0
qwen3-tts-backend/core/__init__.py
Normal file
0
qwen3-tts-backend/core/__init__.py
Normal file
141
qwen3-tts-backend/core/batch_processor.py
Normal file
141
qwen3-tts-backend/core/batch_processor.py
Normal file
@@ -0,0 +1,141 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple
|
||||
from dataclasses import dataclass
|
||||
from collections import deque
|
||||
|
||||
from core.config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
class BatchRequest:
    """A single queued request awaiting batch processing."""

    # Caller-supplied identifier, used only for logging/tracing.
    request_id: str
    # Keyword arguments forwarded to the processing function.
    data: Dict[str, Any]
    # Resolved with the result (or exception) when the batch executes.
    future: asyncio.Future
    # Enqueue time (time.time()); used to compute queue wait duration.
    timestamp: float
|
||||
|
||||
|
||||
class BatchProcessor:
    """Micro-batching request queue (process-wide async singleton).

    Requests submitted through ``submit`` are parked on an internal deque.
    A background task flushes them when either ``batch_size`` requests are
    queued or the oldest request has waited ``batch_wait_time`` seconds.
    Subclasses provide the actual work in ``_execute_single_request``.
    """

    _instance: Optional['BatchProcessor'] = None
    _lock = asyncio.Lock()  # guards lazy singleton construction

    def __init__(self, batch_size: Optional[int] = None, batch_wait_time: Optional[float] = None):
        """Initialize queue state.

        Args:
            batch_size: Max requests per batch (default: settings.BATCH_SIZE).
            batch_wait_time: Max seconds the oldest request may wait before a
                partial batch is flushed (default: settings.BATCH_WAIT_TIME).
        """
        self.batch_size = batch_size or settings.BATCH_SIZE
        self.batch_wait_time = batch_wait_time or settings.BATCH_WAIT_TIME
        self.queue: deque = deque()
        self.queue_lock = asyncio.Lock()  # guards all access to self.queue
        self.processing = False
        self._processor_task: Optional[asyncio.Task] = None
        logger.info(f"BatchProcessor initialized with batch_size={self.batch_size}, wait_time={self.batch_wait_time}s")

    @classmethod
    async def get_instance(cls) -> 'BatchProcessor':
        """Return the singleton, creating it (and its background task) lazily."""
        if cls._instance is None:
            async with cls._lock:
                # Double-checked: another coroutine may have won the race.
                if cls._instance is None:
                    cls._instance = cls()
                    cls._instance._start_processor()
        return cls._instance

    def _start_processor(self):
        """Start the background draining task if it is not already running."""
        if not self._processor_task or self._processor_task.done():
            self._processor_task = asyncio.create_task(self._process_batches())
            logger.info("Batch processor task started")

    async def _process_batches(self):
        """Poll the queue forever and dispatch batches when thresholds are met."""
        logger.info("Batch processing loop started")
        while True:
            try:
                # Polling interval; bounds the extra latency added per request.
                await asyncio.sleep(0.1)

                async with self.queue_lock:
                    if not self.queue:
                        continue

                    current_time = time.time()
                    oldest_request = self.queue[0]
                    wait_duration = current_time - oldest_request.timestamp

                    # Flush on queue size OR age of the oldest queued request.
                    should_process = (
                        len(self.queue) >= self.batch_size or
                        wait_duration >= self.batch_wait_time
                    )

                    if should_process:
                        batch = []
                        for _ in range(min(self.batch_size, len(self.queue))):
                            if self.queue:
                                batch.append(self.queue.popleft())

                        if batch:
                            logger.info(f"Processing batch of {len(batch)} requests (queue_wait={wait_duration:.3f}s)")
                            # Fire-and-forget so the loop keeps draining while
                            # the batch executes.
                            asyncio.create_task(self._process_batch(batch))

            except Exception as e:
                logger.error(f"Error in batch processor loop: {e}", exc_info=True)
                await asyncio.sleep(1)

    async def _process_batch(self, batch: List[BatchRequest]):
        """Execute each request in the batch sequentially, resolving its future."""
        for request in batch:
            try:
                if request.future.done():
                    # Caller already timed out or cancelled; skip the work.
                    continue
                result = await self._execute_single_request(request.data)
                # Fix: re-check after the await — the future may have been
                # cancelled (wait_for timeout) while the request executed;
                # set_result on a done future raises InvalidStateError.
                if not request.future.done():
                    request.future.set_result(result)
            except Exception as e:
                logger.error(f"Error processing request {request.request_id}: {e}", exc_info=True)
                if not request.future.done():
                    request.future.set_exception(e)

    async def _execute_single_request(self, data: Dict[str, Any]) -> Any:
        """Process one request's payload; must be overridden by subclasses."""
        raise NotImplementedError("Subclass must implement _execute_single_request")

    async def submit(self, request_id: str, data: Dict[str, Any], timeout: float = 300) -> Any:
        """Queue a request and await its result.

        Args:
            request_id: Identifier used for logging.
            data: Keyword arguments for ``_execute_single_request``.
            timeout: Seconds to wait before giving up.

        Raises:
            TimeoutError: If no result arrives within ``timeout`` seconds.
        """
        # Fix: bind the future to the running loop explicitly; bare
        # asyncio.Future() relies on the deprecated implicit event loop.
        future = asyncio.get_running_loop().create_future()
        request = BatchRequest(
            request_id=request_id,
            data=data,
            future=future,
            timestamp=time.time()
        )

        async with self.queue_lock:
            self.queue.append(request)
            queue_size = len(self.queue)

        logger.debug(f"Request {request_id} queued (queue_size={queue_size})")

        try:
            result = await asyncio.wait_for(future, timeout=timeout)
            return result
        except asyncio.TimeoutError:
            logger.error(f"Request {request_id} timed out after {timeout}s")
            # Best effort: drop the request if it is still queued so a worker
            # does not process it pointlessly.
            async with self.queue_lock:
                if request in self.queue:
                    self.queue.remove(request)
            raise TimeoutError(f"Request timed out after {timeout}s")

    async def get_queue_length(self) -> int:
        """Current number of queued (not yet dispatched) requests."""
        async with self.queue_lock:
            return len(self.queue)

    async def get_stats(self) -> Dict[str, Any]:
        """Snapshot of queue/config state for monitoring endpoints."""
        queue_length = await self.get_queue_length()
        return {
            "queue_length": queue_length,
            "batch_size": self.batch_size,
            "batch_wait_time": self.batch_wait_time,
            "processor_running": self._processor_task is not None and not self._processor_task.done()
        }
