feat: Enhance API interactions and improve job handling with new request validation and error management
This commit is contained in:
@@ -20,9 +20,8 @@ router = APIRouter(prefix="/jobs", tags=["jobs"])
|
||||
limiter = Limiter(key_func=get_remote_address)
|
||||
|
||||
|
||||
async def get_user_from_token_or_query(
|
||||
async def get_user_from_bearer_token(
|
||||
request: Request,
|
||||
token: Optional[str] = Query(None),
|
||||
db: Session = Depends(get_db)
|
||||
) -> User:
|
||||
auth_token = None
|
||||
@@ -30,8 +29,6 @@ async def get_user_from_token_or_query(
|
||||
auth_header = request.headers.get("Authorization")
|
||||
if auth_header and auth_header.startswith("Bearer "):
|
||||
auth_token = auth_header.split(" ")[1]
|
||||
elif token:
|
||||
auth_token = token
|
||||
|
||||
if not auth_token:
|
||||
raise HTTPException(
|
||||
@@ -76,14 +73,13 @@ async def get_job(
|
||||
if job.status == JobStatus.COMPLETED and job.output_path:
|
||||
output_file = Path(job.output_path)
|
||||
if output_file.exists():
|
||||
download_url = f"{settings.BASE_URL}/jobs/{job.id}/download"
|
||||
download_url = f"/jobs/{job.id}/download"
|
||||
|
||||
return {
|
||||
"id": job.id,
|
||||
"job_type": job.job_type,
|
||||
"status": job.status,
|
||||
"input_params": job.input_params,
|
||||
"output_path": job.output_path,
|
||||
"download_url": download_url,
|
||||
"error_message": job.error_message,
|
||||
"created_at": job.created_at.isoformat() + 'Z' if job.created_at else None,
|
||||
@@ -120,14 +116,13 @@ async def list_jobs(
|
||||
if job.status == JobStatus.COMPLETED and job.output_path:
|
||||
output_file = Path(job.output_path)
|
||||
if output_file.exists():
|
||||
download_url = f"{settings.BASE_URL}/jobs/{job.id}/download"
|
||||
download_url = f"/jobs/{job.id}/download"
|
||||
|
||||
jobs_data.append({
|
||||
"id": job.id,
|
||||
"job_type": job.job_type,
|
||||
"status": job.status,
|
||||
"input_params": job.input_params,
|
||||
"output_path": job.output_path,
|
||||
"download_url": download_url,
|
||||
"error_message": job.error_message,
|
||||
"created_at": job.created_at.isoformat() + 'Z' if job.created_at else None,
|
||||
@@ -158,8 +153,15 @@ async def delete_job(
|
||||
if job.user_id != current_user.id:
|
||||
raise HTTPException(status_code=403, detail="Access denied")
|
||||
|
||||
output_file = None
|
||||
if job.output_path:
|
||||
output_file = Path(job.output_path)
|
||||
output_file = Path(job.output_path).resolve()
|
||||
output_dir = Path(settings.OUTPUT_DIR).resolve()
|
||||
if not output_file.is_relative_to(output_dir):
|
||||
logger.warning(f"Skip deleting file outside output dir: {output_file}")
|
||||
output_file = None
|
||||
|
||||
if output_file:
|
||||
if output_file.exists():
|
||||
try:
|
||||
output_file.unlink()
|
||||
@@ -178,7 +180,7 @@ async def delete_job(
|
||||
async def download_job_output(
|
||||
request: Request,
|
||||
job_id: int,
|
||||
current_user: User = Depends(get_user_from_token_or_query),
|
||||
current_user: User = Depends(get_user_from_bearer_token),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
job = db.query(Job).filter(Job.id == job_id).first()
|
||||
|
||||
@@ -2,7 +2,7 @@ import logging
|
||||
import tempfile
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from fastapi import APIRouter, Depends, HTTPException, BackgroundTasks, UploadFile, File, Form, Request
|
||||
from fastapi import APIRouter, Depends, HTTPException, BackgroundTasks, UploadFile, File, Form, Request, status
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import Optional
|
||||
from slowapi import Limiter
|
||||
@@ -32,6 +32,20 @@ router = APIRouter(prefix="/tts", tags=["tts"])
|
||||
limiter = Limiter(key_func=get_remote_address)
|
||||
|
||||
|
||||
async def read_upload_with_size_limit(upload_file: UploadFile, max_size_bytes: int) -> bytes:
|
||||
chunks = []
|
||||
total = 0
|
||||
while True:
|
||||
chunk = await upload_file.read(1024 * 1024)
|
||||
if not chunk:
|
||||
break
|
||||
total += len(chunk)
|
||||
if total > max_size_bytes:
|
||||
raise ValueError(f"Audio file exceeds {max_size_bytes // (1024 * 1024)}MB limit")
|
||||
chunks.append(chunk)
|
||||
return b"".join(chunks)
|
||||
|
||||
|
||||
async def process_custom_voice_job(
|
||||
job_id: int,
|
||||
user_id: int,
|
||||
@@ -85,7 +99,7 @@ async def process_custom_voice_job(
|
||||
job = db.query(Job).filter(Job.id == job_id).first()
|
||||
if job:
|
||||
job.status = JobStatus.FAILED
|
||||
job.error_message = str(e)
|
||||
job.error_message = "Job processing failed"
|
||||
job.completed_at = datetime.utcnow()
|
||||
db.commit()
|
||||
|
||||
@@ -150,7 +164,7 @@ async def process_voice_design_job(
|
||||
job = db.query(Job).filter(Job.id == job_id).first()
|
||||
if job:
|
||||
job.status = JobStatus.FAILED
|
||||
job.error_message = str(e)
|
||||
job.error_message = "Job processing failed"
|
||||
job.completed_at = datetime.utcnow()
|
||||
db.commit()
|
||||
|
||||
@@ -303,7 +317,7 @@ async def process_voice_clone_job(
|
||||
job = db.query(Job).filter(Job.id == job_id).first()
|
||||
if job:
|
||||
job.status = JobStatus.FAILED
|
||||
job.error_message = str(e)
|
||||
job.error_message = "Job processing failed"
|
||||
job.completed_at = datetime.utcnow()
|
||||
db.commit()
|
||||
|
||||
@@ -594,7 +608,8 @@ async def create_voice_clone_job(
|
||||
if not ref_audio:
|
||||
raise ValueError("Either ref_audio or voice_design_id must be provided")
|
||||
|
||||
ref_audio_data = await ref_audio.read()
|
||||
max_audio_size_bytes = settings.MAX_AUDIO_SIZE_MB * 1024 * 1024
|
||||
ref_audio_data = await read_upload_with_size_limit(ref_audio, max_audio_size_bytes)
|
||||
|
||||
if not validate_ref_audio(ref_audio_data, max_size_mb=settings.MAX_AUDIO_SIZE_MB):
|
||||
raise ValueError("Invalid reference audio: must be 1-30s duration and ≤10MB")
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import logging
|
||||
import json
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Request, BackgroundTasks
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import Optional
|
||||
@@ -20,6 +21,28 @@ logger = logging.getLogger(__name__)
|
||||
router = APIRouter(prefix="/voice-designs", tags=["voice-designs"])
|
||||
limiter = Limiter(key_func=get_remote_address)
|
||||
|
||||
|
||||
def to_voice_design_response(design) -> VoiceDesignResponse:
|
||||
meta_data = design.meta_data
|
||||
if isinstance(meta_data, str):
|
||||
try:
|
||||
meta_data = json.loads(meta_data)
|
||||
except Exception:
|
||||
meta_data = None
|
||||
return VoiceDesignResponse(
|
||||
id=design.id,
|
||||
user_id=design.user_id,
|
||||
name=design.name,
|
||||
backend_type=design.backend_type,
|
||||
instruct=design.instruct,
|
||||
aliyun_voice_id=design.aliyun_voice_id,
|
||||
meta_data=meta_data,
|
||||
preview_text=design.preview_text,
|
||||
created_at=design.created_at,
|
||||
last_used=design.last_used,
|
||||
use_count=design.use_count
|
||||
)
|
||||
|
||||
@router.post("", response_model=VoiceDesignResponse, status_code=status.HTTP_201_CREATED)
|
||||
@limiter.limit("30/minute")
|
||||
async def save_voice_design(
|
||||
@@ -39,7 +62,7 @@ async def save_voice_design(
|
||||
meta_data=data.meta_data,
|
||||
preview_text=data.preview_text
|
||||
)
|
||||
return VoiceDesignResponse.from_orm(design)
|
||||
return to_voice_design_response(design)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to save voice design: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail="Failed to save voice design")
|
||||
@@ -55,7 +78,8 @@ async def list_voice_designs(
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
designs = crud.list_voice_designs(db, current_user.id, backend_type, skip, limit)
|
||||
return VoiceDesignListResponse(designs=[VoiceDesignResponse.from_orm(d) for d in designs], total=len(designs))
|
||||
total = crud.count_voice_designs(db, current_user.id, backend_type)
|
||||
return VoiceDesignListResponse(designs=[to_voice_design_response(d) for d in designs], total=total)
|
||||
|
||||
@router.post("/{design_id}/prepare-clone")
|
||||
@limiter.limit("10/minute")
|
||||
@@ -168,4 +192,4 @@ async def prepare_voice_clone_prompt(
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to prepare voice clone prompt: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
raise HTTPException(status_code=500, detail="Failed to prepare voice clone prompt")
|
||||
|
||||
@@ -58,8 +58,7 @@ class Settings(BaseSettings):
|
||||
|
||||
def validate(self):
|
||||
if self.SECRET_KEY == "your-secret-key-change-this-in-production":
|
||||
import warnings
|
||||
warnings.warn("Using default SECRET_KEY! Change this in production!")
|
||||
raise ValueError("Insecure default SECRET_KEY is not allowed. Please set a strong SECRET_KEY in environment.")
|
||||
|
||||
Path(self.CACHE_DIR).mkdir(parents=True, exist_ok=True)
|
||||
Path(self.OUTPUT_DIR).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
import hashlib
|
||||
import pickle
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
from typing import Optional, Dict, Any
|
||||
from datetime import datetime, timedelta
|
||||
import logging
|
||||
import numpy as np
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
from db.crud import (
|
||||
@@ -58,8 +58,13 @@ class VoiceCacheManager:
|
||||
delete_cache_entry(db, cache_entry.id, user_id)
|
||||
return None
|
||||
|
||||
resolved_cache_file = cache_file.resolve()
|
||||
if not resolved_cache_file.is_relative_to(self.cache_dir.resolve()):
|
||||
logger.warning(f"Cache path out of cache dir: {resolved_cache_file}")
|
||||
return None
|
||||
|
||||
with open(cache_file, 'rb') as f:
|
||||
cache_data = pickle.load(f)
|
||||
cache_data = np.load(f, allow_pickle=False)
|
||||
|
||||
logger.info(f"Cache hit: user={user_id}, hash={ref_audio_hash[:8]}..., access_count={cache_entry.access_count}")
|
||||
return {
|
||||
@@ -84,8 +89,13 @@ class VoiceCacheManager:
|
||||
logger.warning(f"Cache file missing: {cache_file}")
|
||||
return None
|
||||
|
||||
resolved_cache_file = cache_file.resolve()
|
||||
if not resolved_cache_file.is_relative_to(self.cache_dir.resolve()):
|
||||
logger.warning(f"Cache path out of cache dir: {resolved_cache_file}")
|
||||
return None
|
||||
|
||||
with open(cache_file, 'rb') as f:
|
||||
cache_data = pickle.load(f)
|
||||
cache_data = np.load(f, allow_pickle=False)
|
||||
|
||||
cache_entry.last_accessed = datetime.utcnow()
|
||||
cache_entry.access_count += 1
|
||||
@@ -112,11 +122,16 @@ class VoiceCacheManager:
|
||||
) -> str:
|
||||
async with self._lock:
|
||||
try:
|
||||
cache_filename = f"{user_id}_{ref_audio_hash}.pkl"
|
||||
cache_filename = f"{user_id}_{ref_audio_hash}.npy"
|
||||
cache_path = self.cache_dir / cache_filename
|
||||
|
||||
if hasattr(cache_data, "detach"):
|
||||
cache_data = cache_data.detach().cpu().numpy()
|
||||
elif not isinstance(cache_data, np.ndarray):
|
||||
cache_data = np.asarray(cache_data, dtype=np.float32)
|
||||
|
||||
with open(cache_path, 'wb') as f:
|
||||
pickle.dump(cache_data, f)
|
||||
np.save(f, cache_data, allow_pickle=False)
|
||||
|
||||
cache_entry = create_cache_entry(
|
||||
db=db,
|
||||
|
||||
@@ -54,9 +54,17 @@ async def cleanup_old_jobs(db_url: str, days: int = 7) -> dict:
|
||||
).all()
|
||||
|
||||
deleted_files = 0
|
||||
output_dir = Path(settings.OUTPUT_DIR).resolve()
|
||||
for job in old_jobs:
|
||||
if job.output_path:
|
||||
output_file = Path(job.output_path)
|
||||
output_file = Path(job.output_path).resolve()
|
||||
if not output_file.is_relative_to(output_dir):
|
||||
logger.warning(f"Skip deleting file outside output dir during cleanup: {output_file}")
|
||||
output_file = None
|
||||
else:
|
||||
output_file = None
|
||||
|
||||
if output_file:
|
||||
if output_file.exists():
|
||||
output_file.unlink()
|
||||
deleted_files += 1
|
||||
@@ -110,7 +118,8 @@ async def cleanup_orphaned_files(db_url: str) -> dict:
|
||||
freed_space_bytes += size
|
||||
|
||||
if cache_dir.exists():
|
||||
for cache_file in cache_dir.glob("*.pkl"):
|
||||
for pattern in ("*.npy", "*.pkl"):
|
||||
for cache_file in cache_dir.glob(pattern):
|
||||
if cache_file.name not in cache_files_in_db:
|
||||
size = cache_file.stat().st_size
|
||||
cache_file.unlink()
|
||||
|
||||
@@ -288,7 +288,7 @@ def create_voice_design(
|
||||
backend_type=backend_type,
|
||||
instruct=instruct,
|
||||
aliyun_voice_id=aliyun_voice_id,
|
||||
meta_data=json.dumps(meta_data) if meta_data else None,
|
||||
meta_data=meta_data,
|
||||
preview_text=preview_text,
|
||||
created_at=datetime.utcnow(),
|
||||
last_used=datetime.utcnow()
|
||||
@@ -320,6 +320,19 @@ def list_voice_designs(
|
||||
query = query.filter(VoiceDesign.backend_type == backend_type)
|
||||
return query.order_by(VoiceDesign.last_used.desc()).offset(skip).limit(limit).all()
|
||||
|
||||
def count_voice_designs(
|
||||
db: Session,
|
||||
user_id: int,
|
||||
backend_type: Optional[str] = None
|
||||
) -> int:
|
||||
query = db.query(VoiceDesign).filter(
|
||||
VoiceDesign.user_id == user_id,
|
||||
VoiceDesign.is_active == True
|
||||
)
|
||||
if backend_type:
|
||||
query = query.filter(VoiceDesign.backend_type == backend_type)
|
||||
return query.count()
|
||||
|
||||
def update_voice_design_usage(db: Session, design_id: int, user_id: int) -> Optional[VoiceDesign]:
|
||||
design = get_voice_design(db, design_id, user_id)
|
||||
if design:
|
||||
@@ -328,4 +341,3 @@ def update_voice_design_usage(db: Session, design_id: int, user_id: int) -> Opti
|
||||
db.commit()
|
||||
db.refresh(design)
|
||||
return design
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ from contextlib import asynccontextmanager
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi import FastAPI, Request, Depends, HTTPException
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from slowapi import Limiter, _rate_limit_exceeded_handler
|
||||
from slowapi.util import get_remote_address
|
||||
@@ -16,6 +16,8 @@ from core.database import init_db
|
||||
from core.model_manager import ModelManager
|
||||
from core.cleanup import run_scheduled_cleanup
|
||||
from api import auth, jobs, tts, users, voice_designs
|
||||
from api.auth import get_current_user
|
||||
from schemas.user import User
|
||||
from apscheduler.schedulers.asyncio import AsyncIOScheduler
|
||||
|
||||
logging.basicConfig(
|
||||
@@ -134,6 +136,14 @@ app.include_router(voice_designs.router)
|
||||
|
||||
@app.get("/health")
|
||||
async def health_check():
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
@app.get("/health/details")
|
||||
async def health_check_details(current_user: User = Depends(get_current_user)):
|
||||
if not current_user.is_superuser:
|
||||
raise HTTPException(status_code=403, detail="Superuser access required")
|
||||
|
||||
from core.batch_processor import BatchProcessor
|
||||
from core.database import SessionLocal
|
||||
|
||||
|
||||
@@ -1,2 +1,2 @@
|
||||
VITE_API_URL=http://localhost:8000
|
||||
VITE_API_URL=/api
|
||||
VITE_APP_NAME=Qwen3-TTS
|
||||
|
||||
@@ -47,15 +47,15 @@ const CustomVoiceForm = forwardRef<CustomVoiceFormHandle>((_props, ref) => {
|
||||
const PRESET_INSTRUCTS = useMemo(() => tConstants('presetInstructs', { returnObjects: true }) as Array<{ label: string; instruct: string; text: string }>, [tConstants])
|
||||
|
||||
const formSchema = z.object({
|
||||
text: z.string().min(1, tErrors('validation.required', { field: tErrors('fieldNames.text') })).max(5000, tErrors('validation.maxLength', { field: tErrors('fieldNames.text'), max: 5000 })),
|
||||
text: z.string().min(1, tErrors('validation.required', { field: tErrors('fieldNames.text') })).max(1000, tErrors('validation.maxLength', { field: tErrors('fieldNames.text'), max: 1000 })),
|
||||
language: z.string().min(1, tErrors('validation.required', { field: tErrors('fieldNames.language') })),
|
||||
speaker: z.string().min(1, tErrors('validation.required', { field: tErrors('fieldNames.speaker') })),
|
||||
instruct: z.string().optional(),
|
||||
max_new_tokens: z.number().min(1).max(10000).optional(),
|
||||
temperature: z.number().min(0).max(2).optional(),
|
||||
max_new_tokens: z.number().min(128).max(4096).optional(),
|
||||
temperature: z.number().min(0.1).max(2).optional(),
|
||||
top_k: z.number().min(1).max(100).optional(),
|
||||
top_p: z.number().min(0).max(1).optional(),
|
||||
repetition_penalty: z.number().min(0).max(2).optional(),
|
||||
repetition_penalty: z.number().min(1).max(2).optional(),
|
||||
})
|
||||
const [languages, setLanguages] = useState<Language[]>([])
|
||||
const [unifiedSpeakers, setUnifiedSpeakers] = useState<UnifiedSpeakerItem[]>([])
|
||||
@@ -395,8 +395,8 @@ const CustomVoiceForm = forwardRef<CustomVoiceFormHandle>((_props, ref) => {
|
||||
<Input
|
||||
id="dialog-max_new_tokens"
|
||||
type="number"
|
||||
min={1}
|
||||
max={10000}
|
||||
min={128}
|
||||
max={4096}
|
||||
value={tempAdvancedParams.max_new_tokens}
|
||||
onChange={(e) => setTempAdvancedParams({
|
||||
...tempAdvancedParams,
|
||||
@@ -414,7 +414,7 @@ const CustomVoiceForm = forwardRef<CustomVoiceFormHandle>((_props, ref) => {
|
||||
<Input
|
||||
id="dialog-temperature"
|
||||
type="number"
|
||||
min={0}
|
||||
min={0.1}
|
||||
max={2}
|
||||
step={0.1}
|
||||
value={tempAdvancedParams.temperature}
|
||||
|
||||
@@ -49,17 +49,17 @@ function VoiceCloneForm() {
|
||||
const PRESET_REF_TEXTS = useMemo(() => tConstants('presetRefTexts', { returnObjects: true }) as Array<{ label: string; text: string }>, [tConstants])
|
||||
|
||||
const formSchema = z.object({
|
||||
text: z.string().min(1, tErrors('validation.required', { field: tErrors('fieldNames.text') })).max(5000, tErrors('validation.maxLength', { field: tErrors('fieldNames.text'), max: 5000 })),
|
||||
text: z.string().min(1, tErrors('validation.required', { field: tErrors('fieldNames.text') })).max(1000, tErrors('validation.maxLength', { field: tErrors('fieldNames.text'), max: 1000 })),
|
||||
language: z.string().optional(),
|
||||
ref_audio: z.instanceof(File, { message: tErrors('validation.required', { field: tErrors('fieldNames.reference_audio') }) }),
|
||||
ref_text: z.string().optional(),
|
||||
use_cache: z.boolean().optional(),
|
||||
x_vector_only_mode: z.boolean().optional(),
|
||||
max_new_tokens: z.number().min(1).max(10000).optional(),
|
||||
temperature: z.number().min(0).max(2).optional(),
|
||||
max_new_tokens: z.number().min(128).max(4096).optional(),
|
||||
temperature: z.number().min(0.1).max(2).optional(),
|
||||
top_k: z.number().min(1).max(100).optional(),
|
||||
top_p: z.number().min(0).max(1).optional(),
|
||||
repetition_penalty: z.number().min(0).max(2).optional(),
|
||||
repetition_penalty: z.number().min(1).max(2).optional(),
|
||||
})
|
||||
const [languages, setLanguages] = useState<Language[]>([])
|
||||
const [isLoading, setIsLoading] = useState(false)
|
||||
@@ -358,8 +358,8 @@ function VoiceCloneForm() {
|
||||
<Input
|
||||
id="dialog-max_new_tokens"
|
||||
type="number"
|
||||
min={1}
|
||||
max={10000}
|
||||
min={128}
|
||||
max={4096}
|
||||
value={tempAdvancedParams.max_new_tokens}
|
||||
onChange={(e) => setTempAdvancedParams({
|
||||
...tempAdvancedParams,
|
||||
|
||||
@@ -46,14 +46,14 @@ const VoiceDesignForm = forwardRef<VoiceDesignFormHandle>((_props, ref) => {
|
||||
const PRESET_VOICE_DESIGNS = useMemo(() => tConstants('presetVoiceDesigns', { returnObjects: true }) as Array<{ label: string; instruct: string; text: string }>, [tConstants])
|
||||
|
||||
const formSchema = z.object({
|
||||
text: z.string().min(1, tErrors('validation.required', { field: tErrors('fieldNames.text') })).max(5000, tErrors('validation.maxLength', { field: tErrors('fieldNames.text'), max: 5000 })),
|
||||
text: z.string().min(1, tErrors('validation.required', { field: tErrors('fieldNames.text') })).max(1000, tErrors('validation.maxLength', { field: tErrors('fieldNames.text'), max: 1000 })),
|
||||
language: z.string().min(1, tErrors('validation.required', { field: tErrors('fieldNames.language') })),
|
||||
instruct: z.string().min(10, tErrors('validation.minLength', { field: tErrors('fieldNames.instruct'), min: 10 })).max(500, tErrors('validation.maxLength', { field: tErrors('fieldNames.instruct'), max: 500 })),
|
||||
max_new_tokens: z.number().min(1).max(10000).optional(),
|
||||
temperature: z.number().min(0).max(2).optional(),
|
||||
max_new_tokens: z.number().min(128).max(4096).optional(),
|
||||
temperature: z.number().min(0.1).max(2).optional(),
|
||||
top_k: z.number().min(1).max(100).optional(),
|
||||
top_p: z.number().min(0).max(1).optional(),
|
||||
repetition_penalty: z.number().min(0).max(2).optional(),
|
||||
repetition_penalty: z.number().min(1).max(2).optional(),
|
||||
})
|
||||
const [languages, setLanguages] = useState<Language[]>([])
|
||||
const [isLoading, setIsLoading] = useState(false)
|
||||
@@ -310,8 +310,8 @@ const VoiceDesignForm = forwardRef<VoiceDesignFormHandle>((_props, ref) => {
|
||||
<Input
|
||||
id="dialog-max_new_tokens"
|
||||
type="number"
|
||||
min={1}
|
||||
max={10000}
|
||||
min={128}
|
||||
max={4096}
|
||||
value={tempAdvancedParams.max_new_tokens}
|
||||
onChange={(e) => setTempAdvancedParams({
|
||||
...tempAdvancedParams,
|
||||
@@ -329,7 +329,7 @@ const VoiceDesignForm = forwardRef<VoiceDesignFormHandle>((_props, ref) => {
|
||||
<Input
|
||||
id="dialog-temperature"
|
||||
type="number"
|
||||
min={0}
|
||||
min={0.1}
|
||||
max={2}
|
||||
step={0.1}
|
||||
value={tempAdvancedParams.temperature}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import { createContext, useContext, useState, useCallback, useMemo, type ReactNode } from 'react'
|
||||
import { createContext, useContext, useState, useCallback, useMemo, useRef, useEffect, type ReactNode } from 'react'
|
||||
import { toast } from 'sonner'
|
||||
import { jobApi } from '@/lib/api'
|
||||
import type { Job, JobStatus } from '@/types/job'
|
||||
@@ -25,13 +25,27 @@ export function JobProvider({ children }: { children: ReactNode }) {
|
||||
const [elapsedTime, setElapsedTime] = useState(0)
|
||||
|
||||
const { refresh: historyRefresh } = useHistoryContext()
|
||||
const pollIntervalRef = useRef<ReturnType<typeof setInterval> | null>(null)
|
||||
const timeIntervalRef = useRef<ReturnType<typeof setInterval> | null>(null)
|
||||
|
||||
const clearIntervals = useCallback(() => {
|
||||
if (pollIntervalRef.current) {
|
||||
clearInterval(pollIntervalRef.current)
|
||||
pollIntervalRef.current = null
|
||||
}
|
||||
if (timeIntervalRef.current) {
|
||||
clearInterval(timeIntervalRef.current)
|
||||
timeIntervalRef.current = null
|
||||
}
|
||||
}, [])
|
||||
|
||||
const stopJob = useCallback(() => {
|
||||
clearIntervals()
|
||||
setCurrentJob(null)
|
||||
setStatus(null)
|
||||
setError(null)
|
||||
setElapsedTime(0)
|
||||
}, [])
|
||||
}, [clearIntervals])
|
||||
|
||||
const resetJob = useCallback(() => {
|
||||
setError(null)
|
||||
@@ -45,15 +59,13 @@ export function JobProvider({ children }: { children: ReactNode }) {
|
||||
}, [])
|
||||
|
||||
const startJob = useCallback((jobId: number) => {
|
||||
clearIntervals()
|
||||
// Reset state for new job
|
||||
setCurrentJob(null)
|
||||
setStatus('pending')
|
||||
setError(null)
|
||||
setElapsedTime(0)
|
||||
|
||||
let pollInterval: ReturnType<typeof setInterval> | null = null
|
||||
let timeInterval: ReturnType<typeof setInterval> | null = null
|
||||
|
||||
const poll = async () => {
|
||||
try {
|
||||
const job = await jobApi.getJob(jobId)
|
||||
@@ -61,15 +73,13 @@ export function JobProvider({ children }: { children: ReactNode }) {
|
||||
setStatus(job.status)
|
||||
|
||||
if (job.status === 'completed') {
|
||||
if (pollInterval) clearInterval(pollInterval)
|
||||
if (timeInterval) clearInterval(timeInterval)
|
||||
clearIntervals()
|
||||
toast.success('任务完成!')
|
||||
try {
|
||||
historyRefresh()
|
||||
} catch {}
|
||||
} else if (job.status === 'failed') {
|
||||
if (pollInterval) clearInterval(pollInterval)
|
||||
if (timeInterval) clearInterval(timeInterval)
|
||||
clearIntervals()
|
||||
setError(job.error_message || '任务失败')
|
||||
toast.error(job.error_message || '任务失败')
|
||||
try {
|
||||
@@ -77,8 +87,7 @@ export function JobProvider({ children }: { children: ReactNode }) {
|
||||
} catch {}
|
||||
}
|
||||
} catch (error: any) {
|
||||
if (pollInterval) clearInterval(pollInterval)
|
||||
if (timeInterval) clearInterval(timeInterval)
|
||||
clearIntervals()
|
||||
const message = error.response?.data?.detail || '获取任务状态失败'
|
||||
setError(message)
|
||||
toast.error(message)
|
||||
@@ -86,16 +95,17 @@ export function JobProvider({ children }: { children: ReactNode }) {
|
||||
}
|
||||
|
||||
poll()
|
||||
pollInterval = setInterval(poll, POLL_INTERVAL)
|
||||
timeInterval = setInterval(() => {
|
||||
pollIntervalRef.current = setInterval(poll, POLL_INTERVAL)
|
||||
timeIntervalRef.current = setInterval(() => {
|
||||
setElapsedTime((prev) => prev + 1)
|
||||
}, 1000)
|
||||
}, [historyRefresh, clearIntervals])
|
||||
|
||||
useEffect(() => {
|
||||
return () => {
|
||||
if (pollInterval) clearInterval(pollInterval)
|
||||
if (timeInterval) clearInterval(timeInterval)
|
||||
clearIntervals()
|
||||
}
|
||||
}, [historyRefresh])
|
||||
}, [clearIntervals])
|
||||
|
||||
const value = useMemo(
|
||||
() => ({
|
||||
|
||||
@@ -13,9 +13,22 @@ const apiClient = axios.create({
|
||||
},
|
||||
})
|
||||
|
||||
const isTrustedApiRequest = (url?: string, baseURL?: string): boolean => {
|
||||
if (!url) return false
|
||||
if (url.startsWith('/')) return true
|
||||
|
||||
try {
|
||||
const resolvedUrl = new URL(url, baseURL || window.location.origin)
|
||||
const apiOrigin = baseURL ? new URL(baseURL, window.location.origin).origin : window.location.origin
|
||||
return resolvedUrl.origin === apiOrigin
|
||||
} catch {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
apiClient.interceptors.request.use((config) => {
|
||||
const token = localStorage.getItem('token')
|
||||
if (token) {
|
||||
if (token && isTrustedApiRequest(config.url, config.baseURL || import.meta.env.VITE_API_URL)) {
|
||||
config.headers.Authorization = `Bearer ${token}`
|
||||
}
|
||||
return config
|
||||
@@ -346,11 +359,23 @@ export const jobApi = {
|
||||
getAudioUrl: (id: number, audioPath?: string): string => {
|
||||
if (audioPath) {
|
||||
if (audioPath.startsWith('http')) {
|
||||
const apiBase = import.meta.env.VITE_API_URL
|
||||
if (apiBase) {
|
||||
try {
|
||||
const audioOrigin = new URL(audioPath).origin
|
||||
const apiOrigin = new URL(apiBase, window.location.origin).origin
|
||||
if (audioOrigin !== apiOrigin) {
|
||||
return API_ENDPOINTS.JOBS.AUDIO(id)
|
||||
}
|
||||
} catch {
|
||||
return API_ENDPOINTS.JOBS.AUDIO(id)
|
||||
}
|
||||
}
|
||||
if (audioPath.includes('localhost') || audioPath.includes('127.0.0.1')) {
|
||||
const url = new URL(audioPath)
|
||||
return url.pathname
|
||||
}
|
||||
return audioPath
|
||||
return API_ENDPOINTS.JOBS.AUDIO(id)
|
||||
} else {
|
||||
return audioPath.startsWith('/') ? audioPath : `/${audioPath}`
|
||||
}
|
||||
|
||||
@@ -4,6 +4,15 @@ import { defineConfig } from "vite"
|
||||
|
||||
export default defineConfig({
|
||||
plugins: [react()],
|
||||
server: {
|
||||
proxy: {
|
||||
'/api': {
|
||||
target: 'http://127.0.0.1:8000',
|
||||
changeOrigin: true,
|
||||
rewrite: (path) => path.replace(/^\/api/, ''),
|
||||
},
|
||||
},
|
||||
},
|
||||
resolve: {
|
||||
alias: {
|
||||
"@": path.resolve(__dirname, "./src"),
|
||||
|
||||
Reference in New Issue
Block a user