Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import pandas as pd | |
| import numpy as np | |
| import io | |
| import base64 | |
| from typing import Optional, Tuple | |
| import plotly.express as px | |
| import plotly.graph_objects as go | |
| from plotly.subplots import make_subplots | |
| import warnings | |
# Suppress every Python warning process-wide so the app's console output stays quiet.
warnings.filterwarnings("ignore")

# The MOSTLY AI SDK is an optional import: the UI still loads without it, and the
# MOSTLY_AI_AVAILABLE flag lets later code report the missing package to the user.
try:
    from mostlyai.sdk import MostlyAI
except ImportError:
    MOSTLY_AI_AVAILABLE = False
    print("Warning: Mostly AI SDK not available. Please install with: pip install mostlyai[local]")
else:
    MOSTLY_AI_AVAILABLE = True
class SyntheticDataGenerator:
    """Thin wrapper around the MOSTLY AI SDK used by the Gradio callbacks.

    Holds the SDK client, the most recently trained generator and the data that
    generator was trained on. Public methods report failures through their
    return values (status flag and/or message) instead of raising.
    """

    def __init__(self):
        self.mostly = None  # MostlyAI client; set by initialize_mostly_ai()
        self.generator = None  # object returned by MostlyAI.train(); None until a run succeeds
        self.original_data = None  # DataFrame the current generator was trained on

    def initialize_mostly_ai(self) -> Tuple[bool, str]:
        """Create a local-mode MOSTLY AI client on port 8080.

        Returns:
            ``(success, message)`` describing the outcome.
        """
        if not MOSTLY_AI_AVAILABLE:
            return False, "Mostly AI SDK not installed. Please install with: pip install mostlyai[local]"
        try:
            self.mostly = MostlyAI(local=True, local_port=8080)
        except Exception as e:
            return False, f"Failed to initialize Mostly AI SDK: {str(e)}"
        return True, "Mostly AI SDK initialized successfully!"

    def train_generator(self, data: pd.DataFrame, name: str, epochs: int = 10, max_training_time: int = 60, batch_size: int = 32, value_protection: bool = True) -> Tuple[bool, str]:
        """Train a single-table generator on *data*.

        Args:
            data: training table.
            name: table name placed in the SDK config and echoed in the message.
            epochs: sent as ``max_epochs``.
            max_training_time: sent as ``max_training_time``.
            batch_size: sent as ``batch_size``.
            value_protection: sent as ``value_protection``.

        Returns:
            ``(success, message)``. ``self.generator`` and ``self.original_data``
            are replaced only when training succeeds, so they always stay paired.
        """
        if self.mostly is None:
            return False, "Mostly AI SDK not initialized"
        try:
            # UI sliders may deliver numbers as floats; the numeric knobs are
            # whole-number settings, so coerce them before building the config.
            train_config = {
                'tables': [
                    {
                        'name': name,
                        'data': data,
                        'tabular_model_configuration': {
                            'max_epochs': int(epochs),
                            'max_training_time': int(max_training_time),
                            'value_protection': value_protection,
                            'batch_size': int(batch_size),
                        },
                    }
                ]
            }
            trained = self.mostly.train(config=train_config)
        except Exception as e:
            return False, f"Training failed: {str(e)}"
        # Commit state only after a successful run.
        self.generator = trained
        self.original_data = data
        return True, f"Generator trained successfully! Model: {name}"

    def generate_synthetic_data(self, size: int) -> Tuple[Optional[pd.DataFrame], str]:
        """Sample *size* synthetic records from the trained generator.

        Returns:
            ``(frame, message)``; *frame* is ``None`` when there is no trained
            generator or generation failed.
        """
        if self.generator is None:
            return None, "No trained generator available"
        try:
            synthetic_data = self.mostly.generate(self.generator, size=int(size))
            df = synthetic_data.data()
        except Exception as e:
            return None, f"Generation failed: {str(e)}"
        return df, f"Generated {len(df)} synthetic records successfully!"

    def get_quality_report(self) -> str:
        """Return the trained generator's quality report as text, or an error message."""
        if self.generator is None:
            return "No trained generator available"
        try:
            report = self.generator.reports(display=False)
        except Exception as e:
            return f"Failed to generate report: {str(e)}"
        return str(report)

    def estimate_memory_usage(self, df: pd.DataFrame) -> str:
        """Summarise the in-memory size of *df* as a Markdown snippet.

        Returns:
            Markdown with the data size, a rough training-memory estimate, a
            size status and the table shape, or ``"No data to analyze"`` for
            ``None`` / an empty frame.
        """
        if df is None or df.empty:
            return "No data to analyze"
        memory_mb = df.memory_usage(deep=True).sum() / (1024 * 1024)
        rows, cols = df.shape
        # Rough heuristic: training is assumed to need about 3-5x the raw data size.
        estimated_training_mb = memory_mb * 4
        if memory_mb < 100:
            status = "✅ Good"
        elif memory_mb < 500:
            status = "⚠️ Large"
        else:
            status = "❌ Very Large"
        return "\n".join([
            "**Memory Usage Estimate:**",
            f"- Data size: {memory_mb:.1f} MB",
            f"- Estimated training memory: {estimated_training_mb:.1f} MB",
            f"- Status: {status}",
            f"- Rows: {rows:,} | Columns: {cols}",
        ])
