init commit

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
This commit is contained in:
2026-01-26 15:34:31 +08:00
commit 80513a3258
141 changed files with 24966 additions and 0 deletions

View File

View File

@@ -0,0 +1,107 @@
from datetime import timedelta
from typing import Annotated
from fastapi import APIRouter, Depends, HTTPException, status, Request
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from sqlalchemy.orm import Session
from slowapi import Limiter
from slowapi.util import get_remote_address
from config import settings
from core.security import (
get_password_hash,
verify_password,
create_access_token,
decode_access_token
)
from db.database import get_db
from db.crud import get_user_by_username, get_user_by_email, create_user
from schemas.user import User, UserCreate, Token
router = APIRouter(prefix="/auth", tags=["authentication"])
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/auth/token")
limiter = Limiter(key_func=get_remote_address)
async def get_current_user(
    token: Annotated[str, Depends(oauth2_scheme)],
    db: Session = Depends(get_db)
) -> User:
    """Resolve the bearer token to a User record.

    Raises HTTP 401 (with a WWW-Authenticate header) when the token cannot
    be decoded or the encoded username no longer exists in the database.
    """
    def _unauthorized() -> HTTPException:
        return HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Could not validate credentials",
            headers={"WWW-Authenticate": "Bearer"},
        )

    username = decode_access_token(token)
    if username is None:
        raise _unauthorized()
    user = get_user_by_username(db, username=username)
    if user is None:
        raise _unauthorized()
    return user
@router.post("/register", response_model=User, status_code=status.HTTP_201_CREATED)
@limiter.limit("5/minute")
async def register(request: Request, user_data: UserCreate, db: Session = Depends(get_db)):
    """Create a new account after verifying username and email are unused.

    NOTE(review): the uniqueness checks are check-then-create and can race
    under concurrent registrations — confirm a DB unique constraint backs them.
    """
    if get_user_by_username(db, username=user_data.username):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Username already registered"
        )
    if get_user_by_email(db, email=user_data.email):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Email already registered"
        )
    return create_user(
        db,
        username=user_data.username,
        email=user_data.email,
        hashed_password=get_password_hash(user_data.password)
    )
@router.post("/token", response_model=Token)
@limiter.limit("5/minute")
async def login(
    request: Request,
    form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
    db: Session = Depends(get_db)
):
    """Exchange username/password credentials for a bearer access token."""
    user = get_user_by_username(db, username=form_data.username)
    # A single 401 for both unknown user and wrong password avoids leaking
    # which usernames exist.
    credentials_ok = user is not None and verify_password(
        form_data.password, user.hashed_password
    )
    if not credentials_ok:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect username or password",
            headers={"WWW-Authenticate": "Bearer"},
        )
    if not user.is_active:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Inactive user"
        )
    token = create_access_token(
        data={"sub": user.username},
        expires_delta=timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES),
    )
    return {"access_token": token, "token_type": "bearer"}
@router.get("/me", response_model=User)
@limiter.limit("30/minute")
async def get_current_user_info(
    request: Request,
    current_user: Annotated[User, Depends(get_current_user)]
):
    """Return the profile of the authenticated caller."""
    return current_user

View File

