diff --git a/qwen3-tts-backend/api/tts.py b/qwen3-tts-backend/api/tts.py index 6313ac1..768f0ba 100644 --- a/qwen3-tts-backend/api/tts.py +++ b/qwen3-tts-backend/api/tts.py @@ -164,7 +164,8 @@ async def process_voice_clone_job( request_data: dict, ref_audio_path: str, backend_type: str, - db_url: str + db_url: str, + use_voice_design: bool = False ): from core.database import SessionLocal from core.tts_service import TTSServiceFactory @@ -194,13 +195,36 @@ async def process_voice_clone_job( ref_audio_data = f.read() cache_manager = await VoiceCacheManager.get_instance() - ref_audio_hash = cache_manager.get_audio_hash(ref_audio_data) + + voice_design_id = request_data.get('voice_design_id') + if voice_design_id: + from db.crud import get_voice_design + design = get_voice_design(db, voice_design_id, user_id) + if not design or not design.voice_cache_id: + raise RuntimeError(f"Voice design {voice_design_id} has no prepared clone prompt") + + cached = await cache_manager.get_cache_by_id(design.voice_cache_id, db) + if not cached: + raise RuntimeError(f"Cache {design.voice_cache_id} not found") + + ref_audio_hash = f"voice_design_{voice_design_id}" + cache_metrics.record_hit(user_id) + logger.info(f"Using voice design {voice_design_id}, cache_id={design.voice_cache_id}") + + else: + with open(ref_audio_path, 'rb') as f: + ref_audio_data = f.read() + ref_audio_hash = cache_manager.get_audio_hash(ref_audio_data) if request_data.get('x_vector_only_mode', False) and backend_type == "local": x_vector = None cache_id = None - if request_data.get('use_cache', True): + if voice_design_id: + design = get_voice_design(db, voice_design_id, user_id) + x_vector = cached['data'] + cache_id = design.voice_cache_id + elif request_data.get('use_cache', True): cached = await cache_manager.get_cache(user_id, ref_audio_hash, db) if cached: x_vector = cached['data'] @@ -248,7 +272,17 @@ async def process_voice_clone_job( backend = await TTSServiceFactory.get_backend(backend_type, user_api_key) - audio_bytes, sample_rate = await backend.generate_voice_clone(request_data, ref_audio_data) + if voice_design_id and backend_type == "local": + from db.crud import get_voice_design + design = get_voice_design(db, voice_design_id, user_id) + cached = await cache_manager.get_cache_by_id(design.voice_cache_id, db) + x_vector = cached['data'] + audio_bytes, sample_rate = await backend.generate_voice_clone(request_data, x_vector=x_vector) + logger.info(f"Generated audio using cached x_vector from voice design {voice_design_id}") + else: + with open(ref_audio_path, 'rb') as f: + ref_audio_data = f.read() + audio_bytes, sample_rate = await backend.generate_voice_clone(request_data, ref_audio_data) timestamp = datetime.utcnow().strftime("%Y%m%d_%H%M%S") filename = f"{user_id}_{job_id}_{timestamp}.wav" @@ -274,7 +308,7 @@ async def process_voice_clone_job( db.commit() finally: - if Path(ref_audio_path).exists(): + if not use_voice_design and ref_audio_path and Path(ref_audio_path).exists(): Path(ref_audio_path).unlink() db.close() @@ -483,10 +517,11 @@ async def create_voice_clone_job( request: Request, text: str = Form(...), language: str = Form(default="Auto"), - ref_audio: UploadFile = File(...), + ref_audio: Optional[UploadFile] = File(default=None), ref_text: Optional[str] = Form(default=None), use_cache: bool = Form(default=True), x_vector_only_mode: bool = Form(default=False), + voice_design_id: Optional[int] = Form(default=None), max_new_tokens: Optional[int] = Form(default=2048), temperature: Optional[float] = Form(default=0.9), top_k: Optional[int] = Form(default=50), @@ -498,7 +533,7 @@ async def create_voice_clone_job( db: Session = Depends(get_db) ): from core.security import decrypt_api_key - from db.crud import get_user_preferences, can_user_use_local_model + from db.crud import get_user_preferences, can_user_use_local_model, get_voice_design user_prefs = get_user_preferences(db, current_user.id) preferred_backend = user_prefs.get("default_backend", "aliyun") @@ -519,6 +554,10 @@ async def create_voice_clone_job( detail="Aliyun API key not configured. Please set your API key in Settings." ) + ref_audio_data = None + ref_audio_hash = None + use_voice_design = False + try: validate_text_length(text) language = validate_language(language) @@ -531,13 +570,36 @@ async def create_voice_clone_job( 'repetition_penalty': repetition_penalty }) - ref_audio_data = await ref_audio.read() - - if not validate_ref_audio(ref_audio_data, max_size_mb=settings.MAX_AUDIO_SIZE_MB): - raise ValueError("Invalid reference audio: must be 1-30s duration and ≤10MB") - cache_manager = await VoiceCacheManager.get_instance() - ref_audio_hash = cache_manager.get_audio_hash(ref_audio_data) + + if voice_design_id: + design = get_voice_design(db, voice_design_id, current_user.id) + if not design: + raise ValueError("Voice design not found") + + if design.backend_type != backend_type: + raise ValueError(f"Voice design backend ({design.backend_type}) doesn't match request backend ({backend_type})") + + if not design.voice_cache_id: + raise ValueError("Voice design has no prepared clone prompt. Please call /voice-designs/{id}/prepare-clone first") + + use_voice_design = True + ref_audio_hash = f"voice_design_{voice_design_id}" + if not ref_text: + ref_text = design.ref_text + + logger.info(f"Using voice design {voice_design_id} with cache_id={design.voice_cache_id}") + + else: + if not ref_audio: + raise ValueError("Either ref_audio or voice_design_id must be provided") + + ref_audio_data = await ref_audio.read() + + if not validate_ref_audio(ref_audio_data, max_size_mb=settings.MAX_AUDIO_SIZE_MB): + raise ValueError("Invalid reference audio: must be 1-30s duration and ≤10MB") + + ref_audio_hash = cache_manager.get_audio_hash(ref_audio_data) except ValueError as e: raise HTTPException(status_code=400, detail=str(e)) @@ -555,6 +617,7 @@ async def create_voice_clone_job( "ref_audio_hash": ref_audio_hash, "use_cache": use_cache, "x_vector_only_mode": x_vector_only_mode, + "voice_design_id": voice_design_id, **params } ) @@ -562,9 +625,13 @@ async def create_voice_clone_job( db.commit() db.refresh(job) - with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as tmp_file: - tmp_file.write(ref_audio_data) - tmp_audio_path = tmp_file.name + if use_voice_design: + design = get_voice_design(db, voice_design_id, current_user.id) + tmp_audio_path = design.ref_audio_path + else: + with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as tmp_file: + tmp_file.write(ref_audio_data) + tmp_audio_path = tmp_file.name request_data = { "text": text, @@ -572,6 +639,7 @@ async def create_voice_clone_job( "ref_text": ref_text or "", "use_cache": use_cache, "x_vector_only_mode": x_vector_only_mode, + "voice_design_id": voice_design_id, **params } @@ -582,7 +650,8 @@ async def create_voice_clone_job( request_data, tmp_audio_path, backend_type, - str(settings.DATABASE_URL) + str(settings.DATABASE_URL), + use_voice_design ) existing_cache = await cache_manager.get_cache(current_user.id, ref_audio_hash, db) diff --git a/qwen3-tts-backend/api/voice_designs.py b/qwen3-tts-backend/api/voice_designs.py index a4c4a08..edd0c84 100644 --- a/qwen3-tts-backend/api/voice_designs.py +++ b/qwen3-tts-backend/api/voice_designs.py @@ -1,13 +1,14 @@ import logging -from fastapi import APIRouter, Depends, HTTPException, status, Request +from fastapi import APIRouter, Depends, HTTPException, status, Request, BackgroundTasks from sqlalchemy.orm import Session from typing import Optional from slowapi import Limiter from slowapi.util import get_remote_address +from pathlib import Path from core.database import get_db from api.auth import get_current_user -from db.models import User +from db.models import User, Job, JobStatus from db import crud from schemas.voice_design import ( VoiceDesignCreate, @@ -95,3 +96,116 @@ async def delete_voice_design( success = crud.delete_voice_design(db, design_id, current_user.id) if not success: raise HTTPException(status_code=404, detail="Voice design not found") + +@router.post("/{design_id}/prepare-clone") +@limiter.limit("10/minute") +async def prepare_voice_clone_prompt( + request: Request, + design_id: int, + background_tasks: BackgroundTasks, + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db) +): + from core.tts_service import TTSServiceFactory + from core.cache_manager import VoiceCacheManager + from utils.audio import process_ref_audio, extract_audio_features + from core.config import settings + from db.crud import can_user_use_local_model + from datetime import datetime + + design = crud.get_voice_design(db, design_id, current_user.id) + if not design: + raise HTTPException(status_code=404, detail="Voice design not found") + + if design.backend_type != "local": + raise HTTPException( + status_code=400, + detail="Voice clone prompt preparation is only supported for local backend" + ) + + if not can_user_use_local_model(current_user): + raise HTTPException( + status_code=403, + detail="Local model access required" + ) + + if design.voice_cache_id: + return { + "message": "Voice clone prompt already exists", + "cache_id": design.voice_cache_id + } + + try: + backend = await TTSServiceFactory.get_backend("local") + + ref_text = design.preview_text or design.instruct[:100] + + logger.info(f"Generating reference audio for voice design {design_id}") + ref_audio_bytes, sample_rate = await backend.generate_voice_design({ + "text": ref_text, + "language": "Auto", + "instruct": design.instruct, + "max_new_tokens": 2048, + "temperature": 0.3, + "top_k": 10, + "top_p": 0.5, + "repetition_penalty": 1.05 + }) + + timestamp = datetime.utcnow().strftime("%Y%m%d_%H%M%S") + ref_filename = f"voice_design_{design_id}_{timestamp}.wav" + ref_audio_path = Path(settings.OUTPUT_DIR) / ref_filename + + with open(ref_audio_path, 'wb') as f: + f.write(ref_audio_bytes) + + logger.info(f"Extracting voice clone prompt from reference audio") + ref_audio_array, ref_sr = process_ref_audio(ref_audio_bytes) + + from core.model_manager import ModelManager + model_manager = await ModelManager.get_instance() + await model_manager.load_model("base") + _, tts = await model_manager.get_current_model() + + if tts is None: + raise RuntimeError("Failed to load base model") + + x_vector = tts.create_voice_clone_prompt( + ref_audio=(ref_audio_array, ref_sr), + ref_text=ref_text, + x_vector_only_mode=True + ) + + cache_manager = await VoiceCacheManager.get_instance() + ref_audio_hash = cache_manager.get_audio_hash(ref_audio_bytes) + + features = extract_audio_features(ref_audio_array, ref_sr) + metadata = { + 'duration': features['duration'], + 'sample_rate': features['sample_rate'], + 'ref_text': ref_text, + 'x_vector_only_mode': True, + 'voice_design_id': design_id, + 'instruct': design.instruct + } + + cache_id = await cache_manager.set_cache( + current_user.id, ref_audio_hash, x_vector, metadata, db + ) + + design.voice_cache_id = cache_id + design.ref_audio_path = str(ref_audio_path) + design.ref_text = ref_text + db.commit() + + logger.info(f"Voice clone prompt prepared for design {design_id}, cache_id={cache_id}") + + return { + "message": "Voice clone prompt prepared successfully", + "cache_id": cache_id, + "ref_text": ref_text + } + + except Exception as e: + logger.error(f"Failed to prepare voice clone prompt: {e}", exc_info=True) + raise HTTPException(status_code=500, detail=str(e)) diff --git a/qwen3-tts-backend/core/cache_manager.py b/qwen3-tts-backend/core/cache_manager.py index f9f6a01..ab48001 100644 --- a/qwen3-tts-backend/core/cache_manager.py +++ b/qwen3-tts-backend/core/cache_manager.py @@ -72,6 +72,36 @@ class VoiceCacheManager: logger.error(f"Cache retrieval error: {e}", exc_info=True) return None + async def get_cache_by_id(self, cache_id: int, db: Session) -> Optional[Dict[str, Any]]: + try: + cache_entry = db.query(VoiceCache).filter(VoiceCache.id == cache_id).first() + if not cache_entry: + logger.debug(f"Cache not found: id={cache_id}") + return None + + cache_file = Path(cache_entry.cache_path) + if not cache_file.exists(): + logger.warning(f"Cache file missing: {cache_file}") + return None + + with open(cache_file, 'rb') as f: + cache_data = pickle.load(f) + + cache_entry.last_accessed = datetime.utcnow() + cache_entry.access_count += 1 + db.commit() + + logger.info(f"Cache loaded by id: cache_id={cache_id}, access_count={cache_entry.access_count}") + return { + 'cache_id': cache_entry.id, + 'data': cache_data, + 'metadata': cache_entry.meta_data + } + + except Exception as e: + logger.error(f"Cache retrieval by id error: {e}", exc_info=True) + return None + async def set_cache( self, user_id: int, diff --git a/qwen3-tts-backend/core/tts_service.py b/qwen3-tts-backend/core/tts_service.py index be2db9f..a410858 100644 --- a/qwen3-tts-backend/core/tts_service.py +++ b/qwen3-tts-backend/core/tts_service.py @@ -83,19 +83,23 @@ class LocalTTSBackend(TTSBackend): audio_data = np.array(audio_data) return self._numpy_to_bytes(audio_data), 24000 - async def generate_voice_clone(self, params: dict, ref_audio_bytes: bytes) -> Tuple[bytes, int]: + async def generate_voice_clone(self, params: dict, ref_audio_bytes: bytes = None, x_vector=None) -> Tuple[bytes, int]: from utils.audio import process_ref_audio - ref_audio_array, ref_sr = process_ref_audio(ref_audio_bytes) - await self.model_manager.load_model("base") _, tts = await self.model_manager.get_current_model() - x_vector = tts.create_voice_clone_prompt( - ref_audio=(ref_audio_array, ref_sr), - ref_text=params.get('ref_text', ''), - x_vector_only_mode=False - ) + if x_vector is None: + if ref_audio_bytes is None: + raise ValueError("Either ref_audio_bytes or x_vector must be provided") + + ref_audio_array, ref_sr = process_ref_audio(ref_audio_bytes) + + x_vector = tts.create_voice_clone_prompt( + ref_audio=(ref_audio_array, ref_sr), + ref_text=params.get('ref_text', ''), + x_vector_only_mode=False + ) wavs, sample_rate = tts.generate_voice_clone( text=params['text'], diff --git a/qwen3-tts-backend/db/models.py b/qwen3-tts-backend/db/models.py index d5a79c1..913ad20 100644 --- a/qwen3-tts-backend/db/models.py +++ b/qwen3-tts-backend/db/models.py @@ -90,6 +90,9 @@ class VoiceDesign(Base): aliyun_voice_id = Column(String(255), nullable=True) meta_data = Column(JSON, nullable=True) preview_text = Column(Text, nullable=True) + ref_audio_path = Column(String(500), nullable=True) + ref_text = Column(Text, nullable=True) + voice_cache_id = Column(Integer, nullable=True) created_at = Column(DateTime, default=datetime.utcnow, nullable=False) last_used = Column(DateTime, default=datetime.utcnow, nullable=False, index=True) use_count = Column(Integer, default=0, nullable=False)