import random
from bisect import bisect_right
from copy import copy
from typing import Any, List, Optional, Tuple

import matplotlib.pyplot as plt
import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.optim as optim
from beartype import beartype as typechecker
from jaxtyping import Float, Int, Shaped, jaxtyped
from lightning_utilities.core.rank_zero import rank_zero_only
from torch.optim.lr_scheduler import (CosineAnnealingLR, LinearLR,
                                      ReduceLROnPlateau, SequentialLR)
from tqdm import tqdm

import wandb
from conf.dataset import DatasetParams
from conf.model import (BackboneParams, DiffusionParams_DDIM,
                        DiffusionParams_DDPM, LoggingParams, ModelParams)
from src.Backbones.utils import get_model
from utils.Logging.LoggingImport import get_log_strategy
from utils.Losses import get_loss
from utils.Metrics import get_metrics
from utils.utils import (broadcast_modes_to_pixels,
                         broadcast_modes_to_pixels_shape, display_mask,
                         display_tensor, freeze, get_undersample_indices,
                         is_logging_time, norm_max, norm_minmax)


class SequentialLR2(SequentialLR):
    """``SequentialLR`` variant that can drive a ``ReduceLROnPlateau`` scheduler.

    ``SequentialLR.step`` takes no metric, so it cannot feed a plateau scheduler. This subclass
    forwards ``metric`` to the active scheduler when it is a ``ReduceLROnPlateau``. Every other
    scheduler is stepped as ``SequentialLR`` does: restarted with ``step(0)`` on the step where
    it becomes active, and ``step()`` afterwards.
    """

    def step(self, metric=None):
        """Advance the scheduler sequence by one step.

        Args:
            metric: monitored value, required when the active scheduler is a
                ``ReduceLROnPlateau``; ignored by every other scheduler.
        """
        self.last_epoch += 1
        idx = bisect_right(self._milestones, self.last_epoch)
        scheduler = self._schedulers[idx]
        at_milestone = idx > 0 and self._milestones[idx - 1] == self.last_epoch

        if isinstance(scheduler, ReduceLROnPlateau):
            # The first positional argument of ReduceLROnPlateau.step is the metric, not an
            # epoch: the ``step(0)`` restart used below would record a fake metric of 0 (an
            # unbeatable "best" in 'min' mode) and cause spurious LR reductions.
            scheduler.step(metric)
        elif at_milestone:
            # restart the newly activated scheduler from epoch 0, as SequentialLR does
            scheduler.step(0)
        else:
            scheduler.step()

        # Report the LR actually applied to the optimizer. Unlike the previous scheduler's last
        # LR, this stays correct after ReduceLROnPlateau reductions.
        self._last_lr = [group['lr'] for group in self.optimizer.param_groups]


@jaxtyped
@typechecker
def extract(
    data: Float[torch.Tensor, 'X'],
    t: Int[torch.Tensor, 'b n_dom'],
    c_per_dom: List[int],
) -> Float[torch.Tensor, 'b c_n_dom 1 1']:
    """Look up one schedule value per (sample, domain) and spread it over that domain's channels.

    Args:
        data: 1-D schedule indexed by diffusion time step.
        t: time step of every domain for every sample.
        c_per_dom: number of channels of each domain (the first ``n_dom`` entries are used).

    Returns:
        Tensor of shape ``[b, sum(c_per_dom[:n_dom]), 1, 1]``, ready to broadcast against a
        channel-concatenated image batch.
    """
    schedule = data.to(t.device)
    batch_size, n_domains = t.shape

    # one value per (sample, domain)
    picked = schedule.gather(-1, t.flatten()).reshape([batch_size, n_domains])

    # widen each domain's column to its number of channels
    per_domain_columns = []
    for dom in range(n_domains):
        column = picked[:, dom].unsqueeze(1)
        per_domain_columns.append(column.repeat(1, c_per_dom[dom]))
    coefficients = torch.cat(per_domain_columns, dim=1)

    # trailing singleton spatial dimensions
    return coefficients[:, :, None, None]


class DiffusionVariables_DDPM(nn.Module):
    """Linear-beta DDPM schedule plus the derived coefficients used for training and sampling.

    Every schedule tensor is 1-D with ``n_steps_training`` entries indexed by the time step. The
    tensors are stored as plain attributes, not registered buffers.
    """

    def __init__(self, params: DiffusionParams_DDPM):
        super().__init__()
        self.p = params

        # schedule definition
        self.n_steps_training = params.n_steps_training
        self.beta_min = params.beta_min
        self.beta_max = params.beta_max

        # sampling options copied from the config
        self.clamp_generation = params.clamp_generation
        self.clamp_end = params.clamp_end
        self.clamp_min_max = params.clamp_min_max
        self.jump = params.jump
        self.jump_len = params.jump_len
        self.jump_n_sample = params.jump_n_sample

        # beta_t grows linearly from beta_min to beta_max
        beta = torch.linspace(self.beta_min, self.beta_max, self.n_steps_training)
        alpha = 1. - beta
        alpha_bar = alpha.cumprod(dim=0)
        one_minus_alpha_bar_sqrt = (1. - alpha_bar).sqrt()

        self.beta = beta
        self.sqrt_beta = beta.sqrt()
        self.alpha = alpha
        self.sqrt_alpha = alpha.sqrt()
        self.alpha_bar = alpha_bar
        self.one_minus_alpha = beta  # 1 - alpha_t is beta_t: the same tensor object is reused
        self.sqrt_recip_alpha = (1. / alpha).sqrt()

        self.sqrt_alpha_bar = alpha_bar.sqrt()
        self.sqrt_one_minus_alphas_bar = one_minus_alpha_bar_sqrt
        self.eps_coef = beta / one_minus_alpha_bar_sqrt

        # sigma_t^2 is taken equal to beta_t
        self.sigma2 = beta
        self.sigma2_sqrt = beta.sqrt()

        # coefficients of x_0 = sqrt(1 / alpha_bar) * x_t - sqrt(1 / alpha_bar - 1) * eps
        self.sqrt_recip_alpha_bar = (1. / alpha_bar).sqrt()
        self.sqrt_recip_m1_alpha_bar = (1. / alpha_bar - 1.).sqrt()


