feat: Add voice design support for voice cloning and enhance cache management
This commit is contained in:
@@ -164,7 +164,8 @@ async def process_voice_clone_job(
|
||||
request_data: dict,
|
||||
ref_audio_path: str,
|
||||
backend_type: str,
|
||||
db_url: str
|
||||
db_url: str,
|
||||
use_voice_design: bool = False
|
||||
):
|
||||
from core.database import SessionLocal
|
||||
from core.tts_service import TTSServiceFactory
|
||||
@@ -194,13 +195,36 @@ async def process_voice_clone_job(
|
||||
ref_audio_data = f.read()
|
||||
|
||||
cache_manager = await VoiceCacheManager.get_instance()
|
||||
ref_audio_hash = cache_manager.get_audio_hash(ref_audio_data)
|
||||
|
||||
voice_design_id = request_data.get('voice_design_id')
|
||||
if voice_design_id:
|
||||
from db.crud import get_voice_design
|
||||
design = get_voice_design(db, voice_design_id, user_id)
|
||||
if not design or not design.voice_cache_id:
|
||||
raise RuntimeError(f"Voice design {voice_design_id} has no prepared clone prompt")
|
||||
|
||||
cached = await cache_manager.get_cache_by_id(design.voice_cache_id, db)
|
||||
if not cached:
|
||||
raise RuntimeError(f"Cache {design.voice_cache_id} not found")
|
||||
|
||||
ref_audio_hash = f"voice_design_{voice_design_id}"
|
||||
cache_metrics.record_hit(user_id)
|
||||
logger.info(f"Using voice design {voice_design_id}, cache_id={design.voice_cache_id}")
|
||||
|
||||
else:
|
||||
with open(ref_audio_path, 'rb') as f:
|
||||
ref_audio_data = f.read()
|
||||
ref_audio_hash = cache_manager.get_audio_hash(ref_audio_data)
|
||||
|
||||
if request_data.get('x_vector_only_mode', False) and backend_type == "local":
|
||||
x_vector = None
|
||||
cache_id = None
|
||||
|
||||
if request_data.get('use_cache', True):
|
||||
if voice_design_id:
|
||||
design = get_voice_design(db, voice_design_id, user_id)
|
||||
x_vector = cached['data']
|
||||
cache_id = design.voice_cache_id
|
||||
elif request_data.get('use_cache', True):
|
||||
cached = await cache_manager.get_cache(user_id, ref_audio_hash, db)
|
||||
if cached:
|
||||
x_vector = cached['data']
|
||||
@@ -248,7 +272,17 @@ async def process_voice_clone_job(
|
||||
|
||||
backend = await TTSServiceFactory.get_backend(backend_type, user_api_key)
|
||||
|
||||
audio_bytes, sample_rate = await backend.generate_voice_clone(request_data, ref_audio_data)
|
||||
if voice_design_id and backend_type == "local":
|
||||
from db.crud import get_voice_design
|
||||
design = get_voice_design(db, voice_design_id, user_id)
|
||||
cached = await cache_manager.get_cache_by_id(design.voice_cache_id, db)
|
||||
x_vector = cached['data']
|
||||
audio_bytes, sample_rate = await backend.generate_voice_clone(request_data, x_vector=x_vector)
|
||||
logger.info(f"Generated audio using cached x_vector from voice design {voice_design_id}")
|
||||
else:
|
||||
with open(ref_audio_path, 'rb') as f:
|
||||
ref_audio_data = f.read()
|
||||
audio_bytes, sample_rate = await backend.generate_voice_clone(request_data, ref_audio_data)
|
||||
|
||||
timestamp = datetime.utcnow().strftime("%Y%m%d_%H%M%S")
|
||||
filename = f"{user_id}_{job_id}_{timestamp}.wav"
|
||||
@@ -274,7 +308,7 @@ async def process_voice_clone_job(
|
||||
db.commit()
|
||||
|
||||
finally:
|
||||
if Path(ref_audio_path).exists():
|
||||
if not use_voice_design and ref_audio_path and Path(ref_audio_path).exists():
|
||||
Path(ref_audio_path).unlink()
|
||||
db.close()
|
||||
|
||||
@@ -483,10 +517,11 @@ async def create_voice_clone_job(
|
||||
request: Request,
|
||||
text: str = Form(...),
|
||||
language: str = Form(default="Auto"),
|
||||
ref_audio: UploadFile = File(...),
|
||||
ref_audio: Optional[UploadFile] = File(default=None),
|
||||
ref_text: Optional[str] = Form(default=None),
|
||||
use_cache: bool = Form(default=True),
|
||||
x_vector_only_mode: bool = Form(default=False),
|
||||
voice_design_id: Optional[int] = Form(default=None),
|
||||
max_new_tokens: Optional[int] = Form(default=2048),
|
||||
temperature: Optional[float] = Form(default=0.9),
|
||||
top_k: Optional[int] = Form(default=50),
|
||||
@@ -498,7 +533,7 @@ async def create_voice_clone_job(
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
from core.security import decrypt_api_key
|
||||
from db.crud import get_user_preferences, can_user_use_local_model
|
||||
from db.crud import get_user_preferences, can_user_use_local_model, get_voice_design
|
||||
|
||||
user_prefs = get_user_preferences(db, current_user.id)
|
||||
preferred_backend = user_prefs.get("default_backend", "aliyun")
|
||||
@@ -519,6 +554,10 @@ async def create_voice_clone_job(
|
||||
detail="Aliyun API key not configured. Please set your API key in Settings."
|
||||
)
|
||||
|
||||
ref_audio_data = None
|
||||
ref_audio_hash = None
|
||||
use_voice_design = False
|
||||
|
||||
try:
|
||||
validate_text_length(text)
|
||||
language = validate_language(language)
|
||||
@@ -531,13 +570,36 @@ async def create_voice_clone_job(
|
||||
'repetition_penalty': repetition_penalty
|
||||
})
|
||||
|
||||
ref_audio_data = await ref_audio.read()
|
||||
|
||||
if not validate_ref_audio(ref_audio_data, max_size_mb=settings.MAX_AUDIO_SIZE_MB):
|
||||
raise ValueError("Invalid reference audio: must be 1-30s duration and ≤10MB")
|
||||
|
||||
cache_manager = await VoiceCacheManager.get_instance()
|
||||
ref_audio_hash = cache_manager.get_audio_hash(ref_audio_data)
|
||||
|
||||
if voice_design_id:
|
||||
design = get_voice_design(db, voice_design_id, current_user.id)
|
||||
if not design:
|
||||
raise ValueError("Voice design not found")
|
||||
|
||||
if design.backend_type != backend_type:
|
||||
raise ValueError(f"Voice design backend ({design.backend_type}) doesn't match request backend ({backend_type})")
|
||||
|
||||
if not design.voice_cache_id:
|
||||
raise ValueError("Voice design has no prepared clone prompt. Please call /voice-designs/{id}/prepare-clone first")
|
||||
|
||||
use_voice_design = True
|
||||
ref_audio_hash = f"voice_design_{voice_design_id}"
|
||||
if not ref_text:
|
||||
ref_text = design.ref_text
|
||||
|
||||
logger.info(f"Using voice design {voice_design_id} with cache_id={design.voice_cache_id}")
|
||||
|
||||
else:
|
||||
if not ref_audio:
|
||||
raise ValueError("Either ref_audio or voice_design_id must be provided")
|
||||
|
||||
ref_audio_data = await ref_audio.read()
|
||||
|
||||
if not validate_ref_audio(ref_audio_data, max_size_mb=settings.MAX_AUDIO_SIZE_MB):
|
||||
raise ValueError("Invalid reference audio: must be 1-30s duration and ≤10MB")
|
||||
|
||||
ref_audio_hash = cache_manager.get_audio_hash(ref_audio_data)
|
||||
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
@@ -555,6 +617,7 @@ async def create_voice_clone_job(
|
||||
"ref_audio_hash": ref_audio_hash,
|
||||
"use_cache": use_cache,
|
||||
"x_vector_only_mode": x_vector_only_mode,
|
||||
"voice_design_id": voice_design_id,
|
||||
**params
|
||||
}
|
||||
)
|
||||
@@ -562,9 +625,13 @@ async def create_voice_clone_job(
|
||||
db.commit()
|
||||
db.refresh(job)
|
||||
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as tmp_file:
|
||||
tmp_file.write(ref_audio_data)
|
||||
tmp_audio_path = tmp_file.name
|
||||
if use_voice_design:
|
||||
design = get_voice_design(db, voice_design_id, current_user.id)
|
||||
tmp_audio_path = design.ref_audio_path
|
||||
else:
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as tmp_file:
|
||||
tmp_file.write(ref_audio_data)
|
||||
tmp_audio_path = tmp_file.name
|
||||
|
||||
request_data = {
|
||||
"text": text,
|
||||
@@ -572,6 +639,7 @@ async def create_voice_clone_job(
|
||||
"ref_text": ref_text or "",
|
||||
"use_cache": use_cache,
|
||||
"x_vector_only_mode": x_vector_only_mode,
|
||||
"voice_design_id": voice_design_id,
|
||||
**params
|
||||
}
|
||||
|
||||
@@ -582,7 +650,8 @@ async def create_voice_clone_job(
|
||||
request_data,
|
||||
tmp_audio_path,
|
||||
backend_type,
|
||||
str(settings.DATABASE_URL)
|
||||
str(settings.DATABASE_URL),
|
||||
use_voice_design
|
||||
)
|
||||
|
||||
existing_cache = await cache_manager.get_cache(current_user.id, ref_audio_hash, db)
|
||||
|
||||
Reference in New Issue
Block a user