refactor: rename canto-backend → backend, canto-frontend → frontend
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
0
backend/api/__init__.py
Normal file
0
backend/api/__init__.py
Normal file
22
backend/api/admin.py
Normal file
22
backend/api/admin.py
Normal file
@@ -0,0 +1,22 @@
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from api.users import require_superuser
|
||||
from db.database import get_db
|
||||
from db.crud import get_usage_stats
|
||||
from schemas.user import User
|
||||
|
||||
router = APIRouter(prefix="/admin", tags=["admin"])
|
||||
|
||||
|
||||
@router.get("/usage")
|
||||
async def get_usage_statistics(
|
||||
date_from: Optional[datetime] = Query(None),
|
||||
date_to: Optional[datetime] = Query(None),
|
||||
db: Session = Depends(get_db),
|
||||
_: User = Depends(require_superuser),
|
||||
):
|
||||
return get_usage_stats(db, date_from=date_from, date_to=date_to)
|
||||
1028
backend/api/audiobook.py
Normal file
1028
backend/api/audiobook.py
Normal file
File diff suppressed because it is too large
Load Diff
216
backend/api/auth.py
Normal file
216
backend/api/auth.py
Normal file
@@ -0,0 +1,216 @@
|
||||
from datetime import timedelta
|
||||
from typing import Annotated, Optional
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Request
|
||||
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
|
||||
from sqlalchemy.orm import Session
|
||||
from slowapi import Limiter
|
||||
from slowapi.util import get_remote_address
|
||||
|
||||
from config import settings
|
||||
from core.security import (
|
||||
get_password_hash,
|
||||
verify_password,
|
||||
create_access_token,
|
||||
decode_access_token
|
||||
)
|
||||
from db.database import get_db
|
||||
from db.crud import get_user_by_username, get_user_by_email, create_user, change_user_password, get_user_preferences, update_user_preferences, can_user_use_nsfw, get_system_setting
|
||||
from schemas.user import User, UserCreate, Token, PasswordChange, UserPreferences, UserPreferencesResponse
|
||||
from schemas.audiobook import LLMConfigResponse
|
||||
|
||||
router = APIRouter(prefix="/auth", tags=["authentication"])
|
||||
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/auth/token", auto_error=not settings.DEV_MODE)
|
||||
|
||||
limiter = Limiter(key_func=get_remote_address)
|
||||
|
||||
async def get_current_user(
|
||||
token: Annotated[Optional[str], Depends(oauth2_scheme)],
|
||||
db: Session = Depends(get_db)
|
||||
) -> User:
|
||||
if settings.DEV_MODE and not token:
|
||||
user = get_user_by_username(db, username="admin")
|
||||
if user:
|
||||
return user
|
||||
|
||||
credentials_exception = HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Could not validate credentials",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
if token is None:
|
||||
raise credentials_exception
|
||||
|
||||
username = decode_access_token(token)
|
||||
if username is None:
|
||||
raise credentials_exception
|
||||
|
||||
user = get_user_by_username(db, username=username)
|
||||
if user is None:
|
||||
raise credentials_exception
|
||||
|
||||
return user
|
||||
|
||||
@router.post("/register", response_model=User, status_code=status.HTTP_201_CREATED)
|
||||
@limiter.limit("5/minute")
|
||||
async def register(request: Request, user_data: UserCreate, db: Session = Depends(get_db)):
|
||||
existing_user = get_user_by_username(db, username=user_data.username)
|
||||
if existing_user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Username already registered"
|
||||
)
|
||||
|
||||
existing_email = get_user_by_email(db, email=user_data.email)
|
||||
if existing_email:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Email already registered"
|
||||
)
|
||||
|
||||
hashed_password = get_password_hash(user_data.password)
|
||||
user = create_user(
|
||||
db,
|
||||
username=user_data.username,
|
||||
email=user_data.email,
|
||||
hashed_password=hashed_password
|
||||
)
|
||||
|
||||
return user
|
||||
|
||||
@router.post("/token", response_model=Token)
|
||||
@limiter.limit("5/minute")
|
||||
async def login(
|
||||
request: Request,
|
||||
form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
user = get_user_by_username(db, username=form_data.username)
|
||||
if not user or not verify_password(form_data.password, user.hashed_password):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Incorrect username or password",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
if not user.is_active:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Inactive user"
|
||||
)
|
||||
|
||||
access_token_expires = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
|
||||
access_token = create_access_token(
|
||||
data={"sub": user.username}, expires_delta=access_token_expires
|
||||
)
|
||||
|
||||
return {"access_token": access_token, "token_type": "bearer"}
|
||||
|
||||
@router.get("/dev-token", response_model=Token)
|
||||
async def dev_token(db: Session = Depends(get_db)):
|
||||
if not settings.DEV_MODE:
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Not available outside DEV_MODE")
|
||||
user = get_user_by_username(db, username="admin")
|
||||
if not user:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Admin user not found")
|
||||
access_token = create_access_token(data={"sub": user.username})
|
||||
return {"access_token": access_token, "token_type": "bearer"}
|
||||
|
||||
@router.get("/me", response_model=User)
|
||||
@limiter.limit("30/minute")
|
||||
async def get_current_user_info(
|
||||
request: Request,
|
||||
current_user: Annotated[User, Depends(get_current_user)]
|
||||
):
|
||||
return current_user
|
||||
|
||||
@router.post("/change-password", response_model=User)
|
||||
@limiter.limit("5/minute")
|
||||
async def change_password(
|
||||
request: Request,
|
||||
password_data: PasswordChange,
|
||||
current_user: Annotated[User, Depends(get_current_user)],
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
if not verify_password(password_data.current_password, current_user.hashed_password):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Current password is incorrect"
|
||||
)
|
||||
|
||||
new_hashed_password = get_password_hash(password_data.new_password)
|
||||
|
||||
user = change_user_password(
|
||||
db,
|
||||
user_id=current_user.id,
|
||||
new_hashed_password=new_hashed_password
|
||||
)
|
||||
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="User not found"
|
||||
)
|
||||
|
||||
return user
|
||||
|
||||
@router.get("/preferences", response_model=UserPreferencesResponse)
|
||||
@limiter.limit("30/minute")
|
||||
async def get_preferences(
|
||||
request: Request,
|
||||
current_user: Annotated[User, Depends(get_current_user)],
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
prefs = get_user_preferences(db, current_user.id)
|
||||
|
||||
return {
|
||||
"default_backend": "local",
|
||||
"onboarding_completed": prefs.get("onboarding_completed", False),
|
||||
"available_backends": ["local"]
|
||||
}
|
||||
|
||||
@router.put("/preferences")
|
||||
@limiter.limit("10/minute")
|
||||
async def update_preferences(
|
||||
request: Request,
|
||||
preferences: UserPreferences,
|
||||
current_user: Annotated[User, Depends(get_current_user)],
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
updated_user = update_user_preferences(
|
||||
db,
|
||||
current_user.id,
|
||||
preferences.model_dump()
|
||||
)
|
||||
|
||||
if not updated_user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="User not found"
|
||||
)
|
||||
|
||||
return {"message": "Preferences updated successfully"}
|
||||
|
||||
|
||||
@router.get("/llm-config", response_model=LLMConfigResponse)
|
||||
@limiter.limit("30/minute")
|
||||
async def get_llm_config(
|
||||
request: Request,
|
||||
current_user: Annotated[User, Depends(get_current_user)],
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
return LLMConfigResponse(
|
||||
base_url=get_system_setting(db, "llm_base_url"),
|
||||
model=get_system_setting(db, "llm_model"),
|
||||
has_key=bool(get_system_setting(db, "llm_api_key")),
|
||||
)
|
||||
|
||||
|
||||
@router.get("/nsfw-access")
|
||||
@limiter.limit("30/minute")
|
||||
async def get_nsfw_access(
|
||||
request: Request,
|
||||
current_user: Annotated[User, Depends(get_current_user)],
|
||||
):
|
||||
return {"has_access": can_user_use_nsfw(current_user)}
|
||||
216
backend/api/jobs.py
Normal file
216
backend/api/jobs.py
Normal file
@@ -0,0 +1,216 @@
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Request
|
||||
from fastapi.responses import FileResponse
|
||||
from sqlalchemy.orm import Session
|
||||
from slowapi import Limiter
|
||||
from slowapi.util import get_remote_address
|
||||
|
||||
from core.database import get_db
|
||||
from core.config import settings
|
||||
from core.security import decode_access_token
|
||||
from db.models import Job, JobStatus, User
|
||||
from db.crud import get_user_by_username
|
||||
from api.auth import get_current_user
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter(prefix="/jobs", tags=["jobs"])
|
||||
|
||||
limiter = Limiter(key_func=get_remote_address)
|
||||
|
||||
|
||||
async def get_user_from_bearer_token(
|
||||
request: Request,
|
||||
db: Session = Depends(get_db)
|
||||
) -> User:
|
||||
auth_token = None
|
||||
|
||||
auth_header = request.headers.get("Authorization")
|
||||
if auth_header and auth_header.startswith("Bearer "):
|
||||
auth_token = auth_header.split(" ")[1]
|
||||
|
||||
if not auth_token:
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="Missing authentication token"
|
||||
)
|
||||
|
||||
username = decode_access_token(auth_token)
|
||||
if username is None:
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="Invalid or expired token"
|
||||
)
|
||||
|
||||
user = get_user_by_username(db, username=username)
|
||||
if user is None:
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="User not found"
|
||||
)
|
||||
|
||||
return user
|
||||
|
||||
|
||||
@router.get("/{job_id}")
|
||||
@limiter.limit("30/minute")
|
||||
async def get_job(
|
||||
request: Request,
|
||||
job_id: int,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
job = db.query(Job).filter(Job.id == job_id).first()
|
||||
|
||||
if not job:
|
||||
raise HTTPException(status_code=404, detail="Job not found")
|
||||
|
||||
if job.user_id != current_user.id:
|
||||
raise HTTPException(status_code=403, detail="Access denied")
|
||||
|
||||
download_url = None
|
||||
if job.status == JobStatus.COMPLETED and job.output_path:
|
||||
output_file = Path(job.output_path)
|
||||
if output_file.exists():
|
||||
download_url = f"/jobs/{job.id}/download"
|
||||
|
||||
return {
|
||||
"id": job.id,
|
||||
"job_type": job.job_type,
|
||||
"status": job.status,
|
||||
"input_params": job.input_params,
|
||||
"download_url": download_url,
|
||||
"error_message": job.error_message,
|
||||
"created_at": job.created_at.isoformat() + 'Z' if job.created_at else None,
|
||||
"started_at": job.started_at.isoformat() + 'Z' if job.started_at else None,
|
||||
"completed_at": job.completed_at.isoformat() + 'Z' if job.completed_at else None
|
||||
}
|
||||
|
||||
|
||||
@router.get("")
|
||||
@limiter.limit("30/minute")
|
||||
async def list_jobs(
|
||||
request: Request,
|
||||
skip: int = Query(0, ge=0),
|
||||
limit: int = Query(100, ge=1, le=100),
|
||||
status: Optional[str] = Query(None),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
query = db.query(Job).filter(Job.user_id == current_user.id)
|
||||
|
||||
if status:
|
||||
try:
|
||||
status_enum = JobStatus(status)
|
||||
query = query.filter(Job.status == status_enum)
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=400, detail=f"Invalid status: {status}")
|
||||
|
||||
total = query.count()
|
||||
jobs = query.order_by(Job.created_at.desc()).offset(skip).limit(limit).all()
|
||||
|
||||
jobs_data = []
|
||||
for job in jobs:
|
||||
download_url = None
|
||||
if job.status == JobStatus.COMPLETED and job.output_path:
|
||||
output_file = Path(job.output_path)
|
||||
if output_file.exists():
|
||||
download_url = f"/jobs/{job.id}/download"
|
||||
|
||||
jobs_data.append({
|
||||
"id": job.id,
|
||||
"job_type": job.job_type,
|
||||
"status": job.status,
|
||||
"input_params": job.input_params,
|
||||
"download_url": download_url,
|
||||
"error_message": job.error_message,
|
||||
"created_at": job.created_at.isoformat() + 'Z' if job.created_at else None,
|
||||
"completed_at": job.completed_at.isoformat() + 'Z' if job.completed_at else None
|
||||
})
|
||||
|
||||
return {
|
||||
"total": total,
|
||||
"skip": skip,
|
||||
"limit": limit,
|
||||
"jobs": jobs_data
|
||||
}
|
||||
|
||||
|
||||
@router.delete("/{job_id}")
|
||||
@limiter.limit("30/minute")
|
||||
async def delete_job(
|
||||
request: Request,
|
||||
job_id: int,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
job = db.query(Job).filter(Job.id == job_id).first()
|
||||
|
||||
if not job:
|
||||
raise HTTPException(status_code=404, detail="Job not found")
|
||||
|
||||
if job.user_id != current_user.id:
|
||||
raise HTTPException(status_code=403, detail="Access denied")
|
||||
|
||||
output_file = None
|
||||
if job.output_path:
|
||||
output_file = Path(job.output_path).resolve()
|
||||
output_dir = Path(settings.OUTPUT_DIR).resolve()
|
||||
if not output_file.is_relative_to(output_dir):
|
||||
logger.warning(f"Skip deleting file outside output dir: {output_file}")
|
||||
output_file = None
|
||||
|
||||
if output_file:
|
||||
if output_file.exists():
|
||||
try:
|
||||
output_file.unlink()
|
||||
logger.info(f"Deleted output file: {output_file}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete output file {output_file}: {e}")
|
||||
|
||||
db.delete(job)
|
||||
db.commit()
|
||||
|
||||
return {"message": "Job deleted successfully"}
|
||||
|
||||
|
||||
@router.get("/{job_id}/download")
|
||||
@limiter.limit("30/minute")
|
||||
async def download_job_output(
|
||||
request: Request,
|
||||
job_id: int,
|
||||
current_user: User = Depends(get_user_from_bearer_token),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
job = db.query(Job).filter(Job.id == job_id).first()
|
||||
|
||||
if not job:
|
||||
raise HTTPException(status_code=404, detail="Job not found")
|
||||
|
||||
if job.user_id != current_user.id:
|
||||
raise HTTPException(status_code=403, detail="Access denied")
|
||||
|
||||
if job.status != JobStatus.COMPLETED:
|
||||
raise HTTPException(status_code=400, detail="Job not completed yet")
|
||||
|
||||
if not job.output_path:
|
||||
raise HTTPException(status_code=404, detail="Output file not found")
|
||||
|
||||
output_file = Path(job.output_path)
|
||||
if not output_file.exists():
|
||||
raise HTTPException(status_code=404, detail="Output file does not exist")
|
||||
|
||||
output_dir = Path(settings.OUTPUT_DIR).resolve()
|
||||
if not output_file.resolve().is_relative_to(output_dir):
|
||||
logger.warning(f"Path traversal attempt detected: {output_file}")
|
||||
raise HTTPException(status_code=403, detail="Access denied")
|
||||
|
||||
return FileResponse(
|
||||
path=str(output_file),
|
||||
media_type="audio/wav",
|
||||
filename=output_file.name,
|
||||
headers={
|
||||
"Content-Disposition": f'attachment; filename="{output_file.name}"'
|
||||
}
|
||||
)
|
||||
737
backend/api/tts.py
Normal file
737
backend/api/tts.py
Normal file
@@ -0,0 +1,737 @@
|
||||
import logging
|
||||
import tempfile
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from fastapi import APIRouter, Depends, HTTPException, BackgroundTasks, UploadFile, File, Form, Request, status
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import Optional
|
||||
from slowapi import Limiter
|
||||
from slowapi.util import get_remote_address
|
||||
|
||||
from core.config import settings
|
||||
from core.database import get_db
|
||||
from core.model_manager import ModelManager
|
||||
from core.cache_manager import VoiceCacheManager
|
||||
from db.models import Job, JobStatus, User
|
||||
from schemas.tts import CustomVoiceRequest, VoiceDesignRequest, IndexTTS2FromDesignRequest
|
||||
from api.auth import get_current_user
|
||||
from utils.validation import (
|
||||
validate_language,
|
||||
validate_speaker,
|
||||
validate_text_length,
|
||||
validate_generation_params,
|
||||
get_supported_languages,
|
||||
get_supported_speakers
|
||||
)
|
||||
from utils.audio import save_audio_file, validate_ref_audio, process_ref_audio, extract_audio_features
|
||||
from utils.metrics import cache_metrics
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter(prefix="/tts", tags=["tts"])
|
||||
|
||||
limiter = Limiter(key_func=get_remote_address)
|
||||
|
||||
|
||||
async def read_upload_with_size_limit(upload_file: UploadFile, max_size_bytes: int) -> bytes:
|
||||
chunks = []
|
||||
total = 0
|
||||
while True:
|
||||
chunk = await upload_file.read(1024 * 1024)
|
||||
if not chunk:
|
||||
break
|
||||
total += len(chunk)
|
||||
if total > max_size_bytes:
|
||||
raise ValueError(f"Audio file exceeds {max_size_bytes // (1024 * 1024)}MB limit")
|
||||
chunks.append(chunk)
|
||||
return b"".join(chunks)
|
||||
|
||||
|
||||
async def process_custom_voice_job(
|
||||
job_id: int,
|
||||
user_id: int,
|
||||
request_data: dict,
|
||||
backend_type: str,
|
||||
db_url: str
|
||||
):
|
||||
from core.database import SessionLocal
|
||||
from core.tts_service import TTSServiceFactory
|
||||
from core.security import decrypt_api_key
|
||||
|
||||
db = SessionLocal()
|
||||
try:
|
||||
job = db.query(Job).filter(Job.id == job_id).first()
|
||||
if not job:
|
||||
logger.error(f"Job {job_id} not found")
|
||||
return
|
||||
|
||||
job.status = JobStatus.PROCESSING
|
||||
job.started_at = datetime.utcnow()
|
||||
db.commit()
|
||||
|
||||
logger.info(f"Processing custom-voice job {job_id} with backend {backend_type}")
|
||||
|
||||
backend = await TTSServiceFactory.get_backend()
|
||||
|
||||
audio_bytes, sample_rate = await backend.generate_custom_voice(request_data)
|
||||
|
||||
timestamp = datetime.utcnow().strftime("%Y%m%d_%H%M%S")
|
||||
filename = f"{user_id}_{job_id}_{timestamp}.wav"
|
||||
output_path = Path(settings.OUTPUT_DIR) / filename
|
||||
|
||||
with open(output_path, 'wb') as f:
|
||||
f.write(audio_bytes)
|
||||
|
||||
job.status = JobStatus.COMPLETED
|
||||
job.output_path = str(output_path)
|
||||
job.completed_at = datetime.utcnow()
|
||||
db.commit()
|
||||
|
||||
logger.info(f"Job {job_id} completed successfully")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Job {job_id} failed: {e}", exc_info=True)
|
||||
job = db.query(Job).filter(Job.id == job_id).first()
|
||||
if job:
|
||||
job.status = JobStatus.FAILED
|
||||
job.error_message = "Job processing failed"
|
||||
job.completed_at = datetime.utcnow()
|
||||
db.commit()
|
||||
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
async def process_voice_design_job(
|
||||
job_id: int,
|
||||
user_id: int,
|
||||
request_data: dict,
|
||||
backend_type: str,
|
||||
db_url: str,
|
||||
saved_voice_id: Optional[str] = None
|
||||
):
|
||||
from core.database import SessionLocal
|
||||
from core.tts_service import TTSServiceFactory
|
||||
from core.security import decrypt_api_key
|
||||
|
||||
db = SessionLocal()
|
||||
try:
|
||||
job = db.query(Job).filter(Job.id == job_id).first()
|
||||
if not job:
|
||||
logger.error(f"Job {job_id} not found")
|
||||
return
|
||||
|
||||
job.status = JobStatus.PROCESSING
|
||||
job.started_at = datetime.utcnow()
|
||||
db.commit()
|
||||
|
||||
logger.info(f"Processing voice-design job {job_id} with backend {backend_type}")
|
||||
|
||||
backend = await TTSServiceFactory.get_backend()
|
||||
|
||||
audio_bytes, sample_rate = await backend.generate_voice_design(request_data)
|
||||
|
||||
timestamp = datetime.utcnow().strftime("%Y%m%d_%H%M%S")
|
||||
filename = f"{user_id}_{job_id}_{timestamp}.wav"
|
||||
output_path = Path(settings.OUTPUT_DIR) / filename
|
||||
|
||||
with open(output_path, 'wb') as f:
|
||||
f.write(audio_bytes)
|
||||
|
||||
job.status = JobStatus.COMPLETED
|
||||
job.output_path = str(output_path)
|
||||
job.completed_at = datetime.utcnow()
|
||||
db.commit()
|
||||
|
||||
logger.info(f"Job {job_id} completed successfully")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Job {job_id} failed: {e}", exc_info=True)
|
||||
job = db.query(Job).filter(Job.id == job_id).first()
|
||||
if job:
|
||||
job.status = JobStatus.FAILED
|
||||
job.error_message = "Job processing failed"
|
||||
job.completed_at = datetime.utcnow()
|
||||
db.commit()
|
||||
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
async def process_voice_clone_job(
|
||||
job_id: int,
|
||||
user_id: int,
|
||||
request_data: dict,
|
||||
ref_audio_path: str,
|
||||
backend_type: str,
|
||||
db_url: str,
|
||||
use_voice_design: bool = False
|
||||
):
|
||||
from core.database import SessionLocal
|
||||
from core.tts_service import TTSServiceFactory
|
||||
import numpy as np
|
||||
|
||||
db = SessionLocal()
|
||||
try:
|
||||
job = db.query(Job).filter(Job.id == job_id).first()
|
||||
if not job:
|
||||
logger.error(f"Job {job_id} not found")
|
||||
return
|
||||
|
||||
job.status = JobStatus.PROCESSING
|
||||
job.started_at = datetime.utcnow()
|
||||
db.commit()
|
||||
|
||||
logger.info(f"Processing voice-clone job {job_id} with backend {backend_type}")
|
||||
|
||||
with open(ref_audio_path, 'rb') as f:
|
||||
ref_audio_data = f.read()
|
||||
|
||||
cache_manager = await VoiceCacheManager.get_instance()
|
||||
|
||||
voice_design_id = request_data.get('voice_design_id')
|
||||
if voice_design_id:
|
||||
from db.crud import get_voice_design
|
||||
design = get_voice_design(db, voice_design_id, user_id)
|
||||
if not design or not design.voice_cache_id:
|
||||
raise RuntimeError(f"Voice design {voice_design_id} has no prepared clone prompt")
|
||||
|
||||
cached = await cache_manager.get_cache_by_id(design.voice_cache_id, db)
|
||||
if not cached:
|
||||
raise RuntimeError(f"Cache {design.voice_cache_id} not found")
|
||||
|
||||
ref_audio_hash = f"voice_design_{voice_design_id}"
|
||||
cache_metrics.record_hit(user_id)
|
||||
logger.info(f"Using voice design {voice_design_id}, cache_id={design.voice_cache_id}")
|
||||
|
||||
else:
|
||||
with open(ref_audio_path, 'rb') as f:
|
||||
ref_audio_data = f.read()
|
||||
ref_audio_hash = cache_manager.get_audio_hash(ref_audio_data)
|
||||
|
||||
if request_data.get('x_vector_only_mode', False):
|
||||
x_vector = None
|
||||
cache_id = None
|
||||
|
||||
if voice_design_id:
|
||||
design = get_voice_design(db, voice_design_id, user_id)
|
||||
x_vector = cached['data']
|
||||
cache_id = design.voice_cache_id
|
||||
elif request_data.get('use_cache', True):
|
||||
cached = await cache_manager.get_cache(user_id, ref_audio_hash, db)
|
||||
if cached:
|
||||
x_vector = cached['data']
|
||||
cache_id = cached['cache_id']
|
||||
cache_metrics.record_hit(user_id)
|
||||
logger.info(f"Cache hit for job {job_id}, cache_id={cache_id}")
|
||||
|
||||
if x_vector is None:
|
||||
cache_metrics.record_miss(user_id)
|
||||
logger.info(f"Cache miss for job {job_id}, creating voice clone prompt")
|
||||
ref_audio_array, ref_sr = process_ref_audio(ref_audio_data)
|
||||
|
||||
model_manager = await ModelManager.get_instance()
|
||||
await model_manager.load_model("base")
|
||||
_, tts = await model_manager.get_current_model()
|
||||
|
||||
if tts is None:
|
||||
raise RuntimeError("Failed to load base model")
|
||||
|
||||
x_vector = tts.create_voice_clone_prompt(
|
||||
ref_audio=(ref_audio_array, ref_sr),
|
||||
ref_text=request_data.get('ref_text', ''),
|
||||
x_vector_only_mode=True
|
||||
)
|
||||
|
||||
if request_data.get('use_cache', True):
|
||||
features = extract_audio_features(ref_audio_array, ref_sr)
|
||||
metadata = {
|
||||
'duration': features['duration'],
|
||||
'sample_rate': features['sample_rate'],
|
||||
'ref_text': request_data.get('ref_text', ''),
|
||||
'x_vector_only_mode': True
|
||||
}
|
||||
cache_id = await cache_manager.set_cache(
|
||||
user_id, ref_audio_hash, x_vector, metadata, db
|
||||
)
|
||||
logger.info(f"Created cache for job {job_id}, cache_id={cache_id}")
|
||||
|
||||
job.status = JobStatus.COMPLETED
|
||||
job.output_path = f"x_vector_cached_{cache_id}"
|
||||
job.completed_at = datetime.utcnow()
|
||||
db.commit()
|
||||
logger.info(f"Job {job_id} completed (x_vector_only_mode)")
|
||||
return
|
||||
|
||||
backend = await TTSServiceFactory.get_backend()
|
||||
|
||||
if voice_design_id:
|
||||
from db.crud import get_voice_design
|
||||
design = get_voice_design(db, voice_design_id, user_id)
|
||||
cached = await cache_manager.get_cache_by_id(design.voice_cache_id, db)
|
||||
x_vector = cached['data']
|
||||
audio_bytes, sample_rate = await backend.generate_voice_clone(request_data, x_vector=x_vector)
|
||||
logger.info(f"Generated audio using cached x_vector from voice design {voice_design_id}")
|
||||
else:
|
||||
with open(ref_audio_path, 'rb') as f:
|
||||
ref_audio_data = f.read()
|
||||
audio_bytes, sample_rate = await backend.generate_voice_clone(request_data, ref_audio_data)
|
||||
|
||||
timestamp = datetime.utcnow().strftime("%Y%m%d_%H%M%S")
|
||||
filename = f"{user_id}_{job_id}_{timestamp}.wav"
|
||||
output_path = Path(settings.OUTPUT_DIR) / filename
|
||||
|
||||
with open(output_path, 'wb') as f:
|
||||
f.write(audio_bytes)
|
||||
|
||||
job.status = JobStatus.COMPLETED
|
||||
job.output_path = str(output_path)
|
||||
job.completed_at = datetime.utcnow()
|
||||
db.commit()
|
||||
|
||||
logger.info(f"Job {job_id} completed successfully")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Job {job_id} failed: {e}", exc_info=True)
|
||||
job = db.query(Job).filter(Job.id == job_id).first()
|
||||
if job:
|
||||
job.status = JobStatus.FAILED
|
||||
job.error_message = "Job processing failed"
|
||||
job.completed_at = datetime.utcnow()
|
||||
db.commit()
|
||||
|
||||
finally:
|
||||
if not use_voice_design and ref_audio_path and Path(ref_audio_path).exists():
|
||||
Path(ref_audio_path).unlink()
|
||||
db.close()
|
||||
|
||||
|
||||
@router.post("/custom-voice")
|
||||
@limiter.limit("10/minute")
|
||||
async def create_custom_voice_job(
|
||||
request: Request,
|
||||
req_data: CustomVoiceRequest,
|
||||
background_tasks: BackgroundTasks,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
from db.crud import can_user_use_local_model
|
||||
|
||||
if not can_user_use_local_model(current_user):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Local model is not available. Please contact administrator."
|
||||
)
|
||||
|
||||
backend_type = "local"
|
||||
|
||||
try:
|
||||
validate_text_length(req_data.text)
|
||||
language = validate_language(req_data.language)
|
||||
speaker = validate_speaker(req_data.speaker)
|
||||
|
||||
params = validate_generation_params({
|
||||
'max_new_tokens': req_data.max_new_tokens,
|
||||
'temperature': req_data.temperature,
|
||||
'top_k': req_data.top_k,
|
||||
'top_p': req_data.top_p,
|
||||
'repetition_penalty': req_data.repetition_penalty
|
||||
})
|
||||
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
job = Job(
|
||||
user_id=current_user.id,
|
||||
job_type="custom-voice",
|
||||
status=JobStatus.PENDING,
|
||||
backend_type=backend_type,
|
||||
input_data="",
|
||||
input_params={
|
||||
"text": req_data.text,
|
||||
"language": language,
|
||||
"speaker": speaker,
|
||||
"instruct": req_data.instruct or "",
|
||||
**params
|
||||
}
|
||||
)
|
||||
db.add(job)
|
||||
db.commit()
|
||||
db.refresh(job)
|
||||
|
||||
request_data = {
|
||||
"text": req_data.text,
|
||||
"language": language,
|
||||
"speaker": speaker,
|
||||
"instruct": req_data.instruct or "",
|
||||
**params
|
||||
}
|
||||
|
||||
background_tasks.add_task(
|
||||
process_custom_voice_job,
|
||||
job.id,
|
||||
current_user.id,
|
||||
request_data,
|
||||
backend_type,
|
||||
str(settings.DATABASE_URL)
|
||||
)
|
||||
|
||||
return {
|
||||
"job_id": job.id,
|
||||
"status": job.status,
|
||||
"message": "Job created successfully"
|
||||
}
|
||||
|
||||
|
||||
@router.post("/voice-design")
|
||||
@limiter.limit("10/minute")
|
||||
async def create_voice_design_job(
|
||||
request: Request,
|
||||
req_data: VoiceDesignRequest,
|
||||
background_tasks: BackgroundTasks,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
from db.crud import can_user_use_local_model, get_voice_design, update_voice_design_usage
|
||||
|
||||
if not can_user_use_local_model(current_user):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Local model is not available. Please contact administrator."
|
||||
)
|
||||
|
||||
backend_type = "local"
|
||||
|
||||
if req_data.saved_design_id:
|
||||
saved_design = get_voice_design(db, req_data.saved_design_id, current_user.id)
|
||||
if not saved_design:
|
||||
raise HTTPException(status_code=404, detail="Saved voice design not found")
|
||||
|
||||
req_data.instruct = saved_design.instruct
|
||||
update_voice_design_usage(db, req_data.saved_design_id, current_user.id)
|
||||
|
||||
try:
|
||||
validate_text_length(req_data.text)
|
||||
language = validate_language(req_data.language)
|
||||
|
||||
if not req_data.saved_design_id:
|
||||
if not req_data.instruct or not req_data.instruct.strip():
|
||||
raise ValueError("Instruct parameter is required when saved_design_id is not provided")
|
||||
|
||||
params = validate_generation_params({
|
||||
'max_new_tokens': req_data.max_new_tokens,
|
||||
'temperature': req_data.temperature,
|
||||
'top_k': req_data.top_k,
|
||||
'top_p': req_data.top_p,
|
||||
'repetition_penalty': req_data.repetition_penalty
|
||||
})
|
||||
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
job = Job(
|
||||
user_id=current_user.id,
|
||||
job_type="voice-design",
|
||||
status=JobStatus.PENDING,
|
||||
backend_type=backend_type,
|
||||
input_data="",
|
||||
input_params={
|
||||
"text": req_data.text,
|
||||
"language": language,
|
||||
"instruct": req_data.instruct,
|
||||
**params
|
||||
}
|
||||
)
|
||||
db.add(job)
|
||||
db.commit()
|
||||
db.refresh(job)
|
||||
|
||||
request_data = {
|
||||
"text": req_data.text,
|
||||
"language": language,
|
||||
"instruct": req_data.instruct,
|
||||
**params
|
||||
}
|
||||
|
||||
background_tasks.add_task(
|
||||
process_voice_design_job,
|
||||
job.id,
|
||||
current_user.id,
|
||||
request_data,
|
||||
backend_type,
|
||||
str(settings.DATABASE_URL),
|
||||
saved_voice_id
|
||||
)
|
||||
|
||||
return {
|
||||
"job_id": job.id,
|
||||
"status": job.status,
|
||||
"message": "Job created successfully"
|
||||
}
|
||||
|
||||
|
||||
@router.post("/voice-clone")
|
||||
@limiter.limit("10/minute")
|
||||
async def create_voice_clone_job(
|
||||
request: Request,
|
||||
text: str = Form(...),
|
||||
language: str = Form(default="Auto"),
|
||||
ref_audio: Optional[UploadFile] = File(default=None),
|
||||
ref_text: Optional[str] = Form(default=None),
|
||||
use_cache: bool = Form(default=True),
|
||||
x_vector_only_mode: bool = Form(default=False),
|
||||
voice_design_id: Optional[int] = Form(default=None),
|
||||
max_new_tokens: Optional[int] = Form(default=2048),
|
||||
temperature: Optional[float] = Form(default=0.9),
|
||||
top_k: Optional[int] = Form(default=50),
|
||||
top_p: Optional[float] = Form(default=1.0),
|
||||
repetition_penalty: Optional[float] = Form(default=1.05),
|
||||
backend: Optional[str] = Form(default=None),
|
||||
background_tasks: BackgroundTasks = None,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
from db.crud import can_user_use_local_model, get_voice_design
|
||||
|
||||
if not can_user_use_local_model(current_user):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Local model is not available. Please contact administrator."
|
||||
)
|
||||
|
||||
backend_type = "local"
|
||||
|
||||
ref_audio_data = None
|
||||
ref_audio_hash = None
|
||||
use_voice_design = False
|
||||
|
||||
try:
|
||||
validate_text_length(text)
|
||||
language = validate_language(language)
|
||||
|
||||
params = validate_generation_params({
|
||||
'max_new_tokens': max_new_tokens,
|
||||
'temperature': temperature,
|
||||
'top_k': top_k,
|
||||
'top_p': top_p,
|
||||
'repetition_penalty': repetition_penalty
|
||||
})
|
||||
|
||||
cache_manager = await VoiceCacheManager.get_instance()
|
||||
|
||||
if voice_design_id:
|
||||
design = get_voice_design(db, voice_design_id, current_user.id)
|
||||
if not design:
|
||||
raise ValueError("Voice design not found")
|
||||
|
||||
if not design.voice_cache_id:
|
||||
raise ValueError("Voice design has no prepared clone prompt. Please call /voice-designs/{id}/prepare-clone first")
|
||||
|
||||
use_voice_design = True
|
||||
ref_audio_hash = f"voice_design_{voice_design_id}"
|
||||
if not ref_text:
|
||||
ref_text = design.ref_text
|
||||
|
||||
logger.info(f"Using voice design {voice_design_id} with cache_id={design.voice_cache_id}")
|
||||
|
||||
else:
|
||||
if not ref_audio:
|
||||
raise ValueError("Either ref_audio or voice_design_id must be provided")
|
||||
|
||||
max_audio_size_bytes = settings.MAX_AUDIO_SIZE_MB * 1024 * 1024
|
||||
ref_audio_data = await read_upload_with_size_limit(ref_audio, max_audio_size_bytes)
|
||||
|
||||
if not validate_ref_audio(ref_audio_data, max_size_mb=settings.MAX_AUDIO_SIZE_MB):
|
||||
raise ValueError("Invalid reference audio: must be 1-30s duration and ≤10MB")
|
||||
|
||||
ref_audio_hash = cache_manager.get_audio_hash(ref_audio_data)
|
||||
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
job = Job(
|
||||
user_id=current_user.id,
|
||||
job_type="voice-clone",
|
||||
status=JobStatus.PENDING,
|
||||
backend_type=backend_type,
|
||||
input_data="",
|
||||
input_params={
|
||||
"text": text,
|
||||
"language": language,
|
||||
"ref_text": ref_text or "",
|
||||
"ref_audio_hash": ref_audio_hash,
|
||||
"use_cache": use_cache,
|
||||
"x_vector_only_mode": x_vector_only_mode,
|
||||
"voice_design_id": voice_design_id,
|
||||
**params
|
||||
}
|
||||
)
|
||||
db.add(job)
|
||||
db.commit()
|
||||
db.refresh(job)
|
||||
|
||||
if use_voice_design:
|
||||
design = get_voice_design(db, voice_design_id, current_user.id)
|
||||
tmp_audio_path = design.ref_audio_path
|
||||
else:
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as tmp_file:
|
||||
tmp_file.write(ref_audio_data)
|
||||
tmp_audio_path = tmp_file.name
|
||||
|
||||
request_data = {
|
||||
"text": text,
|
||||
"language": language,
|
||||
"ref_text": ref_text or "",
|
||||
"use_cache": use_cache,
|
||||
"x_vector_only_mode": x_vector_only_mode,
|
||||
"voice_design_id": voice_design_id,
|
||||
**params
|
||||
}
|
||||
|
||||
background_tasks.add_task(
|
||||
process_voice_clone_job,
|
||||
job.id,
|
||||
current_user.id,
|
||||
request_data,
|
||||
tmp_audio_path,
|
||||
backend_type,
|
||||
str(settings.DATABASE_URL),
|
||||
use_voice_design
|
||||
)
|
||||
|
||||
existing_cache = await cache_manager.get_cache(current_user.id, ref_audio_hash, db)
|
||||
cache_info = {"cache_id": existing_cache['cache_id']} if existing_cache else None
|
||||
|
||||
return {
|
||||
"job_id": job.id,
|
||||
"status": job.status,
|
||||
"message": "Job created successfully",
|
||||
"cache_info": cache_info
|
||||
}
|
||||
|
||||
|
||||
async def process_indextts2_job(
|
||||
job_id: int,
|
||||
user_id: int,
|
||||
voice_design_id: int,
|
||||
text: str,
|
||||
emo_text: Optional[str],
|
||||
emo_alpha: float,
|
||||
):
|
||||
from core.database import SessionLocal
|
||||
from core.tts_service import IndexTTS2Backend
|
||||
from db.crud import get_voice_design
|
||||
|
||||
db = SessionLocal()
|
||||
try:
|
||||
job = db.query(Job).filter(Job.id == job_id).first()
|
||||
if not job:
|
||||
return
|
||||
|
||||
job.status = JobStatus.PROCESSING
|
||||
job.started_at = datetime.utcnow()
|
||||
db.commit()
|
||||
|
||||
design = get_voice_design(db, voice_design_id, user_id)
|
||||
if not design or not design.ref_audio_path:
|
||||
raise RuntimeError("Voice design has no ref_audio_path")
|
||||
|
||||
from pathlib import Path as _Path
|
||||
if not _Path(design.ref_audio_path).exists():
|
||||
raise RuntimeError(f"ref_audio_path does not exist: {design.ref_audio_path}")
|
||||
|
||||
timestamp = datetime.utcnow().strftime("%Y%m%d_%H%M%S")
|
||||
filename = f"{user_id}_{job_id}_{timestamp}.wav"
|
||||
output_path = str(_Path(settings.OUTPUT_DIR) / filename)
|
||||
_Path(settings.OUTPUT_DIR).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
backend = IndexTTS2Backend()
|
||||
audio_bytes = await backend.generate(
|
||||
text=text,
|
||||
spk_audio_prompt=design.ref_audio_path,
|
||||
output_path=output_path,
|
||||
emo_text=emo_text,
|
||||
emo_alpha=emo_alpha,
|
||||
)
|
||||
|
||||
job.status = JobStatus.COMPLETED
|
||||
job.output_path = output_path
|
||||
job.completed_at = datetime.utcnow()
|
||||
db.commit()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"IndexTTS2 job {job_id} failed: {e}", exc_info=True)
|
||||
job = db.query(Job).filter(Job.id == job_id).first()
|
||||
if job:
|
||||
job.status = JobStatus.FAILED
|
||||
job.error_message = "Job processing failed"
|
||||
job.completed_at = datetime.utcnow()
|
||||
db.commit()
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
@router.post("/indextts2-from-design")
|
||||
@limiter.limit("10/minute")
|
||||
async def create_indextts2_from_design_job(
|
||||
request: Request,
|
||||
req_data: IndexTTS2FromDesignRequest,
|
||||
background_tasks: BackgroundTasks,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
from db.crud import get_voice_design
|
||||
|
||||
design = get_voice_design(db, req_data.voice_design_id, current_user.id)
|
||||
if not design:
|
||||
raise HTTPException(status_code=404, detail="Voice design not found")
|
||||
if not design.ref_audio_path:
|
||||
raise HTTPException(status_code=400, detail="Voice design has no ref_audio_path")
|
||||
|
||||
from pathlib import Path as _Path
|
||||
if not _Path(design.ref_audio_path).exists():
|
||||
raise HTTPException(status_code=400, detail="ref_audio_path file does not exist")
|
||||
|
||||
job = Job(
|
||||
user_id=current_user.id,
|
||||
job_type="indextts2",
|
||||
status=JobStatus.PENDING,
|
||||
backend_type="local",
|
||||
input_data="",
|
||||
input_params={
|
||||
"text": req_data.text,
|
||||
"voice_design_id": req_data.voice_design_id,
|
||||
"emo_text": req_data.emo_text,
|
||||
"emo_alpha": req_data.emo_alpha,
|
||||
}
|
||||
)
|
||||
db.add(job)
|
||||
db.commit()
|
||||
db.refresh(job)
|
||||
|
||||
background_tasks.add_task(
|
||||
process_indextts2_job,
|
||||
job.id,
|
||||
current_user.id,
|
||||
req_data.voice_design_id,
|
||||
req_data.text,
|
||||
req_data.emo_text,
|
||||
req_data.emo_alpha,
|
||||
)
|
||||
|
||||
return {
|
||||
"job_id": job.id,
|
||||
"status": job.status,
|
||||
"message": "Job created successfully"
|
||||
}
|
||||
|
||||
|
||||
@router.get("/speakers")
|
||||
@limiter.limit("30/minute")
|
||||
async def list_speakers(request: Request, backend: Optional[str] = "local"):
|
||||
return get_supported_speakers(backend)
|
||||
|
||||
|
||||
@router.get("/languages")
|
||||
@limiter.limit("30/minute")
|
||||
async def list_languages(request: Request):
|
||||
return get_supported_languages()
|
||||
290
backend/api/users.py
Normal file
290
backend/api/users.py
Normal file
@@ -0,0 +1,290 @@
|
||||
from typing import Annotated
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Request
|
||||
from sqlalchemy.orm import Session
|
||||
from slowapi import Limiter
|
||||
from slowapi.util import get_remote_address
|
||||
|
||||
from api.auth import get_current_user
|
||||
from core.security import get_password_hash
|
||||
from db.database import get_db
|
||||
from db.crud import (
|
||||
get_user_by_id,
|
||||
get_user_by_username,
|
||||
get_user_by_email,
|
||||
list_users,
|
||||
create_user_by_admin,
|
||||
update_user,
|
||||
delete_user
|
||||
)
|
||||
from schemas.user import User, UserCreateByAdmin, UserUpdate, UserListResponse
|
||||
from schemas.audiobook import LLMConfigUpdate, LLMConfigResponse, NsfwSynopsisGenerationRequest, NsfwScriptGenerationRequest
|
||||
|
||||
router = APIRouter(prefix="/users", tags=["users"])
|
||||
limiter = Limiter(key_func=get_remote_address)
|
||||
|
||||
async def require_superuser(
|
||||
current_user: Annotated[User, Depends(get_current_user)]
|
||||
) -> User:
|
||||
if not current_user.is_superuser:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Superuser access required"
|
||||
)
|
||||
return current_user
|
||||
|
||||
@router.get("", response_model=UserListResponse)
|
||||
@limiter.limit("30/minute")
|
||||
async def get_users(
|
||||
request: Request,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
db: Session = Depends(get_db),
|
||||
_: User = Depends(require_superuser)
|
||||
):
|
||||
users, total = list_users(db, skip=skip, limit=limit)
|
||||
return UserListResponse(users=users, total=total, skip=skip, limit=limit)
|
||||
|
||||
@router.post("", response_model=User, status_code=status.HTTP_201_CREATED)
|
||||
@limiter.limit("10/minute")
|
||||
async def create_user(
|
||||
request: Request,
|
||||
user_data: UserCreateByAdmin,
|
||||
db: Session = Depends(get_db),
|
||||
_: User = Depends(require_superuser)
|
||||
):
|
||||
existing_user = get_user_by_username(db, username=user_data.username)
|
||||
if existing_user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Username already registered"
|
||||
)
|
||||
|
||||
existing_email = get_user_by_email(db, email=user_data.email)
|
||||
if existing_email:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Email already registered"
|
||||
)
|
||||
|
||||
hashed_password = get_password_hash(user_data.password)
|
||||
user = create_user_by_admin(
|
||||
db,
|
||||
username=user_data.username,
|
||||
email=user_data.email,
|
||||
hashed_password=hashed_password,
|
||||
is_superuser=user_data.is_superuser,
|
||||
can_use_local_model=user_data.can_use_local_model
|
||||
)
|
||||
|
||||
return user
|
||||
|
||||
@router.get("/me", response_model=User)
|
||||
@limiter.limit("30/minute")
|
||||
async def get_current_user_info(
|
||||
request: Request,
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
return current_user
|
||||
|
||||
@router.get("/{user_id}", response_model=User)
|
||||
@limiter.limit("30/minute")
|
||||
async def get_user(
|
||||
request: Request,
|
||||
user_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
_: User = Depends(require_superuser)
|
||||
):
|
||||
user = get_user_by_id(db, user_id=user_id)
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="User not found"
|
||||
)
|
||||
return user
|
||||
|
||||
@router.put("/{user_id}", response_model=User)
|
||||
@limiter.limit("10/minute")
|
||||
async def update_user_info(
|
||||
request: Request,
|
||||
user_id: int,
|
||||
user_data: UserUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_superuser)
|
||||
):
|
||||
existing_user = get_user_by_id(db, user_id=user_id)
|
||||
if not existing_user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="User not found"
|
||||
)
|
||||
|
||||
if user_data.username is not None:
|
||||
username_exists = get_user_by_username(db, username=user_data.username)
|
||||
if username_exists and username_exists.id != user_id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Username already taken"
|
||||
)
|
||||
|
||||
if user_data.email is not None:
|
||||
email_exists = get_user_by_email(db, email=user_data.email)
|
||||
if email_exists and email_exists.id != user_id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Email already taken"
|
||||
)
|
||||
|
||||
hashed_password = None
|
||||
if user_data.password is not None:
|
||||
hashed_password = get_password_hash(user_data.password)
|
||||
|
||||
user = update_user(
|
||||
db,
|
||||
user_id=user_id,
|
||||
username=user_data.username,
|
||||
email=user_data.email,
|
||||
hashed_password=hashed_password,
|
||||
is_active=user_data.is_active,
|
||||
is_superuser=user_data.is_superuser,
|
||||
can_use_local_model=user_data.can_use_local_model,
|
||||
can_use_nsfw=user_data.can_use_nsfw,
|
||||
)
|
||||
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="User not found"
|
||||
)
|
||||
|
||||
return user
|
||||
|
||||
@router.delete("/{user_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||
@limiter.limit("10/minute")
|
||||
async def delete_user_by_id(
|
||||
request: Request,
|
||||
user_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_superuser)
|
||||
):
|
||||
if user_id == current_user.id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Cannot delete yourself"
|
||||
)
|
||||
|
||||
success = delete_user(db, user_id=user_id)
|
||||
if not success:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="User not found"
|
||||
)
|
||||
|
||||
|
||||
@router.put("/system/llm-config")
|
||||
@limiter.limit("10/minute")
|
||||
async def set_system_llm_config(
|
||||
request: Request,
|
||||
config: LLMConfigUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
_: User = Depends(require_superuser)
|
||||
):
|
||||
from core.security import encrypt_api_key
|
||||
from core.llm_service import LLMService
|
||||
from db.crud import set_system_setting
|
||||
|
||||
api_key = config.api_key.strip()
|
||||
base_url = config.base_url.strip()
|
||||
model = config.model.strip()
|
||||
llm = LLMService(base_url=base_url, api_key=api_key, model=model)
|
||||
try:
|
||||
await llm.chat("You are a test assistant.", "Reply with 'ok'.")
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"LLM API validation failed: {e}")
|
||||
set_system_setting(db, "llm_api_key", encrypt_api_key(api_key))
|
||||
set_system_setting(db, "llm_base_url", base_url)
|
||||
set_system_setting(db, "llm_model", model)
|
||||
return {"message": "LLM config updated"}
|
||||
|
||||
|
||||
@router.get("/system/llm-config", response_model=LLMConfigResponse)
|
||||
@limiter.limit("30/minute")
|
||||
async def get_system_llm_config(
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
_: User = Depends(require_superuser)
|
||||
):
|
||||
from db.crud import get_system_setting
|
||||
return LLMConfigResponse(
|
||||
base_url=get_system_setting(db, "llm_base_url"),
|
||||
model=get_system_setting(db, "llm_model"),
|
||||
has_key=bool(get_system_setting(db, "llm_api_key")),
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/system/llm-config")
|
||||
@limiter.limit("10/minute")
|
||||
async def delete_system_llm_config(
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
_: User = Depends(require_superuser)
|
||||
):
|
||||
from db.crud import delete_system_setting
|
||||
delete_system_setting(db, "llm_api_key")
|
||||
delete_system_setting(db, "llm_base_url")
|
||||
delete_system_setting(db, "llm_model")
|
||||
return {"message": "LLM config deleted"}
|
||||
|
||||
|
||||
@router.put("/system/grok-config")
|
||||
@limiter.limit("10/minute")
|
||||
async def set_system_grok_config(
|
||||
request: Request,
|
||||
config: LLMConfigUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
_: User = Depends(require_superuser)
|
||||
):
|
||||
from core.security import encrypt_api_key
|
||||
from core.llm_service import GrokLLMService
|
||||
from db.crud import set_system_setting
|
||||
|
||||
api_key = config.api_key.strip()
|
||||
base_url = config.base_url.strip()
|
||||
model = config.model.strip()
|
||||
grok = GrokLLMService(base_url=base_url, api_key=api_key, model=model)
|
||||
try:
|
||||
await grok.chat("You are a test assistant.", "Reply with 'ok'.")
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"Grok API validation failed: {e}")
|
||||
set_system_setting(db, "grok_api_key", encrypt_api_key(api_key))
|
||||
set_system_setting(db, "grok_base_url", base_url)
|
||||
set_system_setting(db, "grok_model", model)
|
||||
return {"message": "Grok config updated"}
|
||||
|
||||
|
||||
@router.get("/system/grok-config", response_model=LLMConfigResponse)
|
||||
@limiter.limit("30/minute")
|
||||
async def get_system_grok_config(
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
_: User = Depends(require_superuser)
|
||||
):
|
||||
from db.crud import get_system_setting
|
||||
return LLMConfigResponse(
|
||||
base_url=get_system_setting(db, "grok_base_url"),
|
||||
model=get_system_setting(db, "grok_model"),
|
||||
has_key=bool(get_system_setting(db, "grok_api_key")),
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/system/grok-config")
|
||||
@limiter.limit("10/minute")
|
||||
async def delete_system_grok_config(
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
_: User = Depends(require_superuser)
|
||||
):
|
||||
from db.crud import delete_system_setting
|
||||
delete_system_setting(db, "grok_api_key")
|
||||
delete_system_setting(db, "grok_base_url")
|
||||
delete_system_setting(db, "grok_model")
|
||||
return {"message": "Grok config deleted"}
|
||||
281
backend/api/voice_designs.py
Normal file
281
backend/api/voice_designs.py
Normal file
@@ -0,0 +1,281 @@
|
||||
import logging
|
||||
import json
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Request, BackgroundTasks
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import Optional
|
||||
from slowapi import Limiter
|
||||
from slowapi.util import get_remote_address
|
||||
from pathlib import Path
|
||||
|
||||
from core.database import get_db
|
||||
from api.auth import get_current_user
|
||||
from db.models import User, Job, JobStatus
|
||||
from db import crud
|
||||
from schemas.voice_design import (
|
||||
VoiceDesignCreate,
|
||||
VoiceDesignResponse,
|
||||
VoiceDesignListResponse
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter(prefix="/voice-designs", tags=["voice-designs"])
|
||||
limiter = Limiter(key_func=get_remote_address)
|
||||
|
||||
|
||||
def to_voice_design_response(design) -> VoiceDesignResponse:
|
||||
meta_data = design.meta_data
|
||||
if isinstance(meta_data, str):
|
||||
try:
|
||||
meta_data = json.loads(meta_data)
|
||||
except Exception:
|
||||
meta_data = None
|
||||
return VoiceDesignResponse(
|
||||
id=design.id,
|
||||
user_id=design.user_id,
|
||||
name=design.name,
|
||||
instruct=design.instruct,
|
||||
meta_data=meta_data,
|
||||
preview_text=design.preview_text,
|
||||
ref_audio_path=design.ref_audio_path,
|
||||
created_at=design.created_at,
|
||||
last_used=design.last_used,
|
||||
use_count=design.use_count
|
||||
)
|
||||
|
||||
@router.post("", response_model=VoiceDesignResponse, status_code=status.HTTP_201_CREATED)
|
||||
@limiter.limit("30/minute")
|
||||
async def save_voice_design(
|
||||
request: Request,
|
||||
data: VoiceDesignCreate,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
try:
|
||||
design = crud.create_voice_design(
|
||||
db=db,
|
||||
user_id=current_user.id,
|
||||
name=data.name,
|
||||
instruct=data.instruct,
|
||||
meta_data=data.meta_data,
|
||||
preview_text=data.preview_text
|
||||
)
|
||||
return to_voice_design_response(design)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to save voice design: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail="Failed to save voice design")
|
||||
|
||||
@router.get("", response_model=VoiceDesignListResponse)
|
||||
@limiter.limit("30/minute")
|
||||
async def list_voice_designs(
|
||||
request: Request,
|
||||
backend_type: Optional[str] = None,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
designs = crud.list_voice_designs(db, current_user.id, backend_type, skip, limit)
|
||||
total = crud.count_voice_designs(db, current_user.id, backend_type)
|
||||
return VoiceDesignListResponse(designs=[to_voice_design_response(d) for d in designs], total=total)
|
||||
|
||||
@router.post("/prepare-and-create", response_model=VoiceDesignResponse, status_code=status.HTTP_201_CREATED)
|
||||
@limiter.limit("10/minute")
|
||||
async def prepare_and_create_voice_design(
|
||||
request: Request,
|
||||
data: VoiceDesignCreate,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
from core.tts_service import TTSServiceFactory
|
||||
from core.cache_manager import VoiceCacheManager
|
||||
from utils.audio import process_ref_audio, extract_audio_features
|
||||
from core.config import settings
|
||||
from db.crud import can_user_use_local_model
|
||||
from datetime import datetime
|
||||
|
||||
if not can_user_use_local_model(current_user):
|
||||
raise HTTPException(status_code=403, detail="Local model access required")
|
||||
|
||||
try:
|
||||
backend = await TTSServiceFactory.get_backend("local")
|
||||
ref_text = data.preview_text or data.instruct[:100]
|
||||
|
||||
ref_audio_bytes, _ = await backend.generate_voice_design({
|
||||
"text": ref_text,
|
||||
"language": "Auto",
|
||||
"instruct": data.instruct,
|
||||
"max_new_tokens": 2048,
|
||||
"temperature": 0.3,
|
||||
"top_k": 10,
|
||||
"top_p": 0.5,
|
||||
"repetition_penalty": 1.05
|
||||
})
|
||||
|
||||
timestamp = datetime.utcnow().strftime("%Y%m%d_%H%M%S")
|
||||
ref_filename = f"voice_design_new_{timestamp}.wav"
|
||||
ref_audio_path = Path(settings.OUTPUT_DIR) / ref_filename
|
||||
with open(ref_audio_path, 'wb') as f:
|
||||
f.write(ref_audio_bytes)
|
||||
|
||||
ref_audio_array, ref_sr = process_ref_audio(ref_audio_bytes)
|
||||
|
||||
from core.model_manager import ModelManager
|
||||
model_manager = await ModelManager.get_instance()
|
||||
await model_manager.load_model("base")
|
||||
_, tts = await model_manager.get_current_model()
|
||||
if tts is None:
|
||||
raise RuntimeError("Failed to load base model")
|
||||
|
||||
x_vector = tts.create_voice_clone_prompt(
|
||||
ref_audio=(ref_audio_array, ref_sr),
|
||||
ref_text=ref_text,
|
||||
)
|
||||
|
||||
cache_manager = await VoiceCacheManager.get_instance()
|
||||
ref_audio_hash = cache_manager.get_audio_hash(ref_audio_bytes)
|
||||
features = extract_audio_features(ref_audio_array, ref_sr)
|
||||
metadata = {
|
||||
'duration': features['duration'],
|
||||
'sample_rate': features['sample_rate'],
|
||||
'ref_text': ref_text,
|
||||
'instruct': data.instruct
|
||||
}
|
||||
cache_id = await cache_manager.set_cache(
|
||||
current_user.id, ref_audio_hash, x_vector, metadata, db
|
||||
)
|
||||
|
||||
design = crud.create_voice_design(
|
||||
db=db,
|
||||
user_id=current_user.id,
|
||||
name=data.name,
|
||||
instruct=data.instruct,
|
||||
meta_data=data.meta_data,
|
||||
preview_text=data.preview_text,
|
||||
voice_cache_id=cache_id,
|
||||
ref_audio_path=str(ref_audio_path),
|
||||
ref_text=ref_text,
|
||||
)
|
||||
|
||||
logger.info(f"Voice design created with clone prompt: design_id={design.id}, cache_id={cache_id}")
|
||||
return to_voice_design_response(design)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to prepare and create voice design: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail="Failed to prepare voice design")
|
||||
|
||||
|
||||
@router.delete("/{design_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||
async def delete_voice_design(
|
||||
design_id: int,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
deleted = crud.delete_voice_design(db, design_id, current_user.id)
|
||||
if not deleted:
|
||||
raise HTTPException(status_code=404, detail="Voice design not found")
|
||||
|
||||
|
||||
@router.post("/{design_id}/prepare-clone")
|
||||
@limiter.limit("10/minute")
|
||||
async def prepare_voice_clone_prompt(
|
||||
request: Request,
|
||||
design_id: int,
|
||||
background_tasks: BackgroundTasks,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
from core.tts_service import TTSServiceFactory
|
||||
from core.cache_manager import VoiceCacheManager
|
||||
from utils.audio import process_ref_audio, extract_audio_features
|
||||
from core.config import settings
|
||||
from db.crud import can_user_use_local_model
|
||||
from datetime import datetime
|
||||
|
||||
design = crud.get_voice_design(db, design_id, current_user.id)
|
||||
if not design:
|
||||
raise HTTPException(status_code=404, detail="Voice design not found")
|
||||
|
||||
if not can_user_use_local_model(current_user):
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail="Local model access required"
|
||||
)
|
||||
|
||||
if design.voice_cache_id:
|
||||
return {
|
||||
"message": "Voice clone prompt already exists",
|
||||
"cache_id": design.voice_cache_id
|
||||
}
|
||||
|
||||
try:
|
||||
backend = await TTSServiceFactory.get_backend("local")
|
||||
|
||||
ref_text = design.preview_text or design.instruct[:100]
|
||||
|
||||
logger.info(f"Generating reference audio for voice design {design_id}")
|
||||
ref_audio_bytes, sample_rate = await backend.generate_voice_design({
|
||||
"text": ref_text,
|
||||
"language": "Auto",
|
||||
"instruct": design.instruct,
|
||||
"max_new_tokens": 2048,
|
||||
"temperature": 0.3,
|
||||
"top_k": 10,
|
||||
"top_p": 0.5,
|
||||
"repetition_penalty": 1.05
|
||||
})
|
||||
|
||||
timestamp = datetime.utcnow().strftime("%Y%m%d_%H%M%S")
|
||||
ref_filename = f"voice_design_{design_id}_{timestamp}.wav"
|
||||
ref_audio_path = Path(settings.OUTPUT_DIR) / ref_filename
|
||||
|
||||
with open(ref_audio_path, 'wb') as f:
|
||||
f.write(ref_audio_bytes)
|
||||
|
||||
logger.info(f"Extracting voice clone prompt from reference audio")
|
||||
ref_audio_array, ref_sr = process_ref_audio(ref_audio_bytes)
|
||||
|
||||
from core.model_manager import ModelManager
|
||||
model_manager = await ModelManager.get_instance()
|
||||
await model_manager.load_model("base")
|
||||
_, tts = await model_manager.get_current_model()
|
||||
|
||||
if tts is None:
|
||||
raise RuntimeError("Failed to load base model")
|
||||
|
||||
x_vector = tts.create_voice_clone_prompt(
|
||||
ref_audio=(ref_audio_array, ref_sr),
|
||||
ref_text=ref_text,
|
||||
)
|
||||
|
||||
cache_manager = await VoiceCacheManager.get_instance()
|
||||
ref_audio_hash = cache_manager.get_audio_hash(ref_audio_bytes)
|
||||
|
||||
features = extract_audio_features(ref_audio_array, ref_sr)
|
||||
metadata = {
|
||||
'duration': features['duration'],
|
||||
'sample_rate': features['sample_rate'],
|
||||
'ref_text': ref_text,
|
||||
'voice_design_id': design_id,
|
||||
'instruct': design.instruct
|
||||
}
|
||||
|
||||
cache_id = await cache_manager.set_cache(
|
||||
current_user.id, ref_audio_hash, x_vector, metadata, db
|
||||
)
|
||||
|
||||
design.voice_cache_id = cache_id
|
||||
design.ref_audio_path = str(ref_audio_path)
|
||||
design.ref_text = ref_text
|
||||
db.commit()
|
||||
|
||||
logger.info(f"Voice clone prompt prepared for design {design_id}, cache_id={cache_id}")
|
||||
|
||||
return {
|
||||
"message": "Voice clone prompt prepared successfully",
|
||||
"cache_id": cache_id,
|
||||
"ref_text": ref_text
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to prepare voice clone prompt: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail="Failed to prepare voice clone prompt")
|
||||
Reference in New Issue
Block a user