import torch
import torch.nn.functional as F


def accuracy(output, target, topk=(1,)):
	"""Compute top-k accuracy, in percent, for every k in ``topk``.

	Args:
		output: Score tensor. The top-k search runs along dim 1, so it is
			indexed as (batch, classes).
		target: Tensor of class indices; flattened, it must hold one entry
			per row of ``output``.
		topk: Iterable of k values to evaluate.

	Returns:
		A list with one single-element float tensor per entry of ``topk``
		(same order), each holding the percentage of samples whose target
		is among the k highest-scoring classes.
	"""
	maxk = max(topk)
	batch_size = target.size(0)
	# Indices of the maxk largest scores per row, best first.
	_, pred = output.topk(maxk, dim=1, largest=True, sorted=True)
	# Transpose to (maxk, batch) so that row i holds the i-th best guess.
	pred = pred.t()
	correct = pred.eq(target.reshape(1, -1).expand_as(pred))
	res = []
	for k in topk:
		# reshape, not view: `correct` inherits the transposed
		# (non-contiguous) layout of `pred`, and view(-1) raises a
		# RuntimeError on such a slice whenever k > 1.
		correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
		res.append(correct_k.mul_(100.0 / batch_size))
	return res


def cross_entropy(output, target, n_classes):
	"""Return the mean cross-entropy between ``output`` and ``target``.

	Args:
		output: Tensor whose log is taken directly, so it is presumably a
			tensor of class probabilities (e.g. softmax output), not
			logits. Classes are summed over dim 1.
		target: Either a 1-D tensor of class indices, which is one-hot
			encoded to ``n_classes`` columns, or a tensor of per-class
			weights that is used as given.
		n_classes: Number of classes for the one-hot encoding; only used
			when ``target`` is 1-D.

	Returns:
		Scalar tensor: ``-sum(target * log(output), dim=1)`` averaged over
		the remaining dimension(s).
	"""
	if target.dim() == 1:
		target = F.one_hot(target, n_classes).to(dtype=torch.float)
	if output.is_floating_point():
		# log(0) is -inf and 0 * -inf is NaN, so a single exactly-zero
		# probability on a class with zero target weight used to poison
		# the whole loss (and its gradient). Clamping to the smallest
		# positive normal value keeps both finite and leaves every
		# ordinary probability untouched.
		output = output.clamp_min(torch.finfo(output.dtype).tiny)
	return -torch.sum(target * torch.log(output), dim=1).mean()
	
