import os
import argparse
import torch
import json
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel, PeftConfig
from safetensors.torch import load_file
from model.buddy_model import BuddyForCausalLM
from tqdm import tqdm
from functools import partial
from datasets import load_dataset
from utils.prompter import AlpacaPrompter, SamSumPrompter
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

import evaluate

# ROUGE metric shared by rouge_scores(); loaded once at import time.
_rouge = evaluate.load("rouge")
# BERTScore is currently disabled.
# _bertscore = evaluate.load("bertscore")

# Prefer the first CUDA device; fall back to the CPU when no GPU is available.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

# Seeding is disabled, so sampled generations differ from run to run.
# torch.manual_seed(20250925)


def parse_args(argv=None):
    """Parse the command-line options of this analysis script.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            ``sys.argv[1:]`` is parsed, exactly as before.

    Returns:
        argparse.Namespace holding the parsed options.
    """
    parser = argparse.ArgumentParser(description='Tuning Pruned LLM')
    parser.add_argument('--base_model', type=str, default="baffo32/decapoda-research-llama-7B-hf",
                        help='base model name')
    # Required by load_model() (used for the layer budget); the plotting
    # passes do not read it.
    parser.add_argument("--num_remove_blocks", type=int, default=None, help="num_remove_blocks")
    parser.add_argument('--peft', type=str, default=None, help='peft path')
    parser.add_argument('--tasks', type=str, default="ptb,wikitext2", help='data name')
    # Directory for the JSON results and figures ("" = current directory).
    parser.add_argument('--output_path', type=str, default="", help='output_path')
    # Used as max_new_tokens during generation.
    parser.add_argument('--cutoff_len', type=int, default=1028, help='cutoff length')

    # buddy
    parser.add_argument("--lambda_reg", type=float, default=0.1, help="lambda_reg")
    parser.add_argument("--sensitivity_path", type=str, default="", help="sensitivity_path")
    parser.add_argument("--sensitivity_type", type=str, default="", help="sensitivity_type")

    return parser.parse_args(argv)


def decision_forward_hook(module, input, output, decisions):
    """Record one forward pass of the hooked module.

    Meant for ``register_forward_hook`` with ``decisions`` bound through
    ``functools.partial``. The module output is converted to nested Python
    lists and appended to ``decisions``. Nothing is returned, so the module's
    output is left unchanged.
    """
    recorded = output.tolist()
    decisions.append(recorded)


def load_model(args):
    """Load the tokenizer and the PEFT-wrapped Buddy model plus its router weights.

    Args:
        args: Parsed CLI namespace. Reads ``base_model``, ``peft``,
            ``num_remove_blocks``, ``sensitivity_type``, ``sensitivity_path``
            and ``lambda_reg``.

    Returns:
        ``(model, tokenizer, kwargs)``. ``model`` is in eval mode on
        ``device``. ``kwargs`` holds extra ``generate()`` arguments; currently
        only ``budgets``, the fraction of hidden layers kept.

    Raises:
        ValueError: if ``--num_remove_blocks`` or ``--peft`` was not given.
    """
    # Fail fast, before any weights are loaded. Both values are needed below
    # (budget computation and adapter/router loading).
    if args.num_remove_blocks is None:
        raise ValueError("--num_remove_blocks is required to load the model")
    if args.peft is None:
        raise ValueError("--peft (adapter directory) is required to load the model")

    # load model
    tokenizer = AutoTokenizer.from_pretrained(args.base_model)
    tokenizer.pad_token_id = 0
    tokenizer.padding_side = "left"

    kwargs = {}

    model = BuddyForCausalLM.from_pretrained(args.base_model)
    # Treat both the CLI default ("") and the literal "None" as "no sensitivity
    # prior". Previously only "None" skipped this branch, so the default ""
    # reached read_pre_sensitivity with an empty path.
    if args.sensitivity_type not in ("", "None"):
        from utils.sensitivity_utils import read_pre_sensitivity
        pre_sensitivity = read_pre_sensitivity(
            path=args.sensitivity_path,
            type=args.sensitivity_type
        )
        model.set_sensitivity(pre_sensitivity, args.lambda_reg)

    model = model.to(torch.bfloat16)
    model = PeftModel.from_pretrained(model, args.peft)

    # Fraction of hidden layers kept after removing num_remove_blocks of them;
    # callers forward this to model.generate().
    budgets = 1.0 - (args.num_remove_blocks / model.model.config.num_hidden_layers)
    kwargs["budgets"] = budgets

    # The router weights live in a separate file inside the adapter directory.
    # Merge them into the full state dict so the strict load sees every key.
    router_path = os.path.join(args.peft, "router_weights.safetensors")
    router_weights = load_file(router_path)
    state_dict = model.state_dict()
    state_dict.update(router_weights)
    model.load_state_dict(state_dict)

    model = model.to(device)
    model.eval()

    return model, tokenizer, kwargs


