# Environment setup
from pathlib import Path
import os
import sys

sys.path.append(str(Path(__file__).parent))

# FIXME add weights_only=False in /usr/local/lib/python3.10/site-packages/fairseq/checkpoint_utils.py#315
if os.path.exists('/usr/local/lib/python3.10/site-packages/fairseq/checkpoint_utils.py'):
    file_lines = []
    with open('/usr/local/lib/python3.10/site-packages/fairseq/checkpoint_utils.py', 'r') as f:
        for line in f:
            file_lines.append(line.strip('\n'))
    # Patch line 315 (index 314) so fairseq checkpoints still load under newer torch.load defaults
    file_lines[314] = file_lines[314].replace(
        "state = torch.load(f, map_location=torch.device(\"cpu\"))",
        "state = torch.load(f, map_location=torch.device(\"cpu\"), weights_only=False)"
    )
    with open('/usr/local/lib/python3.10/site-packages/fairseq/checkpoint_utils.py', 'w') as f:
        for line in file_lines:
            f.write(line + '\n')
    print('[DEBUG] added weights_only=False')

# Run
import spaces
import gradio as gr
from zipfile import ZipFile
from typing import Literal
from huggingface_hub import snapshot_download
from fireredtts.models.fireredtts import FireRedTTS

# NOTE disable verbose INFO logs
import logging

httpx_logger = logging.getLogger("httpx")
httpx_logger.setLevel(logging.WARNING)

# NOTE Manual launch setup steps:
# - install fairseq manually (downgrade pip first: "python -m pip install pip==24.0")
# - manually add weights_only=False in /usr/local/lib/python3.10/site-packages/fairseq/checkpoint_utils.py#315

# ================================================
# FireRedTTS-1S Model
# ================================================
# Global model instances, lazily created by initiate_model()
tts_flow: FireRedTTS = None
tts_acollm: FireRedTTS = None


def initiate_model(pretrained_dir: str):
    global tts_flow, tts_acollm
    if tts_flow is None:
        tts_flow = FireRedTTS(
            config_path='configs/config_24k_flow.json',
            pretrained_path=pretrained_dir,
        )
    if tts_acollm is None:
        tts_acollm = FireRedTTS(
            config_path='configs/config_24k.json',
            pretrained_path=pretrained_dir,
        )


# ================================================
# Gradio
# ================================================
# i18n
_i18n_key2lang_dict = dict(
    # Title markdown
    title_md_desc=dict(
        en="FireRedTTS-1s πŸ”₯ Streamable TTS",
        zh="FireRedTTS-1s πŸ”₯ 可桁式TTS",
    ),
    # Decoder choice radio
    decoder_choice_label=dict(
        en="Decoder Choice",
        zh="解码器选择",
    ),
    decoder_choice_1=dict(
        en="Flow Matching",
        zh="Flow Matching",
    ),
    decoder_choice_2=dict(
        en="Acoustic LLM",
        zh="Acoustic LLM",
    ),
    # Speaker Prompt
    spk_prompt_audio_label=dict(
        en="Speaker Prompt Audio",
        zh="ε‚θ€ƒθ―­ιŸ³",
    ),
    spk_prompt_text_label=dict(
        en="Speaker Prompt Text",
        zh="ε‚θ€ƒθ―­ιŸ³ηš„ζ–‡ζœ¬",
    ),
    spk_prompt_text_placeholder=dict(
        en="Speaker Prompt Text",
        zh="ε‚θ€ƒθ―­ιŸ³ηš„ζ–‡ζœ¬",
    ),
    # Input textbox
    target_text_input_label=dict(
        en="Text To Synthesize",
        zh="εΎ…εˆζˆζ–‡ζœ¬",
    ),
    target_text_input_placeholder=dict(
        en="Text To Synthesize",
        zh="εΎ…εˆζˆζ–‡ζœ¬",
    ),
    # Generate button
    generate_btn_label=dict(
        en="Generate Audio",
        zh="合成",
    ),
    # Generated audio
    generated_audio_label=dict(
        en="Generated Audio",
        zh="εˆζˆηš„ιŸ³ι’‘",
    ),
    # Warning 1: incomplete prompt info
    warn_incomplete_prompt=dict(
        en="Please provide prompt audio and text",
        zh="θ―·ζδΎ›θ―΄θ―δΊΊε‚θ€ƒθ―­ιŸ³δΈŽε‚θ€ƒζ–‡ζœ¬",
    ),
    # Warning 2: invalid text for target text input
    warn_invalid_target_text=dict(
        en="Empty input text",
        zh="εΎ…εˆζˆζ–‡ζœ¬δΈΊη©Ί",
    ),
)

global_lang: Literal['zh', 'en'] = 'zh'


def i18n(key):
    global global_lang
    return _i18n_key2lang_dict[key][global_lang]
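
# A quick illustration of the i18n lookup above (values copied from the table, no new keys):
#   with global_lang == 'zh', i18n('generate_btn_label') returns "合成"
#   with global_lang == 'en', i18n('generate_btn_label') returns "Generate Audio"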


def check_monologue_text(text: str, prefix: str = None) -> bool:
    text = text.strip()
    # Check speaker tags
    if prefix is not None and (not text.startswith(prefix)):
        return False
    # Remove prefix
    if prefix is not None:
        text = text.removeprefix(prefix)
    text = text.strip()
    # If empty?
    if len(text) == 0:
        return False
    return True


@spaces.GPU(duration=60)
def synthesis_function(
    spk_prompt_audio: str,
    spk_prompt_text: str,
    target_text: str,
    decoder_choice: Literal[0, 1] = 0,  # 0 means flow matching decoder
):
    global tts_flow, tts_acollm
    # Check prompt info
    spk_prompt_text = spk_prompt_text.strip()
    if spk_prompt_audio is None or spk_prompt_text == "":
        gr.Warning(message=i18n('warn_incomplete_prompt'))
        return None
    # Check target text
    target_text = target_text.strip()
    if target_text == "":
        gr.Warning(message=i18n('warn_invalid_target_text'))
        return None
    # Go synthesis
    if decoder_choice == 0:
        audio = tts_flow.synthesize(
            prompt_wav=spk_prompt_audio,
            prompt_text=spk_prompt_text,
            text=target_text,
            lang="zh",
            use_tn=True,
        )
    else:
        audio = tts_acollm.synthesize(
            prompt_wav=spk_prompt_audio,
            prompt_text=spk_prompt_text,
            text=target_text,
            lang="zh",
            use_tn=True,
        )
    return (24000, audio.detach().cpu().squeeze(0).numpy())


# UI rendering
def render_interface() -> gr.Blocks:
    with gr.Blocks(title="FireRedTTS-1S", theme=gr.themes.Default()) as page:
        # ======================== UI ========================
        # A large title
        title_desc = gr.Markdown(value="# {}".format(i18n('title_md_desc')))
        with gr.Row():
            lang_choice = gr.Radio(
                choices=['δΈ­ζ–‡', 'English'],
                value='δΈ­ζ–‡',
                label='Display Language/ζ˜Ύη€Ίθ―­θ¨€',
                type="index",
                interactive=True,
            )
            decoder_choice = gr.Radio(
                choices=[i18n('decoder_choice_1'), i18n('decoder_choice_2')],
                value=i18n('decoder_choice_1'),
                label=i18n('decoder_choice_label'),
                type="index",
                interactive=True,
            )
        with gr.Row():
            # ==== Speaker Prompt ====
            spk_prompt_text = gr.Textbox(
                label=i18n('spk_prompt_text_label'),
                placeholder=i18n('spk_prompt_text_placeholder'),
                lines=5,
            )
            spk_prompt_audio = gr.Audio(
                label=i18n('spk_prompt_audio_label'),
                type="filepath",
                editable=False,
                interactive=True,
            )  # Audio component returns tmp audio path
            # ==== Target Text ====
            target_text_input = gr.Textbox(
                label=i18n('target_text_input_label'),
                placeholder=i18n('target_text_input_placeholder'),
                lines=5,
            )
        # Generate button
        generate_btn = gr.Button(value=i18n('generate_btn_label'), variant="primary", size="lg")
        # Long output audio
        generate_audio = gr.Audio(
            label=i18n('generated_audio_label'),
            interactive=False,
        )

        # ======================== Action ========================
        # Language action
        def _change_component_language(lang):
            global global_lang
            global_lang = ['zh', 'en'][lang]
            return [
                # title_desc
                gr.update(value="# {}".format(i18n('title_md_desc'))),
                # decoder_choice
                gr.update(label=i18n('decoder_choice_label')),
                # spk_prompt_{text,audio}
                gr.update(label=i18n('spk_prompt_text_label'), placeholder=i18n('spk_prompt_text_placeholder')),
                gr.update(label=i18n('spk_prompt_audio_label')),
                # target_text_input
                gr.update(label=i18n('target_text_input_label'), placeholder=i18n('target_text_input_placeholder')),
                # generate_btn
                gr.update(value=i18n('generate_btn_label')),
                # generate_audio
                gr.update(label=i18n('generated_audio_label')),
            ]

        lang_choice.change(
            fn=_change_component_language,
            inputs=[lang_choice],
            outputs=[
                title_desc,
                decoder_choice,
                spk_prompt_text,
                spk_prompt_audio,
                target_text_input,
                generate_btn,
                generate_audio,
            ],
        )
        generate_btn.click(
            fn=synthesis_function,
            inputs=[spk_prompt_audio, spk_prompt_text, target_text_input, decoder_choice],
            outputs=[generate_audio],
        )
    return page
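

# A minimal sketch of driving the pipeline without the Gradio UI. It reuses
# initiate_model() and synthesis_function() from above and assumes the checkpoints
# have already been downloaded and unzipped as in __main__ below; the prompt audio
# path and texts are hypothetical placeholders, not files shipped with this repo.
# The function is defined but never called, so it does not affect the app itself.
def _example_offline_synthesis():
    initiate_model('pretrained_models/FireRedTTS-1S/pretrained_models')
    result = synthesis_function(
        spk_prompt_audio='examples/prompt.wav',  # hypothetical reference audio path
        spk_prompt_text='ε‚θ€ƒθ―­ιŸ³ηš„ζ–‡ζœ¬',          # transcript of the reference audio
        target_text='δ»Šε€©ε€©ζ°”ηœŸδΈι”™γ€‚',             # text to synthesize (lang="zh" is hard-coded above)
        decoder_choice=0,                         # 0 = flow matching, 1 = acoustic LLM
    )
    if result is not None:
        sample_rate, audio_np = result  # (24000, 1-D float numpy array)
        print(f'[INFO] synthesized {audio_np.shape[0] / sample_rate:.2f}s of audio')
    return result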


if __name__ == '__main__':
    # Download model
    snapshot_download(
        repo_id='FireRedTeam/FireRedTTS-1S',
        local_dir='pretrained_models/FireRedTTS-1S',
    )
    # Unzip model, weights under "pretrained_models/FireRedTTS-1S/pretrained_models"
    with ZipFile('pretrained_models/FireRedTTS-1S/pretrained_models.zip', 'r') as zipf:
        zipf.extractall('pretrained_models/FireRedTTS-1S')
    # Init model
    initiate_model('pretrained_models/FireRedTTS-1S/pretrained_models')
    print('[INFO] model loaded')
    # UI
    page = render_interface()
    page.launch()
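
# If the demo must be reachable from other machines (outside Hugging Face Spaces),
# launch() accepts the usual Gradio options; a sketch, with host and port chosen
# here only as examples:
#   page.launch(server_name="0.0.0.0", server_port=7860)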