from typing import Callable, Tuple, Dict
import math
import torch
import numpy as np
from torch.utils.data import DataLoader
from .base_pt import TorchEvaluator
from accelerate.utils import tqdm


class LavaPretrainEvaluator(TorchEvaluator):
    """Evaluator for language-model pretraining.

    Runs validation batches through a trainer-supplied eval function and
    reports the mean loss, bits-per-character (``bpc``) and perplexity
    (``ppl``) over the evaluated batches.
    """

    def __init__(
        self,
        model,
        val_data,
        data_collator,
        config,
    ) -> None:
        """
        Args:
            model: the model being evaluated; stored on ``self._model``.
            val_data: validation data, forwarded to ``TorchEvaluator``.
            data_collator: collator, forwarded to ``TorchEvaluator``.
            config: evaluation config; this class reads ``batchnorm``,
                ``batch_size`` and ``eval_samples`` from it.
        """
        super().__init__(val_data, data_collator, config)

        self._model = model
        self._batchnorm = config.batchnorm
        self._batch_size = config.batch_size
        self._config = config
        # Budget for the loop in ``evaluate``; defaults to the number of
        # batches in ``self._val_loader`` (presumably built by
        # ``TorchEvaluator``). NOTE(review): ``evaluate`` compares this against
        # a counter advanced by ``num_processes`` per batch, so it acts as a
        # batch-step budget rather than a sample count — confirm intent.
        self._total_samples = (
            len(self._val_loader)
            if self._config.eval_samples is None
            else self._config.eval_samples
        )

    def evaluate(
        self,
        trainer_eval_fn: Callable[..., Tuple],
        prefix="eval_",
        accelerator=None,
        **kwargs,
    ) -> Dict[str, torch.Tensor]:
        """Iterate over validation data, get outputs from the trainer eval
        function and compute metrics.

        Evaluation is decoupled from the trainer so data-specific logic can
        be added here, e.g.:
            - squad: split into overlapping windows
            - language: generate from prompts

        Args:
            trainer_eval_fn: called as ``trainer_eval_fn(batch)`` for each
                validation batch; must return a ``(labels, output)`` pair
                where ``output["loss"]`` is a loss tensor. It is expected to
                place data on devices and contain the platform-specific
                model call.
            prefix: prepended to every metric name (e.g. ``"eval_"`` for
                validation, ``"test_"`` for test data).
            accelerator: optional Accelerate ``Accelerator``; when given,
                each batch advances the progress bar and the budget counter
                by ``accelerator.num_processes``.
            **kwargs: accepted for interface compatibility; unused.

        Returns:
            Dict mapping ``prefix + metric_name`` to the metric averaged over
            the evaluated batches, with ``ppl`` recomputed as ``exp`` of the
            averaged loss. Empty dict if no batch was evaluated.
        """
        step = 1 if accelerator is None else accelerator.num_processes
        progress_bar = tqdm(
            range(self._total_samples), position=0, leave=True, initial=0
        )
        scores = {}
        n_batches = 0
        it = 0
        try:
            for batch in self._val_loader:
                labels, output = trainer_eval_fn(batch)
                metrics = self.compute_metrics(output, labels)
                for k, v in metrics.items():
                    scores[k] = scores.get(k, 0) + v
                n_batches += 1
                progress_bar.update(step)
                it += step
                if it >= self._total_samples:
                    break
        finally:
            # Release the bar even if the eval function raises.
            progress_bar.close()

        if n_batches == 0:
            return {}

        averaged = {k: v / n_batches for k, v in scores.items()}
        if "loss" in averaged:
            # Perplexity is exp of the mean loss; averaging per-batch
            # exp(loss) overestimates it (Jensen's inequality).
            averaged["ppl"] = torch.exp(torch.as_tensor(averaged["loss"]))
        return {prefix + k: v for k, v in averaged.items()}

    def compute_metrics(self, output, labels):
        """Compute per-batch language-modelling metrics.

        Args:
            output: mapping holding a ``"loss"`` tensor, reduced with
                ``.mean()``.
            labels: unused here; kept for interface compatibility.

        Returns:
            Dict with ``loss`` (mean loss), ``bpc`` (``loss / ln 2``, i.e. the
            loss in bits assuming it is a natural-log cross-entropy) and
            ``ppl`` (``exp(loss)``), all as tensors on the loss's device.
        """
        with torch.no_grad():
            loss = output["loss"].mean()
            # torch.exp keeps the result on the loss's device; np.exp would
            # try to convert a CUDA tensor to NumPy and raise.
            return {
                "loss": loss,
                "bpc": loss / math.log(2),
                "ppl": torch.exp(loss),
            }
