"""Functions returning Winogrande datasets.


From the HuggingFace dataset page:
    https://huggingface.co/datasets/allenai/winogrande

Sizes:
    name                train   validation  test
    winogrande_debiased 9248    1267        1767
    winogrande_l        10234   1267        1767
    winogrande_m        2558    1267        1767
    winogrande_s        640     1267        1767
    winogrande_xl       40398   1267        1767
    winogrande_xs       160     1267        1767

I think (unverified) that every subtask above is a subset of the "winogrande_xl" subtask.

"""
from typing import Any, Dict, Optional

import datasets
from lighteval.tasks import default_prompts
import torch
from transformers import PreTrainedTokenizer

from npeff_torch.datasets import preprocessing_common


###############################################################################

# Names accepted for the `subtask` argument. Each one is the common
# "winogrande_" prefix followed by a variant suffix.
_VALID_SUBTASKS = tuple(
    f'winogrande_{variant}'
    for variant in ('debiased', 'l', 'm', 's', 'xl', 'xs')
)

###############################################################################


def load_default_lm_suffix_mc_task(
    task: str,
    subtask: str,
    split: str,
    tokenizer: PreTrainedTokenizer,
    sequence_length: int,
):
    """Load one split of a Winogrande subtask as an encoded iterable dataset.

    The requested configuration of "allenai/winogrande" is loaded, the chosen
    split is converted to an iterable dataset, every example is passed through
    `encode_lm_suffix_mc_example`, and only the encoded columns are kept.

    Args:
        task: Must be exactly "winogrande".
        subtask: Dataset configuration name; must be in `_VALID_SUBTASKS`.
        split: Key used to pick the split out of the loaded dataset.
        tokenizer: Tokenizer forwarded to the per-example encoder.
        sequence_length: Sequence length forwarded to the per-example encoder.

    Returns:
        An iterable dataset restricted to the columns "input_ids",
        "attention_mask", "labels" and "context_length".

    Raises:
        ValueError: If `task` is not "winogrande" or `subtask` is not valid.
    """
    if task != 'winogrande':
        raise ValueError(f'The task must be "winogrande". Instead received: {task}')
    if subtask not in _VALID_SUBTASKS:
        raise ValueError(f'Invalid subtask: {subtask}')

    def _encode(example):
        # Applied one example at a time (the map below is not batched).
        return encode_lm_suffix_mc_example(tokenizer, example, sequence_length)

    split_ds = datasets.load_dataset("allenai/winogrande", subtask)[split]
    kept_columns = ['input_ids', 'attention_mask', 'labels', 'context_length']

    return (
        split_ds.to_iterable_dataset()
        .map(_encode, batched=False)
        .select_columns(kept_columns)
    )


def encode_lm_suffix_mc_example(
    tokenizer: PreTrainedTokenizer,
    example: Dict[str, Any],
    sequence_length: int
):
    """Encode a single raw Winogrande example.

    The example is first turned into a prompt document with lighteval's
    default Winogrande prompt, then handed to the shared
    `preprocessing_common.encode_lm_suffix_mc_example` helper.

    Args:
        tokenizer: Tokenizer forwarded to the shared encoder.
        example: One raw dataset row.
        sequence_length: Sequence length forwarded to the shared encoder.

    Returns:
        Whatever the shared encoder returns. Presumably a mapping that
        includes "input_ids", "attention_mask", "labels" and
        "context_length", since the loader in this module selects those
        columns -- TODO confirm against `preprocessing_common`.
    """
    prompt_doc = default_prompts.winogrande(example)
    encoded = preprocessing_common.encode_lm_suffix_mc_example(
        prompt_doc,
        tokenizer=tokenizer,
        sequence_length=sequence_length,
    )
    return encoded