feat: refactor voice bootstrap logic and improve error handling in audio generation
This commit is contained in:
@@ -469,71 +469,7 @@ async def parse_one_chapter(project_id: int, chapter_id: int, user: User, db) ->
|
|||||||
|
|
||||||
|
|
||||||
async def _bootstrap_character_voices(segments, user, backend, backend_type: str, db: Session) -> None:
|
async def _bootstrap_character_voices(segments, user, backend, backend_type: str, db: Session) -> None:
|
||||||
bootstrapped: set[int] = set()
|
pass
|
||||||
|
|
||||||
for seg in segments:
|
|
||||||
char = crud.get_audiobook_character(db, seg.character_id)
|
|
||||||
if not char or not char.voice_design_id or char.voice_design_id in bootstrapped:
|
|
||||||
continue
|
|
||||||
bootstrapped.add(char.voice_design_id)
|
|
||||||
|
|
||||||
design = crud.get_voice_design(db, char.voice_design_id, user.id)
|
|
||||||
if not design:
|
|
||||||
continue
|
|
||||||
|
|
||||||
try:
|
|
||||||
if backend_type == "local" and not design.voice_cache_id:
|
|
||||||
from core.model_manager import ModelManager
|
|
||||||
from core.cache_manager import VoiceCacheManager
|
|
||||||
from utils.audio import process_ref_audio
|
|
||||||
import hashlib
|
|
||||||
|
|
||||||
ref_text = "你好,这是参考音频。"
|
|
||||||
ref_audio_bytes, _ = await backend.generate_voice_design({
|
|
||||||
"text": ref_text,
|
|
||||||
"language": "Auto",
|
|
||||||
"instruct": design.instruct or "",
|
|
||||||
"max_new_tokens": 512,
|
|
||||||
"temperature": 0.3,
|
|
||||||
"top_k": 10,
|
|
||||||
"top_p": 0.9,
|
|
||||||
"repetition_penalty": 1.05,
|
|
||||||
})
|
|
||||||
|
|
||||||
model_manager = await ModelManager.get_instance()
|
|
||||||
await model_manager.load_model("base")
|
|
||||||
_, tts = await model_manager.get_current_model()
|
|
||||||
|
|
||||||
ref_audio_array, ref_sr = process_ref_audio(ref_audio_bytes)
|
|
||||||
x_vector = tts.create_voice_clone_prompt(
|
|
||||||
ref_audio=(ref_audio_array, ref_sr),
|
|
||||||
ref_text=ref_text,
|
|
||||||
)
|
|
||||||
|
|
||||||
cache_manager = await VoiceCacheManager.get_instance()
|
|
||||||
ref_audio_hash = hashlib.sha256(ref_audio_bytes).hexdigest()
|
|
||||||
cache_id = await cache_manager.set_cache(
|
|
||||||
user.id, ref_audio_hash, x_vector,
|
|
||||||
{"ref_text": ref_text, "instruct": design.instruct},
|
|
||||||
db
|
|
||||||
)
|
|
||||||
design.voice_cache_id = cache_id
|
|
||||||
db.commit()
|
|
||||||
logger.info(f"Bootstrapped local voice cache: design_id={design.id}, cache_id={cache_id}")
|
|
||||||
|
|
||||||
elif backend_type == "aliyun" and not design.aliyun_voice_id:
|
|
||||||
from core.tts_service import AliyunTTSBackend
|
|
||||||
if isinstance(backend, AliyunTTSBackend):
|
|
||||||
voice_id = await backend._create_voice_design(
|
|
||||||
instruct=design.instruct or "",
|
|
||||||
preview_text="你好,这是参考音频。"
|
|
||||||
)
|
|
||||||
design.aliyun_voice_id = voice_id
|
|
||||||
db.commit()
|
|
||||||
logger.info(f"Bootstrapped aliyun voice_id: design_id={design.id}, voice_id={voice_id}")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Failed to bootstrap voice for design_id={design.id}: {e}", exc_info=True)
|
|
||||||
|
|
||||||
|
|
||||||
async def generate_project(project_id: int, user: User, db: Session, chapter_index: Optional[int] = None, cancel_event: Optional[asyncio.Event] = None, force: bool = False) -> None:
|
async def generate_project(project_id: int, user: User, db: Session, chapter_index: Optional[int] = None, cancel_event: Optional[asyncio.Event] = None, force: bool = False) -> None:
|
||||||
@@ -570,24 +506,9 @@ async def generate_project(project_id: int, user: User, db: Session, chapter_ind
|
|||||||
output_base = Path(settings.OUTPUT_DIR) / "audiobook" / str(project_id) / "segments"
|
output_base = Path(settings.OUTPUT_DIR) / "audiobook" / str(project_id) / "segments"
|
||||||
output_base.mkdir(parents=True, exist_ok=True)
|
output_base.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
from core.tts_service import TTSServiceFactory
|
from core.tts_service import IndexTTS2Backend
|
||||||
from core.security import decrypt_api_key
|
|
||||||
|
|
||||||
backend_type = user.user_preferences.get("default_backend", "aliyun") if user.user_preferences else "aliyun"
|
|
||||||
|
|
||||||
user_api_key = None
|
|
||||||
if backend_type == "aliyun":
|
|
||||||
from db.crud import get_system_setting
|
|
||||||
encrypted = get_system_setting(db, "aliyun_api_key")
|
|
||||||
if encrypted:
|
|
||||||
user_api_key = decrypt_api_key(encrypted)
|
|
||||||
|
|
||||||
backend = await TTSServiceFactory.get_backend(backend_type, user_api_key)
|
|
||||||
|
|
||||||
await _bootstrap_character_voices(segments, user, backend, backend_type, db)
|
|
||||||
|
|
||||||
for seg in segments:
|
for seg in segments:
|
||||||
# Check cancel event before each segment
|
|
||||||
if cancel_event and cancel_event.is_set():
|
if cancel_event and cancel_event.is_set():
|
||||||
logger.info(f"Generation cancelled for project {project_id}, stopping at segment {seg.id}")
|
logger.info(f"Generation cancelled for project {project_id}, stopping at segment {seg.id}")
|
||||||
break
|
break
|
||||||
@@ -608,75 +529,25 @@ async def generate_project(project_id: int, user: User, db: Session, chapter_ind
|
|||||||
audio_filename = f"ch{seg.chapter_index:03d}_seg{seg.segment_index:04d}.wav"
|
audio_filename = f"ch{seg.chapter_index:03d}_seg{seg.segment_index:04d}.wav"
|
||||||
audio_path = output_base / audio_filename
|
audio_path = output_base / audio_filename
|
||||||
|
|
||||||
ref_audio_for_emo = design.ref_audio_path
|
ref_audio = design.ref_audio_path
|
||||||
if not ref_audio_for_emo:
|
if not ref_audio or not Path(ref_audio).exists():
|
||||||
preview_path = Path(settings.OUTPUT_DIR) / "audiobook" / str(project_id) / "previews" / f"char_{char.id}.wav"
|
preview_path = Path(settings.OUTPUT_DIR) / "audiobook" / str(project_id) / "previews" / f"char_{char.id}.wav"
|
||||||
if preview_path.exists():
|
if preview_path.exists():
|
||||||
ref_audio_for_emo = str(preview_path)
|
ref_audio = str(preview_path)
|
||||||
|
|
||||||
|
if not ref_audio or not Path(ref_audio).exists():
|
||||||
|
logger.error(f"No ref audio for char {char.id}, skipping segment {seg.id}")
|
||||||
|
crud.update_audiobook_segment_status(db, seg.id, "error")
|
||||||
|
continue
|
||||||
|
|
||||||
if seg.emo_text and ref_audio_for_emo and Path(ref_audio_for_emo).exists():
|
|
||||||
from core.tts_service import IndexTTS2Backend
|
|
||||||
indextts2 = IndexTTS2Backend()
|
indextts2 = IndexTTS2Backend()
|
||||||
audio_bytes = await indextts2.generate(
|
audio_bytes = await indextts2.generate(
|
||||||
text=seg.text,
|
text=seg.text,
|
||||||
spk_audio_prompt=ref_audio_for_emo,
|
spk_audio_prompt=ref_audio,
|
||||||
output_path=str(audio_path),
|
output_path=str(audio_path),
|
||||||
emo_text=seg.emo_text,
|
emo_text=seg.emo_text or None,
|
||||||
emo_alpha=seg.emo_alpha if seg.emo_alpha is not None else 0.6,
|
emo_alpha=seg.emo_alpha if seg.emo_alpha is not None else 0.6,
|
||||||
)
|
)
|
||||||
elif backend_type == "aliyun":
|
|
||||||
if design.aliyun_voice_id:
|
|
||||||
audio_bytes, _ = await backend.generate_voice_design(
|
|
||||||
{"text": seg.text, "language": "zh"},
|
|
||||||
saved_voice_id=design.aliyun_voice_id
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
audio_bytes, _ = await backend.generate_voice_design({
|
|
||||||
"text": seg.text,
|
|
||||||
"language": "zh",
|
|
||||||
"instruct": _get_gendered_instruct(char.gender, design.instruct),
|
|
||||||
})
|
|
||||||
else:
|
|
||||||
if design.voice_cache_id:
|
|
||||||
from core.cache_manager import VoiceCacheManager
|
|
||||||
cache_manager = await VoiceCacheManager.get_instance()
|
|
||||||
cache_result = await cache_manager.get_cache_by_id(design.voice_cache_id, db)
|
|
||||||
x_vector = cache_result['data'] if cache_result else None
|
|
||||||
if x_vector:
|
|
||||||
audio_bytes, _ = await backend.generate_voice_clone(
|
|
||||||
{
|
|
||||||
"text": seg.text,
|
|
||||||
"language": "Auto",
|
|
||||||
"max_new_tokens": 2048,
|
|
||||||
"temperature": 0.3,
|
|
||||||
"top_k": 10,
|
|
||||||
"top_p": 0.9,
|
|
||||||
"repetition_penalty": 1.05,
|
|
||||||
},
|
|
||||||
x_vector=x_vector
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
audio_bytes, _ = await backend.generate_voice_design({
|
|
||||||
"text": seg.text,
|
|
||||||
"language": "Auto",
|
|
||||||
"instruct": _get_gendered_instruct(char.gender, design.instruct),
|
|
||||||
"max_new_tokens": 2048,
|
|
||||||
"temperature": 0.3,
|
|
||||||
"top_k": 10,
|
|
||||||
"top_p": 0.9,
|
|
||||||
"repetition_penalty": 1.05,
|
|
||||||
})
|
|
||||||
else:
|
|
||||||
audio_bytes, _ = await backend.generate_voice_design({
|
|
||||||
"text": seg.text,
|
|
||||||
"language": "Auto",
|
|
||||||
"instruct": _get_gendered_instruct(char.gender, design.instruct),
|
|
||||||
"max_new_tokens": 2048,
|
|
||||||
"temperature": 0.3,
|
|
||||||
"top_k": 10,
|
|
||||||
"top_p": 0.9,
|
|
||||||
"repetition_penalty": 1.05,
|
|
||||||
})
|
|
||||||
|
|
||||||
with open(audio_path, "wb") as f:
|
with open(audio_path, "wb") as f:
|
||||||
f.write(audio_bytes)
|
f.write(audio_bytes)
|
||||||
@@ -725,18 +596,7 @@ async def generate_single_segment(segment_id: int, user: User, db: Session) -> N
|
|||||||
|
|
||||||
crud.update_audiobook_segment_status(db, segment_id, "generating")
|
crud.update_audiobook_segment_status(db, segment_id, "generating")
|
||||||
try:
|
try:
|
||||||
from core.tts_service import TTSServiceFactory
|
from core.tts_service import IndexTTS2Backend
|
||||||
from core.security import decrypt_api_key
|
|
||||||
|
|
||||||
backend_type = user.user_preferences.get("default_backend", "aliyun") if user.user_preferences else "aliyun"
|
|
||||||
user_api_key = None
|
|
||||||
if backend_type == "aliyun":
|
|
||||||
from db.crud import get_system_setting
|
|
||||||
encrypted = get_system_setting(db, "aliyun_api_key")
|
|
||||||
if encrypted:
|
|
||||||
user_api_key = decrypt_api_key(encrypted)
|
|
||||||
|
|
||||||
backend = await TTSServiceFactory.get_backend(backend_type, user_api_key)
|
|
||||||
|
|
||||||
char = crud.get_audiobook_character(db, seg.character_id)
|
char = crud.get_audiobook_character(db, seg.character_id)
|
||||||
if not char or not char.voice_design_id:
|
if not char or not char.voice_design_id:
|
||||||
@@ -748,81 +608,28 @@ async def generate_single_segment(segment_id: int, user: User, db: Session) -> N
|
|||||||
crud.update_audiobook_segment_status(db, segment_id, "error")
|
crud.update_audiobook_segment_status(db, segment_id, "error")
|
||||||
return
|
return
|
||||||
|
|
||||||
await _bootstrap_character_voices([seg], user, backend, backend_type, db)
|
|
||||||
db.refresh(design)
|
|
||||||
|
|
||||||
audio_filename = f"ch{seg.chapter_index:03d}_seg{seg.segment_index:04d}.wav"
|
audio_filename = f"ch{seg.chapter_index:03d}_seg{seg.segment_index:04d}.wav"
|
||||||
audio_path = output_base / audio_filename
|
audio_path = output_base / audio_filename
|
||||||
|
|
||||||
ref_audio_for_emo = design.ref_audio_path
|
ref_audio = design.ref_audio_path
|
||||||
if not ref_audio_for_emo:
|
if not ref_audio or not Path(ref_audio).exists():
|
||||||
preview_path = Path(settings.OUTPUT_DIR) / "audiobook" / str(seg.project_id) / "previews" / f"char_{char.id}.wav"
|
preview_path = Path(settings.OUTPUT_DIR) / "audiobook" / str(seg.project_id) / "previews" / f"char_{char.id}.wav"
|
||||||
if preview_path.exists():
|
if preview_path.exists():
|
||||||
ref_audio_for_emo = str(preview_path)
|
ref_audio = str(preview_path)
|
||||||
|
|
||||||
|
if not ref_audio or not Path(ref_audio).exists():
|
||||||
|
logger.error(f"No ref audio for char {char.id}, skipping segment {segment_id}")
|
||||||
|
crud.update_audiobook_segment_status(db, segment_id, "error")
|
||||||
|
return
|
||||||
|
|
||||||
if seg.emo_text and ref_audio_for_emo and Path(ref_audio_for_emo).exists():
|
|
||||||
from core.tts_service import IndexTTS2Backend
|
|
||||||
indextts2 = IndexTTS2Backend()
|
indextts2 = IndexTTS2Backend()
|
||||||
audio_bytes = await indextts2.generate(
|
audio_bytes = await indextts2.generate(
|
||||||
text=seg.text,
|
text=seg.text,
|
||||||
spk_audio_prompt=ref_audio_for_emo,
|
spk_audio_prompt=ref_audio,
|
||||||
output_path=str(audio_path),
|
output_path=str(audio_path),
|
||||||
emo_text=seg.emo_text,
|
emo_text=seg.emo_text or None,
|
||||||
emo_alpha=seg.emo_alpha if seg.emo_alpha is not None else 0.6,
|
emo_alpha=seg.emo_alpha if seg.emo_alpha is not None else 0.6,
|
||||||
)
|
)
|
||||||
elif backend_type == "aliyun":
|
|
||||||
if design.aliyun_voice_id:
|
|
||||||
audio_bytes, _ = await backend.generate_voice_design(
|
|
||||||
{"text": seg.text, "language": "zh"},
|
|
||||||
saved_voice_id=design.aliyun_voice_id
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
audio_bytes, _ = await backend.generate_voice_design({
|
|
||||||
"text": seg.text,
|
|
||||||
"language": "zh",
|
|
||||||
"instruct": _get_gendered_instruct(char.gender, design.instruct),
|
|
||||||
})
|
|
||||||
else:
|
|
||||||
if design.voice_cache_id:
|
|
||||||
from core.cache_manager import VoiceCacheManager
|
|
||||||
cache_manager = await VoiceCacheManager.get_instance()
|
|
||||||
cache_result = await cache_manager.get_cache_by_id(design.voice_cache_id, db)
|
|
||||||
x_vector = cache_result['data'] if cache_result else None
|
|
||||||
if x_vector:
|
|
||||||
audio_bytes, _ = await backend.generate_voice_clone(
|
|
||||||
{
|
|
||||||
"text": seg.text,
|
|
||||||
"language": "Auto",
|
|
||||||
"max_new_tokens": 2048,
|
|
||||||
"temperature": 0.3,
|
|
||||||
"top_k": 10,
|
|
||||||
"top_p": 0.9,
|
|
||||||
"repetition_penalty": 1.05,
|
|
||||||
},
|
|
||||||
x_vector=x_vector
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
audio_bytes, _ = await backend.generate_voice_design({
|
|
||||||
"text": seg.text,
|
|
||||||
"language": "Auto",
|
|
||||||
"instruct": _get_gendered_instruct(char.gender, design.instruct),
|
|
||||||
"max_new_tokens": 2048,
|
|
||||||
"temperature": 0.3,
|
|
||||||
"top_k": 10,
|
|
||||||
"top_p": 0.9,
|
|
||||||
"repetition_penalty": 1.05,
|
|
||||||
})
|
|
||||||
else:
|
|
||||||
audio_bytes, _ = await backend.generate_voice_design({
|
|
||||||
"text": seg.text,
|
|
||||||
"language": "Auto",
|
|
||||||
"instruct": _get_gendered_instruct(char.gender, design.instruct),
|
|
||||||
"max_new_tokens": 2048,
|
|
||||||
"temperature": 0.3,
|
|
||||||
"top_k": 10,
|
|
||||||
"top_p": 0.9,
|
|
||||||
"repetition_penalty": 1.05,
|
|
||||||
})
|
|
||||||
|
|
||||||
with open(audio_path, "wb") as f:
|
with open(audio_path, "wb") as f:
|
||||||
f.write(audio_bytes)
|
f.write(audio_bytes)
|
||||||
|
|||||||
Reference in New Issue
Block a user