import os
import argparse
import torch
import json
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel, PeftConfig
from safetensors.torch import load_file
from model.buddy_model import BuddyForCausalLM
from tqdm import tqdm
from functools import partial
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from utils.metric import get_test_loaders
from matplotlib.colors import LinearSegmentedColormap

# Run on the first CUDA GPU when one is visible; otherwise use the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda:0"

# Seed the global torch RNG so runs are repeatable.
torch.manual_seed(20250925)


def parse_args():
    """Parse command-line options for router-decision collection and plotting.

    Returns:
        argparse.Namespace holding the base model / PEFT adapter locations,
        the evaluation tasks, the output location, batching options and the
        Buddy sensitivity settings.
    """
    parser = argparse.ArgumentParser(description='Tuning Pruned LLM')
    parser.add_argument('--base_model', type=str, default="baffo32/decapoda-research-llama-7B-hf",
                        help='base model name')
    # Used by load_model to derive the layer budget; must be given for inference.
    parser.add_argument("--num_remove_blocks", type=int, default=None, help="num_remove_blocks")
    # Directory of the PEFT adapter; it also holds router_weights.safetensors.
    parser.add_argument('--peft', type=str, default=None, help='peft path')
    # Comma-separated dataset names, evaluated one after another by main().
    parser.add_argument('--tasks', type=str, default="ptb,wikitext2", help='data name')
    # Prefix for the JSON result file and directory for the saved figures.
    parser.add_argument('--output_path', type=str, default="", help='output_path')
    parser.add_argument('--cutoff_len', type=int, default=1024, help='cutoff length')
    parser.add_argument('--batch_size', type=int, default=4, help='batch size')

    # Buddy-specific options
    parser.add_argument("--lambda_reg", type=float, default=0.1, help="lambda_reg")
    parser.add_argument("--sensitivity_path", type=str, default="", help="sensitivity_path")
    # "None" skips loading the pre-computed sensitivity in load_model.
    parser.add_argument("--sensitivity_type", type=str, default="", help="sensitivity_type")

    return parser.parse_args()


def decision_forward_hook(module, input, output, decisions):
    """Record a module's forward output into ``decisions``.

    ``main`` registers this via ``register_forward_hook`` after binding
    ``decisions`` with ``functools.partial``. ``module`` and ``input`` are
    ignored, and ``output`` is stored as-is (no copy, no device move).
    """
    decisions.extend((output,))


def load_model(args):
    """Load the tokenizer and the PEFT-wrapped Buddy model with its router weights.

    Args:
        args: Parsed CLI namespace (see ``parse_args``). Reads ``base_model``,
            ``sensitivity_type``, ``sensitivity_path``, ``lambda_reg``,
            ``peft`` and ``num_remove_blocks``.

    Returns:
        ``(model, tokenizer, kwargs)``. ``model`` is the PEFT-wrapped model
        cast to bfloat16, placed on the module-level ``device`` and put in
        eval mode. ``kwargs`` is ``{"budgets": float}``: one minus the
        fraction of hidden layers removed. It is meant to be forwarded to
        the model call.

    Raises:
        ValueError: if ``--peft`` or ``--num_remove_blocks`` was not supplied.
    """
    # Validate up front so a missing option fails fast instead of after the
    # (slow) base-model load.
    if args.peft is None:
        raise ValueError("--peft is required (directory of the PEFT adapter)")
    if args.num_remove_blocks is None:
        raise ValueError("--num_remove_blocks is required to compute the layer budget")

    # Tokenizer: pad id 0, left padding.
    tokenizer = AutoTokenizer.from_pretrained(args.base_model)
    tokenizer.pad_token_id = 0
    tokenizer.padding_side = "left"

    model = BuddyForCausalLM.from_pretrained(args.base_model)
    # "None" and the parser's default empty string both mean "no sensitivity
    # prior". Otherwise an empty type would be looked up at an empty path.
    if args.sensitivity_type not in (None, "", "None"):
        from utils.sensitivity_utils import read_pre_sensitivity
        pre_sensitivity = read_pre_sensitivity(
            path=args.sensitivity_path,
            type=args.sensitivity_type,
        )
        model.set_sensitivity(pre_sensitivity, args.lambda_reg)

    model = model.to(torch.bfloat16)
    model = PeftModel.from_pretrained(model, args.peft)

    num_layers = model.model.config.num_hidden_layers
    kwargs = {"budgets": 1.0 - (args.num_remove_blocks / num_layers)}

    # Router weights live in a separate safetensors file inside the adapter
    # directory. They are merged into the full state dict and loaded with the
    # default strict=True, so any router key that does not match a model
    # parameter raises instead of being silently ignored.
    router_weights = load_file(os.path.join(args.peft, "router_weights.safetensors"))
    state_dict = model.state_dict()
    state_dict.update(router_weights)
    model.load_state_dict(state_dict)

    # Use the module-level ``device`` (which falls back to CPU) rather than a
    # hard-coded "cuda", so this also works on CPU-only machines.
    model = model.to(device)
    model.eval()

    return model, tokenizer, kwargs


