import torch as th
from torch.utils.data import Dataset, DataLoader
from torch import nn

import h5py
import os
from utils.configuration import Configuration
from utils.io import model_path
from model.loci import Loci
from nn.background import ViTDepthUncertantyBackground
from utils.utils import LambdaModule, Gaus2D, Prioritize, MaskCenter, RadomSimilarityBasedMaskDrop
from utils.optimizers import SDAdam, SDAMSGrad
from utils.io import Timer
import numpy as np
import cv2
from pathlib import Path
import shutil
import pickle
import torch.nn.functional as F
from einops import rearrange, repeat, reduce
from model.scripts.training import eval_net
from model.pretrainer import LociPretrainer
import pytorch_lightning as pl
from data.lightning_objects import LociPretrainerDataModule
from data.lightning_background_v2 import LociBackgroundDataModule
from data.lightning_loci import LociDataModule
from data.lightning_uncertainty import LociUncertaintyPretrainerDataModule
from model.lightning.pretrainer import LociPretrainerModule
from model.lightning.background_v2 import LociBackgroundModule
from model.lightning.loci import LociModule
from model.lightning.uncertainty import LociUncertaintyPretrainerModule
from utils.loss import SSIM

def preprocess(tensor, scale=1, normalize=False, mean_std_normalize=False, size=None, add_text=False, text="", position=(10, 30), font_scale=1, font_color=(255,255,255), outline_color=(0,0,0), font_thickness=2):
    """Prepare an image batch for visualisation.

    The enabled steps run in this order: min/max normalisation, mean/std
    normalisation, nearest-neighbour upscaling, bicubic resizing, text overlay.

    Args:
        tensor: Image batch; the resize and overlay steps treat it as
            (batch, channels, height, width).
        scale: Upscaling factor, only applied when greater than 1.
        normalize: Rescale linearly so the global minimum maps to 0 and the
            global maximum to 1. A constant input maps to all zeros.
        mean_std_normalize: Map ``mean +- 2 * std`` onto [0, 1], clipping
            everything outside. A constant input maps to 0.5.
        size: Optional (height, width) target for bicubic interpolation.
        add_text: Draw ``text`` onto the first image of the batch.
        text: String to draw.
        position: Text origin handed to ``cv2.putText``.
        font_scale: Font scale handed to ``cv2.putText``.
        font_color: Colour of the text.
        outline_color: Colour of the slightly thicker outline drawn underneath.
        font_thickness: Stroke thickness of the text.

    Returns:
        The processed tensor. With ``add_text`` only the first batch element
        survives: the result is a batch of one image rebuilt from its 8-bit
        rendering and scaled back to [0, 1].
    """
    if normalize:
        min_ = th.min(tensor)
        span = th.max(tensor) - min_
        # A constant image has zero range; divide by one instead of producing NaN.
        span = th.where(span > 0, span, th.ones_like(span))
        tensor = (tensor - min_) / span

    if mean_std_normalize:
        mean = th.mean(tensor)
        std = th.std(tensor)
        # Same guard for a zero (or undefined) standard deviation.
        std = th.where(std > 0, std, th.ones_like(std))
        tensor = th.clip((tensor - mean) / (2 * std), -1, 1) * 0.5 + 0.5

    if scale > 1:
        tensor = F.interpolate(tensor, scale_factor=float(scale), mode='nearest')

    if size is not None:
        tensor = F.interpolate(tensor, size=size, mode='bicubic', align_corners=True)

    if add_text:
        font = cv2.FONT_HERSHEY_SIMPLEX
        frame = tensor[0].detach().cpu().numpy().transpose(1, 2, 0) * 255
        # Saturate before the cast: bicubic resizing can overshoot [0, 1] and a
        # plain uint8 cast would wrap around instead of clipping.
        img = np.clip(frame, 0, 255).astype(np.uint8)
        img = cv2.UMat(img).get()
        # Outline first (one pixel thicker), then the text itself on top.
        img = cv2.putText(img, text, position, font, font_scale, outline_color, font_thickness+1, cv2.LINE_AA)
        img = cv2.putText(img, text, position, font, font_scale, font_color, font_thickness, cv2.LINE_AA)
        if len(img.shape) == 2:
            # Restore the channel axis if the image came back as (height, width).
            img = np.expand_dims(img, axis=2)
        tensor = th.tensor(img.transpose(2,0,1), device=tensor.device).unsqueeze(0) / 255.0

    return tensor

def color_mask(mask):
    """Blend per-slot masks into one RGB image using a fixed colour palette.

    Slot ``k`` is tinted with palette entry ``k`` and the tinted masks are
    summed over the slot dimension.

    Args:
        mask: Tensor indexed as (batch, slots, ...), typically
            (batch, slots, height, width). The palette has 48 entries.

    Returns:
        Tensor with the slot dimension replaced by three colour channels.
    """
    palette_rows = [
        [255,   0,   0], [  0,   0, 255], [255, 255,   0], [255,   0, 255],
        [  0, 255, 255], [  0, 255,   0], [255, 128,   0], [128, 255,   0],
        [128,   0, 255], [255,   0, 128], [  0, 255, 128], [  0, 128, 255],
        [255, 128, 128], [128, 255, 128], [128, 128, 255], [255, 128, 128],
        [128, 255, 128], [128, 128, 255], [255, 128, 255], [128, 255, 255],
        [128, 255, 255], [255, 255, 128], [255, 255, 128], [255, 128, 255],
        [128,   0,   0], [  0,   0, 128], [128, 128,   0], [128,   0, 128],
        [  0, 128, 128], [  0, 128,   0], [128, 128,   0], [128, 128,   0],
        [128,   0, 128], [128,   0, 128], [  0, 128, 128], [  0, 128, 128],
    ] + [[128, 128, 128]] * 12  # the last twelve slots all share mid-grey

    palette = th.tensor(palette_rows, device=mask.device) / 255.0

    # (1, 48, 3, 1, 1), cut down to the number of slots actually present.
    num_slots = mask.shape[1]
    slot_colors = palette.view(1, -1, 3, 1, 1)[:, :num_slots]

    # Insert a channel axis so every slot mask broadcasts against its colour.
    tinted = slot_colors * mask.unsqueeze(dim=2)
    return th.sum(tinted, dim=1)


