import torch
from torch import nn
import torch.nn.functional as F

def adopt_weight(weight, global_step, threshold=0, value=0.0):
    """Gate a loss weight on the training step.

    Used in GAN training so that the discriminator term only contributes
    once ``threshold`` iterations have elapsed.

    Args:
        weight (float): Weight to use once the threshold has been reached.
        global_step (int): Current global training step.
        threshold (int): Step before which ``value`` is returned instead.
        value (float): Replacement weight while ``global_step < threshold``.

    Returns:
        float: ``value`` if ``global_step < threshold``, otherwise ``weight``.
    """
    return value if global_step < threshold else weight


class LeCAM_EMA(object):
    """Exponential moving averages of the mean real and fake logits.

    Both averages are stored as plain Python floats (``.item()`` is taken on
    each batch mean), so no autograd graph is retained between updates.
    """

    def __init__(self, init=0.0, decay=0.999):
        # decay is the fraction of the previous average kept on each update.
        self.decay = decay
        self.logits_real_ema = init
        self.logits_fake_ema = init

    def update(self, logits_real, logits_fake):
        """Blend the current batch means into the running averages."""
        batch_real = torch.mean(logits_real).item()
        batch_fake = torch.mean(logits_fake).item()
        keep = self.decay

        self.logits_real_ema = self.logits_real_ema * keep + batch_real * (1 - keep)
        self.logits_fake_ema = self.logits_fake_ema * keep + batch_fake * (1 - keep)


def lecam_reg(real_pred, fake_pred, lecam_ema):
    """Compute the LeCAM regularization term from discriminator predictions.

    Args:
        real_pred: Discriminator outputs for real samples.
        fake_pred: Discriminator outputs for fake samples.
        lecam_ema: Object exposing ``logits_real_ema`` and ``logits_fake_ema``.

    Returns:
        Scalar tensor: mean of ``relu(real_pred - logits_fake_ema) ** 2`` plus
        mean of ``relu(logits_real_ema - fake_pred) ** 2``.
    """
    # Real predictions are anchored against the running *fake* average ...
    real_term = F.relu(real_pred - lecam_ema.logits_fake_ema).pow(2).mean()
    # ... and fake predictions against the running *real* average.
    fake_term = F.relu(lecam_ema.logits_real_ema - fake_pred).pow(2).mean()
    return real_term + fake_term


def hinge_d_loss(logits_real, logits_fake):
    """Hinge discriminator loss.

    Returns the average of ``mean(relu(1 - logits_real))`` and
    ``mean(relu(1 + logits_fake))`` as a scalar tensor.
    """
    real_term = F.relu(1.0 - logits_real).mean()
    fake_term = F.relu(1.0 + logits_fake).mean()
    return 0.5 * (real_term + fake_term)


def vanilla_d_loss(logits_real, logits_fake):
    """Softplus-based discriminator loss.

    Returns the average of ``mean(softplus(-logits_real))`` and
    ``mean(softplus(logits_fake))`` as a scalar tensor.
    """
    real_term = F.softplus(-logits_real).mean()
    fake_term = F.softplus(logits_fake).mean()
    return 0.5 * (real_term + fake_term)


def non_saturating_d_loss(logits_real, logits_fake):
    """Binary cross-entropy discriminator loss on raw logits.

    Real logits are scored against a target of 1 and fake logits against a
    target of 0; the two mean losses are averaged.

    Args:
        logits_real: Discriminator logits for real samples (floating point).
        logits_fake: Discriminator logits for fake samples (floating point).

    Returns:
        Scalar tensor ``0.5 * (bce(logits_real, 1) + bce(logits_fake, 0))``.
    """
    # F.binary_cross_entropy_with_logits expects (input, target): the logits
    # go first and the labels second. Passing them the other way round makes
    # the fake-side term a constant with no gradient. The default
    # reduction="mean" already yields a scalar.
    loss_real = F.binary_cross_entropy_with_logits(
        logits_real, torch.ones_like(logits_real)
    )
    loss_fake = F.binary_cross_entropy_with_logits(
        logits_fake, torch.zeros_like(logits_fake)
    )
    return 0.5 * (loss_real + loss_fake)


def hinge_gen_loss(logit_fake):
    """Hinge generator loss: the negated mean of the fake logits."""
    mean_fake = logit_fake.mean()
    return -mean_fake


def non_saturating_gen_loss(logit_fake):
    """Non-saturating generator loss on raw discriminator logits.

    Scores the fake logits against a target of 1, i.e. the generator is
    rewarded when the discriminator classifies its samples as real.

    Args:
        logit_fake: Discriminator logits for generated samples (floating point).

    Returns:
        Scalar tensor: mean binary cross-entropy of ``logit_fake`` vs. ones.
    """
    # F.binary_cross_entropy_with_logits expects (input, target): logits
    # first, labels second. The default reduction="mean" returns a scalar.
    return F.binary_cross_entropy_with_logits(
        logit_fake, torch.ones_like(logit_fake)
    )
