init commit

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
This commit is contained in:
2026-01-26 15:34:31 +08:00
commit 80513a3258
141 changed files with 24966 additions and 0 deletions

View File

View File

@@ -0,0 +1,141 @@
import asyncio
import logging
import time
from typing import Any, Callable, Dict, List, Optional, Tuple
from dataclasses import dataclass
from collections import deque
from core.config import settings
logger = logging.getLogger(__name__)
@dataclass
class BatchRequest:
    """One queued unit of work awaiting batch processing."""
    request_id: str  # caller-supplied identifier, used only for logging
    data: Dict[str, Any]  # payload forwarded to the processor's execute hook
    future: asyncio.Future  # resolved with the result (or exception) when processed
    timestamp: float  # enqueue time from time.time(); used to compute queue wait
class BatchProcessor:
    """Singleton that coalesces submitted requests into batches.

    Requests enter via submit() and are queued; a background task drains the
    queue whenever batch_size requests have accumulated or the oldest queued
    request has waited batch_wait_time seconds.

    NOTE(review): _instance lives on this class, so subclasses such as
    TTSBatchProcessor share the same singleton slot — confirm only one concrete
    processor type is used per process.
    """

    _instance: Optional['BatchProcessor'] = None
    _lock = asyncio.Lock()

    def __init__(self, batch_size: Optional[int] = None, batch_wait_time: Optional[float] = None):
        """Initialize queue state; fall back to settings for unspecified limits."""
        self.batch_size = batch_size or settings.BATCH_SIZE
        self.batch_wait_time = batch_wait_time or settings.BATCH_WAIT_TIME
        self.queue: deque = deque()
        self.queue_lock = asyncio.Lock()
        self.processing = False  # NOTE(review): never mutated after init — candidate for removal
        self._processor_task: Optional[asyncio.Task] = None
        logger.info(f"BatchProcessor initialized with batch_size={self.batch_size}, wait_time={self.batch_wait_time}s")

    @classmethod
    async def get_instance(cls) -> 'BatchProcessor':
        """Return the shared instance, creating it and its drain task on first use."""
        if cls._instance is None:
            # Double-checked locking so concurrent first callers create one instance.
            async with cls._lock:
                if cls._instance is None:
                    cls._instance = cls()
                    cls._instance._start_processor()
        return cls._instance

    def _start_processor(self):
        """Start the background drain task if none is running."""
        if not self._processor_task or self._processor_task.done():
            self._processor_task = asyncio.create_task(self._process_batches())
            logger.info("Batch processor task started")

    async def _process_batches(self):
        """Poll the queue every 100 ms and dispatch any ready batch."""
        logger.info("Batch processing loop started")
        while True:
            try:
                await asyncio.sleep(0.1)
                async with self.queue_lock:
                    if not self.queue:
                        continue
                    current_time = time.time()
                    oldest_request = self.queue[0]
                    wait_duration = current_time - oldest_request.timestamp
                    # Flush when the batch is full or the head request has waited long enough.
                    should_process = (
                        len(self.queue) >= self.batch_size or
                        wait_duration >= self.batch_wait_time
                    )
                    if should_process:
                        batch = []
                        for _ in range(min(self.batch_size, len(self.queue))):
                            if self.queue:
                                batch.append(self.queue.popleft())
                        if batch:
                            logger.info(f"Processing batch of {len(batch)} requests (queue_wait={wait_duration:.3f}s)")
                            # Fire-and-forget so the drain loop keeps polling while the batch runs.
                            asyncio.create_task(self._process_batch(batch))
            except Exception as e:
                logger.error(f"Error in batch processor loop: {e}", exc_info=True)
                await asyncio.sleep(1)

    async def _process_batch(self, batch: List[BatchRequest]):
        """Execute each request in the batch and resolve its future."""
        for request in batch:
            try:
                # Skip requests whose futures are already settled (e.g. submit() timed
                # out and wait_for cancelled the future).
                if not request.future.done():
                    result = await self._execute_single_request(request.data)
                    request.future.set_result(result)
            except Exception as e:
                logger.error(f"Error processing request {request.request_id}: {e}", exc_info=True)
                if not request.future.done():
                    request.future.set_exception(e)

    async def _execute_single_request(self, data: Dict[str, Any]) -> Any:
        """Process one request payload; subclasses must override."""
        raise NotImplementedError("Subclass must implement _execute_single_request")

    async def submit(self, request_id: str, data: Dict[str, Any], timeout: float = 300) -> Any:
        """Queue a request and await its result.

        Raises:
            TimeoutError: when no result arrives within `timeout` seconds; the
                request is removed from the queue if still pending there.
        """
        # Create the future on the running loop; bare asyncio.Future() relies on
        # the deprecated implicit get_event_loop().
        future = asyncio.get_running_loop().create_future()
        request = BatchRequest(
            request_id=request_id,
            data=data,
            future=future,
            timestamp=time.time()
        )
        async with self.queue_lock:
            self.queue.append(request)
            queue_size = len(self.queue)
        logger.debug(f"Request {request_id} queued (queue_size={queue_size})")
        try:
            result = await asyncio.wait_for(future, timeout=timeout)
            return result
        except asyncio.TimeoutError:
            logger.error(f"Request {request_id} timed out after {timeout}s")
            # Drop the request if no batch has picked it up yet.
            async with self.queue_lock:
                if request in self.queue:
                    self.queue.remove(request)
            raise TimeoutError(f"Request timed out after {timeout}s")

    async def get_queue_length(self) -> int:
        """Return the number of requests currently queued."""
        async with self.queue_lock:
            return len(self.queue)

    async def get_stats(self) -> Dict[str, Any]:
        """Return queue length, configured limits, and drain-task liveness."""
        queue_length = await self.get_queue_length()
        return {
            "queue_length": queue_length,
            "batch_size": self.batch_size,
            "batch_wait_time": self.batch_wait_time,
            "processor_running": self._processor_task is not None and not self._processor_task.done()
        }
