import os
import time

import argparse
import torch
import json
from transformers import AutoModelForCausalLM, AutoTokenizer
from baselines.FiRST.model import RouterSelectiveLlamaForCausalLM
from model.buddy_model import BuddyForCausalLM
from peft import PeftModel, PeftConfig
from datasets import load_dataset
from utils.prompter import AlpacaPrompter, SamSumPrompter
from datetime import datetime
from tqdm import tqdm
import pandas as pd
from utils.sensitivity.utils import get_block_pruned_network
from safetensors.torch import load_file
from peft import LoraConfig, get_peft_model
import numpy as np

# Target device for the model and all inputs: CUDA when torch reports it as
# available, CPU otherwise.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

# Seed torch's default random number generator with a fixed value.
torch.manual_seed(20250925)

def parse_args(argv=None):
    """Parse the command-line options of the speed benchmark.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, which is what the script entry
            point relies on.

    Returns:
        argparse.Namespace with the parsed options.
    """
    parser = argparse.ArgumentParser(description='Speed benchmark for pruned LLMs')
    parser.add_argument('--base_model', type=str, default="baffo32/decapoda-research-llama-7B-hf",
                        help='base model name')
    parser.add_argument('--name', type=str, default="wo", help='model name')
    parser.add_argument("--num_remove_blocks", type=int, default=None, help="num_remove_blocks")
    parser.add_argument('--peft', type=str, default=None, help='peft path')
    # The default must only name tasks that load_eval_dataset() supports;
    # the previous default ("ptb,wikitext2") made a default run fail.
    parser.add_argument('--tasks', type=str, default="alpaca,samsum",
                        help='comma-separated data names (alpaca, samsum)')
    parser.add_argument('--output_path', type=str, default="", help='output_path')
    parser.add_argument('--cutoff_len', type=int, default=1024, help='cutoff length')

    # buddy
    parser.add_argument("--lambda_reg", type=float, default=0.1, help="lambda_reg")
    parser.add_argument("--sensitivity_path", type=str, default="", help="sensitivity_path")
    parser.add_argument("--sensitivity_type", type=str, default="", help="sensitivity_type")

    # pudding
    parser.add_argument("--layerset_path", type=str, default="", help="layerset_path")

    # shortened, shortgpt, sleb
    parser.add_argument("--block_order_path", type=str, default="", help="block_order_path")

    return parser.parse_args(argv)

def load_eval_dataset(name):
    """Load the evaluation data and the matching prompter for one task.

    Args:
        name: Task name; "alpaca" and "samsum" are supported.

    Returns:
        Tuple ``(data, prompter)``: the dataset to iterate over and the
        prompter used to turn each item into a prompt.

    Raises:
        ValueError: If ``name`` is not a supported task.
    """
    if name == "alpaca":
        dataset = load_dataset("yahma/alpaca-cleaned")
        # Carve 100 examples out of the train split, without shuffling, to
        # serve as the benchmark set.
        split = dataset["train"].train_test_split(
            test_size=100, shuffle=False, seed=42
        )
        return split["test"], AlpacaPrompter()

    if name == "samsum":
        # Only the first 100 validation examples are used.
        data = load_dataset("knkarthick/samsum", split="validation[:100]")
        return data, SamSumPrompter()

    raise ValueError(
        f"Unsupported task {name!r}; expected one of: alpaca, samsum"
    )

def _load_base_model(base_model):
    """Load the plain causal LM from ``base_model`` and cast it to bfloat16."""
    model = AutoModelForCausalLM.from_pretrained(base_model)
    model.to(torch.bfloat16)
    return model


def _load_router_weights(model, peft_path):
    """Overlay ``router_weights.safetensors`` from ``peft_path`` onto ``model``."""
    router_path = os.path.join(peft_path, "router_weights.safetensors")
    router_weights = load_file(router_path)
    # Merge into the full state dict so the (strict) load only replaces the
    # entries present in the router file.
    state_dict = model.state_dict()
    state_dict.update(router_weights)
    model.load_state_dict(state_dict)


