# Copyright 2024-2025
# [ANONYMIZED_INSTITUTION],
# [ANONYMIZED_FACULTY],
# [ANONYMIZED_DEPARTMENT]
#
# Authors:
# AUTHOR_1 (author1@example.com)
# AUTHOR_2 (author2@example.com)
#
# Code generation tools and workflows:
# First versions of this code were potentially generated
# with the help of AI writing assistants including
# GitHub Copilot, ChatGPT, Microsoft Copilot, Google Gemini.
# Afterwards, the generated segments were manually reviewed and edited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Iterate over a dataset and compute the perplexity for each sentence."""

import logging

import datasets
import numpy as np
import torch
from tqdm import tqdm
from transformers import PreTrainedModel, PreTrainedTokenizer, PreTrainedTokenizerFast

from topollm.config_classes.tokenizer.tokenizer_config import TokenizerConfig
from topollm.model_handling.loaded_model_container import LoadedModelContainer
from topollm.model_inference.perplexity.repeat_tensor_input_and_apply_diagonal_mask import (
    repeat_tensor_input_and_apply_diagonal_mask,
)
from topollm.model_inference.perplexity.saving.sentence_perplexity_container import SentencePerplexityContainer
from topollm.typing.enums import LMmode, MLMPseudoperplexityGranularity, Verbosity
from topollm.typing.types import PerplexityResultsList

# Device used by functions in this module when the caller does not pass one.
default_device = torch.device("cpu")

# Logger used by functions in this module when the caller does not pass one.
default_logger: logging.Logger = logging.getLogger(__name__)


def _masked_lm_loss(
    model: PreTrainedModel,
    masked_input: torch.Tensor,
    labels: torch.Tensor,
) -> float:
    """Run one forward pass without gradient tracking and return `output.loss` as a Python float."""
    with torch.inference_mode():
        output = model(
            masked_input,
            labels=labels,
        )
        # `.item()` already yields a Python scalar, so no explicit `.cpu()` is needed.
        return output.loss.item()


def pseudoperplexity_per_token_of_sentence(
    sentence: str,
    tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast,
    tokenizer_config: TokenizerConfig,
    model: PreTrainedModel,
    mlm_pseudoperplexity_granularity: MLMPseudoperplexityGranularity = MLMPseudoperplexityGranularity.SENTENCE,  # type: ignore - Problems with StrEnum
    device: torch.device = default_device,
    verbosity: Verbosity = Verbosity.NORMAL,
    logger: logging.Logger = default_logger,
) -> SentencePerplexityContainer:
    """Compute masked-language-model losses for a single sentence.

    The sentence is tokenized (truncated to `tokenizer_config.max_length`, no padding),
    expanded into a batch of masked copies by `repeat_tensor_input_and_apply_diagonal_mask`,
    and passed through `model` together with the matching labels.

    Args:
        sentence: The text to score.
        tokenizer: Tokenizer used for encoding; it must define an integer `mask_token_id`.
        tokenizer_config: Provides `max_length` for truncation.
        model: Model that is called as `model(input_ids, labels=...)` and returns an object with a `.loss`.
        mlm_pseudoperplexity_granularity: `SENTENCE` sends the whole masked batch through the model
            in one call and records a single loss; `TOKEN` sends each masked row through the model
            separately and records one loss per row.
        device: Device the masked inputs and labels are moved to.
            NOTE(review): the model is not moved here; presumably the caller has already placed it
            on the same device - confirm.
        verbosity: Currently unused in this function.
        logger: Currently unused in this function.

    Returns:
        A `SentencePerplexityContainer` holding the token ids, the token strings, and the recorded
        loss values with `0.0` prepended and appended (placeholders for the start and end token).
        The values are the raw losses returned by the model; they are not exponentiated here.
        In `SENTENCE` mode this list therefore always has three entries, independent of the
        number of tokens.

    Raises:
        TypeError: If `tokenizer.mask_token_id` is not an int or the encoded input is not a `torch.Tensor`.
        ValueError: If `mlm_pseudoperplexity_granularity` is neither `SENTENCE` nor `TOKEN`.

    """
    mask_token_id = tokenizer.mask_token_id
    if not isinstance(
        mask_token_id,
        int,
    ):
        msg = "Expected an integer."
        raise TypeError(msg)

    # Make sure that `padding=False`, otherwise padding positions would be carried into
    # every repeated copy of the input.
    tensor_input = tokenizer.encode(
        text=sentence,
        return_tensors="pt",
        max_length=tokenizer_config.max_length,
        padding=False,
        truncation="longest_first",
    )

    if not isinstance(
        tensor_input,
        torch.Tensor,
    ):
        msg = "Expected a torch.Tensor."
        raise TypeError(msg)

    # The encoded tensor carries a leading batch dimension of size one; row 0 is the sentence.
    token_id_list = tensor_input[0].tolist()
    # Decode from the plain int list so that we do not iterate over 0-d tensors.
    tensor_input_decoded: list[str] = [
        tokenizer.convert_ids_to_tokens(single_token_id) for single_token_id in token_id_list
    ]

    # Presumably one row per non-special token, each with a single position masked
    # - confirm against the helper's implementation.
    (
        masked_input,
        labels,
    ) = repeat_tensor_input_and_apply_diagonal_mask(
        tensor_input=tensor_input,
        mask_token_id=mask_token_id,
    )

    # Move inputs and labels to the correct device.
    masked_input = masked_input.to(
        device=device,
    )
    labels = labels.to(
        device=device,
    )

    results_loss_list: list[float]

    if mlm_pseudoperplexity_granularity == MLMPseudoperplexityGranularity.SENTENCE:
        # We send the entire batch at once through the model to get a sentence-level loss.
        results_loss_list = [
            _masked_lm_loss(
                model=model,
                masked_input=masked_input,
                labels=labels,
            ),
        ]
    elif mlm_pseudoperplexity_granularity == MLMPseudoperplexityGranularity.TOKEN:
        # Each masked row goes through the model as its own batch of size one,
        # which yields a separate loss per row.
        results_loss_list = [
            _masked_lm_loss(
                model=model,
                masked_input=masked_input_row.unsqueeze(0),
                labels=labels_row.unsqueeze(0),
            )
            for masked_input_row, labels_row in zip(
                masked_input,
                labels,
                strict=True,
            )
        ]
    else:
        msg = f"Invalid value for `mlm_pseudoperplexity_granularity`: {mlm_pseudoperplexity_granularity!r}."
        raise ValueError(msg)

    # Concatenate 0.0 for the start and end token to the results list.
    results_loss_list_with_start_and_end = [
        0.0,
        *results_loss_list,
        0.0,
    ]

    return SentencePerplexityContainer(
        token_ids=token_id_list,
        token_strings=tensor_input_decoded,
        token_perplexities=results_loss_list_with_start_and_end,
    )


