import re
from typing import Any, Optional, Union
import datasets
import torch
from torch import Tensor
from transformers import AutoTokenizer

from _utils import pad_seq
from _abstract_task.data import HFDataModule


class GLUEHFDataModule(HFDataModule):
    """Shared preprocessing/collation logic for GLUE and SuperGLUE tasks.

    Subclasses declare ``BENCHMARK`` and ``TASK`` (passed as the dataset
    ``path`` and ``name``), and either ``SEGMENT_NAMES`` (one or two text
    columns fed to the tokenizer) or their own ``_preprocess_fn``.
    ``FEATURE_NAMES`` lists the columns every processed example must provide.
    """

    # Reference:
    # google-research/electra (https://github.com/google-research/electra/blob/master/finetune/classification/classification_tasks.py)
    # microsoft/DeBERTa (https://github.com/microsoft/DeBERTa/blob/master/DeBERTa/apps/tasks/superglue_tasks.py)
    # nyu-mll/jiant (https://github.com/nyu-mll/jiant/tree/235f646bd292e2f5bde444379f487d1d50f035a4/jiant/tasks/lib)

    BENCHMARK: str
    TASK: str
    SEGMENT_NAMES: list[str]
    FEATURE_NAMES: list[str] = ["idx", "input_ids", "non_pad", "segment_ids", "label"]

    def __init__(self, config):
        """Load the tokenizer named by ``config.tokenizer`` and set up the dataset.

        Reads ``config.tokenizer``, ``config.max_sequence_length`` and
        ``config.datasets_cache_dir``.
        """
        self.config = config
        self.tokenizer = AutoTokenizer.from_pretrained(config.tokenizer)
        self.max_sequence_length = self.config.max_sequence_length
        # The sequence length is part of the cache file name, so caches built
        # with different lengths do not collide.
        seq_len = f"_{self.config.max_sequence_length}"
        # The "wsc" task is loaded under the dataset name "wsc.fixed".
        task = "wsc.fixed" if self.TASK == "wsc" else self.TASK
        super().__init__(
            dataset={
                "path": self.BENCHMARK,
                "name": task,
            },
            cache_name_template=f"processed{seq_len}-{{split}}.arrow",
            cache_dir=self.config.datasets_cache_dir,
            using_features=self.FEATURE_NAMES,
        )

    # An overwritten method activated when no cache found
    def _preprocess(self, dataset, split, cache_file_path):
        """Run ``_batch_preprocess_fn`` over ``dataset``, caching to ``cache_file_path``."""
        # We use batched `Dataset.map` by default, because many tasks need to
        # skip or add some examples on the fly (i.e. not one-to-one processing)
        # NOTE(review): the mapped dataset is discarded here; presumably the
        # parent class reloads it from `cache_file_path` -- confirm.
        dataset.map(
            self._batch_preprocess_fn,
            cache_file_name=cache_file_path,
            batched=True,
            remove_columns=dataset.column_names,
            fn_kwargs={"split": split},
        )

    # Abstract batch preprocessing logic to provide a simple api for `_preprocess_fn`
    def _batch_preprocess_fn(
        self, new_features: dict[str, list], split: str
    ) -> dict[str, list]:
        """Apply ``_preprocess_fn`` to every example of a column-oriented batch.

        Args:
            new_features: Batch as columns (feature name -> list of values).
            split: Name of the split being processed.

        Returns:
            The processed examples as columns, one list per ``FEATURE_NAMES`` entry.

        Raises:
            ValueError: If ``_preprocess_fn`` drops an example outside the
                "train" split.
        """
        # batched `Dataset.map` receives data "columns" (features) but not "rows" (examples),
        # so we rebuild one example at a time from the columns
        column_names = list(new_features)
        processed_rows: list[dict] = []
        for row_values in zip(*new_features.values()):
            raw_example = dict(zip(column_names, row_values))
            result = self._preprocess_fn(raw_example, split=split)
            if result is None:
                # Explicit raise instead of `assert`, so the check survives
                # `python -O` and evaluation examples are never dropped silently.
                if split != "train":
                    raise ValueError(
                        "Skipping examples in validation/test set is not allowed."
                    )
            elif isinstance(result, list):
                processed_rows.extend(result)
            else:
                processed_rows.append(result)

        # batched `Dataset.map` expects data "columns" but not "rows"
        # so we reshape processed rows (examples) back to columns (features)
        return {
            name: [row[name] for row in processed_rows]
            for name in self.FEATURE_NAMES
        }

    # Receive an example, return None(skip)/ an example/ multiple examples
    def _preprocess_fn(
        self, example: dict, split: str
    ) -> Union[None, dict, list[dict]]:
        """Tokenize the column(s) named in ``SEGMENT_NAMES`` of one raw example.

        The second segment is used only when ``SEGMENT_NAMES`` has exactly
        two entries.
        """
        nameA = self.SEGMENT_NAMES[0]
        nameB = self.SEGMENT_NAMES[1] if len(self.SEGMENT_NAMES) == 2 else None
        return self._generate_example(
            idx=example["idx"],
            textA=example[nameA],
            textB=example[nameB] if nameB else None,
            label=example["label"],
        )

    # A helper function for tokenization and creating a data instance
    def _generate_example(
        self,
        idx: int,
        label: Any,
        textA: str,
        textB: Optional[str] = None,
        **other_features,
    ) -> dict:
        """Tokenize ``textA`` (and ``textB``) and assemble one processed example.

        Inputs are truncated to ``config.max_sequence_length`` with the
        "longest_first" strategy. Any ``other_features`` are merged into the
        returned dict unchanged.
        """
        output = self.tokenizer(
            textA,
            textB,
            max_length=self.config.max_sequence_length,
            truncation="longest_first",
        )
        return {
            "idx": idx,
            "input_ids": output.input_ids,
            # Every real token is marked True; padding added at collate time is False.
            "non_pad": [True] * len(output.input_ids),
            "segment_ids": output.token_type_ids,
            "label": label,
            **other_features,
        }

    # Used as dataloader's collate_fn
    def _collate(self, samples: list[dict[str, Tensor]]) -> dict[str, Tensor]:
        """Merge a list of samples into a padded batch (dataloader ``collate_fn``)."""
        b = {key: [sample[key] for sample in samples] for key in samples[0]}  # batch
        pad_id = self.tokenizer.pad_token_id
        b["idx"] = torch.stack(b["idx"])  # <int>(B)
        b["input_ids"] = pad_seq(b["input_ids"], pad=pad_id)  # <int>(B,L)
        b["non_pad"] = pad_seq(b["non_pad"], pad=False)  # <bool>(B,L)
        b["segment_ids"] = pad_seq(b["segment_ids"], pad=0)  # <int>(B,L)
        b["label"] = torch.stack(b["label"])  # <int>(B,)
        return b


