Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
|
@@ -269,6 +269,31 @@ def check_seq(seq, mel_len, mel_ptcs):
|
|
| 269 |
#==================================================================================
|
| 270 |
|
| 271 |
@spaces.GPU
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 272 |
def Generate_Accompaniment(input_midi,
|
| 273 |
input_melody,
|
| 274 |
melody_patch,
|
|
@@ -278,54 +303,6 @@ def Generate_Accompaniment(input_midi,
|
|
| 278 |
|
| 279 |
#===============================================================================
|
| 280 |
|
| 281 |
-
def generate_full_seq(input_seq,
|
| 282 |
-
max_toks=3072,
|
| 283 |
-
temperature=0.9,
|
| 284 |
-
top_k_value=15,
|
| 285 |
-
verbose=True
|
| 286 |
-
):
|
| 287 |
-
|
| 288 |
-
seq_abs_run_time = sum([t for t in input_seq if t < 128])
|
| 289 |
-
|
| 290 |
-
cur_time = 0
|
| 291 |
-
|
| 292 |
-
full_seq = copy.deepcopy(input_seq)
|
| 293 |
-
|
| 294 |
-
toks_counter = 0
|
| 295 |
-
|
| 296 |
-
while cur_time <= seq_abs_run_time+32:
|
| 297 |
-
|
| 298 |
-
if verbose:
|
| 299 |
-
if toks_counter % 128 == 0:
|
| 300 |
-
print('Generated', toks_counter, 'tokens')
|
| 301 |
-
|
| 302 |
-
x = torch.LongTensor(full_seq).cuda()
|
| 303 |
-
|
| 304 |
-
with ctx:
|
| 305 |
-
out = model.generate(x,
|
| 306 |
-
1,
|
| 307 |
-
filter_logits_fn=top_k,
|
| 308 |
-
filter_kwargs={'k': top_k_value},
|
| 309 |
-
temperature=temperature,
|
| 310 |
-
return_prime=False,
|
| 311 |
-
verbose=False)
|
| 312 |
-
|
| 313 |
-
y = out.tolist()[0][0]
|
| 314 |
-
|
| 315 |
-
if y < 128:
|
| 316 |
-
cur_time += y
|
| 317 |
-
|
| 318 |
-
full_seq.append(y)
|
| 319 |
-
|
| 320 |
-
toks_counter += 1
|
| 321 |
-
|
| 322 |
-
if toks_counter == max_toks:
|
| 323 |
-
return full_seq
|
| 324 |
-
|
| 325 |
-
return full_seq
|
| 326 |
-
|
| 327 |
-
#===============================================================================
|
| 328 |
-
|
| 329 |
print('=' * 70)
|
| 330 |
print('Req start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
|
| 331 |
start_time = reqtime.time()
|
|
@@ -353,7 +330,7 @@ def Generate_Accompaniment(input_midi,
|
|
| 353 |
|
| 354 |
if input_midi:
|
| 355 |
inp_mel = 'Custom MIDI'
|
| 356 |
-
score,
|
| 357 |
|
| 358 |
else:
|
| 359 |
mel_list = [m[0].lower() for m in popular_hook_melodies]
|
|
@@ -366,8 +343,9 @@ def Generate_Accompaniment(input_midi,
|
|
| 366 |
break
|
| 367 |
|
| 368 |
score = popular_hook_melodies[[m[0] for m in popular_hook_melodies].index(inp_mel)][1]
|
| 369 |
-
|
| 370 |
-
|
|
|
|
| 371 |
print('Selected melody:', inp_mel)
|
| 372 |
|
| 373 |
print('Sample score events', score[:12])
|
|
@@ -382,12 +360,12 @@ def Generate_Accompaniment(input_midi,
|
|
| 382 |
|
| 383 |
#==================================================================
|
| 384 |
|
| 385 |
-
|
| 386 |
|
| 387 |
#==================================================================
|
| 388 |
|
| 389 |
-
input_seq = generate_full_seq(
|
| 390 |
-
|
| 391 |
temperature=model_temperature,
|
| 392 |
top_k_value=model_sampling_top_k,
|
| 393 |
)
|
|
|
|
| 269 |
#==================================================================================
|
| 270 |
|
| 271 |
@spaces.GPU
|
| 272 |
+
def generate_full_seq(score_seq,
|
| 273 |
+
score_len,
|
| 274 |
+
temperature=0.9,
|
| 275 |
+
top_k_value=15,
|
| 276 |
+
num_batches=64,
|
| 277 |
+
verbose=True
|
| 278 |
+
):
|
| 279 |
+
|
| 280 |
+
x = torch.LongTensor([score_seq] * num_batches).cuda()
|
| 281 |
+
|
| 282 |
+
with ctx:
|
| 283 |
+
out = model.generate(x,
|
| 284 |
+
32*score_len,
|
| 285 |
+
filter_logits_fn=top_k,
|
| 286 |
+
filter_kwargs={'k': top_k_value},
|
| 287 |
+
temperature=temperature,
|
| 288 |
+
return_prime=False,
|
| 289 |
+
verbose=verbose)
|
| 290 |
+
|
| 291 |
+
output = out.tolist()
|
| 292 |
+
|
| 293 |
+
return output
|
| 294 |
+
|
| 295 |
+
#==================================================================================
|
| 296 |
+
|
| 297 |
def Generate_Accompaniment(input_midi,
|
| 298 |
input_melody,
|
| 299 |
melody_patch,
|
|
|
|
| 303 |
|
| 304 |
#===============================================================================
|
| 305 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 306 |
print('=' * 70)
|
| 307 |
print('Req start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
|
| 308 |
start_time = reqtime.time()
|
|
|
|
| 330 |
|
| 331 |
if input_midi:
|
| 332 |
inp_mel = 'Custom MIDI'
|
| 333 |
+
score, score_len, score_ptcs = load_midi(input_midi.name, melody_patch, use_nth_note)
|
| 334 |
|
| 335 |
else:
|
| 336 |
mel_list = [m[0].lower() for m in popular_hook_melodies]
|
|
|
|
| 343 |
break
|
| 344 |
|
| 345 |
score = popular_hook_melodies[[m[0] for m in popular_hook_melodies].index(inp_mel)][1]
|
| 346 |
+
score_len = len(score) // 3
|
| 347 |
+
score_ptcs = [t-256 for t in score if t > 256]
|
| 348 |
+
|
| 349 |
print('Selected melody:', inp_mel)
|
| 350 |
|
| 351 |
print('Sample score events', score[:12])
|
|
|
|
| 360 |
|
| 361 |
#==================================================================
|
| 362 |
|
| 363 |
+
score_seq = [1792] + score + [1793]
|
| 364 |
|
| 365 |
#==================================================================
|
| 366 |
|
| 367 |
+
input_seq = generate_full_seq(score_seq,
|
| 368 |
+
score_len,
|
| 369 |
temperature=model_temperature,
|
| 370 |
top_k_value=model_sampling_top_k,
|
| 371 |
)
|