samyhusy commited on
Commit
48e3908
·
verified ·
1 Parent(s): 51c8816

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +51 -25
app.py CHANGED
@@ -4,33 +4,37 @@ import torch
4
 
5
  # Load the model and tokenizer
6
  model_name = "jbochi/madlad400-3b-mt"
7
-
8
- @gr.cache_resource()
9
- def load_model():
10
- tokenizer = AutoTokenizer.from_pretrained(model_name)
11
- model = AutoModelForSeq2SeqLM.from_pretrained(
12
- model_name,
13
- torch_dtype=torch.float16,
14
- device_map="auto"
15
- )
16
- return tokenizer, model
17
-
18
- tokenizer, model = load_model()
19
 
20
  def translate_text(text, source_lang, target_lang):
21
  """
22
  Translate text between English and Persian using MADLAD-400-3B
23
  """
24
- lang_codes = {"English": "en", "Persian": "fa"}
 
 
 
 
 
25
  source_code = lang_codes[source_lang]
26
  target_code = lang_codes[target_lang]
27
 
 
28
  prompt = f"<2{target_code}> {text}"
29
 
30
  try:
 
31
  inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
 
 
32
  inputs = {k: v.to(model.device) for k, v in inputs.items()}
33
 
 
34
  with torch.no_grad():
35
  outputs = model.generate(
36
  **inputs,
@@ -38,45 +42,55 @@ def translate_text(text, source_lang, target_lang):
38
  num_beams=5,
39
  early_stopping=True,
40
  no_repeat_ngram_size=3,
 
41
  )
42
 
 
43
  translated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
 
44
  return translated_text
45
 
46
  except Exception as e:
47
  return f"Error during translation: {str(e)}"
48
 
49
- with gr.Blocks(title="English-Persian Translator", theme=gr.themes.Soft()) as demo:
50
- gr.Markdown("# 🌍 English-Persian Translator")
51
- gr.Markdown("Translate text between English and Persian using MADLAD-400-3B model")
 
 
 
 
 
 
 
52
 
53
  with gr.Row():
54
  with gr.Column():
55
  source_lang = gr.Dropdown(
56
  choices=["English", "Persian"],
57
  value="English",
58
- label="From"
59
  )
60
  input_text = gr.Textbox(
61
- lines=4,
62
  placeholder="Enter text to translate...",
63
  label="Input Text"
64
  )
 
65
 
66
  with gr.Column():
67
  target_lang = gr.Dropdown(
68
  choices=["Persian", "English"],
69
  value="Persian",
70
- label="To"
71
  )
72
  output_text = gr.Textbox(
73
- lines=4,
74
  label="Translated Text",
75
  interactive=False
76
  )
77
 
78
- translate_btn = gr.Button("Translate ✨", variant="primary")
79
-
80
  gr.Examples(
81
  examples=[
82
  ["Hello, how are you today?", "English", "Persian"],
@@ -86,19 +100,31 @@ with gr.Blocks(title="English-Persian Translator", theme=gr.themes.Soft()) as de
86
  ],
87
  inputs=[input_text, source_lang, target_lang],
88
  outputs=output_text,
89
- fn=translate_text
 
90
  )
91
 
 
92
  translate_btn.click(
93
  fn=translate_text,
94
  inputs=[input_text, source_lang, target_lang],
95
  outputs=output_text
96
  )
97
 
 
98
  def update_target_lang(source_lang):
99
  return "Persian" if source_lang == "English" else "English"
100
 
101
- source_lang.change(update_target_lang, source_lang, target_lang)
 
 
 
 
102
 
103
  if __name__ == "__main__":
104
- demo.launch()
 
 
 
 
 
 
4
 
5
  # Load the model and tokenizer
6
  model_name = "jbochi/madlad400-3b-mt"
7
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
8
+ model = AutoModelForSeq2SeqLM.from_pretrained(
9
+ model_name,
10
+ torch_dtype=torch.float16, # Use float16 to reduce memory usage
11
+ device_map="auto"
12
+ )
 
 
 
 
 
 
13
 
14
  def translate_text(text, source_lang, target_lang):
15
  """
16
  Translate text between English and Persian using MADLAD-400-3B
17
  """
18
+ # Define language codes for the model
19
+ lang_codes = {
20
+ "English": "en",
21
+ "Persian": "fa"
22
+ }
23
+
24
  source_code = lang_codes[source_lang]
25
  target_code = lang_codes[target_lang]
26
 
27
+ # Create the translation prompt in the format the model expects
28
  prompt = f"<2{target_code}> {text}"
29
 
30
  try:
31
+ # Tokenize input
32
  inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
33
+
34
+ # Move inputs to the same device as model
35
  inputs = {k: v.to(model.device) for k, v in inputs.items()}
36
 
37
+ # Generate translation
38
  with torch.no_grad():
39
  outputs = model.generate(
40
  **inputs,
 
42
  num_beams=5,
43
  early_stopping=True,
44
  no_repeat_ngram_size=3,
45
+ length_penalty=1.0
46
  )
47
 
48
+ # Decode the output
49
  translated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
50
+
51
  return translated_text
52
 
53
  except Exception as e:
54
  return f"Error during translation: {str(e)}"
55
 
56
+ # Create the Gradio interface
57
+ with gr.Blocks(title="English-Persian Translator") as demo:
58
+ gr.Markdown(
59
+ """
60
+ # 🌍 English-Persian Translator
61
+ **Powered by MADLAD-400-3B Model**
62
+
63
+ Translate text between English and Persian using the state-of-the-art MADLAD-400 model.
64
+ """
65
+ )
66
 
67
  with gr.Row():
68
  with gr.Column():
69
  source_lang = gr.Dropdown(
70
  choices=["English", "Persian"],
71
  value="English",
72
+ label="Source Language"
73
  )
74
  input_text = gr.Textbox(
75
+ lines=5,
76
  placeholder="Enter text to translate...",
77
  label="Input Text"
78
  )
79
+ translate_btn = gr.Button("Translate", variant="primary")
80
 
81
  with gr.Column():
82
  target_lang = gr.Dropdown(
83
  choices=["Persian", "English"],
84
  value="Persian",
85
+ label="Target Language"
86
  )
87
  output_text = gr.Textbox(
88
+ lines=5,
89
  label="Translated Text",
90
  interactive=False
91
  )
92
 
93
+ # Examples
 
94
  gr.Examples(
95
  examples=[
96
  ["Hello, how are you today?", "English", "Persian"],
 
100
  ],
101
  inputs=[input_text, source_lang, target_lang],
102
  outputs=output_text,
103
+ fn=translate_text,
104
+ cache_examples=False
105
  )
106
 
107
+ # Connect the button
108
  translate_btn.click(
109
  fn=translate_text,
110
  inputs=[input_text, source_lang, target_lang],
111
  outputs=output_text
112
  )
113
 
114
+ # Auto-update target language based on source selection
115
  def update_target_lang(source_lang):
116
  return "Persian" if source_lang == "English" else "English"
117
 
118
+ source_lang.change(
119
+ fn=update_target_lang,
120
+ inputs=source_lang,
121
+ outputs=target_lang
122
+ )
123
 
124
  if __name__ == "__main__":
125
+ # Launch the app
126
+ demo.launch(
127
+ server_name="0.0.0.0", # Allow external access
128
+ share=False, # Set to True to get a public URL
129
+ debug=True
130
+ )