import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple, List, Dict
from dataclasses import dataclass


@dataclass
class TokenMergingConfig:
    """Settings shared by the token matcher and token mergers in this module.

    NOTE(review): ``merge_strategy``, ``max_merge_ratio`` and
    ``importance_weight`` are not read by any code in this file -- confirm
    whether they are consumed elsewhere or are placeholders.
    """
    # Name of the merge algorithm; not read in this file.
    merge_strategy: str = 'bipartite'
    # Similarity used by BipartiteTokenMatcher: 'cosine', 'dot' or 'euclidean'
    # (any other value raises ValueError when similarities are computed).
    similarity_metric: str = 'cosine'
    # When True and spatial positions are supplied, the matcher discourages
    # matches whose positions are farther apart than ``spatial_window``.
    preserve_spatial: bool = True
    # Not read in this file; presumably a cap on the fraction of tokens
    # merged -- TODO confirm.
    max_merge_ratio: float = 0.5
    # Matches scoring below this threshold are rejected (reported as -1).
    min_similarity: float = 0.5
    # Not read in this file.
    importance_weight: float = 0.3
    # True: each merged source contributes with weight equal to its match
    # score (the kept token counts with weight 1); False: each merged source
    # counts with weight 1.
    weighted_merge: bool = True
    # Maximum Euclidean distance between source and target positions that
    # is still treated as "nearby" by the matcher.
    spatial_window: int = 3


