feat: Implement character voice preview playback and regeneration, and add a turbo mode status indicator for audiobook projects.
@@ -201,36 +201,86 @@ async def analyze_project(project_id: int, user: User, db: Session, turbo: bool
        crud.delete_audiobook_characters(db, project_id)

        backend_type = user.user_preferences.get("default_backend", "aliyun") if user.user_preferences else "aliyun"

        for char_data in characters_data:

        async def _create_char_with_voice(char_data):
            name = char_data.get("name", "narrator")
            instruct = char_data.get("instruct", "")
            description = char_data.get("description", "")
            gender = char_data.get("gender") or ("未知" if name == "narrator" else None)

            voice_design = crud.create_voice_design(
                db=db,
                user_id=user.id,
                name=f"[有声书] {project.title} - {name}",
                instruct=instruct,
                backend_type=backend_type,
                preview_text=description[:100] if description else None,
            )
            # Each concurrent task needs isolated DB access.
            try:
                # core.crud uses synchronous SQLAlchemy, so the calls below run in a
                # worker thread via asyncio.to_thread, each with its own local session.
                import asyncio

                def db_ops():
                    from core.database import SessionLocal
                    local_db = SessionLocal()
                    try:
                        voice_design = crud.create_voice_design(
                            db=local_db,
                            user_id=user.id,
                            name=f"[有声书] {project.title} - {name}",
                            instruct=instruct,
                            backend_type=backend_type,
                            preview_text=description[:100] if description else None,
                        )

                        crud.create_audiobook_character(
                            db=db,
                            project_id=project_id,
                            name=name,
                            gender=gender,
                            description=description,
                            instruct=instruct,
                            voice_design_id=voice_design.id,
                        )
                        crud.create_audiobook_character(
                            db=local_db,
                            project_id=project_id,
                            name=name,
                            gender=gender,
                            description=description,
                            instruct=instruct,
                            voice_design_id=voice_design.id,
                        )
                    finally:
                        local_db.close()

                await asyncio.to_thread(db_ops)
            except Exception as e:
                logger.error(f"Failed to create char/voice for {name}: {e}")

        import asyncio
        batch_tasks = [_create_char_with_voice(cd) for cd in characters_data]
        if batch_tasks:
            await asyncio.gather(*batch_tasks)

        crud.update_audiobook_project_status(db, project_id, "characters_ready")
        ps.mark_done(key)
        logger.info(f"Project {project_id} character extraction complete: {len(characters_data)} characters")

        # Kick off background preview generation
        import asyncio
        from core.database import SessionLocal

        user_id = user.id

        async def _generate_all_previews():
            async_db = SessionLocal()
            try:
                db_user = crud.get_user_by_id(async_db, user_id)
                characters = crud.list_audiobook_characters(async_db, project_id)

                # Use a semaphore to limit concurrent TTS requests
                sem = asyncio.Semaphore(3)
                async def _gen(char_id: int):
                    async with sem:
                        try:
                            await generate_character_preview(project_id, char_id, db_user, async_db)
                        except Exception as e:
                            logger.error(f"Background preview generation failed for char {char_id}: {e}")

                tasks = [_gen(c.id) for c in characters]
                if tasks:
                    await asyncio.gather(*tasks)
            finally:
                async_db.close()

        asyncio.create_task(_generate_all_previews())

    except Exception as e:
        logger.error(f"Analysis failed for project {project_id}: {e}", exc_info=True)
        ps.append_line(key, f"\n[错误] {e}")
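Note on the hunk above: because core.crud is synchronous SQLAlchemy, each concurrent character task pushes its DB writes through asyncio.to_thread with a session of its own rather than sharing the request session. A minimal sketch of that session-per-thread pattern, with illustrative names (the engine URL, `items` table, and helper names are placeholders, not part of this codebase):

```python
import asyncio
from sqlalchemy import create_engine, text
from sqlalchemy.orm import sessionmaker

# Illustrative setup; the real project uses core.database.SessionLocal.
engine = create_engine("sqlite:///example.db")
make_session = sessionmaker(bind=engine)

def write_row(payload: str) -> None:
    # Runs in a worker thread: open, use, and close a session locally,
    # because SQLAlchemy sessions are not safe to share across threads/tasks.
    session = make_session()
    try:
        session.execute(text("INSERT INTO items (payload) VALUES (:p)"), {"p": payload})
        session.commit()
    finally:
        session.close()

async def create_many(payloads: list[str]) -> None:
    # Each task gets its own thread-side session via asyncio.to_thread.
    await asyncio.gather(*(asyncio.to_thread(write_row, p) for p in payloads))
```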
@@ -587,6 +637,10 @@ async def parse_all_chapters(project_id: int, user: User, db: Session, statuses:
    max_concurrent = settings.AUDIOBOOK_PARSE_CONCURRENCY
    semaphore = asyncio.Semaphore(max_concurrent)
    logger.info(f"parse_all_chapters: project={project_id}, {len(pending)} chapters, concurrency={max_concurrent}")

    ps = ProgressStore()
    key = f"project_{project_id}"
    ps.append_line(key, f"\n[状态] 开启章节并发解析,共 {len(pending)} 章待处理,最大并发: {max_concurrent}...\n")

    async def parse_with_limit(chapter):
        if cancel_ev.is_set():
@@ -605,6 +659,12 @@ async def parse_all_chapters(project_id: int, user: User, db: Session, statuses:
    await asyncio.gather(*[parse_with_limit(ch) for ch in pending])
    _cancel_events.pop(project_id, None)

    if cancel_ev.is_set():
        ps.append_line(key, f"\n[状态] 章节批量解析被用户取消\n")
    else:
        ps.append_line(key, f"\n[状态] 所有章节批量解析已完成\n")

    logger.info(f"parse_all_chapters: project={project_id} {'cancelled' if cancel_ev.is_set() else 'complete'}")
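The two hunks above cap chapter parsing at settings.AUDIOBOOK_PARSE_CONCURRENCY and let a per-project cancel event short-circuit queued work; most of parse_with_limit's body sits outside this diff. A minimal sketch of the semaphore-plus-cancel-event shape it appears to follow (parse_chapter and the chapter list are placeholders):

```python
import asyncio

async def parse_chapter(chapter: int) -> None:
    # Placeholder for the real per-chapter parsing work.
    await asyncio.sleep(0.1)

async def parse_all(chapters: list[int], max_concurrent: int, cancel_ev: asyncio.Event) -> None:
    semaphore = asyncio.Semaphore(max_concurrent)

    async def parse_with_limit(chapter: int) -> None:
        if cancel_ev.is_set():
            return  # skip chapters still queued after cancellation
        async with semaphore:
            if cancel_ev.is_set():
                return
            await parse_chapter(chapter)

    await asyncio.gather(*[parse_with_limit(ch) for ch in chapters])
```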
@@ -677,3 +737,150 @@ async def process_all(project_id: int, user: User, db: Session) -> None:
    logger.info(f"process_all: project={project_id} complete")


async def generate_character_preview(project_id: int, char_id: int, user: User, db: Session) -> None:
    """Generate a short audio preview for a specific character."""
    project = crud.get_audiobook_project(db, project_id, user.id)
    if not project:
        raise ValueError("Project not found")

    char = crud.get_audiobook_character(db, char_id)
    if not char or char.project_id != project_id:
        raise ValueError("Character not found or doesn't belong to this project")

    if not char.voice_design_id:
        raise ValueError("Character has no associated voice design")

    design = crud.get_voice_design(db, char.voice_design_id, user.id)
    if not design:
        raise ValueError("Voice design not found")

    output_base = Path(settings.OUTPUT_DIR) / "audiobook" / str(project_id) / "previews"
    output_base.mkdir(parents=True, exist_ok=True)
    audio_path = output_base / f"char_{char_id}.wav"

    preview_name = char.name
    if preview_name == "narrator":
        preview_name = "旁白"

    preview_desc = ""
    if char.description:
        # Take a short snippet of description to make it sound natural
        preview_desc = "," + char.description[:30].replace('\n', ',')
        if not preview_desc.endswith('。') and not preview_desc.endswith('!'):
            preview_desc += "。"

    preview_text = f"你好,我是{preview_name}{preview_desc}"

    from core.tts_service import TTSServiceFactory
    from core.security import decrypt_api_key

    backend_type = user.user_preferences.get("default_backend", "aliyun") if user.user_preferences else "aliyun"
    user_api_key = None
    if backend_type == "aliyun" and user.aliyun_api_key:
        user_api_key = decrypt_api_key(user.aliyun_api_key)

    backend = await TTSServiceFactory.get_backend(backend_type, user_api_key)

    try:
        if backend_type == "local" and not design.voice_cache_id:
            logger.info(f"Local voice cache missing for char {char_id}. Bootstrapping now...")
            from core.model_manager import ModelManager
            from core.cache_manager import VoiceCacheManager
            from utils.audio import process_ref_audio
            import hashlib

            ref_text = "你好,这是参考音频。"
            ref_audio_bytes, _ = await backend.generate_voice_design({
                "text": ref_text,
                "language": "Auto",
                "instruct": design.instruct or "",
                "max_new_tokens": 512,
                "temperature": 0.3,
                "top_k": 10,
                "top_p": 0.9,
                "repetition_penalty": 1.05,
            })

            model_manager = await ModelManager.get_instance()
            await model_manager.load_model("base")
            _, tts = await model_manager.get_current_model()

            ref_audio_array, ref_sr = process_ref_audio(ref_audio_bytes)
            x_vector = tts.create_voice_clone_prompt(
                ref_audio=(ref_audio_array, ref_sr),
                ref_text=ref_text,
            )

            cache_manager = await VoiceCacheManager.get_instance()
            ref_audio_hash = hashlib.sha256(ref_audio_bytes).hexdigest()
            cache_id = await cache_manager.set_cache(
                user.id, ref_audio_hash, x_vector,
                {"ref_text": ref_text, "instruct": design.instruct},
                db
            )
            design.voice_cache_id = cache_id
            db.commit()
            logger.info(f"Bootstrapped local voice cache for preview: design_id={design.id}, cache_id={cache_id}")

        if backend_type == "aliyun":
            if design.aliyun_voice_id:
                audio_bytes, _ = await backend.generate_voice_design(
                    {"text": preview_text, "language": "zh"},
                    saved_voice_id=design.aliyun_voice_id
                )
            else:
                audio_bytes, _ = await backend.generate_voice_design({
                    "text": preview_text,
                    "language": "zh",
                    "instruct": design.instruct,
                })
        else:
            if design.voice_cache_id:
                from core.cache_manager import VoiceCacheManager
                cache_manager = await VoiceCacheManager.get_instance()
                cache_result = await cache_manager.get_cache_by_id(design.voice_cache_id, db)
                x_vector = cache_result['data'] if cache_result else None
                if x_vector:
                    audio_bytes, _ = await backend.generate_voice_clone(
                        {
                            "text": preview_text,
                            "language": "Auto",
                            "max_new_tokens": 512,
                            "temperature": 0.3,
                            "top_k": 10,
                            "top_p": 0.9,
                            "repetition_penalty": 1.05,
                        },
                        x_vector=x_vector
                    )
                else:
                    audio_bytes, _ = await backend.generate_voice_design({
                        "text": preview_text,
                        "language": "Auto",
                        "instruct": design.instruct,
                        "max_new_tokens": 512,
                        "temperature": 0.3,
                        "top_k": 10,
                        "top_p": 0.9,
                        "repetition_penalty": 1.05,
                    })
            else:
                audio_bytes, _ = await backend.generate_voice_design({
                    "text": preview_text,
                    "language": "Auto",
                    "instruct": design.instruct,
                    "max_new_tokens": 512,
                    "temperature": 0.3,
                    "top_k": 10,
                    "top_p": 0.9,
                    "repetition_penalty": 1.05,
                })

        with open(audio_path, "wb") as f:
            f.write(audio_bytes)

        logger.info(f"Preview generated for char {char_id}: {audio_path}")
    except Exception as e:
        logger.error(f"Failed to generate preview for char {char_id}: {e}")
        raise
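generate_character_preview writes the preview to OUTPUT_DIR/audiobook/&lt;project_id&gt;/previews/char_&lt;id&gt;.wav; the routes that trigger regeneration and serve playback are not part of this diff. A hedged sketch of what such endpoints might look like, assuming a FastAPI-style router and this project's dependency helpers (the paths, get_db, get_current_user, and settings wiring below are assumptions, not code from this commit):

```python
from pathlib import Path
from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import FileResponse

router = APIRouter()

@router.post("/audiobook/{project_id}/characters/{char_id}/preview")
async def regenerate_preview(project_id: int, char_id: int,
                             user=Depends(get_current_user),  # assumed helper
                             db=Depends(get_db)):             # assumed helper
    # Re-runs TTS for this character and overwrites previews/char_<id>.wav.
    await generate_character_preview(project_id, char_id, user, db)
    return {"status": "ok"}

@router.get("/audiobook/{project_id}/characters/{char_id}/preview")
async def play_preview(project_id: int, char_id: int,
                       user=Depends(get_current_user)):       # assumed helper
    # Serves the file written by generate_character_preview for playback.
    path = Path(settings.OUTPUT_DIR) / "audiobook" / str(project_id) / "previews" / f"char_{char_id}.wav"
    if not path.exists():
        raise HTTPException(status_code=404, detail="Preview not generated yet")
    return FileResponse(path, media_type="audio/wav")
```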
@@ -134,13 +134,15 @@ class LLMService:
        if turbo and len(text_samples) > 1:
            logger.info(f"Extracting characters in turbo mode: {len(text_samples)} samples concurrent")

            async def _extract_one(sample: str) -> list[Dict]:
            async def _extract_one(i: int, sample: str) -> list[Dict]:
                user_message = f"请分析以下小说文本并提取角色:\n\n{sample}"
                result = await self.stream_chat_json(system_prompt, user_message, None)
                if on_sample:
                    on_sample(i, len(text_samples))
                return result.get("characters", [])

            results = await asyncio.gather(
                *[_extract_one(s) for s in text_samples],
                *[_extract_one(i, s) for i, s in enumerate(text_samples)],
                return_exceptions=True,
            )
            raw_all: list[Dict] = []
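The hunk above threads a sample index into _extract_one and fires on_sample(i, total) after each turbo sample completes, which is presumably what drives the turbo-mode status indicator named in the commit subject. The call site is not shown in this diff; one plausible wiring, assuming the same ProgressStore used elsewhere in this commit (the extraction method name, its signature, and the status-line format below are illustrative assumptions):

```python
async def run_turbo_extraction(llm_service, project_id: int, text_samples: list[str]):
    # ProgressStore is the progress channel used elsewhere in this commit;
    # extract_characters(...) is an assumed method name for illustration.
    ps = ProgressStore()
    key = f"project_{project_id}"
    done = 0

    def on_sample(i: int, total: int) -> None:
        # Invoked once per finished sample; pushes a progress line to the UI.
        nonlocal done
        done += 1
        ps.append_line(key, f"\n[状态] Turbo 抽样进度 {done}/{total}\n")

    return await llm_service.extract_characters(
        text_samples, turbo=True, on_sample=on_sample,
    )
```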
@@ -32,6 +32,8 @@ class TTSBackend(ABC):
class LocalTTSBackend(TTSBackend):
    def __init__(self):
        self.model_manager = None
        # Add a lock to prevent concurrent VRAM contention and CUDA errors on local GPU models
        self._gpu_lock = asyncio.Lock()

    async def initialize(self):
        from core.model_manager import ModelManager
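The new _gpu_lock is held around every run_in_executor call in the hunks below, so only one local inference touches the GPU at a time even when many requests arrive concurrently. A stripped-down, self-contained illustration of why holding an asyncio.Lock across the executor call serializes the blocking work (the fake _infer stands in for the real model call):

```python
import asyncio
import functools
import time

class FakeLocalBackend:
    def __init__(self):
        self._gpu_lock = asyncio.Lock()

    def _infer(self, text: str) -> str:
        # Stand-in for the blocking GPU call (tts.generate_custom_voice, etc.).
        time.sleep(0.2)
        return f"audio:{text}"

    async def generate(self, text: str) -> str:
        loop = asyncio.get_event_loop()
        async with self._gpu_lock:
            # The lock stays held across the executor call, so concurrent
            # generate() coroutines queue up instead of running model code in
            # parallel threads and contending for VRAM.
            return await loop.run_in_executor(None, functools.partial(self._infer, text))

async def main():
    backend = FakeLocalBackend()
    results = await asyncio.gather(*(backend.generate(f"t{i}") for i in range(3)))
    print(results)

asyncio.run(main())
```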
@@ -42,21 +44,22 @@ class LocalTTSBackend(TTSBackend):
        _, tts = await self.model_manager.get_current_model()

        loop = asyncio.get_event_loop()
        result = await loop.run_in_executor(
            None,
            functools.partial(
                tts.generate_custom_voice,
                text=params['text'],
                language=params['language'],
                speaker=params['speaker'],
                instruct=params.get('instruct', ''),
                max_new_tokens=params['max_new_tokens'],
                temperature=params['temperature'],
                top_k=params['top_k'],
                top_p=params['top_p'],
                repetition_penalty=params['repetition_penalty'],
        async with self._gpu_lock:
            result = await loop.run_in_executor(
                None,
                functools.partial(
                    tts.generate_custom_voice,
                    text=params['text'],
                    language=params['language'],
                    speaker=params['speaker'],
                    instruct=params.get('instruct', ''),
                    max_new_tokens=params['max_new_tokens'],
                    temperature=params['temperature'],
                    top_k=params['top_k'],
                    top_p=params['top_p'],
                    repetition_penalty=params['repetition_penalty'],
                )
            )
        )

        import numpy as np
        wavs, sample_rate = result if isinstance(result, tuple) else (result, 24000)
@@ -68,20 +71,21 @@ class LocalTTSBackend(TTSBackend):
        _, tts = await self.model_manager.get_current_model()

        loop = asyncio.get_event_loop()
        result = await loop.run_in_executor(
            None,
            functools.partial(
                tts.generate_voice_design,
                text=params['text'],
                language=params['language'],
                instruct=params['instruct'],
                max_new_tokens=params['max_new_tokens'],
                temperature=params['temperature'],
                top_k=params['top_k'],
                top_p=params['top_p'],
                repetition_penalty=params['repetition_penalty'],
        async with self._gpu_lock:
            result = await loop.run_in_executor(
                None,
                functools.partial(
                    tts.generate_voice_design,
                    text=params['text'],
                    language=params['language'],
                    instruct=params['instruct'],
                    max_new_tokens=params['max_new_tokens'],
                    temperature=params['temperature'],
                    top_k=params['top_k'],
                    top_p=params['top_p'],
                    repetition_penalty=params['repetition_penalty'],
                )
            )
        )

        import numpy as np
        wavs, sample_rate = result if isinstance(result, tuple) else (result, 24000)
@@ -96,37 +100,38 @@ class LocalTTSBackend(TTSBackend):
        loop = asyncio.get_event_loop()

        if x_vector is None:
            if ref_audio_bytes is None:
                raise ValueError("Either ref_audio_bytes or x_vector must be provided")
        async with self._gpu_lock:
            if x_vector is None:
                if ref_audio_bytes is None:
                    raise ValueError("Either ref_audio_bytes or x_vector must be provided")

            ref_audio_array, ref_sr = process_ref_audio(ref_audio_bytes)
                ref_audio_array, ref_sr = process_ref_audio(ref_audio_bytes)

            x_vector = await loop.run_in_executor(
                x_vector = await loop.run_in_executor(
                    None,
                    functools.partial(
                        tts.create_voice_clone_prompt,
                        ref_audio=(ref_audio_array, ref_sr),
                        ref_text=params.get('ref_text', ''),
                        x_vector_only_mode=False,
                    )
                )

        wavs, sample_rate = await loop.run_in_executor(
            None,
            functools.partial(
                tts.create_voice_clone_prompt,
                ref_audio=(ref_audio_array, ref_sr),
                ref_text=params.get('ref_text', ''),
                x_vector_only_mode=False,
                tts.generate_voice_clone,
                text=params['text'],
                language=params['language'],
                voice_clone_prompt=x_vector,
                max_new_tokens=params['max_new_tokens'],
                temperature=params['temperature'],
                top_k=params['top_k'],
                top_p=params['top_p'],
                repetition_penalty=params['repetition_penalty'],
            )
        )

            wavs, sample_rate = await loop.run_in_executor(
                None,
                functools.partial(
                    tts.generate_voice_clone,
                    text=params['text'],
                    language=params['language'],
                    voice_clone_prompt=x_vector,
                    max_new_tokens=params['max_new_tokens'],
                    temperature=params['temperature'],
                    top_k=params['top_k'],
                    top_p=params['top_p'],
                    repetition_penalty=params['repetition_penalty'],
                )
            )

        import numpy as np
        audio_data = wavs[0] if isinstance(wavs, list) else wavs
        if isinstance(audio_data, list):