import math
import torch
import matplotlib.pyplot as plt
import seaborn as sns

from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
from transformers.generation.utils import GenerateDecoderOnlyOutput
from flash_attn import flash_attn_func
from typing import List, Optional, Tuple, Union, Any, Dict


def indexing(key, sort_idx, block_size, value=None):
    """Gather rows of ``key`` (and ``value``) along dim 2 and pad to a block multiple.

    Args:
        key: 4-D tensor; rows are selected along dim 2.
        sort_idx: 3-D integer tensor of row indices into dim 2 of ``key``.
        block_size: the gathered sequence is zero-padded at its end so that its
            length along dim 2 becomes a multiple of this value.
        value: optional tensor gathered and padded with the same indices
            (assumed to have the same shape as ``key`` -- the index tensor is
            expanded to ``key.shape[-1]``).

    Returns:
        ``(out_key, out_value)``; ``out_value`` is ``None`` when ``value`` is
        not given.

    Raises:
        ValueError: if ``block_size`` yields a negative padding amount (this
            happens for a negative ``block_size``).
    """
    selected = sort_idx.shape[-1]
    pad_rows = math.ceil(selected / block_size) * block_size - selected
    if pad_rows < 0:
        # Previously this case dropped into an interactive debugger.
        raise ValueError(
            f"block_size={block_size!r} gives a negative padding for {selected} selected rows"
        )

    # One index per (batch, head, row), broadcast over the feature dimension.
    indices = sort_idx.unsqueeze(-1).expand(-1, -1, -1, key.shape[-1])
    # F.pad pairs run from the last dim backwards: (feat_left, feat_right, row_left, row_right).
    pad_spec = (0, 0, 0, pad_rows)

    out_key = torch.nn.functional.pad(key.gather(2, indices), pad_spec, mode='constant', value=0.)
    out_value = None
    if value is not None:
        out_value = torch.nn.functional.pad(value.gather(2, indices), pad_spec, mode='constant', value=0.)
    return out_key, out_value

