From a93754f449266cff383cbbaaff217461db516f63 Mon Sep 17 00:00:00 2001 From: bdim404 Date: Fri, 6 Mar 2026 12:03:41 +0800 Subject: [PATCH] feat: Enhance API interactions and improve job handling with new request validation and error management --- qwen3-tts-backend/api/jobs.py | 22 +++++----- qwen3-tts-backend/api/tts.py | 25 ++++++++--- qwen3-tts-backend/api/voice_designs.py | 30 +++++++++++-- qwen3-tts-backend/config.py | 3 +- qwen3-tts-backend/core/cache_manager.py | 25 ++++++++--- qwen3-tts-backend/core/cleanup.py | 23 ++++++---- qwen3-tts-backend/db/crud.py | 16 ++++++- qwen3-tts-backend/main.py | 12 +++++- qwen3-tts-frontend/.env.example | 2 +- .../src/components/tts/CustomVoiceForm.tsx | 14 +++---- .../src/components/tts/VoiceCloneForm.tsx | 12 +++--- .../src/components/tts/VoiceDesignForm.tsx | 14 +++---- .../src/contexts/JobContext.tsx | 42 ++++++++++++------- qwen3-tts-frontend/src/lib/api.ts | 29 ++++++++++++- qwen3-tts-frontend/vite.config.ts | 9 ++++ 15 files changed, 204 insertions(+), 74 deletions(-) diff --git a/qwen3-tts-backend/api/jobs.py b/qwen3-tts-backend/api/jobs.py index 4b61323..36cfb3d 100644 --- a/qwen3-tts-backend/api/jobs.py +++ b/qwen3-tts-backend/api/jobs.py @@ -20,9 +20,8 @@ router = APIRouter(prefix="/jobs", tags=["jobs"]) limiter = Limiter(key_func=get_remote_address) -async def get_user_from_token_or_query( +async def get_user_from_bearer_token( request: Request, - token: Optional[str] = Query(None), db: Session = Depends(get_db) ) -> User: auth_token = None @@ -30,8 +29,6 @@ async def get_user_from_token_or_query( auth_header = request.headers.get("Authorization") if auth_header and auth_header.startswith("Bearer "): auth_token = auth_header.split(" ")[1] - elif token: - auth_token = token if not auth_token: raise HTTPException( @@ -76,14 +73,13 @@ async def get_job( if job.status == JobStatus.COMPLETED and job.output_path: output_file = Path(job.output_path) if output_file.exists(): - download_url = f"{settings.BASE_URL}/jobs/{job.id}/download" + download_url = f"/jobs/{job.id}/download" return { "id": job.id, "job_type": job.job_type, "status": job.status, "input_params": job.input_params, - "output_path": job.output_path, "download_url": download_url, "error_message": job.error_message, "created_at": job.created_at.isoformat() + 'Z' if job.created_at else None, @@ -120,14 +116,13 @@ async def list_jobs( if job.status == JobStatus.COMPLETED and job.output_path: output_file = Path(job.output_path) if output_file.exists(): - download_url = f"{settings.BASE_URL}/jobs/{job.id}/download" + download_url = f"/jobs/{job.id}/download" jobs_data.append({ "id": job.id, "job_type": job.job_type, "status": job.status, "input_params": job.input_params, - "output_path": job.output_path, "download_url": download_url, "error_message": job.error_message, "created_at": job.created_at.isoformat() + 'Z' if job.created_at else None, @@ -158,8 +153,15 @@ async def delete_job( if job.user_id != current_user.id: raise HTTPException(status_code=403, detail="Access denied") + output_file = None if job.output_path: - output_file = Path(job.output_path) + output_file = Path(job.output_path).resolve() + output_dir = Path(settings.OUTPUT_DIR).resolve() + if not output_file.is_relative_to(output_dir): + logger.warning(f"Skip deleting file outside output dir: {output_file}") + output_file = None + + if output_file: if output_file.exists(): try: output_file.unlink() @@ -178,7 +180,7 @@ async def delete_job( async def download_job_output( request: Request, job_id: int, - current_user: User = Depends(get_user_from_token_or_query), + current_user: User = Depends(get_user_from_bearer_token), db: Session = Depends(get_db) ): job = db.query(Job).filter(Job.id == job_id).first() diff --git a/qwen3-tts-backend/api/tts.py b/qwen3-tts-backend/api/tts.py index 2e5b601..322a1e4 100644 --- a/qwen3-tts-backend/api/tts.py +++ b/qwen3-tts-backend/api/tts.py @@ -2,7 +2,7 @@ import logging import tempfile from datetime import datetime from pathlib import Path -from fastapi import APIRouter, Depends, HTTPException, BackgroundTasks, UploadFile, File, Form, Request +from fastapi import APIRouter, Depends, HTTPException, BackgroundTasks, UploadFile, File, Form, Request, status from sqlalchemy.orm import Session from typing import Optional from slowapi import Limiter @@ -32,6 +32,20 @@ router = APIRouter(prefix="/tts", tags=["tts"]) limiter = Limiter(key_func=get_remote_address) +async def read_upload_with_size_limit(upload_file: UploadFile, max_size_bytes: int) -> bytes: + chunks = [] + total = 0 + while True: + chunk = await upload_file.read(1024 * 1024) + if not chunk: + break + total += len(chunk) + if total > max_size_bytes: + raise ValueError(f"Audio file exceeds {max_size_bytes // (1024 * 1024)}MB limit") + chunks.append(chunk) + return b"".join(chunks) + + async def process_custom_voice_job( job_id: int, user_id: int, @@ -85,7 +99,7 @@ async def process_custom_voice_job( job = db.query(Job).filter(Job.id == job_id).first() if job: job.status = JobStatus.FAILED - job.error_message = str(e) + job.error_message = "Job processing failed" job.completed_at = datetime.utcnow() db.commit() @@ -150,7 +164,7 @@ async def process_voice_design_job( job = db.query(Job).filter(Job.id == job_id).first() if job: job.status = JobStatus.FAILED - job.error_message = str(e) + job.error_message = "Job processing failed" job.completed_at = datetime.utcnow() db.commit() @@ -303,7 +317,7 @@ async def process_voice_clone_job( job = db.query(Job).filter(Job.id == job_id).first() if job: job.status = JobStatus.FAILED - job.error_message = str(e) + job.error_message = "Job processing failed" job.completed_at = datetime.utcnow() db.commit() @@ -594,7 +608,8 @@ async def create_voice_clone_job( if not ref_audio: raise ValueError("Either ref_audio or voice_design_id must be provided") - ref_audio_data = await ref_audio.read() + max_audio_size_bytes = settings.MAX_AUDIO_SIZE_MB * 1024 * 1024 + ref_audio_data = await read_upload_with_size_limit(ref_audio, max_audio_size_bytes) if not validate_ref_audio(ref_audio_data, max_size_mb=settings.MAX_AUDIO_SIZE_MB): raise ValueError("Invalid reference audio: must be 1-30s duration and ≤10MB") diff --git a/qwen3-tts-backend/api/voice_designs.py b/qwen3-tts-backend/api/voice_designs.py index 167421d..a5e885a 100644 --- a/qwen3-tts-backend/api/voice_designs.py +++ b/qwen3-tts-backend/api/voice_designs.py @@ -1,4 +1,5 @@ import logging +import json from fastapi import APIRouter, Depends, HTTPException, status, Request, BackgroundTasks from sqlalchemy.orm import Session from typing import Optional @@ -20,6 +21,28 @@ logger = logging.getLogger(__name__) router = APIRouter(prefix="/voice-designs", tags=["voice-designs"]) limiter = Limiter(key_func=get_remote_address) + +def to_voice_design_response(design) -> VoiceDesignResponse: + meta_data = design.meta_data + if isinstance(meta_data, str): + try: + meta_data = json.loads(meta_data) + except Exception: + meta_data = None + return VoiceDesignResponse( + id=design.id, + user_id=design.user_id, + name=design.name, + backend_type=design.backend_type, + instruct=design.instruct, + aliyun_voice_id=design.aliyun_voice_id, + meta_data=meta_data, + preview_text=design.preview_text, + created_at=design.created_at, + last_used=design.last_used, + use_count=design.use_count + ) + @router.post("", response_model=VoiceDesignResponse, status_code=status.HTTP_201_CREATED) @limiter.limit("30/minute") async def save_voice_design( @@ -39,7 +62,7 @@ async def save_voice_design( meta_data=data.meta_data, preview_text=data.preview_text ) - return VoiceDesignResponse.from_orm(design) + return to_voice_design_response(design) except Exception as e: logger.error(f"Failed to save voice design: {e}", exc_info=True) raise HTTPException(status_code=500, detail="Failed to save voice design") @@ -55,7 +78,8 @@ async def list_voice_designs( db: Session = Depends(get_db) ): designs = crud.list_voice_designs(db, current_user.id, backend_type, skip, limit) - return VoiceDesignListResponse(designs=[VoiceDesignResponse.from_orm(d) for d in designs], total=len(designs)) + total = crud.count_voice_designs(db, current_user.id, backend_type) + return VoiceDesignListResponse(designs=[to_voice_design_response(d) for d in designs], total=total) @router.post("/{design_id}/prepare-clone") @limiter.limit("10/minute") @@ -168,4 +192,4 @@ async def prepare_voice_clone_prompt( except Exception as e: logger.error(f"Failed to prepare voice clone prompt: {e}", exc_info=True) - raise HTTPException(status_code=500, detail=str(e)) + raise HTTPException(status_code=500, detail="Failed to prepare voice clone prompt") diff --git a/qwen3-tts-backend/config.py b/qwen3-tts-backend/config.py index 8700e75..aeafe90 100644 --- a/qwen3-tts-backend/config.py +++ b/qwen3-tts-backend/config.py @@ -58,8 +58,7 @@ class Settings(BaseSettings): def validate(self): if self.SECRET_KEY == "your-secret-key-change-this-in-production": - import warnings - warnings.warn("Using default SECRET_KEY! Change this in production!") + raise ValueError("Insecure default SECRET_KEY is not allowed. Please set a strong SECRET_KEY in environment.") Path(self.CACHE_DIR).mkdir(parents=True, exist_ok=True) Path(self.OUTPUT_DIR).mkdir(parents=True, exist_ok=True) diff --git a/qwen3-tts-backend/core/cache_manager.py b/qwen3-tts-backend/core/cache_manager.py index ab48001..dd03471 100644 --- a/qwen3-tts-backend/core/cache_manager.py +++ b/qwen3-tts-backend/core/cache_manager.py @@ -1,10 +1,10 @@ import hashlib -import pickle import asyncio from pathlib import Path from typing import Optional, Dict, Any from datetime import datetime, timedelta import logging +import numpy as np from sqlalchemy.orm import Session from db.crud import ( @@ -58,8 +58,13 @@ class VoiceCacheManager: delete_cache_entry(db, cache_entry.id, user_id) return None + resolved_cache_file = cache_file.resolve() + if not resolved_cache_file.is_relative_to(self.cache_dir.resolve()): + logger.warning(f"Cache path out of cache dir: {resolved_cache_file}") + return None + with open(cache_file, 'rb') as f: - cache_data = pickle.load(f) + cache_data = np.load(f, allow_pickle=False) logger.info(f"Cache hit: user={user_id}, hash={ref_audio_hash[:8]}..., access_count={cache_entry.access_count}") return { @@ -84,8 +89,13 @@ class VoiceCacheManager: logger.warning(f"Cache file missing: {cache_file}") return None + resolved_cache_file = cache_file.resolve() + if not resolved_cache_file.is_relative_to(self.cache_dir.resolve()): + logger.warning(f"Cache path out of cache dir: {resolved_cache_file}") + return None + with open(cache_file, 'rb') as f: - cache_data = pickle.load(f) + cache_data = np.load(f, allow_pickle=False) cache_entry.last_accessed = datetime.utcnow() cache_entry.access_count += 1 @@ -112,11 +122,16 @@ class VoiceCacheManager: ) -> str: async with self._lock: try: - cache_filename = f"{user_id}_{ref_audio_hash}.pkl" + cache_filename = f"{user_id}_{ref_audio_hash}.npy" cache_path = self.cache_dir / cache_filename + if hasattr(cache_data, "detach"): + cache_data = cache_data.detach().cpu().numpy() + elif not isinstance(cache_data, np.ndarray): + cache_data = np.asarray(cache_data, dtype=np.float32) + with open(cache_path, 'wb') as f: - pickle.dump(cache_data, f) + np.save(f, cache_data, allow_pickle=False) cache_entry = create_cache_entry( db=db, diff --git a/qwen3-tts-backend/core/cleanup.py b/qwen3-tts-backend/core/cleanup.py index 8c917fc..cf82a01 100644 --- a/qwen3-tts-backend/core/cleanup.py +++ b/qwen3-tts-backend/core/cleanup.py @@ -54,9 +54,17 @@ async def cleanup_old_jobs(db_url: str, days: int = 7) -> dict: ).all() deleted_files = 0 + output_dir = Path(settings.OUTPUT_DIR).resolve() for job in old_jobs: if job.output_path: - output_file = Path(job.output_path) + output_file = Path(job.output_path).resolve() + if not output_file.is_relative_to(output_dir): + logger.warning(f"Skip deleting file outside output dir during cleanup: {output_file}") + output_file = None + else: + output_file = None + + if output_file: if output_file.exists(): output_file.unlink() deleted_files += 1 @@ -110,12 +118,13 @@ async def cleanup_orphaned_files(db_url: str) -> dict: freed_space_bytes += size if cache_dir.exists(): - for cache_file in cache_dir.glob("*.pkl"): - if cache_file.name not in cache_files_in_db: - size = cache_file.stat().st_size - cache_file.unlink() - deleted_orphans += 1 - freed_space_bytes += size + for pattern in ("*.npy", "*.pkl"): + for cache_file in cache_dir.glob(pattern): + if cache_file.name not in cache_files_in_db: + size = cache_file.stat().st_size + cache_file.unlink() + deleted_orphans += 1 + freed_space_bytes += size freed_space_mb = freed_space_bytes / (1024 * 1024) diff --git a/qwen3-tts-backend/db/crud.py b/qwen3-tts-backend/db/crud.py index df615de..442aa2b 100644 --- a/qwen3-tts-backend/db/crud.py +++ b/qwen3-tts-backend/db/crud.py @@ -288,7 +288,7 @@ def create_voice_design( backend_type=backend_type, instruct=instruct, aliyun_voice_id=aliyun_voice_id, - meta_data=json.dumps(meta_data) if meta_data else None, + meta_data=meta_data, preview_text=preview_text, created_at=datetime.utcnow(), last_used=datetime.utcnow() @@ -320,6 +320,19 @@ def list_voice_designs( query = query.filter(VoiceDesign.backend_type == backend_type) return query.order_by(VoiceDesign.last_used.desc()).offset(skip).limit(limit).all() +def count_voice_designs( + db: Session, + user_id: int, + backend_type: Optional[str] = None +) -> int: + query = db.query(VoiceDesign).filter( + VoiceDesign.user_id == user_id, + VoiceDesign.is_active == True + ) + if backend_type: + query = query.filter(VoiceDesign.backend_type == backend_type) + return query.count() + def update_voice_design_usage(db: Session, design_id: int, user_id: int) -> Optional[VoiceDesign]: design = get_voice_design(db, design_id, user_id) if design: @@ -328,4 +341,3 @@ def update_voice_design_usage(db: Session, design_id: int, user_id: int) -> Opti db.commit() db.refresh(design) return design - diff --git a/qwen3-tts-backend/main.py b/qwen3-tts-backend/main.py index f1b230f..085dec7 100644 --- a/qwen3-tts-backend/main.py +++ b/qwen3-tts-backend/main.py @@ -4,7 +4,7 @@ from contextlib import asynccontextmanager from pathlib import Path import torch -from fastapi import FastAPI, Request +from fastapi import FastAPI, Request, Depends, HTTPException from fastapi.middleware.cors import CORSMiddleware from slowapi import Limiter, _rate_limit_exceeded_handler from slowapi.util import get_remote_address @@ -16,6 +16,8 @@ from core.database import init_db from core.model_manager import ModelManager from core.cleanup import run_scheduled_cleanup from api import auth, jobs, tts, users, voice_designs +from api.auth import get_current_user +from schemas.user import User from apscheduler.schedulers.asyncio import AsyncIOScheduler logging.basicConfig( @@ -134,6 +136,14 @@ app.include_router(voice_designs.router) @app.get("/health") async def health_check(): + return {"status": "ok"} + + +@app.get("/health/details") +async def health_check_details(current_user: User = Depends(get_current_user)): + if not current_user.is_superuser: + raise HTTPException(status_code=403, detail="Superuser access required") + from core.batch_processor import BatchProcessor from core.database import SessionLocal diff --git a/qwen3-tts-frontend/.env.example b/qwen3-tts-frontend/.env.example index ec52b11..a17c527 100644 --- a/qwen3-tts-frontend/.env.example +++ b/qwen3-tts-frontend/.env.example @@ -1,2 +1,2 @@ -VITE_API_URL=http://localhost:8000 +VITE_API_URL=/api VITE_APP_NAME=Qwen3-TTS diff --git a/qwen3-tts-frontend/src/components/tts/CustomVoiceForm.tsx b/qwen3-tts-frontend/src/components/tts/CustomVoiceForm.tsx index 4370937..427ee8a 100644 --- a/qwen3-tts-frontend/src/components/tts/CustomVoiceForm.tsx +++ b/qwen3-tts-frontend/src/components/tts/CustomVoiceForm.tsx @@ -47,15 +47,15 @@ const CustomVoiceForm = forwardRef((_props, ref) => { const PRESET_INSTRUCTS = useMemo(() => tConstants('presetInstructs', { returnObjects: true }) as Array<{ label: string; instruct: string; text: string }>, [tConstants]) const formSchema = z.object({ - text: z.string().min(1, tErrors('validation.required', { field: tErrors('fieldNames.text') })).max(5000, tErrors('validation.maxLength', { field: tErrors('fieldNames.text'), max: 5000 })), + text: z.string().min(1, tErrors('validation.required', { field: tErrors('fieldNames.text') })).max(1000, tErrors('validation.maxLength', { field: tErrors('fieldNames.text'), max: 1000 })), language: z.string().min(1, tErrors('validation.required', { field: tErrors('fieldNames.language') })), speaker: z.string().min(1, tErrors('validation.required', { field: tErrors('fieldNames.speaker') })), instruct: z.string().optional(), - max_new_tokens: z.number().min(1).max(10000).optional(), - temperature: z.number().min(0).max(2).optional(), + max_new_tokens: z.number().min(128).max(4096).optional(), + temperature: z.number().min(0.1).max(2).optional(), top_k: z.number().min(1).max(100).optional(), top_p: z.number().min(0).max(1).optional(), - repetition_penalty: z.number().min(0).max(2).optional(), + repetition_penalty: z.number().min(1).max(2).optional(), }) const [languages, setLanguages] = useState([]) const [unifiedSpeakers, setUnifiedSpeakers] = useState([]) @@ -395,8 +395,8 @@ const CustomVoiceForm = forwardRef((_props, ref) => { setTempAdvancedParams({ ...tempAdvancedParams, @@ -414,7 +414,7 @@ const CustomVoiceForm = forwardRef((_props, ref) => { tConstants('presetRefTexts', { returnObjects: true }) as Array<{ label: string; text: string }>, [tConstants]) const formSchema = z.object({ - text: z.string().min(1, tErrors('validation.required', { field: tErrors('fieldNames.text') })).max(5000, tErrors('validation.maxLength', { field: tErrors('fieldNames.text'), max: 5000 })), + text: z.string().min(1, tErrors('validation.required', { field: tErrors('fieldNames.text') })).max(1000, tErrors('validation.maxLength', { field: tErrors('fieldNames.text'), max: 1000 })), language: z.string().optional(), ref_audio: z.instanceof(File, { message: tErrors('validation.required', { field: tErrors('fieldNames.reference_audio') }) }), ref_text: z.string().optional(), use_cache: z.boolean().optional(), x_vector_only_mode: z.boolean().optional(), - max_new_tokens: z.number().min(1).max(10000).optional(), - temperature: z.number().min(0).max(2).optional(), + max_new_tokens: z.number().min(128).max(4096).optional(), + temperature: z.number().min(0.1).max(2).optional(), top_k: z.number().min(1).max(100).optional(), top_p: z.number().min(0).max(1).optional(), - repetition_penalty: z.number().min(0).max(2).optional(), + repetition_penalty: z.number().min(1).max(2).optional(), }) const [languages, setLanguages] = useState([]) const [isLoading, setIsLoading] = useState(false) @@ -358,8 +358,8 @@ function VoiceCloneForm() { setTempAdvancedParams({ ...tempAdvancedParams, diff --git a/qwen3-tts-frontend/src/components/tts/VoiceDesignForm.tsx b/qwen3-tts-frontend/src/components/tts/VoiceDesignForm.tsx index ad306a3..251cd4e 100644 --- a/qwen3-tts-frontend/src/components/tts/VoiceDesignForm.tsx +++ b/qwen3-tts-frontend/src/components/tts/VoiceDesignForm.tsx @@ -46,14 +46,14 @@ const VoiceDesignForm = forwardRef((_props, ref) => { const PRESET_VOICE_DESIGNS = useMemo(() => tConstants('presetVoiceDesigns', { returnObjects: true }) as Array<{ label: string; instruct: string; text: string }>, [tConstants]) const formSchema = z.object({ - text: z.string().min(1, tErrors('validation.required', { field: tErrors('fieldNames.text') })).max(5000, tErrors('validation.maxLength', { field: tErrors('fieldNames.text'), max: 5000 })), + text: z.string().min(1, tErrors('validation.required', { field: tErrors('fieldNames.text') })).max(1000, tErrors('validation.maxLength', { field: tErrors('fieldNames.text'), max: 1000 })), language: z.string().min(1, tErrors('validation.required', { field: tErrors('fieldNames.language') })), instruct: z.string().min(10, tErrors('validation.minLength', { field: tErrors('fieldNames.instruct'), min: 10 })).max(500, tErrors('validation.maxLength', { field: tErrors('fieldNames.instruct'), max: 500 })), - max_new_tokens: z.number().min(1).max(10000).optional(), - temperature: z.number().min(0).max(2).optional(), + max_new_tokens: z.number().min(128).max(4096).optional(), + temperature: z.number().min(0.1).max(2).optional(), top_k: z.number().min(1).max(100).optional(), top_p: z.number().min(0).max(1).optional(), - repetition_penalty: z.number().min(0).max(2).optional(), + repetition_penalty: z.number().min(1).max(2).optional(), }) const [languages, setLanguages] = useState([]) const [isLoading, setIsLoading] = useState(false) @@ -310,8 +310,8 @@ const VoiceDesignForm = forwardRef((_props, ref) => { setTempAdvancedParams({ ...tempAdvancedParams, @@ -329,7 +329,7 @@ const VoiceDesignForm = forwardRef((_props, ref) => { | null>(null) + const timeIntervalRef = useRef | null>(null) + + const clearIntervals = useCallback(() => { + if (pollIntervalRef.current) { + clearInterval(pollIntervalRef.current) + pollIntervalRef.current = null + } + if (timeIntervalRef.current) { + clearInterval(timeIntervalRef.current) + timeIntervalRef.current = null + } + }, []) const stopJob = useCallback(() => { + clearIntervals() setCurrentJob(null) setStatus(null) setError(null) setElapsedTime(0) - }, []) + }, [clearIntervals]) const resetJob = useCallback(() => { setError(null) @@ -45,15 +59,13 @@ export function JobProvider({ children }: { children: ReactNode }) { }, []) const startJob = useCallback((jobId: number) => { + clearIntervals() // Reset state for new job setCurrentJob(null) setStatus('pending') setError(null) setElapsedTime(0) - let pollInterval: ReturnType | null = null - let timeInterval: ReturnType | null = null - const poll = async () => { try { const job = await jobApi.getJob(jobId) @@ -61,15 +73,13 @@ export function JobProvider({ children }: { children: ReactNode }) { setStatus(job.status) if (job.status === 'completed') { - if (pollInterval) clearInterval(pollInterval) - if (timeInterval) clearInterval(timeInterval) + clearIntervals() toast.success('任务完成!') try { historyRefresh() } catch {} } else if (job.status === 'failed') { - if (pollInterval) clearInterval(pollInterval) - if (timeInterval) clearInterval(timeInterval) + clearIntervals() setError(job.error_message || '任务失败') toast.error(job.error_message || '任务失败') try { @@ -77,8 +87,7 @@ export function JobProvider({ children }: { children: ReactNode }) { } catch {} } } catch (error: any) { - if (pollInterval) clearInterval(pollInterval) - if (timeInterval) clearInterval(timeInterval) + clearIntervals() const message = error.response?.data?.detail || '获取任务状态失败' setError(message) toast.error(message) @@ -86,16 +95,17 @@ export function JobProvider({ children }: { children: ReactNode }) { } poll() - pollInterval = setInterval(poll, POLL_INTERVAL) - timeInterval = setInterval(() => { + pollIntervalRef.current = setInterval(poll, POLL_INTERVAL) + timeIntervalRef.current = setInterval(() => { setElapsedTime((prev) => prev + 1) }, 1000) + }, [historyRefresh, clearIntervals]) + useEffect(() => { return () => { - if (pollInterval) clearInterval(pollInterval) - if (timeInterval) clearInterval(timeInterval) + clearIntervals() } - }, [historyRefresh]) + }, [clearIntervals]) const value = useMemo( () => ({ diff --git a/qwen3-tts-frontend/src/lib/api.ts b/qwen3-tts-frontend/src/lib/api.ts index 17c8fee..aea6a1f 100644 --- a/qwen3-tts-frontend/src/lib/api.ts +++ b/qwen3-tts-frontend/src/lib/api.ts @@ -13,9 +13,22 @@ const apiClient = axios.create({ }, }) +const isTrustedApiRequest = (url?: string, baseURL?: string): boolean => { + if (!url) return false + if (url.startsWith('/')) return true + + try { + const resolvedUrl = new URL(url, baseURL || window.location.origin) + const apiOrigin = baseURL ? new URL(baseURL, window.location.origin).origin : window.location.origin + return resolvedUrl.origin === apiOrigin + } catch { + return false + } +} + apiClient.interceptors.request.use((config) => { const token = localStorage.getItem('token') - if (token) { + if (token && isTrustedApiRequest(config.url, config.baseURL || import.meta.env.VITE_API_URL)) { config.headers.Authorization = `Bearer ${token}` } return config @@ -346,11 +359,23 @@ export const jobApi = { getAudioUrl: (id: number, audioPath?: string): string => { if (audioPath) { if (audioPath.startsWith('http')) { + const apiBase = import.meta.env.VITE_API_URL + if (apiBase) { + try { + const audioOrigin = new URL(audioPath).origin + const apiOrigin = new URL(apiBase, window.location.origin).origin + if (audioOrigin !== apiOrigin) { + return API_ENDPOINTS.JOBS.AUDIO(id) + } + } catch { + return API_ENDPOINTS.JOBS.AUDIO(id) + } + } if (audioPath.includes('localhost') || audioPath.includes('127.0.0.1')) { const url = new URL(audioPath) return url.pathname } - return audioPath + return API_ENDPOINTS.JOBS.AUDIO(id) } else { return audioPath.startsWith('/') ? audioPath : `/${audioPath}` } diff --git a/qwen3-tts-frontend/vite.config.ts b/qwen3-tts-frontend/vite.config.ts index cbe423c..6ef3387 100644 --- a/qwen3-tts-frontend/vite.config.ts +++ b/qwen3-tts-frontend/vite.config.ts @@ -4,6 +4,15 @@ import { defineConfig } from "vite" export default defineConfig({ plugins: [react()], + server: { + proxy: { + '/api': { + target: 'http://127.0.0.1:8000', + changeOrigin: true, + rewrite: (path) => path.replace(/^\/api/, ''), + }, + }, + }, resolve: { alias: { "@": path.resolve(__dirname, "./src"),