from __future__ import annotations

import os
import random
from functools import partial
from pathlib import Path
from typing import *

import datasets
import numpy as np
import torch
import wandb
from accelerate import Accelerator
from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed
from trl import (
    DataCollatorForCompletionOnlyLM,
    ModelConfig,
    SFTConfig,
    SFTTrainer,
)

from utils.argument import H4ArgumentParser, ScriptArguments


def set_random_seed(seed):
    """Seed every RNG touched during training and pin cuDNN to deterministic kernels.

    Args:
        seed: Seed value passed unchanged to Python's ``random``, NumPy,
            PyTorch (CPU and all CUDA devices) and Hugging Face's ``set_seed``.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,  # every visible GPU, not only the current one
        set_seed,  # keeps the Hugging Face Trainer's own seeding consistent
    )
    for seed_fn in seeders:
        seed_fn(seed)

    # Give up cuDNN autotuning in exchange for reproducible results across
    # runs and multi-GPU setups.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


# Tulu-3-style single-turn layout using the same `<|user|>` / `<|assistant|>`
# markers as the completion-only collator. Placeholders: user_prompt,
# assistant_response, eos_token.
# NOTE(review): not referenced elsewhere in this file; rendering goes through
# `tokenizer.apply_chat_template` instead — confirm whether this is still needed.
FORMAT_TEMPLATE = (
    "<|user|>\n{user_prompt}\n"
    "<|assistant|>\n{assistant_response}{eos_token}"
)


def formatting_prompt_dataset(
    ds: datasets.Dataset, tokenizer: AutoTokenizer
) -> List[str]:
    """Render each ``chosen`` conversation to an untokenized chat-template string.

    Intended as SFTTrainer's ``formatting_func``. ``ds["chosen"]`` is expected
    to hold one list of ``{"role": ..., "content": ...}`` messages per row.

    A *leading* BOS token emitted by the chat template is stripped, presumably
    because the tokenizer adds its own BOS when the text is tokenized later
    (confirm for the tokenizer in use). Only the leading occurrence is
    removed. Some tokenizers use the same string for BOS and EOS
    (e.g. ``<|endoftext|>``), and removing every occurrence would also delete
    the EOS that terminates each response. Tokenizers without a BOS token
    (``bos_token is None``) are passed through unchanged.

    Args:
        ds: Batch mapping or dataset with a ``chosen`` column of message lists.
        tokenizer: Tokenizer providing ``apply_chat_template`` and ``bos_token``.

    Returns:
        One formatted string per entry of ``ds["chosen"]``, in order.
    """
    bos_token = tokenizer.bos_token

    def _strip_leading_bos(text: str) -> str:
        # A None/empty bos_token means there is nothing to strip.
        if bos_token and text.startswith(bos_token):
            return text[len(bos_token):]
        return text

    # Fetch the column once. On a full `datasets.Dataset`, `ds["chosen"]` may
    # materialize the whole column, so re-indexing it per row was quadratic.
    conversations = ds["chosen"]
    return [
        _strip_leading_bos(tokenizer.apply_chat_template(conv, tokenize=False))
        for conv in conversations
    ]


def _load_split(path, split):
    """Load a dataset saved with ``save_to_disk``, falling back to ``load_dataset``.

    Args:
        path: Directory written by ``Dataset.save_to_disk``, or any dataset
            id/path that ``datasets.load_dataset`` can resolve.
        split: Split requested from ``load_dataset`` on the fallback path.

    Returns:
        The loaded dataset.
    """
    try:
        return datasets.load_from_disk(path)
    except Exception:
        # Not loadable as a save_to_disk directory: let `load_dataset`
        # resolve the path (Hub id, local script/files, ...).
        return datasets.load_dataset(path, split=split)


def _dialog_from_prompt_dialg(example):
    """Append the chosen/rejected replies to the row's ``prompt_dialg`` messages."""
    return {
        "chosen": example["prompt_dialg"]
        + [{"role": "assistant", "content": example["chosen"]}],
        "rejected": example["prompt_dialg"]
        + [{"role": "assistant", "content": example["rejected"]}],
    }


