import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Dict, List, Tuple, Any
from dataclasses import dataclass


@dataclass
class TaskVectorConfig:
    """Hyperparameters shared by the task-vector extractor, scorer and compressor.

    The size fields describe the geometry of the host model; the remaining
    fields control how task vectors are extracted and how raw token scores are
    turned into importance weights.
    """

    # Width of embeddings and task vectors; sizes the projector and scorer
    # linear layers and the second dimension of the running-mean buffer.
    hidden_size: int = 3584
    # Head count of the host model (the factory fills it from
    # ``num_attention_heads``).
    num_heads: int = 28
    # Key/value head count of the host model (the factory fills it from
    # ``num_key_value_heads``).
    num_kv_heads: int = 4
    # Number of layers; sizes the per-layer running statistics and bounds the
    # loop that extracts one task vector per layer.
    num_layers: int = 28
    # Per-head width. NOTE(review): the scorer takes the head width from the
    # key tensor's shape rather than from this field.
    head_dim: int = 128
    # Default extraction strategy: "embedding_diff", "head_activation" or
    # "combined"; any other value raises ValueError at extraction time.
    extraction_method: str = "combined"
    # Weight given to the newest task vector in the per-layer exponential
    # moving average.
    ema_alpha: float = 0.1
    # Divisor applied to projection scores in ``compute_importance``.
    temperature: float = 1.0
    # Constant added to every min-shifted score so the lowest-scoring token
    # still receives a non-zero weight.
    min_importance: float = 0.01
    # When True, each extracted task vector is L2-normalised.
    normalize_per_layer: bool = True

