feat: Enhance API interactions and improve job handling with new request validation and error management

This commit is contained in:
2026-03-06 12:03:41 +08:00
parent 3844e825cd
commit a93754f449
15 changed files with 204 additions and 74 deletions

View File

@@ -20,9 +20,8 @@ router = APIRouter(prefix="/jobs", tags=["jobs"])
limiter = Limiter(key_func=get_remote_address)
async def get_user_from_token_or_query(
async def get_user_from_bearer_token(
request: Request,
token: Optional[str] = Query(None),
db: Session = Depends(get_db)
) -> User:
auth_token = None
@@ -30,8 +29,6 @@ async def get_user_from_token_or_query(
auth_header = request.headers.get("Authorization")
if auth_header and auth_header.startswith("Bearer "):
auth_token = auth_header.split(" ")[1]
elif token:
auth_token = token
if not auth_token:
raise HTTPException(
@@ -76,14 +73,13 @@ async def get_job(
if job.status == JobStatus.COMPLETED and job.output_path:
output_file = Path(job.output_path)
if output_file.exists():
download_url = f"{settings.BASE_URL}/jobs/{job.id}/download"
download_url = f"/jobs/{job.id}/download"
return {
"id": job.id,
"job_type": job.job_type,
"status": job.status,
"input_params": job.input_params,
"output_path": job.output_path,
"download_url": download_url,
"error_message": job.error_message,
"created_at": job.created_at.isoformat() + 'Z' if job.created_at else None,
@@ -120,14 +116,13 @@ async def list_jobs(
if job.status == JobStatus.COMPLETED and job.output_path:
output_file = Path(job.output_path)
if output_file.exists():
download_url = f"{settings.BASE_URL}/jobs/{job.id}/download"
download_url = f"/jobs/{job.id}/download"
jobs_data.append({
"id": job.id,
"job_type": job.job_type,
"status": job.status,
"input_params": job.input_params,
"output_path": job.output_path,
"download_url": download_url,
"error_message": job.error_message,
"created_at": job.created_at.isoformat() + 'Z' if job.created_at else None,
@@ -158,8 +153,15 @@ async def delete_job(
if job.user_id != current_user.id:
raise HTTPException(status_code=403, detail="Access denied")
output_file = None
if job.output_path:
output_file = Path(job.output_path)
output_file = Path(job.output_path).resolve()
output_dir = Path(settings.OUTPUT_DIR).resolve()
if not output_file.is_relative_to(output_dir):
logger.warning(f"Skip deleting file outside output dir: {output_file}")
output_file = None
if output_file:
if output_file.exists():
try:
output_file.unlink()
@@ -178,7 +180,7 @@ async def delete_job(
async def download_job_output(
request: Request,
job_id: int,
current_user: User = Depends(get_user_from_token_or_query),
current_user: User = Depends(get_user_from_bearer_token),
db: Session = Depends(get_db)
):
job = db.query(Job).filter(Job.id == job_id).first()

View File

@@ -2,7 +2,7 @@ import logging
import tempfile
from datetime import datetime
from pathlib import Path
from fastapi import APIRouter, Depends, HTTPException, BackgroundTasks, UploadFile, File, Form, Request
from fastapi import APIRouter, Depends, HTTPException, BackgroundTasks, UploadFile, File, Form, Request, status
from sqlalchemy.orm import Session
from typing import Optional
from slowapi import Limiter
@@ -32,6 +32,20 @@ router = APIRouter(prefix="/tts", tags=["tts"])
limiter = Limiter(key_func=get_remote_address)
async def read_upload_with_size_limit(upload_file: UploadFile, max_size_bytes: int) -> bytes:
chunks = []
total = 0
while True:
chunk = await upload_file.read(1024 * 1024)
if not chunk:
break
total += len(chunk)
if total > max_size_bytes:
raise ValueError(f"Audio file exceeds {max_size_bytes // (1024 * 1024)}MB limit")
chunks.append(chunk)
return b"".join(chunks)
async def process_custom_voice_job(
job_id: int,
user_id: int,
@@ -85,7 +99,7 @@ async def process_custom_voice_job(
job = db.query(Job).filter(Job.id == job_id).first()
if job:
job.status = JobStatus.FAILED
job.error_message = str(e)
job.error_message = "Job processing failed"
job.completed_at = datetime.utcnow()
db.commit()
@@ -150,7 +164,7 @@ async def process_voice_design_job(
job = db.query(Job).filter(Job.id == job_id).first()
if job:
job.status = JobStatus.FAILED
job.error_message = str(e)
job.error_message = "Job processing failed"
job.completed_at = datetime.utcnow()
db.commit()
@@ -303,7 +317,7 @@ async def process_voice_clone_job(
job = db.query(Job).filter(Job.id == job_id).first()
if job:
job.status = JobStatus.FAILED
job.error_message = str(e)
job.error_message = "Job processing failed"
job.completed_at = datetime.utcnow()
db.commit()
@@ -594,7 +608,8 @@ async def create_voice_clone_job(
if not ref_audio:
raise ValueError("Either ref_audio or voice_design_id must be provided")
ref_audio_data = await ref_audio.read()
max_audio_size_bytes = settings.MAX_AUDIO_SIZE_MB * 1024 * 1024
ref_audio_data = await read_upload_with_size_limit(ref_audio, max_audio_size_bytes)
if not validate_ref_audio(ref_audio_data, max_size_mb=settings.MAX_AUDIO_SIZE_MB):
raise ValueError("Invalid reference audio: must be 1-30s duration and ≤10MB")

View File

@@ -1,4 +1,5 @@
import logging
import json
from fastapi import APIRouter, Depends, HTTPException, status, Request, BackgroundTasks
from sqlalchemy.orm import Session
from typing import Optional
@@ -20,6 +21,28 @@ logger = logging.getLogger(__name__)
router = APIRouter(prefix="/voice-designs", tags=["voice-designs"])
limiter = Limiter(key_func=get_remote_address)
def to_voice_design_response(design) -> VoiceDesignResponse:
meta_data = design.meta_data
if isinstance(meta_data, str):
try:
meta_data = json.loads(meta_data)
except Exception:
meta_data = None
return VoiceDesignResponse(
id=design.id,
user_id=design.user_id,
name=design.name,
backend_type=design.backend_type,
instruct=design.instruct,
aliyun_voice_id=design.aliyun_voice_id,
meta_data=meta_data,
preview_text=design.preview_text,
created_at=design.created_at,
last_used=design.last_used,
use_count=design.use_count
)
@router.post("", response_model=VoiceDesignResponse, status_code=status.HTTP_201_CREATED)
@limiter.limit("30/minute")
async def save_voice_design(
@@ -39,7 +62,7 @@ async def save_voice_design(
meta_data=data.meta_data,
preview_text=data.preview_text
)
return VoiceDesignResponse.from_orm(design)
return to_voice_design_response(design)
except Exception as e:
logger.error(f"Failed to save voice design: {e}", exc_info=True)
raise HTTPException(status_code=500, detail="Failed to save voice design")
@@ -55,7 +78,8 @@ async def list_voice_designs(
db: Session = Depends(get_db)
):
designs = crud.list_voice_designs(db, current_user.id, backend_type, skip, limit)
return VoiceDesignListResponse(designs=[VoiceDesignResponse.from_orm(d) for d in designs], total=len(designs))
total = crud.count_voice_designs(db, current_user.id, backend_type)
return VoiceDesignListResponse(designs=[to_voice_design_response(d) for d in designs], total=total)
@router.post("/{design_id}/prepare-clone")
@limiter.limit("10/minute")
@@ -168,4 +192,4 @@ async def prepare_voice_clone_prompt(
except Exception as e:
logger.error(f"Failed to prepare voice clone prompt: {e}", exc_info=True)
raise HTTPException(status_code=500, detail=str(e))
raise HTTPException(status_code=500, detail="Failed to prepare voice clone prompt")