# Single module-level generator used by every callback function below.
# NOTE(review): being a module global, it is shared by all app sessions, not isolated per user.
generator: SyntheticDataGenerator = SyntheticDataGenerator()
def initialize_sdk() -> Tuple[str, str]:
    """Initialise the shared generator's MOSTLY AI client for the UI.

    Returns:
        ``(status, message)`` where *status* is ``"✅ Success"`` or ``"❌ Error"``.
    """
    success, message = generator.initialize_mostly_ai()
    status = "✅ Success" if success else "❌ Error"
    return status, message
def train_model(data: pd.DataFrame, model_name: str, epochs: int, max_training_time: int, batch_size: int, value_protection: bool) -> Tuple[str, str]:
    """Train the shared generator on *data* with the UI-selected settings.

    Returns:
        ``(status, message)`` where *status* is ``"✅ Success"`` or ``"❌ Error"``.
        Missing or empty *data* short-circuits without contacting the SDK.
    """
    if data is None or data.empty:
        return "❌ Error", "Please upload or create sample data first"
    success, message = generator.train_generator(data, model_name, epochs, max_training_time, batch_size, value_protection)
    status = "✅ Success" if success else "❌ Error"
    return status, message
def generate_data(size: int) -> Tuple[Optional[pd.DataFrame], str]:
    """Generate *size* synthetic rows with the shared generator.

    Returns:
        ``(frame, status_message)``; *frame* is ``None`` when no model has been
        trained yet or generation failed.
    """
    if generator.generator is None:
        return None, "❌ Please train a model first"
    synthetic_df, message = generator.generate_synthetic_data(size)
    status = "✅ Success" if synthetic_df is not None else "❌ Error"
    return synthetic_df, f"{status} - {message}"
def get_quality_report() -> str:
    """Fetch the quality-assurance report text from the shared generator."""
    report_text = generator.get_quality_report()
    return report_text
