refactor: Consolidate cache file loading logic and enhance cache saving for different data types

This commit is contained in:
2026-03-06 14:07:22 +08:00
parent 0cbf629499
commit 5e1e3e0668

View File

@@ -45,6 +45,13 @@ class VoiceCacheManager:
def get_audio_hash(self, audio_data: bytes) -> str: def get_audio_hash(self, audio_data: bytes) -> str:
return hashlib.sha256(audio_data).hexdigest() return hashlib.sha256(audio_data).hexdigest()
def _load_cache_file(self, cache_file: Path):
if cache_file.suffix == '.pt':
import torch
return torch.load(cache_file, weights_only=False)
with open(cache_file, 'rb') as f:
return np.load(f, allow_pickle=False)
async def get_cache(self, user_id: int, ref_audio_hash: str, db: Session) -> Optional[Dict[str, Any]]: async def get_cache(self, user_id: int, ref_audio_hash: str, db: Session) -> Optional[Dict[str, Any]]:
try: try:
cache_entry = get_cache_entry(db, user_id, ref_audio_hash) cache_entry = get_cache_entry(db, user_id, ref_audio_hash)
@@ -63,8 +70,7 @@ class VoiceCacheManager:
logger.warning(f"Cache path out of cache dir: {resolved_cache_file}") logger.warning(f"Cache path out of cache dir: {resolved_cache_file}")
return None return None
with open(cache_file, 'rb') as f: cache_data = self._load_cache_file(cache_file)
cache_data = np.load(f, allow_pickle=False)
logger.info(f"Cache hit: user={user_id}, hash={ref_audio_hash[:8]}..., access_count={cache_entry.access_count}") logger.info(f"Cache hit: user={user_id}, hash={ref_audio_hash[:8]}..., access_count={cache_entry.access_count}")
return { return {
@@ -94,8 +100,7 @@ class VoiceCacheManager:
logger.warning(f"Cache path out of cache dir: {resolved_cache_file}") logger.warning(f"Cache path out of cache dir: {resolved_cache_file}")
return None return None
with open(cache_file, 'rb') as f: cache_data = self._load_cache_file(cache_file)
cache_data = np.load(f, allow_pickle=False)
cache_entry.last_accessed = datetime.utcnow() cache_entry.last_accessed = datetime.utcnow()
cache_entry.access_count += 1 cache_entry.access_count += 1
@@ -122,16 +127,19 @@ class VoiceCacheManager:
) -> str: ) -> str:
async with self._lock: async with self._lock:
try: try:
cache_filename = f"{user_id}_{ref_audio_hash}.npy"
cache_path = self.cache_dir / cache_filename
if hasattr(cache_data, "detach"): if hasattr(cache_data, "detach"):
cache_data = cache_data.detach().cpu().numpy() cache_data = cache_data.detach().cpu().numpy()
elif not isinstance(cache_data, np.ndarray):
cache_data = np.asarray(cache_data, dtype=np.float32)
if isinstance(cache_data, np.ndarray):
cache_filename = f"{user_id}_{ref_audio_hash}.npy"
cache_path = self.cache_dir / cache_filename
with open(cache_path, 'wb') as f: with open(cache_path, 'wb') as f:
np.save(f, cache_data, allow_pickle=False) np.save(f, cache_data, allow_pickle=False)
else:
import torch
cache_filename = f"{user_id}_{ref_audio_hash}.pt"
cache_path = self.cache_dir / cache_filename
torch.save(cache_data, cache_path)
cache_entry = create_cache_entry( cache_entry = create_cache_entry(
db=db, db=db,