"""Function returning SST2 datasets compatible with some other places.

https://huggingface.co/datasets/stanfordnlp/sst2

Sizes:
    train   validation  test
    67349   872         1821

"""
import json
import os
import string
from typing import Any, Dict, List, Optional, Union

import datasets
import torch
from transformers import PreTrainedTokenizer

from npeff_torch.datasets import preprocessing_common
from npeff_torch.icl import icl_datasets_common


###############################################################################
###############################################################################

# Prompt template for a single SST2 example, following the SST2 prompt format
# from https://arxiv.org/pdf/2102.09690. `{sentence}` is filled in with the
# stripped review text. When a label is appended (as is done for ICL examples
# in this file), it comes after "Sentiment:" separated by a single space.
LM_MCQA_TEMPLATE = "Review: {sentence}\nSentiment:"


# Labels will be (' Negative', ' Positive')
def load_default_lm_mcqa_task(
    task: str,
    subtask: Optional[str],
    split: str,
    tokenizer: PreTrainedTokenizer,
    sequence_length: int,
    *,
    template: str = LM_MCQA_TEMPLATE,
):
    """Load a split of SST2 as an iterable dataset of tokenized LM prompts.

    Each example's sentence is rendered through ``template`` and tokenized by
    ``encode_lm_mcqa_example``; its label is mapped through
    ``encode_lm_mcqa_label`` to an index into ``('Negative', 'Positive')``.

    Args:
        task: Must be ``'sst2'``; any other value is rejected.
        subtask: Forwarded as the second positional argument of
            ``datasets.load_dataset``.
        split: Key of the split to read from the loaded dataset.
        tokenizer: Tokenizer used to encode the rendered prompt.
        sequence_length: Length handed to the tokenizer as ``max_length`` and
            to the truncate-and-pad helper.
        template: Format string containing a ``{sentence}`` placeholder.

    Returns:
        An iterable dataset restricted to the columns ``input_ids``,
        ``attention_mask`` and ``labels``.

    Raises:
        ValueError: If ``task`` is not ``'sst2'``.
    """
    if task != 'sst2':
        raise ValueError(f'The task must be "sst2". Instead received: {task}')

    def _encode_inputs(ex):
        # Produces the 'input_ids' and 'attention_mask' columns.
        return encode_lm_mcqa_example(tokenizer, ex, sequence_length, template=template)

    def _encode_label(ex):
        # Overwrites 'label' with its index among the LM answer choices.
        return {'label': encode_lm_mcqa_label(ex['label'])}

    all_splits = datasets.load_dataset("stanfordnlp/sst2", subtask)
    stream = all_splits[split].to_iterable_dataset()

    stream = stream.map(_encode_inputs, batched=False)
    stream = stream.map(_encode_label, batched=False)
    stream = stream.rename_column('label', 'labels')

    # Drop the raw text and any other columns that are not model inputs.
    return stream.select_columns(['input_ids', 'attention_mask', 'labels'])


###############################################################################

def encode_lm_mcqa_example(
    tokenizer: PreTrainedTokenizer,
    example: Dict[str, Any],
    sequence_length: int,
    *,
    template: str = LM_MCQA_TEMPLATE,
) -> Dict[str, torch.Tensor]:
    """Render one SST2 example through ``template`` and tokenize the prompt.

    Args:
        tokenizer: Tokenizer called on the rendered prompt.
        example: Mapping with a ``'sentence'`` entry; surrounding whitespace
            of the sentence is stripped before templating.
        sequence_length: Passed to the tokenizer as ``max_length`` and to
            ``preprocessing_common.single_truncate_and_pad``.
        template: Format string containing a ``{sentence}`` placeholder.

    Returns:
        A dict with ``'input_ids'`` and ``'attention_mask'``, each being the
        output of ``single_truncate_and_pad`` (presumably fixed to
        ``sequence_length`` -- confirm against that helper).
    """
    prompt = template.format(sentence=example['sentence'].strip())

    # NOTE(review): with truncation enabled, a prompt longer than
    # `sequence_length` may lose its trailing "Sentiment:" cue if the
    # tokenizer truncates on the right -- confirm the tokenizer's setting.
    encoded = tokenizer(
        prompt,
        return_tensors="pt",
        # Special tokens are disabled so that no EOS token ends up after the prompt.
        add_special_tokens=False,
        max_length=sequence_length,
        truncation=True,
        padding=True,
    )

    # Fill value used by the truncate-and-pad helper for each returned field.
    fill_values = {
        'input_ids': tokenizer.pad_token_id,
        'attention_mask': 0,
    }
    return {
        field: preprocessing_common.single_truncate_and_pad(encoded[field], fill, sequence_length)
        for field, fill in fill_values.items()
    }


