import gradio as gr
from transformers import pipeline
import torch
import os
import spaces

# Load the multimodal speech/text translation pipeline
pipe = pipeline(
    task="multimodal_mt",
    model="BSC-LT/salamandra-TAV-7b",
    trust_remote_code=True,
    token=os.environ.get("HF_TOKEN"),
    device_map="auto",
    torch_dtype=torch.float16,
)

# Define the languages for the dropdowns (code -> display name)
LANGUAGES = {
    "autodetect": "Autodetect",
    "en": "English",
    "es": "Spanish",
    "ca": "Catalan",
    "pt": "Portuguese",
    "gl": "Galician",
    "eu": "Basque",
}

# Invert the dictionary for easy lookup (display name -> language code)
LANG_TO_NAME = {v: k for k, v in LANGUAGES.items()}


@spaces.GPU
def process_audio(audio, source_lang_name, target_lang_name):
    """
    Processes the audio input to perform speech-to-text translation or transcription.
    """
    if audio is None:
        return "Please provide an audio file or record one.", ""
    if target_lang_name is None:
        return "Please select a target language.", ""

    source_lang = LANG_TO_NAME.get(source_lang_name)
    target_lang = LANG_TO_NAME.get(target_lang_name)

    generation_kwargs = {"beam_size": 5, "max_new_tokens": 100}

    # Step 1: transcribe the audio (ASR)
    asr_kwargs = {"mode": "asr", "return_chat_history": True, **generation_kwargs}
    if source_lang != "autodetect":
        asr_kwargs["src_lang"] = source_lang_name
    history = pipe(audio, **asr_kwargs)

    # If source and target languages are the same, we're done (transcription only)
    if source_lang == target_lang:
        text = history.get_assistant_messages()[-1]
    else:
        # Step 2: text-to-text translation of the transcript
        t2tt_kwargs = {
            "mode": "t2tt",
            "tgt_lang": target_lang_name,
            "return_chat_history": True,
            **generation_kwargs,
        }
        if source_lang != "autodetect":
            t2tt_kwargs["src_lang"] = source_lang_name
        history = pipe(history, **t2tt_kwargs)
        text = history.get_assistant_messages()[-1]

    detected_language = ""
    if source_lang == "autodetect":
        # Step 3: identify the source language
        lang_history = pipe(history, mode="lid", return_chat_history=True, **generation_kwargs)
        detected_language = lang_history.get_assistant_messages()[-1]

    return text, detected_language


# Create the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# SalamandraTAV: Speech-to-Text Translation Demo")
    gr.Markdown(
        "A multilingual model for Speech-to-Text Translation (S2TT) and Automatic Speech Recognition (ASR) for Iberian languages."
    )
    with gr.Row():
        with gr.Column():
            audio_input = gr.Audio(sources=["upload", "microphone"], type="filepath", label="Audio Input")
            with gr.Row():
                source_lang_dropdown = gr.Dropdown(
                    choices=list(LANGUAGES.values()),
                    value=LANGUAGES["autodetect"],
                    label="Source Language (Optional)",
                )
                target_lang_dropdown = gr.Dropdown(
                    choices=[lang for key, lang in LANGUAGES.items() if key != "autodetect"],
                    label="Target Language (Required)",
                )
            submit_button = gr.Button("Translate/Transcribe")
        with gr.Column():
            output_text = gr.Textbox(label="Output", lines=10, interactive=False)
            detected_lang_output = gr.Textbox(label="Detected Source Language", interactive=False)

    submit_button.click(
        fn=process_audio,
        inputs=[audio_input, source_lang_dropdown, target_lang_dropdown],
        outputs=[output_text, detected_lang_output],
    )

    gr.Markdown("## Examples")
    gr.Examples(
        examples=[
            [
                "https://github.com/voxserv/audio_quality_testing_samples/raw/refs/heads/master/orig/127389__acclivity__thetimehascome.wav",
                LANGUAGES["en"],
                LANGUAGES["es"],
            ],
            [
                "https://github.com/voxserv/audio_quality_testing_samples/raw/refs/heads/master/orig/127389__acclivity__thetimehascome.wav",
                LANGUAGES["en"],
                LANGUAGES["en"],
            ],
        ],
        inputs=[audio_input, source_lang_dropdown, target_lang_dropdown],
        outputs=[output_text, detected_lang_output],
        fn=process_audio,
    )

if __name__ == "__main__":
    demo.launch()
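
# Minimal sanity check without launching the UI (hypothetical audio path; the
# chat-history API used above is assumed from the custom pipeline code shipped
# with BSC-LT/salamandra-TAV-7b):
#
#     text, detected = process_audio("sample.wav", "Autodetect", "Spanish")
#     print(detected, text)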