class RTEData(GLUEHFDataModule):  # Recognizing Textual Entailment
    """Data module for the ``rte`` task of the ``glue`` benchmark.

    Relies on the inherited preprocessing, which tokenizes the ``sentence1``
    and ``sentence2`` columns as a sentence pair.
    """

    BENCHMARK = "glue"
    TASK = "rte"
    SEGMENT_NAMES = ["sentence1", "sentence2"]


class CBData(GLUEHFDataModule):  # CommitmentBank
    """Data module for the ``cb`` task of the ``super_glue`` benchmark.

    Relies on the inherited preprocessing, which tokenizes the ``premise``
    and ``hypothesis`` columns as a sentence pair.
    """

    BENCHMARK = "super_glue"
    TASK = "cb"
    SEGMENT_NAMES = ["premise", "hypothesis"]


class COPAData(GLUEHFDataModule):  # Choice of Plausible Alternatives
    """Data module for the ``copa`` task of the ``super_glue`` benchmark.

    Each example is tokenized twice, once per candidate choice. The first
    choice fills the standard features, and the second choice is stored in
    the ``choice2_*`` features until collation.
    """

    BENCHMARK = "super_glue"
    TASK = "copa"
    FEATURE_NAMES = [
        "idx",
        "input_ids",
        "non_pad",
        "segment_ids",
        "label",
        "choice2_input_ids",
        "choice2_non_pad",
        "choice2_segment_ids",
    ]

    # Natural-language question appended to the premise, keyed by the
    # example's "question" field.
    _QUESTION_PROMPTS = {
        "cause": "What was the cause of this?",
        "effect": "What happened as a result?",
    }

    def _preprocess_fn(self, example: dict, split: str) -> dict:
        """Build one processed example holding both tokenized choices.

        The prompt is the premise followed by the question text for the
        example's ``question`` type; it is paired with ``choice1`` and
        ``choice2`` in turn.
        """
        prompt = (
            example["premise"] + " " + self._QUESTION_PROMPTS[example["question"]]
        )
        # Only the token-level features of the second choice are kept, so
        # its label is irrelevant.
        choice2 = self._generate_example(
            idx=example["idx"], textA=prompt, textB=example["choice2"], label=None,
        )
        return self._generate_example(
            idx=example["idx"],
            textA=prompt,
            textB=example["choice1"],
            label=example["label"],
            choice2_input_ids=choice2["input_ids"],
            choice2_non_pad=choice2["non_pad"],
            choice2_segment_ids=choice2["segment_ids"],
        )

    def _collate(self, samples: list[dict[str, Tensor]]) -> dict[str, Tensor]:
        """Collate samples, stacking choice-2 sequences after choice-1 sequences.

        The ``choice2_*`` entries are removed from the batch; their sequences
        are appended to the corresponding choice-1 lists before padding, so
        the token-level tensors hold 2B rows for a batch of B samples.
        """
        b = {key: [sample[key] for sample in samples] for key in samples[0]}  # batch
        pad_id = self.tokenizer.pad_token_id
        b["idx"] = torch.stack(b["idx"])  # <int>(B)
        b["input_ids"] = pad_seq(
            b["input_ids"] + b.pop("choice2_input_ids"), pad=pad_id
        )  # <int>(2B,L)
        b["non_pad"] = pad_seq(
            b["non_pad"] + b.pop("choice2_non_pad"), pad=False
        )  # <bool>(2B,L)
        b["segment_ids"] = pad_seq(
            b["segment_ids"] + b.pop("choice2_segment_ids"), pad=0
        )  # <int>(2B,L)
        b["label"] = torch.stack(b["label"])  # <int>(B)
        return b


