refactor: rename canto-backend → backend, canto-frontend → frontend

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-04-07 18:11:00 +08:00
parent 2fa9c1fcb6
commit 60489eab59
327 changed files with 0 additions and 0 deletions

0
backend/api/__init__.py Normal file
View File

22
backend/api/admin.py Normal file
View File

@@ -0,0 +1,22 @@
from datetime import datetime
from typing import Optional
from fastapi import APIRouter, Depends, Query
from sqlalchemy.orm import Session
from api.users import require_superuser
from db.database import get_db
from db.crud import get_usage_stats
from schemas.user import User
# Router for the admin endpoints; every route declared on it is served under the "/admin" prefix.
router = APIRouter(prefix="/admin", tags=["admin"])
@router.get("/usage")
async def get_usage_statistics(
    date_from: Optional[datetime] = Query(None),
    date_to: Optional[datetime] = Query(None),
    db: Session = Depends(get_db),
    _: User = Depends(require_superuser),
):
    """Return usage statistics, optionally bounded by a date range.

    Access is gated by the ``require_superuser`` dependency. Both bounds are
    optional query parameters and are forwarded unchanged to
    ``get_usage_stats``.
    """
    stats = get_usage_stats(db, date_from=date_from, date_to=date_to)
    return stats

1028
backend/api/audiobook.py Normal file

File diff suppressed because it is too large Load Diff

216
backend/api/auth.py Normal file
View File

@@ -0,0 +1,216 @@
from datetime import timedelta
from typing import Annotated, Optional
from fastapi import APIRouter, Depends, HTTPException, status, Request
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from sqlalchemy.orm import Session
from slowapi import Limiter
from slowapi.util import get_remote_address
from config import settings
from core.security import (
get_password_hash,
verify_password,
create_access_token,
decode_access_token
)
from db.database import get_db
from db.crud import get_user_by_username, get_user_by_email, create_user, change_user_password, get_user_preferences, update_user_preferences, can_user_use_nsfw, get_system_setting
from schemas.user import User, UserCreate, Token, PasswordChange, UserPreferences, UserPreferencesResponse
from schemas.audiobook import LLMConfigResponse
# Router for the authentication endpoints; routes are served under the "/auth" prefix.
router = APIRouter(prefix="/auth", tags=["authentication"])
# Bearer-token scheme. auto_error is disabled when settings.DEV_MODE is on, so a
# missing token is handed to the dependency as None instead of being rejected here.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/auth/token", auto_error=not settings.DEV_MODE)
# Rate limiter used by the @limiter.limit decorators below, keyed by get_remote_address.
limiter = Limiter(key_func=get_remote_address)
async def get_current_user(
    token: Annotated[Optional[str], Depends(oauth2_scheme)],
    db: Session = Depends(get_db)
) -> User:
    """Resolve the bearer token of the current request to a user.

    When ``settings.DEV_MODE`` is on and no token was supplied, the user
    named "admin" is returned if it exists. In every other case the token
    must decode to a username that exists in the database.

    Raises:
        HTTPException: 401 when the token is missing, cannot be decoded, or
            names a user that does not exist.
    """
    dev_fallback = settings.DEV_MODE and not token
    if dev_fallback:
        admin_user = get_user_by_username(db, username="admin")
        if admin_user:
            return admin_user

    unauthorized = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )

    # A missing token and an undecodable token are rejected the same way.
    subject = None if token is None else decode_access_token(token)
    if subject is None:
        raise unauthorized

    account = get_user_by_username(db, username=subject)
    if account is None:
        raise unauthorized
    return account
@router.post("/register", response_model=User, status_code=status.HTTP_201_CREATED)
@limiter.limit("5/minute")
async def register(request: Request, user_data: UserCreate, db: Session = Depends(get_db)):
    """Create a new account once username and e-mail are confirmed unused.

    Raises:
        HTTPException: 400 when the username or the e-mail is already
            registered.
    """
    # The username is checked first, so a double conflict reports the username.
    if get_user_by_username(db, username=user_data.username):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Username already registered"
        )

    if get_user_by_email(db, email=user_data.email):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Email already registered"
        )

    # Only the hash of the password is handed to the persistence layer.
    return create_user(
        db,
        username=user_data.username,
        email=user_data.email,
        hashed_password=get_password_hash(user_data.password),
    )
@router.post("/token", response_model=Token)
@limiter.limit("5/minute")
async def login(
    request: Request,
    form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
    db: Session = Depends(get_db)
):
    """Exchange a username/password form for a bearer access token.

    Raises:
        HTTPException: 401 when the username is unknown or the password does
            not verify; 400 when the account is not active.
    """
    account = get_user_by_username(db, username=form_data.username)
    # Short-circuits: the password is only verified when the account exists.
    credentials_ok = account and verify_password(form_data.password, account.hashed_password)
    if not credentials_ok:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect username or password",
            headers={"WWW-Authenticate": "Bearer"},
        )

    if not account.is_active:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Inactive user"
        )

    lifetime = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
    token = create_access_token(
        data={"sub": account.username}, expires_delta=lifetime
    )
    return {"access_token": token, "token_type": "bearer"}
@router.get("/dev-token", response_model=Token)
async def dev_token(db: Session = Depends(get_db)):
    """Issue an access token for the "admin" user without credentials.

    Only available while ``settings.DEV_MODE`` is on.

    Raises:
        HTTPException: 403 outside DEV_MODE; 404 when no "admin" user exists.
    """
    if settings.DEV_MODE:
        admin_user = get_user_by_username(db, username="admin")
        if admin_user:
            token = create_access_token(data={"sub": admin_user.username})
            return {"access_token": token, "token_type": "bearer"}
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Admin user not found")
    raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Not available outside DEV_MODE")
@router.get("/me", response_model=User)
@limiter.limit("30/minute")
async def get_current_user_info(
    request: Request,
    current_user: Annotated[User, Depends(get_current_user)]
):
    """Return the user resolved by the ``get_current_user`` dependency."""
    authenticated_user = current_user
    return authenticated_user
@router.post("/change-password", response_model=User)
@limiter.limit("5/minute")
async def change_password(
    request: Request,
    password_data: PasswordChange,
    current_user: Annotated[User, Depends(get_current_user)],
    db: Session = Depends(get_db)
):
    """Replace the caller's password after verifying the current one.

    Raises:
        HTTPException: 400 when the supplied current password does not
            verify; 404 when ``change_user_password`` returns nothing.
    """
    current_matches = verify_password(
        password_data.current_password, current_user.hashed_password
    )
    if not current_matches:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Current password is incorrect"
        )

    updated_account = change_user_password(
        db,
        user_id=current_user.id,
        new_hashed_password=get_password_hash(password_data.new_password),
    )
    if updated_account:
        return updated_account

    raise HTTPException(
        status_code=status.HTTP_404_NOT_FOUND,
        detail="User not found"
    )
@router.get("/preferences", response_model=UserPreferencesResponse)
@limiter.limit("30/minute")
async def get_preferences(
    request: Request,
    current_user: Annotated[User, Depends(get_current_user)],
    db: Session = Depends(get_db)
):
    """Return the caller's preferences.

    The backend fields are fixed to "local"; only ``onboarding_completed``
    is read from the stored preferences (defaulting to ``False``).
    """
    stored = get_user_preferences(db, current_user.id)
    onboarding_done = stored.get("onboarding_completed", False)

    response = {
        "default_backend": "local",
        "onboarding_completed": onboarding_done,
        "available_backends": ["local"]
    }
    return response
@router.put("/preferences")
@limiter.limit("10/minute")
async def update_preferences(
    request: Request,
    preferences: UserPreferences,
    current_user: Annotated[User, Depends(get_current_user)],
    db: Session = Depends(get_db)
):
    """Persist the submitted preferences for the caller.

    Raises:
        HTTPException: 404 when ``update_user_preferences`` returns nothing.
    """
    new_values = preferences.model_dump()
    saved = update_user_preferences(db, current_user.id, new_values)
    if saved:
        return {"message": "Preferences updated successfully"}

    raise HTTPException(
        status_code=status.HTTP_404_NOT_FOUND,
        detail="User not found"
    )
