import hashlib
import asyncio
from pathlib import Path
from typing import Optional, Dict, Any
from datetime import datetime, timedelta
import logging

import numpy as np
from sqlalchemy.orm import Session

from db.crud import (
    create_cache_entry,
    get_cache_entry,
    list_cache_entries,
    delete_cache_entry
)
from db.models import VoiceCache
from core.config import settings

logger = logging.getLogger(__name__)


class VoiceCacheManager:
    """Async singleton that manages per-user voice-cache artifacts.

    Each cache entry is a file on disk (``.npy`` for numpy arrays,
    ``.pt`` for torch objects) plus a DB row (``VoiceCache``) keyed by
    ``(user_id, ref_audio_hash)``. The manager enforces a per-user max
    entry count (LRU-style eviction) and a TTL-based expiry sweep.
    """

    _instance = None
    # Shared class-level lock: guards both singleton construction and
    # cache writes in set_cache.
    # NOTE(review): created at import time — on Python < 3.10 an
    # asyncio.Lock binds to the loop active at creation; confirm the app
    # runs a single event loop.
    _lock = asyncio.Lock()

    def __init__(self, cache_dir: str, max_entries: int, ttl_days: int):
        """Create the manager and ensure the cache directory exists.

        Args:
            cache_dir: Directory where cache files are stored.
            max_entries: Per-user cap enforced by enforce_max_entries.
            ttl_days: Age (by last access) after which entries expire.
        """
        self.cache_dir = Path(cache_dir)
        self.max_entries = max_entries
        self.ttl_days = ttl_days
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        logger.info(f"VoiceCacheManager initialized: dir={cache_dir}, max={max_entries}, ttl={ttl_days}d")

    @classmethod
    async def get_instance(cls) -> 'VoiceCacheManager':
        """Return the process-wide singleton, creating it on first use.

        Uses double-checked locking so concurrent first callers do not
        construct two instances.
        """
        if cls._instance is None:
            async with cls._lock:
                if cls._instance is None:
                    cls._instance = VoiceCacheManager(
                        cache_dir=settings.CACHE_DIR,
                        max_entries=settings.MAX_CACHE_ENTRIES,
                        ttl_days=settings.CACHE_TTL_DAYS
                    )
        return cls._instance

    def get_audio_hash(self, audio_data: bytes) -> str:
        """Return the hex SHA-256 digest of raw audio bytes (the cache key)."""
        return hashlib.sha256(audio_data).hexdigest()

    def _load_cache_file(self, cache_file: Path):
        """Load a cache payload from disk based on file extension.

        ``.pt`` files are loaded with torch; everything else is treated
        as a numpy array saved with allow_pickle=False.
        """
        if cache_file.suffix == '.pt':
            import torch
            # SECURITY(review): weights_only=False deserializes arbitrary
            # pickle — safe only if the cache dir is not attacker-writable.
            # Consider weights_only=True if cached objects are pure tensors.
            return torch.load(cache_file, weights_only=False)
        with open(cache_file, 'rb') as f:
            return np.load(f, allow_pickle=False)

    async def get_cache(self, user_id: int, ref_audio_hash: str, db: Session) -> Optional[Dict[str, Any]]:
        """Look up a cache entry by (user_id, ref_audio_hash).

        Returns a dict with 'cache_id', 'data', and 'metadata', or None on
        miss, missing/out-of-tree file, or any load error. A DB row whose
        backing file vanished is deleted as a side effect.
        """
        try:
            cache_entry = get_cache_entry(db, user_id, ref_audio_hash)
            if not cache_entry:
                logger.debug(f"Cache miss: user={user_id}, hash={ref_audio_hash[:8]}...")
                return None

            cache_file = Path(cache_entry.cache_path)
            if not cache_file.exists():
                # Stale DB row: drop it so future lookups miss cleanly.
                logger.warning(f"Cache file missing: {cache_file}")
                delete_cache_entry(db, cache_entry.id, user_id)
                return None

            # Defend against a tampered cache_path escaping the cache dir.
            resolved_cache_file = cache_file.resolve()
            if not resolved_cache_file.is_relative_to(self.cache_dir.resolve()):
                logger.warning(f"Cache path out of cache dir: {resolved_cache_file}")
                return None

            cache_data = self._load_cache_file(cache_file)
            # NOTE(review): unlike get_cache_by_id, this path does not bump
            # last_accessed/access_count here — presumably get_cache_entry
            # does it; confirm, otherwise LRU/TTL never see these hits.
            logger.info(f"Cache hit: user={user_id}, hash={ref_audio_hash[:8]}..., access_count={cache_entry.access_count}")
            return {
                'cache_id': cache_entry.id,
                'data': cache_data,
                'metadata': cache_entry.meta_data
            }
        except Exception as e:
            # Best-effort contract: any failure reads as a cache miss.
            logger.error(f"Cache retrieval error: {e}", exc_info=True)
            return None

    async def get_cache_by_id(self, cache_id: int, db: Session) -> Optional[Dict[str, Any]]:
        """Load a cache entry by primary key, updating its access stats.

        Returns the same dict shape as get_cache, or None if the row or
        file is missing, the path escapes the cache dir, or loading fails.
        """
        try:
            cache_entry = db.query(VoiceCache).filter(VoiceCache.id == cache_id).first()
            if not cache_entry:
                logger.debug(f"Cache not found: id={cache_id}")
                return None

            cache_file = Path(cache_entry.cache_path)
            if not cache_file.exists():
                logger.warning(f"Cache file missing: {cache_file}")
                return None

            resolved_cache_file = cache_file.resolve()
            if not resolved_cache_file.is_relative_to(self.cache_dir.resolve()):
                logger.warning(f"Cache path out of cache dir: {resolved_cache_file}")
                return None

            cache_data = self._load_cache_file(cache_file)

            # Record the access so LRU eviction and TTL expiry stay accurate.
            # NOTE(review): datetime.utcnow() is deprecated in 3.12+; the
            # schema presumably stores naive UTC — confirm before switching
            # to timezone-aware datetimes.
            cache_entry.last_accessed = datetime.utcnow()
            cache_entry.access_count += 1
            db.commit()

            logger.info(f"Cache loaded by id: cache_id={cache_id}, access_count={cache_entry.access_count}")
            return {
                'cache_id': cache_entry.id,
                'data': cache_data,
                'metadata': cache_entry.meta_data
            }
        except Exception as e:
            logger.error(f"Cache retrieval by id error: {e}", exc_info=True)
            return None

    async def set_cache(
        self,
        user_id: int,
        ref_audio_hash: str,
        cache_data: Any,
        metadata: Dict[str, Any],
        db: Session
    ) -> int:
        """Persist cache_data to disk, register it in the DB, enforce the cap.

        Tensors (anything with .detach) are converted to numpy and saved
        as .npy; other payloads fall back to torch.save (.pt).

        Returns:
            The new cache entry's id.

        Raises:
            Re-raises any failure after best-effort file cleanup.
        """
        async with self._lock:
            # Pre-bind so the except block can safely test it even when the
            # failure happens before the filename is built (fixes a NameError
            # that previously masked the original exception).
            cache_path: Optional[Path] = None
            try:
                # Normalize torch tensors to numpy for the cheaper format.
                if hasattr(cache_data, "detach"):
                    cache_data = cache_data.detach().cpu().numpy()

                if isinstance(cache_data, np.ndarray):
                    cache_filename = f"{user_id}_{ref_audio_hash}.npy"
                    cache_path = self.cache_dir / cache_filename
                    with open(cache_path, 'wb') as f:
                        np.save(f, cache_data, allow_pickle=False)
                else:
                    import torch
                    cache_filename = f"{user_id}_{ref_audio_hash}.pt"
                    cache_path = self.cache_dir / cache_filename
                    torch.save(cache_data, cache_path)

                cache_entry = create_cache_entry(
                    db=db,
                    user_id=user_id,
                    ref_audio_hash=ref_audio_hash,
                    cache_path=str(cache_path),
                    meta_data=metadata
                )

                # Evict oldest entries if the user is now over the cap.
                await self.enforce_max_entries(user_id, db)

                logger.info(f"Cache created: user={user_id}, hash={ref_audio_hash[:8]}..., id={cache_entry.id}")
                return cache_entry.id
            except Exception as e:
                logger.error(f"Cache creation error: {e}", exc_info=True)
                # Best-effort removal of a partially written file; cache_path
                # is None when the failure preceded file creation.
                if cache_path is not None and cache_path.exists():
                    cache_path.unlink()
                raise

    async def enforce_max_entries(self, user_id: int, db: Session) -> int:
        """Evict a user's surplus cache entries beyond max_entries.

        NOTE(review): assumes list_cache_entries returns entries ordered
        most-recently-used first, so the tail is the LRU set — confirm the
        ordering in db.crud.

        Returns the number of entries deleted (0 on error — best-effort).
        """
        try:
            # limit=9999 is an effectively-unbounded page size for one user.
            all_caches = list_cache_entries(db, user_id, skip=0, limit=9999)
            if len(all_caches) <= self.max_entries:
                return 0

            caches_to_delete = all_caches[self.max_entries:]
            deleted_count = 0
            for cache in caches_to_delete:
                cache_file = Path(cache.cache_path)
                if cache_file.exists():
                    cache_file.unlink()
                delete_cache_entry(db, cache.id, user_id)
                deleted_count += 1

            if deleted_count > 0:
                logger.info(f"LRU eviction: user={user_id}, deleted={deleted_count} entries")
            return deleted_count
        except Exception as e:
            logger.error(f"LRU enforcement error: {e}", exc_info=True)
            return 0

    async def cleanup_expired(self, db: Session) -> int:
        """Delete all cache entries (any user) not accessed within ttl_days.

        Returns the number of entries removed (0 on error; the transaction
        is rolled back on failure).
        """
        try:
            cutoff_date = datetime.utcnow() - timedelta(days=self.ttl_days)
            expired_caches = db.query(VoiceCache).filter(
                VoiceCache.last_accessed < cutoff_date
            ).all()

            deleted_count = 0
            for cache in expired_caches:
                cache_file = Path(cache.cache_path)
                if cache_file.exists():
                    cache_file.unlink()
                db.delete(cache)
                deleted_count += 1

            if deleted_count > 0:
                db.commit()
                logger.info(f"Expired cache cleanup: deleted={deleted_count} entries")
            return deleted_count
        except Exception as e:
            logger.error(f"Expired cache cleanup error: {e}", exc_info=True)
            db.rollback()
            return 0