def create_comparison_plot(original_df: pd.DataFrame, synthetic_df: pd.DataFrame) -> Optional["go.Figure"]:
    """Overlay histograms of the numeric columns shared by both datasets.

    Args:
        original_df: the training data.
        synthetic_df: the generated data.

    Returns:
        A plotly figure with one subplot per shared numeric column (at most
        three per row), or ``None`` when either frame is missing or empty, or
        when the frames share no numeric column of the original.
    """
    if original_df is None or synthetic_df is None:
        return None
    if original_df.empty or synthetic_df.empty:
        return None

    # Only compare columns present in both frames; indexing a column the
    # synthetic frame lacks would raise KeyError.
    synthetic_cols = set(synthetic_df.columns)
    numeric_cols = [
        c for c in original_df.select_dtypes(include=[np.number]).columns
        if c in synthetic_cols
    ]
    if not numeric_cols:
        return None

    n_cols = min(3, len(numeric_cols))
    n_rows = -(-len(numeric_cols) // n_cols)  # ceiling division
    fig = make_subplots(
        rows=n_rows,
        cols=n_cols,
        subplot_titles=[str(c) for c in numeric_cols],
    )

    # Fixed colours plus legend groups keep each dataset in one colour with a
    # single legend entry across all subplots.
    series = (
        ("Original", original_df, "#636EFA"),
        ("Synthetic", synthetic_df, "#EF553B"),
    )
    for i, col in enumerate(numeric_cols):
        row, col_idx = divmod(i, n_cols)
        for label, frame, color in series:
            fig.add_trace(
                go.Histogram(
                    x=frame[col],
                    name=label,
                    legendgroup=label,
                    showlegend=(i == 0),
                    marker_color=color,
                    opacity=0.7,
                    nbinsx=20,
                ),
                row=row + 1,
                col=col_idx + 1,
            )

    fig.update_layout(
        title="Original vs Synthetic Data Comparison",
        height=300 * n_rows,
        showlegend=True,
        barmode="overlay",  # draw both histograms on top of each other so opacity matters
    )
    return fig
def download_csv(df: pd.DataFrame) -> Optional[str]:
    """Write *df* to a temporary CSV file and return its path.

    The result feeds a ``gr.DownloadButton``, whose value is a file path rather
    than file content, so the frame is written to disk first.

    Returns:
        Path to ``synthetic_data.csv`` inside a fresh temporary directory, or
        ``None`` when *df* is missing or empty.
    """
    import os
    import tempfile

    if df is None or df.empty:
        return None
    # A fresh directory per call keeps the user-facing filename stable without
    # overwriting a file an earlier download may still point at.
    out_path = os.path.join(tempfile.mkdtemp(prefix="synthetic_"), "synthetic_data.csv")
    df.to_csv(out_path, index=False)
    return out_path
# Create the Gradio interface
def create_interface():
    """Build the Gradio Blocks UI and wire its event handlers.

    Tabs: "Quick Start" (SDK initialisation), upload + training, and
    generation (table, CSV download, comparison plot). Callbacks operate on the
    module-level ``generator``.

    Returns:
        The ``gr.Blocks`` app; the caller launches it.
    """
    with gr.Blocks(title="MOSTLY AI Synthetic Data Generator", theme=gr.themes.Soft()) as demo:
        gr.Markdown(
            "# MOSTLY AI Synthetic Data Generator\n"
            "Generate high-quality synthetic data using the Mostly AI SDK. "
            "Upload your own CSV files to generate synthetic data that preserves "
            "the statistical properties of your original dataset."
        )

        # Holds the uploaded DataFrame server-side so training and plotting use the
        # exact frame that was read, not a copy round-tripped through the display table.
        uploaded_df = gr.State(None)

        with gr.Tab("Quick Start"):
            gr.Markdown("### Initialize the SDK and upload your data")
            with gr.Row():
                with gr.Column():
                    init_btn = gr.Button("Initialize Mostly AI SDK", variant="primary")
                    init_status = gr.Textbox(label="Initialization Status", interactive=False)
                with gr.Column():
                    gr.Markdown(
                        "**Next Steps:**\n"
                        "1. Initialize the SDK (click the button on the left)\n"
                        "2. Go to \"Upload Data and Train Model\" tab to upload your CSV file\n"
                        "3. Train a model on your data\n"
                        "4. Generate synthetic data"
                    )

        with gr.Tab("Upload Data and Train Model"):
            gr.Markdown("### Upload your CSV file to generate synthetic data")
            gr.Markdown(
                "**File Requirements:**\n"
                "- **Format:** CSV with header row\n"
                "- **Size:** Optimized for Hugging Face Spaces (2 vCPU, 16GB RAM)"
            )
            file_upload = gr.File(
                label="Upload CSV File",
                file_types=[".csv"],
                file_count="single"
            )
            uploaded_data = gr.Dataframe(label="Uploaded Data", interactive=False)
            memory_info = gr.Markdown(label="Memory Usage Info", visible=False)
            with gr.Row():
                with gr.Column():
                    model_name = gr.Textbox(
                        value="My Synthetic Model",
                        label="Model Name",
                        placeholder="Enter a name for your model"
                    )
                    epochs = gr.Slider(1, 200, value=100, step=1, label="Training Epochs")
                    max_training_time = gr.Slider(1, 1000, value=60, step=1, label="Maximum Training Time")
                    batch_size = gr.Slider(8, 1024, value=32, step=8, label="Training Batch Size")
                    value_protection = gr.Checkbox(label="Value Protection", info="Enable Value Protection")
                    train_btn = gr.Button("Train Model", variant="primary")
                with gr.Column():
                    train_status = gr.Textbox(label="Training Status", interactive=False)
                    quality_report = gr.Textbox(label="Quality Report", lines=10, interactive=False)
                    get_report_btn = gr.Button("Get Quality Report", variant="secondary")

        with gr.Tab("🎲 Generate Data"):
            gr.Markdown("### Generate synthetic data from your trained model")
            with gr.Row():
                with gr.Column():
                    gen_size = gr.Slider(10, 1000, value=100, step=10, label="Number of Records to Generate")
                    generate_btn = gr.Button("Generate Synthetic Data", variant="primary")
                with gr.Column():
                    gen_status = gr.Textbox(label="Generation Status", interactive=False)
            synthetic_data = gr.Dataframe(label="Synthetic Data", interactive=False)
            with gr.Row():
                download_btn = gr.DownloadButton("Download CSV", variant="secondary")
            comparison_plot = gr.Plot(label="Data Comparison")

        # Each status box receives one combined "status - message" string; wiring
        # the same component twice as an output would let one value overwrite the other.
        def _init_and_report():
            status, message = initialize_sdk()
            return f"{status} - {message}"

        def _train_and_report(data, name, n_epochs, max_time, batch, protect):
            status, message = train_model(data, name, n_epochs, max_time, batch, protect)
            return f"{status} - {message}"

        def process_uploaded_file(file):
            """Read the uploaded CSV.

            Returns values for (uploaded_data, uploaded_df, train_status, memory_info).
            """
            if file is None:
                return None, None, "No file uploaded", gr.update(visible=False)
            # Newer Gradio versions pass a path string; older ones pass a
            # tempfile-like object exposing ``.name``.
            path = file if isinstance(file, str) else file.name
            try:
                df = pd.read_csv(path)
            except Exception as e:
                return None, None, f"❌ Error reading file: {str(e)}", gr.update(visible=False)
            success_msg = f"✅ File uploaded successfully! {len(df)} rows × {len(df.columns)} columns"
            memory_text = generator.estimate_memory_usage(df)
            return df, df, success_msg, gr.update(value=memory_text, visible=True)

        # Event handlers
        init_btn.click(
            _init_and_report,
            outputs=[init_status]
        )
        train_btn.click(
            _train_and_report,
            inputs=[uploaded_df, model_name, epochs, max_training_time, batch_size, value_protection],
            outputs=[train_status]
        )
        get_report_btn.click(
            get_quality_report,
            outputs=[quality_report]
        )
        generate_btn.click(
            generate_data,
            inputs=[gen_size],
            outputs=[synthetic_data, gen_status]
        )
        # Refresh the download file whenever the synthetic table changes.
        synthetic_data.change(
            download_csv,
            inputs=[synthetic_data],
            outputs=[download_btn]
        )
        # Redraw the comparison plot whenever the synthetic table changes.
        synthetic_data.change(
            create_comparison_plot,
            inputs=[uploaded_df, synthetic_data],
            outputs=[comparison_plot]
        )
        file_upload.change(
            process_uploaded_file,
            inputs=[file_upload],
            outputs=[uploaded_data, uploaded_df, train_status, memory_info]
        )
    return demo
if __name__ == "__main__":
    # Serve on all network interfaces at port 7860 and also request a public share link.
    launch_options = {"server_name": "0.0.0.0", "server_port": 7860, "share": True}
    demo = create_interface()
    demo.launch(**launch_options)