"""Function returning SNLI datasets compatible with some other places."""
import string
from typing import Any, Dict, List, Optional

import datasets
import torch
from transformers import PreTrainedTokenizer

from npeff_torch.datasets import preprocessing_common
from npeff_torch.icl import icl_datasets_common


###############################################################################
###############################################################################


def load_default_sequence_classification_task(
    task: str,
    subtask: Optional[str],
    split: str,
    tokenizer: PreTrainedTokenizer,
    sequence_length: int,
):
    """Build the SNLI sequence-classification dataset for a single split.

    Args:
        task: Name of the task; only 'snli' is accepted.
        subtask: Forwarded to `datasets.load_dataset` as its second positional
            argument.
        split: Key of the split to pick from the loaded dataset.
        tokenizer: Tokenizer used to encode each premise/hypothesis pair.
        sequence_length: Passed to `encode_example` as its `max_length`.

    Returns:
        An iterable dataset restricted to the columns 'input_ids',
        'token_type_ids', 'attention_mask' and 'labels'.

    Raises:
        ValueError: If `task` is not 'snli'.
    """
    if task != 'snli':
        raise ValueError(f'The task must be "snli". Instead received: {task}')

    def _is_labeled(example):
        # Examples without a label are marked with -1.
        return example['label'] != -1

    def _encode(example):
        return encode_example(tokenizer, example['premise'], example['hypothesis'], sequence_length)

    all_splits = datasets.load_dataset("stanfordnlp/snli", subtask)
    stream = all_splits[split].to_iterable_dataset()

    # Drop the unlabeled examples.
    # NOTE: I think somewhere between 700-800 examples in the train set have no labels.
    stream = stream.filter(_is_labeled)

    stream = stream.map(_encode, batched=False)
    stream = stream.rename_column('label', 'labels')

    kept_columns = ['input_ids', 'token_type_ids', 'attention_mask', 'labels']
    return stream.select_columns(kept_columns)


###############################################################################


def encode_example(tokenizer: PreTrainedTokenizer, premise: str, hypothesis: str, max_length: int) -> Dict[str, torch.Tensor]:
    """Tokenize a premise/hypothesis pair for sequence classification.

    Args:
        tokenizer: Tokenizer used to encode the pair.
        premise: First segment of the pair.
        hypothesis: Second segment of the pair.
        max_length: Maximum length handed to the tokenizer for truncation and
            to `preprocessing_common.single_truncate_and_pad`.

    Returns:
        A dict with the keys 'input_ids', 'token_type_ids' and
        'attention_mask', each being the result of
        `preprocessing_common.single_truncate_and_pad` on the corresponding
        tokenizer output.
    """
    # Call the tokenizer directly instead of the deprecated `encode_plus`;
    # this also matches how the other encoders in this file tokenize text.
    encoded = tokenizer(
        premise,
        hypothesis,
        return_tensors="pt",
        add_special_tokens=True,
        max_length=max_length,
        return_token_type_ids=True,
        truncation=True,
        padding=True,
    )

    # NOTE(review): `single_truncate_and_pad` presumably brings every field to
    # exactly `max_length` using the given fill value -- confirm against
    # `preprocessing_common`.
    return {
        'input_ids': preprocessing_common.single_truncate_and_pad(encoded['input_ids'], tokenizer.pad_token_id, max_length),
        'token_type_ids': preprocessing_common.single_truncate_and_pad(encoded['token_type_ids'], tokenizer.pad_token_type_id, max_length),
        'attention_mask': preprocessing_common.single_truncate_and_pad(encoded['attention_mask'], 0, max_length),
    }


###############################################################################
###############################################################################


# Default prompt template for the LM multiple-choice formulation. It is filled
# in with `str.format` using the `premise` and `hypothesis` fields.
_LM_MCQA_TEMPLATE = (
    "{premise}\n"
    "Question: {hypothesis}. True, False or Neither?\n"
    "Answer:"
)


