from typing import Callable, Tuple, Dict
import math
import torch
import numpy as np
from torch.utils.data import DataLoader
from .base_pt import TorchEvaluator
from accelerate.utils import tqdm


def pred_acc_lm(output, labels):
    """Compute loss-derived metrics plus argmax token accuracy.

    Args:
        output: mapping with ``"logits"`` (indexed as ``[batch, position, vocab]``
            by the slicing below) and ``"loss"`` (anything with ``.mean()``).
        labels: target ids indexed as ``[batch, position]``; ``-100`` marks
            positions excluded from the accuracy denominator.

    Returns:
        Dict with ``loss`` (mean of ``output["loss"]``), ``bpc`` (loss / ln 2),
        ``ppl`` (exp(loss)) and ``accuracy`` (fraction of non-ignored shifted
        positions whose argmax logit equals the label).
    """
    # Compare the logit at position t with the label at position t + 1:
    # drop the last prediction and the first label so the two line up.
    shifted_logits = output["logits"][:, :-1, :]
    shifted_labels = labels[:, 1:]
    valid_count = (shifted_labels != -100).sum()
    mean_loss = output["loss"].mean()
    # argmax is never -100, so ignored positions never count as correct.
    predictions = np.argmax(shifted_logits, -1)
    correct = np.sum(predictions == shifted_labels)
    token_accuracy = correct / valid_count
    return {
        "loss": mean_loss,
        "bpc": mean_loss / math.log(2),
        "ppl": np.exp(mean_loss),
        "accuracy": token_accuracy,
    }


def pred_fast(output, labels):
    """Compute loss-derived metrics only (no logits pass, hence "fast").

    Args:
        output: mapping whose ``"loss"`` entry supports ``.mean()``.
        labels: unused; accepted so the signature matches ``pred_acc_lm``.

    Returns:
        Dict with ``loss`` (mean of ``output["loss"]``), ``bpc`` (loss / ln 2)
        and ``ppl`` (exp(loss)).
    """
    mean_loss = output["loss"].mean()
    bits_per_char = mean_loss / math.log(2)
    perplexity = np.exp(mean_loss)
    return dict(loss=mean_loss, bpc=bits_per_char, ppl=perplexity)


class LanguageEvaluator(TorchEvaluator):
    """Evaluator for language-model validation data.

    Runs a trainer-supplied evaluation function over the validation loader,
    computes per-batch metrics with :meth:`compute_metrics`, and returns each
    metric averaged over the batches that were processed.
    """

    def __init__(
        self,
        model,
        tokenizer,
        val_data,
        data_collator,
        config,
    ) -> None:
        """
        Args:
            model: model under evaluation; stored on the instance (not used by
                this class's own methods).
            tokenizer: tokenizer; stored on the instance (not used by this
                class's own methods).
            val_data: validation data, forwarded to ``TorchEvaluator.__init__``.
            data_collator: collate function, forwarded to
                ``TorchEvaluator.__init__``.
            config: must expose ``batchnorm``, ``batch_size`` and
                ``eval_samples``; ``eval_samples=None`` means "use the whole
                validation loader".
        """
        super().__init__(val_data, data_collator, config)

        self._tokenizer = tokenizer
        self._model = model
        self._batchnorm = config.batchnorm
        self._batch_size = config.batch_size
        self._config = config
        # Evaluation budget. Despite the name, it is compared against a
        # batch-based counter in ``evaluate``. ``self._val_loader`` is
        # presumably built by ``TorchEvaluator.__init__`` (so ``len`` is a
        # batch count) — TODO confirm.
        self._total_samples = (
            len(self._val_loader)
            if self._config.eval_samples is None
            else self._config.eval_samples
        )

    def evaluate(
        self,
        trainer_eval_fn: Callable[..., Tuple],
        prefix="eval_",
        accelerator=None,
        **kwargs,
    ) -> Dict[str, torch.Tensor]:
        """Iterate over validation data, evaluate each batch and average metrics.

        Evaluation is decoupled from the trainer so that data-specific logic can
        live here (e.g. SQuAD split into overlapping windows, or language data
        generated from prompts).

        Args:
            trainer_eval_fn: called once per batch as ``trainer_eval_fn(batch)``
                and must return ``(labels, model_output)``. It is expected to
                place data on devices according to the trainer's sharding and
                to contain the platform-specific model call.
            prefix: prepended to every metric name, e.g. ``"eval_"`` for
                validation or ``"test_"`` for test data.
            accelerator: optional; only ``num_processes`` is read. Each batch
                advances the progress bar and the sample budget by that amount.
            **kwargs: accepted for interface compatibility; ignored.

        Returns:
            Dict mapping ``prefix + metric_name`` to that metric's mean over
            the batches processed. Empty if the loader yields no batches.
        """
        step = 1 if accelerator is None else accelerator.num_processes
        # Pass ``total=`` rather than an iterable: the bar is only advanced
        # manually via ``update`` and never iterated.
        progress_bar = tqdm(
            total=self._total_samples, position=0, leave=True, initial=0
        )
        scores = {}
        n_batches = 0
        try:
            for batch in self._val_loader:
                labels, output = trainer_eval_fn(batch)
                metrics = self.compute_metrics(output, labels)
                for name, value in metrics.items():
                    scores[name] = scores.get(name, 0) + value
                n_batches += 1
                progress_bar.update(step)
                # Stop once the (batch * process) budget has been consumed.
                if n_batches * step >= self._total_samples:
                    break
        finally:
            # Always release/finalize the bar, even if evaluation raises.
            progress_bar.close()

        # If no batch was processed, ``scores`` is empty and no division occurs.
        return {prefix + name: total / n_batches for name, total in scores.items()}

    def compute_metrics(self, output, labels):
        """Compute per-batch metrics (loss, bpc, ppl) via ``pred_fast``.

        Runs under ``torch.no_grad()`` so the metric arithmetic records no
        autograd history. Swap in ``pred_acc_lm(output, labels)`` to also report
        token accuracy (requires ``output["logits"]``).
        """
        with torch.no_grad():
            return pred_fast(output, labels)
