import os
import argparse
import torch
import json
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel, PeftConfig
from safetensors.torch import load_file
from model.buddy_model import BuddyForCausalLM
from tqdm import tqdm
from functools import partial
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from utils.metric import get_test_loaders
from torch.distributions import Categorical
import torch.nn.functional as F

# Prefer the first CUDA device when one is present, otherwise run on CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

# Fixed global seed so repeated runs are reproducible.
torch.manual_seed(20250925)


def parse_args(argv=None):
    """Parse command-line options for the budget-decision evaluation script.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, exactly as before; passing an
            explicit list allows programmatic use and testing.

    Returns:
        argparse.Namespace holding the model, data and buddy-specific options.
    """
    parser = argparse.ArgumentParser(description='Tuning Pruned LLM')
    parser.add_argument('--base_model', type=str, default="baffo32/decapoda-research-llama-7B-hf",
                        help='base model name')
    parser.add_argument('--peft', type=str, default=None, help='peft path')
    parser.add_argument('--tasks', type=str, default="ptb,wikitext2", help='data name')
    parser.add_argument('--output_path', type=str, default="", help='output_path')
    parser.add_argument('--cutoff_len', type=int, default=1024, help='cutoff length')
    parser.add_argument('--batch_size', type=int, default=4, help='batch size')

    # buddy-specific options
    parser.add_argument("--lambda_reg", type=float, default=0.1, help="lambda_reg")
    parser.add_argument("--sensitivity_path", type=str, default="", help="sensitivity_path")
    parser.add_argument("--sensitivity_type", type=str, default="", help="sensitivity_type")

    return parser.parse_args(argv)


def budget_predictor_forward_hook(module, input, output, decisions):
    """Forward hook that records the budget predictor's greedy choice.

    Meant to be bound with ``functools.partial(..., decisions=some_list)`` and
    registered via ``register_forward_hook``.

    Args:
        module: The hooked module (unused).
        input: The module's positional inputs (unused).
        output: The module's output tensor; the argmax is taken over its last
            dimension.
        decisions: List to which the chosen indices are appended, one tensor
            per forward call.
    """
    # Greedy selection over the last dimension, shifted so indices start at 1.
    chosen = output.argmax(dim=-1) + 1
    decisions.append(chosen)


def load_model(args):
    """Load the tokenizer and the PEFT-wrapped Buddy model with its router weights.

    The base model is switched to ``train_mode = "predictor"``. If a sensitivity
    type is given, a sensitivity prior is attached. The model is then cast to
    bfloat16, wrapped with the PEFT adapter from ``args.peft``, and updated with
    ``<args.peft>/router_weights.safetensors``. Finally it is moved to the
    module-level ``device`` and put in eval mode.

    Args:
        args: Parsed CLI namespace (see ``parse_args``).

    Returns:
        Tuple ``(model, tokenizer, kwargs)``. ``kwargs`` holds extra keyword
        arguments for the model's forward call; currently only ``budgets=None``.

    Raises:
        ValueError: If ``args.peft`` is empty/None. The adapter directory is
            required, and checking first avoids loading the base model for nothing.
    """
    if not args.peft:
        raise ValueError("--peft must point to the PEFT adapter directory")

    tokenizer = AutoTokenizer.from_pretrained(args.base_model)
    tokenizer.pad_token_id = 0
    tokenizer.padding_side = "left"

    model = BuddyForCausalLM.from_pretrained(args.base_model)
    model.model.train_mode = "predictor"
    # Both the CLI default "" and the literal string "None" mean "no sensitivity
    # prior". Otherwise the default run would try to read an empty path.
    if args.sensitivity_type not in ("", "None"):
        from utils.sensitivity_utils import read_pre_sensitivity
        pre_sensitivity = read_pre_sensitivity(
            path=args.sensitivity_path,
            type=args.sensitivity_type
        )
        model.set_sensitivity(pre_sensitivity, args.lambda_reg)

    model = model.to(torch.bfloat16)
    model = PeftModel.from_pretrained(model, args.peft)

    # Router weights are stored next to the adapter. Merge them into the full
    # state dict and load it strictly, so any key mismatch raises.
    router_weights = load_file(os.path.join(args.peft, "router_weights.safetensors"))
    state_dict = model.state_dict()
    state_dict.update(router_weights)
    model.load_state_dict(state_dict)

    # Use the module-level device (which falls back to CPU) rather than a
    # hard-coded "cuda", so CPU-only hosts do not crash here.
    model = model.to(device)
    model.eval()

    kwargs = {"budgets": None}
    return model, tokenizer, kwargs