class TTSBatchProcessor(BatchProcessor):
    """Batch processor that delegates each queued payload to a supplied coroutine."""

    def __init__(self, process_func: Callable, batch_size: int = None, batch_wait_time: float = None):
        """Remember the processing coroutine and set up the base queue machinery."""
        super().__init__(batch_size, batch_wait_time)
        self.process_func = process_func

    async def _execute_single_request(self, data: Dict[str, Any]) -> Any:
        # The queued payload dict becomes the coroutine's keyword arguments.
        outcome = await self.process_func(**data)
        return outcome

View File

@@ -0,0 +1,161 @@
import hashlib
import pickle
import asyncio
from pathlib import Path
from typing import Optional, Dict, Any
from datetime import datetime, timedelta
import logging
from sqlalchemy.orm import Session
from db.crud import (
create_cache_entry,
get_cache_entry,
list_cache_entries,
delete_cache_entry
)
from db.models import VoiceCache
from core.config import settings
logger = logging.getLogger(__name__)
class VoiceCacheManager:
    """Singleton managing pickled per-user voice caches on disk.

    Database rows (via db.crud) index each entry; payloads are pickled under
    cache_dir. Entries beyond max_entries are evicted per user, and entries
    not accessed within ttl_days are expired.
    """

    _instance = None
    _lock = asyncio.Lock()

    def __init__(self, cache_dir: str, max_entries: int, ttl_days: int):
        """Record the limits and make sure the cache directory exists."""
        self.cache_dir = Path(cache_dir)
        self.max_entries = max_entries
        self.ttl_days = ttl_days
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        logger.info(f"VoiceCacheManager initialized: dir={cache_dir}, max={max_entries}, ttl={ttl_days}d")

    @classmethod
    async def get_instance(cls) -> 'VoiceCacheManager':
        """Return the shared manager, creating it from settings on first use."""
        if cls._instance is None:
            async with cls._lock:
                if cls._instance is None:
                    cls._instance = VoiceCacheManager(
                        cache_dir=settings.CACHE_DIR,
                        max_entries=settings.MAX_CACHE_ENTRIES,
                        ttl_days=settings.CACHE_TTL_DAYS
                    )
        return cls._instance

    def get_audio_hash(self, audio_data: bytes) -> str:
        """Return the SHA-256 hex digest of the raw reference-audio bytes."""
        return hashlib.sha256(audio_data).hexdigest()

    async def get_cache(self, user_id: int, ref_audio_hash: str, db: "Session") -> Optional[Dict[str, Any]]:
        """Look up a cached payload for (user, reference-audio hash).

        Returns None on a miss, on a stale DB row whose file has vanished
        (the row is deleted), or on any unexpected error.
        """
        try:
            cache_entry = get_cache_entry(db, user_id, ref_audio_hash)
            if not cache_entry:
                logger.debug(f"Cache miss: user={user_id}, hash={ref_audio_hash[:8]}...")
                return None
            cache_file = Path(cache_entry.cache_path)
            if not cache_file.exists():
                # The DB row outlived its file; drop it so a fresh entry can be made.
                logger.warning(f"Cache file missing: {cache_file}")
                delete_cache_entry(db, cache_entry.id, user_id)
                return None
            # NOTE: pickle is acceptable only because these files are written by
            # this service itself — never point cache_dir at untrusted data.
            with open(cache_file, 'rb') as f:
                cache_data = pickle.load(f)
            logger.info(f"Cache hit: user={user_id}, hash={ref_audio_hash[:8]}..., access_count={cache_entry.access_count}")
            return {
                'cache_id': cache_entry.id,
                'data': cache_data,
                'metadata': cache_entry.meta_data
            }
        except Exception as e:
            logger.error(f"Cache retrieval error: {e}", exc_info=True)
            return None

    async def set_cache(
        self,
        user_id: int,
        ref_audio_hash: str,
        cache_data: Any,
        metadata: Dict[str, Any],
        db: "Session"
    ) -> str:
        """Persist a cache payload and its DB row; return the new entry id.

        On failure the partially written file (if any) is removed and the
        original exception is re-raised.
        """
        async with self._lock:
            # Compute the path before the try block so the cleanup handler can
            # never reference an unbound name.
            cache_filename = f"{user_id}_{ref_audio_hash}.pkl"
            cache_path = self.cache_dir / cache_filename
            try:
                with open(cache_path, 'wb') as f:
                    pickle.dump(cache_data, f)
                cache_entry = create_cache_entry(
                    db=db,
                    user_id=user_id,
                    ref_audio_hash=ref_audio_hash,
                    cache_path=str(cache_path),
                    meta_data=metadata
                )
                await self.enforce_max_entries(user_id, db)
                logger.info(f"Cache created: user={user_id}, hash={ref_audio_hash[:8]}..., id={cache_entry.id}")
                return cache_entry.id
            except Exception as e:
                logger.error(f"Cache creation error: {e}", exc_info=True)
                if cache_path.exists():
                    cache_path.unlink()
                raise

    async def enforce_max_entries(self, user_id: int, db: "Session") -> int:
        """Evict a user's caches beyond max_entries; return the count deleted.

        Assumes list_cache_entries returns most-recently-used entries first —
        TODO confirm the ordering implemented in db.crud.
        """
        try:
            all_caches = list_cache_entries(db, user_id, skip=0, limit=9999)
            if len(all_caches) <= self.max_entries:
                return 0
            caches_to_delete = all_caches[self.max_entries:]
            deleted_count = 0
            for cache in caches_to_delete:
                cache_file = Path(cache.cache_path)
                if cache_file.exists():
                    cache_file.unlink()
                delete_cache_entry(db, cache.id, user_id)
                deleted_count += 1
            if deleted_count > 0:
                logger.info(f"LRU eviction: user={user_id}, deleted={deleted_count} entries")
            return deleted_count
        except Exception as e:
            logger.error(f"LRU enforcement error: {e}", exc_info=True)
            return 0

    async def cleanup_expired(self, db: "Session") -> int:
        """Delete caches not accessed within ttl_days; return the count deleted."""
        try:
            cutoff_date = datetime.utcnow() - timedelta(days=self.ttl_days)
            expired_caches = db.query(VoiceCache).filter(
                VoiceCache.last_accessed < cutoff_date
            ).all()
            deleted_count = 0
            for cache in expired_caches:
                cache_file = Path(cache.cache_path)
                if cache_file.exists():
                    cache_file.unlink()
                db.delete(cache)
                deleted_count += 1
            if deleted_count > 0:
                db.commit()
                logger.info(f"Expired cache cleanup: deleted={deleted_count} entries")
            return deleted_count
        except Exception as e:
            logger.error(f"Expired cache cleanup error: {e}", exc_info=True)
            db.rollback()
            return 0

