import os

import pylab as p

os.environ.update(TOKENIZERS_PARALLELISM="false")  # disable HF tokenizers parallelism (or "true" to enable it)

# Load model directly
import argparse
import torch

from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import TrainingArguments, DataCollatorForSeq2Seq
from transformers.trainer import Trainer
from model.buddy_model import BuddyForCausalLM

from peft import prepare_model_for_kbit_training, get_peft_model, LoraConfig

from safetensors.torch import save_file
from accelerate import Accelerator
from utils.dataset import load_instruction_dataset
import wandb

import warnings

# Silence noisy UserWarning / FutureWarning output from the libraries during training.
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)

# Process-wide runtime objects shared by load_model() and main().
accelerator = Accelerator()
# Kept as a plain string: load_model() compares it against the literal 'cuda'.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Seeds torch before the LoRA adapters are created in load_model().
# NOTE: transformers' Trainer also seeds from TrainingArguments.seed.
torch.manual_seed(20250926)

# W&B logging is disabled. The project is passed by keyword because the first
# positional parameter of wandb.init is not `project` (it has been `job_type`
# in many releases), so the positional "deputy" never set the project name.
wandb.init(project="deputy", mode="disabled")

def parse_args(argv=None):
    """Parse command-line options for fine-tuning a pruned LLM with LoRA + router.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]``. ``None`` (the default) reads the real command
            line, exactly as before; passing a list allows programmatic use
            and testing.

    Returns:
        argparse.Namespace holding every option defined below plus
        ``torch_version``: the *minor* component of ``torch.__version__``
        (e.g. ``1`` for ``"2.1.0"``).
    """
    parser = argparse.ArgumentParser(description='Tuning Pruned LLM')

    # Model type & paths
    parser.add_argument('--base_model', type=str, default="baffo32/decapoda-research-llama-7B-hf",
                        help='base model name')
    parser.add_argument('--data_name', type=str, default="openbookqa", help='data name')
    parser.add_argument('--data_path', type=str, default="openbookqa", help='data path')
    parser.add_argument('--output_dir', type=str, default="./tune_log/openbookqa/deputy/", help='output directory')
    parser.add_argument('--lambda_reg', type=float, default=1.0, help='lambda_reg')
    # The literal string "None" disables loading precomputed sensitivities (see load_model).
    parser.add_argument('--sensitivity_type', type=str, default="ppl", help='sensitivity_type')
    parser.add_argument('--sensitivity_path', type=str, default="utils/sensitivity/output/ppl/all_ppl_unsorted.csv",
                        help='sensitivity_path')

    # Training hyperparameters
    parser.add_argument('--batch_size', type=int, default=4, help='batch size')
    parser.add_argument('--num_epochs', type=int, default=2, help='number of epochs')
    parser.add_argument('--learning_rate', type=float, default=1e-4, help='learning rate')
    parser.add_argument('--gradient_accumulation_steps', type=int, default=16, help='gradient accumulation steps')
    parser.add_argument('--cutoff_len', type=int, default=1024, help='cutoff length')

    # LoRA configuration
    parser.add_argument('--lora_r', type=int, default=8, help='lora r')
    parser.add_argument('--lora_alpha', type=int, default=16, help='lora alpha')
    parser.add_argument('--lora_dropout', type=float, default=0.05, help='lora dropout')
    # Comma-separated module names; split on "," in load_model.
    parser.add_argument('--lora_target_modules', type=str,
                        default="q_proj,k_proj,v_proj,o_proj,gate_proj,down_proj,up_proj", help='lora target modules')

    # LLM hyperparameters
    parser.add_argument('--group_by_length', default=False, action="store_true",
                        help="faster, but produces an odd training loss curve")

    # Resuming (not a wandb option): forwarded to Trainer.train(resume_from_checkpoint=...).
    parser.add_argument('--resume_from_checkpoint', type=str, help="either training checkpoint or final adapter")

    args = parser.parse_args(argv)
    # NOTE: this is the *minor* version number (index 1 after splitting on "."),
    # not the major one. The historical attribute name is kept for compatibility.
    args.torch_version = int(torch.__version__.split('.')[1])

    return args

