import itertools
import math
from typing import Callable, Tuple, Dict

import jax
from flax import linen as nn
from jax import numpy as jnp
from torch.utils.data import DataLoader
from tqdm import tqdm

from .base import Evaluator
from .losses import cross_entropy_loss_lm


def pred_acc_lm(output, labels):
    """Turn a batch's loss into language-modelling metrics.

    Args:
        output: mapping whose ``"loss"`` entry holds the batch loss
            (presumably mean per-token cross-entropy in nats — TODO confirm).
        labels: not used; accepted to fit the ``compute_metrics(output, labels)``
            call convention.

    Returns:
        dict with ``"loss"`` (unchanged), ``"bpc"`` (loss divided by ln 2) and
        ``"ppl"`` (``exp(loss)``).
    """
    nats = output["loss"]
    ln2 = math.log(2)
    return {
        "loss": nats,
        "bpc": nats / ln2,
        "ppl": jnp.exp(nats),
    }


class LavaPretrainEvaluator(Evaluator):
    """Evaluator for language-model pretraining.

    Runs a trainer-supplied eval function over (a prefix of) the validation
    loader and reports batch-averaged loss, bits-per-character and perplexity.
    """

    def __init__(
        self,
        model,
        val_data,
        data_collator,
        config,
    ) -> None:
        """
        Args:
            model: model object; stored on the instance, not used by the code
                in this class.
            val_data: validation dataset, forwarded to ``Evaluator.__init__``
                (presumably wrapped into ``self._val_loader`` there — TODO
                confirm).
            data_collator: collate function, forwarded to ``Evaluator.__init__``.
            config: must expose ``batchnorm``, ``batch_size`` and
                ``eval_samples``. ``eval_samples`` is a number of loader
                *batches* (not examples) to evaluate; ``None`` means all of them.

        Raises:
            ValueError: if ``eval_samples`` is negative or larger than the
                number of batches in the validation loader.
        """
        super().__init__(val_data, data_collator, config)

        self._model = model
        self._batchnorm = config.batchnorm
        self._batch_size = config.batch_size
        self._config = config

        num_batches = len(self._val_loader)
        # NOTE: despite its name, this counts loader batches; evaluate()
        # consumes exactly one batch per step.
        self._total_samples = (
            num_batches if config.eval_samples is None else config.eval_samples
        )
        # Explicit check rather than `assert`, which is stripped under `python -O`.
        if not 0 <= self._total_samples <= num_batches:
            raise ValueError(
                f"eval_samples={self._total_samples} must be between 0 and "
                f"the number of validation batches ({num_batches})"
            )

    def evaluate(
        self,
        trainer_eval_fn: Callable[[Dict], Tuple[jax.Array, Dict[str, jax.Array]]],
        prefix="eval_",
        **kwargs,
    ) -> Dict[str, jax.Array]:
        """Iterate over validation data, get outputs from the trainer's eval
        function and compute metrics.

        Decoupled from the trainer so data-specific evaluation logic can be
        added, e.g.:
            - SQuAD: split inputs into overlapping windows
            - language modelling: generation from prompts

        Args:
            trainer_eval_fn: called once per batch as ``trainer_eval_fn(batch)``
                and must return ``(labels, model_output)``. It is expected to
                place data on devices per the trainer's sharding and contain
                the platform-specific model call.
            prefix: prepended to every metric name (e.g. ``"eval_"`` vs
                ``"test_"``).
            **kwargs: accepted for interface compatibility (e.g. a trainer
                ``state``); ignored here.

        Returns:
            Mapping ``prefix + metric_name`` -> metric averaged over the
            batches actually evaluated. ``ppl`` is recomputed as ``exp`` of
            the averaged loss rather than the mean of per-batch perplexities.
            Empty if no batch was evaluated.
        """
        totals = {}
        num_batches = 0
        # islice stops cleanly even if the loader yields fewer batches than
        # len() reported, instead of leaking StopIteration out of next().
        batches = itertools.islice(self._val_loader, self._total_samples)
        # Context manager guarantees the bar is closed, also on exceptions.
        with tqdm(total=self._total_samples, position=0, leave=True) as progress_bar:
            for batch in batches:
                labels, output = trainer_eval_fn(batch)
                metrics = self.compute_metrics(output, labels)
                for name, value in metrics.items():
                    totals[name] = totals.get(name, 0) + value
                num_batches += 1
                progress_bar.update(1)

        # When num_batches == 0, totals is empty, so no division happens.
        averaged = {name: value / num_batches for name, value in totals.items()}
        # Perplexity is not linear in the loss: mean(exp(l_i)) >= exp(mean(l_i))
        # (Jensen), so average the loss first and exponentiate afterwards.
        # bpc is linear in the loss, so its plain average is already correct.
        if "loss" in averaged and "ppl" in averaged:
            averaged["ppl"] = jnp.exp(averaged["loss"])

        return {prefix + name: value for name, value in averaged.items()}

    def compute_metrics(self, output, labels):
        """Compute per-batch metrics; delegates to ``pred_acc_lm``."""
        return pred_acc_lm(output, labels)
