import re
from functools import partial
import numpy as np
from collator import MaskedDataCollatorForLM

# Header text prepended to the question when format_dataset builds "prompt"
# and "text".
INSTRUCTION_TEMPLATE = "### Human:\n"
# Header text prepended to every answer choice when format_dataset builds
# "response", "response_list", "text" and "text_list".
RESPONSE_TEMPLATE = "### Response:\n"

def dict_of_lists_to_list_of_dicts(d):
    """Turn a column-oriented mapping into a list of row dicts.

    ``d`` maps each key to a sequence of values; row ``i`` of the result maps
    each key to the ``i``-th value of its sequence. Rows are produced with
    ``zip``, so the result is as long as the shortest sequence and an empty
    mapping yields an empty list.
    """
    rows = []
    for row_values in zip(*d.values()):
        rows.append(dict(zip(d.keys(), row_values)))
    return rows


def list_of_dicts_to_dict_of_lists(l):  # noqa: E741 - name kept for callers
    """Turn a list of row dicts into a column-oriented dict of lists.

    The keys of the first row define the columns; every row must provide
    each of those keys (a missing key raises ``KeyError``). An empty input
    has no first row to take keys from, so it yields an empty dict instead
    of raising ``IndexError``.
    """
    if len(l) == 0:
        return {}
    return {key: [row[key] for row in l] for key in l[0]}


def flatten_nested_list(nested_list):
    """Flatten one level of nesting and record each sublist's length.

    Returns a ``(flat_items, lengths)`` tuple: ``flat_items`` holds the items
    of every sublist in order, and ``lengths[i]`` is the length of the
    ``i``-th sublist.
    """
    flat_items = []
    for sublist in nested_list:
        flat_items.extend(sublist)
    # Second pass over the input, so it must be re-iterable.
    sublist_lengths = list(map(len, nested_list))
    return flat_items, sublist_lengths


def pack_nested_list(flattened_list, lengths):
    """Split a flat sequence back into consecutive sublists.

    ``lengths[i]`` is the size of the ``i``-th sublist, which makes this the
    inverse of ``flatten_nested_list``. The split is done with ``np.split``,
    so numeric elements come back as NumPy scalars rather than Python
    numbers.

    With no lengths there are no sublists, so an empty list is returned.
    Without this guard ``np.split`` would hand back a single empty chunk.
    """
    if len(lengths) == 0:
        return []
    # Cumulative sizes are the boundaries; the last one is the total length
    # and would only add a trailing empty chunk.
    split_points = np.cumsum(lengths)[:-1]
    return [list(chunk) for chunk in np.split(flattened_list, split_points)]


def sciq_formatter_func(example):
    """Convert a SciQ record into the shared question/choices/target layout.

    The question is the left-stripped ``support`` passage, a newline, and the
    original question. The three distractors come first and the correct
    answer last, so ``target`` is always index 3.
    """
    passage = example["support"].lstrip()
    question = "{}\n{}".format(passage, example["question"])
    choice_keys = ("distractor1", "distractor2", "distractor3", "correct_answer")
    choices = [example[key] for key in choice_keys]
    return {"question": question, "choices": choices, "target": 3}


def hellaswag_formatter_func(example):
    """Convert a HellaSwag record into the shared question/choices/target layout.

    The question is ``"<activity_label>: <ctx_a> <ctx_b capitalized>"`` and
    the choices are the candidate endings; both go through the same text
    clean-up. ``target`` is the ``label`` field converted to ``int``.
    """

    def clean(text):
        # " [title]" markers become sentence breaks, then the ends are trimmed.
        text = text.replace(" [title]", ". ").strip()
        # Drop any remaining bracketed tag.
        text = re.sub("\\[.*?\\]", "", text)
        # Single pass: each pair of spaces becomes one space.
        return text.replace("  ", " ")

    context = example["ctx_a"] + " " + example["ctx_b"].capitalize()
    question = clean(example["activity_label"] + ": " + context)
    endings = [clean(ending) for ending in example["endings"]]
    return {
        "question": question,
        "choices": endings,
        "target": int(example["label"]),
    }


