from __future__ import division

from ignite.metrics import Metric, MetricsLambda, Average

import torch

from experiments.utils.ignite_output import IgniteOutput


def PCorrect() -> Average:
    """Build a metric that averages the probability assigned to the true class.

    For each sample, ``y_pred`` is turned into a probability distribution by a
    softmax over its last dimension. The entry at the target index ``y`` is
    then selected. The ignite ``Average`` metric averages these per-sample
    probabilities over every example passed to ``update``.

    Returns:
        An ignite ``Average`` metric. Its ``output_transform`` reads the
        ``y_pred`` and ``y`` attributes of an ``IgniteOutput``.
    """

    def output_transform(output: IgniteOutput) -> torch.Tensor:
        """Return the softmax probability of the target class for each sample."""
        # NOTE(review): assumes y_pred is (batch, num_classes) unnormalised
        # scores and y is (batch,) integer class labels -- confirm vs callers.
        probs = output.y_pred.softmax(dim=-1)
        # torch.gather requires an int64 index. .long() is a no-op for labels
        # that are already int64, and it lets int32/uint8 labels work too.
        target = output.y.long().unsqueeze(-1)
        # Gather along the same (last) dim the softmax normalised over.
        # The result keeps a trailing singleton dim, i.e. (batch, 1).
        # For tensors with ndim > 1, ignite's Average counts shape[0] samples
        # per update. A 1-D tensor would be counted as a single sample.
        correct_probs = probs.gather(dim=-1, index=target)
        return correct_probs

    metric = Average(output_transform=output_transform)
    return metric
