feat: Enhance API interactions and improve job handling with new request validation and error management
This commit is contained in:
@@ -20,9 +20,8 @@ router = APIRouter(prefix="/jobs", tags=["jobs"])
|
||||
limiter = Limiter(key_func=get_remote_address)
|
||||
|
||||
|
||||
async def get_user_from_token_or_query(
|
||||
async def get_user_from_bearer_token(
|
||||
request: Request,
|
||||
token: Optional[str] = Query(None),
|
||||
db: Session = Depends(get_db)
|
||||
) -> User:
|
||||
auth_token = None
|
||||
@@ -30,8 +29,6 @@ async def get_user_from_token_or_query(
|
||||
auth_header = request.headers.get("Authorization")
|
||||
if auth_header and auth_header.startswith("Bearer "):
|
||||
auth_token = auth_header.split(" ")[1]
|
||||
elif token:
|
||||
auth_token = token
|
||||
|
||||
if not auth_token:
|
||||
raise HTTPException(
|
||||
@@ -76,14 +73,13 @@ async def get_job(
|
||||
if job.status == JobStatus.COMPLETED and job.output_path:
|
||||
output_file = Path(job.output_path)
|
||||
if output_file.exists():
|
||||
download_url = f"{settings.BASE_URL}/jobs/{job.id}/download"
|
||||
download_url = f"/jobs/{job.id}/download"
|
||||
|
||||
return {
|
||||
"id": job.id,
|
||||
"job_type": job.job_type,
|
||||
"status": job.status,
|
||||
"input_params": job.input_params,
|
||||
"output_path": job.output_path,
|
||||
"download_url": download_url,
|
||||
"error_message": job.error_message,
|
||||
"created_at": job.created_at.isoformat() + 'Z' if job.created_at else None,
|
||||
@@ -120,14 +116,13 @@ async def list_jobs(
|
||||
if job.status == JobStatus.COMPLETED and job.output_path:
|
||||
output_file = Path(job.output_path)
|
||||
if output_file.exists():
|
||||
download_url = f"{settings.BASE_URL}/jobs/{job.id}/download"
|
||||
download_url = f"/jobs/{job.id}/download"
|
||||
|
||||
jobs_data.append({
|
||||
"id": job.id,
|
||||
"job_type": job.job_type,
|
||||
"status": job.status,
|
||||
"input_params": job.input_params,
|
||||
"output_path": job.output_path,
|
||||
"download_url": download_url,
|
||||
"error_message": job.error_message,
|
||||
"created_at": job.created_at.isoformat() + 'Z' if job.created_at else None,
|
||||
@@ -158,8 +153,15 @@ async def delete_job(
|
||||
if job.user_id != current_user.id:
|
||||
raise HTTPException(status_code=403, detail="Access denied")
|
||||
|
||||
output_file = None
|
||||
if job.output_path:
|
||||
output_file = Path(job.output_path)
|
||||
output_file = Path(job.output_path).resolve()
|
||||
output_dir = Path(settings.OUTPUT_DIR).resolve()
|
||||
if not output_file.is_relative_to(output_dir):
|
||||
logger.warning(f"Skip deleting file outside output dir: {output_file}")
|
||||
output_file = None
|
||||
|
||||
if output_file:
|
||||
if output_file.exists():
|
||||
try:
|
||||
output_file.unlink()
|
||||
@@ -178,7 +180,7 @@ async def delete_job(
|
||||
async def download_job_output(
|
||||
request: Request,
|
||||
job_id: int,
|
||||
current_user: User = Depends(get_user_from_token_or_query),
|
||||
current_user: User = Depends(get_user_from_bearer_token),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
job = db.query(Job).filter(Job.id == job_id).first()
|
||||
|
||||
@@ -2,7 +2,7 @@ import logging
|
||||
import tempfile
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from fastapi import APIRouter, Depends, HTTPException, BackgroundTasks, UploadFile, File, Form, Request
|
||||
from fastapi import APIRouter, Depends, HTTPException, BackgroundTasks, UploadFile, File, Form, Request, status
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import Optional
|
||||
from slowapi import Limiter
|
||||
@@ -32,6 +32,20 @@ router = APIRouter(prefix="/tts", tags=["tts"])
|
||||
limiter = Limiter(key_func=get_remote_address)
|
||||
|
||||
|
||||
async def read_upload_with_size_limit(upload_file: UploadFile, max_size_bytes: int) -> bytes:
|
||||
chunks = []
|
||||
total = 0
|
||||
while True:
|
||||
chunk = await upload_file.read(1024 * 1024)
|
||||
if not chunk:
|
||||
break
|
||||
total += len(chunk)
|
||||
if total > max_size_bytes:
|
||||
raise ValueError(f"Audio file exceeds {max_size_bytes // (1024 * 1024)}MB limit")
|
||||
chunks.append(chunk)
|
||||
return b"".join(chunks)
|
||||
|
||||
|
||||
async def process_custom_voice_job(
|
||||
job_id: int,
|
||||
user_id: int,
|
||||
@@ -85,7 +99,7 @@ async def process_custom_voice_job(
|
||||
job = db.query(Job).filter(Job.id == job_id).first()
|
||||
if job:
|
||||
job.status = JobStatus.FAILED
|
||||
job.error_message = str(e)
|
||||
job.error_message = "Job processing failed"
|
||||
job.completed_at = datetime.utcnow()
|
||||
db.commit()
|
||||
|
||||
@@ -150,7 +164,7 @@ async def process_voice_design_job(
|
||||
job = db.query(Job).filter(Job.id == job_id).first()
|
||||
if job:
|
||||
job.status = JobStatus.FAILED
|
||||
job.error_message = str(e)
|
||||
job.error_message = "Job processing failed"
|
||||
job.completed_at = datetime.utcnow()
|
||||
db.commit()
|
||||
|
||||
@@ -303,7 +317,7 @@ async def process_voice_clone_job(
|
||||
job = db.query(Job).filter(Job.id == job_id).first()
|
||||
if job:
|
||||
job.status = JobStatus.FAILED
|
||||
job.error_message = str(e)
|
||||
job.error_message = "Job processing failed"
|
||||
job.completed_at = datetime.utcnow()
|
||||
db.commit()
|
||||
|
||||
@@ -594,7 +608,8 @@ async def create_voice_clone_job(
|
||||
if not ref_audio:
|
||||
raise ValueError("Either ref_audio or voice_design_id must be provided")
|
||||
|
||||
ref_audio_data = await ref_audio.read()
|
||||
max_audio_size_bytes = settings.MAX_AUDIO_SIZE_MB * 1024 * 1024
|
||||
ref_audio_data = await read_upload_with_size_limit(ref_audio, max_audio_size_bytes)
|
||||
|
||||
if not validate_ref_audio(ref_audio_data, max_size_mb=settings.MAX_AUDIO_SIZE_MB):
|
||||
raise ValueError("Invalid reference audio: must be 1-30s duration and ≤10MB")
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import logging
|
||||
import json
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Request, BackgroundTasks
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import Optional
|
||||
@@ -20,6 +21,28 @@ logger = logging.getLogger(__name__)
|
||||
router = APIRouter(prefix="/voice-designs", tags=["voice-designs"])
|
||||
limiter = Limiter(key_func=get_remote_address)
|
||||
|
||||
|
||||
def to_voice_design_response(design) -> VoiceDesignResponse:
|
||||
meta_data = design.meta_data
|
||||
if isinstance(meta_data, str):
|
||||
try:
|
||||
meta_data = json.loads(meta_data)
|
||||
except Exception:
|
||||
meta_data = None
|
||||
return VoiceDesignResponse(
|
||||
id=design.id,
|
||||
user_id=design.user_id,
|
||||
name=design.name,
|
||||
backend_type=design.backend_type,
|
||||
instruct=design.instruct,
|
||||
aliyun_voice_id=design.aliyun_voice_id,
|
||||
meta_data=meta_data,
|
||||
preview_text=design.preview_text,
|
||||
created_at=design.created_at,
|
||||
last_used=design.last_used,
|
||||
use_count=design.use_count
|
||||
)
|
||||
|
||||
@router.post("", response_model=VoiceDesignResponse, status_code=status.HTTP_201_CREATED)
|
||||
@limiter.limit("30/minute")
|
||||
async def save_voice_design(
|
||||
@@ -39,7 +62,7 @@ async def save_voice_design(
|
||||
meta_data=data.meta_data,
|
||||
preview_text=data.preview_text
|
||||
)
|
||||
return VoiceDesignResponse.from_orm(design)
|
||||
return to_voice_design_response(design)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to save voice design: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail="Failed to save voice design")
|
||||
@@ -55,7 +78,8 @@ async def list_voice_designs(
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
designs = crud.list_voice_designs(db, current_user.id, backend_type, skip, limit)
|
||||
return VoiceDesignListResponse(designs=[VoiceDesignResponse.from_orm(d) for d in designs], total=len(designs))
|
||||
total = crud.count_voice_designs(db, current_user.id, backend_type)
|
||||
return VoiceDesignListResponse(designs=[to_voice_design_response(d) for d in designs], total=total)
|
||||
|
||||
@router.post("/{design_id}/prepare-clone")
|
||||
@limiter.limit("10/minute")
|
||||
@@ -168,4 +192,4 @@ async def prepare_voice_clone_prompt(
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to prepare voice clone prompt: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
raise HTTPException(status_code=500, detail="Failed to prepare voice clone prompt")
|
||||
|
||||
Reference in New Issue
Block a user