@torch.no_grad()
def main(args):
    """Record the budget predictor's decisions on every requested dataset.

    A forward hook on ``model.base_model.model.model.budget_predictor`` collects
    one tensor of 1-based budget indices per predictor call. After each dataset
    these tensors are concatenated along dim 0, cast to float and averaged over
    dim 0. The per-dataset results are written as JSON to
    ``os.path.join(args.output_path, "budget_result.json")``.

    Args:
        args: Parsed CLI namespace (see ``parse_args``).

    Raises:
        RuntimeError: If a dataset produces no predictor decisions at all.
    """
    model, tokenizer, kwargs = load_model(args)
    # Ignore empty entries such as the one produced by a trailing comma.
    datasets = [name.strip() for name in args.tasks.split(",") if name.strip()]

    budgets = []
    budget_predictor = model.base_model.model.model.budget_predictor
    hook_handle = budget_predictor.register_forward_hook(
        partial(budget_predictor_forward_hook, decisions=budgets)
    )

    results = {}
    try:
        for dataset in datasets:
            budgets.clear()
            test_loader = get_test_loaders(dataset, tokenizer, args.cutoff_len, args.batch_size)

            bar_format = "Calculating decision for " + dataset + ":" + "{l_bar}{bar}{r_bar}"
            for batch in tqdm(test_loader, bar_format=bar_format):
                # Only the hook's side effect is needed; the logits are discarded.
                model(batch.to(device), budgets=kwargs["budgets"])

            if not budgets:
                raise RuntimeError(f"budget predictor produced no decisions for dataset {dataset!r}")
            budgets_tensor = torch.cat(budgets, dim=0).float()
            results[dataset] = budgets_tensor.mean(dim=0).tolist()
    finally:
        # Detach the hook so the model is left unmodified for any later use.
        hook_handle.remove()

    # Treat output_path as a directory; "" keeps writing to the current directory.
    output_file = os.path.join(args.output_path, "budget_result.json")
    with open(output_file, "w") as file:
        json.dump(results, file)


def plot_results(args, file_name="decision_result.json"):
    """Render per-benchmark, per-layer values as a heatmap and save it as a PDF.

    Reads ``os.path.join(args.output_path, file_name)``, a JSON object mapping
    benchmark name to a list of per-layer values. The lists are stacked into a
    (benchmarks x layers) matrix and drawn with seaborn. The figure is saved to
    ``os.path.join(args.output_path, "decision_result.pdf")``.

    Args:
        args: Parsed CLI namespace; only ``output_path`` is used.
        file_name: JSON file to read inside ``output_path``. It defaults to the
            original hard-coded name. Note that ``main`` writes
            ``budget_result.json``.
    """
    with open(os.path.join(args.output_path, file_name), "r") as file:
        data = json.load(file)

    # Convert to a (num_benchmarks, num_layers) matrix.
    benchmarks = list(data.keys())
    matrix = np.array([data[bench] for bench in benchmarks], dtype=float)
    if matrix.ndim == 1:
        # One scalar per benchmark: draw it as a single-column heatmap.
        matrix = matrix[:, None]
    # Label columns 1..num_layers from the data rather than assuming 30 layers.
    layers = np.arange(1, matrix.shape[1] + 1)

    # Draw the heatmap.
    plt.figure(figsize=(12, 3))
    sns.heatmap(
        matrix,
        xticklabels=layers,
        yticklabels=benchmarks,
        cmap="RdYlGn_r",
        annot=False,
        cbar_kws={
            'label': 'Value',
        },
        linewidths=0.1
    )
    plt.title("Layer Decisions on Benchmarks Heatmap of Layers Removal")
    plt.xlabel("Layer Index")
    plt.ylabel("Benchmarks")

    plt.tight_layout(pad=0.01)

    # Join paths the same way as the read above. A literal "/" prefix would
    # write to the filesystem root when output_path is "".
    plt.savefig(os.path.join(args.output_path, "decision_result.pdf"))
    plt.show()
    plt.close()


if __name__ == "__main__":
    # Entry point: parse CLI options, then collect the budget decisions.
    # To render the heatmap from a saved JSON file instead, call
    # plot_results(cli_args).
    cli_args = parse_args()
    main(cli_args)
