# -*- coding: utf-8 -*-

from typing import List, Tuple

import numpy as np
import torch
import torch.nn.functional as F

from common.constants import CLASSIFICATION, PROMPT
from tabular_diffusion.multinomial_diffusion_utils import cosine_beta_schedule, extract, index_to_log_onehot, \
    log_1_min_a, log_add_exp, log_onehot_to_index


class DiffusionManager:
    """Joint Gaussian / multinomial diffusion process for tabular rows.

    A row is laid out as ``[continuous features | categorical features]``: the first ``num_cont`` columns hold
    real values diffused with Gaussian noise, the remaining ``len(num_classes)`` columns hold class indices
    diffused with the multinomial (uniform-noise) process. The multinomial helpers work on the categorical part
    as per-feature log one-hot blocks concatenated along dim 1 (total width ``sum(num_classes)``).
    """

    def __init__(self,
                 num_cont: int,
                 num_classes: List[int],
                 condition_type: str,
                 timesteps: int = 1000,
                 problem_type: str = CLASSIFICATION,
                 device: str = 'cuda'):
        """
        :param num_cont: int, number of continuous features
        :param num_classes: List[int], number of classes of each categorical feature
        :param condition_type: str, "only_target", "prompt" or None; ``total_loss`` only honours the
                               caller's mask when this equals PROMPT
        :param timesteps: int, number of timesteps of the diffusion process
        :param problem_type: str, CLASSIFICATION or REGRESSION
        :param device: str, device on which the schedule tensors are stored
        """
        self.device = device
        self.num_cont = num_cont
        self.num_cat = len(num_classes)
        self.num_classes = num_classes
        self.condition_type = condition_type
        self.timesteps = timesteps
        self.problem_type = problem_type
        self.target_index = -1 if problem_type == CLASSIFICATION else 0
        self.target_dtype = torch.long if problem_type == CLASSIFICATION else torch.float

        # Per-step alphas, kept in float64 while the log-space quantities are derived.
        alphas = torch.tensor(cosine_beta_schedule(timesteps).astype('float64'))

        # Multinomial (log space). torch ops are used directly rather than NumPy ufuncs, which only
        # return tensors by round-tripping through Tensor.__array_wrap__.
        log_alpha = torch.log(alphas)
        log_cumprod_alpha = torch.cumsum(log_alpha, dim=0)
        log_1_min_alpha = log_1_min_a(log_alpha)
        log_1_min_cumprod_alpha = log_1_min_a(log_cumprod_alpha)

        # Sanity checks: log(a + (1 - a)) must be 0 for both the per-step and the cumulative schedule.
        assert log_add_exp(log_alpha, log_1_min_alpha).abs().sum().item() < 1.e-5
        assert log_add_exp(log_cumprod_alpha, log_1_min_cumprod_alpha).abs().sum().item() < 1e-5

        # Per-timestep loss statistics consumed by sample_time(method='importance'). Nothing in this class
        # updates them, so importance sampling falls back to uniform unless a caller fills them in.
        self.lt_history = torch.zeros(timesteps)
        self.lt_count = torch.zeros(timesteps)

        # Convert to float32 and move to the target device.
        self.log_alpha = log_alpha.float().to(self.device)
        self.log_cumprod_alpha = log_cumprod_alpha.float().to(self.device)
        self.log_1_min_alpha = log_1_min_alpha.float().to(self.device)
        self.log_1_min_cumprod_alpha = log_1_min_cumprod_alpha.float().to(self.device)

        # Gaussian (derived from the float32 alphas, already on self.device)
        self.alpha = alphas.float().to(self.device)
        self.betas = (1 - alphas).float().to(self.device)
        self.alpha_cumulative = torch.cumprod(self.alpha, dim=0)
        self.sqrt_alpha_cumulative = torch.sqrt(self.alpha_cumulative)
        self.one_by_sqrt_alpha = 1. / torch.sqrt(self.alpha)
        self.sqrt_one_minus_alpha_cumulative = torch.sqrt(1 - self.alpha_cumulative)

    def _cat_blocks(self):
        """Yield ``(j, n, columns)`` for every categorical feature ``j`` with ``n`` classes, where ``columns`` is
        the slice of that feature's block inside a tensor of width ``sum(self.num_classes)``."""
        start = 0
        for j, n in enumerate(self.num_classes):
            yield j, n, slice(start, start + n)
            start += n

    def _log_onehot_to_indices(self, log_x: torch.Tensor) -> torch.Tensor:
        """Turn concatenated per-feature log one-hot blocks into a [batch_size, self.num_cat] index tensor."""
        return torch.cat([log_onehot_to_index(log_x[:, cols]).unsqueeze(dim=1)
                          for _, _, cols in self._cat_blocks()], dim=1)

    def _to_model_input(self, x_num: torch.Tensor, log_x_cat: torch.Tensor) -> torch.Tensor:
        """Build a ``[continuous | categorical indices]`` row tensor from the two diffusion states."""
        parts = [x_num] if self.num_cont > 0 else []
        if self.num_cat > 0:
            parts.append(self._log_onehot_to_indices(log_x_cat))
        return torch.cat(parts, dim=1)

    def total_loss(self,
                   denoising_model: torch.nn.Module,
                   batch: torch.Tensor,
                   mask: torch.Tensor = None,
                   target_weights: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute the training losses for one batch.

        One timestep per row is drawn uniformly, both parts of the row are noised to that timestep, and the
        model is asked for the Gaussian noise (continuous part) and x0 logits (categorical part). Columns with
        ``mask == 1`` are fed to the model un-noised and are zeroed out of both losses.

        :param denoising_model: nn.Module called as ``denoising_model(x_t, t, mask=mask)``
        :param batch: shape=[batch_size, self.num_cont + self.num_cat]
        :param mask: optional, shape=[batch_size, self.num_cont + self.num_cat]; only used when
                     ``condition_type == PROMPT``, otherwise (or when None) nothing is masked
        :param target_weights: optional per-row weights, shape=[batch_size, ]
        :return: (continuous loss, categorical loss); each a scalar, or None when that part is absent.
                 Losses are in nats (no conversion to bits is applied).
        """
        batch_size, device = batch.size(0), batch.device

        # Sample time to use in the stochastic optimization
        t, pt = self.sample_time(batch_size, device, 'uniform')

        # Continuous (gaussian) forward diffusion
        x_num_t, eps_num_t = None, None
        if self.num_cont > 0:
            x_num_t, eps_num_t = self.q_sample_gaussian(batch[:, :self.num_cont], t)

        # Categorical (multinomial) forward diffusion, one feature at a time
        log_x_start, log_x_t = None, None
        if self.num_cat > 0:
            x_start = batch[:, self.num_cont:].to(dtype=torch.long)
            log_x_start, log_x_t = [], []
            for j, n in enumerate(self.num_classes):
                log_x_start.append(index_to_log_onehot(x_start[:, j], n))
                log_x_t.append(self.q_sample_multinomial(log_x_start=log_x_start[-1], t=t, num_classes=n))
            log_x_start = torch.cat(log_x_start, dim=1)
            log_x_t = torch.cat(log_x_t, dim=1)

        x_t = self._to_model_input(x_num_t, log_x_t)

        # Apply the prompt: masked (conditioning) columns keep their clean values.
        if self.condition_type != PROMPT or mask is None:
            mask = torch.zeros((batch_size, self.num_cont + self.num_cat), device=device)
        x_t_mask = (1 - mask) * x_t + mask * batch

        # Call the module (not .forward) so registered hooks run.
        model_prediction = denoising_model(x_t_mask, t, mask=mask)

        cont_loss = None
        if self.num_cont > 0:
            model_prediction_num = model_prediction[:, :self.num_cont]
            cont_loss = self.mean_flat((0.5 * (eps_num_t - model_prediction_num) ** 2) *
                                       (1 - mask[:, :self.num_cont]))
            if target_weights is not None:
                cont_loss = cont_loss * target_weights
            cont_loss = cont_loss.mean()

        vb_loss = None
        if self.num_cat > 0:
            model_prediction_cat = model_prediction[:, self.num_cont:]
            cat_mask = mask[:, self.num_cont:]

            kl = self.compute_lt(model_prediction_cat, log_x_start, log_x_t, t, cat_mask)
            kl_prior = self.kl_prior(log_x_start, cat_mask)

            # Importance-weight the per-timestep term by 1 / p(t), then average over categorical features.
            vb_loss = (kl / pt + kl_prior) / len(self.num_classes)
            if target_weights is not None:
                vb_loss = vb_loss * target_weights
            vb_loss = vb_loss.mean()

        return cont_loss, vb_loss

    def q_sample_gaussian(self,
                          x_start: torch.Tensor,
                          timesteps: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward diffusion for continuous variables:
                x_t = sqrt(cum_alpha_t) * x0 + sqrt(1 - cum_alpha_t) * eps,   eps ~ N(0, I)

        :param x_start: shape = [batch_size, self.num_cont], float
        :param timesteps: shape = [batch_size, ], long
        :return: (x_t, eps), both shaped like ``x_start``
        """
        eps = torch.randn_like(x_start)
        mean = extract(self.sqrt_alpha_cumulative, t=timesteps, x_shape=x_start.shape) * x_start
        std_dev = extract(self.sqrt_one_minus_alpha_cumulative, t=timesteps, x_shape=x_start.shape)
        sample = mean + std_dev * eps
        return sample, eps

    def p_pred_gaussian(self,
                        model_predictions: torch.Tensor,
                        x_t: torch.Tensor,
                        t: torch.Tensor) -> torch.Tensor:
        """One reverse (denoising) step for continuous variables:
                x_{t-1} = 1/sqrt(alpha_t) * (x_t - beta_t / sqrt(1 - cum_alpha_t) * eps_pred) + sqrt(beta_t) * z
        where z ~ N(0, I) is added only for rows with t > 0 (the final step is deterministic).

        :param model_predictions: predicted noise, shape = [batch_size, self.num_cont]
        :param x_t: shape = [batch_size, self.num_cont]
        :param t: shape = [batch_size, ]
        :return: x_{t-1}, shape = [batch_size, self.num_cont]
        """
        z = torch.randn_like(x_t)
        nonzero_mask = (t != 0).float().view(-1, *([1] * (len(x_t.shape) - 1)))

        beta_t = extract(self.betas, t, x_t.shape)
        one_by_sqrt_alpha_t = extract(self.one_by_sqrt_alpha, t, x_t.shape)
        sqrt_one_minus_alpha_cumulative_t = extract(self.sqrt_one_minus_alpha_cumulative, t, x_t.shape)

        return (one_by_sqrt_alpha_t * (x_t - (beta_t / sqrt_one_minus_alpha_cumulative_t) * model_predictions)
                + torch.sqrt(beta_t) * nonzero_mask * z)

    def q_sample_multinomial(self,
                             log_x_start: torch.Tensor,
                             t: torch.Tensor,
                             num_classes: int) -> torch.Tensor:
        """Sample x_t ~ q(x_t | x0) for a single categorical feature, in log one-hot form.

        :param log_x_start: shape = [batch_size, num_classes]
        :param t: shape = [batch_size, ]
        :param num_classes: int
        :return: log one-hot sample, shape = [batch_size, num_classes]
        """
        log_expected_value_qxt_given_x0 = self.q_pred_multinomial(log_x_start, t, num_classes)
        return self.log_sample_categorical(log_expected_value_qxt_given_x0, num_classes)

    def q_pred_multinomial(self,
                           log_x_start: torch.Tensor,
                           t: torch.Tensor,
                           num_classes: int) -> torch.Tensor:
        """Compute log q(x_t | x0) for one categorical feature, the log-space version of
                q(x_t | x0) = cum_alpha_t * x0 + (1 - cum_alpha_t) / K

        :param log_x_start: log of x0, shape = [batch_size, num_classes]
        :param t: timesteps, shape = [batch_size, ]
        :param num_classes: int, K
        :return: shape = [batch_size, num_classes]
        """
        log_cumprod_alpha_t = extract(self.log_cumprod_alpha, t, log_x_start.shape)
        log_1_min_cumprod_alpha = extract(self.log_1_min_cumprod_alpha, t, log_x_start.shape)
        return log_add_exp(log_x_start + log_cumprod_alpha_t,
                           log_1_min_cumprod_alpha - np.log(num_classes))

    def q_pred_one_timestep(self,
                            log_x_t: torch.Tensor,
                            t: torch.Tensor,
                            num_classes: int) -> torch.Tensor:
        """Compute the one-step multinomial transition in log space:
                alpha_t * x_t + (1 - alpha_t) / K

        :param log_x_t: log of x_t, shape = [batch_size, num_classes]
        :param t: shape = [batch_size, ]
        :param num_classes: int, K
        :return: shape = [batch_size, num_classes]
        """
        log_alpha_t = extract(self.log_alpha, t, log_x_t.shape)
        log_1_min_alpha_t = extract(self.log_1_min_alpha, t, log_x_t.shape)
        return log_add_exp(log_x_t + log_alpha_t, log_1_min_alpha_t - np.log(num_classes))

    @staticmethod
    def log_sample_categorical(logits: torch.Tensor, num_classes: int) -> torch.Tensor:
        """Draw one class per row with the Gumbel-max trick and return it as a log one-hot.

        :param logits: unnormalised log-probabilities, shape = [batch_size, num_classes]
        :param num_classes: int
        :return: shape = [batch_size, num_classes]
        """
        uniform = torch.rand_like(logits)
        gumbel_noise = -torch.log(-torch.log(uniform + 1e-30) + 1e-30)
        sample = (gumbel_noise + logits).argmax(dim=1)
        return index_to_log_onehot(sample, num_classes)

    def compute_lt(self,
                   model_predictions: torch.Tensor,
                   log_x_start: torch.Tensor,
                   log_x_t: torch.Tensor,
                   t: torch.Tensor,
                   mask: torch.Tensor,
                   detach_mean: bool = False) -> torch.Tensor:
        """Per-timestep multinomial loss term.

        For rows with t > 0 this is KL(q(x_{t-1} | x_t, x0) || p(x_{t-1} | x_t)); for rows with t == 0 it is the
        decoder negative log-likelihood of x0. Both are averaged over categorical features.

        :param model_predictions: x0 logits, shape = [batch_size, sum(self.num_classes)]
        :param log_x_start: shape = [batch_size, sum(self.num_classes)]
        :param log_x_t: shape = [batch_size, sum(self.num_classes)]
        :param t: shape = [batch_size, ]
        :param mask: shape = [batch_size, self.num_cat], 1 excludes a feature from the loss
        :param detach_mean: bool, if True no gradient flows through the model distribution
        :return: shape = [batch_size, ]
        """
        log_true_prob = self.q_posterior(log_x_start=log_x_start, log_x_t=log_x_t, t=t)
        log_model_prob = self.p_pred(model_predictions=model_predictions, log_x=log_x_t, t=t)

        if detach_mean:
            log_model_prob = log_model_prob.detach()

        kl = self.mean_flat(self.multinomial_kl(log_true_prob, log_model_prob, mask))
        decoder_nll = self.mean_flat(-self.log_categorical(log_x_start, log_model_prob, mask))

        decoder_mask = (t == torch.zeros_like(t)).float()
        return decoder_mask * decoder_nll + (1. - decoder_mask) * kl

    def q_posterior(self, log_x_start: torch.Tensor, log_x_t: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """Compute log q(x_{t-1} | x_t, x0) for every categorical feature:
                q(x_{t-1} | x_t, x0) ∝ q(x_t | x_{t-1}) * q(x_{t-1} | x0)
        normalised per feature. Rows with t == 0 use x0 in place of q(x_{t-1} | x0).

        :param log_x_start: shape = [batch_size, sum(self.num_classes)]
        :param log_x_t: shape = [batch_size, sum(self.num_classes)]
        :param t: shape = [batch_size, ]
        :return: shape = [batch_size, sum(self.num_classes)]
        """
        # Clamp t-1 at 0; rows with t == 0 are overwritten with x0 just below.
        t_minus_1 = torch.where(t - 1 < 0, torch.zeros_like(t), t - 1)
        log_q_xtm1_given_x0 = torch.cat([self.q_pred_multinomial(log_x_start[:, cols], t_minus_1, n)
                                         for _, n, cols in self._cat_blocks()], dim=1)

        num_axes = (1,) * (len(log_x_start.size()) - 1)
        is_first_step = (t == 0).view(-1, *num_axes)
        log_q_xtm1_given_x0 = torch.where(is_first_step, log_x_start, log_q_xtm1_given_x0)

        # The one-step forward kernel evaluated at x_t (_not_ x_{t-1}): alpha_t * x_t + (1 - alpha_t) / K is,
        # as a function of x_{t-1}, exactly q(x_t | x_{t-1}).
        log_q_xt_given_xtm1 = torch.cat([self.q_pred_one_timestep(log_x_t[:, cols], t, n)
                                         for _, n, cols in self._cat_blocks()], dim=1)
        unnormed_logprobs = log_q_xtm1_given_x0 + log_q_xt_given_xtm1

        # Normalise each feature's block independently.
        return torch.cat([unnormed_logprobs[:, cols] - torch.logsumexp(unnormed_logprobs[:, cols], dim=1, keepdim=True)
                          for _, _, cols in self._cat_blocks()], dim=1)

    def p_pred(self, model_predictions: torch.Tensor, log_x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """Model reverse distribution in log space:
                p(x_{t-1} | x_t) = q(x_{t-1} | x_t, x0_hat),   x0_hat = softmax(model logits) per feature

        :param model_predictions: x0 logits, shape = [batch_size, sum(self.num_classes)]
        :param log_x: log one-hot x_t, shape = [batch_size, sum(self.num_classes)]
        :param t: shape = [batch_size, ]
        :return: shape = [batch_size, sum(self.num_classes)]
        """
        assert model_predictions.shape[0] == log_x.shape[0]
        assert model_predictions.shape[1] == sum(self.num_classes)
        log_x_recon = torch.cat([F.log_softmax(model_predictions[:, cols], dim=1)
                                 for _, _, cols in self._cat_blocks()], dim=1)
        return self.q_posterior(log_x_start=log_x_recon, log_x_t=log_x, t=t)

    def multinomial_kl(self,
                       log_prob1: torch.Tensor,
                       log_prob2: torch.Tensor,
                       mask: torch.Tensor) -> torch.Tensor:
        """Per-feature KL(p1 || p2) between categorical distributions given in log space.

        :param log_prob1: shape = [batch_size, sum(self.num_classes)]
        :param log_prob2: shape = [batch_size, sum(self.num_classes)]
        :param mask: shape = [batch_size, self.num_cat], 1 zeroes the feature's KL
        :return: shape = [batch_size, len(self.num_classes)]
        """
        kl = []
        for j, _, cols in self._cat_blocks():
            lp1 = log_prob1[:, cols]
            # Sum over the classes of the current feature, then drop masked features.
            feature_kl = (lp1.exp() * (lp1 - log_prob2[:, cols])).sum(dim=1) * (1 - mask[:, j])
            kl.append(feature_kl.unsqueeze(dim=1))
        return torch.cat(kl, dim=1)

    def log_categorical(self,
                        log_x_start: torch.Tensor,
                        log_prob: torch.Tensor,
                        mask: torch.Tensor) -> torch.Tensor:
        """Per-feature sum_k x0[k] * log_prob[k], i.e. the log-likelihood of a one-hot x0.

        :param log_x_start: shape = [batch_size, sum(self.num_classes)]
        :param log_prob: shape = [batch_size, sum(self.num_classes)]
        :param mask: shape = [batch_size, len(self.num_classes)], 1 zeroes the feature's term
        :return: shape = [batch_size, len(self.num_classes)]
        """
        log_cat = []
        for j, _, cols in self._cat_blocks():
            feature_ll = (log_x_start[:, cols].exp() * log_prob[:, cols]).sum(dim=1) * (1 - mask[:, j])
            log_cat.append(feature_ll.unsqueeze(dim=1))
        return torch.cat(log_cat, dim=1)

    def kl_prior(self, log_x_start: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """KL between q(x_T | x0) at the last timestep and the uniform distribution, per row.

        :param log_x_start: shape = [batch_size, sum(self.num_classes)]
        :param mask: shape = [batch_size, len(self.num_classes)]
        :return: shape = [batch_size, ]
        """
        ones = torch.ones(log_x_start.size(0), device=log_x_start.device).long()
        t_last = (self.timesteps - 1) * ones

        log_qxT_prob = []
        log_uniform_prob = []
        for _, n, cols in self._cat_blocks():
            log_qxT_prob.append(self.q_pred_multinomial(log_x_start[:, cols], t=t_last, num_classes=n))
            log_uniform_prob.append(-torch.log(n * torch.ones_like(log_qxT_prob[-1])))
        log_qxT_prob = torch.cat(log_qxT_prob, dim=1)
        log_uniform_prob = torch.cat(log_uniform_prob, dim=1)

        return self.mean_flat(self.multinomial_kl(log_qxT_prob, log_uniform_prob, mask))

    def sample_time(self, batch_size: int, device, method: str = 'uniform') -> Tuple[torch.Tensor, torch.Tensor]:
        """Draw one training timestep per row, together with its sampling probability.

        :param batch_size: int
        :param device: device on which t and pt are returned
        :param method: 'uniform' or 'importance'; 'importance' falls back to 'uniform' until every
                       entry of ``self.lt_count`` exceeds 10
        :return: (t, pt), both of shape = [batch_size]
        :raises ValueError: on an unknown method
        """
        if method == 'importance':
            if not (self.lt_count > 10).all():
                return self.sample_time(batch_size, device, method='uniform')

            lt_sqrt = torch.sqrt(self.lt_history + 1e-10) + 0.0001
            lt_sqrt[0] = lt_sqrt[1]  # Overwrite decoder term with L1.
            pt_all = lt_sqrt / lt_sqrt.sum()

            t = torch.multinomial(pt_all, num_samples=batch_size, replacement=True)
            pt = pt_all.gather(dim=0, index=t)
            # The history buffers are created on the CPU; return on the requested device like 'uniform'.
            return t.to(device), pt.to(device)

        if method == 'uniform':
            t = torch.randint(0, self.timesteps, (batch_size,), device=device).long()
            pt = torch.ones_like(t).float() / self.timesteps
            return t, pt

        raise ValueError("Unknown time sampling method: {}".format(method))

    def sample(self,
               denoising_model: torch.nn.Module,
               dataloader,
               disable_progress_bar: bool = False) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Generate samples conditioned on prompts by running the full reverse diffusion chain.

        :param denoising_model: nn.Module, the denoising model
        :param dataloader: iterable of (batch, mask) pairs; masked columns of ``batch`` are kept as the prompt
        :param disable_progress_bar: bool, if True nothing is printed
        :return: (samples, original batches, masks), each concatenated along dim 0 and on the CPU;
                 samples have shape = [num_samples, self.num_cont + self.num_cat]
        """
        assert self.condition_type == "prompt"

        device = self.log_alpha.device
        torch_samples, torch_original, torch_mask = [], [], []
        # Sampling needs no gradients; without no_grad the continuous chain would keep the autograd graph of
        # every timestep alive and the returned samples would require grad.
        with torch.no_grad():
            for batch, mask in dataloader:
                batch_size = batch.shape[0]
                batch = batch.to(device)
                mask = mask.to(device)

                # Initial noise: standard normal for continuous, uniform categorical for categorical features.
                z_norm = None
                if self.num_cont > 0:
                    z_norm = torch.randn((batch_size, self.num_cont), device=device)
                log_z = None
                if self.num_cat > 0:
                    log_z = torch.cat([self.log_sample_categorical(torch.zeros((batch_size, n), device=device), n)
                                       for n in self.num_classes], dim=1)

                for i in reversed(range(0, self.timesteps)):
                    if not disable_progress_bar:
                        print("\r {} {}".format('Sample timestep', i), end='')
                    t = torch.full((batch_size,), i, device=device, dtype=torch.long)

                    # Current state with the prompt applied.
                    torch_z = (1 - mask) * self._to_model_input(z_norm, log_z) + mask * batch
                    model_prediction = denoising_model(torch_z.to(torch.float32), t, mask=mask)

                    if self.num_cont > 0:
                        z_norm = self.p_pred_gaussian(model_predictions=model_prediction[:, :self.num_cont],
                                                      x_t=z_norm, t=t)
                    if self.num_cat > 0:
                        model_log_prob = self.p_pred(model_predictions=model_prediction[:, self.num_cont:],
                                                     log_x=log_z, t=t)
                        log_z = torch.cat([self.log_sample_categorical(model_log_prob[:, cols], n)
                                           for _, n, cols in self._cat_blocks()], dim=1)
                if not disable_progress_bar:
                    print()

                x0 = self._to_model_input(z_norm, log_z)
                torch_samples.append(((1 - mask) * x0 + mask * batch).to('cpu'))
                torch_original.append(batch.to('cpu'))
                torch_mask.append(mask.to('cpu'))
        return torch.cat(torch_samples, dim=0), torch.cat(torch_original, dim=0), torch.cat(torch_mask)

    @staticmethod
    def mean_flat(tensor):
        """
        Take the mean over all non-batch dimensions.
        """
        return tensor.mean(dim=list(range(1, len(tensor.shape))))