@router.get("/llm-config", response_model=LLMConfigResponse)
@limiter.limit("30/minute")
async def get_llm_config(
    request: Request,
    current_user: Annotated[User, Depends(get_current_user)],
    db: Session = Depends(get_db)
):
    """Return the stored LLM settings without revealing the API key.

    The key itself is never included in the response; ``has_key`` only
    reports whether the "llm_api_key" setting has a truthy value.
    """
    base_url = get_system_setting(db, "llm_base_url")
    model_name = get_system_setting(db, "llm_model")
    key_present = bool(get_system_setting(db, "llm_api_key"))

    return LLMConfigResponse(
        base_url=base_url,
        model=model_name,
        has_key=key_present,
    )
@router.get("/nsfw-access")
@limiter.limit("30/minute")
async def get_nsfw_access(
    request: Request,
    current_user: Annotated[User, Depends(get_current_user)],
):
    """Report the result of ``can_user_use_nsfw`` for the caller."""
    allowed = can_user_use_nsfw(current_user)
    return {"has_access": allowed}

216
backend/api/jobs.py Normal file
View File

@@ -0,0 +1,216 @@
import logging
from pathlib import Path
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Query, Request
from fastapi.responses import FileResponse
from sqlalchemy.orm import Session
from slowapi import Limiter
from slowapi.util import get_remote_address
from core.database import get_db
from core.config import settings
from core.security import decode_access_token
from db.models import Job, JobStatus, User
from db.crud import get_user_by_username
from api.auth import get_current_user
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Router for the job endpoints; routes are served under the "/jobs" prefix.
router = APIRouter(prefix="/jobs", tags=["jobs"])
# Rate limiter used by the @limiter.limit decorators below, keyed by get_remote_address.
limiter = Limiter(key_func=get_remote_address)
async def get_user_from_bearer_token(
    request: Request,
    db: Session = Depends(get_db)
) -> User:
    """Authenticate a request from its raw ``Authorization`` header.

    The header must start with "Bearer "; the token is the second
    space-separated field of the header value.

    Raises:
        HTTPException: 401 when the token is missing, cannot be decoded, or
            names a user that does not exist.
    """
    header_value = request.headers.get("Authorization")
    has_bearer_prefix = bool(header_value) and header_value.startswith("Bearer ")
    bearer_token = header_value.split(" ")[1] if has_bearer_prefix else None

    if not bearer_token:
        raise HTTPException(
            status_code=401,
            detail="Missing authentication token"
        )

    subject = decode_access_token(bearer_token)
    if subject is None:
        raise HTTPException(
            status_code=401,
            detail="Invalid or expired token"
        )

    account = get_user_by_username(db, username=subject)
    if account is None:
        raise HTTPException(
            status_code=401,
            detail="User not found"
        )
    return account
