import torch
from torch import nn as nn
from torch.distributions import Bernoulli, Categorical


class HLoss(nn.Module):
    """Entropy loss over dimension 1 of a batch of logits.

    The entropy of ``softmax(x, dim=1)`` is computed per sample and averaged
    over the batch. With ``mode="max"`` the negated entropy is returned, so
    minimising the loss drives the entropy up. With ``mode="min"`` the entropy
    itself is returned.
    """

    def __init__(self, mode="max"):
        """Store the optimisation direction; ``mode`` is "max" or "min"."""
        super().__init__()
        self.mode = mode

    def forward(self, x, t=None):
        """Return the signed mean entropy of ``softmax(x, dim=1)``.

        ``t`` is accepted for call-signature compatibility and is not used.

        Raises:
            ValueError: if ``self.mode`` is neither "max" nor "min".
        """
        plogp = torch.softmax(x, dim=1) * torch.log_softmax(x, dim=1)
        entropy = -plogp.sum(dim=1).mean()

        if self.mode == "min":
            return entropy
        if self.mode == "max":
            # Minimising the negative entropy maximises the entropy.
            return -entropy
        raise ValueError("HLoss optimization mode %s not supported!" % self.mode)


class SoftTargetLoss(torch.nn.Module):
    """Cross-entropy against soft (probability-valued) targets.

    ``input`` holds logits and ``target`` holds per-class weights over the
    last dimension. The per-sample loss ``-sum(target * log_softmax(input))``
    is averaged over every remaining element.
    """

    def __init__(self):
        super().__init__()

    def forward(self, input, target):
        log_probs = torch.log_softmax(input, dim=-1)
        per_sample = torch.sum(target * log_probs, dim=-1)
        return -per_sample.mean()


class CamSoftmaxCELoss(torch.nn.Module):
    """Sum of soft-target cross-entropies over two halves of dimension 1.

    Columns ``[:c]`` form the u/d head and columns ``[c:]`` the l/r head,
    where ``c = input.shape[1] // 2``; when that size is odd, the l/r head
    gets the extra column. Each head is normalised with its own log-softmax
    over the last dimension, and its loss is averaged over the batch.
    """

    def __init__(self):
        super().__init__()

    def forward(self, input, target):
        half = input.shape[1] // 2
        heads = (
            (input[:, :half], target[:, :half]),  # u/d
            (input[:, half:], target[:, half:]),  # l/r
        )
        losses = [
            -torch.sum(tgt * torch.log_softmax(logits, dim=-1), dim=-1).mean()
            for logits, tgt in heads
        ]
        return losses[0] + losses[1]


class CamEntropyLoss(torch.nn.Module):
    """Negative entropy of the two camera heads (u/d and l/r).

    ``input`` is split along dimension 1 into ``[:c]`` (u/d) and ``[c:]``
    (l/r), with ``c = input.shape[1] // 2``. Each half is treated as the
    logits of a categorical distribution over its last dimension. The return
    value is minus the sum of the two batch-mean entropies, so minimising it
    maximises the entropy of both heads. ``target`` is accepted for
    call-signature compatibility with the other losses and is ignored.
    """

    def __init__(self):
        super(CamEntropyLoss, self).__init__()

    def forward(self, input, target):
        # split into u/d and l/r
        c = input.shape[1] // 2
        input_ud = input[:, :c]
        input_lr = input[:, c:]

        # The halves are unnormalised scores, so they must go in as ``logits``.
        # log_softmax output is log-probabilities (<= 0). Passing it as
        # ``probs`` either fails Categorical's simplex validation or, with
        # validation off, is renormalised into a meaningless distribution.
        loss_ud = Categorical(logits=input_ud).entropy().mean()
        loss_lr = Categorical(logits=input_lr).entropy().mean()

        return -(loss_ud + loss_lr)


class MultiStepCamSoftmaxCELoss(CamSoftmaxCELoss):
    """CamSoftmaxCELoss for inputs that carry extra leading (step) dimensions.

    Every dimension except the last is flattened into a single batch
    dimension before delegating to ``CamSoftmaxCELoss.forward``.
    """

    def __init__(self):
        super().__init__()

    def forward(self, input, target):
        flat_input = input.reshape((-1, input.shape[-1]))
        flat_target = target.reshape((-1, target.shape[-1]))
        return super().forward(flat_input, flat_target)
