import copy

import torch
import torch.nn.functional as F
from torch import nn

from mmhug.models.custom_transformers.maetok.vit_maetok_decoder import ViTMaetokDecoder
from mmhug.registry import HF_MODELS

from .modules.discriminators import (
    PatchGANDiscriminator,
    PatchGANMaskBitDiscriminator,
    StyleGANDiscriminator,
    DinoDiscriminator,
)
from .modules.lpips import LPIPS
from .perceptual_loss import PerceptualLoss
from .utils.diff_aug import DiffAugment
from .utils.gan_utils import LeCAM_EMA, adopt_weight, hinge_d_loss, hinge_gen_loss, lecam_reg, non_saturating_d_loss, non_saturating_gen_loss, vanilla_d_loss

@HF_MODELS.register_module()
class GANLoss(nn.Module):
    """
    Composite loss for GAN-based image reconstruction training.

    Combines a pixel-wise reconstruction loss, an optional perceptual/style
    loss and an optional adversarial loss. The discriminator is owned by this
    module, so the same instance serves both the generator update
    (``is_discriminator=False``) and the discriminator update
    (``is_discriminator=True``) through :meth:`forward`.

    Attributes:
        pixel_loss_cfg (dict): Private copy of the pixel loss configuration.
        perceptual_loss_cfg (dict | None): Private copy of the perceptual loss
            configuration (with ``warmup`` removed once the loss is built), or
            None when the perceptual loss is disabled.
        disc_loss_cfg (dict | None): Private copy of the adversarial loss
            configuration, or None when the adversarial loss is disabled.
        pixel_loss (Callable): ``F.mse_loss`` or ``F.l1_loss``.
        perceptual_loss (nn.Module): Frozen perceptual network built through
            ``HF_MODELS``; only present when the perceptual loss is enabled.
        discriminator (nn.Module): Discriminator network; only present when
            the adversarial loss is enabled.
        disc_loss (Callable): Discriminator-side adversarial loss function.
        gen_adv_loss (Callable): Generator-side adversarial loss function.
    """

    def __init__(
        self,
        pixel_loss=dict(
            type="l2",
            weight=1.0,
        ),
        perceptual_loss=dict(
            warmup=10000,
            type="PerceptualLoss",
            norm_img=True,
            use_input_norm=True,
            layer_weights={
                '2': 0.1,
                '7': 0.1,
                '16': 1.0,
                '25': 1.0,
                '34': 1.0,
            },
            vgg_type='vgg19',
            perceptual_weight=1.0,
            style_weight=0,
        ),
        disc_loss=dict(
            disc_start=30000,            # iteration to start using discriminator loss
            use_diff_aug=True,           # apply differentiable data augmentation
            discriminator="dino",        # which architecture to use for D
            disc_loss_type="hinge",      # type of adversarial loss for D
            disc_adaptive_weight=True,   # adaptively scale adversarial weight
            disc_weight=0.4,             # base weight for discriminator loss
            disc_cr_loss_weight=4.0,     # weight for consistency regularization
            disc_in_channels=3,          # number of input channels to discriminator
            image_size=(256, 256),       # expected spatial size
            disc_num_layers=3,           # how many conv layers in discriminator
            disc_dim=64,                 # base channel dimension for discriminator
            gen_adv_loss_type="hinge",   # adversarial loss type used by generator
            lecam_weight=0.001           # weight for LeCam regularization if used
        ),
    ):
        """
        Initialize the GANLoss module, constructing sub-losses per configuration.

        The configuration dicts are deep-copied before use, so neither the
        shared default dicts of this signature nor dicts owned by the caller
        are modified by this module.

        Args:
            pixel_loss (dict): Configuration dict for the pixel-wise loss:
                - type (str): "l1" or "l2".
                - weight (float): scaling factor (defaults to 1).
            perceptual_loss (dict | None): Configuration dict for the
                perceptual loss, or None to disable it:
                - warmup (int): number of iterations over which the loss is
                  linearly ramped from 0 to its full value (0 = no ramp).
                - every other key is forwarded to ``HF_MODELS.build`` to
                  construct the perceptual module (``type`` selects it).
            disc_loss (dict | None): Configuration dict for the adversarial
                loss, or None to disable it. It is also disabled when
                ``disc_weight`` is missing or <= 0:
                - disc_start (int): iteration index passed to ``adopt_weight``
                  and used to gate the consistency regularization.
                - use_diff_aug (bool): whether to apply DiffAugment.
                - discriminator (str): "patchgan", "stylegan", "maskbit" or "dino".
                - disc_loss_type (str): "hinge", "vanilla" or "non-saturating".
                - disc_adaptive_weight (bool): adaptively adjust loss weight.
                - disc_weight (float): base weight for adversarial term.
                - disc_cr_loss_weight (float): weight for consistency regularization.
                - disc_in_channels (int): channels for D input.
                - image_size (tuple): H×W for D input (used by "stylegan").
                - disc_num_layers (int): depth of discriminator (patch-based D).
                - disc_dim (int): base channels for D (patch-based D).
                - gen_adv_loss_type (str): "hinge" or "non-saturating".
                - lecam_weight (float | None): LeCam regularization weight;
                  None or <= 0 disables it.

        Raises:
            ValueError: If a loss or discriminator type is not supported.
            KeyError: If a required configuration key is missing.
        """
        super().__init__()

        # Keep private copies of the raw configs. The signature defaults are
        # shared mutable dicts and build_perceptual_loss() pops 'warmup', so
        # without copying a second default-constructed instance would silently
        # lose its warmup (and caller-owned dicts would be modified).
        self.pixel_loss_cfg = copy.deepcopy(pixel_loss)
        self.perceptual_loss_cfg = copy.deepcopy(perceptual_loss)
        self.disc_loss_cfg = copy.deepcopy(disc_loss)

        # Build each individual loss module
        self.build_pixel_loss()
        self.build_perceptual_loss()
        self.build_disc_loss()

    def train(self, mode: bool = True):
        """
        Switch train/eval mode while keeping the perceptual network frozen.

        The mode is applied to this module and its children (including the
        discriminator); afterwards the perceptual sub-module, if any, is put
        back into eval mode with gradients disabled for its parameters.

        Args:
            mode (bool): True for training mode, False for evaluation mode.

        Returns:
            GANLoss: ``self``, matching the ``nn.Module.train`` contract so
            calls can be chained.
        """
        super().train(mode)

        # getattr: robust even if train() is invoked before the sub-losses are built.
        if getattr(self, "use_perceptual_loss", False):
            # Freeze and set perceptual net to eval to avoid BatchNorm updates / dropout
            self.perceptual_loss.eval()
            # Ensure no gradients flow into perceptual net parameters
            self.perceptual_loss.requires_grad_(False)
        return self

    def calculate_adaptive_weight(self, other_loss: torch.Tensor, g_loss: torch.Tensor, last_layer: torch.Tensor):
        """
        Compute the adaptive weight balancing adversarial and reconstruction terms.

        The weight is the ratio of the gradient norm of ``other_loss`` to the
        gradient norm of ``g_loss``, both taken w.r.t. ``last_layer``, clamped
        to [0, 1e4] and detached from the graph.

        Args:
            other_loss: Sum of the non-adversarial terms (pixel, perceptual, style).
            g_loss: Generator adversarial loss.
            last_layer: Tensor of the generator to differentiate against; both
                losses must depend on it, otherwise the gradient is None and
                the norm computation fails.

        Returns:
            torch.Tensor: Scalar weight without gradient history.
        """
        # retain_graph=True: the caller still backpropagates through both losses afterwards.
        other_grads = torch.autograd.grad(other_loss, last_layer, retain_graph=True, allow_unused=True)[0]
        g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True, allow_unused=True)[0]

        # 1e-4 guards against division by a vanishing adversarial gradient.
        d_weight = torch.norm(other_grads) / (torch.norm(g_grads) + 1e-4)
        return torch.clamp(d_weight, 0.0, 1e4).detach()

    def compute_loss_discriminator(
        self,
        pred_img: torch.Tensor,
        gt_img: torch.Tensor,
        cur_step: int,
    ) -> dict:
        """
        Compute the discriminator loss: adversarial term, optional LeCam
        regularization and optional consistency regularization.

        Args:
            pred_img (torch.Tensor): Fake/generated images (B, C, H, W).
            gt_img (torch.Tensor): Ground-truth/real images (B, C, H, W).
            cur_step (int): Current iteration, used for weight scheduling.

        Returns:
            dict: {'discriminator_loss': scalar tensor}.

        Raises:
            RuntimeError: If the adversarial loss is disabled, i.e. no
                discriminator was built.
        """
        if not self.use_adv_loss:
            raise RuntimeError(
                "Discriminator loss requested, but the adversarial loss is disabled "
                "(disc_loss is None or disc_weight <= 0)."
            )

        # Detach both inputs: the discriminator update must not reach the generator.
        real_input = gt_img.detach()
        fake_input = pred_img.detach()

        # Apply differentiable augmentation if enabled
        if self.use_diff_aug:
            real_input = DiffAugment(real_input, policy="color,translation,cutout_0.2", prob=0.5)
            fake_input = DiffAugment(fake_input, policy="color,translation,cutout_0.2", prob=0.5)

        # Compute logits for real and fake images
        logits_real = self.discriminator(real_input)
        logits_fake = self.discriminator(fake_input)

        # Step-dependent discriminator weight; the schedule itself lives in adopt_weight.
        disc_weight = adopt_weight(
            self.disc_weight, cur_step, self.discriminator_iter_start
        )

        d_loss = self.disc_loss(logits_real, logits_fake)
        # LeCam regularization term (optional)
        if self.use_lecam:
            # The EMA statistics are refreshed before the regularizer reads them.
            self.lecam_ema.update(logits_real, logits_fake)
            lecam_loss = lecam_reg(logits_real, logits_fake, self.lecam_ema)
            d_loss = d_loss + self.lecam_loss_weight * lecam_loss

        # Total discriminator loss with scaling
        discriminator_loss = disc_weight * d_loss

        # Consistency Regularization Loss (CR Loss), only once the discriminator is active
        if self.disc_cr_loss_weight > 0.0 and cur_step >= self.discriminator_iter_start:
            # Stronger augmentation for CR, always applied (prob=1.0)
            real_aug = DiffAugment(gt_img.detach(), policy='color,translation,cutout_0.5', prob=1.0)
            fake_aug = DiffAugment(pred_img.detach(), policy='color,translation,cutout_0.5', prob=1.0)

            # Forward through D
            logits_real_s = self.discriminator(real_aug)
            logits_fake_s = self.discriminator(fake_aug)

            # MSE between the logits above and the logits of the strongly augmented inputs
            cr_loss = F.mse_loss(
                torch.cat([logits_real, logits_fake], dim=0),
                torch.cat([logits_real_s, logits_fake_s], dim=0)
            )

            discriminator_loss = discriminator_loss + self.disc_cr_loss_weight * cr_loss

        return {
            "discriminator_loss": discriminator_loss,
        }

    def compute_loss_generator(
        self,
        pred_img: torch.Tensor,
        gt_img: torch.Tensor,
        cur_step: int,
        last_layer: torch.Tensor = None,
    ) -> dict:
        """
        Compute the generator's training loss terms:
        1. Pixel reconstruction (L1/L2) loss
        2. (Optional) Perceptual and style losses with a linear warmup
        3. (Optional) Adversarial loss, optionally adaptively weighted

        Args:
            pred_img (Tensor): Generated images, shape (B, C, H, W).
                Expected to be normalized to [-1, 1] (i.e. mean=[0.5], std=[0.5]).
            gt_img (Tensor): Ground-truth images, same shape and normalization.
            cur_step (int): Current training iteration index.
            last_layer (Tensor, optional): Tensor of the generator (typically
                the weight of its final layer) used to compute the adaptive
                weight between the reconstruction and adversarial terms. When
                None, or when ``disc_adaptive_weight=False``, the adaptive
                weight is 1.0.

        Returns:
            Dict[str, Tensor]: Always contains 'pixel_loss'. The keys
            'perceptual_loss', 'style_loss' and 'generator_adv_loss' are only
            present when the corresponding term is enabled and produced.
        """

        # 1. Pixel Reconstruction Loss, scaled by its configured weight.
        pixel_loss = self.pixel_loss(pred_img, gt_img) * self.pixel_weights

        # 2. Perceptual / style loss with warmup
        p_loss = None
        style_loss = None
        if self.use_perceptual_loss:
            # The perceptual module returns (perceptual, style); either may be None.
            p_loss, style_loss = self.perceptual_loss(pred_img, gt_img)

            # Warmup coefficient: ramps linearly from 0 to 1 over perceptual_warmup steps
            if self.perceptual_warmup > 0:
                alpha = min(1.0, cur_step / float(self.perceptual_warmup))
            else:
                alpha = 1.0
            if p_loss is not None:
                p_loss = alpha * p_loss
            if style_loss is not None:
                style_loss = alpha * style_loss

        # 3. Generator Adversarial Loss
        generator_adv_loss = None
        if self.use_adv_loss:
            # Optionally apply DiffAugment (no detach here: we want grad back to generator)
            fake_for_d = pred_img
            if self.use_diff_aug:
                fake_for_d = DiffAugment(fake_for_d, policy="color,translation,cutout_0.2", prob=0.5)

            logits_fake = self.discriminator(fake_for_d)

            # Base adversarial loss (e.g. hinge on fake)
            adv_base = self.gen_adv_loss(logits_fake)

            # Adaptive weight between (pixel + perceptual + style) and adversarial terms
            if self.disc_adaptive_weight and last_layer is not None:
                # Out-of-place sums on purpose: an in-place `+=` would also
                # modify the `pixel_loss` tensor that is returned below.
                other = pixel_loss
                if p_loss is not None:
                    other = other + p_loss
                if style_loss is not None:
                    other = other + style_loss
                adapt_w = self.calculate_adaptive_weight(other, adv_base, last_layer=last_layer)
            else:
                adapt_w = 1.0

            # Step-dependent adversarial weight; the schedule itself lives in adopt_weight.
            adv_weight = adopt_weight(self.disc_weight, cur_step, self.discriminator_iter_start)

            generator_adv_loss = adapt_w * adv_weight * adv_base

        loss_dict = {
            "pixel_loss":        pixel_loss,
        }
        if p_loss is not None:
            loss_dict["perceptual_loss"] = p_loss
        if style_loss is not None:
            loss_dict["style_loss"] = style_loss
        if generator_adv_loss is not None:
            loss_dict["generator_adv_loss"] = generator_adv_loss
        return loss_dict

    def build_pixel_loss(self):
        """
        Select the pixel-wise loss function and its weight from ``pixel_loss_cfg``.

        Raises:
            ValueError: If ``type`` is neither "l1" nor "l2".
        """
        pixel_loss_type = self.pixel_loss_cfg['type']
        self.pixel_weights = self.pixel_loss_cfg.get('weight', 1)
        if pixel_loss_type == "l2":
            self.pixel_loss = F.mse_loss
        elif pixel_loss_type == "l1":
            self.pixel_loss = F.l1_loss
        else:
            raise ValueError(f"pixel_loss_type '{pixel_loss_type}' not supported")

    def build_perceptual_loss(self):
        """
        Build the frozen perceptual network from ``perceptual_loss_cfg``.

        Sets ``use_perceptual_loss`` to False and returns early when the
        configuration is None. Otherwise ``warmup`` is removed from the
        (private) configuration and the remaining keys are passed to
        ``HF_MODELS.build``.
        """
        self.use_perceptual_loss = True
        if self.perceptual_loss_cfg is None:
            self.use_perceptual_loss = False
            return
        # 'warmup' is consumed here and therefore not forwarded to the registry.
        self.perceptual_warmup = self.perceptual_loss_cfg.pop('warmup', 0)
        self.perceptual_loss: PerceptualLoss = HF_MODELS.build(self.perceptual_loss_cfg)
        self.perceptual_loss.eval()
        self.perceptual_loss.requires_grad_(False)

    def build_disc_loss(self):
        """
        Build the discriminator and the adversarial loss functions from
        ``disc_loss_cfg``.

        Sets ``use_adv_loss`` to False and returns early when the
        configuration is None or ``disc_weight`` is missing or <= 0; in that
        case no discriminator-related attribute is created.

        Raises:
            ValueError: If the discriminator type, the discriminator loss type
                or the generator adversarial loss type is not supported.
            KeyError: If a required configuration key is missing.
        """
        self.use_adv_loss = True
        if self.disc_loss_cfg is None:
            self.use_adv_loss = False
            return

        disc_weight = self.disc_loss_cfg.get('disc_weight', 0)
        if disc_weight <= 0:
            self.use_adv_loss = False
            return

        disc_start = self.disc_loss_cfg['disc_start']
        disc_adaptive_weight = self.disc_loss_cfg['disc_adaptive_weight']
        gen_adv_loss = self.disc_loss_cfg['gen_adv_loss_type']
        lecam_loss_weight = self.disc_loss_cfg['lecam_weight']
        use_diff_aug = self.disc_loss_cfg['use_diff_aug']

        discriminator_type = self.disc_loss_cfg['discriminator']
        disc_loss_type = self.disc_loss_cfg['disc_loss_type']
        disc_in_channels = self.disc_loss_cfg['disc_in_channels']
        disc_num_layers = self.disc_loss_cfg['disc_num_layers']
        disc_dim = self.disc_loss_cfg['disc_dim']
        image_size = self.disc_loss_cfg['image_size']

        disc_loss_fns = {
            "hinge": hinge_d_loss,
            "vanilla": vanilla_d_loss,
            "non-saturating": non_saturating_d_loss,
        }
        gen_adv_loss_fns = {
            "hinge": hinge_gen_loss,
            "non-saturating": non_saturating_gen_loss,
        }

        # Validate every choice up front with real exceptions (asserts vanish
        # under `python -O`), before the possibly expensive discriminator is built.
        if discriminator_type not in ("patchgan", "stylegan", "maskbit", "dino"):
            raise ValueError(f"Unknown GAN discriminator type '{discriminator_type}'.")
        if disc_loss_type not in disc_loss_fns:
            raise ValueError(f"Unknown GAN discriminator loss '{disc_loss_type}'.")
        if gen_adv_loss not in gen_adv_loss_fns:
            raise ValueError(f"Unknown GAN generator loss '{gen_adv_loss}'.")

        self.use_diff_aug = use_diff_aug

        if discriminator_type == "patchgan":
            self.discriminator = PatchGANDiscriminator(
                input_nc=disc_in_channels,
                n_layers=disc_num_layers,
                ndf=disc_dim,
            )
        elif discriminator_type == "stylegan":
            self.discriminator = StyleGANDiscriminator(
                input_nc=disc_in_channels,
                image_size=image_size,
            )
        elif discriminator_type == "maskbit":
            self.discriminator = PatchGANMaskBitDiscriminator(
                input_nc=disc_in_channels,
                n_layers=disc_num_layers,
                ndf=disc_dim,
            )
        else:  # "dino"
            self.discriminator = DinoDiscriminator()

        self.disc_loss = disc_loss_fns[disc_loss_type]
        self.discriminator_iter_start = disc_start
        self.disc_weight = disc_weight
        self.disc_adaptive_weight = disc_adaptive_weight

        self.gen_adv_loss = gen_adv_loss_fns[gen_adv_loss]

        # LeCam regularization is active only for a strictly positive weight.
        self.lecam_loss_weight = lecam_loss_weight
        self.use_lecam = lecam_loss_weight is not None and lecam_loss_weight > 0
        if self.use_lecam:
            self.lecam_ema = LeCAM_EMA()

        self.disc_cr_loss_weight = self.disc_loss_cfg.get('disc_cr_loss_weight', 0.)

    def forward(
        self,
        pred_img: torch.Tensor,
        gt_img: torch.Tensor,
        cur_step: int,
        is_discriminator: bool = False,
        last_layer: torch.Tensor = None,
    ) -> dict:
        """
        Dispatch to the discriminator or the generator loss computation.

        Args:
            pred_img: predicted image. Shape: (B, C, H, W).
            gt_img: ground truth image. Shape: (B, C, H, W)
            cur_step: current training step
            is_discriminator: True to compute the discriminator loss, False
                to compute the generator loss terms
            last_layer: generator tensor used for the adaptive adversarial
                weight; only used for the generator loss

        Returns:
            dict: ``{'discriminator_loss': ...}`` when ``is_discriminator`` is
            True, otherwise the dict returned by ``compute_loss_generator``.
        """
        if is_discriminator:
            return self.compute_loss_discriminator(
                pred_img=pred_img,
                gt_img=gt_img,
                cur_step=cur_step,
            )

        return self.compute_loss_generator(
            pred_img=pred_img,
            gt_img=gt_img,
            cur_step=cur_step,
            last_layer=last_layer,
        )

if __name__ == "__main__":
    from mmengine.device import get_device

    # Smoke test: run one generator-side and one discriminator-side loss
    # evaluation on random data and print the resulting dicts.
    device = get_device()
    latent_shape = (2, 129, 1024)
    image_shape = (2, 3, 256, 256)

    latents = torch.randn(latent_shape).to(device)
    generator = ViTMaetokDecoder(to_pixel="linear").to(device).train()
    loss_module = GANLoss().to(device)

    fake = generator(latents)
    real = torch.randn(image_shape).to(device)

    # A step index beyond the default perceptual warmup and disc_start values.
    cur_step = 100000

    out_gen = loss_module(
        fake,
        real,
        cur_step,
        is_discriminator=False,
        last_layer=generator.last_layer,
    )
    print(out_gen)

    out_disc = loss_module(fake, real, cur_step, is_discriminator=True)
    print(out_disc)