|
||||
|
||||
|
||||
class TTSBatchProcessor(BatchProcessor):
    """BatchProcessor whose per-request work is an injected coroutine.

    Args:
        process_func: Async callable invoked as ``process_func(**data)`` for
            each queued request.
    """

    def __init__(self, process_func: Callable, batch_size: Optional[int] = None, batch_wait_time: Optional[float] = None):
        # Fix: batch_size/batch_wait_time defaulted to None, so annotate them
        # Optional to match (plain ``int = None`` is a typing error).
        super().__init__(batch_size, batch_wait_time)
        self.process_func = process_func

    async def _execute_single_request(self, data: Dict[str, Any]) -> Any:
        # Fan the queued kwargs out to the injected coroutine.
        return await self.process_func(**data)
|
||||
161
qwen3-tts-backend/core/cache_manager.py
Normal file
161
qwen3-tts-backend/core/cache_manager.py
Normal file
@@ -0,0 +1,161 @@
|
||||
import hashlib
|
||||
import pickle
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
from typing import Optional, Dict, Any
|
||||
from datetime import datetime, timedelta
|
||||
import logging
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
from db.crud import (
|
||||
create_cache_entry,
|
||||
get_cache_entry,
|
||||
list_cache_entries,
|
||||
delete_cache_entry
|
||||
)
|
||||
from db.models import VoiceCache
|
||||
from core.config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class VoiceCacheManager:
    """Disk-backed cache of processed reference-voice data (singleton).

    Entries are pickled into ``cache_dir`` and indexed in the database via
    db.crud; eviction is entry-count-bounded per user and TTL-based globally.
    """

    _instance = None
    _lock = asyncio.Lock()  # guards singleton creation and cache writes

    def __init__(self, cache_dir: str, max_entries: int, ttl_days: int):
        """
        Args:
            cache_dir: Directory for pickle files (created if missing).
            max_entries: Per-user cap enforced by ``enforce_max_entries``.
            ttl_days: Age (by last access) after which entries expire.
        """
        self.cache_dir = Path(cache_dir)
        self.max_entries = max_entries
        self.ttl_days = ttl_days
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        logger.info(f"VoiceCacheManager initialized: dir={cache_dir}, max={max_entries}, ttl={ttl_days}d")

    @classmethod
    async def get_instance(cls) -> 'VoiceCacheManager':
        """Return the singleton, built from settings on first use."""
        if cls._instance is None:
            async with cls._lock:
                if cls._instance is None:
                    cls._instance = VoiceCacheManager(
                        cache_dir=settings.CACHE_DIR,
                        max_entries=settings.MAX_CACHE_ENTRIES,
                        ttl_days=settings.CACHE_TTL_DAYS
                    )
        return cls._instance

    def get_audio_hash(self, audio_data: bytes) -> str:
        """SHA-256 hex digest of the raw audio bytes (cache key component)."""
        return hashlib.sha256(audio_data).hexdigest()

    async def get_cache(self, user_id: int, ref_audio_hash: str, db: Session) -> Optional[Dict[str, Any]]:
        """Look up a cached entry; returns None on miss or any error.

        Self-heals: a DB row whose backing file is gone is deleted.
        """
        try:
            cache_entry = get_cache_entry(db, user_id, ref_audio_hash)
            if not cache_entry:
                logger.debug(f"Cache miss: user={user_id}, hash={ref_audio_hash[:8]}...")
                return None

            cache_file = Path(cache_entry.cache_path)
            if not cache_file.exists():
                logger.warning(f"Cache file missing: {cache_file}")
                delete_cache_entry(db, cache_entry.id, user_id)
                return None

            # SECURITY NOTE: pickle.load executes arbitrary code if the file
            # is attacker-controlled. Files here are written only by
            # set_cache; keep cache_dir non-world-writable.
            with open(cache_file, 'rb') as f:
                cache_data = pickle.load(f)

            logger.info(f"Cache hit: user={user_id}, hash={ref_audio_hash[:8]}..., access_count={cache_entry.access_count}")
            return {
                'cache_id': cache_entry.id,
                'data': cache_data,
                'metadata': cache_entry.meta_data
            }

        except Exception as e:
            logger.error(f"Cache retrieval error: {e}", exc_info=True)
            return None

    async def set_cache(
        self,
        user_id: int,
        ref_audio_hash: str,
        cache_data: Any,
        metadata: Dict[str, Any],
        db: Session
    ) -> str:
        """Persist ``cache_data`` to disk and register it in the DB.

        Returns the new cache entry id; re-raises on failure after removing
        any partially written file.
        """
        async with self._lock:
            # Fix: pre-initialize so the except-branch cannot hit
            # UnboundLocalError when the failure happens before the path is
            # computed (which would mask the original exception).
            cache_path: Optional[Path] = None
            try:
                cache_filename = f"{user_id}_{ref_audio_hash}.pkl"
                cache_path = self.cache_dir / cache_filename

                with open(cache_path, 'wb') as f:
                    pickle.dump(cache_data, f)

                cache_entry = create_cache_entry(
                    db=db,
                    user_id=user_id,
                    ref_audio_hash=ref_audio_hash,
                    cache_path=str(cache_path),
                    meta_data=metadata
                )

                # Keep this user within the configured entry cap.
                await self.enforce_max_entries(user_id, db)

                logger.info(f"Cache created: user={user_id}, hash={ref_audio_hash[:8]}..., id={cache_entry.id}")
                return cache_entry.id

            except Exception as e:
                logger.error(f"Cache creation error: {e}", exc_info=True)
                if cache_path is not None and cache_path.exists():
                    cache_path.unlink()
                raise

    async def enforce_max_entries(self, user_id: int, db: Session) -> int:
        """Evict a user's surplus entries beyond ``max_entries``; returns count.

        NOTE(review): assumes list_cache_entries returns most-recently-used
        first, so the tail slice is the LRU portion — confirm against db.crud.
        """
        try:
            all_caches = list_cache_entries(db, user_id, skip=0, limit=9999)
            if len(all_caches) <= self.max_entries:
                return 0

            caches_to_delete = all_caches[self.max_entries:]
            deleted_count = 0

            for cache in caches_to_delete:
                cache_file = Path(cache.cache_path)
                if cache_file.exists():
                    cache_file.unlink()

                delete_cache_entry(db, cache.id, user_id)
                deleted_count += 1

            if deleted_count > 0:
                logger.info(f"LRU eviction: user={user_id}, deleted={deleted_count} entries")

            return deleted_count

        except Exception as e:
            logger.error(f"LRU enforcement error: {e}", exc_info=True)
            return 0

    async def cleanup_expired(self, db: Session) -> int:
        """Delete all entries not accessed within ``ttl_days``; returns count."""
        try:
            cutoff_date = datetime.utcnow() - timedelta(days=self.ttl_days)
            expired_caches = db.query(VoiceCache).filter(
                VoiceCache.last_accessed < cutoff_date
            ).all()

            deleted_count = 0
            for cache in expired_caches:
                cache_file = Path(cache.cache_path)
                if cache_file.exists():
                    cache_file.unlink()

                db.delete(cache)
                deleted_count += 1

            if deleted_count > 0:
                db.commit()
                logger.info(f"Expired cache cleanup: deleted={deleted_count} entries")

            return deleted_count

        except Exception as e:
            logger.error(f"Expired cache cleanup error: {e}", exc_info=True)
            db.rollback()
            return 0
