init commit

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
This commit is contained in:
2026-01-26 15:34:31 +08:00
commit 80513a3258
141 changed files with 24966 additions and 0 deletions

28
.gitignore vendored Normal file
View File

@@ -0,0 +1,28 @@
__pycache__/
*.py[cod]
*$py.class
*.so
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
.idea/
.vscode/
venv/
env/
Qwen/
qwen3-tts-frontend/node_modules/
qwen3-tts-frontend/dist/
qwen3-tts-frontend/.env
qwen3-tts-frontend/.env.local

76
README.md Normal file
View File

@@ -0,0 +1,76 @@
# Qwen3-TTS WebUI
A text-to-speech web application based on Qwen3-TTS, supporting custom voice, voice design, and voice cloning.
[中文文档](./README.zh.md)
## Features
- Custom Voice: Predefined speaker voices
- Voice Design: Create voices from natural language descriptions
- Voice Cloning: Clone voices from uploaded audio
- JWT auth, async tasks, voice cache, dark mode
## Tech Stack
Backend: FastAPI + SQLAlchemy + PyTorch + JWT
Frontend: React 19 + TypeScript + Vite + Tailwind + Shadcn/ui
## Quick Start
### Backend
```bash
cd qwen3-tts-backend
python -m venv venv
source venv/bin/activate
pip install -r requirements.txt
cp .env.example .env
# Edit .env to configure MODEL_BASE_PATH etc.
uvicorn main:app --host 0.0.0.0 --port 8000 --reload
```
### Frontend
```bash
cd qwen3-tts-frontend
npm install
cp .env.example .env
# Edit .env to configure VITE_API_URL
npm run dev
```
Visit `http://localhost:5173`
## Configuration
Backend `.env` key settings:
```env
SECRET_KEY=your-secret-key
MODEL_DEVICE=cuda:0
MODEL_BASE_PATH=../Qwen
DATABASE_URL=sqlite:///./qwen_tts.db
```
Frontend `.env`:
```env
VITE_API_URL=http://localhost:8000
```
## API
```
POST /auth/register - Register
POST /auth/token - Login
POST /tts/custom-voice - Custom voice
POST /tts/voice-design - Voice design
POST /tts/voice-clone - Voice cloning
GET /jobs - Job list
GET /jobs/{id}/download - Download result
```
## License
MIT

76
README.zh.md Normal file
View File

@@ -0,0 +1,76 @@
# Qwen3-TTS WebUI
基于 Qwen3-TTS 的文本转语音 Web 应用,支持自定义语音、语音设计和语音克隆。
[English Documentation](./README.md)
## 功能特性
- 自定义语音:预定义说话人语音
- 语音设计:自然语言描述创建语音
- 语音克隆:上传音频克隆语音
- JWT 认证、异步任务、语音缓存、暗黑模式
## 技术栈
后端FastAPI + SQLAlchemy + PyTorch + JWT
前端React 19 + TypeScript + Vite + Tailwind + Shadcn/ui
## 快速开始
### 后端
```bash
cd qwen3-tts-backend
python -m venv venv
source venv/bin/activate
pip install -r requirements.txt
cp .env.example .env
# 编辑 .env 配置 MODEL_BASE_PATH 等
uvicorn main:app --host 0.0.0.0 --port 8000 --reload
```
### 前端
```bash
cd qwen3-tts-frontend
npm install
cp .env.example .env
# 编辑 .env 配置 VITE_API_URL
npm run dev
```
访问 `http://localhost:5173`
## 配置
后端 `.env` 关键配置:
```env
SECRET_KEY=your-secret-key
MODEL_DEVICE=cuda:0
MODEL_BASE_PATH=../Qwen
DATABASE_URL=sqlite:///./qwen_tts.db
```
前端 `.env`
```env
VITE_API_URL=http://localhost:8000
```
## API
```
POST /auth/register - 注册
POST /auth/token - 登录
POST /tts/custom-voice - 自定义语音
POST /tts/voice-design - 语音设计
POST /tts/voice-clone - 语音克隆
GET /jobs - 任务列表
GET /jobs/{id}/download - 下载结果
```
## 许可证
MIT

View File

@@ -0,0 +1,22 @@
SECRET_KEY=your-secret-key-change-this-in-production
ALGORITHM=HS256
ACCESS_TOKEN_EXPIRE_MINUTES=30
DATABASE_URL=sqlite:///./qwen_tts.db
CACHE_DIR=./voice_cache
OUTPUT_DIR=./outputs
MODEL_DEVICE=cuda:0
MODEL_BASE_PATH=../Qwen
MAX_CACHE_ENTRIES=100
CACHE_TTL_DAYS=7
HOST=0.0.0.0
PORT=8000
WORKERS=1
LOG_LEVEL=info
LOG_FILE=./app.log
RATE_LIMIT_PER_MINUTE=50
RATE_LIMIT_PER_HOUR=1000
MAX_QUEUE_SIZE=100
BATCH_SIZE=4
BATCH_WAIT_TIME=0.5
MAX_TEXT_LENGTH=1000
MAX_AUDIO_SIZE_MB=10

11
qwen3-tts-backend/.gitignore vendored Normal file
View File

@@ -0,0 +1,11 @@
.env
*.pyc
__pycache__/
*.log
qwen_tts.db
voice_cache/
outputs/
venv/
.pytest_cache/
.coverage
htmlcov/

View File

View File

@@ -0,0 +1,107 @@
from datetime import timedelta
from typing import Annotated
from fastapi import APIRouter, Depends, HTTPException, status, Request
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from sqlalchemy.orm import Session
from slowapi import Limiter
from slowapi.util import get_remote_address
from config import settings
from core.security import (
get_password_hash,
verify_password,
create_access_token,
decode_access_token
)
from db.database import get_db
from db.crud import get_user_by_username, get_user_by_email, create_user
from schemas.user import User, UserCreate, Token
router = APIRouter(prefix="/auth", tags=["authentication"])
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/auth/token")
limiter = Limiter(key_func=get_remote_address)
async def get_current_user(
token: Annotated[str, Depends(oauth2_scheme)],
db: Session = Depends(get_db)
) -> User:
credentials_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Could not validate credentials",
headers={"WWW-Authenticate": "Bearer"},
)
username = decode_access_token(token)
if username is None:
raise credentials_exception
user = get_user_by_username(db, username=username)
if user is None:
raise credentials_exception
return user
@router.post("/register", response_model=User, status_code=status.HTTP_201_CREATED)
@limiter.limit("5/minute")
async def register(request: Request, user_data: UserCreate, db: Session = Depends(get_db)):
existing_user = get_user_by_username(db, username=user_data.username)
if existing_user:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Username already registered"
)
existing_email = get_user_by_email(db, email=user_data.email)
if existing_email:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Email already registered"
)
hashed_password = get_password_hash(user_data.password)
user = create_user(
db,
username=user_data.username,
email=user_data.email,
hashed_password=hashed_password
)
return user
@router.post("/token", response_model=Token)
@limiter.limit("5/minute")
async def login(
request: Request,
form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
db: Session = Depends(get_db)
):
user = get_user_by_username(db, username=form_data.username)
if not user or not verify_password(form_data.password, user.hashed_password):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Incorrect username or password",
headers={"WWW-Authenticate": "Bearer"},
)
if not user.is_active:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Inactive user"
)
access_token_expires = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
access_token = create_access_token(
data={"sub": user.username}, expires_delta=access_token_expires
)
return {"access_token": access_token, "token_type": "bearer"}
@router.get("/me", response_model=User)
@limiter.limit("30/minute")
async def get_current_user_info(
request: Request,
current_user: Annotated[User, Depends(get_current_user)]
):
return current_user

View File

