Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -26,7 +26,6 @@ class SyntheticDataGenerator:
|
|
| 26 |
self.original_data = None
|
| 27 |
|
| 28 |
def initialize_mostly_ai(self) -> Tuple[bool, str]:
|
| 29 |
-
"""Initialize Mostly AI SDK"""
|
| 30 |
if not MOSTLY_AI_AVAILABLE:
|
| 31 |
return False, "Mostly AI SDK not installed. Please install with: pip install mostlyai[local]"
|
| 32 |
try:
|
|
@@ -44,7 +43,6 @@ class SyntheticDataGenerator:
|
|
| 44 |
batch_size: int = 32,
|
| 45 |
value_protection: bool = True,
|
| 46 |
) -> Tuple[bool, str]:
|
| 47 |
-
"""Train the synthetic data generator"""
|
| 48 |
if not self.mostly:
|
| 49 |
return False, "Mostly AI SDK not initialized. Please initialize the SDK first."
|
| 50 |
try:
|
|
@@ -63,14 +61,12 @@ class SyntheticDataGenerator:
|
|
| 63 |
}
|
| 64 |
]
|
| 65 |
}
|
| 66 |
-
|
| 67 |
self.generator = self.mostly.train(config=train_config)
|
| 68 |
return True, f"Training completed successfully. Model name: {name}"
|
| 69 |
except Exception as e:
|
| 70 |
return False, f"Training failed with error: {str(e)}"
|
| 71 |
|
| 72 |
def generate_synthetic_data(self, size: int) -> Tuple[Optional[pd.DataFrame], str]:
|
| 73 |
-
"""Generate synthetic data"""
|
| 74 |
if not self.generator:
|
| 75 |
return None, "No trained generator available. Please train a model first."
|
| 76 |
try:
|
|
@@ -82,27 +78,28 @@ class SyntheticDataGenerator:
|
|
| 82 |
|
| 83 |
def get_quality_report_file(self) -> Optional[str]:
|
| 84 |
"""
|
| 85 |
-
|
| 86 |
-
|
| 87 |
"""
|
| 88 |
if not self.generator:
|
| 89 |
return None
|
| 90 |
try:
|
| 91 |
rep = self.generator.reports(display=False)
|
| 92 |
|
| 93 |
-
#
|
| 94 |
if isinstance(rep, str) and rep.endswith(".zip") and os.path.exists(rep):
|
| 95 |
return rep
|
| 96 |
|
| 97 |
-
#
|
| 98 |
for attr in ("archive_path", "zip_path", "path", "file_path"):
|
| 99 |
if hasattr(rep, attr):
|
| 100 |
p = getattr(rep, attr)
|
| 101 |
if isinstance(p, str) and os.path.exists(p):
|
| 102 |
return p
|
| 103 |
|
| 104 |
-
#
|
| 105 |
-
|
|
|
|
| 106 |
if hasattr(rep, "save"):
|
| 107 |
try:
|
| 108 |
rep.save(target_zip)
|
|
@@ -118,8 +115,8 @@ class SyntheticDataGenerator:
|
|
| 118 |
except Exception:
|
| 119 |
pass
|
| 120 |
|
| 121 |
-
#
|
| 122 |
-
target_txt = "/
|
| 123 |
with open(target_txt, "w", encoding="utf-8") as f:
|
| 124 |
f.write(str(rep))
|
| 125 |
return target_txt
|
|
@@ -128,21 +125,12 @@ class SyntheticDataGenerator:
|
|
| 128 |
return None
|
| 129 |
|
| 130 |
def estimate_memory_usage(self, df: pd.DataFrame) -> str:
|
| 131 |
-
"""Estimate memory usage for the dataset"""
|
| 132 |
if df is None or df.empty:
|
| 133 |
return "No data available to analyze."
|
| 134 |
-
|
| 135 |
memory_mb = df.memory_usage(deep=True).sum() / (1024 * 1024)
|
| 136 |
rows, cols = len(df), len(df.columns)
|
| 137 |
estimated_training_mb = memory_mb * 4
|
| 138 |
-
|
| 139 |
-
if memory_mb < 100:
|
| 140 |
-
status = "Good"
|
| 141 |
-
elif memory_mb < 500:
|
| 142 |
-
status = "Large"
|
| 143 |
-
else:
|
| 144 |
-
status = "Very Large"
|
| 145 |
-
|
| 146 |
return f"""
|
| 147 |
Memory Usage Estimate:
|
| 148 |
- Data size: {memory_mb:.1f} MB
|
|
@@ -152,10 +140,12 @@ Memory Usage Estimate:
|
|
| 152 |
""".strip()
|
| 153 |
|
| 154 |
|
| 155 |
-
#
|
| 156 |
generator = SyntheticDataGenerator()
|
|
|
|
|
|
|
| 157 |
|
| 158 |
-
# ----
|
| 159 |
def initialize_sdk() -> str:
|
| 160 |
ok, msg = generator.initialize_mostly_ai()
|
| 161 |
return ("Success: " if ok else "Error: ") + msg
|
|
@@ -178,62 +168,53 @@ def train_model(
|
|
| 178 |
|
| 179 |
|
| 180 |
def generate_data(size: int) -> Tuple[Optional[pd.DataFrame], str]:
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 184 |
|
| 185 |
|
| 186 |
-
def create_comparison_plot(original_df: pd.DataFrame, synthetic_df: pd.DataFrame)
|
| 187 |
if original_df is None or synthetic_df is None:
|
| 188 |
return None
|
| 189 |
-
|
| 190 |
numeric_cols = original_df.select_dtypes(include=[np.number]).columns.tolist()
|
| 191 |
if not numeric_cols:
|
| 192 |
return None
|
| 193 |
-
|
| 194 |
n_cols = min(3, len(numeric_cols))
|
| 195 |
n_rows = (len(numeric_cols) + n_cols - 1) // n_cols
|
| 196 |
-
|
| 197 |
fig = make_subplots(rows=n_rows, cols=n_cols, subplot_titles=numeric_cols[: n_rows * n_cols])
|
| 198 |
-
|
| 199 |
for i, col in enumerate(numeric_cols[: n_rows * n_cols]):
|
| 200 |
row = i // n_cols + 1
|
| 201 |
col_idx = i % n_cols + 1
|
| 202 |
-
|
| 203 |
-
fig.add_trace(
|
| 204 |
-
go.Histogram(x=original_df[col], name=f"Original {col}", opacity=0.7, nbinsx=20),
|
| 205 |
-
row=row,
|
| 206 |
-
col=col_idx,
|
| 207 |
-
)
|
| 208 |
-
fig.add_trace(
|
| 209 |
-
go.Histogram(x=synthetic_df[col], name=f"Synthetic {col}", opacity=0.7, nbinsx=20),
|
| 210 |
-
row=row,
|
| 211 |
-
col=col_idx,
|
| 212 |
-
)
|
| 213 |
-
|
| 214 |
fig.update_layout(title="Original vs Synthetic Data Comparison", height=300 * n_rows, showlegend=True)
|
| 215 |
return fig
|
| 216 |
|
| 217 |
|
| 218 |
-
def download_csv(df: pd.DataFrame) -> Optional[str]:
|
| 219 |
-
if df is None or df.empty:
|
| 220 |
-
return None
|
| 221 |
-
path = "/mnt/data/synthetic_data.csv"
|
| 222 |
-
df.to_csv(path, index=False)
|
| 223 |
-
return path
|
| 224 |
-
|
| 225 |
-
|
| 226 |
# ---- UI ----
|
| 227 |
def create_interface():
|
| 228 |
with gr.Blocks(title="MOSTLY AI Synthetic Data Generator", theme=gr.themes.Soft()) as demo:
|
| 229 |
-
# Header image
|
| 230 |
gr.Image(
|
| 231 |
value="https://img.mailinblue.com/8225865/images/content_library/original/6880d164e4e4ea1a183ad4c0.png",
|
| 232 |
show_label=False,
|
| 233 |
elem_id="header-image",
|
| 234 |
)
|
| 235 |
|
| 236 |
-
# README
|
| 237 |
gr.Markdown(
|
| 238 |
"""
|
| 239 |
# Synthetic Data SDK by MOSTLY AI Demo Space
|
|
@@ -289,6 +270,7 @@ def create_interface():
|
|
| 289 |
train_status = gr.Textbox(label="Training Status", interactive=False)
|
| 290 |
|
| 291 |
with gr.Row():
|
|
|
|
| 292 |
get_report_btn = gr.DownloadButton("Get Quality Report", variant="secondary")
|
| 293 |
|
| 294 |
with gr.Tab("Generate Data"):
|
|
@@ -302,10 +284,11 @@ def create_interface():
|
|
| 302 |
|
| 303 |
synthetic_data = gr.Dataframe(label="Synthetic Data", interactive=False)
|
| 304 |
with gr.Row():
|
|
|
|
| 305 |
download_btn = gr.DownloadButton("Download CSV", variant="secondary")
|
| 306 |
comparison_plot = gr.Plot(label="Data Comparison")
|
| 307 |
|
| 308 |
-
# ----
|
| 309 |
init_btn.click(initialize_sdk, outputs=[init_status])
|
| 310 |
|
| 311 |
train_btn.click(
|
|
@@ -314,21 +297,18 @@ def create_interface():
|
|
| 314 |
outputs=[train_status],
|
| 315 |
)
|
| 316 |
|
| 317 |
-
#
|
| 318 |
-
get_report_btn.click(generator.get_quality_report_file, outputs=
|
| 319 |
|
| 320 |
-
# Generate data
|
| 321 |
generate_btn.click(generate_data, inputs=[gen_size], outputs=[synthetic_data, gen_status])
|
| 322 |
|
| 323 |
-
# Update CSV DownloadButton whenever synthetic data changes
|
| 324 |
-
synthetic_data.change(download_csv, inputs=[synthetic_data], outputs=[download_btn])
|
| 325 |
-
|
| 326 |
# Build comparison plot when both datasets are available
|
| 327 |
-
synthetic_data.change(
|
| 328 |
-
|
| 329 |
-
)
|
|
|
|
| 330 |
|
| 331 |
-
#
|
| 332 |
def process_uploaded_file(file):
|
| 333 |
if file is None:
|
| 334 |
return None, "No file uploaded.", gr.update(visible=False)
|
|
|
|
| 26 |
self.original_data = None
|
| 27 |
|
| 28 |
def initialize_mostly_ai(self) -> Tuple[bool, str]:
|
|
|
|
| 29 |
if not MOSTLY_AI_AVAILABLE:
|
| 30 |
return False, "Mostly AI SDK not installed. Please install with: pip install mostlyai[local]"
|
| 31 |
try:
|
|
|
|
| 43 |
batch_size: int = 32,
|
| 44 |
value_protection: bool = True,
|
| 45 |
) -> Tuple[bool, str]:
|
|
|
|
| 46 |
if not self.mostly:
|
| 47 |
return False, "Mostly AI SDK not initialized. Please initialize the SDK first."
|
| 48 |
try:
|
|
|
|
| 61 |
}
|
| 62 |
]
|
| 63 |
}
|
|
|
|
| 64 |
self.generator = self.mostly.train(config=train_config)
|
| 65 |
return True, f"Training completed successfully. Model name: {name}"
|
| 66 |
except Exception as e:
|
| 67 |
return False, f"Training failed with error: {str(e)}"
|
| 68 |
|
| 69 |
def generate_synthetic_data(self, size: int) -> Tuple[Optional[pd.DataFrame], str]:
|
|
|
|
| 70 |
if not self.generator:
|
| 71 |
return None, "No trained generator available. Please train a model first."
|
| 72 |
try:
|
|
|
|
| 78 |
|
| 79 |
def get_quality_report_file(self) -> Optional[str]:
|
| 80 |
"""
|
| 81 |
+
Build/export the quality report and return a file path for immediate download.
|
| 82 |
+
Uses /tmp for Spaces; tries ZIP, falls back to TXT.
|
| 83 |
"""
|
| 84 |
if not self.generator:
|
| 85 |
return None
|
| 86 |
try:
|
| 87 |
rep = self.generator.reports(display=False)
|
| 88 |
|
| 89 |
+
# If a string path to a .zip is returned
|
| 90 |
if isinstance(rep, str) and rep.endswith(".zip") and os.path.exists(rep):
|
| 91 |
return rep
|
| 92 |
|
| 93 |
+
# If object exposes a path-like attribute
|
| 94 |
for attr in ("archive_path", "zip_path", "path", "file_path"):
|
| 95 |
if hasattr(rep, attr):
|
| 96 |
p = getattr(rep, attr)
|
| 97 |
if isinstance(p, str) and os.path.exists(p):
|
| 98 |
return p
|
| 99 |
|
| 100 |
+
# Try saving/exporting
|
| 101 |
+
os.makedirs("/tmp", exist_ok=True)
|
| 102 |
+
target_zip = "/tmp/quality_report.zip"
|
| 103 |
if hasattr(rep, "save"):
|
| 104 |
try:
|
| 105 |
rep.save(target_zip)
|
|
|
|
| 115 |
except Exception:
|
| 116 |
pass
|
| 117 |
|
| 118 |
+
# Fallback: stringify into TXT
|
| 119 |
+
target_txt = "/tmp/quality_report.txt"
|
| 120 |
with open(target_txt, "w", encoding="utf-8") as f:
|
| 121 |
f.write(str(rep))
|
| 122 |
return target_txt
|
|
|
|
| 125 |
return None
|
| 126 |
|
| 127 |
def estimate_memory_usage(self, df: pd.DataFrame) -> str:
|
|
|
|
| 128 |
if df is None or df.empty:
|
| 129 |
return "No data available to analyze."
|
|
|
|
| 130 |
memory_mb = df.memory_usage(deep=True).sum() / (1024 * 1024)
|
| 131 |
rows, cols = len(df), len(df.columns)
|
| 132 |
estimated_training_mb = memory_mb * 4
|
| 133 |
+
status = "Good" if memory_mb < 100 else ("Large" if memory_mb < 500 else "Very Large")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 134 |
return f"""
|
| 135 |
Memory Usage Estimate:
|
| 136 |
- Data size: {memory_mb:.1f} MB
|
|
|
|
| 140 |
""".strip()
|
| 141 |
|
| 142 |
|
| 143 |
+
# App state
|
| 144 |
generator = SyntheticDataGenerator()
|
| 145 |
+
_last_synth_df: Optional[pd.DataFrame] = None # store latest synthetic DF for download
|
| 146 |
+
|
| 147 |
|
| 148 |
+
# ---- Gradio wrappers ----
|
| 149 |
def initialize_sdk() -> str:
|
| 150 |
ok, msg = generator.initialize_mostly_ai()
|
| 151 |
return ("Success: " if ok else "Error: ") + msg
|
|
|
|
| 168 |
|
| 169 |
|
| 170 |
def generate_data(size: int) -> Tuple[Optional[pd.DataFrame], str]:
|
| 171 |
+
global _last_synth_df
|
| 172 |
+
synth_df, message = generator.generate_synthetic_data(size)
|
| 173 |
+
if synth_df is not None:
|
| 174 |
+
_last_synth_df = synth_df.copy()
|
| 175 |
+
return synth_df, f"Success: {message}"
|
| 176 |
+
else:
|
| 177 |
+
return None, f"Error: {message}"
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
def download_csv_now() -> Optional[str]:
|
| 181 |
+
"""Write the most recent synthetic DF to /tmp and return the path for direct download."""
|
| 182 |
+
global _last_synth_df
|
| 183 |
+
if _last_synth_df is None or _last_synth_df.empty:
|
| 184 |
+
return None
|
| 185 |
+
os.makedirs("/tmp", exist_ok=True)
|
| 186 |
+
path = "/tmp/synthetic_data.csv"
|
| 187 |
+
_last_synth_df.to_csv(path, index=False)
|
| 188 |
+
return path
|
| 189 |
|
| 190 |
|
| 191 |
+
def create_comparison_plot(original_df: pd.DataFrame, synthetic_df: pd.DataFrame):
|
| 192 |
if original_df is None or synthetic_df is None:
|
| 193 |
return None
|
|
|
|
| 194 |
numeric_cols = original_df.select_dtypes(include=[np.number]).columns.tolist()
|
| 195 |
if not numeric_cols:
|
| 196 |
return None
|
|
|
|
| 197 |
n_cols = min(3, len(numeric_cols))
|
| 198 |
n_rows = (len(numeric_cols) + n_cols - 1) // n_cols
|
|
|
|
| 199 |
fig = make_subplots(rows=n_rows, cols=n_cols, subplot_titles=numeric_cols[: n_rows * n_cols])
|
|
|
|
| 200 |
for i, col in enumerate(numeric_cols[: n_rows * n_cols]):
|
| 201 |
row = i // n_cols + 1
|
| 202 |
col_idx = i % n_cols + 1
|
| 203 |
+
fig.add_trace(go.Histogram(x=original_df[col], name=f"Original {col}", opacity=0.7, nbinsx=20), row=row, col=col_idx)
|
| 204 |
+
fig.add_trace(go.Histogram(x=synthetic_df[col], name=f"Synthetic {col}", opacity=0.7, nbinsx=20), row=row, col=col_idx)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 205 |
fig.update_layout(title="Original vs Synthetic Data Comparison", height=300 * n_rows, showlegend=True)
|
| 206 |
return fig
|
| 207 |
|
| 208 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 209 |
# ---- UI ----
|
| 210 |
def create_interface():
|
| 211 |
with gr.Blocks(title="MOSTLY AI Synthetic Data Generator", theme=gr.themes.Soft()) as demo:
|
|
|
|
| 212 |
gr.Image(
|
| 213 |
value="https://img.mailinblue.com/8225865/images/content_library/original/6880d164e4e4ea1a183ad4c0.png",
|
| 214 |
show_label=False,
|
| 215 |
elem_id="header-image",
|
| 216 |
)
|
| 217 |
|
|
|
|
| 218 |
gr.Markdown(
|
| 219 |
"""
|
| 220 |
# Synthetic Data SDK by MOSTLY AI Demo Space
|
|
|
|
| 270 |
train_status = gr.Textbox(label="Training Status", interactive=False)
|
| 271 |
|
| 272 |
with gr.Row():
|
| 273 |
+
# This download button calls a function that returns a file path → download starts immediately
|
| 274 |
get_report_btn = gr.DownloadButton("Get Quality Report", variant="secondary")
|
| 275 |
|
| 276 |
with gr.Tab("Generate Data"):
|
|
|
|
| 284 |
|
| 285 |
synthetic_data = gr.Dataframe(label="Synthetic Data", interactive=False)
|
| 286 |
with gr.Row():
|
| 287 |
+
# Same pattern: click → function returns the CSV path → immediate download
|
| 288 |
download_btn = gr.DownloadButton("Download CSV", variant="secondary")
|
| 289 |
comparison_plot = gr.Plot(label="Data Comparison")
|
| 290 |
|
| 291 |
+
# ---- Events ----
|
| 292 |
init_btn.click(initialize_sdk, outputs=[init_status])
|
| 293 |
|
| 294 |
train_btn.click(
|
|
|
|
| 297 |
outputs=[train_status],
|
| 298 |
)
|
| 299 |
|
| 300 |
+
# IMPORTANT: For DownloadButton, do NOT specify outputs — the returned path is auto-downloaded.
|
| 301 |
+
get_report_btn.click(generator.get_quality_report_file, inputs=None, outputs=None)
|
| 302 |
|
|
|
|
| 303 |
generate_btn.click(generate_data, inputs=[gen_size], outputs=[synthetic_data, gen_status])
|
| 304 |
|
|
|
|
|
|
|
|
|
|
| 305 |
# Build comparison plot when both datasets are available
|
| 306 |
+
synthetic_data.change(create_comparison_plot, inputs=[uploaded_data, synthetic_data], outputs=[comparison_plot])
|
| 307 |
+
|
| 308 |
+
# CSV download: return a path from the click handler (no outputs)
|
| 309 |
+
download_btn.click(download_csv_now, inputs=None, outputs=None)
|
| 310 |
|
| 311 |
+
# File upload handler
|
| 312 |
def process_uploaded_file(file):
|
| 313 |
if file is None:
|
| 314 |
return None, "No file uploaded.", gr.update(visible=False)
|