import shutil
from dataclasses import asdict
from typing import cast

import torch
from accelerate import PartialState
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser
from trl import (
    ModelConfig,
    get_kbit_device_map,
    get_peft_config,
    get_quantization_config,
)
from trl.trainer.dpo_trainer import DPOConfig, DPOTrainer
from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE

import wandb
from dr.dataset import get_uf_dpo_dataset
from dr.dpo_trainer import DRDPOTrainer
from dr.utils import DRConfig, get_rlhf_run_name

if __name__ == "__main__":
    # Train a causal LM with DPO on the UltraFeedback preference data returned by
    # get_uf_dpo_dataset, using DRDPOTrainer instead of TRL's DPOTrainer when
    # --eps is non-zero; then save the model and report eval metrics.
    parser = HfArgumentParser((DPOConfig, ModelConfig, DRConfig))  # type: ignore
    config, model_config, dr_config = parser.parse_args_into_dataclasses()

    # parse_args_into_dataclasses returns untyped objects; narrow them for type checkers.
    config = cast(DPOConfig, config)
    model_config = cast(ModelConfig, model_config)
    dr_config = cast(DRConfig, dr_config)
    # Sequence-length limit passed to the dataset preprocessing below.
    # NOTE(review): independent of config.max_length used by the trainer — confirm they should agree.
    max_length = 1024

    if model_config.model_name_or_path is None:
        raise RuntimeError("No model specified")

    # Build the distributed state once; reused for rank checks and the barrier below.
    state = PartialState()

    print(dr_config, "\n")

    run_name = get_rlhf_run_name("dpo", None, model_config.model_name_or_path, dr_config, config)
    output_dir = "models/" + run_name
    config.output_dir = output_dir
    # Keep the Trainer's run_name consistent with the wandb run and the real output dir,
    # rather than the value derived at argument-parsing time (before output_dir was replaced).
    config.run_name = run_name

    print(f"Output directory: '{output_dir}'")
    print(f"wandb run name: '{run_name}'")

    if state.is_main_process:
        # asdict() already returns a fresh dict, so no extra copy is needed.
        wandb_run = wandb.init(project="dr-rlhf", name=run_name, config=asdict(dr_config), allow_val_change=True)
        wandb_run.tags = wandb_run.tags + ("dpo", "dataset-uf")

    # Remove output_dir if it exists so the run starts clean. Only the main process deletes;
    # every rank then waits at the barrier so no rank can delete files another rank has
    # already recreated.
    # TODO: do we really want to remove without confirmation?
    if state.is_main_process:
        shutil.rmtree(config.output_dir, ignore_errors=True)
    state.wait_for_everyone()

    ################
    # Model & Tokenizer
    ################

    # "auto"/None are passed through as-is; any other string names a torch dtype attribute (e.g. "bfloat16").
    torch_dtype = model_config.torch_dtype if model_config.torch_dtype in ["auto", None] else getattr(torch, model_config.torch_dtype)  # type: ignore
    quantization_config = get_quantization_config(model_config)
    tokenizer = AutoTokenizer.from_pretrained(
        model_config.model_name_or_path,
        padding_side="left",
        use_fast=True,
        clean_up_tokenization_spaces=True,
        trust_remote_code=model_config.trust_remote_code,
    )
    # Some base tokenizers ship without a pad token or chat template; supply fallbacks.
    if tokenizer.pad_token is None:
        print("pad_token is None, replacing pad token with eos token")
        tokenizer.pad_token = tokenizer.eos_token
    if tokenizer.chat_template is None:
        print("chat_template is None, replacing chat template with SIMPLE_CHAT_TEMPLATE")
        tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE

    model_kwargs = dict(
        revision=model_config.model_revision,
        attn_implementation=model_config.attn_implementation,
        torch_dtype=torch_dtype,
        # The KV cache is useless during training and conflicts with gradient checkpointing.
        use_cache=not config.gradient_checkpointing,
        # Quantized models need an explicit per-process device map.
        device_map=get_kbit_device_map() if quantization_config is not None else None,
        quantization_config=quantization_config,
        trust_remote_code=model_config.trust_remote_code,
    )
    model = AutoModelForCausalLM.from_pretrained(model_config.model_name_or_path, **model_kwargs)
    peft_config = get_peft_config(model_config)
    # Without PEFT, load a separate frozen copy as the reference model; with PEFT the
    # trainer is given ref_model=None.
    if peft_config is None:
        ref_model = AutoModelForCausalLM.from_pretrained(model_config.model_name_or_path, **model_kwargs)
    else:
        ref_model = None

    if model_config.model_name_or_path.startswith("google/gemma-2"):
        print("Attention implementation:", model.config._attn_implementation)

    ################
    # Dataset
    ################
    raw_datasets = get_uf_dpo_dataset(
        tokenizer, max_length, config.dataset_num_proc, version=dr_config.dataset_version, subset_to_remove=dr_config.subset_to_remove, remove_columns=True
    )

    train_dataset = raw_datasets["train"]
    eval_dataset = raw_datasets["val"]

    ################
    # Training
    ################
    # A non-zero eps selects the DR trainer, which takes extra constructor arguments and
    # a custom loss_type; eps == 0 falls back to the stock TRL DPOTrainer.
    if dr_config.eps != 0:
        DPOTrainerClass = DRDPOTrainer
        extra_kwargs = {"eps": dr_config.eps, "dist_fn": dr_config.dist_fn}
        print(f"Using DR RLHF, epsilon={dr_config.eps}")
        config.loss_type = "dr_sigmoid"  # type: ignore
    else:
        DPOTrainerClass = DPOTrainer
        extra_kwargs = {}

    trainer = DPOTrainerClass(
        model,
        ref_model,
        args=config,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        processing_class=tokenizer,
        peft_config=peft_config,  # type: ignore
        **extra_kwargs,
    )
    trainer.train()
    trainer.save_model(config.output_dir)
    metrics = trainer.evaluate()
    trainer.log_metrics("eval", metrics)
    print(metrics)