def load_default_lm_mcqa_task(
    task: str,
    subtask: Optional[str],
    split: str,
    tokenizer: PreTrainedTokenizer,
    sequence_length: int,
    *,
    template: str = _LM_MCQA_TEMPLATE,
):
    """Build the SNLI dataset in the LM multiple-choice prompt format.

    Labels will be (' True', ' False', ' Neither'); the 'labels' column holds
    the value returned by `encode_lm_mcqa_label` for each example.

    Args:
        task: Name of the task; only 'snli' is accepted.
        subtask: Forwarded to `datasets.load_dataset` as its second positional
            argument.
        split: Key of the split to pick from the loaded dataset.
        tokenizer: Tokenizer used to encode each prompt.
        sequence_length: Passed to `encode_lm_mcqa_example` as its `max_length`.
        template: Prompt template with `premise` and `hypothesis` fields.

    Returns:
        An iterable dataset restricted to the columns 'input_ids',
        'attention_mask' and 'labels'.

    Raises:
        ValueError: If `task` is not 'snli'.
    """
    if task != 'snli':
        raise ValueError(f'The task must be "snli". Instead received: {task}')

    def _is_labeled(example):
        # Examples without a label are marked with -1.
        return example['label'] != -1

    def _encode_prompt(example):
        return encode_lm_mcqa_example(tokenizer, example['premise'], example['hypothesis'], sequence_length, template=template)

    def _remap_label(example):
        return {'label': encode_lm_mcqa_label(example['label'])}

    all_splits = datasets.load_dataset("stanfordnlp/snli", subtask)
    stream = all_splits[split].to_iterable_dataset()

    # Drop the unlabeled examples.
    # NOTE: I think somewhere between 700-800 examples in the train set have no labels.
    stream = stream.filter(_is_labeled)

    stream = stream.map(_encode_prompt, batched=False)
    stream = stream.map(_remap_label, batched=False)
    stream = stream.rename_column('label', 'labels')

    kept_columns = ['input_ids', 'attention_mask', 'labels']
    return stream.select_columns(kept_columns)


###############################################################################


# TODO: Make this better? [see what lighteval did]
def _remove_trailing_punctuation(s: str) -> str:
    """Return `s` without any trailing characters found in `string.punctuation`."""
    end = len(s)
    # Walk backwards over the trailing run of punctuation characters.
    while end > 0 and s[end - 1] in string.punctuation:
        end -= 1
    return s[:end]


def encode_lm_mcqa_example(
    tokenizer: PreTrainedTokenizer,
    premise: str,
    hypothesis: str,
    max_length: int,
    *,
    template: str = _LM_MCQA_TEMPLATE,
) -> Dict[str, torch.Tensor]:
    """Render one premise/hypothesis pair as a prompt and tokenize it.

    Args:
        tokenizer: Tokenizer used to encode the prompt.
        premise: Text substituted for the template's `premise` field.
        hypothesis: Text substituted for the template's `hypothesis` field
            after its trailing punctuation has been removed.
        max_length: Maximum length handed to the tokenizer for truncation and
            to `preprocessing_common.single_truncate_and_pad`.
        template: Prompt template with `premise` and `hypothesis` fields.

    Returns:
        A dict with the keys 'input_ids' and 'attention_mask', each being the
        result of `preprocessing_common.single_truncate_and_pad` on the
        corresponding tokenizer output.
    """
    prompt = template.format(
        premise=premise,
        hypothesis=_remove_trailing_punctuation(hypothesis),
    )

    encoded = tokenizer(
        prompt,
        # Special tokens are disabled so that no EOS token ends up after the prompt.
        add_special_tokens=False,
        truncation=True,
        max_length=max_length,
        padding=True,
        return_tensors="pt",
    )

    # Fill value used for each returned field.
    fill_values = {
        'input_ids': tokenizer.pad_token_id,
        'attention_mask': 0,
    }
    return {
        field: preprocessing_common.single_truncate_and_pad(encoded[field], fill, max_length)
        for field, fill in fill_values.items()
    }