def priority_to_img(priority, h, w):
    """Render each priority value as white text on a black image.

    Args:
        priority: Tensor read as ``priority[0, 0, p]``; one image is produced
            per entry along dimension 2, always from the first batch element.
        h: Image height in pixels.
        w: Image width in pixels.

    Returns:
        A list with one uint8 tensor of shape (1, 1, 3, h, w) per entry, on
        ``priority.device``, with values in 0..255. The list is empty when
        ``priority.shape[2]`` is 0.
    """
    # cv2 takes the text origin as (x, y), i.e. (column, row), so the
    # horizontal offset comes from the width and the vertical one from the
    # height. The old (h // 6, w // 2) only worked for square images.
    text_position = (w // 6, h // 2)
    font_scale    = w / 256
    font_color    = (255, 255, 255)
    thickness     = 2
    line_type     = 2

    imgs = []
    for p in range(priority.shape[2]):
        img = np.zeros((h, w, 3), np.uint8)

        cv2.putText(
            img,
            f'{priority[0,0,p].item():.2e}',
            text_position,
            cv2.FONT_HERSHEY_SIMPLEX,
            font_scale,
            font_color,
            thickness,
            line_type,
        )

        imgs.append(rearrange(th.tensor(img, device=priority.device), 'h w c -> 1 1 c h w'))

    return imgs

def to_rgb_object(tensor, o):
    """Tint ``tensor`` with the palette colour of slot ``o``.

    Args:
        tensor: Tensor that broadcasts against a (3, 1, 1) colour, e.g. a
            single-channel (batch, 1, height, width) map.
        o: Index into the 48-entry palette.

    Returns:
        ``tensor`` multiplied by the RGB colour of slot ``o`` (components in
        [0, 1]).
    """
    palette_rows = [
        [255,   0,   0], [  0,   0, 255], [255, 255,   0], [255,   0, 255],
        [  0, 255, 255], [  0, 255,   0], [255, 128,   0], [128, 255,   0],
        [128,   0, 255], [255,   0, 128], [  0, 255, 128], [  0, 128, 255],
        [255, 128, 128], [128, 255, 128], [128, 128, 255], [255, 128, 128],
        [128, 255, 128], [128, 128, 255], [255, 128, 255], [128, 255, 255],
        [128, 255, 255], [255, 255, 128], [255, 255, 128], [255, 128, 255],
        [128,   0,   0], [  0,   0, 128], [128, 128,   0], [128,   0, 128],
        [  0, 128, 128], [  0, 128,   0], [128, 128,   0], [128, 128,   0],
        [128,   0, 128], [128,   0, 128], [  0, 128, 128], [  0, 128, 128],
    ] + [[128, 128, 128]] * 12  # the last twelve slots all share mid-grey

    palette = th.tensor(palette_rows, device=tensor.device) / 255.0

    # One (3, 1, 1) colour per slot so it broadcasts over the spatial dims.
    slot_color = palette.view(48, 3, 1, 1)[o]
    return slot_color * tensor

def to_rgb(tensor: th.Tensor):
    """Stack ``tensor`` three times along dim 1, lifting the first copy.

    The first copy is rescaled to ``tensor * 0.6 + 0.4``; the other two are
    the input unchanged.
    """
    lifted = tensor * 0.6 + 0.4
    return th.cat([lifted, tensor, tensor], dim=1)

"""
def distance_weights(positions, sigma = 0.5):
    # positions is of shape (B, N, 3)
    B, N, _ = positions.shape

    # give more weight to z distance
    positions = th.cat((positions, positions[:,:,2:3]), dim=2)
    
    # Expand dims to compute pairwise differences
    p1 = positions[:, :, None, :]
    p2 = positions[:, None, :, :]
    
    # Compute pairwise differences and squared Euclidean distance
    diff = p1 - p2
    squared_diff = diff ** 2
    squared_distances = th.sum(squared_diff, dim=-1)
    
    # Compute the actual distances
    distances = th.sqrt(squared_distances)
    weights = th.exp(-distances / (2 * sigma ** 2))
    
    return weights
"""

def distance_weights(positions, sigma_scale = 25):
    """Pairwise slot weights that decay with distance, relative to slot size.

    Args:
        positions: Rank-3 tensor (batch, slots, channels). The last channel is
            used as each slot's sigma; the remaining channels are the
            coordinates. The coordinate at index 2 is appended a second time
            so that it counts double in the distance.
        sigma_scale: Factor applied to the product of the two slots' sigmas.

    Returns:
        (batch, slots, slots) tensor with entries
        ``exp(-d_ij / (2 * sigma_i * sigma_j * sigma_scale + 1e-5))`` where
        ``d_ij`` is the Euclidean distance between the extended coordinates.
    """
    # TODO tune this and the reactivater correlations !!
    _, _, _ = positions.shape  # only rank-3 input is accepted

    sigma  = positions[:, :, -1]
    coords = positions[:, :, :-1]

    # give more weight to z distance by duplicating that coordinate
    coords = th.cat((coords, coords[:, :, 2:3]), dim=2)

    # Broadcast (B, N, 1, C) against (B, 1, N, C) for all pairwise differences.
    delta     = coords.unsqueeze(2) - coords.unsqueeze(1)
    distances = delta.pow(2).sum(dim=-1).sqrt()

    # Pairwise "variance" from the two slots' sigmas.
    var = sigma.unsqueeze(2) * sigma.unsqueeze(1) * sigma_scale

    return th.exp(-distances / (2 * var + 1e-5))

def batch_covariance(slots):
    """Per-batch sample covariance between the columns of ``slots``.

    Args:
        slots: Tensor of shape (batch, samples, variables); dim 1 is the
            sample axis that is averaged over.

    Returns:
        (batch, variables, variables) covariance matrices, normalised by
        ``samples - 1``.
    """
    num_samples = slots.shape[1]
    centered = slots - slots.mean(dim=1, keepdim=True)
    scatter = th.bmm(centered.permute(0, 2, 1), centered)
    return scatter / (num_samples - 1)

def batch_correlation(slots):
    """Per-batch Pearson correlation between the columns of ``slots``.

    Args:
        slots: Tensor of shape (batch, samples, variables), passed straight to
            ``batch_covariance``.

    Returns:
        (batch, variables, variables) correlation matrices. Entries involving
        a variable with zero variance have no defined correlation and are
        returned as 0 instead of NaN, so one constant column cannot poison
        downstream sorting or averaging.
    """
    cov_matrix = batch_covariance(slots)
    variances  = th.diagonal(cov_matrix, dim1=-2, dim2=-1)

    # Outer product of the standard deviations: sqrt(var_i * var_j).
    std_matrix  = th.sqrt(variances[:, :, None] * variances[:, None, :])
    corr_matrix = cov_matrix / std_matrix

    # Only overwrite entries whose denominator is exactly zero; every other
    # value is left as the plain quotient.
    return th.where(std_matrix > 0, corr_matrix, th.zeros_like(corr_matrix))

def get_topk_indices(correlation_matrix, k):
    """Return, for every row, the indices of the ``k`` most correlated others.

    Note: the diagonal of ``correlation_matrix`` is overwritten with -1.0
    IN PLACE, so the caller's tensor is modified.

    Args:
        correlation_matrix: (batch, n, n) tensor.
        k: Number of indices to keep per row.

    Returns:
        (batch, n, k) index tensor, ordered by descending correlation.
    """
    n = correlation_matrix.size(-1)
    diag = th.arange(0, n)

    # Push self-correlation to the bottom so a slot never selects itself
    # (as long as k < n).
    correlation_matrix[:, diag, diag] = -1.0

    # Rank each row from highest to lowest and keep the leading k columns.
    ranking = th.argsort(correlation_matrix, dim=-1, descending=True)
    return ranking[:, :, :k]

def get_drop_mask(similarity_matrix):
    """Sample a random 0/1 mask from pairwise slot similarities.

    Negative similarities are clipped to zero, the global mean is subtracted
    (clipping at zero again) and the result is rescaled by ``1 - mean``. Only
    the strict upper triangle is kept, so a row is compared only against
    higher-indexed columns. The row-wise maximum is used as drop probability.

    Args:
        similarity_matrix: (batch, n, n) tensor.

    Returns:
        (batch, n) float tensor that is 1.0 where the drop probability is
        below a fresh uniform sample (the entry is kept) and 0.0 otherwise.
    """
    num_slots = similarity_matrix.shape[-1]

    positive = th.relu(similarity_matrix)
    baseline = th.mean(positive)
    rescaled = th.relu(positive - baseline) / (1 - baseline)

    # Strict upper triangle: triu keeps the diagonal, the eye term removes it.
    off_diagonal = 1 - th.eye(num_slots, device=similarity_matrix.device)
    upper        = th.triu(rescaled) * off_diagonal

    drop_probability = reduce(upper, 'b n m -> b n', 'max')
    keep = drop_probability < th.rand_like(drop_probability)
    return keep.float()

# NOTE(review): module-level tensor created at import time (draws from torch's
# global RNG). It is not referenced in the visible part of this file and looks
# like leftover debugging - confirm before removing.
slots = th.randn(1, 10, 5)

def save(cfg: Configuration, dataset: Dataset, checkpoint_path, active_layer, size, object_view, nice_view, individual_views, add_text):
    """Run a Loci model over the validation data and write visualisation images.

    Builds a ``LociDataModule`` and ``LociModule`` from ``cfg``, optionally
    loads weights from ``checkpoint_path`` and iterates the validation
    dataloader under ``th.no_grad()``. For every processed time step, when
    ``object_view`` or ``individual_views`` is set, a composite image
    ``loci-grid-<sequence>-<step>.jpg`` is written to the current working
    directory. With ``individual_views`` every panel is also written as its own
    JPEG into ``individual_images/``. At time step 5 the per-slot gestalt and
    position codes are pickled to ``latent-<sequence>-<slot>.pickle``.

    Side effects: seeds the numpy and torch RNGs with 1234, sets
    ``cfg.model.batch_size`` to 1, creates the ``individual_images`` directory
    and calls ``exit()`` when a new sequence starts after more than 100
    sequences have already been started.

    Args:
        cfg: Configuration; ``cfg.model``, ``cfg.device``,
            ``cfg.teacher_forcing``, ``cfg.backprop_steps`` and
            ``cfg.sequence_len`` are read here.
        dataset: Not used in this function.
        checkpoint_path: Optional checkpoint file. Paths ending in ``ckpt`` are
            loaded via ``LociModule.load_from_checkpoint``, anything else via
            ``th.load`` + ``load_state_dict``. If ``None`` or missing on disk
            the freshly constructed model is used.
        active_layer: Not used in this function.
        size: Indexed as ``size[0]`` (panel height) and ``size[1]`` (panel
            width) for the image layout; also passed to ``Gaus2D``.
        object_view: Enables writing the composite grid image.
        nice_view: Not used in this function.
        individual_views: Enables the grid image and the per-panel JPEGs.
        add_text: Forwarded to ``preprocess`` to caption the main panels.
    """

    np.random.seed(1234)
    th.manual_seed(1234)

    #assert(cfg.sequence_len == 2)
    cfg_net = cfg.model
    device = th.device(cfg.device)
    # Visualisation always runs one sample at a time.
    cfg_net.batch_size = 1

    data_module = LociDataModule(cfg)
    model       = LociModule(cfg)

    # Load the model from the checkpoint if provided, otherwise create a new model
    if checkpoint_path is not None and os.path.exists(checkpoint_path):

        # ends with ckpt
        if checkpoint_path[-4:] == 'ckpt':
            model = LociModule.load_from_checkpoint(checkpoint_path, cfg=cfg, strict=False)
        else:
            model.load_state_dict(th.load(checkpoint_path, map_location=device))

    net = model.net.to(device=device)
    net.eval()

    dataloader = data_module.val_dataloader()

    gaus2d = Gaus2D(size, position_limit=3.5).to(device)

    init = 1#net.get_init_status()
    print(f"Init status: {init}")
    # NOTE(review): mseloss, prioritize and last_input are created but never
    # used in this function.
    mseloss = nn.MSELoss()

    prioritize = Prioritize(cfg_net.num_objects).to(device)

    last_input = None
    last_rgb = None
    last_depth = None
    # NOTE(review): the loop below reads and writes `last_time_step` (singular);
    # this `last_time_steps` is never used. If the very first batch is not a
    # sequence start, `last_time_step` is undefined when it is read.
    last_time_steps = None
    output = None

    # i counts the sequences started so far.
    i = 0
    # For sequences longer than one frame: cfg.teacher_forcing rounded down to
    # a multiple of cfg.backprop_steps, plus one.
    teacher_forcing = ((cfg.teacher_forcing // cfg.backprop_steps) * cfg.backprop_steps + 1) if cfg.sequence_len > 1 else cfg.teacher_forcing

    with th.no_grad():
        for input in dataloader:

            # input[4] is not read here.
            tensor_rgb        = input[0].to(device)
            tensor_depth      = input[1].to(device) 
            time_steps        = input[2].to(device)
            use_depth         = input[3].to(device)
            instance_masks    = input[5].to(device)

            timestep = time_steps[0,0].item()

            # A batch whose first time step equals -teacher_forcing is treated
            # as the beginning of a new sequence.
            if timestep == -teacher_forcing:
                # NOTE(review): terminates the whole process, not just this function.
                if i > 100:
                    exit()
                net.reset_state()
                i += 1

                # Empty recurrent state for the first forward pass.
                output = { 
                    'reconstruction' : {
                        'object': None,
                        'depth_raw': None,
                        'mask': None,
                        'mask_raw': None,
                        'occlusion': None,
                        'position': None,
                        'gestalt': None,
                        'priority': None,
                        'output_depth': None,
                    },
                    'prediction' : {
                        'bg_rgb': None,
                        'bg_depth': None,
                        'output_depth': None,
                    }
                }

                # Foreground mask: pixels whose estimated uncertainty exceeds 0.8.
                bg_input        = th.cat((tensor_rgb[:,0], tensor_depth[:,0]), dim=1) if cfg.model.input_depth else tensor_rgb[:,0]
                uncertainty_cur = net.background.uncertainty_estimation(bg_input)[0]
                fg_mask         = (uncertainty_cur > 0.8).float()

                # Slot proposals from the first 16 instance masks of the first frame.
                results = net.proposal(instance_masks[:,0,:16], tensor_depth[:,0], tensor_rgb[:,0], fg_mask = fg_mask)

                seg_position = results['position']
                seg_mask     = results['mask']

                # sort by mask size
                seg_mask_sum = reduce(seg_mask, 'b o h w -> b o', 'sum')
                sorted_values, sorted_indices = th.sort(seg_mask_sum, dim=1, descending=True)

                # Using advanced indexing to sort the masks and positions
                sorted_seg_mask     = seg_mask[th.arange(seg_mask.size(0)).unsqueeze(1), sorted_indices]
                sorted_seg_position = seg_position[th.arange(seg_position.size(0)).unsqueeze(1), sorted_indices]

                # Keep only as many proposals as the model has slots.
                sorted_seg_position = sorted_seg_position[:,:cfg.model.num_objects]
                sorted_seg_mask     = sorted_seg_mask[:,:cfg.model.num_objects]

                # Seed the recurrent state. The extra last mask channel is
                # 1 - max over ALL proposal masks (before truncation).
                output['reconstruction']['position'] = rearrange(sorted_seg_position, 'b n c -> b (n c)')
                output['reconstruction']['mask']     = th.cat((sorted_seg_mask, 1 - reduce(results['mask'], 'b n h w -> b 1 h w', 'max')), dim=1)


            else:
                # Continuation of a sequence: prepend the last frame of the
                # previous batch so that every frame has a successor.
                tensor_rgb   = th.cat((last_rgb, tensor_rgb), dim=1)
                tensor_depth = th.cat((last_depth, tensor_depth), dim=1)
                time_steps   = th.cat((last_time_step, time_steps), dim=1)

            last_rgb       = tensor_rgb[:, -1:]
            last_depth     = tensor_depth[:, -1:]
            last_time_step = time_steps[:, -1:]

            for time_step in range(len(time_steps[0])-1):
                timestep = time_steps[0,time_step].item()
                # While the time step is negative (or for single-frame
                # sequences) frame 0 is used as both input and target.
                t = time_step if time_steps[0,time_step].item() >= 0 and cfg.sequence_len > 1 else 0

                input_rgb   = tensor_rgb[:, t]
                input_depth = tensor_depth[:, t]

                target_rgb   = tensor_rgb[:, t+1 if timestep >= 0 and cfg.sequence_len > 1 else 0]
                target_depth = tensor_depth[:, t+1 if timestep >= 0 and cfg.sequence_len > 1 else 0]

                # Feed back the prediction branch once the time step is
                # positive, the reconstruction branch before that.
                output_last = output['prediction'] if time_steps[0,time_step].item() > 0 else output['reconstruction']
                output = net(
                    input_rgb       = input_rgb,
                    input_depth     = input_depth if cfg.model.input_depth else output_last['output_depth'],
                    bg_rgb_last     = output['prediction']['bg_rgb'],
                    bg_depth_last   = output['prediction']['bg_depth'],
                    object_last     = output_last['object'],
                    depth_raw_last  = output_last['depth_raw'],
                    mask_last       = output_last['mask'],
                    mask_raw_last   = output_last['mask_raw'],
                    occlusion_last  = output_last['occlusion'],
                    position_last   = output_last['position'],
                    gestalt_last    = output_last['gestalt'],
                    priority_last   = output_last['priority'],
                    teacher_forcing = timestep < 0,
                    reset           = False,
                    detach          = False,
                    evaluate        = True, 
                    test            = False,
                )

                # Branch that is visualised for this step.
                output_next = output['prediction'] if time_steps[0,time_step].item() >= 0 else output['reconstruction']

                bg_depth_next        = output['prediction']['bg_depth']
                background_next      = output['prediction']['bg_rgb']
                output_rgb_next      = output_next['output_rgb']
                output_depth_next    = output_next['output_depth']
                mask_next            = output_next['mask']
                object_next          = output_next['object']
                depth_next           = output_next['depth']
                position_next        = output_next['position']
                gestalt_next         = output_next['gestalt']
                priority_next        = output_next['priority']
                uncertainty_cur      = output['reconstruction']['uncertainty']

                # At time step 5, dump each slot's rounded gestalt code and its
                # position to a pickle file in the working directory.
                if timestep == 5:
                    _gestalt  = rearrange(gestalt_next,  'b (n c) -> b n c', n = cfg_net.num_objects)
                    _position = rearrange(position_next, 'b (n c) -> b n c', n = cfg_net.num_objects)
                    for n in range(0, cfg_net.num_objects):
                        with open(f'latent-{i:04d}-{n:02d}.pickle', "wb") as outfile:
                            state = {
                                "gestalt":  th.round(th.clip(_gestalt[0:1,n], 0, 1)),
                                "position": _position[0:1,n],
                            }
                            pickle.dump(state, outfile)

                print(f'Saving[{timestep+teacher_forcing:3d}/{i+1}/{len(dataloader)}]: {(i*100) / len(dataloader):.3f}%')

                gestalt  = rearrange(output_next['gestalt'],  'b (o c) -> b c o', o = cfg_net.num_objects)
                position = rearrange(output_next['position'], 'b (o c) -> b c o', o = cfg_net.num_objects)
                # A slot counts as visible when its mask exceeds 0.75 somewhere.
                visible  = (reduce(output_next['mask'][:,:-1], 'b c h w -> b 1 c', 'max') > 0.75).float()

                # Invisible slots get a constant gestalt of 0.5 and a zero position.
                gestalt  = gestalt * visible  + 0.5 * (1 - visible)
                position = position * visible

                # Slot-to-slot correlation over (gestalt, last position
                # channel), weighted by the pairwise distance weights.
                #slot_corr    = batch_correlation(th.cat((gestalt, position[:,-1:]), dim=1))
                weights      = distance_weights(rearrange(position, 'b c o -> b o c'))
                slot_corr    = batch_correlation(th.cat((gestalt, position[:,-1:]), dim=1)) * weights 
                # NOTE: get_topk_indices overwrites the diagonal of slot_corr in place.
                topk_indices = get_topk_indices(slot_corr, 5)
                drop_mask    = get_drop_mask(slot_corr).unsqueeze(-1).unsqueeze(-1)

                # The sampled drop mask is discarded: every slot is kept.
                #drop_mask, slot_corr = RadomSimilarityBasedMaskDrop()(output_next['position'], output_next['gestalt'], output_next['mask'][:,:-1])
                #topk_indices = get_topk_indices(slot_corr, 5)
                drop_mask = th.ones_like(drop_mask)


                #print(((th.triu(slot_corr) * (1 - th.eye(cfg_net.num_objects, device=slot_corr.device))).detach().cpu().numpy() * 100).astype(int))

                #if timestep % 2 == 0:
                #    #mask_next = mask_next * th.cat((drop_mask, th.ones_like(drop_mask[:,:1])), dim=1).unsqueeze(-1).unsqueeze(-1)
                #    drop_mask = th.ones_like(drop_mask)

                # Target RGB as greyscale, dimmed to one third inside object
                # masks and overlaid with the slot colours at two thirds.
                highlited_target_rgb = target_rgb
                if mask_next is not None:
                    grayscale                 = target_rgb[:,0:1] * 0.299 + target_rgb[:,1:2] * 0.587 + target_rgb[:,2:3] * 0.114
                    object_mask_next          = th.sum(mask_next[:,:-1]*drop_mask, dim=1).unsqueeze(dim=1)
                    highlited_target_rgb  = grayscale * (1 - object_mask_next) 
                    highlited_target_rgb += grayscale * object_mask_next * 0.3333333 
                    highlited_target_rgb  = highlited_target_rgb + color_mask(mask_next[:,:-1]*drop_mask) * 0.6666666

                # Same overlay for the depth target.
                # NOTE(review): highlited_target_depth is not used further down.
                highlited_target_depth = target_depth if target_depth is not None else th.zeros_like(target_rgb)
                if mask_next is not None:
                    grayscale                 = target_depth if target_depth is not None else th.zeros_like(target_rgb)
                    object_mask_next          = th.sum(mask_next[:,:-1], dim=1).unsqueeze(dim=1)
                    highlited_target_depth  = grayscale * (1 - object_mask_next) 
                    highlited_target_depth += grayscale * object_mask_next * 0.3333333 
                    highlited_target_depth  = highlited_target_depth + color_mask(mask_next[:,:-1]) * 0.6666666

                # NOTE(review): xy_next is not used further down.
                xy_next         = rearrange(position_next, 'b (o c) -> (b o) c', o = cfg_net.num_objects)[:,:2]
                # Append one extra pseudo-slot (three zeros and a one) so the
                # background gets a position panel as well.
                position_next   = th.cat((position_next, th.zeros_like(position_next[:,:3]), th.ones_like(position_next[:,:1])), dim=1)
                position_next2d = rearrange(position_next, 'b (o c) -> (b o) c', o=cfg_net.num_objects+1)
                # Clamp the trailing channel(s) to [1 / min(latent_size), 1]
                # before rendering them with gaus2d.
                position_next2d = th.cat((position_next2d[:,:3], th.clamp(position_next2d[:,3:], 1 / min(cfg_net.latent_size), 1)), dim=1)
                position_next2d = gaus2d(position_next2d)
                position_next2d = rearrange(position_next2d, '(b o) 1 h w -> b o 1 h w', o=cfg_net.num_objects+1)

                # Treat the background as one more "object" for the grid.
                object_next = th.cat((object_next, background_next), dim=1)
                object_next = rearrange(object_next, 'b (o c) h w -> b o c h w', c = cfg_net.img_channels)

                depth_next = th.cat((depth_next, bg_depth_next), dim=1)
                depth_next = rearrange(depth_next, 'b (o 1) h w -> b o 1 h w')
                mask_next  = rearrange(mask_next, 'b (o 1) h w -> b o 1 h w')

                # NOTE(review): this recomputed object_mask_next is not used afterwards.
                object_mask_next = reduce(output_next['mask_raw'][:,:-1], 'b c h w -> b 1 h w', 'sum')

                output_dir = 'individual_images'
                if not os.path.exists(output_dir):
                    os.makedirs(output_dir)


                if object_view or individual_views:
                    # Layout: three main panels on top, nine rows of per-slot
                    # thumbnails, three main panels at the bottom; 18 px
                    # margins and 6 px gaps between thumbnails.
                    num_objects   = cfg_net.num_objects + 1
                    object_width  = int(np.ceil(((size[1] * 3 + 18) - (num_objects - 1) * 6) / num_objects))
                    object_height = int(np.ceil((object_width / size[1]) * size[0]))
                    padding = (object_width * num_objects + (num_objects - 1) * 6) - size[1] * 3
                    obj_size = (object_height, object_width)

                    width  = size[1] * 3 + 18*3 + padding
                    height = size[0] * 2 + 18*4 + 9*3 + object_height*9

                    # Dark grey canvas.
                    img = th.ones((3, height, width), device = object_next.device) * 0.2

                    # Process the main images
                    input_rgb_img = preprocess(input_rgb, size=size, add_text=add_text, text="RGB Input")[0]
                    input_depth_img = preprocess(input_depth if use_depth[0].item() == 1 else th.zeros_like(input_rgb), size=size, add_text=add_text, text="Depth Input")[0]
                    highlited_target_rgb_img = preprocess(highlited_target_rgb, size=size, add_text=add_text, text="Highlighted RGB target")[0]
                    output_rgb_img = preprocess(output_rgb_next, size=size, add_text=add_text, text="RGB Output")[0]
                    output_depth_img = preprocess(output_depth_next, size=size, add_text=add_text, text="Depth Output")[0]
                    uncertainty_img = preprocess(uncertainty_cur, size=size, add_text=add_text, text="Input Uncertainty")[0]

                    # Save the main images individually
                    if individual_views:
                        cv2.imwrite(os.path.join(output_dir, f'input_rgb-{i:04d}-{timestep+teacher_forcing:03d}.jpg'), input_rgb_img.cpu().numpy().transpose(1,2,0) * 255)
                        cv2.imwrite(os.path.join(output_dir, f'input_depth-{i:04d}-{timestep+teacher_forcing:03d}.jpg'), input_depth_img.cpu().numpy().transpose(1,2,0) * 255)
                        cv2.imwrite(os.path.join(output_dir, f'highlited_target_rgb-{i:04d}-{timestep+teacher_forcing:03d}.jpg'), highlited_target_rgb_img.cpu().numpy().transpose(1,2,0) * 255)
                        cv2.imwrite(os.path.join(output_dir, f'output_rgb-{i:04d}-{timestep+teacher_forcing:03d}.jpg'), output_rgb_img.cpu().numpy().transpose(1,2,0) * 255)
                        cv2.imwrite(os.path.join(output_dir, f'output_depth-{i:04d}-{timestep+teacher_forcing:03d}.jpg'), output_depth_img.cpu().numpy().transpose(1,2,0) * 255)
                        cv2.imwrite(os.path.join(output_dir, f'uncertainty-{i:04d}-{timestep+teacher_forcing:03d}.jpg'), uncertainty_img.cpu().numpy().transpose(1,2,0) * 255)

                    # Add the images to the main img as before
                    # Top row: inputs and highlighted target. Bottom row: outputs and uncertainty.
                    img[:, 18:size[0]+18, 18:size[1]+18] = input_rgb_img
                    img[:, 18:size[0]+18, 18*2+size[1]:18*2+size[1]*2] = input_depth_img
                    img[:, 18:size[0]+18, 18*3+size[1]*2:18*3+size[1]*3] = highlited_target_rgb_img
                    img[:, -size[0]-18:-18, 18:size[1]+18] = output_rgb_img
                    img[:, -size[0]-18:-18, 18*2+size[1]:18*2+size[1]*2] = output_depth_img
                    img[:, -size[0]-18:-18, 18*3+size[1]*2:18*3+size[1]*3] = uncertainty_img

                    # One column per slot; the last column is the background.
                    for o in range(num_objects):
                        # Process the images
                        object_img = preprocess(object_next[:,o])
                        depth_img = preprocess(depth_next[:,o])
                        mask_img = to_rgb_object(preprocess(mask_next[:,o]), o)
                        position_img = to_rgb_object(preprocess(position_next2d[:,o]), o)

                        # Masks of the five most correlated other slots, tinted
                        # in those slots' colours; black for the background column.
                        mask_img_top1 = to_rgb_object(preprocess(mask_next[:,topk_indices[0,o,0]]), topk_indices[0,o,0]) if o < cfg_net.num_objects else th.zeros_like(mask_img)
                        mask_img_top2 = to_rgb_object(preprocess(mask_next[:,topk_indices[0,o,1]]), topk_indices[0,o,1]) if o < cfg_net.num_objects else th.zeros_like(mask_img)
                        mask_img_top3 = to_rgb_object(preprocess(mask_next[:,topk_indices[0,o,2]]), topk_indices[0,o,2]) if o < cfg_net.num_objects else th.zeros_like(mask_img)
                        mask_img_top4 = to_rgb_object(preprocess(mask_next[:,topk_indices[0,o,3]]), topk_indices[0,o,3]) if o < cfg_net.num_objects else th.zeros_like(mask_img)
                        mask_img_top5 = to_rgb_object(preprocess(mask_next[:,topk_indices[0,o,4]]), topk_indices[0,o,4]) if o < cfg_net.num_objects else th.zeros_like(mask_img)

                        # Save the individual images
                        if individual_views:
                            cv2.imwrite(os.path.join(output_dir, f'object-{i:04d}-{timestep+teacher_forcing:03d}-obj{o:02d}.jpg'), object_img[0].cpu().numpy().transpose(1,2,0) * 255)
                            cv2.imwrite(os.path.join(output_dir, f'depth-{i:04d}-{timestep+teacher_forcing:03d}-obj{o:02d}.jpg'), depth_img[0].cpu().numpy().transpose(1,2,0) * 255)
                            cv2.imwrite(os.path.join(output_dir, f'mask-{i:04d}-{timestep+teacher_forcing:03d}-obj{o:02d}.jpg'), mask_img[0].cpu().numpy().transpose(1,2,0) * 255)
                            cv2.imwrite(os.path.join(output_dir, f'position-{i:04d}-{timestep+teacher_forcing:03d}-obj{o:02d}.jpg'), position_img[0].cpu().numpy().transpose(1,2,0) * 255)

                        # resize the images
                        # The top-k masks are additionally scaled by their
                        # correlation with slot o, clipped at zero.
                        object_img = preprocess(object_img, size=obj_size)[0]
                        depth_img = preprocess(depth_img, size=obj_size)[0]
                        mask_img = preprocess(mask_img, size=obj_size)[0]
                        mask_img_top1 = preprocess(mask_img_top1, size=obj_size)[0] * max(0, slot_corr[0, o, topk_indices[0,o,0]] if o < cfg_net.num_objects else 0)
                        mask_img_top2 = preprocess(mask_img_top2, size=obj_size)[0] * max(0, slot_corr[0, o, topk_indices[0,o,1]] if o < cfg_net.num_objects else 0)
                        mask_img_top3 = preprocess(mask_img_top3, size=obj_size)[0] * max(0, slot_corr[0, o, topk_indices[0,o,2]] if o < cfg_net.num_objects else 0)
                        mask_img_top4 = preprocess(mask_img_top4, size=obj_size)[0] * max(0, slot_corr[0, o, topk_indices[0,o,3]] if o < cfg_net.num_objects else 0)
                        mask_img_top5 = preprocess(mask_img_top5, size=obj_size)[0] * max(0, slot_corr[0, o, topk_indices[0,o,4]] if o < cfg_net.num_objects else 0)
                        position_img = preprocess(position_img, size=obj_size)[0]

                        # Add the images to the main img as before
                        # Rows from top: object, depth, mask, top-1..top-5 masks, position.
                        img[:,size[0]+18*2                    :size[0]+18*2+      object_height,18+(object_width+6)*o:18+(object_width+6)*o+object_width] = object_img
                        img[:,size[0]+18*2+   6+object_height :size[0]+18*2+  6+2*object_height,18+(object_width+6)*o:18+(object_width+6)*o+object_width] = depth_img
                        img[:,size[0]+18*2+2*(6+object_height):size[0]+18*2+2*6+3*object_height,18+(object_width+6)*o:18+(object_width+6)*o+object_width] = mask_img
                        img[:,size[0]+18*2+3*(6+object_height):size[0]+18*2+3*6+4*object_height,18+(object_width+6)*o:18+(object_width+6)*o+object_width] = mask_img_top1 
                        img[:,size[0]+18*2+4*(6+object_height):size[0]+18*2+4*6+5*object_height,18+(object_width+6)*o:18+(object_width+6)*o+object_width] = mask_img_top2
                        img[:,size[0]+18*2+5*(6+object_height):size[0]+18*2+5*6+6*object_height,18+(object_width+6)*o:18+(object_width+6)*o+object_width] = mask_img_top3
                        img[:,size[0]+18*2+6*(6+object_height):size[0]+18*2+6*6+7*object_height,18+(object_width+6)*o:18+(object_width+6)*o+object_width] = mask_img_top4
                        img[:,size[0]+18*2+7*(6+object_height):size[0]+18*2+7*6+8*object_height,18+(object_width+6)*o:18+(object_width+6)*o+object_width] = mask_img_top5
                        img[:,size[0]+18*2+8*(6+object_height):size[0]+18*2+8*6+9*object_height,18+(object_width+6)*o:18+(object_width+6)*o+object_width] = position_img

                    # Write the composite grid into the current working directory.
                    img = rearrange(img * 255, 'c h w -> h w c').cpu().numpy()
                    cv2.imwrite(f'loci-grid-{i:04d}-{timestep+teacher_forcing:03d}.jpg', img)



def save_bg(cfg: Configuration, dataset: Dataset, file, active_layer, size, vertical_images, add_text, individual_views):

    np.random.seed(1234)
    th.manual_seed(1234)

    #assert(cfg.sequence_len == 2)
    cfg_net = cfg.model
    device = th.device(cfg.device)
    cfg_net.batch_size = 1

    os.makedirs(f"out/{cfg.model_path}", exist_ok=True)

    data_module = LociBackgroundDataModule(cfg)
    dataloader  = data_module.val_dataloader()

    if file != '':
        model = LociBackgroundModule.load_from_checkpoint(file, cfg=cfg, strict=False).to(device)
    else:
        model = LociBackgroundModule(cfg).to(device)

    # create model 
    net = model.net
    net.eval()

    ssim = SSIM()
    
    with th.no_grad():
        for i, batch in enumerate(dataloader):

            batch = [b.to(device) for b in batch]

            source_rgb, source_depth, source_fg_mask, target_rgb, target_depth, target_fg_mask, use_depth, use_fg_masks, delta_t, _, _, _, _, _, input_mode = batch

            source = th.cat([source_rgb, source_depth], dim=1) if cfg.model.input_depth else source_rgb
            target = th.cat([target_rgb, target_depth], dim=1) if cfg.model.input_depth else target_rgb

            source_uncertainty, source_uncertainty_noised = net.uncertainty_estimation(source)
            target_uncertainty, target_uncertainty_noised = net.uncertainty_estimation(target)

            output_rgb, output_depth, motion_context, depth_context = net(
                #source_rgb, target_rgb, th.rand_like(source_uncertainty.detach())*0.01, th.rand_like(target_uncertainty.detach())*0.01, delta_t
                source, target, source_uncertainty.detach(), target_uncertainty.detach(), delta_t, input_mode.view(-1, 1, 1, 1)
            )

            print(f'Saving[{i+1}/{len(dataloader)}|{i+1/len(dataloader)*100:.2f}%]')

            grayscale             = target_rgb[:,0:1] * 0.299 + target_rgb[:,1:2] * 0.587 + target_rgb[:,2:3] * 0.114
            target_rgb_highlited  = grayscale * (1 - target_uncertainty) 
            target_rgb_highlited += grayscale * target_uncertainty * 0.3333333 
            target_rgb_highlited  = target_rgb_highlited + to_rgb_object(target_uncertainty, 5) * 0.6666666

            depth_error      = th.abs(target_depth - output_depth)
            depth_error_mean = th.mean(depth_error, dim=(2,3), keepdim=True)
            depth_error_std  = th.std(depth_error, dim=(2,3), keepdim=True)
            depth_error_mask = (depth_error > depth_error_mean + depth_error_std * 2).float()

            rgb_error      = th.mean(th.abs(target_rgb - output_rgb), dim=1, keepdim=True)
            rgb_error_mean = th.mean(rgb_error, dim=(2,3), keepdim=True)
            rgb_error_std  = th.std(rgb_error, dim=(2,3), keepdim=True)
            rgb_error_mask = (rgb_error > rgb_error_mean + rgb_error_std * 2).float()


            if vertical_images:
                width  = size[1] * 2 + 18 * 3
                height = size[0] * 5 + 18 * 6

                img = th.ones((3, height, width), device=device) * 0.2
                img[:, 18:size[0]+18, 18:size[1]+18] = preprocess(source_rgb, size=size, add_text=add_text, text="RGB Input")[0]
                img[:, 18*2+size[0]:18*2+size[0]*2, 18:size[1]+18] = preprocess(source_depth, size=size, add_text=add_text, text="Depth Input")[0]
                img[:, 18*3+size[0]*2:18*3+size[0]*3, 18:size[1]+18] = preprocess(target_rgb_highlited, size=size, add_text=add_text, text="Uncertainty masked Input")[0]
                img[:, 18*4+size[0]*3:18*4+size[0]*4, 18:size[1]+18] = preprocess(depth_error_mask, size=size, add_text=add_text, text="Depth Error Mask")[0]
                img[:, 18*5+size[0]*4:18*5+size[0]*5, 18:size[1]+18] = preprocess(rgb_error_mask, size=size, add_text=add_text, text="RGB Error Mask")[0]

                img[:, 18:size[0]+18, -size[1]-18:-18] = preprocess(output_rgb, size=size, add_text=add_text, text="RGB Output")[0]
                img[:, 18*2+size[0]:18*2+size[0]*2, -size[1]-18:-18] = preprocess(output_depth, size=size, add_text=add_text, text="Depth Output")[0]
                img[:, 18*3+size[0]*2:18*3+size[0]*3, -size[1]-18:-18] = preprocess(source_uncertainty, size=size, add_text=add_text, text="Uncertainty Output")[0]
                img[:, 18*4+size[0]*3:18*4+size[0]*4, -size[1]-18:-18] = preprocess(depth_error, mean_std_normalize=True, size=size, add_text=add_text, text="Depth Error")[0]
                img[:, 18*5+size[0]*4:18*5+size[0]*5, -size[1]-18:-18] = preprocess(rgb_error, mean_std_normalize=True, size=size, add_text=add_text, text="RGB Error")[0]

                img = rearrange(img * 255, 'c h w -> h w c').cpu().numpy()
                cv2.imwrite(f'background-grid-{i:04d}.jpg', img)
            else:
                width  = size[1] * 5 + 18*6
                height = size[0] * 2 + 18*3

                img = th.ones((3, height, width), device = device) * 0.2
                img[:,18:size[0]+18, 18:size[1]+18]                 = preprocess(source_rgb, size=size, add_text=add_text, text="RGB Input")[0]
                img[:,18:size[0]+18, 18*2+size[1]:18*2+size[1]*2]   = preprocess(source_depth, size = size, add_text=add_text, text="Depth Input")[0]
                img[:,18:size[0]+18, 18*3+size[1]*2:18*3+size[1]*3] = preprocess(target_rgb_highlited, size = size, add_text=add_text, text="Uncertainty masked Input")[0]
                img[:,18:size[0]+18, 18*4+size[1]*3:18*4+size[1]*4] = preprocess(depth_error_mask, size = size, add_text=add_text, text="Depth Error Mask")[0]
                img[:,18:size[0]+18, 18*5+size[1]*4:18*5+size[1]*5] = preprocess(rgb_error_mask, size = size, add_text=add_text, text="RGB Error Mask")[0]


                img[:,-size[0]-18:-18, 18:size[1]+18]                 = preprocess(output_rgb, size=size, add_text=add_text, text="RGB Output")[0]
                img[:,-size[0]-18:-18, 18*2+size[1]:18*2+size[1]*2]   = preprocess(output_depth, size=size, add_text=add_text, text="Depth Output")[0]
                img[:,-size[0]-18:-18, 18*3+size[1]*2:18*3+size[1]*3] = preprocess(source_uncertainty, size = size, add_text=add_text, text="Uncertainty Output")[0]
                img[:,-size[0]-18:-18, 18*4+size[1]*3:18*4+size[1]*4] = preprocess(depth_error, mean_std_normalize=True, size = size, add_text=add_text, text="Depth Error")[0]
                img[:,-size[0]-18:-18, 18*5+size[1]*4:18*5+size[1]*5] = preprocess(rgb_error, mean_std_normalize=True, size = size, add_text=add_text, text="RGB Error")[0]


                img = rearrange(img * 255, 'c h w -> h w c').cpu().numpy()
                cv2.imwrite(f'background-grid-{i:04d}.jpg', img)

            if individual_views:
                cv2.imwrite(f'background-input-rgb-{i:04d}-rgb.jpg', rearrange(preprocess(source_rgb, size=size, add_text=add_text, text="RGB Input")[0] * 255, 'c h w -> h w c').cpu().numpy())
                cv2.imwrite(f'background-input-depth-{i:04d}-{t+3:03d}-depth.jpg', rearrange(preprocess(source_depth, size=size, add_text=add_text, text="Depth Input")[0] * 255, 'c h w -> h w c').cpu().numpy())
                cv2.imwrite(f'background-input-rgb-masked-{i:04d}-{t+3:03d}-rgb.jpg', rearrange(preprocess(target_rgb_highlited, size=size, add_text=add_text, text="Uncertainty masked Input")[0] * 255, 'c h w -> h w c').cpu().numpy())

                cv2.imwrite(f'background-output-rgb-{i:04d}-{t+3:03d}-rgb.jpg', rearrange(preprocess(output_rgb, size=size, add_text=add_text, text="RGB Output")[0] * 255, 'c h w -> h w c').cpu().numpy())
                cv2.imwrite(f'background-output-depth-{i:04d}-{t+3:03d}-depth.jpg', rearrange(preprocess(output_depth, size=size, add_text=add_text, text="Depth Output")[0] * 255, 'c h w -> h w c').cpu().numpy())
                cv2.imwrite(f'background-output-uncertainty-{i:04d}-{t+3:03d}-uncertainty.jpg', rearrange(preprocess(source_uncertainty, size=size, add_text=add_text, text="Uncertainty Output")[0] * 255, 'c h w -> h w c').cpu().numpy())





def save_objects(cfg: Configuration, dataset: Dataset, file, active_layer, size, vertical_images, add_text, individual_views, mask = False, export_latent = False, input_mask = False):
    """Run the object pretrainer on the validation set and write visualisations.

    For every validation sample a grid image ``object-grid-XXXX.jpg`` is
    written to the current working directory. The grid shows the three inputs
    (RGB, depth, ground-truth mask highlighted on a grayscale copy of the RGB
    input) and the three network outputs (RGB, depth, mask).

    Args:
        cfg: Experiment configuration. ``cfg.model``, ``cfg.device``,
            ``cfg.model_path``, ``cfg.pretrainer_iterations`` and
            ``cfg.pretraining_mode`` are read; ``cfg.model.batch_size`` is
            overwritten with 1.
        dataset: Unused; the data comes from ``LociPretrainerDataModule``.
        file: Checkpoint path. An empty string builds an untrained module.
        active_layer: Unused.
        size: ``(height, width)`` of a single panel in pixels.
        vertical_images: If true the grid has 3 rows x 2 columns (inputs in
            the left column, outputs in the right one); otherwise it has
            2 rows x 3 columns (inputs on top, outputs below).
        add_text: NOTE(review): currently ignored, captions are always drawn.
            Kept that way so existing outputs do not change.
        individual_views: Additionally write every panel as its own image.
        mask: Multiply the RGB/depth outputs in the grid with the predicted mask.
        export_latent: Pickle ``results['gestalt']`` and ``results['position']``
            to ``out/<cfg.model_path>/latent-states-XXXX.pkl``.
        input_mask: Multiply the RGB/depth outputs in the grid with the
            ground-truth mask (only consulted when ``mask`` is false).
    """

    np.random.seed(1234)
    th.manual_seed(1234)

    cfg_net = cfg.model
    device = th.device(cfg.device)
    cfg_net.batch_size = 1

    os.makedirs(f"out/{cfg.model_path}", exist_ok=True)

    data_module = LociPretrainerDataModule(cfg)
    dataloader  = data_module.val_dataloader()

    if file != '':
        model = LociPretrainerModule.load_from_checkpoint(file, cfg=cfg, strict=False).to(device)
    else:
        model = LociPretrainerModule(cfg).to(device)

    net = model.net
    net.eval()

    mask_center = MaskCenter(cfg_net.crop_size).to(device)
    gaus2d      = Gaus2D(cfg_net.crop_size).to(device)

    # Grid geometry. The canvas must match the layout that is pasted into it:
    # the vertical layout is 3 rows x 2 columns, the horizontal one 2 x 3.
    # (The vertical canvas used to be allocated as 2 x 3, which made the third
    # row an empty slice and the paste fail with a shape mismatch.)
    pad = 18
    tile_h, tile_w = size[0], size[1]
    n_rows, n_cols = (3, 2) if vertical_images else (2, 3)
    grid_h = tile_h * n_rows + pad * (n_rows + 1)
    grid_w = tile_w * n_cols + pad * (n_cols + 1)

    def panel(tensor, text):
        # Resize a batch to the panel size, caption it, and take its first sample.
        return preprocess(tensor, size=size, add_text=True, text=text)[0]

    def as_cv_image(image):
        # (c, h, w) in [0, 1] -> (h, w, c) numpy array scaled to [0, 255].
        return rearrange(image * 255, 'c h w -> h w c').cpu().numpy()

    def paste(canvas, row, col, tensor, text):
        # Write one panel into grid cell (row, col) of the canvas.
        y0 = pad * (row + 1) + tile_h * row
        x0 = pad * (col + 1) + tile_w * col
        canvas[:, y0:y0 + tile_h, x0:x0 + tile_w] = panel(tensor, text)

    with th.no_grad():
        for i, batch in enumerate(dataloader):

            input_rgb           = batch[0].to(device)
            input_depth         = batch[1].to(device)
            input_instance_mask = batch[2].to(device)

            results = net(input_rgb, input_depth, input_instance_mask, iterations=cfg.pretrainer_iterations, mode=cfg.pretraining_mode)

            if export_latent:
                gestalt  = results['gestalt'].cpu().numpy()
                position = results['position'].cpu().numpy()

                # save using pickle
                with open(f'out/{cfg.model_path}/latent-states-{i:04d}.pkl', 'wb') as f:
                    pickle.dump({'gestalt': gestalt, 'position': position}, f)

            print(f'Saving[{(i+1)*100/len(dataloader):.2f}%/{i+1}/{len(dataloader)}]')

            xy_std = th.cat(mask_center(input_instance_mask), dim=1)
            pos2d  = gaus2d(xy_std)

            # Grayscale copy of the input, dimmed inside the ground-truth mask
            # and overlaid with the coloured mask and position blob.
            grayscale            = input_rgb[:,0:1] * 0.299 + input_rgb[:,1:2] * 0.587 + input_rgb[:,2:3] * 0.114
            highlited_input_rgb  = grayscale * (1 - input_instance_mask)
            highlited_input_rgb += grayscale * input_instance_mask * 0.3333333
            highlited_input_rgb  = highlited_input_rgb + to_rgb_object(input_instance_mask, 5) * 0.333333 + to_rgb_object(pos2d, 1) * 0.333333

            norm_depth = th.sigmoid(results['depth'])

            # Optionally restrict the displayed outputs to a mask.
            if mask:
                output_rgb   = results['object'] * results['mask']
                output_depth = norm_depth * results['mask']
            elif input_mask:
                output_rgb   = results['object'] * input_instance_mask
                output_depth = norm_depth * input_instance_mask
            else:
                output_rgb   = results['object']
                output_depth = norm_depth

            inputs = (
                (input_rgb, "RGB Input"),
                (input_depth, "Depth Input"),
                (highlited_input_rgb, "GT Masked Input"),
            )
            outputs = (
                (output_rgb, "RGB Output"),
                (output_depth, "Depth Output"),
                (results['mask'], "Mask Output"),
            )

            img = th.ones((3, grid_h, grid_w), device=device) * 0.2
            for k, (tensor, text) in enumerate(inputs):
                row, col = (k, 0) if vertical_images else (0, k)
                paste(img, row, col, tensor, text)
            for k, (tensor, text) in enumerate(outputs):
                row, col = (k, 1) if vertical_images else (1, k)
                paste(img, row, col, tensor, text)

            # Both layouts share one file name; the vertical branch used to
            # reference an undefined time index `t` here.
            cv2.imwrite(f'object-grid-{i:04d}.jpg', as_cv_image(img))

            if individual_views:
                cv2.imwrite(f'object-input-rgb-{i:04d}-rgb.jpg', as_cv_image(panel(input_rgb, "RGB Input")))
                cv2.imwrite(f'object-input-depth-{i:04d}-depth.jpg', as_cv_image(panel(input_depth, "Depth Input")))
                cv2.imwrite(f'object-input-masked-{i:04d}-masked.jpg', as_cv_image(panel(highlited_input_rgb, "GT Masked Input")))
                cv2.imwrite(f'object-output-rgb-{i:04d}-rgb.jpg', as_cv_image(panel(results['object'] * results['mask'], "RGB Output")))
                # Use the sigmoid-normalised depth, as the grid does; the raw
                # network output was written here before.
                cv2.imwrite(f'object-output-depth-{i:04d}-depth.jpg', as_cv_image(panel(norm_depth * results['mask'], "Depth Output")))
                cv2.imwrite(f'object-output-mask-{i:04d}-mask.jpg', as_cv_image(panel(results['mask'], "Mask Output")))


def save_masks(cfg: Configuration, dataset: Dataset, file, active_layer, size, vertical_images, add_text, individual_views, mask = False, export_latent = False):
    """Write a 2x2 mask-pretraining grid (``mask-grid-XXXX.jpg``) per validation sample.

    Grid layout: ground-truth mask (top left), predicted mask (bottom left),
    normalised absolute error (top right), position blob (bottom right).

    ``cfg.pretraining_mode`` is set to ``"mask"`` and ``cfg.model.batch_size``
    to 1. ``size`` is the ``(height, width)`` of one panel. When
    ``export_latent`` is true, ``results['gestalt']`` and
    ``results['position']`` are pickled below ``out/<cfg.model_path>/``.

    ``dataset``, ``active_layer``, ``vertical_images``, ``add_text``,
    ``individual_views`` and ``mask`` are accepted but not used here; captions
    are always drawn.
    """

    np.random.seed(1234)
    th.manual_seed(1234)

    cfg.pretraining_mode = "mask"
    net_cfg = cfg.model
    device  = th.device(cfg.device)
    net_cfg.batch_size = 1

    os.makedirs(f"out/{cfg.model_path}", exist_ok=True)

    dataloader = LociPretrainerDataModule(cfg).val_dataloader()

    if file == '':
        model = LociPretrainerModule(cfg).to(device)
    else:
        model = LociPretrainerModule.load_from_checkpoint(file, cfg=cfg, strict=False).to(device)

    model.net.eval()

    # Instantiated as in the original code; the loop below uses the mask
    # center module owned by the network instead.
    mask_center = MaskCenter(net_cfg.crop_size).to(device)
    gaus2d      = Gaus2D(net_cfg.crop_size).to(device)

    # Panel geometry: 18 px border around and between the four panels.
    border = 18
    tile_h, tile_w = size[0], size[1]
    canvas_h = tile_h * 2 + border * 3
    canvas_w = tile_w * 2 + border * 3

    upper = slice(border, tile_h + border)
    lower = slice(border * 2 + tile_h, border * 2 + tile_h * 2)
    left  = slice(border, tile_w + border)
    right = slice(-tile_w - border, -border)

    with th.no_grad():
        for i, batch in enumerate(dataloader):

            input_rgb, input_depth, input_instance_mask = (batch[k].to(device) for k in range(3))

            results = model(input_rgb, input_depth, input_instance_mask)

            if export_latent:
                latent = {
                    'gestalt':  results['gestalt'].cpu().numpy(),
                    'position': results['position'].cpu().numpy(),
                }
                with open(f'out/{cfg.model_path}/latent-states-{i:04d}.pkl', 'wb') as f:
                    pickle.dump(latent, f)

            print(f'Saving[{(i+1)*100/len(dataloader):.2f}%/{i+1}/{len(dataloader)}]')

            pos2d = gaus2d(model.net.mask_pretrainer.mask_center(input_instance_mask), compute_std=False)

            canvas = th.ones((3, canvas_h, canvas_w), device=device) * 0.2

            def put(rows, cols, tensor, text, **kwargs):
                # Resize + caption the batch and paste its first sample.
                canvas[:, rows, cols] = preprocess(tensor, size=size, add_text=True, text=text, **kwargs)[0]

            predicted = results['mask'][:,0:1]

            put(upper, left, input_instance_mask, "GT Mask")
            put(lower, left, predicted, "Mask Output")
            put(upper, right, th.abs(predicted - input_instance_mask), "Error", normalize=True)
            put(lower, right, pos2d, "Position")

            cv2.imwrite(f'mask-grid-{i:04d}.jpg', rearrange(canvas * 255, 'c h w -> h w c').cpu().numpy())


def save_depth(cfg: Configuration, dataset: Dataset, file, active_layer, size, vertical_images, add_text, individual_views, mask = False, export_latent = False):
    """Write a 2x2 depth-pretraining grid (``depth-grid-XXXX.jpg``) per validation sample.

    The input depth is standardised with the mean/std taken over the
    ground-truth instance mask, squashed with a sigmoid and masked. Grid
    layout: that normalised input depth (top left), masked sigmoid of the
    predicted depth (bottom left), normalised absolute error between the two
    (top right), position blob (bottom right).

    ``cfg.pretraining_mode`` is set to ``"depth"`` and ``cfg.model.batch_size``
    to 1. ``size`` is the ``(height, width)`` of one panel. When
    ``export_latent`` is true, ``results['gestalt']`` and
    ``results['position']`` are pickled below ``out/<cfg.model_path>/``.

    ``dataset``, ``active_layer``, ``vertical_images``, ``add_text``,
    ``individual_views`` and ``mask`` are accepted but not used here; captions
    are always drawn.
    """

    np.random.seed(1234)
    th.manual_seed(1234)

    cfg.pretraining_mode = "depth"
    net_cfg = cfg.model
    device  = th.device(cfg.device)
    net_cfg.batch_size = 1

    os.makedirs(f"out/{cfg.model_path}", exist_ok=True)

    dataloader = LociPretrainerDataModule(cfg).val_dataloader()

    if file == '':
        model = LociPretrainerModule(cfg).to(device)
    else:
        model = LociPretrainerModule.load_from_checkpoint(file, cfg=cfg, strict=False).to(device)

    model.net.eval()

    # Instantiated as in the original code; the loop below uses the mask
    # center module owned by the network instead.
    mask_center = MaskCenter(net_cfg.crop_size).to(device)
    gaus2d      = Gaus2D(net_cfg.crop_size).to(device)

    # Panel geometry: 18 px border around and between the four panels.
    border = 18
    tile_h, tile_w = size[0], size[1]
    canvas_h = tile_h * 2 + border * 3
    canvas_w = tile_w * 2 + border * 3

    upper = slice(border, tile_h + border)
    lower = slice(border * 2 + tile_h, border * 2 + tile_h * 2)
    left  = slice(border, tile_w + border)
    right = slice(-tile_w - border, -border)

    with th.no_grad():
        for i, batch in enumerate(dataloader):

            input_rgb, input_depth, input_instance_mask = (batch[k].to(device) for k in range(3))

            results = model(input_rgb, input_depth, input_instance_mask)

            # Mask-weighted mean and standard deviation of the input depth;
            # the epsilon guards against an empty mask.
            mask_area  = th.sum(input_instance_mask, dim=(1,2,3), keepdim=True) + 1e-6
            depth_mean = th.sum(input_depth * input_instance_mask, dim=(1,2,3), keepdim=True) / mask_area
            depth_var  = th.sum((input_depth - depth_mean)**2 * input_instance_mask, dim=(1,2,3), keepdim=True) / mask_area
            depth_std  = th.sqrt(depth_var)

            standardised = (input_depth - depth_mean) / (depth_std + 1e-6)
            input_depth  = th.sigmoid(standardised * input_instance_mask) * input_instance_mask

            if export_latent:
                latent = {
                    'gestalt':  results['gestalt'].cpu().numpy(),
                    'position': results['position'].cpu().numpy(),
                }
                with open(f'out/{cfg.model_path}/latent-states-{i:04d}.pkl', 'wb') as f:
                    pickle.dump(latent, f)

            print(f'Saving[{(i+1)*100/len(dataloader):.2f}%/{i+1}/{len(dataloader)}]')

            pos2d = gaus2d(model.net.mask_pretrainer.mask_center(input_instance_mask), compute_std=False)

            canvas = th.ones((3, canvas_h, canvas_w), device=device) * 0.2

            def put(rows, cols, tensor, text, **kwargs):
                # Resize + caption the batch and paste its first sample.
                canvas[:, rows, cols] = preprocess(tensor, size=size, add_text=True, text=text, **kwargs)[0]

            predicted = th.sigmoid(results['depth']) * input_instance_mask

            put(upper, left, input_depth, "Input Depth")
            put(lower, left, predicted, "Depth Output")
            put(upper, right, th.abs(predicted - input_depth), "Error", normalize=True)
            put(lower, right, pos2d, "Position")

            cv2.imwrite(f'depth-grid-{i:04d}.jpg', rearrange(canvas * 255, 'c h w -> h w c').cpu().numpy())


def save_rgb(cfg: Configuration, dataset: Dataset, file, active_layer, size, vertical_images, add_text, individual_views, mask = False, export_latent = False):
    """Write a 2x2 RGB-pretraining grid (``rgb-grid-XXXX.jpg``) per validation sample.

    Grid layout: masked RGB input (top left), mask-normalised input depth
    (bottom left), normalised masked absolute RGB error (top right), masked
    RGB output (bottom right).

    ``cfg.pretraining_mode`` is set to ``"rgb"`` and ``cfg.model.batch_size``
    to 1. ``size`` is the ``(height, width)`` of one panel. When
    ``export_latent`` is true, ``results['gestalt']`` and
    ``results['position']`` are pickled below ``out/<cfg.model_path>/``.

    ``dataset``, ``active_layer``, ``vertical_images``, ``add_text``,
    ``individual_views`` and ``mask`` are accepted but not used here; captions
    are always drawn.
    """

    np.random.seed(1234)
    th.manual_seed(1234)

    cfg.pretraining_mode = "rgb"
    net_cfg = cfg.model
    device  = th.device(cfg.device)
    net_cfg.batch_size = 1

    os.makedirs(f"out/{cfg.model_path}", exist_ok=True)

    dataloader = LociPretrainerDataModule(cfg).val_dataloader()

    if file == '':
        model = LociPretrainerModule(cfg).to(device)
    else:
        model = LociPretrainerModule.load_from_checkpoint(file, cfg=cfg, strict=False).to(device)

    model.net.eval()

    # Instantiated as in the original code; the loop below uses the mask
    # center module owned by the network instead.
    mask_center = MaskCenter(net_cfg.crop_size).to(device)
    gaus2d      = Gaus2D(net_cfg.crop_size).to(device)

    # Panel geometry: 18 px border around and between the four panels.
    border = 18
    tile_h, tile_w = size[0], size[1]
    canvas_h = tile_h * 2 + border * 3
    canvas_w = tile_w * 2 + border * 3

    upper = slice(border, tile_h + border)
    lower = slice(border * 2 + tile_h, border * 2 + tile_h * 2)
    left  = slice(border, tile_w + border)
    right = slice(-tile_w - border, -border)

    with th.no_grad():
        for i, batch in enumerate(dataloader):

            input_rgb, input_depth, input_instance_mask = (batch[k].to(device) for k in range(3))

            results = model(input_rgb, input_depth, input_instance_mask)

            # Mask-weighted mean and standard deviation of the input depth;
            # the epsilon guards against an empty mask.
            mask_area  = th.sum(input_instance_mask, dim=(1,2,3), keepdim=True) + 1e-6
            depth_mean = th.sum(input_depth * input_instance_mask, dim=(1,2,3), keepdim=True) / mask_area
            depth_var  = th.sum((input_depth - depth_mean)**2 * input_instance_mask, dim=(1,2,3), keepdim=True) / mask_area
            depth_std  = th.sqrt(depth_var)

            standardised = (input_depth - depth_mean) / (depth_std + 1e-6)
            input_depth  = th.sigmoid(standardised * input_instance_mask) * input_instance_mask

            if export_latent:
                latent = {
                    'gestalt':  results['gestalt'].cpu().numpy(),
                    'position': results['position'].cpu().numpy(),
                }
                with open(f'out/{cfg.model_path}/latent-states-{i:04d}.pkl', 'wb') as f:
                    pickle.dump(latent, f)

            print(f'Saving[{(i+1)*100/len(dataloader):.2f}%/{i+1}/{len(dataloader)}]')

            # Computed as in the original code, although no panel of this
            # grid displays the position blob.
            xy_std = model.net.mask_pretrainer.mask_center(input_instance_mask)
            pos2d  = gaus2d(xy_std, compute_std=False)

            canvas = th.ones((3, canvas_h, canvas_w), device=device) * 0.2

            def put(rows, cols, tensor, text, **kwargs):
                # Resize + caption the batch and paste its first sample.
                canvas[:, rows, cols] = preprocess(tensor, size=size, add_text=True, text=text, **kwargs)[0]

            put(upper, left, input_rgb * input_instance_mask, "Input RGB")
            put(lower, left, input_depth, "GT Depth")
            put(upper, right, th.abs(results['object'] - input_rgb) * input_instance_mask, "Error", normalize=True)
            put(lower, right, results['object'] * input_instance_mask, "RGB Output")

            cv2.imwrite(f'rgb-grid-{i:04d}.jpg', rearrange(canvas * 255, 'c h w -> h w c').cpu().numpy())


def save_uncertainty(cfg: Configuration, file, size, vertical_images, add_text, individual_views, mask = False, export_latent = False):
    """Visualise the iterative position proposals of the uncertainty pretrainer.

    For every validation sample the model is queried once per instance-mask
    channel. Each query receives the foreground mask together with the maximum
    over all position blobs proposed so far, and returns a position plus a
    validity logit. After each query a 1x4 strip
    ``uncertainty-grid-<sample>-<step>.jpg`` is written to the current working
    directory, showing RGB input, depth, foreground mask and the accumulated
    position blobs.

    Args:
        cfg: Experiment configuration. ``cfg.model``, ``cfg.device`` and
            ``cfg.model_path`` are read; ``cfg.model.batch_size`` is
            overwritten with 1.
        file: Checkpoint path. An empty string builds an untrained module.
        size: ``(height, width)`` of a single panel in pixels.
        vertical_images: Unused.
        add_text: NOTE(review): currently ignored, captions are always drawn.
        individual_views: Unused.
        mask: Unused.
        export_latent: Unused.
    """

    np.random.seed(1234)
    th.manual_seed(1234)

    cfg_net = cfg.model
    device = th.device(cfg.device)
    cfg_net.batch_size = 1

    os.makedirs(f"out/{cfg.model_path}", exist_ok=True)

    data_module = LociUncertaintyPretrainerDataModule(cfg)
    dataloader  = data_module.val_dataloader()

    if file != '':
        model = LociUncertaintyPretrainerModule.load_from_checkpoint(file, cfg=cfg, strict=False).to(device)
    else:
        model = LociUncertaintyPretrainerModule(cfg).to(device)

    model.net.eval()

    gaus2d = Gaus2D(cfg_net.crop_size).to(device)

    # Canvas geometry is loop invariant: one row of four panels.
    pad = 18
    tile_h, tile_w = size[0], size[1]
    height = tile_h * 1 + pad * 2
    width  = tile_w * 4 + pad * 5

    def paste(canvas, col, tensor, text):
        # Rows span the panel height (size[0]) and columns the panel width
        # (size[1]). The two were swapped before, which only worked for
        # square panels.
        x0 = pad * (col + 1) + tile_w * col
        canvas[:, pad:pad + tile_h, x0:x0 + tile_w] = preprocess(tensor, size=size, add_text=True, text=text)[0]

    with th.no_grad():
        for i, batch in enumerate(dataloader):

            input_rgb      = batch[0].to(device)
            input_depth    = batch[1].to(device)
            instance_masks = batch[2].to(device)

            # Foreground = union (max) over all instance-mask channels.
            fg_mask = reduce(instance_masks, 'b c h w -> b 1 h w', 'max')

            positions2d = [th.zeros_like(fg_mask)]
            valid_gaus_mask = th.zeros_like(input_rgb)
            invalid_gaus_mask = th.zeros_like(input_rgb)

            for n in range(instance_masks.shape[1]):
                # Everything proposed so far, collapsed into one channel.
                gaus_mask = reduce(th.cat(positions2d, dim=1), 'b c h w -> b 1 h w', 'max')
                position, valid = model(input_rgb, input_depth, fg_mask, gaus_mask)
                valid = th.sigmoid(valid)

                position2d = gaus2d(position, compute_std=False)
                positions2d.append(position2d)

                # Three channels: dimmed foreground, blob weighted by
                # validity, blob weighted by (1 - validity).
                gaus_mask = th.cat((fg_mask * 0.2, th.maximum(fg_mask * 0.2, position2d * valid), th.maximum(fg_mask * 0.2, position2d * (1-valid))), dim=1)

                invalid_gaus_mask = th.maximum(invalid_gaus_mask, gaus_mask * (1-valid))
                valid_gaus_mask   = th.maximum(valid_gaus_mask, gaus_mask * valid)

                gaus_mask = th.maximum(invalid_gaus_mask, valid_gaus_mask)

                img = th.ones((3, height, width), device=device) * 0.2
                paste(img, 0, input_rgb, "RGB Input")
                paste(img, 1, input_depth, "GT Depth")
                paste(img, 2, fg_mask, "FG Mask")
                paste(img, 3, gaus_mask, "Position")

                img = rearrange(img * 255, 'c h w -> h w c').cpu().numpy()
                cv2.imwrite(f'uncertainty-grid-{i:04d}-{n:03d}.jpg', img)

                print(f'Saving[{(i+1)*100/len(dataloader):.2f}%/{i+1}/{len(dataloader)}|{n+1}/{instance_masks.shape[1]}%]')
