from pathlib import Path

from accelerate import PartialState
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
from trl.extras.dataset_formatting import conversations_formatting_function

from datasets import Dataset, DatasetDict, load_dataset, load_from_disk

from .utils import DatasetVersion

# Base directory names for the processed datasets that this module saves to and
# loads from disk. The Unified-Feedback loaders below append the tokenizer name
# and option suffixes to these names to form the final directory.

# Reward-modelling variants of llm-blender/Unified-Feedback ("full", "400k", "40k").
DATASET_UF_REW_BASE_DIR = "datasets/llm-blender_Unified-Feedback_full_processed"
DATASET_UF_REW_BASE_DIR_400K = "datasets/llm-blender_Unified-Feedback_400k_processed"
DATASET_UF_REW_BASE_DIR_40K = "datasets/llm-blender_Unified-Feedback_40k_processed"


# PPO (prompt-only) variants of llm-blender/Unified-Feedback ("full", "20k", "5k").
DATASET_UF_PPO_BASE_DIR = "datasets/llm-blender_Unified-Feedback_ppo_full_processed"
DATASET_UF_PPO_BASE_DIR_20K = "datasets/llm-blender_Unified-Feedback_ppo_20k_processed"
DATASET_UF_PPO_BASE_DIR_5K = "datasets/llm-blender_Unified-Feedback_ppo_5k_processed"

# DPO variants of llm-blender/Unified-Feedback ("full", "20k", "400k", "5k").
DATASET_UF_DPO_BASE_DIR = "datasets/llm-blender_Unified-Feedback_dpo_full_processed"
DATASET_UF_DPO_BASE_DIR_20K = "datasets/llm-blender_Unified-Feedback_dpo_20k_processed"
DATASET_UF_DPO_BASE_DIR_400K = "datasets/llm-blender_Unified-Feedback_dpo_400k_processed"
DATASET_UF_DPO_BASE_DIR_5K = "datasets/llm-blender_Unified-Feedback_dpo_5k_processed"

# Base directory for the processed HuggingFaceH4/hhh_alignment dataset.
DATASET_HHH_BASE_DIR = "datasets/HuggingFaceH4_hhh_alignment_processed"


def local_dataset_exists(dataset_dir: str) -> bool:
    """Tell whether a saved dataset dict is present in ``dataset_dir``.

    The check looks for a regular file named ``dataset_dict.json`` directly
    inside the directory; this is the marker this module uses to decide
    between loading a processed dataset from disk and rebuilding it.

    Args:
        dataset_dir: Path of the directory to inspect.

    Returns:
        ``True`` if ``<dataset_dir>/dataset_dict.json`` is an existing file,
        ``False`` otherwise (including when the directory does not exist).
    """
    marker_file = Path(dataset_dir).joinpath("dataset_dict.json")
    return marker_file.is_file()


def binarize_with_margin(example):
    """Turn an A/B-rated conversation pair into a chosen/rejected pair.

    The conversation with the higher rating becomes ``chosen``; on a tie
    ``conv_A`` is kept as ``chosen``. For examples whose ``source`` contains
    ``"summarize"``, the first message of both conversations is prefixed with
    a summarisation instruction (following Yang et al. (2024)). That prefix is
    written into the message dicts in place, so the conversations held by
    ``example`` are modified as well.

    Args:
        example: Mapping with the keys ``conv_A``, ``conv_B`` (lists of
            message dicts with a ``"content"`` entry), ``conv_A_rating``,
            ``conv_B_rating`` and ``source``.

    Returns:
        Dict with ``chosen``, ``rejected``, ``score_chosen``,
        ``score_rejected`` and ``margin`` (absolute rating difference).
    """
    rating_a, rating_b = example["conv_A_rating"], example["conv_B_rating"]

    # Pick the winner; conv_A wins ties.
    if rating_b > rating_a:
        chosen, rejected = example["conv_B"], example["conv_A"]
        score_chosen, score_rejected = rating_b, rating_a
    else:
        chosen, rejected = example["conv_A"], example["conv_B"]
        score_chosen, score_rejected = rating_a, rating_b

    if "summarize" in example["source"]:  # Yang et al. (2024)
        instruction = "Generate one-sentence summary for the following post: "
        for conversation in (chosen, rejected):
            first_message = conversation[0]
            first_message["content"] = instruction + first_message["content"].strip()

    return {
        "chosen": chosen,
        "rejected": rejected,
        "score_chosen": score_chosen,
        "score_rejected": score_rejected,
        "margin": abs(rating_a - rating_b),
    }