class DiffusionVariables_DDIM(nn.Module):
    """DDIM sampling schedule built on top of a linear DDPM beta schedule.

    ``time_steps`` holds the training time steps visited during generation. They are evenly or
    quadratically spaced depending on ``params.ddim_discretize``. The ``ddim_*`` tensors are the
    DDPM ``alpha_bar`` values, and coefficients derived from them, gathered at those steps. All
    tensors are plain attributes, not registered buffers.
    """

    def __init__(self, params: DiffusionParams_DDIM):
        super().__init__()
        self.p = params

        # sampling options copied from the config
        self.skip_steps   = params.skip_steps
        self.repeat_noise = params.repeat_noise
        self.temperature  = params.temperature
        self.clamp_generation = params.clamp_generation
        self.clamp_end = params.clamp_end
        self.clamp_min_max = params.clamp_min_max
        self.jump = params.jump
        self.jump_len = params.jump_len
        self.jump_n_sample = params.jump_n_sample

        self.n_steps_training = params.n_steps_training
        self.n_steps_generation = params.n_steps_generation

        if params.ddim_discretize == 'uniform':
            # Evenly spaced steps with stride n_steps_training // n_steps_generation, shifted by 1.
            # The stride is 0 (and torch.arange raises) if n_steps_generation > n_steps_training.
            c = self.n_steps_training // self.n_steps_generation
            self.time_steps = torch.arange(0, self.n_steps_training - 1, c).long() + 1
        elif params.ddim_discretize == 'quad':
            # Quadratically spaced steps (denser near t=0): linspace in sqrt-space, squared,
            # truncated to int and shifted by 1.
            # NOTE(review): time_step_factor presumably keeps the largest step below
            # n_steps_training; confirm, since alpha_bar is indexed with these steps below.
            self.time_steps = torch.linspace(0, np.sqrt(self.n_steps_training * params.time_step_factor), self.n_steps_generation).pow(2).long() + 1
        else:
            raise ValueError(f'{params.ddim_discretize=} is not an available discretization method.')

        self.beta_min = params.beta_min
        self.beta_max = params.beta_max

        self.beta           = torch.linspace(self.beta_min, self.beta_max, self.n_steps_training)  # linearly increasing variance schedule
        self.sqrt_beta      = self.beta.sqrt()
        self.alpha          = (1. - self.beta)
        self.sqrt_alpha     = self.alpha.sqrt()
        self.alpha_bar      = self.alpha.cumprod(dim=0)
        self.sqrt_alpha_bar = self.alpha_bar.sqrt()
        self.sqrt_one_minus_alphas_bar = (1. - self.alpha_bar).sqrt()

        # In DDIM notation "alpha" is the DDPM cumulative product alpha_bar, taken at the visited steps.
        self.ddim_alpha           = self.alpha_bar[self.time_steps].clone()
        self.ddim_alpha_sqrt      = self.ddim_alpha.sqrt()
        # alpha_bar at the previously visited step; the first visited step falls back to alpha_bar[0]
        self.ddim_alpha_prev      = torch.cat([self.alpha_bar[:1], self.alpha_bar[self.time_steps[:-1]]])
        self.ddim_alpha_prev_sqrt = self.ddim_alpha_prev.sqrt()
        # sigma = eta * sqrt((1 - a_prev) / (1 - a) * (1 - a / a_prev)); ddim_eta = 0 makes every sigma zero
        self.ddim_sigma = (
            params.ddim_eta *
            (
                    (1 - self.ddim_alpha_prev) / (1 - self.ddim_alpha) *
                    (1 - self.ddim_alpha / self.ddim_alpha_prev)
            ).sqrt()
        )

        self.ddim_sqrt_one_minus_alpha = (1 - self.ddim_alpha).sqrt()

        # sqrt(1 - a_prev - sigma^2). This is NaN where sigma^2 > 1 - a_prev.
        # NOTE(review): presumably the coefficient of the "direction pointing to x_t" term of the
        # DDIM update (the sampler is not visible here).
        self.dir_xt_coef = (1. - self.ddim_alpha_prev - self.ddim_sigma.pow(2)).sqrt()

    @jaxtyped
    @typechecker
    def from_indexes_get_tau(self, indexes: Int[torch.Tensor, '*shape']) -> Int[torch.Tensor, '*shape']:  # todo check this one (maybe not even needed)
        """Map DDIM step indexes to training time steps (``time_steps[indexes]``), keeping the input shape."""
        shape = indexes.shape
        time_steps = self.time_steps.to(indexes.device)
        taus = time_steps[indexes.flatten()]
        taus = taus.reshape(shape)
        return taus


