from transformers import HfArgumentParser ,AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer
from peft import get_peft_model,LoraConfig
import torch
from lorra_finetune.src.args import ModelArguments,LoraArguments
from ext.train_normal import DollyDataProvider
from dataclasses import dataclass, field
import wandb

@dataclass
class OtherArgument:
    """Script-specific command-line options parsed next to the model/training/LoRA args.

    Holds the path of the fine-tuning data file and the Weights & Biases
    settings (run mode and project name) consumed by ``main``.
    """

    # Path of the data file handed to the data provider.
    data_file_path: str = field(
        default="./data/hallu/safe_rlhf_safe1w.json",
        metadata={"help": "data_file to finetune LoRRA"},
    )

    # Forwarded as ``mode`` to ``wandb.init``; logging is off by default.
    wandb_mode: str = field(
        default="disabled",
        metadata={"help": "Wandb Switch"},
    )

    # Forwarded as ``project`` to ``wandb.init``.
    project_info: str = field(
        default="Rep-defalut",
        metadata={"help": "Wandb Project Info"},
    )

def main():
    """Fine-tune a causal language model with a LoRA adapter and save the adapter.

    Steps:
      1. Parse ``ModelArguments``, ``TrainingArguments``, ``LoraArguments`` and
         ``OtherArgument`` from the command line.
      2. Start a Weights & Biases run (authenticating only when the selected
         mode actually contacts the wandb server).
      3. Load the base model in bfloat16 on ``cuda:0`` together with its tokenizer.
      4. Wrap the model with LoRA, build the train/eval datasets and train with
         the Hugging Face ``Trainer``.
      5. Save the adapter to ``<output_dir>/<run_name>``.
    """
    device_map = "cuda:0"
    parser = HfArgumentParser((ModelArguments, TrainingArguments, LoraArguments, OtherArgument))
    model_args, training_args, lora_args, def_arg = parser.parse_args_into_dataclasses()

    # "disabled" (the default) and "offline" never talk to the wandb server, so
    # do not demand credentials for them; only log in for modes that sync.
    if def_arg.wandb_mode not in ("disabled", "offline"):
        wandb.login()
    wandb.init(
        project=f"{def_arg.project_info}",
        config={"model_args": model_args, "training_args": training_args, "lora_args": lora_args,
                "def_args": def_arg},
        name=training_args.run_name,
        mode=def_arg.wandb_mode
    )

    model = AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        device_map=device_map,
        torch_dtype=torch.bfloat16,
        trust_remote_code=True
    )

    tokenizer = AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        model_max_length=256,
        trust_remote_code=True
    )
    # Pad with the unknown token as before; if the tokenizer defines no unk
    # token, fall back to EOS so the pad token is not left as None.
    pad_token = tokenizer.unk_token
    if pad_token is None:
        pad_token = tokenizer.eos_token
    tokenizer.pad_token = pad_token

    lora_config = LoraConfig(
        r=lora_args.lora_r,
        lora_alpha=lora_args.lora_alpha,
        target_modules=lora_args.lora_target_modules,
        lora_dropout=lora_args.lora_dropout,
        bias=lora_args.lora_bias,
        task_type="CAUSAL_LM",
    )
    peft_model = get_peft_model(model, lora_config)
    peft_model.print_trainable_parameters()

    # NOTE(review): "[INST]"/"[/INST]" look like the prompt delimiters and
    # (800, 200) like the train/eval split sizes -- confirm against DollyDataProvider.
    provider = DollyDataProvider("[INST]", "[/INST]", tokenizer, data_file=def_arg.data_file_path)
    train_dataset, test_dataset = provider.genrate_dataloader(800, 200)

    if training_args.gradient_checkpointing:
        # Needed so gradients can flow to the adapter when the base model's
        # input embeddings are frozen and activations are checkpointed.
        peft_model.enable_input_require_grads()

    lang_trainer = Trainer(
        args=training_args,
        model=peft_model,
        train_dataset=train_dataset,
        eval_dataset=test_dataset
    )
    # Turn off the generation cache on the model config before training.
    peft_model.config.use_cache = False

    lang_trainer.train()

    # safe_serialization=False is kept from the original call.
    peft_model.save_pretrained(f"{training_args.output_dir}/{training_args.run_name}", safe_serialization=False)

# Run the fine-tuning pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()

