# File: Canto/qwen3-tts-backend/api/tts.py (818 lines, 27 KiB, Python)

import logging
import tempfile
from datetime import datetime
from pathlib import Path
from fastapi import APIRouter, Depends, HTTPException, BackgroundTasks, UploadFile, File, Form, Request, status
from sqlalchemy.orm import Session
from typing import Optional
from slowapi import Limiter
from slowapi.util import get_remote_address
from core.config import settings
from core.database import get_db
from core.model_manager import ModelManager
from core.cache_manager import VoiceCacheManager
from db.models import Job, JobStatus, User
from schemas.tts import CustomVoiceRequest, VoiceDesignRequest, IndexTTS2FromDesignRequest
from api.auth import get_current_user
from utils.validation import (
validate_language,
validate_speaker,
validate_text_length,
validate_generation_params,
get_supported_languages,
get_supported_speakers
)
from utils.audio import save_audio_file, validate_ref_audio, process_ref_audio, extract_audio_features
from utils.metrics import cache_metrics
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Router for the endpoints defined in this file; every route is prefixed with /tts.
router = APIRouter(prefix="/tts", tags=["tts"])
# Rate limiter used by the @limiter.limit decorators below; keyed by get_remote_address.
limiter = Limiter(key_func=get_remote_address)
async def read_upload_with_size_limit(upload_file: UploadFile, max_size_bytes: int) -> bytes:
    """Read an upload in 1 MiB pieces, refusing payloads larger than the limit.

    Args:
        upload_file: Object exposing an awaitable ``read(size)`` method.
        max_size_bytes: Largest total number of bytes that is accepted.

    Returns:
        The full content of the upload.

    Raises:
        ValueError: As soon as the bytes read so far exceed ``max_size_bytes``.
    """
    read_size = 1024 * 1024
    buffer = bytearray()

    piece = await upload_file.read(read_size)
    while piece:
        # Check before storing so an oversized upload is never fully buffered.
        if len(buffer) + len(piece) > max_size_bytes:
            raise ValueError(f"Audio file exceeds {max_size_bytes // (1024 * 1024)}MB limit")
        buffer.extend(piece)
        piece = await upload_file.read(read_size)

    return bytes(buffer)
async def process_custom_voice_job(
    job_id: int,
    user_id: int,
    request_data: dict,
    backend_type: str,
    db_url: str
):
    """Run a custom-voice synthesis job in the background and persist its outcome.

    The job row is moved to PROCESSING, the selected backend produces the
    audio, the bytes are written under ``settings.OUTPUT_DIR`` and the row is
    moved to COMPLETED. Any exception raised while processing is logged and
    the row is marked FAILED with a generic error message.

    Args:
        job_id: Primary key of the ``Job`` row to process.
        user_id: Owner of the job; only used to build the output file name.
        request_data: Parameters passed unchanged to the backend's
            ``generate_custom_voice``.
        backend_type: Backend name handed to ``TTSServiceFactory.get_backend``.
            For ``"aliyun"`` the system-wide API key is loaded and decrypted.
        db_url: Not used by this function (a new ``SessionLocal`` session is
            opened instead); kept so existing call sites keep working.
    """
    from core.database import SessionLocal
    from core.tts_service import TTSServiceFactory
    from core.security import decrypt_api_key

    db = SessionLocal()
    try:
        job = db.query(Job).filter(Job.id == job_id).first()
        if not job:
            logger.error(f"Job {job_id} not found")
            return

        job.status = JobStatus.PROCESSING
        job.started_at = datetime.utcnow()
        db.commit()
        logger.info(f"Processing custom-voice job {job_id} with backend {backend_type}")

        user_api_key = None
        if backend_type == "aliyun":
            from db.crud import get_system_setting
            encrypted = get_system_setting(db, "aliyun_api_key")
            if encrypted:
                user_api_key = decrypt_api_key(encrypted)

        backend = await TTSServiceFactory.get_backend(backend_type, user_api_key)
        # The backend also returns a sample rate; it is not needed here.
        audio_bytes, _sample_rate = await backend.generate_custom_voice(request_data)

        timestamp = datetime.utcnow().strftime("%Y%m%d_%H%M%S")
        filename = f"{user_id}_{job_id}_{timestamp}.wav"
        output_dir = Path(settings.OUTPUT_DIR)
        # Make sure the target directory exists, as the IndexTTS2 worker does.
        output_dir.mkdir(parents=True, exist_ok=True)
        output_path = output_dir / filename
        output_path.write_bytes(audio_bytes)

        job.status = JobStatus.COMPLETED
        job.output_path = str(output_path)
        job.completed_at = datetime.utcnow()
        db.commit()
        logger.info(f"Job {job_id} completed successfully")
    except Exception as e:
        logger.error(f"Job {job_id} failed: {e}", exc_info=True)
        # Drop any half-finished transaction first; otherwise a failed
        # flush/commit leaves the session unusable and the FAILED status
        # below could never be stored.
        db.rollback()
        job = db.query(Job).filter(Job.id == job_id).first()
        if job:
            job.status = JobStatus.FAILED
            job.error_message = "Job processing failed"
            job.completed_at = datetime.utcnow()
            db.commit()
    finally:
        db.close()