class TaskVectorExtractor(nn.Module):
    """Derive per-layer "task vectors" from question/answer representations.

    A task vector is the batch-averaged difference between answer and question
    representations, computed on the raw inputs, on their ``task_projector``
    images, or on an equal blend of the two. For every layer the module also
    keeps an exponential moving average (EMA) of the vectors it produced, which
    is returned when no vector has been cached explicitly for that layer.
    """

    def __init__(self, config: "TaskVectorConfig"):
        """Build the projector MLP and the per-layer running statistics.

        Args:
            config: Object exposing ``hidden_size``, ``num_layers``,
                ``extraction_method``, ``ema_alpha`` and
                ``normalize_per_layer``.
        """
        super().__init__()
        self.config = config

        self.task_projector = nn.Sequential(
            nn.Linear(config.hidden_size, config.hidden_size),
            nn.GELU(),
            nn.Linear(config.hidden_size, config.hidden_size),
        )

        # Explicitly cached vectors; they take precedence over the EMA.
        self.cached_task_vectors: Dict[int, torch.Tensor] = {}

        self.register_buffer('running_mean', torch.zeros(config.num_layers, config.hidden_size))
        self.register_buffer('num_updates', torch.zeros(config.num_layers))

    @staticmethod
    def _pool_sequence(embeddings: torch.Tensor) -> torch.Tensor:
        """Mean-pool dim 1 of a 3-D tensor; return any other rank unchanged."""
        if embeddings.dim() == 3:
            return embeddings.mean(dim=1)
        return embeddings

    def _projected_diff(self, q_repr: torch.Tensor, a_repr: torch.Tensor) -> torch.Tensor:
        """Batch-averaged answer-minus-question difference after projection."""
        return (self.task_projector(a_repr) - self.task_projector(q_repr)).mean(dim=0)

    def _update_running_mean(self, layer_idx: int, task_vector: torch.Tensor) -> None:
        """Fold ``task_vector`` into the EMA of ``layer_idx``.

        The update runs in training mode, or once in eval mode for a layer that
        has never been updated. It is done without autograd tracking on a
        detached copy: writing a grad-carrying tensor into the buffer in place
        would attach a ``grad_fn`` to the buffer, chain the graphs of all
        previous calls together and make the module impossible to deep-copy.
        """
        if not self.training and self.num_updates[layer_idx] != 0:
            return

        alpha = self.config.ema_alpha
        with torch.no_grad():
            sample = task_vector.detach().to(
                device=self.running_mean.device, dtype=self.running_mean.dtype
            )
            self.running_mean[layer_idx] = (
                alpha * sample + (1 - alpha) * self.running_mean[layer_idx]
            )
            self.num_updates[layer_idx] += 1

    def extract_task_vector(
        self,
        question_embeddings: torch.Tensor,
        answer_embeddings: torch.Tensor,
        layer_idx: int,
        method: Optional[str] = None
    ) -> torch.Tensor:
        """Compute the task vector for one layer and update that layer's EMA.

        Args:
            question_embeddings: Question representations. A 3-D tensor is
                mean-pooled over dim 1; any other rank is used as given.
            answer_embeddings: Answer representations, handled the same way.
            layer_idx: Row of the running statistics to update.
            method: "embedding_diff", "head_activation" or "combined";
                falls back to ``config.extraction_method`` when falsy.

        Returns:
            The task vector (reduced over dim 0 of the pooled inputs),
            L2-normalised when ``config.normalize_per_layer`` is set. It stays
            attached to the autograd graph of the projector.

        Raises:
            ValueError: If ``method`` is not one of the supported names.
        """
        method = method or self.config.extraction_method

        q_repr = self._pool_sequence(question_embeddings)
        a_repr = self._pool_sequence(answer_embeddings)

        if method == "embedding_diff":
            task_vector = (a_repr - q_repr).mean(dim=0)
        elif method == "head_activation":
            task_vector = self._projected_diff(q_repr, a_repr)
        elif method == "combined":
            raw_diff = (a_repr - q_repr).mean(dim=0)
            task_vector = 0.5 * raw_diff + 0.5 * self._projected_diff(q_repr, a_repr)
        else:
            raise ValueError(f"Unknown extraction method: {method}")

        if self.config.normalize_per_layer:
            task_vector = F.normalize(task_vector, dim=-1)

        self._update_running_mean(layer_idx, task_vector)

        return task_vector

    def get_cached_task_vector(self, layer_idx: int) -> Optional[torch.Tensor]:
        """Return the cached vector for ``layer_idx``, else its EMA, else None.

        The EMA row is only returned once the layer has received at least one
        update.
        """
        cached = self.cached_task_vectors.get(layer_idx)
        if cached is not None:
            return cached
        if self.num_updates[layer_idx] > 0:
            return self.running_mean[layer_idx]
        return None

    def cache_task_vector(self, layer_idx: int, task_vector: torch.Tensor):
        """Store ``task_vector`` as the explicit vector for ``layer_idx``."""
        self.cached_task_vectors[layer_idx] = task_vector

    def clear_cache(self):
        """Drop all explicitly cached vectors; the EMA buffers are kept."""
        self.cached_task_vectors.clear()


