-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathgenerate.py
More file actions
70 lines (53 loc) · 1.93 KB
/
generate.py
File metadata and controls
70 lines (53 loc) · 1.93 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
# generate.py
import torch
from model import ChoraleLSTM
from bach_dataset import build_vocab, load_chorales_soprano
from music21 import stream, note
import os
# Maximum number of most-recent token ids fed to the model per sampling step.
SEQ_LEN = 32
# Run on the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
def generate(model, seed, length, idx2tok, temperature=1.0):
    """Autoregressively extend ``seed`` by sampling ``length`` new tokens.

    At each step the last ``SEQ_LEN`` token ids are fed to ``model`` and the
    next id is drawn from the softmax of the returned logits.

    Args:
        model: Next-token model, called as ``model(inp)`` where ``inp`` is a
            ``(1, <= SEQ_LEN)`` long tensor of token ids on ``DEVICE``.
            NOTE(review): ``logits[0]`` is treated as a 1-D vector over the
            vocabulary (``.item()`` below needs a single sample), i.e. the
            model is assumed to return ``(batch, vocab)`` logits for the next
            token only -- confirm against ``ChoraleLSTM.forward``.
        seed: Non-empty sequence of integer token ids to start from. It is
            copied, never modified.
        length: Number of tokens to sample and append.
        idx2tok: Mapping from token id to token, used to decode the result.
        temperature: Positive divisor applied to the logits before the
            softmax. ``1.0`` (the default) samples from the model's
            distribution unchanged; smaller values sharpen it, larger values
            flatten it.

    Returns:
        The decoded tokens of ``seed`` followed by the ``length`` sampled
        tokens, as a list.

    Raises:
        ValueError: If ``seed`` is empty or ``temperature`` is not positive.
    """
    if len(seed) == 0:
        # An empty context window would give the model an empty input tensor.
        raise ValueError("seed must contain at least one token id")
    if temperature <= 0:
        raise ValueError("temperature must be positive")

    model.eval()
    result = list(seed)
    # A single no_grad context for the whole loop: sampling never needs gradients.
    with torch.no_grad():
        for _ in range(length):
            window = torch.tensor(
                [result[-SEQ_LEN:]], dtype=torch.long, device=DEVICE)
            logits = model(window)
            probs = torch.softmax(logits[0] / temperature, dim=0)
            next_token = torch.multinomial(probs, 1).item()
            result.append(next_token)
    return [idx2tok[i] for i in result]
def sequence_to_stream(seq, quarterLength=0.25):
    """Render a decoded token sequence as a flat music21 stream.

    Args:
        seq: Decoded tokens, e.g. the list returned by ``generate``. String
            tokens starting with ``"KEY_"`` are dropped. ``-1`` becomes a
            rest (presumably the dataset's rest encoding -- confirm in
            ``bach_dataset``). Every other value is passed to ``note.Note``
            as a MIDI pitch.
        quarterLength: Duration, in quarter notes, given to every emitted
            note and rest.

    Returns:
        A ``music21.stream.Stream`` with one event per non-key token, in
        input order.
    """
    out = stream.Stream()
    for token in seq:
        is_key_marker = isinstance(token, str) and token.startswith("KEY_")
        if is_key_marker:
            continue
        event = (
            note.Rest(quarterLength=quarterLength)
            if token == -1
            else note.Note(midi=token, quarterLength=quarterLength)
        )
        out.append(event)
    return out
def main():
    """Sample a melody seeded with a key token and save it to ``output/``.

    Steps:
        1. Load sequences with ``load_chorales_soprano()`` and build the
           vocabulary from them. NOTE(review): this assumes ``build_vocab``
           is deterministic, so ids match the ones used when the checkpoint
           was trained -- confirm.
        2. Restore ``chorale_lstm_with_keys.pt`` onto ``DEVICE``.
        3. Seed generation with ``KEY_C_MAJOR`` followed by up to
           ``SEQ_LEN - 1`` tokens of the first sequence, then sample 100
           more tokens.
        4. Write ``output/generated.mid`` and ``output/generated.xml`` and
           call music21's ``show('midi')``.

    Raises:
        ValueError: If no sequences were loaded or the chosen key token is
            not in the vocabulary.
    """
    checkpoint_path = "chorale_lstm_with_keys.pt"
    n_generate = 100
    # Choose a key token that exists in your vocab.
    key_token = "KEY_C_MAJOR"

    sequences = load_chorales_soprano()
    if len(sequences) == 0:
        raise ValueError("load_chorales_soprano() returned no sequences")
    tok2idx, idx2tok = build_vocab(sequences)

    model = ChoraleLSTM(len(tok2idx))
    # torch.load unpickles the file: only load checkpoints you trust.
    model.load_state_dict(torch.load(checkpoint_path, map_location=DEVICE))
    model.to(DEVICE)

    if key_token not in tok2idx:
        raise ValueError(f"{key_token} not in vocab!")

    # Seed = chosen key token + real tokens from the first training sequence.
    # Drop that sequence's own leading key token only if it really has one,
    # so a genuine first pitch is never discarded. The seed stays <= SEQ_LEN.
    source = sequences[0]
    if (len(source) > 0 and isinstance(source[0], str)
            and source[0].startswith("KEY_")):
        source = source[1:]
    seed_seq = [tok2idx[key_token]]
    seed_seq += [tok2idx[p] for p in source[:SEQ_LEN - 1]]

    output = generate(model, seed_seq, n_generate, idx2tok)
    s = sequence_to_stream(output)

    os.makedirs("output", exist_ok=True)
    s.write('midi', fp='output/generated.mid')
    s.write('musicxml', fp='output/generated.xml')
    s.show('midi')  # or s.show() to open notation view


if __name__ == "__main__":
    main()