class MultiRCData(GLUEHFDataModule):  # Multi-Sentence Reading Comprehension
    """Data module for the ``multirc`` task of the ``super_glue`` benchmark."""

    BENCHMARK = "super_glue"
    TASK = "multirc"

    def _preprocess_fn(self, example: dict, split: str):
        """Pair the paragraph with "<question> <answer>" and tokenize them.

        The example index becomes the list
        ``[paragraph id, question id, answer id]``.
        """
        id_parts = example["idx"]
        composite_idx = [
            id_parts[part] for part in ("paragraph", "question", "answer")
        ]
        paragraph = example["paragraph"]
        question_with_answer = " ".join((example["question"], example["answer"]))
        return self._generate_example(
            idx=composite_idx,
            textA=paragraph,
            textB=question_with_answer,
            label=example["label"],
        )


class WiCData(GLUEHFDataModule):  # Words in Context
    """Data module for the ``wic`` task of the ``super_glue`` benchmark.

    Besides the standard features, each example carries ``word_ranges``:
    the token range of the target word in each of the two sentences.
    """

    BENCHMARK = "super_glue"
    TASK = "wic"
    FEATURE_NAMES = [
        "idx",
        "input_ids",
        "non_pad",
        "segment_ids",
        "label",
        "word_ranges",
    ]

    def _preprocess_fn(self, example: dict, split: str):
        """Tokenize "<word> [SEP] sentence1" with sentence2 and locate the word.

        Returns a processed example whose ``word_ranges`` is the concatenation
        of the (start, end) token range in the first segment and the
        (start, end) token range in the second segment.
        """
        first_start, first_end = example["start1"], example["end1"]
        word: str = example["sentence1"][first_start:first_end]
        # The word and a separator are prepended to sentence1, so character
        # positions inside sentence1 shift by the length of this prefix.
        prefix = f"{word} {self.tokenizer.sep_token} "
        shift = len(prefix)

        encoded = self.tokenizer(
            f"{word} {self.tokenizer.sep_token} {example['sentence1']}",
            example["sentence2"],
            max_length=self.config.max_sequence_length,
            truncation="longest_first",
            return_offsets_mapping=True,
        )
        segment_ids = encoded.token_type_ids
        offsets = encoded.offset_mapping

        # Map each character span to the tokens of its own segment
        in_first_segment = [seg == 0 for seg in segment_ids]
        first_token_range = character_range_to_token_range(
            target_char_range=(first_start + shift, first_end + shift),
            token_char_ranges=offsets,
            candidate_mask=in_first_segment,
        )  # (token_start, token_end)
        in_second_segment = [seg == 1 for seg in segment_ids]
        second_token_range = character_range_to_token_range(
            target_char_range=(example["start2"], example["end2"]),
            token_char_ranges=offsets,
            candidate_mask=in_second_segment,
        )  # (token_start, token_end)

        token_ids = encoded.input_ids
        return {
            "idx": example["idx"],
            "input_ids": token_ids,
            "non_pad": [True] * len(token_ids),
            "segment_ids": segment_ids,
            "label": example["label"],
            "word_ranges": first_token_range + second_token_range,  # <int>(4)
        }

    def _collate(self, samples):
        """Collate samples into a padded batch with ``word_ranges`` as (B,2,2)."""
        batch = {}
        for key in samples[0]:
            batch[key] = [sample[key] for sample in samples]

        batch["idx"] = torch.stack(batch["idx"])  # <int>(B)
        # Variable-length token features, each with its own padding value
        padding_values = {
            "input_ids": self.tokenizer.pad_token_id,  # <int>(B,L)
            "non_pad": False,  # <bool>(B,L)
            "segment_ids": 0,  # <int>(B,L)
        }
        for key, pad_value in padding_values.items():
            batch[key] = pad_seq(batch[key], pad=pad_value)
        batch["label"] = torch.stack(batch["label"])  # <int>(B)
        stacked_ranges = torch.stack(batch["word_ranges"])
        batch["word_ranges"] = stacked_ranges.view(-1, 2, 2)  # <int>(B,2,2)
        return batch