# Original class names, indexed by the integer label stored in the dataset.
_OG_LABELS = ('negative', 'positive')
# Maps each original class name to the answer string used in the LM prompt.
_OG_TO_BINARIZED_LABEL = {
    'negative': 'Negative',
    'positive': 'Positive',
}
# Answer choices of the LM MCQA task; encoded labels index into this tuple.
_LM_MCQA_LABEL = ('Negative', 'Positive')


def encode_lm_mcqa_label(label: int) -> int:
    """Map a dataset label to its index among the LM answer choices.

    Args:
        label: Integer class index into ``('negative', 'positive')``.

    Returns:
        The index of the corresponding answer in ``_LM_MCQA_LABEL``.

    Raises:
        ValueError: If ``label`` is negative, e.g. a ``-1`` placeholder used
            for examples without a gold label.
        IndexError: If ``label`` is past the last known class.
    """
    # Negative values would otherwise be accepted through Python's negative
    # tuple indexing, silently turning -1 into 'positive' and -2 into
    # 'negative'. Reject them explicitly instead of fabricating a label.
    if label < 0:
        raise ValueError(
            f'The label must be a non-negative class index. Instead received: {label}'
        )
    og_label = _OG_LABELS[label]
    binarized_label = _OG_TO_BINARIZED_LABEL[og_label]
    return _LM_MCQA_LABEL.index(binarized_label)


###############################################################################
###############################################################################

# Labels will be (' Negative', ' Positive')
def load_default_lm_mcqa_task_with_icl(
    task: str,
    subtask: Optional[str],
    split: str,
    tokenizer: PreTrainedTokenizer,
    sequence_length: int,
    *,
    template: str = LM_MCQA_TEMPLATE,
    example_separator: str = '\n\n',
    # Will use the icl context with the highest score from this file(s). This file(s) should
    # have been created by the lm_mcqa_icl_example_selection.py script.
    icl_context_score_filepath: Union[str, List[str]],
):
    """Load a split of SST2 with a fixed ICL context prepended to each prompt.

    The ICL context is the ``'icl_context'`` entry of the highest-``'score'``
    record found across all score files. The winning score is printed to
    stdout.

    Args:
        task: Must be ``'sst2'``; any other value is rejected.
        subtask: Forwarded as the second positional argument of
            ``datasets.load_dataset``.
        split: Key of the split to read from the loaded dataset.
        tokenizer: Tokenizer used to encode the full prompt.
        sequence_length: Length handed to the tokenizer as ``max_length`` and
            to the truncate-and-pad helper.
        template: Format string containing a ``{sentence}`` placeholder.
        example_separator: Text placed between the ICL context and the
            rendered query example.
        icl_context_score_filepath: One path or a list of paths to JSON
            files, each holding a list of records with ``'score'`` and
            ``'icl_context'`` entries. ``~`` is expanded in each path.

    Returns:
        An iterable dataset restricted to the columns ``input_ids``,
        ``attention_mask`` and ``labels``.

    Raises:
        ValueError: If ``task`` is not ``'sst2'``, or if the score files
            contain no records to choose from.
    """
    if task != 'sst2':
        raise ValueError(f'The task must be "sst2". Instead received: {task}')

    # A single path (plain string or path-like object) is treated as a
    # one-element list.
    if isinstance(icl_context_score_filepath, (str, os.PathLike)):
        icl_context_score_filepath = [icl_context_score_filepath]

    score_results = []
    for filepath in icl_context_score_filepath:
        # JSON is read as UTF-8 rather than relying on the platform's locale
        # encoding, which can mis-decode non-ASCII review text.
        with open(os.path.expanduser(filepath), 'rt', encoding='utf-8') as f:
            score_results.extend(json.load(f))

    # Fail with an actionable message instead of max()'s generic
    # "arg is an empty sequence" error.
    if not score_results:
        raise ValueError(
            f'No ICL context scores were found in: {icl_context_score_filepath}'
        )

    best_result = max(score_results, key=lambda r: r['score'])
    print(best_result['score'])
    icl_context = best_result['icl_context']

    base_ds = datasets.load_dataset("stanfordnlp/sst2", subtask)

    ds = base_ds[split].to_iterable_dataset()

    ds = ds.map(
        lambda ex: encode_lm_mcqa_task_with_icl(
            tokenizer, ex, sequence_length,
            template=template, example_separator=example_separator, icl_context=icl_context),
        batched=False,
    )
    ds = ds.map(
        lambda ex: {'label': encode_lm_mcqa_label(ex['label'])},
        batched=False,
    )
    ds = ds.rename_column('label', 'labels')
    ds = ds.select_columns(['input_ids', 'attention_mask', 'labels'])

    return ds