def elicit_prompt(example):
    """Split chosen/rejected conversations into a shared prompt and responses.

    Every message except the last one of ``chosen`` forms the prompt; the last
    message of each conversation is its response. This matches how the rest of
    this module treats ``chosen[:-1]`` as the prompt messages. (The previous
    implementation had the slices inverted: it returned the final message as
    ``prompt`` and the context as ``chosen``/``rejected``.)

    Args:
        example: Mapping with ``chosen`` and ``rejected`` lists of messages.
            Both are expected to share the same leading prompt messages.

    Returns:
        Dict with ``chosen`` and ``rejected`` (each a list holding only the
        final message, or an empty list for an empty conversation) and
        ``prompt`` (the list of all preceding messages of ``chosen``).
    """
    chosen, rejected = example["chosen"], example["rejected"]

    # Slices (rather than indexing) keep every value a list of messages and
    # make empty conversations yield empty lists instead of raising.
    return {"chosen": chosen[-1:], "rejected": rejected[-1:], "prompt": chosen[:-1]}


def get_uf_rew_dataset(
    tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast,
    max_length: int,
    dataset_num_proc: int | None = None,
    version: DatasetVersion = "full",
    subset_to_remove: str = "",
    remove_columns: bool = True,
) -> DatasetDict:
    """Build, or load from disk, the tokenized Unified-Feedback reward dataset.

    The processed dataset is stored in a directory derived from ``version``,
    the tokenizer name, ``subset_to_remove`` and ``remove_columns``. If that
    directory already holds a saved dataset it is loaded and returned as is.
    Otherwise ``llm-blender/Unified-Feedback`` ("all") is loaded, tied pairs
    are dropped, the pairs are binarized into chosen/rejected with a margin,
    wrapped in the tokenizer's chat template, tokenized, filtered by length
    and saved.

    Args:
        tokenizer: Tokenizer used for chat formatting and tokenization; its
            ``name_or_path`` is part of the cache directory name.
        max_length: Pairs whose chosen or rejected sequence has more tokens
            than this are dropped.
        dataset_num_proc: Forwarded as ``num_proc`` to ``filter``/``map``.
        version: ``"full"`` keeps the whole train split, ``"400k"`` keeps
            every 2nd and ``"40k"`` every 20th train example.
        subset_to_remove: If non-empty, examples whose ``source`` ends with
            this string are removed.
        remove_columns: If true, the raw and intermediate columns are dropped.

    Returns:
        A ``DatasetDict`` with ``"train"`` and ``"val"`` splits in torch format.

    Raises:
        NotImplementedError: If ``version`` is not one of the supported values.
    """
    tokenizer_name = tokenizer.name_or_path.replace("/", "_")
    if version == "400k":
        dataset_dir = DATASET_UF_REW_BASE_DIR_400K
    elif version == "40k":
        dataset_dir = DATASET_UF_REW_BASE_DIR_40K
    elif version == "full":
        dataset_dir = DATASET_UF_REW_BASE_DIR
    else:
        raise NotImplementedError(f"Dataset version '{version}' not supported")

    # Every option that changes the stored content is encoded in the directory name.
    dataset_dir += f"_{tokenizer_name}"
    if subset_to_remove != "":
        dataset_dir += f"_no-{subset_to_remove}"
    if remove_columns:
        dataset_dir += "_extr"

    if local_dataset_exists(dataset_dir):
        print(f"Loading dataset '{dataset_dir}'.")
        return load_from_disk(dataset_dir)  # type: ignore

    dataset: DatasetDict | Dataset = load_dataset("llm-blender/Unified-Feedback", "all")  # type: ignore

    def preprocess_function(examples):
        """Tokenize a batch of already chat-formatted chosen/rejected texts."""
        new_examples = {
            "input_ids_chosen": [],
            "attention_mask_chosen": [],
            "input_ids_rejected": [],
            "attention_mask_rejected": [],
            "margin": [],
        }
        for chosen, rejected, margin in zip(examples["chosen"], examples["rejected"], examples["margin"]):
            tokenized_chosen = tokenizer(chosen)
            tokenized_rejected = tokenizer(rejected)
            new_examples["input_ids_chosen"].append(tokenized_chosen["input_ids"])
            new_examples["attention_mask_chosen"].append(tokenized_chosen["attention_mask"])
            new_examples["input_ids_rejected"].append(tokenized_rejected["input_ids"])
            new_examples["attention_mask_rejected"].append(tokenized_rejected["attention_mask"])
            new_examples["margin"].append(margin)

        return new_examples

    # Preprocess the dataset and filter out examples that are longer than max_length.
    # Compute that only on the main process for faster data processing.
    # see: https://github.com/huggingface/trl/pull/1255
    with PartialState().local_main_process_first():
        # Non-main processes only get here after the local main process has
        # left this block, i.e. after it has saved the dataset. Load that copy
        # instead of recomputing everything and writing the same directory
        # again from several processes at once.
        if local_dataset_exists(dataset_dir):
            print(f"Loading dataset '{dataset_dir}'.")
            return load_from_disk(dataset_dir)  # type: ignore

        # Pairs with identical ratings carry no preference signal.
        dataset = dataset.filter(lambda example: example["conv_A_rating"] != example["conv_B_rating"], batched=False, num_proc=dataset_num_proc)

        if subset_to_remove != "":
            print(f"Filtering out {subset_to_remove}...")
            dataset = dataset.filter(lambda example: not example["source"].endswith(subset_to_remove), batched=False, num_proc=dataset_num_proc)

        # Subsample the train split with a fixed stride; "val" is kept whole.
        if version == "400k":
            train_stride = 2  # Yang et al. (2024)
        elif version == "40k":
            train_stride = 20  # Yang et al. (2024)
        else:
            train_stride = 1
        selector = range(0, len(dataset["train"]), train_stride)
        dataset = DatasetDict({"train": dataset["train"].select(selector), "val": dataset["val"]})  # type: ignore

        dataset = dataset.map(binarize_with_margin, batched=False, num_proc=dataset_num_proc)

        # Wrap inputs with chat template.
        # This assumes the chosen/rejected columns are in the OpenAI messages format.
        chosen_fn = conversations_formatting_function(tokenizer, "chosen")  # type: ignore
        rejected_fn = conversations_formatting_function(tokenizer, "rejected")  # type: ignore
        # OURS:
        dataset = dataset.map(lambda x: {"chosen": chosen_fn(x), "rejected": rejected_fn(x), "margin": x["margin"]}, num_proc=dataset_num_proc)
        dataset = dataset.map(preprocess_function, batched=True, num_proc=dataset_num_proc)

        # # THEIRS:
        # def format_func2(example):
        #     chosen, rejected, margin = example["chosen"], example["rejected"], example["margin"]
        #     kwargs = {"return_tensors": "pt"}
        #     prompt_plus_chosen_response = tokenizer.apply_chat_template(chosen, tokenize=False)
        #     prompt_plus_rejected_response = tokenizer.apply_chat_template(rejected, tokenize=False)
        #     tokens_chosen = tokenizer.encode_plus(prompt_plus_chosen_response, **kwargs)
        #     tokens_rejected = tokenizer.encode_plus(prompt_plus_rejected_response, **kwargs)
        #     return {
        #         "input_ids_chosen": tokens_chosen["input_ids"][0],
        #         "attention_mask_chosen": tokens_chosen["attention_mask"][0],
        #         "input_ids_rejected": tokens_rejected["input_ids"][0],
        #         "attention_mask_rejected": tokens_rejected["attention_mask"][0],
        #         "margin": margin,
        #     }

        # dataset = dataset.map(format_func2, num_proc=dataset_num_proc)

        # Keep only pairs where both sequences fit into max_length tokens.
        dataset = dataset.filter(lambda x: len(x["input_ids_chosen"]) <= max_length and len(x["input_ids_rejected"]) <= max_length, num_proc=dataset_num_proc)

        if remove_columns:
            dataset = dataset.remove_columns(
                ["id", "conv_A", "conv_B", "conv_A_rating", "conv_B_rating", "num_turns", "source", "chosen", "rejected", "score_chosen", "score_rejected"]
            )

        dataset.set_format(type="torch")
        dataset.save_to_disk(dataset_dir)
        print(f"Saved dataset to '{dataset_dir}'.")
        return dataset


