"""Common things for preprocessing datasets."""
from lighteval.tasks.requests import Doc
import torch
from transformers import PreTrainedTokenizer


###############################################################################


def encode_lm_suffix_mc_example(
    doc: Doc,
    tokenizer: PreTrainedTokenizer,
    sequence_length: int,
):
    """Encode ``doc`` as a multiple-choice example scored via LM suffix.

    One string holding only the right-stripped query is tokenized together
    with one ``query + choice`` string per choice, as a single padded batch.
    A single space separates query and choice unless the choice already
    starts with a space.

    Args:
        doc: lighteval document; ``query``, ``choices`` and ``gold_index`` are
            read from it.
        tokenizer: Hugging Face tokenizer used for encoding; its
            ``pad_token_id`` fills the extra right padding.
        sequence_length: Width every returned row is truncated/padded to.

    Returns:
        A dict with:
            ``input_ids`` / ``attention_mask``: one row per choice (the
                query-only row is dropped), each ``sequence_length`` wide.
            ``context_length``: 0-d tensor counting the non-zero attention
                mask entries of the query-only row, i.e. its token count.
            ``labels``: ``doc.gold_index`` passed through unchanged.
    """
    # With the tokenizers tried so far, a leading space is attached to the
    # following token. Stripping the query's trailing whitespace and joining
    # with one space is meant to keep the query's tokens a prefix of every
    # full sentence.
    # NOTE(review): this is an assumption about the tokenizer, not checked.
    prompt = doc.query.rstrip()

    # TODO: Maybe make the separator configurable (e.g. a newline) if some
    # tasks need that.
    candidates = [
        f'{prompt}{"" if choice.startswith(" ") else " "}{choice}'
        for choice in doc.choices
    ]

    batch = tokenizer(
        [prompt, *candidates],
        return_tensors="pt",
        # No special tokens, so no EOS token is appended to the end.
        add_special_tokens=False,
        max_length=sequence_length,
        truncation=True,
        # Pads to the longest string in the batch.
        # NOTE(review): assumes right-side padding; the helper below always
        # appends its extra padding on the right.
        padding=True,
    )

    # Row 0 is the bare query; the remaining rows are the choices.
    choice_ids = batch['input_ids'][1:]
    choice_mask = batch['attention_mask'][1:]

    prompt_len = batch['attention_mask'][0].ne(0).to(choice_ids.dtype).sum()

    return {
        'input_ids': _lm_suffix_mc_truncate_and_pad(choice_ids, tokenizer.pad_token_id, sequence_length),
        'attention_mask': _lm_suffix_mc_truncate_and_pad(choice_mask, 0, sequence_length),
        'context_length': prompt_len,
        'labels': doc.gold_index,
    }


def _lm_suffix_mc_truncate_and_pad(x: torch.Tensor, pad_token_id: int, sequence_length: int) -> torch.Tensor:
    """Truncate or right-pad every row of a 2-D tensor to ``sequence_length``.

    Args:
        x: Tensor of shape ``[n_options, sequence]``.
        pad_token_id: Value written into the appended padding columns.
        sequence_length: Width of the returned tensor.

    Returns:
        Tensor of shape ``[n_options, sequence_length]`` with ``x``'s dtype
        and device. Columns beyond ``sequence_length`` are dropped; missing
        columns are filled with ``pad_token_id`` on the right.

    Raises:
        ValueError: If ``x`` is not 2-D.
    """
    # Explicit check rather than `assert`, so it still fires under `python -O`.
    if x.dim() != 2:
        raise ValueError(
            f"expected a 2-D tensor [n_options, sequence], got shape {tuple(x.shape)}"
        )
    x = x[:, :sequence_length]
    n_padding = sequence_length - x.shape[-1]
    if n_padding > 0:
        # torch.full fills directly with the pad value in x's dtype/device.
        padding = torch.full(
            (x.shape[0], n_padding), pad_token_id, dtype=x.dtype, device=x.device
        )
        x = torch.cat([x, padding], dim=-1)
    return x


###############################################################################


def single_truncate_and_pad(x: torch.Tensor, pad_token_id: int, sequence_length: int) -> torch.Tensor:
    """Truncate or right-pad a single sequence to exactly ``sequence_length``.

    Args:
        x: Tensor of shape ``[1, sequence]`` (dummy leading batch dimension)
            or already-flat ``[sequence]``.
        pad_token_id: Value written into the appended padding positions.
        sequence_length: Length of the returned tensor.

    Returns:
        1-D tensor of shape ``[sequence_length]`` with ``x``'s dtype and
        device. Elements beyond ``sequence_length`` are dropped; missing
        positions are filled with ``pad_token_id`` on the right.

    Raises:
        ValueError: If ``x`` is not 1-D after removing a leading batch
            dimension of size 1.
    """
    # For whatever reason, a dummy batch dimension of 1 gets added; remove it.
    # Only squeeze when there is a leading dim to remove: squeezing an
    # already 1-D tensor of length 1 would collapse it to a 0-d scalar.
    if x.dim() > 1:
        x = torch.squeeze(x, dim=0)
    if x.dim() != 1:
        raise ValueError(
            f"expected shape [1, sequence] or [sequence], got {tuple(x.shape)}"
        )
    x = x[:sequence_length]
    n_padding = sequence_length - x.shape[0]
    if n_padding > 0:
        # torch.full fills directly with the pad value in x's dtype/device.
        padding = torch.full((n_padding,), pad_token_id, dtype=x.dtype, device=x.device)
        x = torch.cat([x, padding], dim=-1)
    return x
