import torch

from src.metrics.metric import Metric
from src.settings.metric import ElementWiseScores, MetricResults, MetricType


@Metric.register(MetricType.LOGP_PERPLEXITY)
class LogpPerplexityMetric(Metric):
    """Perplexity metric computed from per-token log-probabilities.

    For each tensor of token log-probabilities, perplexity is
    ``exp(-mean(logp))`` with the mean taken over the last dimension.
    """

    def compute(self, **kwargs) -> list[MetricResults]:
        """Compute a perplexity score for every answer in ``logprobs``.

        Keyword Args:
            logprobs: List of per-answer log-probability tensors. Each is
                expected to have shape ``[1, seq_len]``. A 1-D ``[seq_len]``
                tensor contributes one score, and a ``[batch, seq_len]``
                tensor contributes ``batch`` scores.
            dataset_name: Prefix for the score label (default ``''``).

        Returns:
            One ``MetricResults`` per entry in ``self._settings.need_average``.
            Each holds the perplexity values under the label
            ``'<dataset_name>@@perplexity'``.

        Raises:
            ValueError: If ``logprobs`` is missing or ``None``.
        """
        logprobs: list[torch.Tensor] | None = kwargs.get('logprobs')
        dataset_name: str = kwargs.get('dataset_name', '')

        if logprobs is None:
            raise ValueError('logprobs should not be None')

        perplexity_values: list[float] = []
        for answer_logp in logprobs:
            perplexity_values.extend(self.calculate_logp_perplexity(answer_logp))

        label = f'{dataset_name}@@perplexity'
        return [
            MetricResults(
                element_wise_scores=[
                    ElementWiseScores(
                        label=label,
                        # Give each result its own list so downstream mutation of
                        # one result's values cannot leak into the others.
                        values=list(perplexity_values),
                    )
                ],
                need_average=need_average,
            )
            for need_average in self._settings.need_average
        ]

    @staticmethod
    def calculate_logp_perplexity(logprobs: torch.Tensor) -> list[float]:
        """Return ``exp(-mean(logprobs))`` over the last dim as a flat list.

        Args:
            logprobs: Token log-probabilities; the last dimension is the
                sequence dimension.

        Returns:
            One perplexity per row of ``logprobs``. A 1-D input yields a
            single-element list.
        """
        # Upcast half-precision inputs before reducing: averaging in
        # float16/bfloat16 loses accuracy, and exp() in float16 overflows to
        # inf once the mean negative log-likelihood exceeds ~11.09 (ln 65504).
        nll = -logprobs.detach().float().mean(dim=-1)
        # reshape(-1) guarantees a 1-D tensor so tolist() always returns a
        # list; a 1-D input would otherwise reduce to a 0-d tensor whose
        # tolist() is a bare float.
        return torch.exp(nll).reshape(-1).cpu().tolist()
