import os
import pdb
import copy
import math
import numpy as np 
from dataclasses import dataclass
from typing import Optional, Tuple, Union

import abc
import torch
from torch import nn
import torch.utils.checkpoint
import torch.nn.functional as F
from torch.cuda.amp import autocast
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from transformers import LlamaPreTrainedModel
from transformers.cache_utils import Cache, DynamicCache, StaticCache
from .base import MemoryPolicy, ParamMemoryPolicy
from utils import faster_attn_reversecumsum


def reconstruct_causal_mask(n_q, n_k, attn_mask):
    '''Builds the combined causal/key mask for a block of n_q queries.

       n_q: number of query rows
       n_k: number of key columns; the causal triangle occupies the last
            n_q columns
       attn_mask: per-key mask whose last dimension indexes keys; its dtype
                  and device are reused for the causal pattern

       Returns attn_mask, with a query axis inserted at dim -2, multiplied
       by the (n_q, n_k) causal pattern.'''
    # lower-triangular pattern among the n_q query positions themselves
    lower_tri = torch.ones(
        (n_q, n_q), dtype=attn_mask.dtype, device=attn_mask.device).tril(0)
    # left-pad with ones: the leading n_k - n_q key columns are visible to
    # every query row
    causal_pattern = F.pad(lower_tri, (n_k - n_q, 0), value=1.0)
    return causal_pattern * attn_mask.unsqueeze(-2)


def num_attending_queries(n_q, n_k, attn_mask):
    '''Counts, for each key, how many of the n_q queries attend to it once
       the causal ordering and attn_mask are applied (sum of the
       reconstructed mask over the query axis, dim -2)'''
    return reconstruct_causal_mask(
        n_q=n_q, n_k=n_k, attn_mask=attn_mask).sum(dim=-2)


def reduce_new_weights(attn_weights, attn_mask, attn_ema,
                       return_intermediate_mxs=False):
    '''Collapses the query axis of the new attn weights with a rolling EMA.

       Query row i (out of n_q) is scaled by
       (1 - attn_ema) * attn_ema ** (n_q - 1 - i), so the newest query gets
       the largest coefficient; rows are masked with the causal/key mask and
       then summed over dim -2.

       attn_weights: tensor shaped (..., n_heads, n_q, n_k)
       attn_mask: key mask forwarded to reconstruct_causal_mask
       attn_ema: EMA decay coefficient
       return_intermediate_mxs: accepted but not used by this implementation

       Returns the weights with the query axis summed out.'''
    # the unpacking also enforces at least three dimensions
    *_, _n_heads, n_q, n_k = attn_weights.shape

    # exponent per query row: n_q - 1 for the first row down to 0 for the last
    row_ages = torch.arange(n_q - 1, -1, -1, device=attn_weights.device)
    row_coeffs = torch.pow(attn_ema, row_ages) * (1 - attn_ema)

    causal_mask = reconstruct_causal_mask(
        n_q=n_q, n_k=n_k, attn_mask=attn_mask)

    # one coefficient per query row, zeroed wherever the mask is zero
    discounted = (row_coeffs.unsqueeze(-1) * causal_mask) * attn_weights
    return discounted.sum(dim=-2)

def compute_backward_weights(attn_weights, faster_implementation=False):
    '''Renormalizes attn weights by their reverse cumulative sum.

       Every entry along the last dimension is divided by the sum of itself
       and all entries after it on that dimension. The denominator is
       clamped from below at the smallest normal value of the input dtype so
       the division never hits zero.

       attn_weights: floating-point tensor, normalized over dim -1
       faster_implementation: when true, the reverse cumulative sum comes
           from faster_attn_reversecumsum instead of flip/cumsum/flip

       Returns a tensor with the same shape as attn_weights.'''
    floor = torch.finfo(attn_weights.dtype).smallest_normal

    if not faster_implementation:
        # flip returns a copy, so the input tensor is never written to
        tail_sums = attn_weights.flip(-1).cumsum(dim=-1).flip(-1)
    else:
        tail_sums = faster_attn_reversecumsum(tensor=attn_weights, dim=-1)

    tail_sums.clamp_min_(floor)
    return attn_weights / tail_sums

