import torch
import torch.nn as nn
from kvpress import BasePress, ScorerPress
from dataclasses import dataclass

def exact_leverage_score(K: torch.Tensor, V: torch.Tensor, k: int, method="value"):
    """
    Compute normalized rank-``k`` leverage scores of the keys or the values.

    The selected tensor is cast to float32, centered along dim 2, and
    decomposed with a batched thin SVD. The leverage score of a row is the
    squared norm of that row restricted to the first ``k`` left singular
    vectors. Scores are then normalized to sum to 1 along the last dimension.

    Parameters
    ----------
    K : torch.Tensor
        Key tensor. Must be 4-D (the code indexes four dimensions);
        presumably (batch, heads, seq_len, head_dim) -- confirm against caller.
    V : torch.Tensor
        Value tensor, same layout expectations as ``K``.
    k : int
        Requested number of leading left singular vectors. It is capped at
        ``r - 2`` where ``r`` is the number of singular vectors returned by
        the thin SVD, and floored at 1 so the result is always well defined.
    method : str
        ``"key"`` to score ``K`` or ``"value"`` to score ``V``.

    Returns
    -------
    torch.Tensor
        float32 tensor with the shape of the selected input minus its last
        dimension. Each slice along the last dimension sums to 1.

    Raises
    ------
    AssertionError
        If ``method`` is neither ``"key"`` nor ``"value"``.
    """
    assert method in ["key", "value"], f"method must be 'key' or 'value', got {method!r}"

    X = K if method == "key" else V

    # SVD is computed in float32 regardless of the input dtype.
    X = X.float()
    X_centered = X - X.mean(dim=2, keepdim=True)

    # Only the left singular vectors are needed for leverage scores.
    U, _, _ = torch.linalg.svd(X_centered, full_matrices=False)

    # Keep the original cap of (r - 2), but never let k drop below 1:
    # k == 0 gives all-zero scores and a 0/0 = NaN normalization, and a
    # negative k would slice columns off the end instead of taking a prefix.
    k = max(1, min(k, U.shape[-1] - 2))
    U_k = U[:, :, :, :k]  # first k left singular vectors: (..., rows, k)

    leverage_scores = torch.sum(U_k ** 2, dim=-1)

    # Columns of U are orthonormal, so the sum over rows equals k >= 1.
    return leverage_scores / leverage_scores.sum(dim=-1, keepdim=True)

@dataclass
class ExactCURPress(ScorerPress):
    """
    Scorer press that rates cached positions with exact (SVD-based) leverage scores.

    `score` delegates to `exact_leverage_score` on the keys, on the values, or
    on both (multiplying the two results element-wise), depending on
    `leverage_type`. How the returned scores are turned into a pruned cache is
    handled by the `ScorerPress` base class, not here.
    """

    # Fraction used to derive the subspace rank passed to
    # `exact_leverage_score`: k = int((1 - compression_ratio) * keys.shape[2]).
    compression_ratio: float
    # NOTE(review): declared but not read by `score` in this class; presumably
    # meant to protect leading "sink" tokens -- confirm where/if it is used.
    num_sinks: int = 4
    # Which tensor(s) to score: "key", "value", or "kv_product".
    leverage_type: str = 'kv_product'

    def score(
        self,
        module: nn.Module,
        hidden_states: torch.Tensor,
        keys: torch.Tensor,
        values: torch.Tensor,
        attentions: torch.Tensor,
        kwargs: dict,
    ) -> torch.Tensor:
        """
        Compute leverage scores for the cached keys/values.

        Parameters
        ----------
        module :
            Attention module the press is attached to (unused here)
        hidden_states :
            Hidden states of the layer (unused here)
        keys :
            Keys of the cache; dim 2 is used as the length from which the
            subspace rank is derived
        values :
            Values of the cache
        attentions :
            Attention weights of the layer (unused here)
        kwargs :
            Keyword arguments, as given to the forward pass of the layer (unused here)

        Returns
        -------
        torch.Tensor
            Leverage scores as returned by `exact_leverage_score`. For
            ``leverage_type == "kv_product"`` this is the element-wise product
            of the key scores and the value scores (not renormalized).

        Raises
        ------
        AssertionError
            If `leverage_type` is not one of "key", "value", "kv_product".
        """
        # Rank of the subspace used for the leverage scores.
        num_selection = int((1 - self.compression_ratio) * keys.shape[2])

        assert self.leverage_type in ["key", "value", "kv_product"]

        if self.leverage_type == 'kv_product':
            key_scores = exact_leverage_score(keys, values, k=num_selection, method='key')
            value_scores = exact_leverage_score(keys, values, k=num_selection, method='value')
            return key_scores * value_scores

        # "key" or "value": a single pass over the chosen tensor.
        return exact_leverage_score(keys, values, k=num_selection, method=self.leverage_type)