def winogrande_formatter_func(example):
    """Convert a WinoGrande record into the shared question/choices/target layout.

    The sentence is cut at its first ``"_"`` blank: the text before it is the
    question, and each choice is an option followed by the text after the
    blank. The ``answer`` field ``"1"``/``"2"`` maps to ``target`` 0/1.

    Raises ``ValueError`` if the sentence has no ``"_"`` and ``KeyError`` if
    ``answer`` is neither ``"1"`` nor ``"2"``.
    """
    sentence = example["sentence"]
    blank_position = sentence.index("_")
    before_blank = sentence[:blank_position]
    after_blank = sentence[blank_position + 1 :]
    choices = [example[key] + after_blank for key in ("option1", "option2")]
    answer_to_target = {"1": 0, "2": 1}
    return {
        "question": before_blank,
        "choices": choices,
        "target": answer_to_target[example["answer"]],
    }


def arc_formatter_func(example):
    """Convert an ARC record into the shared question/choices/target layout.

    The question is used as-is, the choices are ``choices["text"]``, and
    ``target`` is the position of ``answerKey`` within ``choices["label"]``
    (``ValueError`` if the key is not among the labels).
    """
    question = "{}".format(example["question"])
    options = example["choices"]
    option_texts = options["text"]
    answer_position = options["label"].index(example["answerKey"])
    return {
        "question": question,
        "choices": option_texts,
        "target": answer_position,
    }


def quartz_formatter_func(example):
    """Convert a QuaRTz record into the shared question/choices/target layout.

    The question is the ``para`` passage, a newline, and the original
    question. The choices are ``choices["text"]`` and ``target`` is the
    position of ``answerKey`` within ``choices["label"]`` (``ValueError`` if
    the key is not among the labels).
    """
    question = "{}\n{}".format(example["para"], example["question"])
    options = example["choices"]
    option_texts = options["text"]
    answer_position = options["label"].index(example["answerKey"])
    return {
        "question": question,
        "choices": option_texts,
        "target": answer_position,
    }


def quail_formatter_func(example):
    """Convert a QuAIL record into the shared question/choices/target layout.

    The question is the ``context`` passage, a newline, and the original
    question. ``answers`` becomes the choices and ``correct_answer_id`` is
    passed through unchanged as ``target``.
    """
    question = "{}\n{}".format(example["context"], example["question"])
    formatted = {"question": question}
    formatted["choices"] = example["answers"]
    formatted["target"] = example["correct_answer_id"]
    return formatted


def format_dataset(dataset, dataset_name, num_proc):
    """Format a raw dataset into prompt/response text columns.

    Runs two ``dataset.map`` passes. The first applies the formatter selected
    by ``dataset_name`` to produce ``question``, ``choices`` and ``target``.
    The second adds:

    * ``prompt``: instruction header + question
    * ``response``: response header + the target choice
    * ``text``: prompt, a newline, then the response
    * ``response_list`` / ``text_list``: the same strings for every choice

    Args:
        dataset: Object exposing a ``map`` method; presumably a Hugging Face
            ``datasets.Dataset`` (confirm against the caller).
        dataset_name: One of ``"sciq"``, ``"hellaswag"``, ``"winogrande"``,
            ``"arc"``, ``"quartz"``, ``"quail"``; anything else raises
            ``KeyError``.
        num_proc: Forwarded to both ``map`` calls.

    Returns:
        The result of the second ``map`` call.
    """
    formatters = {
        "sciq": sciq_formatter_func,
        "hellaswag": hellaswag_formatter_func,
        "winogrande": winogrande_formatter_func,
        "arc": arc_formatter_func,
        "quartz": quartz_formatter_func,
        "quail": quail_formatter_func,
    }
    formatter_func = formatters[dataset_name]

    # Otherwise will be misused by Trainers
    columns_to_drop = "label" if dataset_name == "hellaswag" else None

    def add_prompt_fields(example):
        # Build each distinct string once, then pick out the target's entries.
        prompt = INSTRUCTION_TEMPLATE + example["question"]
        responses = [RESPONSE_TEMPLATE + choice for choice in example["choices"]]
        texts = [prompt + "\n" + response for response in responses]
        target = example["target"]
        return {
            "prompt": prompt,
            "response": responses[target],
            "text": texts[target],
            "response_list": responses,
            "text_list": texts,
        }

    formatted = dataset.map(
        formatter_func,
        num_proc=num_proc,
        remove_columns=columns_to_drop,
        desc="Formatting dataset",
    )
    return formatted.map(
        add_prompt_fields,
        num_proc=num_proc,
        desc="Formatting dataset",
    )