@torch.no_grad()
def main(args):
    """Collect the router's outputs on each task and save their per-task mean.

    For every dataset in ``args.tasks`` (comma-separated), the model is run
    over that dataset's test loader. A forward hook captures each output of
    the router module. The captured outputs are concatenated along dim 0 and
    averaged over dim 0.

    The per-dataset means are written as JSON to
    ``args.output_path + "decision_result_<num_remove_blocks>.json"``. This is
    plain string concatenation, so ``output_path`` should end with a path
    separator when it names a directory.

    Raises:
        RuntimeError: if no router output was captured for a dataset
            (e.g. its test loader is empty).
    """
    model, tokenizer, kwargs = load_model(args)
    # Tolerate spaces around commas ("ptb, wikitext2") and drop empty entries.
    datasets = [name.strip() for name in args.tasks.split(",") if name.strip()]

    decisions = []
    router = model.base_model.model.model.router
    hook_handle = router.register_forward_hook(
        partial(decision_forward_hook, decisions=decisions)
    )

    results = {}
    try:
        for dataset in datasets:
            decisions.clear()
            test_loader = get_test_loaders(dataset, tokenizer, args.cutoff_len, args.batch_size)

            bar_format = "Calculating decision for " + dataset + ":" + "{l_bar}{bar}{r_bar}"
            for batch in tqdm(test_loader, bar_format=bar_format, ncols=100):
                # Only the hook's side effect is needed; the model output is discarded.
                model(batch.to(device), budgets=kwargs["budgets"])

            if not decisions:
                raise RuntimeError(f"no router outputs were captured for dataset {dataset!r}")

            decisions_tensor = torch.cat(decisions, dim=0)
            # Upcast before averaging. load_model casts the model to bfloat16,
            # so the router output is presumably bf16, whose ~3 significant
            # digits would otherwise limit the precision of the stored mean.
            results[dataset] = decisions_tensor.float().mean(dim=0).tolist()
    finally:
        # Detach the hook so the model can be reused without it still
        # appending to ``decisions``.
        hook_handle.remove()

    file_name = f"decision_result_{args.num_remove_blocks}.json"
    with open(args.output_path + file_name, "w") as file:
        json.dump(results, file)


def plot_results(args):
    """Plot saved per-layer router decisions as a benchmark-by-layer heatmap.

    Reads ``args.output_path + "decision_result_<num_remove_blocks>.json"``
    (the file written by ``main``). It draws one heatmap row per benchmark and
    one column per entry of that benchmark's list. The figure is saved as
    ``decision_result_<num_remove_blocks>.png`` and ``.pdf`` inside
    ``args.output_path`` (the current directory when it is empty) and then
    shown.

    Raises:
        ValueError: if the JSON file contains no benchmarks.
    """
    # Load results. The JSON path is prefix-concatenated, matching how main() writes it.
    file_name = f"decision_result_{args.num_remove_blocks}.json"
    with open(args.output_path + file_name, "r") as file:
        data = json.load(file)

    if not data:
        raise ValueError(f"no benchmark results found in {args.output_path + file_name!r}")

    # Build the matrix: one row per benchmark, one column per layer.
    benchmarks = list(data.keys())
    matrix = np.array([data[bench] for bench in benchmarks])
    # 1-based layer labels sized from the data rather than hard-coded to 30.
    layers = np.arange(1, matrix.shape[1] + 1)

    # Draw the heatmap.
    plt.figure(figsize=(12, 3))
    sns.heatmap(
        matrix,
        xticklabels=layers,
        yticklabels=benchmarks,
        cmap="RdYlGn_r",
        annot=False,
        cbar_kws={
            'label': 'Value',
        },
        linewidths=0.1,
    )
    plt.title(f"Layer Decisions on Benchmarks Heatmap of {args.num_remove_blocks} Layers Removal")
    plt.xlabel("Layer Index")
    plt.ylabel("Benchmarks")

    plt.tight_layout(pad=0.01)

    # os.path.join keeps the default empty output_path in the current
    # directory, whereas f"{output_path}/..." would target the filesystem root.
    stem = f"decision_result_{args.num_remove_blocks}"
    plt.savefig(os.path.join(args.output_path, f"{stem}.png"))
    plt.savefig(os.path.join(args.output_path, f"{stem}.pdf"))
    plt.show()


if __name__ == "__main__":
    cli_args = parse_args()
    # Inference is switched off by default; uncomment the next line to
    # regenerate the decision JSON before plotting.
    # main(cli_args)
    plot_results(cli_args)