@router.get("/{job_id}")
@limiter.limit("30/minute")
async def get_job(
    request: Request,
    job_id: int,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """Return the details of one job owned by the caller.

    ``download_url`` is only populated when the job is completed and its
    output path exists on disk.

    Raises:
        HTTPException: 404 when the job does not exist; 403 when it belongs
            to another user.
    """
    def _stamp(moment):
        # Timestamps are rendered as ISO-8601 with a literal 'Z' suffix.
        return moment.isoformat() + 'Z' if moment else None

    job = db.query(Job).filter(Job.id == job_id).first()
    if not job:
        raise HTTPException(status_code=404, detail="Job not found")
    if job.user_id != current_user.id:
        raise HTTPException(status_code=403, detail="Access denied")

    output_ready = (
        job.status == JobStatus.COMPLETED
        and job.output_path
        and Path(job.output_path).exists()
    )
    download_url = f"/jobs/{job.id}/download" if output_ready else None

    return {
        "id": job.id,
        "job_type": job.job_type,
        "status": job.status,
        "input_params": job.input_params,
        "download_url": download_url,
        "error_message": job.error_message,
        "created_at": _stamp(job.created_at),
        "started_at": _stamp(job.started_at),
        "completed_at": _stamp(job.completed_at)
    }
@router.get("")
@limiter.limit("30/minute")
async def list_jobs(
    request: Request,
    skip: int = Query(0, ge=0),
    limit: int = Query(100, ge=1, le=100),
    status: Optional[str] = Query(None),
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """List the caller's jobs, newest first, with paging and a status filter.

    ``total`` counts all jobs matching the filter, not just the returned
    page.

    Raises:
        HTTPException: 400 when ``status`` is not a valid ``JobStatus`` value.
    """
    def _stamp(moment):
        # Timestamps are rendered as ISO-8601 with a literal 'Z' suffix.
        return moment.isoformat() + 'Z' if moment else None

    def _serialize(job):
        output_ready = (
            job.status == JobStatus.COMPLETED
            and job.output_path
            and Path(job.output_path).exists()
        )
        return {
            "id": job.id,
            "job_type": job.job_type,
            "status": job.status,
            "input_params": job.input_params,
            "download_url": f"/jobs/{job.id}/download" if output_ready else None,
            "error_message": job.error_message,
            "created_at": _stamp(job.created_at),
            "completed_at": _stamp(job.completed_at)
        }

    query = db.query(Job).filter(Job.user_id == current_user.id)
    if status:
        try:
            query = query.filter(Job.status == JobStatus(status))
        except ValueError:
            raise HTTPException(status_code=400, detail=f"Invalid status: {status}")

    total = query.count()
    page = query.order_by(Job.created_at.desc()).offset(skip).limit(limit).all()

    return {
        "total": total,
        "skip": skip,
        "limit": limit,
        "jobs": [_serialize(job) for job in page]
    }
@router.delete("/{job_id}")
@limiter.limit("30/minute")
async def delete_job(
    request: Request,
    job_id: int,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """Delete one of the caller's jobs and, when safe, its output file.

    The output file is only removed when its resolved path lies inside
    ``settings.OUTPUT_DIR``. A failure to remove the file is logged and
    does not prevent the job row from being deleted.

    Raises:
        HTTPException: 404 when the job does not exist; 403 when it belongs
            to another user.
    """
    job = db.query(Job).filter(Job.id == job_id).first()
    if not job:
        raise HTTPException(status_code=404, detail="Job not found")
    if job.user_id != current_user.id:
        raise HTTPException(status_code=403, detail="Access denied")

    if job.output_path:
        output_file = Path(job.output_path).resolve()
        allowed_root = Path(settings.OUTPUT_DIR).resolve()
        if not output_file.is_relative_to(allowed_root):
            logger.warning(f"Skip deleting file outside output dir: {output_file}")
        elif output_file.exists():
            try:
                output_file.unlink()
                logger.info(f"Deleted output file: {output_file}")
            except Exception as e:
                # Best effort: the job row is still removed below.
                logger.error(f"Failed to delete output file {output_file}: {e}")

    db.delete(job)
    db.commit()
    return {"message": "Job deleted successfully"}
@router.get("/{job_id}/download")
@limiter.limit("30/minute")
async def download_job_output(
    request: Request,
    job_id: int,
    current_user: User = Depends(get_user_from_bearer_token),
    db: Session = Depends(get_db)
):
    """Send the output file of a completed job as a WAV attachment.

    Authentication goes through ``get_user_from_bearer_token``, which reads
    the ``Authorization`` header directly.

    Raises:
        HTTPException: 404 when the job, its output path, or the file is
            missing; 403 when the job belongs to another user or the file
            resolves outside ``settings.OUTPUT_DIR``; 400 when the job has
            not completed.
    """
    job = db.query(Job).filter(Job.id == job_id).first()
    if not job:
        raise HTTPException(status_code=404, detail="Job not found")
    if job.user_id != current_user.id:
        raise HTTPException(status_code=403, detail="Access denied")
    if job.status != JobStatus.COMPLETED:
        raise HTTPException(status_code=400, detail="Job not completed yet")
    if not job.output_path:
        raise HTTPException(status_code=404, detail="Output file not found")

    output_file = Path(job.output_path)
    if not output_file.exists():
        raise HTTPException(status_code=404, detail="Output file does not exist")

    # Refuse to serve anything that resolves outside the output directory.
    allowed_root = Path(settings.OUTPUT_DIR).resolve()
    inside_root = output_file.resolve().is_relative_to(allowed_root)
    if not inside_root:
        logger.warning(f"Path traversal attempt detected: {output_file}")
        raise HTTPException(status_code=403, detail="Access denied")

    disposition = f'attachment; filename="{output_file.name}"'
    return FileResponse(
        path=str(output_file),
        media_type="audio/wav",
        filename=output_file.name,
        headers={
            "Content-Disposition": disposition
        }
    )

737
backend/api/tts.py Normal file
View File

@@ -0,0 +1,737 @@
import logging
import tempfile
from datetime import datetime
from pathlib import Path
from fastapi import APIRouter, Depends, HTTPException, BackgroundTasks, UploadFile, File, Form, Request, status
from sqlalchemy.orm import Session
from typing import Optional
from slowapi import Limiter
from slowapi.util import get_remote_address
from core.config import settings
from core.database import get_db
from core.model_manager import ModelManager
from core.cache_manager import VoiceCacheManager
from db.models import Job, JobStatus, User
from schemas.tts import CustomVoiceRequest, VoiceDesignRequest, IndexTTS2FromDesignRequest
from api.auth import get_current_user
from utils.validation import (
validate_language,
validate_speaker,
validate_text_length,
validate_generation_params,
get_supported_languages,
get_supported_speakers
)
from utils.audio import save_audio_file, validate_ref_audio, process_ref_audio, extract_audio_features
from utils.metrics import cache_metrics
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Router for the text-to-speech endpoints; routes are served under the "/tts" prefix.
router = APIRouter(prefix="/tts", tags=["tts"])
# Rate limiter used by the @limiter.limit decorators below, keyed by get_remote_address.
limiter = Limiter(key_func=get_remote_address)
async def read_upload_with_size_limit(upload_file: UploadFile, max_size_bytes: int) -> bytes:
    """Read an upload fully into memory, aborting once it exceeds a size cap.

    The upload is consumed in 1 MiB reads so an oversized file is rejected
    without buffering all of it first.

    Raises:
        ValueError: when more than ``max_size_bytes`` bytes have been read.
    """
    read_size = 1024 * 1024
    pieces = []
    received = 0

    piece = await upload_file.read(read_size)
    while piece:
        received += len(piece)
        if received > max_size_bytes:
            raise ValueError(f"Audio file exceeds {max_size_bytes // (1024 * 1024)}MB limit")
        pieces.append(piece)
        piece = await upload_file.read(read_size)

    return b"".join(pieces)
async def process_custom_voice_job(
    job_id: int,
    user_id: int,
    request_data: dict,
    backend_type: str,
    db_url: str
):
    """Background task: run a custom-voice generation job and store its WAV.

    Marks the job PROCESSING, asks the TTS backend for audio, writes it to
    ``settings.OUTPUT_DIR`` as ``{user_id}_{job_id}_{timestamp}.wav`` and
    marks the job COMPLETED. Any exception marks the job FAILED with a
    generic error message; the exception is logged, not re-raised.

    Args:
        job_id: Primary key of the job row to process.
        user_id: Owner of the job; used in the output file name.
        request_data: Generation parameters passed to the backend.
        backend_type: Only used in the log message.
        db_url: Not used; the session is created from ``SessionLocal``.
    """
    from core.database import SessionLocal
    from core.tts_service import TTSServiceFactory

    db = SessionLocal()
    try:
        job = db.query(Job).filter(Job.id == job_id).first()
        if not job:
            logger.error(f"Job {job_id} not found")
            return

        job.status = JobStatus.PROCESSING
        job.started_at = datetime.utcnow()
        db.commit()
        logger.info(f"Processing custom-voice job {job_id} with backend {backend_type}")

        backend = await TTSServiceFactory.get_backend()
        audio_bytes, _sample_rate = await backend.generate_custom_voice(request_data)

        timestamp = datetime.utcnow().strftime("%Y%m%d_%H%M%S")
        output_dir = Path(settings.OUTPUT_DIR)
        # Make sure the target directory exists before writing into it.
        output_dir.mkdir(parents=True, exist_ok=True)
        output_path = output_dir / f"{user_id}_{job_id}_{timestamp}.wav"
        with open(output_path, 'wb') as f:
            f.write(audio_bytes)

        job.status = JobStatus.COMPLETED
        job.output_path = str(output_path)
        job.completed_at = datetime.utcnow()
        db.commit()
        logger.info(f"Job {job_id} completed successfully")
    except Exception as e:
        logger.error(f"Job {job_id} failed: {e}", exc_info=True)
        # Discard any half-applied or failed transaction so the session is
        # usable again for recording the failure.
        db.rollback()
        job = db.query(Job).filter(Job.id == job_id).first()
        if job:
            job.status = JobStatus.FAILED
            job.error_message = "Job processing failed"
            job.completed_at = datetime.utcnow()
            db.commit()
    finally:
        db.close()
async def process_voice_design_job(
    job_id: int,
    user_id: int,
    request_data: dict,
    backend_type: str,
    db_url: str,
    saved_voice_id: Optional[str] = None
):
    """Background task: run a voice-design generation job and store its WAV.

    Marks the job PROCESSING, asks the TTS backend for audio, writes it to
    ``settings.OUTPUT_DIR`` as ``{user_id}_{job_id}_{timestamp}.wav`` and
    marks the job COMPLETED. Any exception marks the job FAILED with a
    generic error message; the exception is logged, not re-raised.

    Args:
        job_id: Primary key of the job row to process.
        user_id: Owner of the job; used in the output file name.
        request_data: Generation parameters passed to the backend.
        backend_type: Only used in the log message.
        db_url: Not used; the session is created from ``SessionLocal``.
        saved_voice_id: Accepted for callers that pass it; not used here.
    """
    from core.database import SessionLocal
    from core.tts_service import TTSServiceFactory

    db = SessionLocal()
    try:
        job = db.query(Job).filter(Job.id == job_id).first()
        if not job:
            logger.error(f"Job {job_id} not found")
            return

        job.status = JobStatus.PROCESSING
        job.started_at = datetime.utcnow()
        db.commit()
        logger.info(f"Processing voice-design job {job_id} with backend {backend_type}")

        backend = await TTSServiceFactory.get_backend()
        audio_bytes, _sample_rate = await backend.generate_voice_design(request_data)

        timestamp = datetime.utcnow().strftime("%Y%m%d_%H%M%S")
        output_dir = Path(settings.OUTPUT_DIR)
        # Make sure the target directory exists before writing into it.
        output_dir.mkdir(parents=True, exist_ok=True)
        output_path = output_dir / f"{user_id}_{job_id}_{timestamp}.wav"
        with open(output_path, 'wb') as f:
            f.write(audio_bytes)

        job.status = JobStatus.COMPLETED
        job.output_path = str(output_path)
        job.completed_at = datetime.utcnow()
        db.commit()
        logger.info(f"Job {job_id} completed successfully")
    except Exception as e:
        logger.error(f"Job {job_id} failed: {e}", exc_info=True)
        # Discard any half-applied or failed transaction so the session is
        # usable again for recording the failure.
        db.rollback()
        job = db.query(Job).filter(Job.id == job_id).first()
        if job:
            job.status = JobStatus.FAILED
            job.error_message = "Job processing failed"
            job.completed_at = datetime.utcnow()
            db.commit()
    finally:
        db.close()
async def process_voice_clone_job(
    job_id: int,
    user_id: int,
    request_data: dict,
    ref_audio_path: str,
    backend_type: str,
    db_url: str,
    use_voice_design: bool = False
):
    """Background task: run a voice-clone job.

    Two sources for the voice are supported:

    * ``request_data['voice_design_id']`` set: the clone prompt cached for
      that voice design is used; the reference audio file is only read if
      that cached prompt turns out to be empty.
    * otherwise: the reference audio at ``ref_audio_path`` is read once and
      reused for hashing, prompt creation and generation.

    With ``request_data['x_vector_only_mode']`` the job only prepares (and
    optionally caches) the clone prompt and finishes with
    ``output_path = "x_vector_cached_{cache_id}"``. Otherwise audio is
    generated and written to ``settings.OUTPUT_DIR``.

    Any exception marks the job FAILED with a generic message; it is logged,
    not re-raised. The reference audio file is removed afterwards unless
    ``use_voice_design`` is true.

    Args:
        job_id: Primary key of the job row to process.
        user_id: Owner of the job; used for cache lookups and the file name.
        request_data: Generation parameters and mode switches.
        ref_audio_path: Path of the reference audio file.
        backend_type: Only used in the log message.
        db_url: Not used; the session is created from ``SessionLocal``.
        use_voice_design: When true the reference audio file is kept.
    """
    from core.database import SessionLocal
    from core.tts_service import TTSServiceFactory

    db = SessionLocal()
    try:
        job = db.query(Job).filter(Job.id == job_id).first()
        if not job:
            logger.error(f"Job {job_id} not found")
            return

        job.status = JobStatus.PROCESSING
        job.started_at = datetime.utcnow()
        db.commit()
        logger.info(f"Processing voice-clone job {job_id} with backend {backend_type}")

        cache_manager = await VoiceCacheManager.get_instance()
        voice_design_id = request_data.get('voice_design_id')
        design = None
        cached = None
        ref_audio_data = None

        if voice_design_id:
            from db.crud import get_voice_design
            design = get_voice_design(db, voice_design_id, user_id)
            if not design or not design.voice_cache_id:
                raise RuntimeError(f"Voice design {voice_design_id} has no prepared clone prompt")
            cached = await cache_manager.get_cache_by_id(design.voice_cache_id, db)
            if not cached:
                raise RuntimeError(f"Cache {design.voice_cache_id} not found")
            ref_audio_hash = f"voice_design_{voice_design_id}"
            cache_metrics.record_hit(user_id)
            logger.info(f"Using voice design {voice_design_id}, cache_id={design.voice_cache_id}")
        else:
            # The uploaded reference audio is read once and reused below.
            with open(ref_audio_path, 'rb') as f:
                ref_audio_data = f.read()
            ref_audio_hash = cache_manager.get_audio_hash(ref_audio_data)

        if request_data.get('x_vector_only_mode', False):
            x_vector = None
            cache_id = None
            if voice_design_id:
                x_vector = cached['data']
                cache_id = design.voice_cache_id
            elif request_data.get('use_cache', True):
                cached = await cache_manager.get_cache(user_id, ref_audio_hash, db)
                if cached:
                    x_vector = cached['data']
                    cache_id = cached['cache_id']
                    cache_metrics.record_hit(user_id)
                    logger.info(f"Cache hit for job {job_id}, cache_id={cache_id}")

            if x_vector is None:
                cache_metrics.record_miss(user_id)
                logger.info(f"Cache miss for job {job_id}, creating voice clone prompt")
                if ref_audio_data is None:
                    # Voice-design job whose cached prompt was empty: fall
                    # back to the design's reference audio file.
                    with open(ref_audio_path, 'rb') as f:
                        ref_audio_data = f.read()
                ref_audio_array, ref_sr = process_ref_audio(ref_audio_data)
                model_manager = await ModelManager.get_instance()
                await model_manager.load_model("base")
                _, tts = await model_manager.get_current_model()
                if tts is None:
                    raise RuntimeError("Failed to load base model")
                x_vector = tts.create_voice_clone_prompt(
                    ref_audio=(ref_audio_array, ref_sr),
                    ref_text=request_data.get('ref_text', ''),
                    x_vector_only_mode=True
                )
                if request_data.get('use_cache', True):
                    features = extract_audio_features(ref_audio_array, ref_sr)
                    metadata = {
                        'duration': features['duration'],
                        'sample_rate': features['sample_rate'],
                        'ref_text': request_data.get('ref_text', ''),
                        'x_vector_only_mode': True
                    }
                    cache_id = await cache_manager.set_cache(
                        user_id, ref_audio_hash, x_vector, metadata, db
                    )
                    logger.info(f"Created cache for job {job_id}, cache_id={cache_id}")

            job.status = JobStatus.COMPLETED
            # Not a file path: records which cache entry holds the prompt.
            job.output_path = f"x_vector_cached_{cache_id}"
            job.completed_at = datetime.utcnow()
            db.commit()
            logger.info(f"Job {job_id} completed (x_vector_only_mode)")
            return

        backend = await TTSServiceFactory.get_backend()
        if voice_design_id:
            audio_bytes, _sample_rate = await backend.generate_voice_clone(
                request_data, x_vector=cached['data']
            )
            logger.info(f"Generated audio using cached x_vector from voice design {voice_design_id}")
        else:
            audio_bytes, _sample_rate = await backend.generate_voice_clone(request_data, ref_audio_data)

        timestamp = datetime.utcnow().strftime("%Y%m%d_%H%M%S")
        output_dir = Path(settings.OUTPUT_DIR)
        # Make sure the target directory exists before writing into it.
        output_dir.mkdir(parents=True, exist_ok=True)
        output_path = output_dir / f"{user_id}_{job_id}_{timestamp}.wav"
        with open(output_path, 'wb') as f:
            f.write(audio_bytes)

        job.status = JobStatus.COMPLETED
        job.output_path = str(output_path)
        job.completed_at = datetime.utcnow()
        db.commit()
        logger.info(f"Job {job_id} completed successfully")
    except Exception as e:
        logger.error(f"Job {job_id} failed: {e}", exc_info=True)
        # Discard any half-applied or failed transaction so the session is
        # usable again for recording the failure.
        db.rollback()
        job = db.query(Job).filter(Job.id == job_id).first()
        if job:
            job.status = JobStatus.FAILED
            job.error_message = "Job processing failed"
            job.completed_at = datetime.utcnow()
            db.commit()
    finally:
        try:
            # Only uploaded temp files are removed; a voice design's
            # reference audio is kept.
            if not use_voice_design and ref_audio_path and Path(ref_audio_path).exists():
                Path(ref_audio_path).unlink()
        except OSError as e:
            logger.warning(f"Failed to remove reference audio {ref_audio_path}: {e}")
        finally:
            # Always release the session, even if cleanup failed.
            db.close()
@router.post("/custom-voice")
@limiter.limit("10/minute")
async def create_custom_voice_job(
    request: Request,
    req_data: CustomVoiceRequest,
    background_tasks: BackgroundTasks,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """Validate a custom-voice request, record a job and queue its processing.

    Raises:
        HTTPException: 403 when ``can_user_use_local_model`` denies the
            caller; 400 when any validator raises ``ValueError``.
    """
    from db.crud import can_user_use_local_model

    if not can_user_use_local_model(current_user):
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Local model is not available. Please contact administrator."
        )

    backend_type = "local"
    try:
        validate_text_length(req_data.text)
        language = validate_language(req_data.language)
        speaker = validate_speaker(req_data.speaker)
        params = validate_generation_params({
            'max_new_tokens': req_data.max_new_tokens,
            'temperature': req_data.temperature,
            'top_k': req_data.top_k,
            'top_p': req_data.top_p,
            'repetition_penalty': req_data.repetition_penalty
        })
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))

    # The stored job parameters and the worker payload have the same content.
    job_params = {
        "text": req_data.text,
        "language": language,
        "speaker": speaker,
        "instruct": req_data.instruct or "",
        **params
    }

    job = Job(
        user_id=current_user.id,
        job_type="custom-voice",
        status=JobStatus.PENDING,
        backend_type=backend_type,
        input_data="",
        input_params=job_params
    )
    db.add(job)
    db.commit()
    db.refresh(job)

    request_data = dict(job_params)
    background_tasks.add_task(
        process_custom_voice_job,
        job.id,
        current_user.id,
        request_data,
        backend_type,
        str(settings.DATABASE_URL)
    )

    return {
        "job_id": job.id,
        "status": job.status,
        "message": "Job created successfully"
    }
@router.post("/voice-design")
@limiter.limit("10/minute")
async def create_voice_design_job(
    request: Request,
    req_data: VoiceDesignRequest,
    background_tasks: BackgroundTasks,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """Validate a voice-design request, record a job and queue its processing.

    When ``saved_design_id`` is given, the instruct text is taken from that
    saved design (and its usage is recorded); otherwise a non-blank
    ``instruct`` is required.

    Raises:
        HTTPException: 403 when ``can_user_use_local_model`` denies the
            caller; 404 when the saved design is not found; 400 when any
            validation raises ``ValueError``.
    """
    from db.crud import can_user_use_local_model, get_voice_design, update_voice_design_usage

    if not can_user_use_local_model(current_user):
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Local model is not available. Please contact administrator."
        )

    backend_type = "local"

    if req_data.saved_design_id:
        saved_design = get_voice_design(db, req_data.saved_design_id, current_user.id)
        if not saved_design:
            raise HTTPException(status_code=404, detail="Saved voice design not found")
        # The saved design's instruct replaces whatever the request carried.
        req_data.instruct = saved_design.instruct
        update_voice_design_usage(db, req_data.saved_design_id, current_user.id)

    try:
        validate_text_length(req_data.text)
        language = validate_language(req_data.language)
        if not req_data.saved_design_id:
            if not req_data.instruct or not req_data.instruct.strip():
                raise ValueError("Instruct parameter is required when saved_design_id is not provided")
        params = validate_generation_params({
            'max_new_tokens': req_data.max_new_tokens,
            'temperature': req_data.temperature,
            'top_k': req_data.top_k,
            'top_p': req_data.top_p,
            'repetition_penalty': req_data.repetition_penalty
        })
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))

    job = Job(
        user_id=current_user.id,
        job_type="voice-design",
        status=JobStatus.PENDING,
        backend_type=backend_type,
        input_data="",
        input_params={
            "text": req_data.text,
            "language": language,
            "instruct": req_data.instruct,
            **params
        }
    )
    db.add(job)
    db.commit()
    db.refresh(job)

    request_data = {
        "text": req_data.text,
        "language": language,
        "instruct": req_data.instruct,
        **params
    }

    # Pass the saved design id from the request. The name previously used
    # here was never defined, which raised NameError after the job row had
    # already been committed.
    background_tasks.add_task(
        process_voice_design_job,
        job.id,
        current_user.id,
        request_data,
        backend_type,
        str(settings.DATABASE_URL),
        req_data.saved_design_id
    )

    return {
        "job_id": job.id,
        "status": job.status,
        "message": "Job created successfully"
    }
