#!/usr/bin/env python3
"""
Gemini TTS API Server

A FastAPI-based REST API for Google's Gemini Text-to-Speech service
with concurrent request handling and audio format conversion.
"""
import os
import asyncio
import json
import base64
import uuid
from datetime import datetime
from typing import Optional, List, Dict, Any
from io import BytesIO

import aiohttp
import aiofiles
from fastapi import FastAPI, HTTPException, BackgroundTasks, Request
from fastapi.responses import FileResponse, JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from pydub import AudioSegment
import uvicorn
# Pydantic models for request/response
class VoiceConfig(BaseModel):
    voice_name: str = Field(default="Zephyr", description="Voice name (e.g., Zephyr, Puck)")

class SpeakerConfig(BaseModel):
    speaker: str = Field(description="Speaker identifier")
    voice_config: VoiceConfig

class TTSRequest(BaseModel):
    text: str = Field(description="Text to convert to speech")
    speakers: Optional[List[SpeakerConfig]] = Field(
        default=None,
        description="Multi-speaker configuration (optional)"
    )
    voice_name: Optional[str] = Field(
        default="Zephyr",
        description="Single voice name (used if speakers not provided)"
    )
    output_format: str = Field(default="wav", description="Output format: wav or mp3")
    speed_factor: float = Field(default=1.0, description="Speed adjustment factor")
    temperature: float = Field(default=1.0, description="Generation temperature")
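
# Illustrative request bodies for the /tts endpoint defined below (text and
# voice values are placeholders):
#   single voice:
#     {"text": "Hello world", "voice_name": "Zephyr",
#      "output_format": "mp3", "speed_factor": 1.1}
#   multi-speaker:
#     {"text": "Alice: Hi!\nBob: Hello!",
#      "speakers": [
#        {"speaker": "Alice", "voice_config": {"voice_name": "Zephyr"}},
#        {"speaker": "Bob", "voice_config": {"voice_name": "Puck"}}]}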
class TTSResponse(BaseModel):
    task_id: str
    status: str
    message: str
    audio_url: Optional[str] = None
    metadata: Optional[Dict[str, Any]] = None

class TaskStatus(BaseModel):
    task_id: str
    status: str
    progress: Optional[str] = None
    error: Optional[str] = None
    result: Optional[Dict[str, Any]] = None

# Global task storage (in production, use Redis or a database)
tasks: Dict[str, Dict[str, Any]] = {}
# FastAPI app initialization
app = FastAPI(
    title="Gemini TTS API",
    description="Text-to-Speech API using Google's Gemini model with concurrent request handling",
    version="1.0.0"
)

# CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Configure appropriately for production
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Configuration
def get_api_keys():
    """Get API keys from environment variables."""
    # Support multiple formats for API keys
    api_keys = []

    # Single API key (backward compatibility)
    single_key = os.getenv('GEMINI_API_KEY')
    if single_key:
        api_keys.append(single_key.strip())

    # Multiple API keys (comma-separated)
    multi_keys = os.getenv('GEMINI_API_KEYS')
    if multi_keys:
        keys = [key.strip() for key in multi_keys.split(',') if key.strip()]
        api_keys.extend(keys)

    # Individual API keys (GEMINI_API_KEY_1, GEMINI_API_KEY_2, etc.)
    i = 1
    while True:
        key = os.getenv(f'GEMINI_API_KEY_{i}')
        if not key:
            break
        api_keys.append(key.strip())
        i += 1

    # Remove duplicates while preserving order
    seen = set()
    unique_keys = []
    for key in api_keys:
        if key not in seen:
            seen.add(key)
            unique_keys.append(key)
    return unique_keys
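
# Example environment configuration (placeholder keys; any of the three forms
# read above is accepted, and duplicates are dropped):
#   export GEMINI_API_KEY="AIza...single"
#   export GEMINI_API_KEYS="AIza...one,AIza...two"
#   export GEMINI_API_KEY_1="AIza...one"
#   export GEMINI_API_KEY_2="AIza...two"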
GEMINI_API_KEYS = get_api_keys()
MODEL_ID = "gemini-2.5-flash-preview-tts"
GENERATE_CONTENT_API = "streamGenerateContent"
OUTPUT_DIR = "/tmp/audio_files"
MAX_CONCURRENT_REQUESTS = 10
RATE_LIMIT_RETRY_DELAY = 60  # seconds to wait after a rate limit
MAX_RETRIES_PER_KEY = 2
# API key management
class APIKeyManager:
    def __init__(self, api_keys: List[str]):
        self.api_keys = api_keys
        self.current_key_index = 0
        self.key_stats = {key: {"requests": 0, "failures": 0, "last_rate_limit": None} for key in api_keys}
        self.lock = asyncio.Lock()

    async def get_next_key(self) -> Optional[str]:
        """Get the next available API key."""
        async with self.lock:
            if not self.api_keys:
                return None

            # Try to find a key that is not rate limited
            for _ in range(len(self.api_keys)):
                key = self.api_keys[self.current_key_index]
                stats = self.key_stats[key]

                # Check if this key is currently rate limited
                if stats["last_rate_limit"]:
                    time_since_limit = datetime.now().timestamp() - stats["last_rate_limit"]
                    if time_since_limit < RATE_LIMIT_RETRY_DELAY:
                        # Still rate limited, try the next key
                        self.current_key_index = (self.current_key_index + 1) % len(self.api_keys)
                        continue
                    else:
                        # Rate limit period has passed, reset
                        stats["last_rate_limit"] = None

                # This key is available
                stats["requests"] += 1
                return key

            # All keys are rate limited; return the one with the oldest rate limit
            oldest_key = min(
                self.api_keys,
                key=lambda k: self.key_stats[k]["last_rate_limit"] or 0
            )
            return oldest_key

    async def mark_rate_limited(self, api_key: str):
        """Mark an API key as rate limited."""
        async with self.lock:
            if api_key in self.key_stats:
                self.key_stats[api_key]["last_rate_limit"] = datetime.now().timestamp()
                self.key_stats[api_key]["failures"] += 1

    async def mark_success(self, api_key: str):
        """Mark an API key as successful (decrement its failure count)."""
        async with self.lock:
            if api_key in self.key_stats:
                self.key_stats[api_key]["failures"] = max(0, self.key_stats[api_key]["failures"] - 1)

    def get_stats(self) -> dict:
        """Get statistics for all API keys."""
        return {
            "total_keys": len(self.api_keys),
            "key_stats": self.key_stats.copy()
        }
# Initialize API key manager
api_key_manager = APIKeyManager(GEMINI_API_KEYS)

# Ensure output directory exists
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Semaphore to limit concurrent requests
semaphore = asyncio.Semaphore(MAX_CONCURRENT_REQUESTS)
async def convert_and_adjust_audio(
    audio_data: bytes,
    output_format: str = "wav",
    speed_factor: float = 1.0
) -> tuple[bytes, str]:
    """
    Convert raw PCM audio data to the requested format and adjust playback speed.

    The pydub processing runs in a thread pool so it does not block the event loop.
    """
    def _convert():
        # Create an AudioSegment from the raw PCM data returned by Gemini
        audio = AudioSegment(
            data=audio_data,
            sample_width=2,    # 16-bit = 2 bytes
            frame_rate=24000,  # 24 kHz
            channels=1         # mono
        )

        # Adjust speed by changing the frame rate, then resampling back
        if speed_factor != 1.0:
            new_frame_rate = int(audio.frame_rate * speed_factor)
            audio_speed_adjusted = audio._spawn(
                audio.raw_data,
                overrides={"frame_rate": new_frame_rate}
            )
            audio_speed_adjusted = audio_speed_adjusted.set_frame_rate(audio.frame_rate)
        else:
            audio_speed_adjusted = audio

        # Export to the desired format
        buffer = BytesIO()
        if output_format.lower() == "mp3":
            audio_speed_adjusted.export(buffer, format="mp3", bitrate="128k")
            return buffer.getvalue(), "mp3"
        else:
            audio_speed_adjusted.export(buffer, format="wav")
            return buffer.getvalue(), "wav"

    # Run audio processing in a thread pool to avoid blocking
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, _convert)
async def generate_tts_audio(
    task_id: str,
    text: str,
    speakers: Optional[List[SpeakerConfig]] = None,
    voice_name: str = "Zephyr",
    output_format: str = "wav",
    speed_factor: float = 1.0,
    temperature: float = 1.0
):
    """
    Generate TTS audio with the Gemini API, rotating across API keys and
    handling rate limits.
    """
    async with semaphore:  # Limit concurrent requests
        try:
            # Update task status
            tasks[task_id]["status"] = "processing"
            tasks[task_id]["progress"] = "Preparing request"

            # Prepare request data
            if speakers:
                # Multi-speaker configuration
                speech_config = {
                    "multi_speaker_voice_config": {
                        "speaker_voice_configs": [
                            {
                                "speaker": speaker.speaker,
                                "voice_config": {
                                    "prebuilt_voice_config": {
                                        "voice_name": speaker.voice_config.voice_name
                                    }
                                }
                            }
                            for speaker in speakers
                        ]
                    }
                }
            else:
                # Single voice configuration
                speech_config = {
                    "voice_config": {
                        "prebuilt_voice_config": {
                            "voice_name": voice_name
                        }
                    }
                }

            request_data = {
                "contents": [
                    {
                        "role": "user",
                        "parts": [{"text": text}]
                    }
                ],
                "generationConfig": {
                    "responseModalities": ["audio"],
                    "temperature": temperature,
                    "speech_config": speech_config
                }
            }

            # API endpoint
            url = f"https://generativelanguage.googleapis.com/v1beta/models/{MODEL_ID}:{GENERATE_CONTENT_API}"
            tasks[task_id]["progress"] = "Calling Gemini API"

            # Try multiple API keys with rate limit handling
            last_error = None
            attempts = 0
            max_total_attempts = len(GEMINI_API_KEYS) * MAX_RETRIES_PER_KEY if GEMINI_API_KEYS else 1

            while attempts < max_total_attempts:
                current_api_key = await api_key_manager.get_next_key()
                if not current_api_key:
                    raise HTTPException(status_code=500, detail="No API keys available")

                attempts += 1
                tasks[task_id]["progress"] = f"Attempting API call (attempt {attempts}/{max_total_attempts})"

                try:
                    # Make the async API request
                    async with aiohttp.ClientSession() as session:
                        async with session.post(
                            url,
                            headers={"Content-Type": "application/json"},
                            params={"key": current_api_key},
                            json=request_data,
                            timeout=aiohttp.ClientTimeout(total=120)  # 2 minute timeout
                        ) as response:
                            # Handle the different HTTP status codes
                            if response.status == 200:
                                # Success: mark the key as healthy and stop retrying
                                await api_key_manager.mark_success(current_api_key)
                                response_data = await response.json()
                                break
                            elif response.status == 429:  # Rate limit exceeded
                                error_text = await response.text()
                                await api_key_manager.mark_rate_limited(current_api_key)
                                last_error = f"Rate limit exceeded for API key: {error_text}"
                                print(f"Rate limit hit for key ending in ...{current_api_key[-4:]}, trying next key")
                                continue
                            elif response.status in [403, 401]:  # Auth errors
                                error_text = await response.text()
                                await api_key_manager.mark_rate_limited(current_api_key)  # Temporarily disable this key
                                last_error = f"Authentication error: {error_text}"
                                print(f"Auth error for key ending in ...{current_api_key[-4:]}: {error_text}")
                                continue
                            else:  # Other HTTP errors
                                error_text = await response.text()
                                last_error = f"HTTP {response.status}: {error_text}"
                                # Don't mark as rate limited for other errors, but still try the next key
                                continue
                except asyncio.TimeoutError:
                    last_error = "Request timeout"
                    continue
                except aiohttp.ClientError as e:
                    last_error = f"Client error: {str(e)}"
                    continue
                except Exception as e:
                    last_error = f"Unexpected error: {str(e)}"
                    continue
            else:
                # The while loop ended without a successful break: all attempts failed
                raise HTTPException(
                    status_code=500,
                    detail=f"All API keys exhausted. Last error: {last_error}"
                )

            tasks[task_id]["progress"] = "Processing audio data"

            # Extract audio data
            if response_data and len(response_data) > 0:
                candidates = response_data[0].get("candidates", [])
                if not candidates:
                    raise HTTPException(status_code=500, detail="No candidates in response")

                parts = candidates[0].get("content", {}).get("parts", [])
                audio_data_b64 = None
                for part in parts:
                    if "inlineData" in part:
                        audio_data_b64 = part["inlineData"].get("data", "")
                        break

                if not audio_data_b64:
                    raise HTTPException(status_code=500, detail="No audio data found in response")

                # Decode the base64 audio data
                audio_data = base64.b64decode(audio_data_b64)
                tasks[task_id]["progress"] = "Converting audio format"

                # Convert and adjust the audio
                converted_audio, file_ext = await convert_and_adjust_audio(
                    audio_data, output_format, speed_factor
                )

                # Generate a filename
                timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
                filename = f"gemini_audio_{task_id}_{timestamp}.{file_ext}"
                filepath = os.path.join(OUTPUT_DIR, filename)

                # Save the audio file
                async with aiofiles.open(filepath, "wb") as f:
                    await f.write(converted_audio)

                # Update the task with the results
                tasks[task_id].update({
                    "status": "completed",
                    "progress": "Completed",
                    "result": {
                        "filename": filename,
                        "filepath": filepath,
                        "format": output_format.upper(),
                        "speed_factor": speed_factor,
                        "original_size": len(audio_data),
                        "converted_size": len(converted_audio),
                        "audio_url": f"/audio/{filename}"
                    }
                })
            else:
                # Without this branch the task would stay in "processing" forever
                raise HTTPException(status_code=500, detail="Empty response from Gemini API")
        except Exception as e:
            tasks[task_id].update({
                "status": "failed",
                "error": str(e)
            })
# API Endpoints
@app.get("/")
async def root():
    """Root endpoint with API information."""
    return {
        "message": "Gemini TTS API Server",
        "version": "1.0.0",
        "endpoints": {
            "POST /tts": "Generate TTS audio",
            "GET /status/{task_id}": "Get task status",
            "GET /audio/{filename}": "Download audio file",
            "GET /tasks": "List all tasks"
        }
    }
@app.post("/tts")
async def create_tts_task(
    request: TTSRequest,
    background_tasks: BackgroundTasks
):
    """
    Create a new TTS generation task.
    """
    if not GEMINI_API_KEYS:
        raise HTTPException(
            status_code=500,
            detail="No Gemini API keys configured. Set GEMINI_API_KEY, GEMINI_API_KEYS, or GEMINI_API_KEY_1, GEMINI_API_KEY_2, etc."
        )

    # Generate a unique task ID
    task_id = str(uuid.uuid4())

    # Initialize the task
    tasks[task_id] = {
        "task_id": task_id,
        "status": "queued",
        "created_at": datetime.now().isoformat(),
        "request": request.dict()
    }

    # Start the background task
    background_tasks.add_task(
        generate_tts_audio,
        task_id,
        request.text,
        request.speakers,
        request.voice_name,
        request.output_format,
        request.speed_factor,
        request.temperature
    )

    return TTSResponse(
        task_id=task_id,
        status="queued",
        message="TTS generation task created successfully"
    )
@app.get("/status/{task_id}")
async def get_task_status(task_id: str):
    """
    Get the status of a TTS generation task.
    """
    if task_id not in tasks:
        raise HTTPException(status_code=404, detail="Task not found")

    task = tasks[task_id]
    return TaskStatus(
        task_id=task_id,
        status=task["status"],
        progress=task.get("progress"),
        error=task.get("error"),
        result=task.get("result")
    )
@app.get("/audio/{filename}")
async def download_audio(filename: str):
    """
    Download a generated audio file.
    """
    safe_name = os.path.basename(filename)  # guard against path traversal
    filepath = os.path.join(OUTPUT_DIR, safe_name)
    if not os.path.exists(filepath):
        raise HTTPException(status_code=404, detail="Audio file not found")

    return FileResponse(
        filepath,
        media_type="application/octet-stream",
        filename=safe_name
    )
@app.get("/tasks")
async def list_tasks():
    """
    List all tasks with their current status.
    """
    return {"tasks": list(tasks.values())}
@app.delete("/tasks/{task_id}")  # route path assumed; not listed in the root endpoint index
async def delete_task(task_id: str):
    """
    Delete a task and its associated audio file.
    """
    if task_id not in tasks:
        raise HTTPException(status_code=404, detail="Task not found")

    task = tasks[task_id]

    # Delete the audio file if it exists
    if task.get("result") and task["result"].get("filepath"):
        filepath = task["result"]["filepath"]
        if os.path.exists(filepath):
            os.remove(filepath)

    # Remove the task from memory
    del tasks[task_id]
    return {"message": "Task deleted successfully"}
@app.get("/health")  # route path assumed
async def health_check():
    """
    Health check endpoint with API key status.
    """
    api_stats = api_key_manager.get_stats()
    return {
        "status": "healthy",
        "timestamp": datetime.now().isoformat(),
        "active_tasks": len([t for t in tasks.values() if t["status"] in ["queued", "processing"]]),
        "total_tasks": len(tasks),
        "api_keys": {
            "total_configured": api_stats["total_keys"],
            "available_keys": len([
                key for key, stats in api_stats["key_stats"].items()
                if not stats["last_rate_limit"] or
                (datetime.now().timestamp() - stats["last_rate_limit"]) > RATE_LIMIT_RETRY_DELAY
            ]),
            "rate_limited_keys": len([
                key for key, stats in api_stats["key_stats"].items()
                if stats["last_rate_limit"] and
                (datetime.now().timestamp() - stats["last_rate_limit"]) <= RATE_LIMIT_RETRY_DELAY
            ])
        }
    }
@app.get("/api-keys/stats")  # route path assumed
async def get_api_key_stats():
    """
    Get detailed statistics for all API keys.
    """
    stats = api_key_manager.get_stats()

    # Mask API keys for security (show only the last 4 characters)
    masked_stats = {}
    for key, data in stats["key_stats"].items():
        masked_key = f"***{key[-4:]}" if len(key) > 4 else "***"
        masked_stats[masked_key] = {
            **data,
            "is_rate_limited": (
                data["last_rate_limit"] and
                (datetime.now().timestamp() - data["last_rate_limit"]) <= RATE_LIMIT_RETRY_DELAY
            ) if data["last_rate_limit"] else False,
            "time_until_available": max(0, RATE_LIMIT_RETRY_DELAY - (
                datetime.now().timestamp() - data["last_rate_limit"]
            )) if data["last_rate_limit"] else 0
        }

    return {
        "total_keys": stats["total_keys"],
        "key_statistics": masked_stats,
        "rate_limit_settings": {
            "retry_delay_seconds": RATE_LIMIT_RETRY_DELAY,
            "max_retries_per_key": MAX_RETRIES_PER_KEY
        }
    }
# Cleanup old files periodically (you might want to implement this with a proper scheduler)
async def cleanup_old_files():
    """
    Clean up old audio files and completed tasks.
    """
    # Intentionally left as a stub; consider using APScheduler in production
    pass
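
# A minimal sketch of the cleanup described above, assuming a one-hour file TTL;
# the TTL and the task-pruning rule are illustrative choices, and in production
# this would be scheduled periodically (e.g. with APScheduler or a background
# asyncio task) rather than called ad hoc.
async def cleanup_old_files_sketch(max_age_seconds: int = 3600):
    now = datetime.now().timestamp()

    # Remove audio files older than the TTL
    for name in os.listdir(OUTPUT_DIR):
        path = os.path.join(OUTPUT_DIR, name)
        if os.path.isfile(path) and now - os.path.getmtime(path) > max_age_seconds:
            os.remove(path)

    # Drop completed tasks whose audio file is gone
    for task_id, task in list(tasks.items()):
        filepath = (task.get("result") or {}).get("filepath")
        if task["status"] == "completed" and filepath and not os.path.exists(filepath):
            del tasks[task_id]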
if __name__ == "__main__":
    # Configuration for running the server
    uvicorn.run(
        "gemini_tts_api:app",
        host="0.0.0.0",
        port=8000,
        reload=False,  # Keep False in production
        workers=1      # Task state is in-memory; use multiple workers only with shared task storage (e.g. Redis)
    )