refactor: rename canto-backend → backend, canto-frontend → frontend
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
286
backend/core/tts_service.py
Normal file
286
backend/core/tts_service.py
Normal file
@@ -0,0 +1,286 @@
|
||||
import asyncio
|
||||
import functools
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Tuple, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TTSBackend(ABC):
|
||||
@abstractmethod
|
||||
async def generate_custom_voice(self, params: dict) -> Tuple[bytes, int]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def generate_voice_design(self, params: dict) -> Tuple[bytes, int]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def generate_voice_clone(self, params: dict, ref_audio_bytes: bytes) -> Tuple[bytes, int]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def health_check(self) -> dict:
|
||||
pass
|
||||
|
||||
|
||||
class LocalTTSBackend(TTSBackend):
|
||||
def __init__(self):
|
||||
self.model_manager = None
|
||||
# Add a lock to prevent concurrent VRAM contention and CUDA errors on local GPU models
|
||||
self._gpu_lock = asyncio.Lock()
|
||||
|
||||
async def initialize(self):
|
||||
from core.model_manager import ModelManager
|
||||
self.model_manager = await ModelManager.get_instance()
|
||||
|
||||
async def generate_custom_voice(self, params: dict) -> Tuple[bytes, int]:
|
||||
await self.model_manager.load_model("custom-voice")
|
||||
_, tts = await self.model_manager.get_current_model()
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
async with self._gpu_lock:
|
||||
result = await loop.run_in_executor(
|
||||
None,
|
||||
functools.partial(
|
||||
tts.generate_custom_voice,
|
||||
text=params['text'],
|
||||
language=params['language'],
|
||||
speaker=params['speaker'],
|
||||
instruct=params.get('instruct', ''),
|
||||
max_new_tokens=params['max_new_tokens'],
|
||||
temperature=params['temperature'],
|
||||
top_k=params['top_k'],
|
||||
top_p=params['top_p'],
|
||||
repetition_penalty=params['repetition_penalty'],
|
||||
)
|
||||
)
|
||||
|
||||
import numpy as np
|
||||
wavs, sample_rate = result if isinstance(result, tuple) else (result, 24000)
|
||||
audio_data = wavs[0] if isinstance(wavs, list) else wavs
|
||||
return self._numpy_to_bytes(audio_data), sample_rate
|
||||
|
||||
async def generate_voice_design(self, params: dict) -> Tuple[bytes, int]:
|
||||
await self.model_manager.load_model("voice-design")
|
||||
_, tts = await self.model_manager.get_current_model()
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
async with self._gpu_lock:
|
||||
result = await loop.run_in_executor(
|
||||
None,
|
||||
functools.partial(
|
||||
tts.generate_voice_design,
|
||||
text=params['text'],
|
||||
language=params['language'],
|
||||
instruct=params['instruct'],
|
||||
max_new_tokens=params['max_new_tokens'],
|
||||
temperature=params['temperature'],
|
||||
top_k=params['top_k'],
|
||||
top_p=params['top_p'],
|
||||
repetition_penalty=params['repetition_penalty'],
|
||||
)
|
||||
)
|
||||
|
||||
import numpy as np
|
||||
wavs, sample_rate = result if isinstance(result, tuple) else (result, 24000)
|
||||
audio_data = wavs[0] if isinstance(wavs, list) else wavs
|
||||
return self._numpy_to_bytes(audio_data), sample_rate
|
||||
|
||||
async def generate_voice_clone(self, params: dict, ref_audio_bytes: bytes = None, x_vector=None) -> Tuple[bytes, int]:
|
||||
from utils.audio import process_ref_audio
|
||||
|
||||
await self.model_manager.load_model("base")
|
||||
_, tts = await self.model_manager.get_current_model()
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
async with self._gpu_lock:
|
||||
if x_vector is None:
|
||||
if ref_audio_bytes is None:
|
||||
raise ValueError("Either ref_audio_bytes or x_vector must be provided")
|
||||
|
||||
ref_audio_array, ref_sr = process_ref_audio(ref_audio_bytes)
|
||||
|
||||
x_vector = await loop.run_in_executor(
|
||||
None,
|
||||
functools.partial(
|
||||
tts.create_voice_clone_prompt,
|
||||
ref_audio=(ref_audio_array, ref_sr),
|
||||
ref_text=params.get('ref_text', ''),
|
||||
x_vector_only_mode=False,
|
||||
)
|
||||
)
|
||||
|
||||
wavs, sample_rate = await loop.run_in_executor(
|
||||
None,
|
||||
functools.partial(
|
||||
tts.generate_voice_clone,
|
||||
text=params['text'],
|
||||
language=params['language'],
|
||||
voice_clone_prompt=x_vector,
|
||||
max_new_tokens=params['max_new_tokens'],
|
||||
temperature=params['temperature'],
|
||||
top_k=params['top_k'],
|
||||
top_p=params['top_p'],
|
||||
repetition_penalty=params['repetition_penalty'],
|
||||
)
|
||||
)
|
||||
|
||||
import numpy as np
|
||||
audio_data = wavs[0] if isinstance(wavs, list) else wavs
|
||||
if isinstance(audio_data, list):
|
||||
audio_data = np.array(audio_data)
|
||||
return self._numpy_to_bytes(audio_data), sample_rate
|
||||
|
||||
async def health_check(self) -> dict:
|
||||
return {
|
||||
"available": self.model_manager is not None,
|
||||
"current_model": self.model_manager.current_model_name if self.model_manager else None
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _numpy_to_bytes(audio_array) -> bytes:
|
||||
import numpy as np
|
||||
import io
|
||||
import wave
|
||||
|
||||
if isinstance(audio_array, list):
|
||||
audio_array = np.array(audio_array)
|
||||
|
||||
audio_array = np.clip(audio_array, -1.0, 1.0)
|
||||
audio_int16 = (audio_array * 32767).astype(np.int16)
|
||||
|
||||
buffer = io.BytesIO()
|
||||
with wave.open(buffer, 'wb') as wav_file:
|
||||
wav_file.setnchannels(1)
|
||||
wav_file.setsampwidth(2)
|
||||
wav_file.setframerate(24000)
|
||||
wav_file.writeframes(audio_int16.tobytes())
|
||||
|
||||
buffer.seek(0)
|
||||
return buffer.read()
|
||||
|
||||
|
||||
class IndexTTS2Backend:
|
||||
_gpu_lock = asyncio.Lock()
|
||||
|
||||
# Level 10 = these raw weights. Scale linearly: level N → N/10 * max
|
||||
EMO_LEVEL_MAX: dict[str, float] = {
|
||||
"开心": 0.75, "happy": 0.75,
|
||||
"愤怒": 0.08, "angry": 0.08,
|
||||
"悲伤": 0.90, "sad": 0.90,
|
||||
"恐惧": 0.10, "fear": 0.10,
|
||||
"厌恶": 0.50, "hate": 0.50,
|
||||
"低沉": 0.35, "low": 0.35,
|
||||
"惊讶": 0.35, "surprise": 0.35,
|
||||
}
|
||||
|
||||
# Emotion keyword → index mapping
|
||||
# Order: [happy, angry, sad, fear, hate, low, surprise, neutral]
|
||||
_EMO_KEYWORDS = [
|
||||
['喜', '开心', '快乐', '高兴', '欢乐', '愉快', 'happy', '热情', '兴奋', '愉悦', '激动'],
|
||||
['怒', '愤怒', '生气', '恼', 'angry', '气愤', '愤慨'],
|
||||
['哀', '悲伤', '难过', '忧郁', '伤心', '悲', 'sad', '感慨', '沉重', '沉痛', '哭'],
|
||||
['惧', '恐惧', '害怕', '恐', 'fear', '担心', '紧张'],
|
||||
['厌恶', '厌', 'hate', '讨厌', '反感'],
|
||||
['低落', '沮丧', '消沉', 'low', '抑郁', '颓废'],
|
||||
['惊喜', '惊讶', '意外', 'surprise', '惊', '吃惊', '震惊'],
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def _emo_text_to_vector(emo_text: str) -> Optional[list]:
|
||||
tokens = [t.strip() for t in emo_text.split('+') if t.strip()]
|
||||
matched = []
|
||||
for tok in tokens:
|
||||
if ':' in tok:
|
||||
name_part, w_str = tok.rsplit(':', 1)
|
||||
try:
|
||||
weight: Optional[float] = float(w_str)
|
||||
except ValueError:
|
||||
weight = None
|
||||
else:
|
||||
name_part = tok
|
||||
weight = None
|
||||
name_lower = name_part.lower().strip()
|
||||
for idx, words in enumerate(IndexTTS2Backend._EMO_KEYWORDS):
|
||||
for word in words:
|
||||
if word in name_lower:
|
||||
matched.append((idx, weight))
|
||||
break
|
||||
if not matched:
|
||||
return None
|
||||
vec = [0.0] * 8
|
||||
has_explicit = any(w is not None for _, w in matched)
|
||||
if has_explicit:
|
||||
for idx, w in matched:
|
||||
vec[idx] = w if w is not None else 0.5
|
||||
else:
|
||||
score = 0.8 if len(matched) == 1 else 0.5
|
||||
for idx, _ in matched:
|
||||
vec[idx] = 0.2 if idx == 1 else score
|
||||
return vec
|
||||
|
||||
async def generate(
|
||||
self,
|
||||
text: str,
|
||||
spk_audio_prompt: str,
|
||||
output_path: str,
|
||||
emo_text: str = None,
|
||||
emo_alpha: float = 0.6,
|
||||
) -> bytes:
|
||||
from core.model_manager import IndexTTS2ModelManager
|
||||
manager = await IndexTTS2ModelManager.get_instance()
|
||||
tts = await manager.get_model()
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
emo_vector = None
|
||||
if emo_text and len(emo_text.strip()) > 0:
|
||||
resolved_emo_text = emo_text
|
||||
resolved_emo_alpha = emo_alpha
|
||||
if emo_alpha is not None and emo_alpha > 1:
|
||||
level = min(10, max(1, round(emo_alpha)))
|
||||
name = emo_text.strip()
|
||||
max_val = self.EMO_LEVEL_MAX.get(name)
|
||||
if max_val is None:
|
||||
name_lower = name.lower()
|
||||
for key, val in self.EMO_LEVEL_MAX.items():
|
||||
if key in name_lower or name_lower in key:
|
||||
max_val = val
|
||||
break
|
||||
if max_val is None:
|
||||
max_val = 0.20
|
||||
weight = round(level / 10 * max_val, 4)
|
||||
resolved_emo_text = f"{name}:{weight}"
|
||||
resolved_emo_alpha = 1.0
|
||||
raw_vector = self._emo_text_to_vector(resolved_emo_text)
|
||||
if raw_vector is not None:
|
||||
emo_vector = [v * resolved_emo_alpha for v in raw_vector]
|
||||
logger.info(f"IndexTTS2 emo_text={repr(emo_text)} emo_alpha={emo_alpha} → resolved={repr(resolved_emo_text)} emo_vector={emo_vector}")
|
||||
|
||||
async with IndexTTS2Backend._gpu_lock:
|
||||
await loop.run_in_executor(
|
||||
None,
|
||||
functools.partial(
|
||||
tts.infer,
|
||||
spk_audio_prompt=spk_audio_prompt,
|
||||
text=text,
|
||||
output_path=output_path,
|
||||
emo_vector=emo_vector,
|
||||
emo_alpha=1.0,
|
||||
)
|
||||
)
|
||||
with open(output_path, 'rb') as f:
|
||||
return f.read()
|
||||
|
||||
|
||||
class TTSServiceFactory:
|
||||
_local_backend: Optional[LocalTTSBackend] = None
|
||||
|
||||
@classmethod
|
||||
async def get_backend(cls, backend_type: str = None, user_api_key: Optional[str] = None) -> TTSBackend:
|
||||
if cls._local_backend is None:
|
||||
cls._local_backend = LocalTTSBackend()
|
||||
await cls._local_backend.initialize()
|
||||
return cls._local_backend
|
||||
Reference in New Issue
Block a user