import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple, List, Dict, Union
from dataclasses import dataclass
import math


@dataclass
class DynamicMemoryConfig:
    """Tunable settings for the two-tier (core + latent) KV memory.

    NOTE: field order is part of the positional-constructor interface of this
    dataclass; do not reorder fields.
    """

    # Fraction of each sequence kept in core memory (the manager also applies
    # a lower bound on the core token count).
    core_compression_ratio: float = 1e-1
    # Fraction of each sequence kept in core + latent combined; the latent
    # share is this value minus the core share.
    latent_compression_ratio: float = 5e-1
    # Default number of latent entries returned by one retrieval.
    retrieval_top_k: int = 32
    # 'cosine' selects cosine similarity; any other value selects scaled
    # dot-product similarity.
    similarity_metric: str = 'cosine'
    # Whether stored importance scores are blended into retrieval scores.
    attention_weighted_retrieval: bool = True
    # Retrieval similarity scores are divided by this value.
    retrieval_temperature: float = 1.0
    # Flag for caching latent retrieval results per layer.
    # NOTE(review): confirm which callers honour it.
    cache_retrieval: bool = True
    # Device on which latent-bank tensors are kept.
    latent_device: str = 'cpu'
    # Per-layer cap on latent-bank length; oldest entries are dropped past it.
    max_latent_size: int = 2 ** 12


class CoreMemory:
    """Per-layer holder for one key tensor and one value tensor.

    Each ``update`` overwrites (it does not append to) the layer's slot and
    moves the tensors to ``device``. ``seq_lengths`` records the size of dim 2
    of the last key tensor stored for each layer; ``total_tokens`` sums it.
    """

    def __init__(
        self,
        num_layers: int,
        num_heads: int,
        head_dim: int,
        device: str = 'cuda',
    ):
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.device = device
        self._reset()

    def _reset(self) -> None:
        """Set every layer slot back to empty (no tensors, length 0)."""
        n = self.num_layers
        self.key_cache: List[Optional[torch.Tensor]] = [None for _ in range(n)]
        self.value_cache: List[Optional[torch.Tensor]] = [None for _ in range(n)]
        self.seq_lengths: List[int] = [0 for _ in range(n)]

    def update(
        self,
        layer_idx: int,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
    ):
        """Replace layer ``layer_idx``'s tensors with copies on ``self.device``."""
        target = self.device
        self.key_cache[layer_idx] = key_states.to(target)
        self.value_cache[layer_idx] = value_states.to(target)
        self.seq_lengths[layer_idx] = key_states.size(2)

    def get(self, layer_idx: int) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
        """Return ``(keys, values)`` for the layer; both are ``None`` if never set."""
        keys = self.key_cache[layer_idx]
        values = self.value_cache[layer_idx]
        return (keys, values)

    def clear(self):
        """Drop the stored tensors for all layers."""
        self._reset()

    @property
    def total_tokens(self) -> int:
        """Sum of ``seq_lengths`` over all layers."""
        return sum(self.seq_lengths)


