import logging
from dataclasses import dataclass

import torch
from torch import nn
from transformers.models.llama.modeling_llama import repeat_kv

from kvpress.presses.base_press import BasePress
from kvpress.presses.scorer_press import ScorerPress
from kvpress.presses.efficient_ada_scorer_press import EfficientAdaScorerPress
from kvpress.vw_norm import vw_l1norm

# Module-level logger, named after this module so log records can be filtered per file.
logger = logging.getLogger(__name__)


class CriticalKVPress(ScorerPress):
    """
    CriticalKV rescales the scores of a ScorerPress by
    the L1 norm of Wo @ values.

    Scoring happens in two stages:

    1. The wrapped press scores every position; the best ``first_stage_ratio`` share of the kept budget is
       remembered so it can be pinned to the maximum score at the end.
    2. All scores are rescaled to ``(score + epsilon) * projected_norm`` where ``projected_norm`` comes from
       :meth:`vwl1norm`.

    Parameters
    ----------
    press : ScorerPress
        Press whose ``score`` output is rescaled. Its ``compression_ratio`` is exposed (read/write) through this
        wrapper's ``compression_ratio`` property.
    epsilon : float
        Added to the wrapped scores before the multiplication.
    first_stage_ratio : float
        Fraction of the kept budget that is selected from the wrapped press's scores alone.
    """

    def __init__(self, press: ScorerPress, epsilon: float = 1e-4, first_stage_ratio: float = 0.5):
        self.press = press
        self.epsilon = epsilon
        self.first_stage_ratio = first_stage_ratio

        assert isinstance(self.press, ScorerPress), "CriticalKVPress requires a ScorerPress as input"

    @property
    def compression_ratio(self):
        """Compression ratio of the wrapped press (this wrapper stores none of its own)."""
        return self.press.compression_ratio

    @compression_ratio.setter
    def compression_ratio(self, value):
        self.press.compression_ratio = value

    @staticmethod
    def vwl1norm(values, module):
        """
        Return the output-projected value norm, averaged over each key/value head's query-head group.

        ``values`` is unpacked as ``(bsz, num_key_value_heads, q_len, head_dim)``. ``module`` must expose
        ``o_proj.weight`` and a ``config`` with ``num_attention_heads``, ``head_dim`` and ``hidden_size``.
        The result has shape ``(bsz, num_key_value_heads, q_len)``.
        """
        bsz, num_key_value_heads, q_len, _ = values.shape
        num_attention_heads = module.config.num_attention_heads
        num_key_value_groups = num_attention_heads // num_key_value_heads

        # Split the transposed output projection into one (head_dim, hidden_size) block per attention head.
        Wo = module.o_proj.weight.transpose(0, 1)
        Wo = Wo.view(num_attention_heads, module.config.head_dim, module.config.hidden_size)
        V = repeat_kv(values, num_key_value_groups)

        # NOTE(review): vw_l1norm is assumed to return one norm per (batch, attention head, position) - confirm.
        WoV_norm = vw_l1norm(V, Wo)

        # Regroup with the same group count that was passed to repeat_kv; relying on a separate
        # ``module.num_key_value_groups`` attribute could disagree with it or be missing.
        WoV_norm = WoV_norm.view(bsz, num_key_value_heads, num_key_value_groups, q_len).mean(dim=2)

        return WoV_norm

    def score(self, module, hidden_states, keys, values, attentions, kwargs):
        """
        Score positions with the wrapped press, rescale by the projected value norm, and pin the stage-1 picks.

        All arguments are forwarded unchanged to ``self.press.score``. The returned tensor holds
        ``torch.finfo(dtype).max`` at the stage-1 positions and ``(score + epsilon) * projected_norm`` elsewhere.
        """
        # Stage 1: remember the top entries according to the wrapped press alone.
        scores = self.press.score(module, hidden_states, keys, values, attentions, kwargs)
        q_len = keys.shape[2]
        selection_budget = int((1 - self.compression_ratio) * q_len * self.first_stage_ratio)
        top_k_index = torch.topk(scores, selection_budget, sorted=True, dim=-1).indices

        # Stage 2: rescale every score; epsilon keeps zero scores from erasing the norm signal.
        projected_norm = self.vwl1norm(values, module)
        scores = (scores + self.epsilon) * projected_norm

        # Merge the two stages: stage-1 picks receive the largest representable score.
        scores.scatter_(-1, top_k_index, torch.finfo(scores.dtype).max)

        return scores


