from __future__ import print_function

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import grad

# Module-level compute device name: 'cuda' when a CUDA device is available,
# otherwise 'cpu'.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'


def get_jacobian(outputs, inputs, max_pool=False):
    """Build a Jacobian by differentiating each output column w.r.t. ``inputs``.

    For every column ``col`` along dim 1 of ``outputs``, the gradient of
    ``outputs[:, col].sum()`` with respect to ``inputs`` is taken. Summing over
    the batch yields per-sample rows only if samples do not interact in the
    forward pass (NOTE(review): e.g. no batch-statistics layers — confirm).

    The graph is retained between columns and ``create_graph=True`` keeps the
    result differentiable, so it can be used inside a training loss.

    Args:
        outputs: Tensor with at least 2 dims; columns are taken along dim 1.
        inputs: Tensor that ``outputs`` was computed from.
        max_pool: If True, each gradient is reduced with a 2x2, stride-2
            ``max_pool2d`` (assumes image-like gradients — TODO confirm).

    Returns:
        Tensor of shape ``gradient.shape + (outputs.size(1),)``. Each gradient
        has the shape of ``inputs`` (spatially pooled when ``max_pool``).
    """
    def _column_grad(col):
        # Gradient of one output column, summed over the batch.
        (col_grad,) = grad(
            [outputs[:, col].sum()], [inputs],
            retain_graph=True, create_graph=True)
        if max_pool:
            return F.max_pool2d(col_grad, kernel_size=2, stride=2)
        return col_grad

    per_column = [_column_grad(col) for col in range(outputs.size(1))]
    return torch.stack(per_column, dim=-1)


def bce_loss(labels, outputs, inputs=None, true_masks=None, l=0.5):
    """Return the mean binary cross-entropy of ``outputs`` against ``labels``.

    ``outputs`` is flattened to 1-D and ``labels`` is cast to float32 before
    the loss is computed. ``outputs`` must hold probabilities in [0, 1], as
    binary cross-entropy requires.

    ``inputs``, ``true_masks`` and ``l`` are accepted but ignored. Presumably
    they exist so this function can be called with the same arguments as
    ``mask_loss_binary`` (NOTE(review): confirm against callers).
    """
    targets = labels.to(torch.float32)
    probabilities = outputs.reshape(-1)
    # Same computation as nn.BCELoss() with its defaults (no weight, mean reduction).
    return F.binary_cross_entropy(probabilities, targets)


def mask_loss_binary(labels, outputs, inputs, true_masks=None, l=0.5, use_grad=True):
    """Binary cross-entropy plus an optional input-gradient mask penalty.

    Args:
        labels: Binary targets (cast to float32 by ``bce_loss``).
        outputs: Model outputs. Flattened for the BCE term and indexed as
            ``outputs[:, i]`` when building the Jacobian.
        inputs: Tensor ``outputs`` was computed from (must be in the graph).
        true_masks: Optional mask whose shape must equal the Jacobian's shape
            after any flattening below. Entries where it is False are
            penalized. Assumed to be a bool tensor — ``~`` is bitwise NOT on
            integer tensors (TODO confirm with callers).
        l: Weight applied to the mask penalty.
        use_grad: If False, the penalty is still computed and returned but is
            not added to the loss.

    Returns:
        The BCE loss alone when ``true_masks`` is None. Otherwise a list
        ``[total_loss, mask_loss]``.
    """
    classification_loss = bce_loss(labels, outputs)
    if true_masks is None:
        return classification_loss

    jac = get_jacobian(outputs, inputs)
    if jac.dim() > 3:
        # Multi-dimensional input: collapse everything but the batch dim.
        jac = jac.view(jac.shape[0], -1)

    assert jac.shape == true_masks.shape

    # Log-softmax over dim 1, then the L2 norm of the entries outside the mask.
    log_probs = F.log_softmax(jac, dim=1)
    penalty = torch.norm(log_probs[~true_masks], p=2).sum()

    total = classification_loss + l * penalty if use_grad else classification_loss
    return [total, penalty]


def feature_difference_loss_kl(net, inputs, masks):
    """KL divergence between features of the full and the masked inputs.

    Only samples whose mask is not all-zero take part. For those samples it
    computes ``curr_feat = net.endo_map(net.features(inputs))`` and
    ``mask_feat = net.features(inputs * masks)``, both flattened per sample.
    The two are softmax-normalized and compared with
    ``nn.KLDivLoss(reduction='batchmean', log_target=True)``. Per the PyTorch
    docs, that is KL(softmax(mask_feat) || softmax(curr_feat)) averaged over
    the kept samples.

    Args:
        net: Object exposing ``features`` and ``endo_map`` callables.
        inputs: Batch tensor; dims 2 and 3 are read as spatial sizes, so it
            is expected to be 4-D.
        masks: Mask tensor, or None. If its shape differs from ``inputs``, it
            is reshaped to ``(B, -1, H, W)`` using the spatial sizes of
            ``inputs`` and then broadcast in ``inputs * masks``.

    Returns:
        Scalar loss tensor. When ``masks`` is None or every mask is all-zero,
        returns a zero that stays connected to ``net.features``' graph.
    """
    # The None check must come first: reading ``masks.shape`` on None raises.
    if masks is None:
        return (net.features(inputs) * 0).sum()
    if masks.shape != inputs.shape:
        # reshape (not view) also copes with non-contiguous masks.
        masks = masks.reshape(masks.shape[0], -1, inputs.shape[2], inputs.shape[3])

    keep = masks.sum(dim=(1, 2, 3)) > 0
    if not keep.any():
        # An empty batch would make 'batchmean' divide 0/0 and return NaN.
        return (net.features(inputs) * 0).sum()
    inputs = inputs[keep]
    masks = masks[keep]

    curr_feat = net.endo_map(net.features(inputs)).flatten(start_dim=1)
    mask_feat = net.features(inputs * masks).flatten(start_dim=1)
    kl_loss = nn.KLDivLoss(reduction='batchmean', log_target=True)
    return kl_loss(F.log_softmax(curr_feat, dim=1), F.log_softmax(mask_feat, dim=1))
