Create handler.py
Browse files
Add handler.py for Endpoint deployment
- handler.py +31 -0
handler.py
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from fish_speech.models.fish_speech import FishSpeech
|
| 3 |
+
from fish_speech.inference import infer
|
| 4 |
+
import io
|
| 5 |
+
import base64
|
| 6 |
+
import soundfile as sf
|
| 7 |
+
|
| 8 |
+
# Load the model
|
| 9 |
+
# Module-level load: runs once when this module is imported; predict() reuses this instance.
model = FishSpeech.from_pretrained('fishaudio/fish-speech-1.5')
|
| 10 |
+
|
| 11 |
+
def predict(inputs: dict):
    """Synthesize audio for the request text and return it as base64-encoded WAV.

    Args:
        inputs: Request payload. The text is read from its ``'inputs'`` key;
            ``'Hello world'`` is used when the key is absent.

    Returns:
        A dict ``{"audio": <str>}`` whose value is the base64 text of the WAV
        bytes produced for the request.

    A ``[singing]`` tag anywhere in the text (matched case-insensitively)
    selects ``mode="singing"``; every occurrence of the tag is removed before
    the text is passed to the model. Without the tag, ``mode="speech"`` is used.
    """
    # Local import keeps this block self-contained; the file's import header
    # does not bring in ``re``.
    import re

    text = inputs.get('inputs', 'Hello world')

    # Support the [singing] tag.
    # Detection and removal share one case-insensitive pattern, so a tag such
    # as "[Singing]" is both recognized and stripped (previously it was
    # recognized via text.lower() but left in the text by a case-sensitive
    # replace). re.ASCII restricts case folding to ASCII letters, matching
    # what the original lower()-based check accepted.
    cleaned, tag_count = re.subn(
        r"\[singing\]", "", text, flags=re.IGNORECASE | re.ASCII
    )
    if tag_count:
        mode = "singing"
        text = cleaned
    else:
        mode = "speech"

    # Generate audio.
    audio = infer(model, text, mode=mode)

    # Convert to base64 WAV.
    # NOTE(review): the sample rate is hard-coded to 24000 — confirm it matches
    # what infer() actually produces for this checkpoint.
    # NOTE(review): assumes audio is a tensor laid out as (frames,) or
    # (frames, channels), as soundfile expects — TODO confirm.
    buffer = io.BytesIO()
    sf.write(buffer, audio.cpu().numpy(), 24000, format='WAV')
    audio_b64 = base64.b64encode(buffer.getvalue()).decode()

    return {"audio": audio_b64}
|
| 29 |
+
|
| 30 |
+
def query(payload):
    """Forward *payload* to predict() and hand back its result unchanged."""
    response = predict(payload)
    return response
|