def _dialog_from_prompt(example):
    """Build single-turn chosen/rejected dialogs from a plain-text ``prompt``."""
    return {
        "chosen": [
            {"role": "user", "content": example["prompt"]},
            {"role": "assistant", "content": example["chosen"]},
        ],
        "rejected": [
            {"role": "user", "content": example["prompt"]},
            {"role": "assistant", "content": example["rejected"]},
        ],
    }


def _to_dialog_format(ds, tokenizer, num_proc):
    """Convert a string-valued preference dataset to chat-message format.

    Each dataset is inspected on its own. Datasets that are None, empty, or
    whose ``chosen`` values are already non-strings are returned unchanged.

    Handled layouts:
      * rows with ``prompt_dialg`` (as produced by run_annotation.py): the
        replies are appended to that message list;
      * rows with a plain-string ``prompt``: a single user turn is built.

    Args:
        ds: Dataset to convert (or None).
        tokenizer: Used only to read ``bos_token`` for the prompt sanity check.
        num_proc: Worker count forwarded to ``Dataset.map``.

    Returns:
        The converted dataset, or ``ds`` itself when no conversion applies.

    Raises:
        ValueError: if a plain-text prompt already starts with the BOS token,
            i.e. a chat template was already applied to it.
    """
    if ds is None or len(ds) == 0 or not isinstance(ds[0]["chosen"], str):
        return ds

    columns = ds.column_names
    if "prompt_dialg" in columns:
        return ds.map(
            _dialog_from_prompt_dialg,
            # Drop only the prompt columns that actually exist.
            remove_columns=[c for c in ("prompt", "prompt_dialg") if c in columns],
            num_proc=num_proc,
        )
    if "prompt" in columns and isinstance(ds[0]["prompt"], str):
        # The prompt must be plain text, not a chat-template-rendered prompt
        # (which would start with a bos_token). Tokenizers may have no BOS.
        bos_token = tokenizer.bos_token
        if bos_token and ds[0]["prompt"].startswith(bos_token):
            raise ValueError("the prompt should not start with a bos_token")
        return ds.map(
            _dialog_from_prompt,
            remove_columns=["prompt"],
            num_proc=num_proc,
        )
    return ds