@router.post("/voice-clone")
@limiter.limit("10/minute")
async def create_voice_clone_job(
    request: Request,
    text: str = Form(...),
    language: str = Form(default="Auto"),
    ref_audio: Optional[UploadFile] = File(default=None),
    ref_text: Optional[str] = Form(default=None),
    use_cache: bool = Form(default=True),
    x_vector_only_mode: bool = Form(default=False),
    voice_design_id: Optional[int] = Form(default=None),
    max_new_tokens: Optional[int] = Form(default=2048),
    temperature: Optional[float] = Form(default=0.9),
    top_k: Optional[int] = Form(default=50),
    top_p: Optional[float] = Form(default=1.0),
    repetition_penalty: Optional[float] = Form(default=1.05),
    backend: Optional[str] = Form(default=None),
    background_tasks: BackgroundTasks = None,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """Validate a voice-clone request, record a job and queue its processing.

    The voice comes either from a prepared voice design
    (``voice_design_id``) or from an uploaded ``ref_audio`` file, which is
    size-limited, validated and written to a temporary ``.wav`` file for the
    background worker. The ``backend`` form field is accepted but not used;
    the backend type is always "local".

    Returns:
        A dict with the job id, its status, a message, and ``cache_info``
        (the id of an existing cache entry for this reference, or ``None``).

    Raises:
        HTTPException: 403 when ``can_user_use_local_model`` denies the
            caller; 400 when validation raises ``ValueError``.
    """
    from db.crud import can_user_use_local_model, get_voice_design

    if not can_user_use_local_model(current_user):
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Local model is not available. Please contact administrator."
        )

    backend_type = "local"
    ref_audio_data = None
    ref_audio_hash = None
    design_audio_path = None
    use_voice_design = False

    try:
        validate_text_length(text)
        language = validate_language(language)
        params = validate_generation_params({
            'max_new_tokens': max_new_tokens,
            'temperature': temperature,
            'top_k': top_k,
            'top_p': top_p,
            'repetition_penalty': repetition_penalty
        })
        cache_manager = await VoiceCacheManager.get_instance()

        if voice_design_id:
            design = get_voice_design(db, voice_design_id, current_user.id)
            if not design:
                raise ValueError("Voice design not found")
            if not design.voice_cache_id:
                raise ValueError("Voice design has no prepared clone prompt. Please call /voice-designs/{id}/prepare-clone first")
            use_voice_design = True
            ref_audio_hash = f"voice_design_{voice_design_id}"
            # Captured now so the design does not have to be fetched again
            # after the job has been committed.
            design_audio_path = design.ref_audio_path
            if not ref_text:
                ref_text = design.ref_text
            logger.info(f"Using voice design {voice_design_id} with cache_id={design.voice_cache_id}")
        else:
            if not ref_audio:
                raise ValueError("Either ref_audio or voice_design_id must be provided")
            max_audio_size_bytes = settings.MAX_AUDIO_SIZE_MB * 1024 * 1024
            ref_audio_data = await read_upload_with_size_limit(ref_audio, max_audio_size_bytes)
            if not validate_ref_audio(ref_audio_data, max_size_mb=settings.MAX_AUDIO_SIZE_MB):
                # Report the configured size limit rather than a fixed number.
                raise ValueError(
                    f"Invalid reference audio: must be 1-30s duration and ≤{settings.MAX_AUDIO_SIZE_MB}MB"
                )
            ref_audio_hash = cache_manager.get_audio_hash(ref_audio_data)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))

    job = Job(
        user_id=current_user.id,
        job_type="voice-clone",
        status=JobStatus.PENDING,
        backend_type=backend_type,
        input_data="",
        input_params={
            "text": text,
            "language": language,
            "ref_text": ref_text or "",
            "ref_audio_hash": ref_audio_hash,
            "use_cache": use_cache,
            "x_vector_only_mode": x_vector_only_mode,
            "voice_design_id": voice_design_id,
            **params
        }
    )
    db.add(job)
    db.commit()
    db.refresh(job)

    if use_voice_design:
        tmp_audio_path = design_audio_path
    else:
        # delete=False: the background worker removes the file when it is done.
        with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as tmp_file:
            tmp_file.write(ref_audio_data)
            tmp_audio_path = tmp_file.name

    request_data = {
        "text": text,
        "language": language,
        "ref_text": ref_text or "",
        "use_cache": use_cache,
        "x_vector_only_mode": x_vector_only_mode,
        "voice_design_id": voice_design_id,
        **params
    }

    background_tasks.add_task(
        process_voice_clone_job,
        job.id,
        current_user.id,
        request_data,
        tmp_audio_path,
        backend_type,
        str(settings.DATABASE_URL),
        use_voice_design
    )

    existing_cache = await cache_manager.get_cache(current_user.id, ref_audio_hash, db)
    cache_info = {"cache_id": existing_cache['cache_id']} if existing_cache else None

    return {
        "job_id": job.id,
        "status": job.status,
        "message": "Job created successfully",
        "cache_info": cache_info
    }
