from typing import Optional, Tuple
import torch
import types
from torch import nn
from transformers.models.qwen2.modeling_qwen2 import (Qwen2Attention, repeat_kv, apply_rotary_pos_emb)
from transformers.cache_utils import Cache
import math
import pandas as pd
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.cache_utils import DynamicCache
from typing import List, Optional, Tuple, Union
import torch.nn.functional as F
import time

# Module-level state shared between sm_forward and the patched attention forwards.
big_model = None  # big model driven from sm_forward; set by enable_qwen2_small_model
mapping_attn = None  # small-model attention in the big model's layer/head layout; set by extract_mapping_attn
mapping_dict = None  # "{big_layer}_{big_head}" -> "{small_layer}_{small_head}"; set by establish_mapping
past_key_values_big = None  # big model's KV cache carried across decode steps by sm_forward

# Sparsity budgets, each a fraction of the key length; set by enable_qwen2_small_model.
heavy_budget_ratio = 0
recent_budget_ratio = 0
compensate_budget_ratio = 0

def qwen2_cc_attention_forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Cache] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
    cache_position: Optional[torch.LongTensor] = None,
    position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    """Eager Qwen2 attention sparsified with the small model's (mapped) attention.

    Bound onto each ``Qwen2Attention`` by ``enable_qwen2_cc``. After the usual
    softmax, the big model keeps its own weight only for:

    * the ``heavy_budget`` keys ranked highest by the column sums of
      ``mapping_attn[self.layer_idx]`` (small-model attention remapped to this
      layer's heads), and
    * keys within ``recent_budget`` positions of each query's absolute position.

    All other weights are zeroed. If ``compensate_budget`` > 0, the next
    ``compensate_budget`` ranked keys that lie outside the recent band receive
    the small model's mapped weights instead.

    Each budget is ``int(<ratio> * kv_len)`` using the module globals
    ``heavy_budget_ratio`` / ``recent_budget_ratio`` / ``compensate_budget_ratio``.

    Returns:
        ``(attn_output, attn_weights if output_attentions else None, past_key_value)``.

    Raises:
        RuntimeError: if the module global ``mapping_attn`` has not been built.
        ValueError: if the attention output has an unexpected shape.
    """
    bsz, q_len, _ = hidden_states.size()
    query_states = self.q_proj(hidden_states)
    key_states = self.k_proj(hidden_states)
    value_states = self.v_proj(hidden_states)

    query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
    key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
    value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)

    if position_embeddings is None:
        cos, sin = self.rotary_emb(value_states, position_ids)
    else:
        cos, sin = position_embeddings
    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

    if past_key_value is not None:
        cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}  # Specific to RoPE models
        key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

    # repeat k/v heads if n_kv_heads < n_heads
    key_states = repeat_kv(key_states, self.num_key_value_groups)
    value_states = repeat_kv(value_states, self.num_key_value_groups)

    # Scores are computed in fp32 (originally introduced for the gsm8k test).
    query_states = query_states.to(dtype=torch.float32)
    key_states = key_states.to(dtype=torch.float32)

    attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
    if attention_mask is not None:  # no matter the length, we just slice it
        causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
        attn_weights = attn_weights + causal_mask

    # upcast attention to fp32
    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(value_states.dtype)
    attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)

    # Heavy + Recent + Marginal (compensate) budgets, as fractions of the key length.
    kv_len = attn_weights.shape[-1]
    heavy_budget = int(heavy_budget_ratio * kv_len)
    recent_budget = int(recent_budget_ratio * kv_len)
    compensate_budget = int(compensate_budget_ratio * kv_len)

    if mapping_attn is None:
        raise RuntimeError(
            "mapping_attn is not set; extract_mapping_attn must run before the cc attention forward"
        )

    # Saliency-shift compensation: rank keys by the small model's mapped attention.
    # Presumably indexed like attn_weights, i.e. (bsz, heads, q_len, kv_len) — TODO confirm.
    mapped_weights = mapping_attn[self.layer_idx].to(attn_weights.device)
    # Accumulated attention each key receives over all query rows.
    mapped_saliency = torch.sum(mapped_weights, dim=-2)
    # Clamp so topk never asks for more keys than exist.
    num_ranked = min(heavy_budget + compensate_budget, kv_len)
    _, ranked_keys = mapped_saliency.topk(k=num_ranked, dim=-1)
    # Slice along the key-rank (last) axis; slicing dim 1 would cut heads instead.
    heavy_keys = ranked_keys[..., :heavy_budget]

    heavy_mask = torch.zeros_like(mapped_saliency, dtype=torch.bool)
    heavy_mask = heavy_mask.scatter(-1, heavy_keys, True).unsqueeze(2)
    heavy_mask = heavy_mask.expand(
        heavy_mask.shape[0], heavy_mask.shape[1], attn_weights.shape[-2], heavy_mask.shape[-1]
    )

    # Band of keys within recent_budget of each query's absolute position.
    # With a KV cache, query row i sits at absolute position i + (kv_len - q_rows),
    # so the band is bottom-right aligned like the causal mask. Without a cache the
    # offset is 0 and this is the plain |j - i| <= recent_budget band.
    q_rows = attn_weights.shape[-2]
    offset = kv_len - q_rows
    recent_mask = torch.ones_like(attn_weights, dtype=torch.bool)
    recent_mask = torch.tril(recent_mask, diagonal=offset + recent_budget)
    recent_mask = torch.triu(recent_mask, diagonal=offset - recent_budget)

    keep_mask = torch.logical_or(heavy_mask, recent_mask)
    attn_weights = attn_weights.masked_fill(~keep_mask, 0)

    # Marginal compensation: the next-ranked keys (outside the recent band) get the
    # small model's mapped weights. masked_fill is out-of-place, so the shared
    # mapping_attn buffer is never mutated.
    if compensate_budget != 0:
        compensate_keys = ranked_keys[..., heavy_budget:]
        compensate_mask = torch.zeros_like(mapped_saliency, dtype=torch.bool)
        compensate_mask = compensate_mask.scatter(-1, compensate_keys, True).unsqueeze(2)
        compensate_mask = compensate_mask.expand(
            compensate_mask.shape[0], compensate_mask.shape[1], attn_weights.shape[-2], compensate_mask.shape[-1]
        )
        compensate_mask = torch.logical_and(compensate_mask, ~recent_mask)
        compensation = mapped_weights.masked_fill(~compensate_mask, 0).to(attn_weights.dtype)
        attn_weights = attn_weights + compensation

    attn_output = torch.matmul(attn_weights, value_states)
    if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
        raise ValueError(
            f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
            f" {attn_output.size()}"
        )

    attn_output = attn_output.transpose(1, 2).contiguous()
    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

    attn_output = self.o_proj(attn_output)

    if not output_attentions:
        attn_weights = None

    return attn_output, attn_weights, past_key_value