@@ -0,0 +1,156 @@
import logging
import json
from pathlib import Path
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Request
from sqlalchemy.orm import Session
from slowapi import Limiter
from slowapi.util import get_remote_address
from core.config import settings
from core.database import get_db
from core.cache_manager import VoiceCacheManager
from api.auth import get_current_user
from db.crud import list_cache_entries, delete_cache_entry
from db.models import VoiceCache, User
from utils.metrics import cache_metrics
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/cache", tags=["cache"])
limiter = Limiter(key_func=get_remote_address)
@router.get("/voices")
@limiter.limit("30/minute")
async def list_user_caches(
request: Request,
skip: int = 0,
limit: int = 100,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
caches = list_cache_entries(db, current_user.id, skip=skip, limit=limit)
result = []
for cache in caches:
meta_data = json.loads(cache.meta_data) if cache.meta_data else {}
cache_file = Path(cache.cache_path)
file_size_mb = cache_file.stat().st_size / (1024 * 1024) if cache_file.exists() else 0
result.append({
'id': cache.id,
'ref_audio_hash': cache.ref_audio_hash,
'created_at': cache.created_at.isoformat(),
'last_accessed': cache.last_accessed.isoformat(),
'access_count': cache.access_count,
'metadata': meta_data,
'size_mb': round(file_size_mb, 2)
})
return {
'caches': result,
'total': len(result)
}
@router.delete("/voices/{cache_id}")
@limiter.limit("30/minute")
async def delete_user_cache(
request: Request,
cache_id: int,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
cache = db.query(VoiceCache).filter(
VoiceCache.id == cache_id,
VoiceCache.user_id == current_user.id
).first()
if not cache:
raise HTTPException(status_code=404, detail="Cache not found")
cache_file = Path(cache.cache_path)
if cache_file.exists():
cache_file.unlink()
success = delete_cache_entry(db, cache_id, current_user.id)
if not success:
raise HTTPException(status_code=500, detail="Failed to delete cache")
logger.info(f"Cache deleted: id={cache_id}, user={current_user.id}")
return {
'message': 'Cache deleted successfully',
'cache_id': cache_id
}
@router.delete("/voices")
@limiter.limit("10/minute")
async def cleanup_expired_caches(
request: Request,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
cache_manager = await VoiceCacheManager.get_instance()
deleted_count = await cache_manager.cleanup_expired(db)
logger.info(f"Expired cache cleanup: user={current_user.id}, deleted={deleted_count}")
return {
'message': 'Expired caches cleaned up',
'deleted_count': deleted_count
}
@router.post("/voices/prune")
@limiter.limit("10/minute")
async def prune_caches(
request: Request,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
cache_manager = await VoiceCacheManager.get_instance()
deleted_count = await cache_manager.enforce_max_entries(current_user.id, db)
logger.info(f"LRU prune: user={current_user.id}, deleted={deleted_count}")
return {
'message': 'LRU pruning completed',
'deleted_count': deleted_count
}
@router.get("/stats")
@limiter.limit("30/minute")
async def get_cache_stats(
request: Request,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
stats = cache_metrics.get_stats(db, settings.CACHE_DIR)
user_stats = None
for user_stat in stats['users']:
if user_stat['user_id'] == current_user.id:
user_stats = user_stat
break
if user_stats is None:
user_cache_count = db.query(VoiceCache).filter(
VoiceCache.user_id == current_user.id
).count()
user_stats = {
'user_id': current_user.id,
'hits': 0,
'misses': 0,
'hit_rate': 0.0,
'cache_entries': user_cache_count
}
return {
'global': stats['global'],
'user': user_stats
}

View File

@@ -0,0 +1,176 @@
import logging
from pathlib import Path
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Query, Request
from fastapi.responses import FileResponse
from sqlalchemy.orm import Session
from slowapi import Limiter
from slowapi.util import get_remote_address
from core.database import get_db
from core.config import settings
from db.models import Job, JobStatus, User
from api.auth import get_current_user
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/jobs", tags=["jobs"])
limiter = Limiter(key_func=get_remote_address)
@router.get("/{job_id}")
@limiter.limit("30/minute")
async def get_job(
request: Request,
job_id: int,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
job = db.query(Job).filter(Job.id == job_id).first()
if not job:
raise HTTPException(status_code=404, detail="Job not found")
if job.user_id != current_user.id:
raise HTTPException(status_code=403, detail="Access denied")
download_url = None
if job.status == JobStatus.COMPLETED and job.output_path:
output_file = Path(job.output_path)
if output_file.exists():
download_url = f"{settings.BASE_URL}/jobs/{job.id}/download"
return {
"id": job.id,
"job_type": job.job_type,
"status": job.status,
"input_params": job.input_params,
"output_path": job.output_path,
"download_url": download_url,
"error_message": job.error_message,
"created_at": job.created_at.isoformat() + 'Z' if job.created_at else None,
"started_at": job.started_at.isoformat() + 'Z' if job.started_at else None,
"completed_at": job.completed_at.isoformat() + 'Z' if job.completed_at else None
}
@router.get("")
@limiter.limit("30/minute")
async def list_jobs(
request: Request,
skip: int = Query(0, ge=0),
limit: int = Query(100, ge=1, le=100),
status: Optional[str] = Query(None),
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
query = db.query(Job).filter(Job.user_id == current_user.id)
if status:
try:
status_enum = JobStatus(status)
query = query.filter(Job.status == status_enum)
except ValueError:
raise HTTPException(status_code=400, detail=f"Invalid status: {status}")
total = query.count()
jobs = query.order_by(Job.created_at.desc()).offset(skip).limit(limit).all()
jobs_data = []
for job in jobs:
download_url = None
if job.status == JobStatus.COMPLETED and job.output_path:
output_file = Path(job.output_path)
if output_file.exists():
download_url = f"{settings.BASE_URL}/jobs/{job.id}/download"
jobs_data.append({
"id": job.id,
"job_type": job.job_type,
"status": job.status,
"input_params": job.input_params,
"output_path": job.output_path,
"download_url": download_url,
"error_message": job.error_message,
"created_at": job.created_at.isoformat() + 'Z' if job.created_at else None,
"completed_at": job.completed_at.isoformat() + 'Z' if job.completed_at else None
})
return {
"total": total,
"skip": skip,
"limit": limit,
"jobs": jobs_data
}
@router.delete("/{job_id}")
@limiter.limit("30/minute")
async def delete_job(
request: Request,
job_id: int,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
job = db.query(Job).filter(Job.id == job_id).first()
if not job:
raise HTTPException(status_code=404, detail="Job not found")
if job.user_id != current_user.id:
raise HTTPException(status_code=403, detail="Access denied")
if job.output_path:
output_file = Path(job.output_path)
if output_file.exists():
try:
output_file.unlink()
logger.info(f"Deleted output file: {output_file}")
except Exception as e:
logger.error(f"Failed to delete output file {output_file}: {e}")
db.delete(job)
db.commit()
return {"message": "Job deleted successfully"}
@router.get("/{job_id}/download")
@limiter.limit("30/minute")
async def download_job_output(
request: Request,
job_id: int,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
job = db.query(Job).filter(Job.id == job_id).first()
if not job:
raise HTTPException(status_code=404, detail="Job not found")
if job.user_id != current_user.id:
raise HTTPException(status_code=403, detail="Access denied")
if job.status != JobStatus.COMPLETED:
raise HTTPException(status_code=400, detail="Job not completed yet")
if not job.output_path:
raise HTTPException(status_code=404, detail="Output file not found")
output_file = Path(job.output_path)
if not output_file.exists():
raise HTTPException(status_code=404, detail="Output file does not exist")
output_dir = Path(settings.OUTPUT_DIR).resolve()
if not output_file.resolve().is_relative_to(output_dir):
logger.warning(f"Path traversal attempt detected: {output_file}")
raise HTTPException(status_code=403, detail="Access denied")
return FileResponse(
path=str(output_file),
media_type="audio/wav",
filename=output_file.name,
headers={
"Content-Disposition": f'attachment; filename="{output_file.name}"'
}
)

View File

@@ -0,0 +1,21 @@
import logging
from fastapi import APIRouter
from core.metrics import MetricsCollector
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/metrics", tags=["metrics"])
@router.get("")
async def get_metrics():
metrics = await MetricsCollector.get_instance()
data = await metrics.get_metrics()
return data
@router.post("/reset")
async def reset_metrics():
metrics = await MetricsCollector.get_instance()
await metrics.reset()
return {"message": "Metrics reset successfully"}

View File

@@ -0,0 +1,553 @@
import logging
import tempfile
from datetime import datetime
from pathlib import Path
from fastapi import APIRouter, Depends, HTTPException, BackgroundTasks, UploadFile, File, Form, Request
from sqlalchemy.orm import Session
from typing import Optional
from slowapi import Limiter
from slowapi.util import get_remote_address
from core.config import settings
from core.database import get_db
from core.model_manager import ModelManager
from core.cache_manager import VoiceCacheManager
from db.models import Job, JobStatus, User
from schemas.tts import CustomVoiceRequest, VoiceDesignRequest
from api.auth import get_current_user
from utils.validation import (
validate_language,
validate_speaker,
validate_text_length,
validate_generation_params,
get_supported_languages,
get_supported_speakers
)
from utils.audio import save_audio_file, validate_ref_audio, process_ref_audio, extract_audio_features
from utils.metrics import cache_metrics
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/tts", tags=["tts"])
limiter = Limiter(key_func=get_remote_address)
async def process_custom_voice_job(
job_id: int,
user_id: int,
request_data: dict,
db_url: str
):
from core.database import SessionLocal
db = SessionLocal()
try:
job = db.query(Job).filter(Job.id == job_id).first()
if not job:
logger.error(f"Job {job_id} not found")
return
job.status = JobStatus.PROCESSING
job.started_at = datetime.utcnow()
db.commit()
logger.info(f"Processing custom-voice job {job_id}")
model_manager = await ModelManager.get_instance()
await model_manager.load_model("custom-voice")
_, tts = await model_manager.get_current_model()
if tts is None:
raise RuntimeError("Failed to load custom-voice model")
result = tts.generate_custom_voice(
text=request_data['text'],
language=request_data['language'],
speaker=request_data['speaker'],
instruct=request_data.get('instruct', ''),
max_new_tokens=request_data['max_new_tokens'],
temperature=request_data['temperature'],
top_k=request_data['top_k'],
top_p=request_data['top_p'],
repetition_penalty=request_data['repetition_penalty']
)
import numpy as np
if isinstance(result, tuple):
audio_data = result[0]
elif isinstance(result, list):
audio_data = np.array(result)
else:
audio_data = result
from pathlib import Path
timestamp = datetime.utcnow().strftime("%Y%m%d_%H%M%S")
filename = f"{user_id}_{job_id}_{timestamp}.wav"
output_path = Path(settings.OUTPUT_DIR) / filename
save_audio_file(audio_data, 24000, output_path)
job.status = JobStatus.COMPLETED
job.output_path = str(output_path)
job.completed_at = datetime.utcnow()
db.commit()
logger.info(f"Job {job_id} completed successfully")
except Exception as e:
logger.error(f"Job {job_id} failed: {e}", exc_info=True)
job = db.query(Job).filter(Job.id == job_id).first()
if job:
job.status = JobStatus.FAILED
job.error_message = str(e)
job.completed_at = datetime.utcnow()
db.commit()
finally:
db.close()
async def process_voice_design_job(
job_id: int,
user_id: int,
request_data: dict,
db_url: str
):
from core.database import SessionLocal
db = SessionLocal()
try:
job = db.query(Job).filter(Job.id == job_id).first()
if not job:
logger.error(f"Job {job_id} not found")
return
job.status = JobStatus.PROCESSING
job.started_at = datetime.utcnow()
db.commit()
logger.info(f"Processing voice-design job {job_id}")
model_manager = await ModelManager.get_instance()
await model_manager.load_model("voice-design")
_, tts = await model_manager.get_current_model()
if tts is None:
raise RuntimeError("Failed to load voice-design model")
result = tts.generate_voice_design(
text=request_data['text'],
language=request_data['language'],
instruct=request_data['instruct'],
max_new_tokens=request_data['max_new_tokens'],
temperature=request_data['temperature'],
top_k=request_data['top_k'],
top_p=request_data['top_p'],
repetition_penalty=request_data['repetition_penalty']
)
import numpy as np
if isinstance(result, tuple):
audio_data = result[0]
elif isinstance(result, list):
audio_data = np.array(result)
else:
audio_data = result
from pathlib import Path
timestamp = datetime.utcnow().strftime("%Y%m%d_%H%M%S")
filename = f"{user_id}_{job_id}_{timestamp}.wav"
output_path = Path(settings.OUTPUT_DIR) / filename
save_audio_file(audio_data, 24000, output_path)
job.status = JobStatus.COMPLETED
job.output_path = str(output_path)
job.completed_at = datetime.utcnow()
db.commit()
logger.info(f"Job {job_id} completed successfully")
except Exception as e:
logger.error(f"Job {job_id} failed: {e}", exc_info=True)
job = db.query(Job).filter(Job.id == job_id).first()
if job:
job.status = JobStatus.FAILED
job.error_message = str(e)
job.completed_at = datetime.utcnow()
db.commit()
finally:
db.close()
async def process_voice_clone_job(
job_id: int,
user_id: int,
request_data: dict,
ref_audio_path: str,
db_url: str
):
from core.database import SessionLocal
import numpy as np
db = SessionLocal()
try:
job = db.query(Job).filter(Job.id == job_id).first()
if not job:
logger.error(f"Job {job_id} not found")
return
job.status = JobStatus.PROCESSING
job.started_at = datetime.utcnow()
db.commit()
logger.info(f"Processing voice-clone job {job_id}")
with open(ref_audio_path, 'rb') as f:
ref_audio_data = f.read()
cache_manager = await VoiceCacheManager.get_instance()
ref_audio_hash = cache_manager.get_audio_hash(ref_audio_data)
x_vector = None
cache_id = None
if request_data.get('use_cache', True):
cached = await cache_manager.get_cache(user_id, ref_audio_hash, db)
if cached:
x_vector = cached['data']
cache_id = cached['cache_id']
cache_metrics.record_hit(user_id)
logger.info(f"Cache hit for job {job_id}, cache_id={cache_id}")
if x_vector is None:
cache_metrics.record_miss(user_id)
logger.info(f"Cache miss for job {job_id}, creating voice clone prompt")
ref_audio_array, ref_sr = process_ref_audio(ref_audio_data)
model_manager = await ModelManager.get_instance()
await model_manager.load_model("base")
_, tts = await model_manager.get_current_model()
if tts is None:
raise RuntimeError("Failed to load base model")
x_vector = tts.create_voice_clone_prompt(
ref_audio=(ref_audio_array, ref_sr),
ref_text=request_data.get('ref_text', ''),
x_vector_only_mode=request_data.get('x_vector_only_mode', False)
)
if request_data.get('use_cache', True):
features = extract_audio_features(ref_audio_array, ref_sr)
metadata = {
'duration': features['duration'],
'sample_rate': features['sample_rate'],
'ref_text': request_data.get('ref_text', ''),
'x_vector_only_mode': request_data.get('x_vector_only_mode', False)
}
cache_id = await cache_manager.set_cache(
user_id, ref_audio_hash, x_vector, metadata, db
)
logger.info(f"Created cache for job {job_id}, cache_id={cache_id}")
if request_data.get('x_vector_only_mode', False):
job.status = JobStatus.COMPLETED
job.output_path = f"x_vector_cached_{cache_id}"
job.completed_at = datetime.utcnow()
db.commit()
logger.info(f"Job {job_id} completed (x_vector_only_mode)")
return
model_manager = await ModelManager.get_instance()
await model_manager.load_model("base")
_, tts = await model_manager.get_current_model()
if tts is None:
raise RuntimeError("Failed to load base model")
wavs, sample_rate = tts.generate_voice_clone(
text=request_data['text'],
language=request_data['language'],
voice_clone_prompt=x_vector,
max_new_tokens=request_data['max_new_tokens'],
temperature=request_data['temperature'],
top_k=request_data['top_k'],
top_p=request_data['top_p'],
repetition_penalty=request_data['repetition_penalty']
)
audio_data = wavs[0] if isinstance(wavs, list) else wavs
timestamp = datetime.utcnow().strftime("%Y%m%d_%H%M%S")
filename = f"{user_id}_{job_id}_{timestamp}.wav"
output_path = Path(settings.OUTPUT_DIR) / filename
save_audio_file(audio_data, sample_rate, output_path)
job.status = JobStatus.COMPLETED
job.output_path = str(output_path)
job.completed_at = datetime.utcnow()
db.commit()
logger.info(f"Job {job_id} completed successfully")
except Exception as e:
logger.error(f"Job {job_id} failed: {e}", exc_info=True)
job = db.query(Job).filter(Job.id == job_id).first()
if job:
job.status = JobStatus.FAILED
job.error_message = str(e)
job.completed_at = datetime.utcnow()
db.commit()
finally:
if Path(ref_audio_path).exists():
Path(ref_audio_path).unlink()
db.close()
@router.post("/custom-voice")
@limiter.limit("10/minute")
async def create_custom_voice_job(
request: Request,
req_data: CustomVoiceRequest,
background_tasks: BackgroundTasks,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
try:
validate_text_length(req_data.text)
language = validate_language(req_data.language)
speaker = validate_speaker(req_data.speaker)
params = validate_generation_params({
'max_new_tokens': req_data.max_new_tokens,
'temperature': req_data.temperature,
'top_k': req_data.top_k,
'top_p': req_data.top_p,
'repetition_penalty': req_data.repetition_penalty
})
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
job = Job(
user_id=current_user.id,
job_type="custom-voice",
status=JobStatus.PENDING,
input_data="",
input_params={
"text": req_data.text,
"language": language,
"speaker": speaker,
"instruct": req_data.instruct or "",
**params
}
)
db.add(job)
db.commit()
db.refresh(job)
request_data = {
"text": req_data.text,
"language": language,
"speaker": speaker,
"instruct": req_data.instruct or "",
**params
}
background_tasks.add_task(
process_custom_voice_job,
job.id,
current_user.id,
request_data,
str(settings.DATABASE_URL)
)
return {
"job_id": job.id,
"status": job.status,
"message": "Job created successfully"
}
@router.post("/voice-design")
@limiter.limit("10/minute")
async def create_voice_design_job(
request: Request,
req_data: VoiceDesignRequest,
background_tasks: BackgroundTasks,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
try:
validate_text_length(req_data.text)
language = validate_language(req_data.language)
if not req_data.instruct or not req_data.instruct.strip():
raise ValueError("Instruct parameter is required for voice design")
params = validate_generation_params({
'max_new_tokens': req_data.max_new_tokens,
'temperature': req_data.temperature,
'top_k': req_data.top_k,
'top_p': req_data.top_p,
'repetition_penalty': req_data.repetition_penalty
})
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
job = Job(
user_id=current_user.id,
job_type="voice-design",
status=JobStatus.PENDING,
input_data="",
input_params={
"text": req_data.text,
"language": language,
"instruct": req_data.instruct,
**params
}
)
db.add(job)
db.commit()
db.refresh(job)
request_data = {
"text": req_data.text,
"language": language,
"instruct": req_data.instruct,
**params
}
background_tasks.add_task(
process_voice_design_job,
job.id,
current_user.id,
request_data,
str(settings.DATABASE_URL)
)
return {
"job_id": job.id,
"status": job.status,
"message": "Job created successfully"
}
@router.post("/voice-clone")
@limiter.limit("10/minute")
async def create_voice_clone_job(
request: Request,
text: str = Form(...),
language: str = Form(default="Auto"),
ref_audio: UploadFile = File(...),
ref_text: Optional[str] = Form(default=None),
use_cache: bool = Form(default=True),
x_vector_only_mode: bool = Form(default=False),
max_new_tokens: Optional[int] = Form(default=2048),
temperature: Optional[float] = Form(default=0.9),
top_k: Optional[int] = Form(default=50),
top_p: Optional[float] = Form(default=1.0),
repetition_penalty: Optional[float] = Form(default=1.05),
background_tasks: BackgroundTasks = None,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
try:
validate_text_length(text)
language = validate_language(language)
params = validate_generation_params({
'max_new_tokens': max_new_tokens,
'temperature': temperature,
'top_k': top_k,
'top_p': top_p,
'repetition_penalty': repetition_penalty
})
ref_audio_data = await ref_audio.read()
if not validate_ref_audio(ref_audio_data, max_size_mb=settings.MAX_AUDIO_SIZE_MB):
raise ValueError("Invalid reference audio: must be 1-30s duration and ≤10MB")
cache_manager = await VoiceCacheManager.get_instance()
ref_audio_hash = cache_manager.get_audio_hash(ref_audio_data)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
job = Job(
user_id=current_user.id,
job_type="voice-clone",
status=JobStatus.PENDING,
input_data="",
input_params={
"text": text,
"language": language,
"ref_text": ref_text or "",
"ref_audio_hash": ref_audio_hash,
"use_cache": use_cache,
"x_vector_only_mode": x_vector_only_mode,
**params
}
)
db.add(job)
db.commit()
db.refresh(job)
with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as tmp_file:
tmp_file.write(ref_audio_data)
tmp_audio_path = tmp_file.name
request_data = {
"text": text,
"language": language,
"ref_text": ref_text or "",
"use_cache": use_cache,
"x_vector_only_mode": x_vector_only_mode,
**params
}
background_tasks.add_task(
process_voice_clone_job,
job.id,
current_user.id,
request_data,
tmp_audio_path,
str(settings.DATABASE_URL)
)
existing_cache = await cache_manager.get_cache(current_user.id, ref_audio_hash, db)
cache_info = {"cache_id": existing_cache['cache_id']} if existing_cache else None
return {
"job_id": job.id,
"status": job.status,
"message": "Job created successfully",
"cache_info": cache_info
}
@router.get("/models")
@limiter.limit("30/minute")
async def list_models(request: Request):
model_manager = await ModelManager.get_instance()
return model_manager.get_model_info()
@router.get("/speakers")
@limiter.limit("30/minute")
async def list_speakers(request: Request):
return get_supported_speakers()
@router.get("/languages")
@limiter.limit("30/minute")
async def list_languages(request: Request):
return get_supported_languages()

View File

@@ -0,0 +1,169 @@
from typing import Annotated
from fastapi import APIRouter, Depends, HTTPException, status, Request
from sqlalchemy.orm import Session
from slowapi import Limiter
from slowapi.util import get_remote_address
from api.auth import get_current_user
from config import settings
from core.security import get_password_hash
from db.database import get_db
from db.crud import (
get_user_by_id,
get_user_by_username,
get_user_by_email,
list_users,
create_user_by_admin,
update_user,
delete_user
)
from schemas.user import User, UserCreateByAdmin, UserUpdate, UserListResponse
router = APIRouter(prefix="/users", tags=["users"])
limiter = Limiter(key_func=get_remote_address)
async def require_superuser(
current_user: Annotated[User, Depends(get_current_user)]
) -> User:
if not current_user.is_superuser:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Superuser access required"
)
return current_user
@router.get("", response_model=UserListResponse)
@limiter.limit("30/minute")
async def get_users(
request: Request,
skip: int = 0,
limit: int = 100,
db: Session = Depends(get_db),
_: User = Depends(require_superuser)
):
users, total = list_users(db, skip=skip, limit=limit)
return UserListResponse(users=users, total=total, skip=skip, limit=limit)
@router.post("", response_model=User, status_code=status.HTTP_201_CREATED)
@limiter.limit("10/minute")
async def create_user(
request: Request,
user_data: UserCreateByAdmin,
db: Session = Depends(get_db),
_: User = Depends(require_superuser)
):
existing_user = get_user_by_username(db, username=user_data.username)
if existing_user:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Username already registered"
)
existing_email = get_user_by_email(db, email=user_data.email)
if existing_email:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Email already registered"
)
hashed_password = get_password_hash(user_data.password)
user = create_user_by_admin(
db,
username=user_data.username,
email=user_data.email,
hashed_password=hashed_password,
is_superuser=user_data.is_superuser
)
return user
@router.get("/{user_id}", response_model=User)
@limiter.limit("30/minute")
async def get_user(
request: Request,
user_id: int,
db: Session = Depends(get_db),
_: User = Depends(require_superuser)
):
user = get_user_by_id(db, user_id=user_id)
if not user:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="User not found"
)
return user
@router.put("/{user_id}", response_model=User)
@limiter.limit("10/minute")
async def update_user_info(
request: Request,
user_id: int,
user_data: UserUpdate,
db: Session = Depends(get_db),
current_user: User = Depends(require_superuser)
):
existing_user = get_user_by_id(db, user_id=user_id)
if not existing_user:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="User not found"
)
if user_data.username is not None:
username_exists = get_user_by_username(db, username=user_data.username)
if username_exists and username_exists.id != user_id:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Username already taken"
)
if user_data.email is not None:
email_exists = get_user_by_email(db, email=user_data.email)
if email_exists and email_exists.id != user_id:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Email already taken"
)
hashed_password = None
if user_data.password is not None:
hashed_password = get_password_hash(user_data.password)
user = update_user(
db,
user_id=user_id,
username=user_data.username,
email=user_data.email,
hashed_password=hashed_password,
is_active=user_data.is_active,
is_superuser=user_data.is_superuser
)
if not user:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="User not found"
)
return user
@router.delete("/{user_id}", status_code=status.HTTP_204_NO_CONTENT)
@limiter.limit("10/minute")
async def delete_user_by_id(
request: Request,
user_id: int,
db: Session = Depends(get_db),
current_user: User = Depends(require_superuser)
):
if user_id == current_user.id:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Cannot delete yourself"
)
success = delete_user(db, user_id=user_id)
if not success:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="User not found"
)

View File

@@ -0,0 +1,66 @@
import os
from pathlib import Path
from typing import Optional
from pydantic_settings import BaseSettings
from pydantic import Field, field_validator
class Settings(BaseSettings):
SECRET_KEY: str = Field(default="your-secret-key-change-this-in-production")
ALGORITHM: str = Field(default="HS256")
ACCESS_TOKEN_EXPIRE_MINUTES: int = Field(default=30)
DATABASE_URL: str = Field(default="sqlite:///./qwen_tts.db")
CACHE_DIR: str = Field(default="./voice_cache")
OUTPUT_DIR: str = Field(default="./outputs")
BASE_URL: str = Field(default="http://localhost:8000")
MODEL_DEVICE: str = Field(default="cuda:0")
MODEL_BASE_PATH: str = Field(default="../Qwen")
MAX_CACHE_ENTRIES: int = Field(default=100)
CACHE_TTL_DAYS: int = Field(default=7)
HOST: str = Field(default="0.0.0.0")
PORT: int = Field(default=8000)
WORKERS: int = Field(default=1)
LOG_LEVEL: str = Field(default="info")
LOG_FILE: str = Field(default="./app.log")
RATE_LIMIT_PER_MINUTE: int = Field(default=50)
RATE_LIMIT_PER_HOUR: int = Field(default=1000)
MAX_QUEUE_SIZE: int = Field(default=100)
BATCH_SIZE: int = Field(default=4)
BATCH_WAIT_TIME: float = Field(default=0.5)
MAX_TEXT_LENGTH: int = Field(default=1000)
MAX_AUDIO_SIZE_MB: int = Field(default=10)
class Config:
env_file = ".env"
case_sensitive = True
@field_validator('MODEL_BASE_PATH')
@classmethod
def validate_model_path(cls, v: str) -> str:
path = Path(v)
if not path.exists():
raise ValueError(f"Model base path does not exist: {v}")
return v
def validate(self):
if self.SECRET_KEY == "your-secret-key-change-this-in-production":
import warnings
warnings.warn("Using default SECRET_KEY! Change this in production!")
Path(self.CACHE_DIR).mkdir(parents=True, exist_ok=True)
Path(self.OUTPUT_DIR).mkdir(parents=True, exist_ok=True)
if self.WORKERS > 1:
import warnings
warnings.warn("WORKERS > 1 not recommended for GPU models. Setting to 1.")
self.WORKERS = 1
return True
settings = Settings()

View File

View File

@@ -0,0 +1,141 @@
import asyncio
import logging
import time
from typing import Any, Callable, Dict, List, Optional, Tuple
from dataclasses import dataclass
from collections import deque
from core.config import settings
logger = logging.getLogger(__name__)
@dataclass
class BatchRequest:
request_id: str
data: Dict[str, Any]
future: asyncio.Future
timestamp: float
class BatchProcessor:
_instance: Optional['BatchProcessor'] = None
_lock = asyncio.Lock()
def __init__(self, batch_size: int = None, batch_wait_time: float = None):
self.batch_size = batch_size or settings.BATCH_SIZE
self.batch_wait_time = batch_wait_time or settings.BATCH_WAIT_TIME
self.queue: deque = deque()
self.queue_lock = asyncio.Lock()
self.processing = False
self._processor_task: Optional[asyncio.Task] = None
logger.info(f"BatchProcessor initialized with batch_size={self.batch_size}, wait_time={self.batch_wait_time}s")
@classmethod
async def get_instance(cls) -> 'BatchProcessor':
if cls._instance is None:
async with cls._lock:
if cls._instance is None:
cls._instance = cls()
cls._instance._start_processor()
return cls._instance
def _start_processor(self):
if not self._processor_task or self._processor_task.done():
self._processor_task = asyncio.create_task(self._process_batches())
logger.info("Batch processor task started")
async def _process_batches(self):
logger.info("Batch processing loop started")
while True:
try:
await asyncio.sleep(0.1)
async with self.queue_lock:
if not self.queue:
continue
current_time = time.time()
oldest_request = self.queue[0]
wait_duration = current_time - oldest_request.timestamp
should_process = (
len(self.queue) >= self.batch_size or
wait_duration >= self.batch_wait_time
)
if should_process:
batch = []
for _ in range(min(self.batch_size, len(self.queue))):
if self.queue:
batch.append(self.queue.popleft())
if batch:
logger.info(f"Processing batch of {len(batch)} requests (queue_wait={wait_duration:.3f}s)")
asyncio.create_task(self._process_batch(batch))
except Exception as e:
logger.error(f"Error in batch processor loop: {e}", exc_info=True)
await asyncio.sleep(1)
async def _process_batch(self, batch: List[BatchRequest]):
for request in batch:
try:
if not request.future.done():
result = await self._execute_single_request(request.data)
request.future.set_result(result)
except Exception as e:
logger.error(f"Error processing request {request.request_id}: {e}", exc_info=True)
if not request.future.done():
request.future.set_exception(e)
async def _execute_single_request(self, data: Dict[str, Any]) -> Any:
raise NotImplementedError("Subclass must implement _execute_single_request")
async def submit(self, request_id: str, data: Dict[str, Any], timeout: float = 300) -> Any:
future = asyncio.Future()
request = BatchRequest(
request_id=request_id,
data=data,
future=future,
timestamp=time.time()
)
async with self.queue_lock:
self.queue.append(request)
queue_size = len(self.queue)
logger.debug(f"Request {request_id} queued (queue_size={queue_size})")
try:
result = await asyncio.wait_for(future, timeout=timeout)
return result
except asyncio.TimeoutError:
logger.error(f"Request {request_id} timed out after {timeout}s")
async with self.queue_lock:
if request in self.queue:
self.queue.remove(request)
raise TimeoutError(f"Request timed out after {timeout}s")
async def get_queue_length(self) -> int:
async with self.queue_lock:
return len(self.queue)
async def get_stats(self) -> Dict[str, Any]:
queue_length = await self.get_queue_length()
return {
"queue_length": queue_length,
"batch_size": self.batch_size,
"batch_wait_time": self.batch_wait_time,
"processor_running": self._processor_task is not None and not self._processor_task.done()
}
class TTSBatchProcessor(BatchProcessor):
def __init__(self, process_func: Callable, batch_size: int = None, batch_wait_time: float = None):
super().__init__(batch_size, batch_wait_time)
self.process_func = process_func
async def _execute_single_request(self, data: Dict[str, Any]) -> Any:
return await self.process_func(**data)

View File

@@ -0,0 +1,161 @@
import hashlib
import pickle
import asyncio
from pathlib import Path
from typing import Optional, Dict, Any
from datetime import datetime, timedelta
import logging
from sqlalchemy.orm import Session
from db.crud import (
create_cache_entry,
get_cache_entry,
list_cache_entries,
delete_cache_entry
)
from db.models import VoiceCache
from core.config import settings
logger = logging.getLogger(__name__)
class VoiceCacheManager:
_instance = None
_lock = asyncio.Lock()
def __init__(self, cache_dir: str, max_entries: int, ttl_days: int):
self.cache_dir = Path(cache_dir)
self.max_entries = max_entries
self.ttl_days = ttl_days
self.cache_dir.mkdir(parents=True, exist_ok=True)
logger.info(f"VoiceCacheManager initialized: dir={cache_dir}, max={max_entries}, ttl={ttl_days}d")
@classmethod
async def get_instance(cls) -> 'VoiceCacheManager':
if cls._instance is None:
async with cls._lock:
if cls._instance is None:
cls._instance = VoiceCacheManager(
cache_dir=settings.CACHE_DIR,
max_entries=settings.MAX_CACHE_ENTRIES,
ttl_days=settings.CACHE_TTL_DAYS
)
return cls._instance
def get_audio_hash(self, audio_data: bytes) -> str:
return hashlib.sha256(audio_data).hexdigest()
async def get_cache(self, user_id: int, ref_audio_hash: str, db: Session) -> Optional[Dict[str, Any]]:
try:
cache_entry = get_cache_entry(db, user_id, ref_audio_hash)
if not cache_entry:
logger.debug(f"Cache miss: user={user_id}, hash={ref_audio_hash[:8]}...")
return None
cache_file = Path(cache_entry.cache_path)
if not cache_file.exists():
logger.warning(f"Cache file missing: {cache_file}")
delete_cache_entry(db, cache_entry.id, user_id)
return None
with open(cache_file, 'rb') as f:
cache_data = pickle.load(f)
logger.info(f"Cache hit: user={user_id}, hash={ref_audio_hash[:8]}..., access_count={cache_entry.access_count}")
return {
'cache_id': cache_entry.id,
'data': cache_data,
'metadata': cache_entry.meta_data
}
except Exception as e:
logger.error(f"Cache retrieval error: {e}", exc_info=True)
return None
async def set_cache(
self,
user_id: int,
ref_audio_hash: str,
cache_data: Any,
metadata: Dict[str, Any],
db: Session
) -> str:
async with self._lock:
try:
cache_filename = f"{user_id}_{ref_audio_hash}.pkl"
cache_path = self.cache_dir / cache_filename
with open(cache_path, 'wb') as f:
pickle.dump(cache_data, f)
cache_entry = create_cache_entry(
db=db,
user_id=user_id,
ref_audio_hash=ref_audio_hash,
cache_path=str(cache_path),
meta_data=metadata
)
await self.enforce_max_entries(user_id, db)
logger.info(f"Cache created: user={user_id}, hash={ref_audio_hash[:8]}..., id={cache_entry.id}")
return cache_entry.id
except Exception as e:
logger.error(f"Cache creation error: {e}", exc_info=True)
if cache_path.exists():
cache_path.unlink()
raise
async def enforce_max_entries(self, user_id: int, db: Session) -> int:
try:
all_caches = list_cache_entries(db, user_id, skip=0, limit=9999)
if len(all_caches) <= self.max_entries:
return 0
caches_to_delete = all_caches[self.max_entries:]
deleted_count = 0
for cache in caches_to_delete:
cache_file = Path(cache.cache_path)
if cache_file.exists():
cache_file.unlink()
delete_cache_entry(db, cache.id, user_id)
deleted_count += 1
if deleted_count > 0:
logger.info(f"LRU eviction: user={user_id}, deleted={deleted_count} entries")
return deleted_count
except Exception as e:
logger.error(f"LRU enforcement error: {e}", exc_info=True)
return 0
async def cleanup_expired(self, db: Session) -> int:
try:
cutoff_date = datetime.utcnow() - timedelta(days=self.ttl_days)
expired_caches = db.query(VoiceCache).filter(
VoiceCache.last_accessed < cutoff_date
).all()
deleted_count = 0
for cache in expired_caches:
cache_file = Path(cache.cache_path)
if cache_file.exists():
cache_file.unlink()
db.delete(cache)
deleted_count += 1
if deleted_count > 0:
db.commit()
logger.info(f"Expired cache cleanup: deleted={deleted_count} entries")
return deleted_count
except Exception as e:
logger.error(f"Expired cache cleanup error: {e}", exc_info=True)
db.rollback()
return 0

View File

@@ -0,0 +1,166 @@
import logging
from datetime import datetime, timedelta
from pathlib import Path
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from core.config import settings
from core.cache_manager import VoiceCacheManager
from db.models import Job
logger = logging.getLogger(__name__)
async def cleanup_expired_caches(db_url: str) -> dict:
try:
engine = create_engine(db_url)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
db = SessionLocal()
cache_manager = await VoiceCacheManager.get_instance()
deleted_count = await cache_manager.cleanup_expired(db)
freed_space_mb = 0
db.close()
logger.info(f"Cleanup: deleted {deleted_count} expired caches")
return {
'deleted_count': deleted_count,
'freed_space_mb': freed_space_mb
}
except Exception as e:
logger.error(f"Expired cache cleanup failed: {e}", exc_info=True)
return {
'deleted_count': 0,
'freed_space_mb': 0,
'error': str(e)
}
async def cleanup_old_jobs(db_url: str, days: int = 7) -> dict:
try:
engine = create_engine(db_url)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
db = SessionLocal()
cutoff_date = datetime.utcnow() - timedelta(days=days)
old_jobs = db.query(Job).filter(
Job.completed_at < cutoff_date,
Job.status.in_(['completed', 'failed'])
).all()
deleted_files = 0
for job in old_jobs:
if job.output_path:
output_file = Path(job.output_path)
if output_file.exists():
output_file.unlink()
deleted_files += 1
db.delete(job)
db.commit()
deleted_jobs = len(old_jobs)
db.close()
logger.info(f"Cleanup: deleted {deleted_jobs} old jobs, {deleted_files} files")
return {
'deleted_jobs': deleted_jobs,
'deleted_files': deleted_files
}
except Exception as e:
logger.error(f"Old job cleanup failed: {e}", exc_info=True)
return {
'deleted_jobs': 0,
'deleted_files': 0,
'error': str(e)
}
async def cleanup_orphaned_files(db_url: str) -> dict:
try:
engine = create_engine(db_url)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
db = SessionLocal()
output_dir = Path(settings.OUTPUT_DIR)
cache_dir = Path(settings.CACHE_DIR)
output_files_in_db = {Path(job.output_path).name for job in db.query(Job.output_path).filter(Job.output_path.isnot(None)).all()}
from db.models import VoiceCache
cache_files_in_db = {Path(cache.cache_path).name for cache in db.query(VoiceCache.cache_path).all()}
deleted_orphans = 0
freed_space_bytes = 0
if output_dir.exists():
for output_file in output_dir.glob("*.wav"):
if output_file.name not in output_files_in_db:
size = output_file.stat().st_size
output_file.unlink()
deleted_orphans += 1
freed_space_bytes += size
if cache_dir.exists():
for cache_file in cache_dir.glob("*.pkl"):
if cache_file.name not in cache_files_in_db:
size = cache_file.stat().st_size
cache_file.unlink()
deleted_orphans += 1
freed_space_bytes += size
freed_space_mb = freed_space_bytes / (1024 * 1024)
db.close()
logger.info(f"Cleanup: deleted {deleted_orphans} orphaned files, freed {freed_space_mb:.2f} MB")
return {
'deleted_orphans': deleted_orphans,
'freed_space_mb': freed_space_mb
}
except Exception as e:
logger.error(f"Orphaned file cleanup failed: {e}", exc_info=True)
return {
'deleted_orphans': 0,
'freed_space_mb': 0,
'error': str(e)
}
async def run_scheduled_cleanup(db_url: str) -> dict:
logger.info("Starting scheduled cleanup task...")
try:
cache_result = await cleanup_expired_caches(db_url)
job_result = await cleanup_old_jobs(db_url)
orphan_result = await cleanup_orphaned_files(db_url)
result = {
'timestamp': datetime.utcnow().isoformat(),
'expired_caches': cache_result,
'old_jobs': job_result,
'orphaned_files': orphan_result,
'status': 'completed'
}
logger.info(f"Scheduled cleanup completed: {result}")
return result
except Exception as e:
logger.error(f"Scheduled cleanup failed: {e}", exc_info=True)
return {
'timestamp': datetime.utcnow().isoformat(),
'status': 'failed',
'error': str(e)
}

View File

@@ -0,0 +1,3 @@
from config import settings, Settings
__all__ = ['settings', 'Settings']

View File

@@ -0,0 +1,3 @@
from db.database import Base, engine, SessionLocal, get_db, init_db
__all__ = ['Base', 'engine', 'SessionLocal', 'get_db', 'init_db']

View File

@@ -0,0 +1,156 @@
import time
import logging
from collections import deque, defaultdict
from typing import Dict, List, Optional, Any
from dataclasses import dataclass, field
import asyncio
import statistics
logger = logging.getLogger(__name__)
@dataclass
class RequestMetric:
timestamp: float
endpoint: str
duration: float
status_code: int
queue_time: float = 0.0
class MetricsCollector:
_instance: Optional['MetricsCollector'] = None
_lock = asyncio.Lock()
def __init__(self, window_size: int = 1000):
self.window_size = window_size
self.requests: deque = deque(maxlen=window_size)
self.request_counts: Dict[str, int] = defaultdict(int)
self.error_counts: Dict[str, int] = defaultdict(int)
self.total_requests = 0
self.start_time = time.time()
self.batch_stats = {
'total_batches': 0,
'total_requests_batched': 0,
'avg_batch_size': 0.0
}
self._lock_local = asyncio.Lock()
logger.info("MetricsCollector initialized")
@classmethod
async def get_instance(cls) -> 'MetricsCollector':
if cls._instance is None:
async with cls._lock:
if cls._instance is None:
cls._instance = cls()
return cls._instance
async def record_request(
self,
endpoint: str,
duration: float,
status_code: int,
queue_time: float = 0.0
):
async with self._lock_local:
metric = RequestMetric(
timestamp=time.time(),
endpoint=endpoint,
duration=duration,
status_code=status_code,
queue_time=queue_time
)
self.requests.append(metric)
self.request_counts[endpoint] += 1
self.total_requests += 1
if status_code >= 400:
self.error_counts[endpoint] += 1
async def record_batch(self, batch_size: int):
async with self._lock_local:
self.batch_stats['total_batches'] += 1
self.batch_stats['total_requests_batched'] += batch_size
total_batches = self.batch_stats['total_batches']
total_requests = self.batch_stats['total_requests_batched']
self.batch_stats['avg_batch_size'] = total_requests / total_batches if total_batches > 0 else 0.0
async def get_metrics(self) -> Dict[str, Any]:
async with self._lock_local:
current_time = time.time()
uptime = current_time - self.start_time
recent_requests = [r for r in self.requests if current_time - r.timestamp < 60]
durations = [r.duration for r in self.requests if r.duration > 0]
queue_times = [r.queue_time for r in self.requests if r.queue_time > 0]
percentiles = {}
if durations:
sorted_durations = sorted(durations)
percentiles = {
'p50': statistics.median(sorted_durations),
'p95': sorted_durations[int(len(sorted_durations) * 0.95)] if len(sorted_durations) > 0 else 0,
'p99': sorted_durations[int(len(sorted_durations) * 0.99)] if len(sorted_durations) > 0 else 0,
'avg': statistics.mean(sorted_durations),
'min': min(sorted_durations),
'max': max(sorted_durations)
}
queue_percentiles = {}
if queue_times:
sorted_queue_times = sorted(queue_times)
queue_percentiles = {
'p50': statistics.median(sorted_queue_times),
'p95': sorted_queue_times[int(len(sorted_queue_times) * 0.95)] if len(sorted_queue_times) > 0 else 0,
'p99': sorted_queue_times[int(len(sorted_queue_times) * 0.99)] if len(sorted_queue_times) > 0 else 0,
'avg': statistics.mean(sorted_queue_times)
}
requests_per_second = len(recent_requests) / 60.0 if recent_requests else 0.0
import torch
gpu_stats = {}
if torch.cuda.is_available():
gpu_stats = {
'gpu_available': True,
'gpu_memory_allocated_mb': torch.cuda.memory_allocated(0) / 1024**2,
'gpu_memory_reserved_mb': torch.cuda.memory_reserved(0) / 1024**2,
'gpu_memory_total_mb': torch.cuda.get_device_properties(0).total_memory / 1024**2
}
else:
gpu_stats = {'gpu_available': False}
from core.batch_processor import BatchProcessor
batch_processor = await BatchProcessor.get_instance()
batch_stats_current = await batch_processor.get_stats()
return {
'uptime_seconds': uptime,
'total_requests': self.total_requests,
'requests_per_second': requests_per_second,
'request_counts_by_endpoint': dict(self.request_counts),
'error_counts_by_endpoint': dict(self.error_counts),
'latency': percentiles,
'queue_time': queue_percentiles,
'batch_processing': {
**self.batch_stats,
**batch_stats_current
},
'gpu': gpu_stats
}
async def reset(self):
async with self._lock_local:
self.requests.clear()
self.request_counts.clear()
self.error_counts.clear()
self.total_requests = 0
self.start_time = time.time()
self.batch_stats = {
'total_batches': 0,
'total_requests_batched': 0,
'avg_batch_size': 0.0
}
logger.info("Metrics reset")

View File

@@ -0,0 +1,123 @@
import asyncio
import logging
from typing import Optional
import torch
from qwen_tts import Qwen3TTSModel
from core.config import settings
logger = logging.getLogger(__name__)
class ModelManager:
_instance: Optional['ModelManager'] = None
_lock = asyncio.Lock()
MODEL_PATHS = {
"custom-voice": "Qwen3-TTS-12Hz-1.7B-CustomVoice",
"voice-design": "Qwen3-TTS-12Hz-1.7B-VoiceDesign",
"base": "Qwen3-TTS-12Hz-1.7B-Base"
}
def __init__(self):
if ModelManager._instance is not None:
raise RuntimeError("Use get_instance() to get ModelManager")
self.current_model_name: Optional[str] = None
self.tts: Optional[Qwen3TTSModel] = None
@classmethod
async def get_instance(cls) -> 'ModelManager':
if cls._instance is None:
async with cls._lock:
if cls._instance is None:
cls._instance = cls()
return cls._instance
async def load_model(self, model_name: str) -> None:
if model_name not in self.MODEL_PATHS:
raise ValueError(
f"Unknown model: {model_name}. "
f"Available models: {list(self.MODEL_PATHS.keys())}"
)
if self.current_model_name == model_name and self.tts is not None:
logger.info(f"Model {model_name} already loaded")
return
async with self._lock:
logger.info(f"Loading model: {model_name}")
if self.tts is not None:
logger.info(f"Unloading current model: {self.current_model_name}")
await self._unload_model_internal()
from pathlib import Path
model_base_path = Path(settings.MODEL_BASE_PATH)
local_model_path = model_base_path / self.MODEL_PATHS[model_name]
if local_model_path.exists():
model_path = str(local_model_path)
logger.info(f"Using local model: {model_path}")
else:
model_path = f"Qwen/{self.MODEL_PATHS[model_name]}"
logger.info(f"Local path not found, using HuggingFace: {model_path}")
try:
self.tts = Qwen3TTSModel.from_pretrained(
str(model_path),
device_map=settings.MODEL_DEVICE,
torch_dtype=torch.bfloat16
)
self.current_model_name = model_name
logger.info(f"Successfully loaded model: {model_name}")
if torch.cuda.is_available():
allocated = torch.cuda.memory_allocated(0) / 1024**3
logger.info(f"GPU memory allocated: {allocated:.2f} GB")
except Exception as e:
logger.error(f"Failed to load model {model_name}: {e}")
self.tts = None
self.current_model_name = None
raise
async def get_current_model(self) -> tuple[Optional[str], Optional[Qwen3TTSModel]]:
return self.current_model_name, self.tts
async def unload_model(self) -> None:
async with self._lock:
await self._unload_model_internal()
async def _unload_model_internal(self) -> None:
if self.tts is not None:
logger.info(f"Unloading model: {self.current_model_name}")
del self.tts
self.tts = None
self.current_model_name = None
if torch.cuda.is_available():
torch.cuda.empty_cache()
logger.info("Cleared CUDA cache")
async def get_memory_usage(self) -> dict:
memory_info = {
"gpu_available": torch.cuda.is_available(),
"current_model": self.current_model_name
}
if torch.cuda.is_available():
memory_info.update({
"allocated_gb": torch.cuda.memory_allocated(0) / 1024**3,
"reserved_gb": torch.cuda.memory_reserved(0) / 1024**3,
"total_gb": torch.cuda.get_device_properties(0).total_memory / 1024**3
})
return memory_info
def get_model_info(self) -> dict:
return {
name: {
"path": path,
"loaded": name == self.current_model_name
}
for name, path in self.MODEL_PATHS.items()
}

View File

@@ -0,0 +1,35 @@
from datetime import datetime, timedelta
from typing import Optional
from jose import JWTError, jwt
from passlib.context import CryptContext
from config import settings
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
def get_password_hash(password: str) -> str:
return pwd_context.hash(password)
def verify_password(plain_password: str, hashed_password: str) -> bool:
return pwd_context.verify(plain_password, hashed_password)
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
to_encode = data.copy()
if expires_delta:
expire = datetime.utcnow() + expires_delta
else:
expire = datetime.utcnow() + timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
to_encode.update({"exp": expire})
encoded_jwt = jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
return encoded_jwt
def decode_access_token(token: str) -> Optional[str]:
try:
payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM])
username: str = payload.get("sub")
if username is None:
return None
return username
except JWTError:
return None

