feat: Add voice design support for voice cloning and enhance cache management

This commit is contained in:
2026-02-04 17:52:24 +08:00
parent 13820e38c7
commit 9e5d12c9fb
5 changed files with 247 additions and 27 deletions

View File

@@ -164,7 +164,8 @@ async def process_voice_clone_job(
request_data: dict,
ref_audio_path: str,
backend_type: str,
db_url: str
db_url: str,
use_voice_design: bool = False
):
from core.database import SessionLocal
from core.tts_service import TTSServiceFactory
@@ -194,13 +195,36 @@ async def process_voice_clone_job(
ref_audio_data = f.read()
cache_manager = await VoiceCacheManager.get_instance()
ref_audio_hash = cache_manager.get_audio_hash(ref_audio_data)
voice_design_id = request_data.get('voice_design_id')
if voice_design_id:
from db.crud import get_voice_design
design = get_voice_design(db, voice_design_id, user_id)
if not design or not design.voice_cache_id:
raise RuntimeError(f"Voice design {voice_design_id} has no prepared clone prompt")
cached = await cache_manager.get_cache_by_id(design.voice_cache_id, db)
if not cached:
raise RuntimeError(f"Cache {design.voice_cache_id} not found")
ref_audio_hash = f"voice_design_{voice_design_id}"
cache_metrics.record_hit(user_id)
logger.info(f"Using voice design {voice_design_id}, cache_id={design.voice_cache_id}")
else:
with open(ref_audio_path, 'rb') as f:
ref_audio_data = f.read()
ref_audio_hash = cache_manager.get_audio_hash(ref_audio_data)
if request_data.get('x_vector_only_mode', False) and backend_type == "local":
x_vector = None
cache_id = None
if request_data.get('use_cache', True):
if voice_design_id:
design = get_voice_design(db, voice_design_id, user_id)
x_vector = cached['data']
cache_id = design.voice_cache_id
elif request_data.get('use_cache', True):
cached = await cache_manager.get_cache(user_id, ref_audio_hash, db)
if cached:
x_vector = cached['data']
@@ -248,7 +272,17 @@ async def process_voice_clone_job(
backend = await TTSServiceFactory.get_backend(backend_type, user_api_key)
audio_bytes, sample_rate = await backend.generate_voice_clone(request_data, ref_audio_data)
if voice_design_id and backend_type == "local":
from db.crud import get_voice_design
design = get_voice_design(db, voice_design_id, user_id)
cached = await cache_manager.get_cache_by_id(design.voice_cache_id, db)
x_vector = cached['data']
audio_bytes, sample_rate = await backend.generate_voice_clone(request_data, x_vector=x_vector)
logger.info(f"Generated audio using cached x_vector from voice design {voice_design_id}")
else:
with open(ref_audio_path, 'rb') as f:
ref_audio_data = f.read()
audio_bytes, sample_rate = await backend.generate_voice_clone(request_data, ref_audio_data)
timestamp = datetime.utcnow().strftime("%Y%m%d_%H%M%S")
filename = f"{user_id}_{job_id}_{timestamp}.wav"
@@ -274,7 +308,7 @@ async def process_voice_clone_job(
db.commit()
finally:
if Path(ref_audio_path).exists():
if not use_voice_design and ref_audio_path and Path(ref_audio_path).exists():
Path(ref_audio_path).unlink()
db.close()
@@ -483,10 +517,11 @@ async def create_voice_clone_job(
request: Request,
text: str = Form(...),
language: str = Form(default="Auto"),
ref_audio: UploadFile = File(...),
ref_audio: Optional[UploadFile] = File(default=None),
ref_text: Optional[str] = Form(default=None),
use_cache: bool = Form(default=True),
x_vector_only_mode: bool = Form(default=False),
voice_design_id: Optional[int] = Form(default=None),
max_new_tokens: Optional[int] = Form(default=2048),
temperature: Optional[float] = Form(default=0.9),
top_k: Optional[int] = Form(default=50),
@@ -498,7 +533,7 @@ async def create_voice_clone_job(
db: Session = Depends(get_db)
):
from core.security import decrypt_api_key
from db.crud import get_user_preferences, can_user_use_local_model
from db.crud import get_user_preferences, can_user_use_local_model, get_voice_design
user_prefs = get_user_preferences(db, current_user.id)
preferred_backend = user_prefs.get("default_backend", "aliyun")
@@ -519,6 +554,10 @@ async def create_voice_clone_job(
detail="Aliyun API key not configured. Please set your API key in Settings."
)
ref_audio_data = None
ref_audio_hash = None
use_voice_design = False
try:
validate_text_length(text)
language = validate_language(language)
@@ -531,13 +570,36 @@ async def create_voice_clone_job(
'repetition_penalty': repetition_penalty
})
ref_audio_data = await ref_audio.read()
if not validate_ref_audio(ref_audio_data, max_size_mb=settings.MAX_AUDIO_SIZE_MB):
raise ValueError("Invalid reference audio: must be 1-30s duration and ≤10MB")
cache_manager = await VoiceCacheManager.get_instance()
ref_audio_hash = cache_manager.get_audio_hash(ref_audio_data)
if voice_design_id:
design = get_voice_design(db, voice_design_id, current_user.id)
if not design:
raise ValueError("Voice design not found")
if design.backend_type != backend_type:
raise ValueError(f"Voice design backend ({design.backend_type}) doesn't match request backend ({backend_type})")
if not design.voice_cache_id:
raise ValueError("Voice design has no prepared clone prompt. Please call /voice-designs/{id}/prepare-clone first")
use_voice_design = True
ref_audio_hash = f"voice_design_{voice_design_id}"
if not ref_text:
ref_text = design.ref_text
logger.info(f"Using voice design {voice_design_id} with cache_id={design.voice_cache_id}")
else:
if not ref_audio:
raise ValueError("Either ref_audio or voice_design_id must be provided")
ref_audio_data = await ref_audio.read()
if not validate_ref_audio(ref_audio_data, max_size_mb=settings.MAX_AUDIO_SIZE_MB):
raise ValueError("Invalid reference audio: must be 1-30s duration and ≤10MB")
ref_audio_hash = cache_manager.get_audio_hash(ref_audio_data)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
@@ -555,6 +617,7 @@ async def create_voice_clone_job(
"ref_audio_hash": ref_audio_hash,
"use_cache": use_cache,
"x_vector_only_mode": x_vector_only_mode,
"voice_design_id": voice_design_id,
**params
}
)
@@ -562,9 +625,13 @@ async def create_voice_clone_job(
db.commit()
db.refresh(job)
with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as tmp_file:
tmp_file.write(ref_audio_data)
tmp_audio_path = tmp_file.name
if use_voice_design:
design = get_voice_design(db, voice_design_id, current_user.id)
tmp_audio_path = design.ref_audio_path
else:
with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as tmp_file:
tmp_file.write(ref_audio_data)
tmp_audio_path = tmp_file.name
request_data = {
"text": text,
@@ -572,6 +639,7 @@ async def create_voice_clone_job(
"ref_text": ref_text or "",
"use_cache": use_cache,
"x_vector_only_mode": x_vector_only_mode,
"voice_design_id": voice_design_id,
**params
}
@@ -582,7 +650,8 @@ async def create_voice_clone_job(
request_data,
tmp_audio_path,
backend_type,
str(settings.DATABASE_URL)
str(settings.DATABASE_URL),
use_voice_design
)
existing_cache = await cache_manager.get_cache(current_user.id, ref_audio_hash, db)