"""Function returning PILE datasets compatible with some other places."""
from typing import Dict, Optional

import datasets
import torch
from transformers import PreTrainedTokenizer

###############################################################################

_TASK_TO_DATASET_NAME = {
    'pile-uncopyrighted': "monology/pile-uncopyrighted"
}

###############################################################################


def load_default_language_modeling_task(
    task: str,
    subtask: Optional[str],
    split: str,
    tokenizer: PreTrainedTokenizer,
    sequence_length: int,
    *,
    skip: Optional[int] = None,
    # Intended for debugging.
    select_relevant_columns: bool = True,
):
    """Build a streaming, tokenized language-modeling dataset for `task`.

    The dataset is loaded in streaming mode, examples whose text is empty
    after stripping whitespace are dropped, and every remaining example is
    tokenized with `encode_example`.

    Args:
        task: Key into `_TASK_TO_DATASET_NAME` selecting the dataset.
        subtask: Forwarded to `datasets.load_dataset` as its second
            positional argument; may be None.
        split: Name of the split to take from the loaded dataset.
        tokenizer: Tokenizer used to encode each example's text.
        sequence_length: Length that every encoded example is truncated or
            padded to.
        skip: If not None, number of examples to skip. Skipping happens
            after the empty-text filter and before tokenization.
        select_relevant_columns: If True, keep only the `input_ids`,
            `attention_mask` and `labels` columns. Disable to keep every
            column, e.g. when debugging.

    Returns:
        The processed streaming dataset for the requested split.

    Raises:
        ValueError: If `task` is not a known task name.
    """
    dataset_name = _TASK_TO_DATASET_NAME.get(task)
    if dataset_name is None:
        raise ValueError(f'Invalid task name: {task}')

    def _has_text(example):
        # Keep only examples that contain something besides whitespace.
        return bool(example['text'].strip())

    def _tokenize(example):
        return encode_example(tokenizer, example['text'], sequence_length)

    stream = datasets.load_dataset(dataset_name, subtask, streaming=True)[split]
    stream = stream.filter(_has_text)
    if skip is not None:
        stream = stream.skip(skip)
    stream = stream.map(_tokenize, batched=False)

    if not select_relevant_columns:
        return stream
    return stream.select_columns(['input_ids', 'attention_mask', 'labels'])


###############################################################################


def _truncate_and_pad(x: torch.Tensor, pad_token_id: int, sequence_length: int) -> torch.Tensor:
    # For whatever reason, a dummy batch dimension of 1 gets added. Remove it.
    x = torch.squeeze(x, dim=0)
    x = x[:sequence_length]
    if len(x) < sequence_length:
        n_padding = sequence_length - len(x)
        padding = pad_token_id * torch.ones([n_padding], dtype=x.dtype, device=x.device)
        x = torch.cat([x, padding], dim=-1)
    return x


def encode_example(tokenizer: PreTrainedTokenizer, text: str, max_length: int) -> Dict[str, torch.Tensor]:
    """Tokenize `text` into fixed-length tensors for next-token prediction.

    Args:
        tokenizer: Tokenizer called on `text`; its `pad_token_id` is used
            for padding.
        text: The raw text of a single example.
        max_length: Length every returned tensor is truncated or padded to.

    Returns:
        A dict with three int32 tensors:
        `input_ids` (token ids padded with the tokenizer's pad token id),
        `attention_mask` (padded with 0), and `labels` (`input_ids` shifted
        left by one position, with the last position set to the pad token
        id).
    """
    encoded = tokenizer(text, return_tensors="pt", max_length=max_length, padding=True, truncation=True)

    input_ids = _truncate_and_pad(encoded['input_ids'], tokenizer.pad_token_id, max_length).to(torch.int32)
    attention_mask = _truncate_and_pad(encoded['attention_mask'], 0, max_length).to(torch.int32)

    # The label at position i is the token at position i + 1; the final
    # position has no successor and receives the pad token id.
    labels = torch.zeros_like(input_ids)
    labels[:-1] = input_ids[1:]
    labels[-1] = tokenizer.pad_token_id

    return {
        'input_ids': input_ids,
        'attention_mask': attention_mask,
        'labels': labels,
    }
