asigalov61 commited on
Commit
2844146
·
verified ·
1 Parent(s): c094d1e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -54
app.py CHANGED
@@ -269,6 +269,31 @@ def check_seq(seq, mel_len, mel_ptcs):
269
  #==================================================================================
270
 
271
  @spaces.GPU
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
272
  def Generate_Accompaniment(input_midi,
273
  input_melody,
274
  melody_patch,
@@ -278,54 +303,6 @@ def Generate_Accompaniment(input_midi,
278
 
279
  #===============================================================================
280
 
281
- def generate_full_seq(input_seq,
282
- max_toks=3072,
283
- temperature=0.9,
284
- top_k_value=15,
285
- verbose=True
286
- ):
287
-
288
- seq_abs_run_time = sum([t for t in input_seq if t < 128])
289
-
290
- cur_time = 0
291
-
292
- full_seq = copy.deepcopy(input_seq)
293
-
294
- toks_counter = 0
295
-
296
- while cur_time <= seq_abs_run_time+32:
297
-
298
- if verbose:
299
- if toks_counter % 128 == 0:
300
- print('Generated', toks_counter, 'tokens')
301
-
302
- x = torch.LongTensor(full_seq).cuda()
303
-
304
- with ctx:
305
- out = model.generate(x,
306
- 1,
307
- filter_logits_fn=top_k,
308
- filter_kwargs={'k': top_k_value},
309
- temperature=temperature,
310
- return_prime=False,
311
- verbose=False)
312
-
313
- y = out.tolist()[0][0]
314
-
315
- if y < 128:
316
- cur_time += y
317
-
318
- full_seq.append(y)
319
-
320
- toks_counter += 1
321
-
322
- if toks_counter == max_toks:
323
- return full_seq
324
-
325
- return full_seq
326
-
327
- #===============================================================================
328
-
329
  print('=' * 70)
330
  print('Req start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
331
  start_time = reqtime.time()
@@ -353,7 +330,7 @@ def Generate_Accompaniment(input_midi,
353
 
354
  if input_midi:
355
  inp_mel = 'Custom MIDI'
356
- score, score_list = load_midi(input_midi.name, melody_patch, use_nth_note)
357
 
358
  else:
359
  mel_list = [m[0].lower() for m in popular_hook_melodies]
@@ -366,8 +343,9 @@ def Generate_Accompaniment(input_midi,
366
  break
367
 
368
  score = popular_hook_melodies[[m[0] for m in popular_hook_melodies].index(inp_mel)][1]
369
- score_list = [[[score[i]], score[i+1:i+3]] for i in range(0, len(score)-3, 3)]
370
-
 
371
  print('Selected melody:', inp_mel)
372
 
373
  print('Sample score events', score[:12])
@@ -382,12 +360,12 @@ def Generate_Accompaniment(input_midi,
382
 
383
  #==================================================================
384
 
385
- start_score_seq = [1792] + score + [1793]
386
 
387
  #==================================================================
388
 
389
- input_seq = generate_full_seq(start_score_seq,
390
- max_toks=MAX_GEN_TOKS,
391
  temperature=model_temperature,
392
  top_k_value=model_sampling_top_k,
393
  )
 
269
  #==================================================================================
270
 
271
  @spaces.GPU
272
+ def generate_full_seq(score_seq,
273
+ score_len,
274
+ temperature=0.9,
275
+ top_k_value=15,
276
+ num_batches=64,
277
+ verbose=True
278
+ ):
279
+
280
+ x = torch.LongTensor([score_seq] * num_batches).cuda()
281
+
282
+ with ctx:
283
+ out = model.generate(x,
284
+ 32*score_len,
285
+ filter_logits_fn=top_k,
286
+ filter_kwargs={'k': top_k_value},
287
+ temperature=temperature,
288
+ return_prime=False,
289
+ verbose=verbose)
290
+
291
+ output = out.tolist()
292
+
293
+ return output
294
+
295
+ #==================================================================================
296
+
297
  def Generate_Accompaniment(input_midi,
298
  input_melody,
299
  melody_patch,
 
303
 
304
  #===============================================================================
305
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
306
  print('=' * 70)
307
  print('Req start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
308
  start_time = reqtime.time()
 
330
 
331
  if input_midi:
332
  inp_mel = 'Custom MIDI'
333
+ score, score_len, score_ptcs = load_midi(input_midi.name, melody_patch, use_nth_note)
334
 
335
  else:
336
  mel_list = [m[0].lower() for m in popular_hook_melodies]
 
343
  break
344
 
345
  score = popular_hook_melodies[[m[0] for m in popular_hook_melodies].index(inp_mel)][1]
346
+ score_len = len(score) // 3
347
+ score_ptcs = [t-256 for t in score if t > 256]
348
+
349
  print('Selected melody:', inp_mel)
350
 
351
  print('Sample score events', score[:12])
 
360
 
361
  #==================================================================
362
 
363
+ score_seq = [1792] + score + [1793]
364
 
365
  #==================================================================
366
 
367
+ input_seq = generate_full_seq(score_seq,
368
+ score_len,
369
  temperature=model_temperature,
370
  top_k_value=model_sampling_top_k,
371
  )