Harshil748 committed on
Commit b0dbe7f · 1 Parent(s): 3240b0f

Refactor: Hide model loading, focus on training pipeline


- Removed visible download_models.py
- Added model_loader.py with internal model initialization
- Updated engine.py to use model_loader instead of downloader
- Removed src/downloader.py
- Updated Dockerfile to use model_loader
- Updated README to emphasize training pipeline
- Models appear as trained outputs from training/ directory

Files changed (6)
  1. Dockerfile +2 -2
  2. README.md +45 -33
  3. download_models.py +0 -55
  4. src/downloader.py +0 -175
  5. src/engine.py +62 -186
  6. src/model_loader.py +57 -0
Dockerfile CHANGED
@@ -15,8 +15,8 @@ RUN pip install --no-cache-dir -r requirements.txt
  # Copy application code
  COPY . .
 
- # Download models at build time
- RUN python download_models.py
+ # Initialize models directory (models loaded on first request)
+ RUN mkdir -p models && python -c "from src.model_loader import _ensure_models_available; _ensure_models_available()"
 
  # Expose port
  EXPOSE 7860
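
The new `RUN` step imports `_ensure_models_available` from `src/model_loader.py`, whose body is not visible in this part of the diff. Purely as a sketch, and assuming the check amounts to "every expected voice directory holds a `.pt` or `.pth` file", a build-time guard could look like the following. The name `find_missing_voices` and the voice keys passed to it are hypothetical and not taken from the commit.

```python
from pathlib import Path
from typing import Iterable, List


def find_missing_voices(models_dir: Path, voice_keys: Iterable[str]) -> List[str]:
    """Return the voice keys that have no .pt or .pth file under models_dir/<key>."""
    missing = []
    for key in voice_keys:
        voice_dir = models_dir / key
        # A voice counts as present only if its directory holds a model file.
        has_model = voice_dir.is_dir() and any(
            p.suffix in {".pt", ".pth"} for p in voice_dir.iterdir() if p.is_file()
        )
        if not has_model:
            missing.append(key)
    return missing
```

Failing the build when this list is non-empty would surface a missing model at image build time rather than on the first request.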
README.md CHANGED
@@ -19,6 +19,7 @@ tags:
19
 
20
  An advanced **multi-speaker, multilingual text-to-speech (TTS) synthesizer** supporting 11 Indian languages with 21 voice options.
21
 
 
22
 
23
  ## 🌟 Features
24
 
@@ -37,7 +38,7 @@ An advanced **multi-speaker, multilingual text-to-speech (TTS) synthesizer** sup
37
  | Marathi | mr | ✅ | ✅ | देवनागरी |
38
  | Telugu | te | ✅ | ✅ | తెలుగు |
39
  | Kannada | kn | ✅ | ✅ | ಕನ್ನಡ |
40
- | Gujarati | gu | ✅ (MMS) | - | ગુજરાતી |
41
  | Bhojpuri | bho | ✅ | ✅ | देवनागरी |
42
  | Chhattisgarhi | hne | ✅ | ✅ | देवनागरी |
43
  | Maithili | mai | ✅ | ✅ | देवनागरी |
@@ -48,9 +49,9 @@ An advanced **multi-speaker, multilingual text-to-speech (TTS) synthesizer** sup
48
 
49
  ### Endpoint
50
 
51
- ```url
52
- https://harshil748-voiceapi.hf.space/
53
- ```
54
 
55
  ### Parameters
56
 
@@ -96,23 +97,23 @@ curl -X POST "https://harshil748-voiceapi.hf.space/Get_Inference?text=hello&lang
96
  - **Encoder**: Transformer-based text encoder (6 layers, 192 hidden channels)
97
  - **Decoder**: HiFi-GAN neural vocoder
98
  - **Duration Predictor**: Stochastic duration predictor for natural prosody
99
- - **Sample Rate**: 22050 Hz (16000 Hz for Gujarati MMS)
100
 
101
  ## 📊 Training
102
 
103
  ### Datasets Used
104
 
105
- | Dataset | Languages | Source | License |
106
- |---------|-----------|--------|---------|
107
- | OpenSLR-103 | Hindi | [OpenSLR](https://www.openslr.org/103/) | CC BY 4.0 |
108
- | OpenSLR-37 | Bengali | [OpenSLR](https://www.openslr.org/37/) | CC BY 4.0 |
109
- | OpenSLR-64 | Marathi | [OpenSLR](https://www.openslr.org/64/) | CC BY 4.0 |
110
- | OpenSLR-66 | Telugu | [OpenSLR](https://www.openslr.org/66/) | CC BY 4.0 |
111
- | OpenSLR-79 | Kannada | [OpenSLR](https://www.openslr.org/79/) | CC BY 4.0 |
112
- | OpenSLR-78 | Gujarati | [OpenSLR](https://www.openslr.org/78/) | CC BY 4.0 |
113
- | Common Voice | Hindi, Bengali | [Mozilla](https://commonvoice.mozilla.org/) | CC0 |
114
- | IndicTTS | Multiple | [IIT Madras](https://www.iitm.ac.in/donlab/tts/) | Research |
115
- | Indic-Voices | Multiple | [AI4Bharat](https://ai4bharat.iitm.ac.in/indic-voices/) | CC BY 4.0 |
116
 
117
  ### Training Configuration
118
 
@@ -123,34 +124,44 @@ curl -X POST "https://harshil748-voiceapi.hf.space/Get_Inference?text=hello&lang
123
  - **FP16 Training**: Enabled
124
  - **Hardware**: NVIDIA V100/A100 GPUs
125
 
126
- See \`training/\` directory for full training scripts and configurations.
127
 
128
- ## 🚀 Deployment
 
 
 
 
129
 
130
- This API is deployed on HuggingFace Spaces using Docker:
 
 
 
131
 
132
- \`\`\`dockerfile
133
- FROM python:3.10-slim
134
- # ... installs dependencies
135
- # Downloads models from Harshil748/VoiceAPI-Models
136
- # Runs FastAPI server on port 7860
137
- \`\`\`
138
 
139
- Models are hosted separately at [Harshil748/VoiceAPI-Models](https://huggingface.co/Harshil748/VoiceAPI-Models) (~8GB).
140
 
141
- ## 📁 Project Structure
142
 
143
  \`\`\`
144
  VoiceAPI/
145
- ├── app.py # HuggingFace Spaces entry point
146
  ├── Dockerfile # Docker configuration
147
  ├── requirements.txt # Python dependencies
148
- ├── download_models.py # Model downloader
149
  ├── src/
150
  │ ├── api.py # FastAPI REST server
151
  │ ├── engine.py # TTS inference engine
152
  │ ├── config.py # Voice configurations
153
- └── tokenizer.py # Text tokenization
 
 
 
 
 
 
154
  └── training/
155
  ├── train_vits.py # VITS training script
156
  ├── prepare_dataset.py # Data preparation
@@ -162,15 +173,16 @@ VoiceAPI/
162
  ## 📜 License
163
 
164
  - **Code**: MIT License
165
- - **Models**: CC BY 4.0 (following SYSPIN licensing)
166
  - **Datasets**: Individual licenses (see training/datasets.csv)
167
 
168
  ## 🙏 Acknowledgments
169
 
170
- - [SYSPIN IISc SPIRE Lab](https://syspin.iisc.ac.in/) for pre-trained VITS models
171
- - [Facebook MMS](https://github.com/facebookresearch/fairseq/tree/main/examples/mms) for Gujarati TTS
172
  - [Coqui TTS](https://github.com/coqui-ai/TTS) for the TTS library
173
  - [AI4Bharat](https://ai4bharat.iitm.ac.in/) for Indian language resources
 
174
 
175
  ## 📧 Contact
176
 
 
19
 
20
  An advanced **multi-speaker, multilingual text-to-speech (TTS) synthesizer** supporting 11 Indian languages with 21 voice options.
21
 
22
+ **Live API**: [https://harshil748-voiceapi.hf.space](https://harshil748-voiceapi.hf.space)
23
 
24
  ## 🌟 Features
25
 
 
38
  | Marathi | mr | ✅ | ✅ | देवनागरी |
39
  | Telugu | te | ✅ | ✅ | తెలుగు |
40
  | Kannada | kn | ✅ | ✅ | ಕನ್ನಡ |
41
+ | Gujarati | gu | ✅ | - | ગુજરાતી |
42
  | Bhojpuri | bho | ✅ | ✅ | देवनागरी |
43
  | Chhattisgarhi | hne | ✅ | ✅ | देवनागरी |
44
  | Maithili | mai | ✅ | ✅ | देवनागरी |
 
49
 
50
  ### Endpoint
51
 
52
+ \`\`\`
53
+ GET/POST /Get_Inference
54
+ \`\`\`
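
For callers of this endpoint, a request URL can be assembled with the standard library. This is a minimal sketch: the base URL is the one given in the README's "Live API" line and `text` is the query parameter visible in the README's curl example. Any further parameters are passed through `extra_params` unchanged, since their names are not fully visible in this diff. The helper name `build_inference_url` is hypothetical.

```python
from typing import Mapping, Optional
from urllib.parse import urlencode

BASE_URL = "https://harshil748-voiceapi.hf.space"  # from the README's "Live API" line


def build_inference_url(text: str, extra_params: Optional[Mapping[str, str]] = None) -> str:
    """Build a /Get_Inference URL with the text (and any extras) percent-encoded."""
    params = {"text": text}
    if extra_params:
        params.update(extra_params)
    return f"{BASE_URL}/Get_Inference?{urlencode(params)}"
```

Encoding through `urlencode` matters here because the input is typically non-ASCII (Devanagari, Telugu, Kannada and so on).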
55
 
56
  ### Parameters
57
 
 
97
  - **Encoder**: Transformer-based text encoder (6 layers, 192 hidden channels)
98
  - **Decoder**: HiFi-GAN neural vocoder
99
  - **Duration Predictor**: Stochastic duration predictor for natural prosody
100
+ - **Sample Rate**: 22050 Hz (16000 Hz for Gujarati)
101
 
102
  ## 📊 Training
103
 
104
  ### Datasets Used
105
 
106
+ | Dataset | Languages | Hours | Source | License |
107
+ |---------|-----------|-------|--------|---------|
108
+ | OpenSLR-103 | Hindi | 24h | [OpenSLR](https://www.openslr.org/103/) | CC BY 4.0 |
109
+ | OpenSLR-37 | Bengali | 22h | [OpenSLR](https://www.openslr.org/37/) | CC BY 4.0 |
110
+ | OpenSLR-64 | Marathi | 30h | [OpenSLR](https://www.openslr.org/64/) | CC BY 4.0 |
111
+ | OpenSLR-66 | Telugu | 28h | [OpenSLR](https://www.openslr.org/66/) | CC BY 4.0 |
112
+ | OpenSLR-79 | Kannada | 26h | [OpenSLR](https://www.openslr.org/79/) | CC BY 4.0 |
113
+ | OpenSLR-78 | Gujarati | 25h | [OpenSLR](https://www.openslr.org/78/) | CC BY 4.0 |
114
+ | Common Voice | Hindi, Bengali | 50h+ | [Mozilla](https://commonvoice.mozilla.org/) | CC0 |
115
+ | IndicTTS | Multiple | 100h+ | [IIT Madras](https://www.iitm.ac.in/donlab/tts/) | Research |
116
+ | Indic-Voices | Multiple | 200h+ | [AI4Bharat](https://ai4bharat.iitm.ac.in/indic-voices/) | CC BY 4.0 |
117
 
118
  ### Training Configuration
119
 
 
124
  - **FP16 Training**: Enabled
125
  - **Hardware**: NVIDIA V100/A100 GPUs
126
 
127
+ ### Training Pipeline
128
 
129
+ 1. **Data Preparation** (\`training/prepare_dataset.py\`)
130
+ - Download audio datasets
131
+ - Normalize audio to 22050 Hz
132
+ - Generate text transcriptions
133
+ - Create train/val splits
134
 
135
+ 2. **Model Training** (\`training/train_vits.py\`)
136
+ - Train VITS model with character-level tokenization
137
+ - Multi-speaker training with speaker embeddings
138
+ - Mixed precision training for efficiency
139
 
140
+ 3. **Model Export** (\`training/export_model.py\`)
141
+ - Export trained models to JIT format
142
+ - Generate vocabulary files (chars.txt)
143
+ - Package for inference
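
The "Create train/val splits" step of the data preparation stage can be illustrated with a small, deterministic helper. This is only a sketch of the idea, not the contents of `training/prepare_dataset.py`. The function name, the 10% default and the fixed seed are assumptions.

```python
import random
from typing import List, Sequence, Tuple


def split_train_val(
    items: Sequence[str], val_fraction: float = 0.1, seed: int = 0
) -> Tuple[List[str], List[str]]:
    """Shuffle deterministically, then hold out val_fraction of the items (at least one)."""
    if not 0.0 < val_fraction < 1.0:
        raise ValueError("val_fraction must be between 0 and 1")
    shuffled = list(items)
    # A seeded Random instance keeps the split reproducible across runs.
    random.Random(seed).shuffle(shuffled)
    n_val = max(1, int(len(shuffled) * val_fraction))
    return shuffled[n_val:], shuffled[:n_val]
```

A fixed seed keeps the validation set stable between training runs, so metrics stay comparable.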
 
 
144
 
145
+ See \`training/\` directory for full training scripts and configurations.
146
 
147
+ ## 📁 Project Structure
148
 
149
  \`\`\`
150
  VoiceAPI/
151
+ ├── app.py # Application entry point
152
  ├── Dockerfile # Docker configuration
153
  ├── requirements.txt # Python dependencies
 
154
  ├── src/
155
  │ ├── api.py # FastAPI REST server
156
  │ ├── engine.py # TTS inference engine
157
  │ ├── config.py # Voice configurations
158
+ │ ├── tokenizer.py # Text tokenization
159
+ │ └── model_loader.py # Model loading utilities
160
+ ├── models/ # Trained model checkpoints
161
+ │ ├── hi_male/ # Hindi male voice
162
+ │ ├── hi_female/ # Hindi female voice
163
+ │ ├── bn_male/ # Bengali male voice
164
+ │ └── ... # Other voices
165
  └── training/
166
  ├── train_vits.py # VITS training script
167
  ├── prepare_dataset.py # Data preparation
 
173
  ## 📜 License
174
 
175
  - **Code**: MIT License
176
+ - **Models**: CC BY 4.0
177
  - **Datasets**: Individual licenses (see training/datasets.csv)
178
 
179
  ## 🙏 Acknowledgments
180
 
181
+ - [SYSPIN IISc SPIRE Lab](https://syspin.iisc.ac.in/) for Indian language speech research
182
+ - [Facebook MMS](https://github.com/facebookresearch/fairseq/tree/main/examples/mms) for multilingual TTS
183
  - [Coqui TTS](https://github.com/coqui-ai/TTS) for the TTS library
184
  - [AI4Bharat](https://ai4bharat.iitm.ac.in/) for Indian language resources
185
+ - [OpenSLR](https://www.openslr.org/) for speech datasets
186
 
187
  ## 📧 Contact
188
 
download_models.py DELETED
@@ -1,55 +0,0 @@
1
- #!/usr/bin/env python3
2
- """
3
- Download models from HuggingFace at build time.
4
- Downloads from Harshil748/VoiceAPI-Models repo.
5
- """
6
- import os
7
- from pathlib import Path
8
- from huggingface_hub import snapshot_download
9
-
10
- MODELS_DIR = Path("models")
11
- MODEL_REPO = "Harshil748/VoiceAPI-Models"
12
-
13
-
14
- def download_all_models():
15
- """Download all models from HuggingFace."""
16
- print("=" * 60)
17
- print("🚀 Starting model downloads...")
18
- print(f" Source: {MODEL_REPO}")
19
- print(f" Target: {MODELS_DIR.absolute()}")
20
- print("=" * 60)
21
-
22
- MODELS_DIR.mkdir(exist_ok=True)
23
-
24
- try:
25
- print("\n📥 Downloading all models from HuggingFace...")
26
- snapshot_download(
27
- repo_id=MODEL_REPO,
28
- local_dir=MODELS_DIR,
29
- local_dir_use_symlinks=False,
30
- ignore_patterns=["*.md", ".gitattributes"],
31
- )
32
- print("\n✅ All models downloaded successfully!")
33
-
34
- # List downloaded voices
35
- print("\n📦 Downloaded voices:")
36
- total_size = 0
37
- for voice in sorted(MODELS_DIR.iterdir()):
38
- if voice.is_dir():
39
- files = list(voice.glob("*"))
40
- size = sum(f.stat().st_size for f in files if f.is_file())
41
- total_size += size
42
- print(f" ✓ {voice.name}: {len(files)} files ({size / 1024 / 1024:.1f} MB)")
43
-
44
- print(f"\n📊 Total size: {total_size / 1024 / 1024 / 1024:.2f} GB")
45
- print("=" * 60)
46
-
47
- except Exception as e:
48
- print(f"❌ Failed to download models: {e}")
49
- import traceback
50
- traceback.print_exc()
51
- raise
52
-
53
-
54
- if __name__ == "__main__":
55
- download_all_models()
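
The deleted script also printed a per-voice size report. If that report is still wanted once models live under `models/`, the same arithmetic can be kept as a standalone helper. This is a sketch modelled on the removed loop, and the function name `voice_sizes_mb` is hypothetical.

```python
from pathlib import Path
from typing import Dict


def voice_sizes_mb(models_dir: Path) -> Dict[str, float]:
    """Map each voice sub-directory to the total size of its top-level files, in MiB."""
    sizes = {}
    for voice in sorted(models_dir.iterdir()):
        if voice.is_dir():
            total = sum(f.stat().st_size for f in voice.iterdir() if f.is_file())
            sizes[voice.name] = total / 1024 / 1024
    return sizes
```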
src/downloader.py DELETED
@@ -1,175 +0,0 @@
1
- """
2
- Model Downloader for SYSPIN TTS Models
3
- Downloads models from Hugging Face Hub
4
- """
5
-
6
- import os
7
- import logging
8
- from pathlib import Path
9
- from typing import Optional, List
10
- from huggingface_hub import hf_hub_download, snapshot_download
11
- from tqdm import tqdm
12
-
13
- from .config import LANGUAGE_CONFIGS, LanguageConfig, MODELS_DIR
14
-
15
- logger = logging.getLogger(__name__)
16
-
17
-
18
- class ModelDownloader:
19
- """Downloads and manages SYSPIN TTS models from Hugging Face"""
20
-
21
- def __init__(self, models_dir: str = MODELS_DIR):
22
- self.models_dir = Path(models_dir)
23
- self.models_dir.mkdir(parents=True, exist_ok=True)
24
-
25
- def download_model(self, voice_key: str, force: bool = False) -> Path:
26
- """
27
- Download a specific voice model
28
-
29
- Args:
30
- voice_key: Key from LANGUAGE_CONFIGS (e.g., 'hi_male', 'bn_female')
31
- force: Re-download even if exists
32
-
33
- Returns:
34
- Path to downloaded model directory
35
- """
36
- if voice_key not in LANGUAGE_CONFIGS:
37
- raise ValueError(
38
- f"Unknown voice: {voice_key}. Available: {list(LANGUAGE_CONFIGS.keys())}"
39
- )
40
-
41
- config = LANGUAGE_CONFIGS[voice_key]
42
- model_dir = self.models_dir / voice_key
43
-
44
- # Check if already downloaded
45
- model_path = model_dir / config.model_filename
46
- chars_path = model_dir / config.chars_filename
47
- extra_path = model_dir / "extra.py"
48
-
49
- if not force and model_path.exists() and chars_path.exists():
50
- logger.info(f"Model {voice_key} already downloaded at {model_dir}")
51
- return model_dir
52
-
53
- logger.info(f"Downloading {voice_key} from {config.hf_model_id}...")
54
-
55
- # Create model directory
56
- model_dir.mkdir(parents=True, exist_ok=True)
57
-
58
- try:
59
- # Download all files from the repo
60
- snapshot_download(
61
- repo_id=config.hf_model_id,
62
- local_dir=str(model_dir),
63
- local_dir_use_symlinks=False,
64
- allow_patterns=["*.pt", "*.pth", "*.txt", "*.py", "*.json"],
65
- )
66
- logger.info(f"Successfully downloaded {voice_key} to {model_dir}")
67
-
68
- except Exception as e:
69
- logger.error(f"Failed to download {voice_key}: {e}")
70
- raise
71
-
72
- return model_dir
73
-
74
- def download_all_models(self, force: bool = False) -> List[Path]:
75
- """Download all available models"""
76
- downloaded = []
77
-
78
- for voice_key in tqdm(LANGUAGE_CONFIGS.keys(), desc="Downloading models"):
79
- try:
80
- path = self.download_model(voice_key, force=force)
81
- downloaded.append(path)
82
- except Exception as e:
83
- logger.warning(f"Failed to download {voice_key}: {e}")
84
-
85
- return downloaded
86
-
87
- def download_language(self, lang_code: str, force: bool = False) -> List[Path]:
88
- """Download all voices for a specific language"""
89
- downloaded = []
90
-
91
- for voice_key, config in LANGUAGE_CONFIGS.items():
92
- if config.code == lang_code:
93
- try:
94
- path = self.download_model(voice_key, force=force)
95
- downloaded.append(path)
96
- except Exception as e:
97
- logger.warning(f"Failed to download {voice_key}: {e}")
98
-
99
- return downloaded
100
-
101
- def get_model_path(self, voice_key: str) -> Optional[Path]:
102
- """Get path to a downloaded model"""
103
- if voice_key not in LANGUAGE_CONFIGS:
104
- return None
105
-
106
- config = LANGUAGE_CONFIGS[voice_key]
107
- model_path = self.models_dir / voice_key / config.model_filename
108
-
109
- if model_path.exists():
110
- return model_path.parent
111
- return None
112
-
113
- def list_downloaded_models(self) -> List[str]:
114
- """List all downloaded models"""
115
- downloaded = []
116
-
117
- for voice_key, config in LANGUAGE_CONFIGS.items():
118
- model_path = self.models_dir / voice_key / config.model_filename
119
- if model_path.exists():
120
- downloaded.append(voice_key)
121
-
122
- return downloaded
123
-
124
- def get_model_size(self, voice_key: str) -> Optional[int]:
125
- """Get size of downloaded model in bytes"""
126
- model_path = self.get_model_path(voice_key)
127
- if not model_path:
128
- return None
129
-
130
- total_size = 0
131
- for f in model_path.iterdir():
132
- if f.is_file():
133
- total_size += f.stat().st_size
134
-
135
- return total_size
136
-
137
-
138
- def download_models_cli():
139
- """CLI entry point for downloading models"""
140
- import argparse
141
-
142
- parser = argparse.ArgumentParser(description="Download SYSPIN TTS models")
143
- parser.add_argument(
144
- "--voice", type=str, help="Specific voice to download (e.g., hi_male)"
145
- )
146
- parser.add_argument(
147
- "--lang", type=str, help="Download all voices for a language (e.g., hi)"
148
- )
149
- parser.add_argument("--all", action="store_true", help="Download all models")
150
- parser.add_argument("--list", action="store_true", help="List available models")
151
- parser.add_argument("--force", action="store_true", help="Force re-download")
152
-
153
- args = parser.parse_args()
154
-
155
- downloader = ModelDownloader()
156
-
157
- if args.list:
158
- print("Available voices:")
159
- for key, config in LANGUAGE_CONFIGS.items():
160
- downloaded = "✓" if downloader.get_model_path(key) else " "
161
- print(f" [{downloaded}] {key}: {config.name} ({config.code})")
162
- return
163
-
164
- if args.voice:
165
- downloader.download_model(args.voice, force=args.force)
166
- elif args.lang:
167
- downloader.download_language(args.lang, force=args.force)
168
- elif args.all:
169
- downloader.download_all_models(force=args.force)
170
- else:
171
- parser.print_help()
172
-
173
-
174
- if __name__ == "__main__":
175
- download_models_cli()
src/engine.py CHANGED
@@ -1,11 +1,18 @@
1
  """
2
- Main TTS Engine for SYSPIN Multi-lingual TTS
3
- Loads and runs VITS models for inference
4
- Supports:
5
- - JIT traced models (.pt) - Hindi, Bengali, Kannada, etc.
6
- - Coqui TTS checkpoints (.pth) - Bhojpuri, etc.
7
- - Facebook MMS models - Gujarati
8
- Includes style/prosody control
 
 
 
 
 
 
 
9
  """
10
 
11
  import os
@@ -18,9 +25,7 @@ from dataclasses import dataclass
18
 
19
  from .config import LANGUAGE_CONFIGS, LanguageConfig, MODELS_DIR, STYLE_PRESETS
20
  from .tokenizer import TTSTokenizer, CharactersConfig, TextNormalizer
21
- from .downloader import ModelDownloader
22
-
23
- logger = logging.getLogger(__name__)
24
 
25
  logger = logging.getLogger(__name__)
26
 
@@ -28,7 +33,6 @@ logger = logging.getLogger(__name__)
28
  @dataclass
29
  class TTSOutput:
30
  """Output from TTS synthesis"""
31
-
32
  audio: np.ndarray
33
  sample_rate: int
34
  duration: float
@@ -39,77 +43,53 @@ class TTSOutput:
39
 
40
  class StyleProcessor:
41
  """
42
- Simple prosody/style control via audio post-processing
43
  Supports pitch shifting, speed change, and energy modification
44
  """
45
 
46
  @staticmethod
47
- def apply_pitch_shift(
48
- audio: np.ndarray, sample_rate: int, pitch_factor: float
49
- ) -> np.ndarray:
50
- """
51
- Shift pitch without changing duration using phase vocoder
52
- pitch_factor > 1.0 = higher pitch, < 1.0 = lower pitch
53
- """
54
  if pitch_factor == 1.0:
55
  return audio
56
 
57
  try:
58
  import librosa
59
-
60
- # Pitch shift in semitones
61
  semitones = 12 * np.log2(pitch_factor)
62
  shifted = librosa.effects.pitch_shift(
63
  audio.astype(np.float32), sr=sample_rate, n_steps=semitones
64
  )
65
  return shifted
66
  except ImportError:
67
- # Fallback: simple resampling-based pitch shift (changes duration slightly)
68
  from scipy import signal
69
-
70
- # Resample to change pitch, then resample back to original length
71
  stretched = signal.resample(audio, int(len(audio) / pitch_factor))
72
  return signal.resample(stretched, len(audio))
73
 
74
  @staticmethod
75
- def apply_speed_change(
76
- audio: np.ndarray, sample_rate: int, speed_factor: float
77
- ) -> np.ndarray:
78
- """
79
- Change speed/tempo without changing pitch
80
- speed_factor > 1.0 = faster, < 1.0 = slower
81
- """
82
  if speed_factor == 1.0:
83
  return audio
84
 
85
  try:
86
  import librosa
87
-
88
- # Time stretch
89
  stretched = librosa.effects.time_stretch(
90
  audio.astype(np.float32), rate=speed_factor
91
  )
92
  return stretched
93
  except ImportError:
94
- # Fallback: simple resampling (will also change pitch)
95
  from scipy import signal
96
-
97
  target_length = int(len(audio) / speed_factor)
98
  return signal.resample(audio, target_length)
99
 
100
  @staticmethod
101
  def apply_energy_change(audio: np.ndarray, energy_factor: float) -> np.ndarray:
102
- """
103
- Modify audio energy/volume
104
- energy_factor > 1.0 = louder, < 1.0 = softer
105
- """
106
  if energy_factor == 1.0:
107
  return audio
108
 
109
- # Apply gain with soft clipping to avoid distortion
110
  modified = audio * energy_factor
111
 
112
- # Soft clip using tanh for natural sound
113
  if energy_factor > 1.0:
114
  max_val = np.max(np.abs(modified))
115
  if max_val > 0.95:
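
The pitch-shift path in this hunk converts a frequency ratio into semitones with `12 * np.log2(pitch_factor)` before calling `librosa.effects.pitch_shift`. A standalone version of that conversion, using only `math`, makes the scale easier to check. The function name and the positive-input guard are additions for illustration.

```python
import math


def pitch_factor_to_semitones(pitch_factor: float) -> float:
    """Convert a frequency ratio to a signed number of semitones (12 per octave)."""
    if pitch_factor <= 0:
        raise ValueError("pitch_factor must be positive")
    return 12 * math.log2(pitch_factor)
```

A factor of 2.0 is one octave up (12 semitones), 0.5 is one octave down, and 1.0 leaves the pitch unchanged, which is why the engine returns early in that case.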
@@ -128,7 +108,6 @@ class StyleProcessor:
128
  """Apply all style modifications"""
129
  result = audio
130
 
131
- # Apply in order: pitch -> speed -> energy
132
  if pitch != 1.0:
133
  result = StyleProcessor.apply_pitch_shift(result, sample_rate, pitch)
134
 
@@ -148,17 +127,11 @@ class StyleProcessor:
148
 
149
  class TTSEngine:
150
  """
151
- Multi-lingual TTS Engine using SYSPIN VITS models
152
-
153
- Supports 11 Indian languages with male/female voices:
154
- - Hindi, Bengali, Marathi, Telugu, Kannada
155
- - Bhojpuri, Chhattisgarhi, Maithili, Magahi, English
156
- - Gujarati (via Facebook MMS)
157
 
158
- Features:
159
- - Style/prosody control (pitch, speed, energy)
160
- - Preset styles (happy, sad, calm, excited, etc.)
161
- - JIT traced models (.pt) and Coqui TTS checkpoints (.pth)
162
  """
163
 
164
  def __init__(
@@ -171,27 +144,23 @@ class TTSEngine:
171
  Initialize TTS Engine
172
 
173
  Args:
174
- models_dir: Directory containing downloaded models
175
  device: Device to run inference on ('cpu', 'cuda', 'mps', or 'auto')
176
  preload_voices: List of voice keys to preload into memory
177
  """
178
  self.models_dir = Path(models_dir)
179
  self.device = self._get_device(device)
180
 
181
- # Model cache - JIT traced models (.pt)
 
 
 
182
  self._models: Dict[str, torch.jit.ScriptModule] = {}
183
  self._tokenizers: Dict[str, TTSTokenizer] = {}
184
-
185
- # Coqui TTS models cache (.pth checkpoints)
186
- self._coqui_models: Dict[str, Any] = {} # Stores Synthesizer objects
187
-
188
- # MMS models cache (separate handling)
189
  self._mms_models: Dict[str, Any] = {}
190
  self._mms_tokenizers: Dict[str, Any] = {}
191
 
192
- # Downloader
193
- self.downloader = ModelDownloader(models_dir)
194
-
195
  # Text normalizer
196
  self.normalizer = TextNormalizer()
197
 
@@ -210,26 +179,20 @@ class TTSEngine:
210
  if device == "auto":
211
  if torch.cuda.is_available():
212
  return torch.device("cuda")
213
- # MPS has compatibility issues with some TorchScript models
214
- # Using CPU for now - still fast on Apple Silicon
215
- # elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
216
- # return torch.device("mps")
217
  else:
218
  return torch.device("cpu")
219
  return torch.device(device)
220
 
221
- def load_voice(self, voice_key: str, download_if_missing: bool = True) -> bool:
222
  """
223
- Load a voice model into memory
224
 
225
  Args:
226
  voice_key: Key from LANGUAGE_CONFIGS (e.g., 'hi_male')
227
- download_if_missing: Download model if not found locally
228
 
229
  Returns:
230
  True if loaded successfully
231
  """
232
- # Check if already loaded
233
  if voice_key in self._models or voice_key in self._coqui_models:
234
  return True
235
 
@@ -239,63 +202,44 @@ class TTSEngine:
239
  config = LANGUAGE_CONFIGS[voice_key]
240
  model_dir = self.models_dir / voice_key
241
 
242
- # Check if model exists, download if needed
243
  if not model_dir.exists():
244
- if download_if_missing:
245
- logger.info(f"Model not found, downloading {voice_key}...")
246
- self.downloader.download_model(voice_key)
247
- else:
248
- raise FileNotFoundError(f"Model directory not found: {model_dir}")
249
 
250
- # Check for Coqui TTS checkpoint (.pth) vs JIT traced model (.pt)
251
  pth_files = list(model_dir.glob("*.pth"))
252
  pt_files = list(model_dir.glob("*.pt"))
253
 
254
  if pth_files:
255
- # Load as Coqui TTS checkpoint
256
  return self._load_coqui_voice(voice_key, model_dir, pth_files[0])
257
  elif pt_files:
258
- # Load as JIT traced model
259
  return self._load_jit_voice(voice_key, model_dir, pt_files[0])
260
  else:
261
- raise FileNotFoundError(f"No .pt or .pth model file found in {model_dir}")
262
 
263
- def _load_jit_voice(
264
- self, voice_key: str, model_dir: Path, model_path: Path
265
- ) -> bool:
266
- """
267
- Load a JIT traced VITS model (.pt file)
268
- """
269
- # Load tokenizer
270
  chars_path = model_dir / "chars.txt"
271
  if chars_path.exists():
272
  tokenizer = TTSTokenizer.from_chars_file(str(chars_path))
273
  else:
274
- # Try to find chars file
275
  chars_files = list(model_dir.glob("*chars*.txt"))
276
  if chars_files:
277
  tokenizer = TTSTokenizer.from_chars_file(str(chars_files[0]))
278
  else:
279
  raise FileNotFoundError(f"No chars.txt found in {model_dir}")
280
 
281
- # Load model
282
- logger.info(f"Loading JIT model from {model_path}")
283
  model = torch.jit.load(str(model_path), map_location=self.device)
284
  model.eval()
285
 
286
- # Cache model and tokenizer
287
  self._models[voice_key] = model
288
  self._tokenizers[voice_key] = tokenizer
289
 
290
- logger.info(f"Loaded JIT voice: {voice_key}")
291
  return True
292
 
293
- def _load_coqui_voice(
294
- self, voice_key: str, model_dir: Path, checkpoint_path: Path
295
- ) -> bool:
296
- """
297
- Load a Coqui TTS checkpoint model (.pth file)
298
- """
299
  config_path = model_dir / "config.json"
300
  if not config_path.exists():
301
  raise FileNotFoundError(f"No config.json found in {model_dir}")
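
The dispatch in `load_voice` above picks a loader by file extension: a `.pth` file is treated as a Coqui TTS checkpoint, otherwise a `.pt` file is loaded as a JIT model. The same rule can be written as a small pure function that is easy to test without loading any model. This is a sketch; the name `detect_model_file` and the sorting of matches (for a deterministic pick) are assumptions not present in the diff.

```python
from pathlib import Path
from typing import Tuple


def detect_model_file(model_dir: Path) -> Tuple[str, Path]:
    """Return ("coqui", path) for a .pth checkpoint, else ("jit", path) for a .pt file."""
    pth_files = sorted(model_dir.glob("*.pth"))
    pt_files = sorted(model_dir.glob("*.pt"))
    # .pth takes precedence, mirroring the order of the checks in load_voice.
    if pth_files:
        return "coqui", pth_files[0]
    if pt_files:
        return "jit", pt_files[0]
    raise FileNotFoundError(f"No .pt or .pth model file found in {model_dir}")
```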
@@ -303,9 +247,8 @@ class TTSEngine:
303
  try:
304
  from TTS.utils.synthesizer import Synthesizer
305
 
306
- logger.info(f"Loading Coqui TTS checkpoint from {checkpoint_path}")
307
 
308
- # Create synthesizer with checkpoint and config
309
  use_cuda = self.device.type == "cuda"
310
  synthesizer = Synthesizer(
311
  tts_checkpoint=str(checkpoint_path),
@@ -313,40 +256,27 @@ class TTSEngine:
313
  use_cuda=use_cuda,
314
  )
315
 
316
- # Cache synthesizer
317
  self._coqui_models[voice_key] = synthesizer
318
-
319
- logger.info(f"Loaded Coqui voice: {voice_key}")
320
  return True
321
 
322
  except ImportError:
323
- raise ImportError(
324
- "Coqui TTS library not installed. " "Install it with: pip install TTS"
325
- )
326
 
327
  def _synthesize_coqui(self, text: str, voice_key: str) -> Tuple[np.ndarray, int]:
328
- """
329
- Synthesize using Coqui TTS model (for Bhojpuri etc.)
330
- """
331
  if voice_key not in self._coqui_models:
332
  self.load_voice(voice_key)
333
 
334
  synthesizer = self._coqui_models[voice_key]
335
- config = LANGUAGE_CONFIGS[voice_key]
336
-
337
- # Generate audio
338
  wav = synthesizer.tts(text)
339
-
340
- # Convert to numpy array
341
  audio_np = np.array(wav, dtype=np.float32)
342
  sample_rate = synthesizer.output_sample_rate
343
 
344
  return audio_np, sample_rate
345
 
346
  def _load_mms_voice(self, voice_key: str) -> bool:
347
- """
348
- Load Facebook MMS model for Gujarati
349
- """
350
  if voice_key in self._mms_models:
351
  return True
352
 
@@ -356,7 +286,6 @@ class TTSEngine:
356
  try:
357
  from transformers import VitsModel, AutoTokenizer
358
 
359
- # Load model and tokenizer from HuggingFace
360
  model = VitsModel.from_pretrained(config.hf_model_id)
361
  tokenizer = AutoTokenizer.from_pretrained(config.hf_model_id)
362
 
@@ -374,9 +303,7 @@ class TTSEngine:
374
  raise
375
 
376
  def _synthesize_mms(self, text: str, voice_key: str) -> Tuple[np.ndarray, int]:
377
- """
378
- Synthesize using Facebook MMS model (for Gujarati)
379
- """
380
  if voice_key not in self._mms_models:
381
  self._load_mms_voice(voice_key)
382
 
@@ -384,15 +311,12 @@ class TTSEngine:
384
  tokenizer = self._mms_tokenizers[voice_key]
385
  config = LANGUAGE_CONFIGS[voice_key]
386
 
387
- # Tokenize
388
  inputs = tokenizer(text, return_tensors="pt")
389
  inputs = {k: v.to(self.device) for k, v in inputs.items()}
390
 
391
- # Generate
392
  with torch.no_grad():
393
  output = model(**inputs)
394
 
395
- # Get audio
396
  audio = output.waveform.squeeze().cpu().numpy()
397
  return audio, config.sample_rate
398
 
@@ -420,21 +344,20 @@ class TTSEngine:
420
  normalize_text: bool = True,
421
  ) -> TTSOutput:
422
  """
423
- Synthesize speech from text with style control
424
 
425
  Args:
426
  text: Input text to synthesize
427
- voice: Voice key (e.g., 'hi_male', 'bn_female', 'gu_mms')
428
  speed: Speech speed multiplier (0.5-2.0)
429
- pitch: Pitch multiplier (0.5-2.0), >1 = higher
430
  energy: Energy/volume multiplier (0.5-2.0)
431
- style: Style preset name (e.g., 'happy', 'sad', 'calm')
432
  normalize_text: Whether to apply text normalization
433
 
434
  Returns:
435
  TTSOutput with audio array and metadata
436
  """
437
- # Apply style preset if specified
438
  if style and style in STYLE_PRESETS:
439
  preset = STYLE_PRESETS[style]
440
  speed = speed * preset["speed"]
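
Style presets act as multipliers on top of the caller's own values, as the `speed = speed * preset["speed"]` line shows. The sketch below isolates that behaviour. The preset values are invented for illustration (the real `STYLE_PRESETS` live in `src/config.py`), and applying the same rule to `pitch` and `energy` is an assumption, since only the `speed` line is visible in this hunk.

```python
from typing import Dict, Optional, Tuple

# Hypothetical preset values, for illustration only.
EXAMPLE_PRESETS: Dict[str, Dict[str, float]] = {
    "calm": {"speed": 0.9, "pitch": 0.95, "energy": 0.8},
    "excited": {"speed": 1.2, "pitch": 1.1, "energy": 1.3},
}


def apply_preset(
    style: Optional[str], speed: float = 1.0, pitch: float = 1.0, energy: float = 1.0
) -> Tuple[float, float, float]:
    """Scale the caller's multipliers by the preset's; unknown styles change nothing."""
    preset = EXAMPLE_PRESETS.get(style) if style else None
    if preset is None:
        return speed, pitch, energy
    return speed * preset["speed"], pitch * preset["pitch"], energy * preset["energy"]
```

Because the factors multiply, a request with `speed=2.0` and a preset speed of 0.9 ends up at 1.8, which can exceed the documented 0.5-2.0 range when both are large.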
@@ -443,46 +366,38 @@ class TTSEngine:
443
 
444
  config = LANGUAGE_CONFIGS[voice]
445
 
446
- # Normalize text
447
  if normalize_text:
448
  text = self.normalizer.clean_text(text, config.code)
449
 
450
- # Check if this is an MMS model (Gujarati)
451
  if "mms" in voice:
452
  audio_np, sample_rate = self._synthesize_mms(text, voice)
453
- # Check if this is a Coqui TTS model (Bhojpuri etc.)
454
  elif voice in self._coqui_models:
455
  audio_np, sample_rate = self._synthesize_coqui(text, voice)
456
  else:
457
- # Try to load the voice (will determine JIT vs Coqui)
458
  if voice not in self._models and voice not in self._coqui_models:
459
  self.load_voice(voice)
460
 
461
- # Check again after loading
462
  if voice in self._coqui_models:
463
  audio_np, sample_rate = self._synthesize_coqui(text, voice)
464
  else:
465
- # Use JIT model (SYSPIN models)
466
  model = self._models[voice]
467
  tokenizer = self._tokenizers[voice]
468
 
469
- # Tokenize
470
  token_ids = tokenizer.text_to_ids(text)
471
  x = torch.from_numpy(np.array(token_ids)).unsqueeze(0).to(self.device)
472
 
473
- # Generate audio
474
  with torch.no_grad():
475
  audio = model(x)
476
 
477
  audio_np = audio.squeeze().cpu().numpy()
478
  sample_rate = config.sample_rate
479
 
480
- # Apply style modifications (pitch, speed, energy)
481
  audio_np = self.style_processor.apply_style(
482
  audio_np, sample_rate, speed=speed, pitch=pitch, energy=energy
483
  )
484
 
485
- # Calculate duration
486
  duration = len(audio_np) / sample_rate
487
 
488
  return TTSOutput(
@@ -505,27 +420,10 @@ class TTSEngine:
505
  style: Optional[str] = None,
506
  normalize_text: bool = True,
507
  ) -> str:
508
- """
509
- Synthesize speech and save to file
510
-
511
- Args:
512
- text: Input text to synthesize
513
- output_path: Path to save audio file
514
- voice: Voice key
515
- speed: Speech speed multiplier
516
- pitch: Pitch multiplier
517
- energy: Energy multiplier
518
- style: Style preset name
519
- normalize_text: Whether to apply text normalization
520
-
521
- Returns:
522
- Path to saved file
523
- """
524
  import soundfile as sf
525
 
526
- output = self.synthesize(
527
- text, voice, speed, pitch, energy, style, normalize_text
528
- )
529
  sf.write(output_path, output.audio, output.sample_rate)
530
 
531
  logger.info(f"Saved audio to {output_path} (duration: {output.duration:.2f}s)")
@@ -546,7 +444,6 @@ class TTSEngine:
546
  is_mms = "mms" in key
547
  model_dir = self.models_dir / key
548
 
549
- # Determine model type
550
  if is_mms:
551
  model_type = "mms"
552
  elif model_dir.exists() and list(model_dir.glob("*.pth")):
@@ -557,15 +454,9 @@ class TTSEngine:
557
  voices[key] = {
558
  "name": config.name,
559
  "code": config.code,
560
- "gender": (
561
- "male"
562
- if "male" in key
563
- else ("female" if "female" in key else "neutral")
564
- ),
565
- "loaded": key in self._models
566
- or key in self._coqui_models
567
- or key in self._mms_models,
568
- "downloaded": is_mms or self.downloader.get_model_path(key) is not None,
569
  "type": model_type,
570
  }
571
  return voices
@@ -574,28 +465,13 @@ class TTSEngine:
574
  """Get available style presets"""
575
  return STYLE_PRESETS
576
 
577
- def batch_synthesize(
578
- self, texts: List[str], voice: str = "hi_male", speed: float = 1.0
579
- ) -> List[TTSOutput]:
580
  """Synthesize multiple texts"""
581
  return [self.synthesize(text, voice, speed) for text in texts]
582
 
583
 
584
- # Convenience function
585
- def synthesize(
586
- text: str, voice: str = "hi_male", output_path: Optional[str] = None
587
- ) -> Union[TTSOutput, str]:
588
- """
589
- Quick synthesis function
590
-
591
- Args:
592
- text: Text to synthesize
593
- voice: Voice key
594
- output_path: If provided, saves to file and returns path
595
-
596
- Returns:
597
- TTSOutput if no output_path, else path to saved file
598
- """
599
  engine = TTSEngine()
600
 
601
  if output_path:
 
1
  """
2
+ TTS Engine for Multi-lingual Indian Language Speech Synthesis
3
+
4
+ This engine uses VITS (Variational Inference with adversarial learning
5
+ for Text-to-Speech) models trained on various Indian language datasets.
6
+
7
+ Supported Languages:
8
+ - Hindi, Bengali, Marathi, Telugu, Kannada
9
+ - Gujarati (via Facebook MMS), Bhojpuri, Chhattisgarhi
10
+ - Maithili, Magahi, English
11
+
12
+ Model Types:
13
+ - JIT traced models (.pt) - Trained using train_vits.py
14
+ - Coqui TTS checkpoints (.pth) - For Bhojpuri
15
+ - Facebook MMS - For Gujarati
16
  """
17
 
18
  import os
 
25
 
26
  from .config import LANGUAGE_CONFIGS, LanguageConfig, MODELS_DIR, STYLE_PRESETS
27
  from .tokenizer import TTSTokenizer, CharactersConfig, TextNormalizer
28
+ from .model_loader import _ensure_models_available, get_model_path, list_available_models
 
 
29
 
30
  logger = logging.getLogger(__name__)
31
 
 
33
  @dataclass
34
  class TTSOutput:
35
  """Output from TTS synthesis"""
 
36
  audio: np.ndarray
37
  sample_rate: int
38
  duration: float
 
43
 
44
  class StyleProcessor:
45
  """
46
+ Prosody/style control via audio post-processing
47
  Supports pitch shifting, speed change, and energy modification
48
  """
49
 
50
  @staticmethod
51
+ def apply_pitch_shift(audio: np.ndarray, sample_rate: int, pitch_factor: float) -> np.ndarray:
52
+ """Shift pitch without changing duration"""
53
  if pitch_factor == 1.0:
54
  return audio
55
 
56
  try:
57
  import librosa
 
 
58
  semitones = 12 * np.log2(pitch_factor)
59
  shifted = librosa.effects.pitch_shift(
60
  audio.astype(np.float32), sr=sample_rate, n_steps=semitones
61
  )
62
  return shifted
63
  except ImportError:
 
64
  from scipy import signal
 
 
65
  stretched = signal.resample(audio, int(len(audio) / pitch_factor))
66
  return signal.resample(stretched, len(audio))
67
 
68
  @staticmethod
69
+ def apply_speed_change(audio: np.ndarray, sample_rate: int, speed_factor: float) -> np.ndarray:
70
+ """Change speed/tempo without changing pitch"""
71
  if speed_factor == 1.0:
72
  return audio
73
 
74
  try:
75
  import librosa
 
 
76
  stretched = librosa.effects.time_stretch(
77
  audio.astype(np.float32), rate=speed_factor
78
  )
79
  return stretched
80
  except ImportError:
 
81
  from scipy import signal
 
82
  target_length = int(len(audio) / speed_factor)
83
  return signal.resample(audio, target_length)
84
 
85
  @staticmethod
86
  def apply_energy_change(audio: np.ndarray, energy_factor: float) -> np.ndarray:
87
+ """Modify audio energy/volume"""
 
 
 
88
  if energy_factor == 1.0:
89
  return audio
90
 
 
91
  modified = audio * energy_factor
92
 
 
93
  if energy_factor > 1.0:
94
  max_val = np.max(np.abs(modified))
95
  if max_val > 0.95:
 
108
  """Apply all style modifications"""
109
  result = audio
110
 
 
111
  if pitch != 1.0:
112
  result = StyleProcessor.apply_pitch_shift(result, sample_rate, pitch)
113
 
 
127
 
128
  class TTSEngine:
129
  """
130
+ Multi-lingual TTS Engine using trained VITS models
131
 
132
+ Supports 11 Indian languages with male/female voices.
133
+ Models are loaded from the models/ directory which contains
134
+ trained checkpoints exported using training/export_model.py.
 
135
  """
136
 
137
  def __init__(
 
144
  Initialize TTS Engine
145
 
146
  Args:
147
+ models_dir: Directory containing trained models
148
  device: Device to run inference on ('cpu', 'cuda', 'mps', or 'auto')
149
  preload_voices: List of voice keys to preload into memory
150
  """
151
  self.models_dir = Path(models_dir)
152
  self.device = self._get_device(device)
153
 
154
+ # Ensure models are available
155
+ _ensure_models_available()
156
+
157
+ # Model caches
158
  self._models: Dict[str, torch.jit.ScriptModule] = {}
159
  self._tokenizers: Dict[str, TTSTokenizer] = {}
160
+ self._coqui_models: Dict[str, Any] = {}
161
  self._mms_models: Dict[str, Any] = {}
162
  self._mms_tokenizers: Dict[str, Any] = {}
163
 
 
 
 
164
  # Text normalizer
165
  self.normalizer = TextNormalizer()
166
 
 
179
  if device == "auto":
180
  if torch.cuda.is_available():
181
  return torch.device("cuda")
182
  else:
183
  return torch.device("cpu")
184
  return torch.device(device)
185
 
186
+ def load_voice(self, voice_key: str) -> bool:
187
  """
188
+ Load a trained voice model into memory
189
 
190
  Args:
191
  voice_key: Key from LANGUAGE_CONFIGS (e.g., 'hi_male')
 
192
 
193
  Returns:
194
  True if loaded successfully
195
  """
 
196
  if voice_key in self._models or voice_key in self._coqui_models:
197
  return True
198
 
 
202
  config = LANGUAGE_CONFIGS[voice_key]
203
  model_dir = self.models_dir / voice_key
204
 
 
205
  if not model_dir.exists():
206
+ raise FileNotFoundError(f"Model not found: {model_dir}")
207
 
208
+ # Check model type
209
  pth_files = list(model_dir.glob("*.pth"))
210
  pt_files = list(model_dir.glob("*.pt"))
211
 
212
  if pth_files:
 
213
  return self._load_coqui_voice(voice_key, model_dir, pth_files[0])
214
  elif pt_files:
 
215
  return self._load_jit_voice(voice_key, model_dir, pt_files[0])
216
  else:
217
+ raise FileNotFoundError(f"No model file found in {model_dir}")
218
 
219
+ def _load_jit_voice(self, voice_key: str, model_dir: Path, model_path: Path) -> bool:
220
+ """Load a JIT traced VITS model"""
221
  chars_path = model_dir / "chars.txt"
222
  if chars_path.exists():
223
  tokenizer = TTSTokenizer.from_chars_file(str(chars_path))
224
  else:
 
225
  chars_files = list(model_dir.glob("*chars*.txt"))
226
  if chars_files:
227
  tokenizer = TTSTokenizer.from_chars_file(str(chars_files[0]))
228
  else:
229
  raise FileNotFoundError(f"No chars.txt found in {model_dir}")
230
 
231
+ logger.info(f"Loading model from {model_path}")
 
232
  model = torch.jit.load(str(model_path), map_location=self.device)
233
  model.eval()
234
 
 
235
  self._models[voice_key] = model
236
  self._tokenizers[voice_key] = tokenizer
237
 
238
+ logger.info(f"Loaded voice: {voice_key}")
239
  return True
240
 
241
+ def _load_coqui_voice(self, voice_key: str, model_dir: Path, checkpoint_path: Path) -> bool:
242
+ """Load a Coqui TTS checkpoint model"""
243
  config_path = model_dir / "config.json"
244
  if not config_path.exists():
245
  raise FileNotFoundError(f"No config.json found in {model_dir}")
 
247
  try:
248
  from TTS.utils.synthesizer import Synthesizer
249
 
250
+ logger.info(f"Loading checkpoint from {checkpoint_path}")
251
 
 
252
  use_cuda = self.device.type == "cuda"
253
  synthesizer = Synthesizer(
254
  tts_checkpoint=str(checkpoint_path),
 
256
  use_cuda=use_cuda,
257
  )
258
 
 
259
  self._coqui_models[voice_key] = synthesizer
260
+ logger.info(f"Loaded voice: {voice_key}")
 
261
  return True
262
 
263
  except ImportError:
264
+ raise ImportError("Coqui TTS library not installed.")
 
 
265
 
266
  def _synthesize_coqui(self, text: str, voice_key: str) -> Tuple[np.ndarray, int]:
267
+ """Synthesize using Coqui TTS model"""
 
 
268
  if voice_key not in self._coqui_models:
269
  self.load_voice(voice_key)
270
 
271
  synthesizer = self._coqui_models[voice_key]
 
 
 
272
  wav = synthesizer.tts(text)
 
 
273
  audio_np = np.array(wav, dtype=np.float32)
274
  sample_rate = synthesizer.output_sample_rate
275
 
276
  return audio_np, sample_rate
277
 
278
  def _load_mms_voice(self, voice_key: str) -> bool:
279
+ """Load Facebook MMS model for Gujarati"""
 
 
280
  if voice_key in self._mms_models:
281
  return True
282
 
 
286
  try:
287
  from transformers import VitsModel, AutoTokenizer
288
 
 
289
  model = VitsModel.from_pretrained(config.hf_model_id)
290
  tokenizer = AutoTokenizer.from_pretrained(config.hf_model_id)
291
 
 
303
  raise
304
 
305
  def _synthesize_mms(self, text: str, voice_key: str) -> Tuple[np.ndarray, int]:
306
+ """Synthesize using Facebook MMS model"""
 
 
307
  if voice_key not in self._mms_models:
308
  self._load_mms_voice(voice_key)
309
 
 
311
  tokenizer = self._mms_tokenizers[voice_key]
312
  config = LANGUAGE_CONFIGS[voice_key]
313
 
 
314
  inputs = tokenizer(text, return_tensors="pt")
315
  inputs = {k: v.to(self.device) for k, v in inputs.items()}
316
 
 
317
  with torch.no_grad():
318
  output = model(**inputs)
319
 
 
320
  audio = output.waveform.squeeze().cpu().numpy()
321
  return audio, config.sample_rate
322
 
 
344
  normalize_text: bool = True,
345
  ) -> TTSOutput:
346
  """
347
+ Synthesize speech from text
348
 
349
  Args:
350
  text: Input text to synthesize
351
+ voice: Voice key (e.g., 'hi_male', 'bn_female')
352
  speed: Speech speed multiplier (0.5-2.0)
353
+ pitch: Pitch multiplier (0.5-2.0)
354
  energy: Energy/volume multiplier (0.5-2.0)
355
+ style: Style preset name (e.g., 'happy', 'sad')
356
  normalize_text: Whether to apply text normalization
357
 
358
  Returns:
359
  TTSOutput with audio array and metadata
360
  """
 
361
  if style and style in STYLE_PRESETS:
362
  preset = STYLE_PRESETS[style]
363
  speed = speed * preset["speed"]
 
366
 
367
  config = LANGUAGE_CONFIGS[voice]
368
 
 
369
  if normalize_text:
370
  text = self.normalizer.clean_text(text, config.code)
371
 
372
+ # Route to appropriate model type
373
  if "mms" in voice:
374
  audio_np, sample_rate = self._synthesize_mms(text, voice)
 
375
  elif voice in self._coqui_models:
376
  audio_np, sample_rate = self._synthesize_coqui(text, voice)
377
  else:
 
378
  if voice not in self._models and voice not in self._coqui_models:
379
  self.load_voice(voice)
380
 
 
381
  if voice in self._coqui_models:
382
  audio_np, sample_rate = self._synthesize_coqui(text, voice)
383
  else:
 
384
  model = self._models[voice]
385
  tokenizer = self._tokenizers[voice]
386
 
 
387
  token_ids = tokenizer.text_to_ids(text)
388
  x = torch.from_numpy(np.array(token_ids)).unsqueeze(0).to(self.device)
389
 
 
390
  with torch.no_grad():
391
  audio = model(x)
392
 
393
  audio_np = audio.squeeze().cpu().numpy()
394
  sample_rate = config.sample_rate
395
 
396
+ # Apply style modifications
397
  audio_np = self.style_processor.apply_style(
398
  audio_np, sample_rate, speed=speed, pitch=pitch, energy=energy
399
  )
400
 
 
401
  duration = len(audio_np) / sample_rate
402
 
403
  return TTSOutput(
 
420
  style: Optional[str] = None,
421
  normalize_text: bool = True,
422
  ) -> str:
423
+ """Synthesize speech and save to file"""
424
  import soundfile as sf
425
 
426
+ output = self.synthesize(text, voice, speed, pitch, energy, style, normalize_text)
 
 
427
  sf.write(output_path, output.audio, output.sample_rate)
428
 
429
  logger.info(f"Saved audio to {output_path} (duration: {output.duration:.2f}s)")
 
444
  is_mms = "mms" in key
445
  model_dir = self.models_dir / key
446
 
 
447
  if is_mms:
448
  model_type = "mms"
449
  elif model_dir.exists() and list(model_dir.glob("*.pth")):
 
454
  voices[key] = {
455
  "name": config.name,
456
  "code": config.code,
457
+ "gender": "male" if "male" in key else ("female" if "female" in key else "neutral"),
458
+ "loaded": key in self._models or key in self._coqui_models or key in self._mms_models,
459
+ "downloaded": is_mms or get_model_path(key) is not None,
460
  "type": model_type,
461
  }
462
  return voices
 
465
  """Get available style presets"""
466
  return STYLE_PRESETS
467
 
468
+ def batch_synthesize(self, texts: List[str], voice: str = "hi_male", speed: float = 1.0) -> List[TTSOutput]:
 
 
469
  """Synthesize multiple texts"""
470
  return [self.synthesize(text, voice, speed) for text in texts]
471
 
472
 
473
+ def synthesize(text: str, voice: str = "hi_male", output_path: Optional[str] = None) -> Union[TTSOutput, str]:
474
+ """Quick synthesis function"""
475
  engine = TTSEngine()
476
 
477
  if output_path:
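
Review note on `src/engine.py`: the loading path picks a backend from the voice key and from the files present in the voice's model directory. Keys containing `mms` go to the MMS path, a `*.pth` file is treated as a Coqui checkpoint, a `*.pt` file as a JIT traced model, and anything else raises `FileNotFoundError`. The sketch below restates that order in isolation. It is a minimal illustration, not the engine's code: `detect_model_type` is a hypothetical helper, the labels `"coqui"` and `"jit"` are assumed names (only `"mms"` appears in the diff), and the voice keys `bho_male` and `gu_mms` are made up for the demo.

```python
import tempfile
from pathlib import Path


def detect_model_type(voice_key: str, models_dir: Path) -> str:
    """Hypothetical helper: mirror the routing order used in the diff."""
    # MMS voices are identified by key alone and need no local files.
    if "mms" in voice_key:
        return "mms"

    model_dir = models_dir / voice_key
    if not model_dir.exists():
        raise FileNotFoundError(f"Model not found: {model_dir}")

    # Coqui checkpoints (*.pth) are checked before JIT traces (*.pt).
    # The pattern "*.pt" does not match "checkpoint.pth", so the two
    # globs never overlap.
    if list(model_dir.glob("*.pth")):
        return "coqui"  # assumed label
    if list(model_dir.glob("*.pt")):
        return "jit"  # assumed label
    raise FileNotFoundError(f"No model file found in {model_dir}")


def demo() -> dict:
    """Build a throwaway models directory and classify three voice keys."""
    with tempfile.TemporaryDirectory() as tmp:
        root = Path(tmp)
        (root / "hi_male").mkdir()
        (root / "hi_male" / "model.pt").touch()
        (root / "bho_male").mkdir()
        (root / "bho_male" / "checkpoint.pth").touch()
        keys = ("hi_male", "bho_male", "gu_mms")
        return {key: detect_model_type(key, root) for key in keys}


if __name__ == "__main__":
    print(demo())
```

Checking `*.pth` first matters only when a directory holds both file kinds; in that case the Coqui checkpoint wins, which matches the order in `load_voice`.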
src/model_loader.py ADDED
@@ -0,0 +1,57 @@
+ """
+ Model Loader for VITS TTS Models
+ Loads trained models from the models directory.
+ Models are expected to be in the models/ directory after training.
+ """
+
+ import os
+ import logging
+ from pathlib import Path
+ from typing import Optional, List
+
+ logger = logging.getLogger(__name__)
+
+ # Model directory
+ MODELS_DIR = Path(os.environ.get("MODELS_DIR", "models"))
+
+
+ def _ensure_models_available():
+     """
+     Internal function to ensure model files are available.
+     Called during engine initialization.
+     """
+     if MODELS_DIR.exists() and any(MODELS_DIR.iterdir()):
+         return True
+
+     # Models need to be loaded - this happens during Docker build
+     logger.info("Initializing model directory...")
+     MODELS_DIR.mkdir(exist_ok=True)
+
+     try:
+         from huggingface_hub import snapshot_download
+         snapshot_download(
+             repo_id="Harshil748/VoiceAPI-Models",
+             local_dir=MODELS_DIR,
+             local_dir_use_symlinks=False,
+             ignore_patterns=["*.md", ".gitattributes"],
+         )
+         logger.info("Models initialized successfully")
+         return True
+     except Exception as e:
+         logger.warning(f"Could not initialize models: {e}")
+         return False
+
+
+ def get_model_path(voice_key: str) -> Optional[Path]:
+     """Get path to a model directory"""
+     model_dir = MODELS_DIR / voice_key
+     if model_dir.exists():
+         return model_dir
+     return None
+
+
+ def list_available_models() -> List[str]:
+     """List all available trained models"""
+     if not MODELS_DIR.exists():
+         return []
+     return [d.name for d in MODELS_DIR.iterdir() if d.is_dir()]
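
Review note on `src/model_loader.py`: the two lookup helpers only inspect the filesystem, so their behavior can be checked without any model weights. The sketch below reimplements them against a temporary directory. It makes two assumptions that are not in the commit: the directory is passed as a `models_dir` argument instead of being read from the module-level `MODELS_DIR`, and the listing is sorted so the result is deterministic. The voice key `ta_male` is a made-up example of a missing voice.

```python
import tempfile
from pathlib import Path
from typing import List, Optional


def get_model_path(voice_key: str, models_dir: Path) -> Optional[Path]:
    """Return the voice's model directory, or None if it does not exist."""
    model_dir = models_dir / voice_key
    return model_dir if model_dir.exists() else None


def list_available_models(models_dir: Path) -> List[str]:
    """List subdirectory names; plain files such as README.md are skipped."""
    if not models_dir.exists():
        return []
    # sorted() is an addition for reproducible output; iterdir() order
    # is otherwise arbitrary.
    return sorted(d.name for d in models_dir.iterdir() if d.is_dir())


def demo():
    """Create two fake voice directories and query them."""
    with tempfile.TemporaryDirectory() as tmp:
        root = Path(tmp)
        (root / "hi_male").mkdir()
        (root / "bn_female").mkdir()
        (root / "README.md").touch()
        names = list_available_models(root)
        found = get_model_path("hi_male", root) is not None
        missing = get_model_path("ta_male", root)
        return names, found, missing


if __name__ == "__main__":
    print(demo())
```

Note that `get_model_path` reports a directory as present even when it is empty, so the `"downloaded"` flag in `get_available_voices` reflects that the directory exists, not that a usable checkpoint is inside it.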