# encoding: utf-8
import copy
import itertools
import logging
from collections import OrderedDict

import torch

from fastreid.utils import comm
from .evaluator import DatasetEvaluator

logger = logging.getLogger(__name__)


def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k"""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res


class ClasEvaluator(DatasetEvaluator):
    def __init__(self, cfg, output_dir=None):
        self.cfg = cfg
        self._output_dir = output_dir
        self._cpu_device = torch.device('cpu')

        self._predictions = []

    def reset(self):
        self._predictions = []

    def process(self, inputs, outputs):
        pred_logits = outputs.to(self._cpu_device, torch.float32)
        labels = inputs["targets"].to(self._cpu_device)

        # measure accuracy
        acc1, = accuracy(pred_logits, labels, topk=(1,))
        num_correct_acc1 = acc1 * labels.size(0) / 100

        self._predictions.append({"num_correct": num_correct_acc1, "num_samples": labels.size(0)})

    def evaluate(self):
        if comm.get_world_size() > 1:
            comm.synchronize()
            predictions = comm.gather(self._predictions, dst=0)
            predictions = list(itertools.chain(*predictions))

            if not comm.is_main_process(): return {}

        else:
            predictions = self._predictions

        total_correct_num = 0
        total_samples = 0
        for prediction in predictions:
            total_correct_num += prediction["num_correct"]
            total_samples += prediction["num_samples"]

        acc1 = total_correct_num / total_samples * 100

        self._results = OrderedDict()
        self._results["Acc@1"] = acc1
        self._results["metric"] = acc1

        return copy.deepcopy(self._results)
