import torch
import torch.nn.functional as F
from typing import Optional, Tuple, List


class ExplicitMemoryBank:
    """Store pooled key/value summaries and retrieve the most similar ones.

    Each :meth:`add_memory` call pools a ``(num_heads, seq_len, head_dim)``
    key/value block (optionally with a leading 4th dimension, of which only
    index 0 is kept) down to one ``(num_heads, 1, head_dim)`` entry.
    :meth:`retrieve` scores a pooled query against every stored key, per
    head, and gathers the top-k keys and values.

    Args:
        hidden_size: Model hidden size; used to derive ``head_dim`` when it
            is not given.
        num_heads: Number of heads in the stored tensors.
        head_dim: Per-head dimension; defaults to ``hidden_size // num_heads``.
        device: Device that stored memories are moved to.
        similarity_type: ``'cosine'`` selects cosine similarity; any other
            value uses dot product scaled by ``head_dim ** -0.5``.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int = 1,
        head_dim: Optional[int] = None,
        device: str = 'cpu',
        similarity_type: str = 'cosine'  # 'cosine' or 'dot'
    ):
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = head_dim if head_dim is not None else hidden_size // num_heads
        self.device = device
        self.similarity_type = similarity_type

        # One (num_heads, 1, head_dim) tensor per stored memory.
        self.key_storage: List[torch.Tensor] = []
        self.value_storage: List[torch.Tensor] = []

        # Concatenation of the storage lists along dim 1, rebuilt lazily.
        self._cached_keys: Optional[torch.Tensor] = None
        self._cached_values: Optional[torch.Tensor] = None
        self._cache_valid: bool = False

    @property
    def num_memories(self) -> int:
        """Number of stored memory entries."""
        return len(self.key_storage)

    @property
    def is_empty(self) -> bool:
        """True when no memory has been stored."""
        return len(self.key_storage) == 0

    @staticmethod
    def _drop_batch_dim(tensor: torch.Tensor) -> torch.Tensor:
        """Return ``tensor[0]`` for 4-D input, otherwise ``tensor`` unchanged."""
        return tensor[0] if tensor.dim() == 4 else tensor

    def _invalidate_cache(self):
        """Drop the concatenated key/value cache so it is rebuilt on next use."""
        self._cache_valid = False
        self._cached_keys = None
        self._cached_values = None

    def _build_cache(self):
        """Concatenate stored entries along dim 1 unless cached or empty."""
        if self._cache_valid or self.is_empty:
            return

        self._cached_keys = torch.cat(self.key_storage, dim=1)
        self._cached_values = torch.cat(self.value_storage, dim=1)
        self._cache_valid = True

    def add_memory(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        pooling: str = 'mean'
    ):
        """Pool a key/value block over dim 1 and store it as one memory.

        Args:
            key_states: Key tensor, 3-D or 4-D (index 0 of dim 0 is used).
            value_states: Value tensor with the same layout as ``key_states``.
            pooling: ``'mean'``, ``'max'`` or ``'last'``.

        Raises:
            ValueError: If ``pooling`` is not one of the supported methods.
        """
        key_states = self._drop_batch_dim(key_states.to(self.device))
        value_states = self._drop_batch_dim(value_states.to(self.device))

        if pooling == 'mean':
            key_pooled = key_states.mean(dim=1, keepdim=True)
            value_pooled = value_states.mean(dim=1, keepdim=True)
        elif pooling == 'max':
            key_pooled = key_states.max(dim=1, keepdim=True)[0]
            value_pooled = value_states.max(dim=1, keepdim=True)[0]
        elif pooling == 'last':
            key_pooled = key_states[:, -1:, :]
            value_pooled = value_states[:, -1:, :]
        else:
            raise ValueError(f"Unknown pooling method: {pooling}")

        # clone(): 'last' pooling is a view into the caller's tensor (and
        # .to() is a no-op when the device already matches). Storing that view
        # would keep the whole sequence's storage alive and let later in-place
        # writes to the caller's buffer silently change the stored memory.
        self.key_storage.append(key_pooled.detach().clone())
        self.value_storage.append(value_pooled.detach().clone())

        self._invalidate_cache()

    def add_memory_batch(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        chunk_size: int,
        pooling: str = 'mean'
    ):
        """Split dim 1 into ``chunk_size`` pieces and store each as a memory.

        The final chunk may be shorter than ``chunk_size``.

        Raises:
            ValueError: If ``chunk_size`` is not positive, or ``pooling`` is
                unknown (raised by :meth:`add_memory`).
        """
        if chunk_size <= 0:
            raise ValueError(f"chunk_size must be positive, got {chunk_size}")

        key_states = self._drop_batch_dim(key_states.to(self.device))
        value_states = self._drop_batch_dim(value_states.to(self.device))

        seq_len = key_states.shape[1]

        for start_idx in range(0, seq_len, chunk_size):
            end_idx = min(start_idx + chunk_size, seq_len)
            key_chunk = key_states[:, start_idx:end_idx, :]
            value_chunk = value_states[:, start_idx:end_idx, :]
            self.add_memory(key_chunk, value_chunk, pooling=pooling)

    def retrieve(
        self,
        query_state: torch.Tensor,
        top_k: int = 5,
        return_scores: bool = False
    ) -> Tuple[torch.Tensor, ...]:
        """Return the ``top_k`` stored keys/values most similar to the query.

        The query is mean-pooled over dim 1 and compared with every stored
        key independently per head (dim 0).

        Args:
            query_state: Query tensor, 3-D or 4-D (index 0 of dim 0 is used).
            top_k: Maximum number of memories to return per head.
            return_scores: Also return the similarity scores when True.

        Returns:
            ``(keys, values)`` or ``(keys, values, scores)``, where keys and
            values have shape ``(num_heads, k, head_dim)`` and scores
            ``(num_heads, k)``, with ``k = min(top_k, num_memories)``. All
            outputs use the query's device and dtype.
        """
        if self.is_empty:
            empty_keys = torch.zeros(
                self.num_heads, 0, self.head_dim,
                device=query_state.device, dtype=query_state.dtype
            )
            empty_values = torch.zeros_like(empty_keys)
            if return_scores:
                # Match the dtype the non-empty path returns (query dtype).
                empty_scores = torch.zeros(
                    self.num_heads, 0,
                    device=query_state.device, dtype=query_state.dtype
                )
                return empty_keys, empty_values, empty_scores
            return empty_keys, empty_values

        self._build_cache()

        cached_keys = self._cached_keys.to(query_state.device, dtype=query_state.dtype)
        cached_values = self._cached_values.to(query_state.device, dtype=query_state.dtype)

        query_state = self._drop_batch_dim(query_state)
        query_pooled = query_state.mean(dim=1, keepdim=True)

        if self.similarity_type == 'cosine':
            query_norm = F.normalize(query_pooled, p=2, dim=-1)
            keys_norm = F.normalize(cached_keys, p=2, dim=-1)
            similarity = torch.bmm(query_norm, keys_norm.transpose(1, 2))
        else:
            scale = self.head_dim ** -0.5
            similarity = torch.bmm(query_pooled, cached_keys.transpose(1, 2)) * scale

        # (heads, 1, num_memories) -> (heads, num_memories)
        similarity = similarity.squeeze(1)

        actual_k = min(top_k, self.num_memories)
        top_scores, top_indices = torch.topk(similarity, k=actual_k, dim=-1)

        # Broadcast per-head indices across the feature dim for gather.
        head_dim = cached_keys.shape[2]
        indices_expanded = top_indices.unsqueeze(-1).expand(-1, -1, head_dim)

        retrieved_keys = torch.gather(cached_keys, dim=1, index=indices_expanded)
        retrieved_values = torch.gather(cached_values, dim=1, index=indices_expanded)

        if return_scores:
            return retrieved_keys, retrieved_values, top_scores
        return retrieved_keys, retrieved_values

    def retrieve_all(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return all stored keys and values concatenated along dim 1.

        Note: the internal cached tensors are returned directly, not copies.
        """
        if self.is_empty:
            return (
                torch.zeros(self.num_heads, 0, self.head_dim, device=self.device),
                torch.zeros(self.num_heads, 0, self.head_dim, device=self.device)
            )

        self._build_cache()
        return self._cached_keys, self._cached_values

    def clear(self):
        """Remove every stored memory."""
        self.key_storage.clear()
        self.value_storage.clear()
        self._invalidate_cache()

    def save(self, path: str):
        """Save configuration and concatenated memories with ``torch.save``.

        When the bank is empty, ``keys`` and ``values`` are saved as None.
        """
        self._build_cache()
        state = {
            'hidden_size': self.hidden_size,
            'num_heads': self.num_heads,
            'head_dim': self.head_dim,
            'similarity_type': self.similarity_type,
            'keys': self._cached_keys,
            'values': self._cached_values,
        }
        torch.save(state, path)

    def load(self, path: str):
        """Replace configuration and memories with those saved at ``path``."""
        # NOTE(security): torch.load unpickles the file; only load trusted
        # checkpoints (or pass weights_only=True where the torch version allows).
        state = torch.load(path, map_location=self.device)
        self.hidden_size = state['hidden_size']
        self.num_heads = state['num_heads']
        self.head_dim = state['head_dim']
        self.similarity_type = state['similarity_type']

        keys = state['keys']
        values = state['values']
        if keys is not None and keys.numel() > 0:
            num_memories = keys.shape[1]
            self.key_storage = [keys[:, i:i+1, :] for i in range(num_memories)]
            self.value_storage = [values[:, i:i+1, :] for i in range(num_memories)]
        else:
            self.key_storage = []
            self.value_storage = []
        self._invalidate_cache()
        