|
||||
166
qwen3-tts-backend/core/cleanup.py
Normal file
166
qwen3-tts-backend/core/cleanup.py
Normal file
@@ -0,0 +1,166 @@
|
||||
import logging
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from core.config import settings
|
||||
from core.cache_manager import VoiceCacheManager
|
||||
from db.models import Job
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def cleanup_expired_caches(db_url: str) -> dict:
    """Delete TTL-expired voice caches; returns a result dict (never raises).

    Args:
        db_url: SQLAlchemy database URL; a fresh engine/session is created
            and torn down per call.

    Returns:
        dict with 'deleted_count' and 'freed_space_mb' (plus 'error' on failure).
    """
    engine = None
    db = None
    try:
        engine = create_engine(db_url)
        SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
        db = SessionLocal()

        cache_manager = await VoiceCacheManager.get_instance()
        deleted_count = await cache_manager.cleanup_expired(db)

        # NOTE(review): freed disk space is not measured here; reported as 0.
        freed_space_mb = 0

        logger.info(f"Cleanup: deleted {deleted_count} expired caches")

        return {
            'deleted_count': deleted_count,
            'freed_space_mb': freed_space_mb
        }

    except Exception as e:
        logger.error(f"Expired cache cleanup failed: {e}", exc_info=True)
        return {
            'deleted_count': 0,
            'freed_space_mb': 0,
            'error': str(e)
        }
    finally:
        # Fix: the original leaked the session (and engine pool) whenever an
        # exception was raised before db.close().
        if db is not None:
            db.close()
        if engine is not None:
            engine.dispose()