def token_level_to_sentence_level_pseudoperplexity(
    loss: torch.Tensor,
) -> float:
    """Return the exponential of a scalar loss tensor.

    Args:
        loss: A tensor holding exactly one element (as required by `Tensor.item()`).

    Returns:
        `exp(loss)`, computed with NumPy on the extracted Python scalar.

    """
    scalar_loss = loss.item()
    exponentiated_loss = np.exp(scalar_loss)
    return exponentiated_loss


def compute_perplexity_over_dataset(
    loaded_model_container: LoadedModelContainer,
    dataset: datasets.Dataset,
    column_name: str,
    verbosity: Verbosity = Verbosity.NORMAL,
    logger: logging.Logger = default_logger,
) -> PerplexityResultsList:
    """Score every entry of `dataset` with `pseudoperplexity_per_token_of_sentence`.

    Each entry's `column_name` field is passed as the sentence, using the tokenizer,
    tokenizer config, model and device stored in `loaded_model_container`, always with
    `TOKEN` granularity.

    Args:
        loaded_model_container: Bundle providing `lm_mode`, `tokenizer`, `tokenizer_config`,
            `model` and `device`.
        dataset: The dataset to iterate over; every entry must be a dict.
        column_name: Key of the text field inside each entry.
        verbosity: At `NORMAL` or above, the number of results is logged at the end.
        logger: Logger for the summary message; it is also forwarded to the per-sentence function.

    Returns:
        A list of `(index, SentencePerplexityContainer)` tuples in iteration order.

    Raises:
        NotImplementedError: If the container's `lm_mode` is `LMmode.CLM`.
        TypeError: If a dataset entry is not a dict.

    """
    # Only the masked-LM path exists; refuse causal language models up front.
    if loaded_model_container.lm_mode == LMmode.CLM:
        error_message = "Perplexity computation not implemented for CLM yet."
        raise NotImplementedError(error_message)

    collected_results: PerplexityResultsList = []
    entries_with_progress = tqdm(dataset, desc="Iterating over dataset")

    for entry_index, entry in enumerate(entries_with_progress):
        if not isinstance(entry, dict):
            error_message = "Expected a dictionary."
            raise TypeError(error_message)

        sentence_container: SentencePerplexityContainer = pseudoperplexity_per_token_of_sentence(
            # The text to score lives under the requested column.
            sentence=entry[column_name],
            tokenizer=loaded_model_container.tokenizer,
            tokenizer_config=loaded_model_container.tokenizer_config,
            model=loaded_model_container.model,
            mlm_pseudoperplexity_granularity=MLMPseudoperplexityGranularity.TOKEN,  # type: ignore - Problems with StrEnum
            device=loaded_model_container.device,
            verbosity=verbosity,
            logger=logger,
        )
        collected_results.append((entry_index, sentence_container))

    if verbosity >= Verbosity.NORMAL:
        logger.info("len(results_list):\n%s", len(collected_results))

    return collected_results
