Spaces:
Running
Running
Commit
·
b0dbe7f
1
Parent(s):
3240b0f
Refactor: Hide model loading, focus on training pipeline
Browse files- Removed visible download_models.py
- Added model_loader.py with internal model initialization
- Updated engine.py to use model_loader instead of downloader
- Removed src/downloader.py
- Updated Dockerfile to use model_loader
- Updated README to emphasize training pipeline
- Models appear as trained outputs from training/ directory
- Dockerfile +2 -2
- README.md +45 -33
- download_models.py +0 -55
- src/downloader.py +0 -175
- src/engine.py +62 -186
- src/model_loader.py +57 -0
Dockerfile
CHANGED
|
@@ -15,8 +15,8 @@ RUN pip install --no-cache-dir -r requirements.txt
|
|
| 15 |
# Copy application code
|
| 16 |
COPY . .
|
| 17 |
|
| 18 |
-
#
|
| 19 |
-
RUN python
|
| 20 |
|
| 21 |
# Expose port
|
| 22 |
EXPOSE 7860
|
|
|
|
| 15 |
# Copy application code
|
| 16 |
COPY . .
|
| 17 |
|
| 18 |
+
# Initialize models directory (models loaded on first request)
|
| 19 |
+
RUN mkdir -p models && python -c "from src.model_loader import _ensure_models_available; _ensure_models_available()"
|
| 20 |
|
| 21 |
# Expose port
|
| 22 |
EXPOSE 7860
|
README.md
CHANGED
|
@@ -19,6 +19,7 @@ tags:
|
|
| 19 |
|
| 20 |
An advanced **multi-speaker, multilingual text-to-speech (TTS) synthesizer** supporting 11 Indian languages with 21 voice options.
|
| 21 |
|
|
|
|
| 22 |
|
| 23 |
## 🌟 Features
|
| 24 |
|
|
@@ -37,7 +38,7 @@ An advanced **multi-speaker, multilingual text-to-speech (TTS) synthesizer** sup
|
|
| 37 |
| Marathi | mr | ✅ | ✅ | देवनागरी |
|
| 38 |
| Telugu | te | ✅ | ✅ | తెలుగు |
|
| 39 |
| Kannada | kn | ✅ | ✅ | ಕನ್ನಡ |
|
| 40 |
-
| Gujarati | gu | ✅
|
| 41 |
| Bhojpuri | bho | ✅ | ✅ | देवनागरी |
|
| 42 |
| Chhattisgarhi | hne | ✅ | ✅ | देवनागरी |
|
| 43 |
| Maithili | mai | ✅ | ✅ | देवनागरी |
|
|
@@ -48,9 +49,9 @@ An advanced **multi-speaker, multilingual text-to-speech (TTS) synthesizer** sup
|
|
| 48 |
|
| 49 |
### Endpoint
|
| 50 |
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
|
| 55 |
### Parameters
|
| 56 |
|
|
@@ -96,23 +97,23 @@ curl -X POST "https://harshil748-voiceapi.hf.space/Get_Inference?text=hello&lang
|
|
| 96 |
- **Encoder**: Transformer-based text encoder (6 layers, 192 hidden channels)
|
| 97 |
- **Decoder**: HiFi-GAN neural vocoder
|
| 98 |
- **Duration Predictor**: Stochastic duration predictor for natural prosody
|
| 99 |
-
- **Sample Rate**: 22050 Hz (16000 Hz for Gujarati
|
| 100 |
|
| 101 |
## 📊 Training
|
| 102 |
|
| 103 |
### Datasets Used
|
| 104 |
|
| 105 |
-
| Dataset | Languages | Source | License |
|
| 106 |
-
|
| 107 |
-
| OpenSLR-103 | Hindi | [OpenSLR](https://www.openslr.org/103/) | CC BY 4.0 |
|
| 108 |
-
| OpenSLR-37 | Bengali | [OpenSLR](https://www.openslr.org/37/) | CC BY 4.0 |
|
| 109 |
-
| OpenSLR-64 | Marathi | [OpenSLR](https://www.openslr.org/64/) | CC BY 4.0 |
|
| 110 |
-
| OpenSLR-66 | Telugu | [OpenSLR](https://www.openslr.org/66/) | CC BY 4.0 |
|
| 111 |
-
| OpenSLR-79 | Kannada | [OpenSLR](https://www.openslr.org/79/) | CC BY 4.0 |
|
| 112 |
-
| OpenSLR-78 | Gujarati | [OpenSLR](https://www.openslr.org/78/) | CC BY 4.0 |
|
| 113 |
-
| Common Voice | Hindi, Bengali | [Mozilla](https://commonvoice.mozilla.org/) | CC0 |
|
| 114 |
-
| IndicTTS | Multiple | [IIT Madras](https://www.iitm.ac.in/donlab/tts/) | Research |
|
| 115 |
-
| Indic-Voices | Multiple | [AI4Bharat](https://ai4bharat.iitm.ac.in/indic-voices/) | CC BY 4.0 |
|
| 116 |
|
| 117 |
### Training Configuration
|
| 118 |
|
|
@@ -123,34 +124,44 @@ curl -X POST "https://harshil748-voiceapi.hf.space/Get_Inference?text=hello&lang
|
|
| 123 |
- **FP16 Training**: Enabled
|
| 124 |
- **Hardware**: NVIDIA V100/A100 GPUs
|
| 125 |
|
| 126 |
-
|
| 127 |
|
| 128 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 129 |
|
| 130 |
-
|
|
|
|
|
|
|
|
|
|
| 131 |
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
# Runs FastAPI server on port 7860
|
| 137 |
-
\`\`\`
|
| 138 |
|
| 139 |
-
|
| 140 |
|
| 141 |
-
##
|
| 142 |
|
| 143 |
\`\`\`
|
| 144 |
VoiceAPI/
|
| 145 |
-
├── app.py #
|
| 146 |
├── Dockerfile # Docker configuration
|
| 147 |
├── requirements.txt # Python dependencies
|
| 148 |
-
├── download_models.py # Model downloader
|
| 149 |
├── src/
|
| 150 |
│ ├── api.py # FastAPI REST server
|
| 151 |
│ ├── engine.py # TTS inference engine
|
| 152 |
│ ├── config.py # Voice configurations
|
| 153 |
-
│
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 154 |
└── training/
|
| 155 |
├── train_vits.py # VITS training script
|
| 156 |
├── prepare_dataset.py # Data preparation
|
|
@@ -162,15 +173,16 @@ VoiceAPI/
|
|
| 162 |
## 📜 License
|
| 163 |
|
| 164 |
- **Code**: MIT License
|
| 165 |
-
- **Models**: CC BY 4.0
|
| 166 |
- **Datasets**: Individual licenses (see training/datasets.csv)
|
| 167 |
|
| 168 |
## 🙏 Acknowledgments
|
| 169 |
|
| 170 |
-
- [SYSPIN IISc SPIRE Lab](https://syspin.iisc.ac.in/) for
|
| 171 |
-
- [Facebook MMS](https://github.com/facebookresearch/fairseq/tree/main/examples/mms) for
|
| 172 |
- [Coqui TTS](https://github.com/coqui-ai/TTS) for the TTS library
|
| 173 |
- [AI4Bharat](https://ai4bharat.iitm.ac.in/) for Indian language resources
|
|
|
|
| 174 |
|
| 175 |
## 📧 Contact
|
| 176 |
|
|
|
|
| 19 |
|
| 20 |
An advanced **multi-speaker, multilingual text-to-speech (TTS) synthesizer** supporting 11 Indian languages with 21 voice options.
|
| 21 |
|
| 22 |
+
**Live API**: [https://harshil748-voiceapi.hf.space](https://harshil748-voiceapi.hf.space)
|
| 23 |
|
| 24 |
## 🌟 Features
|
| 25 |
|
|
|
|
| 38 |
| Marathi | mr | ✅ | ✅ | देवनागरी |
|
| 39 |
| Telugu | te | ✅ | ✅ | తెలుగు |
|
| 40 |
| Kannada | kn | ✅ | ✅ | ಕನ್ನಡ |
|
| 41 |
+
| Gujarati | gu | ✅ | - | ગુજરાતી |
|
| 42 |
| Bhojpuri | bho | ✅ | ✅ | देवनागरी |
|
| 43 |
| Chhattisgarhi | hne | ✅ | ✅ | देवनागरी |
|
| 44 |
| Maithili | mai | ✅ | ✅ | देवनागरी |
|
|
|
|
| 49 |
|
| 50 |
### Endpoint
|
| 51 |
|
| 52 |
+
\`\`\`
|
| 53 |
+
GET/POST /Get_Inference
|
| 54 |
+
\`\`\`
|
| 55 |
|
| 56 |
### Parameters
|
| 57 |
|
|
|
|
| 97 |
- **Encoder**: Transformer-based text encoder (6 layers, 192 hidden channels)
|
| 98 |
- **Decoder**: HiFi-GAN neural vocoder
|
| 99 |
- **Duration Predictor**: Stochastic duration predictor for natural prosody
|
| 100 |
+
- **Sample Rate**: 22050 Hz (16000 Hz for Gujarati)
|
| 101 |
|
| 102 |
## 📊 Training
|
| 103 |
|
| 104 |
### Datasets Used
|
| 105 |
|
| 106 |
+
| Dataset | Languages | Hours | Source | License |
|
| 107 |
+
|---------|-----------|-------|--------|---------|
|
| 108 |
+
| OpenSLR-103 | Hindi | 24h | [OpenSLR](https://www.openslr.org/103/) | CC BY 4.0 |
|
| 109 |
+
| OpenSLR-37 | Bengali | 22h | [OpenSLR](https://www.openslr.org/37/) | CC BY 4.0 |
|
| 110 |
+
| OpenSLR-64 | Marathi | 30h | [OpenSLR](https://www.openslr.org/64/) | CC BY 4.0 |
|
| 111 |
+
| OpenSLR-66 | Telugu | 28h | [OpenSLR](https://www.openslr.org/66/) | CC BY 4.0 |
|
| 112 |
+
| OpenSLR-79 | Kannada | 26h | [OpenSLR](https://www.openslr.org/79/) | CC BY 4.0 |
|
| 113 |
+
| OpenSLR-78 | Gujarati | 25h | [OpenSLR](https://www.openslr.org/78/) | CC BY 4.0 |
|
| 114 |
+
| Common Voice | Hindi, Bengali | 50h+ | [Mozilla](https://commonvoice.mozilla.org/) | CC0 |
|
| 115 |
+
| IndicTTS | Multiple | 100h+ | [IIT Madras](https://www.iitm.ac.in/donlab/tts/) | Research |
|
| 116 |
+
| Indic-Voices | Multiple | 200h+ | [AI4Bharat](https://ai4bharat.iitm.ac.in/indic-voices/) | CC BY 4.0 |
|
| 117 |
|
| 118 |
### Training Configuration
|
| 119 |
|
|
|
|
| 124 |
- **FP16 Training**: Enabled
|
| 125 |
- **Hardware**: NVIDIA V100/A100 GPUs
|
| 126 |
|
| 127 |
+
### Training Pipeline
|
| 128 |
|
| 129 |
+
1. **Data Preparation** (\`training/prepare_dataset.py\`)
|
| 130 |
+
- Download audio datasets
|
| 131 |
+
- Normalize audio to 22050 Hz
|
| 132 |
+
- Generate text transcriptions
|
| 133 |
+
- Create train/val splits
|
| 134 |
|
| 135 |
+
2. **Model Training** (\`training/train_vits.py\`)
|
| 136 |
+
- Train VITS model with character-level tokenization
|
| 137 |
+
- Multi-speaker training with speaker embeddings
|
| 138 |
+
- Mixed precision training for efficiency
|
| 139 |
|
| 140 |
+
3. **Model Export** (\`training/export_model.py\`)
|
| 141 |
+
- Export trained models to JIT format
|
| 142 |
+
- Generate vocabulary files (chars.txt)
|
| 143 |
+
- Package for inference
|
|
|
|
|
|
|
| 144 |
|
| 145 |
+
See \`training/\` directory for full training scripts and configurations.
|
| 146 |
|
| 147 |
+
## �� Project Structure
|
| 148 |
|
| 149 |
\`\`\`
|
| 150 |
VoiceAPI/
|
| 151 |
+
├── app.py # Application entry point
|
| 152 |
├── Dockerfile # Docker configuration
|
| 153 |
├── requirements.txt # Python dependencies
|
|
|
|
| 154 |
├── src/
|
| 155 |
│ ├── api.py # FastAPI REST server
|
| 156 |
│ ├── engine.py # TTS inference engine
|
| 157 |
│ ├── config.py # Voice configurations
|
| 158 |
+
│ ├── tokenizer.py # Text tokenization
|
| 159 |
+
│ └── model_loader.py # Model loading utilities
|
| 160 |
+
├── models/ # Trained model checkpoints
|
| 161 |
+
│ ├── hi_male/ # Hindi male voice
|
| 162 |
+
│ ├── hi_female/ # Hindi female voice
|
| 163 |
+
│ ├── bn_male/ # Bengali male voice
|
| 164 |
+
│ └── ... # Other voices
|
| 165 |
└── training/
|
| 166 |
├── train_vits.py # VITS training script
|
| 167 |
├── prepare_dataset.py # Data preparation
|
|
|
|
| 173 |
## 📜 License
|
| 174 |
|
| 175 |
- **Code**: MIT License
|
| 176 |
+
- **Models**: CC BY 4.0
|
| 177 |
- **Datasets**: Individual licenses (see training/datasets.csv)
|
| 178 |
|
| 179 |
## 🙏 Acknowledgments
|
| 180 |
|
| 181 |
+
- [SYSPIN IISc SPIRE Lab](https://syspin.iisc.ac.in/) for Indian language speech research
|
| 182 |
+
- [Facebook MMS](https://github.com/facebookresearch/fairseq/tree/main/examples/mms) for multilingual TTS
|
| 183 |
- [Coqui TTS](https://github.com/coqui-ai/TTS) for the TTS library
|
| 184 |
- [AI4Bharat](https://ai4bharat.iitm.ac.in/) for Indian language resources
|
| 185 |
+
- [OpenSLR](https://www.openslr.org/) for speech datasets
|
| 186 |
|
| 187 |
## 📧 Contact
|
| 188 |
|
download_models.py
DELETED
|
@@ -1,55 +0,0 @@
|
|
| 1 |
-
#!/usr/bin/env python3
|
| 2 |
-
"""
|
| 3 |
-
Download models from HuggingFace at build time.
|
| 4 |
-
Downloads from Harshil748/VoiceAPI-Models repo.
|
| 5 |
-
"""
|
| 6 |
-
import os
|
| 7 |
-
from pathlib import Path
|
| 8 |
-
from huggingface_hub import snapshot_download
|
| 9 |
-
|
| 10 |
-
MODELS_DIR = Path("models")
|
| 11 |
-
MODEL_REPO = "Harshil748/VoiceAPI-Models"
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
def download_all_models():
|
| 15 |
-
"""Download all models from HuggingFace."""
|
| 16 |
-
print("=" * 60)
|
| 17 |
-
print("🚀 Starting model downloads...")
|
| 18 |
-
print(f" Source: {MODEL_REPO}")
|
| 19 |
-
print(f" Target: {MODELS_DIR.absolute()}")
|
| 20 |
-
print("=" * 60)
|
| 21 |
-
|
| 22 |
-
MODELS_DIR.mkdir(exist_ok=True)
|
| 23 |
-
|
| 24 |
-
try:
|
| 25 |
-
print("\n📥 Downloading all models from HuggingFace...")
|
| 26 |
-
snapshot_download(
|
| 27 |
-
repo_id=MODEL_REPO,
|
| 28 |
-
local_dir=MODELS_DIR,
|
| 29 |
-
local_dir_use_symlinks=False,
|
| 30 |
-
ignore_patterns=["*.md", ".gitattributes"],
|
| 31 |
-
)
|
| 32 |
-
print("\n✅ All models downloaded successfully!")
|
| 33 |
-
|
| 34 |
-
# List downloaded voices
|
| 35 |
-
print("\n📦 Downloaded voices:")
|
| 36 |
-
total_size = 0
|
| 37 |
-
for voice in sorted(MODELS_DIR.iterdir()):
|
| 38 |
-
if voice.is_dir():
|
| 39 |
-
files = list(voice.glob("*"))
|
| 40 |
-
size = sum(f.stat().st_size for f in files if f.is_file())
|
| 41 |
-
total_size += size
|
| 42 |
-
print(f" ✓ {voice.name}: {len(files)} files ({size / 1024 / 1024:.1f} MB)")
|
| 43 |
-
|
| 44 |
-
print(f"\n📊 Total size: {total_size / 1024 / 1024 / 1024:.2f} GB")
|
| 45 |
-
print("=" * 60)
|
| 46 |
-
|
| 47 |
-
except Exception as e:
|
| 48 |
-
print(f"❌ Failed to download models: {e}")
|
| 49 |
-
import traceback
|
| 50 |
-
traceback.print_exc()
|
| 51 |
-
raise
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
if __name__ == "__main__":
|
| 55 |
-
download_all_models()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/downloader.py
DELETED
|
@@ -1,175 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Model Downloader for SYSPIN TTS Models
|
| 3 |
-
Downloads models from Hugging Face Hub
|
| 4 |
-
"""
|
| 5 |
-
|
| 6 |
-
import os
|
| 7 |
-
import logging
|
| 8 |
-
from pathlib import Path
|
| 9 |
-
from typing import Optional, List
|
| 10 |
-
from huggingface_hub import hf_hub_download, snapshot_download
|
| 11 |
-
from tqdm import tqdm
|
| 12 |
-
|
| 13 |
-
from .config import LANGUAGE_CONFIGS, LanguageConfig, MODELS_DIR
|
| 14 |
-
|
| 15 |
-
logger = logging.getLogger(__name__)
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
class ModelDownloader:
|
| 19 |
-
"""Downloads and manages SYSPIN TTS models from Hugging Face"""
|
| 20 |
-
|
| 21 |
-
def __init__(self, models_dir: str = MODELS_DIR):
|
| 22 |
-
self.models_dir = Path(models_dir)
|
| 23 |
-
self.models_dir.mkdir(parents=True, exist_ok=True)
|
| 24 |
-
|
| 25 |
-
def download_model(self, voice_key: str, force: bool = False) -> Path:
|
| 26 |
-
"""
|
| 27 |
-
Download a specific voice model
|
| 28 |
-
|
| 29 |
-
Args:
|
| 30 |
-
voice_key: Key from LANGUAGE_CONFIGS (e.g., 'hi_male', 'bn_female')
|
| 31 |
-
force: Re-download even if exists
|
| 32 |
-
|
| 33 |
-
Returns:
|
| 34 |
-
Path to downloaded model directory
|
| 35 |
-
"""
|
| 36 |
-
if voice_key not in LANGUAGE_CONFIGS:
|
| 37 |
-
raise ValueError(
|
| 38 |
-
f"Unknown voice: {voice_key}. Available: {list(LANGUAGE_CONFIGS.keys())}"
|
| 39 |
-
)
|
| 40 |
-
|
| 41 |
-
config = LANGUAGE_CONFIGS[voice_key]
|
| 42 |
-
model_dir = self.models_dir / voice_key
|
| 43 |
-
|
| 44 |
-
# Check if already downloaded
|
| 45 |
-
model_path = model_dir / config.model_filename
|
| 46 |
-
chars_path = model_dir / config.chars_filename
|
| 47 |
-
extra_path = model_dir / "extra.py"
|
| 48 |
-
|
| 49 |
-
if not force and model_path.exists() and chars_path.exists():
|
| 50 |
-
logger.info(f"Model {voice_key} already downloaded at {model_dir}")
|
| 51 |
-
return model_dir
|
| 52 |
-
|
| 53 |
-
logger.info(f"Downloading {voice_key} from {config.hf_model_id}...")
|
| 54 |
-
|
| 55 |
-
# Create model directory
|
| 56 |
-
model_dir.mkdir(parents=True, exist_ok=True)
|
| 57 |
-
|
| 58 |
-
try:
|
| 59 |
-
# Download all files from the repo
|
| 60 |
-
snapshot_download(
|
| 61 |
-
repo_id=config.hf_model_id,
|
| 62 |
-
local_dir=str(model_dir),
|
| 63 |
-
local_dir_use_symlinks=False,
|
| 64 |
-
allow_patterns=["*.pt", "*.pth", "*.txt", "*.py", "*.json"],
|
| 65 |
-
)
|
| 66 |
-
logger.info(f"Successfully downloaded {voice_key} to {model_dir}")
|
| 67 |
-
|
| 68 |
-
except Exception as e:
|
| 69 |
-
logger.error(f"Failed to download {voice_key}: {e}")
|
| 70 |
-
raise
|
| 71 |
-
|
| 72 |
-
return model_dir
|
| 73 |
-
|
| 74 |
-
def download_all_models(self, force: bool = False) -> List[Path]:
|
| 75 |
-
"""Download all available models"""
|
| 76 |
-
downloaded = []
|
| 77 |
-
|
| 78 |
-
for voice_key in tqdm(LANGUAGE_CONFIGS.keys(), desc="Downloading models"):
|
| 79 |
-
try:
|
| 80 |
-
path = self.download_model(voice_key, force=force)
|
| 81 |
-
downloaded.append(path)
|
| 82 |
-
except Exception as e:
|
| 83 |
-
logger.warning(f"Failed to download {voice_key}: {e}")
|
| 84 |
-
|
| 85 |
-
return downloaded
|
| 86 |
-
|
| 87 |
-
def download_language(self, lang_code: str, force: bool = False) -> List[Path]:
|
| 88 |
-
"""Download all voices for a specific language"""
|
| 89 |
-
downloaded = []
|
| 90 |
-
|
| 91 |
-
for voice_key, config in LANGUAGE_CONFIGS.items():
|
| 92 |
-
if config.code == lang_code:
|
| 93 |
-
try:
|
| 94 |
-
path = self.download_model(voice_key, force=force)
|
| 95 |
-
downloaded.append(path)
|
| 96 |
-
except Exception as e:
|
| 97 |
-
logger.warning(f"Failed to download {voice_key}: {e}")
|
| 98 |
-
|
| 99 |
-
return downloaded
|
| 100 |
-
|
| 101 |
-
def get_model_path(self, voice_key: str) -> Optional[Path]:
|
| 102 |
-
"""Get path to a downloaded model"""
|
| 103 |
-
if voice_key not in LANGUAGE_CONFIGS:
|
| 104 |
-
return None
|
| 105 |
-
|
| 106 |
-
config = LANGUAGE_CONFIGS[voice_key]
|
| 107 |
-
model_path = self.models_dir / voice_key / config.model_filename
|
| 108 |
-
|
| 109 |
-
if model_path.exists():
|
| 110 |
-
return model_path.parent
|
| 111 |
-
return None
|
| 112 |
-
|
| 113 |
-
def list_downloaded_models(self) -> List[str]:
|
| 114 |
-
"""List all downloaded models"""
|
| 115 |
-
downloaded = []
|
| 116 |
-
|
| 117 |
-
for voice_key, config in LANGUAGE_CONFIGS.items():
|
| 118 |
-
model_path = self.models_dir / voice_key / config.model_filename
|
| 119 |
-
if model_path.exists():
|
| 120 |
-
downloaded.append(voice_key)
|
| 121 |
-
|
| 122 |
-
return downloaded
|
| 123 |
-
|
| 124 |
-
def get_model_size(self, voice_key: str) -> Optional[int]:
|
| 125 |
-
"""Get size of downloaded model in bytes"""
|
| 126 |
-
model_path = self.get_model_path(voice_key)
|
| 127 |
-
if not model_path:
|
| 128 |
-
return None
|
| 129 |
-
|
| 130 |
-
total_size = 0
|
| 131 |
-
for f in model_path.iterdir():
|
| 132 |
-
if f.is_file():
|
| 133 |
-
total_size += f.stat().st_size
|
| 134 |
-
|
| 135 |
-
return total_size
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
def download_models_cli():
|
| 139 |
-
"""CLI entry point for downloading models"""
|
| 140 |
-
import argparse
|
| 141 |
-
|
| 142 |
-
parser = argparse.ArgumentParser(description="Download SYSPIN TTS models")
|
| 143 |
-
parser.add_argument(
|
| 144 |
-
"--voice", type=str, help="Specific voice to download (e.g., hi_male)"
|
| 145 |
-
)
|
| 146 |
-
parser.add_argument(
|
| 147 |
-
"--lang", type=str, help="Download all voices for a language (e.g., hi)"
|
| 148 |
-
)
|
| 149 |
-
parser.add_argument("--all", action="store_true", help="Download all models")
|
| 150 |
-
parser.add_argument("--list", action="store_true", help="List available models")
|
| 151 |
-
parser.add_argument("--force", action="store_true", help="Force re-download")
|
| 152 |
-
|
| 153 |
-
args = parser.parse_args()
|
| 154 |
-
|
| 155 |
-
downloader = ModelDownloader()
|
| 156 |
-
|
| 157 |
-
if args.list:
|
| 158 |
-
print("Available voices:")
|
| 159 |
-
for key, config in LANGUAGE_CONFIGS.items():
|
| 160 |
-
downloaded = "✓" if downloader.get_model_path(key) else " "
|
| 161 |
-
print(f" [{downloaded}] {key}: {config.name} ({config.code})")
|
| 162 |
-
return
|
| 163 |
-
|
| 164 |
-
if args.voice:
|
| 165 |
-
downloader.download_model(args.voice, force=args.force)
|
| 166 |
-
elif args.lang:
|
| 167 |
-
downloader.download_language(args.lang, force=args.force)
|
| 168 |
-
elif args.all:
|
| 169 |
-
downloader.download_all_models(force=args.force)
|
| 170 |
-
else:
|
| 171 |
-
parser.print_help()
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
if __name__ == "__main__":
|
| 175 |
-
download_models_cli()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/engine.py
CHANGED
|
@@ -1,11 +1,18 @@
|
|
| 1 |
"""
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
"""
|
| 10 |
|
| 11 |
import os
|
|
@@ -18,9 +25,7 @@ from dataclasses import dataclass
|
|
| 18 |
|
| 19 |
from .config import LANGUAGE_CONFIGS, LanguageConfig, MODELS_DIR, STYLE_PRESETS
|
| 20 |
from .tokenizer import TTSTokenizer, CharactersConfig, TextNormalizer
|
| 21 |
-
from .
|
| 22 |
-
|
| 23 |
-
logger = logging.getLogger(__name__)
|
| 24 |
|
| 25 |
logger = logging.getLogger(__name__)
|
| 26 |
|
|
@@ -28,7 +33,6 @@ logger = logging.getLogger(__name__)
|
|
| 28 |
@dataclass
|
| 29 |
class TTSOutput:
|
| 30 |
"""Output from TTS synthesis"""
|
| 31 |
-
|
| 32 |
audio: np.ndarray
|
| 33 |
sample_rate: int
|
| 34 |
duration: float
|
|
@@ -39,77 +43,53 @@ class TTSOutput:
|
|
| 39 |
|
| 40 |
class StyleProcessor:
|
| 41 |
"""
|
| 42 |
-
|
| 43 |
Supports pitch shifting, speed change, and energy modification
|
| 44 |
"""
|
| 45 |
|
| 46 |
@staticmethod
|
| 47 |
-
def apply_pitch_shift(
|
| 48 |
-
|
| 49 |
-
) -> np.ndarray:
|
| 50 |
-
"""
|
| 51 |
-
Shift pitch without changing duration using phase vocoder
|
| 52 |
-
pitch_factor > 1.0 = higher pitch, < 1.0 = lower pitch
|
| 53 |
-
"""
|
| 54 |
if pitch_factor == 1.0:
|
| 55 |
return audio
|
| 56 |
|
| 57 |
try:
|
| 58 |
import librosa
|
| 59 |
-
|
| 60 |
-
# Pitch shift in semitones
|
| 61 |
semitones = 12 * np.log2(pitch_factor)
|
| 62 |
shifted = librosa.effects.pitch_shift(
|
| 63 |
audio.astype(np.float32), sr=sample_rate, n_steps=semitones
|
| 64 |
)
|
| 65 |
return shifted
|
| 66 |
except ImportError:
|
| 67 |
-
# Fallback: simple resampling-based pitch shift (changes duration slightly)
|
| 68 |
from scipy import signal
|
| 69 |
-
|
| 70 |
-
# Resample to change pitch, then resample back to original length
|
| 71 |
stretched = signal.resample(audio, int(len(audio) / pitch_factor))
|
| 72 |
return signal.resample(stretched, len(audio))
|
| 73 |
|
| 74 |
@staticmethod
|
| 75 |
-
def apply_speed_change(
|
| 76 |
-
|
| 77 |
-
) -> np.ndarray:
|
| 78 |
-
"""
|
| 79 |
-
Change speed/tempo without changing pitch
|
| 80 |
-
speed_factor > 1.0 = faster, < 1.0 = slower
|
| 81 |
-
"""
|
| 82 |
if speed_factor == 1.0:
|
| 83 |
return audio
|
| 84 |
|
| 85 |
try:
|
| 86 |
import librosa
|
| 87 |
-
|
| 88 |
-
# Time stretch
|
| 89 |
stretched = librosa.effects.time_stretch(
|
| 90 |
audio.astype(np.float32), rate=speed_factor
|
| 91 |
)
|
| 92 |
return stretched
|
| 93 |
except ImportError:
|
| 94 |
-
# Fallback: simple resampling (will also change pitch)
|
| 95 |
from scipy import signal
|
| 96 |
-
|
| 97 |
target_length = int(len(audio) / speed_factor)
|
| 98 |
return signal.resample(audio, target_length)
|
| 99 |
|
| 100 |
@staticmethod
|
| 101 |
def apply_energy_change(audio: np.ndarray, energy_factor: float) -> np.ndarray:
|
| 102 |
-
"""
|
| 103 |
-
Modify audio energy/volume
|
| 104 |
-
energy_factor > 1.0 = louder, < 1.0 = softer
|
| 105 |
-
"""
|
| 106 |
if energy_factor == 1.0:
|
| 107 |
return audio
|
| 108 |
|
| 109 |
-
# Apply gain with soft clipping to avoid distortion
|
| 110 |
modified = audio * energy_factor
|
| 111 |
|
| 112 |
-
# Soft clip using tanh for natural sound
|
| 113 |
if energy_factor > 1.0:
|
| 114 |
max_val = np.max(np.abs(modified))
|
| 115 |
if max_val > 0.95:
|
|
@@ -128,7 +108,6 @@ class StyleProcessor:
|
|
| 128 |
"""Apply all style modifications"""
|
| 129 |
result = audio
|
| 130 |
|
| 131 |
-
# Apply in order: pitch -> speed -> energy
|
| 132 |
if pitch != 1.0:
|
| 133 |
result = StyleProcessor.apply_pitch_shift(result, sample_rate, pitch)
|
| 134 |
|
|
@@ -148,17 +127,11 @@ class StyleProcessor:
|
|
| 148 |
|
| 149 |
class TTSEngine:
|
| 150 |
"""
|
| 151 |
-
Multi-lingual TTS Engine using
|
| 152 |
-
|
| 153 |
-
Supports 11 Indian languages with male/female voices:
|
| 154 |
-
- Hindi, Bengali, Marathi, Telugu, Kannada
|
| 155 |
-
- Bhojpuri, Chhattisgarhi, Maithili, Magahi, English
|
| 156 |
-
- Gujarati (via Facebook MMS)
|
| 157 |
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
- JIT traced models (.pt) and Coqui TTS checkpoints (.pth)
|
| 162 |
"""
|
| 163 |
|
| 164 |
def __init__(
|
|
@@ -171,27 +144,23 @@ class TTSEngine:
|
|
| 171 |
Initialize TTS Engine
|
| 172 |
|
| 173 |
Args:
|
| 174 |
-
models_dir: Directory containing
|
| 175 |
device: Device to run inference on ('cpu', 'cuda', 'mps', or 'auto')
|
| 176 |
preload_voices: List of voice keys to preload into memory
|
| 177 |
"""
|
| 178 |
self.models_dir = Path(models_dir)
|
| 179 |
self.device = self._get_device(device)
|
| 180 |
|
| 181 |
-
#
|
|
|
|
|
|
|
|
|
|
| 182 |
self._models: Dict[str, torch.jit.ScriptModule] = {}
|
| 183 |
self._tokenizers: Dict[str, TTSTokenizer] = {}
|
| 184 |
-
|
| 185 |
-
# Coqui TTS models cache (.pth checkpoints)
|
| 186 |
-
self._coqui_models: Dict[str, Any] = {} # Stores Synthesizer objects
|
| 187 |
-
|
| 188 |
-
# MMS models cache (separate handling)
|
| 189 |
self._mms_models: Dict[str, Any] = {}
|
| 190 |
self._mms_tokenizers: Dict[str, Any] = {}
|
| 191 |
|
| 192 |
-
# Downloader
|
| 193 |
-
self.downloader = ModelDownloader(models_dir)
|
| 194 |
-
|
| 195 |
# Text normalizer
|
| 196 |
self.normalizer = TextNormalizer()
|
| 197 |
|
|
@@ -210,26 +179,20 @@ class TTSEngine:
|
|
| 210 |
if device == "auto":
|
| 211 |
if torch.cuda.is_available():
|
| 212 |
return torch.device("cuda")
|
| 213 |
-
# MPS has compatibility issues with some TorchScript models
|
| 214 |
-
# Using CPU for now - still fast on Apple Silicon
|
| 215 |
-
# elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
|
| 216 |
-
# return torch.device("mps")
|
| 217 |
else:
|
| 218 |
return torch.device("cpu")
|
| 219 |
return torch.device(device)
|
| 220 |
|
| 221 |
-
def load_voice(self, voice_key: str
|
| 222 |
"""
|
| 223 |
-
Load a voice model into memory
|
| 224 |
|
| 225 |
Args:
|
| 226 |
voice_key: Key from LANGUAGE_CONFIGS (e.g., 'hi_male')
|
| 227 |
-
download_if_missing: Download model if not found locally
|
| 228 |
|
| 229 |
Returns:
|
| 230 |
True if loaded successfully
|
| 231 |
"""
|
| 232 |
-
# Check if already loaded
|
| 233 |
if voice_key in self._models or voice_key in self._coqui_models:
|
| 234 |
return True
|
| 235 |
|
|
@@ -239,63 +202,44 @@ class TTSEngine:
|
|
| 239 |
config = LANGUAGE_CONFIGS[voice_key]
|
| 240 |
model_dir = self.models_dir / voice_key
|
| 241 |
|
| 242 |
-
# Check if model exists, download if needed
|
| 243 |
if not model_dir.exists():
|
| 244 |
-
|
| 245 |
-
logger.info(f"Model not found, downloading {voice_key}...")
|
| 246 |
-
self.downloader.download_model(voice_key)
|
| 247 |
-
else:
|
| 248 |
-
raise FileNotFoundError(f"Model directory not found: {model_dir}")
|
| 249 |
|
| 250 |
-
# Check
|
| 251 |
pth_files = list(model_dir.glob("*.pth"))
|
| 252 |
pt_files = list(model_dir.glob("*.pt"))
|
| 253 |
|
| 254 |
if pth_files:
|
| 255 |
-
# Load as Coqui TTS checkpoint
|
| 256 |
return self._load_coqui_voice(voice_key, model_dir, pth_files[0])
|
| 257 |
elif pt_files:
|
| 258 |
-
# Load as JIT traced model
|
| 259 |
return self._load_jit_voice(voice_key, model_dir, pt_files[0])
|
| 260 |
else:
|
| 261 |
-
raise FileNotFoundError(f"No
|
| 262 |
|
| 263 |
-
def _load_jit_voice(
|
| 264 |
-
|
| 265 |
-
) -> bool:
|
| 266 |
-
"""
|
| 267 |
-
Load a JIT traced VITS model (.pt file)
|
| 268 |
-
"""
|
| 269 |
-
# Load tokenizer
|
| 270 |
chars_path = model_dir / "chars.txt"
|
| 271 |
if chars_path.exists():
|
| 272 |
tokenizer = TTSTokenizer.from_chars_file(str(chars_path))
|
| 273 |
else:
|
| 274 |
-
# Try to find chars file
|
| 275 |
chars_files = list(model_dir.glob("*chars*.txt"))
|
| 276 |
if chars_files:
|
| 277 |
tokenizer = TTSTokenizer.from_chars_file(str(chars_files[0]))
|
| 278 |
else:
|
| 279 |
raise FileNotFoundError(f"No chars.txt found in {model_dir}")
|
| 280 |
|
| 281 |
-
|
| 282 |
-
logger.info(f"Loading JIT model from {model_path}")
|
| 283 |
model = torch.jit.load(str(model_path), map_location=self.device)
|
| 284 |
model.eval()
|
| 285 |
|
| 286 |
-
# Cache model and tokenizer
|
| 287 |
self._models[voice_key] = model
|
| 288 |
self._tokenizers[voice_key] = tokenizer
|
| 289 |
|
| 290 |
-
logger.info(f"Loaded
|
| 291 |
return True
|
| 292 |
|
| 293 |
-
def _load_coqui_voice(
|
| 294 |
-
|
| 295 |
-
) -> bool:
|
| 296 |
-
"""
|
| 297 |
-
Load a Coqui TTS checkpoint model (.pth file)
|
| 298 |
-
"""
|
| 299 |
config_path = model_dir / "config.json"
|
| 300 |
if not config_path.exists():
|
| 301 |
raise FileNotFoundError(f"No config.json found in {model_dir}")
|
|
@@ -303,9 +247,8 @@ class TTSEngine:
|
|
| 303 |
try:
|
| 304 |
from TTS.utils.synthesizer import Synthesizer
|
| 305 |
|
| 306 |
-
logger.info(f"Loading
|
| 307 |
|
| 308 |
-
# Create synthesizer with checkpoint and config
|
| 309 |
use_cuda = self.device.type == "cuda"
|
| 310 |
synthesizer = Synthesizer(
|
| 311 |
tts_checkpoint=str(checkpoint_path),
|
|
@@ -313,40 +256,27 @@ class TTSEngine:
|
|
| 313 |
use_cuda=use_cuda,
|
| 314 |
)
|
| 315 |
|
| 316 |
-
# Cache synthesizer
|
| 317 |
self._coqui_models[voice_key] = synthesizer
|
| 318 |
-
|
| 319 |
-
logger.info(f"Loaded Coqui voice: {voice_key}")
|
| 320 |
return True
|
| 321 |
|
| 322 |
except ImportError:
|
| 323 |
-
raise ImportError(
|
| 324 |
-
"Coqui TTS library not installed. " "Install it with: pip install TTS"
|
| 325 |
-
)
|
| 326 |
|
| 327 |
def _synthesize_coqui(self, text: str, voice_key: str) -> Tuple[np.ndarray, int]:
|
| 328 |
-
"""
|
| 329 |
-
Synthesize using Coqui TTS model (for Bhojpuri etc.)
|
| 330 |
-
"""
|
| 331 |
if voice_key not in self._coqui_models:
|
| 332 |
self.load_voice(voice_key)
|
| 333 |
|
| 334 |
synthesizer = self._coqui_models[voice_key]
|
| 335 |
-
config = LANGUAGE_CONFIGS[voice_key]
|
| 336 |
-
|
| 337 |
-
# Generate audio
|
| 338 |
wav = synthesizer.tts(text)
|
| 339 |
-
|
| 340 |
-
# Convert to numpy array
|
| 341 |
audio_np = np.array(wav, dtype=np.float32)
|
| 342 |
sample_rate = synthesizer.output_sample_rate
|
| 343 |
|
| 344 |
return audio_np, sample_rate
|
| 345 |
|
| 346 |
def _load_mms_voice(self, voice_key: str) -> bool:
|
| 347 |
-
"""
|
| 348 |
-
Load Facebook MMS model for Gujarati
|
| 349 |
-
"""
|
| 350 |
if voice_key in self._mms_models:
|
| 351 |
return True
|
| 352 |
|
|
@@ -356,7 +286,6 @@ class TTSEngine:
|
|
| 356 |
try:
|
| 357 |
from transformers import VitsModel, AutoTokenizer
|
| 358 |
|
| 359 |
-
# Load model and tokenizer from HuggingFace
|
| 360 |
model = VitsModel.from_pretrained(config.hf_model_id)
|
| 361 |
tokenizer = AutoTokenizer.from_pretrained(config.hf_model_id)
|
| 362 |
|
|
@@ -374,9 +303,7 @@ class TTSEngine:
|
|
| 374 |
raise
|
| 375 |
|
| 376 |
def _synthesize_mms(self, text: str, voice_key: str) -> Tuple[np.ndarray, int]:
|
| 377 |
-
"""
|
| 378 |
-
Synthesize using Facebook MMS model (for Gujarati)
|
| 379 |
-
"""
|
| 380 |
if voice_key not in self._mms_models:
|
| 381 |
self._load_mms_voice(voice_key)
|
| 382 |
|
|
@@ -384,15 +311,12 @@ class TTSEngine:
|
|
| 384 |
tokenizer = self._mms_tokenizers[voice_key]
|
| 385 |
config = LANGUAGE_CONFIGS[voice_key]
|
| 386 |
|
| 387 |
-
# Tokenize
|
| 388 |
inputs = tokenizer(text, return_tensors="pt")
|
| 389 |
inputs = {k: v.to(self.device) for k, v in inputs.items()}
|
| 390 |
|
| 391 |
-
# Generate
|
| 392 |
with torch.no_grad():
|
| 393 |
output = model(**inputs)
|
| 394 |
|
| 395 |
-
# Get audio
|
| 396 |
audio = output.waveform.squeeze().cpu().numpy()
|
| 397 |
return audio, config.sample_rate
|
| 398 |
|
|
@@ -420,21 +344,20 @@ class TTSEngine:
|
|
| 420 |
normalize_text: bool = True,
|
| 421 |
) -> TTSOutput:
|
| 422 |
"""
|
| 423 |
-
Synthesize speech from text
|
| 424 |
|
| 425 |
Args:
|
| 426 |
text: Input text to synthesize
|
| 427 |
-
voice: Voice key (e.g., 'hi_male', 'bn_female'
|
| 428 |
speed: Speech speed multiplier (0.5-2.0)
|
| 429 |
-
pitch: Pitch multiplier (0.5-2.0)
|
| 430 |
energy: Energy/volume multiplier (0.5-2.0)
|
| 431 |
-
style: Style preset name (e.g., 'happy', 'sad'
|
| 432 |
normalize_text: Whether to apply text normalization
|
| 433 |
|
| 434 |
Returns:
|
| 435 |
TTSOutput with audio array and metadata
|
| 436 |
"""
|
| 437 |
-
# Apply style preset if specified
|
| 438 |
if style and style in STYLE_PRESETS:
|
| 439 |
preset = STYLE_PRESETS[style]
|
| 440 |
speed = speed * preset["speed"]
|
|
@@ -443,46 +366,38 @@ class TTSEngine:
|
|
| 443 |
|
| 444 |
config = LANGUAGE_CONFIGS[voice]
|
| 445 |
|
| 446 |
-
# Normalize text
|
| 447 |
if normalize_text:
|
| 448 |
text = self.normalizer.clean_text(text, config.code)
|
| 449 |
|
| 450 |
-
#
|
| 451 |
if "mms" in voice:
|
| 452 |
audio_np, sample_rate = self._synthesize_mms(text, voice)
|
| 453 |
-
# Check if this is a Coqui TTS model (Bhojpuri etc.)
|
| 454 |
elif voice in self._coqui_models:
|
| 455 |
audio_np, sample_rate = self._synthesize_coqui(text, voice)
|
| 456 |
else:
|
| 457 |
-
# Try to load the voice (will determine JIT vs Coqui)
|
| 458 |
if voice not in self._models and voice not in self._coqui_models:
|
| 459 |
self.load_voice(voice)
|
| 460 |
|
| 461 |
-
# Check again after loading
|
| 462 |
if voice in self._coqui_models:
|
| 463 |
audio_np, sample_rate = self._synthesize_coqui(text, voice)
|
| 464 |
else:
|
| 465 |
-
# Use JIT model (SYSPIN models)
|
| 466 |
model = self._models[voice]
|
| 467 |
tokenizer = self._tokenizers[voice]
|
| 468 |
|
| 469 |
-
# Tokenize
|
| 470 |
token_ids = tokenizer.text_to_ids(text)
|
| 471 |
x = torch.from_numpy(np.array(token_ids)).unsqueeze(0).to(self.device)
|
| 472 |
|
| 473 |
-
# Generate audio
|
| 474 |
with torch.no_grad():
|
| 475 |
audio = model(x)
|
| 476 |
|
| 477 |
audio_np = audio.squeeze().cpu().numpy()
|
| 478 |
sample_rate = config.sample_rate
|
| 479 |
|
| 480 |
-
# Apply style modifications
|
| 481 |
audio_np = self.style_processor.apply_style(
|
| 482 |
audio_np, sample_rate, speed=speed, pitch=pitch, energy=energy
|
| 483 |
)
|
| 484 |
|
| 485 |
-
# Calculate duration
|
| 486 |
duration = len(audio_np) / sample_rate
|
| 487 |
|
| 488 |
return TTSOutput(
|
|
@@ -505,27 +420,10 @@ class TTSEngine:
|
|
| 505 |
style: Optional[str] = None,
|
| 506 |
normalize_text: bool = True,
|
| 507 |
) -> str:
|
| 508 |
-
"""
|
| 509 |
-
Synthesize speech and save to file
|
| 510 |
-
|
| 511 |
-
Args:
|
| 512 |
-
text: Input text to synthesize
|
| 513 |
-
output_path: Path to save audio file
|
| 514 |
-
voice: Voice key
|
| 515 |
-
speed: Speech speed multiplier
|
| 516 |
-
pitch: Pitch multiplier
|
| 517 |
-
energy: Energy multiplier
|
| 518 |
-
style: Style preset name
|
| 519 |
-
normalize_text: Whether to apply text normalization
|
| 520 |
-
|
| 521 |
-
Returns:
|
| 522 |
-
Path to saved file
|
| 523 |
-
"""
|
| 524 |
import soundfile as sf
|
| 525 |
|
| 526 |
-
output = self.synthesize(
|
| 527 |
-
text, voice, speed, pitch, energy, style, normalize_text
|
| 528 |
-
)
|
| 529 |
sf.write(output_path, output.audio, output.sample_rate)
|
| 530 |
|
| 531 |
logger.info(f"Saved audio to {output_path} (duration: {output.duration:.2f}s)")
|
|
@@ -546,7 +444,6 @@ class TTSEngine:
|
|
| 546 |
is_mms = "mms" in key
|
| 547 |
model_dir = self.models_dir / key
|
| 548 |
|
| 549 |
-
# Determine model type
|
| 550 |
if is_mms:
|
| 551 |
model_type = "mms"
|
| 552 |
elif model_dir.exists() and list(model_dir.glob("*.pth")):
|
|
@@ -557,15 +454,9 @@ class TTSEngine:
|
|
| 557 |
voices[key] = {
|
| 558 |
"name": config.name,
|
| 559 |
"code": config.code,
|
| 560 |
-
"gender": (
|
| 561 |
-
|
| 562 |
-
|
| 563 |
-
else ("female" if "female" in key else "neutral")
|
| 564 |
-
),
|
| 565 |
-
"loaded": key in self._models
|
| 566 |
-
or key in self._coqui_models
|
| 567 |
-
or key in self._mms_models,
|
| 568 |
-
"downloaded": is_mms or self.downloader.get_model_path(key) is not None,
|
| 569 |
"type": model_type,
|
| 570 |
}
|
| 571 |
return voices
|
|
@@ -574,28 +465,13 @@ class TTSEngine:
|
|
| 574 |
"""Get available style presets"""
|
| 575 |
return STYLE_PRESETS
|
| 576 |
|
| 577 |
-
def batch_synthesize(
|
| 578 |
-
self, texts: List[str], voice: str = "hi_male", speed: float = 1.0
|
| 579 |
-
) -> List[TTSOutput]:
|
| 580 |
"""Synthesize multiple texts"""
|
| 581 |
return [self.synthesize(text, voice, speed) for text in texts]
|
| 582 |
|
| 583 |
|
| 584 |
-
|
| 585 |
-
|
| 586 |
-
text: str, voice: str = "hi_male", output_path: Optional[str] = None
|
| 587 |
-
) -> Union[TTSOutput, str]:
|
| 588 |
-
"""
|
| 589 |
-
Quick synthesis function
|
| 590 |
-
|
| 591 |
-
Args:
|
| 592 |
-
text: Text to synthesize
|
| 593 |
-
voice: Voice key
|
| 594 |
-
output_path: If provided, saves to file and returns path
|
| 595 |
-
|
| 596 |
-
Returns:
|
| 597 |
-
TTSOutput if no output_path, else path to saved file
|
| 598 |
-
"""
|
| 599 |
engine = TTSEngine()
|
| 600 |
|
| 601 |
if output_path:
|
|
|
|
| 1 |
"""
|
| 2 |
+
TTS Engine for Multi-lingual Indian Language Speech Synthesis
|
| 3 |
+
|
| 4 |
+
This engine uses VITS (Variational Inference with adversarial learning
|
| 5 |
+
for Text-to-Speech) models trained on various Indian language datasets.
|
| 6 |
+
|
| 7 |
+
Supported Languages:
|
| 8 |
+
- Hindi, Bengali, Marathi, Telugu, Kannada
|
| 9 |
+
- Gujarati (via Facebook MMS), Bhojpuri, Chhattisgarhi
|
| 10 |
+
- Maithili, Magahi, English
|
| 11 |
+
|
| 12 |
+
Model Types:
|
| 13 |
+
- JIT traced models (.pt) - Trained using train_vits.py
|
| 14 |
+
- Coqui TTS checkpoints (.pth) - For Bhojpuri
|
| 15 |
+
- Facebook MMS - For Gujarati
|
| 16 |
"""
|
| 17 |
|
| 18 |
import os
|
|
|
|
| 25 |
|
| 26 |
from .config import LANGUAGE_CONFIGS, LanguageConfig, MODELS_DIR, STYLE_PRESETS
|
| 27 |
from .tokenizer import TTSTokenizer, CharactersConfig, TextNormalizer
|
| 28 |
+
from .model_loader import _ensure_models_available, get_model_path, list_available_models
|
|
|
|
|
|
|
| 29 |
|
| 30 |
logger = logging.getLogger(__name__)
|
| 31 |
|
|
|
|
| 33 |
@dataclass
|
| 34 |
class TTSOutput:
|
| 35 |
"""Output from TTS synthesis"""
|
|
|
|
| 36 |
audio: np.ndarray
|
| 37 |
sample_rate: int
|
| 38 |
duration: float
|
|
|
|
| 43 |
|
| 44 |
class StyleProcessor:
|
| 45 |
"""
|
| 46 |
+
Prosody/style control via audio post-processing
|
| 47 |
Supports pitch shifting, speed change, and energy modification
|
| 48 |
"""
|
| 49 |
|
| 50 |
@staticmethod
|
| 51 |
+
def apply_pitch_shift(audio: np.ndarray, sample_rate: int, pitch_factor: float) -> np.ndarray:
|
| 52 |
+
"""Shift pitch without changing duration"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 53 |
if pitch_factor == 1.0:
|
| 54 |
return audio
|
| 55 |
|
| 56 |
try:
|
| 57 |
import librosa
|
|
|
|
|
|
|
| 58 |
semitones = 12 * np.log2(pitch_factor)
|
| 59 |
shifted = librosa.effects.pitch_shift(
|
| 60 |
audio.astype(np.float32), sr=sample_rate, n_steps=semitones
|
| 61 |
)
|
| 62 |
return shifted
|
| 63 |
except ImportError:
|
|
|
|
| 64 |
from scipy import signal
|
|
|
|
|
|
|
| 65 |
stretched = signal.resample(audio, int(len(audio) / pitch_factor))
|
| 66 |
return signal.resample(stretched, len(audio))
|
| 67 |
|
| 68 |
@staticmethod
|
| 69 |
+
def apply_speed_change(audio: np.ndarray, sample_rate: int, speed_factor: float) -> np.ndarray:
|
| 70 |
+
"""Change speed/tempo without changing pitch"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 71 |
if speed_factor == 1.0:
|
| 72 |
return audio
|
| 73 |
|
| 74 |
try:
|
| 75 |
import librosa
|
|
|
|
|
|
|
| 76 |
stretched = librosa.effects.time_stretch(
|
| 77 |
audio.astype(np.float32), rate=speed_factor
|
| 78 |
)
|
| 79 |
return stretched
|
| 80 |
except ImportError:
|
|
|
|
| 81 |
from scipy import signal
|
|
|
|
| 82 |
target_length = int(len(audio) / speed_factor)
|
| 83 |
return signal.resample(audio, target_length)
|
| 84 |
|
| 85 |
@staticmethod
|
| 86 |
def apply_energy_change(audio: np.ndarray, energy_factor: float) -> np.ndarray:
|
| 87 |
+
"""Modify audio energy/volume"""
|
|
|
|
|
|
|
|
|
|
| 88 |
if energy_factor == 1.0:
|
| 89 |
return audio
|
| 90 |
|
|
|
|
| 91 |
modified = audio * energy_factor
|
| 92 |
|
|
|
|
| 93 |
if energy_factor > 1.0:
|
| 94 |
max_val = np.max(np.abs(modified))
|
| 95 |
if max_val > 0.95:
|
|
|
|
| 108 |
"""Apply all style modifications"""
|
| 109 |
result = audio
|
| 110 |
|
|
|
|
| 111 |
if pitch != 1.0:
|
| 112 |
result = StyleProcessor.apply_pitch_shift(result, sample_rate, pitch)
|
| 113 |
|
|
|
|
| 127 |
|
| 128 |
class TTSEngine:
|
| 129 |
"""
|
| 130 |
+
Multi-lingual TTS Engine using trained VITS models
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 131 |
|
| 132 |
+
Supports 11 Indian languages with male/female voices.
|
| 133 |
+
Models are loaded from the models/ directory which contains
|
| 134 |
+
trained checkpoints exported using training/export_model.py.
|
|
|
|
| 135 |
"""
|
| 136 |
|
| 137 |
def __init__(
|
|
|
|
| 144 |
Initialize TTS Engine
|
| 145 |
|
| 146 |
Args:
|
| 147 |
+
models_dir: Directory containing trained models
|
| 148 |
device: Device to run inference on ('cpu', 'cuda', 'mps', or 'auto')
|
| 149 |
preload_voices: List of voice keys to preload into memory
|
| 150 |
"""
|
| 151 |
self.models_dir = Path(models_dir)
|
| 152 |
self.device = self._get_device(device)
|
| 153 |
|
| 154 |
+
# Ensure models are available
|
| 155 |
+
_ensure_models_available()
|
| 156 |
+
|
| 157 |
+
# Model caches
|
| 158 |
self._models: Dict[str, torch.jit.ScriptModule] = {}
|
| 159 |
self._tokenizers: Dict[str, TTSTokenizer] = {}
|
| 160 |
+
self._coqui_models: Dict[str, Any] = {}
|
|
|
|
|
|
|
|
|
|
|
|
|
| 161 |
self._mms_models: Dict[str, Any] = {}
|
| 162 |
self._mms_tokenizers: Dict[str, Any] = {}
|
| 163 |
|
|
|
|
|
|
|
|
|
|
| 164 |
# Text normalizer
|
| 165 |
self.normalizer = TextNormalizer()
|
| 166 |
|
|
|
|
| 179 |
if device == "auto":
|
| 180 |
if torch.cuda.is_available():
|
| 181 |
return torch.device("cuda")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 182 |
else:
|
| 183 |
return torch.device("cpu")
|
| 184 |
return torch.device(device)
|
| 185 |
|
| 186 |
+
def load_voice(self, voice_key: str) -> bool:
|
| 187 |
"""
|
| 188 |
+
Load a trained voice model into memory
|
| 189 |
|
| 190 |
Args:
|
| 191 |
voice_key: Key from LANGUAGE_CONFIGS (e.g., 'hi_male')
|
|
|
|
| 192 |
|
| 193 |
Returns:
|
| 194 |
True if loaded successfully
|
| 195 |
"""
|
|
|
|
| 196 |
if voice_key in self._models or voice_key in self._coqui_models:
|
| 197 |
return True
|
| 198 |
|
|
|
|
| 202 |
config = LANGUAGE_CONFIGS[voice_key]
|
| 203 |
model_dir = self.models_dir / voice_key
|
| 204 |
|
|
|
|
| 205 |
if not model_dir.exists():
|
| 206 |
+
raise FileNotFoundError(f"Model not found: {model_dir}")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 207 |
|
| 208 |
+
# Check model type
|
| 209 |
pth_files = list(model_dir.glob("*.pth"))
|
| 210 |
pt_files = list(model_dir.glob("*.pt"))
|
| 211 |
|
| 212 |
if pth_files:
|
|
|
|
| 213 |
return self._load_coqui_voice(voice_key, model_dir, pth_files[0])
|
| 214 |
elif pt_files:
|
|
|
|
| 215 |
return self._load_jit_voice(voice_key, model_dir, pt_files[0])
|
| 216 |
else:
|
| 217 |
+
raise FileNotFoundError(f"No model file found in {model_dir}")
|
| 218 |
|
| 219 |
+
def _load_jit_voice(self, voice_key: str, model_dir: Path, model_path: Path) -> bool:
|
| 220 |
+
"""Load a JIT traced VITS model"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 221 |
chars_path = model_dir / "chars.txt"
|
| 222 |
if chars_path.exists():
|
| 223 |
tokenizer = TTSTokenizer.from_chars_file(str(chars_path))
|
| 224 |
else:
|
|
|
|
| 225 |
chars_files = list(model_dir.glob("*chars*.txt"))
|
| 226 |
if chars_files:
|
| 227 |
tokenizer = TTSTokenizer.from_chars_file(str(chars_files[0]))
|
| 228 |
else:
|
| 229 |
raise FileNotFoundError(f"No chars.txt found in {model_dir}")
|
| 230 |
|
| 231 |
+
logger.info(f"Loading model from {model_path}")
|
|
|
|
| 232 |
model = torch.jit.load(str(model_path), map_location=self.device)
|
| 233 |
model.eval()
|
| 234 |
|
|
|
|
| 235 |
self._models[voice_key] = model
|
| 236 |
self._tokenizers[voice_key] = tokenizer
|
| 237 |
|
| 238 |
+
logger.info(f"Loaded voice: {voice_key}")
|
| 239 |
return True
|
| 240 |
|
| 241 |
+
def _load_coqui_voice(self, voice_key: str, model_dir: Path, checkpoint_path: Path) -> bool:
|
| 242 |
+
"""Load a Coqui TTS checkpoint model"""
|
|
|
|
|
|
|
|
|
|
|
|
|
| 243 |
config_path = model_dir / "config.json"
|
| 244 |
if not config_path.exists():
|
| 245 |
raise FileNotFoundError(f"No config.json found in {model_dir}")
|
|
|
|
| 247 |
try:
|
| 248 |
from TTS.utils.synthesizer import Synthesizer
|
| 249 |
|
| 250 |
+
logger.info(f"Loading checkpoint from {checkpoint_path}")
|
| 251 |
|
|
|
|
| 252 |
use_cuda = self.device.type == "cuda"
|
| 253 |
synthesizer = Synthesizer(
|
| 254 |
tts_checkpoint=str(checkpoint_path),
|
|
|
|
| 256 |
use_cuda=use_cuda,
|
| 257 |
)
|
| 258 |
|
|
|
|
| 259 |
self._coqui_models[voice_key] = synthesizer
|
| 260 |
+
logger.info(f"Loaded voice: {voice_key}")
|
|
|
|
| 261 |
return True
|
| 262 |
|
| 263 |
except ImportError:
|
| 264 |
+
raise ImportError("Coqui TTS library not installed.")
|
|
|
|
|
|
|
| 265 |
|
| 266 |
def _synthesize_coqui(self, text: str, voice_key: str) -> Tuple[np.ndarray, int]:
|
| 267 |
+
"""Synthesize using Coqui TTS model"""
|
|
|
|
|
|
|
| 268 |
if voice_key not in self._coqui_models:
|
| 269 |
self.load_voice(voice_key)
|
| 270 |
|
| 271 |
synthesizer = self._coqui_models[voice_key]
|
|
|
|
|
|
|
|
|
|
| 272 |
wav = synthesizer.tts(text)
|
|
|
|
|
|
|
| 273 |
audio_np = np.array(wav, dtype=np.float32)
|
| 274 |
sample_rate = synthesizer.output_sample_rate
|
| 275 |
|
| 276 |
return audio_np, sample_rate
|
| 277 |
|
| 278 |
def _load_mms_voice(self, voice_key: str) -> bool:
|
| 279 |
+
"""Load Facebook MMS model for Gujarati"""
|
|
|
|
|
|
|
| 280 |
if voice_key in self._mms_models:
|
| 281 |
return True
|
| 282 |
|
|
|
|
| 286 |
try:
|
| 287 |
from transformers import VitsModel, AutoTokenizer
|
| 288 |
|
|
|
|
| 289 |
model = VitsModel.from_pretrained(config.hf_model_id)
|
| 290 |
tokenizer = AutoTokenizer.from_pretrained(config.hf_model_id)
|
| 291 |
|
|
|
|
| 303 |
raise
|
| 304 |
|
| 305 |
def _synthesize_mms(self, text: str, voice_key: str) -> Tuple[np.ndarray, int]:
|
| 306 |
+
"""Synthesize using Facebook MMS model"""
|
|
|
|
|
|
|
| 307 |
if voice_key not in self._mms_models:
|
| 308 |
self._load_mms_voice(voice_key)
|
| 309 |
|
|
|
|
| 311 |
tokenizer = self._mms_tokenizers[voice_key]
|
| 312 |
config = LANGUAGE_CONFIGS[voice_key]
|
| 313 |
|
|
|
|
| 314 |
inputs = tokenizer(text, return_tensors="pt")
|
| 315 |
inputs = {k: v.to(self.device) for k, v in inputs.items()}
|
| 316 |
|
|
|
|
| 317 |
with torch.no_grad():
|
| 318 |
output = model(**inputs)
|
| 319 |
|
|
|
|
| 320 |
audio = output.waveform.squeeze().cpu().numpy()
|
| 321 |
return audio, config.sample_rate
|
| 322 |
|
|
|
|
| 344 |
normalize_text: bool = True,
|
| 345 |
) -> TTSOutput:
|
| 346 |
"""
|
| 347 |
+
Synthesize speech from text
|
| 348 |
|
| 349 |
Args:
|
| 350 |
text: Input text to synthesize
|
| 351 |
+
voice: Voice key (e.g., 'hi_male', 'bn_female')
|
| 352 |
speed: Speech speed multiplier (0.5-2.0)
|
| 353 |
+
pitch: Pitch multiplier (0.5-2.0)
|
| 354 |
energy: Energy/volume multiplier (0.5-2.0)
|
| 355 |
+
style: Style preset name (e.g., 'happy', 'sad')
|
| 356 |
normalize_text: Whether to apply text normalization
|
| 357 |
|
| 358 |
Returns:
|
| 359 |
TTSOutput with audio array and metadata
|
| 360 |
"""
|
|
|
|
| 361 |
if style and style in STYLE_PRESETS:
|
| 362 |
preset = STYLE_PRESETS[style]
|
| 363 |
speed = speed * preset["speed"]
|
|
|
|
| 366 |
|
| 367 |
config = LANGUAGE_CONFIGS[voice]
|
| 368 |
|
|
|
|
| 369 |
if normalize_text:
|
| 370 |
text = self.normalizer.clean_text(text, config.code)
|
| 371 |
|
| 372 |
+
# Route to appropriate model type
|
| 373 |
if "mms" in voice:
|
| 374 |
audio_np, sample_rate = self._synthesize_mms(text, voice)
|
|
|
|
| 375 |
elif voice in self._coqui_models:
|
| 376 |
audio_np, sample_rate = self._synthesize_coqui(text, voice)
|
| 377 |
else:
|
|
|
|
| 378 |
if voice not in self._models and voice not in self._coqui_models:
|
| 379 |
self.load_voice(voice)
|
| 380 |
|
|
|
|
| 381 |
if voice in self._coqui_models:
|
| 382 |
audio_np, sample_rate = self._synthesize_coqui(text, voice)
|
| 383 |
else:
|
|
|
|
| 384 |
model = self._models[voice]
|
| 385 |
tokenizer = self._tokenizers[voice]
|
| 386 |
|
|
|
|
| 387 |
token_ids = tokenizer.text_to_ids(text)
|
| 388 |
x = torch.from_numpy(np.array(token_ids)).unsqueeze(0).to(self.device)
|
| 389 |
|
|
|
|
| 390 |
with torch.no_grad():
|
| 391 |
audio = model(x)
|
| 392 |
|
| 393 |
audio_np = audio.squeeze().cpu().numpy()
|
| 394 |
sample_rate = config.sample_rate
|
| 395 |
|
| 396 |
+
# Apply style modifications
|
| 397 |
audio_np = self.style_processor.apply_style(
|
| 398 |
audio_np, sample_rate, speed=speed, pitch=pitch, energy=energy
|
| 399 |
)
|
| 400 |
|
|
|
|
| 401 |
duration = len(audio_np) / sample_rate
|
| 402 |
|
| 403 |
return TTSOutput(
|
|
|
|
| 420 |
style: Optional[str] = None,
|
| 421 |
normalize_text: bool = True,
|
| 422 |
) -> str:
|
| 423 |
+
"""Synthesize speech and save to file"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 424 |
import soundfile as sf
|
| 425 |
|
| 426 |
+
output = self.synthesize(text, voice, speed, pitch, energy, style, normalize_text)
|
|
|
|
|
|
|
| 427 |
sf.write(output_path, output.audio, output.sample_rate)
|
| 428 |
|
| 429 |
logger.info(f"Saved audio to {output_path} (duration: {output.duration:.2f}s)")
|
|
|
|
| 444 |
is_mms = "mms" in key
|
| 445 |
model_dir = self.models_dir / key
|
| 446 |
|
|
|
|
| 447 |
if is_mms:
|
| 448 |
model_type = "mms"
|
| 449 |
elif model_dir.exists() and list(model_dir.glob("*.pth")):
|
|
|
|
| 454 |
voices[key] = {
|
| 455 |
"name": config.name,
|
| 456 |
"code": config.code,
|
| 457 |
+
"gender": "male" if "male" in key else ("female" if "female" in key else "neutral"),
|
| 458 |
+
"loaded": key in self._models or key in self._coqui_models or key in self._mms_models,
|
| 459 |
+
"downloaded": is_mms or get_model_path(key) is not None,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 460 |
"type": model_type,
|
| 461 |
}
|
| 462 |
return voices
|
|
|
|
| 465 |
"""Get available style presets"""
|
| 466 |
return STYLE_PRESETS
|
| 467 |
|
| 468 |
+
def batch_synthesize(self, texts: List[str], voice: str = "hi_male", speed: float = 1.0) -> List[TTSOutput]:
|
|
|
|
|
|
|
| 469 |
"""Synthesize multiple texts"""
|
| 470 |
return [self.synthesize(text, voice, speed) for text in texts]
|
| 471 |
|
| 472 |
|
| 473 |
+
def synthesize(text: str, voice: str = "hi_male", output_path: Optional[str] = None) -> Union[TTSOutput, str]:
|
| 474 |
+
"""Quick synthesis function"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 475 |
engine = TTSEngine()
|
| 476 |
|
| 477 |
if output_path:
|
src/model_loader.py
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Model Loader for VITS TTS Models
|
| 3 |
+
Loads trained models from the models directory.
|
| 4 |
+
Models are expected to be in the models/ directory after training.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import os
|
| 8 |
+
import logging
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
from typing import Optional, List
|
| 11 |
+
|
| 12 |
+
logger = logging.getLogger(__name__)
|
| 13 |
+
|
| 14 |
+
# Model directory
|
| 15 |
+
MODELS_DIR = Path(os.environ.get("MODELS_DIR", "models"))
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def _ensure_models_available():
|
| 19 |
+
"""
|
| 20 |
+
Internal function to ensure model files are available.
|
| 21 |
+
Called during engine initialization.
|
| 22 |
+
"""
|
| 23 |
+
if MODELS_DIR.exists() and any(MODELS_DIR.iterdir()):
|
| 24 |
+
return True
|
| 25 |
+
|
| 26 |
+
# Models need to be loaded - this happens during Docker build
|
| 27 |
+
logger.info("Initializing model directory...")
|
| 28 |
+
MODELS_DIR.mkdir(exist_ok=True)
|
| 29 |
+
|
| 30 |
+
try:
|
| 31 |
+
from huggingface_hub import snapshot_download
|
| 32 |
+
snapshot_download(
|
| 33 |
+
repo_id="Harshil748/VoiceAPI-Models",
|
| 34 |
+
local_dir=MODELS_DIR,
|
| 35 |
+
local_dir_use_symlinks=False,
|
| 36 |
+
ignore_patterns=["*.md", ".gitattributes"],
|
| 37 |
+
)
|
| 38 |
+
logger.info("Models initialized successfully")
|
| 39 |
+
return True
|
| 40 |
+
except Exception as e:
|
| 41 |
+
logger.warning(f"Could not initialize models: {e}")
|
| 42 |
+
return False
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def get_model_path(voice_key: str) -> Optional[Path]:
|
| 46 |
+
"""Get path to a model directory"""
|
| 47 |
+
model_dir = MODELS_DIR / voice_key
|
| 48 |
+
if model_dir.exists():
|
| 49 |
+
return model_dir
|
| 50 |
+
return None
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def list_available_models() -> List[str]:
|
| 54 |
+
"""List all available trained models"""
|
| 55 |
+
if not MODELS_DIR.exists():
|
| 56 |
+
return []
|
| 57 |
+
return [d.name for d in MODELS_DIR.iterdir() if d.is_dir()]
|