from dataclasses import dataclass

import torch
from transformers import PreTrainedModel, PreTrainedTokenizer

from utils.context_probing import (
    compute_majority_token,
    compute_target_token_probs,
    is_prediction_correct,
    sample_random_string_replacement,
)
from utils.encoding import encode_strings_characterwise


@dataclass
class ProbingResult:
    """Bundle of values returned by ``test_model_with_strings``.

    Shapes and dtypes of the tensor fields are determined by the helpers in
    ``utils.context_probing`` and are not visible here.
    """

    # Return value of ``is_prediction_correct`` for the logits at the
    # second-to-last position; presumably one entry per input string
    # (TODO confirm against the helper).
    correct: torch.Tensor
    # True when the token returned by ``compute_majority_token`` equals the
    # target token (the shared last token of the inputs).
    majority_prediction_correct: bool
    # Return value of ``compute_target_token_probs`` for the target token.
    target_token_prob: torch.Tensor
    # Raw object returned by the model call, stored unmodified.
    # NOTE(review): annotated as ``dict``; a Hugging Face model presumably
    # returns a ``ModelOutput`` mapping here - confirm.
    output: dict


def test_model_with_strings(
    model: PreTrainedModel,
    tokenizer: PreTrainedTokenizer,
    text: list[str],
) -> ProbingResult:
    """Run ``model`` on ``text`` and score its prediction of the final token.

    Every string is encoded with ``encode_strings_characterwise``; the last
    token id of each encoding is taken as the target and must be the same for
    all strings. The logits at the second-to-last position are then compared
    against that target.

    Args:
        model: Model to probe. It is called with ``output_attentions=True``
            for this forward pass only; ``model.config`` is left untouched.
        tokenizer: Tokenizer handed to ``encode_strings_characterwise``.
        text: Non-empty list of strings that all end in the same token.

    Returns:
        A ``ProbingResult`` holding the per-input correctness, whether the
        majority token matches the target, the target token probability, and
        the raw model output (including attentions).

    Raises:
        ValueError: If ``text`` is empty or the encodings do not all end in
            the same token id.
    """
    if not text:
        raise ValueError("text must contain at least one string")

    data = encode_strings_characterwise(tokenizer, text)
    # Use the last token as target token
    target_token_ids = torch.tensor(
        [encoding[-1] for encoding in data["input_ids"]]
    )
    target_token_id = target_token_ids[0]
    # Explicit check instead of ``assert`` so it still runs under ``python -O``.
    if not bool((target_token_ids == target_token_id).all()):
        raise ValueError(
            "all strings must end in the same token, got final token ids "
            f"{target_token_ids.tolist()}"
        )

    # Request attentions per call rather than by mutating ``model.config``,
    # which would silently change every later use of the caller's model.
    # NOTE(review): the forward pass runs with autograd enabled, as before;
    # wrap the call in ``torch.no_grad()`` upstream if gradients are unneeded.
    output = model(
        input_ids=data["input_ids"],
        attention_mask=data["attention_mask"],
        output_attentions=True,
    )

    # Logits at the second-to-last position, scored against the last token.
    # Presumably a causal LM where position i predicts token i + 1, and the
    # encodings carry no right padding - TODO confirm against the encoder.
    token_logits = output["logits"][:, -2]
    predictions_correct = is_prediction_correct(token_logits, target_token_id)
    majority_correct_token = compute_majority_token(token_logits)
    print("Target token:", target_token_id)
    print("Majority correct token:", majority_correct_token)
    majority_prediction_correct = bool(
        (majority_correct_token == target_token_id).item()
    )
    target_token_prob = compute_target_token_probs(
        token_logits, target_token_id
    )
    return ProbingResult(
        correct=predictions_correct,
        majority_prediction_correct=majority_prediction_correct,
        target_token_prob=target_token_prob,
        output=output,
    )
