import torch
from datasets import Dataset, DatasetDict
from datasets.utils.logging import disable_progress_bar
from transformers import BatchEncoding, PreTrainedModel, PreTrainedTokenizer

from utils.encoding import encode_data_characterwise


def test_model_with_string(
    model: PreTrainedModel,
    tokenizer: PreTrainedTokenizer,
    text: str,
):
    """Run ``model`` on ``text`` and check its prediction of the final token.

    The text is wrapped in a one-row ``DatasetDict`` (split ``"test"``) and
    encoded with ``encode_data_characterwise``. The last token of the encoded
    sequence is used as the target, and the model's logits at the
    second-to-last position are compared against it.

    Args:
        model: Model to evaluate. It is called with ``input_ids``,
            ``attention_mask`` and ``output_attentions=True``; its config is
            not modified.
        tokenizer: Tokenizer forwarded to ``encode_data_characterwise``.
        text: Input string to encode and evaluate.

    Returns:
        A tuple ``(predictions_correct, target_token_prob, output)``:

        - ``predictions_correct``: the comparison of the argmax token id at the
          second-to-last position with the target token id.
        - ``target_token_prob``: the softmax probability the first sequence
          assigns to the target token.
        - ``output``: the raw model output, requested with attentions.
    """
    # Note: this turns off `datasets` progress bars process-wide.
    disable_progress_bar()
    text_dataset = DatasetDict({"test": Dataset.from_dict({"text": [text]})})
    data = encode_data_characterwise(tokenizer, text_dataset)["test"]

    # NOTE(review): assumes the encoded split is torch-formatted (e.g. via
    # `set_format("torch")`) so the columns can be fed to the model as-is —
    # confirm in `encode_data_characterwise`.
    input_ids = data["input_ids"]
    attention_mask = data["attention_mask"]

    # Use the last token as target token
    target_token_id = input_ids[0][-1]

    # Request attentions for this call only. Setting
    # `model.config.output_attentions` would leak into every later forward
    # pass made with this model.
    output = model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        output_attentions=True,
    )

    # Assumes a causal LM: logits at position i score the token at i + 1, so
    # the second-to-last position holds the prediction for the target token.
    token_logits = output["logits"][:, -2]
    token_probs = torch.softmax(token_logits, dim=-1)
    max_prob_token_id = token_logits.argmax(dim=-1)
    predictions_correct = max_prob_token_id == target_token_id
    target_token_prob = token_probs[0, target_token_id]

    return predictions_correct, target_token_prob, output


def sample_random_replacement(
    alphabet: str,
    length: int,
    seed: int,
) -> str:
    """Return ``length`` characters drawn uniformly, with replacement, from ``alphabet``.

    Sampling uses a dedicated ``torch.Generator`` seeded with ``seed``. The
    same ``(alphabet, length, seed)`` therefore always gives the same string,
    and the global torch RNG state is left untouched.

    Args:
        alphabet: Characters to sample from; must be non-empty
            (``torch.randint`` rejects an empty range).
        length: Number of characters to draw.
        seed: Seed for the local random generator.

    Returns:
        The sampled string of length ``length``.
    """
    rng = torch.Generator().manual_seed(seed)
    replacement_token_indices = torch.randint(
        len(alphabet),
        size=(length,),
        generator=rng,
    )
    # Convert to plain ints once rather than indexing the string with one
    # 0-d tensor per character.
    return "".join(alphabet[i] for i in replacement_token_indices.tolist())
