Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -33,12 +33,12 @@ class SyntheticDataGenerator:
|
|
| 33 |
|
| 34 |
def initialize_mostly_ai(self) -> Tuple[bool, str]:
|
| 35 |
if not MOSTLY_AI_AVAILABLE:
|
| 36 |
-
return False, "
|
| 37 |
try:
|
| 38 |
self.mostly = MostlyAI(local=True, local_port=8080)
|
| 39 |
-
return True, "
|
| 40 |
except Exception as e:
|
| 41 |
-
return False, f"
|
| 42 |
|
| 43 |
def train_generator(
|
| 44 |
self,
|
|
@@ -59,7 +59,7 @@ class SyntheticDataGenerator:
|
|
| 59 |
weight_decay: float = 0.0001,
|
| 60 |
) -> Tuple[bool, str]:
|
| 61 |
if not self.mostly:
|
| 62 |
-
return False, "
|
| 63 |
try:
|
| 64 |
self.original_data = data
|
| 65 |
train_config = {
|
|
@@ -86,33 +86,33 @@ class SyntheticDataGenerator:
|
|
| 86 |
]
|
| 87 |
}
|
| 88 |
self.generator = self.mostly.train(config=train_config)
|
| 89 |
-
return True, f"
|
| 90 |
except Exception as e:
|
| 91 |
-
return False, f"
|
| 92 |
|
| 93 |
def generate_synthetic_data(self, size: int) -> Tuple[Optional[pd.DataFrame], str]:
|
| 94 |
if not self.generator:
|
| 95 |
-
return None, "
|
| 96 |
try:
|
| 97 |
synthetic_data = self.mostly.generate(self.generator, size=size)
|
| 98 |
df = synthetic_data.data()
|
| 99 |
-
return df, f"
|
| 100 |
except Exception as e:
|
| 101 |
-
return None, f"
|
| 102 |
|
| 103 |
def estimate_memory_usage(self, df: pd.DataFrame) -> str:
|
| 104 |
if df is None or df.empty:
|
| 105 |
-
return "
|
| 106 |
memory_mb = df.memory_usage(deep=True).sum() / (1024 * 1024)
|
| 107 |
rows, cols = len(df), len(df.columns)
|
| 108 |
estimated_training_mb = memory_mb * 4
|
| 109 |
-
status = "
|
| 110 |
return f"""
|
| 111 |
-
|
| 112 |
-
-
|
| 113 |
-
-
|
| 114 |
-
-
|
| 115 |
-
-
|
| 116 |
""".strip()
|
| 117 |
|
| 118 |
|
|
@@ -123,7 +123,7 @@ _last_synth_df: Optional[pd.DataFrame] = None
|
|
| 123 |
|
| 124 |
def initialize_sdk() -> str:
|
| 125 |
ok, msg = generator.initialize_mostly_ai()
|
| 126 |
-
return ("
|
| 127 |
|
| 128 |
|
| 129 |
def train_model(
|
|
@@ -144,7 +144,7 @@ def train_model(
|
|
| 144 |
weight_decay: float,
|
| 145 |
) -> str:
|
| 146 |
if data is None or data.empty:
|
| 147 |
-
return "
|
| 148 |
|
| 149 |
# enforce backend caps regardless of ui inputs
|
| 150 |
epochs = min(int(epochs), MAX_EPOCHS)
|
|
@@ -167,7 +167,7 @@ def train_model(
|
|
| 167 |
dropout_rate=dropout_rate,
|
| 168 |
weight_decay=weight_decay,
|
| 169 |
)
|
| 170 |
-
return ("
|
| 171 |
|
| 172 |
|
| 173 |
def generate_data(size: int) -> Tuple[Optional[pd.DataFrame], str]:
|
|
@@ -175,13 +175,13 @@ def generate_data(size: int) -> Tuple[Optional[pd.DataFrame], str]:
|
|
| 175 |
synth_df, message = generator.generate_synthetic_data(size)
|
| 176 |
if synth_df is not None:
|
| 177 |
_last_synth_df = synth_df.copy()
|
| 178 |
-
return synth_df, f"
|
| 179 |
else:
|
| 180 |
-
return None, f"
|
| 181 |
|
| 182 |
|
| 183 |
def download_csv_prepare() -> Optional[str]:
|
| 184 |
-
"""
|
| 185 |
global _last_synth_df
|
| 186 |
if _last_synth_df is None or _last_synth_df.empty:
|
| 187 |
return None
|
|
@@ -203,14 +203,14 @@ def create_comparison_plot(original_df: pd.DataFrame, synthetic_df: pd.DataFrame
|
|
| 203 |
for i, col in enumerate(numeric_cols[: n_rows * n_cols]):
|
| 204 |
row = i // n_cols + 1
|
| 205 |
col_idx = i % n_cols + 1
|
| 206 |
-
fig.add_trace(go.Histogram(x=original_df[col], name=f"
|
| 207 |
-
fig.add_trace(go.Histogram(x=synthetic_df[col], name=f"
|
| 208 |
-
fig.update_layout(title="
|
| 209 |
return fig
|
| 210 |
|
| 211 |
|
| 212 |
def create_interface():
|
| 213 |
-
with gr.Blocks(title="
|
| 214 |
gr.Image(
|
| 215 |
value="https://img.mailinblue.com/8225865/images/content_library/original/6880d164e4e4ea1a183ad4c0.png",
|
| 216 |
show_label=False,
|
|
@@ -225,131 +225,133 @@ def create_interface():
|
|
| 225 |
|
| 226 |
A Python toolkit for generating high-fidelity, privacy-safe synthetic data. This is a limited demo space intended to showcase the features of the [Synthetic Data SDK](https://github.com/mostly-ai/mostlyai).
|
| 227 |
|
| 228 |
-
**Demo Space Limitations
|
| 229 |
Training is supported up to **≤ {MAX_EPOCHS} epochs** and **≤ {MAX_TRAINING_MINUTES} minutes**.
|
| 230 |
-
|
| 231 |
)
|
| 232 |
|
| 233 |
-
with gr.Tab("
|
| 234 |
-
gr.Markdown("###
|
| 235 |
with gr.Row():
|
| 236 |
with gr.Column():
|
| 237 |
-
init_btn = gr.Button("
|
| 238 |
-
init_status = gr.Textbox(label="
|
| 239 |
with gr.Column():
|
| 240 |
gr.Markdown(
|
| 241 |
"""
|
| 242 |
-
**
|
| 243 |
-
1.
|
| 244 |
-
2.
|
| 245 |
-
3.
|
| 246 |
-
4.
|
| 247 |
"""
|
| 248 |
)
|
| 249 |
|
| 250 |
-
with gr.Tab("
|
| 251 |
-
gr.Markdown("###
|
| 252 |
gr.Markdown(
|
| 253 |
f"""
|
| 254 |
-
**
|
| 255 |
-
-
|
| 256 |
-
-
|
| 257 |
-
-
|
| 258 |
"""
|
| 259 |
)
|
| 260 |
|
| 261 |
-
file_upload = gr.File(label="
|
| 262 |
-
uploaded_data = gr.Dataframe(label="
|
| 263 |
-
memory_info = gr.Markdown(label="
|
| 264 |
|
| 265 |
with gr.Row():
|
| 266 |
with gr.Column(scale=1):
|
| 267 |
model_name = gr.Textbox(
|
| 268 |
-
value="
|
| 269 |
-
label="
|
| 270 |
-
placeholder="
|
| 271 |
-
info="
|
| 272 |
)
|
| 273 |
epochs = gr.Slider(
|
| 274 |
-
1, MAX_EPOCHS, value=MAX_EPOCHS, step=1, label=f"
|
| 275 |
-
info=f"
|
| 276 |
)
|
| 277 |
max_training_time = gr.Slider(
|
| 278 |
1, MAX_TRAINING_MINUTES, value=MAX_TRAINING_MINUTES, step=1,
|
| 279 |
-
label=f"
|
| 280 |
-
info=f"
|
| 281 |
)
|
| 282 |
batch_size = gr.Slider(
|
| 283 |
-
8, 1024, value=32, step=8, label="
|
| 284 |
-
info="
|
| 285 |
)
|
| 286 |
value_protection = gr.Checkbox(
|
| 287 |
-
label="
|
| 288 |
-
info="
|
| 289 |
value=False
|
| 290 |
)
|
| 291 |
rare_category_protection = gr.Checkbox(
|
| 292 |
-
label="
|
| 293 |
-
info="
|
| 294 |
value=False
|
| 295 |
)
|
| 296 |
with gr.Column(scale=1):
|
| 297 |
flexible_generation = gr.Checkbox(
|
| 298 |
-
label="
|
| 299 |
-
info="
|
| 300 |
value=True
|
| 301 |
)
|
| 302 |
model_size = gr.Dropdown(
|
| 303 |
choices=["SMALL", "MEDIUM", "LARGE"],
|
| 304 |
value="MEDIUM",
|
| 305 |
-
label="
|
| 306 |
-
info="
|
| 307 |
)
|
| 308 |
target_accuracy = gr.Slider(
|
| 309 |
-
0.50, 0.999, value=0.95, step=0.001, label="
|
| 310 |
-
info="
|
| 311 |
)
|
| 312 |
validation_split = gr.Slider(
|
| 313 |
-
0.05, 0.5, value=0.2, step=0.01, label="
|
| 314 |
-
info="
|
| 315 |
)
|
| 316 |
early_stopping_patience = gr.Slider(
|
| 317 |
-
0, 50, value=10, step=1, label="
|
| 318 |
-
info="
|
| 319 |
)
|
| 320 |
with gr.Column(scale=1):
|
| 321 |
learning_rate = gr.Number(
|
| 322 |
-
value=0.001, precision=6, label="
|
| 323 |
-
info="
|
| 324 |
)
|
| 325 |
dropout_rate = gr.Slider(
|
| 326 |
-
0.0, 0.6, value=0.1, step=0.01, label="
|
| 327 |
-
info="
|
| 328 |
)
|
| 329 |
weight_decay = gr.Number(
|
| 330 |
-
value=0.0001, precision=6, label="
|
| 331 |
-
info="
|
| 332 |
)
|
| 333 |
-
train_btn = gr.Button("
|
| 334 |
-
train_status = gr.Textbox(label="
|
| 335 |
|
| 336 |
-
with gr.Tab("
|
| 337 |
-
gr.Markdown("###
|
| 338 |
with gr.Row():
|
| 339 |
with gr.Column():
|
| 340 |
-
gen_size = gr.Slider(
|
| 341 |
-
|
| 342 |
-
|
|
|
|
|
|
|
| 343 |
with gr.Column():
|
| 344 |
-
gen_status = gr.Textbox(label="
|
| 345 |
|
| 346 |
-
synthetic_data = gr.Dataframe(label="
|
| 347 |
|
| 348 |
with gr.Row():
|
| 349 |
-
csv_download_btn = gr.Button("
|
| 350 |
with gr.Group(visible=False) as csv_group:
|
| 351 |
-
csv_file = gr.File(label="
|
| 352 |
-
comparison_plot = gr.Plot(label="
|
| 353 |
|
| 354 |
init_btn.click(initialize_sdk, outputs=[init_status])
|
| 355 |
|
|
@@ -383,7 +385,7 @@ def create_interface():
|
|
| 383 |
|
| 384 |
def process_uploaded_file(file):
|
| 385 |
if file is None:
|
| 386 |
-
return None, "
|
| 387 |
try:
|
| 388 |
df = pd.read_csv(file.name)
|
| 389 |
original_shape = df.shape
|
|
@@ -393,12 +395,15 @@ def create_interface():
|
|
| 393 |
df = df.iloc[:MAX_ROWS].copy()
|
| 394 |
trimmed_note = ""
|
| 395 |
if df.shape != original_shape:
|
| 396 |
-
trimmed_note =
|
| 397 |
-
|
|
|
|
|
|
|
|
|
|
| 398 |
mem_info = generator.estimate_memory_usage(df)
|
| 399 |
return df, success_msg, gr.update(value=mem_info, visible=True)
|
| 400 |
except Exception as e:
|
| 401 |
-
return None, f"
|
| 402 |
|
| 403 |
file_upload.change(process_uploaded_file, inputs=[file_upload], outputs=[uploaded_data, train_status, memory_info])
|
| 404 |
|
|
@@ -406,5 +411,5 @@ def create_interface():
|
|
| 406 |
|
| 407 |
|
| 408 |
if __name__ == "__main__":
|
| 409 |
-
|
| 410 |
-
|
|
|
|
| 33 |
|
| 34 |
def initialize_mostly_ai(self) -> Tuple[bool, str]:
    """Create the local Mostly AI client and report the outcome.

    On success the client is stored on ``self.mostly``.

    Returns:
        An ``(ok, message)`` pair. ``ok`` is False when the module-level
        ``MOSTLY_AI_AVAILABLE`` flag is falsy or when constructing
        ``MostlyAI`` raises; ``message`` is a human-readable status line.
    """
    if MOSTLY_AI_AVAILABLE:
        try:
            self.mostly = MostlyAI(local=True, local_port=8080)
        except Exception as exc:
            # Any construction failure is reported to the caller as text
            # instead of propagating.
            reason = str(exc)
            return False, f"Failed to initialize Mostly AI SDK: {reason}"
        return True, "Mostly AI SDK initialized successfully."
    return False, "Mostly AI SDK not installed. Please install with: pip install mostlyai[local]."
|
| 42 |
|
| 43 |
def train_generator(
|
| 44 |
self,
|
|
|
|
| 59 |
weight_decay: float = 0.0001,
|
| 60 |
) -> Tuple[bool, str]:
|
| 61 |
if not self.mostly:
|
| 62 |
+
return False, "Mostly AI SDK not initialized. Please initialize the SDK first."
|
| 63 |
try:
|
| 64 |
self.original_data = data
|
| 65 |
train_config = {
|
|
|
|
| 86 |
]
|
| 87 |
}
|
| 88 |
self.generator = self.mostly.train(config=train_config)
|
| 89 |
+
return True, f"Training completed successfully. Model name: {name}"
|
| 90 |
except Exception as e:
|
| 91 |
+
return False, f"Training failed with error: {str(e)}"
|
| 92 |
|
| 93 |
def generate_synthetic_data(self, size: int) -> Tuple[Optional[pd.DataFrame], str]:
    """Generate ``size`` synthetic records from the trained generator.

    Args:
        size: Passed through as the ``size`` keyword of ``self.mostly.generate``.

    Returns:
        ``(data, message)``. ``data`` is the object returned by the
        generation result's ``.data()`` call, or ``None`` when no generator
        is set or any step raises; ``message`` describes the outcome.
    """
    if not self.generator:
        return None, "No trained generator available. Please train a model first."

    try:
        frame = self.mostly.generate(self.generator, size=size).data()
        # len() stays inside the try so an unsized result is reported as a
        # generation failure rather than raised to the caller.
        summary = f"Synthetic data generated successfully. {len(frame)} records created."
    except Exception as exc:
        reason = str(exc)
        return None, f"Synthetic data generation failed with error: {reason}"
    return frame, summary
|
| 102 |
|
| 103 |
def estimate_memory_usage(self, df: pd.DataFrame) -> str:
    """Summarize the in-memory size of ``df`` as a short Markdown report.

    Args:
        df: The frame to measure. ``None`` and empty frames are accepted.

    Returns:
        ``"No data available to analyze."`` for a missing or empty frame;
        otherwise a bold heading followed by four bullet lines (data size,
        estimated training memory, status, and row/column counts).
    """
    if df is None or df.empty:
        return "No data available to analyze."

    memory_mb = df.memory_usage(deep=True).sum() / (1024 * 1024)
    # Training footprint is estimated as four times the raw frame size.
    estimated_training_mb = memory_mb * 4
    rows, cols = len(df), len(df.columns)

    if memory_mb < 100:
        status = "Good"
    elif memory_mb < 500:
        status = "Large"
    else:
        status = "Very Large"

    # Assemble the report line by line instead of with a triple-quoted
    # literal: inside an indented method body such a literal carries the
    # source indentation into the Markdown, which can stop the bullets from
    # rendering as a list.
    report_lines = (
        "**Memory Usage Estimate**",
        f"- Data size: {memory_mb:.1f} MB",
        f"- Estimated training memory: {estimated_training_mb:.1f} MB",
        f"- Status: {status}",
        f"- Rows: {rows:,} | Columns: {cols}",
    )
    return "\n".join(report_lines)
|
| 117 |
|
| 118 |
|
|
|
|
| 123 |
|
| 124 |
def initialize_sdk() -> str:
    """Initialize the shared generator's SDK client and return a status line.

    Returns:
        The message from ``generator.initialize_mostly_ai()``, prefixed with
        ``"Success: "`` or ``"Error: "`` according to its boolean result.
    """
    succeeded, detail = generator.initialize_mostly_ai()
    if succeeded:
        prefix = "Success: "
    else:
        prefix = "Error: "
    return prefix + detail
|
| 127 |
|
| 128 |
|
| 129 |
def train_model(
|
|
|
|
| 144 |
weight_decay: float,
|
| 145 |
) -> str:
|
| 146 |
if data is None or data.empty:
|
| 147 |
+
return "Error: No data provided. Please upload or create sample data first."
|
| 148 |
|
| 149 |
# enforce backend caps regardless of ui inputs
|
| 150 |
epochs = min(int(epochs), MAX_EPOCHS)
|
|
|
|
| 167 |
dropout_rate=dropout_rate,
|
| 168 |
weight_decay=weight_decay,
|
| 169 |
)
|
| 170 |
+
return ("Success: " if ok else "Error: ") + msg
|
| 171 |
|
| 172 |
|
| 173 |
def generate_data(size: int) -> Tuple[Optional[pd.DataFrame], str]:
|
|
|
|
| 175 |
synth_df, message = generator.generate_synthetic_data(size)
|
| 176 |
if synth_df is not None:
|
| 177 |
_last_synth_df = synth_df.copy()
|
| 178 |
+
return synth_df, f"Success: {message}"
|
| 179 |
else:
|
| 180 |
+
return None, f"Error: {message}"
|
| 181 |
|
| 182 |
|
| 183 |
def download_csv_prepare() -> Optional[str]:
|
| 184 |
+
"""Return a path to the latest synthetic CSV; used as output to gr.File."""
|
| 185 |
global _last_synth_df
|
| 186 |
if _last_synth_df is None or _last_synth_df.empty:
|
| 187 |
return None
|
|
|
|
| 203 |
for i, col in enumerate(numeric_cols[: n_rows * n_cols]):
|
| 204 |
row = i // n_cols + 1
|
| 205 |
col_idx = i % n_cols + 1
|
| 206 |
+
fig.add_trace(go.Histogram(x=original_df[col], name=f"Original {col}", opacity=0.7, nbinsx=20), row=row, col=col_idx)
|
| 207 |
+
fig.add_trace(go.Histogram(x=synthetic_df[col], name=f"Synthetic {col}", opacity=0.7, nbinsx=20), row=row, col=col_idx)
|
| 208 |
+
fig.update_layout(title="Original vs. Synthetic Data Comparison", height=300 * n_rows, showlegend=True)
|
| 209 |
return fig
|
| 210 |
|
| 211 |
|
| 212 |
def create_interface():
|
| 213 |
+
with gr.Blocks(title="MOSTLY AI Synthetic Data Generator", theme=gr.themes.Soft()) as demo:
|
| 214 |
gr.Image(
|
| 215 |
value="https://img.mailinblue.com/8225865/images/content_library/original/6880d164e4e4ea1a183ad4c0.png",
|
| 216 |
show_label=False,
|
|
|
|
| 225 |
|
| 226 |
A Python toolkit for generating high-fidelity, privacy-safe synthetic data. This is a limited demo space intended to showcase the features of the [Synthetic Data SDK](https://github.com/mostly-ai/mostlyai).
|
| 227 |
|
| 228 |
+
**Demo Space Limitations:** Datasets are supported up to **{MAX_ROWS:,} rows** and **{MAX_COLS} columns**.
|
| 229 |
Training is supported up to **≤ {MAX_EPOCHS} epochs** and **≤ {MAX_TRAINING_MINUTES} minutes**.
|
| 230 |
+
"""
|
| 231 |
)
|
| 232 |
|
| 233 |
+
with gr.Tab("Quick Start"):
|
| 234 |
+
gr.Markdown("### Initialize the SDK and Upload Your Data")
|
| 235 |
with gr.Row():
|
| 236 |
with gr.Column():
|
| 237 |
+
init_btn = gr.Button("Initialize Mostly AI SDK", variant="primary")
|
| 238 |
+
init_status = gr.Textbox(label="Initialization Status", interactive=False)
|
| 239 |
with gr.Column():
|
| 240 |
gr.Markdown(
|
| 241 |
"""
|
| 242 |
+
**Next Steps**
|
| 243 |
+
1. Initialize the SDK.
|
| 244 |
+
2. Go to the “Upload Data and Train Model” tab to upload your CSV file.
|
| 245 |
+
3. Train a model on your data.
|
| 246 |
+
4. Generate synthetic data.
|
| 247 |
"""
|
| 248 |
)
|
| 249 |
|
| 250 |
+
with gr.Tab("Upload Data and Train Model"):
|
| 251 |
+
gr.Markdown("### Upload Your CSV File to Generate Synthetic Data")
|
| 252 |
gr.Markdown(
|
| 253 |
f"""
|
| 254 |
+
**File Requirements & Limits**
|
| 255 |
+
- Format: CSV with a header row.
|
| 256 |
+
- Size: Optimized for Hugging Face Spaces (2 vCPU, 16 GB RAM).
|
| 257 |
+
- This app will automatically trim to the first **{MAX_ROWS:,}** rows and first **{MAX_COLS}** columns.
|
| 258 |
"""
|
| 259 |
)
|
| 260 |
|
| 261 |
+
file_upload = gr.File(label="Upload CSV File", file_types=[".csv"], file_count="single")
|
| 262 |
+
uploaded_data = gr.Dataframe(label="Uploaded (Trimmed) Data", interactive=False)
|
| 263 |
+
memory_info = gr.Markdown(label="Memory Usage Info", visible=False)
|
| 264 |
|
| 265 |
with gr.Row():
|
| 266 |
with gr.Column(scale=1):
|
| 267 |
model_name = gr.Textbox(
|
| 268 |
+
value="My Synthetic Model",
|
| 269 |
+
label="Generator Name",
|
| 270 |
+
placeholder="Enter a name for your generator.",
|
| 271 |
+
info="Appears in training runs and saved generators."
|
| 272 |
)
|
| 273 |
epochs = gr.Slider(
|
| 274 |
+
1, MAX_EPOCHS, value=MAX_EPOCHS, step=1, label=f"Training Epochs (≤ {MAX_EPOCHS})",
|
| 275 |
+
info=f"Maximum number of passes over the training data. Capped at {MAX_EPOCHS}."
|
| 276 |
)
|
| 277 |
max_training_time = gr.Slider(
|
| 278 |
1, MAX_TRAINING_MINUTES, value=MAX_TRAINING_MINUTES, step=1,
|
| 279 |
+
label=f"Maximum Training Time (minutes, ≤ {MAX_TRAINING_MINUTES})",
|
| 280 |
+
info=f"Upper bound in minutes; training stops if exceeded. Capped at {MAX_TRAINING_MINUTES}."
|
| 281 |
)
|
| 282 |
batch_size = gr.Slider(
|
| 283 |
+
8, 1024, value=32, step=8, label="Batch Size",
|
| 284 |
+
info="Number of rows per optimization step. Larger can speed up but requires more memory."
|
| 285 |
)
|
| 286 |
value_protection = gr.Checkbox(
|
| 287 |
+
label="Value Protection",
|
| 288 |
+
info="Adds protections to reduce memorization of unique or sensitive values.",
|
| 289 |
value=False
|
| 290 |
)
|
| 291 |
rare_category_protection = gr.Checkbox(
|
| 292 |
+
label="Rare Category Protection",
|
| 293 |
+
info="Prevents overfitting to infrequent categories to improve privacy and robustness.",
|
| 294 |
value=False
|
| 295 |
)
|
| 296 |
with gr.Column(scale=1):
|
| 297 |
flexible_generation = gr.Checkbox(
|
| 298 |
+
label="Flexible Generation",
|
| 299 |
+
info="Allows generation when inputs slightly differ from the training schema.",
|
| 300 |
value=True
|
| 301 |
)
|
| 302 |
model_size = gr.Dropdown(
|
| 303 |
choices=["SMALL", "MEDIUM", "LARGE"],
|
| 304 |
value="MEDIUM",
|
| 305 |
+
label="Model Size",
|
| 306 |
+
info="Sets model capacity. Larger can improve fidelity but uses more compute."
|
| 307 |
)
|
| 308 |
target_accuracy = gr.Slider(
|
| 309 |
+
0.50, 0.999, value=0.95, step=0.001, label="Target Accuracy",
|
| 310 |
+
info="Stop early when validation accuracy reaches this threshold."
|
| 311 |
)
|
| 312 |
validation_split = gr.Slider(
|
| 313 |
+
0.05, 0.5, value=0.2, step=0.01, label="Validation Split",
|
| 314 |
+
info="Fraction of the dataset held out for validation during training."
|
| 315 |
)
|
| 316 |
early_stopping_patience = gr.Slider(
|
| 317 |
+
0, 50, value=10, step=1, label="Early Stopping Patience (epochs)",
|
| 318 |
+
info="Stop if no validation improvement after this many epochs."
|
| 319 |
)
|
| 320 |
with gr.Column(scale=1):
|
| 321 |
learning_rate = gr.Number(
|
| 322 |
+
value=0.001, precision=6, label="Learning Rate",
|
| 323 |
+
info="Step size for the optimizer. Typical range: 1e-4 to 1e-2."
|
| 324 |
)
|
| 325 |
dropout_rate = gr.Slider(
|
| 326 |
+
0.0, 0.6, value=0.1, step=0.01, label="Dropout Rate",
|
| 327 |
+
info="Regularization to reduce overfitting by randomly dropping units."
|
| 328 |
)
|
| 329 |
weight_decay = gr.Number(
|
| 330 |
+
value=0.0001, precision=6, label="Weight Decay",
|
| 331 |
+
info="L2 regularization strength applied to model weights."
|
| 332 |
)
|
| 333 |
+
train_btn = gr.Button("Train Model", variant="primary")
|
| 334 |
+
train_status = gr.Textbox(label="Training Status", interactive=False)
|
| 335 |
|
| 336 |
+
with gr.Tab("Generate Data"):
|
| 337 |
+
gr.Markdown("### Generate Synthetic Data From Your Trained Model")
|
| 338 |
with gr.Row():
|
| 339 |
with gr.Column():
|
| 340 |
+
gen_size = gr.Slider(
|
| 341 |
+
10, 1000, value=100, step=10, label="Number of Records to Generate",
|
| 342 |
+
info="How many synthetic rows to create in the table."
|
| 343 |
+
)
|
| 344 |
+
generate_btn = gr.Button("Generate Synthetic Data", variant="primary")
|
| 345 |
with gr.Column():
|
| 346 |
+
gen_status = gr.Textbox(label="Generation Status", interactive=False)
|
| 347 |
|
| 348 |
+
synthetic_data = gr.Dataframe(label="Synthetic Data", interactive=False)
|
| 349 |
|
| 350 |
with gr.Row():
|
| 351 |
+
csv_download_btn = gr.Button("Download CSV", variant="secondary")
|
| 352 |
with gr.Group(visible=False) as csv_group:
|
| 353 |
+
csv_file = gr.File(label="Synthetic CSV", interactive=False)
|
| 354 |
+
comparison_plot = gr.Plot(label="Data Comparison")
|
| 355 |
|
| 356 |
init_btn.click(initialize_sdk, outputs=[init_status])
|
| 357 |
|
|
|
|
| 385 |
|
| 386 |
def process_uploaded_file(file):
|
| 387 |
if file is None:
|
| 388 |
+
return None, "No file uploaded.", gr.update(visible=False)
|
| 389 |
try:
|
| 390 |
df = pd.read_csv(file.name)
|
| 391 |
original_shape = df.shape
|
|
|
|
| 395 |
df = df.iloc[:MAX_ROWS].copy()
|
| 396 |
trimmed_note = ""
|
| 397 |
if df.shape != original_shape:
|
| 398 |
+
trimmed_note = (
|
| 399 |
+
f" (trimmed to {df.shape[0]:,} rows × {df.shape[1]} columns "
|
| 400 |
+
f"from {original_shape[0]:,} × {original_shape[1]})"
|
| 401 |
+
)
|
| 402 |
+
success_msg = f"File uploaded successfully.{trimmed_note}"
|
| 403 |
mem_info = generator.estimate_memory_usage(df)
|
| 404 |
return df, success_msg, gr.update(value=mem_info, visible=True)
|
| 405 |
except Exception as e:
|
| 406 |
+
return None, f"Error reading file: {str(e)}", gr.update(visible=False)
|
| 407 |
|
| 408 |
file_upload.change(process_uploaded_file, inputs=[file_upload], outputs=[uploaded_data, train_status, memory_info])
|
| 409 |
|
|
|
|
| 411 |
|
| 412 |
|
| 413 |
if __name__ == "__main__":
    sdk_demo = create_interface()
    # share=True is intentionally not passed. The app targets Hugging Face
    # Spaces, where Gradio does not support share links, and on a local run
    # a share link would publish this upload/training UI through a public
    # tunnel without the operator asking for it.
    sdk_demo.launch(server_name="0.0.0.0", server_port=7860)
|