# LoRA fine-tuning script for a (pruned) causal language model.
import argparse
import os


import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import Trainer, TrainingArguments, DataCollatorForSeq2Seq
from peft import prepare_model_for_kbit_training, LoraConfig, get_peft_model

from accelerate import Accelerator
from utils.dataset import load_instruction_dataset
import wandb

# Run on the GPU when one is visible; load_model/main branch on this value.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Seed the global torch RNG so model-side randomness (e.g. LoRA adapter
# initialisation in load_model) is repeatable across runs.
torch.manual_seed(20250131)

# W&B logging is switched off via mode="disabled". The project is passed by
# keyword because wandb.init's first positional parameter is not `project`,
# so the previous positional "deputy" landed in the wrong field.
wandb.init(project="deputy", mode="disabled")


def parse_args(argv=None):
    """Parse command-line options for LoRA fine-tuning of a (pruned) LLM.

    Args:
        argv: Optional list of argument strings to parse instead of the real
            command line. ``None`` (the default) reads ``sys.argv[1:]``,
            exactly as before; passing a list is useful for tests and
            programmatic invocation.

    Returns:
        argparse.Namespace containing every option defined below plus a
        ``torch_version`` attribute (see note at the bottom).
    """
    parser = argparse.ArgumentParser(description='Tuning Pruned LLM')

    # Model type & paths
    parser.add_argument('--base_model', type=str, default="baffo32/decapoda-research-llama-7B-hf",
                        help='base model name')
    parser.add_argument('--data_name', type=str, default="openbookqa", help='data name')
    parser.add_argument('--data_path', type=str, default="openbookqa", help='data path')
    # parser.add_argument('--data_path', type=str, default="yahma/alpaca-cleaned", help='data path')
    parser.add_argument('--output_dir', type=str, default="./tune_log/lora/", help='output directory')

    # Training hyperparameters
    parser.add_argument('--batch_size', type=int, default=8, help='batch size')
    parser.add_argument('--num_epochs', type=int, default=2, help='number of epochs')
    parser.add_argument('--learning_rate', type=float, default=1e-4, help='learning rate')
    parser.add_argument('--cutoff_len', type=int, default=256, help='cutoff length')

    # LoRA configuration
    parser.add_argument('--lora_r', type=int, default=32, help='lora r')
    parser.add_argument('--lora_alpha', type=int, default=32, help='lora alpha')
    parser.add_argument('--lora_dropout', type=float, default=0.05, help='lora dropout')
    # Comma-separated names of the modules that receive LoRA adapters.
    parser.add_argument('--lora_target_modules', type=str,
                        default="q_proj,k_proj,v_proj,o_proj,gate_proj,down_proj,up_proj", help='lora target modules')

    # Trainer behaviour
    parser.add_argument('--group_by_length', default=False, action="store_true",
                        help="faster, but produces an odd training loss curve")

    # Checkpoint resumption
    parser.add_argument('--resume_from_checkpoint', type=str, help="either training checkpoint or final adapter")

    args = parser.parse_args(argv)

    # NOTE(review): this keeps only the *minor* component of the torch version
    # ("2.1.0" -> 1, "1.13.1" -> 13), so it cannot be compared against a major
    # version such as 2. Kept unchanged for compatibility; nothing in this
    # script reads it.
    args.torch_version = int(torch.__version__.split('.')[1])

    return args


def _parse_target_modules(spec):
    """Split a comma-separated module list, trimming whitespace and dropping empties.

    ``"q_proj, k_proj,"`` -> ``["q_proj", "k_proj"]``. A bare ``split(",")``
    would produce ``" k_proj"`` and ``""``, neither of which matches a module
    name, so the affected modules would silently get no adapter.
    """
    return [name.strip() for name in spec.split(",") if name.strip()]