View File

@@ -0,0 +1,166 @@
import logging
from datetime import datetime, timedelta
from pathlib import Path
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from core.config import settings
from core.cache_manager import VoiceCacheManager
from db.models import Job
logger = logging.getLogger(__name__)
async def cleanup_expired_caches(db_url: str) -> dict:
    """Delete TTL-expired voice caches; return a summary dict.

    Never raises: failures are logged and reported through an 'error' key.
    """
    try:
        engine = create_engine(db_url)
        SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
        db = SessionLocal()
        try:
            cache_manager = await VoiceCacheManager.get_instance()
            deleted_count = await cache_manager.cleanup_expired(db)
        finally:
            # Close the session even when cleanup fails, so connections aren't leaked.
            db.close()
        freed_space_mb = 0  # NOTE(review): freed space is not measured yet
        logger.info(f"Cleanup: deleted {deleted_count} expired caches")
        return {
            'deleted_count': deleted_count,
            'freed_space_mb': freed_space_mb
        }
    except Exception as e:
        logger.error(f"Expired cache cleanup failed: {e}", exc_info=True)
        return {
            'deleted_count': 0,
            'freed_space_mb': 0,
            'error': str(e)
        }
async def cleanup_old_jobs(db_url: str, days: int = 7) -> dict:
    """Delete completed/failed jobs older than `days` plus their output files.

    Never raises: failures are logged and reported through an 'error' key.
    """
    try:
        engine = create_engine(db_url)
        SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
        db = SessionLocal()
        try:
            cutoff_date = datetime.utcnow() - timedelta(days=days)
            old_jobs = db.query(Job).filter(
                Job.completed_at < cutoff_date,
                Job.status.in_(['completed', 'failed'])
            ).all()
            deleted_files = 0
            for job in old_jobs:
                # Remove the rendered output file first, then the DB row.
                if job.output_path:
                    output_file = Path(job.output_path)
                    if output_file.exists():
                        output_file.unlink()
                        deleted_files += 1
                db.delete(job)
            db.commit()
            deleted_jobs = len(old_jobs)
        finally:
            # Close the session even on failure, so connections aren't leaked.
            db.close()
        logger.info(f"Cleanup: deleted {deleted_jobs} old jobs, {deleted_files} files")
        return {
            'deleted_jobs': deleted_jobs,
            'deleted_files': deleted_files
        }
    except Exception as e:
        logger.error(f"Old job cleanup failed: {e}", exc_info=True)
        return {
            'deleted_jobs': 0,
            'deleted_files': 0,
            'error': str(e)
        }