|
||||
|
||||
|
||||
async def cleanup_old_jobs(db_url: str, days: int = 7) -> dict:
    """Delete completed/failed jobs older than ``days`` and their output files.

    Args:
        db_url: SQLAlchemy database URL.
        days: Age threshold (by completed_at) for deletion.

    Returns:
        dict with 'deleted_jobs' and 'deleted_files' (plus 'error' on failure).
    """
    engine = None
    db = None
    try:
        engine = create_engine(db_url)
        SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
        db = SessionLocal()

        cutoff_date = datetime.utcnow() - timedelta(days=days)

        # Only terminal jobs are eligible; running/pending jobs are kept.
        old_jobs = db.query(Job).filter(
            Job.completed_at < cutoff_date,
            Job.status.in_(['completed', 'failed'])
        ).all()

        deleted_files = 0
        for job in old_jobs:
            if job.output_path:
                output_file = Path(job.output_path)
                if output_file.exists():
                    output_file.unlink()
                    deleted_files += 1

            db.delete(job)

        db.commit()
        deleted_jobs = len(old_jobs)

        logger.info(f"Cleanup: deleted {deleted_jobs} old jobs, {deleted_files} files")

        return {
            'deleted_jobs': deleted_jobs,
            'deleted_files': deleted_files
        }

    except Exception as e:
        logger.error(f"Old job cleanup failed: {e}", exc_info=True)
        return {
            'deleted_jobs': 0,
            'deleted_files': 0,
            'error': str(e)
        }
    finally:
        # Fix: close the session/engine even when the body raises.
        if db is not None:
            db.close()
        if engine is not None:
            engine.dispose()