async def process_indextts2_job(
    job_id: int,
    user_id: int,
    voice_design_id: int,
    text: str,
    emo_text: Optional[str],
    emo_alpha: float,
):
    """Background task: synthesize speech with IndexTTS2 from a voice design.

    Uses the voice design's ``ref_audio_path`` as the speaker prompt and
    asks the backend to write to
    ``settings.OUTPUT_DIR/{user_id}_{job_id}_{timestamp}.wav``. Any
    exception marks the job FAILED with a generic message; it is logged,
    not re-raised.

    Args:
        job_id: Primary key of the job row to process.
        user_id: Owner of the job and of the voice design.
        voice_design_id: Voice design supplying the reference audio.
        text: Text to synthesize.
        emo_text: Optional emotion text forwarded to the backend.
        emo_alpha: Emotion weight forwarded to the backend.
    """
    from core.database import SessionLocal
    from core.tts_service import IndexTTS2Backend
    from db.crud import get_voice_design

    db = SessionLocal()
    try:
        job = db.query(Job).filter(Job.id == job_id).first()
        if not job:
            logger.error(f"Job {job_id} not found")
            return

        job.status = JobStatus.PROCESSING
        job.started_at = datetime.utcnow()
        db.commit()

        design = get_voice_design(db, voice_design_id, user_id)
        if not design or not design.ref_audio_path:
            raise RuntimeError("Voice design has no ref_audio_path")
        if not Path(design.ref_audio_path).exists():
            raise RuntimeError(f"ref_audio_path does not exist: {design.ref_audio_path}")

        timestamp = datetime.utcnow().strftime("%Y%m%d_%H%M%S")
        filename = f"{user_id}_{job_id}_{timestamp}.wav"
        output_dir = Path(settings.OUTPUT_DIR)
        output_path = str(output_dir / filename)
        output_dir.mkdir(parents=True, exist_ok=True)

        backend = IndexTTS2Backend()
        # The backend is given output_path; its return value is not needed here.
        await backend.generate(
            text=text,
            spk_audio_prompt=design.ref_audio_path,
            output_path=output_path,
            emo_text=emo_text,
            emo_alpha=emo_alpha,
        )

        job.status = JobStatus.COMPLETED
        job.output_path = output_path
        job.completed_at = datetime.utcnow()
        db.commit()
    except Exception as e:
        logger.error(f"IndexTTS2 job {job_id} failed: {e}", exc_info=True)
        # Discard any half-applied or failed transaction so the session is
        # usable again for recording the failure.
        db.rollback()
        job = db.query(Job).filter(Job.id == job_id).first()
        if job:
            job.status = JobStatus.FAILED
            job.error_message = "Job processing failed"
            job.completed_at = datetime.utcnow()
            db.commit()
    finally:
        db.close()