async def cleanup_orphaned_files(db_url: str) -> dict:
    """Delete on-disk output/cache files that no database row references.

    Never raises: failures are logged and reported through an 'error' key.
    """
    try:
        engine = create_engine(db_url)
        SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
        db = SessionLocal()
        try:
            # Build the sets of filenames the DB still knows about.
            output_files_in_db = {
                Path(job.output_path).name
                for job in db.query(Job.output_path).filter(Job.output_path.isnot(None)).all()
            }
            from db.models import VoiceCache
            cache_files_in_db = {
                Path(cache.cache_path).name
                for cache in db.query(VoiceCache.cache_path).all()
            }
        finally:
            # Release the session as soon as the reference sets are built.
            db.close()
        output_dir = Path(settings.OUTPUT_DIR)
        cache_dir = Path(settings.CACHE_DIR)
        deleted_orphans = 0
        freed_space_bytes = 0
        # Same sweep for both directories: delete files the DB no longer references.
        for directory, pattern, known_names in (
            (output_dir, "*.wav", output_files_in_db),
            (cache_dir, "*.pkl", cache_files_in_db),
        ):
            if not directory.exists():
                continue
            for candidate in directory.glob(pattern):
                if candidate.name not in known_names:
                    freed_space_bytes += candidate.stat().st_size
                    candidate.unlink()
                    deleted_orphans += 1
        freed_space_mb = freed_space_bytes / (1024 * 1024)
        logger.info(f"Cleanup: deleted {deleted_orphans} orphaned files, freed {freed_space_mb:.2f} MB")
        return {
            'deleted_orphans': deleted_orphans,
            'freed_space_mb': freed_space_mb
        }
    except Exception as e:
        logger.error(f"Orphaned file cleanup failed: {e}", exc_info=True)
        return {
            'deleted_orphans': 0,
            'freed_space_mb': 0,
            'error': str(e)
        }