class LatentBank:
    """Per-layer, append-only bank of key/value tensors with top-k retrieval.

    Stored tensors are detached and kept on ``config.latent_device``. Each
    layer's bank is capped at ``config.max_latent_size`` positions along dim 2;
    when the cap is exceeded the oldest (front) entries are dropped.
    """

    def __init__(
        self,
        num_layers: int,
        num_heads: int,
        head_dim: int,
        config: "DynamicMemoryConfig",
    ):
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.config = config
        self.device = config.latent_device

        self.key_bank: List[Optional[torch.Tensor]] = [None] * num_layers
        self.value_bank: List[Optional[torch.Tensor]] = [None] * num_layers

        # Optional per-entry importance, concatenated along the last dim.
        # May be shorter than the key bank if some `store` calls omitted it.
        self.importance_scores: List[Optional[torch.Tensor]] = [None] * num_layers

        self.bank_sizes: List[int] = [0] * num_layers

    def store(
        self,
        layer_idx: int,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        importance_scores: Optional[torch.Tensor] = None,
    ):
        """Append key/value states (and optional importance) to a layer's bank.

        Args:
            layer_idx: layer whose bank is extended.
            key_states / value_states: concatenated along dim 2; assumed
                (batch, num_heads, seq, head_dim) — TODO confirm with callers.
            importance_scores: optional scores concatenated along the last dim
                and truncated as ``[:, :, -max:]``, so expected to be 3-D,
                presumably (batch, num_heads, seq).
        """
        keys = key_states.detach().to(self.device)
        values = value_states.detach().to(self.device)

        if self.key_bank[layer_idx] is not None:
            self.key_bank[layer_idx] = torch.cat([self.key_bank[layer_idx], keys], dim=2)
            self.value_bank[layer_idx] = torch.cat([self.value_bank[layer_idx], values], dim=2)
        else:
            self.key_bank[layer_idx] = keys
            self.value_bank[layer_idx] = values

        if importance_scores is not None:
            scores = importance_scores.detach().to(self.device)
            if self.importance_scores[layer_idx] is not None:
                self.importance_scores[layer_idx] = torch.cat(
                    [self.importance_scores[layer_idx], scores], dim=-1
                )
            else:
                self.importance_scores[layer_idx] = scores

        cap = self.config.max_latent_size
        if self.key_bank[layer_idx].shape[2] > cap:
            # Keep only the most recently appended `cap` entries.
            self.key_bank[layer_idx] = self.key_bank[layer_idx][:, :, -cap:, :]
            self.value_bank[layer_idx] = self.value_bank[layer_idx][:, :, -cap:, :]
            if self.importance_scores[layer_idx] is not None:
                self.importance_scores[layer_idx] = self.importance_scores[layer_idx][:, :, -cap:]

        self.bank_sizes[layer_idx] = self.key_bank[layer_idx].shape[2]

    def retrieve(
        self,
        layer_idx: int,
        query_states: torch.Tensor,
        top_k: Optional[int] = None,
    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
        """Return the top-k bank entries most similar to the last query position.

        Args:
            layer_idx: layer whose bank is searched.
            query_states: only the last position along dim 2 is used.
            top_k: number of entries; falsy values fall back to
                ``config.retrieval_top_k``. Clamped to the bank size.

        Returns:
            ``(keys, values, scores)`` moved to ``query_states.device``, or
            ``(None, None, None)`` when the layer's bank is empty.
        """
        if self.key_bank[layer_idx] is None or self.bank_sizes[layer_idx] == 0:
            return None, None, None

        top_k = top_k or self.config.retrieval_top_k
        top_k = min(top_k, self.bank_sizes[layer_idx])

        keys = self.key_bank[layer_idx]
        values = self.value_bank[layer_idx]

        # Only the most recent query position drives retrieval.
        query = query_states.to(self.device)[:, :, -1:, :]

        if self.config.similarity_metric == 'cosine':
            query_norm = F.normalize(query, p=2, dim=-1)
            keys_norm = F.normalize(keys, p=2, dim=-1)
            scores = torch.matmul(query_norm, keys_norm.transpose(-2, -1))
        else:
            scale = 1.0 / math.sqrt(self.head_dim)
            scores = torch.matmul(query, keys.transpose(-2, -1)) * scale

        scores = scores.squeeze(2) / self.config.retrieval_temperature

        imp_scores = self.importance_scores[layer_idx]
        # Importance can only be blended in when it lines up entry-for-entry
        # with the key bank; it falls out of sync if some `store` calls
        # omitted importance scores, and adding it would then raise.
        if (
            self.config.attention_weighted_retrieval
            and imp_scores is not None
            and imp_scores.shape[-1] == scores.shape[-1]
        ):
            imp_scores = (imp_scores - imp_scores.min()) / (imp_scores.max() - imp_scores.min() + 1e-8)
            scores = scores + 0.5 * imp_scores

        top_scores, top_indices = torch.topk(scores, top_k, dim=-1)

        indices_expanded = top_indices.unsqueeze(-1).expand(-1, -1, -1, self.head_dim)
        retrieved_keys = torch.gather(keys, dim=2, index=indices_expanded)
        retrieved_values = torch.gather(values, dim=2, index=indices_expanded)

        original_device = query_states.device
        return (
            retrieved_keys.to(original_device),
            retrieved_values.to(original_device),
            top_scores.to(original_device),
        )

    def clear(self):
        """Clear all latent banks."""
        self.key_bank = [None] * self.num_layers
        self.value_bank = [None] * self.num_layers
        self.importance_scores = [None] * self.num_layers
        self.bank_sizes = [0] * self.num_layers

    @property
    def total_tokens(self) -> int:
        """Total tokens across all layers."""
        return sum(self.bank_sizes)


class DynamicMemoryManager:
    """
    Main class for managing the hierarchical dynamic memory system.

    Coordinates between Core Memory and Latent Bank:
    1. ``compress_and_store`` ranks tokens by importance and splits them into a
       core tier (replaced per call) and a latent tier (appended per call), in
       proportions set by the config's compression ratios.
    2. ``retrieve_for_query`` pulls the latent entries most similar to the
       current query, optionally caching the result per layer.
    3. ``get_augmented_kv`` concatenates core, retrieved and current KV states.
    """

    def __init__(
        self,
        num_layers: int,
        num_heads: int,
        head_dim: int,
        hidden_size: int,
        config: Optional[DynamicMemoryConfig] = None,
        core_device: Optional[str] = None,
    ):
        """
        Args:
            num_layers / num_heads / head_dim / hidden_size: model geometry.
            config: memory settings; a default ``DynamicMemoryConfig`` is used
                when ``None``.
            core_device: device for core memory. Defaults to ``'cuda'`` when
                CUDA is available, otherwise ``'cpu'``.
        """
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.hidden_size = hidden_size
        self.config = config or DynamicMemoryConfig()

        if core_device is None:
            # Keep the historical CUDA placement, but don't crash on CPU-only hosts.
            core_device = 'cuda' if torch.cuda.is_available() else 'cpu'

        self.core = CoreMemory(
            num_layers=num_layers,
            num_heads=num_heads,
            head_dim=head_dim,
            device=core_device,
        )

        # Pass the resolved config: the `config` argument itself may be None.
        self.latent = LatentBank(
            num_layers=num_layers,
            num_heads=num_heads,
            head_dim=head_dim,
            config=self.config,
        )

        # layer_idx -> (retrieved_keys, retrieved_values). Keyed by layer only,
        # so an entry is dropped whenever that layer's latent bank changes.
        self._retrieval_cache: Dict[int, Tuple[torch.Tensor, torch.Tensor]] = {}

    def compress_and_store(
        self,
        layer_idx: int,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        importance_scores: torch.Tensor,
        protected_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Split one layer's KV states into core and latent tiers by importance.

        Tokens are ranked by importance (averaged over dim 1 when
        ``importance_scores`` is 3-D). The top ``num_core`` replace this layer's
        core memory, the next ``num_latent`` are appended to the latent bank,
        and the rest are discarded. Each tier keeps original sequence order.

        Args:
            layer_idx: layer to update.
            key_states / value_states: unpacked as
                (batch, num_heads, seq_len, head_dim).
            importance_scores: either 3-D, presumably (batch, num_heads,
                seq_len), or 2-D (batch, seq_len).
            protected_mask: boolean index into the (batch, seq_len) averaged
                scores; True positions are ranked first (forced into core).

        Returns:
            ``(core_keys, core_values)`` for this layer.
        """
        batch_size, num_heads, seq_len, head_dim = key_states.shape

        num_core = int(seq_len * self.config.core_compression_ratio)
        num_latent = int(seq_len * self.config.latent_compression_ratio) - num_core

        # At least 16 core tokens, but never more than the sequence holds;
        # the latent tier can only use what core did not take.
        num_core = min(max(num_core, 16), seq_len)
        num_latent = max(min(num_latent, seq_len - num_core), 0)

        if importance_scores.dim() == 3:
            imp_avg = importance_scores.mean(dim=1)
        else:
            imp_avg = importance_scores

        if protected_mask is not None:
            imp_avg = imp_avg.clone()
            imp_avg[protected_mask] = float('inf')

        _, sorted_indices = imp_avg.sort(dim=-1, descending=True)

        # Re-sort selected indices so each tier preserves positional order.
        core_indices = sorted_indices[:, :num_core].sort(dim=-1).values
        latent_indices = sorted_indices[:, num_core:num_core + num_latent].sort(dim=-1).values

        core_idx = core_indices[:, None, :, None].expand(-1, num_heads, -1, head_dim)
        core_keys = torch.gather(key_states, dim=2, index=core_idx)
        core_values = torch.gather(value_states, dim=2, index=core_idx)

        self.core.update(layer_idx, core_keys, core_values)

        if num_latent > 0:
            latent_idx = latent_indices[:, None, :, None].expand(-1, num_heads, -1, head_dim)
            latent_keys = torch.gather(key_states, dim=2, index=latent_idx)
            latent_values = torch.gather(value_states, dim=2, index=latent_idx)

            if importance_scores.dim() == 3:
                latent_imp_idx = latent_indices.unsqueeze(1).expand(-1, num_heads, -1)
                latent_imp = torch.gather(importance_scores, dim=2, index=latent_imp_idx)
            else:
                # 2-D scores have no head axis; broadcast across heads so the
                # latent bank always receives (batch, num_heads, num_latent).
                latent_imp = torch.gather(importance_scores, dim=1, index=latent_indices)
                latent_imp = latent_imp.unsqueeze(1).expand(-1, num_heads, -1)

            self.latent.store(layer_idx, latent_keys, latent_values, latent_imp)
            # The latent bank changed, so any cached retrieval for it is stale.
            self._retrieval_cache.pop(layer_idx, None)

        return core_keys, core_values

    def retrieve_for_query(
        self,
        layer_idx: int,
        query_states: torch.Tensor,
        use_cache: bool = True,
    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
        """Retrieve latent keys/values relevant to ``query_states``.

        Caching applies only when both ``use_cache`` and
        ``config.cache_retrieval`` are true. The cache is keyed by layer only,
        so a cached result is reused for later queries on the same layer until
        the layer's latent bank changes or the cache is cleared.

        Returns:
            ``(keys, values)``, or ``(None, None)`` if the latent bank is empty.
        """
        use_cache = use_cache and self.config.cache_retrieval

        if use_cache:
            cached = self._retrieval_cache.get(layer_idx)
            if cached is not None:
                return cached

        retrieved_keys, retrieved_values, _ = self.latent.retrieve(
            layer_idx, query_states
        )

        if use_cache and retrieved_keys is not None:
            self._retrieval_cache[layer_idx] = (retrieved_keys, retrieved_values)

        return retrieved_keys, retrieved_values

    def get_augmented_kv(
        self,
        layer_idx: int,
        query_states: torch.Tensor,
        current_keys: Optional[torch.Tensor] = None,
        current_values: Optional[torch.Tensor] = None,
    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
        """Concatenate core, retrieved and current KV along dim 2 (in that order).

        Core and retrieved tensors are cast to ``query_states``' device/dtype.
        Core tensors are expanded along dim 0 to the query batch size, so they
        must have batch 1 or the same batch as the query.

        Returns:
            ``(keys, values)``, or ``(None, None)`` if no part is available.
        """
        batch_size = query_states.shape[0]
        device = query_states.device
        dtype = query_states.dtype

        kv_parts_keys = []
        kv_parts_values = []

        core_k, core_v = self.core.get(layer_idx)
        if core_k is not None:
            kv_parts_keys.append(core_k.expand(batch_size, -1, -1, -1).to(device, dtype=dtype))
            kv_parts_values.append(core_v.expand(batch_size, -1, -1, -1).to(device, dtype=dtype))

        ret_k, ret_v = self.retrieve_for_query(layer_idx, query_states)
        if ret_k is not None:
            kv_parts_keys.append(ret_k.to(device, dtype=dtype))
            kv_parts_values.append(ret_v.to(device, dtype=dtype))

        if current_keys is not None:
            kv_parts_keys.append(current_keys)
            kv_parts_values.append(current_values)

        if not kv_parts_keys:
            return None, None

        augmented_keys = torch.cat(kv_parts_keys, dim=2)
        augmented_values = torch.cat(kv_parts_values, dim=2)

        return augmented_keys, augmented_values

    def clear_retrieval_cache(self):
        """Drop all cached retrieval results."""
        self._retrieval_cache.clear()

    def clear_all(self):
        """Clear core memory, the latent bank and the retrieval cache."""
        self.core.clear()
        self.latent.clear()
        self._retrieval_cache.clear()

    def get_stats(self) -> Dict:
        """Return a snapshot of token counts (lists are copies, not live state)."""
        return {
            'core_total_tokens': self.core.total_tokens,
            'latent_total_tokens': self.latent.total_tokens,
            'core_per_layer': list(self.core.seq_lengths),
            'latent_per_layer': list(self.latent.bank_sizes),
        }


class QueryAdaptiveAttention(nn.Module):
    """Scaled dot-product attention over memory-augmented keys/values.

    Keys/values come from ``memory_manager.get_augmented_kv``. In the manager
    defined in this file, memory (core + retrieved) tokens are placed ahead
    of the current keys/values along dim 2.
    """

    def __init__(
        self,
        memory_manager: "DynamicMemoryManager",
        layer_idx: int,
        num_heads: int,
        head_dim: int,
    ):
        super().__init__()

        self.memory = memory_manager
        self.layer_idx = layer_idx
        self.num_heads = num_heads
        self.head_dim = head_dim

    def forward(
        self,
        query_states: torch.Tensor,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Attend ``query_states`` over memory plus the current keys/values.

        Args:
            query_states / key_states / value_states: attention inputs; the
                sequence axis is dim 2 (it is the concat axis for memory).
            attention_mask: optional additive mask whose last dim covers either
                the current keys only or the full augmented key length. When it
                is shorter than the augmented length, zeros (no masking) are
                prepended for the memory positions. All leading dims (e.g. a
                per-head dim) are preserved.

        Returns:
            Attention output, ``softmax(q k^T / sqrt(head_dim) + mask) @ v``.
        """
        aug_keys, aug_values = self.memory.get_augmented_kv(
            self.layer_idx,
            query_states,
            key_states,
            value_states,
        )

        scale = 1.0 / math.sqrt(self.head_dim)
        attn_weights = torch.matmul(query_states, aug_keys.transpose(-2, -1)) * scale

        if attention_mask is not None:
            # Pad only what the mask does not already cover, keeping every
            # leading dim so per-head masks (B, H, q, kv) also work.
            prefix_len = aug_keys.shape[2] - attention_mask.shape[-1]
            if prefix_len > 0:
                prefix_mask = attention_mask.new_zeros(*attention_mask.shape[:-1], prefix_len)
                attention_mask = torch.cat([prefix_mask, attention_mask], dim=-1)

            attn_weights = attn_weights + attention_mask

        # Softmax in float32 for stability under half precision; cast back so
        # the output dtype matches the attention weights as before.
        attn_probs = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(attn_weights.dtype)
        return torch.matmul(attn_probs, aug_values)
