chenguittiMaroua committed
Commit 23e0e9b · verified · 1 parent: d69a8fd

Update main.py

Files changed (1): main.py (+54, -83)
main.py CHANGED
```diff
@@ -131,73 +131,45 @@ def get_summarizer():
 
-from transformers import pipeline, AutoModelForSeq2SeqLM, AutoTokenizer
-import torch
-
-# Model options (ordered by preference)
-QA_MODELS = [
-    {"name": "google/flan-t5-small", "max_length": 512},
-    {"name": "facebook/bart-large-cnn", "max_length": 1024}
+MODEL_CHOICES = [
+    "patrickvonplaten/t5-tiny-random",  # Tiny test model (always works)
+    "google/flan-t5-small",             # 300MB
+    "google/flan-t5-base",              # 900MB
+    "facebook/bart-large-cnn"           # 1.6GB
 ]
 
-class QASystem:
+class QAService:
     def __init__(self):
         self.model = None
-        self.tokenizer = None
-        self.current_model = None
+        self.model_name = None
         self.device = 0 if torch.cuda.is_available() else -1
 
-    def load_model(self):
-        for model_info in QA_MODELS:
+    def initialize(self):
+        """Try loading models until one succeeds"""
+        for model_name in MODEL_CHOICES:
             try:
-                logger.info(f"Loading model: {model_info['name']}")
-
-                self.tokenizer = AutoTokenizer.from_pretrained(model_info["name"])
-                self.model = AutoModelForSeq2SeqLM.from_pretrained(
-                    model_info["name"],
-                    device_map="auto",
-                    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
+                logger.info(f"Attempting to load {model_name}")
+
+                # Lightweight pipeline initialization
+                self.model = pipeline(
+                    "text2text-generation",
+                    model=model_name,
+                    device=self.device,
+                    torch_dtype=torch.float16 if self.device == 0 else torch.float32
                 )
-                self.current_model = model_info
-                logger.info(f"Successfully loaded {model_info['name']}")
+                self.model_name = model_name
+                logger.info(f"Successfully loaded {model_name}")
                 return True
 
             except Exception as e:
-                logger.warning(f"Failed to load {model_info['name']}: {str(e)}")
+                logger.warning(f"Failed to load {model_name}: {str(e)}")
                 continue
 
         logger.error("All model loading attempts failed")
         return False
 
-    def generate_answer(self, question: str, context: Optional[str] = None):
-        try:
-            if context:
-                input_text = f"question: {question} context: {context[:2000]}"
-            else:
-                input_text = f"question: {question}"
-
-            inputs = self.tokenizer(
-                input_text,
-                return_tensors="pt",
-                truncation=True,
-                max_length=self.current_model["max_length"]
-            ).to(self.device)
-
-            outputs = self.model.generate(
-                **inputs,
-                max_new_tokens=200,
-                num_beams=4,
-                early_stopping=True
-            )
-
-            return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
-
-        except Exception as e:
-            logger.error(f"Generation failed: {str(e)}")
-            raise
-
-# Initialize QA system
-qa_system = QASystem()
+# Global service instance
+qa_service = QAService()
```
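The rewrite drops the explicit AutoModel/AutoTokenizer pair in favor of a single `text2text-generation` pipeline and walks `MODEL_CHOICES` smallest-first until one loads; `patrickvonplaten/t5-tiny-random` is there purely as an always-loadable smoke test, so only the later checkpoints give meaningful answers. A minimal standalone sketch of that fallback pattern (the model list mirrors the diff; the logging setup is scaffolding, since `main.py` configures its own `logger` elsewhere):

```python
# Standalone sketch of the "try each checkpoint until one loads" pattern.
import logging

import torch
from transformers import pipeline

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

MODEL_CHOICES = ["patrickvonplaten/t5-tiny-random", "google/flan-t5-small"]

def load_first_available(choices):
    """Return (name, pipeline) for the first checkpoint that loads, else (None, None)."""
    device = 0 if torch.cuda.is_available() else -1
    for name in choices:
        try:
            return name, pipeline("text2text-generation", model=name, device=device)
        except Exception as exc:  # download/compatibility failure: try the next model
            logger.warning("Failed to load %s: %s", name, exc)
    return None, None

name, qa = load_first_available(MODEL_CHOICES)
if qa is not None:
    # text2text-generation returns a list of {"generated_text": ...} dicts.
    print(name, qa("question: What is the capital of France?")[0]["generated_text"])
```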
```diff
@@ -891,59 +863,58 @@ async def summarize_document(request: Request, file: UploadFile = File(...)):
 from typing import Optional
 
 @app.post("/qa")
-async def question_answering(
+async def handle_qa_request(
     question: str = Form(...),
-    file: Optional[UploadFile] = File(None),
-    language: str = Form("en")
+    file: Optional[UploadFile] = File(None)
 ):
-    # Initialize model if not loaded
-    if not qa_system.model:
-        if not qa_system.load_model():
+    # Initialize service if needed
+    if not qa_service.model:
+        if not qa_service.initialize():
             raise HTTPException(
-                500,
+                status_code=500,
                 detail={
-                    "error": "System initialization failed",
-                    "tried_models": [m["name"] for m in QA_MODELS],
-                    "suggestion": "Check logs for loading errors"
+                    "error": "System unavailable",
+                    "status": "Model initialization failed",
+                    "recovery_suggestion": "Retry in 30 seconds or contact support"
                 }
             )
 
     try:
-        # Process file if provided
+        # Process input
         context = None
         if file:
-            try:
-                file_ext, content = await process_uploaded_file(file)
-                context = extract_text(content, file_ext)
-                context = re.sub(r'\s+', ' ', context).strip()[:3000]
-            except Exception as e:
-                logger.error(f"File processing failed: {str(e)}")
-                raise HTTPException(422, detail=f"File processing error: {str(e)}")
-
-        # Generate answer
+            file_ext, content = await process_uploaded_file(file)
+            context = extract_text(content, file_ext)[:2000]  # Strict limit
+
+        # Generate response
         try:
-            answer = qa_system.generate_answer(question, context)
+            input_text = f"question: {question}" + (f" context: {context}" if context else "")
+            result = qa_service.model(
+                input_text,
+                max_length=150,
+                num_beams=2,
+                early_stopping=True
+            )
 
             return {
                 "question": question,
-                "answer": answer,
-                "model": qa_system.current_model["name"],
-                "source": "document" if context else "general",
-                "language": language
+                "answer": result[0]["generated_text"],
+                "model": qa_service.model_name,
+                "context_used": bool(context)
             }
 
         except Exception as e:
-            logger.error(f"Answer generation failed: {str(e)}")
+            logger.error(f"Generation failed: {str(e)}")
             raise HTTPException(
-                500,
+                status_code=500,
                 detail={
                     "error": "Answer generation failed",
-                    "model": qa_system.current_model["name"],
-                    "input_length": len(question) + (len(context) if context else 0),
-                    "suggestion": "Try simplifying your question or reducing document size"
+                    "model": qa_service.model_name,
+                    "input_size": len(input_text) if 'input_text' in locals() else None,
+                    "suggestion": "Simplify your question or reduce document size"
                 }
             )
 
     except HTTPException:
         raise
     except Exception as e:
```
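With the `language` form field gone, the endpoint now accepts just a required `question` and an optional `file`, both as multipart form data. A hedged client-side sketch; the host, port, and file name are placeholders, since the diff shows nothing about how the app is served:

```python
# Hypothetical calls to the updated /qa endpoint (assumes http://localhost:8000).
import requests

BASE = "http://localhost:8000"

# Question only: the handler answers from the model's general knowledge.
r = requests.post(f"{BASE}/qa", data={"question": "What is FastAPI?"})
print(r.json())  # expected keys: question, answer, model, context_used

# Question plus an uploaded document ("report.txt" is a placeholder path).
with open("report.txt", "rb") as fh:
    r = requests.post(
        f"{BASE}/qa",
        data={"question": "What does this report conclude?"},
        files={"file": ("report.txt", fh, "text/plain")},
    )
print(r.json())
```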
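Finally, a small smoke test for the rewritten handler using FastAPI's `TestClient`. It assumes `main.py` exposes a module-level `app` (implied by the `@app.post` decorators but never shown in the diff) and that the first fallback model can be downloaded on the initial call:

```python
# Hedged smoke test; `from main import app` assumes the usual FastAPI layout.
from fastapi.testclient import TestClient

from main import app

client = TestClient(app)

def test_qa_without_file():
    resp = client.post("/qa", data={"question": "What is 2 + 2?"})
    assert resp.status_code == 200
    body = resp.json()
    assert body["context_used"] is False        # no file was uploaded
    assert {"question", "answer", "model"} <= body.keys()
```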