|
||||
|
||||
|
||||
async def cleanup_orphaned_files(db_url: str) -> dict:
    """Delete .wav/.pkl files on disk that no database row references.

    Args:
        db_url: SQLAlchemy database URL.

    Returns:
        dict with 'deleted_orphans' and 'freed_space_mb' (plus 'error' on failure).
    """
    engine = None
    db = None
    try:
        engine = create_engine(db_url)
        SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
        db = SessionLocal()

        output_dir = Path(settings.OUTPUT_DIR)
        cache_dir = Path(settings.CACHE_DIR)

        # Compare by filename only — DB rows may store absolute paths.
        output_files_in_db = {Path(job.output_path).name for job in db.query(Job.output_path).filter(Job.output_path.isnot(None)).all()}

        from db.models import VoiceCache
        cache_files_in_db = {Path(cache.cache_path).name for cache in db.query(VoiceCache.cache_path).all()}

        deleted_orphans = 0
        freed_space_bytes = 0

        if output_dir.exists():
            for output_file in output_dir.glob("*.wav"):
                if output_file.name not in output_files_in_db:
                    size = output_file.stat().st_size
                    output_file.unlink()
                    deleted_orphans += 1
                    freed_space_bytes += size

        if cache_dir.exists():
            for cache_file in cache_dir.glob("*.pkl"):
                if cache_file.name not in cache_files_in_db:
                    size = cache_file.stat().st_size
                    cache_file.unlink()
                    deleted_orphans += 1
                    freed_space_bytes += size

        freed_space_mb = freed_space_bytes / (1024 * 1024)

        logger.info(f"Cleanup: deleted {deleted_orphans} orphaned files, freed {freed_space_mb:.2f} MB")

        return {
            'deleted_orphans': deleted_orphans,
            'freed_space_mb': freed_space_mb
        }

    except Exception as e:
        logger.error(f"Orphaned file cleanup failed: {e}", exc_info=True)
        return {
            'deleted_orphans': 0,
            'freed_space_mb': 0,
            'error': str(e)
        }
    finally:
        # Fix: close the session/engine even when the body raises.
        if db is not None:
            db.close()
        if engine is not None:
            engine.dispose()
|
||||
|
||||
|
||||
async def run_scheduled_cleanup(db_url: str) -> dict:
    """Run every cleanup pass in sequence and summarize the outcome.

    Args:
        db_url: SQLAlchemy database URL handed to each cleanup pass.

    Returns:
        dict with a timestamp, one sub-result per pass, and a status flag.
    """
    logger.info("Starting scheduled cleanup task...")

    try:
        # Passes run sequentially; each swallows its own errors internally.
        pass_results = {
            'expired_caches': await cleanup_expired_caches(db_url),
            'old_jobs': await cleanup_old_jobs(db_url),
            'orphaned_files': await cleanup_orphaned_files(db_url),
        }

        result = {
            'timestamp': datetime.utcnow().isoformat(),
            **pass_results,
            'status': 'completed'
        }

        logger.info(f"Scheduled cleanup completed: {result}")

        return result

    except Exception as e:
        logger.error(f"Scheduled cleanup failed: {e}", exc_info=True)
        return {
            'timestamp': datetime.utcnow().isoformat(),
            'status': 'failed',
            'error': str(e)
        }
|
||||
3
qwen3-tts-backend/core/config.py
Normal file
3
qwen3-tts-backend/core/config.py
Normal file
@@ -0,0 +1,3 @@
|
||||
# Compatibility shim: re-export the top-level ``config`` module's objects
# under the ``core.config`` import path used by the rest of the package.
from config import settings, Settings

# Explicit public API of this shim module.
__all__ = ['settings', 'Settings']
|
||||
3
qwen3-tts-backend/core/database.py
Normal file
3
qwen3-tts-backend/core/database.py
Normal file
@@ -0,0 +1,3 @@
|
||||
# Compatibility shim: re-export the database plumbing from ``db.database``
# under the ``core.database`` import path.
from db.database import Base, engine, SessionLocal, get_db, init_db

# Explicit public API of this shim module.
__all__ = ['Base', 'engine', 'SessionLocal', 'get_db', 'init_db']
|
||||
156
qwen3-tts-backend/core/metrics.py
Normal file
156
qwen3-tts-backend/core/metrics.py
Normal file
@@ -0,0 +1,156 @@
|
||||
import time
|
||||
import logging
|
||||
from collections import deque, defaultdict
|
||||
from typing import Dict, List, Optional, Any
|
||||
from dataclasses import dataclass, field
|
||||
import asyncio
|
||||
import statistics
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
class RequestMetric:
    """Point-in-time record of one handled request."""

    # time.time() at which the request was recorded.
    timestamp: float
    # Endpoint name/path the request hit.
    endpoint: str
    # Total handling duration in seconds.
    duration: float
    # HTTP status code returned; >= 400 is counted as an error.
    status_code: int
    # Seconds spent queued before processing (0.0 if not queued).
    queue_time: float = 0.0