def balanced_walk(key, rng, gamma_, temp_, beta_, itrs, block_size, value=None, needle_mask=None, layer=None, sort_idx=None, query=None):
    """Repeatedly halve a key sequence with a randomized signed walk.

    Each of the ``itrs`` rounds splits the current sequence into blocks of
    ``block_size[t]`` rows, assigns every row a sign in {-1, +1} (0 for rows
    masked out by the needle mask), turns the signs into per-row weights and
    keeps the ``n // 2`` rows that sort first by descending weight.

    Args:
        key: tensor unpacked as ``(b, h, n, d)``.
        rng: ``torch.Generator`` used to draw the uniform samples.
        gamma_: scalar or list (one entry per round); scales the correlation
            term subtracted from 0.5 to obtain the probability of sign +1.
        temp_, beta_: the default kernel is
            ``exp(temp_ * <k_i, k_j> / sqrt(d) - beta_)`` on block-centred keys.
        itrs: number of halving rounds.
        block_size: int or list (one entry per round).
        value: optional tensor (assumed same shape as ``key``); when given the
            kernel is multiplied by the value Gram matrix plus a constant.
        needle_mask: optional mask, used as-is as the initial walk mask; it must
            broadcast against ``(b, h, blocks, block_size[0])``.
        layer: when equal to 1 the needle mask is recomputed from the keys in
            the first round and replaces any mask passed in.
        sort_idx: optional existing selection of rows.
            NOTE(review): callers in this file never pass it; passing it
            together with ``value`` fails in round 0 because no weights exist
            yet.
        query: optional tensor; when given, the kernel is built from
            query/key softmax scores using every 4th query head.

    Returns:
        ``(sort_idx, weight_idx, needle_mask)``: indices of the surviving rows
        (into dim 2 of ``key``), their accumulated weights, and the needle mask
        of the first executed round in block layout. For an empty sequence the
        first two are empty tensors of shape ``(b, h, 0)``.
    """
    b, h, n, d = key.shape
    if not isinstance(gamma_, list):
        gamma_ = [gamma_] * itrs
    if not isinstance(block_size, list):
        block_size = [block_size] * itrs
    const_denom = 0.025
    needle_mask_bw = None
    weight_idx = None

    for t in range(itrs):
        if n == 0:
            # Nothing left to split. A zero-length sequence cannot be viewed
            # into blocks (the -1 dimension would be ambiguous), so stop here
            # and hand back empty selections.
            if sort_idx is None:
                sort_idx = torch.zeros((b, h, 0), dtype=torch.long, device=key.device)
            else:
                sort_idx = sort_idx[:, :, :0]
            if weight_idx is None:
                weight_idx = torch.zeros((b, h, 0), dtype=torch.float32, device=key.device)
            else:
                weight_idx = weight_idx[:, :, :0]
            break

        bs = block_size[t]
        padded_n = math.ceil(n / bs) * bs

        if needle_mask_bw is not None:
            # Mask carried over from the previous round: re-pad and re-block it.
            needle_mask_bw = torch.nn.functional.pad(needle_mask_bw, (0, padded_n - needle_mask_bw.shape[-1])).view(b, h, -1, bs)
        if sort_idx is not None:
            key_sorted, value_sorted = indexing(key, sort_idx, bs, value)
            key_sorted = key_sorted.view(b, h, -1, bs, d)
            if value is not None:
                # Surviving values are scaled by the weights accumulated so far.
                weight_idx_padded = torch.nn.functional.pad(weight_idx, (0, padded_n - weight_idx.shape[-1]))
                value_sorted = value_sorted * weight_idx_padded.unsqueeze(-1)
                value_sorted = value_sorted.view(b, h, -1, bs, d)
        else:
            key_sorted = torch.nn.functional.pad(key, (0, 0, 0, padded_n - n), mode='constant', value=0.).view(b, h, -1, bs, d)
            value_sorted = None
            if value is not None:
                value_sorted = torch.nn.functional.pad(value, (0, 0, 0, padded_n - n), mode='constant', value=0.).view(b, h, -1, bs, d)

        # Centre the keys inside each block before building the kernel.
        normal_keys = key_sorted - torch.mean(key_sorted, dim=-2, keepdim=True)

        if query is not None:
            query_key_correlation = torch.softmax(torch.einsum('b h n d,b h s m d->b h s n m', query[:, ::4, :, :], normal_keys), dim=-1).mean(-2, keepdim=True)
            kernel_ = query_key_correlation * query_key_correlation.transpose(-1, -2)
        else:
            kernel_ = torch.exp(temp_ * torch.einsum('...nd,...sd->...ns', normal_keys, normal_keys) / math.sqrt(d) - beta_)
        if value is not None:
            kernel_ *= (1e-8 + torch.einsum('...nd,...sd->...ns', value_sorted, value_sorted) + const_denom)

        if layer == 1 and t == 0:
            # Derive the needle mask from the raw keys: a row is 1 when its mean
            # inner product with the rows of its block exceeds the threshold...
            threshold = 0.0
            key_correlation = (1e-8 + torch.einsum('...nd,...sd->...ns', key_sorted, key_sorted))
            key_correlation_sum = key_correlation.sum(dim=-1) / key_correlation.shape[-1]
            needle_mask = key_correlation_sum > threshold
            needle_mask = needle_mask.view(b, h, -1)[:, :, :n]
            needle_mask = needle_mask.to(torch.int32)
            # ...and it stays 1 only if all 51 entries of the window centred on
            # it (25 each side, out-of-range entries counted as 1) are 1 too.
            needle_mask_padded = torch.nn.functional.pad(needle_mask, (25, 25), mode='constant', value=1)
            unfolded = needle_mask_padded.unfold(-1, 51, 1)
            result = unfolded.sum(dim=-1)
            new_needle_mask = result > 50
            needle_mask = new_needle_mask * needle_mask
            needle_mask = torch.nn.functional.pad(needle_mask, (0, padded_n - needle_mask.shape[-1])).view(b, h, -1, bs)

        signs = torch.zeros(kernel_.shape[:4], dtype=torch.float32, device=kernel_.device)
        signs[:, :, :, 0] = 1
        rand_tensor = torch.rand(signs.shape, generator=rng, device=key.device)

        if needle_mask is None:
            needle_mask = torch.ones_like(signs)
        if needle_mask_bw is None:
            needle_mask_bw = needle_mask

        # Sequential walk inside every block: rows not yet visited still have
        # sign 0 and therefore do not contribute to the partial sum.
        for i in range(1, kernel_.shape[3]):
            partial_inner_prod = (kernel_[:, :, :, i, :] * signs * needle_mask_bw).sum(dim=-1)
            samp_prb = 0.5 - gamma_[t] * partial_inner_prod
            signs[:, :, :, i] = 2 * (rand_tensor[:, :, :, i] < samp_prb) - 1

        # Back to a flat sequence, dropping the padded tail.
        signs = signs.view(b, h, -1)[:, :, :n]
        needle_mask_bw = needle_mask_bw.view(b, h, -1)[:, :, :n]

        # Masked-out rows get sign 0.
        signs = signs * needle_mask_bw

        if torch.all(needle_mask_bw):
            # c: per (batch, head), the later of the first positions at which
            # the count of -1 signs / +1 signs equals n // 2 (argmax yields 0
            # when a count never reaches it).
            cumsum_neg = (signs == -1).cumsum(dim=-1)
            cumsum_pos = (signs == 1).cumsum(dim=-1)
            c_neg = torch.argmax((cumsum_neg == n // 2).to(torch.int64), dim=-1)
            c_pos = torch.argmax((cumsum_pos == n // 2).to(torch.int64), dim=-1)
            c = torch.maximum(c_neg, c_pos)
            c = c.to(signs.device)
            weight = signs  # alias: the updates below modify ``signs`` in place
            indices = torch.arange(signs.shape[2], device=signs.device).view(1, 1, -1)
            # Rows after c get weight |sign|.
            mask_after_c = indices > c.unsqueeze(-1)
            weight[mask_after_c] = torch.abs(weight[mask_after_c])
            # Must be read before the doubling below changes the value at c.
            mask_flip_needed = (signs.gather(2, c.unsqueeze(-1)) == 1).squeeze(-1)
            # Rows up to c are doubled, and negated where the sign at c is +1.
            mask_before_c = indices <= c.unsqueeze(-1)
            weight[mask_before_c] *= 2
            flip_mask = mask_before_c & mask_flip_needed.unsqueeze(-1)
            weight[flip_mask] *= -1
            weight_argsort = torch.argsort(-weight, dim=-1, stable=True)
        else:
            # Sort keys: masked-out rows (sign 0) -> 3, sign +1 -> 2, sign -1 -> 0,
            # so masked-out rows sort first; they then receive weight 1.
            weight = signs
            weight_zeros = weight == 0
            weight[weight_zeros] += 2
            weight += 1
            weight_argsort = torch.argsort(-weight, dim=-1, stable=True)
            weight[weight_zeros] = 1

        n = n // 2
        kept = weight_argsort[:, :, :n]
        if sort_idx is None:
            sort_idx = kept
            weight_idx = weight.gather(-1, kept)
        else:
            # Compose with the earlier selection and multiply the weights.
            sort_idx = sort_idx.gather(2, kept)
            round_weight = weight.gather(-1, kept)
            weight_idx = weight_idx.gather(-1, kept)
            weight_idx = weight_idx * round_weight
        needle_mask_bw = needle_mask_bw.gather(-1, kept)

    return sort_idx, weight_idx, needle_mask


def manual_forward_llama(
    model,
    input_ids,
    kv_type,
    attention_mask=None,
    kv_cache=None, position_ids=None, cache_position=None, num_logits_to_keep=0,
):
    """Run the decoder stack by hand, optionally compressing the KV cache.

    Without ``kv_cache`` (prefill) attention is computed with causal flash
    attention over the input, and a cache is built per layer:

    * ``kv_type == 'exact'``: the full keys/values, stored as ``(k, v)``.
    * ``'weightedbw'``, ``'bw'``, ``'balancedwalk'``, ``'uniform'``: the first
      ``sink_size`` and last ``window_size`` positions are kept verbatim and the
      positions in between are subsampled with ``balanced_walk`` (``'uniform'``
      uses ``gamma = 0``); stored as ``(k, v, log_weights)``. The number of
      retained middle positions is written to ``model.config.compress_size``.

    With ``kv_cache`` (decode) the new keys/values are appended to the cached
    ones; for the compressed types the cached log-weights are added as a bias
    to the attention scores of the compressed positions.

    Args:
        model: object exposing ``model.embed_tokens``, ``model.layers``,
            ``model.norm``, ``lm_head`` and ``config``; for the compressed types
            ``config`` must carry ``rng``, ``gamma``, ``temp``, ``beta``,
            ``itrs``, ``block_size``, ``window_size`` and ``sink_size``.
        input_ids: token ids; the projections are reshaped with a leading
            dimension of 1, so a single sequence is expected.
        kv_type: one of the cache types listed above.
        attention_mask: not used by the computation; returned unchanged.
        kv_cache: per-layer cache from a previous call, or ``None``.
        position_ids: positions for the rotary embedding; defaults to
            ``0..len-1`` (callers must pass them explicitly when decoding).
        cache_position: unused.
        num_logits_to_keep: logits are computed for the last this-many
            positions (0 keeps all positions).

    Returns:
        ``(logits, past_kv_cache, attention_mask)``.

    Raises:
        ValueError: if ``kv_type`` is not supported.

    Notes:
        The head layout is hard-coded (head dim 128, 8 key/value heads).
        The compressed decode path applies no causal mask, so it assumes one
        new token per call.
    """
    compressed_types = ('weightedbw', 'bw', 'balancedwalk', 'uniform')
    if kv_type != 'exact' and kv_type not in compressed_types:
        # Fail early and clearly instead of dropping into a debugger or
        # tripping over an unbound local deeper in the layer loop.
        raise ValueError(f"unsupported kv_type: {kv_type!r}")

    hh = model.model.embed_tokens(input_ids)
    if position_ids is None:
        position_ids = torch.arange(len(input_ids[0]), device=input_ids.device).unsqueeze(0)

    needle_mask = None
    past_kv_cache = []
    for i, decoder_layer in enumerate(model.model.layers):
        res = hh
        hh = decoder_layer.input_layernorm(hh)
        self_attn = decoder_layer.self_attn

        q_len = hh.shape[1]
        kv_len = q_len
        qq = self_attn.q_proj(hh).reshape(1, q_len, -1, 128).transpose(1, 2)
        kk = self_attn.k_proj(hh).reshape(1, kv_len, 8, 128).transpose(1, 2)
        vv = self_attn.v_proj(hh).reshape(1, kv_len, 8, 128).transpose(1, 2)

        cos, sin = self_attn.rotary_emb(vv, position_ids)
        qq, kk = apply_rotary_pos_emb(qq, kk, cos, sin)
        d = qq.shape[-1]

        if kv_cache is None:
            # Prefill: exact causal attention over the input, whatever its length.
            attn_output = flash_attn_func(qq.transpose(1, 2), kk.transpose(1, 2), vv.transpose(1, 2), causal=True)

            if kv_type == 'exact':
                past_kv_cache.append((kk, vv))
            else:
                rng = model.config.rng
                gamma = model.config.gamma
                temp = model.config.temp
                beta = model.config.beta
                itrs = model.config.itrs
                block_size = model.config.block_size
                window_size = model.config.window_size
                sink_size = model.config.sink_size

                # Explicit boundary instead of ``-window_size``: with
                # window_size == 0 a negative-zero slice would select nothing
                # for the middle part and everything for the window.
                tail_start = max(kv_len - window_size, 0)
                k_compressed = kk[:, :, sink_size:tail_start]
                v_compressed = vv[:, :, sink_size:tail_start]

                walk_gamma = 0.0 if kv_type == 'uniform' else gamma
                # The needle mask is produced by layer 1 and reused by all
                # later layers; layers 0 and 1 run without an input mask.
                mask_in = needle_mask if i > 1 else None
                indices, weights, mask_out = balanced_walk(
                    k_compressed, rng, walk_gamma, temp, beta, itrs, block_size,
                    layer=i, needle_mask=mask_in, value=v_compressed,
                )
                if i == 1:
                    needle_mask = mask_out

                k_bw = k_compressed.gather(dim=2, index=indices.unsqueeze(-1).expand(-1, -1, -1, kk.shape[-1]))
                v_bw = v_compressed.gather(dim=2, index=indices.unsqueeze(-1).expand(-1, -1, -1, vv.shape[-1]))

                # Zero the values of rows whose weight is not positive; keep the
                # model dtype so the concatenation below does not mix dtypes.
                v_bw_num = (v_bw * (weights > 0).unsqueeze(-1)).to(vv.dtype)

                weights = weights.unsqueeze(-1)
                # log(weight) is later added to the attention scores; rows with
                # non-positive weight get a large negative bias instead.
                log_weights = torch.where(weights > 0, torch.log(weights), torch.full_like(weights, -1e9))

                key_quant = torch.cat((kk[:, :, :sink_size], k_bw, kk[:, :, tail_start:]), dim=2)
                val_quant = torch.cat((vv[:, :, :sink_size], v_bw_num, vv[:, :, tail_start:]), dim=2)

                model.config.compress_size = k_bw.shape[2]
                past_kv_cache.append((key_quant, val_quant, log_weights))
        elif kv_type == 'exact':
            kk = torch.cat((kv_cache[i][0], kk), 2)
            vv = torch.cat((kv_cache[i][1], vv), 2)
            past_kv_cache.append((kk, vv))
            attn_output = flash_attn_func(qq.transpose(1, 2), kk.transpose(1, 2), vv.transpose(1, 2), causal=True)
        else:
            compress_size = model.config.compress_size
            sink_size = model.config.sink_size

            kk = torch.cat((kv_cache[i][0], kk), dim=2)
            vv = torch.cat((kv_cache[i][1], vv), dim=2)
            needle = kv_cache[i][2]
            # Expand the per-KV-head log-weights to one row per query head and
            # move the position axis last so it lines up with the score matrix.
            needle_used = needle.repeat_interleave(qq.shape[1] // kk.shape[1], dim=1).transpose(-1, -2)
            needle_used = needle_used.unsqueeze(0)

            qk = qq @ repeat_kv(kk, qq.shape[1] // kk.shape[1]).transpose(-1, -2) / d**0.5
            bias = torch.zeros_like(qk)
            # Only the compressed span (right after the sink tokens) is biased.
            bias[:, :, :, sink_size:sink_size + compress_size] = needle_used
            attn_output = ((qk + bias).softmax(dim=-1) @ repeat_kv(vv, qq.shape[1] // vv.shape[1])).transpose(1, 2)
            past_kv_cache.append((kk, vv, needle))

        attn_output = attn_output.contiguous().view(qq.shape[0], qq.shape[2], -1)
        hh = self_attn.o_proj(attn_output)

        hh = res + hh

        res = hh
        hh = decoder_layer.post_attention_layernorm(hh)
        hh = decoder_layer.mlp(hh)
        hh = res + hh

    hidden_states = model.model.norm(hh)
    logits = model.lm_head(hidden_states[:, -num_logits_to_keep:, :])
    return logits, past_kv_cache, attention_mask


@torch.no_grad()
def greedy_generate(self, input_ids, max_new_tokens, kv_type, eos_token_id=[128009], return_dict_in_generate=False):
    """Greedy decoding on top of ``manual_forward_llama``.

    The prompt is processed once to build the KV cache, then one token is
    generated per step by taking the arg-max of the last-position logits.

    Args:
        self: the model handed to ``manual_forward_llama``.
        input_ids: prompt token ids; a single sequence is expected because each
            predicted token is read with ``.item()``.
        max_new_tokens: upper bound on generated tokens (at least one token is
            always produced).
        kv_type: cache type forwarded to ``manual_forward_llama``.
        eos_token_id: a token id, an iterable of token ids, or ``None``;
            generation stops once one of them has been produced.
        return_dict_in_generate: when true, return a
            ``GenerateDecoderOnlyOutput`` carrying the sequences and the cache
            instead of the bare sequence tensor.

    Returns:
        A ``(1, prompt_len + generated)`` tensor of token ids, or a
        ``GenerateDecoderOnlyOutput`` wrapping it.
    """
    # Normalise to a set of plain ints so membership is an ordinary int
    # comparison rather than a tensor-vs-list containment test.
    if eos_token_id is None:
        eos_ids = set()
    elif isinstance(eos_token_id, int):
        eos_ids = {eos_token_id}
    else:
        eos_ids = {int(tok) for tok in eos_token_id}

    prompt_len = input_ids.shape[-1]
    position_ids = torch.arange(prompt_len, device=input_ids.device).unsqueeze(0)
    logits, cache, attention_mask = manual_forward_llama(self, input_ids, position_ids=position_ids, num_logits_to_keep=1, kv_type=kv_type)
    pred_token_idx = logits[:, -1, :].argmax(dim=-1).unsqueeze(1)
    generated_ids = [pred_token_idx.item()]

    for step in range(max_new_tokens - 1):
        # Checked at the top of the loop so that an EOS produced by the prefill
        # step also stops generation.
        if generated_ids[-1] in eos_ids:
            break
        # The token generated at step ``step`` sits at position prompt_len + step.
        position_ids = torch.tensor(prompt_len + step, device=input_ids.device).reshape(1, 1)
        logits, cache, _ = manual_forward_llama(self, pred_token_idx, position_ids=position_ids, attention_mask=attention_mask, kv_cache=cache, num_logits_to_keep=1, kv_type=kv_type)
        pred_token_idx = logits[:, -1, :].argmax(dim=-1).unsqueeze(1)
        generated_ids.append(pred_token_idx.item())

    sequences = torch.tensor(input_ids[0].tolist() + generated_ids, device=input_ids.device).unsqueeze(0)
    if not return_dict_in_generate:
        return sequences
    return GenerateDecoderOnlyOutput(sequences=sequences, past_key_values=cache)