"""Function returning SNLI datasets compatible with some other places."""
import string
from typing import Dict, Optional

import datasets
import torch
from transformers import PreTrainedTokenizer

from npeff_torch.datasets import preprocessing_common


###############################################################################
###############################################################################


# Prompt template adapted from the RTE template in https://arxiv.org/pdf/2102.09690
#
# That source lower-cases "question" and "answer", so the same casing is kept here.
# The `{premise}` and `{hypothesis}` placeholders are filled in with `str.format`.
# The template itself supplies the "." after the hypothesis.
BINARIZED_LM_MCQA_TEMPLATE = "{premise}\nquestion: {hypothesis}. True or False?\nanswer:"


# Labels will be (' True', ' False')
def load_default_binarized_lm_mcqa_task(
    task: str,
    subtask: Optional[str],
    split: str,
    tokenizer: PreTrainedTokenizer,
    sequence_length: int,
    *,
    template: str = BINARIZED_LM_MCQA_TEMPLATE,
):
    """Return an iterable SNLI split encoded as binarized True/False prompts.

    Args:
        task: Name of the task; only ``'snli'`` is accepted.
        subtask: Forwarded to ``datasets.load_dataset`` as its second
            positional argument.
        split: Key of the split to select from the loaded dataset.
        tokenizer: Tokenizer used to encode each prompt.
        sequence_length: Forwarded as ``max_length`` when encoding each prompt.
        template: Prompt template with ``{premise}`` and ``{hypothesis}``
            placeholders.

    Returns:
        An iterable dataset restricted to the ``input_ids``,
        ``attention_mask`` and ``labels`` columns.

    Raises:
        ValueError: If ``task`` is not ``'snli'``.
    """
    if task != 'snli':
        raise ValueError(f'The task must be "snli". Instead received: {task}')

    def _has_label(example):
        return example['label'] != -1

    def _encode_prompt(example):
        return encode_binarized_lm_mcqa_example(
            tokenizer,
            example['premise'],
            example['hypothesis'],
            sequence_length,
            template=template,
        )

    def _binarize_label(example):
        return {'label': encode_binarized_lm_mcqa_label(example['label'])}

    all_splits = datasets.load_dataset("stanfordnlp/snli", subtask)
    stream = all_splits[split].to_iterable_dataset()

    # Drop examples that carry no label (label id -1).
    # NOTE: I think somewhere between 700-800 examples in the train set with no labels.
    stream = stream.filter(_has_label)

    # First attach the tokenized prompt, then overwrite the 3-way label with
    # its binarized class id.
    stream = stream.map(_encode_prompt, batched=False)
    stream = stream.map(_binarize_label, batched=False)

    stream = stream.rename_column('label', 'labels')
    return stream.select_columns(['input_ids', 'attention_mask', 'labels'])


###############################################################################


# TODO: Make this better? [see what lighteval did]
def _remove_trailing_punctuation(s: str) -> str:
    return s.rstrip(string.punctuation)


def encode_binarized_lm_mcqa_example(
    tokenizer: PreTrainedTokenizer,
    premise: str,
    hypothesis: str,
    max_length: int,
    *,
    template: str = BINARIZED_LM_MCQA_TEMPLATE,
) -> Dict[str, torch.Tensor]:
    """Render one premise/hypothesis pair into a prompt and tokenize it.

    The hypothesis is first passed through ``_remove_trailing_punctuation``,
    then both strings are substituted into ``template`` with ``str.format``.

    Args:
        tokenizer: Tokenizer called on the rendered prompt.
        premise: Text substituted for ``{premise}``.
        hypothesis: Text substituted for ``{hypothesis}``.
        max_length: Passed to the tokenizer as ``max_length`` and to
            ``preprocessing_common.single_truncate_and_pad``.
        template: Prompt template with ``{premise}`` and ``{hypothesis}``
            placeholders.

    Returns:
        A dict with ``'input_ids'`` and ``'attention_mask'``, each being the
        result of ``preprocessing_common.single_truncate_and_pad``.
    """
    prompt = template.format(
        premise=premise,
        hypothesis=_remove_trailing_punctuation(hypothesis),
    )

    # NOTE(review): `truncation=True` uses the tokenizer's configured
    # truncation side; if that is the right side, an over-long prompt would
    # lose its trailing "answer:" — confirm this cannot happen for the
    # sequence lengths in use.
    encoded = tokenizer(
        prompt,
        return_tensors="pt",
        # Skipping special tokens ensures that no EOS token gets appended to the end.
        add_special_tokens=False,
        max_length=max_length,
        truncation=True,
        padding=True,
    )

    # Per-field value handed to the helper (presumably the padding value —
    # verify against `single_truncate_and_pad`).
    fill_values = {
        'input_ids': tokenizer.pad_token_id,
        'attention_mask': 0,
    }
    return {
        field: preprocessing_common.single_truncate_and_pad(encoded[field], fill, max_length)
        for field, fill in fill_values.items()
    }


# Original three-way label names, indexed by the integer label id.
# NOTE(review): assumed to match the label order of the loaded SNLI dataset — confirm.
_OG_LABELS = ('entailment', 'neutral', 'contradiction')
# Only entailment counts as 'True'; neutral and contradiction both collapse to 'False'.
_OG_TO_BINARIZED_LABEL = {
    'entailment': 'True',
    'neutral': 'False',
    'contradiction': 'False',
}
# The position in this tuple is the binarized class id (0 -> 'True', 1 -> 'False').
_BINARIZED_LABEL = ('True', 'False')


def encode_binarized_lm_mcqa_label(label: int) -> int:
    """Map a three-way label id to its binarized True/False class id.

    Args:
        label: Index into ``_OG_LABELS`` (0, 1 or 2).

    Returns:
        ``0`` when the label is 'entailment' ('True'), ``1`` when it is
        'neutral' or 'contradiction' ('False').

    Raises:
        IndexError: If ``label`` is outside ``0 .. len(_OG_LABELS) - 1``.
    """
    # Reject negative ids explicitly: Python's negative indexing would
    # otherwise wrap around, so the "no label" id -1 would silently be
    # treated as 'contradiction'.
    if not 0 <= label < len(_OG_LABELS):
        raise IndexError(
            f'The label must be in the range [0, {len(_OG_LABELS) - 1}]. '
            f'Instead received: {label}'
        )
    og_label = _OG_LABELS[label]
    return _BINARIZED_LABEL.index(_OG_TO_BINARIZED_LABEL[og_label])
