from __future__ import annotations

import os
import random
from pathlib import Path
from typing import *

import datasets
import numpy as np
import torch
import wandb
from accelerate import Accelerator
from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed
from trl import DPOConfig, DPOTrainer, ModelConfig

from utils.argument import H4ArgumentParser, ScriptArguments
from utils.data import (
    DPODataCollatorWithPaddingAndOracle,
    PreferenceCollatorWithOracle,
)
from utils.trainer_no_compatible import SoftDPOTrainer


def set_random_seed(seed):
    """Seed every RNG used by the run and switch cuDNN to deterministic mode.

    Seeds, in order: Python's ``random``, NumPy, torch (CPU), torch on all
    CUDA devices, and Hugging Face ``set_seed``. Then sets
    ``cudnn.deterministic = True`` and ``cudnn.benchmark = False`` so that
    multi-GPU runs pick the same kernels every time.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
        set_seed,
    )
    for seeder in seeders:
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


def compute_p_oracle(
    ds: datasets.Dataset,
    label_type: Literal["oracle", "binary", "3level", "5level", "conditioned"],
    margin_scale: float = 1.0,
    soft_threshold: float = 0.1,
    num_proc: int = 16,
):
    """Attach a (possibly stochastic) preference label ``p_oracle`` to every example.

    The oracle probability is ``sigmoid(margin_scale * (chosen_score - rejected_score))``.
    ``label_type`` selects how it becomes the stored label:

    - ``"oracle"``: the probability itself.
    - ``"conditioned"``: 1.0 when the scaled score gap exceeds ``soft_threshold``,
      otherwise the probability.
    - ``"binary"``: one Bernoulli draw (0.0 or 1.0).
    - ``"3level"``: random rounding to one of {0.0, 0.5, 1.0}. It picks between
      the two levels bracketing the probability, so the expected label equals it.
    - ``"5level"``: random rounding to one of {0.0, 0.2, 0.5, 0.8, 1.0}, done
      the same way.

    Args:
        ds: dataset with numeric scalar ``chosen_score`` / ``rejected_score`` columns.
        label_type: one of the modes above.
        margin_scale: multiplier applied to both scores before the sigmoid.
        soft_threshold: scaled score gap above which ``"conditioned"`` emits 1.0.
        num_proc: worker count forwarded to ``ds.map``.

    Returns:
        ``ds`` mapped so that each example carries a float ``p_oracle``.

    Raises:
        ValueError: if ``label_type`` is not a supported mode.
    """
    supported = ("oracle", "binary", "3level", "5level", "conditioned")
    if label_type not in supported:
        # Fail fast in the parent process. Otherwise each map worker would die
        # with an unbound-local error on `p_sampled`.
        raise ValueError(
            f"Unsupported label_type {label_type!r}; expected one of {supported}"
        )

    # Contiguous [lower, upper) buckets covering [0, 1]. The last bucket also
    # absorbs p == 1.0, which float32 sigmoid returns for large score gaps.
    # Previously that value matched no bucket and crashed.
    five_level_buckets = ((0.0, 0.2), (0.2, 0.5), (0.5, 0.8), (0.8, 1.0))

    def assign_p_oracle(example):
        chosen_score = torch.tensor(example["chosen_score"] * margin_scale)
        rejected_score = torch.tensor(example["rejected_score"] * margin_scale)
        diff = chosen_score - rejected_score
        p_oracle = torch.sigmoid(diff)

        if label_type == "oracle":
            p_sampled = p_oracle.item()
        elif label_type == "conditioned":
            # Confident pairs get a hard label; the rest keep the soft
            # probability. `.item()` keeps the output a plain float like
            # every other mode.
            p_sampled = 1.0 if bool(diff > soft_threshold) else p_oracle.item()
        elif label_type == "binary":
            p_sampled = torch.bernoulli(p_oracle).item()
        elif label_type == "3level":
            p = p_oracle.item()
            if p < 0.5:
                p_sampled = torch.bernoulli(p_oracle * 2).item() * 0.5
            elif p > 0.5:
                p_sampled = (torch.bernoulli(p_oracle * 2 - 1).item() + 1) * 0.5
            else:
                p_sampled = 0.5
        else:  # "5level"
            p = p_oracle.item()
            lower, upper = next(
                ((lo, hi) for lo, hi in five_level_buckets if p < hi),
                five_level_buckets[-1],
            )
            # Clamp so float32 rounding can never push the probability
            # outside the [0, 1] range that torch.bernoulli accepts.
            frac = ((p_oracle - lower) / (upper - lower)).clamp(0.0, 1.0)
            p_sampled = (torch.bernoulli(frac) * (upper - lower) + lower).item()
        return {"p_oracle": p_sampled}

    # NOTE(review): with num_proc > 1 the workers may inherit the same torch RNG
    # state, so stochastic modes could draw correlated noise across shards.
    # Confirm whether per-worker seeding is needed.
    return ds.map(assign_p_oracle, num_proc=num_proc)


def _load_preference_split(path: str, split: str) -> datasets.Dataset:
    """Load a dataset from a ``save_to_disk`` directory, else via ``load_dataset``.

    Args:
        path: local ``save_to_disk`` directory or a ``load_dataset`` identifier.
        split: split requested from ``load_dataset`` when the on-disk load fails.
    """
    try:
        return datasets.load_from_disk(path)
    except Exception:
        # Kept broad to match the original fallback: whatever makes the
        # on-disk load fail, retry through load_dataset.
        return datasets.load_dataset(path, split=split)


def _to_conversational(ds, tokenizer, num_proc: int):
    """Rewrite plain-string ``chosen``/``rejected`` columns as chat-message lists.

    Two layouts are handled:
    1. ``prompt_dialg`` holds the prompt as a list of messages. Each response
       is appended to it as an assistant turn, and ``prompt``/``prompt_dialg``
       are removed.
    2. ``prompt`` is a plain string. It becomes a single user turn, and
       ``prompt`` is removed.

    The dataset is returned unchanged if it is empty, already conversational,
    or matches neither layout.

    Raises:
        ValueError: if a plain-string prompt already starts with the tokenizer's
            BOS token, i.e. a chat template seems to have been applied already.
    """
    if len(ds) == 0:
        return ds
    first = ds[0]
    if not isinstance(first["chosen"], str):
        return ds

    if "prompt_dialg" in first:
        return ds.map(
            lambda x: {
                "chosen": x["prompt_dialg"]
                + [{"role": "assistant", "content": x["chosen"]}],
                "rejected": x["prompt_dialg"]
                + [{"role": "assistant", "content": x["rejected"]}],
            },
            remove_columns=["prompt", "prompt_dialg"],
            num_proc=num_proc,
        )

    if "prompt" in first and isinstance(first["prompt"], str):
        bos = tokenizer.bos_token
        # Some tokenizers define no BOS token; str.startswith(None) would raise.
        if bos and first["prompt"].startswith(bos):
            raise ValueError("the prompt should not start with a bos_token")
        return ds.map(
            lambda x: {
                "chosen": [
                    {"role": "user", "content": x["prompt"]},
                    {"role": "assistant", "content": x["chosen"]},
                ],
                "rejected": [
                    {"role": "user", "content": x["prompt"]},
                    {"role": "assistant", "content": x["rejected"]},
                ],
            },
            remove_columns=["prompt"],
            num_proc=num_proc,
        )

    return ds


def main():
    """Train a policy with ``SoftDPOTrainer`` on score-annotated preference pairs.

    The steps are:
    1. Parse ``ScriptArguments`` / ``DPOConfig`` / ``ModelConfig``.
    2. Load the policy model, reference model and tokenizer.
    3. Load the data, optionally label it (``compute_p_oracle``) and split it.
    4. Convert the data to chat-message format.
    5. Derive the output directory and run name, then train and save.
    """
    # ! note that H4ArgumentParser is at early stage of development
    # ! later we might need to use other parsers.
    parser = H4ArgumentParser((ScriptArguments, DPOConfig, ModelConfig))
    script_args, training_args, model_config = parser.parse()

    set_random_seed(script_args.manual_seed)

    # 1. Load the policy model and reference model
    model = AutoModelForCausalLM.from_pretrained(
        model_config.model_name_or_path,
        attn_implementation="flash_attention_2",
        torch_dtype=torch.bfloat16,
    )
    model.config.use_cache = False

    ref_name = script_args.ref_model or model_config.model_name_or_path
    ref_model = AutoModelForCausalLM.from_pretrained(
        ref_name,
        attn_implementation="flash_attention_2",
        torch_dtype=torch.bfloat16,
    )

    tokenizer = AutoTokenizer.from_pretrained(
        model_config.model_name_or_path,
    )
    if script_args.use_eos_padding:
        tokenizer.pad_token = tokenizer.eos_token
    elif tokenizer.pad_token is None:
        tokenizer.add_special_tokens({"pad_token": "[PAD]"})
        model.config.vocab_size += 1
        ref_model.config.vocab_size += 1
        model.config.pad_token_id = tokenizer.pad_token_id
        ref_model.config.pad_token_id = tokenizer.pad_token_id
        model.resize_token_embeddings(len(tokenizer))
        ref_model.resize_token_embeddings(len(tokenizer))

    # os.cpu_count() may return None; fall back to 1 so `//` cannot raise.
    num_proc = max(8, (os.cpu_count() or 1) // 8)

    # 2. Prepare the dataset under `rm_annotated_data`,
    # i.e., containing chosen/rejected and corresponding scores
    train_ds = _load_preference_split(script_args.train_path, "train")
    if script_args.label_type is not None:
        train_ds = compute_p_oracle(
            train_ds,
            script_args.label_type,
            script_args.margin_scale,
            script_args.soft_threshold,
        )
    if script_args.split_eval_from_train:
        n_eval = min(script_args.split_eval_from_train, int(0.5 * len(train_ds)))
        eval_ds = train_ds.select(range(n_eval))
        train_ds = train_ds.select(range(n_eval, len(train_ds)))
    else:
        eval_ds = None

    if script_args.max_training_samples is not None:
        train_ds = train_ds.select(
            range(min(script_args.max_training_samples, len(train_ds)))
        )
    if script_args.eval_path is not None:
        if script_args.split_eval_from_train:
            print(
                "[WARNING] `eval_path` will be ignored since split_eval_from_train is set."
            )
        else:
            eval_ds = _load_preference_split(script_args.eval_path, "validation")
            if script_args.label_type is not None:
                # Use the same soft_threshold as the train set so that the
                # "conditioned" labels agree between train and eval.
                eval_ds = compute_p_oracle(
                    eval_ds,
                    script_args.label_type,
                    script_args.margin_scale,
                    script_args.soft_threshold,
                )

    train_ds = train_ds.shuffle(seed=script_args.manual_seed)
    if eval_ds:
        eval_ds = eval_ds.shuffle(seed=script_args.manual_seed)
    else:
        print("[WARNING] No evaluation dataset is provided.")

    # ! in trl >= 0.12.0, the safest dataset format is only providing `chosen` and `rejected` dialogs in List[Dict] format.
    # ! we do not modify the run_annotation for compatibility with older versions, but handle the conversion here.
    # Each split is inspected on its own, so an eval set in a different
    # format from the train set is still converted.
    train_ds = _to_conversational(train_ds, tokenizer, num_proc)
    if eval_ds:
        eval_ds = _to_conversational(eval_ds, tokenizer, num_proc)
    # ! after the above step, we got the ds with only chosen/rejected columns.

    # 3. Set up training arguments
    print(script_args)
    print(training_args)

    # Only SoftDPOTrainer is supported; script_args.trainer_type is used
    # solely to build the default run name below.
    trainer_cls = SoftDPOTrainer

    if script_args.output_model_name is not None:
        output_name = script_args.output_model_name
    else:
        output_name = f"{model_config.model_name_or_path}_lr{training_args.learning_rate}_localbs{training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps}_trainer{script_args.trainer_type}_label{script_args.label_type}_margin{script_args.margin_scale}_loss{training_args.loss_type}_ds{script_args.train_path.split('/')[-1]}_seed{script_args.manual_seed}"
        if script_args.n_iter is None or not isinstance(script_args.n_iter, int):
            raise ValueError("n_iter should be manually set for clearer logging")
        output_name += f"_iter{script_args.n_iter}"

    if training_args.output_dir is not None:
        Path(training_args.output_dir).mkdir(parents=True, exist_ok=True)
        training_args.output_dir = os.path.join(
            training_args.output_dir, output_name
        )
    else:
        # os.path.join(None, ...) would raise; use the run name on its own.
        training_args.output_dir = output_name
    print(f"Output directory: {training_args.output_dir}")

    training_args.remove_unused_columns = False
    training_args.run_name = output_name
    training_args.save_strategy = "no"

    trainer = trainer_cls(
        model=model,
        ref_model=ref_model,
        processing_class=tokenizer,
        args=training_args,
        train_dataset=train_ds,
        eval_dataset=eval_ds,
        data_collator=PreferenceCollatorWithOracle(
            pad_token_id=tokenizer.pad_token_id
        ),
    )

    accelerator = Accelerator()

    # Only the main process opens the wandb run.
    if accelerator.is_main_process:
        wandb.init(
            project="SoftDPO",
            name=output_name,
        )

    # 4. Kick off the training
    trainer.train()
    trainer.save_model(training_args.output_dir)


if __name__ == "__main__":
    # Run training only when executed as a script, not when imported.
    main()
