# File-viewer banner (not code): 282 lines, 9.8 KiB, Python
import json
import logging
import uuid
from datetime import datetime, timezone
from pathlib import Path
from typing import Optional

from fastapi import APIRouter, Depends, HTTPException, status, Request, BackgroundTasks
from slowapi import Limiter
from slowapi.util import get_remote_address
from sqlalchemy.orm import Session

from api.auth import get_current_user
from core.database import get_db
from db import crud
from db.models import User, Job, JobStatus
from schemas.voice_design import (
    VoiceDesignCreate,
    VoiceDesignResponse,
    VoiceDesignListResponse
)
|
|
|
|
# All routes declared in this module are mounted under the "/voice-designs" prefix.
router = APIRouter(prefix="/voice-designs", tags=["voice-designs"])

# Rate-limit buckets are keyed by whatever get_remote_address returns for the request.
# NOTE(review): slowapi normally expects the limiter to be registered on the
# application as well - confirm this module-local instance is the one in use.
limiter = Limiter(key_func=get_remote_address)

# Module-scoped logger, named after this module.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def to_voice_design_response(design) -> VoiceDesignResponse:
    """Build a ``VoiceDesignResponse`` from a stored voice design record.

    ``design.meta_data`` is decoded with ``json.loads`` when it is a string
    and passed through unchanged otherwise. A string that cannot be decoded
    is reported as ``None`` (and logged) instead of failing the response.

    Args:
        design: Object exposing ``id``, ``user_id``, ``name``, ``instruct``,
            ``meta_data``, ``preview_text``, ``ref_audio_path``,
            ``created_at``, ``last_used`` and ``use_count``.

    Returns:
        The response model populated field-by-field from ``design``.
    """
    meta_data = design.meta_data
    if isinstance(meta_data, str):
        try:
            meta_data = json.loads(meta_data)
        except (ValueError, RecursionError) as exc:
            # json.JSONDecodeError is a ValueError; RecursionError covers
            # pathologically nested documents. Anything else is a real bug
            # and should surface rather than be swallowed.
            logger.warning(
                "Ignoring undecodable meta_data for voice design %s: %s",
                design.id,
                exc,
            )
            meta_data = None
    return VoiceDesignResponse(
        id=design.id,
        user_id=design.user_id,
        name=design.name,
        instruct=design.instruct,
        meta_data=meta_data,
        preview_text=design.preview_text,
        ref_audio_path=design.ref_audio_path,
        created_at=design.created_at,
        last_used=design.last_used,
        use_count=design.use_count,
    )