class DiffusionAlltoAll(pl.LightningModule):
    def __init__(
        self,
        params: ModelParams,
        params_data: DatasetParams,
    ):
        """Build the backbone, diffusion schedule, loss, per-stage metrics and logging strategy.

        Raises:
            ValueError: if ``params.diffusion.diffusion_name`` is neither 'ddim' nor 'ddpm'.
        """
        super().__init__()
        self.params = params
        self.params_data = params_data
        self.model = self._get_epsilons_model(params.backbone)

        # pick the schedule container matching the configured sampler
        diffusion_cls = {
            'ddim': DiffusionVariables_DDIM,
            'ddpm': DiffusionVariables_DDPM,
        }.get(params.diffusion.diffusion_name)
        if diffusion_cls is None:
            raise ValueError(f'{params.diffusion.diffusion_name=} is not an available diffusion method.')
        self.diffusion_variables = diffusion_cls(params=params.diffusion)

        # _step masks the element-wise loss itself, so the loss must not reduce
        self.loss = get_loss(params.loss)
        assert self.params.loss.reduction == 'none'

        # one independent metric module per stage, all frozen
        self.train_metrics = get_metrics(params.metrics)(params, params_data)
        self.valid_metrics = get_metrics(params.metrics)(params, params_data)
        self.test_metrics  = get_metrics(params.metrics)(params, params_data)
        for stage_metrics in (self.train_metrics, self.valid_metrics, self.test_metrics):
            freeze(stage_metrics)

        self.log_strategy = get_log_strategy(params.logging, params_data)

    @property
    def c_per_domain(self):
        """Number of channels of each domain, as declared in the dataset parameters."""
        return self.params_data.data_params.dimension_per_domain

    def _get_epsilons_model(self, params: BackboneParams) -> nn.Module:
        """Build the noise-prediction backbone described by ``params``."""
        backbone = get_model(params)
        return backbone

    def configure_optimizers(self):
        """Create the optimizer and, optionally, its learning-rate scheduler for Lightning.

        The optimizer is selected by ``optimizer.optimizer`` ('adam', 'adamw' or 'sgd'). An
        optional linear warmup and at most one main scheduler (reduce-on-plateau or cosine
        annealing) are supported. When both a warmup and a main scheduler are enabled, they are
        chained with ``SequentialLR2``, switching at ``warmup.total_iters + 1``.

        Returns:
            dict: ``{'optimizer': ...}``, plus an ``'lr_scheduler'`` config when a scheduler is
            enabled. The config also carries ``interval``/``monitor``/``reduce_on_plateau`` when
            plateau scheduling is used.

        Raises:
            ValueError: for an unknown optimizer name.
        """
        optimizers_params = self.params.optimizer

        if self.params.backbone.name == 'celeba_model3':
            # DDP fails on parameters that never get a gradient: keep the first encoder and
            # decoder out of the optimizer for this backbone
            trainable = (
                param
                for param_name, param in self.model.named_parameters()
                if 'encoders.encoders.0' not in param_name and 'decoders.decoders.0' not in param_name
            )
            param_groups = [{'params': trainable}]
        else:
            param_groups = [{'params': self.model.parameters()}]

        name = optimizers_params.optimizer
        if name == 'sgd':
            opti = optim.SGD(
                param_groups,
                lr=optimizers_params.learning_rate,
                momentum=optimizers_params.momentum,
                weight_decay=optimizers_params.weight_decay,
            )
        elif name in ('adam', 'adamw'):
            adam_cls = optim.Adam if name == 'adam' else optim.AdamW
            opti = adam_cls(
                param_groups,
                lr=optimizers_params.learning_rate,
                betas=optimizers_params.betas,
                weight_decay=optimizers_params.weight_decay,
            )
        else:
            raise ValueError(f'{optimizers_params.optimizer=} is not an available optimizer.')

        warmup_cfg = optimizers_params.learning_rate_warmup
        plateau_cfg = optimizers_params.reduce_on_plateau
        cosine_cfg = optimizers_params.cosines_scheduler
        assert sum((plateau_cfg.use_scheduler, cosine_cfg.use_scheduler)) <= 1, 'Only one scheduler can be used at a time.'

        warmup_scheduler = None
        if warmup_cfg.use_scheduler:
            warmup_scheduler = LinearLR(
                optimizer=opti,
                start_factor=warmup_cfg.start_factor,
                end_factor=warmup_cfg.end_factor,
                total_iters=warmup_cfg.total_iters,
                last_epoch=warmup_cfg.last_epoch,
                verbose=warmup_cfg.verbose,
            )

        main_scheduler = None
        if plateau_cfg.use_scheduler:
            main_scheduler = ReduceLROnPlateau(
                optimizer=opti,
                mode=plateau_cfg.mode,
                factor=plateau_cfg.factor,
                patience=plateau_cfg.patience,
                threshold=plateau_cfg.threshold,
                threshold_mode=plateau_cfg.threshold_mode,
                cooldown=plateau_cfg.cooldown,
                min_lr=plateau_cfg.min_lr,
                eps=plateau_cfg.eps,
                verbose=plateau_cfg.verbose,
            )
        if cosine_cfg.use_scheduler:
            main_scheduler = CosineAnnealingLR(
                optimizer=opti,
                T_max=cosine_cfg.T_max,
                eta_min=0,
                verbose=True,
            )

        if warmup_scheduler is None:
            scheduler = main_scheduler
        elif main_scheduler is None:
            scheduler = warmup_scheduler
        else:
            # warmup first, then hand over to the main scheduler
            scheduler = SequentialLR2(
                optimizer=opti,
                schedulers=[warmup_scheduler, main_scheduler],
                milestones=[warmup_scheduler.total_iters + 1],
                last_epoch=-1,
                verbose=True,
            )

        res = {'optimizer': opti}
        if scheduler is not None:
            lr_scheduler_cfg = {'scheduler': scheduler}
            if plateau_cfg.use_scheduler:
                # Lightning then passes the monitored metric to scheduler.step
                lr_scheduler_cfg.update(
                    interval=plateau_cfg.interval,
                    monitor=plateau_cfg.monitor,
                    reduce_on_plateau=True,
                )
            res['lr_scheduler'] = lr_scheduler_cfg

        return res

    def get_number_iterations(self) -> int:
        """Return the total number of training iterations.

        Exactly one of ``optimizer.max_epochs`` and ``optimizer.max_steps`` must be set; the
        other is left at ``-1``. With ``max_epochs`` the count comes from
        ``trainer.estimated_stepping_batches``; with ``max_steps`` it is ``max_steps`` itself.

        Raises:
            ValueError: if both, or neither, of ``max_epochs`` and ``max_steps`` are set.
        """
        max_epochs = self.params.optimizer.max_epochs
        max_steps = self.params.optimizer.max_steps

        # explicit checks instead of an assert, which would be stripped under ``python -O``
        if max_epochs != -1 and max_steps != -1:
            raise ValueError('At least one of max_epochs and max_steps must be -1 to remove ambiguity.')

        if max_epochs != -1:
            # derive the number of optimizer steps from the epochs
            return self.trainer.estimated_stepping_batches
        if max_steps != -1:
            return max_steps

        raise ValueError('Exactly one of max_epochs and max_steps must be set (different from -1), but both are -1.')

    def log_g(self, train_stage: str, logged: str, value: Any, **kwargs):
        """Log ``value`` under ``'{train_stage}/{logged}'``, synchronised across processes.

        Values are always aggregated per epoch. Per-step logging is enabled only when the stage
        name contains ``'train'``. Extra ``kwargs`` are forwarded to ``self.log``.
        """
        self.log(
            f'{train_stage}/{logged}',
            value,
            **kwargs,
            on_epoch=True,
            on_step='train' in train_stage,
            sync_dist=True,
        )

    def training_step(self, batch, batch_idx):
        """Lightning training hook: compute the diffusion loss and log the current learning rate."""
        stage_name = 'train'
        loss = self._step(tuple(batch), stage_name, batch_idx=batch_idx)
        current_lr = self.trainer.optimizers[0].param_groups[0]['lr']
        self.log_g(stage_name, 'lr', current_lr)
        return loss

    @jaxtyped
    @typechecker
    def _step(
        self,
        _batch,  #: Tuple[List[Float[torch.Tensor, 'b _ci h w']], Float[torch.Tensor, 'b n_dom']],
        stage_prefix: str,
        batch_idx: int,
    ) -> torch.Tensor:
        """Run one diffusion step on a batch and return the masked loss.

        The step proceeds as follows:

        - Sample diffusion time steps, either shared by all domains or one per domain. When
          ``train_condition`` is set with one t per domain, some supervised domains are turned
          into clean conditions at t=0.
        - Noise the channel-concatenated data.
        - Mask missing domains according to ``approach_spe.empty_handling``.
        - Predict with ``self.model``.
        - Average the element-wise loss over the supervised pixels, optionally excluding the
          condition domains.

        Metric and image logging are triggered according to the logging configuration.

        Args:
            _batch: ``(data, mode)``. ``data`` holds one tensor per domain, followed by the sample
                indices when ``return_indice`` is set. ``mode`` is ``[b, n_dom]`` with 1 for
                available domains.
            stage_prefix: prefix of every logged key (e.g. ``'train'``, ``'valid/ema'``).
            batch_idx: index of the batch in the epoch, used to schedule logging.

        Returns:
            The scalar masked loss.
        """
        stage = self.get_stage()
        if self.params_data.data_params.return_indice:
            # the last element of the data sequence holds the sample indices
            dataidx, mode = _batch
            data = dataidx[:-1]
            idx = dataidx[-1]
            _batch = data, mode
        else:
            data, mode = _batch
            idx = None
        n_dom = mode.shape[1]

        # broadcast the mode to every pixel, for each domain
        mode_per_pixel = broadcast_modes_to_pixels(data, mode)

        data_cat = torch.cat(data, dim=1)
        batch_size, _, _, _ = data_cat.shape

        # Domains used as clean conditions in this batch (1 = condition). The mask is created
        # unconditionally because the loss masking below reads it whenever train_condition is
        # set. Without one_t_per_dom no domain is a condition, so it stays all zeros.
        condition_mask = torch.zeros_like(mode, device=self.device, dtype=mode.dtype)

        # region get t
        if not self.params.approach_spe.one_t_per_dom:  # one t shared by all the domains
            ts = torch.randint(0, self.diffusion_variables.n_steps_training, (batch_size,), device=self.device, dtype=torch.long)
            ts = ts.unsqueeze(1).repeat(1, n_dom)
        else:  # one t per domain
            ts = torch.randint(0, self.diffusion_variables.n_steps_training, (batch_size, n_dom), device=self.device, dtype=torch.long)

            if self.params.approach_spe.train_condition:  # conditions only on rows with more than one supervised domain
                supervision_amount = mode.sum(dim=1)
                more_than_one_supervision = supervision_amount > 1
                for i in range(batch_size):
                    if more_than_one_supervision[i]:
                        # pick a random subset of the supervised domains as clean conditions (t=0):
                        # possibly none, but never all of them
                        available_supervision = mode[i].nonzero().flatten()
                        number_of_condition = random.randint(0, available_supervision.shape[0] - 1)
                        conditions = available_supervision[
                            torch.randperm(available_supervision.size(0))[:number_of_condition]
                        ]
                        ts[i, conditions] = 0
                        condition_mask[i, conditions] = 1
        # endregion

        if self.params.approach_spe.proportion_t0 is not None:
            # replace t with t0 with a certain probability
            t0 = torch.zeros_like(ts, device=self.device, dtype=torch.long)
            ts = torch.where(torch.rand_like(ts, dtype=torch.float32) < self.params.approach_spe.proportion_t0, t0, ts)

        if self.params.approach_spe.replace_missing_t_with_T:
            # missing domains get the last time step (n_steps_training - 1)
            Ts = torch.full_like(ts, self.diffusion_variables.n_steps_training - 1, device=self.device, dtype=torch.long)
            ts = torch.where(mode == 1, ts, Ts).long()

        # region deterioration
        noise = self.get_noise(shapes=[i.shape for i in data], is_train=True)

        sqrt_alpha_bar            = extract(self.diffusion_variables.sqrt_alpha_bar           , ts, self.c_per_domain)
        sqrt_one_minus_alphas_bar = extract(self.diffusion_variables.sqrt_one_minus_alphas_bar, ts, self.c_per_domain)

        batch_mixed = sqrt_alpha_bar * data_cat + sqrt_one_minus_alphas_bar * noise
        # endregion

        # region mask input according to mode
        if self.params.approach_spe.empty_handling == 'noise':
            # replace missing inputs with noise
            batch_mixed = torch.where(mode_per_pixel == 1, batch_mixed, noise)
        elif self.params.approach_spe.empty_handling == 'pred':
            # replace missing inputs with the prediction from the steps before, by replacing with noise
            raise NotImplementedError
        elif self.params.approach_spe.empty_handling == 'minusone':
            # replace missing inputs with -1
            minusone = torch.full_like(batch_mixed, -1, device=self.device, dtype=batch_mixed.dtype)
            batch_mixed = torch.where(mode_per_pixel == 1, batch_mixed, minusone)
        else:
            raise ValueError(f'{self.params.approach_spe.empty_handling=} is not an available empty_handling.')
        # endregion

        batch_recon = self.model(batch_mixed, ts)

        losses = self.loss(batch=data_cat, batch_mixed=batch_mixed, noise_infos=noise, batch_recon=batch_recon, t=ts)
        # only the pixels of available domains contribute to the loss
        loss_mask = mode_per_pixel
        if (not self.params.approach_spe.train_condition_learn_condition) and self.params.approach_spe.train_condition:
            # do not learn to reconstruct the domains used as clean conditions
            loss_mask = loss_mask * (1 - broadcast_modes_to_pixels(data, condition_mask))

        if self.params_data.ignore_0_index_sunrgbd:
            # found the pixels where there is the 0 class in the segmentation
            zero_layer = data[2][:, :1]  # [batch, 1, H, W] with 1 on misc

            # since it's a misc voxel, the whole pixel slide reconstruction error
            # should be ignored
            # NOTE(review): assumes the concatenated channels 4: are the ones to ignore for SUN RGB-D; confirm
            loss_mask[:, 4:] = loss_mask[:, 4:] * (1 - zero_layer)  # it is 1 where there is not the 0 class

        loss = (losses * loss_mask).sum() / loss_mask.sum()

        with torch.no_grad():
            x0_hat = self.get_x0_hat_for_logging(x_t=batch_mixed, noise_predicted=batch_recon, t=ts)  # TODO check this function, this should not be from DDIM since it's training step

        # Log
        self.log_g(f'{stage_prefix}/step', 'loss', loss.item(), prog_bar=True)

        # Metrics
        if stage in self.params.metrics.metrics_logging_stage:
            freqs = self.params.metrics.metrics_logging_freq
            freq = {'train': freqs[0], 'valid': freqs[1], 'test': freqs[2]}[stage]
            if self.current_epoch % freq == 0:
                metrics = self.get_metric_object().get_dict(
                    data=_batch[0],
                    mode=_batch[1],
                    batch_mixed=batch_mixed,
                    noise=noise,
                    batch_recon=batch_recon,
                    data0=x0_hat,
                )
                for metric_name, value in metrics.items():
                    self.log_g(f'{stage_prefix}/step', metric_name, value)

        # Log images
        if is_logging_time(self.params.logging.log_steps, current_epoch=self.current_epoch, batch_idx=batch_idx, stage=stage):
            self.log_strategy.log_train(
                stage_prefix=f'{stage_prefix}/step',
                prediction=batch_recon, input_to_model=batch_mixed, batch=data_cat, ts=ts,
                x0_hat    =x0_hat,
                plMod     =self,
                batch_idx =batch_idx,
                modes     =_batch[1],
            )
        if is_logging_time(self.params.logging.log_generate, current_epoch=self.current_epoch, batch_idx=batch_idx, stage=stage):
            self.log_strategy.log_generate(
                stage_prefix=f'{stage_prefix}/generate',
                plMod     =self,
                batch     =_batch[0],
                modes     =_batch[1],
                batch_idx =batch_idx,
                idx       = idx,
            )
        if is_logging_time(self.params.logging.log_generate_diversity, current_epoch=self.current_epoch, batch_idx=batch_idx, stage=stage):
            self.log_strategy.log_generate_diversity(
                stage_prefix=f'{stage_prefix}/generate_diversity',
                plMod     =self,
                batch     =_batch[0],
                modes     =_batch[1],
                batch_idx =batch_idx,
                idx       = idx,
            )
        return loss

    def validation_step(self, batch, batch_idx):
        """Lightning validation hook, aware of EMA weights.

        - Without EMA, results are logged under ``valid``.
        - With EMA, they go under ``valid/ema`` (EMA weights in place) or ``valid/van``
          (original weights, when ``validate_original_weights`` is set).
        - When ``perform_double_validation`` is also set, the batch is evaluated twice: first
          with the original weights, then with the EMA weights swapped in, and the weights are
          swapped back afterwards.
        """
        ema_cfg = self.params.optimizer.ema
        double_pass = ema_cfg.use and ema_cfg.validate_original_weights and ema_cfg.perform_double_validation

        if double_pass:
            self._step(tuple(batch), 'valid/van', batch_idx=batch_idx)

            self.ema.swap_model_weights(self.trainer)
            self._step(tuple(batch), 'valid/ema', batch_idx=batch_idx)
            self.ema.swap_model_weights(self.trainer)
            return

        if not ema_cfg.use:
            suffix = ""
        elif ema_cfg.validate_original_weights:
            suffix = "/van"
        else:
            suffix = "/ema"
        self._step(tuple(batch), f'valid{suffix}', batch_idx=batch_idx)

    def test_step(self, batch, batch_idx):
        """Lightning test hook, aware of EMA weights.

        - Without EMA, results are logged under ``test``.
        - With EMA, they go under ``test/ema`` (EMA weights in place) or ``test/van``
          (original weights, when ``validate_original_weights`` is set).
        - When ``perform_double_validation`` is also set, the batch is evaluated twice: first
          with the original weights, then with the EMA weights swapped in, and the weights are
          swapped back afterwards.
        """
        ema_cfg = self.params.optimizer.ema
        double_pass = ema_cfg.use and ema_cfg.validate_original_weights and ema_cfg.perform_double_validation

        if double_pass:
            self._step(tuple(batch), 'test/van', batch_idx=batch_idx)

            self.ema.swap_model_weights(self.trainer)
            self._step(tuple(batch), 'test/ema', batch_idx=batch_idx)
            self.ema.swap_model_weights(self.trainer)
            return

        if not ema_cfg.use:
            suffix = ""
        elif ema_cfg.validate_original_weights:
            suffix = "/van"
        else:
            suffix = "/ema"
        self._step(tuple(batch), f'test{suffix}', batch_idx=batch_idx)

    @jaxtyped
    @typechecker
    @torch.no_grad()
    def get_x0_hat_for_logging(  # TODO ask this question on discord
        self,
        x_t: Float[torch.Tensor, 'b c_ndom h w'],
        noise_predicted: Float[torch.Tensor, 'b c_ndom h w'],
        t: Int[torch.Tensor, 'b ndom'],
    ):
        """Estimate x_0 from x_t and the predicted noise; used only for logging during training.

        Inverts x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps with the per-domain
        time steps ``t``:

            x_0 = (x_t - sqrt(1 - alpha_bar_t) * eps) / sqrt(alpha_bar_t)
        """
        variables = self.diffusion_variables

        signal_scale = extract(variables.sqrt_alpha_bar, t, self.c_per_domain)
        noise_scale = extract(variables.sqrt_one_minus_alphas_bar, t, self.c_per_domain)

        return (x_t - noise_scale * noise_predicted) / signal_scale

