Spaces:
Sleeping
Sleeping
import base64
import io
import os
import tempfile
import warnings
from typing import Optional, Tuple

import gradio as gr
import numpy as np
import pandas as pd
import plotly.graph_objects as go
from plotly.subplots import make_subplots
# Silence every Python warning for the whole process (library deprecation notices etc.).
warnings.filterwarnings("ignore")

# Optional dependency: the Mostly AI SDK. If it cannot be imported the app still
# starts; MOSTLY_AI_AVAILABLE records which case we are in.
try:
    from mostlyai.sdk import MostlyAI
except ImportError:
    MOSTLY_AI_AVAILABLE = False
    print("Warning: Mostly AI SDK not available. Please install with: pip install mostlyai[local]")
else:
    MOSTLY_AI_AVAILABLE = True
class SyntheticDataGenerator:
    """Stateful wrapper around the Mostly AI SDK used by the Gradio callbacks.

    Public methods report problems through their return value (a status
    message, ``None``, or a ``(success, message)`` tuple) instead of raising,
    so the UI can display the outcome directly.
    """

    def __init__(self):
        self.mostly = None  # MostlyAI client; set by initialize_mostly_ai()
        self.generator = None  # object returned by MostlyAI.train()
        self.original_data = None  # DataFrame the current generator was trained on

    def initialize_mostly_ai(self, local_port: int = 8080) -> Tuple[bool, str]:
        """Create the Mostly AI client in LOCAL mode.

        Args:
            local_port: Port passed to ``MostlyAI(local=True, local_port=...)``.

        Returns:
            ``(success, message)``.
        """
        if not MOSTLY_AI_AVAILABLE:
            return False, "Mostly AI SDK not installed. Please install with: pip install mostlyai[local]"
        if self.mostly is not None:
            # Reuse the existing client instead of creating a second local one
            # on the same port.
            return True, "Mostly AI SDK already initialized."
        try:
            self.mostly = MostlyAI(local=True, local_port=local_port)
            return True, "Mostly AI SDK initialized successfully."
        except Exception as e:
            return False, f"Failed to initialize Mostly AI SDK: {str(e)}"

    def train_generator(
        self,
        data: pd.DataFrame,
        name: str,
        epochs: int = 10,
        max_training_time: int = 60,
        batch_size: int = 32,
        value_protection: bool = True,
    ) -> Tuple[bool, str]:
        """Train a single-table generator on ``data``.

        Args:
            data: Training data.
            name: Generator name, also used as the table name. Leading and
                trailing whitespace is stripped; it must not be empty.
            epochs: Maximum number of training epochs (coerced to ``int``,
                since UI sliders may deliver floats).
            max_training_time: Passed through as the model's
                ``max_training_time`` setting.
            batch_size: Training batch size (coerced to ``int``).
            value_protection: Passed through as ``value_protection``.

        Returns:
            ``(success, message)``. On success ``self.generator`` and
            ``self.original_data`` are updated; on failure both are left
            unchanged.
        """
        if self.mostly is None:
            return False, "Mostly AI SDK not initialized. Please initialize the SDK first."
        name = (name or "").strip()
        if not name:
            return False, "Model name must not be empty."
        try:
            train_config = {
                "name": name,
                "tables": [
                    {
                        "name": name,
                        "data": data,
                        "tabular_model_configuration": {
                            "max_epochs": int(epochs),
                            "max_training_time": max_training_time,
                            "value_protection": bool(value_protection),
                            "batch_size": int(batch_size),
                        },
                    }
                ],
            }
            trained = self.mostly.train(config=train_config)
        except Exception as e:
            return False, f"Training failed with error: {str(e)}"
        # Only record state once training actually succeeded.
        self.generator = trained
        self.original_data = data
        return True, f"Training completed successfully. Model name: {name}"

    def generate_synthetic_data(self, size: int) -> Tuple[Optional[pd.DataFrame], str]:
        """Sample ``size`` records from the trained generator.

        Returns:
            ``(dataframe, message)``. The dataframe is ``None`` on failure.
        """
        if self.generator is None:
            return None, "No trained generator available. Please train a model first."
        try:
            # UI sliders may deliver floats; the SDK expects a record count.
            synthetic_data = self.mostly.generate(self.generator, size=int(size))
            df = synthetic_data.data()
            return df, f"Synthetic data generated successfully. {len(df)} records created."
        except Exception as e:
            return None, f"Synthetic data generation failed with error: {str(e)}"

    # ---- Report helpers ----
    def get_quality_report_text(self) -> str:
        """Build the quality report and return a one-line status message."""
        if self.generator is None:
            return "No trained generator available. Please train a model first."
        try:
            self.generator.reports(display=False)
            return "Quality report generated. Use the button to download."
        except Exception as e:
            return f"Failed to generate quality report: {str(e)}"

    @staticmethod
    def _existing_file(candidate) -> Optional[str]:
        """Return ``candidate`` as a ``str`` path if it names an existing file, else ``None``."""
        if isinstance(candidate, (str, bytes, os.PathLike)):
            path = os.fsdecode(candidate)
            if os.path.isfile(path):
                return path
        return None

    def get_quality_report_file(self) -> Optional[str]:
        """Export the quality report and return a downloadable file path.

        Strategies are tried in order:

        1. ``reports()`` returned a path (``str`` or ``os.PathLike`` such as
           ``pathlib.Path``) to an existing file.
        2. The returned object has a path attribute naming an existing file.
        3. The returned object has a callable ``save``/``export`` method.
        4. Fallback: write ``str(report)`` to a text file.

        Files for strategies 3-4 go into a fresh directory under the system
        temp dir rather than a hard-coded absolute path that may not exist.

        Returns:
            The file path, or ``None`` when there is no generator, the report
            is ``None``/could not be built, or nothing could be written.
        """
        if self.generator is None:
            return None
        try:
            rep = self.generator.reports(display=False)
            if rep is None:
                return None
            # 1) A path to an existing file (str or pathlib.Path)
            found = self._existing_file(rep)
            if found:
                return found
            # 2) The object exposes its file location as an attribute
            for attr in ("archive_path", "zip_path", "path", "file_path"):
                found = self._existing_file(getattr(rep, attr, None))
                if found:
                    return found
            # 3) The object can write itself to disk
            out_dir = tempfile.mkdtemp(prefix="mostlyai_report_")
            target_zip = os.path.join(out_dir, "quality_report.zip")
            for method_name in ("save", "export"):
                writer = getattr(rep, method_name, None)
                if not callable(writer):
                    continue
                try:
                    writer(target_zip)
                except Exception:
                    continue  # best effort: fall through to the next strategy
                if os.path.isfile(target_zip):
                    return target_zip
            # 4) Fallback: dump the textual representation
            target_txt = os.path.join(out_dir, "quality_report.txt")
            with open(target_txt, "w", encoding="utf-8") as f:
                f.write(str(rep))
            return target_txt
        except Exception:
            # Callers treat None as "no file available".
            return None

    def estimate_memory_usage(self, df: pd.DataFrame) -> str:
        """Return a Markdown summary of ``df``'s in-memory size.

        The training-memory figure is a rough heuristic (4x the data size),
        not a measurement.
        """
        if df is None or df.empty:
            return "No data available to analyze."
        memory_mb = df.memory_usage(deep=True).sum() / (1024 * 1024)
        rows, cols = df.shape
        estimated_training_mb = memory_mb * 4  # heuristic multiplier
        if memory_mb < 100:
            status = "Good"
        elif memory_mb < 500:
            status = "Large"
        else:
            status = "Very Large"
        return "\n".join(
            [
                "Memory Usage Estimate:",
                f"- Data size: {memory_mb:.1f} MB",
                f"- Estimated training memory: {estimated_training_mb:.1f} MB",
                f"- Status: {status}",
                f"- Rows: {rows:,} | Columns: {cols}",
            ]
        )