def load_eval_dataset(name):
    """Return ``(data, prompter)`` for a small evaluation split.

    Supported names:
        * ``"alpaca"``: the 100-row test part of an unshuffled
          ``train_test_split`` of ``yahma/alpaca-cleaned``'s train split,
          paired with ``AlpacaPrompter``.
        * ``"samsum"``: the first 100 rows of the ``knkarthick/samsum``
          validation split, paired with ``SamSumPrompter``.

    Raises:
        ValueError: for any other ``name`` (checked before anything is loaded).
    """
    if name == "alpaca":
        data = load_dataset("yahma/alpaca-cleaned")
        train_val = data["train"].train_test_split(
            # test_size=2000, shuffle=False, seed=42
            test_size=100, shuffle=False, seed=42
        )
        data = train_val["test"]
        prompter = AlpacaPrompter()

    elif name == "samsum":
        data = load_dataset("knkarthick/samsum", split="validation[:100]")
        prompter = SamSumPrompter()
    else:
        # ValueError subclasses Exception, so existing `except Exception`
        # handlers still catch it.
        raise ValueError(f"Unsupported evaluation dataset: {name!r} (expected 'alpaca' or 'samsum')")

    return data, prompter


######## Test how the routing path changes during inference ########
def decode_in_dataset(dataset_name, model, tokenizer, decisions, *, max_new_tokens=None, **kwargs):
    """Generate for every sample of a dataset and collect the recorded router decisions.

    Args:
        dataset_name: Name accepted by ``load_eval_dataset``.
        model: Model whose router has a forward hook appending into ``decisions``.
        tokenizer: Tokenizer matching ``model``.
        decisions: List shared with that hook; cleared before each sample.
        max_new_tokens: Tokens to generate per sample. When ``None``, falls back
            to the script-level ``args.cutoff_len`` global (set under
            ``__main__``), which was the previous behaviour.
        **kwargs: Forwarded to ``model.generate`` (e.g. ``budgets``).

    Returns:
        One list per sample, holding the decisions recorded while decoding it.
    """
    if max_new_tokens is None:
        # Backward-compatible fallback. This raises NameError if the module was
        # imported rather than run as a script and no value was passed.
        max_new_tokens = args.cutoff_len

    data, prompter = load_eval_dataset(dataset_name)

    results = []
    bar_format = "Calculating decision for " + dataset_name + ":" + "{l_bar}{bar}{r_bar}"
    for item in tqdm(data, bar_format=bar_format):
        decisions.clear()

        full_prompt = prompter.generate_prompt(item)["prompt"]
        inputs = tokenizer(full_prompt, return_tensors="pt").to(device)
        # Only the hook's side effect is needed; the generated tokens are discarded.
        model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            pad_token_id=tokenizer.eos_token_id,
            do_sample=True,
            eos_token_id=None,  # key setting: do not use EOS as a stop token
            **kwargs
        )
        # Snapshot: the hook keeps appending into the same list object.
        results.append(decisions.copy())

    return results


def get_decisions(args):
    """Record router decisions while decoding Alpaca and SAMSum samples.

    Writes ``{"alpaca": [...], "samsum": [...]}`` to
    ``<output_path>/decode_result_<num_remove_blocks>.json``. Each list holds,
    per sample, the router outputs captured on every router forward pass.
    """
    model, tokenizer, kwargs = load_model(args)
    dataset_names = ["alpaca", "samsum"]

    decisions = []
    router = model.base_model.model.model.router
    handle = router.register_forward_hook(partial(decision_forward_hook, decisions=decisions))
    try:
        results = {
            dataset_name: decode_in_dataset(dataset_name, model, tokenizer, decisions, **kwargs)
            for dataset_name in dataset_names
        }
    finally:
        # Detach the hook so the model can be reused without recording.
        handle.remove()

    # os.path.join handles output_path with or without a trailing separator,
    # and "" (the default) maps to the current directory.
    file_path = os.path.join(args.output_path, f"decode_result_{args.num_remove_blocks}.json")
    with open(file_path, "w") as file:
        json.dump(results, file)


