import os
import pdb
import copy
import math
import numpy as np 
from dataclasses import dataclass
from typing import Optional, Tuple, Union

import abc
import torch
from torch import nn
import torch.utils.checkpoint
import torch.nn.functional as F
from torch.cuda.amp import autocast
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from transformers import LlamaPreTrainedModel
from transformers.cache_utils import Cache, DynamicCache, StaticCache
from .base import MemoryPolicy, ParamMemoryPolicy
from utils import faster_attn_reversecumsum


class L2(MemoryPolicy):
    """KV-cache eviction policy driven by the squared L2 norm of the keys.

    Implementation of the algorithm described in
    https://arxiv.org/abs/2406.11430

    For every head, the ``cache_size`` cached tokens whose keys have the
    smallest squared norm are retained; the rest are dropped.
    """

    def __init__(self, cache_size):
        """Forward ``cache_size`` (the per-layer token budget) to the base policy."""
        super().__init__(cache_size=cache_size)

    def update_layer_cache(
            self, layer_id, key_cache, value_cache, num_new_tokens,
            attn_mask=None, **kwargs,) -> Tuple[torch.Tensor, torch.Tensor]:
        """Shrink one layer's key/value cache down to ``self.cache_size`` tokens.

        Args:
            layer_id: Not used by this policy.
            key_cache: 4-D tensor of cached keys; the token axis is dim -2
                and the feature axis is dim -1 (presumably laid out as
                (batch, heads, tokens, head_dim) -- confirm with caller).
            value_cache: Cached values, gathered with the same indices as
                the keys (assumed to share ``key_cache``'s shape).
            num_new_tokens: Not used by this policy.
            attn_mask: Optional mask whose last axis indexes tokens; only
                its trailing ``tokens`` entries are considered. Positions
                that evaluate to False are ranked last for retention.
                NOTE(review): the broadcasting below looks like it expects
                a (batch, tokens) mask -- confirm with caller.
            **kwargs: Ignored.

        Returns:
            The ``(key_cache, value_cache)`` pair. The inputs are returned
            untouched when they already fit within the budget; otherwise
            both are reduced to ``self.cache_size`` entries along dim -2,
            kept in their original token order.
        """
        _, _, seq_len, head_dim = key_cache.shape
        budget = self.cache_size

        # Nothing to evict while the cache still fits in the budget.
        if seq_len <= budget:
            return key_cache, value_cache

        # Per-token score: squared norm of the key vector.
        scores = key_cache.square().sum(dim=-1)

        if attn_mask is not None:
            # Masked-out tokens receive the largest finite value of the key
            # dtype, so the smallest-k selection only picks them as a last
            # resort.
            worst_score = torch.finfo(key_cache.dtype).max
            is_valid = attn_mask[..., -seq_len:].bool().unsqueeze(-2)
            scores = torch.where(is_valid, scores, worst_score)

        # Indices of the `budget` lowest-scoring tokens per head, then put
        # back into ascending order so the token sequence is preserved.
        keep_idxs = torch.topk(
            scores, k=budget, dim=-1, largest=False, sorted=False).indices
        keep_idxs = torch.sort(keep_idxs, dim=-1).values

        # Broadcast the token indices across the feature axis for gather.
        gather_idxs = keep_idxs.unsqueeze(-1).expand(-1, -1, -1, head_dim)
        pruned_keys = key_cache.gather(dim=-2, index=gather_idxs)
        pruned_values = value_cache.gather(dim=-2, index=gather_idxs)
        return pruned_keys, pruned_values

    def finalize_registration(self,):
        """Defer to the base class; this policy adds no extra registration step."""
        super().finalize_registration()