def tokenize_func(examples, tokenizer, is_weight_by_token, is_completion_only):
    """Tokenize every candidate text of a batch and regroup the results per example.

    All strings in ``examples["text_list"]`` are flattened, tokenized in one
    call, passed through ``MaskedDataCollatorForLM`` and then split back into
    per-example groups. For each example, the entry selected by its
    ``target`` is exposed under the plain keys and the whole group under the
    ``*_list`` keys.

    Args:
        examples: Batch dict with ``"text_list"`` (one list of strings per
            example) and ``"target"`` (one index per example).
        tokenizer: Callable applied to the flat list of strings; also handed
            to the collator.
        is_weight_by_token: Forwarded to ``MaskedDataCollatorForLM``.
        is_completion_only: Forwarded to the collator. When truthy,
            ``sft_labels`` is sliced from the end index instead of the start
            index.

    Returns:
        Dict of lists with keys ``input_ids``, ``attention_mask``,
        ``start_index``, ``end_index``, ``sft_ids``, ``sft_labels``,
        ``sft_options`` and the matching ``*_list`` variants.

    NOTE(review): the start/end indices come from ``collator.get_indices``;
    presumably they mark where the instruction and the response begin.
    Confirm against the collator implementation.
    """
    flat_texts, group_sizes = flatten_nested_list(examples["text_list"])
    encoded_rows = dict_of_lists_to_list_of_dicts(tokenizer(flat_texts))
    collator = MaskedDataCollatorForLM(
        tokenizer,
        is_weight_by_token,
        is_completion_only,
    )
    collated = collator.super_touch_call(encoded_rows)
    flat_starts, flat_ends = collator.get_indices(collated)
    flat_samples = dict_of_lists_to_list_of_dicts(collated)

    # Regroup the flat per-text results into one group per example.
    grouped_starts = pack_nested_list(flat_starts, group_sizes)
    grouped_ends = pack_nested_list(flat_ends, group_sizes)
    grouped_samples = pack_nested_list(flat_samples, group_sizes)

    updates = []
    for samples, starts, ends, target in zip(
        grouped_samples, grouped_starts, grouped_ends, examples["target"]
    ):
        chosen = samples[target]
        chosen_start = starts[target]
        chosen_end = ends[target]
        label_offset = chosen_end if is_completion_only else chosen_start
        update = {
            "input_ids": chosen["input_ids"],
            "attention_mask": chosen["attention_mask"],
            "start_index": chosen_start,
            "end_index": chosen_end,
            "sft_ids": chosen["labels"][chosen_start:],
            "sft_labels": chosen["labels"][label_offset:],
            "sft_options": chosen["labels"][chosen_end:],
            "input_ids_list": [sample["input_ids"] for sample in samples],
            "attention_mask_list": [sample["attention_mask"] for sample in samples],
            "start_index_list": starts,
            "end_index_list": ends,
            "sft_ids_list": [
                sample["labels"][start:] for sample, start in zip(samples, starts)
            ],
            "sft_labels_list": [
                sample["labels"][(end if is_completion_only else start) :]
                for sample, start, end in zip(samples, starts, ends)
            ],
            "sft_options_list": [
                sample["labels"][end:] for sample, end in zip(samples, ends)
            ],
        }
        updates.append(update)
    return list_of_dicts_to_dict_of_lists(updates)