def main():
    """Run SFT on the ``chosen`` dialogs with TRL's SFTTrainer.

    Parses ``ScriptArguments``, ``SFTConfig`` and ``ModelConfig`` with
    ``H4ArgumentParser``, loads the model, tokenizer, and train/eval data,
    converts string-format preference data to chat messages, trains with a
    completion-only collator, and saves the final model to
    ``<output_dir>/<output_name>``.
    """
    # NOTE: H4ArgumentParser is at an early stage of development; we may need
    # to switch to a different parser later.
    parser = H4ArgumentParser((ScriptArguments, SFTConfig, ModelConfig))
    script_args, training_args, model_config = parser.parse()

    set_random_seed(script_args.manual_seed)

    # 1. Load the model and tokenizer.
    model = AutoModelForCausalLM.from_pretrained(
        model_config.model_name_or_path,
        attn_implementation="flash_attention_2",
        torch_dtype=torch.bfloat16,
    )
    model.config.use_cache = False  # disable the KV cache for training

    tokenizer = AutoTokenizer.from_pretrained(model_config.model_name_or_path)
    if script_args.use_eos_padding:
        tokenizer.pad_token = tokenizer.eos_token
    elif tokenizer.pad_token is None:
        tokenizer.add_special_tokens({"pad_token": "[PAD]"})
        model.config.pad_token_id = tokenizer.pad_token_id
        # resize_token_embeddings also updates config.vocab_size, so no manual
        # vocab_size bookkeeping is needed.
        model.resize_token_embeddings(len(tokenizer))

    # 2. Load the data produced under `rm_annotated_data`, i.e. rows with
    # chosen/rejected responses and their scores.
    train_ds = _load_split(script_args.train_path, "train")
    if script_args.split_eval_from_train:
        # Hold out the first rows (at most half of the data) for evaluation.
        n_eval = min(script_args.split_eval_from_train, int(0.5 * len(train_ds)))
        eval_ds = train_ds.select(range(n_eval))
        train_ds = train_ds.select(range(n_eval, len(train_ds)))
    else:
        eval_ds = None

    if script_args.max_training_samples is not None:
        train_ds = train_ds.select(
            range(min(script_args.max_training_samples, len(train_ds)))
        )
    if script_args.eval_path is not None:
        if script_args.split_eval_from_train:
            print(
                "[WARNING] `eval_path` will be ignored since split_eval_from_train is set."
            )
        else:
            eval_ds = _load_split(script_args.eval_path, "validation")

    train_ds = train_ds.shuffle(seed=script_args.manual_seed)
    if eval_ds:
        eval_ds = eval_ds.shuffle(seed=script_args.manual_seed)
    else:
        print("[WARNING] No evaluation dataset is provided.")

    # In trl >= 0.12 the safest dataset format provides only `chosen` and
    # `rejected` dialogs as List[Dict]. run_annotation.py is left unchanged for
    # compatibility with older versions, so the conversion happens here, and
    # each split is checked independently.
    # `os.cpu_count()` may return None; fall back to 1 in that case.
    num_proc = max(8, (os.cpu_count() or 1) // 8)
    train_ds = _to_dialog_format(train_ds, tokenizer, num_proc)
    eval_ds = _to_dialog_format(eval_ds, tokenizer, num_proc)

    # Loss is computed only on assistant turns. This assumes the tokenizer's
    # chat template emits Tulu-3 style `<|user|>` / `<|assistant|>` markers.
    instruction_template = "<|user|>"
    response_template = "<|assistant|>\n"
    collator = DataCollatorForCompletionOnlyLM(
        response_template=response_template,
        instruction_template=instruction_template,
        tokenizer=tokenizer,
    )

    # 3. Set up training arguments.
    print(script_args)
    print(training_args)

    if script_args.output_model_name is not None:
        output_name = script_args.output_model_name
    else:
        local_bs = (
            training_args.per_device_train_batch_size
            * training_args.gradient_accumulation_steps
        )
        output_name = (
            f"SFT_{model_config.model_name_or_path}"
            f"_lr{training_args.learning_rate}"
            f"_localbs{local_bs}"
            f"_ds{script_args.train_path.split('/')[-1]}"
            f"_seed{script_args.manual_seed}"
        )

    if training_args.output_dir is not None:
        Path(training_args.output_dir).mkdir(parents=True, exist_ok=True)

    training_args.output_dir = os.path.join(training_args.output_dir, output_name)
    print(f"Output directory: {training_args.output_dir}")

    training_args.run_name = output_name
    # No intermediate checkpoints; the final model is saved explicitly below.
    training_args.save_strategy = "no"

    formatter = partial(formatting_prompt_dataset, tokenizer=tokenizer)

    trainer = SFTTrainer(
        model=model,
        processing_class=tokenizer,
        args=training_args,
        train_dataset=train_ds,
        eval_dataset=eval_ds,
        formatting_func=formatter,
        data_collator=collator,
    )

    # Use the trainer's own process info instead of building a second
    # Accelerator just to check the rank.
    if trainer.is_world_process_zero():
        wandb.init(
            project="SoftDPO",
            name=output_name,
        )

    # 4. Kick off the training.
    trainer.train()
    trainer.save_model(training_args.output_dir)


if __name__ == "__main__":
    # Script entry point; command-line arguments are parsed inside main()
    # by H4ArgumentParser.
    main()