def get_uf_ppo_dataset(
    tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast,
    max_length: int,
    dataset_num_proc: int | None = None,
    version: DatasetVersion = "full",
    remove_columns: bool = True,
) -> DatasetDict:
    """Build, or load from disk, the prompt-only Unified-Feedback PPO dataset.

    The processed dataset is stored in a directory derived from ``version``,
    the tokenizer name and ``remove_columns``. If that directory already holds
    a saved dataset it is loaded and returned as is. Otherwise
    ``llm-blender/Unified-Feedback`` ("all") is loaded and shuffled with seed
    42, optionally subsampled, binarized, and the prompt part of each chosen
    conversation (all messages but the last) is chat-formatted and tokenized
    with padding/truncation to ``max_length``.

    Args:
        tokenizer: Tokenizer used for chat formatting and tokenization; its
            ``name_or_path`` is part of the cache directory name.
        max_length: Length the tokenized prompts are padded/truncated to.
        dataset_num_proc: Forwarded as ``num_proc`` to ``map``/``filter``.
        version: ``"full"`` keeps the original splits; ``"5k"`` / ``"20k"``
            take the first 5000 / 20000 shuffled train examples as train and
            train examples 20000-20999 as validation.
        remove_columns: If true, only the columns produced by the
            tokenization step (``margin``, ``input_ids``, ``attention_mask``)
            are kept.

    Returns:
        A ``DatasetDict`` with ``"train"`` and ``"val"`` splits in torch format.

    Raises:
        NotImplementedError: If ``version`` is not one of the supported values.
    """
    tokenizer_name = tokenizer.name_or_path.replace("/", "_")

    if version == "5k":
        dataset_dir = DATASET_UF_PPO_BASE_DIR_5K
    elif version == "20k":
        dataset_dir = DATASET_UF_PPO_BASE_DIR_20K
    elif version == "full":
        dataset_dir = DATASET_UF_PPO_BASE_DIR
    else:
        raise NotImplementedError(f"Dataset version '{version}' not supported")

    # Every option that changes the stored content is encoded in the directory name.
    dataset_dir += f"_{tokenizer_name}"
    if remove_columns:
        dataset_dir += "_extr"

    if local_dataset_exists(dataset_dir):
        print(f"Loading dataset '{dataset_dir}'.")
        return load_from_disk(dataset_dir)  # type: ignore

    dataset: DatasetDict = load_dataset("llm-blender/Unified-Feedback", "all")  # type: ignore
    dataset = dataset.shuffle(seed=42)  # following Yang et al. (2024)

    kwargs = {"padding": "max_length", "truncation": True, "max_length": max_length, "return_tensors": "pt"}

    def preprocess_function(examples):
        """Chat-format and tokenize the prompt part of each chosen conversation."""
        new_examples = {
            "margin": [],
            "input_ids": [],
            "attention_mask": [],
        }
        for chosen, margin in zip(examples["chosen"], examples["margin"]):
            # assert chosen[:-1] == rejected[:-1]  # should have same prompt for both
            prompt_messages = chosen[:-1]
            prompt_formatted: str = tokenizer.apply_chat_template(prompt_messages, tokenize=False)  # type: ignore
            tokenized_prompt_only = tokenizer(prompt_formatted, **kwargs)
            # return_tensors="pt" adds a leading batch dimension of size 1; drop it.
            new_examples["input_ids"].append(tokenized_prompt_only["input_ids"].squeeze(0))  # type: ignore
            new_examples["attention_mask"].append(tokenized_prompt_only["attention_mask"].squeeze(0))  # type: ignore
            new_examples["margin"].append(margin)

        return new_examples

    with PartialState().local_main_process_first():
        # Non-main processes only get here after the local main process has
        # left this block, i.e. after it has saved the dataset. Load that copy
        # instead of recomputing everything and writing the same directory
        # again from several processes at once.
        if local_dataset_exists(dataset_dir):
            print(f"Loading dataset '{dataset_dir}'.")
            return load_from_disk(dataset_dir)  # type: ignore

        if version in ("5k", "20k"):
            train_size = 5000 if version == "5k" else 20000
            dataset = DatasetDict(
                {
                    "train": dataset["train"].select(range(train_size)),
                    # odd choice (validation is taken from train), but following Yang et al. (2024)
                    "val": dataset["train"].select(range(20000, 21000)),
                }
            )

        # binarize_with_margin adds the summarize prompt
        dataset = dataset.map(binarize_with_margin, batched=False, num_proc=dataset_num_proc)
        dataset = dataset.map(preprocess_function, batched=True, num_proc=dataset_num_proc, remove_columns=dataset.column_names["train"] if remove_columns else None)
        # NOTE(review): prompts are padded/truncated to max_length above, so this
        # filter presumably keeps every row - confirm before removing it.
        dataset = dataset.filter(lambda x: len(x["input_ids"]) <= max_length, num_proc=dataset_num_proc)
        dataset.set_format(type="torch")
        dataset.save_to_disk(dataset_dir)
        print(f"Saved dataset to '{dataset_dir}'.")
        return dataset


