0
qwen3-tts-backend/api/__init__.py
Normal file
0
qwen3-tts-backend/api/__init__.py
Normal file
107
qwen3-tts-backend/api/auth.py
Normal file
107
qwen3-tts-backend/api/auth.py
Normal file
@@ -0,0 +1,107 @@
|
||||
from datetime import timedelta
|
||||
from typing import Annotated
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Request
|
||||
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
|
||||
from sqlalchemy.orm import Session
|
||||
from slowapi import Limiter
|
||||
from slowapi.util import get_remote_address
|
||||
|
||||
from config import settings
|
||||
from core.security import (
|
||||
get_password_hash,
|
||||
verify_password,
|
||||
create_access_token,
|
||||
decode_access_token
|
||||
)
|
||||
from db.database import get_db
|
||||
from db.crud import get_user_by_username, get_user_by_email, create_user
|
||||
from schemas.user import User, UserCreate, Token
|
||||
|
||||
router = APIRouter(prefix="/auth", tags=["authentication"])
|
||||
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/auth/token")
|
||||
|
||||
limiter = Limiter(key_func=get_remote_address)
|
||||
|
||||
async def get_current_user(
|
||||
token: Annotated[str, Depends(oauth2_scheme)],
|
||||
db: Session = Depends(get_db)
|
||||
) -> User:
|
||||
credentials_exception = HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Could not validate credentials",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
username = decode_access_token(token)
|
||||
if username is None:
|
||||
raise credentials_exception
|
||||
|
||||
user = get_user_by_username(db, username=username)
|
||||
if user is None:
|
||||
raise credentials_exception
|
||||
|
||||
return user
|
||||
|
||||
@router.post("/register", response_model=User, status_code=status.HTTP_201_CREATED)
|
||||
@limiter.limit("5/minute")
|
||||
async def register(request: Request, user_data: UserCreate, db: Session = Depends(get_db)):
|
||||
existing_user = get_user_by_username(db, username=user_data.username)
|
||||
if existing_user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Username already registered"
|
||||
)
|
||||
|
||||
existing_email = get_user_by_email(db, email=user_data.email)
|
||||
if existing_email:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Email already registered"
|
||||
)
|
||||
|
||||
hashed_password = get_password_hash(user_data.password)
|
||||
user = create_user(
|
||||
db,
|
||||
username=user_data.username,
|
||||
email=user_data.email,
|
||||
hashed_password=hashed_password
|
||||
)
|
||||
|
||||
return user
|
||||
|
||||
@router.post("/token", response_model=Token)
|
||||
@limiter.limit("5/minute")
|
||||
async def login(
|
||||
request: Request,
|
||||
form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
user = get_user_by_username(db, username=form_data.username)
|
||||
if not user or not verify_password(form_data.password, user.hashed_password):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Incorrect username or password",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
if not user.is_active:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Inactive user"
|
||||
)
|
||||
|
||||
access_token_expires = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
|
||||
access_token = create_access_token(
|
||||
data={"sub": user.username}, expires_delta=access_token_expires
|
||||
)
|
||||
|
||||
return {"access_token": access_token, "token_type": "bearer"}
|
||||
|
||||
@router.get("/me", response_model=User)
|
||||
@limiter.limit("30/minute")
|
||||
async def get_current_user_info(
|
||||
request: Request,
|
||||
current_user: Annotated[User, Depends(get_current_user)]
|
||||
):
|
||||
return current_user
|
||||
156
qwen3-tts-backend/api/cache.py
Normal file
156
qwen3-tts-backend/api/cache.py
Normal file
@@ -0,0 +1,156 @@
|
||||
import logging
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from sqlalchemy.orm import Session
|
||||
from slowapi import Limiter
|
||||
from slowapi.util import get_remote_address
|
||||
|
||||
from core.config import settings
|
||||
from core.database import get_db
|
||||
from core.cache_manager import VoiceCacheManager
|
||||
from api.auth import get_current_user
|
||||
from db.crud import list_cache_entries, delete_cache_entry
|
||||
from db.models import VoiceCache, User
|
||||
from utils.metrics import cache_metrics
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter(prefix="/cache", tags=["cache"])
|
||||
|
||||
limiter = Limiter(key_func=get_remote_address)
|
||||
|
||||
|
||||
@router.get("/voices")
|
||||
@limiter.limit("30/minute")
|
||||
async def list_user_caches(
|
||||
request: Request,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
caches = list_cache_entries(db, current_user.id, skip=skip, limit=limit)
|
||||
|
||||
result = []
|
||||
for cache in caches:
|
||||
meta_data = json.loads(cache.meta_data) if cache.meta_data else {}
|
||||
cache_file = Path(cache.cache_path)
|
||||
file_size_mb = cache_file.stat().st_size / (1024 * 1024) if cache_file.exists() else 0
|
||||
|
||||
result.append({
|
||||
'id': cache.id,
|
||||
'ref_audio_hash': cache.ref_audio_hash,
|
||||
'created_at': cache.created_at.isoformat(),
|
||||
'last_accessed': cache.last_accessed.isoformat(),
|
||||
'access_count': cache.access_count,
|
||||
'metadata': meta_data,
|
||||
'size_mb': round(file_size_mb, 2)
|
||||
})
|
||||
|
||||
return {
|
||||
'caches': result,
|
||||
'total': len(result)
|
||||
}
|
||||
|
||||
|
||||
@router.delete("/voices/{cache_id}")
|
||||
@limiter.limit("30/minute")
|
||||
async def delete_user_cache(
|
||||
request: Request,
|
||||
cache_id: int,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
cache = db.query(VoiceCache).filter(
|
||||
VoiceCache.id == cache_id,
|
||||
VoiceCache.user_id == current_user.id
|
||||
).first()
|
||||
|
||||
if not cache:
|
||||
raise HTTPException(status_code=404, detail="Cache not found")
|
||||
|
||||
cache_file = Path(cache.cache_path)
|
||||
if cache_file.exists():
|
||||
cache_file.unlink()
|
||||
|
||||
success = delete_cache_entry(db, cache_id, current_user.id)
|
||||
|
||||
if not success:
|
||||
raise HTTPException(status_code=500, detail="Failed to delete cache")
|
||||
|
||||
logger.info(f"Cache deleted: id={cache_id}, user={current_user.id}")
|
||||
|
||||
return {
|
||||
'message': 'Cache deleted successfully',
|
||||
'cache_id': cache_id
|
||||
}
|
||||
|
||||
|
||||
@router.delete("/voices")
|
||||
@limiter.limit("10/minute")
|
||||
async def cleanup_expired_caches(
|
||||
request: Request,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
cache_manager = await VoiceCacheManager.get_instance()
|
||||
deleted_count = await cache_manager.cleanup_expired(db)
|
||||
|
||||
logger.info(f"Expired cache cleanup: user={current_user.id}, deleted={deleted_count}")
|
||||
|
||||
return {
|
||||
'message': 'Expired caches cleaned up',
|
||||
'deleted_count': deleted_count
|
||||
}
|
||||
|
||||
|
||||
@router.post("/voices/prune")
|
||||
@limiter.limit("10/minute")
|
||||
async def prune_caches(
|
||||
request: Request,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
cache_manager = await VoiceCacheManager.get_instance()
|
||||
deleted_count = await cache_manager.enforce_max_entries(current_user.id, db)
|
||||
|
||||
logger.info(f"LRU prune: user={current_user.id}, deleted={deleted_count}")
|
||||
|
||||
return {
|
||||
'message': 'LRU pruning completed',
|
||||
'deleted_count': deleted_count
|
||||
}
|
||||
|
||||
|
||||
@router.get("/stats")
|
||||
@limiter.limit("30/minute")
|
||||
async def get_cache_stats(
|
||||
request: Request,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
stats = cache_metrics.get_stats(db, settings.CACHE_DIR)
|
||||
|
||||
user_stats = None
|
||||
for user_stat in stats['users']:
|
||||
if user_stat['user_id'] == current_user.id:
|
||||
user_stats = user_stat
|
||||
break
|
||||
|
||||
if user_stats is None:
|
||||
user_cache_count = db.query(VoiceCache).filter(
|
||||
VoiceCache.user_id == current_user.id
|
||||
).count()
|
||||
user_stats = {
|
||||
'user_id': current_user.id,
|
||||
'hits': 0,
|
||||
'misses': 0,
|
||||
'hit_rate': 0.0,
|
||||
'cache_entries': user_cache_count
|
||||
}
|
||||
|
||||
return {
|
||||
'global': stats['global'],
|
||||
'user': user_stats
|
||||
}
|
||||
176
qwen3-tts-backend/api/jobs.py
Normal file
176
qwen3-tts-backend/api/jobs.py
Normal file
@@ -0,0 +1,176 @@
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Request
|
||||
from fastapi.responses import FileResponse
|
||||
from sqlalchemy.orm import Session
|
||||
from slowapi import Limiter
|
||||
from slowapi.util import get_remote_address
|
||||
|
||||
from core.database import get_db
|
||||
from core.config import settings
|
||||
from db.models import Job, JobStatus, User
|
||||
from api.auth import get_current_user
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter(prefix="/jobs", tags=["jobs"])
|
||||
|
||||
limiter = Limiter(key_func=get_remote_address)
|
||||
|
||||
|
||||
@router.get("/{job_id}")
|
||||
@limiter.limit("30/minute")
|
||||
async def get_job(
|
||||
request: Request,
|
||||
job_id: int,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
job = db.query(Job).filter(Job.id == job_id).first()
|
||||
|
||||
if not job:
|
||||
raise HTTPException(status_code=404, detail="Job not found")
|
||||
|
||||
if job.user_id != current_user.id:
|
||||
raise HTTPException(status_code=403, detail="Access denied")
|
||||
|
||||
download_url = None
|
||||
if job.status == JobStatus.COMPLETED and job.output_path:
|
||||
output_file = Path(job.output_path)
|
||||
if output_file.exists():
|
||||
download_url = f"{settings.BASE_URL}/jobs/{job.id}/download"
|
||||
|
||||
return {
|
||||
"id": job.id,
|
||||
"job_type": job.job_type,
|
||||
"status": job.status,
|
||||
"input_params": job.input_params,
|
||||
"output_path": job.output_path,
|
||||
"download_url": download_url,
|
||||
"error_message": job.error_message,
|
||||
"created_at": job.created_at.isoformat() + 'Z' if job.created_at else None,
|
||||
"started_at": job.started_at.isoformat() + 'Z' if job.started_at else None,
|
||||
"completed_at": job.completed_at.isoformat() + 'Z' if job.completed_at else None
|
||||
}
|
||||
|
||||
|
||||
@router.get("")
|
||||
@limiter.limit("30/minute")
|
||||
async def list_jobs(
|
||||
request: Request,
|
||||
skip: int = Query(0, ge=0),
|
||||
limit: int = Query(100, ge=1, le=100),
|
||||
status: Optional[str] = Query(None),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
query = db.query(Job).filter(Job.user_id == current_user.id)
|
||||
|
||||
if status:
|
||||
try:
|
||||
status_enum = JobStatus(status)
|
||||
query = query.filter(Job.status == status_enum)
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=400, detail=f"Invalid status: {status}")
|
||||
|
||||
total = query.count()
|
||||
jobs = query.order_by(Job.created_at.desc()).offset(skip).limit(limit).all()
|
||||
|
||||
jobs_data = []
|
||||
for job in jobs:
|
||||
download_url = None
|
||||
if job.status == JobStatus.COMPLETED and job.output_path:
|
||||
output_file = Path(job.output_path)
|
||||
if output_file.exists():
|
||||
download_url = f"{settings.BASE_URL}/jobs/{job.id}/download"
|
||||
|
||||
jobs_data.append({
|
||||
"id": job.id,
|
||||
"job_type": job.job_type,
|
||||
"status": job.status,
|
||||
"input_params": job.input_params,
|
||||
"output_path": job.output_path,
|
||||
"download_url": download_url,
|
||||
"error_message": job.error_message,
|
||||
"created_at": job.created_at.isoformat() + 'Z' if job.created_at else None,
|
||||
"completed_at": job.completed_at.isoformat() + 'Z' if job.completed_at else None
|
||||
})
|
||||
|
||||
return {
|
||||
"total": total,
|
||||
"skip": skip,
|
||||
"limit": limit,
|
||||
"jobs": jobs_data
|
||||
}
|
||||
|
||||
|
||||
@router.delete("/{job_id}")
|
||||
@limiter.limit("30/minute")
|
||||
async def delete_job(
|
||||
request: Request,
|
||||
job_id: int,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
job = db.query(Job).filter(Job.id == job_id).first()
|
||||
|
||||
if not job:
|
||||
raise HTTPException(status_code=404, detail="Job not found")
|
||||
|
||||
if job.user_id != current_user.id:
|
||||
raise HTTPException(status_code=403, detail="Access denied")
|
||||
|
||||
if job.output_path:
|
||||
output_file = Path(job.output_path)
|
||||
if output_file.exists():
|
||||
try:
|
||||
output_file.unlink()
|
||||
logger.info(f"Deleted output file: {output_file}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete output file {output_file}: {e}")
|
||||
|
||||
db.delete(job)
|
||||
db.commit()
|
||||
|
||||
return {"message": "Job deleted successfully"}
|
||||
|
||||
|
||||
@router.get("/{job_id}/download")
|
||||
@limiter.limit("30/minute")
|
||||
async def download_job_output(
|
||||
request: Request,
|
||||
job_id: int,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
job = db.query(Job).filter(Job.id == job_id).first()
|
||||
|
||||
if not job:
|
||||
raise HTTPException(status_code=404, detail="Job not found")
|
||||
|
||||
if job.user_id != current_user.id:
|
||||
raise HTTPException(status_code=403, detail="Access denied")
|
||||
|
||||
if job.status != JobStatus.COMPLETED:
|
||||
raise HTTPException(status_code=400, detail="Job not completed yet")
|
||||
|
||||
if not job.output_path:
|
||||
raise HTTPException(status_code=404, detail="Output file not found")
|
||||
|
||||
output_file = Path(job.output_path)
|
||||
if not output_file.exists():
|
||||
raise HTTPException(status_code=404, detail="Output file does not exist")
|
||||
|
||||
output_dir = Path(settings.OUTPUT_DIR).resolve()
|
||||
if not output_file.resolve().is_relative_to(output_dir):
|
||||
logger.warning(f"Path traversal attempt detected: {output_file}")
|
||||
raise HTTPException(status_code=403, detail="Access denied")
|
||||
|
||||
return FileResponse(
|
||||
path=str(output_file),
|
||||
media_type="audio/wav",
|
||||
filename=output_file.name,
|
||||
headers={
|
||||
"Content-Disposition": f'attachment; filename="{output_file.name}"'
|
||||
}
|
||||
)
|
||||
21
qwen3-tts-backend/api/metrics.py
Normal file
21
qwen3-tts-backend/api/metrics.py
Normal file
@@ -0,0 +1,21 @@
|
||||
import logging
|
||||
from fastapi import APIRouter
|
||||
|
||||
from core.metrics import MetricsCollector
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter(prefix="/metrics", tags=["metrics"])
|
||||
|
||||
|
||||
@router.get("")
|
||||
async def get_metrics():
|
||||
metrics = await MetricsCollector.get_instance()
|
||||
data = await metrics.get_metrics()
|
||||
return data
|
||||
|
||||
|
||||
@router.post("/reset")
|
||||
async def reset_metrics():
|
||||
metrics = await MetricsCollector.get_instance()
|
||||
await metrics.reset()
|
||||
return {"message": "Metrics reset successfully"}
|
||||
553
qwen3-tts-backend/api/tts.py
Normal file
553
qwen3-tts-backend/api/tts.py
Normal file
@@ -0,0 +1,553 @@
|
||||
import logging
|
||||
import tempfile
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from fastapi import APIRouter, Depends, HTTPException, BackgroundTasks, UploadFile, File, Form, Request
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import Optional
|
||||
from slowapi import Limiter
|
||||
from slowapi.util import get_remote_address
|
||||
|
||||
from core.config import settings
|
||||
from core.database import get_db
|
||||
from core.model_manager import ModelManager
|
||||
from core.cache_manager import VoiceCacheManager
|
||||
from db.models import Job, JobStatus, User
|
||||
from schemas.tts import CustomVoiceRequest, VoiceDesignRequest
|
||||
from api.auth import get_current_user
|
||||
from utils.validation import (
|
||||
validate_language,
|
||||
validate_speaker,
|
||||
validate_text_length,
|
||||
validate_generation_params,
|
||||
get_supported_languages,
|
||||
get_supported_speakers
|
||||
)
|
||||
from utils.audio import save_audio_file, validate_ref_audio, process_ref_audio, extract_audio_features
|
||||
from utils.metrics import cache_metrics
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter(prefix="/tts", tags=["tts"])
|
||||
|
||||
limiter = Limiter(key_func=get_remote_address)
|
||||
|
||||
|
||||
async def process_custom_voice_job(
|
||||
job_id: int,
|
||||
user_id: int,
|
||||
request_data: dict,
|
||||
db_url: str
|
||||
):
|
||||
from core.database import SessionLocal
|
||||
|
||||
db = SessionLocal()
|
||||
try:
|
||||
job = db.query(Job).filter(Job.id == job_id).first()
|
||||
if not job:
|
||||
logger.error(f"Job {job_id} not found")
|
||||
return
|
||||
|
||||
job.status = JobStatus.PROCESSING
|
||||
job.started_at = datetime.utcnow()
|
||||
db.commit()
|
||||
|
||||
logger.info(f"Processing custom-voice job {job_id}")
|
||||
|
||||
model_manager = await ModelManager.get_instance()
|
||||
await model_manager.load_model("custom-voice")
|
||||
_, tts = await model_manager.get_current_model()
|
||||
|
||||
if tts is None:
|
||||
raise RuntimeError("Failed to load custom-voice model")
|
||||
|
||||
result = tts.generate_custom_voice(
|
||||
text=request_data['text'],
|
||||
language=request_data['language'],
|
||||
speaker=request_data['speaker'],
|
||||
instruct=request_data.get('instruct', ''),
|
||||
max_new_tokens=request_data['max_new_tokens'],
|
||||
temperature=request_data['temperature'],
|
||||
top_k=request_data['top_k'],
|
||||
top_p=request_data['top_p'],
|
||||
repetition_penalty=request_data['repetition_penalty']
|
||||
)
|
||||
|
||||
import numpy as np
|
||||
if isinstance(result, tuple):
|
||||
audio_data = result[0]
|
||||
elif isinstance(result, list):
|
||||
audio_data = np.array(result)
|
||||
else:
|
||||
audio_data = result
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
timestamp = datetime.utcnow().strftime("%Y%m%d_%H%M%S")
|
||||
filename = f"{user_id}_{job_id}_{timestamp}.wav"
|
||||
output_path = Path(settings.OUTPUT_DIR) / filename
|
||||
|
||||
save_audio_file(audio_data, 24000, output_path)
|
||||
|
||||
job.status = JobStatus.COMPLETED
|
||||
job.output_path = str(output_path)
|
||||
job.completed_at = datetime.utcnow()
|
||||
db.commit()
|
||||
|
||||
logger.info(f"Job {job_id} completed successfully")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Job {job_id} failed: {e}", exc_info=True)
|
||||
job = db.query(Job).filter(Job.id == job_id).first()
|
||||
if job:
|
||||
job.status = JobStatus.FAILED
|
||||
job.error_message = str(e)
|
||||
job.completed_at = datetime.utcnow()
|
||||
db.commit()
|
||||
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
async def process_voice_design_job(
|
||||
job_id: int,
|
||||
user_id: int,
|
||||
request_data: dict,
|
||||
db_url: str
|
||||
):
|
||||
from core.database import SessionLocal
|
||||
|
||||
db = SessionLocal()
|
||||
try:
|
||||
job = db.query(Job).filter(Job.id == job_id).first()
|
||||
if not job:
|
||||
logger.error(f"Job {job_id} not found")
|
||||
return
|
||||
|
||||
job.status = JobStatus.PROCESSING
|
||||
job.started_at = datetime.utcnow()
|
||||
db.commit()
|
||||
|
||||
logger.info(f"Processing voice-design job {job_id}")
|
||||
|
||||
model_manager = await ModelManager.get_instance()
|
||||
await model_manager.load_model("voice-design")
|
||||
_, tts = await model_manager.get_current_model()
|
||||
|
||||
if tts is None:
|
||||
raise RuntimeError("Failed to load voice-design model")
|
||||
|
||||
result = tts.generate_voice_design(
|
||||
text=request_data['text'],
|
||||
language=request_data['language'],
|
||||
instruct=request_data['instruct'],
|
||||
max_new_tokens=request_data['max_new_tokens'],
|
||||
temperature=request_data['temperature'],
|
||||
top_k=request_data['top_k'],
|
||||
top_p=request_data['top_p'],
|
||||
repetition_penalty=request_data['repetition_penalty']
|
||||
)
|
||||
|
||||
import numpy as np
|
||||
if isinstance(result, tuple):
|
||||
audio_data = result[0]
|
||||
elif isinstance(result, list):
|
||||
audio_data = np.array(result)
|
||||
else:
|
||||
audio_data = result
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
timestamp = datetime.utcnow().strftime("%Y%m%d_%H%M%S")
|
||||
filename = f"{user_id}_{job_id}_{timestamp}.wav"
|
||||
output_path = Path(settings.OUTPUT_DIR) / filename
|
||||
|
||||
save_audio_file(audio_data, 24000, output_path)
|
||||
|
||||
job.status = JobStatus.COMPLETED
|
||||
job.output_path = str(output_path)
|
||||
job.completed_at = datetime.utcnow()
|
||||
db.commit()
|
||||
|
||||
logger.info(f"Job {job_id} completed successfully")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Job {job_id} failed: {e}", exc_info=True)
|
||||
job = db.query(Job).filter(Job.id == job_id).first()
|
||||
if job:
|
||||
job.status = JobStatus.FAILED
|
||||
job.error_message = str(e)
|
||||
job.completed_at = datetime.utcnow()
|
||||
db.commit()
|
||||
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
async def process_voice_clone_job(
|
||||
job_id: int,
|
||||
user_id: int,
|
||||
request_data: dict,
|
||||
ref_audio_path: str,
|
||||
db_url: str
|
||||
):
|
||||
from core.database import SessionLocal
|
||||
import numpy as np
|
||||
|
||||
db = SessionLocal()
|
||||
try:
|
||||
job = db.query(Job).filter(Job.id == job_id).first()
|
||||
if not job:
|
||||
logger.error(f"Job {job_id} not found")
|
||||
return
|
||||
|
||||
job.status = JobStatus.PROCESSING
|
||||
job.started_at = datetime.utcnow()
|
||||
db.commit()
|
||||
|
||||
logger.info(f"Processing voice-clone job {job_id}")
|
||||
|
||||
with open(ref_audio_path, 'rb') as f:
|
||||
ref_audio_data = f.read()
|
||||
|
||||
cache_manager = await VoiceCacheManager.get_instance()
|
||||
ref_audio_hash = cache_manager.get_audio_hash(ref_audio_data)
|
||||
|
||||
x_vector = None
|
||||
cache_id = None
|
||||
|
||||
if request_data.get('use_cache', True):
|
||||
cached = await cache_manager.get_cache(user_id, ref_audio_hash, db)
|
||||
if cached:
|
||||
x_vector = cached['data']
|
||||
cache_id = cached['cache_id']
|
||||
cache_metrics.record_hit(user_id)
|
||||
logger.info(f"Cache hit for job {job_id}, cache_id={cache_id}")
|
||||
|
||||
if x_vector is None:
|
||||
cache_metrics.record_miss(user_id)
|
||||
logger.info(f"Cache miss for job {job_id}, creating voice clone prompt")
|
||||
ref_audio_array, ref_sr = process_ref_audio(ref_audio_data)
|
||||
|
||||
model_manager = await ModelManager.get_instance()
|
||||
await model_manager.load_model("base")
|
||||
_, tts = await model_manager.get_current_model()
|
||||
|
||||
if tts is None:
|
||||
raise RuntimeError("Failed to load base model")
|
||||
|
||||
x_vector = tts.create_voice_clone_prompt(
|
||||
ref_audio=(ref_audio_array, ref_sr),
|
||||
ref_text=request_data.get('ref_text', ''),
|
||||
x_vector_only_mode=request_data.get('x_vector_only_mode', False)
|
||||
)
|
||||
|
||||
if request_data.get('use_cache', True):
|
||||
features = extract_audio_features(ref_audio_array, ref_sr)
|
||||
metadata = {
|
||||
'duration': features['duration'],
|
||||
'sample_rate': features['sample_rate'],
|
||||
'ref_text': request_data.get('ref_text', ''),
|
||||
'x_vector_only_mode': request_data.get('x_vector_only_mode', False)
|
||||
}
|
||||
cache_id = await cache_manager.set_cache(
|
||||
user_id, ref_audio_hash, x_vector, metadata, db
|
||||
)
|
||||
logger.info(f"Created cache for job {job_id}, cache_id={cache_id}")
|
||||
|
||||
if request_data.get('x_vector_only_mode', False):
|
||||
job.status = JobStatus.COMPLETED
|
||||
job.output_path = f"x_vector_cached_{cache_id}"
|
||||
job.completed_at = datetime.utcnow()
|
||||
db.commit()
|
||||
logger.info(f"Job {job_id} completed (x_vector_only_mode)")
|
||||
return
|
||||
|
||||
model_manager = await ModelManager.get_instance()
|
||||
await model_manager.load_model("base")
|
||||
_, tts = await model_manager.get_current_model()
|
||||
|
||||
if tts is None:
|
||||
raise RuntimeError("Failed to load base model")
|
||||
|
||||
wavs, sample_rate = tts.generate_voice_clone(
|
||||
text=request_data['text'],
|
||||
language=request_data['language'],
|
||||
voice_clone_prompt=x_vector,
|
||||
max_new_tokens=request_data['max_new_tokens'],
|
||||
temperature=request_data['temperature'],
|
||||
top_k=request_data['top_k'],
|
||||
top_p=request_data['top_p'],
|
||||
repetition_penalty=request_data['repetition_penalty']
|
||||
)
|
||||
|
||||
audio_data = wavs[0] if isinstance(wavs, list) else wavs
|
||||
|
||||
timestamp = datetime.utcnow().strftime("%Y%m%d_%H%M%S")
|
||||
filename = f"{user_id}_{job_id}_{timestamp}.wav"
|
||||
output_path = Path(settings.OUTPUT_DIR) / filename
|
||||
|
||||
save_audio_file(audio_data, sample_rate, output_path)
|
||||
|
||||
job.status = JobStatus.COMPLETED
|
||||
job.output_path = str(output_path)
|
||||
job.completed_at = datetime.utcnow()
|
||||
db.commit()
|
||||
|
||||
logger.info(f"Job {job_id} completed successfully")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Job {job_id} failed: {e}", exc_info=True)
|
||||
job = db.query(Job).filter(Job.id == job_id).first()
|
||||
if job:
|
||||
job.status = JobStatus.FAILED
|
||||
job.error_message = str(e)
|
||||
job.completed_at = datetime.utcnow()
|
||||
db.commit()
|
||||
|
||||
finally:
|
||||
if Path(ref_audio_path).exists():
|
||||
Path(ref_audio_path).unlink()
|
||||
db.close()
|
||||
|
||||
|
||||
@router.post("/custom-voice")
|
||||
@limiter.limit("10/minute")
|
||||
async def create_custom_voice_job(
|
||||
request: Request,
|
||||
req_data: CustomVoiceRequest,
|
||||
background_tasks: BackgroundTasks,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
try:
|
||||
validate_text_length(req_data.text)
|
||||
language = validate_language(req_data.language)
|
||||
speaker = validate_speaker(req_data.speaker)
|
||||
|
||||
params = validate_generation_params({
|
||||
'max_new_tokens': req_data.max_new_tokens,
|
||||
'temperature': req_data.temperature,
|
||||
'top_k': req_data.top_k,
|
||||
'top_p': req_data.top_p,
|
||||
'repetition_penalty': req_data.repetition_penalty
|
||||
})
|
||||
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
job = Job(
|
||||
user_id=current_user.id,
|
||||
job_type="custom-voice",
|
||||
status=JobStatus.PENDING,
|
||||
input_data="",
|
||||
input_params={
|
||||
"text": req_data.text,
|
||||
"language": language,
|
||||
"speaker": speaker,
|
||||
"instruct": req_data.instruct or "",
|
||||
**params
|
||||
}
|
||||
)
|
||||
db.add(job)
|
||||
db.commit()
|
||||
db.refresh(job)
|
||||
|
||||
request_data = {
|
||||
"text": req_data.text,
|
||||
"language": language,
|
||||
"speaker": speaker,
|
||||
"instruct": req_data.instruct or "",
|
||||
**params
|
||||
}
|
||||
|
||||
background_tasks.add_task(
|
||||
process_custom_voice_job,
|
||||
job.id,
|
||||
current_user.id,
|
||||
request_data,
|
||||
str(settings.DATABASE_URL)
|
||||
)
|
||||
|
||||
return {
|
||||
"job_id": job.id,
|
||||
"status": job.status,
|
||||
"message": "Job created successfully"
|
||||
}
|
||||
|
||||
|
||||
@router.post("/voice-design")
|
||||
@limiter.limit("10/minute")
|
||||
async def create_voice_design_job(
|
||||
request: Request,
|
||||
req_data: VoiceDesignRequest,
|
||||
background_tasks: BackgroundTasks,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
try:
|
||||
validate_text_length(req_data.text)
|
||||
language = validate_language(req_data.language)
|
||||
|
||||
if not req_data.instruct or not req_data.instruct.strip():
|
||||
raise ValueError("Instruct parameter is required for voice design")
|
||||
|
||||
params = validate_generation_params({
|
||||
'max_new_tokens': req_data.max_new_tokens,
|
||||
'temperature': req_data.temperature,
|
||||
'top_k': req_data.top_k,
|
||||
'top_p': req_data.top_p,
|
||||
'repetition_penalty': req_data.repetition_penalty
|
||||
})
|
||||
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
job = Job(
|
||||
user_id=current_user.id,
|
||||
job_type="voice-design",
|
||||
status=JobStatus.PENDING,
|
||||
input_data="",
|
||||
input_params={
|
||||
"text": req_data.text,
|
||||
"language": language,
|
||||
"instruct": req_data.instruct,
|
||||
**params
|
||||
}
|
||||
)
|
||||
db.add(job)
|
||||
db.commit()
|
||||
db.refresh(job)
|
||||
|
||||
request_data = {
|
||||
"text": req_data.text,
|
||||
"language": language,
|
||||
"instruct": req_data.instruct,
|
||||
**params
|
||||
}
|
||||
|
||||
background_tasks.add_task(
|
||||
process_voice_design_job,
|
||||
job.id,
|
||||
current_user.id,
|
||||
request_data,
|
||||
str(settings.DATABASE_URL)
|
||||
)
|
||||
|
||||
return {
|
||||
"job_id": job.id,
|
||||
"status": job.status,
|
||||
"message": "Job created successfully"
|
||||
}
|
||||
|
||||
|
||||
@router.post("/voice-clone")
|
||||
@limiter.limit("10/minute")
|
||||
async def create_voice_clone_job(
|
||||
request: Request,
|
||||
text: str = Form(...),
|
||||
language: str = Form(default="Auto"),
|
||||
ref_audio: UploadFile = File(...),
|
||||
ref_text: Optional[str] = Form(default=None),
|
||||
use_cache: bool = Form(default=True),
|
||||
x_vector_only_mode: bool = Form(default=False),
|
||||
max_new_tokens: Optional[int] = Form(default=2048),
|
||||
temperature: Optional[float] = Form(default=0.9),
|
||||
top_k: Optional[int] = Form(default=50),
|
||||
top_p: Optional[float] = Form(default=1.0),
|
||||
repetition_penalty: Optional[float] = Form(default=1.05),
|
||||
background_tasks: BackgroundTasks = None,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
try:
|
||||
validate_text_length(text)
|
||||
language = validate_language(language)
|
||||
|
||||
params = validate_generation_params({
|
||||
'max_new_tokens': max_new_tokens,
|
||||
'temperature': temperature,
|
||||
'top_k': top_k,
|
||||
'top_p': top_p,
|
||||
'repetition_penalty': repetition_penalty
|
||||
})
|
||||
|
||||
ref_audio_data = await ref_audio.read()
|
||||
|
||||
if not validate_ref_audio(ref_audio_data, max_size_mb=settings.MAX_AUDIO_SIZE_MB):
|
||||
raise ValueError("Invalid reference audio: must be 1-30s duration and ≤10MB")
|
||||
|
||||
cache_manager = await VoiceCacheManager.get_instance()
|
||||
ref_audio_hash = cache_manager.get_audio_hash(ref_audio_data)
|
||||
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
job = Job(
|
||||
user_id=current_user.id,
|
||||
job_type="voice-clone",
|
||||
status=JobStatus.PENDING,
|
||||
input_data="",
|
||||
input_params={
|
||||
"text": text,
|
||||
"language": language,
|
||||
"ref_text": ref_text or "",
|
||||
"ref_audio_hash": ref_audio_hash,
|
||||
"use_cache": use_cache,
|
||||
"x_vector_only_mode": x_vector_only_mode,
|
||||
**params
|
||||
}
|
||||
)
|
||||
db.add(job)
|
||||
db.commit()
|
||||
db.refresh(job)
|
||||
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as tmp_file:
|
||||
tmp_file.write(ref_audio_data)
|
||||
tmp_audio_path = tmp_file.name
|
||||
|
||||
request_data = {
|
||||
"text": text,
|
||||
"language": language,
|
||||
"ref_text": ref_text or "",
|
||||
"use_cache": use_cache,
|
||||
"x_vector_only_mode": x_vector_only_mode,
|
||||
**params
|
||||
}
|
||||
|
||||
background_tasks.add_task(
|
||||
process_voice_clone_job,
|
||||
job.id,
|
||||
current_user.id,
|
||||
request_data,
|
||||
tmp_audio_path,
|
||||
str(settings.DATABASE_URL)
|
||||
)
|
||||
|
||||
existing_cache = await cache_manager.get_cache(current_user.id, ref_audio_hash, db)
|
||||
cache_info = {"cache_id": existing_cache['cache_id']} if existing_cache else None
|
||||
|
||||
return {
|
||||
"job_id": job.id,
|
||||
"status": job.status,
|
||||
"message": "Job created successfully",
|
||||
"cache_info": cache_info
|
||||
}
|
||||
|
||||
|
||||
@router.get("/models")
|
||||
@limiter.limit("30/minute")
|
||||
async def list_models(request: Request):
|
||||
model_manager = await ModelManager.get_instance()
|
||||
return model_manager.get_model_info()
|
||||
|
||||
|
||||
@router.get("/speakers")
|
||||
@limiter.limit("30/minute")
|
||||
async def list_speakers(request: Request):
|
||||
return get_supported_speakers()
|
||||
|
||||
|
||||
@router.get("/languages")
|
||||
@limiter.limit("30/minute")
|
||||
async def list_languages(request: Request):
|
||||
return get_supported_languages()
|
||||
169
qwen3-tts-backend/api/users.py
Normal file
169
qwen3-tts-backend/api/users.py
Normal file
@@ -0,0 +1,169 @@
|
||||
from typing import Annotated
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Request
|
||||
from sqlalchemy.orm import Session
|
||||
from slowapi import Limiter
|
||||
from slowapi.util import get_remote_address
|
||||
|
||||
from api.auth import get_current_user
|
||||
from config import settings
|
||||
from core.security import get_password_hash
|
||||
from db.database import get_db
|
||||
from db.crud import (
|
||||
get_user_by_id,
|
||||
get_user_by_username,
|
||||
get_user_by_email,
|
||||
list_users,
|
||||
create_user_by_admin,
|
||||
update_user,
|
||||
delete_user
|
||||
)
|
||||
from schemas.user import User, UserCreateByAdmin, UserUpdate, UserListResponse
|
||||
|
||||
router = APIRouter(prefix="/users", tags=["users"])
|
||||
limiter = Limiter(key_func=get_remote_address)
|
||||
|
||||
async def require_superuser(
|
||||
current_user: Annotated[User, Depends(get_current_user)]
|
||||
) -> User:
|
||||
if not current_user.is_superuser:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Superuser access required"
|
||||
)
|
||||
return current_user
|
||||
|
||||
@router.get("", response_model=UserListResponse)
|
||||
@limiter.limit("30/minute")
|
||||
async def get_users(
|
||||
request: Request,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
db: Session = Depends(get_db),
|
||||
_: User = Depends(require_superuser)
|
||||
):
|
||||
users, total = list_users(db, skip=skip, limit=limit)
|
||||
return UserListResponse(users=users, total=total, skip=skip, limit=limit)
|
||||
|
||||
@router.post("", response_model=User, status_code=status.HTTP_201_CREATED)
|
||||
@limiter.limit("10/minute")
|
||||
async def create_user(
|
||||
request: Request,
|
||||
user_data: UserCreateByAdmin,
|
||||
db: Session = Depends(get_db),
|
||||
_: User = Depends(require_superuser)
|
||||
):
|
||||
existing_user = get_user_by_username(db, username=user_data.username)
|
||||
if existing_user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Username already registered"
|
||||
)
|
||||
|
||||
existing_email = get_user_by_email(db, email=user_data.email)
|
||||
if existing_email:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Email already registered"
|
||||
)
|
||||
|
||||
hashed_password = get_password_hash(user_data.password)
|
||||
user = create_user_by_admin(
|
||||
db,
|
||||
username=user_data.username,
|
||||
email=user_data.email,
|
||||
hashed_password=hashed_password,
|
||||
is_superuser=user_data.is_superuser
|
||||
)
|
||||
|
||||
return user
|
||||
|
||||
@router.get("/{user_id}", response_model=User)
|
||||
@limiter.limit("30/minute")
|
||||
async def get_user(
|
||||
request: Request,
|
||||
user_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
_: User = Depends(require_superuser)
|
||||
):
|
||||
user = get_user_by_id(db, user_id=user_id)
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="User not found"
|
||||
)
|
||||
return user
|
||||
|
||||
@router.put("/{user_id}", response_model=User)
|
||||
@limiter.limit("10/minute")
|
||||
async def update_user_info(
|
||||
request: Request,
|
||||
user_id: int,
|
||||
user_data: UserUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_superuser)
|
||||
):
|
||||
existing_user = get_user_by_id(db, user_id=user_id)
|
||||
if not existing_user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="User not found"
|
||||
)
|
||||
|
||||
if user_data.username is not None:
|
||||
username_exists = get_user_by_username(db, username=user_data.username)
|
||||
if username_exists and username_exists.id != user_id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Username already taken"
|
||||
)
|
||||
|
||||
if user_data.email is not None:
|
||||
email_exists = get_user_by_email(db, email=user_data.email)
|
||||
if email_exists and email_exists.id != user_id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Email already taken"
|
||||
)
|
||||
|
||||
hashed_password = None
|
||||
if user_data.password is not None:
|
||||
hashed_password = get_password_hash(user_data.password)
|
||||
|
||||
user = update_user(
|
||||
db,
|
||||
user_id=user_id,
|
||||
username=user_data.username,
|
||||
email=user_data.email,
|
||||
hashed_password=hashed_password,
|
||||
is_active=user_data.is_active,
|
||||
is_superuser=user_data.is_superuser
|
||||
)
|
||||
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="User not found"
|
||||
)
|
||||
|
||||
return user
|
||||
|
||||
@router.delete("/{user_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||
@limiter.limit("10/minute")
|
||||
async def delete_user_by_id(
|
||||
request: Request,
|
||||
user_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_superuser)
|
||||
):
|
||||
if user_id == current_user.id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Cannot delete yourself"
|
||||
)
|
||||
|
||||
success = delete_user(db, user_id=user_id)
|
||||
if not success:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="User not found"
|
||||
)
|
||||
Reference in New Issue
Block a user