class WSCData(GLUEHFDataModule):  # The Winograd Schema Challenge
    """Data module for the ``wsc`` task of the ``super_glue`` benchmark.

    Besides the standard features, each example carries ``word_ranges``:
    the token ranges of the two labeled spans in the text.
    """

    BENCHMARK = "super_glue"
    TASK = "wsc"
    FEATURE_NAMES = [
        "idx",
        "input_ids",
        "non_pad",
        "segment_ids",
        "label",
        "word_ranges",
    ]

    def _preprocess_fn(self, example: dict, split: str) -> Optional[dict]:
        """Tokenize the text and locate both labeled spans as token ranges.

        Mislabeled character offsets are repaired with ``_fix_span``.

        Returns:
            The processed example, or ``None`` when a span text does not occur
            in the text at all and ``split`` is not "test".
        """
        text = example["text"]
        span1_char_start, span1_text = example["span1_index"], example["span1_text"]
        span2_char_start, span2_text = example["span2_index"], example["span2_text"]
        if span1_text not in text or span2_text not in text:
            if split == "test":  # Fix two test examples whose span texts are not span
                if span1_text == "Kamenev and Zinoviev":
                    span1_char_start = 21
                    span1_text = "Lev Kamenev, Stalin's old Pravda co-editor, and Grigory Zinoviev"
                elif span1_text == "Kamenev, Zinoviev, and Stalin":
                    span1_char_start = 0
                    span1_text = "Kotkin"
            else:
                return None  # unrecoverable incorrectly labeled example

        # Recover incorrectly labeled example
        span1_char_end = span1_char_start + len(span1_text)
        if text[span1_char_start:span1_char_end] != span1_text:
            span1_char_start, span1_char_end = self._fix_span(
                text, span1_text, span1_char_start,
            )
        assert text[span1_char_start:span1_char_end] == span1_text
        span2_char_end = span2_char_start + len(span2_text)
        if text[span2_char_start:span2_char_end] != span2_text:
            span2_char_start, span2_char_end = self._fix_span(
                text, span2_text, span2_char_start,
            )
        assert text[span2_char_start:span2_char_end] == span2_text

        # Tokenize
        output = self.tokenizer(
            text,
            max_length=self.config.max_sequence_length,
            truncation="longest_first",
            return_offsets_mapping=True,
        )

        # Find tokens that are associated with the span
        span1_token_range = character_range_to_token_range(
            target_char_range=(span1_char_start, span1_char_end),
            token_char_ranges=output.offset_mapping,
        )
        span2_token_range = character_range_to_token_range(
            target_char_range=(span2_char_start, span2_char_end),
            token_char_ranges=output.offset_mapping,
        )

        return {
            "idx": example["idx"],
            "input_ids": output.input_ids,
            "non_pad": [True] * len(output.input_ids),
            "segment_ids": output.token_type_ids,
            "label": example["label"],
            "word_ranges": span1_token_range + span2_token_range,  # <int>(4)
        }

    def _fix_span(self, text, span_text, label_start_char_idx):
        """Find the occurrence of ``span_text`` closest to the labeled start.

        Only occurrences that are not glued to a word character on either
        side are considered.

        Args:
            text: Full example text.
            span_text: Span text to locate (matched literally).
            label_start_char_idx: The (incorrect) labeled start character index.

        Returns:
            ``(start, end)`` character indices with ``text[start:end] == span_text``.

        Raises:
            ValueError: If ``span_text`` has no such occurrence in ``text``.
        """
        # `re.escape` keeps characters such as "." or "(" in the span literal.
        # Lookarounds (instead of consuming the neighbouring separator) also
        # accept a span at the very end of the text and do not hide an
        # occurrence that shares its separator with the previous one.
        pattern = rf"(?<!\w){re.escape(span_text)}(?!\w)"
        start_idxs = [m.start() for m in re.finditer(pattern, text)]
        if not start_idxs:
            raise ValueError(
                f"Span {span_text!r} does not occur as a standalone span in the text."
            )
        # `min` keeps the earliest occurrence on ties.
        start_char_idx = min(
            start_idxs, key=lambda start: abs(start - label_start_char_idx)
        )
        end_char_idx = start_char_idx + len(span_text)
        return start_char_idx, end_char_idx

    def _collate(self, samples):
        """Collate samples into a padded batch with ``word_ranges`` as (B,2,2)."""
        b = {key: [sample[key] for sample in samples] for key in samples[0]}  # batch
        pad_id = self.tokenizer.pad_token_id
        b["idx"] = torch.stack(b["idx"])  # <int>(B)
        b["input_ids"] = pad_seq(b["input_ids"], pad=pad_id)  # <int>(B,L)
        b["non_pad"] = pad_seq(b["non_pad"], pad=False)  # <bool>(B,L)
        b["segment_ids"] = pad_seq(b["segment_ids"], pad=0)  # <int>(B,L)
        b["label"] = torch.stack(b["label"])  # <int>(B)
        b["word_ranges"] = torch.stack(b["word_ranges"]).view(-1, 2, 2)  # <int>(B,2,2)
        return b