# Single module-level instance used by the Gradio wrapper functions in this
# module; its SDK client and trained generator are therefore shared by every
# session of the running app.
generator: SyntheticDataGenerator = SyntheticDataGenerator()
| # ---- Wrapper functions for Gradio ---- | |
def initialize_sdk() -> str:
    """Initialise the SDK on the shared generator and return a UI status line."""
    succeeded, detail = generator.initialize_mostly_ai()
    if succeeded:
        return f"Success: {detail}"
    return f"Error: {detail}"
def train_model(
    data: pd.DataFrame,
    model_name: str,
    epochs: int,
    max_training_time: int,
    batch_size: int,
    value_protection: bool,
) -> str:
    """Gradio callback: train on the uploaded table and return a status line.

    Missing or empty data is rejected before the SDK is touched.
    """
    no_rows = data is None or data.empty
    if no_rows:
        return "Error: No data provided. Please upload or create sample data first."
    succeeded, detail = generator.train_generator(
        data, model_name, epochs, max_training_time, batch_size, value_protection
    )
    outcome = "Success" if succeeded else "Error"
    return f"{outcome}: {detail}"
def generate_data(size: int) -> Tuple[Optional[pd.DataFrame], str]:
    """Gradio callback: sample ``size`` records; returns (table or None, status line)."""
    frame, detail = generator.generate_synthetic_data(size)
    if frame is None:
        return None, f"Error: {detail}"
    return frame, f"Success: {detail}"