async def process_voice_design_job(
    job_id: int,
    user_id: int,
    request_data: dict,
    backend_type: str,
    db_url: str,
    saved_voice_id: Optional[str] = None
):
    """Run a voice-design synthesis job in the background and persist its outcome.

    The job row is moved to PROCESSING, the selected backend produces the
    audio, the bytes are written under ``settings.OUTPUT_DIR`` and the row is
    moved to COMPLETED. Any exception raised while processing is logged and
    the row is marked FAILED with a generic error message.

    Args:
        job_id: Primary key of the ``Job`` row to process.
        user_id: Owner of the job; only used to build the output file name.
        request_data: Parameters passed unchanged to the backend's
            ``generate_voice_design``.
        backend_type: Backend name handed to ``TTSServiceFactory.get_backend``.
            For ``"aliyun"`` the system-wide API key is loaded and decrypted.
        db_url: Not used by this function (a new ``SessionLocal`` session is
            opened instead); kept so existing call sites keep working.
        saved_voice_id: Optional voice id; forwarded to the backend only when
            the backend is ``"aliyun"``.
    """
    from core.database import SessionLocal
    from core.tts_service import TTSServiceFactory
    from core.security import decrypt_api_key

    db = SessionLocal()
    try:
        job = db.query(Job).filter(Job.id == job_id).first()
        if not job:
            logger.error(f"Job {job_id} not found")
            return

        job.status = JobStatus.PROCESSING
        job.started_at = datetime.utcnow()
        db.commit()
        logger.info(f"Processing voice-design job {job_id} with backend {backend_type}")

        user_api_key = None
        if backend_type == "aliyun":
            from db.crud import get_system_setting
            encrypted = get_system_setting(db, "aliyun_api_key")
            if encrypted:
                user_api_key = decrypt_api_key(encrypted)

        backend = await TTSServiceFactory.get_backend(backend_type, user_api_key)
        # The backend also returns a sample rate; it is not needed here.
        if backend_type == "aliyun" and saved_voice_id:
            audio_bytes, _sample_rate = await backend.generate_voice_design(request_data, saved_voice_id)
        else:
            audio_bytes, _sample_rate = await backend.generate_voice_design(request_data)

        timestamp = datetime.utcnow().strftime("%Y%m%d_%H%M%S")
        filename = f"{user_id}_{job_id}_{timestamp}.wav"
        output_dir = Path(settings.OUTPUT_DIR)
        # Make sure the target directory exists, as the IndexTTS2 worker does.
        output_dir.mkdir(parents=True, exist_ok=True)
        output_path = output_dir / filename
        output_path.write_bytes(audio_bytes)

        job.status = JobStatus.COMPLETED
        job.output_path = str(output_path)
        job.completed_at = datetime.utcnow()
        db.commit()
        logger.info(f"Job {job_id} completed successfully")
    except Exception as e:
        logger.error(f"Job {job_id} failed: {e}", exc_info=True)
        # Drop any half-finished transaction first; otherwise a failed
        # flush/commit leaves the session unusable and the FAILED status
        # below could never be stored.
        db.rollback()
        job = db.query(Job).filter(Job.id == job_id).first()
        if job:
            job.status = JobStatus.FAILED
            job.error_message = "Job processing failed"
            job.completed_at = datetime.utcnow()
            db.commit()
    finally:
        db.close()