def encode_lm_mcqa_task_with_icl(
    tokenizer: PreTrainedTokenizer,
    example: Dict[str, Any],
    sequence_length: int,
    *,
    template: str = LM_MCQA_TEMPLATE,
    example_separator: str = '\n\n',
    icl_context: str,
) -> Dict[str, torch.Tensor]:
    """Tokenize one SST2 example with an ICL context placed in front of it.

    The prompt is ``icl_context``, then ``example_separator``, then the
    example's stripped sentence rendered through ``template``.

    Args:
        tokenizer: Tokenizer called on the assembled prompt.
        example: Mapping with a ``'sentence'`` entry.
        sequence_length: Passed to the tokenizer as ``max_length`` and to
            ``preprocessing_common.single_truncate_and_pad``.
        template: Format string containing a ``{sentence}`` placeholder.
        example_separator: Text inserted between the ICL context and the
            rendered query.
        icl_context: Text of the in-context examples.

    Returns:
        A dict with ``'input_ids'`` and ``'attention_mask'``, each being the
        output of ``single_truncate_and_pad`` (presumably fixed to
        ``sequence_length`` -- confirm against that helper).
    """
    query = template.format(sentence=example['sentence'].strip())
    prompt = f'{icl_context}{example_separator}{query}'

    # NOTE(review): if the ICL context alone approaches `sequence_length`,
    # right-side truncation would cut into the query itself -- confirm the
    # tokenizer's truncation side.
    encoded = tokenizer(
        prompt,
        return_tensors="pt",
        # Special tokens are disabled so that no EOS token ends up after the prompt.
        add_special_tokens=False,
        max_length=sequence_length,
        truncation=True,
        padding=True,
    )

    # Fill value used by the truncate-and-pad helper for each returned field.
    fill_values = {
        'input_ids': tokenizer.pad_token_id,
        'attention_mask': 0,
    }
    return {
        field: preprocessing_common.single_truncate_and_pad(encoded[field], fill, sequence_length)
        for field, fill in fill_values.items()
    }

###############################################################################
###############################################################################


class LmMcqaIclExampleHelper(icl_datasets_common.IclExampleHelperAbc):
    """Builds and tokenizes SST2 in-context-learning example text.

    Renders individual examples through a prompt template, joins several
    rendered examples into one context string, and tokenizes text with the
    same tokenizer settings as the dataset encoders in this module.
    """

    def __init__(
        self, *,
        tokenizer: PreTrainedTokenizer,
        sequence_length: int,

        template: Optional[str] = None,
    ):
        """Store the prompt template and forward the rest to the base class.

        Args:
            tokenizer: Forwarded to the base class constructor.
            sequence_length: Forwarded to the base class constructor.
            template: Format string containing a ``{sentence}`` placeholder;
                ``None`` selects ``LM_MCQA_TEMPLATE``.
        """
        super().__init__(tokenizer=tokenizer, sequence_length=sequence_length)
        if template is None:
            template = LM_MCQA_TEMPLATE
        self.template = template

    def make_text_for_individual_example(self, example: Dict[str, Any], *, include_label: bool) -> str:
        """Render one example, optionally followed by its answer string.

        Args:
            example: Mapping with a ``'sentence'`` entry and, when
                ``include_label`` is true, a ``'label'`` entry.
            include_label: Whether to append a space and the answer
                (``'Negative'`` or ``'Positive'``) after the prompt.

        Returns:
            The rendered text.
        """
        prompt = self.template.format(sentence=example['sentence'].strip())
        if not include_label:
            return prompt
        answer = _LM_MCQA_LABEL[encode_lm_mcqa_label(example['label'])]
        return f'{prompt} {answer}'

    def join_examples_text(self, texts: List[str]) -> str:
        """Concatenate rendered examples, separated by a blank line."""
        separator = '\n\n'
        return separator.join(texts)

    def encode_text(self, text: str) -> Dict[str, torch.Tensor]:
        """Tokenize ``text`` and run the result through truncate-and-pad.

        Reads ``self.tokenizer`` and ``self.sequence_length``, which are
        presumably set by the base class constructor -- confirm there.

        Returns:
            A dict with ``'input_ids'`` and ``'attention_mask'``.
        """
        tokenizer = self.tokenizer
        target_length = self.sequence_length

        encoded = tokenizer(
            text,
            return_tensors="pt",
            # Special tokens are disabled so that no EOS token ends up after the text.
            add_special_tokens=False,
            max_length=target_length,
            truncation=True,
            padding=True,
        )

        input_ids = preprocessing_common.single_truncate_and_pad(
            encoded['input_ids'], tokenizer.pad_token_id, target_length)
        attention_mask = preprocessing_common.single_truncate_and_pad(
            encoded['attention_mask'], 0, target_length)
        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
        }
