from torch.autograd import Function
import torch.nn as nn
import torch
from torch.nn import functional as F

class EMLoss(nn.Module):
	"""Soft cross-entropy between two sets of logits.

	``real`` is converted to probabilities with softmax and serves as the
	target distribution; ``fake`` is converted to log-probabilities with
	log-softmax. The loss is ``-sum(p_real * log p_fake)`` over dim 1,
	averaged over every remaining element.

	NOTE(review): despite the "EM" name, the code computes a cross-entropy,
	not an earth mover's distance — confirm the intended loss with callers.
	"""

	def __init__(self):
		super().__init__()

	def forward(self, real, fake):
		"""Return the mean cross-entropy of ``fake`` against ``real``.

		Args:
			real: target logits; softmax is taken along dim 1 (assumed to be
				the class axis — TODO confirm).
			fake: predicted logits; presumably the same shape as ``real``
				(they must at least broadcast for the elementwise product).

		Returns:
			A 0-d tensor.
		"""
		target_probs = F.softmax(real, dim=1)
		pred_log_probs = F.log_softmax(fake, dim=1)
		per_row = (target_probs * pred_log_probs).sum(dim=1)
		# Negate last so the result is the (positive) cross-entropy.
		return -per_row.mean()

class HLoss(nn.Module):
	"""Mean entropy of the softmax distribution along dim 1.

	For logits ``x`` the per-row value is ``-sum(p * log p)`` with
	``p = softmax(x, dim=1)``; the result is averaged over all remaining
	elements.
	"""

	def __init__(self):
		super().__init__()

	def forward(self, x):
		"""Return the mean softmax entropy of ``x`` as a 0-d tensor.

		``x`` is assumed to hold logits with the class axis at dim 1 —
		TODO confirm against callers.
		"""
		probs = F.softmax(x, dim=1)
		log_probs = F.log_softmax(x, dim=1)
		plogp_per_row = (probs * log_probs).sum(dim=1)
		# sum(p * log p) is <= 0; negate to obtain the entropy.
		return -plogp_per_row.mean()

class MSE(nn.Module):
	"""Mean squared error between ``pred`` and ``real``.

	The squared residuals are summed and divided by the number of elements
	of the (broadcast) residual tensor.
	"""

	def __init__(self):
		super().__init__()

	def forward(self, pred, real):
		"""Return ``sum((real - pred) ** 2) / numel`` as a 0-d tensor."""
		residual = real - pred
		squared_total = residual.pow(2).sum()
		return squared_total / residual.numel()

class SIMSE(nn.Module):
	"""Squared mean of the residual ``real - pred``.

	Computes ``(sum(real - pred)) ** 2 / n ** 2``, which equals
	``mean(real - pred) ** 2`` up to rounding. Positive and negative
	residuals cancel before squaring. The name suggests this is the
	scale-invariant correction term of an MSE loss — confirm with callers.
	"""

	def __init__(self):
		super().__init__()

	def forward(self, pred, real):
		"""Return ``(sum(real - pred)) ** 2 / numel ** 2`` as a 0-d tensor."""
		residual = real - pred
		count = residual.numel()
		total = residual.sum()
		return total.pow(2) / (count ** 2)


class SparseParam(nn.Module):
	"""Sum of sigmoid activations of two mask-logit tensors.

	Minimising this value pushes every sigmoid output toward 0, so it acts
	as a sparsity penalty on both masks.
	"""

	def __init__(self):
		super().__init__()

	def forward(self, maskI, maskS):
		"""Return ``sum(sigmoid(maskI)) + sum(sigmoid(maskS))`` as a 0-d tensor."""
		penalty_i = torch.sigmoid(maskI).sum()
		penalty_s = torch.sigmoid(maskS).sum()
		return penalty_i + penalty_s

class OverlapMask(nn.Module):
	"""Soft intersection-over-union of two sigmoid-activated maps.

	Both inputs are passed through a sigmoid. The intersection is
	``sum(a * b)`` and the union is ``sum(a + b - a * b)``. Both are
	smoothed by 1e-6 so that empty or all-zero maps do not divide by zero.
	Because the sigmoid outputs lie in [0, 1], the score is at most 1.
	"""

	def __init__(self):
		super().__init__()

	def forward(self, featI, featS):
		"""Return the smoothed soft IoU of ``featI`` and ``featS`` (0-d tensor)."""
		eps = 1e-6
		prob_i = torch.sigmoid(featI)
		prob_s = torch.sigmoid(featS)
		joint = prob_i * prob_s
		overlap = joint.sum()
		combined = (prob_i + prob_s - joint).sum()
		return (overlap + eps) / (combined + eps)

