import os

# Switch Hugging Face Datasets to offline mode. This assignment is placed ahead
# of the `datasets` import further down.
# NOTE(review): presumably the library reads the flag when it is imported, so
# the ordering matters — confirm before moving this line.
os.environ["HF_DATASETS_OFFLINE"] = "1"

import typing as tp
from dataclasses import dataclass
from typing import Optional, Union

import torch
import torch.utils.data
import transformers
from datasets import Dataset, load_dataset
from transformers import PreTrainedTokenizerBase
from transformers.file_utils import PaddingStrategy


@dataclass
class DataCollatorForMultipleChoice:
    """
    Data collator that dynamically pads multiple-choice inputs.

    Every feature passed to the collator is a mapping whose non-label entries
    hold one sequence per answer choice (indexed ``value[choice]``), plus a
    ``"label"`` or ``"labels"`` entry. The per-choice sequences of the whole
    batch are padded together with ``tokenizer.pad`` and reshaped to
    ``(batch_size, num_choices, -1)``.
    """

    # Tokenizer whose ``pad`` method performs the padding.
    tokenizer: PreTrainedTokenizerBase
    # Forwarded unchanged to ``tokenizer.pad(padding=...)``.
    padding: Union[bool, str, PaddingStrategy] = True
    # Forwarded unchanged to ``tokenizer.pad(max_length=...)``.
    max_length: Optional[int] = None
    # Forwarded unchanged to ``tokenizer.pad(pad_to_multiple_of=...)``.
    pad_to_multiple_of: Optional[int] = None

    def __call__(self, features):
        """Collate a list of multiple-choice features into one padded batch.

        Args:
            features: Non-empty list of mappings. The label key (``"label"``
                if the first feature has it, otherwise ``"labels"``) is read
                from every feature; all remaining entries are indexed per
                choice. The mappings are not modified.

        Returns:
            Dict of tensors viewed as ``(batch_size, num_choices, -1)`` plus a
            ``"labels"`` int64 tensor built from the per-feature labels.
        """
        label_name = "label" if "label" in features[0] else "labels"
        # Read the labels instead of popping them so the caller's feature
        # dicts are left untouched.
        labels = [feature[label_name] for feature in features]
        batch_size = len(features)
        num_choices = len(features[0]["input_ids"])

        # One flat entry per (example, choice) pair, in example-major order.
        # A single comprehension avoids the quadratic `sum(lists, [])` idiom.
        flattened_features = [
            {k: v[i] for k, v in feature.items() if k != label_name}
            for feature in features
            for i in range(num_choices)
        ]

        batch = self.tokenizer.pad(
            flattened_features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="pt",
        )

        # Restore the (example, choice) structure after flat padding.
        batch = {k: v.view(batch_size, num_choices, -1) for k, v in batch.items()}
        batch["labels"] = torch.tensor(labels, dtype=torch.int64)
        return batch


# Type alias: an optional mapping of str -> (mapping of str -> list of str).
# Not referenced in the visible part of this file.
config_type = tp.Optional[tp.Mapping[str, tp.Mapping[str, tp.List[str]]]]

# Split name -> CSV file name. Not referenced in the visible part of this file.
_MODE_FILE = {"train": "train.csv", "validation": "valid.csv", "test": "test.csv"}

# Task name -> names of the (first, second) text columns that `_get_input_ids`
# pulls out of a batch and hands to the tokenizer. `None` in the second slot
# marks a single-sentence task; `_get_input_ids` skips `None` entries.
_TASK_TO_KEYS = {
    "cola": ("sentence", None),
    "mnli": ("premise", "hypothesis"),
    "mrpc": ("sentence1", "sentence2"),
    "qnli": ("question", "sentence"),
    "qqp": ("question1", "question2"),
    "rte": ("sentence1", "sentence2"),
    "sst2": ("sentence", None),
    "stsb": ("sentence1", "sentence2"),
    "wnli": ("sentence1", "sentence2"),
    "boolq": ("passage", "question"),
    "cb": ("premise", "hypothesis"),
}