async def process_voice_clone_job(
    job_id: int,
    user_id: int,
    request_data: dict,
    ref_audio_path: str,
    backend_type: str,
    db_url: str,
    use_voice_design: bool = False
):
    """Run a voice-clone job in the background and persist its outcome.

    Two modes are handled:

    * ``x_vector_only_mode`` on the ``"local"`` backend: only the voice clone
      prompt (x-vector) is resolved or created (and optionally cached); the
      job completes with ``output_path`` set to ``x_vector_cached_<cache_id>``
      and no audio file is written.
    * Otherwise audio is generated by the backend, written under
      ``settings.OUTPUT_DIR`` and the job is marked COMPLETED.

    Any exception raised while processing is logged and the job is marked
    FAILED with a generic error message.

    Args:
        job_id: Primary key of the ``Job`` row to process.
        user_id: Owner of the job; used for cache lookups, metrics and the
            output file name.
        request_data: Job parameters. The keys read here are
            ``voice_design_id``, ``x_vector_only_mode``, ``use_cache`` and
            ``ref_text``; the whole dict is also forwarded to the backend.
        ref_audio_path: Path of the reference audio file. It is deleted at the
            end unless ``use_voice_design`` is true.
        backend_type: Backend name handed to ``TTSServiceFactory.get_backend``.
            For ``"aliyun"`` the system-wide API key is loaded and decrypted.
        db_url: Not used by this function (a new ``SessionLocal`` session is
            opened instead); kept so existing call sites keep working.
        use_voice_design: When true the reference audio belongs to a saved
            voice design and must not be removed.
    """
    from core.database import SessionLocal
    from core.tts_service import TTSServiceFactory
    from core.security import decrypt_api_key

    db = SessionLocal()
    try:
        job = db.query(Job).filter(Job.id == job_id).first()
        if not job:
            logger.error(f"Job {job_id} not found")
            return

        job.status = JobStatus.PROCESSING
        job.started_at = datetime.utcnow()
        db.commit()
        logger.info(f"Processing voice-clone job {job_id} with backend {backend_type}")

        user_api_key = None
        if backend_type == "aliyun":
            from db.crud import get_system_setting
            encrypted = get_system_setting(db, "aliyun_api_key")
            if encrypted:
                user_api_key = decrypt_api_key(encrypted)

        # Read the reference audio once; every later step reuses these bytes.
        with open(ref_audio_path, 'rb') as f:
            ref_audio_data = f.read()

        cache_manager = await VoiceCacheManager.get_instance()
        voice_design_id = request_data.get('voice_design_id')

        # Resolved once here and reused below instead of re-querying.
        design = None
        design_cache = None
        if voice_design_id:
            from db.crud import get_voice_design
            design = get_voice_design(db, voice_design_id, user_id)
            if not design or not design.voice_cache_id:
                raise RuntimeError(f"Voice design {voice_design_id} has no prepared clone prompt")
            design_cache = await cache_manager.get_cache_by_id(design.voice_cache_id, db)
            if not design_cache:
                raise RuntimeError(f"Cache {design.voice_cache_id} not found")
            ref_audio_hash = f"voice_design_{voice_design_id}"
            cache_metrics.record_hit(user_id)
            logger.info(f"Using voice design {voice_design_id}, cache_id={design.voice_cache_id}")
        else:
            ref_audio_hash = cache_manager.get_audio_hash(ref_audio_data)

        if request_data.get('x_vector_only_mode', False) and backend_type == "local":
            x_vector = None
            cache_id = None
            if voice_design_id:
                x_vector = design_cache['data']
                cache_id = design.voice_cache_id
            elif request_data.get('use_cache', True):
                cached = await cache_manager.get_cache(user_id, ref_audio_hash, db)
                if cached:
                    x_vector = cached['data']
                    cache_id = cached['cache_id']
                    cache_metrics.record_hit(user_id)
                    logger.info(f"Cache hit for job {job_id}, cache_id={cache_id}")

            if x_vector is None:
                cache_metrics.record_miss(user_id)
                logger.info(f"Cache miss for job {job_id}, creating voice clone prompt")
                ref_audio_array, ref_sr = process_ref_audio(ref_audio_data)
                model_manager = await ModelManager.get_instance()
                await model_manager.load_model("base")
                _, tts = await model_manager.get_current_model()
                if tts is None:
                    raise RuntimeError("Failed to load base model")
                x_vector = tts.create_voice_clone_prompt(
                    ref_audio=(ref_audio_array, ref_sr),
                    ref_text=request_data.get('ref_text', ''),
                    x_vector_only_mode=True
                )
                if request_data.get('use_cache', True):
                    features = extract_audio_features(ref_audio_array, ref_sr)
                    metadata = {
                        'duration': features['duration'],
                        'sample_rate': features['sample_rate'],
                        'ref_text': request_data.get('ref_text', ''),
                        'x_vector_only_mode': True
                    }
                    cache_id = await cache_manager.set_cache(
                        user_id, ref_audio_hash, x_vector, metadata, db
                    )
                    logger.info(f"Created cache for job {job_id}, cache_id={cache_id}")

            # NOTE(review): cache_id stays None when use_cache is false and no
            # voice design is used, so the stored marker then reads
            # "x_vector_cached_None" - confirm this is intended.
            job.status = JobStatus.COMPLETED
            job.output_path = f"x_vector_cached_{cache_id}"
            job.completed_at = datetime.utcnow()
            db.commit()
            logger.info(f"Job {job_id} completed (x_vector_only_mode)")
            return

        backend = await TTSServiceFactory.get_backend(backend_type, user_api_key)
        # The backend also returns a sample rate; it is not needed here.
        if voice_design_id and backend_type == "local":
            # Reuse the cache entry validated above rather than fetching it
            # again without a None check.
            x_vector = design_cache['data']
            audio_bytes, _sample_rate = await backend.generate_voice_clone(request_data, x_vector=x_vector)
            logger.info(f"Generated audio using cached x_vector from voice design {voice_design_id}")
        else:
            audio_bytes, _sample_rate = await backend.generate_voice_clone(request_data, ref_audio_data)

        timestamp = datetime.utcnow().strftime("%Y%m%d_%H%M%S")
        filename = f"{user_id}_{job_id}_{timestamp}.wav"
        output_dir = Path(settings.OUTPUT_DIR)
        # Make sure the target directory exists, as the IndexTTS2 worker does.
        output_dir.mkdir(parents=True, exist_ok=True)
        output_path = output_dir / filename
        output_path.write_bytes(audio_bytes)

        job.status = JobStatus.COMPLETED
        job.output_path = str(output_path)
        job.completed_at = datetime.utcnow()
        db.commit()
        logger.info(f"Job {job_id} completed successfully")
    except Exception as e:
        logger.error(f"Job {job_id} failed: {e}", exc_info=True)
        # Drop any half-finished transaction first; otherwise a failed
        # flush/commit leaves the session unusable and the FAILED status
        # below could never be stored.
        db.rollback()
        job = db.query(Job).filter(Job.id == job_id).first()
        if job:
            job.status = JobStatus.FAILED
            job.error_message = "Job processing failed"
            job.completed_at = datetime.utcnow()
            db.commit()
    finally:
        try:
            # Only uploaded (temporary) reference audio is removed; audio that
            # belongs to a saved voice design is left in place.
            if not use_voice_design and ref_audio_path and Path(ref_audio_path).exists():
                Path(ref_audio_path).unlink()
        except OSError:
            logger.warning(f"Could not remove reference audio {ref_audio_path}", exc_info=True)
        finally:
            # Always release the session, even if the cleanup above failed.
            db.close()