View File

View File

@@ -0,0 +1,198 @@
import json
from typing import Optional, List, Dict, Any
from datetime import datetime
from sqlalchemy.orm import Session
from db.models import User, Job, VoiceCache
def get_user_by_username(db: Session, username: str) -> Optional[User]:
return db.query(User).filter(User.username == username).first()
def get_user_by_email(db: Session, email: str) -> Optional[User]:
return db.query(User).filter(User.email == email).first()
def create_user(db: Session, username: str, email: str, hashed_password: str) -> User:
user = User(
username=username,
email=email,
hashed_password=hashed_password
)
db.add(user)
db.commit()
db.refresh(user)
return user
def create_user_by_admin(
db: Session,
username: str,
email: str,
hashed_password: str,
is_superuser: bool = False
) -> User:
user = User(
username=username,
email=email,
hashed_password=hashed_password,
is_superuser=is_superuser
)
db.add(user)
db.commit()
db.refresh(user)
return user
def get_user_by_id(db: Session, user_id: int) -> Optional[User]:
return db.query(User).filter(User.id == user_id).first()
def list_users(db: Session, skip: int = 0, limit: int = 100) -> tuple[List[User], int]:
total = db.query(User).count()
users = db.query(User).order_by(User.created_at.desc()).offset(skip).limit(limit).all()
return users, total
def update_user(
db: Session,
user_id: int,
username: Optional[str] = None,
email: Optional[str] = None,
hashed_password: Optional[str] = None,
is_active: Optional[bool] = None,
is_superuser: Optional[bool] = None
) -> Optional[User]:
user = get_user_by_id(db, user_id)
if not user:
return None
if username is not None:
user.username = username
if email is not None:
user.email = email
if hashed_password is not None:
user.hashed_password = hashed_password
if is_active is not None:
user.is_active = is_active
if is_superuser is not None:
user.is_superuser = is_superuser
user.updated_at = datetime.utcnow()
db.commit()
db.refresh(user)
return user
def delete_user(db: Session, user_id: int) -> bool:
user = get_user_by_id(db, user_id)
if not user:
return False
db.delete(user)
db.commit()
return True
def create_job(db: Session, user_id: int, job_type: str, input_data: Dict[str, Any]) -> Job:
job = Job(
user_id=user_id,
job_type=job_type,
input_data=json.dumps(input_data),
status="pending"
)
db.add(job)
db.commit()
db.refresh(job)
return job
def get_job(db: Session, job_id: int, user_id: int) -> Optional[Job]:
return db.query(Job).filter(Job.id == job_id, Job.user_id == user_id).first()
def list_jobs(
db: Session,
user_id: int,
skip: int = 0,
limit: int = 100,
status: Optional[str] = None
) -> List[Job]:
query = db.query(Job).filter(Job.user_id == user_id)
if status:
query = query.filter(Job.status == status)
return query.order_by(Job.created_at.desc()).offset(skip).limit(limit).all()
def update_job_status(
db: Session,
job_id: int,
user_id: int,
status: str,
output_path: Optional[str] = None,
error_message: Optional[str] = None
) -> Optional[Job]:
job = get_job(db, job_id, user_id)
if not job:
return None
job.status = status
if output_path:
job.output_path = output_path
if error_message:
job.error_message = error_message
if status in ["completed", "failed"]:
job.completed_at = datetime.utcnow()
db.commit()
db.refresh(job)
return job
def delete_job(db: Session, job_id: int, user_id: int) -> bool:
job = get_job(db, job_id, user_id)
if not job:
return False
db.delete(job)
db.commit()
return True
def create_cache_entry(
db: Session,
user_id: int,
ref_audio_hash: str,
cache_path: str,
meta_data: Optional[Dict[str, Any]] = None
) -> VoiceCache:
cache = VoiceCache(
user_id=user_id,
ref_audio_hash=ref_audio_hash,
cache_path=cache_path,
meta_data=json.dumps(meta_data) if meta_data else None
)
db.add(cache)
db.commit()
db.refresh(cache)
return cache
def get_cache_entry(db: Session, user_id: int, ref_audio_hash: str) -> Optional[VoiceCache]:
cache = db.query(VoiceCache).filter(
VoiceCache.user_id == user_id,
VoiceCache.ref_audio_hash == ref_audio_hash
).first()
if cache:
cache.last_accessed = datetime.utcnow()
cache.access_count += 1
db.commit()
db.refresh(cache)
return cache
def list_cache_entries(
db: Session,
user_id: int,
skip: int = 0,
limit: int = 100
) -> List[VoiceCache]:
return db.query(VoiceCache).filter(
VoiceCache.user_id == user_id
).order_by(VoiceCache.last_accessed.desc()).offset(skip).limit(limit).all()
def delete_cache_entry(db: Session, cache_id: int, user_id: int) -> bool:
cache = db.query(VoiceCache).filter(
VoiceCache.id == cache_id,
VoiceCache.user_id == user_id
).first()
if not cache:
return False
db.delete(cache)
db.commit()
return True

View File

@@ -0,0 +1,23 @@
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker, declarative_base
from config import settings
engine = create_engine(
settings.DATABASE_URL,
connect_args={"check_same_thread": False} if "sqlite" in settings.DATABASE_URL else {}
)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Base = declarative_base()
def get_db():
db = SessionLocal()
try:
yield db
finally:
db.close()
def init_db():
Base.metadata.create_all(bind=engine)

View File

@@ -0,0 +1,67 @@
from datetime import datetime
from enum import Enum
from sqlalchemy import Column, Integer, String, Boolean, DateTime, ForeignKey, Text, Index, JSON
from sqlalchemy.orm import relationship
from db.database import Base
class JobStatus(str, Enum):
PENDING = "pending"
PROCESSING = "processing"
COMPLETED = "completed"
FAILED = "failed"
class User(Base):
__tablename__ = "users"
id = Column(Integer, primary_key=True, index=True)
username = Column(String(50), unique=True, nullable=False, index=True)
email = Column(String(255), unique=True, nullable=False, index=True)
hashed_password = Column(String(255), nullable=False)
is_active = Column(Boolean, default=True, nullable=False)
is_superuser = Column(Boolean, default=False, nullable=False)
created_at = Column(DateTime, default=datetime.utcnow, nullable=False)
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow, nullable=False)
jobs = relationship("Job", back_populates="user", cascade="all, delete-orphan")
voice_caches = relationship("VoiceCache", back_populates="user", cascade="all, delete-orphan")
class Job(Base):
__tablename__ = "jobs"
id = Column(Integer, primary_key=True, index=True)
user_id = Column(Integer, ForeignKey("users.id"), nullable=False, index=True)
job_type = Column(String(50), nullable=False)
status = Column(String(50), default="pending", nullable=False, index=True)
input_data = Column(Text, nullable=True)
input_params = Column(JSON, nullable=True)
output_path = Column(String(500), nullable=True)
error_message = Column(Text, nullable=True)
created_at = Column(DateTime, default=datetime.utcnow, nullable=False, index=True)
started_at = Column(DateTime, nullable=True)
completed_at = Column(DateTime, nullable=True)
user = relationship("User", back_populates="jobs")
__table_args__ = (
Index('idx_user_status', 'user_id', 'status'),
Index('idx_user_created', 'user_id', 'created_at'),
)
class VoiceCache(Base):
__tablename__ = "voice_caches"
id = Column(Integer, primary_key=True, index=True)
user_id = Column(Integer, ForeignKey("users.id"), nullable=False, index=True)
ref_audio_hash = Column(String(64), nullable=False, index=True)
cache_path = Column(String(500), nullable=False)
meta_data = Column(Text, nullable=True)
created_at = Column(DateTime, default=datetime.utcnow, nullable=False)
last_accessed = Column(DateTime, default=datetime.utcnow, nullable=False, index=True)
access_count = Column(Integer, default=0, nullable=False)
user = relationship("User", back_populates="voice_caches")
__table_args__ = (
Index('idx_user_hash', 'user_id', 'ref_audio_hash'),
)

View File

@@ -0,0 +1,55 @@
upstream qwen_tts_backend {
server 127.0.0.1:8000;
}
server {
listen 80;
server_name your-domain.com;
client_max_body_size 100M;
client_body_timeout 300s;
proxy_read_timeout 300s;
proxy_connect_timeout 300s;
proxy_send_timeout 300s;
location / {
proxy_pass http://qwen_tts_backend;
proxy_set_header Host $host;
proxy_set_header X-Real-IP $remote_addr;
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
proxy_set_header X-Forwarded-Proto $scheme;
add_header 'Access-Control-Allow-Origin' '*' always;
add_header 'Access-Control-Allow-Methods' 'GET, POST, PUT, DELETE, OPTIONS' always;
add_header 'Access-Control-Allow-Headers' 'Authorization, Content-Type' always;
if ($request_method = 'OPTIONS') {
add_header 'Access-Control-Allow-Origin' '*';
add_header 'Access-Control-Allow-Methods' 'GET, POST, PUT, DELETE, OPTIONS';
add_header 'Access-Control-Allow-Headers' 'Authorization, Content-Type';
add_header 'Content-Length' '0';
add_header 'Content-Type' 'text/plain';
return 204;
}
}
location /outputs/ {
alias /opt/qwen3-tts-backend/outputs/;
autoindex off;
add_header Cache-Control "public, max-age=3600";
add_header Content-Disposition "attachment";
}
location /health {
proxy_pass http://qwen_tts_backend/health;
proxy_set_header Host $host;
access_log off;
}
location /metrics {
proxy_pass http://qwen_tts_backend/metrics;
proxy_set_header Host $host;
allow 127.0.0.1;
deny all;
}
}

View File

@@ -0,0 +1,21 @@
[Unit]
Description=Qwen3-TTS Backend API Service
After=network.target
[Service]
Type=simple
User=qwen-tts
Group=qwen-tts
WorkingDirectory=/opt/qwen3-tts-backend
Environment="PATH=/opt/conda/envs/qwen3-tts/bin:/usr/local/bin:/usr/bin:/bin"
EnvironmentFile=/opt/qwen3-tts-backend/.env
ExecStart=/opt/conda/envs/qwen3-tts/bin/python main.py
Restart=on-failure
RestartSec=10s
StandardOutput=append:/var/log/qwen-tts/app.log
StandardError=append:/var/log/qwen-tts/error.log
TimeoutStopSec=30s
KillMode=mixed
[Install]
WantedBy=multi-user.target

221
qwen3-tts-backend/main.py Normal file
View File

@@ -0,0 +1,221 @@
import logging
import sys
from contextlib import asynccontextmanager
from pathlib import Path
import torch
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.util import get_remote_address
from slowapi.errors import RateLimitExceeded
from sqlalchemy import text
from core.config import settings
from core.database import init_db
from core.model_manager import ModelManager
from core.cleanup import run_scheduled_cleanup
from api import auth, jobs, tts, cache, metrics, users
from apscheduler.schedulers.asyncio import AsyncIOScheduler
logging.basicConfig(
level=getattr(logging, settings.LOG_LEVEL.upper()),
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
handlers=[
logging.StreamHandler(sys.stdout),
logging.FileHandler(settings.LOG_FILE)
]
)
logger = logging.getLogger(__name__)
def get_user_identifier(request: Request) -> str:
from jose import jwt
from core.config import settings
auth_header = request.headers.get("Authorization", "")
if auth_header.startswith("Bearer "):
token = auth_header.split(" ")[1]
try:
payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM])
user_id = payload.get("sub")
if user_id:
return f"user:{user_id}"
except Exception:
pass
return get_remote_address(request)
limiter = Limiter(key_func=get_user_identifier)
@asynccontextmanager
async def lifespan(app: FastAPI):
logger.info("Starting Qwen3-TTS Backend Service...")
logger.info(f"Model base path: {settings.MODEL_BASE_PATH}")
logger.info(f"Cache directory: {settings.CACHE_DIR}")
logger.info(f"Output directory: {settings.OUTPUT_DIR}")
logger.info(f"Device: {settings.MODEL_DEVICE}")
try:
settings.validate()
logger.info("Configuration validated successfully")
except Exception as e:
logger.error(f"Configuration validation failed: {e}")
raise
try:
init_db()
logger.info("Database initialized successfully")
except Exception as e:
logger.error(f"Database initialization failed: {e}")
raise
try:
model_manager = await ModelManager.get_instance()
await model_manager.load_model("custom-voice")
logger.info("Preloaded custom-voice model")
except Exception as e:
logger.warning(f"Model preload failed: {e}")
scheduler = AsyncIOScheduler()
scheduler.add_job(
run_scheduled_cleanup,
'interval',
hours=6,
args=[str(settings.DATABASE_URL)],
id='cleanup_task'
)
scheduler.start()
logger.info("Background cleanup scheduler started (runs every 6 hours)")
yield
logger.info("Shutting down Qwen3-TTS Backend Service...")
scheduler.shutdown()
logger.info("Scheduler shutdown completed")
try:
model_manager = await ModelManager.get_instance()
await model_manager.unload_model()
logger.info("Model cleanup completed")
except Exception as e:
logger.error(f"Model cleanup failed: {e}")
app = FastAPI(
title="Qwen3-TTS-WebUI Backend API",
description="Backend service for Qwen3-TTS-WebUI text-to-speech system",
version="0.1.0",
lifespan=lifespan
)
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
app.include_router(auth.router)
app.include_router(jobs.router)
app.include_router(tts.router)
app.include_router(cache.router)
app.include_router(metrics.router)
app.include_router(users.router)
@app.get("/health")
async def health_check():
from core.batch_processor import BatchProcessor
from core.database import SessionLocal
gpu_available = torch.cuda.is_available()
gpu_memory_used_mb = 0
gpu_memory_total_mb = 0
if gpu_available:
gpu_memory_used_mb = torch.cuda.memory_allocated(0) / 1024**2
gpu_memory_total_mb = torch.cuda.get_device_properties(0).total_memory / 1024**2
model_manager = await ModelManager.get_instance()
current_model, _ = await model_manager.get_current_model()
batch_processor = await BatchProcessor.get_instance()
queue_length = await batch_processor.get_queue_length()
database_connected = True
try:
db = SessionLocal()
db.execute(text("SELECT 1"))
db.close()
except Exception as e:
logger.error(f"Database health check failed: {e}")
database_connected = False
cache_dir_writable = True
try:
test_file = Path(settings.CACHE_DIR) / ".health_check"
test_file.write_text("test")
test_file.unlink()
except Exception as e:
logger.error(f"Cache directory health check failed: {e}")
cache_dir_writable = False
output_dir_writable = True
try:
test_file = Path(settings.OUTPUT_DIR) / ".health_check"
test_file.write_text("test")
test_file.unlink()
except Exception as e:
logger.error(f"Output directory health check failed: {e}")
output_dir_writable = False
critical_issues = []
if not database_connected:
critical_issues.append("database_disconnected")
if not cache_dir_writable:
critical_issues.append("cache_dir_not_writable")
if not output_dir_writable:
critical_issues.append("output_dir_not_writable")
minor_issues = []
if not gpu_available:
minor_issues.append("gpu_not_available")
if queue_length > 50:
minor_issues.append("queue_congested")
if critical_issues:
status = "unhealthy"
elif minor_issues:
status = "degraded"
else:
status = "healthy"
return {
"status": status,
"gpu_available": gpu_available,
"gpu_memory_used_mb": gpu_memory_used_mb,
"gpu_memory_total_mb": gpu_memory_total_mb,
"queue_length": queue_length,
"active_model": current_model,
"database_connected": database_connected,
"cache_dir_writable": cache_dir_writable,
"output_dir_writable": output_dir_writable,
"issues": {
"critical": critical_issues,
"minor": minor_issues
}
}
if __name__ == "__main__":
import uvicorn
uvicorn.run(
"main:app",
host=settings.HOST,
port=settings.PORT,
workers=settings.WORKERS,
log_level=settings.LOG_LEVEL.lower()
)

View File

@@ -0,0 +1,6 @@
[pytest]
asyncio_mode = auto
testpaths = tests
python_files = test_*.py
python_classes = Test*
python_functions = test_*

View File

@@ -0,0 +1,18 @@
fastapi==0.115.0
uvicorn[standard]==0.32.0
pydantic==2.9.0
python-multipart==0.0.12
python-jose[cryptography]==3.3.0
passlib==1.7.4
bcrypt==3.2.2
sqlalchemy==2.0.35
aiosqlite==0.20.0
soundfile==0.12.1
scipy>=1.11.0
apscheduler>=3.10.0
slowapi==0.1.9
locust==2.20.0
pytest==8.3.0
pytest-cov==4.1.0
pytest-asyncio==0.23.0
httpx==0.27.0

View File

View File

@@ -0,0 +1,15 @@
from datetime import datetime
from typing import Optional, Dict, Any
from pydantic import BaseModel, ConfigDict
class CacheEntry(BaseModel):
id: int
user_id: int
ref_audio_hash: str
cache_path: str
meta_data: Optional[Dict[str, Any]] = None
created_at: datetime
last_accessed: datetime
access_count: int
model_config = ConfigDict(from_attributes=True)

View File

@@ -0,0 +1,25 @@
from datetime import datetime
from typing import Optional, Dict, Any, List
from pydantic import BaseModel, ConfigDict
class JobBase(BaseModel):
job_type: str
class JobCreate(JobBase):
input_data: Dict[str, Any]
class Job(JobBase):
id: int
user_id: int
status: str
output_path: Optional[str] = None
download_url: Optional[str] = None
error_message: Optional[str] = None
created_at: datetime
completed_at: Optional[datetime] = None
model_config = ConfigDict(from_attributes=True)
class JobList(BaseModel):
total: int
jobs: List[Job]

View File

@@ -0,0 +1,50 @@
from typing import Optional, List
from pydantic import BaseModel, Field
class TTSRequest(BaseModel):
text: str = Field(..., min_length=1, max_length=1000)
ref_audio: Optional[str] = None
ref_text: Optional[str] = None
language: str = Field(default="en")
speed: float = Field(default=1.0, ge=0.5, le=2.0)
class TTSResponse(BaseModel):
job_id: int
status: str
audio_url: Optional[str] = None
class CustomVoiceRequest(BaseModel):
text: str = Field(..., min_length=1, max_length=1000)
language: str = Field(default="Auto")
speaker: str
instruct: Optional[str] = Field(default="")
max_new_tokens: Optional[int] = Field(default=2048, ge=128, le=4096)
temperature: Optional[float] = Field(default=0.9, ge=0.1, le=2.0)
top_k: Optional[int] = Field(default=50, ge=1, le=100)
top_p: Optional[float] = Field(default=1.0, ge=0.0, le=1.0)
repetition_penalty: Optional[float] = Field(default=1.05, ge=1.0, le=2.0)
class VoiceDesignRequest(BaseModel):
text: str = Field(..., min_length=1, max_length=1000)
language: str = Field(default="Auto")
instruct: str = Field(..., min_length=1)
max_new_tokens: Optional[int] = Field(default=2048, ge=128, le=4096)
temperature: Optional[float] = Field(default=0.9, ge=0.1, le=2.0)
top_k: Optional[int] = Field(default=50, ge=1, le=100)
top_p: Optional[float] = Field(default=1.0, ge=0.0, le=1.0)
repetition_penalty: Optional[float] = Field(default=1.05, ge=1.0, le=2.0)
class VoiceCloneRequest(BaseModel):
text: str = Field(..., min_length=1, max_length=1000)
language: str = Field(default="Auto")
ref_text: Optional[str] = Field(default=None, max_length=500)
use_cache: bool = Field(default=True)
x_vector_only_mode: bool = Field(default=False)
max_new_tokens: Optional[int] = Field(default=2048, ge=128, le=4096)
temperature: Optional[float] = Field(default=0.9, ge=0.1, le=2.0)
top_k: Optional[int] = Field(default=50, ge=1, le=100)
top_p: Optional[float] = Field(default=1.0, ge=0.0, le=1.0)
repetition_penalty: Optional[float] = Field(default=1.05, ge=1.0, le=2.0)

View File

@@ -0,0 +1,91 @@
from datetime import datetime
from typing import Optional
from pydantic import BaseModel, EmailStr, Field, field_validator, ConfigDict
import re
class UserBase(BaseModel):
username: str = Field(..., min_length=3, max_length=50)
email: EmailStr
@field_validator('username')
@classmethod
def validate_username(cls, v: str) -> str:
if not re.match(r'^[a-zA-Z0-9_-]+$', v):
raise ValueError('Username must contain only alphanumeric characters, underscores, and dashes')
return v
class UserCreate(UserBase):
password: str = Field(..., min_length=8, max_length=128)
@field_validator('password')
@classmethod
def validate_password_strength(cls, v: str) -> str:
if not re.search(r'[A-Z]', v):
raise ValueError('Password must contain at least one uppercase letter')
if not re.search(r'[a-z]', v):
raise ValueError('Password must contain at least one lowercase letter')
if not re.search(r'\d', v):
raise ValueError('Password must contain at least one digit')
return v
class User(UserBase):
id: int
is_active: bool
is_superuser: bool
created_at: datetime
model_config = ConfigDict(from_attributes=True)
class UserCreateByAdmin(UserBase):
password: str = Field(..., min_length=8, max_length=128)
is_superuser: bool = False
@field_validator('password')
@classmethod
def validate_password_strength(cls, v: str) -> str:
if not re.search(r'[A-Z]', v):
raise ValueError('Password must contain at least one uppercase letter')
if not re.search(r'[a-z]', v):
raise ValueError('Password must contain at least one lowercase letter')
if not re.search(r'\d', v):
raise ValueError('Password must contain at least one digit')
return v
class UserUpdate(BaseModel):
username: Optional[str] = Field(None, min_length=3, max_length=50)
email: Optional[EmailStr] = None
password: Optional[str] = Field(None, min_length=8, max_length=128)
is_active: Optional[bool] = None
is_superuser: Optional[bool] = None
@field_validator('username')
@classmethod
def validate_username(cls, v: Optional[str]) -> Optional[str]:
if v is not None and not re.match(r'^[a-zA-Z0-9_-]+$', v):
raise ValueError('Username must contain only alphanumeric characters, underscores, and dashes')
return v
@field_validator('password')
@classmethod
def validate_password_strength(cls, v: Optional[str]) -> Optional[str]:
if v is not None:
if not re.search(r'[A-Z]', v):
raise ValueError('Password must contain at least one uppercase letter')
if not re.search(r'[a-z]', v):
raise ValueError('Password must contain at least one lowercase letter')
if not re.search(r'\d', v):
raise ValueError('Password must contain at least one digit')
return v
class UserListResponse(BaseModel):
users: list[User]
total: int
skip: int
limit: int
class Token(BaseModel):
access_token: str
token_type: str
class TokenData(BaseModel):
username: Optional[str] = None

View File

@@ -0,0 +1,23 @@
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))
from sqlalchemy import text
from core.database import engine
def add_superuser_field():
try:
with engine.connect() as conn:
conn.execute(text("ALTER TABLE users ADD COLUMN is_superuser BOOLEAN NOT NULL DEFAULT 0"))
conn.commit()
print("Successfully added is_superuser field to users table")
except Exception as e:
if "duplicate column name" in str(e).lower() or "already exists" in str(e).lower():
print("is_superuser field already exists, skipping")
else:
print(f"Error adding is_superuser field: {e}")
raise
if __name__ == "__main__":
add_superuser_field()

View File

@@ -0,0 +1,40 @@
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))
from core.database import SessionLocal
from core.security import get_password_hash
from db.crud import get_user_by_username, create_user_by_admin
def create_admin():
db = SessionLocal()
try:
existing_admin = get_user_by_username(db, username="admin")
if existing_admin:
print("Admin user already exists")
if not existing_admin.is_superuser:
existing_admin.is_superuser = True
db.commit()
print("Updated existing admin user to superuser")
return
hashed_password = get_password_hash("admin123456")
admin_user = create_user_by_admin(
db,
username="admin",
email="admin@example.com",
hashed_password=hashed_password,
is_superuser=True
)
print(f"Created admin user successfully: {admin_user.username}")
print("Username: admin")
print("Password: admin123456")
except Exception as e:
print(f"Error creating admin user: {e}")
raise
finally:
db.close()
if __name__ == "__main__":
create_admin()

View File

View File