def plot_decisions(args):
    keys = [4, 8, 12, 16]
    alpaca = []
    samsum = []

    for rm_blocks in keys:
        # file_name = f"decode_result_{args.num_remove_blocks}.json"
        file_name = f"decode_result_{rm_blocks}.json"
        with open(args.output_path + file_name, "r") as file:
            data = json.load(file)
            alpaca_data = np.array(data["alpaca"]).squeeze(axis=2)
            samsum_data = np.array(data["samsum"]).squeeze(axis=2)

        result_alpaca = []
        for item in alpaca_data:
            _, unique_indices = np.unique(item, axis=0, return_index=True)
            result_alpaca.append(unique_indices.shape[0])

        result_samsum = []
        for item in samsum_data:
            _, unique_indices = np.unique(item, axis=0, return_index=True)
            result_samsum.append(unique_indices.shape[0])

        # print(f"{args.num_remove_blocks}: result_alpaca={result_alpaca}, mean={np.array(result_alpaca).mean()}")
        # print(f"{args.num_remove_blocks}: result_samsum={result_samsum}, mean={np.array(result_samsum).mean()}")

        alpaca.append(np.array(result_alpaca).mean())
        samsum.append(np.array(result_samsum).mean())

    # ===== 画图参数 =====
    x = range(len(keys))
    width = 0.45  # 每组柱宽

    plt.figure(figsize=(12, 8))

    # 两组柱子左右错位
    bars1 = plt.bar([i - width / 2 for i in x], alpaca, width, label="Alpaca", zorder=2, color="#A3D4D5")
    bars2 = plt.bar([i + width / 2 for i in x], samsum, width, label="Samsum", zorder=2, color="#F7CCAD")

    # 坐标与标题
    plt.xticks(list(x), keys)
    plt.xlabel("Remove Blocks", fontsize=28)
    plt.ylabel("Number Of Reasoning Paths", fontsize=28)
    plt.ylim(0, max(max(alpaca), max(samsum)) + 0.5)
    plt.tick_params(axis='both', labelsize=28)
    # plt.title("The number the inference path in the Alpaca and Samsum", fontsize=22)
    plt.legend(fontsize=28, loc="best")
    plt.grid(alpha=0.3, linestyle="--")

    # ===== 在柱顶添加数值标签 =====
    def add_labels(bars):
        for bar in bars:
            h = bar.get_height()
            plt.text(bar.get_x() + bar.get_width() / 2, h, f"{h:.2f}",
                     ha="center", va="bottom", fontsize=28)

    add_labels(bars1)
    add_labels(bars2)

    plt.tight_layout()
    plt.savefig(args.output_path + "/decode_decision_bar.png", dpi=300)
    plt.savefig(args.output_path + "/decode_decision_bar.pdf", dpi=300)
    plt.show()


######## Evaluate how enabling routing-path changes during inference affects performance ########
def rouge_scores(ref: str, hyp: str):
    """Score one hypothesis against one reference with ROUGE (stemming enabled).

    Returns a dict with the ``rouge1``, ``rouge2`` and ``rougeL`` values
    reported by the module-level ``_rouge`` metric.
    """
    metric_values = _rouge.compute(references=[ref], predictions=[hyp], use_stemmer=True)
    # BERTScore is disabled. To restore it, load `_bertscore` at module level and add
    # "bertscore_f1": float(_bertscore.compute(predictions=[hyp], references=[ref], lang="en")["f1"][0])
    wanted = ("rouge1", "rouge2", "rougeL")
    return {name: metric_values[name] for name in wanted}


def get_acc(args):
    model, tokenizer, kwargs = load_model(args)
    dataset_name = "samsum"
    data, prompter = load_eval_dataset(dataset_name)

    def calculate_rouge_scores():
        scores = []
        bar_format = "Calculating Rouge scores for " + dataset_name + ":" + "{l_bar}{bar}{r_bar}"
        for i in tqdm(range(len(data)), bar_format=bar_format):
            item = data[i]
            prompter_item = prompter.generate_prompt(item)
            full_prompt = prompter_item["prompt"]
            inputs = tokenizer(full_prompt, return_tensors="pt").to(device)
            output_tokens = model.generate(
                **inputs,
                max_new_tokens=args.cutoff_len,
                pad_token_id=tokenizer.eos_token_id,
                do_sample=True,
                temperature=2.0,
                top_k=5,
                **kwargs
            )

            output_tokens = output_tokens[0][inputs.input_ids.shape[1]:]
            output_text = tokenizer.decode(output_tokens, skip_special_tokens=True)

            score = rouge_scores(prompter_item["answer"], output_text)
            scores.append(score)
        return scores

    model.decision_reuse = True
    reuse_scores = calculate_rouge_scores()
    model.decision_reuse = False
    no_reuse_scores = calculate_rouge_scores()

    # save results
    file_path = f"{args.output_path}/decode_acc_{args.num_remove_blocks}.json"

    # if os.path.exists(file_path):
    #     with open(file_path, "r") as file:
    #         results = json.load(file)
    # else:
    #     results = {}

    with open(file_path, "w") as file:
        results = {
            "reuse_scores": reuse_scores,
            "no_reuse_scores": no_reuse_scores,

        }
        results = json.dumps(results)
        file.write(results)