async def run_scheduled_cleanup(db_url: str) -> dict:
    """Run every cleanup pass in sequence and report a combined summary.

    Individual passes report their own problems inside their result dicts;
    only an unexpected failure here yields status 'failed'.
    """
    logger.info("Starting scheduled cleanup task...")
    try:
        cache_summary = await cleanup_expired_caches(db_url)
        jobs_summary = await cleanup_old_jobs(db_url)
        orphans_summary = await cleanup_orphaned_files(db_url)
        outcome = {
            'timestamp': datetime.utcnow().isoformat(),
            'expired_caches': cache_summary,
            'old_jobs': jobs_summary,
            'orphaned_files': orphans_summary,
            'status': 'completed'
        }
        logger.info(f"Scheduled cleanup completed: {outcome}")
        return outcome
    except Exception as e:
        logger.error(f"Scheduled cleanup failed: {e}", exc_info=True)
        return {
            'timestamp': datetime.utcnow().isoformat(),
            'status': 'failed',
            'error': str(e)
        }

View File

@@ -0,0 +1,3 @@
from config import settings, Settings
__all__ = ['settings', 'Settings']

View File

@@ -0,0 +1,3 @@
from db.database import Base, engine, SessionLocal, get_db, init_db
__all__ = ['Base', 'engine', 'SessionLocal', 'get_db', 'init_db']

View File

@@ -0,0 +1,156 @@
import time
import logging
from collections import deque, defaultdict
from typing import Dict, List, Optional, Any
from dataclasses import dataclass, field
import asyncio
import statistics
logger = logging.getLogger(__name__)
@dataclass
class RequestMetric:
    """One recorded request observation used for rolling latency metrics."""
    timestamp: float  # time.time() when the request was recorded
    endpoint: str  # route/endpoint identifier the request hit
    duration: float  # handler latency in seconds
    status_code: int  # response status; >= 400 is counted as an error
    queue_time: float = 0.0  # seconds spent queued before processing, if any
