from typing import Dict, Tuple

import torch
import torch.nn as nn

from CRDR.src.utils.registry import TRAINER_REGISTRY
from CRDR.src.trainer.rgan_rate_distortion_trainer import RGANRateDistortionTrainer

@TRAINER_REGISTRY.register()
class BetaCondRGANRateDistortionTrainer(RGANRateDistortionTrainer):
    """RGAN rate-distortion trainer with a beta-weighted realism term.

    The compression model reports a ``beta`` value alongside its other
    auxiliary outputs, and the generator objective used here is::

        distortion + rate + beta * (adversarial + perceptual)

    The adversarial term is a relativistic GAN loss: each discriminator
    prediction is compared against the prediction for the other domain.
    """

    def optimize_parameters(self, current_iter: int, data_dict: Dict):
        """Run one generator (compression model) update, then one discriminator update.

        Args:
            current_iter: Current training iteration. Forwarded to the rate
                loss and included in the skip warning.
            data_dict: Batch passed unchanged to ``self.run_comp_model``.

        Returns:
            A dict of logging values (generator losses, ``qbpp``, optional
            ``aux``, discriminator losses and mean discriminator outputs), or
            ``None`` when ``self.check_loss_nan_inf`` flags the total generator
            loss. In that case neither the generator nor the discriminator is
            updated this iteration.
        """
        log_dict = {}

        ###################################################################
        #                             Train G
        ###################################################################
        # D is frozen so the generator loss does not accumulate D gradients.
        self.discriminator.requires_grad_(False)
        self.g_optimizer.zero_grad()
        if self.aux_optimizer:
            self.aux_optimizer.zero_grad()

        # run model
        real_images, fake_images, bpp, other_outputs = self.run_comp_model(data_dict)
        log_dict['qbpp'] = other_outputs.get('qbpp', -1)
        # Popped so that beta is not forwarded as a keyword argument to the
        # losses or the discriminator below.
        # NOTE(review): l_total must be a scalar for backward(); this assumes
        # beta is a scalar (or that the product reduces to one) — confirm.
        beta = other_outputs.pop('beta')

        # calculate losses
        g_loss_dict = {}
        dist_loss = self.distortion_loss(real_images, fake_images, **other_outputs)
        g_loss_dict['distortion'] = dist_loss

        rate_loss = self.rate_loss(bpp, **other_outputs, current_iter=current_iter)
        g_loss_dict['rate'] = rate_loss

        assert self.perceptual_loss is not None
        percep_loss = self.perceptual_loss(real_images, fake_images)
        g_loss_dict['perceptual'] = percep_loss

        # RGAN adv loss. The real prediction is only a fixed reference for the
        # generator, so it is computed without recording an autograd graph
        # (same values as the previous `.detach()`, without the wasted graph).
        with torch.no_grad():
            real_d_pred = self.discriminator(real_images, **other_outputs)
        fake_g_pred = self.discriminator(fake_images, **other_outputs)

        l_g_real = self.gan_loss(real_d_pred - fake_g_pred, is_real=False, is_disc=False)
        l_g_fake = self.gan_loss(fake_g_pred - real_d_pred, is_real=True, is_disc=False)
        adv_loss = (l_g_real + l_g_fake) / 2.
        g_loss_dict['adv'] = adv_loss

        # beta scales only the realism terms (adversarial + perceptual).
        l_total = dist_loss + rate_loss + beta * (adv_loss + percep_loss)

        # For stability: skip the whole iteration on an anomalous loss.
        # check_loss_nan_inf returns a falsy value when there is no anomaly.
        if (loss_anomaly := self.check_loss_nan_inf(l_total)):
            self.logger.warning(f'iter{current_iter}: skipped because loss is {loss_anomaly}')
            # Leave D in the same trainable state as a completed iteration.
            self.discriminator.requires_grad_(True)
            return None  # skip back-propagation part

        # back prop & update parameters
        l_total.backward()
        if self.opt.optim.get('clip_max_norm'):
            nn.utils.clip_grad_norm_(self.comp_model.parameters(), self.opt.optim.clip_max_norm)
        self.g_optimizer.step()

        log_dict.update(g_loss_dict)

        if self.g_scheduler:
            self.g_scheduler.step()

        if self.aux_optimizer:
            log_dict['aux'] = self.optimize_aux_parameters()

        ###################################################################
        #                             Train D
        ###################################################################
        self.discriminator.requires_grad_(True)
        self.d_optimizer.zero_grad()

        # real: the fake prediction is only a fixed reference here. fake_images
        # is still attached to the already back-propagated G graph, so avoid
        # recording a throwaway autograd graph for this forward pass.
        with torch.no_grad():
            fake_d_pred = self.discriminator(fake_images, **other_outputs)
        real_d_pred = self.discriminator(real_images, **other_outputs)
        l_d_real = self.gan_loss(real_d_pred - fake_d_pred, is_real=True, is_disc=True) * 0.5
        l_d_real.backward()

        # fake: fake_images is detached so this loss does not back-propagate
        # into the compression model through the image input.
        fake_d_pred = self.discriminator(fake_images.detach(), **other_outputs)
        l_d_fake = self.gan_loss(fake_d_pred - real_d_pred.detach(), is_real=False, is_disc=True) * 0.5
        l_d_fake.backward()

        log_dict.update({
            'd_real': l_d_real,
            'd_fake': l_d_fake,
            'd_total': l_d_real + l_d_fake,
            'out_d_real': torch.mean(real_d_pred.detach()),
            'out_d_fake': torch.mean(fake_d_pred.detach()),
        })

        self.d_optimizer.step()
        if self.d_scheduler:
            self.d_scheduler.step()

        return log_dict