import os
import warnings
from typing import Optional, Tuple

import gradio as gr
import pandas as pd
import numpy as np
import plotly.graph_objects as go
from plotly.subplots import make_subplots

warnings.filterwarnings("ignore")

# Limits for this demo Space.
MAX_ROWS = 5_000
MAX_COLS = 10
MAX_EPOCHS = 6
MAX_TRAINING_MINUTES = 20

# Try to import the MOSTLY AI SDK, but allow the app to load without it.
try:
    from mostlyai.sdk import MostlyAI

    MOSTLY_AI_AVAILABLE = True
except ImportError:
    MOSTLY_AI_AVAILABLE = False
    print("Warning: MOSTLY AI SDK not available. Install it with: pip install 'mostlyai[local]'")
class SyntheticDataGenerator:
    """Thin wrapper around the MOSTLY AI SDK that holds the trained generator and the original data."""

    def __init__(self):
        self.mostly = None
        self.generator = None
        self.original_data = None

    def initialize_mostly_ai(self) -> Tuple[bool, str]:
        if not MOSTLY_AI_AVAILABLE:
            return False, "MOSTLY AI SDK not installed. Please install it with: pip install 'mostlyai[local]'."
        try:
            self.mostly = MostlyAI(local=True, local_port=8080)
            return True, "MOSTLY AI SDK initialized successfully."
        except Exception as e:
            return False, f"Failed to initialize the MOSTLY AI SDK: {str(e)}"

    def train_generator(
        self,
        data: pd.DataFrame,
        name: str,
        epochs: int = 10,
        max_training_time: int = 30,
        batch_size: int = 32,
        value_protection: bool = True,
        rare_category_protection: bool = False,
        flexible_generation: bool = False,
        model_size: str = "MEDIUM",
        target_accuracy: float = 0.95,
        validation_split: float = 0.2,
        learning_rate: float = 0.001,
        early_stopping_patience: int = 10,
        dropout_rate: float = 0.1,
        weight_decay: float = 0.0001,
    ) -> Tuple[bool, str]:
        if not self.mostly:
            return False, "MOSTLY AI SDK not initialized. Please initialize the SDK first."
        try:
            self.original_data = data
            train_config = {
                "tables": [
                    {
                        "name": name,
                        "data": data,
                        "tabular_model_configuration": {
                            "max_epochs": epochs,
                            "max_training_time": max_training_time,
                            "value_protection": value_protection,
                            "batch_size": batch_size,
                            "rare_category_protection": rare_category_protection,
                            "flexible_generation": flexible_generation,
                            "model_size": model_size,
                            "target_accuracy": target_accuracy,
                            "validation_split": validation_split,
                            "learning_rate": learning_rate,
                            "early_stopping_patience": early_stopping_patience,
                            "dropout_rate": dropout_rate,
                            "weight_decay": weight_decay,
                        },
                    }
                ]
            }
            self.generator = self.mostly.train(config=train_config)
            return True, f"Training completed successfully. Model name: {name}"
        except Exception as e:
            return False, f"Training failed with error: {str(e)}"

    def generate_synthetic_data(self, size: int) -> Tuple[Optional[pd.DataFrame], str]:
        if not self.generator:
            return None, "No trained generator available. Please train a model first."
        try:
            synthetic_data = self.mostly.generate(self.generator, size=size)
            df = synthetic_data.data()
            return df, f"Synthetic data generated successfully. {len(df)} records created."
        except Exception as e:
            return None, f"Synthetic data generation failed with error: {str(e)}"

    def estimate_memory_usage(self, df: pd.DataFrame) -> str:
        if df is None or df.empty:
            return "No data available to analyze."
        memory_mb = df.memory_usage(deep=True).sum() / (1024 * 1024)
        rows, cols = len(df), len(df.columns)
        estimated_training_mb = memory_mb * 4
        status = "Good" if memory_mb < 100 else ("Large" if memory_mb < 500 else "Very Large")
        return f"""
**Memory Usage Estimate**
- Data size: {memory_mb:.1f} MB
- Estimated training memory: {estimated_training_mb:.1f} MB
- Status: {status}
- Rows: {rows:,} | Columns: {cols}
""".strip()

# Keep global state for the generator and the last synthetic dataset.
generator = SyntheticDataGenerator()
_last_synth_df: Optional[pd.DataFrame] = None


def initialize_sdk() -> str:
    ok, msg = generator.initialize_mostly_ai()
    return ("Success: " if ok else "Error: ") + msg

def train_model(
    data: pd.DataFrame,
    model_name: str,
    epochs: int,
    max_training_time: int,
    batch_size: int,
    value_protection: bool,
    rare_category_protection: bool,
    flexible_generation: bool,
    model_size: str,
    target_accuracy: float,
    validation_split: float,
    learning_rate: float,
    early_stopping_patience: int,
    dropout_rate: float,
    weight_decay: float,
) -> str:
    if data is None or data.empty:
        return "Error: No data provided. Please upload or create sample data first."
    # Enforce backend caps regardless of UI inputs.
    epochs = min(int(epochs), MAX_EPOCHS)
    max_training_time = min(int(max_training_time), MAX_TRAINING_MINUTES)
    ok, msg = generator.train_generator(
        data=data,
        name=model_name,
        epochs=epochs,
        max_training_time=max_training_time,
        batch_size=batch_size,
        value_protection=value_protection,
        rare_category_protection=rare_category_protection,
        flexible_generation=flexible_generation,
        model_size=model_size,
        target_accuracy=target_accuracy,
        validation_split=validation_split,
        learning_rate=learning_rate,
        early_stopping_patience=early_stopping_patience,
        dropout_rate=dropout_rate,
        weight_decay=weight_decay,
    )
    return ("Success: " if ok else "Error: ") + msg

def generate_data(size: int) -> Tuple[Optional[pd.DataFrame], str]:
    global _last_synth_df
    synth_df, message = generator.generate_synthetic_data(size)
    if synth_df is not None:
        _last_synth_df = synth_df.copy()
        return synth_df, f"Success: {message}"
    else:
        return None, f"Error: {message}"

def download_csv_prepare() -> Optional[str]:
    """Return a path to the latest synthetic CSV; used as output to gr.File."""
    global _last_synth_df
    if _last_synth_df is None or _last_synth_df.empty:
        return None
    # /tmp is writable on Hugging Face Spaces; makedirs is a harmless safeguard.
    os.makedirs("/tmp", exist_ok=True)
    path = "/tmp/synthetic_data.csv"
    _last_synth_df.to_csv(path, index=False)
    return path

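# Wiring sketch (mirrors the click handler defined in create_interface below):
# a button click calls download_csv_prepare() and feeds the returned path into
# a gr.File component, which then offers the file for download.
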
def create_comparison_plot(original_df: pd.DataFrame, synthetic_df: pd.DataFrame):
    if original_df is None or synthetic_df is None:
        return None
    numeric_cols = original_df.select_dtypes(include=[np.number]).columns.tolist()
    if not numeric_cols:
        return None
    n_cols = min(3, len(numeric_cols))
    n_rows = (len(numeric_cols) + n_cols - 1) // n_cols
    fig = make_subplots(rows=n_rows, cols=n_cols, subplot_titles=numeric_cols[: n_rows * n_cols])
    for i, col in enumerate(numeric_cols[: n_rows * n_cols]):
        row = i // n_cols + 1
        col_idx = i % n_cols + 1
        fig.add_trace(go.Histogram(x=original_df[col], name=f"Original {col}", opacity=0.7, nbinsx=20), row=row, col=col_idx)
        fig.add_trace(go.Histogram(x=synthetic_df[col], name=f"Synthetic {col}", opacity=0.7, nbinsx=20), row=row, col=col_idx)
    # Overlay the histograms so the 0.7 opacity makes both distributions visible;
    # Plotly's default barmode is "group", which would place the bars side by side.
    fig.update_layout(title="Original vs. Synthetic Data Comparison", height=300 * n_rows, barmode="overlay", showlegend=True)
    return fig

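# Illustrative standalone use (a sketch; assumes two DataFrames that share a
# numeric column, e.g. "age"):
#
#   fig = create_comparison_plot(real_df, synth_df)
#   if fig is not None:
#       fig.show()
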
def create_interface():
    with gr.Blocks(title="MOSTLY AI Synthetic Data Generator", theme=gr.themes.Soft()) as demo:
        gr.Image(
            value="https://img.mailinblue.com/8225865/images/content_library/original/6880d164e4e4ea1a183ad4c0.png",
            show_label=False,
            elem_id="header-image",
        )
        gr.Markdown(
            f"""
# Synthetic Data SDK by MOSTLY AI – Demo Space

[Documentation](https://mostly-ai.github.io/mostlyai/) | [Technical White Paper](https://arxiv.org/abs/2508.00718) | [Usage Examples](https://mostly-ai.github.io/mostlyai/usage/) | [Free Cloud Service](https://app.mostly.ai/)

A Python toolkit for generating high-fidelity, privacy-safe synthetic data. This is a limited demo Space intended to showcase the features of the [Synthetic Data SDK](https://github.com/mostly-ai/mostlyai).

**Demo Space Limitations:** Datasets are supported up to **{MAX_ROWS:,} rows** and **{MAX_COLS} columns**.
Training is capped at **{MAX_EPOCHS} epochs** and **{MAX_TRAINING_MINUTES} minutes**.
"""
        )
| with gr.Tab("Quick Start"): | |
| gr.Markdown("### Initialize the SDK and Upload Your Data") | |
| with gr.Row(): | |
| with gr.Column(): | |
| init_btn = gr.Button("Initialize Mostly AI SDK", variant="primary") | |
| init_status = gr.Textbox(label="Initialization Status", interactive=False) | |
| with gr.Column(): | |
| gr.Markdown( | |
| """ | |
| **Next Steps** | |
| 1. Initialize the SDK. | |
| 2. Go to the “Upload Data and Train Model” tab to upload your CSV file. | |
| 3. Train a model on your data. | |
| 4. Generate synthetic data. | |
| """ | |
| ) | |
| with gr.Tab("Upload Data and Train Model"): | |
| gr.Markdown("### Upload Your CSV File to Generate Synthetic Data") | |
| gr.Markdown( | |
| f""" | |
| **File Requirements & Limits** | |
| - Format: CSV with a header row. | |
| - Size: Optimized for Hugging Face Spaces (2 vCPU, 16 GB RAM). | |
| - This app will automatically trim to the first **{MAX_ROWS:,}** rows and first **{MAX_COLS}** columns. | |
| """ | |
| ) | |
| file_upload = gr.File(label="Upload CSV File", file_types=[".csv"], file_count="single") | |
| uploaded_data = gr.Dataframe(label="Uploaded (Trimmed) Data", interactive=False) | |
| memory_info = gr.Markdown(label="Memory Usage Info", visible=False) | |
            with gr.Row():
                with gr.Column(scale=1):
                    model_name = gr.Textbox(
                        value="My Synthetic Model",
                        label="Generator Name",
                        placeholder="Enter a name for your generator.",
                        info="Appears in training runs and saved generators.",
                    )
                    epochs = gr.Slider(
                        1, MAX_EPOCHS, value=MAX_EPOCHS, step=1, label=f"Training Epochs (≤ {MAX_EPOCHS})",
                        info=f"Maximum number of passes over the training data. Capped at {MAX_EPOCHS}.",
                    )
                    max_training_time = gr.Slider(
                        1, MAX_TRAINING_MINUTES, value=MAX_TRAINING_MINUTES, step=1,
                        label=f"Maximum Training Time (minutes, ≤ {MAX_TRAINING_MINUTES})",
                        info=f"Upper bound in minutes; training stops if exceeded. Capped at {MAX_TRAINING_MINUTES}.",
                    )
                    batch_size = gr.Slider(
                        8, 1024, value=32, step=8, label="Batch Size",
                        info="Number of rows per optimization step. Larger batches can speed up training but require more memory.",
                    )
                    value_protection = gr.Checkbox(
                        label="Value Protection",
                        info="Adds protections to reduce memorization of unique or sensitive values.",
                        value=False,
                    )
                    rare_category_protection = gr.Checkbox(
                        label="Rare Category Protection",
                        info="Prevents overfitting to infrequent categories to improve privacy and robustness.",
                        value=False,
                    )
                with gr.Column(scale=1):
                    flexible_generation = gr.Checkbox(
                        label="Flexible Generation",
                        info="Allows generation when inputs slightly differ from the training schema.",
                        value=True,
                    )
                    model_size = gr.Dropdown(
                        choices=["SMALL", "MEDIUM", "LARGE"],
                        value="MEDIUM",
                        label="Model Size",
                        info="Sets model capacity. Larger models can improve fidelity but use more compute.",
                    )
                    target_accuracy = gr.Slider(
                        0.50, 0.999, value=0.95, step=0.001, label="Target Accuracy",
                        info="Stop early when validation accuracy reaches this threshold.",
                    )
                    validation_split = gr.Slider(
                        0.05, 0.5, value=0.2, step=0.01, label="Validation Split",
                        info="Fraction of the dataset held out for validation during training.",
                    )
                    early_stopping_patience = gr.Slider(
                        0, 50, value=10, step=1, label="Early Stopping Patience (epochs)",
                        info="Stop if there is no validation improvement after this many epochs.",
                    )
                with gr.Column(scale=1):
                    learning_rate = gr.Number(
                        value=0.001, precision=6, label="Learning Rate",
                        info="Step size for the optimizer. Typical range: 1e-4 to 1e-2.",
                    )
                    dropout_rate = gr.Slider(
                        0.0, 0.6, value=0.1, step=0.01, label="Dropout Rate",
                        info="Regularization that reduces overfitting by randomly dropping units.",
                    )
                    weight_decay = gr.Number(
                        value=0.0001, precision=6, label="Weight Decay",
                        info="L2 regularization strength applied to model weights.",
                    )
            train_btn = gr.Button("Train Model", variant="primary")
            train_status = gr.Textbox(label="Training Status", interactive=False)
        with gr.Tab("Generate Data"):
            gr.Markdown("### Generate Synthetic Data From Your Trained Model")
            with gr.Row():
                with gr.Column():
                    gen_size = gr.Slider(
                        10, 1000, value=100, step=10, label="Number of Records to Generate",
                        info="How many synthetic rows to create in the table.",
                    )
                    generate_btn = gr.Button("Generate Synthetic Data", variant="primary")
                with gr.Column():
                    gen_status = gr.Textbox(label="Generation Status", interactive=False)
            synthetic_data = gr.Dataframe(label="Synthetic Data", interactive=False)
            with gr.Row():
                csv_download_btn = gr.Button("Download CSV", variant="secondary")
            with gr.Group(visible=False) as csv_group:
                csv_file = gr.File(label="Synthetic CSV", interactive=False)
            comparison_plot = gr.Plot(label="Data Comparison")
        init_btn.click(initialize_sdk, outputs=[init_status])
        train_btn.click(
            train_model,
            inputs=[
                uploaded_data, model_name,
                epochs, max_training_time, batch_size,
                value_protection, rare_category_protection, flexible_generation,
                model_size, target_accuracy, validation_split,
                learning_rate, early_stopping_patience, dropout_rate, weight_decay,
            ],
            outputs=[train_status],
        )
        generate_btn.click(generate_data, inputs=[gen_size], outputs=[synthetic_data, gen_status])
        synthetic_data.change(create_comparison_plot, inputs=[uploaded_data, synthetic_data], outputs=[comparison_plot])

        def _prepare_csv_for_download():
            path = download_csv_prepare()
            if path:
                return path, gr.update(visible=True)
            else:
                return None, gr.update(visible=False)

        csv_download_btn.click(
            _prepare_csv_for_download,
            outputs=[csv_file, csv_group],
        )
        def process_uploaded_file(file):
            if file is None:
                return None, "No file uploaded.", gr.update(visible=False)
            try:
                # Depending on the Gradio version, gr.File may pass a path string
                # or a tempfile-like object with a .name attribute; handle both.
                path = file if isinstance(file, str) else file.name
                df = pd.read_csv(path)
                original_shape = df.shape
                if df.shape[1] > MAX_COLS:
                    df = df.iloc[:, :MAX_COLS].copy()
                if df.shape[0] > MAX_ROWS:
                    df = df.iloc[:MAX_ROWS].copy()
                trimmed_note = ""
                if df.shape != original_shape:
                    trimmed_note = (
                        f" (trimmed to {df.shape[0]:,} rows × {df.shape[1]} columns "
                        f"from {original_shape[0]:,} × {original_shape[1]})"
                    )
                success_msg = f"File uploaded successfully.{trimmed_note}"
                mem_info = generator.estimate_memory_usage(df)
                return df, success_msg, gr.update(value=mem_info, visible=True)
            except Exception as e:
                return None, f"Error reading file: {str(e)}", gr.update(visible=False)

        file_upload.change(process_uploaded_file, inputs=[file_upload], outputs=[uploaded_data, train_status, memory_info])
    return demo

if __name__ == "__main__":
    sdk_demo = create_interface()
    # share=True has no effect inside Hugging Face Spaces (Gradio ignores it
    # there) but creates a public link when the app runs elsewhere.
    sdk_demo.launch(server_name="0.0.0.0", server_port=7860, share=True)
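
# To run outside Spaces (assumption: dependencies installed and port 7860 free):
#   pip install gradio pandas numpy plotly 'mostlyai[local]'
#   python app.py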