# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Copyright 2024 Paulius Sasnauskas. Changes:
# Implement distributionally robust RLHF trainer
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings
from typing import cast

import torch
from accelerate import PartialState
from tqdm import tqdm
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    HfArgumentParser,
)
from trl import (
    ModelConfig,
    RewardConfig,
    RewardTrainer,
    get_kbit_device_map,
    get_peft_config,
    get_quantization_config,
)
from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE

import wandb
from dr.dataset import get_uf_rew_dataset
from dr.reward_trainer import DRRewardTrainer
from dr.utils import DRConfig, get_rew_run_name

# Register tqdm's pandas integration (adds `progress_apply`-style methods to
# pandas objects). Nothing in this script calls those methods directly;
# presumably it is kept for pandas-based preprocessing inside the imported
# `dr` helpers — TODO confirm it is still needed.
tqdm.pandas()


if __name__ == "__main__":
    ################
    # Config
    ################
    # Parse CLI flags into TRL's RewardConfig (training args), ModelConfig
    # (model / quantization / PEFT options) and the project's DRConfig.
    parser = HfArgumentParser((RewardConfig, ModelConfig, DRConfig))  # type: ignore
    config, model_config, dr_config = parser.parse_args_into_dataclasses()

    # `cast` only narrows types for static checkers; it is a no-op at runtime.
    config = cast(RewardConfig, config)
    model_config = cast(ModelConfig, model_config)
    dr_config = cast(DRConfig, dr_config)

    if model_config.model_name_or_path is None:
        raise RuntimeError("No model specified")
    model_name = model_config.model_name_or_path

    # max_length is required because it is handed to the dataset builder below.
    if config.max_length is None:
        raise RuntimeError("Max length is not specified")

    print(dr_config, "\n")

    run_name = get_rew_run_name("uf", model_name, dr_config, config)
    output_dir = "models/" + run_name
    config.output_dir = output_dir

    print(f"Output directory: '{output_dir}'")
    print(f"wandb run name: '{run_name}'")

    config.gradient_checkpointing_kwargs = dict(use_reentrant=False)

    # Only the main process creates the wandb run and tags it.
    if PartialState().is_main_process:
        wandb_run = wandb.init(project="dr-rlhf", name=run_name)
        extra_tags = ("rm", "dataset-uf")
        if model_name == "meta-llama/Llama-3.2-1B-Instruct":
            # The trailing comma matters: ("llama3.2-1b") is a plain str and
            # `tuple + str` raises TypeError.
            extra_tags += ("llama3.2-1b",)
        wandb_run.tags = wandb_run.tags + extra_tags

    # Checkpoint families this script handles explicitly (attention-backend
    # logging and LoRA target modules below).
    is_known_arch = model_name.startswith(("google/gemma-2", "meta-llama/Llama-3.2-1B-Instruct"))

    ################
    # Model & Tokenizer
    ################
    # "auto"/None are passed through; any other string names a torch dtype attribute (e.g. "bfloat16").
    torch_dtype = model_config.torch_dtype if model_config.torch_dtype in ["auto", None] else getattr(torch, model_config.torch_dtype)  # type: ignore
    quantization_config = get_quantization_config(model_config)
    model_kwargs = dict(
        revision=model_config.model_revision,
        # Quantized (k-bit) loading gets an explicit device map; otherwise default placement.
        device_map=get_kbit_device_map() if quantization_config is not None else None,
        quantization_config=quantization_config,
        use_cache=False,
        attn_implementation=model_config.attn_implementation,
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=model_config.trust_remote_code, use_fast=True, clean_up_tokenization_spaces=True)
    if tokenizer.chat_template is None:
        print("chat_template is None, replacing chat template with SIMPLE_CHAT_TEMPLATE")
        tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE

    # num_labels=1: the classification head emits a single scalar reward per sequence.
    model = AutoModelForSequenceClassification.from_pretrained(
        model_name, num_labels=1, trust_remote_code=model_config.trust_remote_code, torch_dtype=torch_dtype, **model_kwargs
    )

    if is_known_arch:
        print("Attention implementation:", model.config._attn_implementation)

    # Sequence-classification models use config.pad_token_id to find the last
    # non-padding token, so it must be a single id matching the tokenizer.
    # Take it from the tokenizer: model.config.eos_token_id may be a list of
    # ids (as in some Llama 3.x configs), which is not a valid pad id.
    if tokenizer.pad_token is None:
        print("pad_token is None, replacing pad token with eos token")
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.pad_token_id = tokenizer.eos_token_id
        model.config.pad_token_id = tokenizer.pad_token_id
    elif model.config.pad_token_id is None:
        model.config.pad_token_id = tokenizer.pad_token_id

    if model_config.use_peft and model_config.lora_task_type != "SEQ_CLS":
        warnings.warn(
            "You are using a `task_type` that is different than `SEQ_CLS` for PEFT. This will lead to silent bugs"
            " Make sure to pass --lora_task_type SEQ_CLS when using this script with PEFT."
        )

    ################
    # Dataset
    ################
    raw_datasets = get_uf_rew_dataset(
        tokenizer, config.max_length, config.dataset_num_proc, version=dr_config.dataset_version, subset_to_remove=dr_config.subset_to_remove, remove_columns=True
    )

    train_dataset = raw_datasets["train"].remove_columns("margin")  # TODO: parametrize margin
    eval_dataset = raw_datasets["val"].remove_columns("margin")  # TODO: parametrize margin

    ################
    # Training
    ################
    # A nonzero eps selects the distributionally robust trainer (which takes
    # eps and dist_fn); eps == 0 uses TRL's stock RewardTrainer.
    if dr_config.eps != 0:
        RewardTrainerClass = DRRewardTrainer
        extra_kwargs = {"eps": dr_config.eps, "dist_fn": dr_config.dist_fn}
        print(f"Using DR RLHF, epsilon={dr_config.eps}")
    else:
        RewardTrainerClass = RewardTrainer
        extra_kwargs = {}

    # Apply LoRA to the q/k/v/o and gate/up/down projections for the known
    # architectures. This must happen before get_peft_config(model_config) below.
    if model_config.use_peft and is_known_arch:
        model_config.lora_target_modules = ["q_proj", "o_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj"]

    trainer = RewardTrainerClass(
        model=model,
        processing_class=tokenizer,
        args=config,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        peft_config=get_peft_config(model_config),  # type: ignore
        **extra_kwargs,
    )
    trainer.train()
    trainer.save_model(output_dir)
    if config.push_to_hub:
        trainer.push_to_hub()  # type: ignore
    metrics = trainer.evaluate()
    trainer.log_metrics("eval", metrics)
    print(metrics)