@@ -0,0 +1,113 @@
import base64
import io
from pathlib import Path
import numpy as np
import soundfile as sf
from scipy import signal
def validate_ref_audio(audio_data: bytes, max_size_mb: int = 10) -> bool:
try:
size_mb = len(audio_data) / (1024 * 1024)
if size_mb > max_size_mb:
return False
buffer = io.BytesIO(audio_data)
audio_array, sample_rate = sf.read(buffer)
duration = len(audio_array) / sample_rate
if duration < 1.0 or duration > 30.0:
return False
return True
except Exception:
return False
def process_ref_audio(audio_data: bytes) -> tuple[np.ndarray, int]:
buffer = io.BytesIO(audio_data)
audio_array, orig_sr = sf.read(buffer)
if audio_array.ndim > 1:
audio_array = np.mean(audio_array, axis=1)
target_sr = 24000
if orig_sr != target_sr:
audio_array = resample_audio(audio_array, orig_sr, target_sr)
audio_array = audio_array.astype(np.float32)
return audio_array, target_sr
def resample_audio(audio_array: np.ndarray, orig_sr: int, target_sr: int = 24000) -> np.ndarray:
if orig_sr == target_sr:
return audio_array
num_samples = int(len(audio_array) * target_sr / orig_sr)
resampled = signal.resample(audio_array, num_samples)
return resampled.astype(np.float32)
def extract_audio_features(audio_array: np.ndarray, sample_rate: int) -> dict:
duration = len(audio_array) / sample_rate
rms_energy = np.sqrt(np.mean(audio_array ** 2))
return {
'duration': float(duration),
'sample_rate': int(sample_rate),
'num_samples': int(len(audio_array)),
'rms_energy': float(rms_energy)
}
def encode_audio_to_base64(audio_array: np.ndarray, sample_rate: int) -> str:
buffer = io.BytesIO()
sf.write(buffer, audio_array, sample_rate, format='WAV')
buffer.seek(0)
audio_bytes = buffer.read()
return base64.b64encode(audio_bytes).decode('utf-8')
def decode_base64_to_audio(base64_string: str) -> tuple[np.ndarray, int]:
audio_bytes = base64.b64decode(base64_string)
buffer = io.BytesIO(audio_bytes)
audio_array, sample_rate = sf.read(buffer)
return audio_array, sample_rate
def validate_audio_format(audio_data: bytes) -> bool:
try:
buffer = io.BytesIO(audio_data)
sf.read(buffer)
return True
except Exception:
return False
def get_audio_duration(audio_array: np.ndarray, sample_rate: int) -> float:
return len(audio_array) / sample_rate
def save_audio_file(
audio_array: np.ndarray,
sample_rate: int,
output_path: str | Path
) -> str:
output_path = Path(output_path)
output_path.parent.mkdir(parents=True, exist_ok=True)
if not isinstance(audio_array, np.ndarray):
audio_array = np.array(audio_array, dtype=np.float32)
if audio_array.ndim == 1:
pass
elif audio_array.ndim == 2:
if audio_array.shape[0] < audio_array.shape[1]:
audio_array = audio_array.T
else:
raise ValueError(f"Unexpected audio array shape: {audio_array.shape}")
audio_array = audio_array.astype(np.float32)
sf.write(str(output_path), audio_array, sample_rate, format='WAV', subtype='PCM_16')
return str(output_path)

View File

@@ -0,0 +1,80 @@
import threading
from typing import Dict
from pathlib import Path
from sqlalchemy.orm import Session
from db.models import VoiceCache
class CacheMetrics:
def __init__(self):
self._lock = threading.Lock()
self.cache_hits = 0
self.cache_misses = 0
self._user_hits: Dict[int, int] = {}
self._user_misses: Dict[int, int] = {}
def record_hit(self, user_id: int):
with self._lock:
self.cache_hits += 1
self._user_hits[user_id] = self._user_hits.get(user_id, 0) + 1
def record_miss(self, user_id: int):
with self._lock:
self.cache_misses += 1
self._user_misses[user_id] = self._user_misses.get(user_id, 0) + 1
def get_stats(self, db: Session, cache_dir: str) -> dict:
with self._lock:
total_requests = self.cache_hits + self.cache_misses
hit_rate = self.cache_hits / total_requests if total_requests > 0 else 0.0
total_entries = db.query(VoiceCache).count()
total_size_bytes = 0
cache_path = Path(cache_dir)
if cache_path.exists():
for cache_file in cache_path.glob("*.pkl"):
total_size_bytes += cache_file.stat().st_size
total_size_mb = total_size_bytes / (1024 * 1024)
user_stats = []
for user_id in set(list(self._user_hits.keys()) + list(self._user_misses.keys())):
hits = self._user_hits.get(user_id, 0)
misses = self._user_misses.get(user_id, 0)
total = hits + misses
user_hit_rate = hits / total if total > 0 else 0.0
user_cache_count = db.query(VoiceCache).filter(
VoiceCache.user_id == user_id
).count()
user_stats.append({
'user_id': user_id,
'hits': hits,
'misses': misses,
'hit_rate': user_hit_rate,
'cache_entries': user_cache_count
})
return {
'global': {
'total_requests': total_requests,
'cache_hits': self.cache_hits,
'cache_misses': self.cache_misses,
'hit_rate': hit_rate,
'total_entries': total_entries,
'total_size_mb': total_size_mb
},
'users': user_stats
}
def reset(self):
with self._lock:
self.cache_hits = 0
self.cache_misses = 0
self._user_hits.clear()
self._user_misses.clear()
cache_metrics = CacheMetrics()

View File

@@ -0,0 +1,102 @@
from typing import List, Dict
SUPPORTED_LANGUAGES = [
"Chinese", "English", "Japanese", "Korean", "German",
"French", "Russian", "Portuguese", "Spanish", "Italian",
"Auto", "Cantonese"
]
SUPPORTED_SPEAKERS = [
"Vivian", "Serena", "Uncle_Fu", "Dylan", "Eric",
"Ryan", "Aiden", "Ono_Anna", "Sohee"
]
SPEAKER_DESCRIPTIONS = {
"Vivian": "Female, professional and clear",
"Serena": "Female, gentle and warm",
"Uncle_Fu": "Male, mature and authoritative",
"Dylan": "Male, young and energetic",
"Eric": "Male, calm and steady",
"Ryan": "Male, friendly and casual",
"Aiden": "Male, deep and resonant",
"Ono_Anna": "Female, cute and lively",
"Sohee": "Female, soft and melodious"
}
def validate_language(language: str) -> str:
normalized = language.strip()
for supported in SUPPORTED_LANGUAGES:
if normalized.lower() == supported.lower():
return supported
raise ValueError(
f"Unsupported language: {language}. "
f"Supported languages: {', '.join(SUPPORTED_LANGUAGES)}"
)
def validate_speaker(speaker: str) -> str:
normalized = speaker.strip()
for supported in SUPPORTED_SPEAKERS:
if normalized.lower() == supported.lower():
return supported
raise ValueError(
f"Unsupported speaker: {speaker}. "
f"Supported speakers: {', '.join(SUPPORTED_SPEAKERS)}"
)
def validate_text_length(text: str, max_length: int = 1000) -> str:
if not text or not text.strip():
raise ValueError("Text cannot be empty")
if len(text) > max_length:
raise ValueError(
f"Text length ({len(text)}) exceeds maximum ({max_length})"
)
return text.strip()
def validate_generation_params(params: dict) -> dict:
validated = {}
validated['max_new_tokens'] = params.get('max_new_tokens', 2048)
if not 128 <= validated['max_new_tokens'] <= 4096:
raise ValueError("max_new_tokens must be between 128 and 4096")
validated['temperature'] = params.get('temperature', 0.9)
if not 0.1 <= validated['temperature'] <= 2.0:
raise ValueError("temperature must be between 0.1 and 2.0")
validated['top_k'] = params.get('top_k', 50)
if not 1 <= validated['top_k'] <= 100:
raise ValueError("top_k must be between 1 and 100")
validated['top_p'] = params.get('top_p', 1.0)
if not 0.0 <= validated['top_p'] <= 1.0:
raise ValueError("top_p must be between 0.0 and 1.0")
validated['repetition_penalty'] = params.get('repetition_penalty', 1.05)
if not 1.0 <= validated['repetition_penalty'] <= 2.0:
raise ValueError("repetition_penalty must be between 1.0 and 2.0")
return validated
def get_supported_languages() -> List[str]:
return SUPPORTED_LANGUAGES.copy()
def get_supported_speakers() -> List[dict]:
return [
{
"name": speaker,
"description": SPEAKER_DESCRIPTIONS.get(speaker, "")
}
for speaker in SUPPORTED_SPEAKERS
]

View File

@@ -0,0 +1,2 @@
VITE_API_URL=http://localhost:8000
VITE_APP_NAME=Qwen3-TTS

View File

@@ -0,0 +1,2 @@
VITE_API_URL=https://api.example.com
VITE_APP_NAME=Qwen3-TTS

27
qwen3-tts-frontend/.gitignore vendored Normal file
View File

@@ -0,0 +1,27 @@
# Logs
logs
*.log
npm-debug.log*
yarn-debug.log*
yarn-error.log*
pnpm-debug.log*
lerna-debug.log*
node_modules
dist
dist-ssr
*.local
.env
.env.local
.env.*.local
# Editor directories and files
.vscode/*
!.vscode/extensions.json
.idea
.DS_Store
*.suo
*.ntvs*
*.njsproj
*.sln
*.sw?

View File

@@ -0,0 +1,20 @@
{
"$schema": "https://ui.shadcn.com/schema.json",
"style": "default",
"rsc": false,
"tsx": true,
"tailwind": {
"config": "tailwind.config.js",
"css": "src/index.css",
"baseColor": "slate",
"cssVariables": true,
"prefix": ""
},
"aliases": {
"components": "@/components",
"utils": "@/lib/utils",
"ui": "@/components/ui",
"lib": "@/lib",
"hooks": "@/hooks"
}
}

View File

@@ -0,0 +1,23 @@
import js from '@eslint/js'
import globals from 'globals'
import reactHooks from 'eslint-plugin-react-hooks'
import reactRefresh from 'eslint-plugin-react-refresh'
import tseslint from 'typescript-eslint'
import { defineConfig, globalIgnores } from 'eslint/config'
export default defineConfig([
globalIgnores(['dist']),
{
files: ['**/*.{ts,tsx}'],
extends: [
js.configs.recommended,
tseslint.configs.recommended,
reactHooks.configs.flat.recommended,
reactRefresh.configs.vite,
],
languageOptions: {
ecmaVersion: 2020,
globals: globals.browser,
},
},
])

Binary file not shown.

Binary file not shown.

View File

@@ -0,0 +1,25 @@
<!doctype html>
<html lang="en">
<head>
<meta charset="UTF-8" />
<link rel="icon" type="image/svg+xml" href="/vite.svg" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<link rel="preload" href="/fonts/noto-serif-regular.woff2" as="font" type="font/woff2" crossorigin>
<title>Qwen3-TTS-WebUI</title>
<script>
(function() {
try {
const theme = localStorage.getItem('theme');
const systemDark = window.matchMedia('(prefers-color-scheme: dark)').matches;
if (theme === 'dark' || (!theme && systemDark)) {
document.documentElement.classList.add('dark');
}
} catch (e) {}
})();
</script>
</head>
<body>
<div id="root"></div>
<script type="module" src="/src/main.tsx"></script>
</body>
</html>

6157
qwen3-tts-frontend/package-lock.json generated Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,58 @@
{
"name": "qwen3-tts-frontend",
"private": true,
"version": "0.0.0",
"type": "module",
"scripts": {
"dev": "vite",
"build": "tsc -b && vite build",
"lint": "eslint .",
"preview": "vite preview"
},
"dependencies": {
"@hookform/resolvers": "^5.2.2",
"@radix-ui/react-alert-dialog": "^1.1.15",
"@radix-ui/react-checkbox": "^1.3.3",
"@radix-ui/react-collapsible": "^1.1.12",
"@radix-ui/react-dialog": "^1.1.15",
"@radix-ui/react-label": "^2.1.8",
"@radix-ui/react-progress": "^1.1.8",
"@radix-ui/react-scroll-area": "^1.2.10",
"@radix-ui/react-select": "^2.2.6",
"@radix-ui/react-separator": "^1.1.8",
"@radix-ui/react-slider": "^1.3.6",
"@radix-ui/react-slot": "^1.2.4",
"@radix-ui/react-tabs": "^1.1.13",
"@radix-ui/react-tooltip": "^1.2.8",
"axios": "^1.13.3",
"class-variance-authority": "^0.7.1",
"clsx": "^2.1.1",
"lucide-react": "^0.563.0",
"next-themes": "^0.4.6",
"react": "^19.2.0",
"react-dom": "^19.2.0",
"react-h5-audio-player": "^3.10.1",
"react-hook-form": "^7.71.1",
"react-router-dom": "^7.13.0",
"sonner": "^2.0.7",
"tailwind-merge": "^3.4.0",
"zod": "^4.3.6"
},
"devDependencies": {
"@eslint/js": "^9.39.1",
"@types/node": "^24.10.9",
"@types/react": "^19.2.5",
"@types/react-dom": "^19.2.3",
"@vitejs/plugin-react": "^5.1.1",
"autoprefixer": "^10.4.23",
"eslint": "^9.39.1",
"eslint-plugin-react-hooks": "^7.0.1",
"eslint-plugin-react-refresh": "^0.4.24",
"globals": "^16.5.0",
"postcss": "^8.5.6",
"tailwindcss": "^3.4.19",
"typescript": "~5.9.3",
"typescript-eslint": "^8.46.4",
"vite": "^7.2.4"
}
}

View File

@@ -0,0 +1,6 @@
export default {
plugins: {
tailwindcss: {},
autoprefixer: {},
},
}

View File

@@ -0,0 +1 @@
<svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" class="iconify iconify--logos" width="31.88" height="32" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 257"><defs><linearGradient id="IconifyId1813088fe1fbc01fb466" x1="-.828%" x2="57.636%" y1="7.652%" y2="78.411%"><stop offset="0%" stop-color="#41D1FF"></stop><stop offset="100%" stop-color="#BD34FE"></stop></linearGradient><linearGradient id="IconifyId1813088fe1fbc01fb467" x1="43.376%" x2="50.316%" y1="2.242%" y2="89.03%"><stop offset="0%" stop-color="#FFEA83"></stop><stop offset="8.333%" stop-color="#FFDD35"></stop><stop offset="100%" stop-color="#FFA800"></stop></linearGradient></defs><path fill="url(#IconifyId1813088fe1fbc01fb466)" d="M255.153 37.938L134.897 252.976c-2.483 4.44-8.862 4.466-11.382.048L.875 37.958c-2.746-4.814 1.371-10.646 6.827-9.67l120.385 21.517a6.537 6.537 0 0 0 2.322-.004l117.867-21.483c5.438-.991 9.574 4.796 6.877 9.62Z"></path><path fill="url(#IconifyId1813088fe1fbc01fb467)" d="M185.432.063L96.44 17.501a3.268 3.268 0 0 0-2.634 3.014l-5.474 92.456a3.268 3.268 0 0 0 3.997 3.378l24.777-5.718c2.318-.535 4.413 1.507 3.936 3.838l-7.361 36.047c-.495 2.426 1.782 4.5 4.151 3.78l15.304-4.649c2.372-.72 4.652 1.36 4.15 3.788l-11.698 56.621c-.732 3.542 3.979 5.473 5.943 2.437l1.313-2.028l72.516-144.72c1.215-2.423-.88-5.186-3.54-4.672l-25.505 4.922c-2.396.462-4.435-1.77-3.759-4.114l16.646-57.705c.677-2.35-1.37-4.583-3.769-4.113Z"></path></svg>

After

Width:  |  Height:  |  Size: 1.5 KiB

View File

@@ -0,0 +1,98 @@
import { lazy, Suspense } from 'react'
import { BrowserRouter, Routes, Route, Navigate } from 'react-router-dom'
import { Toaster } from 'sonner'
import { ThemeProvider } from '@/contexts/ThemeContext'
import { AuthProvider, useAuth } from '@/contexts/AuthContext'
import { AppProvider } from '@/contexts/AppContext'
import { JobProvider } from '@/contexts/JobContext'
import ErrorBoundary from '@/components/ErrorBoundary'
import LoadingScreen from '@/components/LoadingScreen'
import { SuperAdminRoute } from '@/components/SuperAdminRoute'
const Login = lazy(() => import('@/pages/Login'))
const Home = lazy(() => import('@/pages/Home'))
const UserManagement = lazy(() => import('@/pages/UserManagement'))
function ProtectedRoute({ children }: { children: React.ReactNode }) {
const { isAuthenticated, isLoading } = useAuth()
if (isLoading) {
return (
<div className="min-h-screen flex items-center justify-center">
<div className="text-lg">...</div>
</div>
)
}
if (!isAuthenticated) {
return <Navigate to="/login" replace />
}
return <>{children}</>
}
function PublicRoute({ children }: { children: React.ReactNode }) {
const { isAuthenticated, isLoading } = useAuth()
if (isLoading) {
return (
<div className="min-h-screen flex items-center justify-center">
<div className="text-lg">...</div>
</div>
)
}
if (isAuthenticated) {
return <Navigate to="/" replace />
}
return <>{children}</>
}
function App() {
return (
<ThemeProvider>
<ErrorBoundary>
<BrowserRouter>
<AuthProvider>
<Toaster position="top-right" />
<Suspense fallback={<LoadingScreen />}>
<Routes>
<Route
path="/login"
element={
<PublicRoute>
<Login />
</PublicRoute>
}
/>
<Route
path="/"
element={
<ProtectedRoute>
<AppProvider>
<JobProvider>
<Home />
</JobProvider>
</AppProvider>
</ProtectedRoute>
}
/>
<Route
path="/users"
element={
<SuperAdminRoute>
<UserManagement />
</SuperAdminRoute>
}
/>
</Routes>
</Suspense>
</AuthProvider>
</BrowserRouter>
</ErrorBoundary>
</ThemeProvider>
)
}
export default App

View File

@@ -0,0 +1 @@
<svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" class="iconify iconify--logos" width="35.93" height="32" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 228"><path fill="#00D8FF" d="M210.483 73.824a171.49 171.49 0 0 0-8.24-2.597c.465-1.9.893-3.777 1.273-5.621c6.238-30.281 2.16-54.676-11.769-62.708c-13.355-7.7-35.196.329-57.254 19.526a171.23 171.23 0 0 0-6.375 5.848a155.866 155.866 0 0 0-4.241-3.917C100.759 3.829 77.587-4.822 63.673 3.233C50.33 10.957 46.379 33.89 51.995 62.588a170.974 170.974 0 0 0 1.892 8.48c-3.28.932-6.445 1.924-9.474 2.98C17.309 83.498 0 98.307 0 113.668c0 15.865 18.582 31.778 46.812 41.427a145.52 145.52 0 0 0 6.921 2.165a167.467 167.467 0 0 0-2.01 9.138c-5.354 28.2-1.173 50.591 12.134 58.266c13.744 7.926 36.812-.22 59.273-19.855a145.567 145.567 0 0 0 5.342-4.923a168.064 168.064 0 0 0 6.92 6.314c21.758 18.722 43.246 26.282 56.54 18.586c13.731-7.949 18.194-32.003 12.4-61.268a145.016 145.016 0 0 0-1.535-6.842c1.62-.48 3.21-.974 4.76-1.488c29.348-9.723 48.443-25.443 48.443-41.52c0-15.417-17.868-30.326-45.517-39.844Zm-6.365 70.984c-1.4.463-2.836.91-4.3 1.345c-3.24-10.257-7.612-21.163-12.963-32.432c5.106-11 9.31-21.767 12.459-31.957c2.619.758 5.16 1.557 7.61 2.4c23.69 8.156 38.14 20.213 38.14 29.504c0 9.896-15.606 22.743-40.946 31.14Zm-10.514 20.834c2.562 12.94 2.927 24.64 1.23 33.787c-1.524 8.219-4.59 13.698-8.382 15.893c-8.067 4.67-25.32-1.4-43.927-17.412a156.726 156.726 0 0 1-6.437-5.87c7.214-7.889 14.423-17.06 21.459-27.246c12.376-1.098 24.068-2.894 34.671-5.345a134.17 134.17 0 0 1 1.386 6.193ZM87.276 214.515c-7.882 2.783-14.16 2.863-17.955.675c-8.075-4.657-11.432-22.636-6.853-46.752a156.923 156.923 0 0 1 1.869-8.499c10.486 2.32 22.093 3.988 34.498 4.994c7.084 9.967 14.501 19.128 21.976 27.15a134.668 134.668 0 0 1-4.877 4.492c-9.933 8.682-19.886 14.842-28.658 17.94ZM50.35 144.747c-12.483-4.267-22.792-9.812-29.858-15.863c-6.35-5.437-9.555-10.836-9.555-15.216c0-9.322 13.897-21.212 37.076-29.293c2.813-.98 5.757-1.905 8.812-2.773c3.204 10.42 7.406 21.315 12.477 32.332c-5.137 11.18-9.399 22.249-12.634 32.792a134.718 134.718 0 0 1-6.318-1.979Zm12.378-84.26c-4.811-24.587-1.616-43.134 6.425-47.789c8.564-4.958 27.502 2.111 47.463 19.835a144.318 144.318 0 0 1 3.841 3.545c-7.438 7.987-14.787 17.08-21.808 26.988c-12.04 1.116-23.565 2.908-34.161 5.309a160.342 160.342 0 0 1-1.76-7.887Zm110.427 27.268a347.8 347.8 0 0 0-7.785-12.803c8.168 1.033 15.994 2.404 23.343 4.08c-2.206 7.072-4.956 14.465-8.193 22.045a381.151 381.151 0 0 0-7.365-13.322Zm-45.032-43.861c5.044 5.465 10.096 11.566 15.065 18.186a322.04 322.04 0 0 0-30.257-.006c4.974-6.559 10.069-12.652 15.192-18.18ZM82.802 87.83a323.167 323.167 0 0 0-7.227 13.238c-3.184-7.553-5.909-14.98-8.134-22.152c7.304-1.634 15.093-2.97 23.209-3.984a321.524 321.524 0 0 0-7.848 12.897Zm8.081 65.352c-8.385-.936-16.291-2.203-23.593-3.793c2.26-7.3 5.045-14.885 8.298-22.6a321.187 321.187 0 0 0 7.257 13.246c2.594 4.48 5.28 8.868 8.038 13.147Zm37.542 31.03c-5.184-5.592-10.354-11.779-15.403-18.433c4.902.192 9.899.29 14.978.29c5.218 0 10.376-.117 15.453-.343c-4.985 6.774-10.018 12.97-15.028 18.486Zm52.198-57.817c3.422 7.8 6.306 15.345 8.596 22.52c-7.422 1.694-15.436 3.058-23.88 4.071a382.417 382.417 0 0 0 7.859-13.026a347.403 347.403 0 0 0 7.425-13.565Zm-16.898 8.101a358.557 358.557 0 0 1-12.281 19.815a329.4 329.4 0 0 1-23.444.823c-7.967 0-15.716-.248-23.178-.732a310.202 310.202 0 0 1-12.513-19.846h.001a307.41 307.41 0 0 1-10.923-20.627a310.278 310.278 0 0 1 10.89-20.637l-.001.001a307.318 307.318 0 0 1 12.413-19.761c7.613-.576 15.42-.876 23.31-.876H128c7.926 0 15.743.303 23.354.883a329.357 329.357 0 0 1 12.335 19.695a358.489 358.489 0 0 1 11.036 20.54a329.472 329.472 0 0 1-11 20.722Zm22.56-122.124c8.572 4.944 11.906 24.881 6.52 51.026c-.344 1.668-.73 3.367-1.15 5.09c-10.622-2.452-22.155-4.275-34.23-5.408c-7.034-10.017-14.323-19.124-21.64-27.008a160.789 160.789 0 0 1 5.888-5.4c18.9-16.447 36.564-22.941 44.612-18.3ZM128 90.808c12.625 0 22.86 10.235 22.86 22.86s-10.235 22.86-22.86 22.86s-22.86-10.235-22.86-22.86s10.235-22.86 22.86-22.86Z"></path></svg>

After

Width:  |  Height:  |  Size: 4.0 KiB

View File

@@ -0,0 +1,44 @@
import { useState } from 'react'
import { Tabs, TabsContent, TabsList, TabsTrigger } from '@/components/ui/tabs'
import { Upload, Mic } from 'lucide-react'
import { FileUploader } from '@/components/FileUploader'
import { AudioRecorder } from '@/components/AudioRecorder'
interface AudioInputSelectorProps {
value: File | null
onChange: (file: File | null) => void
error?: string
}
export function AudioInputSelector({ value, onChange, error }: AudioInputSelectorProps) {
const [activeTab, setActiveTab] = useState<string>('upload')
const handleTabChange = (newTab: string) => {
onChange(null)
setActiveTab(newTab)
}
return (
<Tabs value={activeTab} onValueChange={handleTabChange} className="w-full">
<TabsList className="grid w-full grid-cols-2">
<TabsTrigger value="upload" className="flex items-center gap-2">
<Upload className="h-4 w-4" />
</TabsTrigger>
<TabsTrigger value="record" className="flex items-center gap-2">
<Mic className="h-4 w-4" />
</TabsTrigger>
</TabsList>
<TabsContent value="upload" className="mt-4">
<FileUploader value={value} onChange={onChange} error={error} />
</TabsContent>
<TabsContent value="record" className="mt-4">
<AudioRecorder onChange={onChange} />
{error && <p className="text-sm text-destructive mt-2">{error}</p>}
</TabsContent>
</Tabs>
)
}

View File

@@ -0,0 +1,102 @@
.audioPlayerWrapper {
display: flex;
align-items: center;
gap: 0.5rem;
border: 1px solid hsl(var(--border));
border-radius: var(--radius);
padding: 0.75rem;
background: transparent;
}
.audioPlayerWrapper :global(.rhap_container) {
flex: 1;
background-color: transparent;
box-shadow: none;
padding: 0;
}
.audioPlayerWrapper :global(.rhap_main) {
--rhap_theme-color: hsl(var(--primary));
--rhap_background-color: transparent;
--rhap_bar-color: hsl(var(--secondary));
--rhap_time-color: hsl(var(--muted-foreground));
}
.audioPlayerWrapper :global(.rhap_progress-indicator),
.audioPlayerWrapper :global(.rhap_volume-indicator) {
background: hsl(var(--primary));
}
.audioPlayerWrapper :global(.rhap_progress-filled),
.audioPlayerWrapper :global(.rhap_volume-bar) {
background-color: hsl(var(--primary));
}
.audioPlayerWrapper :global(.rhap_progress-bar),
.audioPlayerWrapper :global(.rhap_volume-container) {
background-color: hsl(var(--secondary));
}
.audioPlayerWrapper :global(.rhap_progress-bar) {
height: 6px;
border-radius: 3px;
transition: height 0.15s ease;
}
.audioPlayerWrapper :global(.rhap_progress-bar):hover {
height: 7px;
}
.audioPlayerWrapper :global(.rhap_progress-filled) {
border-radius: 3px;
}
.audioPlayerWrapper :global(.rhap_progress-indicator) {
width: 14px;
height: 14px;
top: -4px;
margin-left: -7px;
transition: transform 0.15s ease, box-shadow 0.15s ease;
}
.audioPlayerWrapper :global(.rhap_progress-indicator):hover {
transform: scale(1.1);
}
.audioPlayerWrapper :global(.rhap_progress-container) {
margin: 0 0.5rem;
}
.audioPlayerWrapper :global(.rhap_horizontal .rhap_controls-section) {
margin-left: 0;
}
.audioPlayerWrapper :global(.rhap_time) {
color: hsl(var(--muted-foreground));
font-size: 0.875rem;
font-weight: 500;
}
.audioPlayerWrapper :global(.rhap_button-clear) {
color: hsl(var(--foreground));
font-size: 1.25rem;
}
.audioPlayerWrapper :global(.rhap_button-clear):hover {
color: hsl(var(--primary));
}
.audioPlayerWrapper :global(.rhap_main-controls-button) {
width: 40px;
height: 40px;
}
.audioPlayerWrapper :global(.rhap_main-controls-button svg) {
width: 22px;
height: 22px;
}
.downloadButton {
min-height: 40px;
min-width: 40px;
}

View File

