import os
import spaces
import pickle
import subprocess
import torch
import torch.nn as nn
import gradio as gr
from dataclasses import asdict
from transformers import T5Tokenizer
from huggingface_hub import hf_hub_download
from time import time_ns
from uuid import uuid4
from transformer_model import Transformer
from pyharp.core import ModelCard, build_endpoint
from pyharp.labels import LabelList

# Model/artifacts from HF Hub
REPO_ID = "amaai-lab/text2midi"
MODEL_PATH = hf_hub_download(repo_id=REPO_ID, filename="pytorch_model.bin")
TOKENIZER_PATH = hf_hub_download(repo_id=REPO_ID, filename="vocab_remi.pkl")
# Optional, only if you later add WAV preview:
SOUNDFONT_PATH = hf_hub_download(repo_id=REPO_ID, filename="soundfont.sf2")


# (Optional) MIDI -> WAV
def save_wav(midi_path: str) -> str:
    directory = os.path.dirname(midi_path) or "."
    stem = os.path.splitext(os.path.basename(midi_path))[0]
    midi_filepath = os.path.join(directory, f"{stem}.mid")
    wav_filepath = os.path.join(directory, f"{stem}.wav")
    cmd = (
        f"fluidsynth -r 16000 {SOUNDFONT_PATH} -g 1.0 --quiet --no-shell "
        f"{midi_filepath} -T wav -F {wav_filepath} > /dev/null"
    )
    subprocess.run(cmd, shell=True, check=False)
    return wav_filepath


# Helpers
def _unique_path(ext: str) -> str:
    """Create a unique file path in /tmp to avoid naming collisions."""
    return os.path.join("/tmp", f"t2m_{time_ns()}_{uuid4().hex[:8]}{ext}")


# Core Text -> MIDI
def generate_midi(prompt: str, temperature: float = 0.9, max_len: int = 500) -> str:
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Load REMI vocab/tokenizer (pickle dict used by the provided model)
    with open(TOKENIZER_PATH, "rb") as f:
        r_tokenizer = pickle.load(f)
    vocab_size = len(r_tokenizer)

    model = Transformer(
        vocab_size,     # vocab size
        768,            # d_model
        8,              # nhead
        2048,           # dim_feedforward
        18,             # nlayers
        1024,           # max_seq_len
        False,          # use_rotary
        8,              # rotary_dim
        device=device,  # device
    )
    model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
    model.eval()

    tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-base")
    inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)
    input_ids = nn.utils.rnn.pad_sequence(inputs.input_ids, batch_first=True, padding_value=0).to(device)
    attention_mask = nn.utils.rnn.pad_sequence(inputs.attention_mask, batch_first=True, padding_value=0).to(device)

    with torch.no_grad():
        output = model.generate(input_ids, attention_mask, max_len=max_len, temperature=temperature)
    output_list = output[0].tolist()

    generated_midi = r_tokenizer.decode(output_list)
    midi_path = _unique_path(".mid")
    generated_midi.dump_midi(midi_path)
    return midi_path


# HARP process function
# Return JSON first, MIDI second
@spaces.GPU(duration=120)
def process_fn(prompt: str, temperature: float, max_length: int):
    try:
        midi_path = generate_midi(prompt, float(temperature), int(max_length))
        labels = LabelList()  # add MidiLabel entries here if you have metadata
        return asdict(labels), midi_path
    except Exception as e:
        # On error: return JSON with error message, and no file
        return {"message": f"Error: {e}"}, None


# HARP Model Card
model_card = ModelCard(
    name="Text2MIDI Generation",
    description=(
        "Turn your musical ideas into playable MIDI notes. \n"
        "Input: Describe what you'd like to hear. For example: a gentle piano lullaby with soft strings. \n"
        "Output: This model will generate a matching MIDI sequence for playback or editing. \n"
        "Use the sliders to control the amount of creativity and length."
    ),
    author="Keshav Bhandari, Abhinaba Roy, Kyra Wang, Geeta Puri, Simon Colton, Dorien Herremans",
    tags=["text-to-music", "midi", "generation"],
)

# Gradio + HARP UI
with gr.Blocks() as demo:
    gr.Markdown("## 🎶 text2midi")

    # Inputs
    prompt_in = gr.Textbox(
        label="Describe Your Music",
        info="Type a short phrase like 'calm piano with flowing arpeggios'",
    ).harp_required(True)

    temperature_in = gr.Slider(
        minimum=0.8, maximum=1.1, value=0.9, step=0.1,
        label="Creativity",
        info=(
            "Adjusts how much freedom the model takes while composing.\n"
            "Lower = safer and more predictable (structured), "
            "Higher = more varied and expressive."
        ),
        interactive=True,
    )

    maxlen_in = gr.Slider(
        minimum=500, maximum=1500, step=100, value=500,
        label="Composition Length",
        info=(
            "Determines how long the generated piece is in musical tokens.\n"
            "Higher values produce longer phrases (roughly more measures of music)."
        ),
    )

    # Outputs (JSON FIRST for HARP, then MIDI)
    labels_out = gr.JSON(label="Labels / Metadata")
    midi_out = gr.File(label="Generated MIDI", file_types=[".mid", ".midi"], type="filepath")

    # Build HARP endpoint
    _ = build_endpoint(
        model_card=model_card,
        input_components=[prompt_in, temperature_in, maxlen_in],
        output_components=[labels_out, midi_out],  # JSON first
        process_fn=process_fn,
    )

# Launch App
demo.launch(share=True, show_error=True, debug=True)