class BoolQData(GLUEHFDataModule):  # BoolQ
    """Data module for the ``boolq`` task of the ``super_glue`` benchmark.

    Relies on the inherited preprocessing, which tokenizes the ``question``
    and ``passage`` columns as a sentence pair.
    """

    BENCHMARK = "super_glue"
    TASK = "boolq"
    SEGMENT_NAMES = ["question", "passage"]


class ReCoRDData(GLUEHFDataModule):  # Reading Comprehension with Commonsense Reasoning
    """Data module for the ``record`` task of the ``super_glue`` benchmark.

    The query's ``@placeholder`` is replaced with the tokenizer's mask token.
    Every entity mention of the passage that survives truncation is stored as
    a token range together with a 0/1 label telling whether its text is one
    of the answers.
    """

    BENCHMARK: str = "super_glue"
    TASK = "record"
    FEATURE_NAMES = [
        "idx",
        "input_ids",
        "non_pad",
        "segment_ids",
        "masked_idx",
        "entity_ranges",
        "entity_labels",
    ]

    # load not only preprocessed dataset but also raw dataset for evaluation and testing
    def setup(self, stage):
        """Load the raw dataset into ``self.raw_datasets``, then run the parent setup."""
        assert len(self.raw_datasets_kwargs) == 1
        self.raw_datasets = datasets.load_dataset(**self.raw_datasets_kwargs[0])
        super().setup(stage)

    def _preprocess_fn(self, example: dict, split):
        """Tokenize passage and masked query, and label passage entity mentions.

        Returns:
            The processed example, or ``None`` for a "train" example in which
            no entity mention survives. For other splits a fake entity with
            token range (0, 1) and label 0 is emitted instead.
        """
        # Tokenize
        query = example["query"].replace("@placeholder", self.tokenizer.mask_token)
        tokenize_kwargs = {
            "max_length": self.config.max_sequence_length,
            "return_offsets_mapping": True,
        }
        try:
            # Prefer truncating only the passage so the query stays intact.
            output = self.tokenizer(
                example["passage"], query, truncation="only_first", **tokenize_kwargs
            )
        except Exception:  # the query is longer than max length
            output = self.tokenizer(
                example["passage"], query, truncation="longest_first", **tokenize_kwargs
            )

        # Find the masked position
        masked_idx = output.input_ids.index(self.tokenizer.mask_token_id)
        assert output.token_type_ids[masked_idx] == 1  # mask is in query text

        # Find the truncated range of passage. Starts at 0 so that a passage
        # with no surviving token treats every entity as truncated.
        last_char_end = 0
        for char_span, sid in zip(output.offset_mapping, output.token_type_ids):
            if sid == 0 and char_span != (0, 0):  # in first sentence and not sentinel
                last_char_end = char_span[1]

        # Convert character indices to token indices and label entities
        passage_token_mask = [i == 0 for i in output.token_type_ids]
        entity_ranges: list[int] = []
        entity_labels: list[int] = []
        for entity_text, entity_start, entity_end in zip(
            example["entity_spans"]["text"],
            example["entity_spans"]["start"],
            example["entity_spans"]["end"],
        ):
            # Fix imperfect labeling which includes spaces or characters unrecognized by the tokenizer in the entity text.
            # Character ranges returned by the tokenizer do not include spaces and unrecognized characters, so a start/end
            # character index pointing at such a character at the beginning/end of the entity is not covered by any token
            # range, and would trigger an error when finding the tokens associated with the entity.
            char_ranges: list[tuple[int, int]] = self.tokenizer(
                entity_text, return_offsets_mapping=True, add_special_tokens=False
            ).offset_mapping
            if not char_ranges:
                continue  # skip this entity which is all invalid characters
            entity_char_start: int = (
                entity_start + char_ranges[0][0]
            )  # start of first token
            entity_char_end: int = (
                entity_start + char_ranges[-1][1]
            )  # end of last token
            # If the entity is truncated, skip this entity
            if entity_end > last_char_end:
                continue

            entity_token_start, entity_token_end = character_range_to_token_range(
                target_char_range=(entity_char_start, entity_char_end),
                token_char_ranges=output.offset_mapping,
                candidate_mask=passage_token_mask,
            )
            label = int(entity_text in example["answers"])
            entity_ranges.append(entity_token_start)
            entity_ranges.append(entity_token_end)
            entity_labels.append(label)

        if not entity_ranges:  # if all entity mentions in passage are truncated ...
            if split == "train":
                return None
            else:  # fake entity
                entity_ranges.append(0)
                entity_ranges.append(1)
                entity_labels.append(0)

        return {
            "idx": example["idx"]["query"],
            "input_ids": output.input_ids,
            "non_pad": [True] * len(output.input_ids),
            "segment_ids": output.token_type_ids,
            "masked_idx": masked_idx,  # <int>
            "entity_ranges": entity_ranges,  # <int>(#entities * 2), token ranges of entities
            "entity_labels": entity_labels,  # <int>(#entities)
        }

    def _collate(self, samples):
        """Collate samples into a padded batch.

        ``entity_labels`` is padded with -1. ``entity_ranges`` is padded with
        -1, reshaped to (B, max #entities, 2), and the start of every padded
        entity is then set to 0.
        """
        b = {key: [sample[key] for sample in samples] for key in samples[0]}  # batch
        pad_id, B = self.tokenizer.pad_token_id, len(b["idx"])
        b["idx"] = torch.stack(b["idx"])  # <int>(B)
        b["input_ids"] = pad_seq(b["input_ids"], pad=pad_id)  # <int>(B,L)
        b["non_pad"] = pad_seq(b["non_pad"], pad=False)  # <bool>(B,L)
        b["segment_ids"] = pad_seq(b["segment_ids"], pad=0)  # <int>(B,L)
        b["masked_idx"] = torch.stack(b["masked_idx"])  # <int>(B)
        b["entity_labels"] = pad_seq(
            b["entity_labels"], pad=-1
        )  # <int>(B, max #entities)
        ranges = pad_seq(b["entity_ranges"], pad=-1)  # <int>(B, max #entities * 2)
        ranges = ranges.view(B, -1, 2)  # <int>(B, max #entities, 2)
        is_pad = b["entity_labels"] == -1  # <bool>(B, max #entities)
        # Select only the "start" slot of padded entities.
        is_pad_start = torch.stack(
            [is_pad, torch.zeros_like(is_pad)], dim=2
        )  # <bool>(B,max #entities,2)
        ranges[
            is_pad_start
        ] = 0  # make pad entity range (0,-1) to avoid 0 entity length
        b["entity_ranges"] = ranges
        return b


