class AverageMeter(object):
    """Track the latest value and the running weighted average of a metric.

    Attributes (all reset to 0 by ``reset``):
        val: the value passed to the most recent ``update`` call.
        avg: ``sum / count``; stays 0 until some positive weight is counted.
        sum: running total of ``val * n`` over every update.
        count: running total of the weights ``n``.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear every statistic back to zero."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and refresh the running average.

        Args:
            val: the newest value, presumably a per-batch mean of the metric.
            n: weight of ``val``, typically the number of samples it covers.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # An update with n=0 before any weighted sample leaves count at 0;
        # dividing would raise ZeroDivisionError, so keep the previous avg.
        if self.count:
            self.avg = self.sum / self.count


def accuracy(output, target, topk=(1,)):
    """Compute top-k accuracy for each k in ``topk``.

    Args:
        output: score tensor ranked along its last dimension; assumed to be
            (batch, num_classes) because the ranked indices are transposed
            with ``.t()`` -- TODO confirm for callers.
        target: ground-truth class indices, one per sample; its first
            dimension is taken as the batch size.
        topk: the k values to evaluate.

    Returns:
        A 0-d float tensor holding the fraction of samples (in [0, 1]) whose
        label is among the k highest-scoring classes. If ``topk`` has a single
        entry, that tensor is returned directly; otherwise a list is returned,
        in ``topk`` order.
    """
    deepest_k = max(topk)
    num_samples = target.size(0)

    # Rank once at the deepest k needed.  After the transpose, row r holds
    # every sample's (r+1)-th best guess, so the first k rows are its top k.
    ranked = output.topk(deepest_k, dim=-1, largest=True, sorted=True)[1].t()
    hits = ranked.eq(target.view(1, -1).expand_as(ranked))

    scores = [
        hits[:k].reshape(-1).float().sum(0).mul_(1. / num_samples)
        for k in topk
    ]
    return scores[0] if len(scores) == 1 else scores


def equal_to(x1, x2, eta=1e-9):
    """Return True when ``x1`` lies strictly inside the open interval
    ``(x2 - eta, x2 + eta)``.

    Both bounds are exclusive, so ``eta=0`` never reports equality, and NaN
    on either side compares as not equal.
    """
    return x2 - eta < x1 < x2 + eta


if __name__ == "__main__":
    import numpy as np
    import torch

    # Smoke test with 3 samples x 3 classes.  The highest scores sit at
    # indices 2, 1, 1 against labels 0, 1, 2, so only the middle sample is
    # correct and top-1 accuracy should print roughly 0.3333.
    scores = torch.from_numpy(np.array([
        [0.1, 0.3, 0.6],
        [0.4, 0.45, 0.15],
        [0.2, 0.6, 0.1],
    ]))
    labels = torch.from_numpy(np.array([0, 1, 2]))
    print(accuracy(scores, labels))