feat: Add voice design support for voice cloning and enhance cache management
This commit is contained in:
@@ -164,7 +164,8 @@ async def process_voice_clone_job(
|
||||
request_data: dict,
|
||||
ref_audio_path: str,
|
||||
backend_type: str,
|
||||
db_url: str
|
||||
db_url: str,
|
||||
use_voice_design: bool = False
|
||||
):
|
||||
from core.database import SessionLocal
|
||||
from core.tts_service import TTSServiceFactory
|
||||
@@ -194,13 +195,36 @@ async def process_voice_clone_job(
|
||||
ref_audio_data = f.read()
|
||||
|
||||
cache_manager = await VoiceCacheManager.get_instance()
|
||||
|
||||
voice_design_id = request_data.get('voice_design_id')
|
||||
if voice_design_id:
|
||||
from db.crud import get_voice_design
|
||||
design = get_voice_design(db, voice_design_id, user_id)
|
||||
if not design or not design.voice_cache_id:
|
||||
raise RuntimeError(f"Voice design {voice_design_id} has no prepared clone prompt")
|
||||
|
||||
cached = await cache_manager.get_cache_by_id(design.voice_cache_id, db)
|
||||
if not cached:
|
||||
raise RuntimeError(f"Cache {design.voice_cache_id} not found")
|
||||
|
||||
ref_audio_hash = f"voice_design_{voice_design_id}"
|
||||
cache_metrics.record_hit(user_id)
|
||||
logger.info(f"Using voice design {voice_design_id}, cache_id={design.voice_cache_id}")
|
||||
|
||||
else:
|
||||
with open(ref_audio_path, 'rb') as f:
|
||||
ref_audio_data = f.read()
|
||||
ref_audio_hash = cache_manager.get_audio_hash(ref_audio_data)
|
||||
|
||||
if request_data.get('x_vector_only_mode', False) and backend_type == "local":
|
||||
x_vector = None
|
||||
cache_id = None
|
||||
|
||||
if request_data.get('use_cache', True):
|
||||
if voice_design_id:
|
||||
design = get_voice_design(db, voice_design_id, user_id)
|
||||
x_vector = cached['data']
|
||||
cache_id = design.voice_cache_id
|
||||
elif request_data.get('use_cache', True):
|
||||
cached = await cache_manager.get_cache(user_id, ref_audio_hash, db)
|
||||
if cached:
|
||||
x_vector = cached['data']
|
||||
@@ -248,6 +272,16 @@ async def process_voice_clone_job(
|
||||
|
||||
backend = await TTSServiceFactory.get_backend(backend_type, user_api_key)
|
||||
|
||||
if voice_design_id and backend_type == "local":
|
||||
from db.crud import get_voice_design
|
||||
design = get_voice_design(db, voice_design_id, user_id)
|
||||
cached = await cache_manager.get_cache_by_id(design.voice_cache_id, db)
|
||||
x_vector = cached['data']
|
||||
audio_bytes, sample_rate = await backend.generate_voice_clone(request_data, x_vector=x_vector)
|
||||
logger.info(f"Generated audio using cached x_vector from voice design {voice_design_id}")
|
||||
else:
|
||||
with open(ref_audio_path, 'rb') as f:
|
||||
ref_audio_data = f.read()
|
||||
audio_bytes, sample_rate = await backend.generate_voice_clone(request_data, ref_audio_data)
|
||||
|
||||
timestamp = datetime.utcnow().strftime("%Y%m%d_%H%M%S")
|
||||
@@ -274,7 +308,7 @@ async def process_voice_clone_job(
|
||||
db.commit()
|
||||
|
||||
finally:
|
||||
if Path(ref_audio_path).exists():
|
||||
if not use_voice_design and ref_audio_path and Path(ref_audio_path).exists():
|
||||
Path(ref_audio_path).unlink()
|
||||
db.close()
|
||||
|
||||
@@ -483,10 +517,11 @@ async def create_voice_clone_job(
|
||||
request: Request,
|
||||
text: str = Form(...),
|
||||
language: str = Form(default="Auto"),
|
||||
ref_audio: UploadFile = File(...),
|
||||
ref_audio: Optional[UploadFile] = File(default=None),
|
||||
ref_text: Optional[str] = Form(default=None),
|
||||
use_cache: bool = Form(default=True),
|
||||
x_vector_only_mode: bool = Form(default=False),
|
||||
voice_design_id: Optional[int] = Form(default=None),
|
||||
max_new_tokens: Optional[int] = Form(default=2048),
|
||||
temperature: Optional[float] = Form(default=0.9),
|
||||
top_k: Optional[int] = Form(default=50),
|
||||
@@ -498,7 +533,7 @@ async def create_voice_clone_job(
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
from core.security import decrypt_api_key
|
||||
from db.crud import get_user_preferences, can_user_use_local_model
|
||||
from db.crud import get_user_preferences, can_user_use_local_model, get_voice_design
|
||||
|
||||
user_prefs = get_user_preferences(db, current_user.id)
|
||||
preferred_backend = user_prefs.get("default_backend", "aliyun")
|
||||
@@ -519,6 +554,10 @@ async def create_voice_clone_job(
|
||||
detail="Aliyun API key not configured. Please set your API key in Settings."
|
||||
)
|
||||
|
||||
ref_audio_data = None
|
||||
ref_audio_hash = None
|
||||
use_voice_design = False
|
||||
|
||||
try:
|
||||
validate_text_length(text)
|
||||
language = validate_language(language)
|
||||
@@ -531,12 +570,35 @@ async def create_voice_clone_job(
|
||||
'repetition_penalty': repetition_penalty
|
||||
})
|
||||
|
||||
cache_manager = await VoiceCacheManager.get_instance()
|
||||
|
||||
if voice_design_id:
|
||||
design = get_voice_design(db, voice_design_id, current_user.id)
|
||||
if not design:
|
||||
raise ValueError("Voice design not found")
|
||||
|
||||
if design.backend_type != backend_type:
|
||||
raise ValueError(f"Voice design backend ({design.backend_type}) doesn't match request backend ({backend_type})")
|
||||
|
||||
if not design.voice_cache_id:
|
||||
raise ValueError("Voice design has no prepared clone prompt. Please call /voice-designs/{id}/prepare-clone first")
|
||||
|
||||
use_voice_design = True
|
||||
ref_audio_hash = f"voice_design_{voice_design_id}"
|
||||
if not ref_text:
|
||||
ref_text = design.ref_text
|
||||
|
||||
logger.info(f"Using voice design {voice_design_id} with cache_id={design.voice_cache_id}")
|
||||
|
||||
else:
|
||||
if not ref_audio:
|
||||
raise ValueError("Either ref_audio or voice_design_id must be provided")
|
||||
|
||||
ref_audio_data = await ref_audio.read()
|
||||
|
||||
if not validate_ref_audio(ref_audio_data, max_size_mb=settings.MAX_AUDIO_SIZE_MB):
|
||||
raise ValueError("Invalid reference audio: must be 1-30s duration and ≤10MB")
|
||||
|
||||
cache_manager = await VoiceCacheManager.get_instance()
|
||||
ref_audio_hash = cache_manager.get_audio_hash(ref_audio_data)
|
||||
|
||||
except ValueError as e:
|
||||
@@ -555,6 +617,7 @@ async def create_voice_clone_job(
|
||||
"ref_audio_hash": ref_audio_hash,
|
||||
"use_cache": use_cache,
|
||||
"x_vector_only_mode": x_vector_only_mode,
|
||||
"voice_design_id": voice_design_id,
|
||||
**params
|
||||
}
|
||||
)
|
||||
@@ -562,6 +625,10 @@ async def create_voice_clone_job(
|
||||
db.commit()
|
||||
db.refresh(job)
|
||||
|
||||
if use_voice_design:
|
||||
design = get_voice_design(db, voice_design_id, current_user.id)
|
||||
tmp_audio_path = design.ref_audio_path
|
||||
else:
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as tmp_file:
|
||||
tmp_file.write(ref_audio_data)
|
||||
tmp_audio_path = tmp_file.name
|
||||
@@ -572,6 +639,7 @@ async def create_voice_clone_job(
|
||||
"ref_text": ref_text or "",
|
||||
"use_cache": use_cache,
|
||||
"x_vector_only_mode": x_vector_only_mode,
|
||||
"voice_design_id": voice_design_id,
|
||||
**params
|
||||
}
|
||||
|
||||
@@ -582,7 +650,8 @@ async def create_voice_clone_job(
|
||||
request_data,
|
||||
tmp_audio_path,
|
||||
backend_type,
|
||||
str(settings.DATABASE_URL)
|
||||
str(settings.DATABASE_URL),
|
||||
use_voice_design
|
||||
)
|
||||
|
||||
existing_cache = await cache_manager.get_cache(current_user.id, ref_audio_hash, db)
|
||||
|
||||
@@ -1,13 +1,14 @@
|
||||
import logging
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Request
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Request, BackgroundTasks
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import Optional
|
||||
from slowapi import Limiter
|
||||
from slowapi.util import get_remote_address
|
||||
from pathlib import Path
|
||||
|
||||
from core.database import get_db
|
||||
from api.auth import get_current_user
|
||||
from db.models import User
|
||||
from db.models import User, Job, JobStatus
|
||||
from db import crud
|
||||
from schemas.voice_design import (
|
||||
VoiceDesignCreate,
|
||||
@@ -95,3 +96,116 @@ async def delete_voice_design(
|
||||
success = crud.delete_voice_design(db, design_id, current_user.id)
|
||||
if not success:
|
||||
raise HTTPException(status_code=404, detail="Voice design not found")
|
||||
|
||||
@router.post("/{design_id}/prepare-clone")
|
||||
@limiter.limit("10/minute")
|
||||
async def prepare_voice_clone_prompt(
|
||||
request: Request,
|
||||
design_id: int,
|
||||
background_tasks: BackgroundTasks,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
from core.tts_service import TTSServiceFactory
|
||||
from core.cache_manager import VoiceCacheManager
|
||||
from utils.audio import process_ref_audio, extract_audio_features
|
||||
from core.config import settings
|
||||
from db.crud import can_user_use_local_model
|
||||
from datetime import datetime
|
||||
|
||||
design = crud.get_voice_design(db, design_id, current_user.id)
|
||||
if not design:
|
||||
raise HTTPException(status_code=404, detail="Voice design not found")
|
||||
|
||||
if design.backend_type != "local":
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Voice clone prompt preparation is only supported for local backend"
|
||||
)
|
||||
|
||||
if not can_user_use_local_model(current_user):
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail="Local model access required"
|
||||
)
|
||||
|
||||
if design.voice_cache_id:
|
||||
return {
|
||||
"message": "Voice clone prompt already exists",
|
||||
"cache_id": design.voice_cache_id
|
||||
}
|
||||
|
||||
try:
|
||||
backend = await TTSServiceFactory.get_backend("local")
|
||||
|
||||
ref_text = design.preview_text or design.instruct[:100]
|
||||
|
||||
logger.info(f"Generating reference audio for voice design {design_id}")
|
||||
ref_audio_bytes, sample_rate = await backend.generate_voice_design({
|
||||
"text": ref_text,
|
||||
"language": "Auto",
|
||||
"instruct": design.instruct,
|
||||
"max_new_tokens": 2048,
|
||||
"temperature": 0.3,
|
||||
"top_k": 10,
|
||||
"top_p": 0.5,
|
||||
"repetition_penalty": 1.05
|
||||
})
|
||||
|
||||
timestamp = datetime.utcnow().strftime("%Y%m%d_%H%M%S")
|
||||
ref_filename = f"voice_design_{design_id}_{timestamp}.wav"
|
||||
ref_audio_path = Path(settings.OUTPUT_DIR) / ref_filename
|
||||
|
||||
with open(ref_audio_path, 'wb') as f:
|
||||
f.write(ref_audio_bytes)
|
||||
|
||||
logger.info(f"Extracting voice clone prompt from reference audio")
|
||||
ref_audio_array, ref_sr = process_ref_audio(ref_audio_bytes)
|
||||
|
||||
from core.model_manager import ModelManager
|
||||
model_manager = await ModelManager.get_instance()
|
||||
await model_manager.load_model("base")
|
||||
_, tts = await model_manager.get_current_model()
|
||||
|
||||
if tts is None:
|
||||
raise RuntimeError("Failed to load base model")
|
||||
|
||||
x_vector = tts.create_voice_clone_prompt(
|
||||
ref_audio=(ref_audio_array, ref_sr),
|
||||
ref_text=ref_text,
|
||||
x_vector_only_mode=True
|
||||
)
|
||||
|
||||
cache_manager = await VoiceCacheManager.get_instance()
|
||||
ref_audio_hash = cache_manager.get_audio_hash(ref_audio_bytes)
|
||||
|
||||
features = extract_audio_features(ref_audio_array, ref_sr)
|
||||
metadata = {
|
||||
'duration': features['duration'],
|
||||
'sample_rate': features['sample_rate'],
|
||||
'ref_text': ref_text,
|
||||
'x_vector_only_mode': True,
|
||||
'voice_design_id': design_id,
|
||||
'instruct': design.instruct
|
||||
}
|
||||
|
||||
cache_id = await cache_manager.set_cache(
|
||||
current_user.id, ref_audio_hash, x_vector, metadata, db
|
||||
)
|
||||
|
||||
design.voice_cache_id = cache_id
|
||||
design.ref_audio_path = str(ref_audio_path)
|
||||
design.ref_text = ref_text
|
||||
db.commit()
|
||||
|
||||
logger.info(f"Voice clone prompt prepared for design {design_id}, cache_id={cache_id}")
|
||||
|
||||
return {
|
||||
"message": "Voice clone prompt prepared successfully",
|
||||
"cache_id": cache_id,
|
||||
"ref_text": ref_text
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to prepare voice clone prompt: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@@ -72,6 +72,36 @@ class VoiceCacheManager:
|
||||
logger.error(f"Cache retrieval error: {e}", exc_info=True)
|
||||
return None
|
||||
|
||||
async def get_cache_by_id(self, cache_id: int, db: Session) -> Optional[Dict[str, Any]]:
|
||||
try:
|
||||
cache_entry = db.query(VoiceCache).filter(VoiceCache.id == cache_id).first()
|
||||
if not cache_entry:
|
||||
logger.debug(f"Cache not found: id={cache_id}")
|
||||
return None
|
||||
|
||||
cache_file = Path(cache_entry.cache_path)
|
||||
if not cache_file.exists():
|
||||
logger.warning(f"Cache file missing: {cache_file}")
|
||||
return None
|
||||
|
||||
with open(cache_file, 'rb') as f:
|
||||
cache_data = pickle.load(f)
|
||||
|
||||
cache_entry.last_accessed = datetime.utcnow()
|
||||
cache_entry.access_count += 1
|
||||
db.commit()
|
||||
|
||||
logger.info(f"Cache loaded by id: cache_id={cache_id}, access_count={cache_entry.access_count}")
|
||||
return {
|
||||
'cache_id': cache_entry.id,
|
||||
'data': cache_data,
|
||||
'metadata': cache_entry.meta_data
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Cache retrieval by id error: {e}", exc_info=True)
|
||||
return None
|
||||
|
||||
async def set_cache(
|
||||
self,
|
||||
user_id: int,
|
||||
|
||||
@@ -83,14 +83,18 @@ class LocalTTSBackend(TTSBackend):
|
||||
audio_data = np.array(audio_data)
|
||||
return self._numpy_to_bytes(audio_data), 24000
|
||||
|
||||
async def generate_voice_clone(self, params: dict, ref_audio_bytes: bytes) -> Tuple[bytes, int]:
|
||||
async def generate_voice_clone(self, params: dict, ref_audio_bytes: bytes = None, x_vector=None) -> Tuple[bytes, int]:
|
||||
from utils.audio import process_ref_audio
|
||||
|
||||
ref_audio_array, ref_sr = process_ref_audio(ref_audio_bytes)
|
||||
|
||||
await self.model_manager.load_model("base")
|
||||
_, tts = await self.model_manager.get_current_model()
|
||||
|
||||
if x_vector is None:
|
||||
if ref_audio_bytes is None:
|
||||
raise ValueError("Either ref_audio_bytes or x_vector must be provided")
|
||||
|
||||
ref_audio_array, ref_sr = process_ref_audio(ref_audio_bytes)
|
||||
|
||||
x_vector = tts.create_voice_clone_prompt(
|
||||
ref_audio=(ref_audio_array, ref_sr),
|
||||
ref_text=params.get('ref_text', ''),
|
||||
|
||||
@@ -90,6 +90,9 @@ class VoiceDesign(Base):
|
||||
aliyun_voice_id = Column(String(255), nullable=True)
|
||||
meta_data = Column(JSON, nullable=True)
|
||||
preview_text = Column(Text, nullable=True)
|
||||
ref_audio_path = Column(String(500), nullable=True)
|
||||
ref_text = Column(Text, nullable=True)
|
||||
voice_cache_id = Column(Integer, nullable=True)
|
||||
created_at = Column(DateTime, default=datetime.utcnow, nullable=False)
|
||||
last_used = Column(DateTime, default=datetime.utcnow, nullable=False, index=True)
|
||||
use_count = Column(Integer, default=0, nullable=False)
|
||||
|
||||
Reference in New Issue
Block a user