import torch
from torch import nn

from .utils.gan_utils import LeCAM_EMA, adopt_weight, hinge_d_loss, hinge_gen_loss, lecam_reg, non_saturating_d_loss, non_saturating_gen_loss, vanilla_d_loss

from mmengine import ConfigDict
from .modules.lpips import LPIPS
from .modules.discriminators import (
    PatchGANDiscriminator,
    PatchGANMaskBitDiscriminator,
    StyleGANDiscriminator,
    DinoDiscriminator,
)

from mmhug.registry import HF_MODELS
from .utils.diff_aug import DiffAugment

class MaetokLoss(nn.Module):
    def __init__(self, gan_loss: ConfigDict) -> None:
        """Construct the loss module and its GAN-loss sub-module.

        Args:
            gan_loss: configuration handed to ``HF_MODELS.build``; the object
                it returns is stored on ``self.gan_loss``.
        """
        super().__init__()
        # Instantiate the configured GAN loss through the model registry.
        built_gan_loss = HF_MODELS.build(gan_loss)
        self.gan_loss = built_gan_loss

    def forward(
        self,
        pred_img: torch.Tensor,
        gt_img: torch.Tensor,
        cur_step: int,
        is_discriminator: bool,
        last_layer: nn.Module = None,
        pred_dino: torch.Tensor = None,
        pred_clip: torch.Tensor = None,
        pred_hog: torch.Tensor = None,
    ) -> dict:
        """Dispatch to the discriminator or generator loss computation.

        If ``is_discriminator`` is true, the result of
        ``self.compute_loss_discriminator`` is returned directly. Otherwise
        ``self.compute_loss_generator`` is called and its result is bound to
        ``loss_gan``.

        Args:
            pred_img: predicted image. Shape: (B, C, H, W). The expected value
                range is not stated here (e.g. [-1, 1] vs [0, 1]) — TODO confirm.
            gt_img: ground-truth image. Shape: (B, C, H, W).
            cur_step: current training step; forwarded to whichever loss path
                is taken.
            is_discriminator: if True, compute the discriminator loss;
                otherwise compute the generator loss.
            last_layer: last layer of the generator; forwarded only on the
                generator path.
            pred_dino: optional auxiliary prediction. Not referenced by the
                dispatch code shown here — NOTE(review): confirm where it is
                consumed.
            pred_clip: optional auxiliary prediction; same caveat as
                ``pred_dino``.
            pred_hog: optional auxiliary prediction; same caveat as
                ``pred_dino``.

        Returns:
            dict of loss terms (per the ``-> dict`` annotation). On the
            discriminator path this is exactly what
            ``compute_loss_discriminator`` returns. NOTE(review): the
            generator path binds ``loss_gan``, but no return statement follows
            it at this point — confirm that the generator path returns a value.
        """
        if is_discriminator:
            # Discriminator update: only the images and step are needed.
            return self.compute_loss_discriminator(
                pred_img=pred_img,
                gt_img=gt_img,
                cur_step=cur_step,
            )
        else:
            # Generator update: last_layer is passed only on this path.
            loss_gan: dict = self.compute_loss_generator(
                pred_img=pred_img,
                gt_img=gt_img,
                cur_step=cur_step,
                last_layer=last_layer
            )
