@router.post("/prepare-and-create", response_model=VoiceDesignResponse, status_code=status.HTTP_201_CREATED)
@limiter.limit("10/minute")
async def prepare_and_create_voice_design(
    request: Request,
    data: VoiceDesignCreate,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """Create a local voice design together with its voice-clone prompt.

    Steps, in order:
      1. Generate reference audio from ``data.instruct`` with the "local"
         TTS backend.
      2. Write that audio to a WAV file under ``settings.OUTPUT_DIR``.
      3. Load the "base" model and build a clone prompt from the audio
         (called with ``x_vector_only_mode=True``).
      4. Store the prompt through ``VoiceCacheManager.set_cache``.
      5. Create the voice design row referencing the cache entry, the audio
         path and the reference text.

    Args:
        request: Incoming request; not used in the body (presumably required
            by the rate-limiter decorator -- confirm).
        data: Payload for the new design (name, instruct, preview_text,
            meta_data are read here).
        current_user: Authenticated user that will own the design.
        db: Database session.

    Returns:
        The created design converted with ``to_voice_design_response``.

    Raises:
        HTTPException: 403 when ``can_user_use_local_model`` rejects the
            user; 500 when any step of the pipeline fails. An
            ``HTTPException`` raised by an inner call is propagated as-is.
    """
    import uuid
    from core.tts_service import TTSServiceFactory
    from core.cache_manager import VoiceCacheManager
    from utils.audio import process_ref_audio, extract_audio_features
    from core.config import settings
    from db.crud import can_user_use_local_model
    from datetime import datetime, timezone

    if not can_user_use_local_model(current_user):
        raise HTTPException(status_code=403, detail="Local model access required")

    # Tracked so the audio file can be removed if a later step fails.
    ref_audio_path = None
    design_created = False

    try:
        backend = await TTSServiceFactory.get_backend("local")
        # Fall back to the first 100 characters of the instruction when the
        # caller supplied no preview text.
        ref_text = data.preview_text or data.instruct[:100]

        ref_audio_bytes, _ = await backend.generate_voice_design({
            "text": ref_text,
            "language": "Auto",
            "instruct": data.instruct,
            "max_new_tokens": 2048,
            "temperature": 0.3,
            "top_k": 10,
            "top_p": 0.5,
            "repetition_penalty": 1.05
        })

        # A second-resolution timestamp alone is not unique: two requests in
        # the same second would overwrite each other's reference audio. The
        # user id and a random suffix make the name collision-free.
        timestamp = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S")
        unique_suffix = uuid.uuid4().hex[:8]
        ref_filename = f"voice_design_new_{current_user.id}_{timestamp}_{unique_suffix}.wav"
        ref_audio_path = Path(settings.OUTPUT_DIR) / ref_filename
        ref_audio_path.write_bytes(ref_audio_bytes)

        ref_audio_array, ref_sr = process_ref_audio(ref_audio_bytes)

        from core.model_manager import ModelManager
        model_manager = await ModelManager.get_instance()
        await model_manager.load_model("base")
        _, tts = await model_manager.get_current_model()
        if tts is None:
            raise RuntimeError("Failed to load base model")

        x_vector = tts.create_voice_clone_prompt(
            ref_audio=(ref_audio_array, ref_sr),
            ref_text=ref_text,
            x_vector_only_mode=True
        )

        cache_manager = await VoiceCacheManager.get_instance()
        ref_audio_hash = cache_manager.get_audio_hash(ref_audio_bytes)
        features = extract_audio_features(ref_audio_array, ref_sr)
        metadata = {
            'duration': features['duration'],
            'sample_rate': features['sample_rate'],
            'ref_text': ref_text,
            'x_vector_only_mode': True,
            'instruct': data.instruct
        }
        cache_id = await cache_manager.set_cache(
            current_user.id, ref_audio_hash, x_vector, metadata, db
        )

        design = crud.create_voice_design(
            db=db,
            user_id=current_user.id,
            name=data.name,
            instruct=data.instruct,
            backend_type="local",
            meta_data=data.meta_data,
            preview_text=data.preview_text,
            voice_cache_id=cache_id,
            ref_audio_path=str(ref_audio_path),
            ref_text=ref_text,
        )
        # From here on the stored row references the file, so keep it.
        design_created = True

        logger.info(f"Voice design created with clone prompt: design_id={design.id}, cache_id={cache_id}")
        return to_voice_design_response(design)

    except HTTPException:
        # Preserve status/detail chosen by inner code instead of masking it
        # as a generic 500.
        raise
    except Exception as e:
        logger.error(f"Failed to prepare and create voice design: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail="Failed to prepare voice design") from e
    finally:
        # Do not leave an orphaned reference-audio file behind when the
        # design was never persisted.
        if not design_created and ref_audio_path is not None:
            try:
                ref_audio_path.unlink(missing_ok=True)
            except OSError as cleanup_error:
                logger.warning(f"Could not remove reference audio {ref_audio_path}: {cleanup_error}")
user_id=user_id, @@ -290,6 +293,9 @@ def create_voice_design( aliyun_voice_id=aliyun_voice_id, meta_data=meta_data, preview_text=preview_text, + voice_cache_id=voice_cache_id, + ref_audio_path=ref_audio_path, + ref_text=ref_text, created_at=datetime.utcnow(), last_used=datetime.utcnow() ) diff --git a/qwen3-tts-frontend/src/components/tts/VoiceDesignForm.tsx b/qwen3-tts-frontend/src/components/tts/VoiceDesignForm.tsx index 251cd4e..ab06424 100644 --- a/qwen3-tts-frontend/src/components/tts/VoiceDesignForm.tsx +++ b/qwen3-tts-frontend/src/components/tts/VoiceDesignForm.tsx @@ -145,28 +145,27 @@ const VoiceDesignForm = forwardRef((_props, ref) => { return } + const backend = preferences?.default_backend || 'local' + const text = watch('text') + const designData = { + name: saveDesignName, + instruct: instruct, + backend_type: backend, + preview_text: text || `${saveDesignName}的声音` + } + try { - const backend = preferences?.default_backend || 'local' - const text = watch('text') - const design = await voiceDesignApi.create({ - name: saveDesignName, - instruct: instruct, - backend_type: backend, - preview_text: text || `${saveDesignName}的声音` - }) - - toast.success(t('designSaved')) - if (backend === 'local') { setIsPreparing(true) try { - await voiceDesignApi.prepareClone(design.id) - toast.success(t('clonePrepared')) - } catch (error) { - toast.error(t('clonePrepareFailed')) + await voiceDesignApi.prepareAndCreate(designData) + toast.success(t('designSaved')) } finally { setIsPreparing(false) } + } else { + await voiceDesignApi.create(designData) + toast.success(t('designSaved')) } setShowSaveDialog(false) diff --git a/qwen3-tts-frontend/src/lib/api.ts b/qwen3-tts-frontend/src/lib/api.ts index aea6a1f..e11a2d2 100644 --- a/qwen3-tts-frontend/src/lib/api.ts +++ b/qwen3-tts-frontend/src/lib/api.ts @@ -432,6 +432,14 @@ export const voiceDesignApi = { return response.data }, + prepareAndCreate: async (data: VoiceDesignCreate): Promise => { + const response = await 
apiClient.post( + API_ENDPOINTS.VOICE_DESIGNS.PREPARE_AND_CREATE, + data + ) + return response.data + }, + prepareClone: async (id: number): Promise<{ message: string; cache_id: number; ref_text: string }> => { const response = await apiClient.post<{ message: string; cache_id: number; ref_text: string }>( API_ENDPOINTS.VOICE_DESIGNS.PREPARE_CLONE(id) diff --git a/qwen3-tts-frontend/src/lib/constants.ts b/qwen3-tts-frontend/src/lib/constants.ts index 105a14f..afe43e9 100644 --- a/qwen3-tts-frontend/src/lib/constants.ts +++ b/qwen3-tts-frontend/src/lib/constants.ts @@ -30,6 +30,7 @@ export const API_ENDPOINTS = { VOICE_DESIGNS: { LIST: '/voice-designs', CREATE: '/voice-designs', + PREPARE_AND_CREATE: '/voice-designs/prepare-and-create', PREPARE_CLONE: (id: number) => `/voice-designs/${id}/prepare-clone`, }, } as const