@router.post("/indextts2-from-design")
@limiter.limit("10/minute")
async def create_indextts2_from_design_job(
    request: Request,
    req_data: IndexTTS2FromDesignRequest,
    background_tasks: BackgroundTasks,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db),
):
    """Record an IndexTTS2 job for one of the caller's voice designs and queue it.

    Raises:
        HTTPException: 404 when the voice design is not found; 400 when it
            has no ``ref_audio_path`` or that file does not exist.
    """
    from db.crud import get_voice_design

    design = get_voice_design(db, req_data.voice_design_id, current_user.id)
    if not design:
        raise HTTPException(status_code=404, detail="Voice design not found")

    reference_audio = design.ref_audio_path
    if not reference_audio:
        raise HTTPException(status_code=400, detail="Voice design has no ref_audio_path")
    if not Path(reference_audio).exists():
        raise HTTPException(status_code=400, detail="ref_audio_path file does not exist")

    job_params = {
        "text": req_data.text,
        "voice_design_id": req_data.voice_design_id,
        "emo_text": req_data.emo_text,
        "emo_alpha": req_data.emo_alpha,
    }
    job = Job(
        user_id=current_user.id,
        job_type="indextts2",
        status=JobStatus.PENDING,
        backend_type="local",
        input_data="",
        input_params=job_params
    )
    db.add(job)
    db.commit()
    db.refresh(job)

    background_tasks.add_task(
        process_indextts2_job,
        job.id,
        current_user.id,
        req_data.voice_design_id,
        req_data.text,
        req_data.emo_text,
        req_data.emo_alpha,
    )

    response = {
        "job_id": job.id,
        "status": job.status,
        "message": "Job created successfully"
    }
    return response
@router.get("/speakers")
@limiter.limit("30/minute")
async def list_speakers(request: Request, backend: Optional[str] = "local"):
    """Return whatever ``get_supported_speakers`` reports for ``backend`` (default "local")."""
    speakers = get_supported_speakers(backend)
    return speakers
@router.get("/languages")
@limiter.limit("30/minute")
async def list_languages(request: Request):
    """Return the result of ``get_supported_languages``."""
    languages = get_supported_languages()
    return languages

290
backend/api/users.py Normal file
View File

@@ -0,0 +1,290 @@
from typing import Annotated
from fastapi import APIRouter, Depends, HTTPException, status, Request
from sqlalchemy.orm import Session
from slowapi import Limiter
from slowapi.util import get_remote_address
from api.auth import get_current_user
from core.security import get_password_hash
from db.database import get_db
from db.crud import (
get_user_by_id,
get_user_by_username,
get_user_by_email,
list_users,
create_user_by_admin,
update_user,
delete_user
)
from schemas.user import User, UserCreateByAdmin, UserUpdate, UserListResponse
from schemas.audiobook import LLMConfigUpdate, LLMConfigResponse, NsfwSynopsisGenerationRequest, NsfwScriptGenerationRequest
router = APIRouter(prefix="/users", tags=["users"])
limiter = Limiter(key_func=get_remote_address)
async def require_superuser(
    current_user: Annotated[User, Depends(get_current_user)]
) -> User:
    """Dependency that lets only superusers through.

    Returns:
        The authenticated user when ``is_superuser`` is truthy.

    Raises:
        HTTPException: 403 when the authenticated user is not a superuser.
    """
    if current_user.is_superuser:
        return current_user
    raise HTTPException(
        status_code=status.HTTP_403_FORBIDDEN,
        detail="Superuser access required"
    )
@router.get("", response_model=UserListResponse)
@limiter.limit("30/minute")
async def get_users(
    request: Request,
    skip: int = 0,
    limit: int = 100,
    db: Session = Depends(get_db),
    _: User = Depends(require_superuser)
):
    """List users for a superuser.

    ``skip`` and ``limit`` are forwarded to ``list_users`` and echoed back in
    the response together with the users and total it returns.
    """
    listing = list_users(db, skip=skip, limit=limit)
    page, total_count = listing
    return UserListResponse(users=page, total=total_count, skip=skip, limit=limit)
@router.post("", response_model=User, status_code=status.HTTP_201_CREATED)
@limiter.limit("10/minute")
async def create_user(
    request: Request,
    user_data: UserCreateByAdmin,
    db: Session = Depends(get_db),
    _: User = Depends(require_superuser)
):
    """Create a user account on behalf of a superuser.

    The username is checked first, then the email; the password is hashed
    before being handed to ``create_user_by_admin``.

    Raises:
        HTTPException: 400 when the username or the email is already registered.
    """
    if get_user_by_username(db, username=user_data.username):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Username already registered"
        )

    if get_user_by_email(db, email=user_data.email):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Email already registered"
        )

    password_hash = get_password_hash(user_data.password)
    return create_user_by_admin(
        db,
        username=user_data.username,
        email=user_data.email,
        hashed_password=password_hash,
        is_superuser=user_data.is_superuser,
        can_use_local_model=user_data.can_use_local_model
    )
@router.get("/me", response_model=User)
@limiter.limit("30/minute")
async def get_current_user_info(
    request: Request,
    current_user: User = Depends(get_current_user)
):
    """Return the authenticated user resolved by ``get_current_user``."""
    me = current_user
    return me
@router.get("/{user_id}", response_model=User)
@limiter.limit("30/minute")
async def get_user(
    request: Request,
    user_id: int,
    db: Session = Depends(get_db),
    _: User = Depends(require_superuser)
):
    """Fetch a single user by id (superuser only).

    Raises:
        HTTPException: 404 when ``get_user_by_id`` finds nothing.
    """
    found = get_user_by_id(db, user_id=user_id)
    if found:
        return found
    raise HTTPException(
        status_code=status.HTTP_404_NOT_FOUND,
        detail="User not found"
    )
@router.put("/{user_id}", response_model=User)
@limiter.limit("10/minute")
async def update_user_info(
    request: Request,
    user_id: int,
    user_data: UserUpdate,
    db: Session = Depends(get_db),
    current_user: User = Depends(require_superuser)
):
    """Update a user's profile, credentials and permission flags (superuser only).

    A new username or email is rejected when it already belongs to a
    different user. A password, when supplied, is hashed before the update.

    Raises:
        HTTPException: 404 when the user does not exist (before or after the
            update call); 400 when the username or email is taken by another user.
    """
    if not get_user_by_id(db, user_id=user_id):
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="User not found"
        )

    requested_username = user_data.username
    if requested_username is not None:
        holder = get_user_by_username(db, username=requested_username)
        # Keeping one's own current username is not a conflict.
        if holder and holder.id != user_id:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Username already taken"
            )

    requested_email = user_data.email
    if requested_email is not None:
        holder = get_user_by_email(db, email=requested_email)
        if holder and holder.id != user_id:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Email already taken"
            )

    if user_data.password is not None:
        password_hash = get_password_hash(user_data.password)
    else:
        password_hash = None

    updated = update_user(
        db,
        user_id=user_id,
        username=user_data.username,
        email=user_data.email,
        hashed_password=password_hash,
        is_active=user_data.is_active,
        is_superuser=user_data.is_superuser,
        can_use_local_model=user_data.can_use_local_model,
        can_use_nsfw=user_data.can_use_nsfw,
    )
    if updated:
        return updated
    raise HTTPException(
        status_code=status.HTTP_404_NOT_FOUND,
        detail="User not found"
    )
@router.delete("/{user_id}", status_code=status.HTTP_204_NO_CONTENT)
@limiter.limit("10/minute")
async def delete_user_by_id(
    request: Request,
    user_id: int,
    db: Session = Depends(get_db),
    current_user: User = Depends(require_superuser)
):
    """Delete a user by id (superuser only); a superuser cannot delete their own account.

    Raises:
        HTTPException: 400 when ``user_id`` is the caller's own id;
            404 when ``delete_user`` reports failure.
    """
    deleting_self = user_id == current_user.id
    if deleting_self:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Cannot delete yourself"
        )

    if not delete_user(db, user_id=user_id):
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="User not found"
        )
