import torch

from src.metrics.metric import Metric
from src.metrics.registry import LogpKLSettings
from src.metrics.utils import logprobs_from_logits
from src.settings.metric import ElementWiseScores, MetricResults, MetricType


@Metric.register(MetricType.LOGPKL)
class LOGPKLMetric(Metric):
    """Per-sample KL-style score between policy log-probs and reference log-probs.

    For every sample the metric reports ``mean(logprob - ref_logprob)``. Which
    reference is used is selected by ``settings.ref_logprobs_type``, the key
    looked up in ``metrics_kwargs``.

    NOTE(review): this is the usual Monte Carlo estimate of
    KL(policy || reference) only if the tokens were sampled from the policy
    that produced ``logprobs``. Confirm against the caller.
    """

    def __init__(self, settings: LogpKLSettings) -> None:
        super().__init__(settings=settings)
        self._settings: LogpKLSettings = settings

    def compute(self, **kwargs) -> list[MetricResults]:
        """Compute one KL-style value per sample.

        Keyword Args:
            logprobs: per-sample log-prob tensors. Each is presumably of shape
                ``[1, seq_len]`` (TODO confirm with producer). Required.
            metrics_kwargs: mapping that must hold the reference log-probs under
                the key ``settings.ref_logprobs_type``. ``None`` is treated as
                an empty mapping.
            dataset_name: prefix for the result label. Defaults to ``''``.

        Returns:
            One ``MetricResults`` per entry in ``settings.need_average``. All of
            them share the same element-wise scores.

        Raises:
            ValueError: if ``logprobs`` or the reference log-probs are missing,
                or if the two lists have different lengths.
        """
        logprobs: list[torch.Tensor] = kwargs.get('logprobs', None)
        # `or {}` also covers an explicit `metrics_kwargs=None`, so the
        # missing-reference case reaches the ValueError below instead of
        # failing with AttributeError on `.get`.
        metrics_kwargs = kwargs.get('metrics_kwargs') or {}
        ref_logprobs: list[torch.Tensor] = metrics_kwargs.get(self._settings.ref_logprobs_type, None)
        dataset_name: str = kwargs.get('dataset_name', '')

        if logprobs is None:
            raise ValueError('logprobs should not be None')
        if ref_logprobs is None:
            raise ValueError(f'{self._settings.ref_logprobs_type} should not be None')
        # zip() would silently truncate to the shorter list. That would drop
        # samples or pair the wrong sample with the wrong reference.
        if len(logprobs) != len(ref_logprobs):
            raise ValueError(
                f'logprobs and {self._settings.ref_logprobs_type} must have the same length, '
                f'got {len(logprobs)} and {len(ref_logprobs)}'
            )

        kl_values = [
            # Move the reference onto the policy tensor's device so that
            # references stored elsewhere (e.g. on CPU) don't make the
            # subtraction fail. This is a no-op when the devices already match.
            (logprob - torch.as_tensor(ref_logprob, device=logprob.device)).mean().item()
            for logprob, ref_logprob in zip(logprobs, ref_logprobs)
        ]

        element_wise_kl_scores = [
            ElementWiseScores(
                label=dataset_name + '@@' + f'kl_with_{self._settings.ref_logprobs_type}',
                values=kl_values,
            )
        ]

        # Emit the same scores once per requested averaging mode.
        return [
            MetricResults(element_wise_scores=element_wise_kl_scores, need_average=need_average)
            for need_average in self._settings.need_average
        ]