@@ -0,0 +1,121 @@
import { useRef, useState, useEffect, useCallback, memo } from 'react'
import AudioPlayerLib from 'react-h5-audio-player'
import 'react-h5-audio-player/lib/styles.css'
import { Button } from '@/components/ui/button'
import { Download } from 'lucide-react'
import apiClient from '@/lib/api'
import styles from './AudioPlayer.module.css'
interface AudioPlayerProps {
audioUrl: string
jobId: number
}
const AudioPlayer = memo(({ audioUrl, jobId }: AudioPlayerProps) => {
const [blobUrl, setBlobUrl] = useState<string>('')
const [isLoading, setIsLoading] = useState(false)
const [loadError, setLoadError] = useState<string | null>(null)
const previousAudioUrlRef = useRef<string>('')
const playerRef = useRef<any>(null)
useEffect(() => {
if (!audioUrl || audioUrl === previousAudioUrlRef.current) return
let active = true
const prevBlobUrl = blobUrl
const fetchAudio = async () => {
setIsLoading(true)
setLoadError(null)
if (prevBlobUrl) {
URL.revokeObjectURL(prevBlobUrl)
}
try {
const response = await apiClient.get(audioUrl, { responseType: 'blob' })
if (active) {
const url = URL.createObjectURL(response.data)
setBlobUrl(url)
previousAudioUrlRef.current = audioUrl
}
} catch (error) {
console.error("Failed to load audio:", error)
if (active) {
setLoadError('Failed to load audio')
}
} finally {
if (active) {
setIsLoading(false)
}
}
}
fetchAudio()
return () => {
active = false
}
}, [audioUrl])
useEffect(() => {
return () => {
if (blobUrl) URL.revokeObjectURL(blobUrl)
}
}, [])
const handleDownload = useCallback(() => {
const link = document.createElement('a')
link.href = blobUrl || audioUrl
link.download = `tts-${jobId}-${Date.now()}.wav`
link.click()
}, [blobUrl, audioUrl, jobId])
if (isLoading) {
return (
<div className="flex items-center justify-center p-4 border rounded-lg">
<span className="text-sm text-muted-foreground">Loading...</span>
</div>
)
}
if (loadError) {
return (
<div className="flex items-center justify-center p-4 border rounded-lg">
<span className="text-sm text-destructive">{loadError}</span>
</div>
)
}
if (!blobUrl) {
return null
}
return (
<div className={styles.audioPlayerWrapper}>
<AudioPlayerLib
src={blobUrl}
layout="horizontal"
customAdditionalControls={[
<Button
key="download"
type="button"
variant="ghost"
size="icon"
onClick={handleDownload}
className={styles.downloadButton}
>
<Download className="h-4 w-4" />
</Button>
]}
customVolumeControls={[]}
showJumpControls={false}
volume={1}
/>
</div>
)
})
AudioPlayer.displayName = 'AudioPlayer'
export { AudioPlayer }

View File

@@ -0,0 +1,153 @@
import { useEffect, useState } from 'react'
import { Button } from '@/components/ui/button'
import { Mic, Trash2, RotateCcw, FileAudio } from 'lucide-react'
import { toast } from 'sonner'
import { useAudioRecorder } from '@/hooks/useAudioRecorder'
import { useAudioValidation } from '@/hooks/useAudioValidation'
interface AudioRecorderProps {
onChange: (file: File | null) => void
}
export function AudioRecorder({ onChange }: AudioRecorderProps) {
const {
isRecording,
recordingDuration,
audioBlob,
error: recorderError,
isSupported,
startRecording,
stopRecording,
clearRecording,
} = useAudioRecorder()
const { validateAudioFile } = useAudioValidation()
const [audioInfo, setAudioInfo] = useState<{ duration: number; size: number } | null>(null)
const [validationError, setValidationError] = useState<string | null>(null)
useEffect(() => {
if (recorderError) {
toast.error(recorderError)
}
}, [recorderError])
useEffect(() => {
if (audioBlob) {
handleValidateRecording(audioBlob)
}
}, [audioBlob])
const handleValidateRecording = async (blob: Blob) => {
const file = new File([blob], 'recording.wav', { type: 'audio/wav' })
const result = await validateAudioFile(file)
if (result.valid && result.duration) {
onChange(file)
setAudioInfo({ duration: result.duration, size: file.size })
setValidationError(null)
} else {
setValidationError(result.error || '录音验证失败')
clearRecording()
onChange(null)
}
}
const handleMouseDown = () => {
if (!isRecording && !audioBlob) {
startRecording()
}
}
const handleMouseUp = () => {
if (isRecording) {
stopRecording()
}
}
const handleReset = () => {
clearRecording()
setAudioInfo(null)
setValidationError(null)
onChange(null)
}
const handleKeyDown = (e: React.KeyboardEvent) => {
if (e.key === ' ' && !isRecording && !audioBlob) {
e.preventDefault()
startRecording()
}
}
const handleKeyUp = (e: React.KeyboardEvent) => {
if (e.key === ' ' && isRecording) {
e.preventDefault()
stopRecording()
}
}
if (!isSupported) {
return (
<div className="p-4 border rounded bg-muted text-muted-foreground text-sm">
</div>
)
}
if (audioBlob && audioInfo) {
return (
<div className="space-y-2">
<div className="flex items-center gap-2 p-3 border rounded">
<FileAudio className="h-5 w-5 text-muted-foreground" />
<div className="flex-1 min-w-0">
<p className="text-sm font-medium"></p>
<p className="text-xs text-muted-foreground">
{(audioInfo.size / 1024 / 1024).toFixed(2)} MB · {audioInfo.duration.toFixed(1)}
</p>
</div>
<Button type="button" variant="ghost" size="icon" onClick={handleReset}>
<Trash2 className="h-4 w-4" />
</Button>
</div>
</div>
)
}
return (
<div className="space-y-2">
<Button
type="button"
variant={isRecording ? 'default' : 'outline'}
className={`w-full h-24 ${isRecording ? 'animate-pulse' : ''}`}
onMouseDown={handleMouseDown}
onMouseUp={handleMouseUp}
onMouseLeave={handleMouseUp}
onTouchStart={handleMouseDown}
onTouchEnd={handleMouseUp}
onKeyDown={handleKeyDown}
onKeyUp={handleKeyUp}
>
<div className="flex flex-col items-center gap-2">
<Mic className="h-8 w-8" />
{isRecording ? (
<>
<span className="text-lg font-semibold">{recordingDuration.toFixed(1)}s</span>
<span className="text-xs"></span>
</>
) : (
<span></span>
)}
</div>
</Button>
{validationError && (
<div className="flex items-center justify-between p-2 border border-destructive rounded bg-destructive/10">
<p className="text-sm text-destructive">{validationError}</p>
<Button type="button" variant="ghost" size="sm" onClick={handleReset}>
<RotateCcw className="h-4 w-4" />
</Button>
</div>
)}
</div>
)
}

View File

@@ -0,0 +1,73 @@
import { Component, type ReactNode } from 'react';
interface Props {
children: ReactNode;
}
interface State {
hasError: boolean;
error: Error | null;
}
class ErrorBoundary extends Component<Props, State> {
constructor(props: Props) {
super(props);
this.state = { hasError: false, error: null };
}
static getDerivedStateFromError(error: Error): State {
return { hasError: true, error };
}
componentDidCatch(error: Error, errorInfo: React.ErrorInfo) {
console.error('ErrorBoundary caught error:', error, errorInfo);
}
handleReset = () => {
this.setState({ hasError: false, error: null });
};
render() {
if (this.state.hasError) {
return (
<div className="flex items-center justify-center min-h-screen bg-background p-4">
<div className="max-w-md w-full space-y-4 text-center">
<div className="space-y-2">
<h1 className="text-2xl font-bold text-destructive">Something went wrong</h1>
<p className="text-muted-foreground">
An unexpected error occurred. Please try refreshing the page.
</p>
</div>
{this.state.error && (
<div className="p-4 bg-muted rounded-lg text-left">
<p className="text-sm font-mono text-destructive break-all">
{this.state.error.message}
</p>
</div>
)}
<div className="flex gap-2 justify-center">
<button
onClick={this.handleReset}
className="px-4 py-2 bg-primary text-primary-foreground rounded-md hover:bg-primary/90 transition-colors"
>
Try Again
</button>
<button
onClick={() => window.location.reload()}
className="px-4 py-2 bg-secondary text-secondary-foreground rounded-md hover:bg-secondary/90 transition-colors"
>
Reload Page
</button>
</div>
</div>
</div>
);
}
return this.props.children;
}
}
export default ErrorBoundary;

View File

@@ -0,0 +1,89 @@
import { useRef, useState, type ChangeEvent } from 'react'
import { Button } from '@/components/ui/button'
import { Upload, X, FileAudio } from 'lucide-react'
import { toast } from 'sonner'
import { useAudioValidation } from '@/hooks/useAudioValidation'
interface AudioInfo {
duration: number
size: number
}
interface FileUploaderProps {
value: File | null
onChange: (file: File | null) => void
error?: string
}
export function FileUploader({ value, onChange, error }: FileUploaderProps) {
const inputRef = useRef<HTMLInputElement>(null)
const { validateAudioFile } = useAudioValidation()
const [isValidating, setIsValidating] = useState(false)
const [audioInfo, setAudioInfo] = useState<AudioInfo | null>(null)
const handleFileSelect = async (e: ChangeEvent<HTMLInputElement>) => {
const file = e.target.files?.[0]
if (!file) return
setIsValidating(true)
const result = await validateAudioFile(file)
setIsValidating(false)
if (result.valid && result.duration) {
onChange(file)
setAudioInfo({ duration: result.duration, size: file.size })
} else {
toast.error(result.error || '文件验证失败')
e.target.value = ''
}
}
const handleRemove = () => {
onChange(null)
setAudioInfo(null)
if (inputRef.current) {
inputRef.current.value = ''
}
}
return (
<div className="space-y-2">
{!value ? (
<Button
type="button"
variant="outline"
onClick={() => inputRef.current?.click()}
disabled={isValidating}
>
<Upload className="mr-2 h-4 w-4" />
{isValidating ? '验证中...' : '选择音频文件'}
</Button>
) : (
<div className="flex items-center gap-2 p-3 border rounded">
<FileAudio className="h-5 w-5 text-muted-foreground" />
<div className="flex-1 min-w-0">
<p className="text-sm font-medium truncate">{value.name}</p>
{audioInfo && (
<p className="text-xs text-muted-foreground">
{(audioInfo.size / 1024 / 1024).toFixed(2)} MB · {audioInfo.duration.toFixed(1)}
</p>
)}
</div>
<Button type="button" variant="ghost" size="icon" onClick={handleRemove}>
<X className="h-4 w-4" />
</Button>
</div>
)}
<input
ref={inputRef}
type="file"
accept="audio/wav,audio/mp3,audio/mpeg"
className="hidden"
onChange={handleFileSelect}
/>
{error && <p className="text-sm text-destructive">{error}</p>}
</div>
)
}

View File

@@ -0,0 +1,29 @@
const FormSkeleton = () => {
return (
<div className="space-y-6 animate-pulse">
<div className="space-y-2">
<div className="h-4 bg-muted rounded w-24" />
<div className="h-10 bg-muted rounded" />
</div>
<div className="space-y-2">
<div className="h-4 bg-muted rounded w-32" />
<div className="h-10 bg-muted rounded" />
</div>
<div className="space-y-2">
<div className="h-4 bg-muted rounded w-28" />
<div className="h-32 bg-muted rounded" />
</div>
<div className="space-y-2">
<div className="h-4 bg-muted rounded w-36" />
<div className="h-10 bg-muted rounded" />
</div>
<div className="h-10 bg-muted rounded w-full" />
</div>
);
};
export default FormSkeleton;

View File

@@ -0,0 +1,172 @@
import { memo, useState } from 'react'
import type { Job } from '@/types/job'
import { Badge } from '@/components/ui/badge'
import { Button } from '@/components/ui/button'
import {
AlertDialog,
AlertDialogAction,
AlertDialogCancel,
AlertDialogContent,
AlertDialogDescription,
AlertDialogFooter,
AlertDialogHeader,
AlertDialogTitle,
AlertDialogTrigger,
} from '@/components/ui/alert-dialog'
import { Trash2, AlertCircle, Loader2, Clock, Eye } from 'lucide-react'
import { getRelativeTime, cn } from '@/lib/utils'
import { JobDetailDialog } from '@/components/JobDetailDialog'
interface HistoryItemProps {
job: Job
onDelete: (id: number) => void
onLoadParams: (job: Job) => void
}
const jobTypeBadgeVariant = {
custom_voice: 'default' as const,
voice_design: 'secondary' as const,
voice_clone: 'outline' as const,
}
const jobTypeLabel = {
custom_voice: '自定义音色',
voice_design: '音色设计',
voice_clone: '声音克隆',
}
const HistoryItem = memo(({ job, onDelete, onLoadParams }: HistoryItemProps) => {
const [detailDialogOpen, setDetailDialogOpen] = useState(false)
const getLanguageDisplay = (lang: string | undefined) => {
if (!lang || lang === 'Auto') return '自动检测'
return lang
}
const handleCardClick = (e: React.MouseEvent) => {
if ((e.target as HTMLElement).closest('button')) return
setDetailDialogOpen(true)
}
return (
<div
className={cn(
"relative border rounded-lg p-4 pb-14 space-y-3 hover:bg-accent/50 transition-colors cursor-pointer",
job.status === 'failed' && "border-destructive/50"
)}
onClick={handleCardClick}
>
<div className="flex items-start justify-between gap-2">
<Badge variant={jobTypeBadgeVariant[job.type]}>
{jobTypeLabel[job.type]}
</Badge>
<div className="flex items-center gap-1.5 text-xs text-muted-foreground whitespace-nowrap">
<span>{getRelativeTime(job.created_at)}</span>
<Eye className="w-3.5 h-3.5" />
</div>
</div>
<div className="space-y-2 text-sm">
{job.parameters?.text && (
<div>
<span className="text-muted-foreground">: </span>
<span className="line-clamp-2">{job.parameters.text}</span>
</div>
)}
<div className="text-muted-foreground">
: {getLanguageDisplay(job.parameters?.language)}
</div>
{job.type === 'custom_voice' && job.parameters?.speaker && (
<div className="text-muted-foreground">
: {job.parameters.speaker}
</div>
)}
{job.type === 'voice_design' && job.parameters?.instruct && (
<div>
<span className="text-muted-foreground">: </span>
<span className="text-xs line-clamp-2">{job.parameters.instruct}</span>
</div>
)}
{job.type === 'voice_clone' && job.parameters?.ref_text && (
<div>
<span className="text-muted-foreground">: </span>
<span className="text-xs line-clamp-1">{job.parameters.ref_text}</span>
</div>
)}
</div>
{job.status === 'processing' && (
<div className="flex items-center gap-2 text-sm text-muted-foreground">
<Loader2 className="w-4 h-4 animate-spin" />
<span>...</span>
</div>
)}
{job.status === 'pending' && (
<div className="flex items-center gap-2 text-sm text-muted-foreground">
<Clock className="w-4 h-4" />
<span>...</span>
</div>
)}
{job.status === 'failed' && job.error_message && (
<div className="flex items-start gap-2 p-2 bg-destructive/10 rounded-md">
<AlertCircle className="w-4 h-4 text-destructive mt-0.5 shrink-0" />
<span className="text-sm text-destructive">{job.error_message}</span>
</div>
)}
<div className="absolute bottom-3 right-3">
<AlertDialog>
<AlertDialogTrigger asChild>
<Button
variant="destructive"
size="sm"
className="min-h-[44px] md:min-h-[36px]"
>
<Trash2 className="w-4 h-4" />
</Button>
</AlertDialogTrigger>
<AlertDialogContent>
<AlertDialogHeader>
<AlertDialogTitle></AlertDialogTitle>
<AlertDialogDescription>
</AlertDialogDescription>
</AlertDialogHeader>
<AlertDialogFooter>
<AlertDialogCancel></AlertDialogCancel>
<AlertDialogAction
onClick={() => onDelete(job.id)}
className="bg-destructive text-destructive-foreground hover:bg-destructive/90"
>
</AlertDialogAction>
</AlertDialogFooter>
</AlertDialogContent>
</AlertDialog>
</div>
<JobDetailDialog
job={job}
open={detailDialogOpen}
onOpenChange={setDetailDialogOpen}
/>
</div>
)
}, (prevProps, nextProps) => {
return (
prevProps.job.id === nextProps.job.id &&
prevProps.job.status === nextProps.job.status &&
prevProps.job.updated_at === nextProps.job.updated_at &&
prevProps.job.error_message === nextProps.job.error_message
)
})
HistoryItem.displayName = 'HistoryItem'
export { HistoryItem }

View File

@@ -0,0 +1,113 @@
import { useRef, useEffect } from 'react'
import { useHistory } from '@/hooks/useHistory'
import { HistoryItem } from '@/components/HistoryItem'
import { ScrollArea } from '@/components/ui/scroll-area'
import { Sheet, SheetContent } from '@/components/ui/sheet'
import { Button } from '@/components/ui/button'
import { Loader2, FileAudio, RefreshCw } from 'lucide-react'
import type { JobType } from '@/types/job'
import { toast } from 'sonner'
interface HistorySidebarProps {
open: boolean
onOpenChange: (open: boolean) => void
onLoadParams: (jobId: number, jobType: JobType) => Promise<void>
}
function HistorySidebarContent({ onLoadParams }: Pick<HistorySidebarProps, 'onLoadParams'>) {
const { jobs, loading, loadingMore, hasMore, loadMore, deleteJob, error, retry } = useHistory()
const observerTarget = useRef<HTMLDivElement>(null)
useEffect(() => {
const observer = new IntersectionObserver(
(entries) => {
if (entries[0].isIntersecting && hasMore && !loadingMore) {
loadMore()
}
},
{ threshold: 0.5 }
)
if (observerTarget.current) {
observer.observe(observerTarget.current)
}
return () => observer.disconnect()
}, [hasMore, loadingMore, loadMore])
const handleLoadParams = async (jobId: number, jobType: JobType) => {
try {
await onLoadParams(jobId, jobType)
} catch (error) {
toast.error('加载参数失败')
}
}
return (
<div className="flex flex-col h-full">
<div className="p-4 border-b">
<h2 className="text-lg font-semibold"></h2>
<p className="text-sm text-muted-foreground"> {jobs.length} </p>
</div>
<ScrollArea className="flex-1">
<div className="p-4 space-y-4">
{loading ? (
<div className="flex items-center justify-center py-8">
<Loader2 className="w-6 h-6 animate-spin text-muted-foreground" />
</div>
) : error ? (
<div className="flex flex-col items-center justify-center py-8 space-y-4">
<p className="text-sm text-destructive text-center">{error}</p>
<Button onClick={retry} variant="outline" size="sm">
<RefreshCw className="w-4 h-4 mr-2" />
</Button>
</div>
) : jobs.length === 0 ? (
<div className="flex flex-col items-center justify-center py-12 space-y-3">
<FileAudio className="w-12 h-12 text-muted-foreground/50" />
<p className="text-sm font-medium text-muted-foreground"></p>
<p className="text-xs text-muted-foreground text-center">
</p>
</div>
) : (
<>
{jobs.map((job) => (
<HistoryItem
key={job.id}
job={job}
onDelete={deleteJob}
onLoadParams={(job) => handleLoadParams(job.id, job.type)}
/>
))}
{hasMore && (
<div ref={observerTarget} className="py-4 flex justify-center">
<Loader2 className="w-5 h-5 animate-spin text-muted-foreground" />
</div>
)}
</>
)}
</div>
</ScrollArea>
</div>
)
}
export function HistorySidebar({ open, onOpenChange, onLoadParams }: HistorySidebarProps) {
return (
<>
<aside className="hidden lg:block w-[320px] border-r h-[calc(100vh-64px)]">
<HistorySidebarContent onLoadParams={onLoadParams} />
</aside>
<Sheet open={open} onOpenChange={onOpenChange}>
<SheetContent side="left" className="w-full sm:max-w-md p-0">
<HistorySidebarContent onLoadParams={onLoadParams} />
</SheetContent>
</Sheet>
</>
)
}

View File

@@ -0,0 +1,230 @@
import { memo } from 'react'
import type { Job } from '@/types/job'
import {
Dialog,
DialogContent,
DialogHeader,
DialogTitle,
} from '@/components/ui/dialog'
import { Badge } from '@/components/ui/badge'
import { Separator } from '@/components/ui/separator'
import {
Collapsible,
CollapsibleContent,
CollapsibleTrigger,
} from '@/components/ui/collapsible'
import { ScrollArea } from '@/components/ui/scroll-area'
import { AudioPlayer } from '@/components/AudioPlayer'
import { ChevronDown, AlertCircle } from 'lucide-react'
import { jobApi } from '@/lib/api'
interface JobDetailDialogProps {
job: Job | null
open: boolean
onOpenChange: (open: boolean) => void
}
const jobTypeBadgeVariant = {
custom_voice: 'default' as const,
voice_design: 'secondary' as const,
voice_clone: 'outline' as const,
}
const jobTypeLabel = {
custom_voice: '自定义音色',
voice_design: '音色设计',
voice_clone: '声音克隆',
}
const formatTimestamp = (timestamp: string) => {
return new Date(timestamp).toLocaleString('zh-CN', {
year: 'numeric',
month: '2-digit',
day: '2-digit',
hour: '2-digit',
minute: '2-digit',
})
}
const getLanguageDisplay = (lang: string | undefined) => {
if (!lang || lang === 'Auto') return '自动检测'
return lang
}
const formatBooleanDisplay = (value: boolean | undefined) => {
return value ? '是' : '否'
}
const JobDetailDialog = memo(({ job, open, onOpenChange }: JobDetailDialogProps) => {
if (!job) return null
const canPlay = job.status === 'completed'
const audioUrl = canPlay ? jobApi.getAudioUrl(job.id, job.audio_url) : ''
return (
<Dialog open={open} onOpenChange={onOpenChange}>
<DialogContent className="sm:max-w-2xl max-h-[90vh] bg-background">
<DialogHeader>
<div className="flex items-center justify-between gap-3">
<DialogTitle className="flex items-center gap-2">
<Badge variant={jobTypeBadgeVariant[job.type]}>
{jobTypeLabel[job.type]}
</Badge>
<span className="text-sm text-muted-foreground">#{job.id}</span>
</DialogTitle>
<span className="text-sm text-muted-foreground">
{formatTimestamp(job.created_at)}
</span>
</div>
</DialogHeader>
<ScrollArea className="max-h-[calc(90vh-120px)] pr-4">
<div className="space-y-4">
<div className="space-y-2">
<h3 className="font-semibold text-sm"></h3>
<div className="space-y-1.5 text-sm bg-muted/30 p-3 rounded-lg">
{job.type === 'custom_voice' && job.parameters?.speaker && (
<div>
<span className="text-muted-foreground">: </span>
<span>{job.parameters.speaker}</span>
</div>
)}
<div>
<span className="text-muted-foreground">: </span>
<span>{getLanguageDisplay(job.parameters?.language)}</span>
</div>
{job.type === 'voice_clone' && (
<>
<div>
<span className="text-muted-foreground">: </span>
<span>{formatBooleanDisplay(job.parameters?.x_vector_only_mode)}</span>
</div>
<div>
<span className="text-muted-foreground">使: </span>
<span>{formatBooleanDisplay(job.parameters?.use_cache)}</span>
</div>
</>
)}
</div>
</div>
<Separator />
<div className="space-y-2">
<h3 className="font-semibold text-sm"></h3>
<div className="text-sm bg-muted/30 p-3 rounded-lg border">
{job.parameters?.text || <span className="text-muted-foreground"></span>}
</div>
</div>
{job.type === 'voice_design' && job.parameters?.instruct && (
<>
<Separator />
<div className="space-y-2">
<h3 className="font-semibold text-sm"></h3>
<div className="text-sm bg-blue-50 dark:bg-blue-950/30 p-3 rounded-lg border border-blue-200 dark:border-blue-800">
{job.parameters.instruct}
</div>
</div>
</>
)}
{job.type === 'custom_voice' && job.parameters?.instruct && (
<>
<Separator />
<div className="space-y-2">
<h3 className="font-semibold text-sm"></h3>
<div className="text-sm bg-muted/30 p-3 rounded-lg border">
{job.parameters.instruct}
</div>
</div>
</>
)}
{job.type === 'voice_clone' && (
<>
<Separator />
<div className="space-y-2">
<h3 className="font-semibold text-sm"></h3>
<div className="text-sm bg-muted/30 p-3 rounded-lg border">
{job.parameters?.ref_text || <span className="text-muted-foreground"></span>}
</div>
</div>
</>
)}
<Separator />
<Collapsible>
<CollapsibleTrigger className="flex items-center gap-2 text-sm font-semibold hover:text-foreground transition-colors w-full">
<ChevronDown className="w-4 h-4 transition-transform ui-expanded:rotate-180" />
</CollapsibleTrigger>
<CollapsibleContent className="pt-3">
<div className="space-y-1.5 text-sm bg-muted/30 p-3 rounded-lg border">
{job.parameters?.max_new_tokens !== undefined && (
<div>
<span className="text-muted-foreground">: </span>
<span>{job.parameters.max_new_tokens}</span>
</div>
)}
{job.parameters?.temperature !== undefined && (
<div>
<span className="text-muted-foreground">: </span>
<span>{job.parameters.temperature}</span>
</div>
)}
{job.parameters?.top_k !== undefined && (
<div>
<span className="text-muted-foreground">Top K: </span>
<span>{job.parameters.top_k}</span>
</div>
)}
{job.parameters?.top_p !== undefined && (
<div>
<span className="text-muted-foreground">Top P: </span>
<span>{job.parameters.top_p}</span>
</div>
)}
{job.parameters?.repetition_penalty !== undefined && (
<div>
<span className="text-muted-foreground">: </span>
<span>{job.parameters.repetition_penalty}</span>
</div>
)}
</div>
</CollapsibleContent>
</Collapsible>
{job.status === 'failed' && job.error_message && (
<>
<Separator />
<div className="flex items-start gap-2 p-3 bg-red-50 dark:bg-red-950/30 rounded-lg border border-red-200 dark:border-red-800">
<AlertCircle className="w-4 h-4 text-destructive mt-0.5 shrink-0" />
<div>
<h3 className="font-semibold text-sm text-destructive mb-1"></h3>
<p className="text-sm text-destructive">{job.error_message}</p>
</div>
</div>
</>
)}
{canPlay && (
<>
<Separator />
<div className="space-y-2">
<h3 className="font-semibold text-sm"></h3>
<AudioPlayer audioUrl={audioUrl} jobId={job.id} />
</div>
</>
)}
</div>
</ScrollArea>
</DialogContent>
</Dialog>
)
})
JobDetailDialog.displayName = 'JobDetailDialog'
export { JobDetailDialog }

View File

@@ -0,0 +1,12 @@
const LoadingScreen = () => {
return (
<div className="flex items-center justify-center min-h-screen bg-background">
<div className="flex flex-col items-center gap-4">
<div className="w-12 h-12 border-4 border-primary border-t-transparent rounded-full animate-spin" />
<p className="text-sm text-muted-foreground">Loading...</p>
</div>
</div>
);
};
export default LoadingScreen;

View File

@@ -0,0 +1,24 @@
import { memo } from 'react'
interface LoadingStateProps {
elapsedTime: number
}
const LoadingState = memo(({ elapsedTime }: LoadingStateProps) => {
const displayText = elapsedTime > 60
? '生成用时较长,请耐心等待...'
: '正在生成音频,请稍候...'
return (
<div className="space-y-4 py-6">
<p className="text-center text-muted-foreground">{displayText}</p>
<p className="text-center text-sm text-muted-foreground">
{elapsedTime}
</p>
</div>
)
})
LoadingState.displayName = 'LoadingState'
export { LoadingState }

View File

