import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple, List, Dict, Union
from dataclasses import dataclass, field
import math

# Import TASM components
from .task_vector import (
    TaskVectorConfig,
    TaskVectorExtractor,
    TaskVectorImportanceScorer,
    TaskVectorGuidedCompressor,
)
from .token_merging import (
    TokenMergingConfig,
    SemanticTokenMerger,
    VisualTokenMerger,
)
from .dynamic_memory import (
    DynamicMemoryConfig,
    DynamicMemoryManager,
    CoreMemory,
    LatentBank,
)


@dataclass
class TASMConfig:
    """Top-level configuration consumed by TASMCompressor.

    Bundles the three per-component configs together with the feature
    switches and token-budget values that the compressor reads directly.
    """
    # Sub-configs forwarded unchanged to the task-vector compressor, the two
    # token mergers (semantic and visual share one config) and the dynamic
    # memory manager.
    task_vector: TaskVectorConfig = field(default_factory=TaskVectorConfig)
    token_merging: TokenMergingConfig = field(default_factory=TokenMergingConfig)
    dynamic_memory: DynamicMemoryConfig = field(default_factory=DynamicMemoryConfig)
    # When False, extract_task_vector() is a no-op and importance scores fall
    # back to the supplied attention scores (or uniform ones).
    enable_task_vector: bool = True
    # When False, compress_layer() selects the top tokens instead of merging.
    enable_token_merging: bool = True
    # When False, the memory manager's compress/store and retrieval steps are
    # bypassed.
    enable_dynamic_retrieval: bool = True
    # Fraction of the sequence length kept per layer when the caller does not
    # pass an explicit num_keep.
    target_compression_ratio: float = 0.2
    # NOTE(review): not read by the code shown in this module; presumably a
    # Jensen-Shannon divergence threshold used elsewhere - confirm.
    js_threshold: float = 0.005
    # NOTE(review): not read by the code shown in this module - confirm where
    # per-layer adaptation is applied.
    layer_adaptive: bool = True
    # sink_tokens + local_tokens is the lower bound applied to num_keep in
    # compress_layer(). Presumably these are the leading "attention sink" and
    # trailing local-window token counts - TODO confirm.
    sink_tokens: int = 4
    local_tokens: int = 64


