import os
import pickle
import warnings
from datetime import datetime
from typing import Any, Dict, List, Optional, Union

import torch

class CacheHelper:
    """Buffer per-step tensors in memory and periodically pickle them to disk.

    Each call to ``cache_hidden_state_logits`` stores one
    ``(hidden_states, mask, token_ids)`` tuple of private CPU copies. Once
    ``max_cache_size`` tuples are buffered they are written to
    ``<cache_dir>/cache_<index>_<device_tag>.pkl`` and the buffer is cleared.
    """

    def __init__(self, cache_dir: str = "./data/hs_cache", max_cache_size: int = 100):
        """
        Initialize CacheHelper to store hidden states, masks and token ids.

        Args:
            cache_dir: Base path of the cache directory. A ``-YYYY-MM-DD-HH:MM:SS``
                timestamp suffix is appended to it, and the directory is created.
            max_cache_size: Number of cached tuples to buffer before flushing them
                to a file.
        """
        now = datetime.now()
        datetime_str = now.strftime("%Y-%m-%d-%H:%M:%S")
        self.cache_dir = cache_dir + '-' + datetime_str
        self.max_cache_size = max_cache_size
        self.hidden_states_cache: List[tuple] = []

        # Create cache directory if it doesn't exist
        os.makedirs(self.cache_dir, exist_ok=True)

        # Index of the next file written by _save_to_file.
        self.file_counter = 0

        # Device of the first hidden_states tensor seen; stored in every saved file.
        self.device = None

    @staticmethod
    def _device_tag() -> str:
        """Return the per-process suffix used in cache file names.

        This is the current CUDA device index when CUDA is available (presumably
        so processes on different GPUs do not overwrite each other's files), and
        ``"cpu"`` otherwise, so CPU-only runs no longer crash on save.
        """
        if torch.cuda.is_available():
            return str(torch.cuda.current_device())
        return "cpu"

    def _cache_filename(self, file_index: int, device_tag: Optional[Union[int, str]] = None) -> str:
        """Build the path of cache file ``file_index`` for ``device_tag`` (default: this process)."""
        if device_tag is None:
            device_tag = self._device_tag()
        return os.path.join(self.cache_dir, f"cache_{file_index:06d}_{device_tag}.pkl")

    def cache_hidden_state_logits(self, hidden_states: torch.Tensor, mask: torch.Tensor, token_ids: torch.Tensor):
        """
        Cache CPU copies of the last-layer hidden states, mask and token ids.

        Flushes the buffer to a file once it holds ``max_cache_size`` tuples.

        Args:
            hidden_states: Hidden states from the last layer [batch_size, seq_len, hidden_size]
            mask: Mask accompanying ``hidden_states`` (presumably the attention
                mask, [batch_size, seq_len] -- confirm with caller).
            token_ids: Token ids accompanying ``hidden_states`` (presumably
                [batch_size, seq_len] -- confirm with caller).
        """
        if self.device is None:
            self.device = hidden_states.device

        # detach() first so the device->host copy is not recorded by autograd;
        # copy=True guarantees a private copy even when the input already lives on
        # the CPU, so later in-place updates by the caller cannot alter the cache.
        self.hidden_states_cache.append(tuple(
            t.detach().to("cpu", copy=True) for t in (hidden_states, mask, token_ids)
        ))

        # If cache is full, save to file
        if len(self.hidden_states_cache) >= self.max_cache_size:
            print("will save to file")
            self._save_to_file()

    def _save_to_file(self):
        """Pickle the buffered tuples to the next cache file and clear the buffer.

        Saving is best-effort: a failure is reported via ``warnings.warn``, and the
        buffer is still cleared and the file counter still advanced, so that
        chunk is dropped rather than retried.
        """
        if not self.hidden_states_cache:
            return

        filename = self._cache_filename(self.file_counter)

        cache_data = {
            'hidden_states': self.hidden_states_cache,
            'device': self.device
        }

        try:
            with open(filename, 'wb') as f:
                pickle.dump(cache_data, f)
            print(f"Saved cache to {filename} ({len(self.hidden_states_cache)} items)")
        except Exception as e:
            warnings.warn(f"Failed to save cache: {e}")

        # Clear the cache
        self.hidden_states_cache.clear()
        self.file_counter += 1

    def save_remaining(self):
        """Save any remaining cached data to file."""
        self._save_to_file()

    def load_cache(self, file_index: int, device_tag: Optional[Union[int, str]] = None) -> Dict[str, Any]:
        """
        Load a specific cache file written by this helper.

        Args:
            file_index: Index of the cache file to load.
            device_tag: Device suffix of the file (CUDA device index or ``"cpu"``).
                Defaults to this process's tag, i.e. the files this process wrote.

        Returns:
            Dictionary with 'hidden_states' (list of (hidden_states, mask,
            token_ids) CPU tensor tuples) and 'device'.

        Raises:
            FileNotFoundError: If the cache file does not exist.
        """
        filename = self._cache_filename(file_index, device_tag)
        if not os.path.exists(filename):
            raise FileNotFoundError(f"Cache file {filename} not found")

        # NOTE: pickle.load can execute arbitrary code -- only load files this
        # helper wrote itself, never untrusted input.
        with open(filename, 'rb') as f:
            return pickle.load(f)

    def get_all_cache_files(self) -> List[str]:
        """Return sorted full paths of every ``cache_*.pkl`` file in the cache directory (all devices)."""
        files = sorted(
            f for f in os.listdir(self.cache_dir)
            if f.startswith("cache_") and f.endswith(".pkl")
        )
        return [os.path.join(self.cache_dir, f) for f in files]

    def __len__(self):
        """Return the number of tuples currently buffered in memory (not yet saved)."""
        return len(self.hidden_states_cache)

    def __del__(self):
        """Flush any buffered tuples when the object is garbage-collected."""
        # __init__ may have raised before the buffer existed; nothing to save then.
        if getattr(self, "hidden_states_cache", None):
            self.save_remaining()


# TOKEN_NORM='true' BASEMODEL_PATH='./.cache/modelscope/hub/models/Qwen/Qwen3-4B' accelerate launch --config_file deepspeed_config.yaml --num_processes=2 gate_train.py