# Max-Heinrich Laves
# Institute of Mechatronic Systems
# Leibniz Universität Hannover, Germany
# 2019

import torch
from PIL import ImageFilter

# Names exported by `from <module> import *`. DataSetWithIndex and the
# gaussian_dropout_linear / gaussian_dropout_conv2d helpers are not listed,
# so star-imports do not bring them in.
__all__ = [
    'accuracy',
    'kl_loss',
    'nentr',
    'xavier_normal_init',
    'gaussian_dropout',
    'Sobel',
]


def accuracy(input, target):
    """
    Classification accuracy of a batch.

    :param input: class scores; the predicted class is the index of the maximum along dim 1
        (computed on ``input.data``, i.e. outside autograd)
    :param target: class labels, compared element-wise with the predicted indices
    :return: fraction of correct predictions (correct count / size of dim 0) as a Python float
    """
    predictions = torch.max(input.data, 1)[1]
    n_correct = predictions.eq(target).sum().float()
    return (n_correct / predictions.size(0)).item()


def kl_loss(logits):
    """
    Mean negative log-softmax of ``logits`` (classes along dim 1).

    For logits of shape (N, C) this equals the batch mean of
    KL(uniform || softmax(logits)) + log(C). Minimizing it therefore pushes the
    predicted distribution towards uniform. (Assumes 2-D logits; for other
    ranks the mean also runs over the remaining positions.)

    :param logits: unnormalized class scores
    :return: scalar tensor
    """
    log_probs = logits.log_softmax(dim=1)
    return log_probs.mean().neg()


def nentr(p, base=None):
    """
    Entropy of each probability distribution in a batch.

    Computes ``|sum_c p_c * log(p_c + 1e-16)|`` over dim 1, i.e. the Shannon
    entropy -sum p log p. The small epsilon makes entries with p_c == 0
    contribute 0 instead of 0 * log(0) = NaN. When ``base`` is given, the
    logarithm is converted to that base by dividing by log(base). If ``base``
    is None (or any other falsy value, e.g. 0), the natural logarithm is used.

    :param p: batches of class label probability distributions (softmax output), classes along dim 1
    :param base: logarithm base b, or None for the natural logarithm
    :return: tensor of entropies with dim 1 reduced away
    """
    eps = torch.tensor([1e-16], device=p.device)
    log_p = torch.log(p + eps)
    if base:
        # change of base: log_b(x) = ln(x) / ln(b)
        log_b = torch.tensor([base], device=p.device, dtype=torch.float32).log()
        log_p = log_p / log_b
    return (p * log_p).sum(dim=1).abs()


def xavier_normal_init(m):
    """
    Re-initialize the weight of a ``torch.nn.Conv2d`` with Xavier-normal values, in place.

    Any other module type, and the Conv2d bias, is left untouched. Presumably
    meant to be used as ``model.apply(xavier_normal_init)``.

    :param m: a module
    :return: None
    """
    if not isinstance(m, torch.nn.Conv2d):
        return
    torch.nn.init.xavier_normal_(m.weight.data)


def gaussian_dropout(x, mu, p, layer):
    """
    Add Gaussian dropout noise to the output ``mu`` of ``layer``.

    Dispatches on the exact type of ``layer`` (subclasses are not accepted)
    to gaussian_dropout_linear or gaussian_dropout_conv2d.

    :param x: input tensor that was fed to ``layer``
    :param mu: mean of the noisy output, presumably ``layer(x)``
    :param p: dropout rate
    :param layer: a ``torch.nn.Linear`` or ``torch.nn.Conv2d``
    :return: the noisy output produced by the matching helper
    :raises TypeError: if ``layer`` is neither exactly Linear nor exactly Conv2d
    """
    if type(layer) is torch.nn.Linear:
        return gaussian_dropout_linear(x, mu, p, layer)
    if type(layer) is torch.nn.Conv2d:
        return gaussian_dropout_conv2d(x, mu, p, layer)
    # Explicit raise instead of `assert False`: asserts are stripped under
    # `python -O`, which would make this silently return None.
    raise TypeError('gaussian_dropout supports torch.nn.Linear and torch.nn.Conv2d, got {}'
                    .format(type(layer).__name__))