|
||||
|
||||
|
||||
class MetricsCollector:
    """Rolling in-memory request metrics (process-wide async singleton).

    Keeps the last ``window_size`` requests plus lifetime counters; exposes
    latency/queue-time percentiles, batch stats, and GPU memory figures.
    """

    _instance: Optional['MetricsCollector'] = None
    _lock = asyncio.Lock()  # guards singleton creation

    def __init__(self, window_size: int = 1000):
        """
        Args:
            window_size: Number of recent RequestMetric records retained for
                percentile computation (older ones fall off the deque).
        """
        self.window_size = window_size
        self.requests: deque = deque(maxlen=window_size)
        self.request_counts: Dict[str, int] = defaultdict(int)
        self.error_counts: Dict[str, int] = defaultdict(int)
        self.total_requests = 0
        self.start_time = time.time()
        self.batch_stats = {
            'total_batches': 0,
            'total_requests_batched': 0,
            'avg_batch_size': 0.0
        }
        self._lock_local = asyncio.Lock()  # guards all mutable state above
        logger.info("MetricsCollector initialized")

    @classmethod
    async def get_instance(cls) -> 'MetricsCollector':
        """Return the singleton, creating it lazily."""
        if cls._instance is None:
            async with cls._lock:
                if cls._instance is None:
                    cls._instance = cls()
        return cls._instance

    @staticmethod
    def _latency_summary(values: List[float], include_extremes: bool = False) -> Dict[str, float]:
        """Percentile summary (p50/p95/p99/avg, optionally min/max) of timings.

        ``values`` must be non-empty. Extracted to remove the duplicated
        percentile code for durations vs queue times.
        """
        ordered = sorted(values)
        summary = {
            'p50': statistics.median(ordered),
            'p95': ordered[int(len(ordered) * 0.95)],
            'p99': ordered[int(len(ordered) * 0.99)],
            'avg': statistics.mean(ordered),
        }
        if include_extremes:
            summary['min'] = ordered[0]
            summary['max'] = ordered[-1]
        return summary

    async def record_request(
        self,
        endpoint: str,
        duration: float,
        status_code: int,
        queue_time: float = 0.0
    ):
        """Record one handled request into the rolling window and counters."""
        async with self._lock_local:
            metric = RequestMetric(
                timestamp=time.time(),
                endpoint=endpoint,
                duration=duration,
                status_code=status_code,
                queue_time=queue_time
            )
            self.requests.append(metric)
            self.request_counts[endpoint] += 1
            self.total_requests += 1

            # 4xx/5xx responses count as errors for this endpoint.
            if status_code >= 400:
                self.error_counts[endpoint] += 1

    async def record_batch(self, batch_size: int):
        """Record one dispatched batch and refresh the running average size."""
        async with self._lock_local:
            self.batch_stats['total_batches'] += 1
            self.batch_stats['total_requests_batched'] += batch_size

            total_batches = self.batch_stats['total_batches']
            total_requests = self.batch_stats['total_requests_batched']
            self.batch_stats['avg_batch_size'] = total_requests / total_batches if total_batches > 0 else 0.0

    async def get_metrics(self) -> Dict[str, Any]:
        """Aggregate snapshot of request, batch, and GPU metrics."""
        async with self._lock_local:
            current_time = time.time()
            uptime = current_time - self.start_time

            # Throughput is computed over the trailing 60 seconds only.
            recent_requests = [r for r in self.requests if current_time - r.timestamp < 60]

            durations = [r.duration for r in self.requests if r.duration > 0]
            queue_times = [r.queue_time for r in self.requests if r.queue_time > 0]

            percentiles = self._latency_summary(durations, include_extremes=True) if durations else {}
            queue_percentiles = self._latency_summary(queue_times) if queue_times else {}

            requests_per_second = len(recent_requests) / 60.0 if recent_requests else 0.0

            # Imported lazily so the collector does not hard-require torch
            # until this endpoint is actually hit.
            import torch
            if torch.cuda.is_available():
                gpu_stats = {
                    'gpu_available': True,
                    'gpu_memory_allocated_mb': torch.cuda.memory_allocated(0) / 1024**2,
                    'gpu_memory_reserved_mb': torch.cuda.memory_reserved(0) / 1024**2,
                    'gpu_memory_total_mb': torch.cuda.get_device_properties(0).total_memory / 1024**2
                }
            else:
                gpu_stats = {'gpu_available': False}

            # Local import avoids a circular import at module load time.
            from core.batch_processor import BatchProcessor
            batch_processor = await BatchProcessor.get_instance()
            batch_stats_current = await batch_processor.get_stats()

            return {
                'uptime_seconds': uptime,
                'total_requests': self.total_requests,
                'requests_per_second': requests_per_second,
                'request_counts_by_endpoint': dict(self.request_counts),
                'error_counts_by_endpoint': dict(self.error_counts),
                'latency': percentiles,
                'queue_time': queue_percentiles,
                'batch_processing': {
                    **self.batch_stats,
                    **batch_stats_current
                },
                'gpu': gpu_stats
            }

    async def reset(self):
        """Clear the window and all counters; restarts the uptime clock."""
        async with self._lock_local:
            self.requests.clear()
            self.request_counts.clear()
            self.error_counts.clear()
            self.total_requests = 0
            self.start_time = time.time()
            self.batch_stats = {
                'total_batches': 0,
                'total_requests_batched': 0,
                'avg_batch_size': 0.0
            }
            logger.info("Metrics reset")
