chenguittiMaroua commited on
Commit
29d0793
·
verified ·
1 Parent(s): cbfdbdf

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +112 -22
main.py CHANGED
@@ -30,7 +30,45 @@ from fastapi import Request
30
  from pathlib import Path
31
  from fastapi.staticfiles import StaticFiles
32
 
 
33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  # Initialize rate limiter
35
  limiter = Limiter(key_func=get_remote_address)
36
 
@@ -122,20 +160,37 @@ async def process_uploaded_file(file: UploadFile) -> Tuple[str, bytes]:
122
  return file_ext, content
123
 
124
  def extract_text(content: bytes, file_ext: str) -> str:
125
- """Extract text from various file formats with enhanced support"""
126
  try:
127
  if file_ext == "docx":
128
  doc = Document(io.BytesIO(content))
129
  return "\n".join(para.text for para in doc.paragraphs if para.text.strip())
130
 
131
  elif file_ext in {"xlsx", "xls"}:
132
- df = pd.read_excel(io.BytesIO(content), sheet_name=None)
 
 
 
 
 
 
 
 
 
133
  all_text = []
134
  for sheet_name, sheet_data in df.items():
135
  sheet_text = []
 
136
  for column in sheet_data.columns:
137
- sheet_text.extend(sheet_data[column].dropna().astype(str).tolist())
 
 
 
 
 
 
138
  all_text.append(f"Sheet: {sheet_name}\n" + "\n".join(sheet_text))
 
139
  return "\n\n".join(all_text)
140
 
141
  elif file_ext == "pptx":
@@ -168,8 +223,8 @@ def extract_text(content: bytes, file_ext: str) -> str:
168
  raise ValueError("Could not extract text or caption from image")
169
 
170
  except Exception as e:
171
- logger.error(f"Text extraction failed for {file_ext}: {str(e)}")
172
- raise HTTPException(422, f"Failed to extract text from {file_ext} file")
173
 
174
  # Visualization Models
175
  class VisualizationRequest(BaseModel):
@@ -213,47 +268,82 @@ def validate_matplotlib_style(style: str) -> str:
213
 
214
 
215
  def generate_visualization_code(df: pd.DataFrame, request: VisualizationRequest) -> str:
216
- """Generate Python code for visualization based on request parameters"""
217
  # Validate style
218
  valid_style = validate_matplotlib_style(request.style)
219
 
 
 
 
220
  code_lines = [
221
  "import matplotlib.pyplot as plt",
222
  "import seaborn as sns",
223
  "import pandas as pd",
 
 
 
 
 
224
  "",
225
- "# Data preparation",
226
- f"df = pd.DataFrame({df.to_dict(orient='list')})",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
227
  ]
228
 
229
- # Apply filters if specified
230
  if request.filters:
231
  filter_conditions = []
232
  for column, condition in request.filters.items():
233
  if isinstance(condition, dict):
234
  if 'min' in condition and 'max' in condition:
235
- filter_conditions.append(f"(df['{column}'] >= {condition['min']}) & (df['{column}'] <= {condition['max']})")
 
 
 
 
236
  elif 'values' in condition:
237
  values = ', '.join([f"'{v}'" if isinstance(v, str) else str(v) for v in condition['values']])
238
- filter_conditions.append(f"df['{column}'].isin([{values}])")
 
 
 
239
  else:
240
- filter_conditions.append(f"df['{column}'] == {repr(condition)}")
 
 
 
241
 
242
  if filter_conditions:
243
  code_lines.extend([
244
  "",
245
- "# Apply filters",
246
- f"df = df[{' & '.join(filter_conditions)}]"
247
  ])
248
 
249
  code_lines.extend([
250
  "",
251
- "# Visualization",
252
  f"plt.style.use('{valid_style}')",
253
  f"plt.figure(figsize=(10, 6))"
254
  ])
255
 
256
- # Chart type specific code
257
  if request.chart_type == "line":
258
  if request.hue_column:
259
  code_lines.append(f"sns.lineplot(data=df, x='{request.x_column}', y='{request.y_column}', hue='{request.hue_column}')")
@@ -270,15 +360,16 @@ def generate_visualization_code(df: pd.DataFrame, request: VisualizationRequest)
270
  else:
271
  code_lines.append(f"plt.scatter(df['{request.x_column}'], df['{request.y_column}'])")
272
  elif request.chart_type == "histogram":
273
- code_lines.append(f"plt.hist(df['{request.x_column}'], bins=20)")
274
  elif request.chart_type == "boxplot":
275
  if request.hue_column:
276
- code_lines.append(f"sns.boxplot(data=df, x='{request.x_column}', y='{request.y_column}', hue='{request.hue_column}')")
277
  else:
278
- code_lines.append(f"sns.boxplot(data=df, x='{request.x_column}', y='{request.y_column}')")
279
  elif request.chart_type == "heatmap":
280
- code_lines.append(f"corr = df.corr()")
281
- code_lines.append(f"sns.heatmap(corr, annot=True, cmap='coolwarm')")
 
282
  else:
283
  raise ValueError(f"Unsupported chart type: {request.chart_type}")
284
 
@@ -296,7 +387,6 @@ def generate_visualization_code(df: pd.DataFrame, request: VisualizationRequest)
296
  ])
297
 
298
  return "\n".join(code_lines)
299
-
300
  def interpret_natural_language(prompt: str, df_columns: list) -> VisualizationRequest:
301
  """Convert natural language prompt to visualization parameters"""
302
  prompt = prompt.lower()
 
30
  from pathlib import Path
31
  from fastapi.staticfiles import StaticFiles
32
 
33
+ # main.py
34
 
35
+ # Standard library imports
36
+ import io
37
+ import re
38
+ import logging
39
+ import tempfile
40
+ import base64
41
+ import warnings
42
+ from typing import Tuple, Optional
43
+ from pathlib import Path
44
+
45
+ # Third-party imports
46
+ from fastapi import FastAPI, UploadFile, File, Form, HTTPException, Request
47
+ from fastapi.middleware.cors import CORSMiddleware
48
+ from fastapi.responses import JSONResponse, HTMLResponse
49
+ from transformers import pipeline
50
+ import fitz # PyMuPDF
51
+ from PIL import Image
52
+ import pandas as pd
53
+ import uvicorn
54
+ from docx import Document
55
+ from pptx import Presentation
56
+ import pytesseract
57
+ from slowapi import Limiter
58
+ from slowapi.util import get_remote_address
59
+ from slowapi.errors import RateLimitExceeded
60
+ from slowapi.middleware import SlowAPIMiddleware
61
+ import matplotlib.pyplot as plt
62
+ import seaborn as sns
63
+ from pydantic import BaseModel
64
+ import traceback
65
+ import ast
66
+ from openpyxl import Workbook
67
+
68
+ # Suppress openpyxl warnings
69
+ warnings.filterwarnings("ignore", category=UserWarning, module="openpyxl")
70
+
71
+ # Rest of your code (app setup, routes, etc.)...
72
  # Initialize rate limiter
73
  limiter = Limiter(key_func=get_remote_address)
74
 
 
160
  return file_ext, content
161
 
162
  def extract_text(content: bytes, file_ext: str) -> str:
163
+ """Extract text from various file formats with enhanced Excel support"""
164
  try:
165
  if file_ext == "docx":
166
  doc = Document(io.BytesIO(content))
167
  return "\n".join(para.text for para in doc.paragraphs if para.text.strip())
168
 
169
  elif file_ext in {"xlsx", "xls"}:
170
+ # Improved Excel handling with better NaN and date support
171
+ df = pd.read_excel(
172
+ io.BytesIO(content),
173
+ sheet_name=None,
174
+ engine='openpyxl',
175
+ na_values=['', 'NA', 'N/A', 'NaN', 'null'],
176
+ keep_default_na=False,
177
+ parse_dates=True
178
+ )
179
+
180
  all_text = []
181
  for sheet_name, sheet_data in df.items():
182
  sheet_text = []
183
+ # Convert all data to string and handle special types
184
  for column in sheet_data.columns:
185
+ # Handle datetime columns
186
+ if pd.api.types.is_datetime64_any_dtype(sheet_data[column]):
187
+ sheet_data[column] = sheet_data[column].dt.strftime('%Y-%m-%d %H:%M:%S')
188
+ # Convert to string and clean
189
+ col_text = sheet_data[column].astype(str).replace(['nan', 'None', 'NaT'], '').tolist()
190
+ sheet_text.extend([x for x in col_text if x.strip()])
191
+
192
  all_text.append(f"Sheet: {sheet_name}\n" + "\n".join(sheet_text))
193
+
194
  return "\n\n".join(all_text)
195
 
196
  elif file_ext == "pptx":
 
223
  raise ValueError("Could not extract text or caption from image")
224
 
225
  except Exception as e:
226
+ logger.error(f"Text extraction failed for {file_ext}: {str(e)}", exc_info=True)
227
+ raise HTTPException(422, f"Failed to extract text from {file_ext} file: {str(e)}")
228
 
229
  # Visualization Models
230
  class VisualizationRequest(BaseModel):
 
268
 
269
 
270
  def generate_visualization_code(df: pd.DataFrame, request: VisualizationRequest) -> str:
271
+ """Generate Python code for visualization with enhanced NaN handling and type safety"""
272
  # Validate style
273
  valid_style = validate_matplotlib_style(request.style)
274
 
275
+ # Convert DataFrame to dict with proper NaN handling
276
+ df_dict = df.where(pd.notnull(df), None).to_dict(orient='list')
277
+
278
  code_lines = [
279
  "import matplotlib.pyplot as plt",
280
  "import seaborn as sns",
281
  "import pandas as pd",
282
+ "import numpy as np",
283
+ "",
284
+ "# Data preparation with NaN handling and type conversion",
285
+ f"raw_data = {df_dict}",
286
+ "df = pd.DataFrame(raw_data)",
287
  "",
288
+ "# Automatic type conversion and cleaning",
289
+ "for col in df.columns:",
290
+ " # Convert strings that should be numeric",
291
+ " if pd.api.types.is_string_dtype(df[col]):",
292
+ " try:",
293
+ " df[col] = pd.to_numeric(df[col])",
294
+ " continue",
295
+ " except (ValueError, TypeError):",
296
+ " pass",
297
+ " ",
298
+ " # Convert string dates to datetime",
299
+ " try:",
300
+ " df[col] = pd.to_datetime(df[col])",
301
+ " continue",
302
+ " except (ValueError, TypeError):",
303
+ " pass",
304
+ " ",
305
+ " # Clean remaining None/NaN values",
306
+ " df[col] = df[col].where(pd.notnull(df[col]), None)",
307
  ]
308
 
309
+ # Apply filters if specified (with enhanced safety)
310
  if request.filters:
311
  filter_conditions = []
312
  for column, condition in request.filters.items():
313
  if isinstance(condition, dict):
314
  if 'min' in condition and 'max' in condition:
315
+ filter_conditions.append(
316
+ f"(pd.notna(df['{column}']) & "
317
+ f"(df['{column}'] >= {condition['min']}) & "
318
+ f"(df['{column}'] <= {condition['max']})"
319
+ )
320
  elif 'values' in condition:
321
  values = ', '.join([f"'{v}'" if isinstance(v, str) else str(v) for v in condition['values']])
322
+ filter_conditions.append(
323
+ f"(pd.notna(df['{column}'])) & "
324
+ f"(df['{column}'].isin([{values}]))"
325
+ )
326
  else:
327
+ filter_conditions.append(
328
+ f"(pd.notna(df['{column}'])) & "
329
+ f"(df['{column}'] == {repr(condition)})"
330
+ )
331
 
332
  if filter_conditions:
333
  code_lines.extend([
334
  "",
335
+ "# Apply filters with NaN checking",
336
+ f"df = df[{' & '.join(filter_conditions)}].copy()"
337
  ])
338
 
339
  code_lines.extend([
340
  "",
341
+ "# Visualization setup",
342
  f"plt.style.use('{valid_style}')",
343
  f"plt.figure(figsize=(10, 6))"
344
  ])
345
 
346
+ # Chart type specific code (unchanged from your original)
347
  if request.chart_type == "line":
348
  if request.hue_column:
349
  code_lines.append(f"sns.lineplot(data=df, x='{request.x_column}', y='{request.y_column}', hue='{request.hue_column}')")
 
360
  else:
361
  code_lines.append(f"plt.scatter(df['{request.x_column}'], df['{request.y_column}'])")
362
  elif request.chart_type == "histogram":
363
+ code_lines.append(f"plt.hist(df['{request.x_column}'].dropna(), bins=20)") # Added dropna()
364
  elif request.chart_type == "boxplot":
365
  if request.hue_column:
366
+ code_lines.append(f"sns.boxplot(data=df.dropna(), x='{request.x_column}', y='{request.y_column}', hue='{request.hue_column}')") # Added dropna()
367
  else:
368
+ code_lines.append(f"sns.boxplot(data=df.dropna(), x='{request.x_column}', y='{request.y_column}')") # Added dropna()
369
  elif request.chart_type == "heatmap":
370
+ code_lines.append("numeric_df = df.select_dtypes(include=[np.number])") # Filter numeric only
371
+ code_lines.append("corr = numeric_df.corr()")
372
+ code_lines.append("sns.heatmap(corr, annot=True, cmap='coolwarm')")
373
  else:
374
  raise ValueError(f"Unsupported chart type: {request.chart_type}")
375
 
 
387
  ])
388
 
389
  return "\n".join(code_lines)
 
390
  def interpret_natural_language(prompt: str, df_columns: list) -> VisualizationRequest:
391
  """Convert natural language prompt to visualization parameters"""
392
  prompt = prompt.lower()