def get_quality_report_and_file():
    """Gradio callback returning ``(status_text, download_button_update)``.

    The download button is shown with the report path when a file could be
    produced, and hidden otherwise.
    """
    status_text = generator.get_quality_report_text()
    report_path = generator.get_quality_report_file()
    if not report_path:
        # Nothing to offer: keep the download button hidden.
        return status_text, gr.update(visible=False)
    return status_text, gr.update(value=report_path, visible=True)
def create_comparison_plot(original_df: pd.DataFrame, synthetic_df: pd.DataFrame) -> "Optional[go.Figure]":
    """Overlay per-column histograms of the original and synthetic data.

    Only columns that are numeric in ``original_df`` AND present in
    ``synthetic_df`` are plotted, at most three subplots per row. Histograms
    are normalised to probabilities so datasets of different sizes are
    comparable. The two traces of each column share bins (``bingroup``) and
    are drawn overlaid with consistent colours and one legend entry per dataset.

    Returns:
        A Plotly figure, or ``None`` when either frame is missing or no shared
        numeric column exists.
    """
    if original_df is None or synthetic_df is None:
        return None
    synthetic_columns = set(synthetic_df.columns)
    numeric_cols = [
        c for c in original_df.select_dtypes(include=[np.number]).columns if c in synthetic_columns
    ]
    if not numeric_cols:
        return None

    n_cols = min(3, len(numeric_cols))
    n_rows = (len(numeric_cols) + n_cols - 1) // n_cols
    fig = make_subplots(rows=n_rows, cols=n_cols, subplot_titles=[str(c) for c in numeric_cols])

    series_styles = (
        ("Original", original_df, "#636EFA"),
        ("Synthetic", synthetic_df, "#EF553B"),
    )
    for i, col in enumerate(numeric_cols):
        row, col_idx = divmod(i, n_cols)
        for label, frame, color in series_styles:
            fig.add_trace(
                go.Histogram(
                    # Coerce in case a column arrives non-numeric; bad values become NaN.
                    x=pd.to_numeric(frame[col], errors="coerce"),
                    name=label,
                    legendgroup=label,
                    showlegend=i == 0,  # one legend entry per dataset
                    marker_color=color,
                    opacity=0.7,
                    nbinsx=20,
                    bingroup=str(i),  # original & synthetic use the same bins
                    histnorm="probability",
                ),
                row=row + 1,
                col=col_idx + 1,
            )
    fig.update_layout(
        title="Original vs Synthetic Data Comparison",
        height=300 * n_rows,
        showlegend=True,
        barmode="overlay",
    )
    return fig