|
||||
123
qwen3-tts-backend/core/model_manager.py
Normal file
123
qwen3-tts-backend/core/model_manager.py
Normal file
@@ -0,0 +1,123 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Optional
|
||||
import torch
|
||||
from qwen_tts import Qwen3TTSModel
|
||||
from core.config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ModelManager:
    """Singleton holding at most one loaded Qwen3 TTS model.

    Only one model variant is resident at a time; loading a different one
    unloads the current model first to free GPU memory.
    """

    _instance: Optional['ModelManager'] = None
    _lock = asyncio.Lock()  # serializes load/unload operations

    # Logical model name -> repository/directory name.
    MODEL_PATHS = {
        "custom-voice": "Qwen3-TTS-12Hz-1.7B-CustomVoice",
        "voice-design": "Qwen3-TTS-12Hz-1.7B-VoiceDesign",
        "base": "Qwen3-TTS-12Hz-1.7B-Base"
    }

    def __init__(self):
        # Enforce singleton usage; construction happens via get_instance().
        if ModelManager._instance is not None:
            raise RuntimeError("Use get_instance() to get ModelManager")
        self.current_model_name: Optional[str] = None
        self.tts: Optional[Qwen3TTSModel] = None

    @classmethod
    async def get_instance(cls) -> 'ModelManager':
        """Return the singleton, creating it on first call."""
        if cls._instance is None:
            async with cls._lock:
                if cls._instance is None:
                    cls._instance = cls()
        return cls._instance

    async def load_model(self, model_name: str) -> None:
        """Load ``model_name``, replacing any currently loaded model.

        Raises:
            ValueError: If the name is not in MODEL_PATHS.
            Exception: Re-raised from the underlying loader on failure, with
                the manager left in an unloaded state.
        """
        if model_name not in self.MODEL_PATHS:
            raise ValueError(
                f"Unknown model: {model_name}. "
                f"Available models: {list(self.MODEL_PATHS.keys())}"
            )

        if self.current_model_name == model_name and self.tts is not None:
            logger.info(f"Model {model_name} already loaded")
            return

        async with self._lock:
            # Fix: re-check under the lock. Two coroutines racing to load the
            # same model previously both passed the unlocked check, causing a
            # needless unload + reload of an already-loaded model.
            if self.current_model_name == model_name and self.tts is not None:
                logger.info(f"Model {model_name} already loaded")
                return

            logger.info(f"Loading model: {model_name}")

            if self.tts is not None:
                logger.info(f"Unloading current model: {self.current_model_name}")
                await self._unload_model_internal()

            from pathlib import Path
            model_base_path = Path(settings.MODEL_BASE_PATH)
            local_model_path = model_base_path / self.MODEL_PATHS[model_name]

            # Prefer a locally downloaded copy; fall back to the HF hub id.
            if local_model_path.exists():
                model_path = str(local_model_path)
                logger.info(f"Using local model: {model_path}")
            else:
                model_path = f"Qwen/{self.MODEL_PATHS[model_name]}"
                logger.info(f"Local path not found, using HuggingFace: {model_path}")

            try:
                # NOTE(review): from_pretrained blocks the event loop for the
                # whole load; consider run_in_executor if loads are slow.
                self.tts = Qwen3TTSModel.from_pretrained(
                    str(model_path),
                    device_map=settings.MODEL_DEVICE,
                    torch_dtype=torch.bfloat16
                )
                self.current_model_name = model_name
                logger.info(f"Successfully loaded model: {model_name}")

                if torch.cuda.is_available():
                    allocated = torch.cuda.memory_allocated(0) / 1024**3
                    logger.info(f"GPU memory allocated: {allocated:.2f} GB")

            except Exception as e:
                logger.error(f"Failed to load model {model_name}: {e}")
                self.tts = None
                self.current_model_name = None
                raise

    async def get_current_model(self) -> tuple[Optional[str], Optional[Qwen3TTSModel]]:
        """Return (name, model) of the loaded model, or (None, None)."""
        return self.current_model_name, self.tts

    async def unload_model(self) -> None:
        """Unload the current model (no-op if none is loaded)."""
        async with self._lock:
            await self._unload_model_internal()

    async def _unload_model_internal(self) -> None:
        """Drop the model reference and clear the CUDA allocator cache.

        Caller must hold ``self._lock``.
        """
        if self.tts is not None:
            logger.info(f"Unloading model: {self.current_model_name}")
            del self.tts
            self.tts = None
            self.current_model_name = None

        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            logger.info("Cleared CUDA cache")

    async def get_memory_usage(self) -> dict:
        """GPU memory stats plus the currently loaded model name."""
        memory_info = {
            "gpu_available": torch.cuda.is_available(),
            "current_model": self.current_model_name
        }

        if torch.cuda.is_available():
            memory_info.update({
                "allocated_gb": torch.cuda.memory_allocated(0) / 1024**3,
                "reserved_gb": torch.cuda.memory_reserved(0) / 1024**3,
                "total_gb": torch.cuda.get_device_properties(0).total_memory / 1024**3
            })

        return memory_info

    def get_model_info(self) -> dict:
        """Map of model name -> {path, loaded} for API listings."""
        return {
            name: {
                "path": path,
                "loaded": name == self.current_model_name
            }
            for name, path in self.MODEL_PATHS.items()
        }