|
|
|
|
@router.post("", response_model=VoiceDesignResponse, status_code=status.HTTP_201_CREATED)
@limiter.limit("30/minute")
async def save_voice_design(
    request: Request,
    data: VoiceDesignCreate,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """Store a new voice design owned by the authenticated user.

    Args:
        request: Incoming request; not read in the body (presumably consumed
            by the rate-limit decorator).
        data: Payload supplying ``name``, ``instruct``, ``meta_data`` and
            ``preview_text``.
        current_user: Authenticated user; its ``id`` becomes the owner id.
        db: Database session handed to the CRUD layer.

    Returns:
        The created design converted with ``to_voice_design_response``.

    Raises:
        HTTPException: 500 when creating or converting the design fails.
    """
    try:
        payload = {
            "name": data.name,
            "instruct": data.instruct,
            "meta_data": data.meta_data,
            "preview_text": data.preview_text,
        }
        created = crud.create_voice_design(db=db, user_id=current_user.id, **payload)
        response = to_voice_design_response(created)
    except Exception as e:
        # Boundary handler: log the full traceback, hide details from the client.
        logger.error(f"Failed to save voice design: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail="Failed to save voice design")
    else:
        return response
|
|
|
|
@router.get("", response_model=VoiceDesignListResponse)
@limiter.limit("30/minute")
async def list_voice_designs(
    request: Request,
    backend_type: Optional[str] = None,
    skip: int = 0,
    limit: int = 100,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """Return one page of the authenticated user's voice designs plus a total.

    Args:
        request: Incoming request; not read in the body (presumably consumed
            by the rate-limit decorator).
        backend_type: Optional filter forwarded unchanged to the CRUD layer.
        skip: Offset forwarded unchanged to ``crud.list_voice_designs``.
        limit: Page size forwarded unchanged to ``crud.list_voice_designs``.
        current_user: Authenticated user whose designs are listed.
        db: Database session handed to the CRUD layer.

    Returns:
        ``VoiceDesignListResponse`` with the converted designs and the count
        reported by ``crud.count_voice_designs``.
    """
    owner_id = current_user.id
    rows = crud.list_voice_designs(db, owner_id, backend_type, skip, limit)
    total = crud.count_voice_designs(db, owner_id, backend_type)

    converted = []
    for row in rows:
        converted.append(to_voice_design_response(row))
    return VoiceDesignListResponse(designs=converted, total=total)
|
|
|
|
@router.post("/prepare-and-create", response_model=VoiceDesignResponse, status_code=status.HTTP_201_CREATED)
@limiter.limit("10/minute")
async def prepare_and_create_voice_design(
    request: Request,
    data: VoiceDesignCreate,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """Generate reference audio, derive a clone prompt, cache it and save the design.

    Steps, in order: generate reference audio on the "local" backend from
    ``data.instruct``; write it as a wav file under ``settings.OUTPUT_DIR``;
    load the "base" model and build a voice clone prompt from the audio;
    store that prompt through ``VoiceCacheManager``; create the voice design
    row referencing the cache entry and the audio file.

    Args:
        request: Incoming request; not read in the body (presumably consumed
            by the rate-limit decorator).
        data: Payload supplying ``name``, ``instruct``, ``meta_data`` and
            ``preview_text``.
        current_user: Authenticated user; must pass ``can_user_use_local_model``.
        db: Database session handed to the cache manager and CRUD layer.

    Returns:
        The created design converted with ``to_voice_design_response``.

    Raises:
        HTTPException: 403 when the user may not use the local model; 500
            when any generation, caching or persistence step fails.
    """
    from core.tts_service import TTSServiceFactory
    from core.cache_manager import VoiceCacheManager
    from utils.audio import process_ref_audio, extract_audio_features
    from core.config import settings
    from db.crud import can_user_use_local_model

    if not can_user_use_local_model(current_user):
        raise HTTPException(status_code=403, detail="Local model access required")

    # Tracked outside the try block so the failure path knows whether a file
    # was written and whether a saved design row already points at it.
    ref_audio_path = None
    design_saved = False

    try:
        backend = await TTSServiceFactory.get_backend("local")
        # Fall back to the first 100 characters of the instruction when no
        # preview text was supplied.
        ref_text = data.preview_text or data.instruct[:100]

        ref_audio_bytes, _ = await backend.generate_voice_design({
            "text": ref_text,
            "language": "Auto",
            "instruct": data.instruct,
            "max_new_tokens": 2048,
            "temperature": 0.3,
            "top_k": 10,
            "top_p": 0.5,
            "repetition_penalty": 1.05
        })

        # The timestamp only has one-second resolution; the random suffix keeps
        # concurrent requests from overwriting each other's reference audio.
        timestamp = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S")
        ref_filename = f"voice_design_new_{timestamp}_{uuid.uuid4().hex[:8]}.wav"
        ref_audio_path = Path(settings.OUTPUT_DIR) / ref_filename
        with open(ref_audio_path, 'wb') as f:
            f.write(ref_audio_bytes)

        ref_audio_array, ref_sr = process_ref_audio(ref_audio_bytes)

        from core.model_manager import ModelManager
        model_manager = await ModelManager.get_instance()
        await model_manager.load_model("base")
        _, tts = await model_manager.get_current_model()
        if tts is None:
            raise RuntimeError("Failed to load base model")

        x_vector = tts.create_voice_clone_prompt(
            ref_audio=(ref_audio_array, ref_sr),
            ref_text=ref_text,
        )

        cache_manager = await VoiceCacheManager.get_instance()
        ref_audio_hash = cache_manager.get_audio_hash(ref_audio_bytes)
        features = extract_audio_features(ref_audio_array, ref_sr)
        metadata = {
            'duration': features['duration'],
            'sample_rate': features['sample_rate'],
            'ref_text': ref_text,
            'instruct': data.instruct
        }
        cache_id = await cache_manager.set_cache(
            current_user.id, ref_audio_hash, x_vector, metadata, db
        )

        design = crud.create_voice_design(
            db=db,
            user_id=current_user.id,
            name=data.name,
            instruct=data.instruct,
            meta_data=data.meta_data,
            preview_text=data.preview_text,
            voice_cache_id=cache_id,
            ref_audio_path=str(ref_audio_path),
            ref_text=ref_text,
        )
        # From here on a stored row references the file; it must be kept.
        design_saved = True

        logger.info(f"Voice design created with clone prompt: design_id={design.id}, cache_id={cache_id}")
        return to_voice_design_response(design)

    except Exception as e:
        logger.error(f"Failed to prepare and create voice design: {e}", exc_info=True)
        # Remove the reference audio when nothing persisted points at it, so
        # failed attempts do not accumulate orphaned wav files.
        if ref_audio_path is not None and not design_saved:
            try:
                ref_audio_path.unlink()
            except FileNotFoundError:
                pass  # The write itself failed before the file was created.
            except OSError as cleanup_error:
                logger.warning(
                    "Could not remove orphaned reference audio %s: %s",
                    ref_audio_path,
                    cleanup_error,
                )
        raise HTTPException(status_code=500, detail="Failed to prepare voice design")
|
|
|
|
|
|
@router.delete("/{design_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_voice_design(
    design_id: int,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """Delete one of the authenticated user's voice designs.

    Args:
        design_id: Identifier of the design to delete.
        current_user: Authenticated user; its ``id`` is passed to the CRUD call.
        db: Database session handed to the CRUD layer.

    Returns:
        ``None``; the route is declared with a 204 status code.

    Raises:
        HTTPException: 404 when ``crud.delete_voice_design`` returns a falsy
            result.
    """
    was_removed = crud.delete_voice_design(db, design_id, current_user.id)
    if was_removed:
        return
    raise HTTPException(status_code=404, detail="Voice design not found")
|
|
|
|
|
|
@router.post("/{design_id}/prepare-clone")
@limiter.limit("10/minute")
async def prepare_voice_clone_prompt(
    request: Request,
    design_id: int,
    background_tasks: BackgroundTasks,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """Create and cache a voice clone prompt for an existing voice design.

    When the design already has a ``voice_cache_id`` the existing id is
    returned and nothing is generated. Otherwise reference audio is generated
    on the "local" backend from the design's instruction, written as a wav
    file under ``settings.OUTPUT_DIR``, turned into a clone prompt with the
    "base" model, cached through ``VoiceCacheManager``, and the design is
    updated with the cache id, audio path and reference text.

    Args:
        request: Incoming request; not read in the body (presumably consumed
            by the rate-limit decorator).
        design_id: Identifier of the design to prepare.
        background_tasks: Accepted as part of the route signature; not used
            in the body.
        current_user: Authenticated user; must own the design and pass
            ``can_user_use_local_model``.
        db: Database session used for lookup, caching and the final commit.

    Returns:
        A dict with ``message`` and ``cache_id``; ``ref_text`` is included
        when a new prompt was prepared.

    Raises:
        HTTPException: 404 when the design is not found for this user; 403
            when the user may not use the local model; 500 when any
            generation, caching or persistence step fails.
    """
    from core.tts_service import TTSServiceFactory
    from core.cache_manager import VoiceCacheManager
    from utils.audio import process_ref_audio, extract_audio_features
    from core.config import settings
    from db.crud import can_user_use_local_model

    design = crud.get_voice_design(db, design_id, current_user.id)
    if not design:
        raise HTTPException(status_code=404, detail="Voice design not found")

    if not can_user_use_local_model(current_user):
        raise HTTPException(
            status_code=403,
            detail="Local model access required"
        )

    if design.voice_cache_id:
        return {
            "message": "Voice clone prompt already exists",
            "cache_id": design.voice_cache_id
        }

    # Tracked outside the try block so the failure path knows whether a file
    # was written and whether the committed design already points at it.
    ref_audio_path = None
    design_updated = False

    try:
        backend = await TTSServiceFactory.get_backend("local")

        # Fall back to the first 100 characters of the instruction when the
        # design has no preview text.
        ref_text = design.preview_text or design.instruct[:100]

        logger.info(f"Generating reference audio for voice design {design_id}")
        ref_audio_bytes, _ = await backend.generate_voice_design({
            "text": ref_text,
            "language": "Auto",
            "instruct": design.instruct,
            "max_new_tokens": 2048,
            "temperature": 0.3,
            "top_k": 10,
            "top_p": 0.5,
            "repetition_penalty": 1.05
        })

        # The timestamp only has one-second resolution; the random suffix keeps
        # two requests for the same design from writing to the same file.
        timestamp = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S")
        ref_filename = f"voice_design_{design_id}_{timestamp}_{uuid.uuid4().hex[:8]}.wav"
        ref_audio_path = Path(settings.OUTPUT_DIR) / ref_filename

        with open(ref_audio_path, 'wb') as f:
            f.write(ref_audio_bytes)

        logger.info("Extracting voice clone prompt from reference audio")
        ref_audio_array, ref_sr = process_ref_audio(ref_audio_bytes)

        from core.model_manager import ModelManager
        model_manager = await ModelManager.get_instance()
        await model_manager.load_model("base")
        _, tts = await model_manager.get_current_model()

        if tts is None:
            raise RuntimeError("Failed to load base model")

        x_vector = tts.create_voice_clone_prompt(
            ref_audio=(ref_audio_array, ref_sr),
            ref_text=ref_text,
        )

        cache_manager = await VoiceCacheManager.get_instance()
        ref_audio_hash = cache_manager.get_audio_hash(ref_audio_bytes)

        features = extract_audio_features(ref_audio_array, ref_sr)
        metadata = {
            'duration': features['duration'],
            'sample_rate': features['sample_rate'],
            'ref_text': ref_text,
            'voice_design_id': design_id,
            'instruct': design.instruct
        }

        cache_id = await cache_manager.set_cache(
            current_user.id, ref_audio_hash, x_vector, metadata, db
        )

        design.voice_cache_id = cache_id
        design.ref_audio_path = str(ref_audio_path)
        design.ref_text = ref_text
        db.commit()
        # From here on the committed design references the file; keep it.
        design_updated = True

        logger.info(f"Voice clone prompt prepared for design {design_id}, cache_id={cache_id}")

        return {
            "message": "Voice clone prompt prepared successfully",
            "cache_id": cache_id,
            "ref_text": ref_text
        }

    except Exception as e:
        logger.error(f"Failed to prepare voice clone prompt: {e}", exc_info=True)
        # Remove the reference audio when the design was never updated to
        # point at it, so failed attempts do not accumulate orphaned wav files.
        if ref_audio_path is not None and not design_updated:
            try:
                ref_audio_path.unlink()
            except FileNotFoundError:
                pass  # The write itself failed before the file was created.
            except OSError as cleanup_error:
                logger.warning(
                    "Could not remove orphaned reference audio %s: %s",
                    ref_audio_path,
                    cleanup_error,
                )
        raise HTTPException(status_code=500, detail="Failed to prepare voice clone prompt")
|