@@ -0,0 +1,50 @@
import { Menu, LogOut, Users } from 'lucide-react'
import { Link } from 'react-router-dom'
import { Button } from '@/components/ui/button'
import { ThemeToggle } from '@/components/ThemeToggle'
import { useAuth } from '@/contexts/AuthContext'
interface NavbarProps {
onToggleSidebar?: () => void
}
export function Navbar({ onToggleSidebar }: NavbarProps) {
const { logout, user } = useAuth()
return (
<nav className="h-16 border-b bg-background flex items-center px-4 gap-4">
{onToggleSidebar && (
<Button
variant="ghost"
size="icon"
onClick={onToggleSidebar}
className="lg:hidden"
>
<Menu className="h-5 w-5" />
</Button>
)}
<div className="flex-1">
<Link to="/">
<h1 className="text-sm md:text-xl font-bold cursor-pointer hover:opacity-80 transition-opacity">
Qwen3-TTS-WebUI
</h1>
</Link>
</div>
<div className="flex items-center gap-2">
{user?.is_superuser && (
<Link to="/users">
<Button variant="ghost" size="icon">
<Users className="h-5 w-5" />
</Button>
</Link>
)}
<ThemeToggle />
<Button variant="ghost" size="icon" onClick={logout}>
<LogOut className="h-5 w-5" />
</Button>
</div>
</nav>
)
}

View File

@@ -0,0 +1,55 @@
import { Label } from '@/components/ui/label'
import { Input } from '@/components/ui/input'
import { Tooltip, TooltipContent, TooltipProvider, TooltipTrigger } from '@/components/ui/tooltip'
import { HelpCircle } from 'lucide-react'
import type { UseFormRegister, FieldValues, Path } from 'react-hook-form'
interface ParamInputProps<T extends FieldValues> {
name: Path<T>
label: string
description: string
tooltip: string
register: UseFormRegister<T>
type?: 'number'
step?: number
min?: number
max?: number
}
export function ParamInput<T extends FieldValues>({
name,
label,
description,
tooltip,
register,
type = 'number',
step,
min,
max,
}: ParamInputProps<T>) {
return (
<div className="space-y-2">
<div className="flex items-center gap-2">
<Label htmlFor={name}>{label}</Label>
<TooltipProvider>
<Tooltip>
<TooltipTrigger type="button" asChild>
<HelpCircle className="h-4 w-4 text-muted-foreground cursor-help" />
</TooltipTrigger>
<TooltipContent side="right" className="max-w-xs">
<p className="text-sm">{tooltip}</p>
</TooltipContent>
</Tooltip>
</TooltipProvider>
</div>
<Input
{...register(name, { valueAsNumber: type === 'number' })}
type={type}
step={step}
min={min}
max={max}
/>
<p className="text-xs text-muted-foreground md:hidden">{description}</p>
</div>
)
}

View File

@@ -0,0 +1,37 @@
import { memo, useMemo } from 'react'
import { Button } from '@/components/ui/button'
interface Preset {
label: string
[key: string]: any
}
interface PresetSelectorProps<T extends Preset> {
presets: readonly T[]
onSelect: (preset: T) => void
}
const PresetSelectorInner = <T extends Preset>({ presets, onSelect }: PresetSelectorProps<T>) => {
const presetButtons = useMemo(() => {
return presets.map((preset, index) => (
<Button
key={`${preset.label}-${index}`}
type="button"
variant="outline"
size="sm"
onClick={() => onSelect(preset)}
className="text-xs md:text-sm px-2.5 md:px-3 h-7 md:h-8"
>
{preset.label}
</Button>
))
}, [presets, onSelect])
return (
<div className="flex flex-wrap gap-1.5 md:gap-2 mt-1.5 md:mt-2">
{presetButtons}
</div>
)
}
export const PresetSelector = memo(PresetSelectorInner) as typeof PresetSelectorInner

View File

@@ -0,0 +1,21 @@
import { Navigate } from 'react-router-dom'
import { useAuth } from '@/contexts/AuthContext'
import LoadingScreen from '@/components/LoadingScreen'
export function SuperAdminRoute({ children }: { children: React.ReactNode }) {
const { isAuthenticated, isLoading, user } = useAuth()
if (isLoading) {
return <LoadingScreen />
}
if (!isAuthenticated) {
return <Navigate to="/login" replace />
}
if (!user?.is_superuser) {
return <Navigate to="/" replace />
}
return <>{children}</>
}

View File

@@ -0,0 +1,17 @@
import { Sun, Moon } from 'lucide-react'
import { Button } from '@/components/ui/button'
import { useTheme } from '@/contexts/ThemeContext'
export function ThemeToggle() {
const { theme, toggleTheme } = useTheme()
return (
<Button variant="ghost" size="icon" onClick={toggleTheme}>
{theme === 'light' ? (
<Sun className="h-5 w-5" />
) : (
<Moon className="h-5 w-5" />
)}
</Button>
)
}

View File

@@ -0,0 +1,276 @@
import { useForm } from 'react-hook-form'
import { zodResolver } from '@hookform/resolvers/zod'
import * as z from 'zod'
import { useEffect, useState, forwardRef, useImperativeHandle, useMemo } from 'react'
import { Button } from '@/components/ui/button'
import { Input } from '@/components/ui/input'
import { Textarea } from '@/components/ui/textarea'
import { Select, SelectContent, SelectItem, SelectTrigger, SelectValue } from '@/components/ui/select'
import { Collapsible, CollapsibleContent, CollapsibleTrigger } from '@/components/ui/collapsible'
import { Label } from '@/components/ui/label'
import { ChevronDown } from 'lucide-react'
import { toast } from 'sonner'
import { ttsApi, jobApi } from '@/lib/api'
import { useJobPolling } from '@/hooks/useJobPolling'
import { LoadingState } from '@/components/LoadingState'
import { AudioPlayer } from '@/components/AudioPlayer'
import { PresetSelector } from '@/components/PresetSelector'
import { ParamInput } from '@/components/ParamInput'
import { PRESET_INSTRUCTS, ADVANCED_PARAMS_INFO } from '@/lib/constants'
import type { Language, Speaker } from '@/types/tts'
const formSchema = z.object({
text: z.string().min(1, '请输入要合成的文本').max(5000, '文本长度不能超过 5000 字符'),
language: z.string().min(1, '请选择语言'),
speaker: z.string().min(1, '请选择发音人'),
instruct: z.string().optional(),
max_new_tokens: z.number().min(1).max(10000).optional(),
temperature: z.number().min(0).max(2).optional(),
top_k: z.number().min(1).max(100).optional(),
top_p: z.number().min(0).max(1).optional(),
repetition_penalty: z.number().min(0).max(2).optional(),
})
type FormData = z.infer<typeof formSchema>
export interface CustomVoiceFormHandle {
loadParams: (params: any) => void
}
const CustomVoiceForm = forwardRef<CustomVoiceFormHandle>((_props, ref) => {
const [languages, setLanguages] = useState<Language[]>([])
const [speakers, setSpeakers] = useState<Speaker[]>([])
const [isLoading, setIsLoading] = useState(false)
const [advancedOpen, setAdvancedOpen] = useState(false)
const { currentJob, isPolling, isCompleted, startPolling, elapsedTime } = useJobPolling()
const {
register,
handleSubmit,
setValue,
watch,
formState: { errors },
} = useForm<FormData>({
resolver: zodResolver(formSchema),
defaultValues: {
text: '',
language: 'Auto',
speaker: '',
instruct: '',
max_new_tokens: 2048,
temperature: 0.3,
top_k: 20,
top_p: 0.7,
repetition_penalty: 1.05,
},
})
useImperativeHandle(ref, () => ({
loadParams: (params: any) => {
setValue('text', params.text || '')
setValue('language', params.language || 'Auto')
setValue('speaker', params.speaker || '')
setValue('instruct', params.instruct || '')
setValue('max_new_tokens', params.max_new_tokens || 2048)
setValue('temperature', params.temperature || 0.3)
setValue('top_k', params.top_k || 20)
setValue('top_p', params.top_p || 0.7)
setValue('repetition_penalty', params.repetition_penalty || 1.05)
}
}))
useEffect(() => {
const fetchData = async () => {
try {
const [langs, spks] = await Promise.all([
ttsApi.getLanguages(),
ttsApi.getSpeakers(),
])
setLanguages(langs)
setSpeakers(spks)
} catch (error) {
toast.error('加载数据失败')
}
}
fetchData()
}, [])
const onSubmit = async (data: FormData) => {
setIsLoading(true)
try {
const result = await ttsApi.createCustomVoiceJob(data)
toast.success('任务已创建')
startPolling(result.job_id)
} catch (error) {
toast.error('创建任务失败')
} finally {
setIsLoading(false)
}
}
const memoizedAudioUrl = useMemo(() => {
if (!currentJob) return ''
return jobApi.getAudioUrl(currentJob.id, currentJob.audio_url)
}, [currentJob?.id, currentJob?.audio_url])
return (
<form onSubmit={handleSubmit(onSubmit)} className="space-y-4 md:space-y-6">
<div className="space-y-1.5 md:space-y-2">
<Label htmlFor="language"></Label>
<Select
value={watch('language')}
onValueChange={(value: string) => setValue('language', value)}
>
<SelectTrigger>
<SelectValue />
</SelectTrigger>
<SelectContent>
{languages.map((lang) => (
<SelectItem key={lang.code} value={lang.code}>
{lang.name}
</SelectItem>
))}
</SelectContent>
</Select>
{errors.language && (
<p className="text-sm text-destructive">{errors.language.message}</p>
)}
</div>
<div className="space-y-1.5 md:space-y-2">
<Label htmlFor="speaker"></Label>
<Select
value={watch('speaker')}
onValueChange={(value: string) => setValue('speaker', value)}
>
<SelectTrigger>
<SelectValue placeholder="选择发音人" />
</SelectTrigger>
<SelectContent>
{speakers.map((speaker) => (
<SelectItem key={speaker.name} value={speaker.name}>
{speaker.name} - {speaker.description}
</SelectItem>
))}
</SelectContent>
</Select>
{errors.speaker && (
<p className="text-sm text-destructive">{errors.speaker.message}</p>
)}
</div>
<div className="space-y-1.5 md:space-y-2">
<Label htmlFor="text"></Label>
<Textarea
{...register('text')}
placeholder="输入要合成的文本..."
rows={2}
className="min-h-[60px] md:min-h-[96px]"
/>
{errors.text && (
<p className="text-sm text-destructive">{errors.text.message}</p>
)}
</div>
<div className="space-y-1.5 md:space-y-2">
<Label htmlFor="instruct"></Label>
<Textarea
{...register('instruct')}
placeholder="例如:温柔体贴,语速平缓,充满关怀"
rows={2}
className="min-h-[60px] md:min-h-[80px]"
/>
<PresetSelector
presets={PRESET_INSTRUCTS}
onSelect={(preset) => {
setValue('instruct', preset.instruct)
if (preset.text) {
setValue('text', preset.text)
}
}}
/>
{errors.instruct && (
<p className="text-sm text-destructive">{errors.instruct.message}</p>
)}
</div>
<Collapsible open={advancedOpen} onOpenChange={setAdvancedOpen}>
<CollapsibleTrigger asChild>
<Button type="button" variant="ghost" className="w-full">
<ChevronDown className="ml-2 h-4 w-4" />
</Button>
</CollapsibleTrigger>
<CollapsibleContent className="space-y-3 md:space-y-4 pt-3 md:pt-4">
<ParamInput
name="max_new_tokens"
label={ADVANCED_PARAMS_INFO.max_new_tokens.label}
description={ADVANCED_PARAMS_INFO.max_new_tokens.description}
tooltip={ADVANCED_PARAMS_INFO.max_new_tokens.tooltip}
register={register}
min={1}
max={10000}
/>
<ParamInput
name="temperature"
label={ADVANCED_PARAMS_INFO.temperature.label}
description={ADVANCED_PARAMS_INFO.temperature.description}
tooltip={ADVANCED_PARAMS_INFO.temperature.tooltip}
register={register}
step={0.1}
min={0}
max={2}
/>
<ParamInput
name="top_k"
label={ADVANCED_PARAMS_INFO.top_k.label}
description={ADVANCED_PARAMS_INFO.top_k.description}
tooltip={ADVANCED_PARAMS_INFO.top_k.tooltip}
register={register}
min={1}
max={100}
/>
<ParamInput
name="top_p"
label={ADVANCED_PARAMS_INFO.top_p.label}
description={ADVANCED_PARAMS_INFO.top_p.description}
tooltip={ADVANCED_PARAMS_INFO.top_p.tooltip}
register={register}
step={0.1}
min={0}
max={1}
/>
<ParamInput
name="repetition_penalty"
label={ADVANCED_PARAMS_INFO.repetition_penalty.label}
description={ADVANCED_PARAMS_INFO.repetition_penalty.description}
tooltip={ADVANCED_PARAMS_INFO.repetition_penalty.tooltip}
register={register}
step={0.01}
min={0}
max={2}
/>
</CollapsibleContent>
</Collapsible>
<Button type="submit" className="w-full" disabled={isLoading || isPolling}>
{isLoading ? '创建中...' : '生成语音'}
</Button>
{isPolling && <LoadingState elapsedTime={elapsedTime} />}
{isCompleted && currentJob && (
<div className="space-y-4 pt-4 border-t">
<AudioPlayer
audioUrl={memoizedAudioUrl}
jobId={currentJob.id}
/>
</div>
)}
</form>
)
})
export default CustomVoiceForm

View File

@@ -0,0 +1,247 @@
import { useForm, Controller } from 'react-hook-form'
import { zodResolver } from '@hookform/resolvers/zod'
import * as z from 'zod'
import { useEffect, useState, useMemo } from 'react'
import { Button } from '@/components/ui/button'
import { Input } from '@/components/ui/input'
import { Textarea } from '@/components/ui/textarea'
import { Select, SelectContent, SelectItem, SelectTrigger, SelectValue } from '@/components/ui/select'
import { Collapsible, CollapsibleContent, CollapsibleTrigger } from '@/components/ui/collapsible'
import { Checkbox } from '@/components/ui/checkbox'
import { Label } from '@/components/ui/label'
import { ChevronDown } from 'lucide-react'
import { toast } from 'sonner'
import { ttsApi, jobApi } from '@/lib/api'
import { useJobPolling } from '@/hooks/useJobPolling'
import { LoadingState } from '@/components/LoadingState'
import { AudioPlayer } from '@/components/AudioPlayer'
import { AudioInputSelector } from '@/components/AudioInputSelector'
import { PresetSelector } from '@/components/PresetSelector'
import { ParamInput } from '@/components/ParamInput'
import { PRESET_REF_TEXTS, ADVANCED_PARAMS_INFO } from '@/lib/constants'
import type { Language } from '@/types/tts'
const formSchema = z.object({
text: z.string().min(1, '请输入要合成的文本').max(5000, '文本长度不能超过 5000 字符'),
language: z.string().optional(),
ref_audio: z.instanceof(File, { message: '请上传参考音频' }),
ref_text: z.string().optional(),
use_cache: z.boolean().optional(),
x_vector_only_mode: z.boolean().optional(),
max_new_tokens: z.number().min(1).max(10000).optional(),
temperature: z.number().min(0).max(2).optional(),
top_k: z.number().min(1).max(100).optional(),
top_p: z.number().min(0).max(1).optional(),
repetition_penalty: z.number().min(0).max(2).optional(),
})
type FormData = z.infer<typeof formSchema>
function VoiceCloneForm() {
const [languages, setLanguages] = useState<Language[]>([])
const [isLoading, setIsLoading] = useState(false)
const [advancedOpen, setAdvancedOpen] = useState(false)
const { currentJob, isPolling, isCompleted, startPolling, elapsedTime } = useJobPolling()
const {
register,
handleSubmit,
setValue,
watch,
control,
formState: { errors },
} = useForm<FormData>({
resolver: zodResolver(formSchema),
defaultValues: {
text: '',
language: 'Auto',
ref_text: '',
use_cache: true,
x_vector_only_mode: false,
max_new_tokens: 2048,
temperature: 0.3,
top_k: 20,
top_p: 0.7,
repetition_penalty: 1.05,
} as Partial<FormData>,
})
useEffect(() => {
const fetchData = async () => {
try {
const langs = await ttsApi.getLanguages()
setLanguages(langs)
} catch (error) {
toast.error('加载数据失败')
}
}
fetchData()
}, [])
const onSubmit = async (data: FormData) => {
setIsLoading(true)
try {
const result = await ttsApi.createVoiceCloneJob({
...data,
ref_audio: data.ref_audio,
})
toast.success('任务已创建')
startPolling(result.job_id)
} catch (error) {
toast.error('创建任务失败')
} finally {
setIsLoading(false)
}
}
const memoizedAudioUrl = useMemo(() => {
if (!currentJob) return ''
return jobApi.getAudioUrl(currentJob.id, currentJob.audio_url)
}, [currentJob?.id, currentJob?.audio_url])
return (
<form onSubmit={handleSubmit(onSubmit)} className="space-y-4 md:space-y-6">
<div className="space-y-1.5 md:space-y-2">
<Label htmlFor="ref_text">稿</Label>
<Textarea
{...register('ref_text')}
placeholder="参考音频对应的文本..."
rows={2}
className="min-h-[60px] md:min-h-[80px]"
/>
<PresetSelector
presets={PRESET_REF_TEXTS}
onSelect={(preset) => setValue('ref_text', preset.text)}
/>
{errors.ref_text && (
<p className="text-sm text-destructive">{errors.ref_text.message}</p>
)}
</div>
<div className="space-y-1.5 md:space-y-2">
<Label htmlFor="ref_audio"></Label>
<Controller
name="ref_audio"
control={control}
render={({ field }) => (
<AudioInputSelector
value={field.value}
onChange={field.onChange}
error={errors.ref_audio?.message}
/>
)}
/>
</div>
<div className="space-y-1.5 md:space-y-2">
<Label htmlFor="language"></Label>
<Select
value={watch('language')}
onValueChange={(value: string) => setValue('language', value)}
>
<SelectTrigger>
<SelectValue />
</SelectTrigger>
<SelectContent>
{languages.map((lang) => (
<SelectItem key={lang.code} value={lang.code}>
{lang.name}
</SelectItem>
))}
</SelectContent>
</Select>
</div>
<div className="space-y-1.5 md:space-y-2">
<Label htmlFor="text"></Label>
<Textarea
{...register('text')}
placeholder="输入要合成的文本..."
rows={2}
className="min-h-[60px] md:min-h-[96px]"
/>
<PresetSelector
presets={PRESET_REF_TEXTS}
onSelect={(preset) => setValue('text', preset.text)}
/>
{errors.text && (
<p className="text-sm text-destructive">{errors.text.message}</p>
)}
</div>
<div className="flex items-center space-x-4">
<div className="flex items-center space-x-2">
<Controller
name="x_vector_only_mode"
control={control}
render={({ field }) => (
<Checkbox
id="x_vector_only_mode"
checked={field.value}
onCheckedChange={field.onChange}
/>
)}
/>
<Label htmlFor="x_vector_only_mode" className="text-sm font-normal">
</Label>
</div>
<div className="flex items-center space-x-2">
<Controller
name="use_cache"
control={control}
render={({ field }) => (
<Checkbox
id="use_cache"
checked={field.value}
onCheckedChange={field.onChange}
/>
)}
/>
<Label htmlFor="use_cache" className="text-sm font-normal">
使
</Label>
</div>
</div>
<Collapsible open={advancedOpen} onOpenChange={setAdvancedOpen}>
<CollapsibleTrigger asChild>
<Button type="button" variant="ghost" className="w-full">
<ChevronDown className="ml-2 h-4 w-4" />
</Button>
</CollapsibleTrigger>
<CollapsibleContent className="space-y-3 md:space-y-4 pt-3 md:pt-4">
<ParamInput
name="max_new_tokens"
label={ADVANCED_PARAMS_INFO.max_new_tokens.label}
description={ADVANCED_PARAMS_INFO.max_new_tokens.description}
tooltip={ADVANCED_PARAMS_INFO.max_new_tokens.tooltip}
register={register}
min={1}
max={10000}
/>
</CollapsibleContent>
</Collapsible>
<Button type="submit" className="w-full" disabled={isLoading || isPolling}>
{isLoading ? '创建中...' : '生成语音'}
</Button>
{isPolling && <LoadingState elapsedTime={elapsedTime} />}
{isCompleted && currentJob && (
<div className="space-y-4 pt-4 border-t">
<AudioPlayer
audioUrl={memoizedAudioUrl}
jobId={currentJob.id}
/>
</div>
)}
</form>
)
}
export default VoiceCloneForm

View File

@@ -0,0 +1,245 @@
import { useForm } from 'react-hook-form'
import { zodResolver } from '@hookform/resolvers/zod'
import * as z from 'zod'
import { useEffect, useState, forwardRef, useImperativeHandle, useMemo } from 'react'
import { Button } from '@/components/ui/button'
import { Input } from '@/components/ui/input'
import { Textarea } from '@/components/ui/textarea'
import { Select, SelectContent, SelectItem, SelectTrigger, SelectValue } from '@/components/ui/select'
import { Collapsible, CollapsibleContent, CollapsibleTrigger } from '@/components/ui/collapsible'
import { Label } from '@/components/ui/label'
import { ChevronDown } from 'lucide-react'
import { toast } from 'sonner'
import { ttsApi, jobApi } from '@/lib/api'
import { useJobPolling } from '@/hooks/useJobPolling'
import { LoadingState } from '@/components/LoadingState'
import { AudioPlayer } from '@/components/AudioPlayer'
import { PresetSelector } from '@/components/PresetSelector'
import { ParamInput } from '@/components/ParamInput'
import { PRESET_VOICE_DESIGNS, ADVANCED_PARAMS_INFO } from '@/lib/constants'
import type { Language } from '@/types/tts'
const formSchema = z.object({
text: z.string().min(1, '请输入要合成的文本').max(5000, '文本长度不能超过 5000 字符'),
language: z.string().min(1, '请选择语言'),
instruct: z.string().min(10, '音色描述至少需要 10 个字符').max(500, '音色描述不能超过 500 字符'),
max_new_tokens: z.number().min(1).max(10000).optional(),
temperature: z.number().min(0).max(2).optional(),
top_k: z.number().min(1).max(100).optional(),
top_p: z.number().min(0).max(1).optional(),
repetition_penalty: z.number().min(0).max(2).optional(),
})
type FormData = z.infer<typeof formSchema>
export interface VoiceDesignFormHandle {
loadParams: (params: any) => void
}
const VoiceDesignForm = forwardRef<VoiceDesignFormHandle>((_props, ref) => {
const [languages, setLanguages] = useState<Language[]>([])
const [isLoading, setIsLoading] = useState(false)
const [advancedOpen, setAdvancedOpen] = useState(false)
const { currentJob, isPolling, isCompleted, startPolling, elapsedTime } = useJobPolling()
const {
register,
handleSubmit,
setValue,
watch,
formState: { errors },
} = useForm<FormData>({
resolver: zodResolver(formSchema),
defaultValues: {
text: '',
language: 'Auto',
instruct: '',
max_new_tokens: 2048,
temperature: 0.3,
top_k: 20,
top_p: 0.7,
repetition_penalty: 1.05,
},
})
useImperativeHandle(ref, () => ({
loadParams: (params: any) => {
setValue('text', params.text || '')
setValue('language', params.language || 'Auto')
setValue('instruct', params.instruct || '')
setValue('max_new_tokens', params.max_new_tokens || 2048)
setValue('temperature', params.temperature || 0.3)
setValue('top_k', params.top_k || 20)
setValue('top_p', params.top_p || 0.7)
setValue('repetition_penalty', params.repetition_penalty || 1.05)
}
}))
useEffect(() => {
const fetchData = async () => {
try {
const langs = await ttsApi.getLanguages()
setLanguages(langs)
} catch (error) {
toast.error('加载数据失败')
}
}
fetchData()
}, [])
const onSubmit = async (data: FormData) => {
setIsLoading(true)
try {
const result = await ttsApi.createVoiceDesignJob(data)
toast.success('任务已创建')
startPolling(result.job_id)
} catch (error) {
toast.error('创建任务失败')
} finally {
setIsLoading(false)
}
}
const memoizedAudioUrl = useMemo(() => {
if (!currentJob) return ''
return jobApi.getAudioUrl(currentJob.id, currentJob.audio_url)
}, [currentJob?.id, currentJob?.audio_url])
return (
<form onSubmit={handleSubmit(onSubmit)} className="space-y-4 md:space-y-6">
<div className="space-y-1.5 md:space-y-2">
<Label htmlFor="language"></Label>
<Select
value={watch('language')}
onValueChange={(value: string) => setValue('language', value)}
>
<SelectTrigger>
<SelectValue />
</SelectTrigger>
<SelectContent>
{languages.map((lang) => (
<SelectItem key={lang.code} value={lang.code}>
{lang.name}
</SelectItem>
))}
</SelectContent>
</Select>
{errors.language && (
<p className="text-sm text-destructive">{errors.language.message}</p>
)}
</div>
<div className="space-y-1.5 md:space-y-2">
<Label htmlFor="text"></Label>
<Textarea
{...register('text')}
placeholder="输入要合成的文本..."
rows={2}
className="min-h-[60px] md:min-h-[96px]"
/>
{errors.text && (
<p className="text-sm text-destructive">{errors.text.message}</p>
)}
</div>
<div className="space-y-1.5 md:space-y-2">
<Label htmlFor="instruct"></Label>
<Textarea
{...register('instruct')}
placeholder="例如:成熟男性,低沉磁性,充满权威感"
rows={2}
className="min-h-[60px] md:min-h-[80px]"
/>
<PresetSelector
presets={PRESET_VOICE_DESIGNS}
onSelect={(preset) => {
setValue('instruct', preset.instruct)
if (preset.text) {
setValue('text', preset.text)
}
}}
/>
{errors.instruct && (
<p className="text-sm text-destructive">{errors.instruct.message}</p>
)}
</div>
<Collapsible open={advancedOpen} onOpenChange={setAdvancedOpen}>
<CollapsibleTrigger asChild>
<Button type="button" variant="ghost" className="w-full">
<ChevronDown className="ml-2 h-4 w-4" />
</Button>
</CollapsibleTrigger>
<CollapsibleContent className="space-y-3 md:space-y-4 pt-3 md:pt-4">
<ParamInput
name="max_new_tokens"
label={ADVANCED_PARAMS_INFO.max_new_tokens.label}
description={ADVANCED_PARAMS_INFO.max_new_tokens.description}
tooltip={ADVANCED_PARAMS_INFO.max_new_tokens.tooltip}
register={register}
min={1}
max={10000}
/>
<ParamInput
name="temperature"
label={ADVANCED_PARAMS_INFO.temperature.label}
description={ADVANCED_PARAMS_INFO.temperature.description}
tooltip={ADVANCED_PARAMS_INFO.temperature.tooltip}
register={register}
step={0.1}
min={0}
max={2}
/>
<ParamInput
name="top_k"
label={ADVANCED_PARAMS_INFO.top_k.label}
description={ADVANCED_PARAMS_INFO.top_k.description}
tooltip={ADVANCED_PARAMS_INFO.top_k.tooltip}
register={register}
min={1}
max={100}
/>
<ParamInput
name="top_p"
label={ADVANCED_PARAMS_INFO.top_p.label}
description={ADVANCED_PARAMS_INFO.top_p.description}
tooltip={ADVANCED_PARAMS_INFO.top_p.tooltip}
register={register}
step={0.1}
min={0}
max={1}
/>
<ParamInput
name="repetition_penalty"
label={ADVANCED_PARAMS_INFO.repetition_penalty.label}
description={ADVANCED_PARAMS_INFO.repetition_penalty.description}
tooltip={ADVANCED_PARAMS_INFO.repetition_penalty.tooltip}
register={register}
step={0.01}
min={0}
max={2}
/>
</CollapsibleContent>
</Collapsible>
<Button type="submit" className="w-full" disabled={isLoading || isPolling}>
{isLoading ? '创建中...' : '生成语音'}
</Button>
{isPolling && <LoadingState elapsedTime={elapsedTime} />}
{isCompleted && currentJob && (
<div className="space-y-4 pt-4 border-t">
<AudioPlayer
audioUrl={memoizedAudioUrl}
jobId={currentJob.id}
/>
</div>
)}
</form>
)
})
export default VoiceDesignForm