def plot_acc(args):
    # 数据
    keys = [4, 8, 12, 16]
    metric_list = ["rouge1", "rougeL"]

    all_data = {}
    for rm_blocks in keys:
        file_name = f"{args.output_path}/decode_acc_{rm_blocks}.json"
        with open(file_name, "r") as file:
            data = json.load(file)

        reuse_scores = data["reuse_scores"]
        no_reuse_scores = data["no_reuse_scores"]

        results = {
            "reuse_scores": {key: [] for key in metric_list},
            "no_reuse_scores": {key: [] for key in metric_list},
        }

        for item in reuse_scores:
            for metric in metric_list:
                results["reuse_scores"][metric].append(item[metric])

        for item in no_reuse_scores:
            for metric in metric_list:
                results["no_reuse_scores"][metric].append(item[metric])

        for metric in metric_list:
            results["reuse_scores"][metric] = np.array(results["reuse_scores"][metric]).mean()
            results["no_reuse_scores"][metric] = np.array(results["no_reuse_scores"][metric]).mean()

        all_data[rm_blocks] = results

    colors = ["#A3D4D5", "#F7CCAD", "#FFF2CC", "#C2DEE6"]

    series_data = [
        [all_data[k]["reuse_scores"]["rouge1"] for k in keys],
        [all_data[k]["no_reuse_scores"]["rouge1"] for k in keys],
        [all_data[k]["reuse_scores"]["rougeL"] for k in keys],
        [all_data[k]["no_reuse_scores"]["rougeL"] for k in keys],
    ]

    reuse_r1, noreuse_r1, reuse_rl, noreuse_rl = series_data

    x = np.arange(len(keys))
    w = 0.44

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(24, 8), sharey=True)

    # ===== 辅助：在柱顶添加标签 =====
    def add_bar_labels(ax, bars, fmt="{:.3f}", dy=0.005, fontsize=28):
        """在每个 bar 顶部添加文本；dy 是向上偏移，防止贴边。"""
        max_h = 0.0
        for b in bars:
            h = b.get_height()
            max_h = max(max_h, h)
            ax.text(b.get_x() + b.get_width() / 2, h + dy, fmt.format(h),
                    ha="center", va="bottom", fontsize=fontsize)
        # 给 y 轴留一点顶部空间，防止遮挡
        ymin, ymax = ax.get_ylim()
        if ymax < max_h * 1.10:
            ax.set_ylim(ymin, max_h * 1.10)

    # (a) ROUGE-1
    bars_r1_reuse = ax1.bar(x - w / 2, reuse_r1, width=w, label="Reuse", color="#FFF2CC", zorder=2)
    bars_r1_noreuse = ax1.bar(x + w / 2, noreuse_r1, width=w, label="Recompute", color="#C2DEE6", zorder=2)
    ax1.set_xticks(x, keys)
    ax1.set_xlabel("Remove Blocks", fontsize=28)
    ax1.set_ylabel("Rough Score", fontsize=28)
    ax1.set_ylim(0, 0.35)
    ax1.tick_params(axis='both', labelsize=28)
    # ax1.set_title("ROUGE-1 between Reuse/Recompute inference path on Samsum", fontsize=22)
    ax1.legend(fontsize=28, loc="best")
    ax1.grid(alpha=0.3, linestyle="--")

    add_bar_labels(ax1, bars_r1_reuse)
    add_bar_labels(ax1, bars_r1_noreuse)

    # (b) ROUGE-L
    bars_rl_reuse = ax2.bar(x - w / 2, reuse_rl, width=w, label="Reuse", color="#FFF2CC", zorder=2)
    bars_rl_noreuse = ax2.bar(x + w / 2, noreuse_rl, width=w, label="Recompute", color="#C2DEE6", zorder=2)
    ax2.set_xticks(x, keys)
    ax2.set_xlabel("Remove Blocks", fontsize=28)
    ax2.set_ylim(0, 0.35)
    ax2.tick_params(axis='both', labelsize=28)
    # ax2.set_title("ROUGE-L between Reuse/Recompute inference path on Samsum", fontsize=22)
    ax2.legend(fontsize=28, loc="best")
    ax2.grid(alpha=0.3, linestyle="--")

    add_bar_labels(ax2, bars_rl_reuse)
    add_bar_labels(ax2, bars_rl_noreuse)

    fig.tight_layout()
    plt.savefig(args.output_path + "decode_acc_bar.png", dpi=300)
    plt.savefig(args.output_path + "decode_acc_bar.pdf", dpi=300)
    plt.show()

    return

    # keys = [4, 8, 12, 16]
    # metric_list = ["rouge1", "rougeL"]
    #
    # all_data = {}
    # for rm_blocks in keys:
    #     file_name = f"{args.output_path}/decode_acc_{rm_blocks}.json"
    #     with open(file_name, "r") as file:
    #         data = json.load(file)
    #
    #     reuse_scores = data["reuse_scores"]
    #     no_reuse_scores = data["no_reuse_scores"]
    #
    #     results = {
    #         "reuse_scores": {key: [] for key in metric_list},
    #         "no_reuse_scores": {key: [] for key in metric_list},
    #     }
    #
    #     for item in reuse_scores:
    #         for metric in metric_list:
    #             results["reuse_scores"][metric].append(item[metric])
    #
    #     for item in no_reuse_scores:
    #         for metric in metric_list:
    #             results["no_reuse_scores"][metric].append(item[metric])
    #
    #     for metric in metric_list:
    #         results["reuse_scores"][metric] = np.array(results["reuse_scores"][metric]).mean()
    #         results["no_reuse_scores"][metric] = np.array(results["no_reuse_scores"][metric]).mean()
    #
    #     all_data[rm_blocks] = results
    #
    # # 每个 key 下 4 根柱：reuse-R1, no_reuse-R1, reuse-RL, no_reuse-RL
    # series_labels = ["Reuse-ROUGE-1", "No_reuse-ROUGE-1", "Reuse-ROUGE-L", "No_reuse-ROUGE-L"]
    # colors = ["#A3D4D5", "#F7CCAD", "#FFF2CC", "#C2DEE6"]
    #
    # series_data = [
    #     [all_data[k]["reuse_scores"]["rouge1"] for k in keys],
    #     [all_data[k]["no_reuse_scores"]["rouge1"] for k in keys],
    #     [all_data[k]["reuse_scores"]["rougeL"] for k in keys],
    #     [all_data[k]["no_reuse_scores"]["rougeL"] for k in keys],
    # ]
    #
    # # ===== 画图 =====
    # x = list(range(len(keys)))
    # group_width = 0.8  # 每组内总宽度
    # n_bars = len(series_data)  # 4
    # bar_width = group_width / n_bars
    #
    # def offsets(i):
    #     # 让每组 4 根柱在组中心左右居中排布
    #     start = -group_width / 2 + bar_width / 2
    #     return [i + start + j * bar_width for j in range(n_bars)]
    #
    # plt.figure(figsize=(24, 5))
    #
    # bars_all = []
    # for j, yvals in enumerate(series_data):
    #     xs = [offsets(i)[j] for i in x]
    #     bars = plt.bar(xs, yvals, width=bar_width, label=series_labels[j], zorder=2, color=colors[j])
    #     bars_all.append(bars)
    #
    # # 坐标与标题
    # plt.xticks(x, keys)
    # plt.xlabel("Remove Blocks", fontsize=16)
    # plt.ylabel("Rough Score", fontsize=16)
    # plt.ylim(0, 0.35)
    # plt.tick_params(axis='both', labelsize=14)
    # plt.title("Performance between Reuse/No_reuse inference path on Samsum", fontsize=20)
    # plt.legend(fontsize=14, loc="best")
    # plt.grid(alpha=0.3, linestyle="--")
    #
    # # 柱顶数值标签
    # for group in bars_all:
    #     for bar in group:
    #         h = bar.get_height()
    #         plt.text(bar.get_x() + bar.get_width() / 2, h, f"{h:.3f}",
    #                  ha="center", va="bottom", fontsize=14)
    #
    # plt.tight_layout()
    # plt.savefig(args.output_path + "decode_acc_bar.png", dpi=300)
    # plt.show()


if __name__ == "__main__":
    # `args` stays a module-level global: decode_in_dataset reads args.cutoff_len.
    args = parse_args()

    # Data-collection passes (need the model and a GPU); enable as required:
    # get_decisions(args)
    # get_acc(args)

    # Plot from the JSON files those passes wrote into --output_path.
    for make_plot in (plot_decisions, plot_acc):
        make_plot(args)