def enable_qwen2_normal(model):
    """Rebind every ``Qwen2Attention`` inside *model* to ``normal_forward``.

    Walks ``model._modules`` in reverse registration order, recursing into any
    child that has submodules before checking the child itself.
    """
    for child_name, child in reversed(model._modules.items()):
        has_children = any(True for _ in child.children())
        if has_children:
            enable_qwen2_normal(child)

        if not isinstance(child, Qwen2Attention):
            continue
        child.forward = types.MethodType(normal_forward, child)

def enable_qwen2_cc(model):
    """Rebind every ``Qwen2Attention`` inside *model* to ``qwen2_cc_attention_forward``.

    Walks ``model._modules`` in reverse registration order, recursing into any
    child that has submodules before checking the child itself.
    """
    for child_name, child in reversed(model._modules.items()):
        has_children = any(True for _ in child.children())
        if has_children:
            enable_qwen2_cc(child)

        if not isinstance(child, Qwen2Attention):
            continue
        child.forward = types.MethodType(qwen2_cc_attention_forward, child)

def enable_qwen2_small_model(small_model, big_model_ar, config):
    """Wire *small_model* to drive *big_model_ar* through ``sm_forward``.

    Copies the three budget ratios from *config* into module globals, stores the
    big model in the ``big_model`` global, and clears ``mapping_dict`` so the
    next forward rebuilds the head mapping. Finally it rebinds
    ``small_model.forward`` to ``sm_forward``.
    """
    global big_model, mapping_dict
    global heavy_budget_ratio, recent_budget_ratio, compensate_budget_ratio

    # Budget ratios (fractions of the key length) read by the cc attention forward.
    heavy_budget_ratio = config.heavy_budget_ratio
    recent_budget_ratio = config.recent_budget_ratio
    compensate_budget_ratio = config.compensate_budget_ratio

    big_model = big_model_ar
    mapping_dict = None  # forces establish_mapping on the next prefill
    small_model.forward = types.MethodType(sm_forward, small_model)


def extract_lower_triangular(matrix):
    """Return the entries on or below the main diagonal of *matrix*, flattened row-major."""
    keep = torch.ones_like(matrix, dtype=torch.bool).tril(diagonal=0)
    return matrix[keep]

