limkang commited on
Commit
82a65d8
ยท
verified ยท
1 Parent(s): 085b201

Upload main.py

Browse files
Files changed (1) hide show
  1. main.py +108 -0
main.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding: utf-8
2
+ import os
3
+ import io
4
+ import torch
5
+ import tempfile
6
+ from fastapi import FastAPI, HTTPException, Form, UploadFile, File
7
+ from fastapi.responses import StreamingResponse
8
+
9
+ # OpenVoice V2 ๊ด€๋ จ ๋ผ์ด๋ธŒ๋Ÿฌ๋ฆฌ ์ž„ํฌํŠธ
10
+ from openvoice import se_extractor
11
+ from openvoice.api import ToneColorConverter
12
+
13
+ # MeloTTS ๊ด€๋ จ ๋ผ์ด๋ธŒ๋Ÿฌ๋ฆฌ ์ž„ํฌํŠธ
14
+ from melo.api import TTS
15
+
16
+ # -------------------------------------------------------------------
17
+ # 1. FastAPI ์•ฑ ์ดˆ๊ธฐํ™” ๋ฐ ๋ชจ๋ธ ๋กœ๋“œ
18
+ # -------------------------------------------------------------------
19
+ app = FastAPI()
20
+
21
+ print("๐Ÿš€ Loading models...")
22
+ try:
23
+ device = "cuda" if torch.cuda.is_available() else "cpu"
24
+
25
+ # OpenVoice ๋ชจ๋ธ ๋กœ๋“œ
26
+ # Tone Color Extractor: ์Œ์ƒ‰ ํŠน์ง•์„ ์ถ”์ถœํ•˜๋Š” ๋ชจ๋ธ
27
+ # Tone Color Converter: ์Œ์ƒ‰์„ ๋ณ€ํ™˜ํ•˜๋Š” ๋ชจ๋ธ
28
+ print("Loading OpenVoice V2 models...")
29
+ tone_color_converter = ToneColorConverter('checkpoints/converter', device=device)
30
+ print("โœ… OpenVoice V2 loaded.")
31
+
32
+ # Melotts ๋ชจ๋ธ ๋กœ๋“œ (ํ•œ๊ตญ์–ด ์ง€์›)
33
+ print("Loading Melotts model...")
34
+ melotts_model = TTS(language='KR', device=device)
35
+ speaker_ids = melotts_model.hps.data.spk2id
36
+ print("โœ… Melotts loaded.")
37
+
38
+ except Exception as ex:
39
+ print(f"โŒ Failed to load models. Error: {ex}")
40
+ tone_color_converter = None
41
+ melotts_model = None
42
+
43
+ # -------------------------------------------------------------------
44
+ # 2. API ์—”๋“œํฌ์ธํŠธ ์ƒ์„ฑ
45
+ # -------------------------------------------------------------------
46
+ @app.post("/generate-cloned-speech/")
47
+ async def generate_cloned_speech(
48
+ text: str = Form(...),
49
+ reference_audio: UploadFile = File(...)
50
+ ):
51
+ if not tone_color_converter or not melotts_model:
52
+ raise HTTPException(status_code=500, detail="Models are not loaded.")
53
+
54
+ # ์ž„์‹œ ํŒŒ์ผ ๊ฒฝ๋กœ๋ฅผ ๊ด€๋ฆฌํ•˜๊ธฐ ์œ„ํ•œ ๋ณ€์ˆ˜
55
+ reference_path = None
56
+ source_path = None
57
+ save_path = None
58
+
59
+ try:
60
+ # 1. ์ฐธ์กฐ ์˜ค๋””์˜ค(๋ชฉ์†Œ๋ฆฌ ์ฃผ์ธ)๋ฅผ ์ž„์‹œ ํŒŒ์ผ๋กœ ์ €์žฅ
61
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_ref_file:
62
+ content = await reference_audio.read()
63
+ temp_ref_file.write(content)
64
+ reference_path = temp_ref_file.name
65
+
66
+ # 2. ์ฐธ์กฐ ์˜ค๋””์˜ค์—์„œ ์Œ์ƒ‰ ํŠน์ง•(Tone Color) ์ถ”์ถœ
67
+ target_se, audio_name = se_extractor.get_se(reference_path, tone_color_converter, target_dir='_outputs/form_clone', vad=True)
68
+
69
+ # 3. Melotts๋ฅผ ์‚ฌ์šฉํ•ด ํ…์ŠคํŠธ๋กœ ๊ธฐ๋ณธ(Source) ์Œ์„ฑ ์ƒ์„ฑ
70
+ # ์†๋„ ์กฐ์ ˆ ๊ฐ€๋Šฅ (speed)
71
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_src_file:
72
+ source_path = temp_src_file.name
73
+
74
+ melotts_model.tts_to_file(text, speaker_ids['KR'], source_path, speed=1.0)
75
+
76
+ # 4. OpenVoice๋ฅผ ์‚ฌ์šฉํ•ด ๊ธฐ๋ณธ ์Œ์„ฑ์— ์ถ”์ถœํ•œ ์Œ์ƒ‰์„ ์ž…ํž˜ (๋ณ€ํ™˜)
77
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_save_file:
78
+ save_path = temp_save_file.name
79
+
80
+ # ํ•ต์‹ฌ ๋ณ€ํ™˜ ๊ณผ์ •
81
+ tone_color_converter.convert(
82
+ audio_src_path=source_path,
83
+ src_se=None, # ์†Œ์Šค ์Œ์„ฑ์˜ ํŠน์ง•์€ ์‚ฌ์šฉ ์•ˆ ํ•จ
84
+ tgt_se=target_se, # ๋ชฉํ‘œ(์ฐธ์กฐ) ์Œ์„ฑ์˜ ํŠน์ง•์„ ์‚ฌ์šฉ
85
+ output_path=save_path,
86
+ message="@MyShell"
87
+ )
88
+
89
+ # 5. ์ƒ์„ฑ๋œ ํŒŒ์ผ์„ ์ฝ์–ด ์ŠคํŠธ๋ฆฌ๋ฐ์œผ๋กœ ๋ฐ˜ํ™˜
90
+ with open(save_path, 'rb') as f:
91
+ audio_data = f.read()
92
+
93
+ return StreamingResponse(
94
+ io.BytesIO(audio_data),
95
+ media_type="audio/wav",
96
+ headers={"Content-Disposition": "inline; filename=cloned_speech.wav"}
97
+ )
98
+
99
+ except Exception as e:
100
+ error_msg = f"Error during speech generation: {str(e)}"
101
+ print(f"โŒ {error_msg}")
102
+ raise HTTPException(status_code=500, detail=error_msg)
103
+
104
+ finally:
105
+ # 6. ์ž‘์—…์ด ๋๋‚˜๋ฉด ๋ชจ๋“  ์ž„์‹œ ํŒŒ์ผ ์‚ญ์ œ
106
+ for path in [reference_path, source_path, save_path]:
107
+ if path and os.path.exists(path):
108
+ os.remove(path)