"""Functions returning trivia-qa datasets.


From the HuggingFace dataset page:
    https://huggingface.co/datasets/mandarjoshi/trivia_qa

Number of examples per split, by configuration:
    name                 train   validation  test
    rc                   138384  18669       17210
    rc.nocontext         138384  18669       17210
    unfiltered           87622   11313       10832
    unfiltered.nocontext 87622   11313       10832

"""
from typing import Any, Dict, Optional

import datasets
from lighteval.tasks import default_prompts
from transformers import PreTrainedTokenizer

from npeff_torch.datasets import preprocessing_common


###############################################################################


def load_default_as_open_qa_task(
    task: str,
    subtask: Optional[str],
    split: str,
    tokenizer: PreTrainedTokenizer,
    sequence_length: int,
):
    """Build a tokenized, streaming TriviaQA split formatted as open QA prompts.

    Args:
        task: Must be exactly 'trivia_qa'; anything else is rejected.
        subtask: Config name passed to `datasets.load_dataset` for
            "mandarjoshi/trivia_qa" (e.g. 'rc', 'unfiltered.nocontext').
        split: Name of the split to select from the loaded dataset dict.
        tokenizer: Tokenizer used to encode each prompt.
        sequence_length: Target length passed to `encode_as_open_qa_example`.

    Returns:
        A `datasets.IterableDataset` whose examples contain only the
        'input_ids' and 'attention_mask' columns.

    Raises:
        ValueError: If `task` is not 'trivia_qa'.
    """
    if task != 'trivia_qa':
        raise ValueError(f'The task must be "trivia_qa". Instead received: {task}')

    def _encode(raw_example):
        # Closure over the tokenizer/length so `map` can call it one example
        # at a time (batched=False below).
        return encode_as_open_qa_example(tokenizer, raw_example, sequence_length)

    all_splits = datasets.load_dataset("mandarjoshi/trivia_qa", subtask)
    # Convert to an IterableDataset so the encoding below runs lazily as
    # examples are consumed.
    stream = all_splits[split].to_iterable_dataset()
    stream = stream.map(_encode, batched=False)

    # Drop the original TriviaQA fields; only the model inputs remain.
    return stream.select_columns(['input_ids', 'attention_mask'])


def encode_as_open_qa_example(
    tokenizer: PreTrainedTokenizer,
    example: Dict[str, Any],
    sequence_length: int
):
    """Encode one raw TriviaQA example as an open-QA prompt (no answer).

    The prompt text is produced by lighteval's `default_prompts.triviaqa`;
    only its `query` (with trailing whitespace removed) is tokenized.

    Args:
        tokenizer: Tokenizer used for encoding. Its `pad_token_id` is used
            as the fill value for `input_ids`.
        example: One raw row of the "mandarjoshi/trivia_qa" dataset, in the
            form expected by `default_prompts.triviaqa`.
        sequence_length: Maximum token count for tokenization, also passed
            to `preprocessing_common.single_truncate_and_pad`. Presumably
            that helper fixes the output to exactly this length (TODO
            confirm against its implementation).

    Returns:
        A dict with keys 'input_ids' and 'attention_mask', each the result
        of `preprocessing_common.single_truncate_and_pad` on the tokenizer
        output. 'input_ids' uses the pad token id as fill and
        'attention_mask' uses 0.
    """
    prompt = default_prompts.triviaqa(example).query.rstrip()

    encoded = tokenizer(
        prompt,
        return_tensors="pt",
        # Skip special tokens so that no EOS token is appended after the
        # prompt; the model is expected to continue generating from here.
        add_special_tokens=False,
        max_length=sequence_length,
        truncation=True,
        # For a single string this pads only to its own length, so no padding
        # is added. Kept as in the original so the tokenizer receives the same
        # arguments.
        padding=True,
    )

    input_ids = preprocessing_common.single_truncate_and_pad(
        encoded['input_ids'], tokenizer.pad_token_id, sequence_length,
    )
    attention_mask = preprocessing_common.single_truncate_and_pad(
        encoded['attention_mask'], 0, sequence_length,
    )
    return {'input_ids': input_ids, 'attention_mask': attention_mask}