View File

@@ -0,0 +1,139 @@
import * as React from "react"
import * as AlertDialogPrimitive from "@radix-ui/react-alert-dialog"
import { cn } from "@/lib/utils"
import { buttonVariants } from "@/components/ui/button"
const AlertDialog = AlertDialogPrimitive.Root
const AlertDialogTrigger = AlertDialogPrimitive.Trigger
const AlertDialogPortal = AlertDialogPrimitive.Portal
const AlertDialogOverlay = React.forwardRef<
React.ElementRef<typeof AlertDialogPrimitive.Overlay>,
React.ComponentPropsWithoutRef<typeof AlertDialogPrimitive.Overlay>
>(({ className, ...props }, ref) => (
<AlertDialogPrimitive.Overlay
className={cn(
"fixed inset-0 z-50 bg-black/80 data-[state=open]:animate-in data-[state=closed]:animate-out data-[state=closed]:fade-out-0 data-[state=open]:fade-in-0",
className
)}
{...props}
ref={ref}
/>
))
AlertDialogOverlay.displayName = AlertDialogPrimitive.Overlay.displayName
const AlertDialogContent = React.forwardRef<
React.ElementRef<typeof AlertDialogPrimitive.Content>,
React.ComponentPropsWithoutRef<typeof AlertDialogPrimitive.Content>
>(({ className, ...props }, ref) => (
<AlertDialogPortal>
<AlertDialogOverlay />
<AlertDialogPrimitive.Content
ref={ref}
className={cn(
"fixed left-[50%] top-[50%] z-50 grid w-full max-w-lg translate-x-[-50%] translate-y-[-50%] gap-4 border bg-background p-6 shadow-lg duration-200 data-[state=open]:animate-in data-[state=closed]:animate-out data-[state=closed]:fade-out-0 data-[state=open]:fade-in-0 data-[state=closed]:zoom-out-95 data-[state=open]:zoom-in-95 data-[state=closed]:slide-out-to-left-1/2 data-[state=closed]:slide-out-to-top-[48%] data-[state=open]:slide-in-from-left-1/2 data-[state=open]:slide-in-from-top-[48%] sm:rounded-lg",
className
)}
{...props}
/>
</AlertDialogPortal>
))
AlertDialogContent.displayName = AlertDialogPrimitive.Content.displayName
const AlertDialogHeader = ({
className,
...props
}: React.HTMLAttributes<HTMLDivElement>) => (
<div
className={cn(
"flex flex-col space-y-2 text-center sm:text-left",
className
)}
{...props}
/>
)
AlertDialogHeader.displayName = "AlertDialogHeader"
const AlertDialogFooter = ({
className,
...props
}: React.HTMLAttributes<HTMLDivElement>) => (
<div
className={cn(
"flex flex-col-reverse sm:flex-row sm:justify-end sm:space-x-2",
className
)}
{...props}
/>
)
AlertDialogFooter.displayName = "AlertDialogFooter"
const AlertDialogTitle = React.forwardRef<
React.ElementRef<typeof AlertDialogPrimitive.Title>,
React.ComponentPropsWithoutRef<typeof AlertDialogPrimitive.Title>
>(({ className, ...props }, ref) => (
<AlertDialogPrimitive.Title
ref={ref}
className={cn("text-lg font-semibold", className)}
{...props}
/>
))
AlertDialogTitle.displayName = AlertDialogPrimitive.Title.displayName
const AlertDialogDescription = React.forwardRef<
React.ElementRef<typeof AlertDialogPrimitive.Description>,
React.ComponentPropsWithoutRef<typeof AlertDialogPrimitive.Description>
>(({ className, ...props }, ref) => (
<AlertDialogPrimitive.Description
ref={ref}
className={cn("text-sm text-muted-foreground", className)}
{...props}
/>
))
AlertDialogDescription.displayName =
AlertDialogPrimitive.Description.displayName
const AlertDialogAction = React.forwardRef<
React.ElementRef<typeof AlertDialogPrimitive.Action>,
React.ComponentPropsWithoutRef<typeof AlertDialogPrimitive.Action>
>(({ className, ...props }, ref) => (
<AlertDialogPrimitive.Action
ref={ref}
className={cn(buttonVariants(), className)}
{...props}
/>
))
AlertDialogAction.displayName = AlertDialogPrimitive.Action.displayName
const AlertDialogCancel = React.forwardRef<
React.ElementRef<typeof AlertDialogPrimitive.Cancel>,
React.ComponentPropsWithoutRef<typeof AlertDialogPrimitive.Cancel>
>(({ className, ...props }, ref) => (
<AlertDialogPrimitive.Cancel
ref={ref}
className={cn(
buttonVariants({ variant: "outline" }),
"mt-2 sm:mt-0",
className
)}
{...props}
/>
))
AlertDialogCancel.displayName = AlertDialogPrimitive.Cancel.displayName
export {
AlertDialog,
AlertDialogPortal,
AlertDialogOverlay,
AlertDialogTrigger,
AlertDialogContent,
AlertDialogHeader,
AlertDialogFooter,
AlertDialogTitle,
AlertDialogDescription,
AlertDialogAction,
AlertDialogCancel,
}

View File

@@ -0,0 +1,36 @@
import * as React from "react"
import { cva, type VariantProps } from "class-variance-authority"
import { cn } from "@/lib/utils"
const badgeVariants = cva(
"inline-flex items-center rounded-full border px-2.5 py-0.5 text-xs font-semibold transition-colors focus:outline-none focus:ring-2 focus:ring-ring focus:ring-offset-2",
{
variants: {
variant: {
default:
"border-transparent bg-primary text-primary-foreground hover:bg-primary/80",
secondary:
"border-transparent bg-secondary text-secondary-foreground hover:bg-secondary/80",
destructive:
"border-transparent bg-destructive text-destructive-foreground hover:bg-destructive/80",
outline: "text-foreground",
},
},
defaultVariants: {
variant: "default",
},
}
)
export interface BadgeProps
extends React.HTMLAttributes<HTMLDivElement>,
VariantProps<typeof badgeVariants> {}
function Badge({ className, variant, ...props }: BadgeProps) {
return (
<div className={cn(badgeVariants({ variant }), className)} {...props} />
)
}
export { Badge, badgeVariants }

View File

@@ -0,0 +1,56 @@
import * as React from "react"
import { Slot } from "@radix-ui/react-slot"
import { cva, type VariantProps } from "class-variance-authority"
import { cn } from "@/lib/utils"
const buttonVariants = cva(
"inline-flex items-center justify-center gap-2 whitespace-nowrap rounded-md text-sm font-medium ring-offset-background transition-colors focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-ring focus-visible:ring-offset-2 disabled:pointer-events-none disabled:opacity-50 [&_svg]:pointer-events-none [&_svg]:size-4 [&_svg]:shrink-0",
{
variants: {
variant: {
default: "bg-primary text-primary-foreground hover:bg-primary/90",
destructive:
"bg-destructive text-destructive-foreground hover:bg-destructive/90",
outline:
"border border-input bg-background hover:bg-accent hover:text-accent-foreground",
secondary:
"bg-secondary text-secondary-foreground hover:bg-secondary/80",
ghost: "hover:bg-accent hover:text-accent-foreground",
link: "text-primary underline-offset-4 hover:underline",
},
size: {
default: "h-10 px-4 py-2",
sm: "h-9 rounded-md px-3",
lg: "h-11 rounded-md px-8",
icon: "h-10 w-10",
},
},
defaultVariants: {
variant: "default",
size: "default",
},
}
)
export interface ButtonProps
extends React.ButtonHTMLAttributes<HTMLButtonElement>,
VariantProps<typeof buttonVariants> {
asChild?: boolean
}
const Button = React.forwardRef<HTMLButtonElement, ButtonProps>(
({ className, variant, size, asChild = false, ...props }, ref) => {
const Comp = asChild ? Slot : "button"
return (
<Comp
className={cn(buttonVariants({ variant, size, className }))}
ref={ref}
{...props}
/>
)
}
)
Button.displayName = "Button"
export { Button, buttonVariants }

View File

@@ -0,0 +1,79 @@
import * as React from "react"
import { cn } from "@/lib/utils"
const Card = React.forwardRef<
HTMLDivElement,
React.HTMLAttributes<HTMLDivElement>
>(({ className, ...props }, ref) => (
<div
ref={ref}
className={cn(
"rounded-lg border bg-card text-card-foreground shadow-sm",
className
)}
{...props}
/>
))
Card.displayName = "Card"
const CardHeader = React.forwardRef<
HTMLDivElement,
React.HTMLAttributes<HTMLDivElement>
>(({ className, ...props }, ref) => (
<div
ref={ref}
className={cn("flex flex-col space-y-1.5 p-6", className)}
{...props}
/>
))
CardHeader.displayName = "CardHeader"
const CardTitle = React.forwardRef<
HTMLDivElement,
React.HTMLAttributes<HTMLDivElement>
>(({ className, ...props }, ref) => (
<div
ref={ref}
className={cn(
"text-2xl font-semibold leading-none tracking-tight",
className
)}
{...props}
/>
))
CardTitle.displayName = "CardTitle"
const CardDescription = React.forwardRef<
HTMLDivElement,
React.HTMLAttributes<HTMLDivElement>
>(({ className, ...props }, ref) => (
<div
ref={ref}
className={cn("text-sm text-muted-foreground", className)}
{...props}
/>
))
CardDescription.displayName = "CardDescription"
const CardContent = React.forwardRef<
HTMLDivElement,
React.HTMLAttributes<HTMLDivElement>
>(({ className, ...props }, ref) => (
<div ref={ref} className={cn("p-6 pt-0", className)} {...props} />
))
CardContent.displayName = "CardContent"
const CardFooter = React.forwardRef<
HTMLDivElement,
React.HTMLAttributes<HTMLDivElement>
>(({ className, ...props }, ref) => (
<div
ref={ref}
className={cn("flex items-center p-6 pt-0", className)}
{...props}
/>
))
CardFooter.displayName = "CardFooter"
export { Card, CardHeader, CardFooter, CardTitle, CardDescription, CardContent }

View File

@@ -0,0 +1,30 @@
"use client"
import * as React from "react"
import * as CheckboxPrimitive from "@radix-ui/react-checkbox"
import { Check } from "lucide-react"
import { cn } from "@/lib/utils"
const Checkbox = React.forwardRef<
React.ElementRef<typeof CheckboxPrimitive.Root>,
React.ComponentPropsWithoutRef<typeof CheckboxPrimitive.Root>
>(({ className, ...props }, ref) => (
<CheckboxPrimitive.Root
ref={ref}
className={cn(
"grid place-content-center peer h-4 w-4 shrink-0 rounded-sm border border-primary ring-offset-background focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-ring focus-visible:ring-offset-2 disabled:cursor-not-allowed disabled:opacity-50 data-[state=checked]:bg-primary data-[state=checked]:text-primary-foreground",
className
)}
{...props}
>
<CheckboxPrimitive.Indicator
className={cn("grid place-content-center text-current")}
>
<Check className="h-4 w-4" />
</CheckboxPrimitive.Indicator>
</CheckboxPrimitive.Root>
))
Checkbox.displayName = CheckboxPrimitive.Root.displayName
export { Checkbox }

View File

@@ -0,0 +1,9 @@
import * as CollapsiblePrimitive from "@radix-ui/react-collapsible"
const Collapsible = CollapsiblePrimitive.Root
const CollapsibleTrigger = CollapsiblePrimitive.CollapsibleTrigger
const CollapsibleContent = CollapsiblePrimitive.CollapsibleContent
export { Collapsible, CollapsibleTrigger, CollapsibleContent }

View File

@@ -0,0 +1,120 @@
import * as React from "react"
import * as DialogPrimitive from "@radix-ui/react-dialog"
import { X } from "lucide-react"
import { cn } from "@/lib/utils"
const Dialog = DialogPrimitive.Root
const DialogTrigger = DialogPrimitive.Trigger
const DialogPortal = DialogPrimitive.Portal
const DialogClose = DialogPrimitive.Close
const DialogOverlay = React.forwardRef<
React.ElementRef<typeof DialogPrimitive.Overlay>,
React.ComponentPropsWithoutRef<typeof DialogPrimitive.Overlay>
>(({ className, ...props }, ref) => (
<DialogPrimitive.Overlay
ref={ref}
className={cn(
"fixed inset-0 z-50 bg-black/80 data-[state=open]:animate-in data-[state=closed]:animate-out data-[state=closed]:fade-out-0 data-[state=open]:fade-in-0",
className
)}
{...props}
/>
))
DialogOverlay.displayName = DialogPrimitive.Overlay.displayName
const DialogContent = React.forwardRef<
React.ElementRef<typeof DialogPrimitive.Content>,
React.ComponentPropsWithoutRef<typeof DialogPrimitive.Content>
>(({ className, children, ...props }, ref) => (
<DialogPortal>
<DialogOverlay />
<DialogPrimitive.Content
ref={ref}
className={cn(
"fixed left-[50%] top-[50%] z-50 grid w-full max-w-lg translate-x-[-50%] translate-y-[-50%] gap-4 border bg-background p-6 shadow-lg duration-200 data-[state=open]:animate-in data-[state=closed]:animate-out data-[state=closed]:fade-out-0 data-[state=open]:fade-in-0 data-[state=closed]:zoom-out-95 data-[state=open]:zoom-in-95 data-[state=closed]:slide-out-to-left-1/2 data-[state=closed]:slide-out-to-top-[48%] data-[state=open]:slide-in-from-left-1/2 data-[state=open]:slide-in-from-top-[48%] sm:rounded-lg",
className
)}
{...props}
>
{children}
<DialogPrimitive.Close className="absolute right-4 top-4 rounded-sm opacity-70 ring-offset-background transition-opacity hover:opacity-100 focus:outline-none focus:ring-2 focus:ring-ring focus:ring-offset-2 disabled:pointer-events-none data-[state=open]:bg-accent data-[state=open]:text-muted-foreground">
<X className="h-4 w-4" />
<span className="sr-only">Close</span>
</DialogPrimitive.Close>
</DialogPrimitive.Content>
</DialogPortal>
))
DialogContent.displayName = DialogPrimitive.Content.displayName
const DialogHeader = ({
className,
...props
}: React.HTMLAttributes<HTMLDivElement>) => (
<div
className={cn(
"flex flex-col space-y-1.5 text-center sm:text-left",
className
)}
{...props}
/>
)
DialogHeader.displayName = "DialogHeader"
const DialogFooter = ({
className,
...props
}: React.HTMLAttributes<HTMLDivElement>) => (
<div
className={cn(
"flex flex-col-reverse sm:flex-row sm:justify-end sm:space-x-2",
className
)}
{...props}
/>
)
DialogFooter.displayName = "DialogFooter"
const DialogTitle = React.forwardRef<
React.ElementRef<typeof DialogPrimitive.Title>,
React.ComponentPropsWithoutRef<typeof DialogPrimitive.Title>
>(({ className, ...props }, ref) => (
<DialogPrimitive.Title
ref={ref}
className={cn(
"text-lg font-semibold leading-none tracking-tight",
className
)}
{...props}
/>
))
DialogTitle.displayName = DialogPrimitive.Title.displayName
const DialogDescription = React.forwardRef<
React.ElementRef<typeof DialogPrimitive.Description>,
React.ComponentPropsWithoutRef<typeof DialogPrimitive.Description>
>(({ className, ...props }, ref) => (
<DialogPrimitive.Description
ref={ref}
className={cn("text-sm text-muted-foreground", className)}
{...props}
/>
))
DialogDescription.displayName = DialogPrimitive.Description.displayName
export {
Dialog,
DialogPortal,
DialogOverlay,
DialogClose,
DialogTrigger,
DialogContent,
DialogHeader,
DialogFooter,
DialogTitle,
DialogDescription,
}

View File

@@ -0,0 +1,178 @@
"use client"
import * as React from "react"
import * as LabelPrimitive from "@radix-ui/react-label"
import { Slot } from "@radix-ui/react-slot"
import {
Controller,
FormProvider,
useFormContext,
type ControllerProps,
type FieldPath,
type FieldValues,
} from "react-hook-form"
import { cn } from "@/lib/utils"
import { Label } from "@/components/ui/label"
const Form = FormProvider
type FormFieldContextValue<
TFieldValues extends FieldValues = FieldValues,
TName extends FieldPath<TFieldValues> = FieldPath<TFieldValues>
> = {
name: TName
}
const FormFieldContext = React.createContext<FormFieldContextValue | null>(null)
const FormField = <
TFieldValues extends FieldValues = FieldValues,
TName extends FieldPath<TFieldValues> = FieldPath<TFieldValues>
>({
...props
}: ControllerProps<TFieldValues, TName>) => {
return (
<FormFieldContext.Provider value={{ name: props.name }}>
<Controller {...props} />
</FormFieldContext.Provider>
)
}
const useFormField = () => {
const fieldContext = React.useContext(FormFieldContext)
const itemContext = React.useContext(FormItemContext)
const { getFieldState, formState } = useFormContext()
if (!fieldContext) {
throw new Error("useFormField should be used within <FormField>")
}
if (!itemContext) {
throw new Error("useFormField should be used within <FormItem>")
}
const fieldState = getFieldState(fieldContext.name, formState)
const { id } = itemContext
return {
id,
name: fieldContext.name,
formItemId: `${id}-form-item`,
formDescriptionId: `${id}-form-item-description`,
formMessageId: `${id}-form-item-message`,
...fieldState,
}
}
type FormItemContextValue = {
id: string
}
const FormItemContext = React.createContext<FormItemContextValue | null>(null)
const FormItem = React.forwardRef<
HTMLDivElement,
React.HTMLAttributes<HTMLDivElement>
>(({ className, ...props }, ref) => {
const id = React.useId()
return (
<FormItemContext.Provider value={{ id }}>
<div ref={ref} className={cn("space-y-2", className)} {...props} />
</FormItemContext.Provider>
)
})
FormItem.displayName = "FormItem"
const FormLabel = React.forwardRef<
React.ElementRef<typeof LabelPrimitive.Root>,
React.ComponentPropsWithoutRef<typeof LabelPrimitive.Root>
>(({ className, ...props }, ref) => {
const { error, formItemId } = useFormField()
return (
<Label
ref={ref}
className={cn(error && "text-destructive", className)}
htmlFor={formItemId}
{...props}
/>
)
})
FormLabel.displayName = "FormLabel"
const FormControl = React.forwardRef<
React.ElementRef<typeof Slot>,
React.ComponentPropsWithoutRef<typeof Slot>
>(({ ...props }, ref) => {
const { error, formItemId, formDescriptionId, formMessageId } = useFormField()
return (
<Slot
ref={ref}
id={formItemId}
aria-describedby={
!error
? `${formDescriptionId}`
: `${formDescriptionId} ${formMessageId}`
}
aria-invalid={!!error}
{...props}
/>
)
})
FormControl.displayName = "FormControl"
const FormDescription = React.forwardRef<
HTMLParagraphElement,
React.HTMLAttributes<HTMLParagraphElement>
>(({ className, ...props }, ref) => {
const { formDescriptionId } = useFormField()
return (
<p
ref={ref}
id={formDescriptionId}
className={cn("text-sm text-muted-foreground", className)}
{...props}
/>
)
})
FormDescription.displayName = "FormDescription"
const FormMessage = React.forwardRef<
HTMLParagraphElement,
React.HTMLAttributes<HTMLParagraphElement>
>(({ className, children, ...props }, ref) => {
const { error, formMessageId } = useFormField()
const body = error ? String(error?.message ?? "") : children
if (!body) {
return null
}
return (
<p
ref={ref}
id={formMessageId}
className={cn("text-sm font-medium text-destructive", className)}
{...props}
>
{body}
</p>
)
})
FormMessage.displayName = "FormMessage"
export {
useFormField,
Form,
FormItem,
FormLabel,
FormControl,
FormDescription,
FormMessage,
FormField,
}

View File

@@ -0,0 +1,22 @@
import * as React from "react"
import { cn } from "@/lib/utils"
const Input = React.forwardRef<HTMLInputElement, React.ComponentProps<"input">>(
({ className, type, ...props }, ref) => {
return (
<input
type={type}
className={cn(
"flex h-10 w-full rounded-md border border-input bg-background px-3 py-2 text-base ring-offset-background file:border-0 file:bg-transparent file:text-sm file:font-medium file:text-foreground placeholder:text-muted-foreground focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-ring focus-visible:ring-offset-2 disabled:cursor-not-allowed disabled:opacity-50 md:text-sm",
className
)}
ref={ref}
{...props}
/>
)
}
)
Input.displayName = "Input"
export { Input }

View File

@@ -0,0 +1,24 @@
import * as React from "react"
import * as LabelPrimitive from "@radix-ui/react-label"
import { cva, type VariantProps } from "class-variance-authority"
import { cn } from "@/lib/utils"
const labelVariants = cva(
"text-sm font-medium leading-none peer-disabled:cursor-not-allowed peer-disabled:opacity-70"
)
const Label = React.forwardRef<
React.ElementRef<typeof LabelPrimitive.Root>,
React.ComponentPropsWithoutRef<typeof LabelPrimitive.Root> &
VariantProps<typeof labelVariants>
>(({ className, ...props }, ref) => (
<LabelPrimitive.Root
ref={ref}
className={cn(labelVariants(), className)}
{...props}
/>
))
Label.displayName = LabelPrimitive.Root.displayName
export { Label }

View File

@@ -0,0 +1,26 @@
import * as React from "react"
import * as ProgressPrimitive from "@radix-ui/react-progress"
import { cn } from "@/lib/utils"
const Progress = React.forwardRef<
React.ElementRef<typeof ProgressPrimitive.Root>,
React.ComponentPropsWithoutRef<typeof ProgressPrimitive.Root>
>(({ className, value, ...props }, ref) => (
<ProgressPrimitive.Root
ref={ref}
className={cn(
"relative h-4 w-full overflow-hidden rounded-full bg-secondary",
className
)}
{...props}
>
<ProgressPrimitive.Indicator
className="h-full w-full flex-1 bg-primary transition-all"
style={{ transform: `translateX(-${100 - (value || 0)}%)` }}
/>
</ProgressPrimitive.Root>
))
Progress.displayName = ProgressPrimitive.Root.displayName
export { Progress }

View File

@@ -0,0 +1,48 @@
"use client"
import * as React from "react"
import * as ScrollAreaPrimitive from "@radix-ui/react-scroll-area"
import { cn } from "@/lib/utils"
const ScrollArea = React.forwardRef<
React.ElementRef<typeof ScrollAreaPrimitive.Root>,
React.ComponentPropsWithoutRef<typeof ScrollAreaPrimitive.Root>
>(({ className, children, ...props }, ref) => (
<ScrollAreaPrimitive.Root
ref={ref}
className={cn("relative overflow-hidden", className)}
{...props}
>
<ScrollAreaPrimitive.Viewport className="h-full w-full rounded-[inherit]">
{children}
</ScrollAreaPrimitive.Viewport>
<ScrollBar />
<ScrollAreaPrimitive.Corner />
</ScrollAreaPrimitive.Root>
))
ScrollArea.displayName = ScrollAreaPrimitive.Root.displayName
const ScrollBar = React.forwardRef<
React.ElementRef<typeof ScrollAreaPrimitive.ScrollAreaScrollbar>,
React.ComponentPropsWithoutRef<typeof ScrollAreaPrimitive.ScrollAreaScrollbar>
>(({ className, orientation = "vertical", ...props }, ref) => (
<ScrollAreaPrimitive.ScrollAreaScrollbar
ref={ref}
orientation={orientation}
className={cn(
"flex touch-none select-none transition-colors",
orientation === "vertical" &&
"h-full w-2.5 border-l border-l-transparent p-[1px]",
orientation === "horizontal" &&
"h-2.5 flex-col border-t border-t-transparent p-[1px]",
className
)}
{...props}
>
<ScrollAreaPrimitive.ScrollAreaThumb className="relative flex-1 rounded-full bg-border" />
</ScrollAreaPrimitive.ScrollAreaScrollbar>
))
ScrollBar.displayName = ScrollAreaPrimitive.ScrollAreaScrollbar.displayName
export { ScrollArea, ScrollBar }

View File