class AXbData(GLUEHFDataModule):  # Broadcoverage Diagnostics
    """Data module for the ``axb`` task of the ``super_glue`` benchmark.

    Relies on the inherited preprocessing, which tokenizes the ``sentence1``
    and ``sentence2`` columns as a sentence pair.
    """

    BENCHMARK = "super_glue"
    TASK = "axb"
    SEGMENT_NAMES = ["sentence1", "sentence2"]


class AXgData(GLUEHFDataModule):  # Winogender Schema Diagnostics
    """Data module for the ``axg`` task of the ``super_glue`` benchmark.

    Relies on the inherited preprocessing, which tokenizes the ``premise``
    and ``hypothesis`` columns as a sentence pair.
    """

    BENCHMARK = "super_glue"
    TASK = "axg"
    SEGMENT_NAMES = ["premise", "hypothesis"]


def character_range_to_token_range(
    target_char_range: tuple[int, int],
    token_char_ranges: list[tuple[int, int]],
    candidate_mask: Optional[list[bool]] = None,
) -> list[int]:
    """Convert a character span into the token span that covers it.

    Args:
        target_char_range: ``(start, end)`` character indices, end exclusive.
        token_char_ranges: Per-token ``(start, end)`` character offsets.
            Tokens with offsets ``(0, 0)`` are treated as sentinels and skipped.
        candidate_mask: Optional per-token flags; tokens whose flag is falsy
            are ignored. ``None`` or an empty mask means every token is a
            candidate.

    Returns:
        ``[token_start, token_end]`` with inclusive start and exclusive end.
        When several tokens qualify, the last one in token order wins.

    Raises:
        ValueError: If no candidate token contains the start or the end of the
            target range, or if the located start lies after the located end.
    """
    tgt_char_start, tgt_char_end = target_char_range
    use_mask = candidate_mask is not None and len(candidate_mask) > 0
    token_start: Optional[int] = None
    token_end: Optional[int] = None
    for token_idx, (token_char_start, token_char_end) in enumerate(token_char_ranges):
        if token_char_start == token_char_end == 0:
            continue  # sentinel tokens (e.g. [CLS], [SEP])
        # Truthiness (not identity with `False`) so integer or array-like
        # mask values are honoured too.
        if use_mask and not candidate_mask[token_idx]:
            continue

        # Note that most of the time a target index equals a token's start or end
        # char index, but sometimes the label is imprecise, so we accept the token
        # that contains the index instead.
        if token_char_start <= tgt_char_start < token_char_end:
            token_start = token_idx  # inclusive start
        if token_char_start < tgt_char_end <= token_char_end:
            token_end = token_idx + 1  # exclusive end

    if token_start is None or token_end is None:
        raise ValueError(
            f"Character range {tuple(target_char_range)} is not covered by any "
            "candidate token (it may have been truncated)."
        )
    if token_start > token_end:
        raise ValueError(
            f"Character range {tuple(target_char_range)} maps to an invalid "
            f"token range ({token_start}, {token_end})."
        )
    return [token_start, token_end]


# Registry mapping a task name to the data module class that handles it.
# NOTE(review): several classes referenced here (e.g. CoLAData, SST2Data,
# MNLIData) are not defined in the visible part of this file -- presumably
# they are defined or imported elsewhere; confirm before relying on this
# mapping.
DATA_MODULES = {
    "cola": CoLAData,
    "sst2": SST2Data,
    "mrpc": MRPCData,
    "stsb": STSBData,
    "qqp": QQPData,
    "mnli": MNLIData,
    "qnli": QNLIData,
    "rte": RTEData,
    "wnli": WNLIData,
    "ax": AXData,
    "cb": CBData,
    "copa": COPAData,
    "multirc": MultiRCData,
    "wic": WiCData,
    "wsc": WSCData,
    "boolq": BoolQData,
    "record": ReCoRDData,
    "axb": AXbData,
    "axg": AXgData,
}