@router.post("/custom-voice")
@limiter.limit("10/minute")
async def create_custom_voice_job(
    request: Request,
    req_data: CustomVoiceRequest,
    background_tasks: BackgroundTasks,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """Create a custom-voice job and queue it for background processing.

    The backend is the one named in the request, falling back to the user's
    ``default_backend`` preference and finally to ``"aliyun"``.

    Args:
        request: Incoming request; not referenced in the body (presumably
            required by the rate-limit decorator).
        req_data: Validated request payload.
        background_tasks: Queue that receives ``process_custom_voice_job``.
        current_user: Authenticated user creating the job.
        db: Database session.

    Returns:
        A dict with ``job_id``, ``status`` and ``message``.

    Raises:
        HTTPException: 403 when the local backend is requested but not allowed
            for this user; 400 when the Aliyun API key is not configured or
            when text, language, speaker or generation parameters are invalid.
    """
    from db.crud import get_user_preferences, can_user_use_local_model

    user_prefs = get_user_preferences(db, current_user.id)
    preferred_backend = user_prefs.get("default_backend", "aliyun")
    can_use_local = can_user_use_local_model(current_user)
    backend_type = getattr(req_data, 'backend', None) or preferred_backend

    if backend_type == "local" and not can_use_local:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Local model is not available. Please contact administrator."
        )

    if backend_type == "aliyun":
        from db.crud import get_system_setting
        if not get_system_setting(db, "aliyun_api_key"):
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Aliyun API key not configured. Please contact administrator."
            )

    try:
        validate_text_length(req_data.text)
        language = validate_language(req_data.language)
        speaker = validate_speaker(req_data.speaker, backend_type)
        params = validate_generation_params({
            'max_new_tokens': req_data.max_new_tokens,
            'temperature': req_data.temperature,
            'top_k': req_data.top_k,
            'top_p': req_data.top_p,
            'repetition_penalty': req_data.repetition_penalty
        })
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))

    # Built once so the stored parameters and the ones handed to the worker
    # cannot drift apart.
    request_data = {
        "text": req_data.text,
        "language": language,
        "speaker": speaker,
        "instruct": req_data.instruct or "",
        **params
    }

    job = Job(
        user_id=current_user.id,
        job_type="custom-voice",
        status=JobStatus.PENDING,
        backend_type=backend_type,
        input_data="",
        # Separate copy: the row and the worker must not share one dict.
        input_params=dict(request_data)
    )
    db.add(job)
    db.commit()
    db.refresh(job)

    background_tasks.add_task(
        process_custom_voice_job,
        job.id,
        current_user.id,
        request_data,
        backend_type,
        str(settings.DATABASE_URL)
    )

    return {
        "job_id": job.id,
        "status": job.status,
        "message": "Job created successfully"
    }