class H2O(MemoryPolicy):
    '''Heavy-hitter KV-cache eviction policy.

       Once a layer holds more than cache_size tokens, every head keeps the
       h2_slots non-recent tokens with the largest accumulated attention
       together with the cache_size - h2_slots most recent tokens.'''

    def __init__(self, cache_size, h2_slots,):
        '''cache_size: number of tokens kept per layer after eviction
           h2_slots: number of kept slots chosen by accumulated attention;
                     the remaining slots go to the most recent tokens'''
        super().__init__(cache_size=cache_size)
        self.h2_slots = h2_slots
        self.recency_slots = cache_size - h2_slots

    def update_attn(self, attn_scores, attn_mx):
        '''Not implemented for this policy.'''
        raise NotImplementedError

    def update_layer_cache(self, layer_id, key_cache, value_cache, num_new_tokens,
                     attn_weights, attn_mask=None, **kwargs,
                     ) -> Tuple[torch.Tensor, torch.Tensor]:
        '''Accumulates attention per token and evicts down to cache_size.

           layer_id: index into self.cumulated_attn_tensors
           key_cache, value_cache: 4-dim tensors gathered along dim -2
               (tokens); the last dim is the embedding size
           num_new_tokens: ignored, re-derived from attn_weights' shape
           attn_weights: passed through self.process_attn_weights and then
               unpacked as (bs, n_heads, num_new_tokens, num_all_tokens)
           attn_mask: optional per-token mask; only its last num_all_tokens
               entries are used, zero entries are excluded from the
               heavy-hitter ranking (presumably shaped (bs, seq) - confirm
               against the caller)

           Returns the (possibly reduced) key and value caches.'''
        attn_weights = self.process_attn_weights(attn_weights=attn_weights)

        cumulated_attn = self.cumulated_attn_tensors[layer_id]
        device = attn_weights.device
        bs, n_heads, num_new_tokens, num_all_tokens = attn_weights.shape
        n_embd = key_cache.shape[-1]
        num_old_tokens = num_all_tokens - num_new_tokens

        # attention received by every token from the new queries; computed
        # once (the original code repeated this reduction)
        new_cumulative_weights = torch.sum(attn_weights, dim=-2)

        # add the scores accumulated so far for tokens already in the cache
        if (cumulated_attn is not None) and (num_all_tokens > num_new_tokens):
            new_cumulative_weights[..., :num_old_tokens] += (
                cumulated_attn[..., :num_old_tokens])

        # nothing to evict yet: only record the updated scores
        if num_all_tokens <= self.cache_size:
            self.cumulated_attn_tensors[layer_id] = new_cumulative_weights
            return key_cache, value_cache

        # the last recency_slots tokens are always kept
        num_nonrecent = num_all_tokens - self.recency_slots
        recency_idxs = (torch.arange(
            self.recency_slots, device=device) + num_nonrecent).view(
                1, 1, self.recency_slots).expand(bs, n_heads, -1)

        if attn_mask is not None:
            # masked-out tokens get a score of -1 for the ranking below
            attn_mask = attn_mask[..., -num_all_tokens:]
            masked_attn_scores = torch.where(
                attn_mask.bool().unsqueeze(-2), new_cumulative_weights, -1)
        else:
            masked_attn_scores = new_cumulative_weights

        # heavy hitters are picked among the non-recent tokens only
        _, h2_idxs = torch.topk(masked_attn_scores[..., :num_nonrecent],
                                k=self.h2_slots, dim=-1, sorted=False)

        # restore positional order of the retained heavy hitters
        h2_idxs, _ = torch.sort(h2_idxs, dim=-1)

        update_idxs = torch.concat([h2_idxs, recency_idxs], dim=-1)
        # stored scores are the unmasked ones, aligned with the kept tokens
        self.cumulated_attn_tensors[layer_id] = torch.gather(
            new_cumulative_weights, dim=-1, index=update_idxs)

        exp_update_idxs = update_idxs.unsqueeze(-1).expand(-1, -1, -1, n_embd)
        key_cache = torch.gather(key_cache, dim=-2, index=exp_update_idxs)
        value_cache = torch.gather(value_cache, dim=-2, index=exp_update_idxs)

        return key_cache, value_cache

    @property
    def requires_attn_scores(self,):
        '''This policy reads the attn_weights argument of update_layer_cache.'''
        return True

    def finalize_registration(self,):
        '''Delegates to the parent implementation.'''
        super().finalize_registration()