@router.put("/system/llm-config")
@limiter.limit("10/minute")
async def set_system_llm_config(
    request: Request,
    config: LLMConfigUpdate,
    db: Session = Depends(get_db),
    _: User = Depends(require_superuser)
):
    """Validate and store the system-wide LLM configuration (superuser only).

    The supplied values are stripped of surrounding whitespace, then a test
    chat is sent through ``LLMService``. Only if that call succeeds are the
    settings saved; the API key is passed through ``encrypt_api_key`` first.

    Raises:
        HTTPException: 400 when the test chat raises, with the error text in the detail.
    """
    from core.security import encrypt_api_key
    from core.llm_service import LLMService
    from db.crud import set_system_setting

    api_key = config.api_key.strip()
    base_url = config.base_url.strip()
    model = config.model.strip()

    # Probe the endpoint before persisting anything.
    probe = LLMService(base_url=base_url, api_key=api_key, model=model)
    try:
        await probe.chat("You are a test assistant.", "Reply with 'ok'.")
    except Exception as exc:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"LLM API validation failed: {exc}")

    encrypted_key = encrypt_api_key(api_key)
    for setting_name, setting_value in (
        ("llm_api_key", encrypted_key),
        ("llm_base_url", base_url),
        ("llm_model", model),
    ):
        set_system_setting(db, setting_name, setting_value)
    return {"message": "LLM config updated"}
@router.get("/system/llm-config", response_model=LLMConfigResponse)
@limiter.limit("30/minute")
async def get_system_llm_config(
    request: Request,
    db: Session = Depends(get_db),
    _: User = Depends(require_superuser)
):
    """Return the stored system LLM settings (superuser only).

    The API key itself is never returned; ``has_key`` only reports whether a
    value is stored under ``llm_api_key``.
    """
    from db.crud import get_system_setting

    stored_base_url = get_system_setting(db, "llm_base_url")
    stored_model = get_system_setting(db, "llm_model")
    key_present = bool(get_system_setting(db, "llm_api_key"))
    return LLMConfigResponse(
        base_url=stored_base_url,
        model=stored_model,
        has_key=key_present,
    )
@router.delete("/system/llm-config")
@limiter.limit("10/minute")
async def delete_system_llm_config(
    request: Request,
    db: Session = Depends(get_db),
    _: User = Depends(require_superuser)
):
    """Remove the three system LLM settings (key, base URL, model); superuser only."""
    from db.crud import delete_system_setting

    for setting_name in ("llm_api_key", "llm_base_url", "llm_model"):
        delete_system_setting(db, setting_name)
    return {"message": "LLM config deleted"}
@router.put("/system/grok-config")
@limiter.limit("10/minute")
async def set_system_grok_config(
    request: Request,
    config: LLMConfigUpdate,
    db: Session = Depends(get_db),
    _: User = Depends(require_superuser)
):
    """Validate and store the system-wide Grok configuration (superuser only).

    The supplied values are stripped of surrounding whitespace, then a test
    chat is sent through ``GrokLLMService``. Only if that call succeeds are
    the settings saved; the API key is passed through ``encrypt_api_key`` first.

    Raises:
        HTTPException: 400 when the test chat raises, with the error text in the detail.
    """
    from core.security import encrypt_api_key
    from core.llm_service import GrokLLMService
    from db.crud import set_system_setting

    api_key = config.api_key.strip()
    base_url = config.base_url.strip()
    model = config.model.strip()

    # Probe the endpoint before persisting anything.
    probe = GrokLLMService(base_url=base_url, api_key=api_key, model=model)
    try:
        await probe.chat("You are a test assistant.", "Reply with 'ok'.")
    except Exception as exc:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"Grok API validation failed: {exc}")

    encrypted_key = encrypt_api_key(api_key)
    for setting_name, setting_value in (
        ("grok_api_key", encrypted_key),
        ("grok_base_url", base_url),
        ("grok_model", model),
    ):
        set_system_setting(db, setting_name, setting_value)
    return {"message": "Grok config updated"}
@router.get("/system/grok-config", response_model=LLMConfigResponse)
@limiter.limit("30/minute")
async def get_system_grok_config(
    request: Request,
    db: Session = Depends(get_db),
    _: User = Depends(require_superuser)
):
    """Return the stored system Grok settings (superuser only).

    The API key itself is never returned; ``has_key`` only reports whether a
    value is stored under ``grok_api_key``.
    """
    from db.crud import get_system_setting

    stored_base_url = get_system_setting(db, "grok_base_url")
    stored_model = get_system_setting(db, "grok_model")
    key_present = bool(get_system_setting(db, "grok_api_key"))
    return LLMConfigResponse(
        base_url=stored_base_url,
        model=stored_model,
        has_key=key_present,
    )
@router.delete("/system/grok-config")
@limiter.limit("10/minute")
async def delete_system_grok_config(
    request: Request,
    db: Session = Depends(get_db),
    _: User = Depends(require_superuser)
):
    """Remove the three system Grok settings (key, base URL, model); superuser only."""
    from db.crud import delete_system_setting

    for setting_name in ("grok_api_key", "grok_base_url", "grok_model"):
        delete_system_setting(db, setting_name)
    return {"message": "Grok config deleted"}

View File

@@ -0,0 +1,281 @@
import logging
import json
from fastapi import APIRouter, Depends, HTTPException, status, Request, BackgroundTasks
from sqlalchemy.orm import Session
from typing import Optional
from slowapi import Limiter
from slowapi.util import get_remote_address
from pathlib import Path
from core.database import get_db
from api.auth import get_current_user
from db.models import User, Job, JobStatus
from db import crud
from schemas.voice_design import (
VoiceDesignCreate,
VoiceDesignResponse,
VoiceDesignListResponse
)
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/voice-designs", tags=["voice-designs"])
limiter = Limiter(key_func=get_remote_address)
def to_voice_design_response(design) -> VoiceDesignResponse:
    """Build a ``VoiceDesignResponse`` from a voice design record.

    When ``design.meta_data`` is a string it is parsed as JSON; if parsing
    fails the response carries ``meta_data=None``. Non-string values are
    passed through unchanged.
    """
    parsed_meta = design.meta_data
    if isinstance(parsed_meta, str):
        try:
            parsed_meta = json.loads(parsed_meta)
        except Exception:
            # Unparseable metadata is dropped rather than failing the response.
            parsed_meta = None

    # Fields copied verbatim from the record.
    copied = {
        attr: getattr(design, attr)
        for attr in (
            "id",
            "user_id",
            "name",
            "instruct",
            "preview_text",
            "ref_audio_path",
            "created_at",
            "last_used",
            "use_count",
        )
    }
    return VoiceDesignResponse(meta_data=parsed_meta, **copied)