@router.post("/voice-design")
@limiter.limit("10/minute")
async def create_voice_design_job(
    request: Request,
    req_data: VoiceDesignRequest,
    background_tasks: BackgroundTasks,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """Create a voice-design job and queue it for background processing.

    The backend is the one named in the request, falling back to the user's
    ``default_backend`` preference and finally to ``"aliyun"``. When
    ``saved_design_id`` is given, the saved design's instruct text replaces
    the one in the request and its ``aliyun_voice_id`` is passed to the worker.

    Args:
        request: Incoming request; not referenced in the body (presumably
            required by the rate-limit decorator).
        req_data: Validated request payload.
        background_tasks: Queue that receives ``process_voice_design_job``.
        current_user: Authenticated user creating the job.
        db: Database session.

    Returns:
        A dict with ``job_id``, ``status`` and ``message``.

    Raises:
        HTTPException: 404 when the saved design does not exist; 400 when its
            backend differs from the selected one, when the Aliyun API key is
            not configured, or when the input is invalid; 403 when the local
            backend is requested but not allowed for this user.
    """
    from db.crud import get_user_preferences, can_user_use_local_model, get_voice_design, update_voice_design_usage

    user_prefs = get_user_preferences(db, current_user.id)
    preferred_backend = user_prefs.get("default_backend", "aliyun")
    can_use_local = can_user_use_local_model(current_user)
    backend_type = getattr(req_data, 'backend', None) or preferred_backend

    # Kept in a local instead of writing back into the request model.
    instruct = req_data.instruct
    saved_voice_id = None
    if req_data.saved_design_id:
        saved_design = get_voice_design(db, req_data.saved_design_id, current_user.id)
        if not saved_design:
            raise HTTPException(status_code=404, detail="Saved voice design not found")
        if saved_design.backend_type != backend_type:
            raise HTTPException(
                status_code=400,
                detail=f"Saved design backend ({saved_design.backend_type}) doesn't match current backend ({backend_type})"
            )
        instruct = saved_design.instruct
        saved_voice_id = saved_design.aliyun_voice_id

    if backend_type == "local" and not can_use_local:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Local model is not available. Please contact administrator."
        )

    if backend_type == "aliyun":
        from db.crud import get_system_setting
        if not get_system_setting(db, "aliyun_api_key"):
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Aliyun API key not configured. Please contact administrator."
            )

    try:
        validate_text_length(req_data.text)
        language = validate_language(req_data.language)
        if not req_data.saved_design_id:
            if not instruct or not instruct.strip():
                raise ValueError("Instruct parameter is required when saved_design_id is not provided")
        params = validate_generation_params({
            'max_new_tokens': req_data.max_new_tokens,
            'temperature': req_data.temperature,
            'top_k': req_data.top_k,
            'top_p': req_data.top_p,
            'repetition_penalty': req_data.repetition_penalty
        })
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))

    # Record usage only once the request has passed every check, so rejected
    # requests (403/400) no longer count as a use of the saved design.
    if req_data.saved_design_id:
        update_voice_design_usage(db, req_data.saved_design_id, current_user.id)

    # Built once so the stored parameters and the ones handed to the worker
    # cannot drift apart.
    request_data = {
        "text": req_data.text,
        "language": language,
        "instruct": instruct,
        **params
    }

    job = Job(
        user_id=current_user.id,
        job_type="voice-design",
        status=JobStatus.PENDING,
        backend_type=backend_type,
        input_data="",
        # Separate copy: the row and the worker must not share one dict.
        input_params=dict(request_data)
    )
    db.add(job)
    db.commit()
    db.refresh(job)

    background_tasks.add_task(
        process_voice_design_job,
        job.id,
        current_user.id,
        request_data,
        backend_type,
        str(settings.DATABASE_URL),
        saved_voice_id
    )

    return {
        "job_id": job.id,
        "status": job.status,
        "message": "Job created successfully"
    }
