Spaces:
Sleeping
Sleeping
Update main.py
Browse files
main.py
CHANGED
|
@@ -798,71 +798,168 @@ async def summarize_document(request: Request, file: UploadFile = File(...)):
|
|
| 798 |
logger.error(f"Summarization failed: {str(e)}", exc_info=True)
|
| 799 |
raise HTTPException(500, "Document summarization failed")
|
| 800 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 801 |
@app.post("/qa")
|
| 802 |
@limiter.limit("5/minute")
|
| 803 |
-
async def
|
| 804 |
request: Request,
|
| 805 |
-
file: UploadFile = File(
|
|
|
|
| 806 |
question: str = Form(...),
|
| 807 |
-
language: str = Form("
|
| 808 |
):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 809 |
try:
|
| 810 |
-
|
| 811 |
-
text =
|
| 812 |
|
| 813 |
-
|
| 814 |
-
|
| 815 |
-
|
| 816 |
-
# Clean and truncate text
|
| 817 |
-
text = re.sub(r'\s+', ' ', text).strip()[:5000]
|
| 818 |
-
|
| 819 |
-
# Theme detection
|
| 820 |
-
theme_keywords = ["thème", "sujet principal", "quoi le sujet", "theme", "main topic"]
|
| 821 |
-
if any(kw in question.lower() for kw in theme_keywords):
|
| 822 |
-
try:
|
| 823 |
-
summarizer = get_summarizer()
|
| 824 |
-
summary_output = summarizer(
|
| 825 |
-
text,
|
| 826 |
-
max_length=min(100, len(text)//4),
|
| 827 |
-
min_length=30,
|
| 828 |
-
do_sample=False,
|
| 829 |
-
truncation=True
|
| 830 |
-
)
|
| 831 |
-
|
| 832 |
-
theme = summary_output[0].get("summary_text", text[:200] + "...")
|
| 833 |
-
return {
|
| 834 |
-
"question": question,
|
| 835 |
-
"answer": f"Le document traite principalement de : {theme}",
|
| 836 |
-
"confidence": 0.95,
|
| 837 |
-
"language": language
|
| 838 |
-
}
|
| 839 |
-
except Exception:
|
| 840 |
-
theme = text[:200] + ("..." if len(text) > 200 else "")
|
| 841 |
-
return {
|
| 842 |
-
"question": question,
|
| 843 |
-
"answer": f"D'après le document : {theme}",
|
| 844 |
-
"confidence": 0.7,
|
| 845 |
-
"language": language,
|
| 846 |
-
"warning": "theme_summary_fallback"
|
| 847 |
-
}
|
| 848 |
-
|
| 849 |
-
# Standard QA
|
| 850 |
-
qa = get_qa_model()
|
| 851 |
-
result = qa(question=question, context=text[:3000])
|
| 852 |
|
| 853 |
-
|
| 854 |
-
|
| 855 |
-
|
| 856 |
-
|
| 857 |
-
|
| 858 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 859 |
|
| 860 |
except HTTPException:
|
| 861 |
raise
|
| 862 |
except Exception as e:
|
| 863 |
-
logger.error(f"QA
|
| 864 |
raise HTTPException(500, detail=f"Analysis failed: {str(e)}")
|
| 865 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 866 |
|
| 867 |
@app.post("/visualize/natural")
|
| 868 |
async def natural_language_visualization(
|
|
|
|
| 798 |
logger.error(f"Summarization failed: {str(e)}", exc_info=True)
|
| 799 |
raise HTTPException(500, "Document summarization failed")
|
| 800 |
|
| 801 |
+
from typing import Optional
|
| 802 |
+
import re
|
| 803 |
+
from fastapi import HTTPException
|
| 804 |
+
from concurrent.futures import ThreadPoolExecutor
|
| 805 |
+
|
| 806 |
+
executor = ThreadPoolExecutor(max_workers=4)
|
| 807 |
+
|
| 808 |
@app.post("/qa")
@limiter.limit("5/minute")
async def universal_question_answering(
    request: Request,
    file: UploadFile = File(None),
    text_input: str = Form(None),
    question: str = Form(...),
    language: str = Form("en")
):
    """
    Answer a question about an uploaded file or about directly supplied text.

    The content is obtained via ``extract_content``, the question is labelled
    by ``classify_question``, and the ``handle_*`` coroutine matching that
    label builds the response. Any label without a dedicated handler is
    served by ``handle_general``.

    HTTPException raised along the way is passed through unchanged; any other
    exception is logged and reported as a 500 error.
    """
    try:
        # Resolve the content first, then decide how the question is answered.
        content_text = await extract_content(file, text_input)
        kind = classify_question(question, language)

        # Handler names are looked up only in the branch that is taken.
        if kind == "theme":
            handler = handle_theme
        elif kind == "summary":
            handler = handle_summary
        elif kind == "fact":
            handler = handle_factual
        elif kind == "list":
            handler = handle_list
        elif kind == "comparison":
            handler = handle_comparison
        else:
            handler = handle_general

        return await handler(content_text, question, language)

    except HTTPException:
        # Deliberate HTTP errors keep their own status code and detail.
        raise
    except Exception as exc:
        logger.error(f"QA failed: {str(exc)}", exc_info=True)
        raise HTTPException(500, detail=f"Analysis failed: {str(exc)}")
|
| 850 |
|
| 851 |
+
async def extract_content(file: Optional[UploadFile], text_input: Optional[str]) -> str:
    """Return cleaned, length-capped text taken from an upload or from raw text.

    When ``file`` is provided it takes precedence over ``text_input``: it is
    passed through ``process_uploaded_file`` and ``extract_text`` runs on the
    module-level thread pool. Otherwise ``text_input`` is used as is.

    Args:
        file: Uploaded file, or a falsy value when none was sent.
        text_input: Raw text used when no file is given.

    Returns:
        The text with every whitespace run collapsed to one space, passed
        through ``smart_truncate`` with a limit of 15000.

    Raises:
        HTTPException: 400 when neither source is provided, or when the
            chosen source contains no non-whitespace text.
    """
    if file:
        file_ext, content = await process_uploaded_file(file)
        # Inside a coroutine the running loop is the one to schedule on.
        loop = asyncio.get_running_loop()
        text = await loop.run_in_executor(executor, extract_text, content, file_ext)
    elif text_input:
        text = text_input
    else:
        raise HTTPException(400, "Either file or text_input must be provided")

    # extract_text's return contract is not visible here; a None result is
    # treated like empty text so the caller gets a 400 rather than a 500.
    if not text or not text.strip():
        raise HTTPException(400, "No extractable content found")

    # Collapse every whitespace run (newlines included) into a single space.
    text = re.sub(r'\s+', ' ', text).strip()
    return smart_truncate(text, 15000)
|
| 868 |
+
|
| 869 |
+
def classify_question(question: str, language: str) -> str:
    """Label a question using keyword matching only.

    Checks run in a fixed order and the first match wins: theme, summary,
    fact, list, comparison. Theme and summary keywords depend on the
    language; the remaining keyword lists are shared by all languages.

    Args:
        question: The question as typed by the user.
        language: Language tag such as "en", "FR" or "fr-FR". Only the
            primary subtag is used; unknown or empty values fall back to
            the English keyword lists.

    Returns:
        One of "theme", "summary", "fact", "list", "comparison", "general".
    """
    question_lower = question.lower().strip()
    lang = (language or "en").strip().lower().replace("_", "-").split("-")[0]

    def mentions(keywords) -> bool:
        # Anchor each keyword at a word start so "tema" does not fire inside
        # "sistema" (nor "list" inside "specialist"), while inflected forms
        # such as "themes" or "summarized" still match.
        return any(
            re.search(r"\b" + re.escape(kw), question_lower) for kw in keywords
        )

    # Theme detection
    theme_keywords = {
        "en": ["theme", "main topic", "about", "subject"],
        "fr": ["thème", "sujet principal", "parle de"],
        "es": ["tema", "asunto principal"],
    }
    if mentions(theme_keywords.get(lang, theme_keywords["en"])):
        return "theme"

    # Summary detection
    summary_keywords = {
        "en": ["summarize", "overview", "brief"],
        "fr": ["résumer", "aperçu"],
        "es": ["resumir", "resumen"],
    }
    if mentions(summary_keywords.get(lang, summary_keywords["en"])):
        return "summary"

    # Factual questions: the interrogative must open the question. Leading
    # whitespace and inverted punctuation are ignored for this check.
    factual_keywords = ("what", "when", "who", "which", "where", "quoi", "quand", "qui")
    opening = question_lower.lstrip("¿¡ \t")
    if opening.startswith(factual_keywords):
        return "fact"

    # List questions
    if mentions(["list", "examples", "name all", "énumérer"]):
        return "list"

    # Comparison questions
    if mentions(["compare", "difference", "contrast", "comparer"]):
        return "comparison"

    return "general"
|
| 907 |
+
|
| 908 |
+
async def handle_theme(text: str, question: str, language: str) -> dict:
    """Answer a theme/topic question with a short summary of the text.

    The summarizer returned by ``get_summarizer`` runs on the module-level
    thread pool. Its output is wrapped in a sentence in English, French or
    Spanish; any other ``language`` value gets the English sentence.

    Args:
        text: Content to summarize.
        question: Original question, echoed back in the response.
        language: Language code selecting the answer sentence.

    Returns:
        The dict built by ``format_response``. The confidence is the fixed
        value 0.95, not a score produced by the model.
    """
    summarizer = get_summarizer()

    def summarize() -> str:
        # truncation=True lets the pipeline clip input that exceeds the
        # model's maximum length instead of failing on long documents.
        output = summarizer(
            text,
            max_length=150,
            min_length=50,
            do_sample=False,
            truncation=True,
        )
        return output[0]["summary_text"]

    summary = await asyncio.get_running_loop().run_in_executor(executor, summarize)

    templates = {
        "en": "The main theme is: {summary}",
        "fr": "Le thème principal est : {summary}",
        "es": "El tema principal es: {summary}",
    }
    template = templates.get(language, templates["en"])
    return format_response(question, template.format(summary=summary), 0.95)
|
| 927 |
+
|
| 928 |
+
async def handle_factual(text: str, question: str, language: str) -> dict:
    """Answer a factual question with the extractive QA model.

    ``select_relevant_context`` chooses the context handed to the model
    returned by ``get_qa_model``; the model call runs on the module-level
    thread pool.

    Args:
        text: Full content the answer is looked up in.
        question: Question passed to the QA model.
        language: Accepted for signature parity with the other handlers;
            not used here.

    Returns:
        The dict built by ``format_response`` from the model's ``answer``
        and ``score`` fields.
    """
    qa = get_qa_model()
    context = select_relevant_context(text, question)

    # Inside a coroutine the running loop is the one to schedule on.
    loop = asyncio.get_running_loop()
    result = await loop.run_in_executor(
        executor,
        lambda: qa(question=question, context=context)
    )

    return format_response(question, result["answer"], result["score"])
|
| 939 |
+
|
| 940 |
+
async def handle_general(text: str, question: str, language: str) -> dict:
    """Answer a question that matched no specific category.

    Extractive QA is tried first through ``handle_factual``. Its result is
    returned only when the reported confidence is above 0.7. A lower
    confidence, or any exception raised by the QA attempt, falls back to
    ``handle_theme``.

    Args:
        text: Content to answer from.
        question: The user's question.
        language: Language code forwarded to both handlers.

    Returns:
        The response dict from whichever handler produced the answer.
    """
    try:
        qa_result = await handle_factual(text, question, language)
    except Exception:
        # Best effort: a failed QA attempt must not fail the request, but it
        # is recorded so the failure is visible in the logs.
        logger.warning("QA attempt failed; falling back to summarization", exc_info=True)
    else:
        if qa_result["confidence"] > 0.7:
            return qa_result

    # Fallback to summarization
    return await handle_theme(text, question, language)
|
| 952 |
+
|
| 953 |
+
def format_response(question: str, answer: str, confidence: float) -> dict:
    """Build the response payload shared by the QA handlers.

    The confidence is coerced with ``float``; the ``type`` field is always
    ``"qa_response"``.
    """
    payload = dict(question=question, answer=answer)
    payload["confidence"] = float(confidence)
    payload["type"] = "qa_response"
    return payload
|
| 961 |
+
|
| 962 |
+
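# TODO: the code above also calls smart_truncate, select_relevant_context, handle_summary,
# handle_list and handle_comparison, none of which are defined in this section; make sure
# they exist in this module (e.g. carried over from the previous implementation).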
|
| 963 |
|
| 964 |
@app.post("/visualize/natural")
|
| 965 |
async def natural_language_visualization(
|