"""Function returning yahoo_answers_topics datasets compatible with some other places.

https://huggingface.co/datasets/community-datasets/yahoo_answers_topics
"""
from typing import Any, Dict, List, Optional, Union

import datasets
import torch
from transformers import PreTrainedTokenizer

from npeff_torch.datasets import preprocessing_common


###############################################################################

# The index of the label here is the integral representation of each label.
LABELS = ('Society & Culture', 'Science & Mathematics', 'Health', 'Education & Reference', 'Computers & Internet', 'Sports', 'Business & Finance', 'Entertainment & Music', 'Family & Relationships', 'Politics & Government')

_NL_LABELS = "\n".join(LABELS)

_SUBTASK_TO_TEMPLATE = {
    'prompt_1': f'Question: {{question}}\nWhat broad topic is this question about? Choose from:\n{_NL_LABELS}.\nTopic:',
}


def load_default_lm_mcqa_task(
    task: str,
    # The subtask can be used to specify a pre-set template. It must be None if template is provided.
    subtask: Optional[str],
    split: str,
    tokenizer: PreTrainedTokenizer,
    sequence_length: int,
    *,
    template: Optional[str] = None,
):
    """Load one split of yahoo_answers_topics as a tokenized iterable dataset.

    Every example is encoded with `encode_lm_mcqa_example` using the chosen
    prompt template, and its `topic` field is exposed as `labels`.

    Args:
        task: Must be 'yahoo_answers_topics'.
        subtask: Key of a pre-set template in `_SUBTASK_TO_TEMPLATE`. Must be
            None when `template` is provided.
        split: Name of the dataset split to read.
        tokenizer: Tokenizer used to encode each prompt.
        sequence_length: Forwarded to `encode_lm_mcqa_example`.
        template: Optional custom prompt template; it is formatted with a
            `question` keyword. When given, no pre-set template is looked up.

    Returns:
        An iterable dataset restricted to the columns 'input_ids',
        'attention_mask' and 'labels'.

    Raises:
        ValueError: If `task` is not 'yahoo_answers_topics', or if both
            `subtask` and `template` are provided.
        KeyError: If `template` is None and `subtask` is not a known pre-set
            template name.
    """
    if task != 'yahoo_answers_topics':
        raise ValueError(f'The task must be "yahoo_answers_topics". Instead received: {task}')

    if template is not None:
        # Explicit raise rather than `assert`: asserts are stripped under
        # `python -O`, which would let a conflicting subtask pass silently.
        if subtask is not None:
            raise ValueError(
                f'The subtask must be None when a template is provided. Instead received: {subtask}'
            )
    elif subtask not in _SUBTASK_TO_TEMPLATE:
        # Still a KeyError, as a plain dict lookup would raise, but with a
        # message that says what went wrong and which names are valid.
        raise KeyError(
            f'Unknown subtask {subtask!r}. Provide a template or one of: {sorted(_SUBTASK_TO_TEMPLATE)}'
        )
    else:
        template = _SUBTASK_TO_TEMPLATE[subtask]

    base_ds = datasets.load_dataset("community-datasets/yahoo_answers_topics")
    ds = base_ds[split].to_iterable_dataset()

    # Tokenize the prompt built from each example.
    ds = ds.map(
        lambda ex: encode_lm_mcqa_example(tokenizer, ex, sequence_length, template=template),
        batched=False,
    )
    # Expose the dataset's `topic` field under the name `labels`.
    ds = ds.map(
        lambda ex: {'labels': ex['topic']},
        batched=False,
    )
    ds = ds.select_columns(['input_ids', 'attention_mask', 'labels'])

    return ds


###############################################################################


def encode_lm_mcqa_example(
    tokenizer: PreTrainedTokenizer,
    example: Dict[str, Any],
    sequence_length: int,
    *,
    template: str,
) -> Dict[str, torch.Tensor]:
    """Render one example into a prompt and tokenize it.

    The example's `question_title` is stripped of surrounding whitespace and
    substituted for `question` in `template`. The resulting text is tokenized
    without special tokens and truncated to at most `sequence_length` tokens.

    Args:
        tokenizer: Tokenizer called on the prompt; its `pad_token_id` is the
            fill value handed to the padding helper for `input_ids`.
        example: A dataset row; only `question_title` is read.
        sequence_length: Maximum token count, also forwarded to
            `preprocessing_common.single_truncate_and_pad`.
        template: Prompt template formatted with a `question` keyword.

    Returns:
        A dict with 'input_ids' and 'attention_mask', each being the output of
        `preprocessing_common.single_truncate_and_pad` (presumably fixed to
        `sequence_length` -- confirm against that helper).
    """
    prompt = template.format(question=example['question_title'].strip())

    # add_special_tokens=False ensures that no EOS token gets appended to the end.
    encoded = tokenizer(
        prompt,
        add_special_tokens=False,
        truncation=True,
        max_length=sequence_length,
        padding=True,
        return_tensors="pt",
    )

    # Fill value used for each output field when it is padded.
    fill_values = {
        'input_ids': tokenizer.pad_token_id,
        'attention_mask': 0,
    }
    return {
        field: preprocessing_common.single_truncate_and_pad(encoded[field], fill, sequence_length)
        for field, fill in fill_values.items()
    }
