from tqdm import tqdm
import torch.nn.functional as F
import torch
from . import metrics


class Evaluator(object):
    """Run a model over a dataloader and accumulate a metric.

    Args:
        metric: Object exposing ``reset()``, ``update(outputs, targets)`` and
            ``get_results()``.
        dataloader: Iterable yielding ``(inputs, targets)`` pairs; both
            elements must support ``.to(device)``.
        device: Anything accepted by ``Tensor.to`` (``torch.device``, string,
            or CUDA index). ``None`` (the default) selects the current CUDA
            device when CUDA is available and the CPU otherwise.
    """

    def __init__(self, metric, dataloader, device=None):
        self.dataloader = dataloader
        self.metric = metric
        self.device = device

    def _target_device(self):
        """Return the device that batches are moved to before the forward pass."""
        if self.device is not None:
            return self.device
        # Historically ``None`` meant "current CUDA device"; keep that when
        # CUDA exists, but stay usable on CPU-only machines.
        return torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def eval(self, model, progress=True):
        """Evaluate ``model`` on every batch and return the metric results.

        The metric is reset first, so each call starts from a clean state.
        The forward passes run under ``torch.no_grad()``. The model's
        train/eval mode is not changed here; the caller is responsible for it.

        Args:
            model: Callable mapping a batch of inputs to outputs.
            progress: When False, the tqdm progress bar is disabled.

        Returns:
            Whatever ``metric.get_results()`` returns after the last batch.
        """
        device = self._target_device()
        self.metric.reset()
        with torch.no_grad():
            for inputs, targets in tqdm(self.dataloader, disable=not progress):
                # ``.to`` (unlike ``.cuda``) also accepts CPU devices.
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = model(inputs)
                self.metric.update(outputs, targets)
        return self.metric.get_results()

    def __call__(self, *args, **kwargs):
        """Alias for :meth:`eval` so the evaluator can be called directly."""
        return self.eval(*args, **kwargs)


def classification_evaluator(dataloader, device):
    """Build an ``Evaluator`` with the classification metric set.

    The composed metric has two entries: ``"Acc"`` (``TopkAccuracy`` with
    ``topk=(1, 5)``) and ``"Loss"`` (``RunningLoss`` wrapping a cross-entropy
    loss with ``reduction="sum"``).

    Args:
        dataloader: Passed through to ``Evaluator``.
        device: Passed through to ``Evaluator``.

    Returns:
        The configured ``Evaluator`` instance.
    """
    accuracy_tracker = metrics.TopkAccuracy(topk=(1, 5))
    summed_cross_entropy = torch.nn.CrossEntropyLoss(reduction="sum")
    loss_tracker = metrics.RunningLoss(summed_cross_entropy)
    composed = metrics.MetricCompose(
        {"Acc": accuracy_tracker, "Loss": loss_tracker}
    )
    return Evaluator(metric=composed, dataloader=dataloader, device=device)


def segmentation_evaluator(dataloader, num_classes, ignore_idx=255, device=None):
    """Build an ``Evaluator`` with the segmentation metric set.

    The composed metric has three entries: ``"mIoU"`` (``mIoU`` over a
    ``ConfusionMatrix``), ``"Acc"`` (``Accuracy``) and ``"Loss"``
    (``RunningLoss`` wrapping a cross-entropy loss with ``reduction="sum"``).

    Args:
        dataloader: Passed through to ``Evaluator``.
        num_classes: Number of classes given to the confusion matrix.
        ignore_idx: Forwarded to the confusion matrix as ``ignore_idx``.
        device: Forwarded unchanged to ``Evaluator``. Defaults to ``None`` so
            existing three-argument calls keep working.

    Returns:
        The configured ``Evaluator`` instance.
    """
    cm = metrics.ConfusionMatrix(num_classes, ignore_idx=ignore_idx)
    metric = metrics.MetricCompose(
        {
            "mIoU": metrics.mIoU(cm),
            "Acc": metrics.Accuracy(),
            "Loss": metrics.RunningLoss(torch.nn.CrossEntropyLoss(reduction="sum")),
        }
    )
    # ``Evaluator`` takes a device argument; omitting it raised a TypeError.
    return Evaluator(metric, dataloader=dataloader, device=device)