|
||||
35
qwen3-tts-backend/core/security.py
Normal file
35
qwen3-tts-backend/core/security.py
Normal file
@@ -0,0 +1,35 @@
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional
|
||||
from jose import JWTError, jwt
|
||||
from passlib.context import CryptContext
|
||||
|
||||
from config import settings
|
||||
|
||||
# Shared passlib context used for hashing and verifying passwords (bcrypt).
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||
|
||||
def get_password_hash(password: str) -> str:
    """Hash a plaintext password with the shared bcrypt context."""
    hashed = pwd_context.hash(password)
    return hashed
|
||||
|
||||
def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Check a plaintext password against a stored bcrypt hash."""
    matches = pwd_context.verify(plain_password, hashed_password)
    return matches
|
||||
|
||||
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
    """Create a signed JWT carrying ``data`` plus an ``exp`` claim.

    Args:
        data: Claims to embed (typically includes 'sub'); not mutated.
        expires_delta: Token lifetime; defaults to
            settings.ACCESS_TOKEN_EXPIRE_MINUTES.

    Returns:
        The encoded JWT string.
    """
    from datetime import timezone  # local import: keep module imports unchanged

    to_encode = data.copy()
    if expires_delta:
        lifetime = expires_delta
    else:
        lifetime = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)

    # Fix: use a timezone-aware timestamp; datetime.utcnow() is deprecated
    # and returns a naive datetime. python-jose converts either form to the
    # same UTC epoch for the 'exp' claim, so behavior is unchanged.
    expire = datetime.now(timezone.utc) + lifetime

    to_encode.update({"exp": expire})
    encoded_jwt = jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
    return encoded_jwt
|
||||
|
||||
def decode_access_token(token: str) -> Optional[str]:
    """Return the token's ``sub`` claim, or None if invalid or absent."""
    try:
        claims = jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM])
    except JWTError:
        # Covers bad signature, malformed token, and expired 'exp'.
        return None
    subject = claims.get("sub")
    if subject is None:
        return None
    return subject
|
||||
Reference in New Issue
Block a user