########################################################################################################################

    @jaxtyped
    @typechecker
    def get_noise(
        self,
        shapes: List,
        is_train: bool = False,
    ) -> Float[torch.Tensor, 'b c_ndom h w']:
        """
        Return the noise to combine (or initial) for each of the domains concatenated in the channel dimension.
        """
        same_noise_per_domain = self.params.approach_spe.noise_train if is_train else self.params.approach_spe.noise_gen

        if same_noise_per_domain == 'constant_noise':
            shape = max(shapes, key=lambda s: s[1])  # generate noise for the domain with more channels
            noise = torch.randn(*shape, device=self.device)
            noises = [noise[:, :c] for _, c, _, _ in shapes]  # crop the noise to the correct number of channels
            noises = torch.cat(noises, dim=1)
        elif same_noise_per_domain == 'vanilla_noise':
            nb_c = sum([c for _, c, _, _ in shapes])
            batch_noise = shapes[0][0]
            h_noise = shapes[0][2]
            w_noise = shapes[0][3]
            shapes_noise = [batch_noise, nb_c, h_noise, w_noise]
            noises = torch.randn(shapes_noise, device=self.device)
        else:
            raise NotImplementedError

        return noises

    @jaxtyped
    @typechecker
    def deteriorate_condition(
        self,
        condition: List[Float[torch.Tensor, 'b _ci h w']],
        t        : Int[torch.Tensor, 'b ndom'],
    ) -> Float[torch.Tensor, 'b c_ndom h w']:
        """Noise the clean condition to the per-domain time steps ``t``.

        Returns ``sqrt(alpha_bar_t) * condition + sqrt(1 - alpha_bar_t) * noise``, with the
        domains concatenated along channels. The noise comes from ``get_noise`` in generation
        mode. The result is used to inject the condition when building x_t.
        """
        clean = torch.cat(condition, dim=1)
        noise = self.get_noise(shapes=[part.shape for part in condition])

        variables = self.diffusion_variables
        signal_coef = extract(variables.sqrt_alpha_bar, t, self.c_per_domain)
        noise_coef = extract(variables.sqrt_one_minus_alphas_bar, t, self.c_per_domain)

        return signal_coef * clean + noise_coef * noise

    @jaxtyped
    @typechecker
    def from_ndom_to_cndom(
        self,
        tensor: Shaped[torch.Tensor, 'b ndom'],
    ) -> Shaped[torch.Tensor, 'b c_ndom']:
        """Expand a per-domain tensor ``[b, ndom]`` into a per-channel tensor ``[b, c_ndom]``.

        Column ``i`` is repeated ``dimension_per_domain[i]`` times along dim 1.
        """
        channels_per_domain = self.params_data.data_params.dimension_per_domain
        columns = [
            tensor[:, dom:dom + 1].repeat(1, n_channels)
            for dom, n_channels in enumerate(channels_per_domain)
        ]
        return torch.cat(columns, dim=1)

    # region UTILS GENERATION
    @jaxtyped
    @typechecker
    def generate_samples(
        self,
        condition: List[Float[torch.Tensor, 'b _ci h w']],
        modes: Float[torch.Tensor, 'b n_dom'],

        undersampling: Optional[int] = 1,
        strategy: str = 'uniform',
        quad_factor: float = 0.8,
    ):
        """Generate all domains conditioned on the available ones.

        Starts from x_T: noised conditions for the available domains, pure noise elsewhere. It
        then walks the generation time steps, possibly augmented with jumps. The noised
        conditions are re-injected before every backward step, and jump steps apply one forward
        diffusion step instead.

        Args:
            condition: one tensor per domain, optionally followed by the dataset index, which is
                dropped. Domains whose mode is not 1 are replaced by noise.
            modes: ``[b, n_dom]``, 1 where the domain is given as condition.
            undersampling: forwarded as ``nb_samples`` to ``get_undersample_indices`` to choose
                which intermediate steps are kept; ``None`` keeps every step. It is annotated
                ``Optional[int]`` so the runtime type check accepts ``None``.
            strategy: forwarded to ``get_undersample_indices``.
            quad_factor: forwarded to ``get_undersample_indices``.

        Returns:
            ``(x_0, samples, x0s)``:

            - ``x_0``: the final sample (after the optional post-normalisation and clamping).
            - ``samples``: the model inputs x_t at the kept steps, with the final sample appended.
            - ``x0s``: the matching x_0 predictions.
        """
        nb_domain = modes.shape[-1]
        condition = copy(condition)  # shallow copy: per-domain entries are replaced below
        # remove index if there is an index from that dataset
        if len(condition) == nb_domain + 1:
            condition = condition[:-1]
        else:
            assert len(condition) == nb_domain, f'{len(condition)=} should be {nb_domain=}'

        modes_per_channel = self.from_ndom_to_cndom(modes)

        # region remove data where there is no condition once for all
        for domain in range(nb_domain):
            modes_per_pixel_dom = broadcast_modes_to_pixels([condition[domain]], modes[:, domain].unsqueeze(1))
            noise = torch.randn_like(condition[domain])
            condition[domain] = torch.where(modes_per_pixel_dom == 1, condition[domain], noise)
        # endregion
        condition_cat = torch.cat(condition, dim=1)
        b, c_ndom, _, _ = condition_cat.shape
        is_condition_channel = modes_per_channel.reshape([b, c_ndom, 1, 1]) == 1

        # region variables definition
        samples = []
        x0s     = []
        n_dom = len(condition)

        all_time_steps = self.get_time_steps()
        initial_time_steps_length = len(all_time_steps)
        all_indexes = torch.arange(len(all_time_steps)).flip(0)  # in decreasing order
        if self.diffusion_variables.jump:
            all_time_steps, all_indexes = self.augmente_time_steps_with_jumps(all_time_steps)
        # endregion

        # region obtain x_T the initial sample
        _, t = self.combine_time_steps(
            index=all_indexes[0],
            step=all_time_steps[0],
            modes=modes,
        )
        x_T = self.deteriorate_condition(condition=condition, t=t)
        # we already have replaced with noise the domains with no supervision
        x_T = torch.where(is_condition_channel, x_T, condition_cat)
        # endregion

        # region define logging variables: loop iterations whose x_t / x_0 prediction are kept
        if undersampling is not None:
            initial_sampled_steps = get_undersample_indices(
                src_length=initial_time_steps_length,
                nb_samples=undersampling,
                strategy=strategy,
                quad_factor=quad_factor,
            )
            # With jumps the same index is visited several times: keep the last loop
            # iteration at which each sampled index is seen.
            seen_steps = dict()
            for indice_loop, index in enumerate(all_indexes):
                if index in initial_sampled_steps:
                    seen_steps[index.item()] = indice_loop
            sampled_steps = sorted([seen_steps[index] for index in initial_sampled_steps])
        else:
            sampled_steps = list(range(len(all_time_steps)))
        # drop the last step (the final sample is appended after the loop anyway);
        # a set gives O(1) membership tests inside the generation loop
        sampled_steps = set(sampled_steps[:-1])
        # endregion

        x_t = x_T
        pred_x_0 = None
        for indice_loop, (i, step) in enumerate(zip(tqdm(all_indexes, position=0, leave=True, desc="Generation Loop"), all_time_steps)):
            # compute the indices
            indexes, t = self.combine_time_steps(index=i, step=step, modes=modes)

            next_i = all_indexes[indice_loop + 1] if indice_loop + 1 < len(all_indexes) else None
            if next_i is not None and i < next_i:  # jump operation, forward diffusion
                x_t = self.forward_diffusion_step(
                    x_t=x_t,
                    t=t,
                )
                x_t_minus_1 = x_t  # just for the end of the loop affectation
            else:  # normal operation, backward diffusion
                # re-inject the noised condition on the condition channels only
                condition_x_t = self.deteriorate_condition(condition=condition, t=t)
                x_t = torch.where(is_condition_channel, condition_x_t, x_t)

                if not self.params.approach_spe.one_t_per_dom:
                    # with a single t for all domains, rebuild indexes / t from the current step
                    indexes = torch.full([b, n_dom], fill_value=i.item(), device=self.device).long()
                    t = torch.full([b, n_dom], fill_value=step.item(), device=self.device).long()

                x_t_minus_1, pred_x_0, _ = self.backward_diffusion_step(
                    x_t=x_t,
                    t=t,
                    indexes=indexes,
                )

            if indice_loop in sampled_steps:
                samples.append(x_t)
                x0s.append(pred_x_0)

            x_t = x_t_minus_1

        if self.params.approach_spe.post_norm is None:
            pass
        elif self.params.approach_spe.post_norm == 'norm_minmax':
            x_t = norm_minmax(x_t, self.params_data.data_params.dimension_per_domain)
            pred_x_0 = norm_minmax(pred_x_0, self.params_data.data_params.dimension_per_domain)
        elif self.params.approach_spe.post_norm == 'norm_max':
            x_t = norm_max(x_t, self.params_data.data_params.dimension_per_domain)
            pred_x_0 = norm_max(pred_x_0, self.params_data.data_params.dimension_per_domain)
        else:
            raise ValueError(f'{self.params.approach_spe.post_norm=} is not an available post_norm.')

        if self.diffusion_variables.clamp_end:
            clamp_min, clamp_max = self.diffusion_variables.clamp_min_max
            x_t = torch.clamp(x_t, clamp_min, clamp_max)
            pred_x_0 = torch.clamp(pred_x_0, clamp_min, clamp_max)

        samples.append(x_t)
        x0s.append(pred_x_0)
        # return x_t, inputs to models, x0 prediction
        return x_t, samples, x0s

    @jaxtyped
    @typechecker
    def get_condition_time_steps(
        self,
        index: int,  # current index in the diffusion process  [high=noisy, low=clean]
        step: int,   # current t in the diffusion process  [high=noisy, low=clean]
    ) -> Tuple[
        int,  # index condition
        int,  # time_step condition
    ]:
        """
        Given the index and time step of the generation, return the index and time step used for the condition.

        Naming convention for local variables:
            *_NC => low = noisy, high = clean (a depth in the denoising progression)
            *_CN => low = clean, high = noisy (an index into the clean-to-noisy lists)
        For lists, "low" means the first elements and "high" the last elements.

        Supported `approach_spe.condition_mode` values:
            'noisy'               : the condition follows the (shifted) generation index.
            'clean'               : the condition always uses the cleanest step.
            'noisy_constant'      : the condition stays `noisy_condition_progression` steps deep in the
                                    progression (0 = noisiest step); values past the last step saturate
                                    at the cleanest step.
            'noisy_constant_fade' : like 'noisy_constant', but once the generation index is at or below the
                                    condition index (generation at least as clean), the condition follows
                                    the generation.
            'noisy_skip'          : the condition is `noisy_condition_progression` steps deeper (cleaner)
                                    than the generation, saturating at the cleanest step.

        The generation index is first shifted by `approach_spe.shift_from_gen` and clamped to the valid range.
        `step` is not used by the computation; time steps are read from `self.get_time_steps()`
        (i.e. without the jump augmentation).

        Raises:
            ValueError: if `condition_mode` is not one of the modes listed above.
        """
        ap = self.params.approach_spe
        cond_prog_NC = ap.noisy_condition_progression

        # steps and indexes without the jump augmentation
        all_time_steps_CN = self.get_time_steps().flip(0).tolist()  # from clean to noisy
        all_indexes_CN = list(range(len(all_time_steps_CN)))        # indexes from clean to noisy
        last_index_CN = len(all_time_steps_CN) - 1                  # noisiest index

        # apply the shift on the generation index and keep it inside [0, last_index_CN]
        index_CN = max(min(index - ap.shift_from_gen, last_index_CN), 0)

        mode = ap.condition_mode
        if mode == 'noisy':
            return index_CN, all_time_steps_CN[index_CN]

        if mode == 'clean':
            index_cond = all_indexes_CN[0]

        elif mode in ('noisy_constant', 'noisy_constant_fade'):
            # the progression counts steps from the noisiest end; saturate it at the cleanest step
            # (previously a progression >= number of steps raised IndexError, e.g. with short DDIM schedules)
            depth_cond_NC = min(cond_prog_NC, last_index_CN)
            index_cond = all_indexes_CN[::-1][depth_cond_NC]
            # fade: once the generation is at least as clean as the condition, follow the generation
            if mode == 'noisy_constant_fade' and index_CN <= index_cond:
                return index_CN, all_time_steps_CN[index_CN]

        elif mode == 'noisy_skip':
            depth_gen_NC = last_index_CN - index_CN  # how many steps deep the generation currently is
            # the condition is cond_prog_NC steps deeper, but never deeper than the cleanest step
            depth_cond_NC = min(depth_gen_NC + cond_prog_NC, last_index_CN)
            index_cond = max(all_indexes_CN[0], all_indexes_CN[::-1][depth_cond_NC])

        else:
            raise ValueError(f"Unknown condition mode {ap.condition_mode}")

        return index_cond, all_time_steps_CN[index_cond]

    @jaxtyped
    @typechecker
    def combine_time_steps(
        self,
        index: Int[torch.Tensor, ''],  # current index in the diffusion process
        step: Int[torch.Tensor, ''],   # current t in the diffusion process

        modes: Float[torch.Tensor, 'b n_dom'],
    ) -> Tuple[
        Int[torch.Tensor, 'b n_dom'],  # indexes
        Int[torch.Tensor, 'b n_dom'],  # time_steps
    ]:
        """
        Build the per-sample, per-domain (index, time step) tensors for the current generation step.

        Domains flagged with `modes == 1` receive the condition index/time step returned by
        `get_condition_time_steps`; all other domains receive the generation `index`/`step`.
        """
        batch, n_dom = modes.shape
        cur_index, cur_step = index.item(), step.item()

        def filled(value):
            # constant (batch, n_dom) int64 tensor on the module device
            return torch.full([batch, n_dom], fill_value=value, device=self.device).long()

        gen_indexes = filled(cur_index)
        gen_steps = filled(cur_step)

        cond_index, cond_step = self.get_condition_time_steps(index=cur_index, step=cur_step)

        is_condition = modes == 1
        indexes = torch.where(is_condition, filled(cond_index), gen_indexes)
        time_steps = torch.where(is_condition, filled(cond_step), gen_steps)
        return indexes, time_steps

    @jaxtyped
    @typechecker
    def augmente_time_steps_with_jumps(
        self,
        original_time_steps: Int[torch.Tensor, 'nb_diffusion_steps'],
    ) -> Tuple[Int[torch.Tensor, 'augmented_steps'], Int[torch.Tensor, 'augmented_steps']]:
        """
        Expand a time-step schedule with the jump (resampling) pattern.

        `jump_len` and `jump_n_sample` are read from `self.diffusion_variables`.

        Returns:
            (time_steps, indices): `indices` index into the clean-to-noisy (flipped) version of
            `original_time_steps`, and `time_steps` are the corresponding time steps.

        Raises:
            NotImplementedError: for diffusion variables that are neither DDPM nor DDIM.
        """
        variables = self.diffusion_variables
        jump_len, jump_n_sample = variables.jump_len, variables.jump_n_sample

        if isinstance(variables, DiffusionVariables_DDPM):
            build_indices = self.augmente_time_steps_with_jumps_ddpm
        elif isinstance(variables, DiffusionVariables_DDIM):
            build_indices = self.augmente_time_steps_with_jumps_ddim
        else:
            raise NotImplementedError

        indices = build_indices(len(original_time_steps), jump_len, jump_n_sample)
        clean_to_noisy = original_time_steps.flip(0)
        return clean_to_noisy[indices], indices

    @jaxtyped
    @typechecker
    def augmente_time_steps_with_jumps_ddpm(self, t_T, jump_len, jump_n_sample) -> Int[torch.Tensor, 'augmented_steps']:
        """
        Build the jump schedule (as in RePaint) over the indexes `t_T - 1 ... 0`.

        Every `jump_len`-th index `j` in `[0, t_T - jump_len)` triggers `jump_n_sample - 1` jumps:
        each time the descent reaches `j`, the schedule climbs back up `jump_len` indexes before
        descending again.

        Returns:
            int64 tensor on `self.device` with the visited indexes, in visiting order.
        """
        jumps_left = {start: jump_n_sample - 1 for start in range(0, t_T - jump_len, jump_len)}

        schedule = []
        current = t_T
        while current >= 1:
            current -= 1
            schedule.append(current)

            if jumps_left.get(current, 0) > 0:
                jumps_left[current] -= 1
                # go back up jump_len indexes, recording every intermediate one
                schedule.extend(range(current + 1, current + jump_len + 1))
                current += jump_len

        return torch.tensor(schedule, device=self.device).long()

    def augmente_time_steps_with_jumps_ddim(self, t_T, jump_len, jump_n_sample) -> torch.Tensor:
        """
        Jump schedule for DDIM: identical to the DDPM one.

        The return type is the int64 tensor produced by `augmente_time_steps_with_jumps_ddpm`
        (previously mis-annotated as `List[int]`).
        """
        return self.augmente_time_steps_with_jumps_ddpm(t_T, jump_len, jump_n_sample)

    @jaxtyped
    @typechecker
    def get_time_steps(self) -> Int[torch.Tensor, 'nb_diffusion_steps']:
        """
        Return the generation time steps for the configured sampler (DDIM or DDPM).

        Raises:
            NotImplementedError: for any other kind of diffusion variables.
        """
        variables = self.diffusion_variables
        if isinstance(variables, DiffusionVariables_DDIM):
            return self.get_time_steps_ddim()
        if isinstance(variables, DiffusionVariables_DDPM):
            return self.get_time_steps_ddpm()
        raise NotImplementedError

    @jaxtyped
    @typechecker
    def get_time_steps_ddim(self) -> Int[torch.Tensor, 'nb_diffusion_steps']:
        """
        Return `diffusion_variables.time_steps` reversed, with the first `skip_steps` entries dropped,
        moved to `self.device` (tau_S, ..., tau_1, assuming `time_steps` is stored in increasing order).
        """
        df: DiffusionVariables_DDIM = self.diffusion_variables
        # NOTE(review): original open question - should the schedule stop at 1 or at 0?
        reversed_steps = df.time_steps.flip(dims=[0])
        return reversed_steps[df.skip_steps:].to(self.device)

    @jaxtyped
    @typechecker
    def get_time_steps_ddpm(self) -> Int[torch.Tensor, 'nb_diffusion_steps']:
        """
        Return every training time step in decreasing order: `n_steps_training - 1, ..., 0`.
        """
        df: DiffusionVariables_DDPM = self.diffusion_variables
        last_step = df.n_steps_training - 1
        return torch.arange(last_step, -1, -1, device=self.device)

    @jaxtyped
    @typechecker
    def forward_diffusion_step(
        self,
        x_t: Float[torch.Tensor, 'b c_n_dom h w'],
        t: Int[torch.Tensor, 'b n_dom'],
    ) -> Float[torch.Tensor, 'b c_n_dom h w']:  # x_t+1
        """
        Apply a single forward (noising) step from x_t to x_t+1 - not the closed form from x0.

        The step is applied to every channel, condition or not:
            x_t+1 = sqrt_alpha[t] * x_t + sqrt_beta[t] * noise
        and the result is clamped to `clamp_min_max` when `clamp_generation` is enabled.
        """
        df: DiffusionVariables_DDPM = self.diffusion_variables
        noise = self.get_noise(shapes=[list(x_t.shape)])

        signal_scale = extract(df.sqrt_alpha, t, self.c_per_domain)
        noise_scale = extract(df.sqrt_beta, t, self.c_per_domain)
        x_next = signal_scale * x_t + noise_scale * noise

        if df.clamp_generation:
            low, high = df.clamp_min_max[0], df.clamp_min_max[1]
            x_next = torch.clamp(x_next, min=low, max=high)
        return x_next

    @jaxtyped
    @typechecker
    def backward_diffusion_step(
        self,
        x_t: Float[torch.Tensor, 'b c_n_dom h w'],
        t: Int[torch.Tensor, 'b n_dom'],
        indexes: Int[torch.Tensor, 'b n_dom'],
    ) -> Tuple[
        Float[torch.Tensor, 'b c_n_dom h w'],  # x_t-1
        Float[torch.Tensor, 'b c_n_dom h w'],  # x_0_t
        Float[torch.Tensor, 'b c_n_dom h w'],  # e_t
    ]:
        """
        Run one denoising step with the sampler matching `self.diffusion_variables`.

        Returns:
            (x_t-1, predicted x_0, predicted noise).

        Raises:
            NotImplementedError: for diffusion variables that are neither DDIM nor DDPM.
        """
        variables = self.diffusion_variables
        if isinstance(variables, DiffusionVariables_DDIM):
            sampler_step = self.backward_diffusion_ddim
        elif isinstance(variables, DiffusionVariables_DDPM):
            sampler_step = self.backward_diffusion_ddpm
        else:
            raise NotImplementedError
        return sampler_step(x_t=x_t, t=t, indexes=indexes)

    @jaxtyped
    @typechecker
    def backward_diffusion_ddpm(
        self,
        x_t: Float[torch.Tensor, 'b c_n_dom h w'],
        t: Int[torch.Tensor, 'b n_dom'],
        indexes: Int[torch.Tensor, 'b n_dom'],
    ) -> Tuple[
        Float[torch.Tensor, 'b c_n_dom h w'],  # x_t-1
        Float[torch.Tensor, 'b c_n_dom h w'],  # x_0_t
        Float[torch.Tensor, 'b c_n_dom h w'],  # e_t
    ]:
        """
        One DDPM ancestral sampling step from x_t to x_t-1.

        The dimension names now match the outputs and the sibling methods ('c_n_dom', 'n_dom'), so the
        jaxtyping check actually verifies that the outputs keep the shape of `x_t`.

        Args:
            x_t: current sample (channels of all domains stacked on dim 1).
            t: time step per sample and per domain.
            indexes: not used by the DDPM update; kept so the signature matches `backward_diffusion_ddim`.

        Returns:
            (x_t-1, predicted x_0, predicted noise eps_theta).
        """
        df: DiffusionVariables_DDPM = self.diffusion_variables

        # gaussian noise, zeroed on the channels whose (per-channel expanded) time step is <= 1
        # NOTE(review): the usual DDPM rule only drops the noise at t == 0; confirm that `> 1` is intended.
        z = torch.randn_like(x_t, device=self.device)
        t_c_ndom = self.from_ndom_to_cndom(t)
        z = torch.where(t_c_ndom.unsqueeze(-1).unsqueeze(-1) > 1, z, torch.zeros_like(z))

        eps_theta = self.model(x_t, t)

        # posterior mean: sqrt_recip_alpha[t] * (x_t - eps_coef[t] * eps_theta)
        eps_coef = extract(df.eps_coef, t, self.c_per_domain)
        sqrt_recip_alpha = extract(df.sqrt_recip_alpha, t, self.c_per_domain)
        mean = sqrt_recip_alpha * (x_t - eps_coef * eps_theta)
        if df.clamp_generation:
            mean = torch.clamp(mean, min=df.clamp_min_max[0], max=df.clamp_min_max[1])
        var_sqrt = extract(df.sigma2_sqrt, t, self.c_per_domain)

        # region compute x0 from x_t and the predicted noise
        sqrt_recip_alpha_bar = extract(df.sqrt_recip_alpha_bar, t, self.c_per_domain)
        sqrt_recip_m1_alpha_bar = extract(df.sqrt_recip_m1_alpha_bar, t, self.c_per_domain)
        x0 = sqrt_recip_alpha_bar * x_t - sqrt_recip_m1_alpha_bar * eps_theta
        if df.clamp_generation:
            x0 = torch.clamp(x0, min=df.clamp_min_max[0], max=df.clamp_min_max[1])
        # endregion

        x_t_minus_one = mean + var_sqrt * z

        return x_t_minus_one, x0, eps_theta

    @jaxtyped
    @typechecker
    def backward_diffusion_ddim(
        self,
        x_t: Float[torch.Tensor, 'b c_n_dom h w'],  # tau_t
        t: Int[torch.Tensor, 'b n_dom'],
        indexes: Int[torch.Tensor, 'b n_dom'],
    ) -> Tuple[
        Float[torch.Tensor, 'b c_n_dom h w'],  # x_t-1
        Float[torch.Tensor, 'b c_n_dom h w'],  # x_0_t
        Float[torch.Tensor, 'b c_n_dom h w'],  # e_t
    ]:
        """
        One DDIM step from x_tau_i to x_tau_(i-1).

        The model is queried with the time steps `t`, while the DDIM coefficients are gathered with
        the schedule positions `indexes`.

        Returns:
            (x_tau_(i-1), predicted x_0, predicted noise e_t).
        """
        batch, _, height, width = x_t.shape
        df: DiffusionVariables_DDIM = self.diffusion_variables
        per_domain = self.c_per_domain

        e_t = self.model(x_t, t)

        def coef(table):
            # per-(sample, channel) coefficient gathered at the schedule positions
            return extract(table, indexes, per_domain)

        alpha_sqrt = coef(df.ddim_alpha_sqrt)
        alpha_prev_sqrt = coef(df.ddim_alpha_prev_sqrt)
        sigma = coef(df.ddim_sigma)
        sqrt_one_minus_alpha = coef(df.ddim_sqrt_one_minus_alpha)
        dir_xt_coef = coef(df.dir_xt_coef)

        # current estimate of x0
        pred_x0 = (x_t - sqrt_one_minus_alpha * e_t) / alpha_sqrt
        if df.clamp_generation:
            pred_x0 = torch.clamp(pred_x0, min=df.clamp_min_max[0], max=df.clamp_min_max[1])
        # direction pointing to x_t
        dir_xt = dir_xt_coef * e_t

        # TODO check: maybe the added noise should be 0 where t is very low
        noise_shapes = [[batch, channels, height, width]
                        for channels in self.params_data.data_params.dimension_per_domain]
        noise = self.get_noise(shapes=noise_shapes)
        noise = torch.where(sigma == 0, torch.zeros_like(x_t), noise).to(self.device)
        if df.repeat_noise:
            noise = noise[0:1]  # share the first sample's noise across the batch (broadcast)
        noise = noise * df.temperature

        x_prev = alpha_prev_sqrt * pred_x0 + dir_xt + sigma * noise
        return x_prev, pred_x0, e_t
    # endregion

    @rank_zero_only
    def log_generation_time_steps(
        self,
        logging_params: LoggingParams,
    ):
        """
        Plot the generation time-step schedule (and the condition schedule) and log it to wandb
        under the key 'generation diffusion steps'.
        """
        self._plot_generation_time_steps(logging_params, log_key='generation diffusion steps')

    @rank_zero_only
    def test_log_generation_time_steps(
        self,
        logging_params: LoggingParams,
    ):
        """
        Log one schedule plot per (condition_mode, noisy_condition_progression) combination.

        The sweep temporarily overwrites `self.params.approach_spe.condition_mode` and
        `noisy_condition_progression`; both are restored afterwards (even if plotting raises), so
        calling this no longer changes the configuration used by later generations.
        """
        ap = self.params.approach_spe
        saved_mode = ap.condition_mode
        saved_progression = ap.noisy_condition_progression
        progression_sweep = [0, 10, 20, 30, 40, 50, 60, 70, 80, 90, 99]
        try:
            for condition_mode in [
                'noisy',
                'clean',
                'noisy_constant',
                'noisy_constant_fade',
                'noisy_skip',
            ]:
                # 'noisy' and 'clean' do not read the progression: a single plot is enough
                progressions = [0] if condition_mode in ['noisy', 'clean'] else progression_sweep
                for noisy_condition_progression in progressions:
                    ap.condition_mode = condition_mode
                    ap.noisy_condition_progression = noisy_condition_progression
                    self._test_log_generation_time_steps(logging_params)
        finally:
            ap.condition_mode = saved_mode
            ap.noisy_condition_progression = saved_progression

    def _test_log_generation_time_steps(
        self,
        logging_params: LoggingParams,
    ):
        """
        Plot and log the schedule for the current condition settings, keyed by
        "<condition_mode> <noisy_condition_progression>".
        """
        ap = self.params.approach_spe
        self._plot_generation_time_steps(
            logging_params,
            log_key=f'{ap.condition_mode} {ap.noisy_condition_progression}',
        )

    def _plot_generation_time_steps(
        self,
        logging_params: LoggingParams,
        log_key: str,
    ):
        """
        Plot the generation time steps, the condition time steps and the saved-sample positions,
        then log the figure to wandb under `log_key`.

        - Blue-ish first line: time step used at each iteration of the generation loop.
        - Second line: condition time step (None, i.e. a gap, on jump/forward iterations).
        - Red dots: loop iterations whose samples are saved (`logging_params.time_step_in_process`).
        """
        undersampling = logging_params.time_step_in_process

        time_steps = self.get_time_steps()
        initial_length = len(time_steps)
        indexes = torch.arange(initial_length).flip(0)  # in decreasing order
        if self.diffusion_variables.jump:
            time_steps, indexes = self.augmente_time_steps_with_jumps(time_steps)

        # plain python ints: matplotlib cannot consume tensors that live on an accelerator
        time_steps = time_steps.cpu().tolist()
        indexes = indexes.cpu().tolist()

        if undersampling is not None:
            initial_sampled_steps = get_undersample_indices(
                src_length=initial_length,
                nb_samples=undersampling,
                strategy=logging_params.strategy,
                quad_factor=logging_params.quad_factor,
            )
            # map each sampled index to the loop iteration where it is saved; with jumps an index is
            # visited several times and the last visit wins
            last_seen = {}
            for loop_idx, index in enumerate(indexes):
                if index in initial_sampled_steps:
                    last_seen[index] = loop_idx
            sampled_steps = [last_seen[index] for index in initial_sampled_steps]
        else:
            sampled_steps = range(len(time_steps))

        conditions_steps = []
        for loop_idx, (index, step) in enumerate(zip(indexes, time_steps)):
            is_jump = loop_idx + 1 < len(indexes) and index < indexes[loop_idx + 1]
            if is_jump:  # jump operation, forward diffusion: no condition step
                conditions_steps.append(None)
            else:  # normal operation, backward diffusion
                _, cond_step = self.get_condition_time_steps(index=index, step=step)
                conditions_steps.append(cond_step)

        fig, ax = plt.subplots()
        try:
            ax.plot(time_steps)
            ax.plot(conditions_steps)
            ax.scatter(sampled_steps, [time_steps[k] for k in sampled_steps], color='red')
            ax.set_xlabel('Generation loop steps')
            ax.set_ylabel('Time steps')
            wandb.log({log_key: wandb.Image(fig)})
        finally:
            # release the figure: the sweep above creates dozens of them
            plt.close(fig)

    def get_stage(self) -> str:
        """
        Map the trainer state to a stage name: 'train', 'valid' (validation or sanity check)
        or 'test' (testing or predicting).

        Raises:
            Exception: when the trainer is in none of these states.
        """
        trainer = self.trainer
        if trainer.training:
            return 'train'
        if trainer.validating or trainer.sanity_checking:
            return 'valid'
        if trainer.testing or trainer.predicting:
            return 'test'
        raise Exception('Stage not supported.')

    def get_metric_object(self):
        """
        Return the metrics object of the current stage (train_metrics, valid_metrics or test_metrics).
        """
        metrics_by_stage = dict(
            train=self.train_metrics,
            valid=self.valid_metrics,
            test=self.test_metrics,
        )
        current_stage = self.get_stage()
        return metrics_by_stage[current_stage]
