import os
import random
from typing import List, Optional

import pytorch_lightning as pl
import torch
import torch.nn.functional as F
import wandb
from beartype import beartype as typechecker
from jaxtyping import Float, Int, jaxtyped 
from lightning_utilities.core.rank_zero import rank_zero_only
from torchvision.utils import save_image

from conf.dataset import ValueRange, CelebAParams
from src.callbacks.id_callback import IDCallback
from utils.Logging.LogStrategy import LogStrategy
from utils.utils import normalize_value_range, undersample_list, mask2rgb, augmentWithBackground, get_hack_mode
from utils.utils import display_tensor, display_mask


class LogCelebA1(LogStrategy):
    """
    Logging strategy used when samples are generated one domain at a time.

    How a tensor is rendered depends on its channel count:
    1 channel is replicated to 3 (sketch), more than 3 channels go through
    ``argmax`` + ``mask2rgb`` (segmentation), anything else is logged unchanged.
    """

    def _to_rgb_image(self, tensor, n_channels):
        """
        Render ``tensor`` as a displayable image according to ``n_channels``.

        ``n_channels`` is given explicitly so that several tensors can be
        converted using the channel count of a single reference tensor.
        """
        if n_channels == 1:  # sketch: replicate the single channel to RGB
            return tensor.repeat(1, 3, 1, 1)
        if n_channels <= 3:  # already displayable as is
            return tensor
        # segmentation: optionally add the background channel, pick the class, colour it
        if not self.params_data.data_params.return_background:
            tensor = augmentWithBackground(tensor)
        return mask2rgb(tensor.argmax(dim=1))

    @jaxtyped
    @typechecker
    @rank_zero_only
    def log_train(
        self,
        stage_prefix: str,
        plMod: pl.LightningModule,
        prediction: Float[torch.Tensor, 'batch c h w'],
        input_to_model: Float[torch.Tensor, 'batch c h w'],
        batch: Float[torch.Tensor, 'batch c h w'],
        x0_hat: Float[torch.Tensor, 'batch c h w'],
        ts: Int[torch.Tensor, 'batch'],
        batch_idx: int,
    ) -> None:
        """
        Log training images to wandb: batch, model input, prediction and x0_hat side by side.

        Only as many samples as ``self.already_logged`` still allows are sent.
        """
        reference_channels = prediction.shape[1]
        # the four panels are laid out along the width, in caption order
        panels = [
            self._to_rgb_image(tensor, reference_channels)
            for tensor in (batch, input_to_model, prediction, x0_hat)
        ]
        images = torch.cat(panels, dim=-1)

        remaining = self.already_logged(plMod, batch_idx, batch_size=images.shape[0])
        if remaining <= 0:
            return
        images = normalize_value_range(images[:remaining], plMod.params.logging.value_range, clip=True)

        caption = '[batch, input_to_model, prediction, x0_hat]'
        wandb_images = [wandb.Image(image, caption=caption) for image in images]
        plMod.logger.experiment.log({f'{stage_prefix}/image': wandb_images})

    @rank_zero_only
    def log_generate(self, plMod: pl.LightningModule, stage_prefix: str, batch_idx: int, batch, idx=None):
        """
        Generate samples and log their trajectories to wandb.

        Each logged image shows the x_t trajectory (top) above the x0_hat
        trajectory (bottom). The time steps are undersampled according to the
        logging params and laid out along the width.
        """
        remaining = self.already_logged(plMod, batch_idx, batch_size=batch.shape[0])
        if remaining <= 0:
            return

        with torch.no_grad():
            _, generated_data, generated_x0 = plMod.generate_samples(remaining, return_samples=True, return_pred_x0=True)

        logging_params = plMod.params.logging
        kept_steps = undersample_list(
            list(zip(generated_data, generated_x0)),
            logging_params.time_step_in_process,
            logging_params.strategy,
            logging_params.quad_factor,
        )
        generated_data, generated_x0 = zip(*kept_steps)

        # [batch, channels, height, width * time_steps] for each trajectory
        xt_row = torch.cat(generated_data, dim=-1)
        x0_row = torch.cat(generated_x0, dim=-1)
        stacked = torch.cat([xt_row, x0_row], dim=-2)  # x_t above x0_hat (height dimension)

        stacked = self._to_rgb_image(stacked, stacked.shape[1])
        stacked = normalize_value_range(stacked, logging_params.value_range, clip=True)

        caption = 'xt and x0_hat'
        wandb_images = [wandb.Image(image, caption=caption) for image in stacked]
        plMod.logger.experiment.log({f'{stage_prefix}/image': wandb_images})