class ParamH2O(ParamMemoryPolicy):
    '''H2O-style eviction whose heavy-hitter/recency split is parameterized.

       The parameters returned by get_layer_params are squashed with a
       sigmoid and scaled by cache_size to decide how many retained slots
       are filled by accumulated-attention ranking; the remaining slots
       keep the most recent tokens.'''

    def __init__(self, cache_size, pop_size, per_head, per_layer,
                 h2_logits_init=0.0):
        # NOTE: h2_logits_init is accepted but not read in this class
        super().__init__(
            cache_size=cache_size,
            base_param_size=1,
            pop_size=pop_size,
            per_head=per_head,
            per_layer=per_layer,
        )
        self.cumulated_attn = None

    def get_h2_slots_mask(self, params,):
        '''Boolean mask over the cache_size slots marking heavy-hitter slots.

           Slot j is marked when sigmoid(params) * cache_size - 0.5 > j, so
           the marked slots always form a prefix of the slot range.'''
        slot_threshold = F.sigmoid(params)*self.cache_size - 0.5
        slot_ids = torch.arange(self.cache_size, device=params.device)
        return slot_threshold > slot_ids

    def update_layer_cache(self, layer_id, key_cache, value_cache, num_new_tokens,
                     attn_weights, attn_mask=None, **kwargs,
                     ) -> Tuple[torch.Tensor, torch.Tensor]:
        '''Accumulates attention per token and evicts down to cache_size.

           layer_id: index for self.cumulated_attn_tensors and the layer
               parameters
           key_cache, value_cache: 4-dim tensors gathered along dim -2
           num_new_tokens: ignored, re-derived from attn_weights' shape
           attn_weights: passed through self.process_attn_weights and then
               unpacked as (bs, n_heads, num_new_tokens, num_all_tokens)
           attn_mask: optional per-token mask; only its last num_all_tokens
               entries are used

           Returns the (possibly reduced) key and value caches.'''
        attn_weights = self.process_attn_weights(attn_weights=attn_weights)

        prev_scores = self.cumulated_attn_tensors[layer_id]
        device = attn_weights.device
        # unpacking also enforces a 4-dim attention tensor
        _, _, num_new_tokens, num_all_tokens = attn_weights.shape
        n_embd = key_cache.shape[-1]
        num_old_tokens = num_all_tokens - num_new_tokens

        # attention received by every token from the new queries
        token_scores = attn_weights.sum(dim=-2)
        if prev_scores is not None and num_old_tokens > 0:
            token_scores[..., :num_old_tokens] += (
                prev_scores[..., :num_old_tokens])

        if num_all_tokens <= self.cache_size:
            # cache not over capacity: only record the updated scores
            self.cumulated_attn_tensors[layer_id] = token_scores
            return key_cache, value_cache

        slot_is_h2 = self.get_h2_slots_mask(
            params=self.get_layer_params(layer_id=layer_id))

        # positions of the trailing cache_size tokens
        num_outside_cache = num_all_tokens - self.cache_size
        recent_idxs = num_outside_cache + torch.arange(
            self.cache_size, device=device)

        ranking_scores = token_scores.clone()

        # trailing-window positions whose slot is not a heavy-hitter slot are
        # kept through the recency branch below, so they get -2 here and drop
        # out of the ranking (assumes real scores are non-negative - confirm)
        trailing_window = ranking_scores[..., num_outside_cache:]
        ranking_scores[..., num_outside_cache:] = torch.where(
            slot_is_h2, trailing_window, -2)

        if attn_mask is not None:
            # masked-out tokens are scored -1
            is_valid = attn_mask[..., -num_all_tokens:].bool().unsqueeze(-2)
            ranking_scores = torch.where(is_valid, ranking_scores, -1)

        # descending ranking, one candidate per cache slot
        ranked_idxs = torch.topk(
            ranking_scores, k=self.cache_size, dim=-1, sorted=True).indices

        # heavy-hitter slots take the ranked tokens, the rest take the
        # matching trailing-window position
        keep_idxs = torch.where(slot_is_h2, ranked_idxs, recent_idxs)

        # restore positional order of the retained tokens
        keep_idxs = torch.sort(keep_idxs, dim=-1).values

        self.cumulated_attn_tensors[layer_id] = token_scores.gather(
            -1, keep_idxs)

        gather_idxs = keep_idxs.unsqueeze(-1).expand(-1, -1, -1, n_embd)
        key_cache = key_cache.gather(-2, gather_idxs)
        value_cache = value_cache.gather(-2, gather_idxs)
        return key_cache, value_cache

    @property
    def requires_attn_scores(self,):
        '''This policy reads the attn_weights argument of update_layer_cache.'''
        return True

    def finalize_registration(self,):
        '''Delegates to the parent implementation.'''
        super().finalize_registration()

    def get_param_stats(self,) -> dict:
        '''Summarizes the heavy-hitter parameters per parameter layer.

           Keys are prefixed with mem_stats/layer_id_<i>/ when per_layer is
           set and with mem_stats/shared/ otherwise. Std entries across
           heads are added only when per_head is set.'''
        stats = {}
        for layer_idx in range(self.param_layer_dim):
            prefix = (f'mem_stats/layer_id_{layer_idx}/' if self.per_layer
                      else 'mem_stats/shared/')
            # first entry along the leading dim (presumably the first
            # population member - confirm)
            layer_params = self.get_layer_params(layer_id=layer_idx)[0]
            h2_slots = self.get_h2_slots_mask(layer_params).sum(-1).float()
            h2_perc = h2_slots / self.cache_size

            stats[prefix + 'h2_logits_mean'] = layer_params.mean().item()
            stats[prefix + 'h2_slots_mean'] = h2_slots.mean().item()
            stats[prefix + 'h2_perc_mean'] = h2_perc.mean().item()
            if not self.per_head:
                continue
            stats[prefix + 'h2_logits_head_std'] = layer_params.std().item()
            stats[prefix + 'h2_slots_head_std'] = h2_slots.std().item()
            stats[prefix + 'h2_perc_head_std'] = h2_perc.std().item()
        return stats