class MetricsCollector:
    """Singleton collecting rolling request/batch metrics for the service.

    Keeps the last `window_size` requests in a deque and derives latency and
    queue-time percentiles, per-endpoint counters, throughput, GPU memory and
    batch-processor stats on demand.
    """

    _instance: Optional['MetricsCollector'] = None
    _lock = asyncio.Lock()

    def __init__(self, window_size: int = 1000):
        """Set up empty counters and the rolling request window."""
        self.window_size = window_size
        self.requests: deque = deque(maxlen=window_size)  # rolling window of RequestMetric
        self.request_counts: Dict[str, int] = defaultdict(int)
        self.error_counts: Dict[str, int] = defaultdict(int)
        self.total_requests = 0
        self.start_time = time.time()
        self.batch_stats = {
            'total_batches': 0,
            'total_requests_batched': 0,
            'avg_batch_size': 0.0
        }
        self._lock_local = asyncio.Lock()  # guards all mutable state above
        logger.info("MetricsCollector initialized")

    @classmethod
    async def get_instance(cls) -> 'MetricsCollector':
        """Return the shared collector, creating it on first use (double-checked)."""
        if cls._instance is None:
            async with cls._lock:
                if cls._instance is None:
                    cls._instance = cls()
        return cls._instance

    @staticmethod
    def _summarize(samples: List[float], include_extremes: bool) -> Dict[str, float]:
        """Return p50/p95/p99/avg (plus min/max when requested) for non-empty samples."""
        ordered = sorted(samples)
        n = len(ordered)
        # Clamp the indices so int(n * 0.95/0.99) can never step past the last element.
        summary = {
            'p50': statistics.median(ordered),
            'p95': ordered[min(int(n * 0.95), n - 1)],
            'p99': ordered[min(int(n * 0.99), n - 1)],
            'avg': statistics.mean(ordered)
        }
        if include_extremes:
            summary['min'] = ordered[0]
            summary['max'] = ordered[-1]
        return summary

    async def record_request(
        self,
        endpoint: str,
        duration: float,
        status_code: int,
        queue_time: float = 0.0
    ):
        """Record one finished request; status_code >= 400 counts as an error."""
        async with self._lock_local:
            metric = RequestMetric(
                timestamp=time.time(),
                endpoint=endpoint,
                duration=duration,
                status_code=status_code,
                queue_time=queue_time
            )
            self.requests.append(metric)
            self.request_counts[endpoint] += 1
            self.total_requests += 1
            if status_code >= 400:
                self.error_counts[endpoint] += 1

    async def record_batch(self, batch_size: int):
        """Record one processed batch and refresh the running average batch size."""
        async with self._lock_local:
            self.batch_stats['total_batches'] += 1
            self.batch_stats['total_requests_batched'] += batch_size
            total_batches = self.batch_stats['total_batches']
            total_requests = self.batch_stats['total_requests_batched']
            self.batch_stats['avg_batch_size'] = total_requests / total_batches if total_batches > 0 else 0.0

    async def get_metrics(self) -> Dict[str, Any]:
        """Build a point-in-time snapshot of all collected metrics."""
        async with self._lock_local:
            current_time = time.time()
            uptime = current_time - self.start_time
            # Throughput derives from requests recorded within the last 60 seconds.
            recent_requests = [r for r in self.requests if current_time - r.timestamp < 60]
            durations = [r.duration for r in self.requests if r.duration > 0]
            queue_times = [r.queue_time for r in self.requests if r.queue_time > 0]
            percentiles = self._summarize(durations, include_extremes=True) if durations else {}
            queue_percentiles = self._summarize(queue_times, include_extremes=False) if queue_times else {}
            requests_per_second = len(recent_requests) / 60.0 if recent_requests else 0.0
            import torch  # local import keeps this module importable on torch-less hosts
            if torch.cuda.is_available():
                gpu_stats = {
                    'gpu_available': True,
                    'gpu_memory_allocated_mb': torch.cuda.memory_allocated(0) / 1024**2,
                    'gpu_memory_reserved_mb': torch.cuda.memory_reserved(0) / 1024**2,
                    'gpu_memory_total_mb': torch.cuda.get_device_properties(0).total_memory / 1024**2
                }
            else:
                gpu_stats = {'gpu_available': False}
            from core.batch_processor import BatchProcessor
            # NOTE(review): awaits the batch-processor singleton (creating it if
            # absent) while holding our lock; its get_stats() takes only its own
            # queue lock, so no lock-order cycle today — revisit if that changes.
            batch_processor = await BatchProcessor.get_instance()
            batch_stats_current = await batch_processor.get_stats()
            return {
                'uptime_seconds': uptime,
                'total_requests': self.total_requests,
                'requests_per_second': requests_per_second,
                'request_counts_by_endpoint': dict(self.request_counts),
                'error_counts_by_endpoint': dict(self.error_counts),
                'latency': percentiles,
                'queue_time': queue_percentiles,
                'batch_processing': {
                    **self.batch_stats,
                    **batch_stats_current
                },
                'gpu': gpu_stats
            }

    async def reset(self):
        """Clear all counters and restart the uptime clock."""
        async with self._lock_local:
            self.requests.clear()
            self.request_counts.clear()
            self.error_counts.clear()
            self.total_requests = 0
            self.start_time = time.time()
            self.batch_stats = {
                'total_batches': 0,
                'total_requests_batched': 0,
                'avg_batch_size': 0.0
            }
            logger.info("Metrics reset")