@@ -0,0 +1,156 @@
import logging
import json
from pathlib import Path
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Request
from sqlalchemy.orm import Session
from slowapi import Limiter
from slowapi.util import get_remote_address
from core.config import settings
from core.database import get_db
from core.cache_manager import VoiceCacheManager
from api.auth import get_current_user
from db.crud import list_cache_entries, delete_cache_entry
from db.models import VoiceCache, User
from utils.metrics import cache_metrics
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/cache", tags=["cache"])
limiter = Limiter(key_func=get_remote_address)
@router.get("/voices")
@limiter.limit("30/minute")
async def list_user_caches(
    request: Request,
    skip: int = 0,
    limit: int = 100,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """List the caller's voice-cache entries with metadata and on-disk size.

    Fix: corrupt metadata JSON or a cache file that disappears between the
    DB read and stat() no longer aborts the whole request with a 500 — the
    affected entry is reported with empty metadata / zero size instead.
    """
    caches = list_cache_entries(db, current_user.id, skip=skip, limit=limit)
    result = []
    for cache in caches:
        try:
            meta_data = json.loads(cache.meta_data) if cache.meta_data else {}
        except (TypeError, ValueError):
            logger.warning(f"Corrupt cache metadata: id={cache.id}")
            meta_data = {}
        try:
            # stat() directly instead of exists()+stat() to avoid a race
            # if the file is removed between the two calls.
            file_size_mb = Path(cache.cache_path).stat().st_size / (1024 * 1024)
        except OSError:
            file_size_mb = 0
        result.append({
            'id': cache.id,
            'ref_audio_hash': cache.ref_audio_hash,
            'created_at': cache.created_at.isoformat(),
            'last_accessed': cache.last_accessed.isoformat(),
            'access_count': cache.access_count,
            'metadata': meta_data,
            'size_mb': round(file_size_mb, 2)
        })
    return {
        'caches': result,
        'total': len(result)
    }
@router.delete("/voices/{cache_id}")
@limiter.limit("30/minute")
async def delete_user_cache(
    request: Request,
    cache_id: int,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """Delete one of the caller's cache entries and its backing file.

    Fix: the DB row is removed before the file. Previously the file was
    unlinked first, so a failed DB delete left a row pointing at a file
    that no longer exists; with this ordering, a failed unlink only
    leaves a harmless orphan file on disk.
    """
    cache = db.query(VoiceCache).filter(
        VoiceCache.id == cache_id,
        VoiceCache.user_id == current_user.id
    ).first()
    if not cache:
        raise HTTPException(status_code=404, detail="Cache not found")
    # Capture the path before the ORM row is deleted.
    cache_path = cache.cache_path
    success = delete_cache_entry(db, cache_id, current_user.id)
    if not success:
        raise HTTPException(status_code=500, detail="Failed to delete cache")
    try:
        Path(cache_path).unlink(missing_ok=True)
    except OSError as e:
        # Best-effort: the DB entry is already gone, so report success.
        logger.error(f"Failed to remove cache file {cache_path}: {e}")
    logger.info(f"Cache deleted: id={cache_id}, user={current_user.id}")
    return {
        'message': 'Cache deleted successfully',
        'cache_id': cache_id
    }
@router.delete("/voices")
@limiter.limit("10/minute")
async def cleanup_expired_caches(
    request: Request,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """Remove expired voice-cache entries; returns how many were deleted."""
    manager = await VoiceCacheManager.get_instance()
    deleted_count = await manager.cleanup_expired(db)
    logger.info(f"Expired cache cleanup: user={current_user.id}, deleted={deleted_count}")
    return {
        'message': 'Expired caches cleaned up',
        'deleted_count': deleted_count
    }
@router.post("/voices/prune")
@limiter.limit("10/minute")
async def prune_caches(
    request: Request,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """LRU-prune the caller's cache entries down to the configured maximum."""
    manager = await VoiceCacheManager.get_instance()
    deleted_count = await manager.enforce_max_entries(current_user.id, db)
    logger.info(f"LRU prune: user={current_user.id}, deleted={deleted_count}")
    return {
        'message': 'LRU pruning completed',
        'deleted_count': deleted_count
    }
@router.get("/stats")
@limiter.limit("30/minute")
async def get_cache_stats(
    request: Request,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """Return global cache metrics plus the caller's per-user slice.

    When the metrics collector has no record for this user yet, a zeroed
    stats entry is synthesized with the user's live cache-entry count.
    """
    stats = cache_metrics.get_stats(db, settings.CACHE_DIR)
    user_stats = next(
        (u for u in stats['users'] if u['user_id'] == current_user.id),
        None
    )
    if user_stats is None:
        entry_count = db.query(VoiceCache).filter(
            VoiceCache.user_id == current_user.id
        ).count()
        user_stats = {
            'user_id': current_user.id,
            'hits': 0,
            'misses': 0,
            'hit_rate': 0.0,
            'cache_entries': entry_count
        }
    return {
        'global': stats['global'],
        'user': user_stats
    }

View File

@@ -0,0 +1,176 @@
import logging
from pathlib import Path
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Query, Request
from fastapi.responses import FileResponse
from sqlalchemy.orm import Session
from slowapi import Limiter
from slowapi.util import get_remote_address
from core.database import get_db
from core.config import settings
from db.models import Job, JobStatus, User
from api.auth import get_current_user
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/jobs", tags=["jobs"])
limiter = Limiter(key_func=get_remote_address)
@router.get("/{job_id}")
@limiter.limit("30/minute")
async def get_job(
    request: Request,
    job_id: int,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """Return full status details for one of the caller's jobs."""
    job = db.query(Job).filter(Job.id == job_id).first()
    if not job:
        raise HTTPException(status_code=404, detail="Job not found")
    if job.user_id != current_user.id:
        raise HTTPException(status_code=403, detail="Access denied")

    def _iso(dt):
        # Timestamps are written via datetime.utcnow() (naive UTC), so the
        # 'Z' suffix makes the timezone explicit for clients.
        return dt.isoformat() + 'Z' if dt else None

    download_url = None
    if (job.status == JobStatus.COMPLETED and job.output_path
            and Path(job.output_path).exists()):
        download_url = f"{settings.BASE_URL}/jobs/{job.id}/download"
    return {
        "id": job.id,
        "job_type": job.job_type,
        "status": job.status,
        "input_params": job.input_params,
        "output_path": job.output_path,
        "download_url": download_url,
        "error_message": job.error_message,
        "created_at": _iso(job.created_at),
        "started_at": _iso(job.started_at),
        "completed_at": _iso(job.completed_at)
    }
@router.get("")
@limiter.limit("30/minute")
async def list_jobs(
    request: Request,
    skip: int = Query(0, ge=0),
    limit: int = Query(100, ge=1, le=100),
    status: Optional[str] = Query(None),
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """Page through the caller's jobs, newest first, optionally filtered by status."""
    query = db.query(Job).filter(Job.user_id == current_user.id)
    if status:
        try:
            query = query.filter(Job.status == JobStatus(status))
        except ValueError:
            raise HTTPException(status_code=400, detail=f"Invalid status: {status}")
    total = query.count()
    jobs = query.order_by(Job.created_at.desc()).offset(skip).limit(limit).all()

    def _serialize(job):
        url = None
        if (job.status == JobStatus.COMPLETED and job.output_path
                and Path(job.output_path).exists()):
            url = f"{settings.BASE_URL}/jobs/{job.id}/download"
        return {
            "id": job.id,
            "job_type": job.job_type,
            "status": job.status,
            "input_params": job.input_params,
            "output_path": job.output_path,
            "download_url": url,
            "error_message": job.error_message,
            "created_at": job.created_at.isoformat() + 'Z' if job.created_at else None,
            "completed_at": job.completed_at.isoformat() + 'Z' if job.completed_at else None
        }

    return {
        "total": total,
        "skip": skip,
        "limit": limit,
        "jobs": [_serialize(job) for job in jobs]
    }
@router.delete("/{job_id}")
@limiter.limit("30/minute")
async def delete_job(
    request: Request,
    job_id: int,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """Delete one of the caller's jobs and, best-effort, its output file."""
    job = db.query(Job).filter(Job.id == job_id).first()
    if not job:
        raise HTTPException(status_code=404, detail="Job not found")
    if job.user_id != current_user.id:
        raise HTTPException(status_code=403, detail="Access denied")
    output_file = Path(job.output_path) if job.output_path else None
    if output_file is not None and output_file.exists():
        try:
            output_file.unlink()
            logger.info(f"Deleted output file: {output_file}")
        except Exception as e:
            # The DB row is removed regardless; an undeletable file is logged.
            logger.error(f"Failed to delete output file {output_file}: {e}")
    db.delete(job)
    db.commit()
    return {"message": "Job deleted successfully"}
@router.get("/{job_id}/download")
@limiter.limit("30/minute")
async def download_job_output(
    request: Request,
    job_id: int,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """Stream a completed job's output WAV to its owner.

    Fix: the OUTPUT_DIR containment check now runs BEFORE the existence
    check, so a stored path outside the output directory is rejected with
    403 without first revealing (via a 404/"does not exist" distinction)
    whether a file exists at that location.
    """
    job = db.query(Job).filter(Job.id == job_id).first()
    if not job:
        raise HTTPException(status_code=404, detail="Job not found")
    if job.user_id != current_user.id:
        raise HTTPException(status_code=403, detail="Access denied")
    if job.status != JobStatus.COMPLETED:
        raise HTTPException(status_code=400, detail="Job not completed yet")
    if not job.output_path:
        raise HTTPException(status_code=404, detail="Output file not found")
    output_file = Path(job.output_path)
    # Defense in depth: only ever serve files inside the configured output dir.
    output_dir = Path(settings.OUTPUT_DIR).resolve()
    if not output_file.resolve().is_relative_to(output_dir):
        logger.warning(f"Path traversal attempt detected: {output_file}")
        raise HTTPException(status_code=403, detail="Access denied")
    if not output_file.exists():
        raise HTTPException(status_code=404, detail="Output file does not exist")
    return FileResponse(
        path=str(output_file),
        media_type="audio/wav",
        filename=output_file.name,
        headers={
            "Content-Disposition": f'attachment; filename="{output_file.name}"'
        }
    )

View File

@@ -0,0 +1,21 @@
import logging
from fastapi import APIRouter
from core.metrics import MetricsCollector
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/metrics", tags=["metrics"])
@router.get("")
async def get_metrics():
    """Return the current application metrics snapshot."""
    collector = await MetricsCollector.get_instance()
    return await collector.get_metrics()
@router.post("/reset")
async def reset_metrics():
    """Reset all collected metrics to their initial state.

    NOTE(review): unlike the rest of the API, this destructive endpoint has
    no authentication and no rate limit — confirm it is only exposed on an
    internal network or add a superuser dependency.
    """
    collector = await MetricsCollector.get_instance()
    await collector.reset()
    return {"message": "Metrics reset successfully"}

View File

@@ -0,0 +1,553 @@
import logging
import tempfile
from datetime import datetime
from pathlib import Path
from fastapi import APIRouter, Depends, HTTPException, BackgroundTasks, UploadFile, File, Form, Request
from sqlalchemy.orm import Session
from typing import Optional
from slowapi import Limiter
from slowapi.util import get_remote_address
from core.config import settings
from core.database import get_db
from core.model_manager import ModelManager
from core.cache_manager import VoiceCacheManager
from db.models import Job, JobStatus, User
from schemas.tts import CustomVoiceRequest, VoiceDesignRequest
from api.auth import get_current_user
from utils.validation import (
validate_language,
validate_speaker,
validate_text_length,
validate_generation_params,
get_supported_languages,
get_supported_speakers
)
from utils.audio import save_audio_file, validate_ref_audio, process_ref_audio, extract_audio_features
from utils.metrics import cache_metrics
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/tts", tags=["tts"])
limiter = Limiter(key_func=get_remote_address)
async def process_custom_voice_job(
    job_id: int,
    user_id: int,
    request_data: dict,
    db_url: str
):
    """Background task: run custom-voice TTS for a job and persist the result.

    Opens its own DB session (background tasks must not share the request
    session), marks the job PROCESSING, generates audio, writes the WAV to
    OUTPUT_DIR and marks the job COMPLETED; any failure marks it FAILED
    with the error message.

    Fix: the except path now calls db.rollback() before re-querying the job
    — if the failure came from the DB layer, the session is in an aborted
    state and cannot be used until rolled back. Also drops the redundant
    local `from pathlib import Path` (Path is imported at module level).
    """
    from core.database import SessionLocal
    import numpy as np
    db = SessionLocal()
    try:
        job = db.query(Job).filter(Job.id == job_id).first()
        if not job:
            logger.error(f"Job {job_id} not found")
            return
        job.status = JobStatus.PROCESSING
        job.started_at = datetime.utcnow()
        db.commit()
        logger.info(f"Processing custom-voice job {job_id}")
        model_manager = await ModelManager.get_instance()
        await model_manager.load_model("custom-voice")
        _, tts = await model_manager.get_current_model()
        if tts is None:
            raise RuntimeError("Failed to load custom-voice model")
        result = tts.generate_custom_voice(
            text=request_data['text'],
            language=request_data['language'],
            speaker=request_data['speaker'],
            instruct=request_data.get('instruct', ''),
            max_new_tokens=request_data['max_new_tokens'],
            temperature=request_data['temperature'],
            top_k=request_data['top_k'],
            top_p=request_data['top_p'],
            repetition_penalty=request_data['repetition_penalty']
        )
        # Normalize the model output to a single array of samples.
        if isinstance(result, tuple):
            audio_data = result[0]
        elif isinstance(result, list):
            audio_data = np.array(result)
        else:
            audio_data = result
        timestamp = datetime.utcnow().strftime("%Y%m%d_%H%M%S")
        output_path = Path(settings.OUTPUT_DIR) / f"{user_id}_{job_id}_{timestamp}.wav"
        # 24000: presumably the model's output sample rate — confirm.
        save_audio_file(audio_data, 24000, output_path)
        job.status = JobStatus.COMPLETED
        job.output_path = str(output_path)
        job.completed_at = datetime.utcnow()
        db.commit()
        logger.info(f"Job {job_id} completed successfully")
    except Exception as e:
        logger.error(f"Job {job_id} failed: {e}", exc_info=True)
        # Clear any aborted transaction before touching the session again.
        db.rollback()
        job = db.query(Job).filter(Job.id == job_id).first()
        if job:
            job.status = JobStatus.FAILED
            job.error_message = str(e)
            job.completed_at = datetime.utcnow()
            db.commit()
    finally:
        db.close()
async def process_voice_design_job(
    job_id: int,
    user_id: int,
    request_data: dict,
    db_url: str
):
    """Background task: run voice-design TTS for a job and persist the result.

    Opens its own DB session, marks the job PROCESSING, generates audio from
    the design instruction, writes the WAV to OUTPUT_DIR and marks the job
    COMPLETED; any failure marks it FAILED with the error message.

    Fix: the except path now calls db.rollback() before re-querying the job
    — a DB-layer failure leaves the session aborted and unusable until
    rolled back. Also drops the redundant local `from pathlib import Path`.
    """
    from core.database import SessionLocal
    import numpy as np
    db = SessionLocal()
    try:
        job = db.query(Job).filter(Job.id == job_id).first()
        if not job:
            logger.error(f"Job {job_id} not found")
            return
        job.status = JobStatus.PROCESSING
        job.started_at = datetime.utcnow()
        db.commit()
        logger.info(f"Processing voice-design job {job_id}")
        model_manager = await ModelManager.get_instance()
        await model_manager.load_model("voice-design")
        _, tts = await model_manager.get_current_model()
        if tts is None:
            raise RuntimeError("Failed to load voice-design model")
        result = tts.generate_voice_design(
            text=request_data['text'],
            language=request_data['language'],
            instruct=request_data['instruct'],
            max_new_tokens=request_data['max_new_tokens'],
            temperature=request_data['temperature'],
            top_k=request_data['top_k'],
            top_p=request_data['top_p'],
            repetition_penalty=request_data['repetition_penalty']
        )
        # Normalize the model output to a single array of samples.
        if isinstance(result, tuple):
            audio_data = result[0]
        elif isinstance(result, list):
            audio_data = np.array(result)
        else:
            audio_data = result
        timestamp = datetime.utcnow().strftime("%Y%m%d_%H%M%S")
        output_path = Path(settings.OUTPUT_DIR) / f"{user_id}_{job_id}_{timestamp}.wav"
        # 24000: presumably the model's output sample rate — confirm.
        save_audio_file(audio_data, 24000, output_path)
        job.status = JobStatus.COMPLETED
        job.output_path = str(output_path)
        job.completed_at = datetime.utcnow()
        db.commit()
        logger.info(f"Job {job_id} completed successfully")
    except Exception as e:
        logger.error(f"Job {job_id} failed: {e}", exc_info=True)
        # Clear any aborted transaction before touching the session again.
        db.rollback()
        job = db.query(Job).filter(Job.id == job_id).first()
        if job:
            job.status = JobStatus.FAILED
            job.error_message = str(e)
            job.completed_at = datetime.utcnow()
            db.commit()
    finally:
        db.close()
async def process_voice_clone_job(
    job_id: int,
    user_id: int,
    request_data: dict,
    ref_audio_path: str,
    db_url: str
):
    """Background task: voice-clone TTS with per-user x-vector caching.

    Reads the uploaded reference audio from *ref_audio_path* (a temp file
    owned by this task and always removed in the finally block). Reuses a
    cached voice-clone prompt when one exists for (user, audio hash);
    otherwise computes one and, if caching is enabled, stores it. When the
    request is x_vector_only_mode, the job completes after caching without
    synthesizing audio.

    Fixes: the except path calls db.rollback() before re-querying (a DB
    failure leaves the session aborted); the unused local numpy import is
    removed; temp-file cleanup is best-effort so a cleanup error cannot
    mask the job outcome.
    """
    from core.database import SessionLocal
    db = SessionLocal()
    try:
        job = db.query(Job).filter(Job.id == job_id).first()
        if not job:
            logger.error(f"Job {job_id} not found")
            return
        job.status = JobStatus.PROCESSING
        job.started_at = datetime.utcnow()
        db.commit()
        logger.info(f"Processing voice-clone job {job_id}")
        with open(ref_audio_path, 'rb') as f:
            ref_audio_data = f.read()
        cache_manager = await VoiceCacheManager.get_instance()
        ref_audio_hash = cache_manager.get_audio_hash(ref_audio_data)
        x_vector = None
        cache_id = None
        if request_data.get('use_cache', True):
            cached = await cache_manager.get_cache(user_id, ref_audio_hash, db)
            if cached:
                x_vector = cached['data']
                cache_id = cached['cache_id']
                cache_metrics.record_hit(user_id)
                logger.info(f"Cache hit for job {job_id}, cache_id={cache_id}")
        if x_vector is None:
            cache_metrics.record_miss(user_id)
            logger.info(f"Cache miss for job {job_id}, creating voice clone prompt")
            ref_audio_array, ref_sr = process_ref_audio(ref_audio_data)
            model_manager = await ModelManager.get_instance()
            await model_manager.load_model("base")
            _, tts = await model_manager.get_current_model()
            if tts is None:
                raise RuntimeError("Failed to load base model")
            x_vector = tts.create_voice_clone_prompt(
                ref_audio=(ref_audio_array, ref_sr),
                ref_text=request_data.get('ref_text', ''),
                x_vector_only_mode=request_data.get('x_vector_only_mode', False)
            )
            if request_data.get('use_cache', True):
                features = extract_audio_features(ref_audio_array, ref_sr)
                metadata = {
                    'duration': features['duration'],
                    'sample_rate': features['sample_rate'],
                    'ref_text': request_data.get('ref_text', ''),
                    'x_vector_only_mode': request_data.get('x_vector_only_mode', False)
                }
                cache_id = await cache_manager.set_cache(
                    user_id, ref_audio_hash, x_vector, metadata, db
                )
                logger.info(f"Created cache for job {job_id}, cache_id={cache_id}")
        if request_data.get('x_vector_only_mode', False):
            # No synthesis requested: record where the prompt was cached.
            job.status = JobStatus.COMPLETED
            job.output_path = f"x_vector_cached_{cache_id}"
            job.completed_at = datetime.utcnow()
            db.commit()
            logger.info(f"Job {job_id} completed (x_vector_only_mode)")
            return
        model_manager = await ModelManager.get_instance()
        await model_manager.load_model("base")
        _, tts = await model_manager.get_current_model()
        if tts is None:
            raise RuntimeError("Failed to load base model")
        wavs, sample_rate = tts.generate_voice_clone(
            text=request_data['text'],
            language=request_data['language'],
            voice_clone_prompt=x_vector,
            max_new_tokens=request_data['max_new_tokens'],
            temperature=request_data['temperature'],
            top_k=request_data['top_k'],
            top_p=request_data['top_p'],
            repetition_penalty=request_data['repetition_penalty']
        )
        audio_data = wavs[0] if isinstance(wavs, list) else wavs
        timestamp = datetime.utcnow().strftime("%Y%m%d_%H%M%S")
        output_path = Path(settings.OUTPUT_DIR) / f"{user_id}_{job_id}_{timestamp}.wav"
        save_audio_file(audio_data, sample_rate, output_path)
        job.status = JobStatus.COMPLETED
        job.output_path = str(output_path)
        job.completed_at = datetime.utcnow()
        db.commit()
        logger.info(f"Job {job_id} completed successfully")
    except Exception as e:
        logger.error(f"Job {job_id} failed: {e}", exc_info=True)
        # Clear any aborted transaction before touching the session again.
        db.rollback()
        job = db.query(Job).filter(Job.id == job_id).first()
        if job:
            job.status = JobStatus.FAILED
            job.error_message = str(e)
            job.completed_at = datetime.utcnow()
            db.commit()
    finally:
        # Best-effort temp-file cleanup; never mask the job outcome.
        try:
            Path(ref_audio_path).unlink(missing_ok=True)
        except OSError:
            logger.warning(f"Could not remove temp ref audio: {ref_audio_path}")
        db.close()
@router.post("/custom-voice")
@limiter.limit("10/minute")
async def create_custom_voice_job(
    request: Request,
    req_data: CustomVoiceRequest,
    background_tasks: BackgroundTasks,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """Validate a custom-voice request, enqueue it, and return the job id."""
    try:
        validate_text_length(req_data.text)
        language = validate_language(req_data.language)
        speaker = validate_speaker(req_data.speaker)
        params = validate_generation_params({
            'max_new_tokens': req_data.max_new_tokens,
            'temperature': req_data.temperature,
            'top_k': req_data.top_k,
            'top_p': req_data.top_p,
            'repetition_penalty': req_data.repetition_penalty
        })
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
    # One payload serves both the persisted job record and the worker.
    request_data = {
        "text": req_data.text,
        "language": language,
        "speaker": speaker,
        "instruct": req_data.instruct or "",
        **params
    }
    job = Job(
        user_id=current_user.id,
        job_type="custom-voice",
        status=JobStatus.PENDING,
        input_data="",
        input_params=dict(request_data)
    )
    db.add(job)
    db.commit()
    db.refresh(job)
    background_tasks.add_task(
        process_custom_voice_job,
        job.id,
        current_user.id,
        request_data,
        str(settings.DATABASE_URL)
    )
    return {
        "job_id": job.id,
        "status": job.status,
        "message": "Job created successfully"
    }
@router.post("/voice-design")
@limiter.limit("10/minute")
async def create_voice_design_job(
    request: Request,
    req_data: VoiceDesignRequest,
    background_tasks: BackgroundTasks,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """Validate a voice-design request, enqueue it, and return the job id."""
    try:
        validate_text_length(req_data.text)
        language = validate_language(req_data.language)
        if not req_data.instruct or not req_data.instruct.strip():
            raise ValueError("Instruct parameter is required for voice design")
        params = validate_generation_params({
            'max_new_tokens': req_data.max_new_tokens,
            'temperature': req_data.temperature,
            'top_k': req_data.top_k,
            'top_p': req_data.top_p,
            'repetition_penalty': req_data.repetition_penalty
        })
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
    # One payload serves both the persisted job record and the worker.
    request_data = {
        "text": req_data.text,
        "language": language,
        "instruct": req_data.instruct,
        **params
    }
    job = Job(
        user_id=current_user.id,
        job_type="voice-design",
        status=JobStatus.PENDING,
        input_data="",
        input_params=dict(request_data)
    )
    db.add(job)
    db.commit()
    db.refresh(job)
    background_tasks.add_task(
        process_voice_design_job,
        job.id,
        current_user.id,
        request_data,
        str(settings.DATABASE_URL)
    )
    return {
        "job_id": job.id,
        "status": job.status,
        "message": "Job created successfully"
    }
@router.post("/voice-clone")
@limiter.limit("10/minute")
async def create_voice_clone_job(
    request: Request,
    text: str = Form(...),
    language: str = Form(default="Auto"),
    ref_audio: UploadFile = File(...),
    ref_text: Optional[str] = Form(default=None),
    use_cache: bool = Form(default=True),
    x_vector_only_mode: bool = Form(default=False),
    max_new_tokens: Optional[int] = Form(default=2048),
    temperature: Optional[float] = Form(default=0.9),
    top_k: Optional[int] = Form(default=50),
    top_p: Optional[float] = Form(default=1.0),
    repetition_penalty: Optional[float] = Form(default=1.05),
    background_tasks: BackgroundTasks = None,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """Validate a voice-clone request, stage the reference audio, and enqueue it.

    Returns the job id plus cache_info when a cached x-vector already exists
    for this (user, reference-audio) pair.

    Fix: the existing-cache lookup now happens BEFORE the reference audio is
    written to a temp file. Previously the lookup ran after scheduling, so an
    exception there leaked the temp file (the background task that deletes it
    only runs after a successful response). The lookup only reads, so moving
    it earlier does not change the response.
    """
    try:
        validate_text_length(text)
        language = validate_language(language)
        params = validate_generation_params({
            'max_new_tokens': max_new_tokens,
            'temperature': temperature,
            'top_k': top_k,
            'top_p': top_p,
            'repetition_penalty': repetition_penalty
        })
        ref_audio_data = await ref_audio.read()
        if not validate_ref_audio(ref_audio_data, max_size_mb=settings.MAX_AUDIO_SIZE_MB):
            raise ValueError("Invalid reference audio: must be 1-30s duration and ≤10MB")
        cache_manager = await VoiceCacheManager.get_instance()
        ref_audio_hash = cache_manager.get_audio_hash(ref_audio_data)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
    existing_cache = await cache_manager.get_cache(current_user.id, ref_audio_hash, db)
    cache_info = {"cache_id": existing_cache['cache_id']} if existing_cache else None
    job = Job(
        user_id=current_user.id,
        job_type="voice-clone",
        status=JobStatus.PENDING,
        input_data="",
        input_params={
            "text": text,
            "language": language,
            "ref_text": ref_text or "",
            "ref_audio_hash": ref_audio_hash,
            "use_cache": use_cache,
            "x_vector_only_mode": x_vector_only_mode,
            **params
        }
    )
    db.add(job)
    db.commit()
    db.refresh(job)
    # Stage the upload for the background task; the task owns this temp file
    # and removes it when done.
    with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as tmp_file:
        tmp_file.write(ref_audio_data)
        tmp_audio_path = tmp_file.name
    request_data = {
        "text": text,
        "language": language,
        "ref_text": ref_text or "",
        "use_cache": use_cache,
        "x_vector_only_mode": x_vector_only_mode,
        **params
    }
    background_tasks.add_task(
        process_voice_clone_job,
        job.id,
        current_user.id,
        request_data,
        tmp_audio_path,
        str(settings.DATABASE_URL)
    )
    return {
        "job_id": job.id,
        "status": job.status,
        "message": "Job created successfully",
        "cache_info": cache_info
    }
@router.get("/models")
@limiter.limit("30/minute")
async def list_models(request: Request):
    """Return info on the available TTS models."""
    manager = await ModelManager.get_instance()
    return manager.get_model_info()
@router.get("/speakers")
@limiter.limit("30/minute")
async def list_speakers(request: Request):
    """Return the speakers supported by the custom-voice model."""
    speakers = get_supported_speakers()
    return speakers
@router.get("/languages")
@limiter.limit("30/minute")
async def list_languages(request: Request):
    """Return the languages supported for synthesis."""
    languages = get_supported_languages()
    return languages

View File

@@ -0,0 +1,169 @@
from typing import Annotated
from fastapi import APIRouter, Depends, HTTPException, status, Request
from sqlalchemy.orm import Session
from slowapi import Limiter
from slowapi.util import get_remote_address
from api.auth import get_current_user
from config import settings
from core.security import get_password_hash
from db.database import get_db
from db.crud import (
get_user_by_id,
get_user_by_username,
get_user_by_email,
list_users,
create_user_by_admin,
update_user,
delete_user
)
from schemas.user import User, UserCreateByAdmin, UserUpdate, UserListResponse
router = APIRouter(prefix="/users", tags=["users"])
limiter = Limiter(key_func=get_remote_address)
async def require_superuser(
    current_user: Annotated[User, Depends(get_current_user)]
) -> User:
    """Dependency: pass through the authenticated user only if superuser; else 403."""
    if current_user.is_superuser:
        return current_user
    raise HTTPException(
        status_code=status.HTTP_403_FORBIDDEN,
        detail="Superuser access required"
    )
@router.get("", response_model=UserListResponse)
@limiter.limit("30/minute")
async def get_users(
    request: Request,
    skip: int = 0,
    limit: int = 100,
    db: Session = Depends(get_db),
    _: User = Depends(require_superuser)
):
    """Superuser-only paginated listing of all users."""
    page, total = list_users(db, skip=skip, limit=limit)
    return UserListResponse(users=page, total=total, skip=skip, limit=limit)
@router.post("", response_model=User, status_code=status.HTTP_201_CREATED)
@limiter.limit("10/minute")
async def create_user(
    request: Request,
    user_data: UserCreateByAdmin,
    db: Session = Depends(get_db),
    _: User = Depends(require_superuser)
):
    """Superuser-only user creation with explicit is_superuser control."""
    if get_user_by_username(db, username=user_data.username):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Username already registered"
        )
    if get_user_by_email(db, email=user_data.email):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Email already registered"
        )
    return create_user_by_admin(
        db,
        username=user_data.username,
        email=user_data.email,
        hashed_password=get_password_hash(user_data.password),
        is_superuser=user_data.is_superuser
    )
@router.get("/{user_id}", response_model=User)
@limiter.limit("30/minute")
async def get_user(
    request: Request,
    user_id: int,
    db: Session = Depends(get_db),
    _: User = Depends(require_superuser)
):
    """Superuser-only lookup of a single user by id."""
    user = get_user_by_id(db, user_id=user_id)
    if user is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="User not found"
        )
    return user
@router.put("/{user_id}", response_model=User)
@limiter.limit("10/minute")
async def update_user_info(
    request: Request,
    user_id: int,
    user_data: UserUpdate,
    db: Session = Depends(get_db),
    current_user: User = Depends(require_superuser)
):
    """Superuser-only partial update of a user; None fields stay unchanged."""
    if get_user_by_id(db, user_id=user_id) is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="User not found"
        )
    # Reject a new username/email only when it belongs to a DIFFERENT user,
    # so resubmitting the current values is allowed.
    if user_data.username is not None:
        clash = get_user_by_username(db, username=user_data.username)
        if clash and clash.id != user_id:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Username already taken"
            )
    if user_data.email is not None:
        clash = get_user_by_email(db, email=user_data.email)
        if clash and clash.id != user_id:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Email already taken"
            )
    hashed_password = (
        get_password_hash(user_data.password)
        if user_data.password is not None
        else None
    )
    user = update_user(
        db,
        user_id=user_id,
        username=user_data.username,
        email=user_data.email,
        hashed_password=hashed_password,
        is_active=user_data.is_active,
        is_superuser=user_data.is_superuser
    )
    if not user:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="User not found"
        )
    return user
@router.delete("/{user_id}", status_code=status.HTTP_204_NO_CONTENT)
@limiter.limit("10/minute")
async def delete_user_by_id(
    request: Request,
    user_id: int,
    db: Session = Depends(get_db),
    current_user: User = Depends(require_superuser)
):
    """Superuser-only hard delete of a user; self-deletion is rejected."""
    if user_id == current_user.id:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Cannot delete yourself"
        )
    if not delete_user(db, user_id=user_id):
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="User not found"
        )