def load_model(args):
    """Load the Buddy model and tokenizer, then wrap the model with LoRA adapters.

    Steps:
        1. Load tokenizer and ``BuddyForCausalLM`` from ``args.base_model``.
        2. Unless ``args.sensitivity_type`` is the string ``"None"``, read
           precomputed sensitivities and attach them with ``lambda_reg``.
        3. Run PEFT's ``prepare_model_for_kbit_training`` and, on CUDA, cast
           the model to bfloat16.
        4. Add LoRA adapters, then leave only LoRA (``lora_A``/``lora_B``) and
           ``router`` parameters trainable.

    Args:
        args: Namespace from ``parse_args``. Reads base_model, sensitivity_type,
            sensitivity_path, lambda_reg, lora_r, lora_alpha, lora_dropout and
            lora_target_modules.

    Returns:
        ``(model, tokenizer)``. The PEFT-wrapped model is moved to the
        module-level ``device``. The tokenizer pads on the left with id 0.
    """
    tokenizer = AutoTokenizer.from_pretrained(args.base_model)
    model = BuddyForCausalLM.from_pretrained(args.base_model)

    # NOTE(review): pad id 0 is assumed to be a safe pad/unk token for this
    # tokenizer; confirm for non-LLaMA base models.
    tokenizer.pad_token_id = 0
    tokenizer.padding_side = "left"

    if args.sensitivity_type != "None":
        from utils.sensitivity_utils import read_pre_sensitivity
        pre_sensitivity = read_pre_sensitivity(
            path=args.sensitivity_path,
            type=args.sensitivity_type
        )
        model.set_sensitivity(pre_sensitivity, args.lambda_reg)

    # Prepare for LoRA. NOTE(review): the model is not quantized here; this
    # helper is presumably used for its freezing / gradient-checkpointing setup.
    model = prepare_model_for_kbit_training(model)

    if device == 'cuda':
        model.to(torch.bfloat16)

    config = LoraConfig(
        r=args.lora_r,
        # Honour --lora_alpha (previously ignored and forced to 2 * r). The
        # fallback keeps callers that build a Namespace without lora_alpha working.
        lora_alpha=getattr(args, "lora_alpha", args.lora_r * 2),
        target_modules=args.lora_target_modules.split(","),
        lora_dropout=args.lora_dropout,
        bias="none",
        task_type="CAUSAL_LM"
    )
    model = get_peft_model(model, config)

    # Train only the LoRA adapters and the router; freeze everything else.
    # (`param` rather than `p`, which would shadow the module-level pylab alias.)
    trainable_keys = ("lora_A", "lora_B", "router")
    for name, param in model.named_parameters():
        param.requires_grad = any(key in name for key in trainable_keys)

    model = model.to(device)
    model.print_trainable_parameters()

    return model, tokenizer

def _save_router_weights(model, output_dir):
    """Save every parameter whose name contains "router" to <output_dir>/router_weights.safetensors."""
    # Detach and copy to contiguous CPU tensors: the router parameters are
    # trainable (require grad) and may live on GPU, while safetensors
    # serialises plain tensors.
    router_weights = {
        name: param.detach().cpu().contiguous()
        for name, param in model.named_parameters()
        if "router" in name
    }
    os.makedirs(output_dir, exist_ok=True)
    # os.path.join works whether or not output_dir ends with a separator.
    save_file(router_weights, os.path.join(output_dir, "router_weights.safetensors"))


def main(args):
    """Fine-tune the pruned model (LoRA adapters + router) and save the results.

    Loads the model/tokenizer (``load_model``) and the instruction dataset,
    trains with the HF ``Trainer``, then writes the PEFT adapter via
    ``save_pretrained`` and the router parameters to
    ``router_weights.safetensors`` inside ``args.output_dir``.

    Args:
        args: Namespace produced by ``parse_args``.
    """
    model, tokenizer = load_model(args)

    train_data, val_data = load_instruction_dataset(
        name=args.data_name,
        path=args.data_path,
        tokenizer=tokenizer,
        max_length=args.cutoff_len
    )

    # Debugging aid from transformers.debug_utils: instruments the model's
    # modules to detect inf/nan activations (per its docs it reports and aborts
    # when one is found). It adds overhead to every forward pass.
    from transformers.debug_utils import DebugUnderflowOverflow
    debug_overflow = DebugUnderflowOverflow(model)  # noqa: F841

    # Use the last path component even when output_dir has a trailing "/"
    # (as the default does); split('/')[-1] returned "" in that case.
    run_name = os.path.basename(os.path.normpath(args.output_dir))

    trainer = Trainer(
        model=model,
        train_dataset=train_data,
        eval_dataset=val_data,
        args=TrainingArguments(
            per_device_train_batch_size=args.batch_size,
            per_device_eval_batch_size=args.batch_size,
            gradient_accumulation_steps=args.gradient_accumulation_steps,
            warmup_ratio=0.1,
            num_train_epochs=args.num_epochs,
            learning_rate=args.learning_rate,
            bf16=True,
            logging_steps=1,
            logging_first_step=True,
            optim="adamw_torch",
            evaluation_strategy="steps",
            save_strategy="steps",
            eval_steps=200,
            save_steps=400,
            output_dir=args.output_dir,
            save_total_limit=1,
            load_best_model_at_end=True,
            ddp_find_unused_parameters=None,
            group_by_length=args.group_by_length,
            run_name=run_name,
            # NOTE(review): a "<name>_loss" metric only exists when eval_dataset
            # is a dict keyed by data_path; assumes load_instruction_dataset
            # returns val_data in that form. Confirm.
            metric_for_best_model="{}_loss".format(args.data_path),
        ),
        data_collator=DataCollatorForSeq2Seq(
            tokenizer,
            pad_to_multiple_of=8,
            return_tensors="pt",
            padding=True,
            max_length=args.cutoff_len
        )
    )

    # Disable the KV cache for training.
    model.config.use_cache = False
    trainer.train(resume_from_checkpoint=args.resume_from_checkpoint)

    accelerator.wait_for_everyone()

    unwrapped_model = accelerator.unwrap_model(model)
    unwrapped_model.save_pretrained(args.output_dir, save_function=accelerator.save)

    # Only one process writes the router file, so ranks do not race on it.
    if accelerator.is_main_process:
        _save_router_weights(model, args.output_dir)

if __name__ == "__main__":
    # Script entry point: parse the command-line options, then run fine-tuning.
    main(parse_args())