View File

@@ -0,0 +1,123 @@
import asyncio
import logging
from typing import Optional
import torch
from qwen_tts import Qwen3TTSModel
from core.config import settings
logger = logging.getLogger(__name__)
class ModelManager:
    """Singleton that keeps exactly one Qwen3-TTS model variant loaded at a time.

    Switching models unloads the current one first and clears the CUDA cache.
    """

    _instance: Optional['ModelManager'] = None
    _lock = asyncio.Lock()

    # Logical model names mapped to their directory / HuggingFace repo names.
    MODEL_PATHS = {
        "custom-voice": "Qwen3-TTS-12Hz-1.7B-CustomVoice",
        "voice-design": "Qwen3-TTS-12Hz-1.7B-VoiceDesign",
        "base": "Qwen3-TTS-12Hz-1.7B-Base"
    }

    def __init__(self):
        # Enforce singleton construction through get_instance().
        if ModelManager._instance is not None:
            raise RuntimeError("Use get_instance() to get ModelManager")
        self.current_model_name: Optional[str] = None
        self.tts: Optional["Qwen3TTSModel"] = None

    @classmethod
    async def get_instance(cls) -> 'ModelManager':
        """Return the shared manager, creating it on first use (double-checked)."""
        if cls._instance is None:
            async with cls._lock:
                if cls._instance is None:
                    cls._instance = cls()
        return cls._instance

    async def load_model(self, model_name: str) -> None:
        """Load `model_name`, unloading any currently loaded model first.

        Prefers a local checkout under settings.MODEL_BASE_PATH, otherwise the
        HuggingFace hub id. Raises ValueError for unknown names; loader
        failures are re-raised after resetting to a clean "nothing loaded" state.
        """
        if model_name not in self.MODEL_PATHS:
            raise ValueError(
                f"Unknown model: {model_name}. "
                f"Available models: {list(self.MODEL_PATHS.keys())}"
            )
        if self.current_model_name == model_name and self.tts is not None:
            logger.info(f"Model {model_name} already loaded")
            return
        async with self._lock:
            # Re-check under the lock: another coroutine may have finished
            # loading the same model while we were waiting.
            if self.current_model_name == model_name and self.tts is not None:
                logger.info(f"Model {model_name} already loaded")
                return
            logger.info(f"Loading model: {model_name}")
            if self.tts is not None:
                logger.info(f"Unloading current model: {self.current_model_name}")
                await self._unload_model_internal()
            from pathlib import Path
            model_base_path = Path(settings.MODEL_BASE_PATH)
            local_model_path = model_base_path / self.MODEL_PATHS[model_name]
            if local_model_path.exists():
                model_path = str(local_model_path)
                logger.info(f"Using local model: {model_path}")
            else:
                model_path = f"Qwen/{self.MODEL_PATHS[model_name]}"
                logger.info(f"Local path not found, using HuggingFace: {model_path}")
            try:
                self.tts = Qwen3TTSModel.from_pretrained(
                    str(model_path),
                    device_map=settings.MODEL_DEVICE,
                    torch_dtype=torch.bfloat16
                )
                self.current_model_name = model_name
                logger.info(f"Successfully loaded model: {model_name}")
                if torch.cuda.is_available():
                    allocated = torch.cuda.memory_allocated(0) / 1024**3
                    logger.info(f"GPU memory allocated: {allocated:.2f} GB")
            except Exception as e:
                logger.error(f"Failed to load model {model_name}: {e}")
                # Leave the manager in a consistent "nothing loaded" state.
                self.tts = None
                self.current_model_name = None
                raise

    async def get_current_model(self) -> "tuple[Optional[str], Optional[Qwen3TTSModel]]":
        """Return (name, model) for the loaded model, or (None, None)."""
        return self.current_model_name, self.tts

    async def unload_model(self) -> None:
        """Unload the current model (no-op when nothing is loaded)."""
        async with self._lock:
            await self._unload_model_internal()

    async def _unload_model_internal(self) -> None:
        # Caller must hold self._lock.
        if self.tts is not None:
            logger.info(f"Unloading model: {self.current_model_name}")
            del self.tts
            self.tts = None
            self.current_model_name = None
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                logger.info("Cleared CUDA cache")

    async def get_memory_usage(self) -> dict:
        """Report GPU availability and memory usage in GB."""
        memory_info = {
            "gpu_available": torch.cuda.is_available(),
            "current_model": self.current_model_name
        }
        if torch.cuda.is_available():
            memory_info.update({
                "allocated_gb": torch.cuda.memory_allocated(0) / 1024**3,
                "reserved_gb": torch.cuda.memory_reserved(0) / 1024**3,
                "total_gb": torch.cuda.get_device_properties(0).total_memory / 1024**3
            })
        return memory_info

    def get_model_info(self) -> dict:
        """Map each known model name to its path and whether it is loaded."""
        return {
            name: {
                "path": path,
                "loaded": name == self.current_model_name
            }
            for name, path in self.MODEL_PATHS.items()
        }