def get_dataloaders(
    dataset_name,
    batch_size,
    max_seq_len,
    tokenizer_name,
    seed,
    task_type,
    task="glue",
    cache_dir=None,
    eval_batch_size=256,
):
    """Build data loaders for every non-test split of a dataset.

    Args:
        dataset_name: Dataset configuration name passed to `_get_datasets`
            (e.g. ``"copa"`` and ``"record"`` get dedicated handling).
        batch_size: Batch size for splits whose name contains ``"train"``.
        max_seq_len: Maximum sequence length used during tokenization.
        tokenizer_name: Name or path given to ``AutoTokenizer.from_pretrained``.
        seed: Seed forwarded to the dataset preprocessing.
        task_type: ``"regression"`` yields ``num_labels == 1``; any other
            value derives the label count from the dataset.
        task: Dataset path passed on to `_get_datasets`.
        cache_dir: Cache directory passed on to `_get_datasets`.
        eval_batch_size: Batch size for every loader whose split name does
            not contain ``"train"``. Defaults to 256, the value that was
            previously hard-coded.

    Returns:
        Tuple ``(dataloaders, num_labels)`` where ``dataloaders`` maps split
        name to a ``torch.utils.data.DataLoader``.
    """
    tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer_name)
    datasets = _get_datasets(
        dataset_name, max_seq_len, tokenizer, seed, task, cache_dir
    )

    # Only COPA needs the multiple-choice collator; everything else uses the
    # DataLoader default. The collator is stateless, so one instance is shared.
    collate_fn = (
        DataCollatorForMultipleChoice(tokenizer, max_length=max_seq_len)
        if dataset_name == "copa"
        else None
    )

    dataloaders = {}
    for split_name, dataset in datasets.items():
        # Splits whose name contains "test" get no loader.
        if "test" in split_name:
            continue
        loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=batch_size if "train" in split_name else eval_batch_size,
            collate_fn=collate_fn,
        )
        dataloaders[split_name] = loader
        print(f"loader {split_name}: {len(loader)}")

    if task_type == "regression":
        num_labels = 1
    elif dataset_name == "record":
        # The record preprocessing emits a 0/1 label per candidate entity.
        num_labels = 2
    else:
        num_labels = len(datasets["train"].features["label"].names)

    return dataloaders, num_labels


def _get_datasets(name, max_seq_len, tokenizer, seed=42, task="glue", cache_dir=None):
    """Load a dataset and preprocess each of its non-test splits.

    ``load_dataset(task, name, cache_dir=cache_dir)`` provides the splits.
    Every split whose name does not contain ``"test"`` is replaced by its
    preprocessed version; the other splits are returned as loaded.

    Returns:
        The mapping returned by ``load_dataset`` with the processed splits
        swapped in.
    """
    datasets = load_dataset(task, name, cache_dir=cache_dir)

    # Choose the preprocessing routine once; it depends only on `name`.
    if name == "copa":
        preprocess = _copa_preprocess
    elif name == "record":
        preprocess = _record_preprocess
    else:
        preprocess = _preprocess_dataset

    # Iterate over a snapshot of the keys because values are reassigned below.
    for split_name in list(datasets):
        if "test" in split_name:
            continue
        datasets[split_name] = preprocess(
            datasets[split_name], max_seq_len, tokenizer, seed, task=name
        )
    return datasets


