"""Feature cache for efficient feature storage and retrieval."""

import torch
import numpy as np
from typing import Dict, List, Optional, Tuple
from pathlib import Path
import pickle
import os


class FeatureCache:
    """Two-tier (in-memory + on-disk) cache for model features and logits.

    Entries are addressed by a ``"{model_name}_{domain}_{split}"`` key. Each
    entry is kept in a per-type in-memory dict (``features_cache`` /
    ``logits_cache``) and mirrored to ``<cache_dir>/<key>_<type>.pkl``.

    NOTE: disk entries are pickles. Unpickling can execute arbitrary code, so
    ``cache_dir`` must only ever contain files written by trusted code.
    """

    def __init__(
        self,
        cache_dir: str = "./cache",
        refresh_qt_only: bool = True
    ):
        """Initialize feature cache.

        Creates ``cache_dir`` (including parents) if it does not exist.

        Args:
            cache_dir: Directory to store cache files
            refresh_qt_only: Whether to only refresh Q_tilde features.
                Stored on the instance; not consulted by any method of this
                class.
        """
        self.cache_dir = Path(cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        self.refresh_qt_only = refresh_qt_only

        # In-memory cache: cache key -> {<type>, 'indices', 'shape', 'model_name'}
        self.features_cache = {}
        self.logits_cache = {}

        # Cache metadata (not populated by any method of this class)
        self.metadata = {}

    def get_cache_key(self, model_name: str, domain: str, split: str) -> str:
        """Generate cache key for model-domain-split combination.

        Args:
            model_name: Name of model (Q, Qt)
            domain: Domain name
            split: Split name

        Returns:
            Cache key of the form ``"{model_name}_{domain}_{split}"``
        """
        return f"{model_name}_{domain}_{split}"

    def get_cache_path(self, cache_key: str, data_type: str) -> Path:
        """Get cache file path.

        Args:
            cache_key: Cache key
            data_type: Type of data (features, logits)

        Returns:
            ``<cache_dir>/<cache_key>_<data_type>.pkl``
        """
        return self.cache_dir / f"{cache_key}_{data_type}.pkl"

    # ------------------------------------------------------------------
    # Private helpers shared by the features/logits code paths
    # ------------------------------------------------------------------

    def _memory_cache(self, data_type: str) -> Dict[str, Dict]:
        """Return the in-memory dict that holds entries of ``data_type``."""
        return self.features_cache if data_type == "features" else self.logits_cache

    @staticmethod
    def _write_pickle(path: Path, payload: Dict) -> None:
        """Pickle ``payload`` to ``path`` via a temp file and ``os.replace``.

        Writing to a sibling ``.tmp`` file first means an interrupted write
        never leaves a truncated ``.pkl`` behind (which would otherwise make
        later stores skip and later loads fail).
        """
        tmp_path = path.with_name(path.name + ".tmp")
        try:
            with open(tmp_path, 'wb') as f:
                pickle.dump(payload, f)
            os.replace(tmp_path, path)
        finally:
            # Only still present if dumping or replacing failed.
            if tmp_path.exists():
                tmp_path.unlink()

    def _store(
        self,
        data_type: str,
        skip_label: str,
        model_name: str,
        domain: str,
        split: str,
        tensor: torch.Tensor,
        indices: Optional[List[int]],
        force_overwrite: bool
    ) -> None:
        """Store ``tensor`` under ``data_type`` in memory and on disk.

        If the disk file already exists and ``force_overwrite`` is false,
        nothing is stored (neither in memory nor on disk).
        """
        cache_key = self.get_cache_key(model_name, domain, split)
        cache_path = self.get_cache_path(cache_key, data_type)

        # Check if cache exists and should be overwritten
        if cache_path.exists() and not force_overwrite:
            print(f"Cache exists for {cache_key}, skipping {skip_label} storage")
            return

        # Detach so the cache does not keep an autograd graph alive and so
        # memory hits behave like disk hits (which are always detached).
        detached = tensor.detach()

        # Store in memory (stays on the tensor's current device)
        self._memory_cache(data_type)[cache_key] = {
            data_type: detached,
            'indices': indices,
            'shape': tensor.shape,
            'model_name': model_name
        }

        # Store on disk (always as a CPU tensor)
        cache_data = {
            data_type: detached.cpu(),
            'indices': indices,
            'shape': tensor.shape,
            'model_name': model_name,
            'domain': domain,
            'split': split
        }
        self._write_pickle(cache_path, cache_data)

        print(f"Stored {data_type} for {cache_key}: {tensor.shape}")

    def _load(
        self,
        data_type: str,
        model_name: str,
        domain: str,
        split: str,
        device: str
    ) -> Optional[Tuple[torch.Tensor, List[int]]]:
        """Load a ``data_type`` entry, trying memory first and then disk."""
        cache_key = self.get_cache_key(model_name, domain, split)
        memory = self._memory_cache(data_type)

        # Try memory cache first
        entry = memory.get(cache_key)
        if entry is not None:
            tensor = entry[data_type].to(device)
            print(f"Loaded {data_type} from memory cache: {cache_key}")
            return tensor, entry['indices']

        # Try disk cache
        cache_path = self.get_cache_path(cache_key, data_type)
        if cache_path.exists():
            # NOTE(security): pickle.load executes arbitrary code from the
            # file; only safe because cache_dir is assumed to be trusted.
            with open(cache_path, 'rb') as f:
                cache_data = pickle.load(f)

            tensor = cache_data[data_type].to(device)
            indices = cache_data['indices']

            # Update memory cache so the next lookup skips the disk read
            memory[cache_key] = {
                data_type: tensor,
                'indices': indices,
                'shape': tensor.shape,
                'model_name': model_name
            }

            print(f"Loaded {data_type} from disk cache: {cache_key}")
            return tensor, indices

        print(f"{data_type.capitalize()} not found in cache: {cache_key}")
        return None

    @staticmethod
    def _collect_outputs(model, data_loader, device, forward) -> torch.Tensor:
        """Run ``forward`` on every batch of ``data_loader`` and concatenate.

        Puts ``model`` in eval mode (it is not switched back afterwards) and
        runs under ``torch.no_grad()``. Each batch result is moved to CPU
        before being concatenated along dim 0.

        Args:
            model: Model to put in eval mode
            data_loader: Iterable of batches; for list/tuple batches the
                first element is taken as the model input
            device: Device the inputs are moved to
            forward: Callable mapping an input batch to an output tensor

        Raises:
            RuntimeError: If ``data_loader`` yields no batches.
        """
        model.eval()
        outputs = []

        with torch.no_grad():
            for batch in data_loader:
                # Accept (data, label) as well as longer tuples or bare tensors.
                data = batch[0] if isinstance(batch, (list, tuple)) else batch
                outputs.append(forward(data.to(device)).cpu())

        if not outputs:
            raise RuntimeError("Cannot extract outputs: data loader yielded no batches")

        return torch.cat(outputs, dim=0)

    @staticmethod
    def _forward_features(model, data: torch.Tensor) -> torch.Tensor:
        """Compute features for one batch.

        Uses ``model.get_features`` when the model has it; otherwise falls
        back to ``model.backbone`` with the output flattened to 2-D.
        """
        if hasattr(model, 'get_features'):
            return model.get_features(data)

        # Fallback: use penultimate layer. reshape (not view) so that a
        # non-contiguous backbone output is handled as well.
        features = model.backbone(data)
        return features.reshape(features.size(0), -1)

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def store_features(
        self,
        model_name: str,
        domain: str,
        split: str,
        features: torch.Tensor,
        indices: Optional[List[int]] = None,
        force_overwrite: bool = False
    ) -> None:
        """Store features in cache (memory and disk).

        If a disk file for this key already exists and ``force_overwrite``
        is false, the call is a no-op apart from a printed notice.

        Args:
            model_name: Name of model
            domain: Domain name
            split: Split name
            features: Feature tensor (stored detached)
            indices: Sample indices
            force_overwrite: Whether to force overwrite existing cache
        """
        self._store(
            "features", "feature", model_name, domain, split,
            features, indices, force_overwrite
        )

    def store_logits(
        self,
        model_name: str,
        domain: str,
        split: str,
        logits: torch.Tensor,
        indices: Optional[List[int]] = None,
        force_overwrite: bool = False
    ) -> None:
        """Store logits in cache (memory and disk).

        If a disk file for this key already exists and ``force_overwrite``
        is false, the call is a no-op apart from a printed notice.

        Args:
            model_name: Name of model
            domain: Domain name
            split: Split name
            logits: Logits tensor (stored detached)
            indices: Sample indices
            force_overwrite: Whether to force overwrite existing cache
        """
        self._store(
            "logits", "logits", model_name, domain, split,
            logits, indices, force_overwrite
        )

    def load_features(
        self,
        model_name: str,
        domain: str,
        split: str,
        device: str = "cpu"
    ) -> Optional[Tuple[torch.Tensor, List[int]]]:
        """Load features from cache.

        Args:
            model_name: Name of model
            domain: Domain name
            split: Split name
            device: Device to load features on

        Returns:
            Tuple of (features, indices) or None if not found
        """
        return self._load("features", model_name, domain, split, device)

    def load_logits(
        self,
        model_name: str,
        domain: str,
        split: str,
        device: str = "cpu"
    ) -> Optional[Tuple[torch.Tensor, List[int]]]:
        """Load logits from cache.

        Args:
            model_name: Name of model
            domain: Domain name
            split: Split name
            device: Device to load logits on

        Returns:
            Tuple of (logits, indices) or None if not found
        """
        return self._load("logits", model_name, domain, split, device)

    def refresh_qt_features(
        self,
        model_Qt: torch.nn.Module,
        data_loader: torch.utils.data.DataLoader,
        domain: str,
        split: str,
        device: str = "cuda",
        force_overwrite: bool = True
    ) -> None:
        """Refresh Q_tilde features for a specific domain/split.

        Recomputes the features with ``model_Qt`` and stores them under the
        model name ``"Qt"``.

        Args:
            model_Qt: Q_tilde model
            data_loader: Data loader
            domain: Domain name
            split: Split name
            device: Device to use
            force_overwrite: Whether to force overwrite existing cache
        """
        print(f"Refreshing Q_tilde features for {domain} {split}...")

        # Extract features
        features = self._extract_features(model_Qt, data_loader, device)

        # Store in cache
        self.store_features(
            "Qt", domain, split, features,
            force_overwrite=force_overwrite
        )

        print(f"Refreshed Q_tilde features for {domain} {split}")

    def refresh_qt_logits(
        self,
        model_Qt: torch.nn.Module,
        data_loader: torch.utils.data.DataLoader,
        domain: str,
        split: str,
        device: str = "cuda",
        force_overwrite: bool = True
    ) -> None:
        """Refresh Q_tilde logits for a specific domain/split.

        Recomputes the logits with ``model_Qt`` and stores them under the
        model name ``"Qt"``.

        Args:
            model_Qt: Q_tilde model
            data_loader: Data loader
            domain: Domain name
            split: Split name
            device: Device to use
            force_overwrite: Whether to force overwrite existing cache
        """
        print(f"Refreshing Q_tilde logits for {domain} {split}...")

        # Extract logits
        logits = self._extract_logits(model_Qt, data_loader, device)

        # Store in cache
        self.store_logits(
            "Qt", domain, split, logits,
            force_overwrite=force_overwrite
        )

        print(f"Refreshed Q_tilde logits for {domain} {split}")

    def _extract_features(
        self,
        model: torch.nn.Module,
        data_loader: torch.utils.data.DataLoader,
        device: str
    ) -> torch.Tensor:
        """Extract features from model.

        Args:
            model: Model to extract features from
            data_loader: Data loader
            device: Device to use

        Returns:
            Feature tensor on CPU, batches concatenated along dim 0
        """
        return self._collect_outputs(
            model, data_loader, device,
            lambda data: self._forward_features(model, data)
        )

    def _extract_logits(
        self,
        model: torch.nn.Module,
        data_loader: torch.utils.data.DataLoader,
        device: str
    ) -> torch.Tensor:
        """Extract logits from model.

        Args:
            model: Model to extract logits from
            data_loader: Data loader
            device: Device to use

        Returns:
            Logits tensor on CPU, batches concatenated along dim 0
        """
        return self._collect_outputs(model, data_loader, device, model)

    def clear_cache(self, model_name: Optional[str] = None) -> None:
        """Clear the in-memory cache (disk files are left untouched).

        Args:
            model_name: Specific model to clear, or None for all
        """
        if model_name is None:
            # Clear all caches
            self.features_cache.clear()
            self.logits_cache.clear()
            print("Cleared all caches")
            return

        # Fallback match for entries that do not record their model name.
        # The trailing underscore keeps "Q" from also matching "Qt_...".
        prefix = f"{model_name}_"

        def belongs_to_model(key, entry) -> bool:
            recorded = entry.get('model_name') if isinstance(entry, dict) else None
            if recorded is not None:
                return recorded == model_name
            return key.startswith(prefix)

        # Features and logits are cleared independently: a key may exist in
        # only one of the two dicts.
        for memory in (self.features_cache, self.logits_cache):
            stale_keys = [k for k, v in memory.items() if belongs_to_model(k, v)]
            for key in stale_keys:
                del memory[key]

        print(f"Cleared cache for model: {model_name}")

    def get_cache_info(self) -> Dict[str, Dict]:
        """Get information about cached data.

        In-memory entries are reported with their tensor's device; entries
        that exist only on disk are reported with device ``'disk'``.

        Returns:
            ``{'features': {key: {'shape', 'device'}}, 'logits': {...}}``
        """
        info = {
            'features': {},
            'logits': {}
        }

        # Memory cache info
        for data_type in ('features', 'logits'):
            for key, data in self._memory_cache(data_type).items():
                info[data_type][key] = {
                    'shape': data['shape'],
                    'device': str(data[data_type].device)
                }

        # Disk cache info
        for cache_file in self.cache_dir.glob("*.pkl"):
            # Classify by the suffix get_cache_path() produces, so a model or
            # domain name containing "features" cannot misfile a logits file.
            if cache_file.name.endswith("_features.pkl"):
                data_type = "features"
            elif cache_file.name.endswith("_logits.pkl"):
                data_type = "logits"
            else:
                # Not written by this class; do not unpickle it.
                continue

            try:
                # NOTE(security): see class docstring regarding pickle.
                with open(cache_file, 'rb') as f:
                    cache_data = pickle.load(f)

                cache_key = f"{cache_data['model_name']}_{cache_data['domain']}_{cache_data['split']}"
                shape = cache_data['shape']
            except Exception:
                # Best effort: unreadable or foreign files are skipped.
                # Unpickling can raise almost any exception type, hence the
                # broad catch (but not KeyboardInterrupt/SystemExit).
                continue

            if cache_key not in info[data_type]:
                info[data_type][cache_key] = {
                    'shape': shape,
                    'device': 'disk'
                }

        return info

    def __str__(self) -> str:
        """String representation of cache."""
        info = self.get_cache_info()
        return f"FeatureCache(cache_dir={self.cache_dir}, features={len(info['features'])}, logits={len(info['logits'])})"

