import torch
from datasets import DatasetDict
from datasets.utils.logging import disable_progress_bar
from transformers import BatchEncoding, PreTrainedTokenizer

from lib_llm.inference import tokenize


def encode_data_naturally(
    tokenizer: PreTrainedTokenizer,
    dataset: DatasetDict,
) -> DatasetDict:
    """Tokenize the ``"text"`` column of ``dataset`` batch by batch.

    Every batch handed over by ``dataset.map`` is passed to ``tokenize`` as
    a whole, with ``max_length`` set to the character count of the longest
    string in that batch.

    Args:
        tokenizer: Tokenizer forwarded to ``tokenize``.
        dataset: Dataset whose ``"text"`` column holds the strings to encode.

    Returns:
        The result of ``dataset.map`` applied with the batch encoder.
    """

    def tokenize_batch(batch: dict):
        texts = batch["text"]
        # Pad to the longest possible tokenization: the character count of
        # the longest string in the batch (presumably an upper bound on its
        # token count -- TODO confirm against ``tokenize``).
        longest = max(map(len, texts))
        return tokenize(tokenizer, texts, max_length=longest)

    return dataset.map(tokenize_batch, batched=True)


def encode_strings_characterwise(
    tokenizer: PreTrainedTokenizer,
    sequences: list[str],
) -> BatchEncoding:
    """Encode strings one character at a time, left-padded to equal length.

    Each character of a string is handed to ``tokenize`` as its own
    sequence with ``max_length=1``, so a string contributes one token id per
    character. Rows are then left-padded with ``tokenizer.pad_token_id``
    (attention mask ``0``) up to the character count of the longest string.

    Empty input is handled explicitly: an empty ``sequences`` list yields
    empty ``(0, 0)`` tensors, and an empty string yields a row consisting
    only of padding (``tokenize`` is not called for it).

    Args:
        tokenizer: Tokenizer forwarded to ``tokenize``; its ``pad_token_id``
            is used as the padding value.
        sequences: Strings to encode.

    Returns:
        A ``BatchEncoding`` with ``input_ids`` and ``attention_mask``, one
        row per input string, each row as long as the longest string.
    """
    # ``default=0`` avoids the ValueError ``max`` raises on an empty batch.
    max_length = max((len(s) for s in sequences), default=0)
    if not sequences:
        # ``torch.stack`` cannot stack an empty list, so build the empty
        # result directly.
        return BatchEncoding(
            {
                "input_ids": torch.empty((0, 0), dtype=torch.long),
                "attention_mask": torch.empty((0, 0), dtype=torch.long),
            }
        )

    sequence_token_ids = []
    sequence_token_masks = []
    for sequence in sequences:
        # Left padding: pad tokens come before the encoded characters.
        num_padding = max_length - len(sequence)
        pad_ids = torch.tensor(
            [tokenizer.pad_token_id] * num_padding, dtype=torch.long
        )
        pad_mask = torch.zeros(num_padding, dtype=torch.long)

        if sequence:
            encoded_chars = tokenize(
                tokenizer,
                list(sequence),
                max_length=1,
            )
            # Assumes ``tokenize`` returns tensors of shape (num_chars, 1);
            # ``squeeze(1)`` flattens them to (num_chars,) -- TODO confirm.
            padded_input_ids = torch.cat(
                (pad_ids, encoded_chars.input_ids.squeeze(1))
            )
            padded_attention_mask = torch.cat(
                (pad_mask, encoded_chars.attention_mask.squeeze(1))
            )
        else:
            # Nothing to tokenize: the row is padding only.
            padded_input_ids = pad_ids
            padded_attention_mask = pad_mask

        sequence_token_ids.append(padded_input_ids)
        sequence_token_masks.append(padded_attention_mask)

    return BatchEncoding(
        {
            "input_ids": torch.stack(sequence_token_ids),
            "attention_mask": torch.stack(sequence_token_masks),
        }
    )


def encode_data_characterwise(
    tokenizer: PreTrainedTokenizer,
    dataset: DatasetDict,
) -> DatasetDict:
    """Character-wise encode the ``"text"`` column of ``dataset``.

    Each batch's ``"text"`` values are passed to
    ``encode_strings_characterwise``; the mapped dataset is returned with
    its format set to ``"torch"``.

    Args:
        tokenizer: Tokenizer forwarded to ``encode_strings_characterwise``.
        dataset: Dataset whose ``"text"`` column holds the strings to encode.

    Returns:
        The mapped dataset after ``with_format("torch")``.
    """

    def encode_batch(batch: dict) -> BatchEncoding:
        texts = batch["text"]
        return encode_strings_characterwise(tokenizer, texts)

    # Must run before ``map``. The call takes no dataset argument, so it
    # presumably silences progress bars library-wide, not only for this call.
    disable_progress_bar()
    encoded = dataset.map(
        encode_batch,
        batched=True,
        # NOTE: remove_columns=["text"] is currently left disabled.
    )
    return encoded.with_format("torch")