def calculate_jaccard_similarity(set1, set2):
    """Return the Jaccard index |set1 ∩ set2| / |set1 ∪ set2|, or 0 when both are empty."""
    union_size = len(set1.union(set2))
    if not union_size:
        return 0
    return len(set1.intersection(set2)) / union_size

def get_topk_column_indices(attentions, k=5):
    """For each layer, return the indices of the *k* keys with the largest column sums.

    Each attention tensor is moved to the CPU and its leading dim squeezed
    (presumably a batch of 1 — TODO confirm). It is then summed over the query
    axis (dim -2), and ``torch.topk(..., k, dim=1)`` picks the key indices, so a
    ``(1, heads, q, kv)`` input yields a ``(heads, k)`` index tensor.
    """
    def _top_columns(layer_attn):
        column_mass = layer_attn.cpu().squeeze(0).sum(dim=-2)
        return torch.topk(column_mass, k, dim=1).indices

    return [_top_columns(layer_attn) for layer_attn in attentions]

def calculate_similarity_optimized(res_a, res_b, k, seq_len, device=None):
    """Map every head in *res_a* to its most similar head in *res_b* (vectorized Jaccard).

    Args:
        res_a, res_b: per-layer tensors of top-k key indices, each
            ``(num_heads, k)`` (as produced by ``get_topk_column_indices``).
            All layers are assumed to have the same head count.
        k: indices per head. Every row is assumed to hold k distinct indices
            (``topk`` guarantees this), so |A ∪ B| = 2k - |A ∩ B|.
        seq_len: key length; all indices must lie in ``[0, seq_len)``.
        device: device to compute on. Defaults to ``"cuda:0"`` when CUDA is
            available (the previous hard-coded value), otherwise ``"cpu"``.

    Side effect:
        Rebinds the module global ``mapping_dict`` to
        ``{"{layer_a}_{head_a}": "{layer_b}_{head_b}"}``. If either input is
        empty the mapping is empty. Ties pick the first best ``b`` head.
    """
    global mapping_dict
    if device is None:
        device = "cuda:0" if torch.cuda.is_available() else "cpu"
    if not res_a or not res_b:
        mapping_dict = {}
        return

    # Flatten (layers, heads) into one row per head.
    indices_a = torch.cat([layer.to(device) for layer in res_a], dim=0).long()
    indices_b = torch.cat([layer.to(device) for layer in res_b], dim=0).long()

    # One-hot membership rows: mask[i, t] == 1 iff key t is in head i's top-k set.
    mask_a = torch.zeros(indices_a.shape[0], seq_len, device=device)
    mask_a.scatter_(1, indices_a, 1.0)
    mask_b = torch.zeros(indices_b.shape[0], seq_len, device=device)
    mask_b.scatter_(1, indices_b, 1.0)

    # Pairwise |A ∩ B| in a single matmul; Jaccard = |A ∩ B| / |A ∪ B|.
    intersection = torch.matmul(mask_a, mask_b.T)
    union = 2 * k - intersection
    similarity = intersection / (union + 1e-12)  # epsilon guards k == 0

    _, best_b = torch.max(similarity, dim=1)
    best_b = best_b.tolist()  # one device->host transfer instead of one per row

    heads_per_layer_a = res_a[0].shape[0]
    heads_per_layer_b = res_b[0].shape[0]

    mapping_dict = {}
    for a_idx, b_idx in enumerate(best_b):
        layer_a, head_a = divmod(a_idx, heads_per_layer_a)
        layer_b, head_b = divmod(b_idx, heads_per_layer_b)
        mapping_dict[f"{layer_a}_{head_a}"] = f"{layer_b}_{head_b}"

def calculate_similarity(res_a, res_b, k=60):
    """Set-based (non-vectorized) Jaccard head matching.

    For every head in *res_a*, find the head in *res_b* whose index set has the
    highest Jaccard similarity. Record it in the module global ``mapping_dict``
    as ``"{layer_a}_{head_a}" -> "{layer_b}_{head_b}"``.

    Entries are added to the existing ``mapping_dict``, which is created if it
    is ``None``. Ties keep the first best match. A head whose best similarity is
    0 gets no entry.

    ``k`` is unused and kept only for signature compatibility.
    """
    global mapping_dict
    if mapping_dict is None:
        mapping_dict = {}

    # Build each candidate set once instead of once per (a, b) pair.
    candidates = [
        (f"{layer_idx_b}_{head_idx_b}", set(heads_b[head_idx_b].cpu().numpy()))
        for layer_idx_b, heads_b in enumerate(res_b)
        for head_idx_b in range(heads_b.shape[0])
    ]

    for layer_idx_a, heads_a in enumerate(res_a):
        for head_idx_a in range(heads_a.shape[0]):
            set_a = set(heads_a[head_idx_a].cpu().numpy())
            max_sim = 0
            for key_b, set_b in candidates:
                union_size = len(set_a.union(set_b))
                similarity = len(set_a.intersection(set_b)) / union_size if union_size else 0.0
                if similarity > max_sim:
                    max_sim = similarity
                    mapping_dict[f"{layer_idx_a}_{head_idx_a}"] = key_b



