import logging
from pathlib import Path

import numpy as np
import torch
from datasets import Dataset
from transformers import BatchEncoding

from ...metrics import (
    CorrectnessMetric,
    LossMetric,
    SequenceProbMetric,
    TokenEvaluationTask,
    TokenMetric,
)


# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Directory that contains this source file.
FILE_DIR = Path(__file__).parent
# Three directory levels above FILE_DIR; presumably the project root.
# NOTE(review): this depends on the file's depth in the package — confirm it
# still points at the intended directory if this module is moved.
ROOT_PATH = FILE_DIR.parent.parent.parent


def memorization_dynamics_metrics(
    alphabet_tokens: list[str],
    alphabet_token_ids: np.ndarray,
    encoded_dataset: Dataset,
) -> TokenEvaluationTask:
    """Build an evaluation task that tracks per-character probabilities.

    Besides the generic ``"loss"`` and ``"correct"`` metrics, one
    ``SequenceProbMetric`` is created per alphabet token. That metric tracks
    the probability distribution of the character over the sequence.

    Args:
        alphabet_tokens: Names of the alphabet tokens. Each name becomes a
            metric key, so names must be unique and must not collide with the
            built-in ``"loss"`` / ``"correct"`` keys.
        alphabet_token_ids: Token id for each entry of ``alphabet_tokens``,
            in the same order (one entry per token along the first axis).
        encoded_dataset: Dataset with ``"input_ids"`` and ``"attention_mask"``
            columns. ``"input_ids"`` must come back as a ``torch.Tensor`` of
            rank >= 2 whose second dimension is the sequence length
            (presumably the dataset is in torch format — confirm with caller).

    Returns:
        A ``TokenEvaluationTask`` over ``encoded_dataset`` with index names
        ``["string", "character"]``.

    Raises:
        ValueError: If the number of token names and token ids differ, a token
            name is duplicated or collides with a built-in metric key, or
            ``"input_ids"`` has fewer than two dimensions.
        TypeError: If ``encoded_dataset["input_ids"]`` is not a
            ``torch.Tensor``.
    """
    # zip() would silently drop characters on a length mismatch.
    if len(alphabet_tokens) != len(alphabet_token_ids):
        raise ValueError(
            f"Got {len(alphabet_tokens)} alphabet tokens but "
            f"{len(alphabet_token_ids)} alphabet token ids."
        )

    # Read the column once: every column access on a Dataset re-formats the
    # whole column.
    input_ids = encoded_dataset["input_ids"]
    if not isinstance(input_ids, torch.Tensor):
        raise TypeError(
            "encoded_dataset['input_ids'] must be a torch.Tensor, got "
            f"{type(input_ids).__name__}."
        )
    if input_ids.ndim < 2:
        raise ValueError(
            "encoded_dataset['input_ids'] must have at least 2 dimensions, "
            f"got shape {tuple(input_ids.shape)}."
        )
    seq_len = input_ids.shape[1]

    strings_encoding = BatchEncoding(
        dict(
            input_ids=input_ids,
            attention_mask=encoded_dataset["attention_mask"],
        )
    )
    metrics: dict[str, TokenMetric] = {
        "loss": LossMetric(strings_encoding),
        "correct": CorrectnessMetric(strings_encoding),
    }

    # Token names are used as metric keys: a clash would silently overwrite
    # an existing metric.
    clashes = sorted(set(alphabet_tokens) & metrics.keys())
    if clashes:
        raise ValueError(
            f"Alphabet tokens collide with built-in metric names: {clashes}"
        )
    if len(set(alphabet_tokens)) != len(alphabet_tokens):
        raise ValueError("alphabet_tokens contains duplicate entries.")

    # Add the character metrics that track the probability distribution
    # for each character.
    for token, token_id in zip(alphabet_tokens, alphabet_token_ids):
        # Sequence probability estimation is done by feeding token sequences
        # with the same length as the sequence of token probabilities
        # produced by the model for the input sequence.
        # Therefore, for each character whose token's probability we want to
        # compute, we create a sequence of the same length as the
        # input/training sequence consisting only of that character's token.
        # E.g. for the 26 Latin characters, we estimate the probabilities
        # of the sequences {"a": ["a", "a", ...], "b": ["b", "b", ...], ...}.
        repeated_ids = torch.tensor([token_id] * seq_len, dtype=torch.long)
        repeated_encoding = BatchEncoding(
            dict(
                input_ids=repeated_ids,
                attention_mask=torch.ones_like(repeated_ids),
            )
        )
        metrics[token] = SequenceProbMetric(repeated_encoding)

    return TokenEvaluationTask(
        metrics,
        data=encoded_dataset,
        index_names=["string", "character"],
    )
