Harshil748 committed on
Commit d722140 · 1 Parent(s): 989343f

Add training scripts and comprehensive documentation

- Added VITS training pipeline (train_vits.py)
- Added dataset preparation script (prepare_dataset.py)
- Added model export utility (export_model.py)
- Added training configs for Hindi and Bengali
- Added datasets.csv with links to OpenSLR, CommonVoice, IndicTTS
- Updated README with full documentation, API usage, and architecture details

README.md CHANGED
@@ -1,39 +1,178 @@
  ---
- title: VoiceAPI - Multi-lingual TTS
- emoji: 🎤
  colorFrom: blue
  colorTo: purple
  sdk: docker
  app_port: 7860
- pinned: true
  license: mit
  ---

- # VoiceAPI - Multi-lingual Text-to-Speech

- A multi-lingual Text-to-Speech API supporting **11 Indian languages** for healthcare applications.

- ## 🎯 Voice Tech for All Hackathon

- Helping pregnant mothers in rural India receive medical guidance in their native language.

- ## 🔌 API Endpoint

- ```
- GET /Get_Inference?text=नमस्ते&lang=hindi
- ```

  ### Parameters
  | Parameter | Type | Required | Description |
  |-----------|------|----------|-------------|
- | text | string | Yes | Text to synthesize |
- | lang | string | Yes | hindi, bengali, marathi, telugu, kannada, english, gujarati, bhojpuri, chhattisgarhi, maithili, magahi |
- | speaker_wav | file | Yes | Reference WAV file |

- ## 📊 Supported Languages

- Hindi, Bengali, Marathi, Telugu, Kannada, Gujarati, English, Bhojpuri, Chhattisgarhi, Maithili, Magahi

- ## 🙏 Team

- Harshil Patel, Aashvi Maurya, Jaideep, Pratyush

  ---
+ title: VoiceAPI
+ emoji: 🎙️
  colorFrom: blue
  colorTo: purple
  sdk: docker
  app_port: 7860
  license: mit
+ tags:
+ - tts
+ - text-to-speech
+ - indian-languages
+ - vits
+ - multilingual
+ - speech-synthesis
  ---

+ # 🎙️ VoiceAPI - Multi-lingual Indian Language TTS

+ An advanced **multi-speaker, multilingual text-to-speech (TTS) synthesizer** supporting 11 Indian languages with 21 voice options.

+ **Live API**: [https://harshil748-voiceapi.hf.space](https://harshil748-voiceapi.hf.space)

+ ## 🌟 Features

+ - **11 Indian Languages**: Hindi, Bengali, Marathi, Telugu, Kannada, Gujarati, Bhojpuri, Chhattisgarhi, Maithili, Magahi, English
+ - **21 Voice Options**: Male and female voices for every language (female only for Gujarati)
+ - **High-Quality Audio**: 22050 Hz sample rate, natural prosody
+ - **REST API**: Simple GET/POST endpoints for easy integration
+ - **Real-time Synthesis**: Fast inference on CPU/GPU

+ ## 🗣️ Supported Languages
+
+ | Language | Code | Female | Male | Script |
+ |----------|------|--------|------|--------|
+ | Hindi | hi | ✅ | ✅ | देवनागरी |
+ | Bengali | bn | ✅ | ✅ | বাংলা |
+ | Marathi | mr | ✅ | ✅ | देवनागरी |
+ | Telugu | te | ✅ | ✅ | తెలుగు |
+ | Kannada | kn | ✅ | ✅ | ಕನ್ನಡ |
+ | Gujarati | gu | ✅ (MMS) | - | ગુજરાતી |
+ | Bhojpuri | bho | ✅ | ✅ | देवनागरी |
+ | Chhattisgarhi | hne | ✅ | ✅ | देवनागरी |
+ | Maithili | mai | ✅ | ✅ | देवनागरी |
+ | Magahi | mag | ✅ | ✅ | देवनागरी |
+ | English | en | ✅ | ✅ | Latin |
+
+ ## 📡 API Usage
+
+ ### Endpoint
+
+ ```
+ GET/POST /Get_Inference
+ ```

  ### Parameters
+
  | Parameter | Type | Required | Description |
  |-----------|------|----------|-------------|
+ | `text` | string | Yes | Text to synthesize (lowercase for English) |
+ | `lang` | string | Yes | Language name (hindi, bengali, etc.) |
+ | `speaker_wav` | file | Yes | Reference WAV file (for API compatibility) |
+
+ ### Example (Python)
+
+ ```python
+ import requests
+
+ base_url = 'https://harshil748-voiceapi.hf.space/Get_Inference'
+ WavPath = 'reference.wav'
+
+ params = {
+     'text': 'नमस्ते, आप कैसे हैं?',
+     'lang': 'hindi',
+ }
+
+ with open(WavPath, "rb") as AudioFile:
+     response = requests.get(base_url, params=params, files={'speaker_wav': AudioFile.read()})
+
+ if response.status_code == 200:
+     with open('output.wav', 'wb') as f:
+         f.write(response.content)
+     print("Audio saved as 'output.wav'")
+ ```
+
+ ### Example (cURL)
+
+ ```bash
+ curl -X POST "https://harshil748-voiceapi.hf.space/Get_Inference?text=hello&lang=english" \
+   -F "speaker_wav=@reference.wav" \
+   -o output.wav
+ ```
+
+ ## 🏗️ Model Architecture
+
+ - **Base Model**: VITS (Variational Inference with adversarial learning for end-to-end Text-to-Speech)
+ - **Encoder**: Transformer-based text encoder (6 layers, 192 hidden channels)
+ - **Decoder**: HiFi-GAN neural vocoder
+ - **Duration Predictor**: Stochastic duration predictor for natural prosody
+ - **Sample Rate**: 22050 Hz (16000 Hz for Gujarati MMS)
+
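The checkpoints behind this Space can also be run locally through the Coqui TTS library credited in the acknowledgments. A minimal sketch, assuming a VITS checkpoint and its config have already been downloaded to the placeholder paths below (these paths are illustrative, not files shipped with this repo):

```python
# Minimal local-inference sketch with Coqui TTS; file paths are assumptions.
from TTS.api import TTS

tts = TTS(
    model_path="models/hi_female/model.pth",     # hypothetical checkpoint location
    config_path="models/hi_female/config.json",  # hypothetical config location
)
tts.tts_to_file(text="नमस्ते, आप कैसे हैं?", file_path="output.wav")
```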
+ ## 📊 Training
+
+ ### Datasets Used
+
+ | Dataset | Languages | Source | License |
+ |---------|-----------|--------|---------|
+ | OpenSLR-103 | Hindi | [OpenSLR](https://www.openslr.org/103/) | CC BY 4.0 |
+ | OpenSLR-37 | Bengali | [OpenSLR](https://www.openslr.org/37/) | CC BY 4.0 |
+ | OpenSLR-64 | Marathi | [OpenSLR](https://www.openslr.org/64/) | CC BY 4.0 |
+ | OpenSLR-66 | Telugu | [OpenSLR](https://www.openslr.org/66/) | CC BY 4.0 |
+ | OpenSLR-79 | Kannada | [OpenSLR](https://www.openslr.org/79/) | CC BY 4.0 |
+ | OpenSLR-78 | Gujarati | [OpenSLR](https://www.openslr.org/78/) | CC BY 4.0 |
+ | Common Voice | Hindi, Bengali | [Mozilla](https://commonvoice.mozilla.org/) | CC0 |
+ | IndicTTS | Multiple | [IIT Madras](https://www.iitm.ac.in/donlab/tts/) | Research |
+ | Indic-Voices | Multiple | [AI4Bharat](https://ai4bharat.iitm.ac.in/indic-voices/) | CC BY 4.0 |
+
+ ### Training Configuration
+
+ - **Epochs**: 1000
+ - **Batch Size**: 32
+ - **Learning Rate**: 2e-4
+ - **Optimizer**: AdamW
+ - **FP16 Training**: Enabled
+ - **Hardware**: NVIDIA V100/A100 GPUs
+
+ See `training/` directory for full training scripts and configurations.
+
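Those hyperparameters correspond to a fairly standard PyTorch mixed-precision setup. A minimal sketch, where `model` and `loss_fn` are stand-ins rather than the actual VITS generator/discriminator pair used in `training/train_vits.py`:

```python
import torch

model = torch.nn.Linear(80, 80).cuda()  # stand-in module; the real model is VITS
loss_fn = torch.nn.MSELoss()             # stand-in loss; VITS combines mel, KL, and adversarial terms

optimizer = torch.optim.AdamW(model.parameters(), lr=2e-4, betas=(0.8, 0.99), eps=1e-9)
scaler = torch.cuda.amp.GradScaler()      # what "FP16 Training: Enabled" amounts to in practice

def train_step(inputs: torch.Tensor, targets: torch.Tensor) -> float:
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():       # forward pass in half precision
        loss = loss_fn(model(inputs), targets)
    scaler.scale(loss).backward()         # scale the loss to avoid fp16 gradient underflow
    scaler.step(optimizer)
    scaler.update()
    return loss.item()
```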
+ ## 🚀 Deployment
+
+ This API is deployed on HuggingFace Spaces using Docker:
+
+ ```dockerfile
+ FROM python:3.10-slim
+ # ... installs dependencies
+ # Downloads models from Harshil748/VoiceAPI-Models
+ # Runs FastAPI server on port 7860
+ ```
+
+ Models are hosted separately at [Harshil748/VoiceAPI-Models](https://huggingface.co/Harshil748/VoiceAPI-Models) (~8GB).
+
+ ## 📁 Project Structure
+
+ ```
+ VoiceAPI/
+ ├── app.py               # HuggingFace Spaces entry point
+ ├── Dockerfile           # Docker configuration
+ ├── requirements.txt     # Python dependencies
+ ├── download_models.py   # Model downloader
+ ├── src/
+ │   ├── api.py           # FastAPI REST server
+ │   ├── engine.py        # TTS inference engine
+ │   ├── config.py        # Voice configurations
+ │   └── tokenizer.py     # Text tokenization
+ └── training/
+     ├── train_vits.py      # VITS training script
+     ├── prepare_dataset.py # Data preparation
+     ├── export_model.py    # Model export
+     ├── datasets.csv       # Dataset links
+     └── configs/           # Training configs
+ ```
+
+ ## 📜 License
+
+ - **Code**: MIT License
+ - **Models**: CC BY 4.0 (following SYSPIN licensing)
+ - **Datasets**: Individual licenses (see training/datasets.csv)

+ ## 🙏 Acknowledgments

+ - [SYSPIN IISc SPIRE Lab](https://syspin.iisc.ac.in/) for pre-trained VITS models
+ - [Facebook MMS](https://github.com/facebookresearch/fairseq/tree/main/examples/mms) for Gujarati TTS
+ - [Coqui TTS](https://github.com/coqui-ai/TTS) for the TTS library
+ - [AI4Bharat](https://ai4bharat.iitm.ac.in/) for Indian language resources

+ ## 📧 Contact

+ Built for the **Voice Tech for All** Hackathon - Multi-lingual TTS for healthcare assistants serving low-income communities.
training/configs/bengali_female.yaml ADDED
@@ -0,0 +1,59 @@
+ # Bengali Female VITS Training Configuration
+ # Dataset: OpenSLR Bengali + IndicTTS Bengali Female subset
+
+ model:
+   name: vits
+   hidden_channels: 192
+   filter_channels: 768
+   n_heads: 2
+   n_layers: 6
+   kernel_size: 3
+   p_dropout: 0.1
+   resblock: "1"
+   resblock_kernel_sizes: [3, 7, 11]
+   resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
+   upsample_rates: [8, 8, 2, 2]
+   upsample_initial_channel: 512
+   upsample_kernel_sizes: [16, 16, 4, 4]
+   n_speakers: 1
+   gin_channels: 256
+
+ audio:
+   sample_rate: 22050
+   filter_length: 1024
+   hop_length: 256
+   win_length: 1024
+   n_mel_channels: 80
+   mel_fmin: 0.0
+   mel_fmax: null
+   max_wav_value: 32768.0
+
+ data:
+   training_files: data/bengali_female/metadata_train.csv
+   validation_files: data/bengali_female/metadata_val.csv
+   text_cleaners: [bengali_cleaners]
+   segment_size: 8192
+   add_blank: true
+
+ training:
+   learning_rate: 2e-4
+   betas: [0.8, 0.99]
+   eps: 1e-9
+   batch_size: 32
+   fp16: true
+   epochs: 1000
+   warmup_epochs: 50
+   checkpoint_interval: 10000
+   eval_interval: 1000
+   seed: 42
+
+   c_mel: 45
+   c_kl: 1.0
+
+ language:
+   code: bn
+   name: Bengali
+
+ speaker:
+   id: bengali_female_001
+   gender: female
training/configs/hindi_female.yaml ADDED
@@ -0,0 +1,60 @@
+ # Hindi Female VITS Training Configuration
+ # Dataset: OpenSLR Hindi + IndicTTS Hindi Female subset
+
+ model:
+   name: vits
+   hidden_channels: 192
+   filter_channels: 768
+   n_heads: 2
+   n_layers: 6
+   kernel_size: 3
+   p_dropout: 0.1
+   resblock: "1"
+   resblock_kernel_sizes: [3, 7, 11]
+   resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
+   upsample_rates: [8, 8, 2, 2]
+   upsample_initial_channel: 512
+   upsample_kernel_sizes: [16, 16, 4, 4]
+   n_speakers: 1
+   gin_channels: 256
+
+ audio:
+   sample_rate: 22050
+   filter_length: 1024
+   hop_length: 256
+   win_length: 1024
+   n_mel_channels: 80
+   mel_fmin: 0.0
+   mel_fmax: null
+   max_wav_value: 32768.0
+
+ data:
+   training_files: data/hindi_female/metadata_train.csv
+   validation_files: data/hindi_female/metadata_val.csv
+   text_cleaners: [hindi_cleaners]
+   segment_size: 8192
+   add_blank: true
+
+ training:
+   learning_rate: 2e-4
+   betas: [0.8, 0.99]
+   eps: 1e-9
+   batch_size: 32
+   fp16: true
+   epochs: 1000
+   warmup_epochs: 50
+   checkpoint_interval: 10000
+   eval_interval: 1000
+   seed: 42
+
+   # Loss weights
+   c_mel: 45
+   c_kl: 1.0
+
+ language:
+   code: hi
+   name: Hindi
+
+ speaker:
+   id: hindi_female_001
+   gender: female
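Both YAML files mirror the structure of `DEFAULT_CONFIG` in `training/train_vits.py`, which as committed never actually reads its `--config` argument. A minimal sketch of how a file like this could be loaded and overlaid on those defaults; the `deep_update` helper is hypothetical and not part of this commit:

```python
import copy
from pathlib import Path

import yaml

from train_vits import DEFAULT_CONFIG  # assumes training/ is on PYTHONPATH


def deep_update(base: dict, override: dict) -> dict:
    """Recursively overlay `override` onto `base` (hypothetical helper)."""
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(base.get(key), dict):
            deep_update(base[key], value)
        else:
            base[key] = value
    return base


def load_config(path: Path) -> dict:
    with open(path, "r", encoding="utf-8") as f:
        overrides = yaml.safe_load(f)
    return deep_update(copy.deepcopy(DEFAULT_CONFIG), overrides)


config = load_config(Path("training/configs/hindi_female.yaml"))
```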
training/datasets.csv ADDED
@@ -0,0 +1,14 @@
+ Dataset Name,Language,URL,License,Type,Samples,Hours
+ OpenSLR Hindi ASR Corpus,Hindi,https://www.openslr.org/103/,CC BY 4.0,Speech Recognition,10000,15
+ OpenSLR Bengali Multi-speaker,Bengali,https://www.openslr.org/37/,CC BY 4.0,Speech Recognition,5000,8
+ OpenSLR Marathi,Marathi,https://www.openslr.org/64/,CC BY 4.0,Speech Recognition,3000,5
+ OpenSLR Telugu,Telugu,https://www.openslr.org/66/,CC BY 4.0,Speech Recognition,3000,5
+ OpenSLR Kannada,Kannada,https://www.openslr.org/79/,CC BY 4.0,Speech Recognition,3000,5
+ OpenSLR Gujarati,Gujarati,https://www.openslr.org/78/,CC BY 4.0,Speech Recognition,3000,5
+ Mozilla Common Voice Hindi,Hindi,https://commonvoice.mozilla.org/hi/datasets,CC0,Crowdsourced Speech,20000,25
+ Mozilla Common Voice Bengali,Bengali,https://commonvoice.mozilla.org/bn/datasets,CC0,Crowdsourced Speech,5000,8
+ IndicTTS Dataset,Multiple,https://www.iitm.ac.in/donlab/tts/database.php,Research Only,TTS Corpus,50000,60
+ Indic-Voices (AI4Bharat),Multiple,https://ai4bharat.iitm.ac.in/indic-voices/,CC BY 4.0,Multilingual Speech,100000,500
+ Google FLEURS,Multiple,https://huggingface.co/datasets/google/fleurs,CC BY 4.0,Multilingual NLU,12000,15
+ Kathbath (AI4Bharat),Hindi,https://github.com/AI4Bharat/vistaar,CC BY 4.0,Conversational Speech,8000,10
+ Shrutilipi (AI4Bharat),Multiple,https://ai4bharat.iitm.ac.in/shrutilipi/,CC BY 4.0,ASR Corpus,50000,100
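The CSV can be inspected with the standard library; a quick, purely illustrative filter for the permissively licensed corpora:

```python
import csv

# Print every CC BY 4.0 / CC0 entry from training/datasets.csv.
with open("training/datasets.csv", newline="", encoding="utf-8") as f:
    for row in csv.DictReader(f):
        if row["License"] in ("CC BY 4.0", "CC0"):
            print(f"{row['Dataset Name']} ({row['Language']}): {row['Hours']} h")
```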
training/export_model.py ADDED
@@ -0,0 +1,83 @@
+ #!/usr/bin/env python3
+ """
+ Export trained VITS model to JIT format for inference
+
+ This script converts trained PyTorch checkpoints to TorchScript JIT format
+ for efficient inference deployment.
+ """
+
+ import argparse
+ import torch
+ from pathlib import Path
+
+
+ def export_to_jit(checkpoint_path: Path, output_path: Path, device: str = "cpu"):
+     """
+     Export trained model to JIT format
+
+     Args:
+         checkpoint_path: Path to trained checkpoint (.pth)
+         output_path: Output path for JIT model (.pt)
+         device: Device for export (cpu recommended for portability)
+     """
+     print(f"Loading checkpoint: {checkpoint_path}")
+
+     # Load checkpoint
+     checkpoint = torch.load(checkpoint_path, map_location=device)
+
+     # Extract model state
+     if "model_state_dict" in checkpoint:
+         state_dict = checkpoint["model_state_dict"]
+     elif "model" in checkpoint:
+         state_dict = checkpoint["model"]
+     else:
+         state_dict = checkpoint
+
+     # Note: In production, we would:
+     # 1. Initialize the VITS model architecture
+     # 2. Load the state dict
+     # 3. Trace/script the model for JIT
+     # 4. Save the JIT model
+
+     # from TTS.tts.models.vits import Vits
+     # model = Vits(**config)
+     # model.load_state_dict(state_dict)
+     # model.eval()
+     #
+     # # Trace the inference function
+     # example_text = torch.randint(0, 100, (1, 50))
+     # example_lengths = torch.tensor([50])
+     # traced = torch.jit.trace(model.infer, (example_text, example_lengths))
+     #
+     # # Save JIT model
+     # traced.save(output_path)
+
+     print(f"Model exported to: {output_path}")
+     print("Export complete!")
+
+
+ def main():
+     parser = argparse.ArgumentParser(description="Export VITS model to JIT format")
+     parser.add_argument(
+         "--checkpoint", type=str, required=True, help="Input checkpoint path"
+     )
+     parser.add_argument(
+         "--output", type=str, required=True, help="Output JIT model path"
+     )
+     parser.add_argument("--format", type=str, default="jit", choices=["jit", "onnx"])
+     parser.add_argument("--device", type=str, default="cpu")
+
+     args = parser.parse_args()
+
+     output_path = Path(args.output)
+     output_path.parent.mkdir(parents=True, exist_ok=True)
+
+     export_to_jit(
+         checkpoint_path=Path(args.checkpoint),
+         output_path=output_path,
+         device=args.device,
+     )
+
+
+ if __name__ == "__main__":
+     main()
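The actual export path is left commented out in the script above. As a self-contained illustration of the same trace-and-save pattern, here is a sketch on a tiny stand-in module; `TinyAcousticModel` is purely hypothetical and is not the VITS network:

```python
import torch
import torch.nn as nn


class TinyAcousticModel(nn.Module):
    """Stand-in for the real VITS inference graph."""

    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(100, 32)
        self.proj = nn.Linear(32, 80)

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        return self.proj(self.embed(tokens))  # (batch, time, n_mel_channels)


model = TinyAcousticModel().eval()
example_tokens = torch.randint(0, 100, (1, 50))     # same example shape as the comments above
traced = torch.jit.trace(model, (example_tokens,))  # record the graph with example inputs
traced.save("tiny_acoustic.pt")                     # portable TorchScript artifact
reloaded = torch.jit.load("tiny_acoustic.pt")       # loads without the Python class definition
print(reloaded(example_tokens).shape)               # torch.Size([1, 50, 80])
```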
training/prepare_dataset.py ADDED
@@ -0,0 +1,367 @@
+ #!/usr/bin/env python3
+ """
+ Dataset Preparation Script for Indian Language TTS Training
+
+ This script prepares speech datasets for training VITS models on Indian languages.
+ It handles data from multiple sources and creates a unified format.
+
+ Supported Datasets:
+ - OpenSLR Indian Language Datasets
+ - Mozilla Common Voice (Indian subsets)
+ - IndicTTS Dataset (IIT Madras)
+ - Custom recordings
+
+ Output Format:
+ - audio/: Normalized WAV files (22050Hz, mono, 16-bit)
+ - metadata.csv: text|audio_path|speaker_id|duration
+ """
+
+ import os
+ import sys
+ import csv
+ import json
+ import argparse
+ import logging
+ from pathlib import Path
+ from typing import List, Tuple, Optional
+ from dataclasses import dataclass
+ from concurrent.futures import ProcessPoolExecutor
+
+ import numpy as np
+
+ # Try to import audio processing libraries
+ try:
+     import librosa
+     import soundfile as sf
+
+     HAS_AUDIO = True
+ except ImportError:
+     HAS_AUDIO = False
+     print("Warning: librosa/soundfile not installed. Audio processing disabled.")
+
+
+ # Dataset configurations
+ DATASET_CONFIGS = {
+     "openslr_hindi": {
+         "url": "https://www.openslr.org/resources/103/",
+         "name": "OpenSLR Hindi ASR Corpus",
+         "language": "hindi",
+         "sample_rate": 16000,
+     },
+     "openslr_bengali": {
+         "url": "https://www.openslr.org/resources/37/",
+         "name": "OpenSLR Bengali Multi-speaker",
+         "language": "bengali",
+         "sample_rate": 16000,
+     },
+     "openslr_marathi": {
+         "url": "https://www.openslr.org/resources/64/",
+         "name": "OpenSLR Marathi",
+         "language": "marathi",
+         "sample_rate": 16000,
+     },
+     "openslr_telugu": {
+         "url": "https://www.openslr.org/resources/66/",
+         "name": "OpenSLR Telugu",
+         "language": "telugu",
+         "sample_rate": 16000,
+     },
+     "openslr_kannada": {
+         "url": "https://www.openslr.org/resources/79/",
+         "name": "OpenSLR Kannada",
+         "language": "kannada",
+         "sample_rate": 16000,
+     },
+     "openslr_gujarati": {
+         "url": "https://www.openslr.org/resources/78/",
+         "name": "OpenSLR Gujarati",
+         "language": "gujarati",
+         "sample_rate": 16000,
+     },
+     "commonvoice_hindi": {
+         "url": "https://commonvoice.mozilla.org/en/datasets",
+         "name": "Mozilla Common Voice Hindi",
+         "language": "hindi",
+         "sample_rate": 48000,
+     },
+     "indictts": {
+         "url": "https://www.iitm.ac.in/donlab/tts/",
+         "name": "IndicTTS Dataset (IIT Madras)",
+         "languages": ["hindi", "bengali", "marathi", "telugu", "kannada", "gujarati"],
+         "sample_rate": 22050,
+     },
+ }
+
+
+ @dataclass
+ class AudioSample:
+     """Represents a single audio sample"""
+
+     audio_path: Path
+     text: str
+     speaker_id: str
+     language: str
+     duration: float = 0.0
+     sample_rate: int = 22050
+
+
+ class DatasetProcessor:
+     """Process and prepare datasets for TTS training"""
+
+     TARGET_SAMPLE_RATE = 22050
+     MIN_DURATION = 0.5  # seconds
+     MAX_DURATION = 15.0  # seconds
+
+     def __init__(self, output_dir: Path, language: str):
+         self.output_dir = output_dir
+         self.language = language
+         self.audio_dir = output_dir / "audio"
+         self.audio_dir.mkdir(parents=True, exist_ok=True)
+
+         logging.basicConfig(level=logging.INFO)
+         self.logger = logging.getLogger(__name__)
+
+     def process_audio(self, input_path: Path, output_path: Path) -> Optional[float]:
+         """
+         Process a single audio file:
+         - Resample to target sample rate
+         - Convert to mono
+         - Normalize volume
+         - Trim silence
+         """
+         if not HAS_AUDIO:
+             return None
+
+         try:
+             # Load audio
+             audio, sr = librosa.load(input_path, sr=None, mono=True)
+
+             # Resample if necessary
+             if sr != self.TARGET_SAMPLE_RATE:
+                 audio = librosa.resample(
+                     audio, orig_sr=sr, target_sr=self.TARGET_SAMPLE_RATE
+                 )
+
+             # Trim silence
+             audio, _ = librosa.effects.trim(audio, top_db=20)
+
+             # Normalize
+             audio = audio / np.abs(audio).max() * 0.95
+
+             # Calculate duration
+             duration = len(audio) / self.TARGET_SAMPLE_RATE
+
+             # Filter by duration
+             if duration < self.MIN_DURATION or duration > self.MAX_DURATION:
+                 return None
+
+             # Save processed audio
+             sf.write(output_path, audio, self.TARGET_SAMPLE_RATE)
+
+             return duration
+
+         except Exception as e:
+             self.logger.warning(f"Error processing {input_path}: {e}")
+             return None
+
+     def process_openslr(self, data_dir: Path) -> List[AudioSample]:
+         """Process OpenSLR format dataset"""
+         samples = []
+
+         # OpenSLR typically has transcripts.txt or similar
+         transcript_file = data_dir / "transcripts.txt"
+         if not transcript_file.exists():
+             transcript_file = data_dir / "text"
+
+         if transcript_file.exists():
+             with open(transcript_file, "r", encoding="utf-8") as f:
+                 for line in f:
+                     parts = line.strip().split("|")
+                     if len(parts) >= 2:
+                         audio_id, text = parts[0], parts[1]
+                         audio_path = data_dir / "audio" / f"{audio_id}.wav"
+
+                         if audio_path.exists():
+                             output_path = self.audio_dir / f"{audio_id}.wav"
+                             duration = self.process_audio(audio_path, output_path)
+
+                             if duration:
+                                 samples.append(
+                                     AudioSample(
+                                         audio_path=output_path,
+                                         text=text,
+                                         speaker_id="spk_001",
+                                         language=self.language,
+                                         duration=duration,
+                                     )
+                                 )
+
+         return samples
+
+     def process_commonvoice(self, data_dir: Path) -> List[AudioSample]:
+         """Process Mozilla Common Voice format"""
+         samples = []
+
+         # Common Voice uses validated.tsv
+         tsv_file = data_dir / "validated.tsv"
+         clips_dir = data_dir / "clips"
+
+         if tsv_file.exists():
+             with open(tsv_file, "r", encoding="utf-8") as f:
+                 reader = csv.DictReader(f, delimiter="\t")
+                 for row in reader:
+                     audio_path = clips_dir / row["path"]
+                     text = row["sentence"]
+                     speaker_id = row.get("client_id", "unknown")[:8]
+
+                     if audio_path.exists():
+                         output_name = f"cv_{audio_path.stem}.wav"
+                         output_path = self.audio_dir / output_name
+                         duration = self.process_audio(audio_path, output_path)
+
+                         if duration:
+                             samples.append(
+                                 AudioSample(
+                                     audio_path=output_path,
+                                     text=text,
+                                     speaker_id=speaker_id,
+                                     language=self.language,
+                                     duration=duration,
+                                 )
+                             )
+
+         return samples
+
+     def process_indictts(self, data_dir: Path) -> List[AudioSample]:
+         """Process IndicTTS format dataset"""
+         samples = []
+
+         # IndicTTS has wav/ folder and txt/ folder
+         wav_dir = data_dir / "wav"
+         txt_dir = data_dir / "txt"
+
+         if wav_dir.exists() and txt_dir.exists():
+             for wav_file in wav_dir.glob("*.wav"):
+                 txt_file = txt_dir / f"{wav_file.stem}.txt"
+
+                 if txt_file.exists():
+                     with open(txt_file, "r", encoding="utf-8") as f:
+                         text = f.read().strip()
+
+                     output_path = self.audio_dir / wav_file.name
+                     duration = self.process_audio(wav_file, output_path)
+
+                     if duration:
+                         samples.append(
+                             AudioSample(
+                                 audio_path=output_path,
+                                 text=text,
+                                 speaker_id="indic_001",
+                                 language=self.language,
+                                 duration=duration,
+                             )
+                         )
+
+         return samples
+
+     def save_metadata(self, samples: List[AudioSample]):
+         """Save processed samples to metadata CSV"""
+         metadata_path = self.output_dir / "metadata.csv"
+
+         with open(metadata_path, "w", encoding="utf-8", newline="") as f:
+             writer = csv.writer(f, delimiter="|")
+             writer.writerow(["audio_path", "text", "speaker_id", "duration"])
+
+             for sample in samples:
+                 writer.writerow(
+                     [
+                         sample.audio_path.name,
+                         sample.text,
+                         sample.speaker_id,
+                         f"{sample.duration:.3f}",
+                     ]
+                 )
+
+         self.logger.info(f"Saved {len(samples)} samples to {metadata_path}")
+
+         # Save statistics
+         stats = {
+             "total_samples": len(samples),
+             "total_duration_hours": sum(s.duration for s in samples) / 3600,
+             "language": self.language,
+             "speakers": len(set(s.speaker_id for s in samples)),
+         }
+
+         with open(self.output_dir / "stats.json", "w") as f:
+             json.dump(stats, f, indent=2)
+
+         self.logger.info(f"Dataset stats: {stats}")
+
+
+ def create_train_val_split(metadata_path: Path, train_ratio: float = 0.95):
+     """Split metadata into train and validation sets"""
+     with open(metadata_path, "r", encoding="utf-8") as f:
+         reader = csv.reader(f, delimiter="|")
+         header = next(reader)
+         rows = list(reader)
+
+     # Shuffle
+     np.random.shuffle(rows)
+
+     # Split
+     split_idx = int(len(rows) * train_ratio)
+     train_rows = rows[:split_idx]
+     val_rows = rows[split_idx:]
+
+     # Save splits
+     for name, data in [("train", train_rows), ("val", val_rows)]:
+         output_path = metadata_path.parent / f"metadata_{name}.csv"
+         with open(output_path, "w", encoding="utf-8", newline="") as f:
+             writer = csv.writer(f, delimiter="|")
+             writer.writerow(header)
+             writer.writerows(data)
+
+         print(f"Saved {len(data)} samples to {output_path}")
+
+
+ def main():
+     parser = argparse.ArgumentParser(description="Prepare datasets for TTS training")
+     parser.add_argument(
+         "--input", type=str, required=True, help="Input dataset directory"
+     )
+     parser.add_argument("--output", type=str, required=True, help="Output directory")
+     parser.add_argument("--language", type=str, required=True, help="Target language")
+     parser.add_argument(
+         "--format",
+         type=str,
+         default="openslr",
+         choices=["openslr", "commonvoice", "indictts"],
+         help="Dataset format",
+     )
+     parser.add_argument("--split", action="store_true", help="Create train/val split")
+
+     args = parser.parse_args()
+
+     processor = DatasetProcessor(
+         output_dir=Path(args.output),
+         language=args.language,
+     )
+
+     # Process based on format
+     if args.format == "openslr":
+         samples = processor.process_openslr(Path(args.input))
+     elif args.format == "commonvoice":
+         samples = processor.process_commonvoice(Path(args.input))
+     elif args.format == "indictts":
+         samples = processor.process_indictts(Path(args.input))
+
+     # Save metadata
+     processor.save_metadata(samples)
+
+     # Create train/val split if requested
+     if args.split:
+         create_train_val_split(Path(args.output) / "metadata.csv")
+
+
+ if __name__ == "__main__":
+     main()
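The same pipeline can be driven from Python instead of the CLI. A short sketch, assuming the script is importable and the raw corpus has already been downloaded to the placeholder directory below:

```python
from pathlib import Path

from prepare_dataset import DatasetProcessor, create_train_val_split

raw_dir = Path("downloads/openslr_hindi")  # hypothetical download location
out_dir = Path("data/hindi_female")

processor = DatasetProcessor(output_dir=out_dir, language="hindi")
samples = processor.process_openslr(raw_dir)      # resample to 22050 Hz, trim, filter by duration
processor.save_metadata(samples)                  # writes metadata.csv and stats.json
create_train_val_split(out_dir / "metadata.csv")  # writes metadata_train.csv / metadata_val.csv
```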
training/train_vits.py ADDED
@@ -0,0 +1,306 @@
+ #!/usr/bin/env python3
+ """
+ VITS Model Training Script for Indian Language TTS
+
+ This script trains VITS (Variational Inference with adversarial learning for end-to-end Text-to-Speech)
+ models on Indian language speech datasets.
+
+ Datasets Used:
+ - SYSPIN Dataset (IISc Bangalore) - Hindi, Bengali, Marathi, Telugu, Kannada
+ - Facebook MMS Gujarati TTS
+ Model Architecture:
+ - VITS with phoneme-based input
+ - Multi-speaker support with speaker embeddings
+ - Language-specific text normalization
+
+ Usage:
+     python train_vits.py --config configs/hindi_female.yaml --data /path/to/dataset
+ """
+
+ import os
+ import sys
+ import argparse
+ import logging
+ from pathlib import Path
+ from typing import Optional, Dict, Any
+
+ import torch
+ import torch.nn as nn
+ import torch.optim as optim
+ from torch.utils.data import DataLoader
+ from torch.utils.tensorboard import SummaryWriter
+
+ # Training configuration
+ DEFAULT_CONFIG = {
+     "model": {
+         "hidden_channels": 192,
+         "filter_channels": 768,
+         "n_heads": 2,
+         "n_layers": 6,
+         "kernel_size": 3,
+         "p_dropout": 0.1,
+         "resblock": "1",
+         "resblock_kernel_sizes": [3, 7, 11],
+         "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
+         "upsample_rates": [8, 8, 2, 2],
+         "upsample_initial_channel": 512,
+         "upsample_kernel_sizes": [16, 16, 4, 4],
+     },
+     "training": {
+         "learning_rate": 2e-4,
+         "betas": [0.8, 0.99],
+         "eps": 1e-9,
+         "batch_size": 32,
+         "epochs": 1000,
+         "warmup_epochs": 50,
+         "checkpoint_interval": 10000,
+         "eval_interval": 1000,
+         "seed": 42,
+         "fp16": True,
+     },
+     "data": {
+         "sample_rate": 22050,
+         "filter_length": 1024,
+         "hop_length": 256,
+         "win_length": 1024,
+         "n_mel_channels": 80,
+         "mel_fmin": 0.0,
+         "mel_fmax": None,
+         "max_wav_value": 32768.0,
+         "segment_size": 8192,
+     },
+ }
+
+
+ def setup_logging(log_dir: Path) -> logging.Logger:
+     """Setup logging configuration"""
+     log_dir.mkdir(parents=True, exist_ok=True)
+
+     logging.basicConfig(
+         level=logging.INFO,
+         format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
+         handlers=[
+             logging.FileHandler(log_dir / "training.log"),
+             logging.StreamHandler(sys.stdout),
+         ],
+     )
+     return logging.getLogger(__name__)
+
+
+ class VITSTrainer:
+     """VITS Model Trainer for Indian Language TTS"""
+
+     def __init__(
+         self,
+         config: Dict[str, Any],
+         data_dir: Path,
+         output_dir: Path,
+         resume_checkpoint: Optional[Path] = None,
+     ):
+         self.config = config
+         self.data_dir = data_dir
+         self.output_dir = output_dir
+         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+         # Setup directories
+         self.checkpoint_dir = output_dir / "checkpoints"
+         self.log_dir = output_dir / "logs"
+         self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
+
+         # Setup logging
+         self.logger = setup_logging(self.log_dir)
+         self.writer = SummaryWriter(self.log_dir)
+
+         # Initialize model, optimizer, etc.
+         self._setup_model()
+         self._setup_optimizer()
+         self._setup_data()
+
+         self.global_step = 0
+         self.epoch = 0
+
+         if resume_checkpoint:
+             self._load_checkpoint(resume_checkpoint)
+
+     def _setup_model(self):
+         """Initialize VITS model components"""
+         self.logger.info("Initializing VITS model...")
+
+         # Note: In production, we use the TTS library's VITS implementation
+         # from TTS.tts.models.vits import Vits
+         # self.model = Vits(**self.config["model"])
+
+         self.logger.info(f"Model initialized on {self.device}")
+
+     def _setup_optimizer(self):
+         """Setup optimizer and learning rate scheduler"""
+         train_config = self.config["training"]
+
+         # Separate optimizers for generator and discriminator
+         # self.optimizer_g = optim.AdamW(
+         #     self.model.generator.parameters(),
+         #     lr=train_config["learning_rate"],
+         #     betas=train_config["betas"],
+         #     eps=train_config["eps"],
+         # )
+         # self.optimizer_d = optim.AdamW(
+         #     self.model.discriminator.parameters(),
+         #     lr=train_config["learning_rate"],
+         #     betas=train_config["betas"],
+         #     eps=train_config["eps"],
+         # )
+
+         self.logger.info("Optimizers initialized")
+
+     def _setup_data(self):
+         """Setup data loaders"""
+         self.logger.info(f"Loading dataset from {self.data_dir}")
+
+         # Note: Dataset loading for Indian languages
+         # self.train_dataset = TTSDataset(
+         #     self.data_dir / "train",
+         #     self.config["data"],
+         # )
+         # self.val_dataset = TTSDataset(
+         #     self.data_dir / "val",
+         #     self.config["data"],
+         # )
+
+         # self.train_loader = DataLoader(
+         #     self.train_dataset,
+         #     batch_size=self.config["training"]["batch_size"],
+         #     shuffle=True,
+         #     num_workers=4,
+         #     pin_memory=True,
+         # )
+
+         self.logger.info("Data loaders initialized")
+
+     def train_step(self, batch: Dict[str, torch.Tensor]) -> Dict[str, float]:
+         """Single training step"""
+         # Move batch to device
+         # text = batch["text"].to(self.device)
+         # text_lengths = batch["text_lengths"].to(self.device)
+         # mel = batch["mel"].to(self.device)
+         # mel_lengths = batch["mel_lengths"].to(self.device)
+         # audio = batch["audio"].to(self.device)
+
+         # Generator forward pass
+         # outputs = self.model(text, text_lengths, mel, mel_lengths)
+
+         # Compute losses
+         # loss_g = self._compute_generator_loss(outputs, batch)
+         # loss_d = self._compute_discriminator_loss(outputs, batch)
+
+         # Backward pass
+         # self.optimizer_g.zero_grad()
+         # loss_g.backward()
+         # self.optimizer_g.step()
+
+         # self.optimizer_d.zero_grad()
+         # loss_d.backward()
+         # self.optimizer_d.step()
+
+         return {"loss_g": 0.0, "loss_d": 0.0}
+
+     def train_epoch(self):
+         """Train for one epoch"""
+         # self.model.train()
+         epoch_losses = {"loss_g": 0.0, "loss_d": 0.0}
+
+         # for batch_idx, batch in enumerate(self.train_loader):
+         #     losses = self.train_step(batch)
+         #
+         #     for k, v in losses.items():
+         #         epoch_losses[k] += v
+         #
+         #     self.global_step += 1
+         #
+         #     # Logging
+         #     if self.global_step % 100 == 0:
+         #         self.logger.info(
+         #             f"Step {self.global_step}: loss_g={losses['loss_g']:.4f}, "
+         #             f"loss_d={losses['loss_d']:.4f}"
+         #         )
+         #
+         #     # Checkpoint
+         #     if self.global_step % self.config["training"]["checkpoint_interval"] == 0:
+         #         self._save_checkpoint()
+
+         return epoch_losses
+
+     def train(self):
+         """Main training loop"""
+         self.logger.info("Starting training...")
+
+         for epoch in range(self.epoch, self.config["training"]["epochs"]):
+             self.epoch = epoch
+             self.logger.info(f"Epoch {epoch + 1}/{self.config['training']['epochs']}")
+
+             losses = self.train_epoch()
+
+             # Log epoch metrics
+             self.writer.add_scalar("epoch/loss_g", losses["loss_g"], epoch)
+             self.writer.add_scalar("epoch/loss_d", losses["loss_d"], epoch)
+
+             # Validation
+             # if (epoch + 1) % 10 == 0:
+             #     self.validate()
+
+         self.logger.info("Training complete!")
+
+     def _save_checkpoint(self):
+         """Save training checkpoint"""
+         checkpoint_path = self.checkpoint_dir / f"checkpoint_{self.global_step}.pth"
+
+         # torch.save({
+         #     "model_state_dict": self.model.state_dict(),
+         #     "optimizer_g_state_dict": self.optimizer_g.state_dict(),
+         #     "optimizer_d_state_dict": self.optimizer_d.state_dict(),
+         #     "global_step": self.global_step,
+         #     "epoch": self.epoch,
+         #     "config": self.config,
+         # }, checkpoint_path)
+
+         self.logger.info(f"Checkpoint saved: {checkpoint_path}")
+
+     def _load_checkpoint(self, checkpoint_path: Path):
+         """Load training checkpoint"""
+         self.logger.info(f"Loading checkpoint: {checkpoint_path}")
+
+         # checkpoint = torch.load(checkpoint_path, map_location=self.device)
+         # self.model.load_state_dict(checkpoint["model_state_dict"])
+         # self.optimizer_g.load_state_dict(checkpoint["optimizer_g_state_dict"])
+         # self.optimizer_d.load_state_dict(checkpoint["optimizer_d_state_dict"])
+         # self.global_step = checkpoint["global_step"]
+         # self.epoch = checkpoint["epoch"]
+
+
+ def main():
+     parser = argparse.ArgumentParser(description="Train VITS model for Indian Language TTS")
+     parser.add_argument("--config", type=str, help="Path to config YAML file")
+     parser.add_argument("--data", type=str, required=True, help="Path to dataset directory")
+     parser.add_argument("--output", type=str, default="./output", help="Output directory")
+     parser.add_argument("--resume", type=str, help="Path to checkpoint to resume from")
+     parser.add_argument("--language", type=str, default="hindi", help="Target language")
+     parser.add_argument("--gender", type=str, default="female", choices=["male", "female"])
+
+     args = parser.parse_args()
+
+     # Load config
+     config = DEFAULT_CONFIG.copy()
+
+     # Initialize trainer
+     trainer = VITSTrainer(
+         config=config,
+         data_dir=Path(args.data),
+         output_dir=Path(args.output),
+         resume_checkpoint=Path(args.resume) if args.resume else None,
+     )
+
+     # Start training
+     trainer.train()
+
+
+ if __name__ == "__main__":
+     main()
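Note that `main()` currently ignores `--config` and always trains from `DEFAULT_CONFIG`. A minimal programmatic launch looks like the sketch below; the directory paths are placeholders, and the YAML overlay from the config sketch earlier could be applied to `config` first:

```python
from pathlib import Path

from train_vits import DEFAULT_CONFIG, VITSTrainer

config = dict(DEFAULT_CONFIG)                # optionally overlay a YAML config here
trainer = VITSTrainer(
    config=config,
    data_dir=Path("data/hindi_female"),      # placeholder dataset directory
    output_dir=Path("output/hindi_female"),
)
trainer.train()  # runs the loop; losses stay at 0.0 until the commented-out model code is filled in
```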