def establish_mapping(small_attn, big_attn, topk_ratio=0.2):
    """Build the big-model -> small-model head mapping from prefill attentions.

    Each head is summarized by the indices of its ``k`` keys with the largest
    accumulated (column-summed) attention, where ``k = int(seq_len * topk_ratio)``
    clamped to ``[1, seq_len]``. Without the clamp, prompts shorter than
    ``1 / topk_ratio`` tokens gave k == 0 and mapped every head to ``0_0``.
    Every big-model head is then mapped to the small-model head whose index set
    has the highest Jaccard overlap.

    Earlier cosine-similarity variants of this function were removed; see
    version control history if they are needed.

    Args:
        small_attn: per-layer attentions of the small model
            (``output_attentions=True``), presumably ``(1, heads, seq, seq)``
            each — TODO confirm batch size 1 is the only supported case.
        big_attn: the same for the big model, over the same tokens.
        topk_ratio: fraction of the key length used as ``k`` (default 0.2, the
            previously hard-coded value).

    Side effect:
        Rebinds the module global ``mapping_dict`` to
        ``{"{big_layer}_{big_head}": "{small_layer}_{small_head}"}``.
    """
    global mapping_dict
    mapping_dict = dict()

    seq_len = small_attn[0].size(-1)
    k = min(seq_len, max(1, int(seq_len * topk_ratio)))
    res_s = get_topk_column_indices(small_attn, k)
    res_b = get_topk_column_indices(big_attn, k)
    # Big-model heads are the "a" side, so mapping keys are big-model heads.
    calculate_similarity_optimized(res_b, res_s, k, seq_len)


def extract_mapping_attn(small_attn):
    """Lay the small model's attention out in the big model's layer/head grid.

    Reads the module global ``mapping_dict``
    (``"{big_layer}_{big_head}" -> "{small_layer}_{small_head}"``). Sets the
    module global ``mapping_attn`` to a tensor of shape
    ``(big_layers, bsz, big_heads, q_len, kv_len)`` with
    ``mapping_attn[bl, :, bh] = small_attn[sl][:, sh]`` for every mapped pair.
    Unmapped slots stay zero. The tensor uses torch's default dtype, as before.

    The buffer is allocated on ``small_attn[0]``'s device. The previous CPU
    buffer forced one device-to-host copy per mapped head on every call.

    Raises:
        RuntimeError: if ``mapping_dict`` has not been established.
    """
    global mapping_attn
    if mapping_dict is None:
        raise RuntimeError("mapping_dict is not established; call establish_mapping first")

    big_layer = len({key.split('_')[0] for key in mapping_dict})
    big_head = len({key.split('_')[1] for key in mapping_dict})

    reference = small_attn[0]
    bsz, _, q_len, kv_len = reference.shape
    buffer = torch.zeros((big_layer, bsz, big_head, q_len, kv_len), device=reference.device)

    for big_key, small_key in mapping_dict.items():
        big_layer_idx, big_head_idx = map(int, big_key.split('_'))
        small_layer_idx, small_head_idx = map(int, small_key.split('_'))
        buffer[big_layer_idx, :, big_head_idx] = small_attn[small_layer_idx][:, small_head_idx]

    # Publish only once fully built so readers never see a half-filled value.
    mapping_attn = buffer