def tokenize_dataset(
    dataset, tokenizer, is_weight_by_token, is_completion_only, batch_size
):
    """Apply ``tokenize_func`` to ``dataset`` in batches.

    Args:
        dataset: Object exposing a ``map`` method.
        tokenizer: Bound into ``tokenize_func``.
        is_weight_by_token: Bound into ``tokenize_func``.
        is_completion_only: Bound into ``tokenize_func``.
        batch_size: Number of examples passed to each ``tokenize_func`` call.

    Returns:
        The result of ``dataset.map``.
    """
    batch_tokenizer = partial(
        tokenize_func,
        tokenizer=tokenizer,
        is_weight_by_token=is_weight_by_token,
        is_completion_only=is_completion_only,
    )
    map_options = {
        "batched": True,
        "batch_size": batch_size,
        # Important, do not change
        "num_proc": 1,
        "desc": "Tokenizing dataset",
    }
    return dataset.map(batch_tokenizer, **map_options)


def init_adaboost_dataset(dataset, is_weight_by_token, num_proc):
    """Add the initial, uniform ``adaboost_weight_0`` column to ``dataset``.

    With ``is_weight_by_token`` false, every example gets the scalar weight
    ``1 / len(dataset)`` via ``dataset.add_column``. Otherwise each example
    gets a list with one weight per entry of its ``sft_labels``, each equal
    to ``1 / len(dataset) / len(sft_labels)``, via ``dataset.map``.

    Edge cases: an empty dataset no longer triggers a division by zero, and
    an example with empty ``sft_labels`` gets an empty weight list.

    Args:
        dataset: Object exposing ``__len__``, ``add_column`` and ``map``.
        is_weight_by_token: Selects per-token lists over per-example scalars.
        num_proc: Forwarded to ``dataset.map`` (per-token branch only).

    Returns:
        The dataset returned by ``add_column`` or ``map``.
    """
    # Read the size once. The mapped function then closes over a float
    # rather than holding a reference to the whole dataset.
    num_examples = len(dataset)
    example_weight = 1 / num_examples if num_examples else 0.0

    if not is_weight_by_token:
        return dataset.add_column(
            "adaboost_weight_0", [example_weight] * num_examples
        )

    def token_weights(example):
        num_tokens = len(example["sft_labels"])
        if num_tokens == 0:
            # No labels means no weights, and avoids dividing by zero.
            return {"adaboost_weight_0": []}
        return {"adaboost_weight_0": [example_weight / num_tokens] * num_tokens}

    return dataset.map(
        token_weights,
        num_proc=num_proc,
        desc="Initializing adaboost weights",
    )


def preprocess_dataset(dataset, tokenizer, sargs, is_train=True):
    """Run the full preprocessing pipeline on ``dataset``.

    Formats the dataset, tokenizes it, and, for training data only, adds the
    initial AdaBoost weight column.

    Args:
        dataset: Raw dataset handed to ``format_dataset``.
        tokenizer: Tokenizer handed to ``tokenize_dataset``.
        sargs: Settings object providing ``dataset_name``, ``num_proc``,
            ``is_weight_by_token``, ``is_completion_only`` and
            ``pred_batch_size``.
        is_train: When false, the AdaBoost weight step is skipped.

    Returns:
        The preprocessed dataset.
    """
    formatted = format_dataset(dataset, sargs.dataset_name, sargs.num_proc)
    tokenized = tokenize_dataset(
        formatted,
        tokenizer,
        is_weight_by_token=sargs.is_weight_by_token,
        is_completion_only=sargs.is_completion_only,
        batch_size=sargs.pred_batch_size,
    )
    if not is_train:
        return tokenized
    return init_adaboost_dataset(
        tokenized, sargs.is_weight_by_token, sargs.num_proc
    )