class LogCelebA3(LogStrategy):
    """
    Logging strategy for the joint 3-domain CelebA setup.

    Data tensors concatenate the domains on the channel dimension:
    - channels 0-2 hold the photo (RGB);
    - channel 3 holds the sketch (single channel);
    - the remaining channels hold the segmentation map.

    For display, every domain is rendered as 3 RGB channels (``data_to_rgb``),
    and the three domains are then stacked along the image height.
    """

    # number of domains rendered in the logged images (photo, sketch, segmentation)
    _N_DOM = 3
    # colour channels used to render each domain
    _DIM_RGB = 3
    # leading data channels that are not segmentation (3 photo + 1 sketch)
    _N_CHANNELS_NOT_SEG = 4
    # width in pixels of the coloured supervision marker, and of the gap after it
    _STRIP_COLOR_WIDTH = 10
    _STRIP_SPACE = 10

    @staticmethod
    def _id_to_str(value) -> str:
        """Render a sample id as plain text, unwrapping tensors so no ``tensor(...)`` repr leaks into names."""
        if isinstance(value, torch.Tensor):
            return str(value.item()) if value.numel() == 1 else str(value.tolist())
        return str(value)

    def _add_supervision_strip(self, img_cat: torch.Tensor, modes: torch.Tensor) -> torch.Tensor:
        """
        Prepend a coloured marker column to each domain's channel group of ``img_cat``.

        The marker is green when ``modes[b, dom] == 1`` and red otherwise.
        ``img_cat`` is [b, n_dom * 3, H, W] and ``modes`` is [b, n_dom]; the
        result is [b, n_dom * 3, H, strip_width + W].
        """
        dim_rgb = self._DIM_RGB
        color_width = self._STRIP_COLOR_WIDTH
        n_samples, _, height, _ = img_cat.shape
        strip = torch.zeros(
            [n_samples, self._N_DOM * dim_rgb, height, color_width + self._STRIP_SPACE],
            device=img_cat.device,
        )
        green = torch.tensor([0., 1., 0.], device=img_cat.device).reshape(dim_rgb, 1, 1)
        red = torch.tensor([1., 0., 0.], device=img_cat.device).reshape(dim_rgb, 1, 1)
        # a single host transfer instead of one tensor comparison per (sample, domain)
        supervised = (modes[:n_samples] == 1).tolist()
        for bi, sample_supervision in enumerate(supervised):
            for dom_i, is_supervised in enumerate(sample_supervision):
                first = dom_i * dim_rgb
                strip[bi, first:first + dim_rgb, :, :color_width] = green if is_supervised else red
        return torch.cat([strip, img_cat], dim=-1)

    def _stack_domains_vertically(self, img: torch.Tensor) -> torch.Tensor:
        """Rearrange [b, n_dom * 3, H, W] (domains on channels) into [b, 3, n_dom * H, W] (domains on height)."""
        n_dom, dim_rgb = self._N_DOM, self._DIM_RGB
        height, width = img.shape[-2], img.shape[-1]
        img = img.reshape([-1, n_dom, dim_rgb, height, width])
        return img.transpose(1, 2).reshape([-1, dim_rgb, n_dom * height, width])

    @jaxtyped
    @typechecker
    def from_data_get_seg(self, data: Float[torch.Tensor, 'b c_n_dom h w']) -> Int[torch.Tensor, 'b h w']:
        """
        Return the categorical segmentation (class index per pixel) held in the channels after the first four.

        When the dataset does not return a background channel, one is added
        with ``augmentWithBackground`` before taking the argmax.
        """
        seg = data[:, self._N_CHANNELS_NOT_SEG:]
        if not self.params_data.data_params.return_background:
            seg = augmentWithBackground(segmentations_maps=seg)
        return seg.argmax(dim=1)

    @jaxtyped
    @typechecker
    def data_to_rgb(self, data: Float[torch.Tensor, 'b c h w']) -> Float[torch.Tensor, 'b 3*3 h w']:
        """
        Render every domain as 3 channels.

        The output concatenates on the channel dimension:
        - the photo (channels 0-2) as is;
        - the sketch (channel 3) replicated 3 times;
        - the segmentation coloured by ``mask2rgb``.
        """
        return torch.cat([
            data[:, :3],
            data[:, 3].unsqueeze(1).repeat(1, 3, 1, 1),
            mask2rgb(self.from_data_get_seg(data)),
        ], dim=1)

    @jaxtyped
    @typechecker
    @rank_zero_only
    def log_train(
        self,
        stage_prefix  : str,
        plMod         : pl.LightningModule,
        ts            : Int[torch.Tensor  , 'b n_dom'],
        modes         : Float[torch.Tensor, 'b n_dom'],
        batch         : Float[torch.Tensor, 'b c_n_dom h w'],
        input_to_model: Float[torch.Tensor, 'b c_n_dom h w'],
        prediction    : Float[torch.Tensor, 'b c_n_dom h w'],
        x0_hat        : Float[torch.Tensor, 'b c_n_dom h w'],
        batch_idx     : int,
    ) -> None:
        """
        Log training images to wandb.

        Each image shows data, model input, prediction and x0_hat side by side,
        with the domains stacked vertically.
        """
        remaining = self.already_logged(plMod, batch_idx, batch_size=batch.shape[0])
        if remaining <= 0:
            return

        n_not_seg = self._N_CHANNELS_NOT_SEG
        images = torch.cat([batch, input_to_model, prediction, x0_hat], dim=-1)
        # normalize everything but the segmentation mask
        # NOTE(review): uses ValueRange.OneUnbound, whereas the generation loggers use
        # plMod.params.logging.value_range — confirm this difference is intended.
        images[:, :n_not_seg] = normalize_value_range(
            images[:, :n_not_seg], ValueRange.OneUnbound, clip=True
        )

        # render each domain as RGB, keep only what may still be logged, stack domains on height
        images = self.data_to_rgb(images)[:remaining]
        images = self._stack_domains_vertically(images).clamp(0, 1)

        caption = 'data, D(data), model(D(data)) rinsed, x0_hat'
        plMod.logger.experiment.log(
            {f'{stage_prefix}/image': [wandb.Image(image, caption=caption) for image in images]}
        )

    @jaxtyped
    @typechecker
    def log_generate(
        self,
        stage_prefix: str,
        plMod: pl.LightningModule,
        batch_idx: int,
        batch: List[Float[torch.Tensor, 'b c h w']],
        modes: Float[torch.Tensor, 'b n_dom'],
        idx=None,
    ):
        """
        Generate samples conditioned on ``batch``/``modes``, then log metrics and images.

        Steps:
        - feed the ID callback when configured;
        - log generation metrics;
        - log images to wandb and, in the test stage, to disk.

        ``idx`` (sample ids) is optional; it is only required when saving images to disk.
        """
        logging_params = plMod.params.logging

        # decide of the mode for the generation current batch
        if logging_params.hack_mode is not None:
            modes = get_hack_mode(hack_mode=logging_params.hack_mode, modes_init=modes)

        # flip fully supervised rows (all modes == 1) to fully unsupervised (all 0)
        if logging_params.if_all_here_generate_none:
            equalizing = torch.full_like(modes, fill_value=-1, dtype=modes.dtype, device=modes.device)
            full_supervision = (modes.sum(dim=1) == modes.shape[1]).int().unsqueeze(dim=1)
            modes = modes.clone() + equalizing * full_supervision

        if self.params.early_leave:
            remaining = self.already_logged(plMod, batch_idx, batch_size=batch[0].shape[0])
            if remaining <= 0:
                return

        with torch.no_grad():
            _, generated_data, generated_x0 = plMod.generate_samples(
                modes=modes,
                condition=list(batch),
                # under sample in the time steps
                undersampling=logging_params.time_step_in_process,
                strategy=logging_params.strategy,
                quad_factor=logging_params.quad_factor,
            )

        if hasattr(plMod, 'id_callback'):
            stage = plMod.get_stage()
            id_cb: IDCallback = plMod.id_callback
            if id_cb.p.compute_running and stage in id_cb.p.stages:
                is_good_ema_mode = (id_cb.p.compute_on_ema and '/ema' in stage_prefix) or (not id_cb.p.compute_on_ema and '/ema' not in stage_prefix)
                is_good_frequency = id_cb.is_time_to_compute_valid(plMod.current_epoch) if stage == 'valid' else stage == 'test'
                if is_good_frequency and is_good_ema_mode:
                    # only use the samples where the face (domain 0) was generated unsupervised
                    mode_bool_face = modes.bool()[:, 0]
                    face_x0 = generated_x0[-1].split(self.params_data.data_params.dimension_per_domain, dim=1)[0]
                    face_x0_unsupervised = face_x0[~mode_bool_face]

                    id_cb.process_batch(batch_fakes=face_x0_unsupervised)
                    id_cb.check_compute_running(epoch=plMod.current_epoch, logger=plMod, trainer=plMod.trainer)

        self.log_generate_metrics(
            stage_prefix=stage_prefix,
            plMod=plMod,
            batch=batch,
            modes=modes,
            generated_x0=generated_x0,
        )

        images = self.log_generate_log_images(
            plMod=plMod,
            stage_prefix=stage_prefix,
            batch_idx=batch_idx,
            batch=batch,
            modes=modes,
            generated_data=generated_data,
            generated_x0=generated_x0,
        )
        self.log_generate_to_wandb(
            stage_prefix=stage_prefix,
            batch_idx=batch_idx,
            img_cat=images,
            plMod=plMod,
            idx=idx,
        )
        self.log_image_to_disk(
            img_cat=images,
            plMod=plMod,
            idx=idx,
        )

    @jaxtyped
    @typechecker
    def log_generate_diversity(
        self,
        stage_prefix: str,
        plMod: pl.LightningModule,
        batch_idx: int,
        batch: List[Float[torch.Tensor, 'b c h w']],
        modes: Float[torch.Tensor, 'b n_dom'],
        idx=None,
    ):
        """
        Generate several variations per conditioning sample, then log diversity metrics and images.

        Each kept sample is repeated ``variation_quantity`` times (defaulting to
        the batch size). Generation runs in chunks of the original batch size.
        The variations of a sample are then logged side by side next to its
        ground truth. ``modes`` is not modified in place.
        """
        logging_params = plMod.params.logging
        original_batch_size = batch[0].shape[0]

        # decide of the mode for the generation current batch
        if logging_params.hack_mode is not None:
            modes = get_hack_mode(hack_mode=logging_params.hack_mode, modes_init=modes)

        # ensure no sample is fully unsupervised by forcing domain index 2 to supervised;
        # work on a copy so the in-place write does not leak into the caller's tensor
        modes = modes.clone()
        modes[modes.sum(dim=1) == 0, 2] = 1

        if self.params.early_leave:
            remaining = self.already_logged(plMod, batch_idx, batch_size=original_batch_size)
            if remaining <= 0:
                return
            batch = [x[:remaining] for x in batch]
            modes = modes[:remaining]
            if idx is not None:
                idx = idx[:remaining]

        diversity_params = self.params.log_generate_diversity
        variation_quantity = diversity_params.variation_quantity or original_batch_size

        if not diversity_params.generate_all_in_batch:  # reduce the batch to one random element
            pick = random.randint(0, batch[0].shape[0] - 1)
            batch = [x[pick:pick + 1] for x in batch]
            modes = modes[pick:pick + 1]
            if idx is not None:
                idx = idx[pick:pick + 1]

        # repeat every sample variation_quantity times: [b, ...] -> [b * variation_quantity, ...]
        batch = [x.unsqueeze(1).repeat(1, variation_quantity, 1, 1, 1).flatten(end_dim=1) for x in batch]
        modes = modes.unsqueeze(1).repeat(1, variation_quantity, 1).flatten(end_dim=1)
        if idx is not None:
            idx = idx.unsqueeze(1).repeat(1, variation_quantity).flatten()

        assert logging_params.time_step_in_process == 1, 'diversity logging only support time_step_in_process=1'
        generated_x0s = []
        # no autograd graph is needed for logging (same as log_generate)
        with torch.no_grad():
            for start in range(0, batch[0].shape[0], original_batch_size):
                stop = start + original_batch_size
                _, _, generated_x0 = plMod.generate_samples(
                    modes=modes[start:stop],
                    condition=[x[start:stop] for x in batch],
                    # under sample in the time steps
                    undersampling=logging_params.time_step_in_process,
                    strategy=logging_params.strategy,
                    quad_factor=logging_params.quad_factor,
                )
                generated_x0s += generated_x0

        # restore the variation dimension: [b * variation_quantity, ...] -> [b, variation_quantity, ...]
        batch = [x.reshape(-1, variation_quantity, *x.shape[1:]) for x in batch]
        modes = modes.reshape(-1, variation_quantity, modes.shape[-1])
        if idx is not None:
            idx = idx.reshape(-1, variation_quantity)
        generated_x0s = torch.cat(generated_x0s, dim=0)
        generated_x0s = generated_x0s.reshape(-1, variation_quantity, *generated_x0s.shape[1:])

        self.log_generate_diversity_metrics(
            stage_prefix=stage_prefix,
            plMod=plMod,
            batch=batch,
            modes=modes,
            generated_x0=generated_x0s,
        )

        images = self.log_generate_diversity_log_images(
            plMod=plMod,
            stage_prefix=stage_prefix,
            batch_idx=batch_idx,
            batch=batch,
            modes=modes,
            generated_x0=generated_x0s,
            idx=idx,
        )
        self.log_generate_to_wandb_diversity(
            stage_prefix=stage_prefix,
            batch_idx=batch_idx,
            img_cat=images,
            plMod=plMod,
            idx=idx,
        )
        self.log_image_to_disk_diversity(
            img_cat=images,
            plMod=plMod,
            idx=idx,
        )

    @jaxtyped
    @typechecker
    def log_generate_diversity_metrics(
        self,
        stage_prefix: str,
        plMod: pl.LightningModule,
        batch: List[Float[torch.Tensor, 'b diversity ci h w']],
        modes: Float[torch.Tensor, 'b diversity n_dom'],
        generated_x0: Float[torch.Tensor, 'b diversity n_dom_c h w'],
    ) -> None:
        """Compute the diversity metrics via ``plMod.get_metric_object()`` and log each one with ``plMod.log_g``."""
        # split the generation into one tensor per domain (channel dim is 2 here)
        generated_prediction = list(generated_x0.split(self.params_data.data_params.dimension_per_domain, dim=2))

        metrics_obs = plMod.get_metric_object()
        metric_dict = metrics_obs.get_dict_generation_diversity(
            batch=batch,
            prediction=generated_prediction,
            modes=modes,
        )
        for metric_name, value in metric_dict.items():
            plMod.log_g(stage_prefix, metric_name, value)

    @jaxtyped
    @typechecker
    def log_generate_diversity_log_images(
        self,
        stage_prefix: str,
        plMod: pl.LightningModule,
        batch_idx: int,
        batch: List[Float[torch.Tensor, 'b diversity ci h w']],  # length of the list is the number of domains
        modes: Float[torch.Tensor, 'b diversity n_dom'],
        generated_x0  : Float[torch.Tensor, 'b diversity n_dom_c h w'],
        idx: Optional[Int[torch.Tensor, 'b diversity']] = None,
    ) -> Float[torch.Tensor, 's 3 hg wg']:
        """
        Build one image per sample for diversity logging.

        Layout of each image:
        - a supervision strip on the left;
        - the ground truth next to it;
        - then every variation side by side along the width;
        - domains stacked vertically.

        ``generated_x0`` is left untouched.
        """
        n_not_seg = self._N_CHANNELS_NOT_SEG
        value_range = plMod.params.logging.value_range
        n_samples = batch[0].shape[0]

        # cat domains along channels -> [b, diversity, C, h, w]; all variations share the
        # same ground truth and modes, so keep the first one -> [b, C, h, w] / [b, n_dom]
        batch = torch.cat(batch, dim=2)[:n_samples, 0]
        modes = modes[:n_samples, 0]
        # clone: the in-place normalization below must not modify the caller's tensor
        x0s = generated_x0[:n_samples].clone()

        # normalize only the photo and sketch channels, not the segmentation
        batch[:, :n_not_seg] = normalize_value_range(batch[:, :n_not_seg], value_range, clip=True)
        x0s[:, :, :n_not_seg] = normalize_value_range(x0s[:, :, :n_not_seg], value_range, clip=True)

        batch = self.data_to_rgb(batch)  # [b, 3*3, h, w]
        x0_hat = self.data_to_rgb(x0s.flatten(end_dim=1))  # [b*diversity, 3*3, h, w]
        x0_hat = x0_hat.reshape(batch.shape[0], -1, *x0_hat.shape[1:])  # [b, diversity, 3*3, h, w]
        x0_hat = x0_hat.permute(0, 2, 3, 1, 4).flatten(start_dim=3)  # [b, 3*3, h, diversity * w]

        img_cat = torch.cat([batch, x0_hat], dim=-1)  # [b, 3*3, h, (1 + diversity) * w]
        img_cat = self._add_supervision_strip(img_cat, modes)
        return self._stack_domains_vertically(img_cat)

    @jaxtyped
    @typechecker
    def log_generate_metrics(
        self,
        stage_prefix: str,
        plMod: pl.LightningModule,
        batch: List[Float[torch.Tensor, 'b ci h w']],
        modes: Float[torch.Tensor, 'b n_dom'],

        generated_x0: List[Float[torch.Tensor, 'b n_dom_c h w']],
    ) -> None:
        """
        Compute the generation metrics on the last x0 estimate and log each one with ``plMod.log_g``.

        The last x0 estimate is ``generated_x0[-1]``.
        """
        # split the final generation into one tensor per domain
        generated_prediction = list(generated_x0[-1].split(self.params_data.data_params.dimension_per_domain, dim=1))

        metrics_obs = plMod.get_metric_object()
        metric_dict = metrics_obs.get_dict_generation(
            data=batch,
            prediction=generated_prediction,
            mode=modes,
        )
        for metric_name, value in metric_dict.items():
            plMod.log_g(stage_prefix, metric_name, value)

    @jaxtyped
    @typechecker
    def log_generate_log_images(
        self,
        stage_prefix: str,
        plMod: pl.LightningModule,
        batch_idx: int,
        batch: List[Float[torch.Tensor, 'b ci h w']],  # length of the list is the number of domains
        modes: Float[torch.Tensor, 'b n_dom'],

        generated_data: List[Float[torch.Tensor, 'b n_dom_c h w']],  # the length of the list is the number of time steps
        generated_x0  : List[Float[torch.Tensor, 'b n_dom_c h w']],  # the length of the list is the number of time steps
    ) -> Float[torch.Tensor, 's 3 hg wg']:
        """
        Build one image per sample showing the generation process.

        Rows, top to bottom:
        - x_t over time;
        - x0_hat over time;
        - an L1 error map (red) over time.

        Each row ends with the ground truth. A supervision strip sits on the
        left, and the domains are stacked vertically.
        """
        params_data: CelebAParams = self.params_data.data_params
        n_not_seg = self._N_CHANNELS_NOT_SEG
        dim_rgb = self._DIM_RGB
        value_range = plMod.params.logging.value_range
        h, w = batch[0].shape[-2:]
        time_steps = len(generated_data)
        batch_size = batch[0].shape[0]

        batch  = torch.cat(batch, dim=1)[:batch_size]  # cat along channels
        modes  = modes[:batch_size]
        x_s    = torch.cat(generated_data, dim=-1)[:batch_size]  # cat along the time steps
        x0_hat = torch.cat(generated_x0, dim=-1)[:batch_size]    # cat along the time steps

        # normalize only the photo and sketch channels, not the segmentation (clipped to the target range)
        batch[:, :n_not_seg]  = normalize_value_range(batch[:, :n_not_seg] , value_range, clip=True)
        x_s[:, :n_not_seg]    = normalize_value_range(x_s[:, :n_not_seg]   , value_range, clip=True)
        x0_hat[:, :n_not_seg] = normalize_value_range(x0_hat[:, :n_not_seg], value_range, clip=True)

        # region L1 map
        # the segmentation is turned into class indices and back into one-hot (n_class channels) so
        # the L1 error compares hard labels instead of soft scores
        x0_hat_for_l1 = torch.cat([
            x0_hat[:, :n_not_seg],
            F.one_hot(self.from_data_get_seg(x0_hat), num_classes=params_data.n_class).permute(0, 3, 1, 2),
        ], dim=1)
        batch_for_l1 = torch.cat([
            batch[:, :n_not_seg],
            F.one_hot(self.from_data_get_seg(batch), num_classes=params_data.n_class).permute(0, 3, 1, 2),
        ], dim=1)

        # ground truth repeated once per time step to match the width of x0_hat
        l1_map = F.l1_loss(batch_for_l1.repeat(1, 1, 1, time_steps), x0_hat_for_l1, reduction='none')
        # collapse the per-class segmentation error into one channel: sum clamped to [0, 1]
        l1_map = torch.cat([
            l1_map[:, :n_not_seg],
            l1_map[:, n_not_seg:].sum(dim=1, keepdim=True).clamp(0., 1.),
        ], dim=1)

        # one rgb triplet per domain -> [batch, n_dom, rgb, h, w * time_steps]
        l1_map = torch.cat([
            l1_map[:, :3].unsqueeze(1),
            l1_map[:, 3:4].unsqueeze(2).repeat(1, 1, dim_rgb, 1, 1),
            l1_map[:, 4:].unsqueeze(2).repeat(1, 1, dim_rgb, 1, 1),
        ], dim=1)
        l1_map[:, :, 1:] = 0  # keep only the red channel: the error map is drawn in red
        l1_map = l1_map.reshape(batch_size, -1, h, w * time_steps)
        # endregion

        batch  = self.data_to_rgb(batch)
        x_s    = self.data_to_rgb(x_s)
        x0_hat = self.data_to_rgb(x0_hat)

        # append a black separator and the ground truth at the end of each row
        separator_width = 10
        black_strip = torch.zeros([batch_size, x_s.shape[1], x_s.shape[2], separator_width],
                                  device=x_s.device)
        x_s = torch.cat([x_s, black_strip, batch], dim=-1)
        x0_hat = torch.cat([x0_hat, black_strip, batch], dim=-1)
        l1_map = torch.cat([l1_map, black_strip, batch], dim=-1)

        # rows stacked in height: x_s, x0_hat, l1 map
        img_cat = torch.cat([x_s, x0_hat, l1_map], dim=-2)
        img_cat = self._add_supervision_strip(img_cat, modes)
        return self._stack_domains_vertically(img_cat)

    @jaxtyped
    @typechecker
    @rank_zero_only
    def log_generate_to_wandb(
        self,
        stage_prefix: str,
        batch_idx: int,
        img_cat: Float[torch.Tensor, 'b 3 hg wg'],
        plMod,
        idx,
    ) -> None:
        """
        Send at most ``already_logged``-allowed images to wandb.

        Captions carry the sample id when ``idx`` is given.
        """
        batch_size = img_cat.shape[0]
        remaining = self.already_logged(plMod, batch_idx, batch_size=batch_size)
        if remaining <= 0:
            return
        remaining = min(remaining, batch_size)

        def caption(i):
            # idx is optional (log_generate defaults it to None)
            if idx is None:
                return 'xt and x0_hat'
            return f'id:{self._id_to_str(idx[i])} xt and x0_hat'

        wandb_images = [wandb.Image(img_cat[i], caption=caption(i)) for i in range(remaining)]
        plMod.logger.experiment.log({f'{stage_prefix}/image': wandb_images})

    @jaxtyped
    @typechecker
    def log_image_to_disk(
        self,
        img_cat: Float[torch.Tensor, 'b 3 hg wg'],
        plMod,
        idx,
    ) -> None:
        """
        In the test stage with ``save_image_to_disk`` enabled, save every image as ``_results/<id>.png``.

        The ``_results`` directory is created if needed.
        """
        if plMod.get_stage() != 'test' or not plMod.params.logging.save_image_to_disk:
            return
        assert idx is not None, 'idx should be provided when saving image to disk'
        os.makedirs('_results', exist_ok=True)
        for i, img_i in enumerate(img_cat):
            save_image(img_i, fp=f'_results/{self._id_to_str(idx[i])}.png')

    @jaxtyped
    @typechecker
    @rank_zero_only
    def log_generate_to_wandb_diversity(
        self,
        stage_prefix: str,
        batch_idx: int,
        img_cat: Float[torch.Tensor, 'b 3 hg wg'],
        plMod,
        idx,
    ) -> None:
        """
        Send at most ``already_logged``-allowed diversity images to wandb.

        When ``idx`` ([b, diversity]) is given, the caption uses the id of the
        first variation.
        """
        batch_size = img_cat.shape[0]
        remaining = self.already_logged(plMod, batch_idx, batch_size=batch_size)
        if remaining <= 0:
            return
        remaining = min(remaining, batch_size)

        def caption(i):
            # idx is optional (log_generate_diversity defaults it to None)
            if idx is None:
                return 'gt and variations'
            return f'id:{self._id_to_str(idx[i, 0])} gt and variations'

        wandb_images = [wandb.Image(img_cat[i], caption=caption(i)) for i in range(remaining)]
        plMod.logger.experiment.log({f'{stage_prefix}/image': wandb_images})

    @jaxtyped
    @typechecker
    def log_image_to_disk_diversity(
        self,
        img_cat: Float[torch.Tensor, 'b 3 hg wg'],
        plMod,
        idx: Optional[Int[torch.Tensor, 'b diversity']],
    ) -> None:
        """
        In the test stage with ``save_image_to_disk`` enabled, save every image as ``_results/<id>_diversity.png``.

        ``idx`` may be None outside that case (it is optional upstream).
        """
        if plMod.get_stage() != 'test' or not plMod.params.logging.save_image_to_disk:
            return
        assert idx is not None, 'idx should be provided when saving image to disk'
        os.makedirs('_results', exist_ok=True)
        for i, img_i in enumerate(img_cat):
            save_image(img_i, fp=f'_results/{self._id_to_str(idx[i, 0])}_diversity.png')