def load_model(args):
    """Build the model variant selected by ``args.name`` together with its tokenizer.

    Supported values of ``args.name``:
        * "unpruned": base model wrapped with a freshly created LoRA adapter.
        * "shortened" / "shortgpt" / "sleb": base model with
          ``args.num_remove_blocks`` blocks removed in the order read from
          ``args.block_order_path``, plus the PEFT adapter at ``args.peft``.
        * "pudding": ``PUDDINGLM`` built from the base model name, adapter
          path, number of removed blocks and layer-set path.
        * "first": ``RouterSelectiveLlamaForCausalLM`` with the adapter and
          router weights from ``args.peft``.
        * "buddy": ``BuddyForCausalLM`` with the adapter and router weights
          from ``args.peft``; additionally returns a ``budgets`` forward kwarg.
        * anything else (including the default "wo"): the unmodified base model.

    Args:
        args: Parsed command-line options (see ``parse_args``).

    Returns:
        Tuple ``(model, tokenizer, kwargs)`` where ``model`` is in eval mode on
        ``device`` and ``kwargs`` holds extra keyword arguments to pass to
        every forward/generate call.
    """
    name = args.name

    tokenizer = AutoTokenizer.from_pretrained(args.base_model)
    tokenizer.pad_token_id = 0
    tokenizer.padding_side = "left"

    kwargs = {}

    # The base checkpoint is only loaded in the branches that actually use it;
    # the pudding/first/buddy branches construct their own model, so loading
    # it up front would load the checkpoint twice.
    if name == "unpruned":
        model = _load_base_model(args.base_model)
        config = LoraConfig(
            r=8,
            lora_alpha=16,
            target_modules="q_proj,k_proj,v_proj,o_proj,gate_proj,down_proj,up_proj".split(","),
            lora_dropout=0.05,
            bias="none",
            task_type="CAUSAL_LM",
        )
        model = get_peft_model(model, config)

    elif name in ["shortened", "shortgpt", "sleb"]:
        model = _load_base_model(args.base_model)
        # The block order is stored in the header row of the CSV file.
        unimportance_order = [
            int(i) for i in pd.read_csv(args.block_order_path).columns.tolist()
        ]

        model = get_block_pruned_network(
            model_orig=model,
            unimportance_order=unimportance_order,
            num_pruned_blocks=args.num_remove_blocks
        )
        model = model.to(torch.bfloat16)
        model = PeftModel.from_pretrained(model, args.peft)

    elif name == "pudding":
        from baselines.PuDDing.zero_shot_eval import PUDDINGLM
        model = PUDDINGLM(
            pretrained=args.base_model,
            peft=args.peft,
            num_remove_blocks=args.num_remove_blocks,
            layerset_path=args.layerset_path
        )

    elif name == "first":
        model = RouterSelectiveLlamaForCausalLM.from_pretrained(args.base_model)
        router_lambda = 1.0
        # Fraction of layers kept after removing ``num_remove_blocks`` of them.
        target_keep_ratio = 1.0 - args.num_remove_blocks / model.config.num_hidden_layers
        print(target_keep_ratio)
        model.set_router_objective(router_lambda=router_lambda, target_keep_ratio=target_keep_ratio)

        model = model.to(torch.bfloat16)
        model = PeftModel.from_pretrained(model, args.peft)
        _load_router_weights(model, args.peft)

    elif name == "buddy":
        model = BuddyForCausalLM.from_pretrained(args.base_model)
        if args.sensitivity_type != "None":
            # NOTE(review): this imports from ``utils.sensitivity_utils`` while
            # the file header imports from ``utils.sensitivity.utils`` —
            # confirm both module paths exist.
            from utils.sensitivity_utils import read_pre_sensitivity
            pre_sensitivity = read_pre_sensitivity(
                path=args.sensitivity_path,
                type=args.sensitivity_type
            )
            model.set_sensitivity(pre_sensitivity, args.lambda_reg)

        model = model.to(torch.bfloat16)
        model = PeftModel.from_pretrained(model, args.peft)

        # Fraction of layers kept; forwarded to every model call as ``budgets``.
        budgets = 1.0 - (args.num_remove_blocks / model.model.config.num_hidden_layers)
        kwargs["budgets"] = budgets

        _load_router_weights(model, args.peft)

    else:
        model = _load_base_model(args.base_model)

    # Single placement/eval step for every variant; honours ``device`` instead
    # of hard-coding "cuda" so CPU-only hosts work too.
    model = model.to(device)
    model.eval()

    return model, tokenizer, kwargs

@torch.no_grad()
def warmup(model, tokenizer, warm_up_steps, **kwargs):
    """Run a fixed prompt through the model repeatedly before timing starts.

    Args:
        model: Callable model; invoked as ``model(**inputs, **kwargs)``.
        tokenizer: Tokenizer used to encode the warm-up prompt.
        warm_up_steps: Number of forward passes to execute.
        **kwargs: Extra keyword arguments forwarded to every model call.
    """
    cuda_available = torch.cuda.is_available()
    if cuda_available:
        # These calls require a CUDA runtime, so skip them on CPU-only hosts.
        torch.cuda.synchronize()
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats(device)

    prompt = "Can you explain Fermat’s Last Theorem?"
    # The prompt never changes, so encode it once instead of on every step.
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    bar_format = "Warming up " + "{l_bar}{bar}{r_bar}"
    for _ in tqdm(range(warm_up_steps), bar_format=bar_format):
        model(**inputs, **kwargs)

    if cuda_available:
        # Wait for the queued warm-up kernels so they cannot spill into the
        # first timed measurement.
        torch.cuda.synchronize()