def download_csv(df: pd.DataFrame) -> Optional[str]:
    """Write ``df`` to a CSV file for the download button and return its path.

    The file goes into a fresh directory under the system temp dir instead of
    a hard-coded absolute path that may not exist. Each call gets its own
    directory, so concurrent sessions do not overwrite each other's file,
    while the file is still named ``synthetic_data.csv``.

    Returns:
        The CSV path, or ``None`` when there is no data to write.
    """
    if df is None or df.empty:
        return None
    out_dir = tempfile.mkdtemp(prefix="mostlyai_synthetic_")
    path = os.path.join(out_dir, "synthetic_data.csv")
    df.to_csv(path, index=False)
    return path
| # ---- UI ---- | |
def create_interface():
    """Build the Gradio Blocks app: header, three tabs, footer and event wiring.

    Returns:
        The ``gr.Blocks`` instance (not yet launched).
    """
    with gr.Blocks(title="MOSTLY AI Synthetic Data Generator", theme=gr.themes.Soft()) as demo:
        # Header image
        gr.Image(
            value="https://img.mailinblue.com/8225865/images/content_library/original/6880d164e4e4ea1a183ad4c0.png",
            show_label=False,
            elem_id="header-image",
        )

        # README
        gr.Markdown(
            """
            # Synthetic Data SDK by MOSTLY AI Demo Space

            [Documentation](https://mostly-ai.github.io/mostlyai/) | [Technical White Paper](https://arxiv.org/abs/2508.00718) | [Usage Examples](https://mostly-ai.github.io/mostlyai/usage/) | [Free Cloud Service](https://app.mostly.ai/)

            A Python toolkit for generating high-fidelity, privacy-safe synthetic data.
            """
        )

        with gr.Tab("Quick Start"):
            gr.Markdown("### Initialize the SDK and upload your data")
            with gr.Row():
                with gr.Column():
                    init_btn = gr.Button("Initialize Mostly AI SDK", variant="primary")
                    init_status = gr.Textbox(label="Initialization Status", interactive=False)
                with gr.Column():
                    gr.Markdown(
                        """
                        **Next Steps:**

                        1. Initialize the SDK (click button above)
                        2. Go to "Upload Data and Train Model" tab to upload your CSV file
                        3. Train a model on your data
                        4. Generate synthetic data
                        """
                    )

        with gr.Tab("Upload Data and Train Model"):
            gr.Markdown("### Upload your CSV file to generate synthetic data")
            gr.Markdown(
                """
                **File Requirements:**

                - Format: CSV with header row
                - Size: Optimized for Hugging Face Spaces (2 vCPU, 16GB RAM)
                """
            )
            file_upload = gr.File(label="Upload CSV File", file_types=[".csv"], file_count="single")
            uploaded_data = gr.Dataframe(label="Uploaded Data", interactive=False)
            memory_info = gr.Markdown(label="Memory Usage Info", visible=False)
            with gr.Row():
                with gr.Column():
                    model_name = gr.Textbox(
                        value="My Synthetic Model", label="Model Name", placeholder="Enter a name for your model"
                    )
                    epochs = gr.Slider(1, 200, value=100, step=1, label="Training Epochs")
                    # The SDK's max_training_time setting is expressed in minutes.
                    max_training_time = gr.Slider(
                        1, 1000, value=60, step=1, label="Maximum Training Time (minutes)"
                    )
                    batch_size = gr.Slider(8, 1024, value=32, step=8, label="Training Batch Size")
                    # Default on, matching train_generator's value_protection=True default.
                    value_protection = gr.Checkbox(
                        value=True, label="Value Protection", info="Enable Value Protection"
                    )
                    train_btn = gr.Button("Train Model", variant="primary")
                with gr.Column():
                    train_status = gr.Textbox(label="Training Status", interactive=False)
                    quality_report = gr.Textbox(label="Quality Report", lines=8, interactive=False)
                    with gr.Row():
                        get_report_btn = gr.Button("Get Quality Report", variant="secondary")
                        report_download_btn = gr.DownloadButton("Download Quality Report", visible=False)

        with gr.Tab("Generate Data"):
            gr.Markdown("### Generate synthetic data from your trained model")
            with gr.Row():
                with gr.Column():
                    gen_size = gr.Slider(10, 1000, value=100, step=10, label="Number of Records to Generate")
                    generate_btn = gr.Button("Generate Synthetic Data", variant="primary")
                with gr.Column():
                    gen_status = gr.Textbox(label="Generation Status", interactive=False)
            synthetic_data = gr.Dataframe(label="Synthetic Data", interactive=False)
            with gr.Row():
                download_btn = gr.DownloadButton("Download CSV", file_name="synthetic_data.csv", variant="secondary")
            comparison_plot = gr.Plot(label="Data Comparison")

        # README footer
        gr.Markdown(
            """
            **Modes of operation:**

            - **LOCAL mode** trains and generates synthetic data on your own compute resources.
            - **CLIENT mode** connects to a remote MOSTLY AI platform for training and generation.
            - Generators trained locally can be imported to the platform for sharing and collaboration.

            **Key resources managed by the SDK:**

            - **Generators**: Train on your tabular or language data assets.
            - **Synthetic datasets**: Generate any number of synthetic samples as needed.
            - **Connectors**: Connect to organizational data sources for reading and writing data.

            **Common intents and API primitives:**

            - Train a generator: `g = mostly.train(config)`
            - Generate records: `sd = mostly.generate(g, config)`
            - Probe generator: `df = mostly.probe(g, config)`
            - Connect to data source: `c = mostly.connect(config)`

            The open source Synthetic Data SDK by MOSTLY AI powers the MOSTLY AI Platform and MOSTLY AI Assistant.
            Sign up for free and try the [MOSTLY AI Platform](https://app.mostly.ai/) today!
            """
        )

        # ---- Event handlers ----
        init_btn.click(initialize_sdk, outputs=[init_status])
        train_btn.click(
            train_model,
            inputs=[uploaded_data, model_name, epochs, max_training_time, batch_size, value_protection],
            outputs=[train_status],
        )
        # Build + expose quality report for download
        get_report_btn.click(
            get_quality_report_and_file,
            outputs=[quality_report, report_download_btn],
        )
        # Generate data
        generate_btn.click(generate_data, inputs=[gen_size], outputs=[synthetic_data, gen_status])
        # Update CSV DownloadButton whenever synthetic data changes
        synthetic_data.change(download_csv, inputs=[synthetic_data], outputs=[download_btn])
        # Build comparison plot when both datasets are available
        synthetic_data.change(
            create_comparison_plot, inputs=[uploaded_data, synthetic_data], outputs=[comparison_plot]
        )

        # Parse an uploaded CSV into the preview table and show a memory estimate.
        # NOTE: no size or column limits are enforced here.
        def process_uploaded_file(file):
            """Return ``(dataframe, status message, memory-info update)`` for an upload."""
            if file is None:
                return None, "No file uploaded.", gr.update(visible=False)
            # gr.File may hand over a plain path string or an object exposing `.name`.
            path = getattr(file, "name", file)
            try:
                df = pd.read_csv(path)
                success_msg = f"File uploaded successfully. {len(df)} rows × {len(df.columns)} columns"
                mem_info = generator.estimate_memory_usage(df)
                return df, success_msg, gr.update(value=mem_info, visible=True)
            except Exception as e:
                return None, f"Error reading file: {str(e)}", gr.update(visible=False)

        file_upload.change(
            process_uploaded_file, inputs=[file_upload], outputs=[uploaded_data, train_status, memory_info]
        )
    return demo
if __name__ == "__main__":
    demo = create_interface()
    # A public share link is opt-in: on Hugging Face Spaces share=True is not
    # supported anyway, and locally it would expose an app that ingests user
    # data through a public URL. Set GRADIO_SHARE=true to enable it.
    share = os.environ.get("GRADIO_SHARE", "").strip().lower() in {"1", "true", "yes"}
    demo.launch(server_name="0.0.0.0", server_port=7860, share=share)