from typing import Any, Optional

import lpips as lpips_lib
import torch
import torch.nn as nn
import torch.nn.functional as F

from dae.utils.generic_utils import ModulesRegister, TaskState
from dae.utils.torch_utils import Frozen, Ref, ensure_1d_tokens, unwrap
from dae.utils.train_utils import aggregate_losses

from ..blocks.discriminator import NLayerDiscriminator
from ..encoders.pretrained import PreTrainedEncoder

# Registry of auxiliary-loss modules defined in this file; classes are added to it
# with the ``@AUX_LOSSES.register(<name>)`` decorator.
# NOTE(review): ``lower=True`` presumably makes registry keys case-insensitive -- confirm
# against ``ModulesRegister``.
AUX_LOSSES = ModulesRegister("AUX_LOSSES", lower=True)


class REPALoss(nn.Module):
    """Representation-alignment (REPA) auxiliary loss.

    A forward hook captures the output of one transformer block of ``unet``'s mid block
    while this module is in training mode. :meth:`forward` projects that activation
    through a small MLP and returns ``1 - mean(cosine similarity)`` against the features
    produced by a frozen pretrained encoder.

    Reference: https://arxiv.org/pdf/2410.06940#page=23
    """

    def __init__(self, unet, model="dinov2_base", n_layers=2, i_extract=4, cache_dir=None, accelerator=None):
        """Build the frozen reference encoder, the projection MLP and the capture hook.

        Args:
            unet: Module exposing ``mid_block`` (with ``transformer_blocks`` either directly
                or under ``mid_block.attention``) and either ``mid_dim`` or
                ``mid_block.attention.inner_dim`` as the captured feature width.
            model: Name forwarded to ``PreTrainedEncoder``.
            n_layers: Number of ``nn.Linear`` layers in the projection MLP; ``nn.SiLU`` is
                inserted between consecutive layers.
            i_extract: Index of the transformer block to hook. It is capped at
                ``len(transformer_blocks) // 2``.
            cache_dir: Forwarded to ``PreTrainedEncoder``.
            accelerator: Forwarded to ``Frozen``.
        """
        super().__init__()
        # Slot filled by the forward hook. Initialised here so that ``forward`` can detect
        # "nothing captured yet" instead of failing with an AttributeError.
        self._repa_layer_output = None

        if hasattr(unet, "mid_dim"):
            features_dim = unet.mid_dim
        else:
            features_dim = unet.mid_block.attention.inner_dim

        self.features_extractor = Frozen(
            PreTrainedEncoder(model, cache_dir=cache_dir, freeze=True, drop_cls=True),
            accelerator,
            allow_grad=False,
        )

        # Create features MLP: hidden layers keep the captured width, the last layer maps
        # to the reference encoder's output width.
        self.repa_mlp = nn.Sequential()
        self.repa_loss = nn.CosineSimilarity(dim=2, eps=1e-5)
        for i in range(n_layers):
            is_last = i == n_layers - 1
            out_dim = self.features_extractor.module.out_dim if is_last else features_dim
            self.repa_mlp.append(nn.Linear(features_dim, out_dim))
            if not is_last:
                self.repa_mlp.append(nn.SiLU())

        # Register hook to get middle features
        if hasattr(unet.mid_block, "transformer_blocks"):
            tformer_blocks = unet.mid_block.transformer_blocks
        else:
            tformer_blocks = unet.mid_block.attention.transformer_blocks

        i_extract = min(i_extract, len(tformer_blocks) // 2)
        tformer_blocks[i_extract].register_forward_hook(self._hook_repa)

    def _hook_repa(self, module, input, output):
        """Forward hook: remember the hooked block's output, but only in training mode."""
        if self.training:
            self._repa_layer_output = output

    def forward(self, x_gt):
        """Return the alignment loss for the most recently captured activation.

        Args:
            x_gt: Input passed to the frozen reference encoder.

        Returns:
            ``1 - mean cosine similarity`` between the projected captured features and the
            reference features, computed in float32.

        Raises:
            RuntimeError: If no activation has been captured since the last call (the
                hooked network was not run, or this module is not in training mode).
            AssertionError: If projected and reference features differ in shape.
        """
        # ``getattr`` with a default also covers instances created without ``__init__``.
        captured = getattr(self, "_repa_layer_output", None)
        if captured is None:
            raise RuntimeError(
                "REPALoss has no captured features: run the hooked network's forward pass "
                "with this module in training mode before computing the loss."
            )

        # Extract and project repa extracted features
        repa_val = ensure_1d_tokens(captured)
        repa_val = self.repa_mlp(repa_val)

        # Compute loss with a reference model; dim 1 of the projection is the token count.
        with torch.no_grad():
            repa_ref = self.features_extractor(x_gt, target_n_tokens=repa_val.shape[1])

        assert repa_val.shape == repa_ref.shape, f"Invalid shape {repa_val.shape} != {repa_ref.shape}"

        # Drop the captured activation so it is neither reused nor kept alive.
        self._repa_layer_output = None
        with torch.autocast("cuda", enabled=False):
            return 1 - self.repa_loss(repa_val.to(torch.float32), repa_ref.to(torch.float32)).mean()


@AUX_LOSSES.register("ae_aux_losses")
class DAELosses(nn.Module):
    """Bundle of optional auxiliary losses (REPA and LPIPS) for an autoencoder."""

    def __init__(
        self,
        ae: nn.Module,
        repa: Optional[dict] = None,
        lpips: bool = True,
        accelerator: Optional[Any] = None,
    ):
        """Create the enabled loss modules.

        Args:
            ae: Autoencoder; when ``repa`` is given it is unwrapped and its ``decoder`` is
                handed to ``REPALoss``.
            repa: Keyword arguments for ``REPALoss``; ``None`` disables the REPA term.
            lpips: Whether to create the frozen VGG-based LPIPS term.
            accelerator: Stored on the instance and forwarded to the loss modules.
        """
        super().__init__()
        self.accelerator = accelerator

        # REPA term (optional)
        if repa is None:
            self.repa_loss = None
        else:
            ae = unwrap(ae, unw_ema=True)
            self.repa_loss = REPALoss(ae.decoder, accelerator=accelerator, **repa)

        # LPIPS term (optional)
        if lpips:
            self.lpips_loss = Frozen(lpips_lib.LPIPS(net="vgg"), accelerator=accelerator)
        else:
            self.lpips_loss = None

    def forward(self, x_BCWH, x0_pred, target_x=None):
        """Compute every enabled auxiliary loss.

        Args:
            x_BCWH: Input batch; fed to the REPA term and used as the LPIPS reference when
                ``target_x`` is ``None``.
            x0_pred: Prediction compared against the reference by LPIPS.
            target_x: Optional explicit LPIPS reference.

        Returns:
            Dict with a ``"repa"`` and/or ``"lpips"`` entry, one per enabled term.
        """
        reference = x_BCWH if target_x is None else target_x
        out = {}

        if self.repa_loss is not None:
            out["repa"] = self.repa_loss(x_BCWH)

        if self.lpips_loss is not None:
            perceptual = self.lpips_loss(reference, x0_pred)
            out["lpips"] = perceptual.mean()

        return out

    def __deepcopy__(self, memo):
        # Deep-copying this module yields ``None`` instead of a duplicate.
        # NOTE(review): presumably so that copying a parent object does not clone the
        # frozen networks -- confirm with the callers that deep-copy it.
        return None


class GanLoss(nn.Module):
    """Adversarial loss built around an ``NLayerDiscriminator``.

    :meth:`forward` runs in one of two modes selected by ``step``:

    * ``"disc_loss"`` -- adversarial term computed from the discriminator's score of the
      prediction (gradients flow into ``x_pred``); stored under ``"gan_disc"``.
    * ``"train"`` -- discriminator update on detached inputs; stored under ``"gan_train"``.
    """

    def __init__(
        self,
        model_last_layer=None,
        start_iter=0,
        adaptive_weight=False,
        patch_gan=False,
        hinge_loss=True,
        discriminator=None,
        cfg=None,
    ):
        """Create the discriminator and store the loss options.

        Args:
            model_last_layer: Wrapped in ``Ref``; its ``value`` is the tensor with respect
                to which gradients are taken for the adaptive weight.
            start_iter: Training step from which the adversarial terms become active.
            adaptive_weight: Whether to rescale the ``"gan_disc"`` term by the gradient-norm
                ratio from :meth:`calculate_adaptive_gan_w`.
            patch_gan: If true the discriminator is built with ``flatten=False`` and
                hinge/softplus losses are used; otherwise binary cross-entropy is used.
            hinge_loss: With ``patch_gan``, choose hinge (true) or softplus (false) for the
                discriminator update.
            discriminator: Extra keyword arguments for ``NLayerDiscriminator``.
            cfg: Optional configuration passed to ``aggregate_losses`` on the adaptive-weight
                path. It may also be attached later as the ``cfg`` attribute.
        """
        super().__init__()

        self.gan_model = NLayerDiscriminator(
            input_nc=3,
            flatten=not patch_gan,
            **(discriminator or {}),
        )
        self.patch_gan = patch_gan
        self.adaptive_gan_weight = adaptive_weight
        self._model_last_layer = Ref(model_last_layer)
        self.hinge_loss = hinge_loss
        self.start_iter = start_iter
        # Only set when provided so an externally attached ``cfg`` is never overwritten.
        if cfg is not None:
            self.cfg = cfg

    def forward(self, x_gt, x_pred, xt, t, n_train_steps, existing_losses=None, step=None):
        """Compute the adversarial loss for the requested ``step``.

        Args:
            x_gt: Real samples (used in the ``"train"`` step only).
            x_pred: Predicted samples.
            xt, t: Extra inputs forwarded to the discriminator.
            n_train_steps: Current training step, compared against ``start_iter``.
            existing_losses: Losses computed so far; required for ``"disc_loss"`` and passed
                to ``aggregate_losses`` when the adaptive weight is enabled.
            step: ``"disc_loss"`` or ``"train"``.

        Returns:
            Dict with ``"gan_disc"`` (``"disc_loss"`` step; empty before ``start_iter``) or
            ``"gan_train"`` (``"train"`` step; multiplied by zero before ``start_iter``).

        Raises:
            ValueError: If ``step`` is not one of the two supported values.
            AttributeError: If the adaptive weight is enabled but ``cfg`` was never set.
        """
        # Losses are evaluated with CUDA autocast disabled.
        with torch.autocast("cuda", enabled=False):
            losses = {}
            if step == "disc_loss":
                assert existing_losses is not None, "Existing losses must be provided for GAN discriminator loss calculation"
                # Add adversarial loss
                if n_train_steps >= self.start_iter:
                    logits_fake = self.gan_model(x_pred, xt, t)

                    if not self.patch_gan:
                        # NOTE(review): binary_cross_entropy expects probabilities in [0, 1];
                        # assumes the flattened discriminator output is already squashed.
                        gan_g_loss = F.binary_cross_entropy(logits_fake, torch.ones_like(logits_fake))
                    else:
                        gan_g_loss = -torch.mean(logits_fake)

                    if self.adaptive_gan_weight:
                        if not hasattr(self, "cfg"):
                            raise AttributeError(
                                "GanLoss.cfg is not set: pass `cfg=` to the constructor or assign "
                                "`gan_loss.cfg` before using adaptive_weight=True."
                            )
                        # First element of the aggregate is used as the reference loss.
                        nll_loss, _ = aggregate_losses(self.cfg, existing_losses)
                        adaptive_gan_w = self.calculate_adaptive_gan_w(nll_loss, gan_g_loss, self._model_last_layer.value)
                        gan_g_loss = gan_g_loss * adaptive_gan_w
                    losses["gan_disc"] = gan_g_loss
            elif step == "train":
                # Train GAN model on detached inputs so no gradient reaches the generator.
                logits_real = self.gan_model(x_gt.detach(), xt.detach(), t)
                logits_fake = self.gan_model(x_pred.detach(), xt.detach(), t)

                if not self.patch_gan:
                    loss_real = F.binary_cross_entropy(logits_real, torch.ones_like(logits_real))
                    loss_fake = F.binary_cross_entropy(logits_fake, torch.zeros_like(logits_fake))
                elif self.hinge_loss:
                    loss_real = torch.mean(F.relu(1.0 - logits_real))
                    loss_fake = torch.mean(F.relu(1.0 + logits_fake))
                else:
                    loss_real = torch.mean(F.softplus(-logits_real))
                    loss_fake = torch.mean(F.softplus(logits_fake))

                losses["gan_train"] = (loss_real + loss_fake) / 2

                # Before ``start_iter`` the key is kept but the value is zeroed.
                if n_train_steps < self.start_iter:
                    losses["gan_train"] = 0 * losses["gan_train"]
            else:
                raise ValueError(f"Unknown GAN step: {step}")
        return losses

    def calculate_adaptive_gan_w(self, nll_loss, g_loss, last_layer):
        """Return the ratio of gradient norms of ``nll_loss`` and ``g_loss`` w.r.t. ``last_layer``.

        Both gradients are mean-reduced through the accelerator from ``TaskState`` before
        the norms are taken. The ratio is clamped to ``[0, 1e4]`` and detached.

        Raises:
            RuntimeError: If either loss has no gradient with respect to ``last_layer``.
        """
        nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True, allow_unused=True)[0]
        g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True, allow_unused=True)[0]
        # ``allow_unused=True`` yields ``None`` when ``last_layer`` is not part of a loss's
        # graph; fail with a clear message instead of passing ``None`` to reduce/norm.
        if nll_grads is None or g_grads is None:
            raise RuntimeError(
                "Cannot compute the adaptive GAN weight: `last_layer` does not contribute to "
                "the reconstruction loss and/or the adversarial loss."
            )
        acc = TaskState().accelerator
        nll_grads = acc.reduce(nll_grads, reduction="mean")
        g_grads = acc.reduce(g_grads, reduction="mean")
        d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
        d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
        return d_weight