@torch.no_grad()
def eval_prefill_speed(model, inputs, **kwargs):
    """Measure forward-pass throughput over the prompt.

    Args:
        model: Callable model; invoked as ``model(**inputs, **kwargs)``.
        inputs: Mapping of model inputs; must contain an ``"input_ids"`` tensor.
        **kwargs: Extra keyword arguments forwarded to the model call.

    Returns:
        Number of elements in ``inputs["input_ids"]`` divided by the
        wall-clock seconds the forward pass took.
    """
    # CUDA kernels are launched asynchronously: without synchronising before
    # and after the call, the timer would mostly measure launch overhead.
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        torch.cuda.synchronize()
    start_time = time.perf_counter()

    model(**inputs, **kwargs)

    if use_cuda:
        torch.cuda.synchronize()
    time_cost = time.perf_counter() - start_time

    token_counts = inputs["input_ids"].numel()
    return token_counts / time_cost

@torch.no_grad()
def eval_decode_speed(model, inputs, tokenizer, max_new_tokens=None, **kwargs):
    """Measure generation throughput with greedy decoding.

    Args:
        model: Object exposing ``generate(**inputs, ...)``.
        inputs: Mapping of model inputs, unpacked into ``generate``.
        tokenizer: Tokenizer; its ``eos_token_id`` is used as the pad token id.
        max_new_tokens: Number of tokens to generate. When ``None`` the value
            is read from the module-level ``args.cutoff_len``, which the
            script entry point sets.
        **kwargs: Extra keyword arguments forwarded to ``generate``.

    Returns:
        Number of elements in the tensor returned by ``generate`` divided by
        the wall-clock seconds the call took.
    """
    if max_new_tokens is None:
        # Backward-compatible fallback to the global set by the script entry
        # point; pass ``max_new_tokens`` explicitly when importing this function.
        max_new_tokens = args.cutoff_len

    # Synchronise around the timed region so queued CUDA work is fully
    # accounted for.
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        torch.cuda.synchronize()
    start_time = time.perf_counter()

    output_tokens = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        pad_token_id=tokenizer.eos_token_id,
        do_sample=False,
        eos_token_id=None,  # key point: do not use EOS as a stop token
        num_beams=1,
        temperature=1.0,
        top_k=50,
        top_p=1.0,
        **kwargs
    )

    if use_cuda:
        torch.cuda.synchronize()
    time_cost = time.perf_counter() - start_time

    # NOTE(review): this counts every element of the returned tensor; for
    # decoder-only Hugging Face models that presumably includes the prompt
    # tokens as well as the generated ones — confirm this is the intended metric.
    token_counts = output_tokens.numel()

    return token_counts / time_cost

def main(args):
    """Benchmark prefill and decode speed on every requested task and save the result.

    For each task in the comma-separated ``args.tasks`` the per-example speeds
    are averaged; the resulting mapping is written as JSON to a timestamped
    ``speed_*.json`` file inside ``args.output_path`` (the current directory
    when that option is empty).

    Args:
        args: Parsed command-line options (see ``parse_args``).
    """
    model, tokenizer, kwargs = load_model(args)

    warmup(model, tokenizer, 100, **kwargs)

    metrics = {}
    # Tolerate spaces and stray commas in the task list.
    task_names = [task.strip() for task in args.tasks.split(",") if task.strip()]
    for dataset_name in task_names:
        data, prompter = load_eval_dataset(dataset_name)

        speed_list_prefill = []
        speed_list_decode = []

        for item in tqdm(data):
            full_prompt = prompter.generate_prompt(item)["prompt"]
            inputs = tokenizer(full_prompt, return_tensors="pt").to(device)

            speed_list_prefill.append(eval_prefill_speed(model, inputs, **kwargs))
            speed_list_decode.append(eval_decode_speed(model, inputs, tokenizer, **kwargs))

        # Convert to plain floats so the values serialise as ordinary JSON numbers.
        metrics[dataset_name] = {
            "speed_prefill": float(np.mean(speed_list_prefill)),
            "speed_decode": float(np.mean(speed_list_decode)),
        }

    # os.makedirs("") raises, so an empty output path means the current directory.
    output_dir = args.output_path or "."
    os.makedirs(output_dir, exist_ok=True)

    formatted_time = datetime.now().strftime("%Y-%m-%dT%H-%M-%S.%f")
    file_name = f"speed_{formatted_time}.json"

    # Join with a separator so the file lands inside ``output_dir`` even when
    # the path was given without a trailing slash.
    with open(os.path.join(output_dir, file_name), "w") as file:
        json.dump(metrics, file)

if __name__ == "__main__":
    # NOTE: ``args`` is bound at module level here, and eval_decode_speed()
    # reads ``args.cutoff_len`` from this global — do not rename or inline it.
    args = parse_args()
    main(args)
