
import os

import fire
import torch
from datasets import load_from_disk
from peft import LoraConfig, TaskType
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    set_seed,
)
from trl import DPOConfig, DPOTrainer


def main(
    preference_path: str = "./nicks_dpo/preference-FastDetectGPT",
    model_name: str = "mistralai/Mistral-7B-Instruct-v0.3",
    seed: int = 43,
) -> int:
    """Run LoRA + DPO fine-tuning of a causal LM on a preference dataset saved to disk.

    Args:
        preference_path: Path accepted by ``datasets.load_from_disk``. The part of
            its final path component after the first ``-`` (e.g. ``FastDetectGPT``
            for ``preference-FastDetectGPT``) names the output directory
            ``./nicks_dpo/outputs-<suffix>``. A trailing path separator is ignored,
            and a name containing no ``-`` is used whole as the suffix.
        model_name: Hugging Face model id or local path of the causal LM to tune.
        seed: Seed passed to ``set_seed`` before the LoRA adapter is created and to
            ``DPOConfig`` so the trainer reseeds with the same value. Defaults to 43,
            the seed the script's entry point uses.

    Returns:
        0 once ``trainer.train()`` has finished.
    """
    # Seed before add_adapter so the LoRA weight initialisation is reproducible.
    set_seed(seed)

    # NOTE(review): DPOTrainer presumably needs prompt/chosen/rejected columns here — confirm.
    dataset = load_from_disk(preference_path)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        attn_implementation="flash_attention_2",
        torch_dtype=torch.bfloat16,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        model_name,
        padding_side="left",
    )
    # The tokenizer may not define a pad token; reuse EOS so batches can be padded.
    tokenizer.pad_token = tokenizer.eos_token
    lora_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        inference_mode=False,
        r=32,
        lora_alpha=64,
        lora_dropout=0.1,
        target_modules="all-linear",
    )
    # NOTE(review): add_adapter injects LoRA into the plain transformers model, so
    # `model` is not a peft.PeftModel. DPOTrainer then likely builds a separate
    # full-size reference model rather than reusing the adapter-disabled base;
    # passing `peft_config=lora_config` to DPOTrainer instead would avoid that
    # copy — verify against the installed TRL version before switching.
    model.add_adapter(lora_config, adapter_name="base")
    # Keep the KV cache off for training. Per the original author, leaving it on
    # triggers a padding-side error (presumably from the flash-attention path
    # seeing right-padded batches — unverified).
    model.config.use_cache = False

    # Derive the output-dir suffix from the dataset directory name.
    # normpath strips a trailing separator; without it basename("dir/") == "" and
    # every such run would write to the same "./nicks_dpo/outputs-" directory.
    dataset_name = os.path.basename(os.path.normpath(preference_path))
    # Drop the leading token: "preference-FastDetectGPT" -> "FastDetectGPT".
    _, dash, suffix = dataset_name.partition("-")
    if not dash:
        # No "-" to split on: use the whole name so runs don't collide on "outputs-".
        suffix = dataset_name

    # Relies on DPOConfig defaults for the rest, per the original author:
    # - 3 train epochs
    # - beta = 0.1
    training_args = DPOConfig(
        output_dir=f"./nicks_dpo/outputs-{suffix}",
        logging_steps=100,
        save_steps=100,
        per_device_train_batch_size=1,
        max_length=128 * 3,
        ddp_find_unused_parameters=False,
        save_total_limit=1,
        # The Trainer reseeds from args.seed when constructed; without this it
        # would use its own default seed instead of `seed` for shuffling/dropout.
        seed=seed,
    )
    trainer = DPOTrainer(
        model=model,
        processing_class=tokenizer,
        args=training_args,
        train_dataset=dataset,
    )
    trainer.train()

    return 0

if __name__ == "__main__":
    # Fix the global RNG state first, then let fire map CLI flags onto main().
    set_seed(seed=43)
    fire.Fire(component=main)