def _preprocess_dataset(dataset, max_seq_len, tokenizer, seed, task):
    """Tokenize one dataset split and return it shuffled in torch format.

    The text fields depend on ``task``:

    * ``"multirc"``: paragraph vs. question + answer, truncating only the
      paragraph.
    * ``"wic"``: word + sentence1 vs. sentence2.
    * ``"wsc"``: span1_text + span2_text vs. text.
    * anything else: the columns selected by `_get_input_ids`, tokenized in
      batched mode.

    Every branch pads to ``max_seq_len``.

    Args:
        dataset: Split object providing ``map``; the result must provide
            ``column_names``, ``set_format`` and ``shuffle``.
        max_seq_len: Length every encoded row is padded (and truncated) to.
        tokenizer: Callable tokenizer exposing ``pad_token`` / ``sep_token``.
        seed: Seed for the final shuffle.
        task: Task name selecting the branch above.

    Returns:
        The encoded split with a ``labels`` column copied from ``label``,
        formatted as torch tensors and shuffled with ``seed``.
    """
    if task == "multirc":
        # NOTE(review): question and answer are joined with `pad_token`, while
        # the wic/wsc branches join with `sep_token`. Left unchanged because
        # switching would alter the model inputs — confirm which is intended.
        encoded_dataset = dataset.map(
            lambda item: tokenizer(
                item["paragraph"],
                item["question"] + tokenizer.pad_token + item["answer"],
                truncation="only_first",
                max_length=max_seq_len,
                padding="max_length",
            )
        )
    elif task == "wic":

        def tokenize(item):
            # Truncation used to be disabled here, so an over-length pair
            # produced a row longer than `max_seq_len`; such a row cannot be
            # stacked with the fixed-length rows when batching. Truncate like
            # the other branches so every row has the same length.
            return tokenizer(
                item["word"] + tokenizer.sep_token + item["sentence1"],
                item["sentence2"],
                truncation="longest_first",
                max_length=max_seq_len,
                padding="max_length",
            )

        encoded_dataset = dataset.map(tokenize)
    elif task == "wsc":

        def tokenize(item):
            return tokenizer(
                item["span1_text"] + tokenizer.sep_token + item["span2_text"],
                item["text"],
                max_length=max_seq_len,
                truncation="longest_first",
                padding="max_length",
            )

        encoded_dataset = dataset.map(tokenize)
    else:
        encoded_dataset = dataset.map(
            lambda examples: tokenizer(
                *_get_input_ids(examples, task),
                max_length=max_seq_len,
                truncation="longest_first",
                padding="max_length",
            ),
            batched=True,
        )

    # Expose the target under the "labels" name as well.
    encoded_dataset = encoded_dataset.map(lambda x: {"labels": x["label"]})
    columns = ["input_ids", "attention_mask", "labels"]
    if "token_type_ids" in encoded_dataset.column_names:
        # bert, albert case
        columns.append("token_type_ids")
    encoded_dataset.set_format(type="torch", columns=columns)
    return encoded_dataset.shuffle(seed=seed)


def _copa_preprocess(dataset, max_seq_len, tokenizer, seed, task) -> Dataset:
    """Turn a COPA split into two-choice tokenized examples.

    For each item the premise is shortened so that premise + question +
    longest choice fits the ``max_seq_len`` budget, the question sentence for
    the item's ``question`` type is appended, and the result is tokenized
    once against each of the two choices.

    Args:
        dataset: Split with ``premise``, ``question``, ``choice1``,
            ``choice2`` and ``label`` columns.
        max_seq_len: Maximum length of each (premise, choice) encoding.
        tokenizer: Tokenizer used for length measurement and encoding.
        seed: Seed for the final shuffle.
        task: Unused; kept for signature parity with the other preprocessors.

    Returns:
        The processed split in torch format, shuffled with ``seed``.
    """
    _QUESTION_DICT = {
        "cause": "What was the cause of this?",
        "effect": "What happened as a result?",
    }

    def preprocess_item(item):
        question = _QUESTION_DICT[item["question"]]
        len_question = len(
            tokenizer.encode(f" {question}", add_special_tokens=False)
        )
        len_choices = max(
            len(tokenizer.encode(item["choice1"], add_special_tokens=False)),
            len(tokenizer.encode(item["choice2"], add_special_tokens=False)),
        )
        encoded_premise = tokenizer.encode(
            item["premise"],
            max_length=max_seq_len - len_question - len_choices,
            truncation=True,
        )
        # `encode` adds the tokenizer's special tokens by default. Drop them
        # when decoding; otherwise their literal text ends up inside the
        # premise string and is tokenized a second time in `tokenize_choices`.
        truncated_premise = tokenizer.decode(
            encoded_premise, skip_special_tokens=True
        )
        return {
            "premise": truncated_premise + f" {question}",
            "choices": [item["choice1"], item["choice2"]],
            "labels": item["label"],
        }

    def tokenize_choices(item):
        # Pair the same premise with each of the two choices.
        premises = [item["premise"]] * 2

        return tokenizer(
            premises, item["choices"], truncation=True, max_length=max_seq_len
        )

    dataset = dataset.map(preprocess_item)
    dataset = dataset.map(tokenize_choices)
    columns = ["input_ids", "attention_mask", "labels"]
    if "token_type_ids" in dataset.column_names:
        columns.append("token_type_ids")
    dataset.set_format("pt", columns=columns)
    # Shuffle with the provided seed, as the other preprocessors in this file
    # do; previously `seed` was accepted but never used.
    return dataset.shuffle(seed=seed)


