"""Functions returning HellaSwag datasets.

Every example appears to have exactly 4 candidate endings.

Sizes:
    train:      39905
    validation: 10042
    test:       10003
"""
import re
from typing import Any, Dict, Optional

import datasets
from lighteval.tasks.requests import Doc
import torch
from transformers import PreTrainedTokenizer

from npeff_torch.datasets import preprocessing_common

###############################################################################


def load_default_lm_suffix_mc_task(
    task: str,
    subtask: Optional[str],
    split: str,
    tokenizer: PreTrainedTokenizer,
    sequence_length: int,
):
    """Build a streaming HellaSwag dataset encoded for LM suffix multiple-choice scoring.

    Args:
        task: Task name; only "hellaswag" is accepted.
        subtask: Forwarded as the configuration name to
            ``datasets.load_dataset("Rowan/hellaswag", ...)``.
        split: Name of the split to return (indexed out of the loaded
            ``DatasetDict``).
        tokenizer: Tokenizer forwarded to ``encode_lm_suffix_mc_example``.
        sequence_length: Forwarded to ``encode_lm_suffix_mc_example``.

    Returns:
        An iterable dataset (from ``to_iterable_dataset``), lazily encoded one
        example at a time and restricted to the columns ``input_ids``,
        ``attention_mask``, ``labels`` and ``context_length``.

    Raises:
        ValueError: If ``task`` is not "hellaswag".
    """
    if task != 'hellaswag':
        raise ValueError(f'The task must be "hellaswag". Instead received: {task}')

    def _encode_one(row):
        # Closure over tokenizer / sequence_length; applied per example.
        return encode_lm_suffix_mc_example(tokenizer, row, sequence_length)

    dataset_dict = datasets.load_dataset("Rowan/hellaswag", subtask)
    streamed = dataset_dict[split].to_iterable_dataset()
    encoded = streamed.map(_encode_one, batched=False)

    kept_columns = ['input_ids', 'attention_mask', 'labels', 'context_length']
    return encoded.select_columns(kept_columns)

###############################################################################


def encode_lm_suffix_mc_example(
    tokenizer: PreTrainedTokenizer,
    example: Dict[str, Any],
    sequence_length: int
):
    """Encode a single raw HellaSwag row for LM suffix multiple-choice scoring.

    The row is first converted into a lighteval ``Doc`` by
    ``_prompt_hellaswag``, then passed to the shared
    ``preprocessing_common.encode_lm_suffix_mc_example`` encoder.

    Args:
        tokenizer: Tokenizer forwarded to the shared encoder.
        example: One raw HellaSwag row.
        sequence_length: Forwarded to the shared encoder.

    Returns:
        Whatever the shared encoder returns. The caller in this module selects
        the ``input_ids``, ``attention_mask``, ``labels`` and
        ``context_length`` columns from it, so it is presumably a mapping with
        those keys (confirm in ``preprocessing_common``).
    """
    prompt_doc = _prompt_hellaswag(example)
    encoded = preprocessing_common.encode_lm_suffix_mc_example(
        prompt_doc,
        tokenizer=tokenizer,
        sequence_length=sequence_length,
    )
    return encoded


# Matches any bracketed markup token (non-greedy), e.g. "[header]" or "[step]".
# Compiled once at import time instead of on every call.
_HELLASWAG_BRACKET_RE = re.compile(r"\[.*?\]")


def _preprocess_hellaswag_text(text: str) -> str:
    """Clean one HellaSwag text fragment (preprocessing from the AI Harness).

    Unlike some versions of this helper, the text is NOT stripped first (the
    strip step was disabled in the copied code), so leading and trailing
    whitespace is preserved.
    """
    # NOTE: Brackets are artifacts of the WikiHow dataset portion of HellaSwag.
    text = text.replace(" [title]", ". ")
    text = _HELLASWAG_BRACKET_RE.sub("", text)
    # Single pass only: a run of 3+ spaces is only partly collapsed. Presumably
    # left as-is so prompts match the reference implementation byte-for-byte.
    text = text.replace("  ", " ")
    return text


def _prompt_hellaswag(line, task_name: Optional[str] = None):
    """Build a lighteval ``Doc`` from one raw HellaSwag row.

    Taken from https://github.com/huggingface/smollm/blob/main/evaluation/tasks.py

    Args:
        line: A HellaSwag row; reads the ``ctx_a``, ``ctx_b``,
            ``activity_label``, ``endings`` and ``label`` fields.
        task_name: Forwarded unchanged to ``Doc(task_name=...)``.

    Returns:
        A ``Doc`` with:
          * ``query``: ``"<activity_label>: <ctx_a> <Ctx_b>"`` after cleaning,
            with trailing whitespace removed;
          * ``choices``: each cleaned ending prefixed with a single space;
          * ``gold_index``: ``int(label)``, or -1 when ``label`` is ``""``
            (the test split).
    """
    # str.capitalize() upper-cases the first character AND lower-cases the rest
    # of ctx_b.
    ctx = f"{line['ctx_a']} {line['ctx_b'].capitalize()} "
    query = _preprocess_hellaswag_text(line["activity_label"] + ": " + ctx).rstrip()
    choices = [" " + _preprocess_hellaswag_text(ending) for ending in line["endings"]]
    return Doc(
        task_name=task_name,
        query=query,
        choices=choices,
        gold_index=int(line["label"]) if line["label"] != "" else -1,  # -1 for test
    )
