Spaces:
Sleeping
Sleeping
Update main.py
Browse files
main.py
CHANGED
|
@@ -838,73 +838,92 @@ async def visualize_with_natural_language(
|
|
| 838 |
style: str = Form("seaborn-v0_8")
|
| 839 |
):
|
| 840 |
try:
|
| 841 |
-
# Read
|
| 842 |
content = await file.read()
|
| 843 |
df = pd.read_excel(BytesIO(content))
|
| 844 |
|
| 845 |
-
#
|
| 846 |
for col in df.columns:
|
| 847 |
-
#
|
| 848 |
df[col] = pd.to_numeric(df[col], errors='ignore')
|
| 849 |
-
# Then try
|
| 850 |
try:
|
| 851 |
df[col] = pd.to_datetime(df[col], errors='ignore')
|
| 852 |
except:
|
| 853 |
pass
|
|
|
|
|
|
|
| 854 |
|
| 855 |
# Generate visualization request
|
| 856 |
vis_request = interpret_natural_language(prompt, df.columns.tolist())
|
|
|
|
|
|
|
| 857 |
|
| 858 |
-
# Create
|
| 859 |
plt.style.use(style)
|
| 860 |
fig, ax = plt.subplots(figsize=(10, 6))
|
| 861 |
|
| 862 |
-
|
| 863 |
-
|
| 864 |
-
|
| 865 |
-
|
| 866 |
-
|
| 867 |
-
|
| 868 |
-
|
| 869 |
-
|
| 870 |
-
|
| 871 |
-
|
| 872 |
-
|
| 873 |
-
|
| 874 |
-
|
| 875 |
-
|
| 876 |
-
|
| 877 |
-
|
| 878 |
-
|
| 879 |
-
|
| 880 |
-
|
| 881 |
-
|
| 882 |
-
|
| 883 |
-
|
| 884 |
-
|
| 885 |
-
|
| 886 |
-
|
| 887 |
-
|
| 888 |
-
|
| 889 |
-
|
| 890 |
-
|
| 891 |
-
|
| 892 |
-
|
| 893 |
-
|
| 894 |
-
|
| 895 |
-
|
| 896 |
-
|
| 897 |
-
|
| 898 |
-
|
| 899 |
-
|
| 900 |
-
return {
|
| 901 |
-
"image": base64.b64encode(buffer.getvalue()).decode('utf-8'),
|
| 902 |
-
"chart_type": vis_request.chart_type,
|
| 903 |
-
"columns": list(df.columns)
|
| 904 |
-
}
|
| 905 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 906 |
except Exception as e:
|
| 907 |
-
raise HTTPException(
|
| 908 |
|
| 909 |
|
| 910 |
@app.get("/visualize/styles")
|
|
|
|
| 838 |
style: str = Form("seaborn-v0_8")
|
| 839 |
):
|
| 840 |
try:
|
| 841 |
+
# Read Excel file
|
| 842 |
content = await file.read()
|
| 843 |
df = pd.read_excel(BytesIO(content))
|
| 844 |
|
| 845 |
+
# Clean and convert data types
|
| 846 |
for col in df.columns:
|
| 847 |
+
# First try numeric conversion
|
| 848 |
df[col] = pd.to_numeric(df[col], errors='ignore')
|
| 849 |
+
# Then try datetime conversion
|
| 850 |
try:
|
| 851 |
df[col] = pd.to_datetime(df[col], errors='ignore')
|
| 852 |
except:
|
| 853 |
pass
|
| 854 |
+
# Finally clean any remaining strings
|
| 855 |
+
df[col] = df[col].astype(str).str.strip().replace('nan', np.nan)
|
| 856 |
|
| 857 |
# Generate visualization request
|
| 858 |
vis_request = interpret_natural_language(prompt, df.columns.tolist())
|
| 859 |
+
if not vis_request:
|
| 860 |
+
raise HTTPException(400, "Could not interpret visualization request")
|
| 861 |
|
| 862 |
+
# Create visualization with type safety
|
| 863 |
plt.style.use(style)
|
| 864 |
fig, ax = plt.subplots(figsize=(10, 6))
|
| 865 |
|
| 866 |
+
try:
|
| 867 |
+
if vis_request.chart_type == "heatmap":
|
| 868 |
+
numeric_df = df.select_dtypes(include=['number'])
|
| 869 |
+
if numeric_df.empty:
|
| 870 |
+
raise ValueError("No numeric columns for heatmap")
|
| 871 |
+
sns.heatmap(numeric_df.corr(), annot=True)
|
| 872 |
+
else:
|
| 873 |
+
# Ensure numeric data for plotting
|
| 874 |
+
plot_data = df.copy()
|
| 875 |
+
if vis_request.x_column:
|
| 876 |
+
plot_data[vis_request.x_column] = pd.to_numeric(
|
| 877 |
+
plot_data[vis_request.x_column],
|
| 878 |
+
errors='coerce'
|
| 879 |
+
)
|
| 880 |
+
if vis_request.y_column:
|
| 881 |
+
plot_data[vis_request.y_column] = pd.to_numeric(
|
| 882 |
+
plot_data[vis_request.y_column],
|
| 883 |
+
errors='coerce'
|
| 884 |
+
)
|
| 885 |
+
|
| 886 |
+
# Remove rows with missing numeric data
|
| 887 |
+
plot_data = plot_data.dropna()
|
| 888 |
+
|
| 889 |
+
if vis_request.chart_type == "line":
|
| 890 |
+
sns.lineplot(
|
| 891 |
+
data=plot_data,
|
| 892 |
+
x=vis_request.x_column,
|
| 893 |
+
y=vis_request.y_column,
|
| 894 |
+
hue=vis_request.hue_column
|
| 895 |
+
)
|
| 896 |
+
elif vis_request.chart_type == "bar":
|
| 897 |
+
sns.barplot(
|
| 898 |
+
data=plot_data,
|
| 899 |
+
x=vis_request.x_column,
|
| 900 |
+
y=vis_request.y_column,
|
| 901 |
+
hue=vis_request.hue_column
|
| 902 |
+
)
|
| 903 |
+
# Add other chart types as needed...
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 904 |
|
| 905 |
+
plt.title(vis_request.title)
|
| 906 |
+
buffer = BytesIO()
|
| 907 |
+
plt.savefig(buffer, format='png', bbox_inches='tight')
|
| 908 |
+
plt.close()
|
| 909 |
+
|
| 910 |
+
return {
|
| 911 |
+
"status": "success",
|
| 912 |
+
"image": base64.b64encode(buffer.getvalue()).decode('utf-8'),
|
| 913 |
+
"chart_type": vis_request.chart_type,
|
| 914 |
+
"columns": list(df.columns),
|
| 915 |
+
"x_column": vis_request.x_column,
|
| 916 |
+
"y_column": vis_request.y_column,
|
| 917 |
+
"hue_column": vis_request.hue_column
|
| 918 |
+
}
|
| 919 |
+
|
| 920 |
+
except Exception as e:
|
| 921 |
+
raise HTTPException(400, f"Plotting error: {str(e)}")
|
| 922 |
+
|
| 923 |
+
except HTTPException:
|
| 924 |
+
raise
|
| 925 |
except Exception as e:
|
| 926 |
+
raise HTTPException(500, f"Server error: {str(e)}")
|
| 927 |
|
| 928 |
|
| 929 |
@app.get("/visualize/styles")
|