# English ↔ Persian translator — Gradio app for a Hugging Face Space,
# backed by the MADLAD-400-3B multilingual machine-translation model.
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
# Load the model and tokenizer once at module import time.
# jbochi/madlad400-3b-mt is a 3B-parameter multilingual MT model on the
# Hugging Face Hub; the first run downloads several GB of weights.
model_name = "jbochi/madlad400-3b-mt"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,  # half precision to halve weight memory
    device_map="auto"  # NOTE(review): requires the `accelerate` package — confirm it is installed
)
def translate_text(text, source_lang, target_lang):
    """Translate text between English and Persian using MADLAD-400-3B.

    Args:
        text: The text to translate.
        source_lang: Display name of the source language ("English" or "Persian").
        target_lang: Display name of the target language ("English" or "Persian").

    Returns:
        The translated text, or an ``"Error..."`` message string on failure
        (this function never raises — the UI shows whatever it returns).
    """
    # Map UI display names to the model's ISO language codes.
    lang_codes = {
        "English": "en",
        "Persian": "fa"
    }

    # Validate up front instead of raising an uncaught KeyError; keep the
    # error-string contract the rest of the function already uses.
    unknown = [lang for lang in (source_lang, target_lang) if lang not in lang_codes]
    if unknown:
        return f"Error: unsupported language(s): {', '.join(unknown)}"

    target_code = lang_codes[target_lang]

    # MADLAD-400 expects a "<2xx>" target-language token prefix; the source
    # language is detected by the model, so no source code is needed in the
    # prompt (the original computed an unused `source_code`).
    prompt = f"<2{target_code}> {text}"

    try:
        # Tokenize input, truncating overly long text to the model's limit.
        inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
        # Keep input tensors on the same device as the (possibly sharded) model.
        inputs = {k: v.to(model.device) for k, v in inputs.items()}

        # Generate translation; no_grad avoids building autograd state.
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=512,  # bound generated tokens explicitly, not total length
                num_beams=5,
                early_stopping=True,
                no_repeat_ngram_size=3,
                length_penalty=1.0
            )

        # Decode the first (and only) beam-search result.
        return tokenizer.decode(outputs[0], skip_special_tokens=True)
    except Exception as e:
        # Surface the failure in the UI instead of crashing the app.
        return f"Error during translation: {str(e)}"
# Gradio interface.
# NOTE: inside gr.Blocks, the order in which components are created
# determines the on-screen layout, so statement order here is significant.
with gr.Blocks(title="English-Persian Translator") as demo:
    gr.Markdown(
        """
        # 🌍 English-Persian Translator
        **Author: Saman Zeitouinan**
        Translate text between English and Persian.
        """
    )
    with gr.Row():
        with gr.Column():
            # Left column: source language, input text, and the trigger button.
            source_lang = gr.Dropdown(
                choices=["English", "Persian"],
                value="English",
                label="Source Language"
            )
            input_text = gr.Textbox(
                lines=5,
                placeholder="Enter text to translate...",
                label="Input Text"
            )
            translate_btn = gr.Button("Translate", variant="primary")
        with gr.Column():
            # Right column: target language and the read-only result box.
            target_lang = gr.Dropdown(
                choices=["Persian", "English"],
                value="Persian",
                label="Target Language"
            )
            output_text = gr.Textbox(
                lines=5,
                label="Translated Text",
                interactive=False
            )
    # Clickable example inputs shown under the interface.
    gr.Examples(
        examples=[
            ["Hello, how are you today?", "English", "Persian"],
            ["What is your name?", "English", "Persian"],
            ["سلام، حالتون چطوره؟", "Persian", "English"],
            ["امروز هوا خوب است", "Persian", "English"]
        ],
        inputs=[input_text, source_lang, target_lang],
        outputs=output_text,
        fn=translate_text,
        cache_examples=False  # avoid running the 3B model at startup to build a cache
    )
    # Connect the button to the translation function.
    translate_btn.click(
        fn=translate_text,
        inputs=[input_text, source_lang, target_lang],
        outputs=output_text
    )
    # Auto-update target language based on source selection.
    def update_target_lang(source_lang):
        # Return the opposite language so source and target never match.
        return "Persian" if source_lang == "English" else "English"
    source_lang.change(
        fn=update_target_lang,
        inputs=source_lang,
        outputs=target_lang
    )
if __name__ == "__main__":
    # Start the Gradio server: listen on all interfaces (required for
    # containerized hosting), no public share link, debug logging on.
    launch_options = dict(server_name="0.0.0.0", share=False, debug=True)
    demo.launch(**launch_options)