def _record_preprocess(dataset, max_seq_len, tokenizer, seed, task):
    """Expand a ReCoRD split into one binary example per candidate entity.

    Every item is unrolled into ``len(item["entities"])`` rows. A row's label
    is 1 when its entity is listed in ``item["answers"]`` and 0 otherwise.
    ``"@placeholder"`` is replaced by the entity in both the passage and the
    query, and each row's ``idx`` is the item's ``idx`` extended with the
    entity's position under the ``"entity"`` key.

    Args:
        dataset: Iterable of items with ``passage``, ``query``, ``entities``,
            ``answers`` and ``idx`` (a dict) fields.
        max_seq_len: Length every encoded row is padded to; only the passage
            is truncated.
        tokenizer: Callable tokenizer applied to (passage, query) pairs.
        seed: Seed for the final shuffle.
        task: Unused; kept for signature parity with the other preprocessors.

    Returns:
        The expanded, tokenized dataset in torch format, shuffled with
        ``seed``.
    """
    passages = []
    queries = []
    entities = []
    labels = []
    idxs = []
    for item in dataset:
        answers = item["answers"]
        for e_id, entity in enumerate(item["entities"]):
            passages.append(item["passage"].replace("@placeholder", entity))
            # The cloze marker lives in the query in ReCoRD, so substitute it
            # there too. Without this every candidate of an item produced the
            # same model input and differed only in its label.
            queries.append(item["query"].replace("@placeholder", entity))
            entities.append(entity)
            labels.append(1 if entity in answers else 0)
            # Copy the idx dict for each row. Reusing `item["idx"]` directly
            # made all rows of an item share one dict, so each ended up with
            # the last entity index.
            c_idx = dict(item["idx"])
            c_idx["entity"] = e_id
            idxs.append(c_idx)

    record_dataset = Dataset.from_dict(
        {
            "passage": passages,
            "query": queries,
            "entity": entities,
            "label": labels,
            "idx": idxs,
        }
    )
    record_dataset = record_dataset.map(
        lambda item: tokenizer(
            item["passage"],
            item["query"],
            max_length=max_seq_len,
            padding="max_length",
            truncation="only_first",
        ),
        batched=True,
        batch_size=2000,
    )
    # Expose the target under the "labels" name as well.
    encoded_dataset = record_dataset.map(lambda x: {"labels": x["label"]})
    columns = ["input_ids", "attention_mask", "labels"]
    if "token_type_ids" in encoded_dataset.column_names:
        # bert, albert case
        columns.append("token_type_ids")
    encoded_dataset.set_format(type="torch", columns=columns)
    return encoded_dataset.shuffle(seed=seed)


def _get_input_ids(examples, task):
    """Return the text columns of ``examples`` that ``task`` feeds the tokenizer.

    The column names come from ``_TASK_TO_KEYS[task]``; ``None`` placeholders
    (single-sentence tasks) are skipped. Raises ``KeyError`` for a task that
    is not in the table.
    """
    selected = []
    for column_name in _TASK_TO_KEYS[task]:
        if column_name is None:
            continue
        selected.append(examples[column_name])
    return tuple(selected)
