Update app.py
Browse files
app.py
CHANGED
|
@@ -220,7 +220,6 @@ if __name__ == "__main__":
|
|
| 220 |
print('Loading MidiCaps dataset...')
|
| 221 |
|
| 222 |
mc_dataset = load_dataset("amaai-lab/MidiCaps")
|
| 223 |
-
# mc_fnames = [f['location'].split('/')[-1].split('.mid')[0] for f in mc_dataset['train']]
|
| 224 |
print('=' * 70)
|
| 225 |
|
| 226 |
print('Loading files list...')
|
|
@@ -231,12 +230,12 @@ if __name__ == "__main__":
|
|
| 231 |
|
| 232 |
print('Loading MIDI corpus embeddings...')
|
| 233 |
|
| 234 |
-
corpus_embeddings = np.load('MIDI_corpus_embeddings_all-
|
| 235 |
print('Done!')
|
| 236 |
print('=' * 70)
|
| 237 |
|
| 238 |
print('Loading Sentence Transformer model...')
|
| 239 |
-
model = SentenceTransformer('all-
|
| 240 |
print('Done!')
|
| 241 |
print('=' * 70)
|
| 242 |
|
|
|
|
| 220 |
print('Loading MidiCaps dataset...')
|
| 221 |
|
| 222 |
mc_dataset = load_dataset("amaai-lab/MidiCaps")
|
|
|
|
| 223 |
print('=' * 70)
|
| 224 |
|
| 225 |
print('Loading files list...')
|
|
|
|
| 230 |
|
| 231 |
print('Loading MIDI corpus embeddings...')
|
| 232 |
|
| 233 |
+
corpus_embeddings = np.load('MIDI_corpus_embeddings_all-mpnet-base-v2.npz')['data']
|
| 234 |
print('Done!')
|
| 235 |
print('=' * 70)
|
| 236 |
|
| 237 |
print('Loading Sentence Transformer model...')
|
| 238 |
+
model = SentenceTransformer('all-mpnet-base-v2')
|
| 239 |
print('Done!')
|
| 240 |
print('=' * 70)
|
| 241 |
|