import copy

import torch


class AverageMeter:
    """Track the most recent value and the weighted running mean of a metric."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all accumulated statistics back to zero."""
        self.val = 0    # last value passed to update()
        self.avg = 0    # sum / count (left at its previous value while count == 0)
        self.sum = 0    # accumulated val * n over all updates
        self.count = 0  # accumulated weight n over all updates

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and refresh the running mean.

        Args:
            val: the new value; it contributes ``val * n`` to the running sum.
            n: weight of ``val`` (presumably the number of samples it was
                averaged over — confirm at call sites).
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # An update with n=0 (e.g. an empty batch) on a meter that has no data
        # yet would otherwise raise ZeroDivisionError.
        if self.count:
            self.avg = self.sum / self.count


def accuracy(output, target, topk=(1,)):
    """Compute top-k accuracy, as a percentage, for each ``k`` in ``topk``.

    Args:
        output: score tensor whose dim 0 indexes samples; the top-k classes
            are taken along dim 1.
        target: ground-truth class indices, one per sample (flattened).
        topk: the ``k`` values to report.

    Returns:
        A list with one 1-element float tensor per ``k``: the percentage of
        samples whose target is among the ``k`` highest-scoring classes.
    """
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)
        # Indices of the maxk largest scores per sample, sorted descending.
        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()  # row i now holds every sample's (i+1)-th best guess
        correct = pred.eq(target.reshape(1, -1).expand_as(pred))

        res = []
        for k in topk:
            # `correct` inherits the transposed (non-contiguous) strides of
            # `pred`, so `.view(-1)` can fail here; `.reshape` copies if needed.
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
    return res


def load_weights(model, params, grads=None):
    """Overwrite ``model``'s parameters (and optionally gradients) with copies.

    Parameters are paired with ``params`` (and ``grads``) in
    ``model.parameters()`` order via ``zip``, so copying stops at the shortest
    of the inputs.

    Args:
        model: module whose parameters are overwritten.
        params: tensors whose data is deep-copied into the parameters.
        grads: optional tensors deep-copied into each parameter's ``.grad``.
            A parameter with no gradient yet gets the copy attached directly.
    """
    if grads is None:
        for mp, p in zip(model.parameters(), params):
            mp.data = copy.deepcopy(p.data)
    else:
        for mp, p, g in zip(model.parameters(), params, grads):
            mp.data = copy.deepcopy(p.data)
            # Copy `.data`: deepcopy refuses non-leaf tensors with autograd
            # history.
            g_copy = copy.deepcopy(g.data)
            if mp.grad is None:
                # No backward pass has populated .grad yet, so there is
                # nothing to write into; attach the copy as the gradient.
                mp.grad = g_copy
            else:
                mp.grad.data = g_copy


def get_param_dim(model):
    """Return the total number of scalar entries across all of ``model``'s parameters."""
    return sum(param.numel() for param in model.parameters())


def zero_grad(model):
    """
    Reset every existing parameter gradient of ``model`` to zero, in place.

    Parameters whose ``.grad`` is ``None`` are skipped, not allocated.
    """
    for param in model.parameters():
        grad = param.grad
        if grad is None:
            continue
        grad.data.zero_()


def get_device(model):
    """Return the device of ``model``'s first parameter.

    Modules without parameters fall back to their first registered buffer.
    Only the first tensor is inspected; the parameter list is never
    materialised.

    Raises:
        IndexError: if the module has neither parameters nor buffers. This is
            the same exception type the earlier ``list(...)[0]`` form raised.
    """
    for tensor in model.parameters():
        return tensor.device
    for tensor in model.buffers():
        return tensor.device
    raise IndexError("get_device: model has no parameters or buffers")