@router.post("/voice-clone")
@limiter.limit("10/minute")
async def create_voice_clone_job(
    request: Request,
    text: str = Form(...),
    language: str = Form(default="Auto"),
    ref_audio: Optional[UploadFile] = File(default=None),
    ref_text: Optional[str] = Form(default=None),
    use_cache: bool = Form(default=True),
    x_vector_only_mode: bool = Form(default=False),
    voice_design_id: Optional[int] = Form(default=None),
    max_new_tokens: Optional[int] = Form(default=2048),
    temperature: Optional[float] = Form(default=0.9),
    top_k: Optional[int] = Form(default=50),
    top_p: Optional[float] = Form(default=1.0),
    repetition_penalty: Optional[float] = Form(default=1.05),
    backend: Optional[str] = Form(default=None),
    background_tasks: BackgroundTasks = None,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """Create a voice-clone job and queue it for background processing.

    The reference voice comes either from a saved voice design
    (``voice_design_id``) or from an uploaded audio file (``ref_audio``); one
    of the two is required. Uploaded audio is written to a temporary ``.wav``
    file whose path is handed to the worker.

    Args:
        request: Incoming request; not referenced in the body (presumably
            required by the rate-limit decorator).
        text: Text to synthesize.
        language: Language name, checked with ``validate_language``.
        ref_audio: Uploaded reference audio; required without
            ``voice_design_id``.
        ref_text: Transcript of the reference audio; defaults to the design's
            ``ref_text`` when a voice design is used.
        use_cache: Forwarded to the worker.
        x_vector_only_mode: Forwarded to the worker.
        voice_design_id: Saved voice design to clone from.
        max_new_tokens, temperature, top_k, top_p, repetition_penalty:
            Generation parameters, checked with ``validate_generation_params``.
        backend: Backend override; falls back to the user's preference.
        background_tasks: Queue that receives ``process_voice_clone_job``.
        current_user: Authenticated user creating the job.
        db: Database session.

    Returns:
        A dict with ``job_id``, ``status``, ``message`` and ``cache_info``
        (``{"cache_id": ...}`` when a cache entry already exists for the
        reference hash, otherwise ``None``).

    Raises:
        HTTPException: 403 when the local backend is requested but not allowed
            for this user; 400 when the Aliyun API key is not configured or
            when any input fails validation.
    """
    from db.crud import get_user_preferences, can_user_use_local_model, get_voice_design

    user_prefs = get_user_preferences(db, current_user.id)
    preferred_backend = user_prefs.get("default_backend", "aliyun")
    can_use_local = can_user_use_local_model(current_user)
    backend_type = backend if backend else preferred_backend

    if backend_type == "local" and not can_use_local:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Local model is not available. Please contact administrator."
        )

    if backend_type == "aliyun":
        from db.crud import get_system_setting
        if not get_system_setting(db, "aliyun_api_key"):
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Aliyun API key not configured. Please contact administrator."
            )

    ref_audio_data = None
    ref_audio_hash = None
    use_voice_design = False
    design_audio_path = None
    try:
        validate_text_length(text)
        language = validate_language(language)
        params = validate_generation_params({
            'max_new_tokens': max_new_tokens,
            'temperature': temperature,
            'top_k': top_k,
            'top_p': top_p,
            'repetition_penalty': repetition_penalty
        })

        cache_manager = await VoiceCacheManager.get_instance()

        if voice_design_id:
            design = get_voice_design(db, voice_design_id, current_user.id)
            if not design:
                raise ValueError("Voice design not found")
            if design.backend_type != backend_type:
                raise ValueError(f"Voice design backend ({design.backend_type}) doesn't match request backend ({backend_type})")
            if not design.voice_cache_id:
                raise ValueError("Voice design has no prepared clone prompt. Please call /voice-designs/{id}/prepare-clone first")
            use_voice_design = True
            # Captured now so the design does not have to be fetched again.
            design_audio_path = design.ref_audio_path
            ref_audio_hash = f"voice_design_{voice_design_id}"
            if not ref_text:
                ref_text = design.ref_text
            logger.info(f"Using voice design {voice_design_id} with cache_id={design.voice_cache_id}")
        else:
            if not ref_audio:
                raise ValueError("Either ref_audio or voice_design_id must be provided")
            max_audio_size_bytes = settings.MAX_AUDIO_SIZE_MB * 1024 * 1024
            ref_audio_data = await read_upload_with_size_limit(ref_audio, max_audio_size_bytes)
            if not validate_ref_audio(ref_audio_data, max_size_mb=settings.MAX_AUDIO_SIZE_MB):
                # Report the configured limit instead of a hard-coded 10MB.
                raise ValueError(
                    f"Invalid reference audio: must be 1-30s duration and ≤{settings.MAX_AUDIO_SIZE_MB}MB"
                )
            ref_audio_hash = cache_manager.get_audio_hash(ref_audio_data)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))

    # Look the cache up before anything is persisted: if this fails, no job
    # row is left PENDING and no temporary file is left behind.
    existing_cache = await cache_manager.get_cache(current_user.id, ref_audio_hash, db)
    cache_info = {"cache_id": existing_cache['cache_id']} if existing_cache else None

    request_data = {
        "text": text,
        "language": language,
        "ref_text": ref_text or "",
        "use_cache": use_cache,
        "x_vector_only_mode": x_vector_only_mode,
        "voice_design_id": voice_design_id,
        **params
    }

    job = Job(
        user_id=current_user.id,
        job_type="voice-clone",
        status=JobStatus.PENDING,
        backend_type=backend_type,
        input_data="",
        # Same parameters as the worker receives, plus the reference hash.
        input_params={**request_data, "ref_audio_hash": ref_audio_hash}
    )
    db.add(job)
    db.commit()
    db.refresh(job)

    if use_voice_design:
        tmp_audio_path = design_audio_path
    else:
        # delete=False: the worker removes the file once the job has finished.
        with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as tmp_file:
            tmp_file.write(ref_audio_data)
            tmp_audio_path = tmp_file.name

    background_tasks.add_task(
        process_voice_clone_job,
        job.id,
        current_user.id,
        request_data,
        tmp_audio_path,
        backend_type,
        str(settings.DATABASE_URL),
        use_voice_design
    )

    return {
        "job_id": job.id,
        "status": job.status,
        "message": "Job created successfully",
        "cache_info": cache_info
    }
