feat: Add voice design support for voice cloning and enhance cache management

This commit is contained in:
2026-02-04 17:52:24 +08:00
parent 13820e38c7
commit 9e5d12c9fb
5 changed files with 247 additions and 27 deletions

View File

@@ -164,7 +164,8 @@ async def process_voice_clone_job(
request_data: dict, request_data: dict,
ref_audio_path: str, ref_audio_path: str,
backend_type: str, backend_type: str,
db_url: str db_url: str,
use_voice_design: bool = False
): ):
from core.database import SessionLocal from core.database import SessionLocal
from core.tts_service import TTSServiceFactory from core.tts_service import TTSServiceFactory
@@ -194,13 +195,36 @@ async def process_voice_clone_job(
ref_audio_data = f.read() ref_audio_data = f.read()
cache_manager = await VoiceCacheManager.get_instance() cache_manager = await VoiceCacheManager.get_instance()
ref_audio_hash = cache_manager.get_audio_hash(ref_audio_data)
voice_design_id = request_data.get('voice_design_id')
if voice_design_id:
from db.crud import get_voice_design
design = get_voice_design(db, voice_design_id, user_id)
if not design or not design.voice_cache_id:
raise RuntimeError(f"Voice design {voice_design_id} has no prepared clone prompt")
cached = await cache_manager.get_cache_by_id(design.voice_cache_id, db)
if not cached:
raise RuntimeError(f"Cache {design.voice_cache_id} not found")
ref_audio_hash = f"voice_design_{voice_design_id}"
cache_metrics.record_hit(user_id)
logger.info(f"Using voice design {voice_design_id}, cache_id={design.voice_cache_id}")
else:
with open(ref_audio_path, 'rb') as f:
ref_audio_data = f.read()
ref_audio_hash = cache_manager.get_audio_hash(ref_audio_data)
if request_data.get('x_vector_only_mode', False) and backend_type == "local": if request_data.get('x_vector_only_mode', False) and backend_type == "local":
x_vector = None x_vector = None
cache_id = None cache_id = None
if request_data.get('use_cache', True): if voice_design_id:
design = get_voice_design(db, voice_design_id, user_id)
x_vector = cached['data']
cache_id = design.voice_cache_id
elif request_data.get('use_cache', True):
cached = await cache_manager.get_cache(user_id, ref_audio_hash, db) cached = await cache_manager.get_cache(user_id, ref_audio_hash, db)
if cached: if cached:
x_vector = cached['data'] x_vector = cached['data']
@@ -248,7 +272,17 @@ async def process_voice_clone_job(
backend = await TTSServiceFactory.get_backend(backend_type, user_api_key) backend = await TTSServiceFactory.get_backend(backend_type, user_api_key)
audio_bytes, sample_rate = await backend.generate_voice_clone(request_data, ref_audio_data) if voice_design_id and backend_type == "local":
from db.crud import get_voice_design
design = get_voice_design(db, voice_design_id, user_id)
cached = await cache_manager.get_cache_by_id(design.voice_cache_id, db)
x_vector = cached['data']
audio_bytes, sample_rate = await backend.generate_voice_clone(request_data, x_vector=x_vector)
logger.info(f"Generated audio using cached x_vector from voice design {voice_design_id}")
else:
with open(ref_audio_path, 'rb') as f:
ref_audio_data = f.read()
audio_bytes, sample_rate = await backend.generate_voice_clone(request_data, ref_audio_data)
timestamp = datetime.utcnow().strftime("%Y%m%d_%H%M%S") timestamp = datetime.utcnow().strftime("%Y%m%d_%H%M%S")
filename = f"{user_id}_{job_id}_{timestamp}.wav" filename = f"{user_id}_{job_id}_{timestamp}.wav"
@@ -274,7 +308,7 @@ async def process_voice_clone_job(
db.commit() db.commit()
finally: finally:
if Path(ref_audio_path).exists(): if not use_voice_design and ref_audio_path and Path(ref_audio_path).exists():
Path(ref_audio_path).unlink() Path(ref_audio_path).unlink()
db.close() db.close()
@@ -483,10 +517,11 @@ async def create_voice_clone_job(
request: Request, request: Request,
text: str = Form(...), text: str = Form(...),
language: str = Form(default="Auto"), language: str = Form(default="Auto"),
ref_audio: UploadFile = File(...), ref_audio: Optional[UploadFile] = File(default=None),
ref_text: Optional[str] = Form(default=None), ref_text: Optional[str] = Form(default=None),
use_cache: bool = Form(default=True), use_cache: bool = Form(default=True),
x_vector_only_mode: bool = Form(default=False), x_vector_only_mode: bool = Form(default=False),
voice_design_id: Optional[int] = Form(default=None),
max_new_tokens: Optional[int] = Form(default=2048), max_new_tokens: Optional[int] = Form(default=2048),
temperature: Optional[float] = Form(default=0.9), temperature: Optional[float] = Form(default=0.9),
top_k: Optional[int] = Form(default=50), top_k: Optional[int] = Form(default=50),
@@ -498,7 +533,7 @@ async def create_voice_clone_job(
db: Session = Depends(get_db) db: Session = Depends(get_db)
): ):
from core.security import decrypt_api_key from core.security import decrypt_api_key
from db.crud import get_user_preferences, can_user_use_local_model from db.crud import get_user_preferences, can_user_use_local_model, get_voice_design
user_prefs = get_user_preferences(db, current_user.id) user_prefs = get_user_preferences(db, current_user.id)
preferred_backend = user_prefs.get("default_backend", "aliyun") preferred_backend = user_prefs.get("default_backend", "aliyun")
@@ -519,6 +554,10 @@ async def create_voice_clone_job(
detail="Aliyun API key not configured. Please set your API key in Settings." detail="Aliyun API key not configured. Please set your API key in Settings."
) )
ref_audio_data = None
ref_audio_hash = None
use_voice_design = False
try: try:
validate_text_length(text) validate_text_length(text)
language = validate_language(language) language = validate_language(language)
@@ -531,13 +570,36 @@ async def create_voice_clone_job(
'repetition_penalty': repetition_penalty 'repetition_penalty': repetition_penalty
}) })
ref_audio_data = await ref_audio.read()
if not validate_ref_audio(ref_audio_data, max_size_mb=settings.MAX_AUDIO_SIZE_MB):
raise ValueError("Invalid reference audio: must be 1-30s duration and ≤10MB")
cache_manager = await VoiceCacheManager.get_instance() cache_manager = await VoiceCacheManager.get_instance()
ref_audio_hash = cache_manager.get_audio_hash(ref_audio_data)
if voice_design_id:
design = get_voice_design(db, voice_design_id, current_user.id)
if not design:
raise ValueError("Voice design not found")
if design.backend_type != backend_type:
raise ValueError(f"Voice design backend ({design.backend_type}) doesn't match request backend ({backend_type})")
if not design.voice_cache_id:
raise ValueError("Voice design has no prepared clone prompt. Please call /voice-designs/{id}/prepare-clone first")
use_voice_design = True
ref_audio_hash = f"voice_design_{voice_design_id}"
if not ref_text:
ref_text = design.ref_text
logger.info(f"Using voice design {voice_design_id} with cache_id={design.voice_cache_id}")
else:
if not ref_audio:
raise ValueError("Either ref_audio or voice_design_id must be provided")
ref_audio_data = await ref_audio.read()
if not validate_ref_audio(ref_audio_data, max_size_mb=settings.MAX_AUDIO_SIZE_MB):
raise ValueError("Invalid reference audio: must be 1-30s duration and ≤10MB")
ref_audio_hash = cache_manager.get_audio_hash(ref_audio_data)
except ValueError as e: except ValueError as e:
raise HTTPException(status_code=400, detail=str(e)) raise HTTPException(status_code=400, detail=str(e))
@@ -555,6 +617,7 @@ async def create_voice_clone_job(
"ref_audio_hash": ref_audio_hash, "ref_audio_hash": ref_audio_hash,
"use_cache": use_cache, "use_cache": use_cache,
"x_vector_only_mode": x_vector_only_mode, "x_vector_only_mode": x_vector_only_mode,
"voice_design_id": voice_design_id,
**params **params
} }
) )
@@ -562,9 +625,13 @@ async def create_voice_clone_job(
db.commit() db.commit()
db.refresh(job) db.refresh(job)
with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as tmp_file: if use_voice_design:
tmp_file.write(ref_audio_data) design = get_voice_design(db, voice_design_id, current_user.id)
tmp_audio_path = tmp_file.name tmp_audio_path = design.ref_audio_path
else:
with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as tmp_file:
tmp_file.write(ref_audio_data)
tmp_audio_path = tmp_file.name
request_data = { request_data = {
"text": text, "text": text,
@@ -572,6 +639,7 @@ async def create_voice_clone_job(
"ref_text": ref_text or "", "ref_text": ref_text or "",
"use_cache": use_cache, "use_cache": use_cache,
"x_vector_only_mode": x_vector_only_mode, "x_vector_only_mode": x_vector_only_mode,
"voice_design_id": voice_design_id,
**params **params
} }
@@ -582,7 +650,8 @@ async def create_voice_clone_job(
request_data, request_data,
tmp_audio_path, tmp_audio_path,
backend_type, backend_type,
str(settings.DATABASE_URL) str(settings.DATABASE_URL),
use_voice_design
) )
existing_cache = await cache_manager.get_cache(current_user.id, ref_audio_hash, db) existing_cache = await cache_manager.get_cache(current_user.id, ref_audio_hash, db)

View File

@@ -1,13 +1,14 @@
import logging import logging
from fastapi import APIRouter, Depends, HTTPException, status, Request from fastapi import APIRouter, Depends, HTTPException, status, Request, BackgroundTasks
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from typing import Optional from typing import Optional
from slowapi import Limiter from slowapi import Limiter
from slowapi.util import get_remote_address from slowapi.util import get_remote_address
from pathlib import Path
from core.database import get_db from core.database import get_db
from api.auth import get_current_user from api.auth import get_current_user
from db.models import User from db.models import User, Job, JobStatus
from db import crud from db import crud
from schemas.voice_design import ( from schemas.voice_design import (
VoiceDesignCreate, VoiceDesignCreate,
@@ -95,3 +96,116 @@ async def delete_voice_design(
success = crud.delete_voice_design(db, design_id, current_user.id) success = crud.delete_voice_design(db, design_id, current_user.id)
if not success: if not success:
raise HTTPException(status_code=404, detail="Voice design not found") raise HTTPException(status_code=404, detail="Voice design not found")
@router.post("/{design_id}/prepare-clone")
@limiter.limit("10/minute")
async def prepare_voice_clone_prompt(
request: Request,
design_id: int,
background_tasks: BackgroundTasks,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
from core.tts_service import TTSServiceFactory
from core.cache_manager import VoiceCacheManager
from utils.audio import process_ref_audio, extract_audio_features
from core.config import settings
from db.crud import can_user_use_local_model
from datetime import datetime
design = crud.get_voice_design(db, design_id, current_user.id)
if not design:
raise HTTPException(status_code=404, detail="Voice design not found")
if design.backend_type != "local":
raise HTTPException(
status_code=400,
detail="Voice clone prompt preparation is only supported for local backend"
)
if not can_user_use_local_model(current_user):
raise HTTPException(
status_code=403,
detail="Local model access required"
)
if design.voice_cache_id:
return {
"message": "Voice clone prompt already exists",
"cache_id": design.voice_cache_id
}
try:
backend = await TTSServiceFactory.get_backend("local")
ref_text = design.preview_text or design.instruct[:100]
logger.info(f"Generating reference audio for voice design {design_id}")
ref_audio_bytes, sample_rate = await backend.generate_voice_design({
"text": ref_text,
"language": "Auto",
"instruct": design.instruct,
"max_new_tokens": 2048,
"temperature": 0.3,
"top_k": 10,
"top_p": 0.5,
"repetition_penalty": 1.05
})
timestamp = datetime.utcnow().strftime("%Y%m%d_%H%M%S")
ref_filename = f"voice_design_{design_id}_{timestamp}.wav"
ref_audio_path = Path(settings.OUTPUT_DIR) / ref_filename
with open(ref_audio_path, 'wb') as f:
f.write(ref_audio_bytes)
logger.info(f"Extracting voice clone prompt from reference audio")
ref_audio_array, ref_sr = process_ref_audio(ref_audio_bytes)
from core.model_manager import ModelManager
model_manager = await ModelManager.get_instance()
await model_manager.load_model("base")
_, tts = await model_manager.get_current_model()
if tts is None:
raise RuntimeError("Failed to load base model")
x_vector = tts.create_voice_clone_prompt(
ref_audio=(ref_audio_array, ref_sr),
ref_text=ref_text,
x_vector_only_mode=True
)
cache_manager = await VoiceCacheManager.get_instance()
ref_audio_hash = cache_manager.get_audio_hash(ref_audio_bytes)
features = extract_audio_features(ref_audio_array, ref_sr)
metadata = {
'duration': features['duration'],
'sample_rate': features['sample_rate'],
'ref_text': ref_text,
'x_vector_only_mode': True,
'voice_design_id': design_id,
'instruct': design.instruct
}
cache_id = await cache_manager.set_cache(
current_user.id, ref_audio_hash, x_vector, metadata, db
)
design.voice_cache_id = cache_id
design.ref_audio_path = str(ref_audio_path)
design.ref_text = ref_text
db.commit()
logger.info(f"Voice clone prompt prepared for design {design_id}, cache_id={cache_id}")
return {
"message": "Voice clone prompt prepared successfully",
"cache_id": cache_id,
"ref_text": ref_text
}
except Exception as e:
logger.error(f"Failed to prepare voice clone prompt: {e}", exc_info=True)
raise HTTPException(status_code=500, detail=str(e))

View File

@@ -72,6 +72,36 @@ class VoiceCacheManager:
logger.error(f"Cache retrieval error: {e}", exc_info=True) logger.error(f"Cache retrieval error: {e}", exc_info=True)
return None return None
async def get_cache_by_id(self, cache_id: int, db: Session) -> Optional[Dict[str, Any]]:
try:
cache_entry = db.query(VoiceCache).filter(VoiceCache.id == cache_id).first()
if not cache_entry:
logger.debug(f"Cache not found: id={cache_id}")
return None
cache_file = Path(cache_entry.cache_path)
if not cache_file.exists():
logger.warning(f"Cache file missing: {cache_file}")
return None
with open(cache_file, 'rb') as f:
cache_data = pickle.load(f)
cache_entry.last_accessed = datetime.utcnow()
cache_entry.access_count += 1
db.commit()
logger.info(f"Cache loaded by id: cache_id={cache_id}, access_count={cache_entry.access_count}")
return {
'cache_id': cache_entry.id,
'data': cache_data,
'metadata': cache_entry.meta_data
}
except Exception as e:
logger.error(f"Cache retrieval by id error: {e}", exc_info=True)
return None
async def set_cache( async def set_cache(
self, self,
user_id: int, user_id: int,

View File

@@ -83,19 +83,23 @@ class LocalTTSBackend(TTSBackend):
audio_data = np.array(audio_data) audio_data = np.array(audio_data)
return self._numpy_to_bytes(audio_data), 24000 return self._numpy_to_bytes(audio_data), 24000
async def generate_voice_clone(self, params: dict, ref_audio_bytes: bytes) -> Tuple[bytes, int]: async def generate_voice_clone(self, params: dict, ref_audio_bytes: bytes = None, x_vector=None) -> Tuple[bytes, int]:
from utils.audio import process_ref_audio from utils.audio import process_ref_audio
ref_audio_array, ref_sr = process_ref_audio(ref_audio_bytes)
await self.model_manager.load_model("base") await self.model_manager.load_model("base")
_, tts = await self.model_manager.get_current_model() _, tts = await self.model_manager.get_current_model()
x_vector = tts.create_voice_clone_prompt( if x_vector is None:
ref_audio=(ref_audio_array, ref_sr), if ref_audio_bytes is None:
ref_text=params.get('ref_text', ''), raise ValueError("Either ref_audio_bytes or x_vector must be provided")
x_vector_only_mode=False
) ref_audio_array, ref_sr = process_ref_audio(ref_audio_bytes)
x_vector = tts.create_voice_clone_prompt(
ref_audio=(ref_audio_array, ref_sr),
ref_text=params.get('ref_text', ''),
x_vector_only_mode=False
)
wavs, sample_rate = tts.generate_voice_clone( wavs, sample_rate = tts.generate_voice_clone(
text=params['text'], text=params['text'],

View File

@@ -90,6 +90,9 @@ class VoiceDesign(Base):
aliyun_voice_id = Column(String(255), nullable=True) aliyun_voice_id = Column(String(255), nullable=True)
meta_data = Column(JSON, nullable=True) meta_data = Column(JSON, nullable=True)
preview_text = Column(Text, nullable=True) preview_text = Column(Text, nullable=True)
ref_audio_path = Column(String(500), nullable=True)
ref_text = Column(Text, nullable=True)
voice_cache_id = Column(Integer, nullable=True)
created_at = Column(DateTime, default=datetime.utcnow, nullable=False) created_at = Column(DateTime, default=datetime.utcnow, nullable=False)
last_used = Column(DateTime, default=datetime.utcnow, nullable=False, index=True) last_used = Column(DateTime, default=datetime.utcnow, nullable=False, index=True)
use_count = Column(Integer, default=0, nullable=False) use_count = Column(Integer, default=0, nullable=False)