class TaskVectorImportanceScorer(nn.Module):
    """Turn alignment with a task vector into per-token importance weights.

    Both scoring methods produce a ``(batch, seq)`` tensor whose rows are
    non-negative and sum to (approximately) one over the unmasked positions.
    """

    def __init__(self, config: "TaskVectorConfig"):
        """Build the optional learned scoring head.

        Args:
            config: Object exposing ``hidden_size``, ``temperature`` and
                ``min_importance``.
        """
        super().__init__()
        self.config = config

        self.importance_transform = nn.Sequential(
            nn.Linear(config.hidden_size, config.hidden_size // 4),
            nn.GELU(),
            nn.Linear(config.hidden_size // 4, 1),
        )

    def _to_distribution(
        self,
        scores: torch.Tensor,
        attention_mask: Optional[torch.Tensor]
    ) -> torch.Tensor:
        """Shift, floor, mask and normalise raw ``(batch, seq)`` scores.

        Scores are shifted so the lowest one becomes ``min_importance``, masked
        positions are zeroed, and each row is divided by its sum. With a mask
        the shift uses the minimum over *unmasked* positions only, so whatever
        happens to sit in padded slots cannot change the weights of real
        tokens. Rows with no unmasked position come out as all zeros.
        """
        if attention_mask is None:
            floor = scores.min(dim=-1, keepdim=True)[0]
            weights = scores - floor + self.config.min_importance
        else:
            mask = attention_mask.to(device=scores.device, dtype=scores.dtype)
            masked_out = mask <= 0
            floor = scores.masked_fill(masked_out, float("inf")).min(dim=-1, keepdim=True)[0]
            # A fully masked row has an infinite "minimum"; use 0 so the row
            # becomes zeros instead of NaN.
            floor = torch.where(torch.isinf(floor), torch.zeros_like(floor), floor)
            weights = (scores - floor + self.config.min_importance) * mask

        return weights / (weights.sum(dim=-1, keepdim=True) + 1e-8)

    @staticmethod
    def _split_task_vector(
        task_vector: torch.Tensor,
        num_heads: int,
        head_dim: int
    ) -> torch.Tensor:
        """Reshape a flat task vector into one ``head_dim`` slice per KV head.

        * Exact fit (``num_heads * head_dim`` elements): plain reshape.
        * More slices than heads: consecutive slices are averaged in groups of
          ``slices // num_heads``; slices left over after the last full group
          are ignored rather than crashing the reshape.
        * Fewer slices than heads: the available slices are reused cyclically
          so every head gets a direction.
        """
        if task_vector.dim() != 1:
            raise ValueError(
                f"task_vector must be 1-D, got shape {tuple(task_vector.shape)}"
            )

        size = task_vector.shape[0]
        if size == num_heads * head_dim:
            return task_vector.reshape(num_heads, head_dim)

        full_num_heads = size // head_dim
        per_slice = task_vector.reshape(full_num_heads, head_dim)
        group_size = full_num_heads // num_heads
        if group_size > 0:
            usable = per_slice[: num_heads * group_size]
            return usable.reshape(num_heads, group_size, head_dim).mean(dim=1)

        picks = torch.arange(num_heads, device=task_vector.device) % full_num_heads
        return per_slice[picks]

    def compute_importance(
        self,
        hidden_states: torch.Tensor,
        task_vector: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        use_learned_transform: bool = False
    ) -> torch.Tensor:
        """Score tokens by cosine similarity of hidden states to the task vector.

        Args:
            hidden_states: 3-D tensor ``(batch, seq, hidden)``.
            task_vector: ``(hidden,)`` vector shared by the batch, or a
                ``(batch, hidden)`` tensor with one vector per sample. It is
                moved to the device and dtype of ``hidden_states``.
            attention_mask: Optional ``(batch, seq)`` mask; positions with a
                value of zero receive zero importance.
            use_learned_transform: Blend 0.3 x sigmoid of the learned head's
                output with 0.7 x the cosine score.

        Returns:
            ``(batch, seq)`` importance weights.

        Raises:
            ValueError: If ``hidden_states`` is not 3-D.
        """
        if hidden_states.dim() != 3:
            raise ValueError(
                f"hidden_states must be 3-D, got shape {tuple(hidden_states.shape)}"
            )
        batch_size = hidden_states.shape[0]

        hidden_norm = F.normalize(hidden_states, dim=-1)

        # bmm requires matching dtypes/devices; a float32 task vector against
        # half-precision states would otherwise raise.
        task_vector = task_vector.to(device=hidden_states.device, dtype=hidden_states.dtype)
        if task_vector.dim() == 1:
            task_vector = task_vector.unsqueeze(0).expand(batch_size, -1)
        task_norm = F.normalize(task_vector, dim=-1)

        projection_scores = torch.bmm(
            hidden_norm,
            task_norm.unsqueeze(-1)
        ).squeeze(-1)

        if use_learned_transform:
            learned_scores = self.importance_transform(hidden_states).squeeze(-1)
            projection_scores = 0.7 * projection_scores + 0.3 * torch.sigmoid(learned_scores)

        importance = projection_scores / self.config.temperature

        return self._to_distribution(importance, attention_mask)

    def compute_kv_importance(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        task_vector: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Score cached tokens from their key directions and value magnitudes.

        The score is 0.8 x the head-averaged cosine similarity between each
        key and the matching task-vector slice, plus 0.2 x the head-averaged
        value norm rescaled by its per-row maximum.

        Args:
            key_states: 4-D tensor ``(batch, heads, seq, head_dim)``.
            value_states: Tensor whose last dim is reduced by a norm and whose
                dim 1 is averaged, giving ``(batch, seq)``.
            task_vector: 1-D vector; split into per-head slices (see
                ``_split_task_vector``) and moved to the device and dtype of
                ``key_states``.
            attention_mask: Optional ``(batch, seq)`` mask; positions with a
                value of zero receive zero importance.

        Returns:
            ``(batch, seq)`` importance weights.
        """
        batch_size, num_heads, seq_len, head_dim = key_states.shape

        # einsum requires matching dtypes/devices between keys and the vector.
        task_vector = task_vector.to(device=key_states.device, dtype=key_states.dtype)
        task_vector_heads = self._split_task_vector(task_vector, num_heads, head_dim)
        task_vector_heads = F.normalize(task_vector_heads, dim=-1)

        key_norm = F.normalize(key_states, dim=-1)

        projections = torch.einsum(
            'bhsd,hd->bhs',
            key_norm,
            task_vector_heads
        )
        importance = projections.mean(dim=1)

        value_magnitude = value_states.norm(dim=-1).mean(dim=1)
        value_magnitude = value_magnitude / (value_magnitude.max(dim=-1, keepdim=True)[0] + 1e-8)
        value_magnitude = value_magnitude.to(device=importance.device, dtype=importance.dtype)

        importance = 0.8 * importance + 0.2 * value_magnitude

        return self._to_distribution(importance, attention_mask)


class TaskVectorGuidedCompressor(nn.Module):
    """Shrink a layer's key/value cache to the tokens most aligned with a task.

    The compressor owns a ``TaskVectorExtractor`` (produces and caches one task
    vector per layer) and a ``TaskVectorImportanceScorer`` (turns a task vector
    into per-token weights), and keeps the top-weighted tokens of each layer.
    """

    def __init__(
        self,
        hidden_size: int = 3584,
        num_heads: int = 28,
        num_kv_heads: int = 4,
        num_layers: int = 28,
        head_dim: int = 128,
        extraction_method: str = "combined",
        config: Optional["TaskVectorConfig"] = None,
        **kwargs
    ):
        """Create the extractor and scorer from a shared configuration.

        Args:
            hidden_size, num_heads, num_kv_heads, num_layers, head_dim: Model
                geometry. These always win: when ``config`` is given they
                overwrite the corresponding fields of the compressor's copy.
            extraction_method: Used only when ``config`` is None; a supplied
                ``config`` keeps its own ``extraction_method``.
            config: Optional base configuration supplying the remaining
                hyperparameters. It is shallow-copied, so the caller's object
                is left untouched.
            **kwargs: Accepted and ignored.
        """
        super().__init__()

        if config is not None:
            import copy

            # Work on a copy: overwriting the geometry fields on the caller's
            # object would silently change a config they may reuse elsewhere.
            self.config = copy.copy(config)
            self.config.hidden_size = hidden_size
            self.config.num_heads = num_heads
            self.config.num_kv_heads = num_kv_heads
            self.config.num_layers = num_layers
            self.config.head_dim = head_dim
        else:
            self.config = TaskVectorConfig(
                hidden_size=hidden_size,
                num_heads=num_heads,
                num_kv_heads=num_kv_heads,
                num_layers=num_layers,
                head_dim=head_dim,
                extraction_method=extraction_method,
            )

        self.extractor = TaskVectorExtractor(self.config)
        self.scorer = TaskVectorImportanceScorer(self.config)

        # Last importance tensor computed for each layer.
        self.cached_importance: Dict[int, torch.Tensor] = {}

    @staticmethod
    def _masked_mean(
        embeddings: torch.Tensor,
        mask: Optional[torch.Tensor]
    ) -> torch.Tensor:
        """Apply ``mask`` to ``embeddings`` and pool 3-D inputs over dim 1.

        Without a mask the input is returned unchanged. With a mask, a 3-D
        input is averaged over the *unmasked* positions only (dividing by the
        mask sum, not the padded length); inputs of any other rank are just
        multiplied by the mask.
        """
        if mask is None:
            return embeddings

        weights = mask.to(device=embeddings.device, dtype=embeddings.dtype).unsqueeze(-1)
        masked = embeddings * weights
        if embeddings.dim() != 3:
            return masked
        return masked.sum(dim=1) / weights.sum(dim=1).clamp(min=1e-8)

    @staticmethod
    def _layer_hidden(hidden_states, idx: int) -> Optional[torch.Tensor]:
        """Return ``hidden_states[idx]`` if that entry exists, else None."""
        if hidden_states is None:
            return None
        count = len(hidden_states)
        if -count <= idx < count:
            return hidden_states[idx]
        return None

    def _split_hidden(
        self,
        layer_hidden: torch.Tensor,
        question_mask: Optional[torch.Tensor],
        answer_mask: Optional[torch.Tensor],
    ) -> Optional[Tuple[torch.Tensor, torch.Tensor]]:
        """Pool one hidden state separately under the question and answer masks.

        Returns None when the masks cannot separate the two spans (a mask is
        missing, the hidden state is not 3-D, or a mask's shape differs from
        the hidden state's first two dims).
        """
        if question_mask is None or answer_mask is None or layer_hidden.dim() != 3:
            return None
        span = layer_hidden.shape[:2]
        if question_mask.shape != span or answer_mask.shape != span:
            return None
        return (
            self._masked_mean(layer_hidden, question_mask),
            self._masked_mean(layer_hidden, answer_mask),
        )

    def extract_and_cache_task_vector(
        self,
        question_embeddings: torch.Tensor,
        answer_embeddings: torch.Tensor,
        layer_idx: Optional[int] = None,
        hidden_states: Optional[List[torch.Tensor]] = None,
        question_mask: Optional[torch.Tensor] = None,
        answer_mask: Optional[torch.Tensor] = None,
    ):
        """Extract task vectors and cache them in the extractor.

        For each processed layer the question/answer inputs are chosen as:

        1. the layer's hidden state pooled under ``question_mask`` and
           ``answer_mask`` respectively, when a hidden state exists for the
           layer and both masks match its first two dims;
        2. otherwise ``question_embeddings`` / ``answer_embeddings`` (masked
           and mean-pooled over valid positions when a mask is given).

        The same hidden state is never passed as both question and answer:
        their difference is identically zero, which would cache an all-zero
        task vector and disable the guidance.

        Args:
            question_embeddings: Question representations.
            answer_embeddings: Answer representations.
            layer_idx: Layer to process; None processes every layer in
                ``range(config.num_layers)``.
            hidden_states: Optional per-layer hidden states, indexed by layer.
            question_mask: Optional mask selecting question positions.
            answer_mask: Optional mask selecting answer positions.
        """
        if layer_idx is not None:
            layer_indices = [layer_idx]
        else:
            layer_indices = range(self.config.num_layers)

        # The embedding-based pair does not depend on the layer; build it once
        # and only if some layer actually needs it.
        embedding_pair = None

        for idx in layer_indices:
            pair = None
            layer_hidden = self._layer_hidden(hidden_states, idx)
            if layer_hidden is not None:
                pair = self._split_hidden(layer_hidden, question_mask, answer_mask)

            if pair is None:
                if embedding_pair is None:
                    embedding_pair = (
                        self._masked_mean(question_embeddings, question_mask),
                        self._masked_mean(answer_embeddings, answer_mask),
                    )
                pair = embedding_pair

            task_vector = self.extractor.extract_task_vector(pair[0], pair[1], idx)
            self.extractor.cache_task_vector(idx, task_vector)

    def compute_layer_importance(
        self,
        layer_idx: int,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Compute and cache ``(batch, seq)`` token importance for one layer.

        Uses hidden-state scoring when ``hidden_states`` is given and key/value
        scoring otherwise. If the extractor has no task vector for the layer,
        a uniform ``1 / seq_len`` tensor is returned and nothing is cached.
        """
        task_vector = self.extractor.get_cached_task_vector(layer_idx)

        if task_vector is None:
            batch_size = key_states.shape[0]
            seq_len = key_states.shape[2]
            return torch.ones(batch_size, seq_len, device=key_states.device) / seq_len

        # The scorer multiplies the task vector with this tensor, which
        # requires matching device and dtype (e.g. float32 task vector vs.
        # half-precision cache).
        reference = hidden_states if hidden_states is not None else key_states
        task_vector = task_vector.to(device=reference.device, dtype=reference.dtype)

        if hidden_states is not None:
            importance = self.scorer.compute_importance(
                hidden_states, task_vector, attention_mask
            )
        else:
            importance = self.scorer.compute_kv_importance(
                key_states, value_states, task_vector, attention_mask
            )

        self.cached_importance[layer_idx] = importance

        return importance

    @staticmethod
    def _gather_tokens(states: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
        """Select ``indices`` (batch, kept) along dim 2 of a 4-D ``states``.

        The index is expanded to the tensor's own head count and width, so
        keys and values may differ in those dimensions.
        """
        index = indices.to(states.device)[:, None, :, None].expand(
            -1, states.shape[1], -1, states.shape[-1]
        )
        return torch.gather(states, dim=2, index=index)

    def compress_layer(
        self,
        layer_idx: int,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        target_length: int,
        hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        return_indices: bool = False
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Keep the ``target_length`` most important tokens of a layer's cache.

        Args:
            layer_idx: Layer whose task vector drives the selection.
            key_states: 4-D tensor ``(batch, heads, seq, head_dim)``.
            value_states: 4-D tensor with the same batch and seq dims.
            target_length: Number of tokens to keep.
            hidden_states: Optional hidden states for hidden-state scoring.
            attention_mask: Optional ``(batch, seq)`` mask forwarded to the
                scorer.
            return_indices: Also return the kept positions.

        Returns:
            ``(keys, values, indices)``. The inputs are returned unchanged when
            the sequence already fits. Kept tokens preserve their original
            order. ``indices`` is ``(batch, kept)`` or None when not requested.
        """
        batch_size, num_heads, seq_len, head_dim = key_states.shape

        if seq_len <= target_length:
            indices = None
            if return_indices:
                indices = torch.arange(seq_len, device=key_states.device)
                indices = indices.unsqueeze(0).expand(batch_size, -1)
            return key_states, value_states, indices

        importance = self.compute_layer_importance(
            layer_idx, key_states, value_states, hidden_states, attention_mask
        )

        _, indices = importance.topk(target_length, dim=-1, sorted=True)
        # Restore positional order so the compressed cache stays sequential.
        indices, _ = indices.sort(dim=-1)

        compressed_keys = self._gather_tokens(key_states, indices)
        compressed_values = self._gather_tokens(value_states, indices)

        return compressed_keys, compressed_values, (indices if return_indices else None)

    def clear_cache(self):
        """Drop cached task vectors and cached importance tensors."""
        self.extractor.clear_cache()
        self.cached_importance.clear()


def create_task_vector_compressor(
    model_config: Any,
    extraction_method: str = "combined",
    **kwargs
) -> TaskVectorGuidedCompressor:
    """Build a ``TaskVectorGuidedCompressor`` sized for a model configuration.

    Geometry is read from ``model_config`` attributes, falling back to the
    built-in defaults when an attribute is missing: ``hidden_size`` (3584),
    ``num_attention_heads`` (28), ``num_key_value_heads`` (4) and
    ``num_hidden_layers`` (28). An explicit, non-None ``head_dim`` attribute is
    honoured, because some model configurations declare a head width that
    differs from ``hidden_size // num_attention_heads``; only when it is
    absent is the head width derived by that division.

    Args:
        model_config: Any object exposing the attributes above.
        extraction_method: Forwarded to the compressor.
        **kwargs: Forwarded to the compressor constructor.

    Returns:
        The configured compressor.
    """
    hidden_size = getattr(model_config, 'hidden_size', 3584)
    num_heads = getattr(model_config, 'num_attention_heads', 28)
    num_kv_heads = getattr(model_config, 'num_key_value_heads', 4)
    num_layers = getattr(model_config, 'num_hidden_layers', 28)

    head_dim = getattr(model_config, 'head_dim', None)
    if head_dim is None:
        head_dim = hidden_size // num_heads

    return TaskVectorGuidedCompressor(
        hidden_size=hidden_size,
        num_heads=num_heads,
        num_kv_heads=num_kv_heads,
        num_layers=num_layers,
        head_dim=head_dim,
        extraction_method=extraction_method,
        **kwargs
    )