async def process_indextts2_job(
    job_id: int,
    user_id: int,
    voice_design_id: int,
    text: str,
    emo_text: Optional[str],
    emo_alpha: float,
):
    """Run an IndexTTS2 synthesis job in the background and persist its outcome.

    The job row is moved to PROCESSING, the voice design's reference audio is
    handed to ``IndexTTS2Backend.generate`` together with the target output
    path, and the row is moved to COMPLETED. Any exception raised while
    processing is logged and the row is marked FAILED with a generic message.

    Args:
        job_id: Primary key of the ``Job`` row to process.
        user_id: Owner of the job; used for the design lookup and file name.
        voice_design_id: Voice design whose ``ref_audio_path`` is the speaker
            prompt.
        text: Text to synthesize.
        emo_text: Optional emotion text forwarded to the backend.
        emo_alpha: Emotion weight forwarded to the backend.
    """
    from core.database import SessionLocal
    from core.tts_service import IndexTTS2Backend
    from db.crud import get_voice_design

    db = SessionLocal()
    try:
        job = db.query(Job).filter(Job.id == job_id).first()
        if not job:
            # Logged like the other workers instead of returning silently.
            logger.error(f"Job {job_id} not found")
            return

        job.status = JobStatus.PROCESSING
        job.started_at = datetime.utcnow()
        db.commit()

        design = get_voice_design(db, voice_design_id, user_id)
        if not design or not design.ref_audio_path:
            raise RuntimeError("Voice design has no ref_audio_path")
        if not Path(design.ref_audio_path).exists():
            raise RuntimeError(f"ref_audio_path does not exist: {design.ref_audio_path}")

        timestamp = datetime.utcnow().strftime("%Y%m%d_%H%M%S")
        filename = f"{user_id}_{job_id}_{timestamp}.wav"
        output_dir = Path(settings.OUTPUT_DIR)
        output_dir.mkdir(parents=True, exist_ok=True)
        output_path = str(output_dir / filename)

        backend = IndexTTS2Backend()
        # The return value is not needed; only output_path is recorded below
        # (presumably the backend writes the audio there - confirm).
        await backend.generate(
            text=text,
            spk_audio_prompt=design.ref_audio_path,
            output_path=output_path,
            emo_text=emo_text,
            emo_alpha=emo_alpha,
        )

        job.status = JobStatus.COMPLETED
        job.output_path = output_path
        job.completed_at = datetime.utcnow()
        db.commit()
    except Exception as e:
        logger.error(f"IndexTTS2 job {job_id} failed: {e}", exc_info=True)
        # Drop any half-finished transaction first; otherwise a failed
        # flush/commit leaves the session unusable and the FAILED status
        # below could never be stored.
        db.rollback()
        job = db.query(Job).filter(Job.id == job_id).first()
        if job:
            job.status = JobStatus.FAILED
            job.error_message = "Job processing failed"
            job.completed_at = datetime.utcnow()
            db.commit()
    finally:
        db.close()