def sm_forward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
    num_logits_to_keep: int = 0,
    **loss_kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
    """Small-model forward that drives the big model (bound by ``enable_qwen2_small_model``).

    1. Run the small decoder (``self.model``) with ``use_cache=True`` and
       ``output_attentions=True``.
    2. If ``return_dict`` is false, return the small model's own logits tuple;
       the big model is not called.
    3. When there is no mapping yet (prefill), run ``big_model`` with dense
       attention (``normal_forward``), build the big->small head mapping from
       both models' attentions, and switch the big model to
       ``qwen2_cc_attention_forward``. Otherwise (decode), remap the small
       model's attentions (``extract_mapping_attn``) and run ``big_model`` on
       its own cache (``past_key_values_big``).
    4. Return the big model's loss/logits/hidden states/attentions, paired with
       the small model's cache, which the caller feeds back as ``past_key_values``.

    A missing or empty ``past_key_values`` marks a new sequence and resets the mapping.
    """
    global mapping_dict, big_model, past_key_values_big, mapping_attn

    if past_key_values is None or past_key_values.get_seq_length() == 0:
        print("Not Establish mapping, will do now")
        mapping_dict = None
        mapping_attn = None

    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
    outputs = self.model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=True,
        output_attentions=True,  # attentions are always needed to drive the big model
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        cache_position=cache_position,
    )

    if not return_dict:
        # The small model's logits are only used on this path, so the lm_head
        # projection is skipped entirely when returning the big model's output.
        logits = self.lm_head(outputs[0][:, -num_logits_to_keep:, :])
        return (logits,) + outputs[1:]

    is_prefill = mapping_dict is None
    if is_prefill:
        past_key_values_big = None
        print(f"Building mapping on {input_ids.size(1)} tokens in prefill stage")
        enable_qwen2_normal(big_model)
    else:
        extract_mapping_attn(outputs.attentions)

    outputs_big = big_model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        labels=input_ids,
        output_attentions=True,
        use_cache=True,
        past_key_values=past_key_values_big,
    )

    if is_prefill:
        establish_mapping(outputs.attentions, outputs_big.attentions)
        enable_qwen2_cc(big_model)
    past_key_values_big = outputs_big.past_key_values

    return CausalLMOutputWithPast(
        loss=outputs_big.loss,
        logits=outputs_big.logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs_big.hidden_states,
        attentions=outputs_big.attentions,
    )


def normal_forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Cache] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
    cache_position: Optional[torch.LongTensor] = None,
    position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    """Dense eager Qwen2 attention, bound onto ``Qwen2Attention`` by ``enable_qwen2_normal``.

    Computes softmax(Q K^T / sqrt(head_dim) + mask) V with rotary embeddings,
    an optional KV-cache update and grouped-query head repetition. The result
    goes through ``o_proj``.

    Returns:
        ``(attn_output, attn_weights if output_attentions else None, past_key_value)``.

    Raises:
        ValueError: if the per-head attention output has an unexpected shape.
    """
    batch, seq = hidden_states.size()[:2]

    # Project first (q, k, v order), then split into heads: (batch, heads, seq, head_dim).
    q_flat = self.q_proj(hidden_states)
    k_flat = self.k_proj(hidden_states)
    v_flat = self.v_proj(hidden_states)

    def _split_heads(flat, n_heads):
        return flat.view(batch, seq, n_heads, self.head_dim).transpose(1, 2)

    q = _split_heads(q_flat, self.num_heads)
    k = _split_heads(k_flat, self.num_key_value_heads)
    v = _split_heads(v_flat, self.num_key_value_heads)

    cos, sin = position_embeddings if position_embeddings is not None else self.rotary_emb(v, position_ids)
    q, k = apply_rotary_pos_emb(q, k, cos, sin)

    if past_key_value is not None:
        # RoPE models pass sin/cos so the cache can handle rotated keys.
        k, v = past_key_value.update(
            k, v, self.layer_idx, {"sin": sin, "cos": cos, "cache_position": cache_position}
        )

    # Grouped-query attention: tile K/V heads up to the query head count.
    k = repeat_kv(k, self.num_key_value_groups)
    v = repeat_kv(v, self.num_key_value_groups)

    scores = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(self.head_dim)
    if attention_mask is not None:
        # The mask may be longer than the key length; keep only the visible part.
        scores = scores + attention_mask[:, :, :, : k.shape[-2]]

    # Softmax in fp32, then cast back to the query dtype.
    probs = nn.functional.softmax(scores, dim=-1, dtype=torch.float32).to(q.dtype)
    probs = nn.functional.dropout(probs, p=self.attention_dropout, training=self.training)
    context = torch.matmul(probs, v)

    expected = (batch, self.num_heads, seq, self.head_dim)
    if context.size() != expected:
        raise ValueError(
            f"`attn_output` should be of size {expected}, but is"
            f" {context.size()}"
        )

    merged = context.transpose(1, 2).contiguous().reshape(batch, seq, self.hidden_size)
    projected = self.o_proj(merged)

    return projected, (probs if output_attentions else None), past_key_value