View File

@@ -0,0 +1,35 @@
from datetime import datetime, timedelta, timezone
from typing import Optional

from jose import JWTError, jwt
from passlib.context import CryptContext

from config import settings
# Shared bcrypt hashing context used by the password helpers below.
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
def get_password_hash(password: str) -> str:
    """Hash *password* with the module-wide bcrypt context for storage."""
    digest = pwd_context.hash(password)
    return digest
def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Check whether *plain_password* matches the stored bcrypt *hashed_password*."""
    matches = pwd_context.verify(plain_password, hashed_password)
    return matches
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
    """Create a signed JWT access token.

    Args:
        data: Claims to embed (callers put the username under "sub").
        expires_delta: Optional custom lifetime; defaults to
            settings.ACCESS_TOKEN_EXPIRE_MINUTES.

    Returns:
        The encoded JWT string.
    """
    to_encode = data.copy()
    # Use timezone-aware UTC: naive datetime.utcnow() is deprecated and easy
    # to misinterpret as local time.
    now = datetime.now(timezone.utc)
    if expires_delta:
        expire = now + expires_delta
    else:
        expire = now + timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
    to_encode.update({"exp": expire})
    encoded_jwt = jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
    return encoded_jwt
def decode_access_token(token: str) -> Optional[str]:
    """Validate *token* and return its "sub" (username) claim.

    Returns None when the token is invalid/expired or carries no "sub" claim.
    """
    try:
        # Keep the try body minimal: only jwt.decode can raise JWTError.
        payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM])
    except JWTError:
        return None
    # "sub" may be absent, so the value is Optional — the original str
    # annotation here was misleading.
    username: Optional[str] = payload.get("sub")
    return username