@router.post("/indextts2-from-design")
@limiter.limit("10/minute")
async def create_indextts2_from_design_job(
    request: Request,
    req_data: IndexTTS2FromDesignRequest,
    background_tasks: BackgroundTasks,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db),
):
    """Create an IndexTTS2 job from a saved voice design and queue it.

    Args:
        request: Incoming request; not referenced in the body (presumably
            required by the rate-limit decorator).
        req_data: Validated request payload (``voice_design_id``, ``text``,
            ``emo_text``, ``emo_alpha``).
        background_tasks: Queue that receives ``process_indextts2_job``.
        current_user: Authenticated user creating the job.
        db: Database session.

    Returns:
        A dict with ``job_id``, ``status`` and ``message``.

    Raises:
        HTTPException: 404 when the voice design does not exist; 400 when it
            has no reference audio, the audio file is missing on disk, or the
            text fails length validation.
    """
    from db.crud import get_voice_design

    # NOTE(review): the job is stored with backend_type="local", but unlike the
    # other endpoints here there is no can_user_use_local_model check - confirm
    # whether that is intended.
    design = get_voice_design(db, req_data.voice_design_id, current_user.id)
    if not design:
        raise HTTPException(status_code=404, detail="Voice design not found")
    if not design.ref_audio_path:
        raise HTTPException(status_code=400, detail="Voice design has no ref_audio_path")
    if not Path(design.ref_audio_path).exists():
        raise HTTPException(status_code=400, detail="ref_audio_path file does not exist")

    # Same text-length guard as the other synthesis endpoints in this module.
    try:
        validate_text_length(req_data.text)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))

    job = Job(
        user_id=current_user.id,
        job_type="indextts2",
        status=JobStatus.PENDING,
        backend_type="local",
        input_data="",
        input_params={
            "text": req_data.text,
            "voice_design_id": req_data.voice_design_id,
            "emo_text": req_data.emo_text,
            "emo_alpha": req_data.emo_alpha,
        }
    )
    db.add(job)
    db.commit()
    db.refresh(job)

    background_tasks.add_task(
        process_indextts2_job,
        job.id,
        current_user.id,
        req_data.voice_design_id,
        req_data.text,
        req_data.emo_text,
        req_data.emo_alpha,
    )

    return {
        "job_id": job.id,
        "status": job.status,
        "message": "Job created successfully"
    }
@router.get("/speakers")
@limiter.limit("30/minute")
async def list_speakers(request: Request, backend: Optional[str] = "local"):
    """Return whatever ``get_supported_speakers`` reports for ``backend``.

    ``backend`` defaults to ``"local"``; ``request`` is not referenced in the
    body.
    """
    speakers = get_supported_speakers(backend)
    return speakers
@router.get("/languages")
@limiter.limit("30/minute")
async def list_languages(request: Request):
    """Return whatever ``get_supported_languages`` reports.

    ``request`` is not referenced in the body.
    """
    languages = get_supported_languages()
    return languages