feat: Add voice design support for voice cloning and enhance cache management
This commit is contained in:
@@ -164,7 +164,8 @@ async def process_voice_clone_job(
|
|||||||
request_data: dict,
|
request_data: dict,
|
||||||
ref_audio_path: str,
|
ref_audio_path: str,
|
||||||
backend_type: str,
|
backend_type: str,
|
||||||
db_url: str
|
db_url: str,
|
||||||
|
use_voice_design: bool = False
|
||||||
):
|
):
|
||||||
from core.database import SessionLocal
|
from core.database import SessionLocal
|
||||||
from core.tts_service import TTSServiceFactory
|
from core.tts_service import TTSServiceFactory
|
||||||
@@ -194,13 +195,36 @@ async def process_voice_clone_job(
|
|||||||
ref_audio_data = f.read()
|
ref_audio_data = f.read()
|
||||||
|
|
||||||
cache_manager = await VoiceCacheManager.get_instance()
|
cache_manager = await VoiceCacheManager.get_instance()
|
||||||
|
|
||||||
|
voice_design_id = request_data.get('voice_design_id')
|
||||||
|
if voice_design_id:
|
||||||
|
from db.crud import get_voice_design
|
||||||
|
design = get_voice_design(db, voice_design_id, user_id)
|
||||||
|
if not design or not design.voice_cache_id:
|
||||||
|
raise RuntimeError(f"Voice design {voice_design_id} has no prepared clone prompt")
|
||||||
|
|
||||||
|
cached = await cache_manager.get_cache_by_id(design.voice_cache_id, db)
|
||||||
|
if not cached:
|
||||||
|
raise RuntimeError(f"Cache {design.voice_cache_id} not found")
|
||||||
|
|
||||||
|
ref_audio_hash = f"voice_design_{voice_design_id}"
|
||||||
|
cache_metrics.record_hit(user_id)
|
||||||
|
logger.info(f"Using voice design {voice_design_id}, cache_id={design.voice_cache_id}")
|
||||||
|
|
||||||
|
else:
|
||||||
|
with open(ref_audio_path, 'rb') as f:
|
||||||
|
ref_audio_data = f.read()
|
||||||
ref_audio_hash = cache_manager.get_audio_hash(ref_audio_data)
|
ref_audio_hash = cache_manager.get_audio_hash(ref_audio_data)
|
||||||
|
|
||||||
if request_data.get('x_vector_only_mode', False) and backend_type == "local":
|
if request_data.get('x_vector_only_mode', False) and backend_type == "local":
|
||||||
x_vector = None
|
x_vector = None
|
||||||
cache_id = None
|
cache_id = None
|
||||||
|
|
||||||
if request_data.get('use_cache', True):
|
if voice_design_id:
|
||||||
|
design = get_voice_design(db, voice_design_id, user_id)
|
||||||
|
x_vector = cached['data']
|
||||||
|
cache_id = design.voice_cache_id
|
||||||
|
elif request_data.get('use_cache', True):
|
||||||
cached = await cache_manager.get_cache(user_id, ref_audio_hash, db)
|
cached = await cache_manager.get_cache(user_id, ref_audio_hash, db)
|
||||||
if cached:
|
if cached:
|
||||||
x_vector = cached['data']
|
x_vector = cached['data']
|
||||||
@@ -248,6 +272,16 @@ async def process_voice_clone_job(
|
|||||||
|
|
||||||
backend = await TTSServiceFactory.get_backend(backend_type, user_api_key)
|
backend = await TTSServiceFactory.get_backend(backend_type, user_api_key)
|
||||||
|
|
||||||
|
if voice_design_id and backend_type == "local":
|
||||||
|
from db.crud import get_voice_design
|
||||||
|
design = get_voice_design(db, voice_design_id, user_id)
|
||||||
|
cached = await cache_manager.get_cache_by_id(design.voice_cache_id, db)
|
||||||
|
x_vector = cached['data']
|
||||||
|
audio_bytes, sample_rate = await backend.generate_voice_clone(request_data, x_vector=x_vector)
|
||||||
|
logger.info(f"Generated audio using cached x_vector from voice design {voice_design_id}")
|
||||||
|
else:
|
||||||
|
with open(ref_audio_path, 'rb') as f:
|
||||||
|
ref_audio_data = f.read()
|
||||||
audio_bytes, sample_rate = await backend.generate_voice_clone(request_data, ref_audio_data)
|
audio_bytes, sample_rate = await backend.generate_voice_clone(request_data, ref_audio_data)
|
||||||
|
|
||||||
timestamp = datetime.utcnow().strftime("%Y%m%d_%H%M%S")
|
timestamp = datetime.utcnow().strftime("%Y%m%d_%H%M%S")
|
||||||
@@ -274,7 +308,7 @@ async def process_voice_clone_job(
|
|||||||
db.commit()
|
db.commit()
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
if Path(ref_audio_path).exists():
|
if not use_voice_design and ref_audio_path and Path(ref_audio_path).exists():
|
||||||
Path(ref_audio_path).unlink()
|
Path(ref_audio_path).unlink()
|
||||||
db.close()
|
db.close()
|
||||||
|
|
||||||
@@ -483,10 +517,11 @@ async def create_voice_clone_job(
|
|||||||
request: Request,
|
request: Request,
|
||||||
text: str = Form(...),
|
text: str = Form(...),
|
||||||
language: str = Form(default="Auto"),
|
language: str = Form(default="Auto"),
|
||||||
ref_audio: UploadFile = File(...),
|
ref_audio: Optional[UploadFile] = File(default=None),
|
||||||
ref_text: Optional[str] = Form(default=None),
|
ref_text: Optional[str] = Form(default=None),
|
||||||
use_cache: bool = Form(default=True),
|
use_cache: bool = Form(default=True),
|
||||||
x_vector_only_mode: bool = Form(default=False),
|
x_vector_only_mode: bool = Form(default=False),
|
||||||
|
voice_design_id: Optional[int] = Form(default=None),
|
||||||
max_new_tokens: Optional[int] = Form(default=2048),
|
max_new_tokens: Optional[int] = Form(default=2048),
|
||||||
temperature: Optional[float] = Form(default=0.9),
|
temperature: Optional[float] = Form(default=0.9),
|
||||||
top_k: Optional[int] = Form(default=50),
|
top_k: Optional[int] = Form(default=50),
|
||||||
@@ -498,7 +533,7 @@ async def create_voice_clone_job(
|
|||||||
db: Session = Depends(get_db)
|
db: Session = Depends(get_db)
|
||||||
):
|
):
|
||||||
from core.security import decrypt_api_key
|
from core.security import decrypt_api_key
|
||||||
from db.crud import get_user_preferences, can_user_use_local_model
|
from db.crud import get_user_preferences, can_user_use_local_model, get_voice_design
|
||||||
|
|
||||||
user_prefs = get_user_preferences(db, current_user.id)
|
user_prefs = get_user_preferences(db, current_user.id)
|
||||||
preferred_backend = user_prefs.get("default_backend", "aliyun")
|
preferred_backend = user_prefs.get("default_backend", "aliyun")
|
||||||
@@ -519,6 +554,10 @@ async def create_voice_clone_job(
|
|||||||
detail="Aliyun API key not configured. Please set your API key in Settings."
|
detail="Aliyun API key not configured. Please set your API key in Settings."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
ref_audio_data = None
|
||||||
|
ref_audio_hash = None
|
||||||
|
use_voice_design = False
|
||||||
|
|
||||||
try:
|
try:
|
||||||
validate_text_length(text)
|
validate_text_length(text)
|
||||||
language = validate_language(language)
|
language = validate_language(language)
|
||||||
@@ -531,12 +570,35 @@ async def create_voice_clone_job(
|
|||||||
'repetition_penalty': repetition_penalty
|
'repetition_penalty': repetition_penalty
|
||||||
})
|
})
|
||||||
|
|
||||||
|
cache_manager = await VoiceCacheManager.get_instance()
|
||||||
|
|
||||||
|
if voice_design_id:
|
||||||
|
design = get_voice_design(db, voice_design_id, current_user.id)
|
||||||
|
if not design:
|
||||||
|
raise ValueError("Voice design not found")
|
||||||
|
|
||||||
|
if design.backend_type != backend_type:
|
||||||
|
raise ValueError(f"Voice design backend ({design.backend_type}) doesn't match request backend ({backend_type})")
|
||||||
|
|
||||||
|
if not design.voice_cache_id:
|
||||||
|
raise ValueError("Voice design has no prepared clone prompt. Please call /voice-designs/{id}/prepare-clone first")
|
||||||
|
|
||||||
|
use_voice_design = True
|
||||||
|
ref_audio_hash = f"voice_design_{voice_design_id}"
|
||||||
|
if not ref_text:
|
||||||
|
ref_text = design.ref_text
|
||||||
|
|
||||||
|
logger.info(f"Using voice design {voice_design_id} with cache_id={design.voice_cache_id}")
|
||||||
|
|
||||||
|
else:
|
||||||
|
if not ref_audio:
|
||||||
|
raise ValueError("Either ref_audio or voice_design_id must be provided")
|
||||||
|
|
||||||
ref_audio_data = await ref_audio.read()
|
ref_audio_data = await ref_audio.read()
|
||||||
|
|
||||||
if not validate_ref_audio(ref_audio_data, max_size_mb=settings.MAX_AUDIO_SIZE_MB):
|
if not validate_ref_audio(ref_audio_data, max_size_mb=settings.MAX_AUDIO_SIZE_MB):
|
||||||
raise ValueError("Invalid reference audio: must be 1-30s duration and ≤10MB")
|
raise ValueError("Invalid reference audio: must be 1-30s duration and ≤10MB")
|
||||||
|
|
||||||
cache_manager = await VoiceCacheManager.get_instance()
|
|
||||||
ref_audio_hash = cache_manager.get_audio_hash(ref_audio_data)
|
ref_audio_hash = cache_manager.get_audio_hash(ref_audio_data)
|
||||||
|
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
@@ -555,6 +617,7 @@ async def create_voice_clone_job(
|
|||||||
"ref_audio_hash": ref_audio_hash,
|
"ref_audio_hash": ref_audio_hash,
|
||||||
"use_cache": use_cache,
|
"use_cache": use_cache,
|
||||||
"x_vector_only_mode": x_vector_only_mode,
|
"x_vector_only_mode": x_vector_only_mode,
|
||||||
|
"voice_design_id": voice_design_id,
|
||||||
**params
|
**params
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
@@ -562,6 +625,10 @@ async def create_voice_clone_job(
|
|||||||
db.commit()
|
db.commit()
|
||||||
db.refresh(job)
|
db.refresh(job)
|
||||||
|
|
||||||
|
if use_voice_design:
|
||||||
|
design = get_voice_design(db, voice_design_id, current_user.id)
|
||||||
|
tmp_audio_path = design.ref_audio_path
|
||||||
|
else:
|
||||||
with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as tmp_file:
|
with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as tmp_file:
|
||||||
tmp_file.write(ref_audio_data)
|
tmp_file.write(ref_audio_data)
|
||||||
tmp_audio_path = tmp_file.name
|
tmp_audio_path = tmp_file.name
|
||||||
@@ -572,6 +639,7 @@ async def create_voice_clone_job(
|
|||||||
"ref_text": ref_text or "",
|
"ref_text": ref_text or "",
|
||||||
"use_cache": use_cache,
|
"use_cache": use_cache,
|
||||||
"x_vector_only_mode": x_vector_only_mode,
|
"x_vector_only_mode": x_vector_only_mode,
|
||||||
|
"voice_design_id": voice_design_id,
|
||||||
**params
|
**params
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -582,7 +650,8 @@ async def create_voice_clone_job(
|
|||||||
request_data,
|
request_data,
|
||||||
tmp_audio_path,
|
tmp_audio_path,
|
||||||
backend_type,
|
backend_type,
|
||||||
str(settings.DATABASE_URL)
|
str(settings.DATABASE_URL),
|
||||||
|
use_voice_design
|
||||||
)
|
)
|
||||||
|
|
||||||
existing_cache = await cache_manager.get_cache(current_user.id, ref_audio_hash, db)
|
existing_cache = await cache_manager.get_cache(current_user.id, ref_audio_hash, db)
|
||||||
|
|||||||
@@ -1,13 +1,14 @@
|
|||||||
import logging
|
import logging
|
||||||
from fastapi import APIRouter, Depends, HTTPException, status, Request
|
from fastapi import APIRouter, Depends, HTTPException, status, Request, BackgroundTasks
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from slowapi import Limiter
|
from slowapi import Limiter
|
||||||
from slowapi.util import get_remote_address
|
from slowapi.util import get_remote_address
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
from core.database import get_db
|
from core.database import get_db
|
||||||
from api.auth import get_current_user
|
from api.auth import get_current_user
|
||||||
from db.models import User
|
from db.models import User, Job, JobStatus
|
||||||
from db import crud
|
from db import crud
|
||||||
from schemas.voice_design import (
|
from schemas.voice_design import (
|
||||||
VoiceDesignCreate,
|
VoiceDesignCreate,
|
||||||
@@ -95,3 +96,116 @@ async def delete_voice_design(
|
|||||||
success = crud.delete_voice_design(db, design_id, current_user.id)
|
success = crud.delete_voice_design(db, design_id, current_user.id)
|
||||||
if not success:
|
if not success:
|
||||||
raise HTTPException(status_code=404, detail="Voice design not found")
|
raise HTTPException(status_code=404, detail="Voice design not found")
|
||||||
|
|
||||||
|
@router.post("/{design_id}/prepare-clone")
@limiter.limit("10/minute")
async def prepare_voice_clone_prompt(
    request: Request,
    design_id: int,
    background_tasks: BackgroundTasks,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """Prepare and cache a voice clone prompt for a local-backend voice design.

    Generates a reference audio clip from the design's ``instruct`` text,
    writes it under ``settings.OUTPUT_DIR``, derives a clone prompt from it
    with the "base" model, stores that prompt through ``VoiceCacheManager``
    and records the resulting cache id, audio path and reference text on the
    design row.

    Args:
        request: Incoming request (declared alongside the rate-limit decorator).
        design_id: Id of the voice design to prepare.
        background_tasks: Injected by FastAPI; not used by this handler.
        current_user: Authenticated user; the design is looked up for this user.
        db: Database session.

    Returns:
        A dict with ``message`` and ``cache_id``; ``ref_text`` is included
        when a new prompt was prepared (not when one already existed).

    Raises:
        HTTPException: 404 if the design is not found for this user, 400 if
            the design is not a local-backend design, 403 if the user may not
            use the local model, 500 if any preparation step fails.
    """
    from core.tts_service import TTSServiceFactory
    from core.cache_manager import VoiceCacheManager
    from utils.audio import process_ref_audio, extract_audio_features
    from core.config import settings
    from db.crud import can_user_use_local_model
    from datetime import datetime

    design = crud.get_voice_design(db, design_id, current_user.id)
    if not design:
        raise HTTPException(status_code=404, detail="Voice design not found")

    if design.backend_type != "local":
        raise HTTPException(
            status_code=400,
            detail="Voice clone prompt preparation is only supported for local backend"
        )

    if not can_user_use_local_model(current_user):
        raise HTTPException(
            status_code=403,
            detail="Local model access required"
        )

    # Already prepared: return the existing cache id instead of regenerating.
    if design.voice_cache_id:
        return {
            "message": "Voice clone prompt already exists",
            "cache_id": design.voice_cache_id
        }

    # Path of a reference file written to disk but not yet committed to the
    # design row; it is removed again if a later step fails.
    pending_audio_path: Optional[Path] = None

    try:
        backend = await TTSServiceFactory.get_backend("local")

        # Fall back to the first 100 characters of the instruction when the
        # design has no preview text.
        ref_text = design.preview_text or design.instruct[:100]

        logger.info(f"Generating reference audio for voice design {design_id}")
        ref_audio_bytes, _ = await backend.generate_voice_design({
            "text": ref_text,
            "language": "Auto",
            "instruct": design.instruct,
            "max_new_tokens": 2048,
            "temperature": 0.3,
            "top_k": 10,
            "top_p": 0.5,
            "repetition_penalty": 1.05
        })

        timestamp = datetime.utcnow().strftime("%Y%m%d_%H%M%S")
        ref_filename = f"voice_design_{design_id}_{timestamp}.wav"
        ref_audio_path = Path(settings.OUTPUT_DIR) / ref_filename

        # Mark the path as pending before opening it so that a partially
        # written file is cleaned up as well.
        pending_audio_path = ref_audio_path
        with open(ref_audio_path, 'wb') as f:
            f.write(ref_audio_bytes)

        logger.info("Extracting voice clone prompt from reference audio")
        ref_audio_array, ref_sr = process_ref_audio(ref_audio_bytes)

        from core.model_manager import ModelManager
        model_manager = await ModelManager.get_instance()
        await model_manager.load_model("base")
        _, tts = await model_manager.get_current_model()

        if tts is None:
            raise RuntimeError("Failed to load base model")

        x_vector = tts.create_voice_clone_prompt(
            ref_audio=(ref_audio_array, ref_sr),
            ref_text=ref_text,
            x_vector_only_mode=True
        )

        cache_manager = await VoiceCacheManager.get_instance()
        ref_audio_hash = cache_manager.get_audio_hash(ref_audio_bytes)

        features = extract_audio_features(ref_audio_array, ref_sr)
        metadata = {
            'duration': features['duration'],
            'sample_rate': features['sample_rate'],
            'ref_text': ref_text,
            'x_vector_only_mode': True,
            'voice_design_id': design_id,
            'instruct': design.instruct
        }

        cache_id = await cache_manager.set_cache(
            current_user.id, ref_audio_hash, x_vector, metadata, db
        )
        # NOTE(review): set_cache is not visible here; this assumes it signals
        # failure by returning None, like the cache getters do - confirm.
        # Without the check a None id would be stored and reported as success.
        if cache_id is None:
            raise RuntimeError("Failed to store voice clone prompt in cache")

        design.voice_cache_id = cache_id
        design.ref_audio_path = str(ref_audio_path)
        design.ref_text = ref_text
        db.commit()

        # The design row now references the file; it must no longer be removed.
        pending_audio_path = None

        logger.info(f"Voice clone prompt prepared for design {design_id}, cache_id={cache_id}")

        return {
            "message": "Voice clone prompt prepared successfully",
            "cache_id": cache_id,
            "ref_text": ref_text
        }

    except Exception as e:
        logger.error(f"Failed to prepare voice clone prompt: {e}", exc_info=True)

        # Discard any half-applied changes so the session is usable again.
        try:
            db.rollback()
        except Exception:
            logger.warning("Rollback failed after voice clone prompt error", exc_info=True)

        # Remove the reference audio that no database row points at.
        if pending_audio_path is not None:
            try:
                if pending_audio_path.exists():
                    pending_audio_path.unlink()
            except OSError:
                logger.warning(
                    f"Could not remove orphaned reference audio: {pending_audio_path}",
                    exc_info=True
                )

        # NOTE(review): detail=str(e) exposes internal error text to the
        # client; kept as-is for compatibility - consider a generic message.
        raise HTTPException(status_code=500, detail=str(e)) from e
|
||||||
|
|||||||
@@ -72,6 +72,36 @@ class VoiceCacheManager:
|
|||||||
logger.error(f"Cache retrieval error: {e}", exc_info=True)
|
logger.error(f"Cache retrieval error: {e}", exc_info=True)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
async def get_cache_by_id(self, cache_id: int, db: Session) -> Optional[Dict[str, Any]]:
    """Load a cached voice prompt by its primary key.

    Looks up the ``VoiceCache`` row with the given id, unpickles the file at
    its ``cache_path``, then updates ``last_accessed`` and ``access_count``
    on the row and commits.

    Args:
        cache_id: Primary key of the ``VoiceCache`` row.
        db: Database session used for the lookup and the usage update.

    Returns:
        A dict with keys ``cache_id``, ``data`` (the unpickled payload) and
        ``metadata``, or ``None`` when the row does not exist, the cache file
        is missing, or any error occurs (errors are logged, not raised).
    """
    # NOTE(review): the lookup is by id only and does not check which user
    # owns the entry; callers must make sure the id belongs to the requester.
    try:
        cache_entry = db.query(VoiceCache).filter(VoiceCache.id == cache_id).first()
        if cache_entry is None:
            logger.debug(f"Cache not found: id={cache_id}")
            return None

        cache_file = Path(cache_entry.cache_path)
        if not cache_file.exists():
            logger.warning(f"Cache file missing: {cache_file}")
            return None

        # SECURITY: pickle.load executes arbitrary code from the file. This is
        # only safe while cache files are written exclusively by this service.
        with open(cache_file, 'rb') as f:
            cache_data = pickle.load(f)

        cache_entry.last_accessed = datetime.utcnow()
        # Treat a NULL counter as zero instead of raising TypeError.
        cache_entry.access_count = (cache_entry.access_count or 0) + 1
        db.commit()

        logger.info(f"Cache loaded by id: cache_id={cache_id}, access_count={cache_entry.access_count}")
        return {
            'cache_id': cache_entry.id,
            'data': cache_data,
            'metadata': cache_entry.meta_data
        }

    except Exception as e:
        logger.error(f"Cache retrieval by id error: {e}", exc_info=True)
        # A failed flush/commit leaves the session unusable until it is rolled
        # back; callers keep using `db` after this method returns None.
        try:
            db.rollback()
        except Exception:
            logger.warning("Rollback failed after cache retrieval error", exc_info=True)
        return None
|
||||||
|
|
||||||
async def set_cache(
|
async def set_cache(
|
||||||
self,
|
self,
|
||||||
user_id: int,
|
user_id: int,
|
||||||
|
|||||||
@@ -83,14 +83,18 @@ class LocalTTSBackend(TTSBackend):
|
|||||||
audio_data = np.array(audio_data)
|
audio_data = np.array(audio_data)
|
||||||
return self._numpy_to_bytes(audio_data), 24000
|
return self._numpy_to_bytes(audio_data), 24000
|
||||||
|
|
||||||
async def generate_voice_clone(self, params: dict, ref_audio_bytes: bytes) -> Tuple[bytes, int]:
|
async def generate_voice_clone(self, params: dict, ref_audio_bytes: bytes = None, x_vector=None) -> Tuple[bytes, int]:
|
||||||
from utils.audio import process_ref_audio
|
from utils.audio import process_ref_audio
|
||||||
|
|
||||||
ref_audio_array, ref_sr = process_ref_audio(ref_audio_bytes)
|
|
||||||
|
|
||||||
await self.model_manager.load_model("base")
|
await self.model_manager.load_model("base")
|
||||||
_, tts = await self.model_manager.get_current_model()
|
_, tts = await self.model_manager.get_current_model()
|
||||||
|
|
||||||
|
if x_vector is None:
|
||||||
|
if ref_audio_bytes is None:
|
||||||
|
raise ValueError("Either ref_audio_bytes or x_vector must be provided")
|
||||||
|
|
||||||
|
ref_audio_array, ref_sr = process_ref_audio(ref_audio_bytes)
|
||||||
|
|
||||||
x_vector = tts.create_voice_clone_prompt(
|
x_vector = tts.create_voice_clone_prompt(
|
||||||
ref_audio=(ref_audio_array, ref_sr),
|
ref_audio=(ref_audio_array, ref_sr),
|
||||||
ref_text=params.get('ref_text', ''),
|
ref_text=params.get('ref_text', ''),
|
||||||
|
|||||||
@@ -90,6 +90,9 @@ class VoiceDesign(Base):
|
|||||||
aliyun_voice_id = Column(String(255), nullable=True)
|
aliyun_voice_id = Column(String(255), nullable=True)
|
||||||
meta_data = Column(JSON, nullable=True)
|
meta_data = Column(JSON, nullable=True)
|
||||||
preview_text = Column(Text, nullable=True)
|
preview_text = Column(Text, nullable=True)
|
||||||
|
ref_audio_path = Column(String(500), nullable=True)
|
||||||
|
ref_text = Column(Text, nullable=True)
|
||||||
|
voice_cache_id = Column(Integer, nullable=True)
|
||||||
created_at = Column(DateTime, default=datetime.utcnow, nullable=False)
|
created_at = Column(DateTime, default=datetime.utcnow, nullable=False)
|
||||||
last_used = Column(DateTime, default=datetime.utcnow, nullable=False, index=True)
|
last_used = Column(DateTime, default=datetime.utcnow, nullable=False, index=True)
|
||||||
use_count = Column(Integer, default=0, nullable=False)
|
use_count = Column(Integer, default=0, nullable=False)
|
||||||
|
|||||||
Reference in New Issue
Block a user