def gaussian_dropout_linear(x, mu, p, layer):
    """
    Add Gaussian dropout noise to the output of a linear layer using the local
    reparameterization trick.

    The per-output noise variance is
        sigma^2 = p/(1-p) * sum(W^2*x^2)
    see: Appendix B in Kingma-Variational Dropout and the Local Reparameterization Trick-2015

    :param x: input tensor that was fed to ``layer``
    :param mu: mean of the noisy output, presumably ``layer(x)``; same shape as ``linear(x, W)``
    :param p: dropout rate; expected 0 <= p < 1 (p == 1 divides by zero)
    :param layer: a ``torch.nn.Linear`` (exact type; subclasses are rejected)
    :return: ``mu + sigma * eps`` with ``eps ~ N(0, 1)`` drawn via ``torch.randn_like(mu)``
    :raises TypeError: if ``layer`` is not exactly ``torch.nn.Linear``
    """
    if type(layer) is not torch.nn.Linear:
        raise TypeError('gaussian_dropout_linear expects torch.nn.Linear, got {}'
                        .format(type(layer).__name__))
    # Weights are detached: no gradient flows into W through the noise scale.
    var = p / (1 - p) * torch.nn.functional.linear(x ** 2, layer.weight.detach() ** 2, bias=None)
    # d/dv sqrt(v) is infinite at v == 0 (all-zero input row, or p == 0), which
    # back-propagates as NaN into x. Take sqrt of a dummy 1 at those entries and
    # select 0 instead: forward values are unchanged, the gradient there is 0.
    is_zero = var == 0
    safe_var = torch.where(is_zero, torch.ones_like(var), var)
    sigma = torch.where(is_zero, torch.zeros_like(var), safe_var.sqrt())
    eps = torch.randn_like(mu)
    return mu + sigma * eps


def gaussian_dropout_conv2d(x, mu, p, prev_layer):
    """
    Add Gaussian dropout noise to the output of a 2-D convolution using the
    local reparameterization trick.

    The per-output noise variance is
        sigma^2 = p/(1-p) * sum(W^2*x^2)
    where the sum runs over each output's receptive field, i.e. conv2d(x^2, W^2).
    see: Appendix B in Kingma-Variational Dropout and the Local Reparameterization Trick-2015

    :param x: input tensor that was fed to ``prev_layer``
    :param mu: mean of the noisy output, presumably ``prev_layer(x)``; same shape as the
        convolution of ``x`` with ``prev_layer``'s geometry
    :param p: dropout rate; expected 0 <= p < 1 (p == 1 divides by zero)
    :param prev_layer: a ``torch.nn.Conv2d`` (exact type) whose stride, padding,
        dilation and groups define the receptive fields
    :return: ``mu + sigma * eps`` with ``eps ~ N(0, 1)`` drawn via ``torch.randn_like(mu)``
    :raises TypeError: if ``prev_layer`` is not exactly ``torch.nn.Conv2d``
    """
    if type(prev_layer) is not torch.nn.Conv2d:
        raise TypeError('gaussian_dropout_conv2d expects torch.nn.Conv2d, got {}'
                        .format(type(prev_layer).__name__))
    # Weights are detached: no gradient flows into W through the noise scale.
    # dilation and groups must match the layer; otherwise grouped convolutions fail
    # with a channel mismatch and dilated ones use a different receptive field than mu.
    # NOTE: F.conv2d always zero-pads, so a layer with padding_mode != 'zeros'
    # is not reproduced exactly.
    var = torch.nn.functional.conv2d(x ** 2, prev_layer.weight.detach() ** 2, bias=None,
                                     stride=prev_layer.stride, padding=prev_layer.padding,
                                     dilation=prev_layer.dilation, groups=prev_layer.groups)
    var = p / (1 - p) * var
    # d/dv sqrt(v) is infinite at v == 0 (all-zero receptive field, or p == 0),
    # which back-propagates as NaN into x. Take sqrt of a dummy 1 at those
    # positions and select 0 instead: forward values are unchanged, gradient is 0.
    is_zero = var == 0
    safe_var = torch.where(is_zero, torch.ones_like(var), var)
    sigma = torch.where(is_zero, torch.zeros_like(var), safe_var.sqrt())
    eps = torch.randn_like(mu)
    return mu + sigma * eps


class Sobel:
    """
    Edge-detection transform for PIL images.

    NOTE: despite the name, this applies PIL's built-in
    ``ImageFilter.FIND_EDGES`` kernel rather than a separable Sobel gradient.
    The filtered image is then cropped to the fixed box (1, 1, 33, 33), i.e. a
    32x32 region starting one pixel in from the top-left corner. The fixed box
    presumably assumes inputs of roughly 34x34 pixels — TODO confirm the
    expected input size.
    """

    def __call__(self, img):
        edges = img.filter(ImageFilter.FIND_EDGES)
        left, upper, right, lower = 1, 1, 33, 33
        return edges.crop((left, upper, right, lower))


class DataSetWithIndex(torch.utils.data.Dataset):
    """
    Wrap an indexable dataset so that every item also carries its own index.

    ``ds[i]`` returns ``(dataset[i], torch.tensor([i]))``. The index is a
    1-element tensor, presumably so that batching yields an (N, 1) index
    tensor alongside the samples.
    """

    def __init__(self, dataset):
        super().__init__()
        self.dataset = dataset  # wrapped dataset; must support len() and []

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        sample = self.dataset[idx]
        index = torch.tensor([idx])
        return sample, index
