chenguittiMaroua committed on
Commit
eca39b4
·
verified ·
1 Parent(s): 56298de

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +68 -49
main.py CHANGED
@@ -838,73 +838,92 @@ async def visualize_with_natural_language(
838
  style: str = Form("seaborn-v0_8")
839
  ):
840
  try:
841
- # Read and process Excel file
842
  content = await file.read()
843
  df = pd.read_excel(BytesIO(content))
844
 
845
- # Automatic type conversion
846
  for col in df.columns:
847
- # Try converting to numeric first
848
  df[col] = pd.to_numeric(df[col], errors='ignore')
849
- # Then try dates
850
  try:
851
  df[col] = pd.to_datetime(df[col], errors='ignore')
852
  except:
853
  pass
 
 
854
 
855
  # Generate visualization request
856
  vis_request = interpret_natural_language(prompt, df.columns.tolist())
 
 
857
 
858
- # Create plot with type-safe data
859
  plt.style.use(style)
860
  fig, ax = plt.subplots(figsize=(10, 6))
861
 
862
- if vis_request.chart_type == "heatmap":
863
- numeric_df = df.select_dtypes(include=['number'])
864
- sns.heatmap(numeric_df.corr(), annot=True)
865
- else:
866
- # Type-safe plotting
867
- plot_data = df.copy()
868
- if vis_request.x_column:
869
- plot_data[vis_request.x_column] = pd.to_numeric(
870
- plot_data[vis_request.x_column],
871
- errors='ignore'
872
- )
873
- if vis_request.y_column:
874
- plot_data[vis_request.y_column] = pd.to_numeric(
875
- plot_data[vis_request.y_column],
876
- errors='ignore'
877
- )
878
-
879
- if vis_request.chart_type == "line":
880
- sns.lineplot(
881
- data=plot_data,
882
- x=vis_request.x_column,
883
- y=vis_request.y_column,
884
- hue=vis_request.hue_column
885
- )
886
- elif vis_request.chart_type == "bar":
887
- sns.barplot(
888
- data=plot_data,
889
- x=vis_request.x_column,
890
- y=vis_request.y_column,
891
- hue=vis_request.hue_column
892
- )
893
- # Add other chart types similarly...
894
-
895
- plt.title(vis_request.title)
896
- buffer = BytesIO()
897
- plt.savefig(buffer, format='png')
898
- plt.close()
899
-
900
- return {
901
- "image": base64.b64encode(buffer.getvalue()).decode('utf-8'),
902
- "chart_type": vis_request.chart_type,
903
- "columns": list(df.columns)
904
- }
905
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
906
  except Exception as e:
907
- raise HTTPException(400, f"Visualization failed: {str(e)}")
908
 
909
 
910
  @app.get("/visualize/styles")
 
838
  style: str = Form("seaborn-v0_8")
839
  ):
840
  try:
841
+ # Read Excel file
842
  content = await file.read()
843
  df = pd.read_excel(BytesIO(content))
844
 
845
+ # Clean and convert data types
846
  for col in df.columns:
847
+ # First try numeric conversion
848
  df[col] = pd.to_numeric(df[col], errors='ignore')
849
+ # Then try datetime conversion
850
  try:
851
  df[col] = pd.to_datetime(df[col], errors='ignore')
852
  except:
853
  pass
854
+ # Finally clean any remaining strings
855
+ df[col] = df[col].astype(str).str.strip().replace('nan', np.nan)
856
 
857
  # Generate visualization request
858
  vis_request = interpret_natural_language(prompt, df.columns.tolist())
859
+ if not vis_request:
860
+ raise HTTPException(400, "Could not interpret visualization request")
861
 
862
+ # Create visualization with type safety
863
  plt.style.use(style)
864
  fig, ax = plt.subplots(figsize=(10, 6))
865
 
866
+ try:
867
+ if vis_request.chart_type == "heatmap":
868
+ numeric_df = df.select_dtypes(include=['number'])
869
+ if numeric_df.empty:
870
+ raise ValueError("No numeric columns for heatmap")
871
+ sns.heatmap(numeric_df.corr(), annot=True)
872
+ else:
873
+ # Ensure numeric data for plotting
874
+ plot_data = df.copy()
875
+ if vis_request.x_column:
876
+ plot_data[vis_request.x_column] = pd.to_numeric(
877
+ plot_data[vis_request.x_column],
878
+ errors='coerce'
879
+ )
880
+ if vis_request.y_column:
881
+ plot_data[vis_request.y_column] = pd.to_numeric(
882
+ plot_data[vis_request.y_column],
883
+ errors='coerce'
884
+ )
885
+
886
+ # Remove rows with missing numeric data
887
+ plot_data = plot_data.dropna()
888
+
889
+ if vis_request.chart_type == "line":
890
+ sns.lineplot(
891
+ data=plot_data,
892
+ x=vis_request.x_column,
893
+ y=vis_request.y_column,
894
+ hue=vis_request.hue_column
895
+ )
896
+ elif vis_request.chart_type == "bar":
897
+ sns.barplot(
898
+ data=plot_data,
899
+ x=vis_request.x_column,
900
+ y=vis_request.y_column,
901
+ hue=vis_request.hue_column
902
+ )
903
+ # Add other chart types as needed...
 
 
 
 
 
904
 
905
+ plt.title(vis_request.title)
906
+ buffer = BytesIO()
907
+ plt.savefig(buffer, format='png', bbox_inches='tight')
908
+ plt.close()
909
+
910
+ return {
911
+ "status": "success",
912
+ "image": base64.b64encode(buffer.getvalue()).decode('utf-8'),
913
+ "chart_type": vis_request.chart_type,
914
+ "columns": list(df.columns),
915
+ "x_column": vis_request.x_column,
916
+ "y_column": vis_request.y_column,
917
+ "hue_column": vis_request.hue_column
918
+ }
919
+
920
+ except Exception as e:
921
+ raise HTTPException(400, f"Plotting error: {str(e)}")
922
+
923
+ except HTTPException:
924
+ raise
925
  except Exception as e:
926
+ raise HTTPException(500, f"Server error: {str(e)}")
927
 
928
 
929
  @app.get("/visualize/styles")