def get_uf_dpo_dataset(
    tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast,
    max_length: int,
    dataset_num_proc: int | None = None,
    version: DatasetVersion = "full",
    subset_to_remove: str = "",
    remove_columns: bool = True,
) -> DatasetDict:
    """Build, or load from disk, the Unified-Feedback DPO preference dataset.

    The processed dataset is stored in a directory derived from ``version``,
    the tokenizer name and ``subset_to_remove``. If that directory already
    holds a saved dataset it is loaded and returned as is. Otherwise
    ``llm-blender/Unified-Feedback`` ("all") is loaded and shuffled with seed
    42, tied pairs are dropped, the splits are subsampled according to
    ``version`` and the pairs are binarized into chosen/rejected
    conversations. No tokenization happens here.

    Args:
        tokenizer: Only its ``name_or_path`` is used (for the directory name).
        max_length: Currently unused by this function.
        dataset_num_proc: Forwarded as ``num_proc`` to ``filter``/``map``.
        version: ``"full"`` keeps both splits whole; ``"400k"`` keeps every
            2nd train example; ``"5k"`` / ``"20k"`` take the first 5000 /
            20000 train examples as train and train examples 20000-20999 as
            validation.
        subset_to_remove: If non-empty, examples whose ``source`` ends with
            this string are removed.
        remove_columns: Currently ignored; column removal is always applied.

    Returns:
        A ``DatasetDict`` with ``"train"`` and ``"val"`` splits in torch format.

    Raises:
        NotImplementedError: If ``version`` is not one of the supported values.
    """
    remove_columns = True  # force for now

    tokenizer_name = tokenizer.name_or_path.replace("/", "_")
    if version == "5k":
        dataset_dir = DATASET_UF_DPO_BASE_DIR_5K
    elif version == "20k":
        dataset_dir = DATASET_UF_DPO_BASE_DIR_20K
    elif version == "400k":
        dataset_dir = DATASET_UF_DPO_BASE_DIR_400K
    elif version == "full":
        dataset_dir = DATASET_UF_DPO_BASE_DIR
    else:
        raise NotImplementedError(f"Dataset version '{version}' not supported")

    # Every option that changes the stored content is encoded in the directory name.
    dataset_dir += f"_{tokenizer_name}"
    if subset_to_remove != "":
        dataset_dir += f"_no-{subset_to_remove}"
    if remove_columns:
        dataset_dir += "_extr"

    if local_dataset_exists(dataset_dir):
        print(f"Loading dataset '{dataset_dir}'.")
        return load_from_disk(dataset_dir)  # type: ignore

    dataset: DatasetDict = load_dataset("llm-blender/Unified-Feedback", "all")  # type: ignore
    dataset = dataset.shuffle(seed=42)  # following Yang et al. (2024)

    # kwargs = {"padding": "max_length", "truncation": True, "max_length": max_length, "return_tensors": "pt"}

    # def preprocess_function(examples):
    #     new_examples = {"prompt": [], "chosen": [], "rejected": []}

    #     for chosen, rejected in zip(examples["chosen"], examples["rejected"]):
    #         prompt_messages = chosen[:-1]
    #         chosen_trunc = chosen[-1]
    #         rejected_trunc = rejected[-1]

    #         new_examples["prompt"].append(prompt_messages)
    #         new_examples["chosen"].append(chosen_trunc)
    #         new_examples["rejected"].append(rejected_trunc)

    #     return new_examples

    with PartialState().local_main_process_first():
        # Non-main processes only get here after the local main process has
        # left this block, i.e. after it has saved the dataset. Load that copy
        # instead of recomputing everything and writing the same directory
        # again from several processes at once.
        if local_dataset_exists(dataset_dir):
            print(f"Loading dataset '{dataset_dir}'.")
            return load_from_disk(dataset_dir)  # type: ignore

        # Pairs with identical ratings carry no preference signal.
        dataset = dataset.filter(lambda example: example["conv_A_rating"] != example["conv_B_rating"], batched=False, num_proc=dataset_num_proc)

        if subset_to_remove != "":
            print(f"Filtering out {subset_to_remove}...")
            dataset = dataset.filter(lambda example: not example["source"].endswith(subset_to_remove), batched=False, num_proc=dataset_num_proc)

        if version in ["5k", "20k"]:
            # Both splits are carved out of the shuffled train split, following Yang et al. (2024).
            train_size = 5000 if version == "5k" else 20000
            train_split = dataset["train"].select(range(train_size))
            val_split = dataset["train"].select(range(20000, 21000))
        else:
            # "400k" keeps every 2nd train example, "full" keeps all of them.
            train_stride = 2 if version == "400k" else 1
            train_split = dataset["train"].select(range(0, len(dataset["train"]), train_stride))
            val_split = dataset["val"].select(range(0, len(dataset["val"])))
        dataset = DatasetDict({"train": train_split, "val": val_split})  # type: ignore

        # binarize_with_margin adds the summarize prompt
        dataset = dataset.map(binarize_with_margin, batched=False, num_proc=dataset_num_proc)

        # chosen_fn = conversations_formatting_function(tokenizer, "chosen")  # type: ignore
        # rejected_fn = conversations_formatting_function(tokenizer, "rejected")  # type: ignore

        # dataset = dataset.map(lambda x: {"chosen": chosen_fn(x), "rejected": rejected_fn(x)}, num_proc=dataset_num_proc)
        # dataset = dataset.map(preprocess_function, batched=True, num_proc=dataset_num_proc)

        if remove_columns:
            dataset = dataset.remove_columns(["id", "conv_A", "conv_B", "conv_A_rating", "conv_B_rating", "num_turns", "source", "margin"])

        dataset.set_format(type="torch")
        dataset.save_to_disk(dataset_dir)
        print(f"Saved dataset to '{dataset_dir}'.")
        return dataset


