chenguittiMaroua committed
Commit e4ab113 · verified · 1 Parent(s): bdb0847

Update main.py

Files changed (1): main.py (+58 -54)
main.py CHANGED
@@ -874,71 +874,75 @@ async def question_answering(
     file: Optional[UploadFile] = File(None)
 ):
     if qa_pipeline is None:
-        raise HTTPException(
-            status_code=503,
-            detail={
-                "error": "QA system unavailable",
-                "status": "No working model could be loaded",
-                "recovery_suggestion": "Please try again later"
-            }
-        )
-
+        raise HTTPException(503, detail="QA system unavailable")
+
     try:
-        # Process input
+        # Process file if provided
         context = None
         if file:
-            try:
-                _, content = await process_uploaded_file(file)
-                context = extract_text(content, file.filename.split('.')[-1])[:1000]  # Smaller context
-            except Exception as e:
-                logger.error(f"File processing failed: {str(e)}")
-                raise HTTPException(422, detail=f"File error: {str(e)}")
-
-        # Generate response - MODIFIED PROMPT HERE
-        try:
-            input_text = f"Réponds à cette question: {question}"  # Changed prompt
-            if context:
-                input_text += f" en utilisant ce contexte: {context[:1000]}"  # Added context
-
-            result = qa_pipeline(
-                input_text,
-                max_length=100,
-                num_beams=2,
-                temperature=0.7,
-                repetition_penalty=2.0
+            _, content = await process_uploaded_file(file)
+            full_text = extract_text(content, file.filename.split('.')[-1])
+            context = re.sub(r'\s+', ' ', full_text).strip()[:2000]  # Clean and limit context
+
+        # Special handling for theme questions
+        theme_keywords = ["thème", "theme", "sujet principal", "quoi le sujet", "de quoi ça parle"]
+        if any(kw in question.lower() for kw in theme_keywords):
+            if not context:
+                return {
+                    "question": question,
+                    "answer": "Aucun document fourni pour déterminer le thème",
+                    "context_used": False
+                }
+
+            # Special prompt for theme detection
+            theme_prompt = (
+                "Extrait le thème principal en 1-2 phrases en français à partir de ce texte. "
+                "Sois concis et précis. Texte:\n" + context[:1500]
             )

-            # ADDED CHECK FOR QUESTION REFORMULATION
-            if "question:" in result[0]["generated_text"].lower():
-                result[0]["generated_text"] = (
-                    "[Note: This is a question-generation model. For better answers, "
-                    "please use a QA model like google/flan-t5-base].\n"
-                    f"Reformulated question: {result[0]['generated_text']}"
-                )
+            theme_result = qa_pipeline(
+                theme_prompt,
+                max_length=150,
+                num_beams=2,
+                temperature=0.3,  # Lower temperature for more focused answers
+                repetition_penalty=2.5
+            )

             return {
                 "question": question,
-                "answer": result[0]["generated_text"],
+                "answer": theme_result[0]["generated_text"],
                 "model": current_model,
-                "context_used": context is not None
+                "context_used": True
             }
-
-        except Exception as e:
-            logger.error(f"Generation failed: {str(e)}")
-            raise HTTPException(
-                status_code=500,
-                detail={
-                    "error": "Answer generation failed",
-                    "model": current_model,
-                    "suggestion": "Try a simpler question or smaller document"
-                }
-            )
-
-    except HTTPException:
-        raise
+
+        # Standard QA handling
+        input_text = f"Réponds en français à: {question}"
+        if context:
+            input_text += f" en utilisant ce contexte: {context[:2000]}"
+
+        result = qa_pipeline(
+            input_text,
+            max_length=150,
+            num_beams=3,
+            temperature=0.7,
+            repetition_penalty=2.0
+        )
+
+        # Post-process answer
+        answer = result[0]["generated_text"]
+        if answer.lower().startswith(("question:", "réponse:")):
+            answer = answer.split(":", 1)[1].strip()
+
+        return {
+            "question": question,
+            "answer": answer,
+            "model": current_model,
+            "context_used": context is not None
+        }
+
     except Exception as e:
-        logger.critical(f"Unexpected error: {str(e)}")
-        raise HTTPException(500, "Internal server error")
+        logger.error(f"Error: {str(e)}")
+        raise HTTPException(500, "Erreur de traitement")
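
A standalone sketch of the answer post-processing added in this commit, pulled out as a hypothetical helper (strip_leading_label is not a name from main.py) to show what the startswith check does: if the model echoes a leading "question:" or "réponse:" label, everything up to the first colon is dropped.

def strip_leading_label(generated: str) -> str:
    # Hypothetical helper mirroring the new post-processing step:
    # drop an echoed "question:"/"réponse:" prefix, case-insensitively.
    if generated.lower().startswith(("question:", "réponse:")):
        return generated.split(":", 1)[1].strip()
    return generated

assert strip_leading_label("Réponse: Le texte traite du climat.") == "Le texte traite du climat."
assert strip_leading_label("Le texte traite du climat.") == "Le texte traite du climat."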
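
For reference, a minimal client sketch exercising the updated endpoint. The base URL, route path "/qa", and form field names are assumptions (the diff only shows the handler body); adjust them to match the actual FastAPI route in main.py.

# Hypothetical client: URL, "/qa" route, and field names are assumptions.
import requests

BASE_URL = "http://localhost:8000"  # assumed local dev server

# A theme question with an uploaded document should take the new
# theme-detection branch (temperature=0.3, repetition_penalty=2.5).
with open("rapport.pdf", "rb") as f:
    resp = requests.post(
        f"{BASE_URL}/qa",
        data={"question": "Quel est le thème principal de ce document ?"},
        files={"file": ("rapport.pdf", f, "application/pdf")},
    )
print(resp.json())
# Expected shape: {"question": ..., "answer": ..., "model": ..., "context_used": true}

# A question without a file goes through the standard QA branch and
# should come back with "context_used": false.
resp = requests.post(f"{BASE_URL}/qa", data={"question": "Qu'est-ce que FastAPI ?"})
print(resp.json())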