@@ -0,0 +1,158 @@
import * as React from "react"
import * as SelectPrimitive from "@radix-ui/react-select"
import { Check, ChevronDown, ChevronUp } from "lucide-react"
import { cn } from "@/lib/utils"
const Select = SelectPrimitive.Root
const SelectGroup = SelectPrimitive.Group
const SelectValue = SelectPrimitive.Value
const SelectTrigger = React.forwardRef<
React.ElementRef<typeof SelectPrimitive.Trigger>,
React.ComponentPropsWithoutRef<typeof SelectPrimitive.Trigger>
>(({ className, children, ...props }, ref) => (
<SelectPrimitive.Trigger
ref={ref}
className={cn(
"flex h-10 w-full items-center justify-between rounded-md border border-input bg-background px-3 py-2 text-sm ring-offset-background data-[placeholder]:text-muted-foreground focus:outline-none focus:ring-2 focus:ring-ring focus:ring-offset-2 disabled:cursor-not-allowed disabled:opacity-50 [&>span]:line-clamp-1",
className
)}
{...props}
>
{children}
<SelectPrimitive.Icon asChild>
<ChevronDown className="h-4 w-4 opacity-50" />
</SelectPrimitive.Icon>
</SelectPrimitive.Trigger>
))
SelectTrigger.displayName = SelectPrimitive.Trigger.displayName
const SelectScrollUpButton = React.forwardRef<
React.ElementRef<typeof SelectPrimitive.ScrollUpButton>,
React.ComponentPropsWithoutRef<typeof SelectPrimitive.ScrollUpButton>
>(({ className, ...props }, ref) => (
<SelectPrimitive.ScrollUpButton
ref={ref}
className={cn(
"flex cursor-default items-center justify-center py-1",
className
)}
{...props}
>
<ChevronUp className="h-4 w-4" />
</SelectPrimitive.ScrollUpButton>
))
SelectScrollUpButton.displayName = SelectPrimitive.ScrollUpButton.displayName
const SelectScrollDownButton = React.forwardRef<
React.ElementRef<typeof SelectPrimitive.ScrollDownButton>,
React.ComponentPropsWithoutRef<typeof SelectPrimitive.ScrollDownButton>
>(({ className, ...props }, ref) => (
<SelectPrimitive.ScrollDownButton
ref={ref}
className={cn(
"flex cursor-default items-center justify-center py-1",
className
)}
{...props}
>
<ChevronDown className="h-4 w-4" />
</SelectPrimitive.ScrollDownButton>
))
SelectScrollDownButton.displayName =
SelectPrimitive.ScrollDownButton.displayName
const SelectContent = React.forwardRef<
React.ElementRef<typeof SelectPrimitive.Content>,
React.ComponentPropsWithoutRef<typeof SelectPrimitive.Content>
>(({ className, children, position = "popper", ...props }, ref) => (
<SelectPrimitive.Portal>
<SelectPrimitive.Content
ref={ref}
className={cn(
"relative z-50 max-h-[--radix-select-content-available-height] min-w-[8rem] overflow-y-auto overflow-x-hidden rounded-md border bg-popover text-popover-foreground shadow-md data-[state=open]:animate-in data-[state=closed]:animate-out data-[state=closed]:fade-out-0 data-[state=open]:fade-in-0 data-[state=closed]:zoom-out-95 data-[state=open]:zoom-in-95 data-[side=bottom]:slide-in-from-top-2 data-[side=left]:slide-in-from-right-2 data-[side=right]:slide-in-from-left-2 data-[side=top]:slide-in-from-bottom-2 origin-[--radix-select-content-transform-origin]",
position === "popper" &&
"data-[side=bottom]:translate-y-1 data-[side=left]:-translate-x-1 data-[side=right]:translate-x-1 data-[side=top]:-translate-y-1",
className
)}
position={position}
{...props}
>
<SelectScrollUpButton />
<SelectPrimitive.Viewport
className={cn(
"p-1",
position === "popper" &&
"h-[var(--radix-select-trigger-height)] w-full min-w-[var(--radix-select-trigger-width)]"
)}
>
{children}
</SelectPrimitive.Viewport>
<SelectScrollDownButton />
</SelectPrimitive.Content>
</SelectPrimitive.Portal>
))
SelectContent.displayName = SelectPrimitive.Content.displayName
const SelectLabel = React.forwardRef<
React.ElementRef<typeof SelectPrimitive.Label>,
React.ComponentPropsWithoutRef<typeof SelectPrimitive.Label>
>(({ className, ...props }, ref) => (
<SelectPrimitive.Label
ref={ref}
className={cn("py-1.5 pl-8 pr-2 text-sm font-semibold", className)}
{...props}
/>
))
SelectLabel.displayName = SelectPrimitive.Label.displayName
const SelectItem = React.forwardRef<
React.ElementRef<typeof SelectPrimitive.Item>,
React.ComponentPropsWithoutRef<typeof SelectPrimitive.Item>
>(({ className, children, ...props }, ref) => (
<SelectPrimitive.Item
ref={ref}
className={cn(
"relative flex w-full cursor-default select-none items-center rounded-sm py-1.5 pl-8 pr-2 text-sm outline-none focus:bg-accent focus:text-accent-foreground data-[disabled]:pointer-events-none data-[disabled]:opacity-50",
className
)}
{...props}
>
<span className="absolute left-2 flex h-3.5 w-3.5 items-center justify-center">
<SelectPrimitive.ItemIndicator>
<Check className="h-4 w-4" />
</SelectPrimitive.ItemIndicator>
</span>
<SelectPrimitive.ItemText>{children}</SelectPrimitive.ItemText>
</SelectPrimitive.Item>
))
SelectItem.displayName = SelectPrimitive.Item.displayName
const SelectSeparator = React.forwardRef<
React.ElementRef<typeof SelectPrimitive.Separator>,
React.ComponentPropsWithoutRef<typeof SelectPrimitive.Separator>
>(({ className, ...props }, ref) => (
<SelectPrimitive.Separator
ref={ref}
className={cn("-mx-1 my-1 h-px bg-muted", className)}
{...props}
/>
))
SelectSeparator.displayName = SelectPrimitive.Separator.displayName
export {
Select,
SelectGroup,
SelectValue,
SelectTrigger,
SelectContent,
SelectLabel,
SelectItem,
SelectSeparator,
SelectScrollUpButton,
SelectScrollDownButton,
}

View File

@@ -0,0 +1,31 @@
"use client"
import * as React from "react"
import * as SeparatorPrimitive from "@radix-ui/react-separator"
import { cn } from "@/lib/utils"
const Separator = React.forwardRef<
React.ElementRef<typeof SeparatorPrimitive.Root>,
React.ComponentPropsWithoutRef<typeof SeparatorPrimitive.Root>
>(
(
{ className, orientation = "horizontal", decorative = true, ...props },
ref
) => (
<SeparatorPrimitive.Root
ref={ref}
decorative={decorative}
orientation={orientation}
className={cn(
"shrink-0 bg-border",
orientation === "horizontal" ? "h-[1px] w-full" : "h-full w-[1px]",
className
)}
{...props}
/>
)
)
Separator.displayName = SeparatorPrimitive.Root.displayName
export { Separator }

View File

@@ -0,0 +1,138 @@
import * as React from "react"
import * as SheetPrimitive from "@radix-ui/react-dialog"
import { cva, type VariantProps } from "class-variance-authority"
import { X } from "lucide-react"
import { cn } from "@/lib/utils"
const Sheet = SheetPrimitive.Root
const SheetTrigger = SheetPrimitive.Trigger
const SheetClose = SheetPrimitive.Close
const SheetPortal = SheetPrimitive.Portal
const SheetOverlay = React.forwardRef<
React.ElementRef<typeof SheetPrimitive.Overlay>,
React.ComponentPropsWithoutRef<typeof SheetPrimitive.Overlay>
>(({ className, ...props }, ref) => (
<SheetPrimitive.Overlay
className={cn(
"fixed inset-0 z-50 bg-black/80 data-[state=open]:animate-in data-[state=closed]:animate-out data-[state=closed]:fade-out-0 data-[state=open]:fade-in-0",
className
)}
{...props}
ref={ref}
/>
))
SheetOverlay.displayName = SheetPrimitive.Overlay.displayName
const sheetVariants = cva(
"fixed z-50 gap-4 bg-background p-6 shadow-lg transition ease-in-out data-[state=open]:animate-in data-[state=closed]:animate-out data-[state=closed]:duration-300 data-[state=open]:duration-500",
{
variants: {
side: {
top: "inset-x-0 top-0 border-b data-[state=closed]:slide-out-to-top data-[state=open]:slide-in-from-top",
bottom:
"inset-x-0 bottom-0 border-t data-[state=closed]:slide-out-to-bottom data-[state=open]:slide-in-from-bottom",
left: "inset-y-0 left-0 h-full w-3/4 border-r data-[state=closed]:slide-out-to-left data-[state=open]:slide-in-from-left sm:max-w-sm",
right:
"inset-y-0 right-0 h-full w-3/4 border-l data-[state=closed]:slide-out-to-right data-[state=open]:slide-in-from-right sm:max-w-sm",
},
},
defaultVariants: {
side: "right",
},
}
)
interface SheetContentProps
extends React.ComponentPropsWithoutRef<typeof SheetPrimitive.Content>,
VariantProps<typeof sheetVariants> {}
const SheetContent = React.forwardRef<
React.ElementRef<typeof SheetPrimitive.Content>,
SheetContentProps
>(({ side = "right", className, children, ...props }, ref) => (
<SheetPortal>
<SheetOverlay />
<SheetPrimitive.Content
ref={ref}
className={cn(sheetVariants({ side }), className)}
{...props}
>
{children}
<SheetPrimitive.Close className="absolute right-4 top-4 rounded-sm opacity-70 ring-offset-background transition-opacity hover:opacity-100 focus:outline-none focus:ring-2 focus:ring-ring focus:ring-offset-2 disabled:pointer-events-none data-[state=open]:bg-secondary">
<X className="h-4 w-4" />
<span className="sr-only">Close</span>
</SheetPrimitive.Close>
</SheetPrimitive.Content>
</SheetPortal>
))
SheetContent.displayName = SheetPrimitive.Content.displayName
const SheetHeader = ({
className,
...props
}: React.HTMLAttributes<HTMLDivElement>) => (
<div
className={cn(
"flex flex-col space-y-2 text-center sm:text-left",
className
)}
{...props}
/>
)
SheetHeader.displayName = "SheetHeader"
const SheetFooter = ({
className,
...props
}: React.HTMLAttributes<HTMLDivElement>) => (
<div
className={cn(
"flex flex-col-reverse sm:flex-row sm:justify-end sm:space-x-2",
className
)}
{...props}
/>
)
SheetFooter.displayName = "SheetFooter"
const SheetTitle = React.forwardRef<
React.ElementRef<typeof SheetPrimitive.Title>,
React.ComponentPropsWithoutRef<typeof SheetPrimitive.Title>
>(({ className, ...props }, ref) => (
<SheetPrimitive.Title
ref={ref}
className={cn("text-lg font-semibold text-foreground", className)}
{...props}
/>
))
SheetTitle.displayName = SheetPrimitive.Title.displayName
const SheetDescription = React.forwardRef<
React.ElementRef<typeof SheetPrimitive.Description>,
React.ComponentPropsWithoutRef<typeof SheetPrimitive.Description>
>(({ className, ...props }, ref) => (
<SheetPrimitive.Description
ref={ref}
className={cn("text-sm text-muted-foreground", className)}
{...props}
/>
))
SheetDescription.displayName = SheetPrimitive.Description.displayName
export {
Sheet,
SheetPortal,
SheetOverlay,
SheetTrigger,
SheetClose,
SheetContent,
SheetHeader,
SheetFooter,
SheetTitle,
SheetDescription,
}

View File

@@ -0,0 +1,28 @@
"use client"
import * as React from "react"
import * as SliderPrimitive from "@radix-ui/react-slider"
import { cn } from "@/lib/utils"
const Slider = React.forwardRef<
React.ElementRef<typeof SliderPrimitive.Root>,
React.ComponentPropsWithoutRef<typeof SliderPrimitive.Root>
>(({ className, ...props }, ref) => (
<SliderPrimitive.Root
ref={ref}
className={cn(
"relative flex w-full touch-none select-none items-center",
className
)}
{...props}
>
<SliderPrimitive.Track className="relative h-2 w-full grow overflow-hidden rounded-full bg-secondary">
<SliderPrimitive.Range className="absolute h-full bg-primary" />
</SliderPrimitive.Track>
<SliderPrimitive.Thumb className="block h-5 w-5 rounded-full border-2 border-primary bg-background ring-offset-background transition-colors focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-ring focus-visible:ring-offset-2 disabled:pointer-events-none disabled:opacity-50" />
</SliderPrimitive.Root>
))
Slider.displayName = SliderPrimitive.Root.displayName
export { Slider }

View File

@@ -0,0 +1,53 @@
import * as React from "react"
import * as TabsPrimitive from "@radix-ui/react-tabs"
import { cn } from "@/lib/utils"
const Tabs = TabsPrimitive.Root
const TabsList = React.forwardRef<
React.ElementRef<typeof TabsPrimitive.List>,
React.ComponentPropsWithoutRef<typeof TabsPrimitive.List>
>(({ className, ...props }, ref) => (
<TabsPrimitive.List
ref={ref}
className={cn(
"inline-flex h-10 items-center justify-center rounded-md bg-muted p-1 text-muted-foreground",
className
)}
{...props}
/>
))
TabsList.displayName = TabsPrimitive.List.displayName
const TabsTrigger = React.forwardRef<
React.ElementRef<typeof TabsPrimitive.Trigger>,
React.ComponentPropsWithoutRef<typeof TabsPrimitive.Trigger>
>(({ className, ...props }, ref) => (
<TabsPrimitive.Trigger
ref={ref}
className={cn(
"inline-flex items-center justify-center whitespace-nowrap rounded-sm px-3 py-1.5 text-sm font-medium ring-offset-background transition-all focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-ring focus-visible:ring-offset-2 disabled:pointer-events-none disabled:opacity-50 data-[state=active]:bg-background data-[state=active]:text-foreground data-[state=active]:shadow-sm",
className
)}
{...props}
/>
))
TabsTrigger.displayName = TabsPrimitive.Trigger.displayName
const TabsContent = React.forwardRef<
React.ElementRef<typeof TabsPrimitive.Content>,
React.ComponentPropsWithoutRef<typeof TabsPrimitive.Content>
>(({ className, ...props }, ref) => (
<TabsPrimitive.Content
ref={ref}
className={cn(
"mt-2 ring-offset-background focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-ring focus-visible:ring-offset-2",
className
)}
{...props}
/>
))
TabsContent.displayName = TabsPrimitive.Content.displayName
export { Tabs, TabsList, TabsTrigger, TabsContent }

View File

@@ -0,0 +1,54 @@
import * as React from "react"
import { cn, debounce } from "@/lib/utils"
const Textarea = React.forwardRef<
HTMLTextAreaElement,
React.ComponentProps<"textarea">
>(({ className, ...props }, ref) => {
const internalRef = React.useRef<HTMLTextAreaElement>(null)
React.useImperativeHandle(ref, () => internalRef.current!)
const adjustHeight = React.useCallback((element: HTMLTextAreaElement) => {
element.style.height = 'auto'
element.style.height = `${element.scrollHeight}px`
}, [])
React.useLayoutEffect(() => {
const element = internalRef.current
if (element) {
adjustHeight(element)
}
}, [props.value, props.defaultValue, adjustHeight])
React.useEffect(() => {
const element = internalRef.current
if (!element) return
const handleInput = () => adjustHeight(element)
const handleResize = debounce(() => adjustHeight(element), 250)
element.addEventListener('input', handleInput)
window.addEventListener('resize', handleResize)
return () => {
element.removeEventListener('input', handleInput)
window.removeEventListener('resize', handleResize)
}
}, [adjustHeight])
return (
<textarea
className={cn(
"flex min-h-[80px] w-full rounded-md border border-input bg-background px-3 py-2 text-base ring-offset-background placeholder:text-muted-foreground focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-ring focus-visible:ring-offset-2 disabled:cursor-not-allowed disabled:opacity-50 md:text-sm max-h-[80vh] md:max-h-[400px] overflow-y-auto",
className
)}
ref={internalRef}
{...props}
/>
)
})
Textarea.displayName = "Textarea"
export { Textarea }

View File

@@ -0,0 +1,28 @@
import * as React from "react"
import * as TooltipPrimitive from "@radix-ui/react-tooltip"
import { cn } from "@/lib/utils"
const TooltipProvider = TooltipPrimitive.Provider
const Tooltip = TooltipPrimitive.Root
const TooltipTrigger = TooltipPrimitive.Trigger
const TooltipContent = React.forwardRef<
React.ElementRef<typeof TooltipPrimitive.Content>,
React.ComponentPropsWithoutRef<typeof TooltipPrimitive.Content>
>(({ className, sideOffset = 4, ...props }, ref) => (
<TooltipPrimitive.Content
ref={ref}
sideOffset={sideOffset}
className={cn(
"z-50 overflow-hidden rounded-md border bg-popover px-3 py-1.5 text-sm text-popover-foreground shadow-md animate-in fade-in-0 zoom-in-95 data-[state=closed]:animate-out data-[state=closed]:fade-out-0 data-[state=closed]:zoom-out-95 data-[side=bottom]:slide-in-from-top-2 data-[side=left]:slide-in-from-right-2 data-[side=right]:slide-in-from-left-2 data-[side=top]:slide-in-from-bottom-2 origin-[--radix-tooltip-content-transform-origin]",
className
)}
{...props}
/>
))
TooltipContent.displayName = TooltipPrimitive.Content.displayName
export { Tooltip, TooltipTrigger, TooltipContent, TooltipProvider }

View File

@@ -0,0 +1,48 @@
import {
AlertDialog,
AlertDialogAction,
AlertDialogCancel,
AlertDialogContent,
AlertDialogDescription,
AlertDialogFooter,
AlertDialogHeader,
AlertDialogTitle,
} from '@/components/ui/alert-dialog'
import type { User } from '@/types/auth'
interface DeleteUserDialogProps {
open: boolean
onOpenChange: (open: boolean) => void
user: User | null
onConfirm: () => Promise<void>
isLoading: boolean
}
export function DeleteUserDialog({
open,
onOpenChange,
user,
onConfirm,
isLoading,
}: DeleteUserDialogProps) {
return (
<AlertDialog open={open} onOpenChange={onOpenChange}>
<AlertDialogContent>
<AlertDialogHeader>
<AlertDialogTitle></AlertDialogTitle>
<AlertDialogDescription>
<strong>{user?.username}</strong>
<br />
</AlertDialogDescription>
</AlertDialogHeader>
<AlertDialogFooter>
<AlertDialogCancel disabled={isLoading}></AlertDialogCancel>
<AlertDialogAction onClick={onConfirm} disabled={isLoading}>
{isLoading ? '删除中...' : '确认删除'}
</AlertDialogAction>
</AlertDialogFooter>
</AlertDialogContent>
</AlertDialog>
)
}

View File

@@ -0,0 +1,203 @@
import { useEffect } from 'react'
import { useForm } from 'react-hook-form'
import { zodResolver } from '@hookform/resolvers/zod'
import * as z from 'zod'
import {
Dialog,
DialogContent,
DialogHeader,
DialogTitle,
DialogFooter,
} from '@/components/ui/dialog'
import {
Form,
FormControl,
FormField,
FormItem,
FormLabel,
FormMessage,
} from '@/components/ui/form'
import { Input } from '@/components/ui/input'
import { Button } from '@/components/ui/button'
import { Checkbox } from '@/components/ui/checkbox'
import type { User } from '@/types/auth'
const editUserFormSchema = z.object({
username: z.string().min(3, '用户名至少3个字符').max(20, '用户名最多20个字符'),
email: z.string().email('请输入有效的邮箱地址'),
password: z.string().optional(),
is_active: z.boolean().default(true),
is_superuser: z.boolean().default(false),
})
const createUserFormSchema = z.object({
username: z.string().min(3, '用户名至少3个字符').max(20, '用户名最多20个字符'),
email: z.string().email('请输入有效的邮箱地址'),
password: z.string().min(8, '密码至少8个字符'),
is_active: z.boolean().default(true),
is_superuser: z.boolean().default(false),
})
type UserFormValues = z.infer<typeof editUserFormSchema>
interface UserDialogProps {
open: boolean
onOpenChange: (open: boolean) => void
user?: User | null
onSubmit: (data: UserFormValues) => Promise<void>
isLoading: boolean
}
export function UserDialog({
open,
onOpenChange,
user,
onSubmit,
isLoading,
}: UserDialogProps) {
const isEditing = !!user
const form = useForm<UserFormValues>({
resolver: zodResolver(isEditing ? editUserFormSchema : createUserFormSchema),
defaultValues: {
username: '',
email: '',
password: '',
is_active: true,
is_superuser: false,
},
})
useEffect(() => {
if (user) {
form.reset({
username: user.username,
email: user.email,
password: '',
is_active: user.is_active,
is_superuser: user.is_superuser,
})
} else {
form.reset({
username: '',
email: '',
password: '',
is_active: true,
is_superuser: false,
})
}
}, [user, form])
const handleSubmit = async (data: UserFormValues) => {
await onSubmit(data)
form.reset()
}
return (
<Dialog open={open} onOpenChange={onOpenChange}>
<DialogContent className="sm:max-w-[500px]">
<DialogHeader>
<DialogTitle>{isEditing ? '编辑用户' : '创建用户'}</DialogTitle>
</DialogHeader>
<Form {...form}>
<form onSubmit={form.handleSubmit(handleSubmit)} className="space-y-4">
<FormField
control={form.control}
name="username"
render={({ field }) => (
<FormItem>
<FormLabel></FormLabel>
<FormControl>
<Input {...field} />
</FormControl>
<FormMessage />
</FormItem>
)}
/>
<FormField
control={form.control}
name="email"
render={({ field }) => (
<FormItem>
<FormLabel></FormLabel>
<FormControl>
<Input type="email" {...field} />
</FormControl>
<FormMessage />
</FormItem>
)}
/>
<FormField
control={form.control}
name="password"
render={({ field }) => (
<FormItem>
<FormLabel>
{isEditing && ' (留空则不修改)'}
</FormLabel>
<FormControl>
<Input type="password" {...field} />
</FormControl>
<FormMessage />
</FormItem>
)}
/>
<FormField
control={form.control}
name="is_active"
render={({ field }) => (
<FormItem className="flex flex-row items-start space-x-3 space-y-0">
<FormControl>
<Checkbox
checked={field.value}
onCheckedChange={field.onChange}
/>
</FormControl>
<div className="space-y-1 leading-none">
<FormLabel></FormLabel>
</div>
</FormItem>
)}
/>
<FormField
control={form.control}
name="is_superuser"
render={({ field }) => (
<FormItem className="flex flex-row items-start space-x-3 space-y-0">
<FormControl>
<Checkbox
checked={field.value}
onCheckedChange={field.onChange}
/>
</FormControl>
<div className="space-y-1 leading-none">
<FormLabel></FormLabel>
</div>
</FormItem>
)}
/>
<DialogFooter>
<Button
type="button"
variant="outline"
onClick={() => onOpenChange(false)}
disabled={isLoading}
>
</Button>
<Button type="submit" disabled={isLoading}>
{isLoading ? '保存中...' : '保存'}
</Button>
</DialogFooter>
</form>
</Form>
</DialogContent>
</Dialog>
)
}

View File

@@ -0,0 +1,87 @@
import { Edit, Trash2 } from 'lucide-react'
import { Button } from '@/components/ui/button'
import { Badge } from '@/components/ui/badge'
import type { User } from '@/types/auth'
interface UserTableProps {
users: User[]
isLoading: boolean
onEdit: (user: User) => void
onDelete: (user: User) => void
}
export function UserTable({ users, isLoading, onEdit, onDelete }: UserTableProps) {
if (isLoading) {
return (
<div className="flex items-center justify-center py-8">
<div className="text-muted-foreground">...</div>
</div>
)
}
if (users.length === 0) {
return (
<div className="flex items-center justify-center py-8">
<div className="text-muted-foreground"></div>
</div>
)
}
return (
<div className="overflow-x-auto">
<table className="w-full">
<thead className="border-b">
<tr className="text-left">
<th className="px-4 py-3 font-medium">ID</th>
<th className="px-4 py-3 font-medium"></th>
<th className="px-4 py-3 font-medium"></th>
<th className="px-4 py-3 font-medium"></th>
<th className="px-4 py-3 font-medium"></th>
<th className="px-4 py-3 font-medium"></th>
<th className="px-4 py-3 font-medium text-right"></th>
</tr>
</thead>
<tbody>
{users.map((user) => (
<tr key={user.id} className="border-b hover:bg-muted/50">
<td className="px-4 py-3">{user.id}</td>
<td className="px-4 py-3">{user.username}</td>
<td className="px-4 py-3">{user.email}</td>
<td className="px-4 py-3">
<Badge variant={user.is_active ? 'default' : 'secondary'}>
{user.is_active ? '活跃' : '停用'}
</Badge>
</td>
<td className="px-4 py-3">
<Badge variant={user.is_superuser ? 'destructive' : 'outline'}>
{user.is_superuser ? '超级管理员' : '普通用户'}
</Badge>
</td>
<td className="px-4 py-3">
{new Date(user.created_at).toLocaleString('zh-CN')}
</td>
<td className="px-4 py-3">
<div className="flex justify-end gap-2">
<Button
variant="ghost"
size="icon"
onClick={() => onEdit(user)}
>
<Edit className="h-4 w-4" />
</Button>
<Button
variant="ghost"
size="icon"
onClick={() => onDelete(user)}
>
<Trash2 className="h-4 w-4" />
</Button>
</div>
</td>
</tr>
))}
</tbody>
</table>
</div>
)
}

View File

@@ -0,0 +1,101 @@
import { createContext, useContext, useState, useEffect, useMemo, useCallback, type ReactNode } from 'react'
import { ttsApi } from '@/lib/api'
import type { Language, Speaker } from '@/types/tts'
interface AppContextType {
currentTab: string
setCurrentTab: (tab: string) => void
languages: Language[]
speakers: Speaker[]
isLoadingConfig: boolean
}
interface CacheEntry<T> {
data: T
timestamp: number
}
const CACHE_DURATION = 5 * 60 * 1000
const cache: {
languages: CacheEntry<Language[]> | null
speakers: CacheEntry<Speaker[]> | null
} = {
languages: null,
speakers: null,
}
const isCacheValid = <T,>(entry: CacheEntry<T> | null): boolean => {
if (!entry) return false
return Date.now() - entry.timestamp < CACHE_DURATION
}
const AppContext = createContext<AppContextType | undefined>(undefined)
export function AppProvider({ children }: { children: ReactNode }) {
const [currentTab, setCurrentTabState] = useState('custom-voice')
const [languages, setLanguages] = useState<Language[]>([])
const [speakers, setSpeakers] = useState<Speaker[]>([])
const [isLoadingConfig, setIsLoadingConfig] = useState(true)
const setCurrentTab = useCallback((tab: string) => {
setCurrentTabState(tab)
}, [])
useEffect(() => {
const loadConfig = async () => {
try {
let languagesData: Language[]
let speakersData: Speaker[]
if (isCacheValid(cache.languages)) {
languagesData = cache.languages!.data
} else {
languagesData = await ttsApi.getLanguages()
cache.languages = { data: languagesData, timestamp: Date.now() }
}
if (isCacheValid(cache.speakers)) {
speakersData = cache.speakers!.data
} else {
speakersData = await ttsApi.getSpeakers()
cache.speakers = { data: speakersData, timestamp: Date.now() }
}
setLanguages(languagesData)
setSpeakers(speakersData)
} catch (error) {
console.error('Failed to load config:', error)
} finally {
setIsLoadingConfig(false)
}
}
loadConfig()
}, [])
const value = useMemo(
() => ({
currentTab,
setCurrentTab,
languages,
speakers,
isLoadingConfig,
}),
[currentTab, setCurrentTab, languages, speakers, isLoadingConfig]
)
return (
<AppContext.Provider value={value}>
{children}
</AppContext.Provider>
)
}
export function useApp() {
const context = useContext(AppContext)
if (!context) {
throw new Error('useApp must be used within AppProvider')
}
return context
}

View File

@@ -0,0 +1,91 @@
import { createContext, useContext, useState, useEffect, type ReactNode } from 'react'
import { useNavigate } from 'react-router-dom'
import { toast } from 'sonner'
import { authApi } from '@/lib/api'
import type { User, LoginRequest, AuthState } from '@/types/auth'
interface AuthContextType extends AuthState {
login: (credentials: LoginRequest) => Promise<void>
logout: () => void
}
const AuthContext = createContext<AuthContextType | undefined>(undefined)
export function AuthProvider({ children }: { children: ReactNode }) {
const [token, setToken] = useState<string | null>(null)
const [user, setUser] = useState<User | null>(null)
const [isLoading, setIsLoading] = useState(true)
const navigate = useNavigate()
useEffect(() => {
const initAuth = async () => {
try {
const storedToken = localStorage.getItem('token')
if (storedToken) {
setToken(storedToken)
const currentUser = await authApi.getCurrentUser()
setUser(currentUser)
}
} catch (error) {
localStorage.removeItem('token')
setToken(null)
setUser(null)
} finally {
setIsLoading(false)
}
}
initAuth()
}, [])
const login = async (credentials: LoginRequest) => {
try {
const response = await authApi.login(credentials)
const newToken = response.access_token
localStorage.setItem('token', newToken)
setToken(newToken)
const currentUser = await authApi.getCurrentUser()
setUser(currentUser)
toast.success('登录成功')
navigate('/')
} catch (error: any) {
const message = error.response?.data?.detail || '登录失败,请检查用户名和密码'
toast.error(message)
throw error
}
}
const logout = () => {
localStorage.removeItem('token')
setToken(null)
setUser(null)
toast.success('已退出登录')
navigate('/login')
}
return (
<AuthContext.Provider
value={{
token,
user,
isLoading,
isAuthenticated: !!token && !!user,
login,
logout,
}}
>
{children}
</AuthContext.Provider>
)
}
export function useAuth() {
const context = useContext(AuthContext)
if (!context) {
throw new Error('useAuth must be used within AuthProvider')
}
return context
}

Some files were not shown because too many files have changed in this diff Show More