import torch as th
import torchvision as tv
from torch import nn
from utils.utils import BatchToSharedObjects, SharedObjectsToBatch, LambdaModule, MultiArgSequential, Binarize
from einops import rearrange, repeat, reduce
from utils.optimizers import Ranger
from nn.spectral_convnext import SPConvNeXtBlock, SPConvNeXtUnet 
import torch.nn.functional as F
from utils.loss import MaskedL1SSIMLoss

class UncertaintyGANLoss(nn.Module):
    """Reconstruction loss combined with a per-pixel "certain vs. uncertain" GAN term.

    The discriminator is an ``SPConvNeXtUnet`` with two output channels
    followed by a softmax over ``dim=1``.  Class ``0`` is the "certain" label
    and class ``1`` the "uncertain" label (see the two target buffers).

    * Discriminator objective: on detached reconstructions, predict "certain"
      where ``1 - uncertainty`` is high and "uncertain" where ``uncertainty``
      is high.
    * Generator objective: make the (frozen) discriminator predict "certain"
      inside the uncertain regions, weighted adaptively against the masked
      L1/SSIM reconstruction loss.

    The GAN terms are only active once ``num_updates >= discriminator_start``.

    Args:
        output_size: spatial size of the discriminator output, unpacked into
            the shape of the pre-allocated target buffers.
        batch_size: batch size used for the pre-allocated target buffers.
            Batches of another size are supported; targets are then built on
            the fly.
        discriminator_start: first update step at which the GAN terms are used.
        in_channels: number of channels of the images fed to the discriminator.
        base_channels: forwarded to ``SPConvNeXtUnet``.
        blocks: forwarded to ``SPConvNeXtUnet`` (as a list).
        discriminator_weight: constant factor applied to the adaptive weight.
    """

    # Log entries that only carry a value while the GAN terms are active.
    _GAN_LOG_KEYS = (
        "d_weight",
        "g_loss",
        "d_loss",
        "acc_fake_certain",
        "acc_fake_uncertain",
        "loss_fake_certain",
        "loss_fake_uncertain",
    )

    def __init__(
        self,
        output_size,
        batch_size,
        discriminator_start = 50001,
        in_channels = 3,
        base_channels = 12,
        blocks = (1, 2, 3, 4),
        discriminator_weight = 0.5,
    ):
        super().__init__()
        self.discriminator = nn.Sequential(
            SPConvNeXtUnet(
                in_channels    = in_channels,
                out_channels   = 2,
                base_channels  = base_channels,
                # Immutable default above; hand the network its own list.
                blocks         = list(blocks),
            ),
            nn.Softmax(dim=1),
        )

        self.l1ssim = MaskedL1SSIMLoss()

        self.discriminator_start  = discriminator_start
        self.discriminator_weight = discriminator_weight

        # Constant class-index maps: 0 = "certain", 1 = "uncertain".
        target_shape = (batch_size, *output_size)
        self.register_buffer('fake_target_certain', th.zeros(target_shape, dtype=th.long), persistent=False)
        self.register_buffer('fake_target_uncertain', th.ones(target_shape, dtype=th.long), persistent=False)

    def calculate_adaptive_weight(self, rec_loss, g_loss, last_layer=None):
        """Balance the generator loss against the reconstruction loss.

        The weight is the ratio of the gradient norms of ``rec_loss`` and
        ``g_loss`` with respect to ``last_layer``, clamped to ``[0, 1e6]``,
        detached, and scaled by ``self.discriminator_weight``.

        Args:
            rec_loss: scalar reconstruction loss (graph must reach ``last_layer``).
            g_loss: scalar generator loss (graph must reach ``last_layer``).
            last_layer: tensor to differentiate with respect to.  Despite the
                ``None`` default it has to be provided, since it is passed
                straight to ``torch.autograd.grad``.

        Returns:
            A detached scalar tensor.
        """
        rec_grads = th.autograd.grad(rec_loss, last_layer, retain_graph=True)[0]
        g_grads   = th.autograd.grad(g_loss, last_layer, retain_graph=True)[0]

        d_weight = th.linalg.norm(rec_grads) / (th.linalg.norm(g_grads) + 1e-6)
        d_weight = th.clamp(d_weight, 0.0, 1e6).detach()
        return d_weight * self.discriminator_weight

    def masked_cross_entropy(self, prediction, target, mask):
        """Mask-weighted mean of the per-pixel cross entropy.

        Args:
            prediction: unnormalised class scores with the class axis at
                ``dim=1`` (``F.cross_entropy`` applies log-softmax itself).
            target: integer class indices, ``prediction`` without its class axis.
            mask: per-pixel weights with a singleton axis at ``dim=1``
                (presumably ``(B, 1, H, W)``); it is detached, so no gradient
                flows into it.

        Returns:
            Scalar tensor ``sum(ce * mask) / (sum(mask) + 1e-6)``.
        """
        weights = mask.detach()
        per_pixel = F.cross_entropy(prediction, target, reduction='none')
        return th.sum(per_pixel * weights.squeeze(dim=1)) / (th.sum(weights) + 1e-6)

    def masked_accuracy(self, prediction, target, mask):
        """Mask-weighted fraction of pixels whose arg-max class equals ``target``.

        Arguments follow :meth:`masked_cross_entropy`; because only the
        arg-max over ``dim=1`` is used, scores and probabilities give the
        same result.
        """
        hits = (th.argmax(prediction, dim=1) == target).float()
        return th.sum(hits * mask.squeeze(dim=1)) / (th.sum(mask) + 1e-6)

    def _discriminator_logits(self, images):
        """Return the discriminator's pre-softmax scores.

        ``self.discriminator`` ends in ``nn.Softmax``; feeding its output to
        ``F.cross_entropy`` would apply softmax twice, compressing the loss
        range and its gradients.  The loss therefore uses the scores of the
        first stage of the ``nn.Sequential``.
        """
        return self.discriminator[0](images)

    def _fake_targets(self, logits):
        """Return ``(certain, uncertain)`` target maps matching ``logits``.

        The pre-allocated buffers are reused when their shape fits; otherwise
        (e.g. a smaller final batch) equivalent targets are created on the
        device of ``logits``.
        """
        wanted = (logits.shape[0], *logits.shape[2:])
        if self.fake_target_certain.shape == wanted:
            return self.fake_target_certain, self.fake_target_uncertain
        certain = th.zeros(wanted, dtype=th.long, device=logits.device)
        return certain, th.ones_like(certain)

    def forward(self, inputs, reconstructions, uncertainty, last_layer, num_updates):
        """Compute the total loss and a dictionary of scalar diagnostics.

        Args:
            inputs: target images, passed to the masked L1/SSIM loss.
            reconstructions: generator output; must require grad.
            uncertainty: per-pixel uncertainty with a singleton axis at
                ``dim=1`` (presumably in ``[0, 1]`` -- confirm with caller).
                It is only used as a detached weight.
            last_layer: tensor used for the adaptive weight, see
                :meth:`calculate_adaptive_weight`.
            num_updates: current update step; the GAN terms are added once it
                reaches ``self.discriminator_start``.

        Returns:
            ``(loss, log)`` where ``loss`` is a scalar tensor and ``log`` maps
            names to Python numbers.  GAN entries are ``0`` while the GAN
            terms are inactive.

        Raises:
            ValueError: if ``reconstructions`` does not require grad.
        """
        # Check whether reconstructions are attached to an autograd graph.
        if not reconstructions.requires_grad:
            raise ValueError("Reconstructions must have a grad function")

        gan_active = num_updates >= self.discriminator_start

        # The masks are pure weights everywhere below, never optimisation targets.
        uncertain_mask = uncertainty.detach()
        certain_mask   = (1 - uncertainty).detach()

        gan_log = dict.fromkeys(self._GAN_LOG_KEYS, 0)

        # discriminator update (generator output detached)
        if gan_active:
            logits_fake = self._discriminator_logits(reconstructions.detach())
            target_certain, target_uncertain = self._fake_targets(logits_fake)

            acc_fake_certain    = self.masked_accuracy(logits_fake,      target_certain,   certain_mask)
            acc_fake_uncertain  = self.masked_accuracy(logits_fake,      target_uncertain, uncertain_mask)
            loss_fake_certain   = self.masked_cross_entropy(logits_fake, target_certain,   certain_mask)
            loss_fake_uncertain = self.masked_cross_entropy(logits_fake, target_uncertain, uncertain_mask)

            d_loss = loss_fake_certain + loss_fake_uncertain

        rec_loss, rec_loss_l1, rec_loss_ssim = self.l1ssim(inputs, reconstructions, certain_mask)
        loss = rec_loss

        # generator update (discriminator frozen)
        if gan_active:
            disc_params = list(self.discriminator.parameters())
            grad_flags  = [param.requires_grad for param in disc_params]
            for param in disc_params:
                param.requires_grad_(False)
            try:
                logits_fake = self._discriminator_logits(reconstructions)
                # The generator wants uncertain regions to be judged "certain".
                g_loss   = self.masked_cross_entropy(logits_fake, target_certain, uncertain_mask)
                d_weight = self.calculate_adaptive_weight(rec_loss, g_loss, last_layer=last_layer)
            finally:
                # Restore the previous flags even on failure, and do not
                # unfreeze parameters the caller had frozen on purpose.
                for param, flag in zip(disc_params, grad_flags):
                    param.requires_grad_(flag)

            loss = rec_loss + d_weight * g_loss + d_loss

            gan_log.update(
                d_weight            = d_weight.item(),
                g_loss              = g_loss.item(),
                d_loss              = d_loss.item(),
                acc_fake_certain    = acc_fake_certain.item(),
                acc_fake_uncertain  = acc_fake_uncertain.item(),
                loss_fake_certain   = loss_fake_certain.item(),
                loss_fake_uncertain = loss_fake_uncertain.item(),
            )

        log = {
            "total_loss"    : loss.item(),
            "rec_loss"      : rec_loss.item(),
            "rec_loss_l1"   : rec_loss_l1.item(),
            "rec_loss_ssim" : rec_loss_ssim.item(),
        }
        log.update(gan_log)

        return loss, log