@dataclass
class EfficientAdaCriticalKVPress(EfficientAdaScorerPress):
    """
    Two-stage CriticalKV scoring on top of a wrapped ScorerPress, returning per-batch flattened scores.

    The hand-written ``__init__`` below takes precedence over the one ``@dataclass`` would generate, so the
    constructor accepts only ``press``, ``epsilon`` and ``first_stage_ratio``. The annotated fields keep their
    class-level defaults unless assigned on the instance afterwards.

    Parameters
    ----------
    press : ScorerPress
        Press whose ``score`` output is rescaled. Its ``compression_ratio`` is exposed (read/write) through this
        wrapper's ``compression_ratio`` property.
    epsilon : float
        Added to the wrapped scores before multiplying by the projected value norm.
    first_stage_ratio : float
        Fraction of the kept budget that is selected from the wrapped press's scores alone.
    """

    # window_size, kernel_size and use_vnorm are only read by __str__ in this class;
    # presumably they are consumed elsewhere - TODO confirm.
    window_size: int = 32
    kernel_size: int = 5
    use_vnorm: bool = True
    # Fraction of the kept budget whose top-scoring entries are pinned to the maximum score in score().
    alpha_safeguard: float = 0.20

    def __init__(self, press: ScorerPress, epsilon: float = 1e-4, first_stage_ratio: float = 0.5):
        self.press = press
        self.epsilon = epsilon
        self.first_stage_ratio = first_stage_ratio

    def __str__(self):
        return (
            f"EfficientAdaCritical_com_ratio={self.compression_ratio}, wind_size={self.window_size}, "
            f"kerl_size={self.kernel_size}, use_vnorm={self.use_vnorm}"
        )

    @property
    def compression_ratio(self):
        """Compression ratio of the wrapped press (this wrapper stores none of its own)."""
        return self.press.compression_ratio

    @compression_ratio.setter
    def compression_ratio(self, value):
        self.press.compression_ratio = value

    @staticmethod
    def vwl1norm(values, module):
        """
        Return the output-projected value norm, averaged over each key/value head's query-head group.

        ``values`` is unpacked as ``(bsz, num_key_value_heads, q_len, head_dim)``. ``module`` must expose
        ``o_proj.weight`` and a ``config`` with ``num_attention_heads``, ``head_dim`` and ``hidden_size``.
        The result has shape ``(bsz, num_key_value_heads, q_len)``.
        """
        bsz, num_key_value_heads, q_len, _ = values.shape
        num_attention_heads = module.config.num_attention_heads
        num_key_value_groups = num_attention_heads // num_key_value_heads

        # Split the transposed output projection into one (head_dim, hidden_size) block per attention head.
        Wo = module.o_proj.weight.transpose(0, 1)
        Wo = Wo.view(num_attention_heads, module.config.head_dim, module.config.hidden_size)
        V = repeat_kv(values, num_key_value_groups)

        # NOTE(review): vw_l1norm is assumed to return one norm per (batch, attention head, position) - confirm.
        WoV_norm = vw_l1norm(V, Wo)

        # Regroup with the same group count that was passed to repeat_kv; relying on a separate
        # ``module.num_key_value_groups`` attribute could disagree with it or be missing.
        WoV_norm = WoV_norm.view(bsz, num_key_value_heads, num_key_value_groups, q_len).mean(dim=2)

        return WoV_norm

    def score(
        self,
        module: nn.Module,
        hidden_states: torch.Tensor,
        keys: torch.Tensor,
        values: torch.Tensor,
        attentions: torch.Tensor,
        kwargs,
    ) -> torch.Tensor:
        """
        Compute CriticalKV scores and return them flattened to ``(bsz, num_key_value_heads * q_len)``.

        ``kwargs["metadata"]`` must provide ``bsz``, ``num_key_value_heads`` and ``head_lens``; all entries of
        ``head_lens`` must be equal (only a first compression is supported). ``keys`` and ``values`` are viewed
        as ``(bsz, num_key_value_heads, head_lens[0], head_dim)`` before being passed to the wrapped press.

        Raises
        ------
        AssertionError
            If the metadata is missing or the per-head lengths differ.
        """
        cache_metadata = kwargs.get("metadata", None)
        assert cache_metadata is not None, "cache_metadata is required for EfficientAdaCriticalKVPress"

        # Current implementation only allows to compress once:
        # equal per-head lengths mean no head has been compressed yet.
        head_lens = cache_metadata.head_lens
        assert all(
            x == head_lens[0] for x in head_lens
        ), "Not all elements in head_lens are the same, implying multiple compressions"

        # Convert to (bsz, num_key_value_heads, q_len, head_dim) for easy scoring.
        # Each tensor keeps its own last dimension, so a value head dim that differs from the key one still works.
        keys = keys.view(
            cache_metadata.bsz, cache_metadata.num_key_value_heads, cache_metadata.head_lens[0], keys.shape[-1]
        )
        values = values.view(
            cache_metadata.bsz, cache_metadata.num_key_value_heads, cache_metadata.head_lens[0], values.shape[-1]
        )
        bsz, num_key_value_heads, q_len, _ = keys.shape

        # NOTE(review): scores is assumed to be (bsz, num_key_value_heads, q_len) - confirm against the press.
        scores = self.press.score(module, hidden_states, keys, values, attentions, kwargs)

        # Safeguard: pin the best alpha_safeguard share of the kept budget in every head.
        n_kept = int(q_len * (1 - self.compression_ratio))
        n_safe = int(n_kept * self.alpha_safeguard)
        safe_index = torch.topk(scores, n_safe, dim=-1).indices
        scores.scatter_(-1, safe_index, torch.finfo(scores.dtype).max)

        # Stage 1: remember the top entries of the (safeguarded) wrapped scores.
        selection_budget = int((1 - self.compression_ratio) * q_len * self.first_stage_ratio)
        top_k_index = torch.topk(scores, selection_budget, sorted=True, dim=-1).indices

        # Stage 2: rescale every score by the projected value norm.
        projected_norm = self.vwl1norm(values, module)
        scores = (scores + self.epsilon) * projected_norm

        # Merge the two stages: stage-1 picks receive the largest representable score.
        scores.scatter_(-1, top_k_index, torch.finfo(scores.dtype).max)
        # Re-pin the safeguarded entries as well. Any that fell outside the stage-1 selection were just
        # multiplied by the norm, which can overflow to inf or collapse to 0; this is a no-op otherwise.
        scores.scatter_(-1, safe_index, torch.finfo(scores.dtype).max)

        # Flatten scores across heads.
        return scores.view(bsz, num_key_value_heads * q_len)

