import torch
import torch.nn as nn
import torch.nn.functional as F

# DCGAN loss
def loss_dcgan_dis(dis_fake, dis_real):
    """Compute the DCGAN (logistic) discriminator loss as separate terms.

    Because ``softplus(-x) == -log(sigmoid(x))`` and
    ``softplus(x) == -log(1 - sigmoid(x))``, the two terms are the binary
    cross-entropy of scoring real samples as 1 and fake samples as 0. The
    inputs are treated as pre-sigmoid scores (logits).

    Args:
        dis_fake: Discriminator scores for generated samples.
        dis_real: Discriminator scores for real samples.

    Returns:
        A ``(loss_real, loss_fake)`` pair of scalar tensors. Each is averaged
        over all elements of its input. They are returned separately rather
        than summed.
    """
    real_term = F.softplus(-dis_real).mean()
    fake_term = F.softplus(dis_fake).mean()
    return real_term, fake_term


def loss_dcgan_gen(dis_fake):
    """Compute the non-saturating DCGAN generator loss.

    ``softplus(-x) == -log(sigmoid(x))``, so this is the mean of
    ``-log(sigmoid(dis_fake))``. The loss shrinks as the discriminator's
    scores on generated samples increase.

    Args:
        dis_fake: Discriminator scores (logits) for generated samples.

    Returns:
        A scalar tensor averaged over all elements of ``dis_fake``.
    """
    per_element = F.softplus(-dis_fake)
    return per_element.mean()


# Hinge Loss
def loss_hinge_dis(dis_fake, dis_real):
    """Compute the hinge discriminator loss as separate real/fake terms.

    Real scores are penalised linearly when they fall below +1. Fake scores
    are penalised linearly when they rise above -1.

    Args:
        dis_fake: Discriminator scores for generated samples.
        dis_real: Discriminator scores for real samples.

    Returns:
        A ``(loss_real, loss_fake)`` pair of scalar tensors. Each is averaged
        over all elements of its input.
    """
    real_violation = F.relu(1. - dis_real)
    fake_violation = F.relu(1. + dis_fake)
    return real_violation.mean(), fake_violation.mean()
# def loss_hinge_dis(dis_fake, dis_real): # This version returns a single loss
  # loss = torch.mean(F.relu(1. - dis_real))
  # loss += torch.mean(F.relu(1. + dis_fake))
  # return loss


def loss_hinge_gen(dis_fake):
    """Compute the hinge generator loss: the negated mean of ``dis_fake``.

    Args:
        dis_fake: Discriminator scores for generated samples.

    Returns:
        A scalar tensor.
    """
    mean_score = torch.mean(dis_fake)
    return -mean_score

# Loss functions used in the SAN objective

def loss_san_hinge_dis(dis_fake_list, dis_real_list):
    """Compute the SAN hinge discriminator loss as separate real/fake terms.

    Each argument must unpack into exactly two items ``(fun, dir)``. These are
    presumably the function-part and direction-part outputs of a SAN
    discriminator; confirm this against the discriminator that produces them.

    The ``fun`` scores receive the usual hinge penalty. The ``dir`` scores
    enter linearly: they are subtracted for real samples and added for fake
    samples. The hinge and linear parts are summed element-wise before
    averaging.

    Returns:
        A ``(loss_real, loss_fake)`` pair of scalar tensors.
    """
    fake_fun, fake_dir = dis_fake_list
    real_fun, real_dir = dis_real_list
    # a - b is bit-identical to a + (-b), so this matches the hinge + negated-dir form.
    real_total = F.relu(1. - real_fun) - real_dir
    fake_total = F.relu(1. + fake_fun) + fake_dir
    return real_total.mean(), fake_total.mean()

def loss_san_hinge_gen(dis_fake):
    """Compute the SAN hinge generator loss: the negated mean of ``dis_fake``.

    Unlike ``loss_san_hinge_dis``, this takes a single score tensor rather
    than a ``(fun, dir)`` pair.

    Returns:
        A scalar tensor.
    """
    average = torch.mean(dis_fake)
    return -average


# Default to hinge loss
# generator_loss = loss_hinge_gen
# discriminator_loss = loss_hinge_dis