import pytorch_lightning as pl
import torch as th
import numpy as np
from torch import nn
from torch.utils.data import DataLoader
from utils.optimizers import Ranger
from utils.configuration import Configuration
from nn.object_discovery import PositionProposalVit
from utils.io import UEMA, Timer
from utils.utils import PositionInMask, Gaus2D
import torch.distributed as dist
from einops import rearrange, repeat, reduce

class LociUncertaintyPretrainerModule(pl.LightningModule):
    """Lightning module that pretrains a ``PositionProposalVit`` position-proposal network.

    Every training step splits the binarised instance masks of a batch into
    "input" masks and "target" masks (see ``compute_step``). Positions of the input
    masks, optionally perturbed, are rendered by ``Gaus2D`` and max-reduced into a
    single ``(b, 1, h, w)`` map that is fed to the network together with the
    foreground mask and the image. The network is trained to predict the position of
    the target mask it is closest to, and whether any target mask remains at all.
    """

    def __init__(self, cfg: Configuration, state_dict=None):
        """Build the network, position helpers, loss and bookkeeping state.

        Args:
            cfg: Experiment configuration. Read here: ``seed``, ``learning_rate``,
                ``model.crop_size``, ``model.input_depth`` and
                ``model.position_proposal.{channels,num_layers}``; ``weight_decay``
                is read later by ``configure_optimizers``.
            state_dict: Accepted for interface compatibility; not used in this body.
        """
        super().__init__()
        self.cfg = cfg

        # seed numpy and torch so that runs are reproducible
        print(f"RANDOM SEED: {cfg.seed}")
        np.random.seed(cfg.seed)
        th.manual_seed(cfg.seed)

        # turn a scalar crop size into a pair; only used for self.min_std below
        crop_size = cfg.model.crop_size
        if not isinstance(crop_size, (tuple, list)):
            crop_size = (crop_size, crop_size)

        # NOTE(review): forward() concatenates fg_mask, rgb and (optionally) depth;
        # the channel count below presumably also accounts for the positions2d map
        # consumed inside PositionProposalVit — confirm.
        self.net = PositionProposalVit(
            input_channels  = 6 if cfg.model.input_depth else 5,
            latent_channels = cfg.model.position_proposal.channels,
            num_layers      = cfg.model.position_proposal.num_layers,
        )

        self.gaus2d = Gaus2D(cfg.model.crop_size)

        self.compute_position = PositionInMask(cfg.model.crop_size)
        # NOTE(review): not read inside this class; compute_step thresholds on
        # self.gaus2d.min_std instead.
        self.min_std  = 1.0 / min(crop_size)

        self.bceloss = nn.BCEWithLogitsLoss()

        self.lr = self.cfg.learning_rate
        self.own_loggers = {}    # metric name -> UEMA(10000) of training values
        self.timer = Timer()
        self.val_metrics = {}    # metric name -> running sum of "val_*" values

        # last global step for which training_step printed a progress line
        self.num_updates = -1

    def forward(self, input_rgb, input_depth, fg_mask, positions2d):
        """Run the position-proposal network.

        The foreground mask, the RGB image and, if ``cfg.model.input_depth`` is set,
        the depth map are concatenated along dim 1 and passed to ``self.net``
        together with ``positions2d``. ``training_step`` unpacks the result as
        ``(position_out, position_valid)``.
        """
        if self.cfg.model.input_depth:
            net_input = th.cat((fg_mask, input_rgb, input_depth), dim=1)
        else:
            net_input = th.cat((fg_mask, input_rgb), dim=1)
        return self.net(net_input, positions2d)

    def log(self, name, value, on_step=True, on_epoch=True, prog_bar=False, logger=True):
        """Log through Lightning (always with ``sync_dist=True``) and mirror the value locally.

        Names starting with ``"val_"`` are summed into ``self.val_metrics``. All
        other names update a per-name ``UEMA(10000)`` in ``self.own_loggers``, which
        ``training_step`` reads for its progress print.
        """
        super().log(name, value, on_step=on_step, on_epoch=on_epoch, prog_bar=prog_bar, logger=logger, sync_dist=True)

        scalar = value.item() if isinstance(value, th.Tensor) else value

        if name.startswith("val_"):
            if name not in self.val_metrics:
                self.val_metrics[name] = 0
                print("Adding metric: ", name)

            self.val_metrics[name] += scalar
        else:
            if name not in self.own_loggers:
                self.own_loggers[name] = UEMA(10000)

            self.own_loggers[name].update(scalar)

    def compute_step(self, batch, batch_idx):
        """Split a batch into network inputs and position targets (no gradients).

        Instance masks are binarised at 0.75. For every sample, a random subset of
        the valid (non-empty) masks becomes "input" masks: all of them when the
        random draw is > 0.75, none when it is <= 0.1, and a random fraction of them
        otherwise. The remaining valid masks become targets.

        Args:
            batch: ``(rgb, depth, instance_masks)`` with ``instance_masks`` of shape
                ``(b, n, h, w)``. Presumably ``rgb`` is ``(b, 3, h, w)`` and ``depth``
                is ``(b, 1, h, w)`` — TODO confirm against the dataset.
            batch_idx: Unused.

        Returns:
            ``(rgb, depth, fg_mask, input_positions2d, target_positions,
            target_error_offset, more_than_zero_targets)``, where:

            - ``fg_mask`` and ``input_positions2d`` are ``(b, 1, h, w)``.
            - ``target_positions`` is ``(b, n, c)``: xy, z and std for target slots.
              Non-target slots hold ``compute_position(1 - fg_mask)[1]`` followed by
              zeros.
            - ``target_error_offset`` is ``(b, n)``: 100 on non-target slots of
              samples that have at least one target, else 0.
            - ``more_than_zero_targets`` is ``(b, 1)``: 1.0 if the sample has at
              least one target mask.
        """
        rgb, depth, instance_masks = batch

        with th.no_grad():
            # binarize instance masks and compute the foreground mask as the maximum of all instance masks
            instance_masks = (instance_masks > 0.75).float()
            fg_mask = reduce(instance_masks, 'b n h w -> b 1 h w', 'max')

            # compute target z as the mean depth inside each instance mask (eps avoids 0/0 for empty masks)
            target_z = th.sum(depth * instance_masks, dim=(2,3)) / (th.sum(instance_masks, dim=(2,3)) + 1e-8)
            target_z = rearrange(target_z, 'b n -> (b n) 1')

            max_num_masks = instance_masks.shape[1]
            valid_masks   = reduce(instance_masks, 'b n h w -> b n', 'max')
            num_masks     = reduce(valid_masks, 'b n -> b 1', 'sum')

            # enumerate valid masks (0.5, 1.5, ...) and push invalid ones past every
            # possible count, so "arange < k" selects the first k valid masks
            arange  = th.cumsum(valid_masks, dim=1) * valid_masks + (1 - valid_masks) * (max_num_masks + 1) - 0.5
            rnd_val = th.rand_like(num_masks)

            # use all masks 25% of the time, use no masks 10% of the time, use a random number of masks else
            num_input_masks = (rnd_val > 0.75).float() * num_masks + (rnd_val > 0.1).float() * (rnd_val < 0.75).float() * num_masks * th.rand_like(num_masks)
            input_selection = (arange < num_input_masks).float().unsqueeze(-1).unsqueeze(-1)

            input_masks  = instance_masks * input_selection
            target_masks = instance_masks * (1 - input_selection) * valid_masks.unsqueeze(-1).unsqueeze(-1)

            input_masks  = rearrange(input_masks, 'b n h w -> (b n) 1 h w')
            target_masks = rearrange(target_masks, 'b n h w -> (b n) 1 h w')

            # target_selection is (b, n, 1): 1 for valid masks that were not used as input
            target_selection = (1 - input_selection).squeeze(-1) * valid_masks.unsqueeze(-1)
            input_selection  = rearrange(input_selection, 'b n 1 1 -> (b n) 1')

            # penalise non-target slots (+100) only for samples that have at least one target
            more_than_zero_targets = reduce(target_masks, '(b n) 1 h w -> b 1', 'max', n = max_num_masks)
            target_error_offset = 100 * more_than_zero_targets * (1 - target_selection.squeeze(-1))

            # compute null positions from the background (1 - fg_mask); extra channels are zero
            zero_xy        = self.compute_position(1 - fg_mask)[1]
            zero_positions = th.cat((zero_xy, zero_xy * 0), dim=1).unsqueeze(1)

            target_xy, target_std = self.compute_position(target_masks)[1:]
            target_positions = th.cat((target_xy, target_z, target_std), dim=1)
            target_positions = rearrange(target_positions, '(b n) c -> b n c', n = max_num_masks)
            target_positions = target_positions * target_selection + zero_positions * (1 - target_selection)

            # extra random positions per input mask are kept with probability 0.3 / 0.1
            rand2_selection = (th.rand_like(target_std) < 0.3).float()
            rand3_selection = (th.rand_like(target_std) < 0.1).float()

            rand1_xy, center_xy, std = self.compute_position(input_masks)
            rand2_xy = self.compute_position(input_masks)[0] * rand2_selection
            rand3_xy = self.compute_position(input_masks)[0] * rand3_selection

            # random std lies in [0.2 * std, 1.2 * std), i.e. it can be slightly bigger than std
            rand_std1 =  std * 0.2 + th.rand_like(std) * std
            rand_std2 = (std * 0.2 + th.rand_like(std) * std) * rand2_selection
            rand_std3 = (std * 0.2 + th.rand_like(std) * std) * rand3_selection

            rand_positions1 = th.cat((rand1_xy, target_z, rand_std1), dim=1) * input_selection
            rand_positions2 = th.cat((rand2_xy, target_z, rand_std2), dim=1) * input_selection
            rand_positions3 = th.cat((rand3_xy, target_z, rand_std3), dim=1) * input_selection

            center_positions = th.cat((center_xy, target_z, std), dim=1) * input_selection

            rnd_val = th.rand_like(center_positions[:,:1])

            # per mask: in 50% of cases use the "correct" positions as input, in the other 50% use the random positions
            rand_positions1 = rand_positions1 * (rnd_val < 0.5).float() + center_positions * (rnd_val >= 0.5).float()
            rand_positions2 = rand_positions2 * (rnd_val < 0.5).float()
            rand_positions3 = rand_positions3 * (rnd_val < 0.5).float()

            center_positions = rearrange(center_positions, '(b n) c -> b (n c)', n = max_num_masks)
            rand_positions1  = rearrange(rand_positions1, '(b n) c -> b (n c)', n = max_num_masks)
            rand_positions2  = rearrange(rand_positions2, '(b n) c -> b (n c)', n = max_num_masks)
            rand_positions3  = rearrange(rand_positions3, '(b n) c -> b (n c)', n = max_num_masks)

            rnd_val = th.rand_like(center_positions[:,:1])

            # per sample: in 50% of cases use the "correct" positions everywhere as input, in the other 50% use the random positions
            rand_positions1 = rand_positions1 * (rnd_val < 0.5).float() + center_positions * (rnd_val >= 0.5).float()
            rand_positions2 = rand_positions2 * (rnd_val < 0.5).float()
            rand_positions3 = rand_positions3 * (rnd_val < 0.5).float()

            input_positions = th.cat((rand_positions1, rand_positions2, rand_positions3), dim=1)
            input_positions = rearrange(input_positions, 'b (n c) -> (b n) c', n = 3 * max_num_masks)

            # render every position whose std exceeds gaus2d.min_std, then max-merge all of them per sample
            input_positions2d = self.gaus2d(input_positions, compute_std=False) * (input_positions[:,-1:] > self.gaus2d.min_std).float().unsqueeze(-1).unsqueeze(-1)
            input_positions2d = reduce(input_positions2d, '(b n) 1 h w -> b 1 h w', 'max', n = 3 * max_num_masks)

            return rgb, depth, fg_mask, input_positions2d, target_positions, target_error_offset, more_than_zero_targets

    def training_step(self, batch, batch_idx):
        """Compute the position and validity losses for one batch and print progress.

        ``position_loss``:
            For each sample, take the smallest squared error between the predicted
            position and any slot of ``target_positions`` (non-target slots carry a
            +100 offset). Zero it for samples without targets, then average over the
            batch.

        ``valid_loss``:
            A BCE-with-logits loss that trains ``position_valid`` to predict
            ``more_than_zero_targets``. This assumes ``position_valid`` is
            ``(b, 1)`` — TODO confirm against PositionProposalVit.
        """
        rgb, depth, fg_mask, input_positions2d, target_positions, target_error_offset, more_than_zero_targets = self.compute_step(batch, batch_idx)

        position_out, position_valid = self(rgb, depth, fg_mask, input_positions2d)
        position_err = reduce((position_out.unsqueeze(1) - target_positions)**2, 'b n c -> b n', 'sum')

        # TODO directly predict if a mask is valid or not, so that no additional loss is imposed on empty slots !!!

        # Keep the per-sample minimum as (b, 1) so that it lines up element-wise with
        # more_than_zero_targets (b, 1). A (b,) tensor would broadcast to (b, b) and
        # pair one sample's error with another sample's target flag.
        min_position_err = reduce(position_err + target_error_offset, 'b n -> b 1', 'min')
        position_loss    = th.mean(min_position_err * more_than_zero_targets)
        valid_loss       = self.bceloss(position_valid, more_than_zero_targets)

        loss = position_loss + valid_loss

        self.log("loss", loss, prog_bar=True)
        self.log("position_loss", position_loss, prog_bar=True)
        self.log("valid_loss", valid_loss, prog_bar=True)

        # print at most one progress line per optimizer (global) step
        if self.num_updates < self.trainer.global_step:
            self.num_updates = self.trainer.global_step
            print("Epoch[{}|{}|{}|{:.2f}%]: {}, Loss: {:.2e} possition-loss: {:.2e}, valid-loss: {:.2e}".format(
                self.trainer.local_rank,
                self.trainer.global_step,
                self.trainer.current_epoch,
                (batch_idx + 1) / len(self.trainer.train_dataloader) * 100,
                str(self.timer),
                float(self.own_loggers['loss']),
                float(self.own_loggers['position_loss']),
                float(self.own_loggers['valid_loss'])
            ), flush=True)

        # NOTE(review): validation metrics are cleared on every training step
        self.val_metrics = {}

        return loss

    def validation_step(self, batch, batch_idx):
        """Validation is a no-op for this pretrainer."""
        pass

    def configure_optimizers(self):
        """Return a ``Ranger`` optimizer over ``self.net``'s parameters.

        The learning rate is ``cfg.learning_rate`` and the weight decay is
        ``cfg.weight_decay``.
        """
        return Ranger(self.net.parameters(), lr=self.lr, weight_decay=self.cfg.weight_decay)