class GenH2O(MemoryPolicy):
    '''Generalized version of h2o, using an EMA of the attention scores and
       allowing for backward renormalization'''

    def __init__(self, cache_size, h2_slots, attn_ema_coeff, backward_renorm):
        '''cache_size: number of tokens kept per layer after eviction
           h2_slots: number of kept slots chosen by accumulated attention;
                     the remaining slots go to the most recent tokens
           attn_ema_coeff: EMA decay applied to the accumulated scores
           backward_renorm: if true, attention weights are renormalized with
                            compute_backward_weights before accumulation'''
        super().__init__(cache_size=cache_size)
        self.h2_slots = h2_slots
        self.recency_slots = cache_size - h2_slots
        self.attn_ema_coeff = attn_ema_coeff
        self.backward_renorm = backward_renorm

    def update_layer_cache(self, layer_id, key_cache, value_cache, num_new_tokens,
                     attn_weights, attn_mask=None, **kwargs,
                     ) -> Tuple[torch.Tensor, torch.Tensor]:
        '''Updates the EMA attention scores and evicts down to cache_size.

           layer_id: index into self.cumulated_attn_tensors
           key_cache, value_cache: 4-dim tensors gathered along dim -2
               (tokens); the last dim is the embedding size
           num_new_tokens: ignored, re-derived from attn_weights' shape
           attn_weights: passed through self.process_attn_weights and then
               unpacked as (bs, n_heads, num_new_tokens, num_all_tokens)
           attn_mask: optional per-token mask; only its last num_all_tokens
               entries are used and entries > 0 count as valid. When None,
               every token is treated as valid.

           Returns the (possibly reduced) key and value caches.'''
        attn_ema = self.attn_ema_coeff

        attn_weights = self.process_attn_weights(attn_weights=attn_weights)

        cumulated_attn = self.cumulated_attn_tensors[layer_id]
        device = attn_weights.device

        bs, n_heads, num_new_tokens, num_all_tokens = attn_weights.shape
        num_old_tokens = num_all_tokens - num_new_tokens

        n_embd = key_cache.shape[-1]

        # normalize the mask to a boolean tensor with a singleton axis at
        # dim -2; from here on attn_mask is never None
        if attn_mask is not None:
            attn_mask = attn_mask[..., -num_all_tokens:].unsqueeze(-2) > 0
        else:
            attn_mask = torch.ones([bs, 1, num_all_tokens],
                                   device=device, dtype=torch.bool)

        # optionally renormalize each weight by its reverse cumulative sum
        if self.backward_renorm:
            scored_weights = compute_backward_weights(
                attn_weights=attn_weights, faster_implementation=False)
        else:
            scored_weights = attn_weights

        # EMA-weighted reduction over the new queries
        new_cumulative_weights = reduce_new_weights(
            attn_weights=scored_weights,
            attn_mask=attn_mask,
            attn_ema=attn_ema,
            )

        # decay the stored scores by one EMA step per new token, then add
        if (cumulated_attn is not None) and (num_all_tokens > num_new_tokens):
            new_cumulative_weights[..., :num_old_tokens] = (
                new_cumulative_weights[..., :num_old_tokens] +
                attn_ema**num_new_tokens * cumulated_attn[..., :num_old_tokens])

        # nothing to evict yet: only record the updated scores
        if num_all_tokens <= self.cache_size:
            self.cumulated_attn_tensors[layer_id] = new_cumulative_weights
            return key_cache, value_cache

        # the last recency_slots tokens are always kept
        num_nonrecent = num_all_tokens - self.recency_slots
        recency_idxs = (torch.arange(
            self.recency_slots, device=device) + num_nonrecent).view(
                1, 1, self.recency_slots).expand(bs, n_heads, -1)

        # masked-out tokens get a score of -1 for the ranking below
        masked_attn_scores = torch.where(
            attn_mask, new_cumulative_weights, -1)

        # heavy hitters are picked among the non-recent tokens only
        _, h2_idxs = torch.topk(masked_attn_scores[..., :num_nonrecent],
                                k=self.h2_slots, dim=-1, sorted=False)

        # restore positional order of the retained heavy hitters
        h2_idxs, _ = torch.sort(h2_idxs, dim=-1)

        update_idxs = torch.concat([h2_idxs, recency_idxs], dim=-1)
        # stored scores are the unmasked ones, aligned with the kept tokens
        self.cumulated_attn_tensors[layer_id] = torch.gather(
            new_cumulative_weights, dim=-1, index=update_idxs)

        exp_update_idxs = update_idxs.unsqueeze(-1).expand(-1, -1, -1, n_embd)
        key_cache = torch.gather(key_cache, dim=-2, index=exp_update_idxs)
        value_cache = torch.gather(value_cache, dim=-2, index=exp_update_idxs)

        return key_cache, value_cache

    @property
    def requires_attn_scores(self,):
        '''This policy reads the attn_weights argument of update_layer_cache.'''
        return True

    def finalize_registration(self,):
        '''Delegates to the parent implementation.'''
        super().finalize_registration()