class LayerMemoryBank:
    """Hold one ``ExplicitMemoryBank`` per layer and route calls by index.

    All per-layer banks are created with the same configuration.
    """

    def __init__(
        self,
        num_layers: int,
        hidden_size: int,
        num_heads: int,
        head_dim: Optional[int] = None,
        device: str = 'cpu',
        similarity_type: str = 'cosine'
    ):
        self.num_layers = num_layers
        self.banks: List[ExplicitMemoryBank] = [
            ExplicitMemoryBank(
                hidden_size=hidden_size,
                num_heads=num_heads,
                head_dim=head_dim,
                device=device,
                similarity_type=similarity_type
            )
            for _ in range(num_layers)
        ]

    def __getitem__(self, layer_idx: int) -> ExplicitMemoryBank:
        """Return the bank for ``layer_idx``."""
        return self.banks[layer_idx]

    def add_memory(
        self,
        layer_idx: int,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        pooling: str = 'mean'
    ):
        """Pool and store a key/value block in the bank for ``layer_idx``."""
        self.banks[layer_idx].add_memory(key_states, value_states, pooling)

    def retrieve(
        self,
        layer_idx: int,
        query_state: torch.Tensor,
        top_k: int = 5,
        return_scores: bool = False
    ):
        """Retrieve the top-k memories from the bank for ``layer_idx``.

        ``return_scores`` is forwarded to ``ExplicitMemoryBank.retrieve``;
        when True the similarity scores are returned as a third element.
        """
        return self.banks[layer_idx].retrieve(
            query_state, top_k, return_scores=return_scores
        )

    def clear(self):
        """Remove every stored memory in every layer."""
        for bank in self.banks:
            bank.clear()

    def clear_layer(self, layer_idx: int):
        """Remove every stored memory in one layer."""
        self.banks[layer_idx].clear()

    @property
    def total_memories(self) -> int:
        """Total number of stored memories across all layers."""
        return sum(bank.num_memories for bank in self.banks)

    def save(self, path: str):
        """Save every layer's memories plus the shared config with ``torch.save``.

        The config is taken from the first bank. Empty layers are saved with
        ``keys``/``values`` set to None.
        """
        states = []
        for bank in self.banks:
            bank._build_cache()
            states.append({
                'keys': bank._cached_keys,
                'values': bank._cached_values,
            })
        torch.save({
            'num_layers': self.num_layers,
            'config': {
                'hidden_size': self.banks[0].hidden_size,
                'num_heads': self.banks[0].num_heads,
                'head_dim': self.banks[0].head_dim,
                'similarity_type': self.banks[0].similarity_type,
            },
            'layers': states,
        }, path)

    def load(self, path: str):
        """Replace every layer's memories (and config) with those at ``path``.

        Layers saved as empty are cleared, so no stale memories survive a load.

        Raises:
            ValueError: If the file's layer count differs from ``num_layers``.
        """
        # NOTE(security): torch.load unpickles the file; only load trusted
        # checkpoints (or pass weights_only=True where the torch version allows).
        data = torch.load(path, map_location=self.banks[0].device)
        layer_states = data['layers']
        if len(layer_states) != self.num_layers:
            raise ValueError(
                f"Checkpoint has {len(layer_states)} layers, "
                f"expected {self.num_layers}"
            )

        config = data.get('config', {})
        for bank, layer_state in zip(self.banks, layer_states):
            # Restore saved config, mirroring ExplicitMemoryBank.load.
            for name in ('hidden_size', 'num_heads', 'head_dim', 'similarity_type'):
                if name in config:
                    setattr(bank, name, config[name])

            keys = layer_state['keys']
            values = layer_state['values']
            if keys is not None and keys.numel() > 0:
                num_memories = keys.shape[1]
                bank.key_storage = [keys[:, j:j+1, :] for j in range(num_memories)]
                bank.value_storage = [values[:, j:j+1, :] for j in range(num_memories)]
            else:
                # Previously skipped: old memories leaked through the load.
                bank.key_storage = []
                bank.value_storage = []
            bank._invalidate_cache()