def get_hhh_dataset(tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast, dataset_num_proc: int | None = None, remove_columns: bool = True) -> DatasetDict:
    """Build, or load from disk, the tokenized HHH alignment preference dataset.

    Loads the ``test`` split of the ``harmless``, ``helpful``, ``honest`` and
    ``other`` subsets of ``HuggingFaceH4/hhh_alignment``, converts every
    example into a chosen/rejected conversation pair, applies the tokenizer's
    chat template and tokenizes both sides.

    The cache directory includes the tokenizer name and a suffix for
    ``remove_columns``, because the stored data depends on both. Without this,
    a dataset processed with one tokenizer would silently be returned for
    another. Datasets cached under the bare base directory by earlier versions
    are therefore not reused.

    Args:
        tokenizer: Tokenizer used for chat formatting and tokenization; its
            ``name_or_path`` is part of the cache directory name.
        dataset_num_proc: Forwarded as ``num_proc`` to ``map``.
        remove_columns: If true, the raw ``input`` and ``targets`` columns are
            dropped.

    Returns:
        A ``DatasetDict`` with one split per subset, in torch format, whose
        ``chosen`` / ``rejected`` columns hold the tokenizer output.

    Raises:
        ValueError: If an example does not have exactly two answer choices.
    """
    tokenizer_name = tokenizer.name_or_path.replace("/", "_")
    dataset_dir = f"{DATASET_HHH_BASE_DIR}_{tokenizer_name}"
    if remove_columns:
        dataset_dir += "_extr"

    if local_dataset_exists(dataset_dir):
        print(f"Loading dataset '{dataset_dir}'.")
        return load_from_disk(dataset_dir)  # type: ignore

    raw_datasets = {}
    for subset in ["harmless", "helpful", "honest", "other"]:
        raw_datasets[subset] = load_dataset("HuggingFaceH4/hhh_alignment", subset)["test"]  # type: ignore
    dataset = DatasetDict(raw_datasets)

    def to_conversation_format(examples):
        """Turn a batch of (input, two labelled choices) into message lists."""
        new_examples = {"chosen": [], "rejected": []}

        for prompt_text, targets_data in zip(examples["input"], examples["targets"]):
            choices = targets_data["choices"]
            # Explicit check instead of `assert`, which is skipped under `python -O`.
            if len(choices) != 2:
                raise ValueError(f"Expected exactly 2 choices per example, got {len(choices)}")
            chosen, rejected = choices
            # The first choice is taken as chosen unless the labels are exactly [0, 1].
            if targets_data["labels"][0] == 0 and targets_data["labels"][1] == 1:
                chosen, rejected = rejected, chosen
            new_examples["chosen"].append([{"content": prompt_text, "role": "user"}, {"content": chosen, "role": "assistant"}])
            new_examples["rejected"].append([{"content": prompt_text, "role": "user"}, {"content": rejected, "role": "assistant"}])

        return new_examples

    dataset = dataset.map(to_conversation_format, num_proc=dataset_num_proc, batched=True)

    # Wrap both conversations with the tokenizer's chat template.
    chosen_fn = conversations_formatting_function(tokenizer, "chosen")  # type: ignore
    rejected_fn = conversations_formatting_function(tokenizer, "rejected")  # type: ignore
    dataset = dataset.map(lambda x: {"chosen": chosen_fn(x), "rejected": rejected_fn(x)}, num_proc=dataset_num_proc)

    # Replace the formatted strings with the tokenizer output.
    dataset = dataset.map(lambda x: {"chosen": tokenizer(x["chosen"]), "rejected": tokenizer(x["rejected"])}, num_proc=dataset_num_proc)

    if remove_columns:
        dataset = dataset.remove_columns(["input", "targets"])

    dataset.set_format(type="torch")
    dataset.save_to_disk(dataset_dir)
    print(f"Saved dataset to '{dataset_dir}'.")

    return dataset
