from __future__ import annotations

import os
import random
from functools import partial
from pathlib import Path
from typing import *

import datasets
import numpy as np
import torch
import wandb
from accelerate import Accelerator
from transformers import (
    AutoModelForCausalLM,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    set_seed,
)
from trl import ModelConfig
from trl.trainer.online_dpo_config import OnlineDPOConfig

from utils.argument import H4ArgumentParser, ScriptArguments
from utils.data import (
    DPODataCollatorWithPaddingAndOracle,
    PreferenceCollatorWithOracle,
)
from utils.trainer import OnlineSoftDPOTrainer, compute_p_oracle


def set_random_seed(seed):
    """Seed every random number generator the training run touches.

    Seeds Python's ``random``, NumPy, PyTorch (CPU and every CUDA device) and
    then ``transformers.set_seed``, in that order. Afterwards it switches
    cuDNN to deterministic kernels and disables its autotuner so repeated
    runs pick the same algorithms.

    Args:
        seed: integer seed applied to every generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,  # all visible GPUs
        set_seed,  # Hugging Face helper, for Trainer consistency
    )
    for seeder in seeders:
        seeder(seed)

    # Deterministic cuDNN behaviour across (multi-)GPU runs.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


# Label types understood by `assign_p_oracle` / `compute_p_oracle`.
_LABEL_TYPES = ("oracle", "binary", "3level", "5level", "conditioned")

# (lower, upper) probability bands used by the "5level" label type, from
# "significantly worse" to "significantly better". A sampled label is snapped
# to one endpoint of the band containing p_oracle.
_FIVE_LEVEL_INTERVALS = (
    (0, 0.2),  # significantly_worse
    (0.2, 0.5),  # slightly_worse
    (0.5, 0.8),  # slightly_better
    (0.8, 1),  # significantly_better
)


def assign_p_oracle(example, label_type, margin_scale=1.0, soft_threshold=0.1):
    """Compute the preference label for a single example.

    ``p_oracle = sigmoid(margin_scale * chosen_score - margin_scale * rejected_score)``
    is turned into the stored label according to ``label_type``:

    * ``"oracle"``: ``p_oracle`` itself.
    * ``"conditioned"``: ``1.0`` when the scaled score gap exceeds
      ``soft_threshold``, otherwise ``p_oracle``.
    * ``"binary"``: a Bernoulli(``p_oracle``) draw, i.e. ``0.0`` or ``1.0``.
    * ``"3level"``: a random draw from {0, 0.5} when ``p_oracle < 0.5`` or
      from {0.5, 1} when ``p_oracle > 0.5``, with expectation ``p_oracle``.
      The label is exactly ``0.5`` when ``p_oracle == 0.5``.
    * ``"5level"``: a random draw from the two endpoints of the
      ``_FIVE_LEVEL_INTERVALS`` band containing ``p_oracle``, with
      expectation ``p_oracle``.

    Args:
        example: mapping with numeric ``"chosen_score"`` and
            ``"rejected_score"`` entries.
        label_type: one of ``_LABEL_TYPES``.
        margin_scale: multiplier applied to both scores before the sigmoid.
        soft_threshold: score-gap threshold for ``"conditioned"``.

    Returns:
        ``{"p_oracle": float}``.

    Raises:
        ValueError: if ``label_type`` is not one of ``_LABEL_TYPES``.
    """
    chosen_score = torch.tensor(example["chosen_score"] * margin_scale)
    rejected_score = torch.tensor(example["rejected_score"] * margin_scale)
    diff = chosen_score - rejected_score
    p_oracle = torch.sigmoid(diff)

    if label_type == "oracle":
        p_sampled = p_oracle.item()
    elif label_type == "conditioned":
        # Confident pairs get a hard label. Keep the stored value a plain
        # float, like every other branch.
        p_sampled = 1.0 if bool(diff > soft_threshold) else p_oracle.item()
    elif label_type == "binary":
        p_sampled = torch.bernoulli(p_oracle).item()
    elif label_type == "3level":
        p = p_oracle.item()
        if p < 0.5:
            p_sampled = torch.bernoulli(p_oracle * 2).item() * 0.5
        elif p > 0.5:
            p_sampled = (torch.bernoulli(p_oracle * 2 - 1).item() + 1) * 0.5
        else:
            p_sampled = 0.5
    elif label_type == "5level":
        p = p_oracle.item()
        for lower, upper in _FIVE_LEVEL_INTERVALS:
            if lower <= p < upper:
                break
        else:
            # The bands are half-open, so a sigmoid saturated to exactly 1.0
            # (float32, large score gap) matches none of them. It belongs to
            # the top band.
            lower, upper = _FIVE_LEVEL_INTERVALS[-1]
        width = upper - lower
        # Clamp so float32 rounding cannot push the probability marginally
        # outside [0, 1], the range torch.bernoulli accepts.
        prob = torch.clamp((p_oracle - lower) / width, 0.0, 1.0)
        p_sampled = (torch.bernoulli(prob) * width + lower).item()
    else:
        raise ValueError(
            f"Unknown label_type {label_type!r}; expected one of {_LABEL_TYPES}."
        )
    return {"p_oracle": p_sampled}


def compute_p_oracle(
    ds: datasets.Dataset,
    label_type: Literal["oracle", "binary", "3level", "5level", "conditioned"],
    margin_scale: float = 1.0,
    soft_threshold: float = 0.1,
    num_proc: int = 16,
):
    """Add a ``"p_oracle"`` label column to ``ds`` from its reward scores.

    Every row is labelled by ``assign_p_oracle``, which documents the meaning
    of each ``label_type``. This definition shadows the ``compute_p_oracle``
    imported from ``utils.trainer`` at the top of the file.

    Args:
        ds: dataset whose rows carry ``"chosen_score"`` and ``"rejected_score"``.
        label_type: how the oracle probability is turned into a label.
        margin_scale: multiplier applied to both scores before the sigmoid.
        soft_threshold: score-gap threshold for ``"conditioned"``.
        num_proc: worker processes passed to ``ds.map``.

    Returns:
        The mapped dataset with a ``"p_oracle"`` column.

    Raises:
        ValueError: if ``label_type`` is unknown. The check happens before any
            mapping work starts.
    """
    if label_type not in _LABEL_TYPES:
        raise ValueError(
            f"Unknown label_type {label_type!r}; expected one of {_LABEL_TYPES}."
        )

    def _label(example):
        return assign_p_oracle(example, label_type, margin_scale, soft_threshold)

    return ds.map(_label, num_proc=num_proc)


# Columns searched, in order, for the first chat message when a dataset has
# no "prompt" column (older TRL versions need an explicit "prompt" string).
_PROMPT_KEYS = ("context_messages", "chosen")


def _add_prompt_column(ds, tokenizer):
    """Return ``ds`` with a ``"prompt"`` column, building one if it is missing.

    The prompt is the first element of the first ``_PROMPT_KEYS`` column found
    in ``ds``. It is rendered with ``tokenizer.apply_chat_template`` and a
    generation prompt is appended.

    Raises:
        ValueError: if ``ds`` has neither ``"prompt"`` nor any ``_PROMPT_KEYS``
            column.
    """
    if "prompt" in ds.column_names:
        return ds
    key = next((k for k in _PROMPT_KEYS if k in ds.column_names), None)
    if key is None:
        raise ValueError(
            "Dataset has no 'prompt' column and none of the fallback columns "
            f"{list(_PROMPT_KEYS)}; found {ds.column_names}."
        )
    return ds.map(
        lambda x: {
            "prompt": tokenizer.apply_chat_template(
                [x[key][0]], tokenize=False, add_generation_prompt=True
            )
        }
    )


def _flatten_pair_columns(ds):
    """Replace list-valued ``chosen``/``rejected`` columns with plain strings.

    When the first row's ``chosen`` is a list of messages, both ``chosen`` and
    ``rejected`` become the ``"content"`` of their element at index 1. Empty
    datasets and datasets whose ``chosen`` is not a list are returned
    unchanged.
    """
    if "chosen" not in ds.column_names or len(ds) == 0:
        return ds
    # Inspect a single row rather than materialising the whole column.
    if not isinstance(ds[0]["chosen"], list):
        return ds
    return ds.map(
        lambda x: {
            "chosen": x["chosen"][1]["content"],
            "rejected": x["rejected"][1]["content"],
        }
    )


def main():
    """Train a causal LM with ``OnlineSoftDPOTrainer`` and save it.

    The script runs these steps in order:

    1. Parse ``ScriptArguments`` / ``OnlineDPOConfig`` / ``ModelConfig``.
    2. Load the policy, reference and reward models.
    3. Prepare the train dataset and an optional eval dataset.
    4. Train, then save to ``<output_dir>/<output_name>``.
    """
    # ! note that H4ArgumentParser is at early stage of development
    # ! later we might need to use other parsers.
    parser = H4ArgumentParser((ScriptArguments, OnlineDPOConfig, ModelConfig))
    script_args, training_args, model_config = parser.parse()

    set_random_seed(script_args.manual_seed)

    # 1. Load the policy model and the reference model. The reference model
    #    falls back to the policy checkpoint when `ref_model` is empty.
    model = AutoModelForCausalLM.from_pretrained(
        model_config.model_name_or_path,
        attn_implementation="flash_attention_2",
        torch_dtype=torch.bfloat16,
    )
    model.config.use_cache = False

    ref_name = script_args.ref_model or model_config.model_name_or_path
    ref_model = AutoModelForCausalLM.from_pretrained(
        ref_name,
        attn_implementation="flash_attention_2",
        torch_dtype=torch.bfloat16,
    )

    tokenizer = AutoTokenizer.from_pretrained(model_config.model_name_or_path)
    if script_args.use_eos_padding:
        tokenizer.pad_token = tokenizer.eos_token
    else:
        # Add a dedicated [PAD] token and grow both models' embeddings to match.
        tokenizer.add_special_tokens({"pad_token": "[PAD]"})
        for m in (model, ref_model):
            m.config.vocab_size += 1
            m.config.pad_token_id = tokenizer.pad_token_id
            m.resize_token_embeddings(len(tokenizer))

    # ! we have to manually handle the reward model
    rm = AutoModelForSequenceClassification.from_pretrained(
        training_args.reward_model_path,
        attn_implementation="flash_attention_2",
        torch_dtype=torch.bfloat16,
    )
    rm_tokenizer = AutoTokenizer.from_pretrained(training_args.reward_model_path)
    if not rm_tokenizer.pad_token:
        rm_tokenizer.pad_token = rm_tokenizer.eos_token
        rm.config.pad_token_id = rm_tokenizer.pad_token_id
    elif rm.config.pad_token_id is None:
        # Keep the classifier's pad id in sync with the tokenizer so padded
        # batches can be scored.
        rm.config.pad_token_id = rm_tokenizer.pad_token_id

    # 2. Prepare the dataset under `rm_annotated_data`,
    # i.e., containing chosen/rejected and corresponding scores
    try:
        train_ds = datasets.load_from_disk(script_args.train_path)
    except Exception:
        # Not a `save_to_disk` directory: treat it as a `load_dataset` id.
        train_ds = datasets.load_dataset(script_args.train_path, split="train")

    eval_ds = None  # stays None when no eval source is configured
    if script_args.split_eval_from_train:
        n_eval = min(script_args.split_eval_from_train, int(0.5 * len(train_ds)))
        eval_ds = train_ds.select(range(n_eval))
        train_ds = train_ds.select(range(n_eval, len(train_ds)))
    if script_args.max_training_samples is not None:
        train_ds = train_ds.select(
            range(min(script_args.max_training_samples, len(train_ds)))
        )
    if script_args.eval_path is not None:
        if script_args.split_eval_from_train:
            print(
                "[WARNING] `eval_path` will be ignored since split_eval_from_train is set."
            )
        else:
            eval_ds = datasets.load_from_disk(script_args.eval_path)
            if script_args.label_type is not None:
                # Use the same soft_threshold as the trainer's soft labels.
                eval_ds = compute_p_oracle(
                    eval_ds,
                    script_args.label_type,
                    script_args.margin_scale,
                    soft_threshold=script_args.soft_threshold,
                )

    train_ds = train_ds.shuffle(seed=script_args.manual_seed)
    if eval_ds is not None:
        eval_ds = eval_ds.shuffle(seed=script_args.manual_seed)
    else:
        print("[WARNING] No evaluation dataset is provided.")

    # ! we have to handle prompt for old version TRL.
    # The prompt is built before flattening because it reads the first
    # message of the list-valued columns.
    train_ds = _flatten_pair_columns(_add_prompt_column(train_ds, tokenizer))
    if eval_ds is not None:
        eval_ds = _flatten_pair_columns(_add_prompt_column(eval_ds, tokenizer))

    # 3. Set up training arguments
    print(script_args)
    print(training_args)

    # ! we fix the usage of OnlineDPOTrainer for now.
    trainer_cls = OnlineSoftDPOTrainer

    if script_args.output_model_name is not None:
        output_name = script_args.output_model_name
    else:
        local_bs = (
            training_args.per_device_train_batch_size
            * training_args.gradient_accumulation_steps
        )
        output_name = (
            f"{model_config.model_name_or_path}_lr{training_args.learning_rate}"
            f"_localbs{local_bs}_trainer{script_args.trainer_type}"
            f"_label{script_args.label_type}_margin{script_args.margin_scale}"
            f"_loss{training_args.loss_type}_seed{script_args.manual_seed}"
        )

    if training_args.output_dir is None:
        raise ValueError("`output_dir` must be set.")
    Path(training_args.output_dir).mkdir(parents=True, exist_ok=True)
    training_args.output_dir = os.path.join(training_args.output_dir, output_name)
    print(f"Output directory: {training_args.output_dir}")

    training_args.remove_unused_columns = False
    training_args.run_name = output_name

    trainer = trainer_cls(
        model=model,
        ref_model=ref_model,
        reward_model=rm,
        reward_processing_class=rm_tokenizer,
        tokenizer=tokenizer,
        args=training_args,
        train_dataset=train_ds,
        eval_dataset=eval_ds,  # None when no eval dataset is configured
        data_collator=DPODataCollatorWithPaddingAndOracle(
            pad_token_id=tokenizer.pad_token_id,
        ),
        compute_soft_label=partial(
            compute_p_oracle,
            label_type=script_args.label_type,
            margin_scale=script_args.margin_scale,
            soft_threshold=script_args.soft_threshold,
        ),
    )

    accelerator = Accelerator()

    if accelerator.is_main_process:
        wandb.init(
            project="SoftDPO",
            name=output_name,
        )

    # 4. Kick off the training
    trainer.train()
    trainer.save_model(training_args.output_dir)
    # NOTE(review): this runs on every process and likely duplicates what
    # `save_model` just wrote; confirm which saved objects are actually needed.
    trainer.model.save_pretrained(
        str(Path(training_args.output_dir) / "final_checkpoint")
    )


if __name__ == "__main__":
    # Run training only when executed as a script, not when imported.
    main()