class TASMCompressor:
    """Coordinate importance scoring, token merging and memory storage for KV compression.

    The class owns four collaborators built from ``TASMConfig``:
    a ``TaskVectorGuidedCompressor`` (importance scoring / token selection),
    a ``SemanticTokenMerger`` and a ``VisualTokenMerger`` (token merging), and
    a ``DynamicMemoryManager`` (core storage and retrieval). It also keeps
    running token counters that ``get_stats`` reports.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        num_layers: int,
        head_dim: Optional[int] = None,
        config: Optional["TASMConfig"] = None,
    ):
        """Build the sub-components.

        Args:
            hidden_size: Model hidden size.
            num_heads: Number of attention (query) heads.
            num_kv_heads: Number of key/value heads; used for the mergers and
                the memory manager.
            num_layers: Number of transformer layers.
            head_dim: Per-head dimension; derived as
                ``hidden_size // num_heads`` when omitted (or falsy).
            config: Pipeline configuration; a default ``TASMConfig`` is
                created when omitted.
        """
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.num_layers = num_layers
        self.head_dim = head_dim if head_dim else hidden_size // num_heads
        self.config = config or TASMConfig()

        self.task_compressor = TaskVectorGuidedCompressor(
            hidden_size=hidden_size,
            num_heads=num_heads,
            num_kv_heads=num_kv_heads,
            num_layers=num_layers,
            head_dim=self.head_dim,
            config=self.config.task_vector,
        )

        # The mergers and the memory manager operate on cached K/V tensors,
        # so they are sized with the KV head count rather than the query one.
        self.token_merger = SemanticTokenMerger(
            hidden_size=hidden_size,
            num_heads=num_kv_heads,
            head_dim=self.head_dim,
            config=self.config.token_merging,
        )

        self.visual_merger = VisualTokenMerger(
            hidden_size=hidden_size,
            num_heads=num_kv_heads,
            head_dim=self.head_dim,
            config=self.config.token_merging,
        )

        self.memory_manager = DynamicMemoryManager(
            num_layers=num_layers,
            num_heads=num_kv_heads,
            head_dim=self.head_dim,
            hidden_size=hidden_size,
            config=self.config.dynamic_memory,
        )

        self._compression_stats = {
            'original_tokens': 0,
            'core_tokens': 0,
            'latent_tokens': 0,
            'merged_tokens': 0,
        }

        # Both are populated by to(); defined here so the attributes always
        # exist, even before the first to() call.
        self._device = None
        self._dtype = None

    def to(self, device, dtype=torch.bfloat16):
        """Move all components to the specified device and dtype.

        The task compressor and both mergers are moved; the memory manager is
        not touched here. Returns ``self`` so calls can be chained.
        """
        self._device = device
        self._dtype = dtype
        self.task_compressor.to(device=device, dtype=dtype)
        self.token_merger.to(device=device, dtype=dtype)
        self.visual_merger.to(device=device, dtype=dtype)
        return self

    def clear(self):
        """Drop the cached task vector, all stored memory, and zero the counters."""
        self.task_compressor.clear_cache()
        self.memory_manager.clear_all()
        self._compression_stats = {k: 0 for k in self._compression_stats}

    def extract_task_vector(
        self,
        question_embeddings: torch.Tensor,
        answer_embeddings: torch.Tensor,
        hidden_states: Optional[List[torch.Tensor]] = None,
        question_mask: Optional[torch.Tensor] = None,
        answer_mask: Optional[torch.Tensor] = None,
    ):
        """Ask the task compressor to extract and cache a task vector.

        Does nothing when ``config.enable_task_vector`` is False. All
        arguments are forwarded unchanged. Returns ``None``.
        """
        if not self.config.enable_task_vector:
            return

        self.task_compressor.extract_and_cache_task_vector(
            question_embeddings=question_embeddings,
            answer_embeddings=answer_embeddings,
            hidden_states=hidden_states,
            question_mask=question_mask,
            answer_mask=answer_mask,
        )

    def compute_importance_scores(
        self,
        layer_idx: int,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        attention_scores: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Return per-token importance scores for one layer.

        Resolution order:
        1. task-vector scoring when ``config.enable_task_vector`` is set
           (``attention_scores`` is ignored in that case);
        2. the caller-supplied ``attention_scores``;
        3. uniform ones shaped like the first three dims of ``key_states``.
        """
        if self.config.enable_task_vector:
            return self.task_compressor.compute_layer_importance(
                layer_idx=layer_idx,
                key_states=key_states,
                value_states=value_states,
                hidden_states=None,
                attention_mask=None,
            )
        if attention_scores is not None:
            return attention_scores
        return torch.ones(
            key_states.shape[0], key_states.shape[1], key_states.shape[2],
            device=key_states.device, dtype=key_states.dtype
        )

    def compress_layer(
        self,
        layer_idx: int,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        importance_scores: Optional[torch.Tensor],
        visual_token_mask: Optional[torch.Tensor] = None,
        protected_mask: Optional[torch.Tensor] = None,
        num_keep: Optional[int] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compress one layer's K/V tensors and return the retained pair.

        Args:
            layer_idx: Layer whose cache is being compressed.
            key_states: 4-D tensor unpacked as
                ``(batch, heads, seq_len, head_dim)``.
            value_states: Values matching ``key_states``.
            importance_scores: Per-token scores, gathered along dim 2. When
                ``None``, scores are obtained from
                ``compute_importance_scores``.
            visual_token_mask: When given and non-empty (``.any()``), the
                visual merger is used instead of the semantic merger.
            protected_mask: Forwarded to the merger / token selector.
            num_keep: Token budget. Defaults to
                ``seq_len * config.target_compression_ratio``.

        Returns:
            ``(keys, values)`` after merging/selection and, when dynamic
            retrieval is enabled, after the memory manager's
            ``compress_and_store`` step.
        """
        batch_size, num_heads, seq_len, head_dim = key_states.shape
        device = key_states.device

        if importance_scores is None:
            # Callers that look scores up per layer may have no entry for
            # this layer; fall back to the compressor's own scoring.
            importance_scores = self.compute_importance_scores(
                layer_idx=layer_idx,
                key_states=key_states,
                value_states=value_states,
            )

        if num_keep is None:
            num_keep = int(seq_len * self.config.target_compression_ratio)
        num_keep = max(num_keep, self.config.sink_tokens + self.config.local_tokens)
        # The sink+local floor can exceed a short sequence; never ask the
        # merger/selector for more tokens than actually exist.
        num_keep = min(num_keep, seq_len)

        self._compression_stats['original_tokens'] += seq_len

        if self.config.enable_token_merging:
            if visual_token_mask is not None and visual_token_mask.any():
                merged_keys, merged_values, merge_map = self.visual_merger(
                    key_states=key_states,
                    value_states=value_states,
                    importance_scores=importance_scores,
                    num_keep=num_keep,
                    visual_token_mask=visual_token_mask,
                    protected_mask=protected_mask,
                )
            else:
                merged_keys, merged_values, merge_map = self.token_merger(
                    key_states=key_states,
                    value_states=value_states,
                    importance_scores=importance_scores,
                    num_keep=num_keep,
                    protected_mask=protected_mask,
                )

            merged_seq_len = merged_keys.shape[2]
            self._compression_stats['merged_tokens'] += merged_seq_len

            # Merged tokens are handed on with uniform importance; the
            # original per-token scores are not carried over.
            merged_importance = torch.ones(
                batch_size, num_heads, merged_seq_len,
                device=device, dtype=importance_scores.dtype
            )
        else:
            indices = self.task_compressor.select_important_tokens(
                importance_scores, num_keep, protected_mask
            )
            indices_expanded = indices.unsqueeze(-1).expand(-1, -1, -1, head_dim)
            merged_keys = torch.gather(key_states, dim=2, index=indices_expanded)
            merged_values = torch.gather(value_states, dim=2, index=indices_expanded)
            # gather() output has the index's shape, so `indices` already
            # matches merged_keys along dim 2.
            merged_importance = torch.gather(importance_scores, dim=2, index=indices)

        if self.config.enable_dynamic_retrieval:
            core_keys, core_values = self.memory_manager.compress_and_store(
                layer_idx=layer_idx,
                key_states=merged_keys,
                value_states=merged_values,
                importance_scores=merged_importance,
                protected_mask=None,
            )

            self._compression_stats['core_tokens'] += core_keys.shape[2]
            # Tokens that survived merging/selection but were not returned
            # as core are counted as latent.
            self._compression_stats['latent_tokens'] += (
                merged_keys.shape[2] - core_keys.shape[2]
            )
        else:
            core_keys = merged_keys
            core_values = merged_values
            self._compression_stats['core_tokens'] += core_keys.shape[2]

        return core_keys, core_values

    def get_augmented_kv(
        self,
        layer_idx: int,
        query_states: torch.Tensor,
        current_keys: Optional[torch.Tensor] = None,
        current_values: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the K/V pair attention should use for ``layer_idx``.

        With dynamic retrieval enabled the call is delegated to the memory
        manager. Otherwise the stored core K/V for the layer is read from
        ``memory_manager.core``:

        * no core stored -> ``(current_keys, current_values)`` unchanged;
        * core stored and current K/V given -> core expanded to the query's
          batch size and concatenated in front of the current K/V on dim 2;
        * core stored, no current K/V -> the core pair as stored.
        """
        if self.config.enable_dynamic_retrieval:
            return self.memory_manager.get_augmented_kv(
                layer_idx=layer_idx,
                query_states=query_states,
                current_keys=current_keys,
                current_values=current_values,
            )

        core_k, core_v = self.memory_manager.core.get(layer_idx)
        if core_k is None:
            return current_keys, current_values

        if current_keys is None:
            return core_k, core_v

        batch_size = query_states.shape[0]
        core_k = core_k.expand(batch_size, -1, -1, -1)
        core_v = core_v.expand(batch_size, -1, -1, -1)
        return (
            torch.cat([core_k, current_keys], dim=2),
            torch.cat([core_v, current_values], dim=2),
        )

    def clear_retrieval_cache(self):
        """Clear the memory manager's retrieval cache."""
        self.memory_manager.clear_retrieval_cache()

    def get_stats(self) -> Dict:
        """Return token counters merged with the memory manager's stats.

        Memory-manager entries overwrite same-named counters. When at least
        one token has been processed, two ratios are added:
        ``compression_ratio`` (core / original) and ``total_retained_ratio``
        ((core + latent) / original).
        """
        stats = self._compression_stats.copy()
        stats.update(self.memory_manager.get_stats())

        if stats['original_tokens'] > 0:
            stats['compression_ratio'] = stats['core_tokens'] / stats['original_tokens']
            stats['total_retained_ratio'] = (
                (stats['core_tokens'] + stats['latent_tokens']) / stats['original_tokens']
            )

        return stats


class TASMCache:
    """Per-layer key/value cache whose contents can be compressed in place.

    Attributes:
        compressor: Object used for compression and augmented K/V lookup.
        max_cache_length: Stored as given; not read by the methods of this
            class.
        key_cache / value_cache: One entry per layer, ``None`` until the
            layer is first written.
        pre_lens: Per-layer sequence length recorded by the latest
            ``compress`` call (0 before any compression).
    """

    def __init__(
        self,
        compressor: "TASMCompressor",
        max_cache_length: int = 2048,
    ):
        self.compressor = compressor
        self.max_cache_length = max_cache_length

        self.key_cache: List[Optional[torch.Tensor]] = []
        self.value_cache: List[Optional[torch.Tensor]] = []

        self.pre_lens: List[int] = []

        self._is_compressed: bool = False

    @property
    def num_layers(self) -> int:
        """Number of layer slots currently allocated."""
        return len(self.key_cache)

    def _ensure_layer(self, layer_idx: int) -> None:
        """Grow the per-layer lists so that ``layer_idx`` is a valid index."""
        while len(self.key_cache) <= layer_idx:
            self.key_cache.append(None)
            self.value_cache.append(None)
            self.pre_lens.append(0)

    def get_past_seq_len(self) -> int:
        """Return the cached sequence length of layer 0 (0 when empty)."""
        if len(self.key_cache) == 0 or self.key_cache[0] is None:
            return 0
        return self.key_cache[0].shape[2]

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Append new K/V states to a layer and return the full cached pair.

        New states are concatenated after the existing ones along dim 2.
        ``cache_kwargs`` is accepted but not used.
        """
        self._ensure_layer(layer_idx)

        if self.key_cache[layer_idx] is not None:
            key_states = torch.cat([self.key_cache[layer_idx], key_states], dim=2)
            value_states = torch.cat([self.value_cache[layer_idx], value_states], dim=2)

        self.key_cache[layer_idx] = key_states
        self.value_cache[layer_idx] = value_states

        return key_states, value_states

    def compress(
        self,
        importance_scores: Dict[int, torch.Tensor],
        visual_token_mask: Optional[torch.Tensor] = None,
        protected_mask: Optional[torch.Tensor] = None,
    ):
        """Compress every populated layer in place via the compressor.

        Args:
            importance_scores: Scores keyed by layer index. A layer with no
                entry is scored through
                ``compressor.compute_importance_scores`` instead of passing
                ``None`` downstream.
            visual_token_mask: Forwarded to ``compress_layer``.
            protected_mask: Forwarded to ``compress_layer``.
        """
        for layer_idx, layer_keys in enumerate(self.key_cache):
            if layer_keys is None:
                continue
            layer_values = self.value_cache[layer_idx]

            imp_scores = importance_scores.get(layer_idx)
            if imp_scores is None:
                imp_scores = self.compressor.compute_importance_scores(
                    layer_idx=layer_idx,
                    key_states=layer_keys,
                    value_states=layer_values,
                )

            core_k, core_v = self.compressor.compress_layer(
                layer_idx=layer_idx,
                key_states=layer_keys,
                value_states=layer_values,
                importance_scores=imp_scores,
                visual_token_mask=visual_token_mask,
                protected_mask=protected_mask,
            )

            self.key_cache[layer_idx] = core_k
            self.value_cache[layer_idx] = core_v
            self.pre_lens[layer_idx] = core_k.shape[2]

        self._is_compressed = True

    def merge_other_kv(self, other_cache: "TASMCache"):
        """Append another cache's K/V to this one, layer by layer.

        Layers that are empty in ``other_cache`` are skipped. Layers this
        cache has not allocated yet are created first, so merging into a
        shorter (or empty) cache works. Where this cache has no data for a
        layer, the other cache's tensors are adopted as-is (same objects).
        """
        for layer_idx, other_k in enumerate(other_cache.key_cache):
            if other_k is None:
                continue
            other_v = other_cache.value_cache[layer_idx]

            self._ensure_layer(layer_idx)

            if self.key_cache[layer_idx] is not None:
                self.key_cache[layer_idx] = torch.cat(
                    [self.key_cache[layer_idx], other_k], dim=2
                )
                self.value_cache[layer_idx] = torch.cat(
                    [self.value_cache[layer_idx], other_v], dim=2
                )
            else:
                self.key_cache[layer_idx] = other_k
                self.value_cache[layer_idx] = other_v

    def get_augmented(
        self,
        layer_idx: int,
        query_states: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the compressor's augmented K/V for ``layer_idx``.

        The layer's cached K/V (or ``None`` when the layer is not allocated)
        is passed to ``compressor.get_augmented_kv`` as the current K/V.
        """
        current_k = self.key_cache[layer_idx] if layer_idx < len(self.key_cache) else None
        current_v = self.value_cache[layer_idx] if layer_idx < len(self.value_cache) else None

        return self.compressor.get_augmented_kv(
            layer_idx=layer_idx,
            query_states=query_states,
            current_keys=current_k,
            current_values=current_v,
        )


def create_tasm_compressor(
    model_config,
    tasm_config: Optional[TASMConfig] = None,
) -> TASMCompressor:
    """Build a TASMCompressor from a model configuration object.

    Args:
        model_config: Object exposing ``hidden_size``,
            ``num_attention_heads`` and ``num_hidden_layers``. Optional
            ``num_key_value_heads`` and ``head_dim`` attributes are honoured
            when present and not ``None``.
        tasm_config: Compression settings; passed through unchanged (the
            compressor substitutes its default when ``None``).

    Returns:
        A newly constructed ``TASMCompressor``.
    """
    num_heads = model_config.num_attention_heads

    # Without a usable KV-head count, assume one KV head per attention head.
    num_kv_heads = getattr(model_config, 'num_key_value_heads', None)
    if num_kv_heads is None:
        num_kv_heads = num_heads

    # Prefer an explicit head_dim: it is not guaranteed to equal
    # hidden_size // num_attention_heads for every architecture.
    head_dim = getattr(model_config, 'head_dim', None)
    if head_dim is None:
        head_dim = model_config.hidden_size // num_heads

    return TASMCompressor(
        hidden_size=model_config.hidden_size,
        num_heads=num_heads,
        num_kv_heads=num_kv_heads,
        num_layers=model_config.num_hidden_layers,
        head_dim=head_dim,
        config=tasm_config,
    )
