import os
import argparse
import torch
import json
from transformers import AutoModelForCausalLM, AutoTokenizer
from baselines.FiRST.model import RouterSelectiveLlamaForCausalLM
from peft import PeftModel, PeftConfig
from safetensors.torch import load_file
from model.buddy_model import BuddyForCausalLM
from datetime import datetime
from utils.sensitivity.utils import get_block_pruned_network
from utils.ppl import PPLMetric
import pandas as pd

# Seed torch's global RNG with a fixed value so repeated runs are reproducible.
torch.manual_seed(20250925)

# Run on the GPU when torch reports CUDA as available, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"


def parse_args(argv=None):
    """Parse the command-line options of the perplexity evaluation script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads ``sys.argv[1:]``, which is what the
            script entry point relies on.

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(description='Tuning Pruned LLM')
    parser.add_argument('--base_model', type=str, default="baffo32/decapoda-research-llama-7B-hf",
                        help='base model name')
    parser.add_argument('--name', type=str, default="wo", help='model name')
    parser.add_argument("--num_remove_blocks", type=int, default=None, help="num_remove_blocks")
    parser.add_argument('--peft', type=str, default=None, help='peft path')
    parser.add_argument('--tasks', type=str, default="ptb,wikitext2", help='data name')
    parser.add_argument('--output_path', type=str, default="", help='output_path')
    parser.add_argument('--cutoff_len', type=int, default=1024, help='cutoff length')
    parser.add_argument('--batch_size', type=int, default=4, help='batch size')

    # buddy
    parser.add_argument("--lambda_reg", type=float, default=0.1, help="lambda_reg")
    parser.add_argument("--sensitivity_path", type=str, default="", help="sensitivity_path")
    parser.add_argument("--sensitivity_type", type=str, default="", help="sensitivity_type")

    # pudding
    parser.add_argument("--layerset_path", type=str, default="", help="layerset_path")

    # shortened, shortgpt, sleb
    parser.add_argument("--block_order_path", type=str, default="", help="block_order_path")

    return parser.parse_args(argv)


def _require_arg(args, attr, method):
    """Return ``args.<attr>``, raising a clear error when it was not supplied."""
    value = getattr(args, attr)
    if value is None:
        raise ValueError(f"--{attr} is required when --name is '{method}'")
    return value


def _load_router_weights(model, peft_path):
    """Overlay ``router_weights.safetensors`` from ``peft_path`` onto ``model``."""
    router_path = os.path.join(peft_path, "router_weights.safetensors")
    router_weights = load_file(router_path)
    # Merge into the full state dict so the load below sees every key.
    state_dict = model.state_dict()
    state_dict.update(router_weights)
    model.load_state_dict(state_dict)


def load_model(args):
    """Build the tokenizer and the model variant selected by ``args.name``.

    Supported values of ``args.name``:

    * ``"shortened"``, ``"shortgpt"``, ``"sleb"``: the base model with blocks
      removed according to the order read from ``args.block_order_path``,
      wrapped with the PEFT adapter at ``args.peft``.
    * ``"pudding"``: a ``PUDDINGLM`` wrapper built from the base model name.
    * ``"first"``: ``RouterSelectiveLlamaForCausalLM`` plus PEFT adapter and
      router weights.
    * ``"buddy"``: ``BuddyForCausalLM`` plus PEFT adapter and router weights.
    * anything else: the unmodified base model.

    Args:
        args: Parsed command-line namespace (see ``parse_args``).

    Returns:
        A ``(model, tokenizer, kwargs)`` tuple. ``kwargs`` holds extra keyword
        arguments for the evaluation call; it contains ``"budgets"`` for the
        ``"buddy"`` variant and is empty otherwise.

    Raises:
        ValueError: If ``--peft`` (shortened/shortgpt/sleb/first/buddy) or
            ``--num_remove_blocks`` (first/buddy) is missing.
    """
    name = args.name

    tokenizer = AutoTokenizer.from_pretrained(args.base_model)
    tokenizer.pad_token_id = 0
    tokenizer.padding_side = "left"

    kwargs = {}

    if name == "pudding":
        from baselines.PuDDing.zero_shot_eval import PUDDINGLM
        model = PUDDINGLM(
            pretrained=args.base_model,
            peft=args.peft,
            num_remove_blocks=args.num_remove_blocks,
            layerset_path=args.layerset_path
        )
    elif name == "first":
        # Validate before the expensive checkpoint load so mistakes fail fast.
        peft_path = _require_arg(args, "peft", name)
        num_remove_blocks = _require_arg(args, "num_remove_blocks", name)

        model = RouterSelectiveLlamaForCausalLM.from_pretrained(args.base_model)
        target_keep_ratio = 1.0 - num_remove_blocks / model.config.num_hidden_layers
        print(target_keep_ratio)
        model.set_router_objective(router_lambda=1.0, target_keep_ratio=target_keep_ratio)

        model = model.to(torch.bfloat16)
        model = PeftModel.from_pretrained(model, peft_path)
        _load_router_weights(model, peft_path)
    elif name == "buddy":
        peft_path = _require_arg(args, "peft", name)
        num_remove_blocks = _require_arg(args, "num_remove_blocks", name)

        model = BuddyForCausalLM.from_pretrained(args.base_model)
        # NOTE(review): only the literal string "None" disables sensitivity;
        # the CLI default "" still triggers the lookup - confirm this is intended.
        if args.sensitivity_type != "None":
            from utils.sensitivity_utils import read_pre_sensitivity
            pre_sensitivity = read_pre_sensitivity(
                path=args.sensitivity_path,
                type=args.sensitivity_type
            )
            model.set_sensitivity(pre_sensitivity, args.lambda_reg)

        model = model.to(torch.bfloat16)
        model = PeftModel.from_pretrained(model, peft_path)

        # Fraction of blocks that stay active, forwarded to the evaluation call.
        kwargs["budgets"] = 1.0 - (num_remove_blocks / model.model.config.num_hidden_layers)

        _load_router_weights(model, peft_path)
    else:
        # Only these variants use the plain base checkpoint; loading it for the
        # branches above would build a second full model just to discard it.
        if name in ("shortened", "shortgpt", "sleb"):
            peft_path = _require_arg(args, "peft", name)

        model = AutoModelForCausalLM.from_pretrained(args.base_model)
        model = model.to(torch.bfloat16)

        if name in ("shortened", "shortgpt", "sleb"):
            # The block removal order is stored as the header row of the CSV.
            unimportance_order = [
                int(i) for i in pd.read_csv(args.block_order_path).columns.tolist()
            ]
            model = get_block_pruned_network(
                model_orig=model,
                unimportance_order=unimportance_order,
                num_pruned_blocks=args.num_remove_blocks
            )
            model = model.to(torch.bfloat16)
            model = PeftModel.from_pretrained(model, peft_path)

    # Single placement/eval step for every variant; uses the module-level
    # ``device`` so CPU-only hosts work for all branches.
    model = model.to(device)
    model.eval()

    return model, tokenizer, kwargs


def _evaluate_ppl(model, tokenizer, args, kwargs):
    """Run ``PPLMetric`` on the datasets listed in ``args.tasks``."""
    return PPLMetric(
        model=model,
        tokenizer=tokenizer,
        datasets=args.tasks.split(","),
        seq_len=args.cutoff_len,
        device=device,
        batch_size=args.batch_size,
        **kwargs
    )


def main(args):
    """Evaluate perplexity with and without the adapter and save the results.

    The model is evaluated twice: once after ``disable_adapter_layers()``
    (stored under ``"no_lora"``) and once after ``enable_adapter_layers()``
    (stored under ``"lora"``). Both results are written as one JSON document
    to ``metric_<timestamp>.json`` inside ``args.output_path``; an empty
    ``output_path`` means the current working directory.

    Args:
        args: Parsed command-line namespace (see ``parse_args``).
    """
    model, tokenizer, kwargs = load_model(args)

    results = {}

    # NOTE: the loaded model must expose disable/enable_adapter_layers
    # (i.e. be PEFT-wrapped); other models fail at the next line.
    model.disable_adapter_layers()
    results["no_lora"] = _evaluate_ppl(model, tokenizer, args, kwargs)

    model.enable_adapter_layers()
    results["lora"] = _evaluate_ppl(model, tokenizer, args, kwargs)

    # os.makedirs("") raises, so map the empty default to the current directory.
    output_dir = args.output_path or "."
    os.makedirs(output_dir, exist_ok=True)

    formatted_time = datetime.now().strftime("%Y-%m-%dT%H-%M-%S.%f")
    file_name = f"metric_{formatted_time}.json"

    # Join with a separator so the file lands inside output_dir even when the
    # path was given without a trailing slash.
    output_file = os.path.join(output_dir, file_name)
    with open(output_file, "w", encoding="utf-8") as file:
        json.dump(results, file)


# Script entry point: parse the command-line options and run the evaluation.
if __name__ == "__main__":
    args = parse_args()
    main(args)
