feat: Integrate IndexTTS2 model and update related schemas and frontend components
This commit is contained in:
@@ -573,6 +573,16 @@ async def generate_project(project_id: int, user: User, db: Session, chapter_ind
|
||||
"language": "zh",
|
||||
"instruct": _get_gendered_instruct(char.gender, design.instruct),
|
||||
})
|
||||
elif char.use_indextts2 and design.ref_audio_path and Path(design.ref_audio_path).exists():
|
||||
from core.tts_service import IndexTTS2Backend
|
||||
indextts2 = IndexTTS2Backend()
|
||||
audio_bytes = await indextts2.generate(
|
||||
text=seg.text,
|
||||
spk_audio_prompt=design.ref_audio_path,
|
||||
output_path=str(audio_path),
|
||||
emo_text=char.instruct or None,
|
||||
emo_alpha=0.6,
|
||||
)
|
||||
else:
|
||||
if design.voice_cache_id:
|
||||
from core.cache_manager import VoiceCacheManager
|
||||
|
||||
@@ -121,3 +121,40 @@ class ModelManager:
|
||||
}
|
||||
for name, path in self.MODEL_PATHS.items()
|
||||
}
|
||||
|
||||
|
||||
class IndexTTS2ModelManager:
|
||||
_instance = None
|
||||
_lock = asyncio.Lock()
|
||||
|
||||
def __init__(self):
|
||||
self.tts = None
|
||||
|
||||
@classmethod
|
||||
async def get_instance(cls):
|
||||
if cls._instance is None:
|
||||
async with cls._lock:
|
||||
if cls._instance is None:
|
||||
cls._instance = cls()
|
||||
return cls._instance
|
||||
|
||||
async def get_model(self):
|
||||
if self.tts is None:
|
||||
await self._load()
|
||||
return self.tts
|
||||
|
||||
async def _load(self):
|
||||
from indextts.infer_indextts2 import IndexTTS2
|
||||
from pathlib import Path
|
||||
model_dir = Path(settings.MODEL_BASE_PATH) / "IndexTTS2"
|
||||
loop = asyncio.get_event_loop()
|
||||
self.tts = await loop.run_in_executor(
|
||||
None,
|
||||
lambda: IndexTTS2(
|
||||
cfg_path=str(model_dir / "config.yaml"),
|
||||
model_dir=str(model_dir),
|
||||
is_fp16=False,
|
||||
use_cuda_kernel=False,
|
||||
use_deepspeed=False,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -427,6 +427,73 @@ class AliyunTTSBackend(TTSBackend):
|
||||
return local_speaker
|
||||
|
||||
|
||||
class IndexTTS2Backend:
|
||||
_gpu_lock = asyncio.Lock()
|
||||
|
||||
# Emotion keyword → index mapping
|
||||
# Order: [happy, angry, sad, fear, hate, low, surprise, neutral]
|
||||
_EMO_KEYWORDS = [
|
||||
['喜', '开心', '快乐', '高兴', '欢乐', '愉快', 'happy', '热情', '兴奋', '愉悦', '激动'],
|
||||
['怒', '愤怒', '生气', '恼', 'angry', '气愤', '愤慨'],
|
||||
['哀', '悲伤', '难过', '忧郁', '伤心', '悲', 'sad', '感慨', '沉重', '沉痛', '哭'],
|
||||
['惧', '恐惧', '害怕', '恐', 'fear', '担心', '紧张'],
|
||||
['厌恶', '厌', 'hate', '讨厌', '反感'],
|
||||
['低落', '沮丧', '消沉', 'low', '抑郁', '颓废'],
|
||||
['惊喜', '惊讶', '意外', 'surprise', '惊', '吃惊', '震惊'],
|
||||
['自然', '平静', '中性', '平和', 'neutral', '平淡', '冷静', '稳定'],
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def _emo_text_to_vector(emo_text: str) -> Optional[list]:
|
||||
text = emo_text.lower()
|
||||
matched = []
|
||||
for idx, words in enumerate(IndexTTS2Backend._EMO_KEYWORDS):
|
||||
for word in words:
|
||||
if word in text:
|
||||
matched.append(idx)
|
||||
break
|
||||
if not matched:
|
||||
return None
|
||||
vec = [0.0] * 8
|
||||
score = 0.8 if len(matched) == 1 else 0.5
|
||||
for idx in matched:
|
||||
vec[idx] = score
|
||||
return vec
|
||||
|
||||
async def generate(
|
||||
self,
|
||||
text: str,
|
||||
spk_audio_prompt: str,
|
||||
output_path: str,
|
||||
emo_text: str = None,
|
||||
emo_alpha: float = 0.6,
|
||||
) -> bytes:
|
||||
from core.model_manager import IndexTTS2ModelManager
|
||||
manager = await IndexTTS2ModelManager.get_instance()
|
||||
tts = await manager.get_model()
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
emo_vector = None
|
||||
if emo_text and len(emo_text.strip()) > 0:
|
||||
emo_vector = self._emo_text_to_vector(emo_text)
|
||||
logger.info(f"IndexTTS2 emo_text={repr(emo_text)} → emo_vector={emo_vector}")
|
||||
|
||||
async with IndexTTS2Backend._gpu_lock:
|
||||
await loop.run_in_executor(
|
||||
None,
|
||||
functools.partial(
|
||||
tts.infer,
|
||||
spk_audio_prompt=spk_audio_prompt,
|
||||
text=text,
|
||||
output_path=output_path,
|
||||
emo_vector=emo_vector,
|
||||
emo_alpha=1.0,
|
||||
)
|
||||
)
|
||||
with open(output_path, 'rb') as f:
|
||||
return f.read()
|
||||
|
||||
|
||||
class TTSServiceFactory:
|
||||
_local_backend: Optional[LocalTTSBackend] = None
|
||||
_aliyun_backend: Optional[AliyunTTSBackend] = None
|
||||
|
||||
Reference in New Issue
Block a user