import pytorch_lightning as pl
import torch as th
import numpy as np
from torch.utils.data import DataLoader
from utils.io import UEMA, Timer
from utils.optimizers import Ranger
from nn.background import ViTDepthUncertantyBackground
from nn.uncertainty_loss import UncertaintyGANLoss

class LociBackgroundModule(pl.LightningModule):
    def __init__(self, cfg):
        """Build the background network and its GAN loss from ``cfg``.

        Args:
            cfg: Configuration object. The constructor reads ``cfg.seed`` and
                the ``cfg.model`` / ``cfg.model.background`` entries that are
                forwarded to the network and the loss below. The object is
                kept as ``self.cfg`` for use by the other methods.
        """
        super().__init__()
        self.cfg = cfg
        # Per-metric running averages for everything that is not a "val_" metric.
        self.own_loggers = {}
        # Accumulators for metrics whose name starts with "val_". `log` reads
        # and updates this dict, so it must exist before the first such call.
        self.val_metrics = {}
        self.timer = Timer()

        # Seeds the process-wide NumPy and torch generators.
        np.random.seed(cfg.seed)
        th.manual_seed(cfg.seed)

        model_cfg = cfg.model
        background_cfg = model_cfg.background

        self.net = ViTDepthUncertantyBackground(  # TODO: use a memory-efficient ConvNeXt
            latent_size               = model_cfg.latent_size,
            reg_lambda                = background_cfg.reg_lambda,
            batch_size                = model_cfg.batch_size,
            hidden_channels           = background_cfg.channels,
            num_embedding_layers      = background_cfg.num_embedding_layers,
            num_attention_layers      = background_cfg.num_attention_layers,
            num_hyper_layers          = background_cfg.num_hyper_layers,
            hyper_channels            = background_cfg.num_hyper_channels,
            num_heads                 = background_cfg.num_heads,
            uncertainty_base_channels = background_cfg.uncertainty_base_channels,
            uncertainty_blocks        = background_cfg.uncertainty_blocks,
            uncertainty_threshold     = background_cfg.uncertainty_threshold,
            depth_input               = model_cfg.input_depth,
            rgbd_decoder              = background_cfg.rgbd_decoder
        )

        self.gan_loss = UncertaintyGANLoss(
            output_size          = model_cfg.crop_size,
            batch_size           = model_cfg.batch_size,
            discriminator_start  = background_cfg.gan_loss_pretraining,
            discriminator_weight = background_cfg.gan_loss_factor
        )

        # Tensors of the previous training step; `training_step` feeds them
        # to the network when predicting the current frame.
        self.last_input = None
        self.last_rgb = None
        self.last_depth = None
        self.last_fg_mask = None
        self.last_uncertainty = None

        # Highest trainer.global_step for which a progress line was printed.
        self.num_updates = -1

    def forward(self, x):
        """Run ``x`` through the wrapped background network and return its result unchanged."""
        background_net = self.net
        return background_net(x)

    def log(self, name, value, on_step=True, on_epoch=True, prog_bar=False, logger=True, **kwargs):
        """Log a metric through Lightning and mirror it in local bookkeeping.

        The value is first forwarded to the base-class ``log``. Afterwards:

        * names starting with ``"val_"`` are summed into ``self.val_metrics``;
        * every other name updates a ``UEMA(10000)`` tracker stored in
          ``self.own_loggers``.

        Args:
            name: Metric name.
            value: Python number or tensor; tensors are converted with
                ``.item()`` for the local bookkeeping.
            on_step, on_epoch, prog_bar, logger: Forwarded to the base class.
            **kwargs: Any further keyword arguments accepted by the base-class
                ``log``. ``sync_dist`` defaults to ``True`` unless given.
        """
        # Keep the original default while still letting callers override it.
        kwargs.setdefault("sync_dist", True)
        super().log(name, value, on_step=on_step, on_epoch=on_epoch, prog_bar=prog_bar, logger=logger, **kwargs)

        scalar = value.item() if isinstance(value, th.Tensor) else value

        if name.startswith("val_"):
            # Create the accumulator dict on demand so that the first "val_"
            # metric does not fail when the attribute was never set up.
            if getattr(self, "val_metrics", None) is None:
                self.val_metrics = {}

            if name not in self.val_metrics:
                self.val_metrics[name] = 0
                print("Adding metric: ", name)

            self.val_metrics[name] += scalar
        else:
            if name not in self.own_loggers:
                self.own_loggers[name] = UEMA(10000)

            self.own_loggers[name].update(scalar)

    def training_step(self, batch, batch_idx):
        """Compute the training loss for one frame of a sequence.

        The step always trains the uncertainty estimate against the foreground
        mask (where ``use_fg_masks`` is set). If the batch is not the first
        frame of a sequence and a previous frame is stored on the module, the
        network additionally predicts the current frame from the previous
        input, which adds depth / RGB reconstruction terms and an unsupervised
        uncertainty term.

        Args:
            batch: Tuple ``(cur_rgb, cur_depth, cur_fg_mask, time_step,
                use_depth, use_fg_masks)``. Only ``time_step[0]`` is inspected,
                so all samples in the batch are assumed to share the same time
                step (TODO confirm against the dataset).
            batch_idx: Index of the batch in the current epoch; only used for
                the progress printout.

        Returns:
            The scalar loss tensor to optimise.
        """
        cur_rgb, cur_depth, cur_fg_mask, time_step, use_depth, use_fg_masks = batch
        last_input, last_fg_mask, last_uncertainty = self.last_input, self.last_fg_mask, self.last_uncertainty

        background_cfg = self.cfg.model.background
        log_kwargs = dict(on_step=True, on_epoch=True, prog_bar=True, logger=True)

        # Read the time step once; every .item() call forces a host sync.
        sequence_step = time_step[0].item()

        # reshape for easy broadcasting
        use_depth = use_depth.view(-1, 1, 1, 1).float()
        use_fg_masks = use_fg_masks.view(-1, 1, 1, 1).float()

        # Where a foreground mask is available, it replaces the previously
        # estimated uncertainty as input for the prediction.
        if last_uncertainty is not None:
            last_uncertainty = last_uncertainty * (1 - use_fg_masks) + last_fg_mask * use_fg_masks

        if sequence_step == 0:
            self.net.reset_state()

        cur_input = th.cat((cur_rgb, cur_depth), dim=1) if self.cfg.model.input_depth else cur_rgb

        # Distance of the mask value from 0.5, rescaled to [0, 1] for masks in [0, 1].
        fg_target    = (cur_fg_mask > 0.5).float()
        confidence   = th.abs(cur_fg_mask - 0.5) * 2
        depth_weight = th.clip(cur_depth, 0.1, 1) * use_depth + (1 - use_depth) # FIXME should be a dataset parameter! Also FIXME for the case that no depth is available!
        warmup       = self.trainer.global_step < background_cfg.uncertainty_warmup_steps

        # run uncertainty estimation
        cur_uncertainty, cur_uncertainty_noised = self.net.uncertainty_estimation(cur_input)

        supervised_uncertainty_loss = th.mean(((cur_uncertainty - fg_target)**2) * confidence * use_fg_masks)

        loss = supervised_uncertainty_loss.clone()

        # The prediction needs a stored previous frame; without one (e.g. the
        # very first batch starts mid-sequence) only the supervised term is used.
        if sequence_step > 0 and last_input is not None:

            self.net.detach()
            # NOTE: `compute_rbg` (sic) is the keyword expected by the network call.
            output_rgb, output_depth = self.net(last_input, last_uncertainty, compute_rbg = not warmup)

            threshold       = background_cfg.uncertainty_threshold
            prediction_mask = (cur_uncertainty < threshold).float().detach()
            prediction_mask = prediction_mask * (1 - use_fg_masks) + use_fg_masks * (cur_fg_mask < threshold).float().detach()

            # During warmup the whole unsupervised term is disabled (this also
            # avoids touching output_rgb while RGB computation is switched off).
            if warmup:
                unsupervised_uncertainty_loss = 0
            else:
                unlabeled = 1 - use_fg_masks
                certainty = 1 - cur_uncertainty_noised
                unsupervised_uncertainty_loss = (
                    background_cfg.uncertainty_regularizer * th.mean(unlabeled * cur_uncertainty_noised**2) +
                    th.mean(unlabeled * depth_weight * th.abs(output_depth - cur_depth).detach() * certainty) +
                    background_cfg.rgb_loss_factor * th.mean(unlabeled * 0.3333 * th.abs(output_rgb - cur_rgb).detach() * certainty)
                )

            uncertainty_loss = supervised_uncertainty_loss + unsupervised_uncertainty_loss

            loss = loss + unsupervised_uncertainty_loss

            rgb_loss = 0

            if warmup:
                # Warmup: plain L1 depth reconstruction, no RGB / GAN term.
                depth_loss = th.mean(th.abs(output_depth - cur_depth))
            else:
                depth_loss        = th.mean(th.abs(output_depth - cur_depth) * prediction_mask)
                rgb_loss, gan_log = self.gan_loss(cur_rgb, output_rgb, 1 - prediction_mask, self.net.get_last_layer(), self.trainer.global_step)

                self.log('train_rgb_loss', rgb_loss.item(), **log_kwargs)
                for metric_name, log_key in (
                    ('train_rgb_loss_rec', 'rec_loss'),
                    ('train_rgb_l1', 'rec_loss_l1'),
                    ('train_rgb_ssim', 'rec_loss_ssim'),
                    ('train_acc_fake_certain', 'acc_fake_certain'),
                    ('train_acc_fake_uncertain', 'acc_fake_uncertain'),
                    ('train_d_loss', 'd_loss'),
                    ('train_g_loss', 'g_loss'),
                    ('train_d_weight', 'd_weight'),
                ):
                    self.log(metric_name, gan_log[log_key], **log_kwargs)

            loss = loss + depth_loss + rgb_loss * background_cfg.rgb_loss_factor

            self.log('train_depth_loss', depth_loss.item(), **log_kwargs)
            self.log('train_loss', loss.item(), **log_kwargs)
            self.log('train_uncertainty', th.mean(cur_uncertainty).item(), **log_kwargs)
            self.log('train_prediction_mask', th.mean(prediction_mask).item(), **log_kwargs)
            self.log('train_openings', self.net.openings(), **log_kwargs)
            self.log('train_uncertainty_loss', uncertainty_loss.item(), **log_kwargs)

            # Print at most one progress line per optimizer step.
            if self.num_updates < self.trainer.global_step:
                self.num_updates = self.trainer.global_step
                averages   = self.own_loggers
                gan_active = not warmup  # the RGB / GAN metrics are not logged during warmup
                print("Epoch[{}|{}|{}|{:.2f}%]: {}, Loss: {:.2e}, U: {:.2e}|{:.2e}, M: {:.2e}, depth: {:.2e}, rgb: {:.2e}|{:.2e}|{:.2e}, GAN: {:.2e}|{:.2e}||{:.2e}|{:.2e}, D: {:.2e}, O: {:.2e}".format(
                    self.trainer.local_rank,
                    self.trainer.global_step,
                    self.trainer.current_epoch,
                    (batch_idx + 1) / len(self.trainer.train_dataloader) * 100,
                    str(self.timer),
                    float(averages['train_loss']),
                    float(averages['train_uncertainty_loss']),
                    float(averages['train_uncertainty']),
                    float(averages['train_prediction_mask']),
                    float(averages['train_depth_loss']),
                    float(averages['train_rgb_loss']) if gan_active else 0,
                    float(averages['train_rgb_l1']) if gan_active else 0,
                    float(averages['train_rgb_ssim']) if gan_active else 0,
                    float(averages['train_g_loss']) if gan_active else 0,
                    float(averages['train_d_loss']) if gan_active else 0,
                    (1 - float(averages['train_acc_fake_certain'])) if gan_active else 0,
                    (1 - float(averages['train_acc_fake_uncertain'])) if gan_active else 0,
                    float(averages['train_d_weight']) if gan_active else 0,
                    float(averages['train_openings']),
                ), flush=True)

        # Remember the current frame for the next step's prediction.
        self.last_input       = cur_input
        self.last_rgb         = cur_rgb
        self.last_depth       = cur_depth
        self.last_fg_mask     = cur_fg_mask
        self.last_uncertainty = cur_uncertainty.detach()

        return loss

    def validation_step(self, batch, batch_idx):
        """Placeholder: validation is not implemented; the batch is ignored."""
        return None

    def test_step(self, batch, batch_idx):
        """Placeholder: testing is not implemented; the batch is ignored."""
        return None

    # TODO: the masked background predictor only learns to replace masked regions, but not to actually ignore / recognize the foreground, because the foreground is already masked out!
    # TODO: this means the uncertainty network could essentially be dropped and the object masks should mask the input to the background model instead!

    def configure_optimizers(self):
        """Create one Ranger optimizer with two parameter groups.

        The background network uses the configured learning rate and weight
        decay; the GAN loss parameters use four times that learning rate and a
        fixed weight decay of 0.001.
        """
        base_lr = self.cfg.learning_rate

        net_group = {
            'params': self.net.parameters(),
            'lr': base_lr,
            "weight_decay": self.cfg.weight_decay,
        }
        gan_group = {
            'params': self.gan_loss.parameters(),
            'lr': base_lr * 4,
            "weight_decay": 0.001,
        }

        return Ranger([net_group, gan_group])

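    # FIXME: the epoch-end summary below is disabled (kept as a bare string literal).
    # Known problems to resolve before re-enabling it:
    #   - the format string has 17 placeholders but is given 18 arguments;
    #   - keys such as 'train_rgb_loss_l1_epoch' / 'train_rgb_loss_ssim_epoch' do not
    #     correspond to the names logged in training_step ('train_rgb_l1', 'train_rgb_ssim').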
    """
    def on_train_epoch_end(self, trainer, pl_module):
        logged_metrics = trainer.logger_connector.logged_metrics
        num_updates = self.global_step
        epoch = self.current_epoch
        batch_index = self.trainer.batch_idx
        trainloader_length = len(self.train_dataloader())

        print("Epoch[{}/{}/{:.2f}%]: Loss: {:.2e}|{:.2e}, U: {:.2e}, M: {:.2e}, depth: {:.2e}, rgb: {:.2e}|{:.2e}|{:.2e}, GAN: {:.2e}|{:.2e}||{:.2e}|{:.2e}, D: {:.2e}, O: {:.2e}".format(
            num_updates,
            batch_index,
            batch_index / trainloader_length * 100,
            epoch + 1,
            logged_metrics['train_loss_epoch'],
            logged_metrics['train_uncertainty_loss_epoch'],
            logged_metrics['train_uncertainty_epoch'],
            logged_metrics['train_prediction_mask_epoch'],
            logged_metrics['train_depth_loss_epoch'],
            logged_metrics['train_rgb_loss_rec_epoch'],
            logged_metrics['train_rgb_loss_l1_epoch'],
            logged_metrics['train_rgb_loss_ssim_epoch'],
            logged_metrics['train_g_loss_epoch'],
            logged_metrics['train_d_loss_epoch'],
            1 - logged_metrics['train_acc_fake_certain_epoch'],
            1 - logged_metrics['train_acc_fake_uncertain_epoch'],
            logged_metrics['train_d_weight_epoch'],
            logged_metrics['train_openings_epoch'],
        ), flush=True)
    """