class BipartiteTokenMatcher:
    """Assign each source token to its most similar target token.

    Despite the name, the assignment is greedy: every source row independently
    takes its highest-scoring target, so several sources may share one target.
    Assignments scoring below ``config.min_similarity`` are rejected.
    """

    def __init__(self, config: TokenMergingConfig):
        self.config = config

    @torch.no_grad()
    def compute_similarity_matrix(
        self,
        source_tokens: torch.Tensor,
        target_tokens: torch.Tensor,
    ) -> torch.Tensor:
        """Return the pairwise source-by-target similarity matrix.

        Args:
            source_tokens: 2-D tensor ``(num_sources, dim)`` (``torch.mm``
                requires 2-D inputs).
            target_tokens: 2-D tensor ``(num_targets, dim)``.

        Returns:
            ``(num_sources, num_targets)`` scores, higher meaning more similar.
            ``'cosine'`` lies in [-1, 1], ``'euclidean'`` in (0, 1] and
            ``'dot'`` is unbounded.

        Raises:
            ValueError: if ``config.similarity_metric`` is not recognised.
        """
        if self.config.similarity_metric == 'cosine':
            source_norm = F.normalize(source_tokens, p=2, dim=-1)
            target_norm = F.normalize(target_tokens, p=2, dim=-1)
            similarity = torch.mm(source_norm, target_norm.t())
        elif self.config.similarity_metric == 'dot':
            similarity = torch.mm(source_tokens, target_tokens.t())
        elif self.config.similarity_metric == 'euclidean':
            # Map a distance in [0, inf) onto a similarity in (0, 1].
            dist = torch.cdist(source_tokens, target_tokens, p=2)
            similarity = 1.0 / (1.0 + dist)
        else:
            raise ValueError(f"Unknown similarity metric: {self.config.similarity_metric}")

        return similarity

    @torch.no_grad()
    def match(
        self,
        source_tokens: torch.Tensor,
        target_tokens: torch.Tensor,
        source_indices: torch.Tensor,
        target_indices: torch.Tensor,
        spatial_positions: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Match every source token to its best target token.

        Args:
            source_tokens: ``(num_sources, dim)`` features of the sources.
            target_tokens: ``(num_targets, dim)`` features of the targets.
            source_indices: positions of the sources; used only to look up
                ``spatial_positions``.
            target_indices: positions of the targets; same purpose.
            spatial_positions: optional per-position coordinates indexable by
                ``source_indices``/``target_indices``. When given and
                ``config.preserve_spatial`` is set, targets farther than
                ``config.spatial_window`` from a source are never matched to it.

        Returns:
            ``(matches, match_weights)``: for each source, the index into the
            target set (``-1`` if rejected) and the match score (``0`` if
            rejected).
        """
        num_sources = source_tokens.shape[0]
        if target_tokens.shape[0] == 0:
            # ``max`` over an empty target dimension raises; with no targets
            # every source is simply unmatched.
            matches = torch.full(
                (num_sources,), -1, dtype=torch.long, device=source_tokens.device
            )
            return matches, source_tokens.new_zeros(num_sources)

        similarity = self.compute_similarity_matrix(source_tokens, target_tokens)

        if self.config.preserve_spatial and spatial_positions is not None:
            source_pos = spatial_positions[source_indices].float()
            target_pos = spatial_positions[target_indices].float()

            spatial_dist = torch.cdist(
                source_pos.unsqueeze(0),
                target_pos.unsqueeze(0),
                p=2,
            ).squeeze(0)

            # Hard constraint: a fixed additive penalty can be outweighed by an
            # unbounded metric such as 'dot', letting out-of-window targets win.
            similarity = similarity.masked_fill(
                spatial_dist > self.config.spatial_window, float('-inf')
            )

        match_weights, matches = similarity.max(dim=-1)

        # A -inf winner means every target was outside the spatial window.
        valid_mask = (match_weights >= self.config.min_similarity) & (
            match_weights > float('-inf')
        )
        matches[~valid_mask] = -1
        match_weights[~valid_mask] = 0.0

        return matches, match_weights


class SemanticTokenMerger(nn.Module):
    """Compress key/value states by merging low-importance tokens into kept ones.

    For each batch element the ``num_keep`` highest-importance positions are
    kept ("targets"). Every remaining position ("source") is matched to a
    target by similarity of head-averaged keys and folded into it. Sources
    without an acceptable match are dropped.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        head_dim: Optional[int] = None,
        config: Optional[TokenMergingConfig] = None,
    ):
        super().__init__()

        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = head_dim if head_dim else hidden_size // num_heads
        self.config = config if config else TokenMergingConfig()

        self.matcher = BipartiteTokenMatcher(self.config)

    @torch.no_grad()
    def split_by_importance(
        self,
        importance_scores: torch.Tensor,
        num_keep: int,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Split positions into kept targets and mergeable sources.

        Args:
            importance_scores: ``(seq_len,)`` scores, or a 2-D tensor that is
                averaged over dim 0 first.
            num_keep: number of highest-scoring positions to keep; clamped to
                ``[0, seq_len]``.

        Returns:
            ``(target_indices, source_indices)``, each sorted ascending.
        """
        if importance_scores.dim() == 2:
            importance_scores = importance_scores.mean(dim=0)

        seq_len = importance_scores.shape[0]
        # Clamp both ends: a negative value would otherwise slice from the end.
        num_keep = max(0, min(num_keep, seq_len))

        _, sorted_indices = importance_scores.sort(descending=True)
        target_indices = sorted_indices[:num_keep].sort().values
        source_indices = sorted_indices[num_keep:].sort().values

        return target_indices, source_indices

    @staticmethod
    def _merged_source_mask(
        matches: torch.Tensor,
        match_weights: torch.Tensor,
    ) -> torch.Tensor:
        """Return True for sources that are actually folded into a target."""
        return (matches >= 0) & (match_weights > 0)

    @torch.no_grad()
    def merge_tokens(
        self,
        tokens: torch.Tensor,
        source_indices: torch.Tensor,
        target_indices: torch.Tensor,
        matches: torch.Tensor,
        match_weights: torch.Tensor,
    ) -> torch.Tensor:
        """Fold matched source tokens into their targets.

        Args:
            tokens: ``(seq_len, dim)`` or ``(num_heads, seq_len, dim)``.
            source_indices: source positions along the sequence dimension.
            target_indices: target positions along the sequence dimension.
            matches: per source, an index into ``target_indices`` or ``-1``.
            match_weights: per-source match score.

        Returns:
            ``(len(target_indices), dim)`` or
            ``(num_heads, len(target_indices), dim)`` merged tokens.
        """
        return self._merge_single_head(
            tokens, source_indices, target_indices, matches, match_weights
        )

    def _merge_single_head(
        self,
        tokens: torch.Tensor,
        source_indices: torch.Tensor,
        target_indices: torch.Tensor,
        matches: torch.Tensor,
        match_weights: torch.Tensor,
    ) -> torch.Tensor:
        """Weighted-average merge along the sequence dimension (dim -2).

        Despite its name this handles any number of leading dimensions, so
        all heads are merged in one vectorized pass.
        """
        device = tokens.device
        dtype = tokens.dtype
        seq_dim = tokens.dim() - 2

        source_indices = source_indices.to(device)
        target_indices = target_indices.to(device)
        matches = matches.to(device)
        match_weights = match_weights.to(device)

        merged_mask = self._merged_source_mask(matches, match_weights)
        dest = matches[merged_mask]
        if self.config.weighted_merge:
            weights = match_weights[merged_mask].to(dtype)
        else:
            weights = torch.ones(dest.shape[0], device=device, dtype=dtype)

        # Each kept token contributes itself with weight 1.
        merged = tokens.index_select(seq_dim, target_indices)
        weight_sum = torch.ones(target_indices.shape[0], device=device, dtype=dtype)

        contributions = tokens.index_select(seq_dim, source_indices[merged_mask])
        merged.index_add_(seq_dim, dest, contributions * weights.unsqueeze(-1))
        weight_sum.index_add_(0, dest, weights)

        return merged / weight_sum.unsqueeze(-1).clamp(min=1e-8)

    def _merge_sequence(
        self,
        keys: torch.Tensor,
        values: torch.Tensor,
        importance: torch.Tensor,
        num_keep: int,
        spatial_positions: Optional[torch.Tensor],
        protected: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Merge one batch element; ``keys``/``values`` are ``(heads, seq_len, dim)``."""
        seq_len = keys.shape[1]
        device = keys.device

        if protected is not None:
            # Protected positions always rank first and are kept when possible.
            importance = importance.clone()
            importance[protected] = float('inf')

        target_idx, source_idx = self.split_by_importance(importance, num_keep)
        target_idx = target_idx.to(device)
        source_idx = source_idx.to(device)

        if source_idx.numel() == 0:
            return keys, values, torch.arange(seq_len, device=device)

        # Matching uses keys averaged over heads, shared by K and V merges.
        k_avg = keys.mean(dim=0)
        matches, match_weights = self.matcher.match(
            k_avg[source_idx], k_avg[target_idx],
            source_idx, target_idx,
            spatial_positions,
        )

        merged_k = self.merge_tokens(keys, source_idx, target_idx, matches, match_weights)
        merged_v = self.merge_tokens(values, source_idx, target_idx, matches, match_weights)

        # Map every original position to its merged slot, or -1 when the
        # token was not folded anywhere (same predicate as the merge itself).
        mapping = torch.zeros(seq_len, dtype=torch.long, device=device)
        mapping[target_idx] = torch.arange(target_idx.shape[0], device=device)
        merged_mask = self._merged_source_mask(matches, match_weights)
        mapping[source_idx] = matches.to(torch.long).masked_fill(~merged_mask, -1)

        return merged_k, merged_v, mapping

    def forward(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        importance_scores: torch.Tensor,
        num_keep: int,
        spatial_positions: Optional[torch.Tensor] = None,
        protected_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Merge key/value states for every batch element.

        Args:
            key_states: ``(batch, num_heads, seq_len, head_dim)``.
            value_states: ``(batch, heads_v, seq_len, dim_v)``; head count,
                head dim and dtype may differ from the keys.
            importance_scores: ``(batch, seq_len)``, or 3-D scores averaged
                over dim 1 (presumably heads -- confirm with callers).
            num_keep: number of positions kept per batch element.
            spatial_positions: optional per-position coordinates forwarded to
                the matcher.
            protected_mask: optional per-batch index/mask of positions forced
                to the top of the importance ranking.

        Returns:
            ``(merged_keys, merged_values, merge_mapping)``. The merged tensors
            are zero-padded to the longest merged length in the batch.
            ``merge_mapping`` is ``(batch, seq_len)`` with the merged slot of
            each original position, or ``-1`` for dropped tokens.
        """
        batch_size, num_heads, seq_len, head_dim = key_states.shape

        merged_keys_list = []
        merged_values_list = []
        merge_mapping_list = []

        for b in range(batch_size):
            if importance_scores.dim() == 3:
                imp = importance_scores[b].mean(dim=0)
            else:
                imp = importance_scores[b]

            merged_k, merged_v, mapping = self._merge_sequence(
                key_states[b],
                value_states[b],
                imp,
                num_keep,
                spatial_positions,
                protected_mask[b] if protected_mask is not None else None,
            )
            merged_keys_list.append(merged_k)
            merged_values_list.append(merged_v)
            merge_mapping_list.append(mapping)

        max_merged_len = max(mk.shape[1] for mk in merged_keys_list)

        # Allocate K and V buffers separately so differing value shapes/dtypes work.
        merged_keys = key_states.new_zeros(
            batch_size, num_heads, max_merged_len, head_dim
        )
        merged_values = value_states.new_zeros(
            batch_size, value_states.shape[1], max_merged_len, value_states.shape[-1]
        )

        for b, (mk, mv) in enumerate(zip(merged_keys_list, merged_values_list)):
            merged_keys[b, :, :mk.shape[1], :] = mk
            merged_values[b, :, :mv.shape[1], :] = mv

        merge_mapping = torch.stack(merge_mapping_list, dim=0)

        return merged_keys, merged_values, merge_mapping


class VisualTokenMerger(SemanticTokenMerger):
    """SemanticTokenMerger that supplies row/column positions for image patches.

    Positions come from a row-major patch grid, so the matcher can restrict
    merges to spatially nearby patches when ``config.preserve_spatial`` is set.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        head_dim: Optional[int] = None,
        config: Optional[TokenMergingConfig] = None,
        grid_size: Tuple[int, int] = (14, 14),
    ):
        if config is None:
            config = TokenMergingConfig(preserve_spatial=True)
        super().__init__(hidden_size, num_heads, head_dim, config)

        self.grid_size = grid_size

        # Not read anywhere in this class; kept for attribute compatibility.
        self._spatial_positions = None

    def get_spatial_positions(self, seq_len: int, device: torch.device) -> torch.Tensor:
        """Return ``(seq_len, 2)`` (row, col) coordinates laid out row-major.

        Uses ``self.grid_size`` when it has exactly ``seq_len`` cells. Otherwise
        it falls back to a width of ``seq_len // floor(sqrt(seq_len))``.
        """
        import math

        if seq_len == 0:
            return torch.zeros(0, 2, device=device)

        h, w = self.grid_size
        if seq_len != h * w:
            side = int(math.sqrt(seq_len))  # >= 1 because seq_len >= 1
            w = seq_len // side

        flat = torch.arange(seq_len, device=device)
        rows = torch.div(flat, w, rounding_mode='floor')
        cols = flat - rows * w
        # Same dtype torch.zeros would produce (the default float dtype).
        return torch.stack((rows, cols), dim=-1).to(torch.get_default_dtype())

    def _positions_from_mask(
        self,
        visual_token_mask: torch.Tensor,
        seq_len: int,
        device: torch.device,
    ) -> torch.Tensor:
        """Build full-length positions from a visual-token mask.

        Visual tokens receive grid coordinates in order of appearance. All
        other tokens share one point placed more than ``spatial_window`` away
        from every grid cell, so they never merge with visual tokens under the
        spatial constraint.

        NOTE(review): only the first batch row of a 2-D mask is used, because
        the base class shares one positions tensor across the batch. The mask
        length is assumed to equal ``seq_len``.
        """
        mask = visual_token_mask[0] if visual_token_mask.dim() > 1 else visual_token_mask
        mask = mask.to(device=device, dtype=torch.bool)
        num_visual = int(mask.sum().item())

        grid = self.get_spatial_positions(num_visual, device)
        # Distance from this point to any non-negative grid cell is at least
        # sqrt(2) * (spatial_window + 1) > spatial_window.
        offset = float(self.config.spatial_window) + 1.0
        positions = torch.full((seq_len, 2), -offset, device=device, dtype=grid.dtype)
        positions[mask] = grid
        return positions

    def forward(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        importance_scores: torch.Tensor,
        num_keep: int,
        visual_token_mask: Optional[torch.Tensor] = None,
        protected_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Merge tokens using grid positions; see SemanticTokenMerger.forward.

        Args:
            visual_token_mask: optional ``(batch, seq_len)`` or ``(seq_len,)``
                mask of visual tokens. Without it, every position is treated
                as a grid cell.
        """
        seq_len = key_states.shape[2]
        device = key_states.device

        if visual_token_mask is not None:
            # Positions must cover the whole sequence because the matcher
            # indexes them with full-sequence source/target indices.
            spatial_positions = self._positions_from_mask(visual_token_mask, seq_len, device)
        else:
            spatial_positions = self.get_spatial_positions(seq_len, device)

        return super().forward(
            key_states=key_states,
            value_states=value_states,
            importance_scores=importance_scores,
            num_keep=num_keep,
            spatial_positions=spatial_positions,
            protected_mask=protected_mask,
        )