def load_model(args):
    """Load the base causal LM and tokenizer and wrap the model with LoRA adapters.

    Args:
        args: Parsed CLI namespace. Reads ``base_model``, ``lora_r``,
            ``lora_alpha``, ``lora_dropout`` and ``lora_target_modules``.

    Returns:
        ``(model, tokenizer)``. ``model`` is the PEFT-wrapped model, cast to
        fp16 when running on CUDA, and moved to the module-level ``device``.
    """
    tokenizer = AutoTokenizer.from_pretrained(args.base_model)
    model = AutoModelForCausalLM.from_pretrained(args.base_model)

    # Pad with token id 0, on the left.
    # NOTE(review): assumes id 0 is an acceptable padding id for the chosen
    # base model's vocabulary — confirm when switching models.
    tokenizer.pad_token_id = 0
    tokenizer.padding_side = "left"

    # Prepare the model for adapter training.
    model = prepare_model_for_kbit_training(model)

    # NOTE(review): prepare_model_for_kbit_training may up-cast some parameters
    # to fp32; half() here casts the whole model back to fp16. Confirm this
    # ordering is intended.
    if device == 'cuda':
        model.half()

    config = LoraConfig(
        r=args.lora_r,
        lora_alpha=args.lora_alpha,
        target_modules=_parse_target_modules(args.lora_target_modules),
        lora_dropout=args.lora_dropout,
        bias="none",
        task_type="CAUSAL_LM",
    )

    model = get_peft_model(model, config)
    model = model.to(device)
    model.print_trainable_parameters()
    return model, tokenizer


def _run_name_from_output_dir(output_dir):
    """Return the last path component of *output_dir*, ignoring trailing separators.

    ``"./tune_log/lora/"`` -> ``"lora"``. ``output_dir.split('/')[-1]`` would
    return ``""`` for any path ending in a slash, including the CLI default.
    """
    return os.path.basename(os.path.normpath(output_dir))


def main(args):
    """Fine-tune the LoRA-wrapped model on the instruction dataset and save it.

    Args:
        args: Parsed CLI namespace, as produced by ``parse_args``.
    """
    model, tokenizer = load_model(args)

    train_data, val_data = load_instruction_dataset(
        name=args.data_name,
        path=args.data_path,
        tokenizer=tokenizer,
        max_length=args.cutoff_len
    )

    trainer = Trainer(
        model=model,
        train_dataset=train_data,
        eval_dataset=val_data,
        args=TrainingArguments(
            per_device_train_batch_size=args.batch_size,
            gradient_accumulation_steps=16,
            warmup_steps=100,
            num_train_epochs=args.num_epochs,
            learning_rate=args.learning_rate,
            # fp16 mixed precision needs a GPU; mirror load_model, which only
            # halves the model when device == 'cuda'.
            fp16=(device == "cuda"),
            logging_steps=10,
            logging_first_step=True,
            optim="adamw_torch",
            # NOTE(review): newer transformers releases rename this argument
            # to `eval_strategy`; confirm against the pinned version.
            evaluation_strategy="steps",
            save_strategy="steps",
            # With eval_steps unset, transformers evaluates every
            # `logging_steps` steps (10 here), which can be slow.
            # eval_steps=200,
            # save_steps=400,
            output_dir=args.output_dir,
            save_total_limit=1,
            # max_steps=100,
            load_best_model_at_end=True,
            ddp_find_unused_parameters=None,
            group_by_length=args.group_by_length,
            run_name=_run_name_from_output_dir(args.output_dir),
            # NOTE(review): the Trainer looks this metric up (with an "eval_"
            # prefix) among the evaluation results. A key containing the
            # dataset name only exists if val_data is a dict keyed by
            # data_path; a single dataset reports plain "eval_loss".
            # Confirm against load_instruction_dataset.
            metric_for_best_model="{}_loss".format(args.data_path),
        ),
        data_collator=DataCollatorForSeq2Seq(
            tokenizer,
            pad_to_multiple_of=8,
            return_tensors="pt",
            padding=True,
            max_length=args.cutoff_len
        )
    )

    # The generation KV cache is not needed while training.
    model.config.use_cache = False
    trainer.train(resume_from_checkpoint=args.resume_from_checkpoint)

    # For a PEFT-wrapped model this writes the adapter weights/config.
    model.save_pretrained(args.output_dir)

if __name__ == "__main__":
    # Script entry point: parse CLI options, then run fine-tuning.
    args = parse_args()
    main(args)
