feat: refactor voice bootstrap logic and improve error handling in audio generation

This commit is contained in:
2026-03-12 23:47:52 +08:00
parent 233c4a9a98
commit 29799a8c7d

View File

@@ -469,71 +469,7 @@ async def parse_one_chapter(project_id: int, chapter_id: int, user: User, db) ->
async def _bootstrap_character_voices(segments, user, backend, backend_type: str, db: Session) -> None:
bootstrapped: set[int] = set()
for seg in segments:
char = crud.get_audiobook_character(db, seg.character_id)
if not char or not char.voice_design_id or char.voice_design_id in bootstrapped:
continue
bootstrapped.add(char.voice_design_id)
design = crud.get_voice_design(db, char.voice_design_id, user.id)
if not design:
continue
try:
if backend_type == "local" and not design.voice_cache_id:
from core.model_manager import ModelManager
from core.cache_manager import VoiceCacheManager
from utils.audio import process_ref_audio
import hashlib
ref_text = "你好,这是参考音频。"
ref_audio_bytes, _ = await backend.generate_voice_design({
"text": ref_text,
"language": "Auto",
"instruct": design.instruct or "",
"max_new_tokens": 512,
"temperature": 0.3,
"top_k": 10,
"top_p": 0.9,
"repetition_penalty": 1.05,
})
model_manager = await ModelManager.get_instance()
await model_manager.load_model("base")
_, tts = await model_manager.get_current_model()
ref_audio_array, ref_sr = process_ref_audio(ref_audio_bytes)
x_vector = tts.create_voice_clone_prompt(
ref_audio=(ref_audio_array, ref_sr),
ref_text=ref_text,
)
cache_manager = await VoiceCacheManager.get_instance()
ref_audio_hash = hashlib.sha256(ref_audio_bytes).hexdigest()
cache_id = await cache_manager.set_cache(
user.id, ref_audio_hash, x_vector,
{"ref_text": ref_text, "instruct": design.instruct},
db
)
design.voice_cache_id = cache_id
db.commit()
logger.info(f"Bootstrapped local voice cache: design_id={design.id}, cache_id={cache_id}")
elif backend_type == "aliyun" and not design.aliyun_voice_id:
from core.tts_service import AliyunTTSBackend
if isinstance(backend, AliyunTTSBackend):
voice_id = await backend._create_voice_design(
instruct=design.instruct or "",
preview_text="你好,这是参考音频。"
)
design.aliyun_voice_id = voice_id
db.commit()
logger.info(f"Bootstrapped aliyun voice_id: design_id={design.id}, voice_id={voice_id}")
except Exception as e:
logger.error(f"Failed to bootstrap voice for design_id={design.id}: {e}", exc_info=True)
pass
async def generate_project(project_id: int, user: User, db: Session, chapter_index: Optional[int] = None, cancel_event: Optional[asyncio.Event] = None, force: bool = False) -> None:
@@ -570,24 +506,9 @@ async def generate_project(project_id: int, user: User, db: Session, chapter_ind
output_base = Path(settings.OUTPUT_DIR) / "audiobook" / str(project_id) / "segments"
output_base.mkdir(parents=True, exist_ok=True)
from core.tts_service import TTSServiceFactory
from core.security import decrypt_api_key
backend_type = user.user_preferences.get("default_backend", "aliyun") if user.user_preferences else "aliyun"
user_api_key = None
if backend_type == "aliyun":
from db.crud import get_system_setting
encrypted = get_system_setting(db, "aliyun_api_key")
if encrypted:
user_api_key = decrypt_api_key(encrypted)
backend = await TTSServiceFactory.get_backend(backend_type, user_api_key)
await _bootstrap_character_voices(segments, user, backend, backend_type, db)
from core.tts_service import IndexTTS2Backend
for seg in segments:
# Check cancel event before each segment
if cancel_event and cancel_event.is_set():
logger.info(f"Generation cancelled for project {project_id}, stopping at segment {seg.id}")
break
@@ -608,75 +529,25 @@ async def generate_project(project_id: int, user: User, db: Session, chapter_ind
audio_filename = f"ch{seg.chapter_index:03d}_seg{seg.segment_index:04d}.wav"
audio_path = output_base / audio_filename
ref_audio_for_emo = design.ref_audio_path
if not ref_audio_for_emo:
ref_audio = design.ref_audio_path
if not ref_audio or not Path(ref_audio).exists():
preview_path = Path(settings.OUTPUT_DIR) / "audiobook" / str(project_id) / "previews" / f"char_{char.id}.wav"
if preview_path.exists():
ref_audio_for_emo = str(preview_path)
ref_audio = str(preview_path)
if not ref_audio or not Path(ref_audio).exists():
logger.error(f"No ref audio for char {char.id}, skipping segment {seg.id}")
crud.update_audiobook_segment_status(db, seg.id, "error")
continue
if seg.emo_text and ref_audio_for_emo and Path(ref_audio_for_emo).exists():
from core.tts_service import IndexTTS2Backend
indextts2 = IndexTTS2Backend()
audio_bytes = await indextts2.generate(
text=seg.text,
spk_audio_prompt=ref_audio_for_emo,
spk_audio_prompt=ref_audio,
output_path=str(audio_path),
emo_text=seg.emo_text,
emo_text=seg.emo_text or None,
emo_alpha=seg.emo_alpha if seg.emo_alpha is not None else 0.6,
)
elif backend_type == "aliyun":
if design.aliyun_voice_id:
audio_bytes, _ = await backend.generate_voice_design(
{"text": seg.text, "language": "zh"},
saved_voice_id=design.aliyun_voice_id
)
else:
audio_bytes, _ = await backend.generate_voice_design({
"text": seg.text,
"language": "zh",
"instruct": _get_gendered_instruct(char.gender, design.instruct),
})
else:
if design.voice_cache_id:
from core.cache_manager import VoiceCacheManager
cache_manager = await VoiceCacheManager.get_instance()
cache_result = await cache_manager.get_cache_by_id(design.voice_cache_id, db)
x_vector = cache_result['data'] if cache_result else None
if x_vector:
audio_bytes, _ = await backend.generate_voice_clone(
{
"text": seg.text,
"language": "Auto",
"max_new_tokens": 2048,
"temperature": 0.3,
"top_k": 10,
"top_p": 0.9,
"repetition_penalty": 1.05,
},
x_vector=x_vector
)
else:
audio_bytes, _ = await backend.generate_voice_design({
"text": seg.text,
"language": "Auto",
"instruct": _get_gendered_instruct(char.gender, design.instruct),
"max_new_tokens": 2048,
"temperature": 0.3,
"top_k": 10,
"top_p": 0.9,
"repetition_penalty": 1.05,
})
else:
audio_bytes, _ = await backend.generate_voice_design({
"text": seg.text,
"language": "Auto",
"instruct": _get_gendered_instruct(char.gender, design.instruct),
"max_new_tokens": 2048,
"temperature": 0.3,
"top_k": 10,
"top_p": 0.9,
"repetition_penalty": 1.05,
})
with open(audio_path, "wb") as f:
f.write(audio_bytes)
@@ -725,18 +596,7 @@ async def generate_single_segment(segment_id: int, user: User, db: Session) -> N
crud.update_audiobook_segment_status(db, segment_id, "generating")
try:
from core.tts_service import TTSServiceFactory
from core.security import decrypt_api_key
backend_type = user.user_preferences.get("default_backend", "aliyun") if user.user_preferences else "aliyun"
user_api_key = None
if backend_type == "aliyun":
from db.crud import get_system_setting
encrypted = get_system_setting(db, "aliyun_api_key")
if encrypted:
user_api_key = decrypt_api_key(encrypted)
backend = await TTSServiceFactory.get_backend(backend_type, user_api_key)
from core.tts_service import IndexTTS2Backend
char = crud.get_audiobook_character(db, seg.character_id)
if not char or not char.voice_design_id:
@@ -748,81 +608,28 @@ async def generate_single_segment(segment_id: int, user: User, db: Session) -> N
crud.update_audiobook_segment_status(db, segment_id, "error")
return
await _bootstrap_character_voices([seg], user, backend, backend_type, db)
db.refresh(design)
audio_filename = f"ch{seg.chapter_index:03d}_seg{seg.segment_index:04d}.wav"
audio_path = output_base / audio_filename
ref_audio_for_emo = design.ref_audio_path
if not ref_audio_for_emo:
ref_audio = design.ref_audio_path
if not ref_audio or not Path(ref_audio).exists():
preview_path = Path(settings.OUTPUT_DIR) / "audiobook" / str(seg.project_id) / "previews" / f"char_{char.id}.wav"
if preview_path.exists():
ref_audio_for_emo = str(preview_path)
ref_audio = str(preview_path)
if not ref_audio or not Path(ref_audio).exists():
logger.error(f"No ref audio for char {char.id}, skipping segment {segment_id}")
crud.update_audiobook_segment_status(db, segment_id, "error")
return
if seg.emo_text and ref_audio_for_emo and Path(ref_audio_for_emo).exists():
from core.tts_service import IndexTTS2Backend
indextts2 = IndexTTS2Backend()
audio_bytes = await indextts2.generate(
text=seg.text,
spk_audio_prompt=ref_audio_for_emo,
spk_audio_prompt=ref_audio,
output_path=str(audio_path),
emo_text=seg.emo_text,
emo_text=seg.emo_text or None,
emo_alpha=seg.emo_alpha if seg.emo_alpha is not None else 0.6,
)
elif backend_type == "aliyun":
if design.aliyun_voice_id:
audio_bytes, _ = await backend.generate_voice_design(
{"text": seg.text, "language": "zh"},
saved_voice_id=design.aliyun_voice_id
)
else:
audio_bytes, _ = await backend.generate_voice_design({
"text": seg.text,
"language": "zh",
"instruct": _get_gendered_instruct(char.gender, design.instruct),
})
else:
if design.voice_cache_id:
from core.cache_manager import VoiceCacheManager
cache_manager = await VoiceCacheManager.get_instance()
cache_result = await cache_manager.get_cache_by_id(design.voice_cache_id, db)
x_vector = cache_result['data'] if cache_result else None
if x_vector:
audio_bytes, _ = await backend.generate_voice_clone(
{
"text": seg.text,
"language": "Auto",
"max_new_tokens": 2048,
"temperature": 0.3,
"top_k": 10,
"top_p": 0.9,
"repetition_penalty": 1.05,
},
x_vector=x_vector
)
else:
audio_bytes, _ = await backend.generate_voice_design({
"text": seg.text,
"language": "Auto",
"instruct": _get_gendered_instruct(char.gender, design.instruct),
"max_new_tokens": 2048,
"temperature": 0.3,
"top_k": 10,
"top_p": 0.9,
"repetition_penalty": 1.05,
})
else:
audio_bytes, _ = await backend.generate_voice_design({
"text": seg.text,
"language": "Auto",
"instruct": _get_gendered_instruct(char.gender, design.instruct),
"max_new_tokens": 2048,
"temperature": 0.3,
"top_k": 10,
"top_p": 0.9,
"repetition_penalty": 1.05,
})
with open(audio_path, "wb") as f:
f.write(audio_bytes)