# TODO: Name this stuff and make it cleaner.
# SNLI class names, in the order assumed for the dataset's integer `label` field.
LABELS = ('entailment', 'neutral', 'contradiction')
# SNLI class name -> answer word used in the prompt.
LABELS_MAP = {
    'entailment': 'True',
    'neutral': 'Neither',
    'contradiction': 'False',
}
# Answer word -> option letter.
LABELS_MAP2 = {
    'True': 'A',
    'False': 'B',
    'Neither': 'C',
}
# Option letters; a letter's position here is its LM-MCQA label.
LABELS2 = ('A', 'B', 'C')

# Answer words indexed by LM-MCQA label.
_MCQA_LABELS = ('True', 'False', 'Neither')

# Lookup table from original label to LM-MCQA label; evaluates to (0, 2, 1).
_OG_TO_LM_MCQA_LABEL = tuple(
    LABELS2.index(LABELS_MAP2[LABELS_MAP[name]]) for name in LABELS
)


def encode_lm_mcqa_label(label: int) -> int:
    """Convert an original SNLI label into its LM-MCQA label.

    Args:
        label: Original label, an index into `LABELS`.

    Returns:
        The matching index into `_MCQA_LABELS`.

    Raises:
        IndexError: If `label` is not a valid index into `LABELS`. Negative
            values are rejected explicitly: tuple indexing would otherwise
            accept them, silently mapping the "no label" marker -1 to a real
            class.
    """
    if not 0 <= label < len(_OG_TO_LM_MCQA_LABEL):
        raise IndexError(
            f'The label must be in [0, {len(_OG_TO_LM_MCQA_LABEL)}). Instead received: {label}'
        )
    return _OG_TO_LM_MCQA_LABEL[label]


###############################################################################


class LmMcqaIclExampleHelper(icl_datasets_common.IclExampleHelperAbc):
    """ICL example helper producing SNLI prompts in the LM multiple-choice format."""

    def __init__(
        self, *,
        tokenizer: PreTrainedTokenizer,
        sequence_length: int,

        template: Optional[str] = None,
    ):
        """Store the prompt template and forward the rest to the base class.

        Args:
            tokenizer: Forwarded to the base class constructor.
            sequence_length: Forwarded to the base class constructor.
            template: Prompt template with `premise` and `hypothesis` fields;
                `_LM_MCQA_TEMPLATE` is used when this is None.
        """
        super().__init__(tokenizer=tokenizer, sequence_length=sequence_length)
        self.template = template if template is not None else _LM_MCQA_TEMPLATE

    def make_text_for_individual_example(self, example: Dict[str, Any], *, include_label: bool) -> str:
        """Render one example as prompt text, optionally followed by its answer.

        Args:
            example: Mapping with 'premise' and 'hypothesis' entries, plus a
                'label' entry when `include_label` is set.
            include_label: Whether to append a space and the answer word.

        Returns:
            The formatted prompt text.
        """
        prompt = self.template.format(
            premise=example['premise'],
            hypothesis=_remove_trailing_punctuation(example['hypothesis']),
        )
        if not include_label:
            return prompt

        answer = _MCQA_LABELS[encode_lm_mcqa_label(example['label'])]
        return f'{prompt} {answer}'

    def join_examples_text(self, texts: List[str]) -> str:
        """Concatenate the example texts, separating them with a blank line."""
        separator = '\n\n'
        return separator.join(texts)

    def encode_text(self, text: str) -> Dict[str, torch.Tensor]:
        """Tokenize `text` and return its 'input_ids' and 'attention_mask'.

        Reads `self.tokenizer` and `self.sequence_length` (presumably set by
        the base class constructor -- confirm against `IclExampleHelperAbc`).
        """
        limit = self.sequence_length
        encoded = self.tokenizer(
            text,
            # Special tokens are disabled so that no EOS token ends up after the text.
            add_special_tokens=False,
            truncation=True,
            max_length=limit,
            padding=True,
            return_tensors="pt",
        )

        # Fill value used for each returned field.
        fill_values = {
            'input_ids': self.tokenizer.pad_token_id,
            'attention_mask': 0,
        }
        return {
            field: preprocessing_common.single_truncate_and_pad(encoded[field], fill, limit)
            for field, fill in fill_values.items()
        }
