import torch
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import torch.nn as nn
import numpy as np
from copy import deepcopy

def laco_merge_layers(model, merge_base_lay, merge_layer_num):
    """Collapse the layers that follow ``merge_base_lay`` into it, on a copy.

    For each of the ``merge_layer_num`` layers directly after the base layer,
    the difference between that layer's projection weight and the *original*
    base layer's weight is added onto the base layer of a deep copy of
    ``model``::

        merged = base + sum(follower_k - base for each absorbed follower)

    The absorbed layers are then removed from the copy. Only the ``weight``
    of ``mlp.{gate,down,up}_proj`` and ``self_attn.{q,k,v,o}_proj`` is
    merged; every other parameter of an absorbed layer is dropped together
    with the layer.

    The differences are taken against the unmodified base weights of
    ``model``. Subtracting the running merged weight instead would make each
    step overwrite the base with the current follower, leaving only the last
    absorbed layer's weights.

    Args:
        model: Object exposing its decoder layers as ``model.model.layers``.
            It is never modified.
        merge_base_lay: Index of the layer that absorbs its followers.
        merge_layer_num: Number of following layers to absorb. Clamped so the
            merge never reaches past the last layer.

    Returns:
        A deep copy of ``model`` whose layer list is shorter by the (clamped)
        number of absorbed layers.

    NOTE(review): nothing besides the layer list and these weights is
    updated; if layer-count or per-layer index metadata is kept elsewhere on
    the model, it is left as is -- confirm whether callers rely on it.
    """
    # Never reach past the last decoder layer.
    merge_layer_num = min(merge_layer_num, len(model.model.layers) - merge_base_lay - 1)

    # (sub-module, projection) pairs whose weights take part in the merge.
    projections = (
        ("mlp", "gate_proj"),
        ("mlp", "down_proj"),
        ("mlp", "up_proj"),
        ("self_attn", "q_proj"),
        ("self_attn", "k_proj"),
        ("self_attn", "v_proj"),
        ("self_attn", "o_proj"),
    )

    model_copy = deepcopy(model)
    source_layers = model.model.layers

    for diff_lay in range(merge_base_lay + 1, merge_base_lay + 1 + merge_layer_num):
        # Looked up inside the loop so that an empty merge range never
        # indexes the layer list at all.
        source_base = source_layers[merge_base_lay]
        source_diff = source_layers[diff_lay]
        merged_base = model_copy.model.layers[merge_base_lay]

        for block_name, proj_name in projections:
            base_w = getattr(getattr(source_base, block_name), proj_name).weight.data
            diff_w = getattr(getattr(source_diff, block_name), proj_name).weight.data
            merged_w = getattr(getattr(merged_base, block_name), proj_name).weight.data
            # Difference against the ORIGINAL base weight (see docstring).
            # The follower may live on another device when the model is
            # sharded; ``.to`` is a no-op when the devices already match.
            merged_w.add_(diff_w.to(merged_w.device) - base_w)

    # Remove absorbed layers from the highest index down so that the
    # remaining indices stay valid while deleting.
    for diff_lay in range(merge_base_lay + merge_layer_num, merge_base_lay, -1):
        del model_copy.model.layers[diff_lay]

    return model_copy

def cal_last_hidden_sim(model1, model2, tokenizer, sents):
    """Return the mean cosine similarity of two models' final hidden states.

    Every entry of ``sents`` is tokenized on its own, moved to the device of
    ``model1``'s first parameter, and fed to both models with gradient
    tracking disabled. The last element of each model's ``hidden_states``
    output is flattened from dimension 1 onward and the two are compared with
    cosine similarity. The list of per-entry scores and their mean are
    printed before the mean is returned.

    Args:
        model1: First model; also determines the device the inputs go to.
        model2: Second model, called with the very same encoded inputs.
        tokenizer: Callable invoked as ``tokenizer(s, return_tensors='pt')``
            whose result supports ``.to(device)`` and ``**`` unpacking.
        sents: Iterable of inputs to evaluate one by one.

    Returns:
        ``np.mean`` over the per-entry similarity scores.

    NOTE(review): each score is converted with ``.item()``, so every entry is
    expected to produce exactly one similarity value (a single-row batch).
    """
    target_device = next(model1.parameters()).device

    def _final_hidden(net, batch):
        # Last entry of the hidden-state tuple returned by the model.
        return net(**batch, output_hidden_states=True).hidden_states[-1]

    raw_scores = []
    for sentence in sents:
        batch = tokenizer(sentence, return_tensors='pt').to(target_device)

        with torch.no_grad():
            hidden_a = _final_hidden(model1, batch)
            hidden_b = _final_hidden(model2, batch)

        raw_scores.append(
            torch.cosine_similarity(
                hidden_a.flatten(start_dim=1),
                hidden_b.flatten(start_dim=1),
            )
        )

    sim_values = [score.item() for score in raw_scores]
    mean_sim = np.mean(sim_values)
    print(sim_values, mean_sim)
    return mean_sim

def merge_laco(args, model, tokenizer, text, device):
    """Greedily collapse groups of layers while the output stays similar.

    Starting ``args.laco_merge_layers`` below the top layer and walking
    downward, the function tentatively merges a group of layers with
    ``laco_merge_layers`` and scores the candidate against the original
    ``model`` with ``cal_last_hidden_sim``. A candidate whose score exceeds
    ``args.laco_threshold`` is kept and the cursor moves down by
    ``args.laco_interval``; otherwise the candidate is discarded and the
    cursor moves down by one. The search ends once the cursor drops below 0.

    Args:
        args: Object providing ``laco_interval``, ``laco_merge_layers`` and
            ``laco_threshold``.
        model: Reference model; it is deep-copied and never modified here.
        tokenizer: Passed through to ``cal_last_hidden_sim``.
        text: Iterable of calibration inputs; its items are copied into a
            list. NOTE(review): a bare ``str`` would be split into single
            characters -- presumably callers pass a list of strings.
        device: Not used by this function.

    Returns:
        The compressed model (a copy of ``model`` if no merge was accepted).
    """
    step = args.laco_interval
    group_size = args.laco_merge_layers
    top_idx = len(model.model.layers) - 1
    threshold = args.laco_threshold
    cursor = top_idx - group_size

    calibration = list(text)

    compressed = deepcopy(model)
    while cursor >= 0:
        print(cursor)
        print('current model layer', len(compressed.model.layers))

        candidate = laco_merge_layers(compressed, cursor, group_size - 1)
        similarity = cal_last_hidden_sim(model, candidate, tokenizer, calibration)

        # Written as ``not (a > b)`` so a NaN score is rejected too.
        if not (similarity > threshold):
            cursor -= 1
            continue

        print("Successfully merged layers from", cursor, "to", cursor + group_size)
        compressed = candidate
        cursor -= step

        # After shrinking, the cursor may point past the end of the model;
        # pull it back so a full group still fits below the top layer.
        remaining = len(compressed.model.layers)
        if cursor >= remaining:
            cursor = remaining - 1 - group_size

    return compressed