@router.post("", response_model=VoiceDesignResponse, status_code=status.HTTP_201_CREATED)
@limiter.limit("30/minute")
async def save_voice_design(
    request: Request,
    data: VoiceDesignCreate,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """Persist a voice design for the current user and return it.

    Raises:
        HTTPException: 500 with a generic detail when creation or response
            conversion fails; the underlying error is logged.
    """
    new_design_fields = dict(
        db=db,
        user_id=current_user.id,
        name=data.name,
        instruct=data.instruct,
        meta_data=data.meta_data,
        preview_text=data.preview_text
    )
    try:
        created = crud.create_voice_design(**new_design_fields)
        return to_voice_design_response(created)
    except Exception as exc:
        logger.error(f"Failed to save voice design: {exc}", exc_info=True)
        raise HTTPException(status_code=500, detail="Failed to save voice design")
@router.get("", response_model=VoiceDesignListResponse)
@limiter.limit("30/minute")
async def list_voice_designs(
    request: Request,
    backend_type: Optional[str] = None,
    skip: int = 0,
    limit: int = 100,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """List the current user's voice designs together with their count.

    ``backend_type``, ``skip`` and ``limit`` are forwarded to the CRUD layer;
    the count is computed with the same user and ``backend_type``.
    """
    owner_id = current_user.id
    records = crud.list_voice_designs(db, owner_id, backend_type, skip, limit)
    total = crud.count_voice_designs(db, owner_id, backend_type)
    serialized = list(map(to_voice_design_response, records))
    return VoiceDesignListResponse(designs=serialized, total=total)
@router.post("/prepare-and-create", response_model=VoiceDesignResponse, status_code=status.HTTP_201_CREATED)
@limiter.limit("10/minute")
async def prepare_and_create_voice_design(
    request: Request,
    data: VoiceDesignCreate,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """Generate reference audio for a new voice design, cache its clone prompt, and save the design.

    Steps, all performed inside this request:
      1. Synthesize a reference clip with the local backend's
         ``generate_voice_design`` from the instruct text.
      2. Write the clip to a uniquely named file under ``settings.OUTPUT_DIR``.
      3. Load the "base" model and build a voice-clone prompt from the clip.
      4. Store the prompt through ``VoiceCacheManager.set_cache``.
      5. Create the voice design linked to the cache entry and the audio file.

    Returns:
        The created design converted with ``to_voice_design_response``.

    Raises:
        HTTPException: 403 when the user may not use the local model;
            500 (generic detail, real error logged) when any step fails.
    """
    import uuid
    from core.tts_service import TTSServiceFactory
    from core.cache_manager import VoiceCacheManager
    from utils.audio import process_ref_audio, extract_audio_features
    from core.config import settings
    from db.crud import can_user_use_local_model
    from datetime import datetime

    if not can_user_use_local_model(current_user):
        raise HTTPException(status_code=403, detail="Local model access required")

    try:
        backend = await TTSServiceFactory.get_backend("local")

        # Fall back to the start of the instruct text when no preview text is given.
        ref_text = data.preview_text or data.instruct[:100]

        ref_audio_bytes, _ = await backend.generate_voice_design({
            "text": ref_text,
            "language": "Auto",
            "instruct": data.instruct,
            "max_new_tokens": 2048,
            "temperature": 0.3,
            "top_k": 10,
            "top_p": 0.5,
            "repetition_penalty": 1.05
        })

        # Make sure the target directory exists before writing into it.
        output_dir = Path(settings.OUTPUT_DIR)
        output_dir.mkdir(parents=True, exist_ok=True)

        # A second-resolution timestamp alone is not unique: two requests in
        # the same second would overwrite each other's reference audio. The
        # user id and a random suffix keep every file distinct.
        timestamp = datetime.utcnow().strftime("%Y%m%d_%H%M%S")
        unique_suffix = uuid.uuid4().hex[:8]
        ref_filename = f"voice_design_new_{current_user.id}_{timestamp}_{unique_suffix}.wav"
        ref_audio_path = output_dir / ref_filename
        ref_audio_path.write_bytes(ref_audio_bytes)

        ref_audio_array, ref_sr = process_ref_audio(ref_audio_bytes)

        from core.model_manager import ModelManager
        model_manager = await ModelManager.get_instance()
        await model_manager.load_model("base")
        _, tts = await model_manager.get_current_model()

        if tts is None:
            raise RuntimeError("Failed to load base model")

        x_vector = tts.create_voice_clone_prompt(
            ref_audio=(ref_audio_array, ref_sr),
            ref_text=ref_text,
        )

        cache_manager = await VoiceCacheManager.get_instance()
        ref_audio_hash = cache_manager.get_audio_hash(ref_audio_bytes)

        features = extract_audio_features(ref_audio_array, ref_sr)
        metadata = {
            'duration': features['duration'],
            'sample_rate': features['sample_rate'],
            'ref_text': ref_text,
            'instruct': data.instruct
        }

        cache_id = await cache_manager.set_cache(
            current_user.id, ref_audio_hash, x_vector, metadata, db
        )

        design = crud.create_voice_design(
            db=db,
            user_id=current_user.id,
            name=data.name,
            instruct=data.instruct,
            meta_data=data.meta_data,
            preview_text=data.preview_text,
            voice_cache_id=cache_id,
            ref_audio_path=str(ref_audio_path),
            ref_text=ref_text,
        )

        logger.info(f"Voice design created with clone prompt: design_id={design.id}, cache_id={cache_id}")
        return to_voice_design_response(design)

    except Exception as e:
        # Log the real cause; the client only gets a generic message.
        logger.error(f"Failed to prepare and create voice design: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail="Failed to prepare voice design")
@router.delete("/{design_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_voice_design(
    design_id: int,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """Delete one of the current user's voice designs.

    Raises:
        HTTPException: 404 when ``crud.delete_voice_design`` reports nothing deleted.
    """
    if not crud.delete_voice_design(db, design_id, current_user.id):
        raise HTTPException(status_code=404, detail="Voice design not found")
@router.post("/{design_id}/prepare-clone")
@limiter.limit("10/minute")
async def prepare_voice_clone_prompt(
    request: Request,
    design_id: int,
    background_tasks: BackgroundTasks,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """Build and cache a voice-clone prompt for an existing voice design.

    If the design already has a ``voice_cache_id`` the existing id is
    returned and nothing is regenerated. Otherwise a reference clip is
    synthesized from the design's instruct text, written under
    ``settings.OUTPUT_DIR``, turned into a clone prompt with the "base"
    model, stored through ``VoiceCacheManager.set_cache``, and the design is
    updated with the cache id, audio path and reference text.

    ``background_tasks`` is accepted but not used; all work happens inside
    this request.

    Returns:
        A dict with a message and ``cache_id`` (plus ``ref_text`` when a new
        prompt was prepared).

    Raises:
        HTTPException: 404 when the design is not found for this user;
            403 when the user may not use the local model;
            500 (generic detail, real error logged) when preparation fails.
    """
    import uuid
    from core.tts_service import TTSServiceFactory
    from core.cache_manager import VoiceCacheManager
    from utils.audio import process_ref_audio, extract_audio_features
    from core.config import settings
    from db.crud import can_user_use_local_model
    from datetime import datetime

    design = crud.get_voice_design(db, design_id, current_user.id)
    if not design:
        raise HTTPException(status_code=404, detail="Voice design not found")

    if not can_user_use_local_model(current_user):
        raise HTTPException(
            status_code=403,
            detail="Local model access required"
        )

    if design.voice_cache_id:
        return {
            "message": "Voice clone prompt already exists",
            "cache_id": design.voice_cache_id
        }

    try:
        backend = await TTSServiceFactory.get_backend("local")

        # Fall back to the start of the instruct text when no preview text is stored.
        ref_text = design.preview_text or design.instruct[:100]

        logger.info(f"Generating reference audio for voice design {design_id}")
        ref_audio_bytes, _ = await backend.generate_voice_design({
            "text": ref_text,
            "language": "Auto",
            "instruct": design.instruct,
            "max_new_tokens": 2048,
            "temperature": 0.3,
            "top_k": 10,
            "top_p": 0.5,
            "repetition_penalty": 1.05
        })

        # Make sure the target directory exists before writing into it.
        output_dir = Path(settings.OUTPUT_DIR)
        output_dir.mkdir(parents=True, exist_ok=True)

        # The random suffix keeps concurrent requests for the same design
        # (same second) from overwriting each other's reference audio.
        timestamp = datetime.utcnow().strftime("%Y%m%d_%H%M%S")
        unique_suffix = uuid.uuid4().hex[:8]
        ref_filename = f"voice_design_{design_id}_{timestamp}_{unique_suffix}.wav"
        ref_audio_path = output_dir / ref_filename
        ref_audio_path.write_bytes(ref_audio_bytes)

        logger.info("Extracting voice clone prompt from reference audio")
        ref_audio_array, ref_sr = process_ref_audio(ref_audio_bytes)

        from core.model_manager import ModelManager
        model_manager = await ModelManager.get_instance()
        await model_manager.load_model("base")
        _, tts = await model_manager.get_current_model()

        if tts is None:
            raise RuntimeError("Failed to load base model")

        x_vector = tts.create_voice_clone_prompt(
            ref_audio=(ref_audio_array, ref_sr),
            ref_text=ref_text,
        )

        cache_manager = await VoiceCacheManager.get_instance()
        ref_audio_hash = cache_manager.get_audio_hash(ref_audio_bytes)

        features = extract_audio_features(ref_audio_array, ref_sr)
        metadata = {
            'duration': features['duration'],
            'sample_rate': features['sample_rate'],
            'ref_text': ref_text,
            'voice_design_id': design_id,
            'instruct': design.instruct
        }

        cache_id = await cache_manager.set_cache(
            current_user.id, ref_audio_hash, x_vector, metadata, db
        )

        # Link the design to the freshly cached prompt and its source audio.
        design.voice_cache_id = cache_id
        design.ref_audio_path = str(ref_audio_path)
        design.ref_text = ref_text
        db.commit()

        logger.info(f"Voice clone prompt prepared for design {design_id}, cache_id={cache_id}")

        return {
            "message": "Voice clone prompt prepared successfully",
            "cache_id": cache_id,
            "ref_text": ref_text
        }

    except Exception as e:
        # Log the real cause; the client only gets a generic message.
        logger.error(f"Failed to prepare voice clone prompt: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail="Failed to prepare voice clone prompt")