import random
import argparse

import numpy as np
import json
import os
import torch
from config import parse_args
from pytorch_lightning import seed_everything
import copy

from data_helper import SafetyDatasetDecoderOnly
from transformers import (
    AutoModelForCausalLM,
    AutoModelForSeq2SeqLM,
    AutoModel,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
    default_data_collator,
)
from transformers import DataCollatorWithPadding, DefaultDataCollator

from trainers import GA_GD_Trainer, GA_GD_KL_Trainer, GA_GD_GD_Trainer, ORPOTrainer, DPOTrainer, GD_Trainer, AlignSFT_Trainer, NPOTrainer, DPOSFTTrainer

if __name__ == "__main__":

    # Every run setting comes from the project's argument parser.
    args = parse_args()

    # Output location and the DeepSpeed configuration handed to TrainingArguments.
    output_dir = args.savedmodel_path
    ds_config = args.ds_config

    # Batch sizes and gradient accumulation.
    train_batch_size = args.batch_size
    eval_batch_size = args.val_batch_size
    gradient_accumulation_steps = args.gradient_accumulation

    # Optimisation schedule.
    learning_rate = args.learning_rate
    num_train_epochs = args.max_epochs
    warmup_steps = args.warmup_steps

    # Step-based evaluation/saving intervals; read here but currently not
    # passed on to TrainingArguments.
    eval_steps = args.eval_step
    save_steps = args.save_step

    # Fix the global random seed before any dataset or model is created.
    seed_everything(args.seed)

    # The loss type selects both the dataset formatting and the trainer class.
    loss_type = args.loss_type

    print(f"using loss type: {loss_type}")

    # Read the raw training examples from the JSON file named on the command line.
    with open(args.train_path, 'r', encoding='utf8') as train_file:
        train_data = json.loads(train_file.read())

    # No validation split is loaded in this script.

    # The dataset wrapper owns the tokenizer; reuse it so it can be saved
    # next to the checkpoints after training.
    train_dataset = SafetyDatasetDecoderOnly(args, train_data, loss_type)
    tokenizer = train_dataset.tokenizer

    # Load the model that will be trained.
    model = AutoModelForCausalLM.from_pretrained(
        args.model_dir,
        use_cache=True,
        trust_remote_code=True,
    )

    # Independent copy of the freshly loaded model, taken before any further
    # modification; it is later handed to the trainer as the reference model.
    pretrain_model = copy.deepcopy(model)

    # Optionally wrap the trainable model with LoRA adapters. The reference copy
    # made earlier is not touched by this step.
    # NOTE(review): gradient checkpointing is enabled in TrainingArguments; with a
    # frozen base model some setups need model.enable_input_require_grads() —
    # confirm that LoRA runs actually receive gradients.
    if args.lora:
        # Imported lazily so peft is only required when LoRA is requested.
        from peft import LoraConfig, get_peft_model

        adapter_settings = {
            "r": args.lora_r,
            "lora_alpha": args.lora_alpha,
            # Only the attention query/value projections receive adapters.
            "target_modules": ["q_proj", "v_proj"],
            "lora_dropout": args.lora_dropout,
            "bias": "none",
            "task_type": "CAUSAL_LM",
        }
        model = get_peft_model(model, LoraConfig(**adapter_settings))
        model.print_trainable_parameters()

    # Metric preprocessing hook: reduce model logits to predicted token ids.
    def preprocess_logits_for_metrics(logits, labels):
        """Return the argmax over the last dimension of the logits.

        Args:
            logits: Either the logits object itself or a tuple whose first
                element is the logits.
            labels: Unused; present to match the hook signature.

        Returns:
            The result of ``argmax(dim=-1)`` on the logits.
        """
        token_scores = logits[0] if isinstance(logits, tuple) else logits
        return token_scores.argmax(dim=-1)

    # balanced_data_collator = BalancedDataCollator(tokenizer)

    # Build the Hugging Face training configuration.
    #
    # Evaluation is intentionally left at the TrainingArguments default ("no").
    # The old `evaluation_strategy` keyword is deprecated (renamed to
    # `eval_strategy` in newer transformers releases), so it is not passed
    # explicitly; this keeps the script working across library versions
    # without changing behaviour.
    training_args = TrainingArguments(
        output_dir=output_dir,
        # Optimisation schedule.
        learning_rate=learning_rate,
        num_train_epochs=num_train_epochs,
        warmup_steps=warmup_steps,
        adam_beta1=0.9,
        adam_beta2=0.95,
        # Batch sizing.
        per_device_train_batch_size=train_batch_size,
        per_device_eval_batch_size=eval_batch_size,
        gradient_accumulation_steps=gradient_accumulation_steps,
        eval_accumulation_steps=1,
        # Memory / precision settings.
        gradient_checkpointing=True,
        half_precision_backend='auto',
        bf16=True,
        # Checkpointing: one checkpoint per epoch, model weights only.
        save_strategy='epoch',
        save_only_model=True,
        load_best_model_at_end=False,
        # Logging.
        report_to='tensorboard',
        logging_steps=1,
        # Keep all dataset columns; presumably the custom trainers read fields
        # that the model's forward signature does not list — verify in trainers.
        remove_unused_columns=False,
        deepspeed=ds_config,
    )

    # Pick the trainer class for the requested loss type; reject anything else
    # before touching the trainer-specific arguments.
    if loss_type not in ('DPO', 'NPO'):
        raise ValueError(f"Invalid loss type: {loss_type}. Valid types are: ['DPO', 'NPO']")

    # Keyword arguments common to both trainers. No evaluation dataset or
    # metric hooks are supplied.
    shared_trainer_kwargs = dict(
        model=model,
        reference_model=pretrain_model,
        args=training_args,
        train_dataset=train_dataset,
        data_collator=default_data_collator,
        alpha=args.dpo_alpha,
        beta=args.dpo_beta,
    )

    if loss_type == 'DPO':
        trainer = DPOTrainer(**shared_trainer_kwargs)
    else:
        # NPO takes one extra weight, theta, on top of the shared arguments.
        trainer = NPOTrainer(**shared_trainer_kwargs, theta=args.theta_GD)
        
    # Always start a fresh run; never resume from an existing checkpoint.
    trainer.train(resume_from_checkpoint=False)

    # Copy the tokenizer into every checkpoint directory so each checkpoint can
    # be loaded on its own.
    # Only the process that is responsible for writing checkpoints does this:
    # in a multi-process (e.g. DeepSpeed) run, letting every rank write the
    # same files races on them, and ranks on nodes without the checkpoint
    # directory would fail when listing it.
    if training_args.should_save:
        # Sorted so the log output is in a stable order.
        for entry in sorted(os.listdir(output_dir)):
            checkpoint_dir = os.path.join(output_dir, entry)
            if entry.startswith("checkpoint-") and os.path.isdir(checkpoint_dir):
                print(f'saving tokenizer to {checkpoint_dir}')
                tokenizer.save_pretrained(checkpoint_dir)
