import math
import copy
from pathlib import Path
from random import random
from functools import partial
from collections import namedtuple
from multiprocessing import cpu_count

import numpy as np

import torch
from torch import nn, einsum
from torch.cuda.amp import autocast
import torch.nn.functional as F

from einops import rearrange, reduce, repeat
from einops.layers.torch import Rearrange

from tqdm.auto import tqdm
from ema_pytorch import EMA

from gps_encoder import Sphere2VecLocationEncoder, SphericalHarmonicsLocationEncoder, RFFLocationEncoder, SphericalHarmonicsDiracLocationEncoder
from gps_decoder import Sphere2VecLocationDecoder, SphericalHarmonicsDiracLocationDecoder
from image_encoder import ImageEncoder, ImageEmbeddingEncoder

from utils import great_circle_distance_loss

from torch.utils.tensorboard import SummaryWriter
# Module-level TensorBoard writer, created at import time with SummaryWriter's defaults.
writer = SummaryWriter()

# Constant written into column 0 of de-normalized embeddings before they are
# decoded or compared (see p_sample_loop and p_losses). Numerically sqrt(pi) / 2.
# NOTE(review): presumably the fixed degree-0 coefficient of the spherical
# encoding -- confirm against gps_encoder / gps_decoder.
DEFAULT_Y0 = 0.886226925452758

# Result pair of GaussianDiffusion.model_predictions: (predicted noise, predicted x_0).
ModelPrediction = namedtuple('ModelPrediction', ['pred_noise', 'pred_x_start'])

def exists(x):
    """Return True unless ``x`` is ``None`` (falsy values such as 0 or '' count as existing)."""
    is_missing = x is None
    return not is_missing

def default(val, d):
    """Return ``val`` when it is not None, otherwise the fallback ``d``.

    A callable fallback is invoked (lazily, only when needed) and its result
    is returned; any other fallback is returned as-is.
    """
    if not exists(val):
        return d() if callable(d) else d
    return val

def cast_tuple(t, length = 1):
    """Return ``t`` unchanged if it is a tuple, else a tuple holding ``t`` repeated ``length`` times."""
    return t if isinstance(t, tuple) else (t,) * length

def divisible_by(numer, denom):
    """Return whether ``numer`` leaves no remainder when divided by ``denom``."""
    remainder = numer % denom
    return remainder == 0

def identity(t, *args, **kwargs):
    """No-op transform: hand back ``t`` itself, ignoring any extra arguments."""
    return t

def cycle(dl):
    """Yield the items of ``dl`` forever, re-iterating it each time it is exhausted."""
    while True:
        yield from dl

def has_int_squareroot(num):
    """Return True when ``num`` is a perfect square (its square root is an integer).

    Uses exact integer arithmetic (``math.isqrt``) rather than a floating-point
    ``sqrt(num) ** 2 == num`` round trip, which mis-reports large perfect
    squares and accepts values such as 2.25 whose root is not an integer.

    Args:
        num: A non-negative integer, or a float holding an integral value.

    Returns:
        bool: True if an integer ``r`` exists with ``r * r == num``.

    Raises:
        ValueError: If ``num`` is negative.
    """
    if isinstance(num, float):
        # Non-integral floats (including inf / nan) cannot be perfect squares.
        if not num.is_integer():
            return False
        num = int(num)
    root = math.isqrt(num)
    return root * root == num

def num_to_groups(num, divisor):
    """Split ``num`` into chunk sizes of ``divisor`` plus a final smaller remainder chunk.

    Example: ``num_to_groups(10, 3)`` gives ``[3, 3, 3, 1]``.
    """
    full_groups, leftover = divmod(num, divisor)
    sizes = [divisor] * full_groups
    if leftover > 0:
        sizes.append(leftover)
    return sizes

def convert_image_to_fn(img_type, image):
    """Return ``image`` in mode ``img_type``, calling ``image.convert`` only when its mode differs."""
    needs_conversion = image.mode != img_type
    return image.convert(img_type) if needs_conversion else image

# normalization functions

def normalize_to_neg_one_to_one(img):
    """Affinely map values from the [0, 1] range onto [-1, 1]."""
    doubled = img * 2
    return doubled - 1

def unnormalize_to_zero_to_one(t):
    """Affinely map values from the [-1, 1] range back onto [0, 1]."""
    shifted = t + 1
    return shifted * 0.5

# small helper modules

def Upsample(dim, dim_out = None):
    """Build a 2x nearest-neighbour upsampling stage followed by a 3x3 convolution.

    The convolution maps ``dim`` channels to ``dim_out`` channels (``dim`` when
    ``dim_out`` is not given).
    """
    out_channels = default(dim_out, dim)
    upscale = nn.Upsample(scale_factor = 2, mode = 'nearest')
    conv = nn.Conv2d(dim, out_channels, 3, padding = 1)
    return nn.Sequential(upscale, conv)

def Downsample(dim, dim_out = None):
    """Build a 2x downsampling stage: fold each 2x2 patch into channels, then a 1x1 convolution.

    The rearrange turns ``c`` channels into ``4 * c``; the convolution maps those
    to ``dim_out`` channels (``dim`` when ``dim_out`` is not given).
    """
    fold_patches = Rearrange('b c (h p1) (w p2) -> b (c p1 p2) h w', p1 = 2, p2 = 2)
    project = nn.Conv2d(dim * 4, default(dim_out, dim), 1)
    return nn.Sequential(fold_patches, project)

def extract(a, t, x_shape):
    """Gather ``a[t]`` per batch element and reshape it to broadcast against ``x_shape``.

    The result has shape ``(len(t), 1, ..., 1)`` with ``len(x_shape) - 1``
    trailing singleton dimensions.
    """
    batch, *_ = t.shape
    gathered = a.gather(-1, t)
    singleton_dims = (1,) * (len(x_shape) - 1)
    return gathered.reshape(batch, *singleton_dims)

def linear_beta_schedule(timesteps):
    """
    Linear beta schedule from the original DDPM paper.

    Betas run linearly from 1e-4 to 2e-2, both scaled by ``1000 / timesteps``,
    and are returned as a float32 tensor of length ``timesteps``.
    """
    scale = 1000 / timesteps
    first, last = scale * 0.0001, scale * 0.02
    return torch.linspace(first, last, timesteps, dtype = torch.float32)

def cosine_beta_schedule(timesteps, s = 0.008, clip_max = None):
    """
    Cosine beta schedule, as proposed in https://openreview.net/forum?id=-NEXDKk8gZ

    Args:
        timesteps: Number of diffusion steps; the result has this many betas.
        s: Small offset applied inside the cosine.
        clip_max: Optional upper bound. When given, betas are clipped to
            ``[0, clip_max]``; when None (default) they are returned unclipped.

    Returns:
        float32 tensor of shape ``(timesteps,)``.

    NOTE: without clipping the final beta is numerically 1, because the
    cumulative alpha at the last step is cos(pi / 2) ** 2, i.e. ~0. Quantities
    that divide by the cumulative alpha then become inf / NaN at that step.
    Pass ``clip_max`` (e.g. 0.999) to avoid this.
    """
    steps = timesteps + 1
    t = torch.linspace(0, timesteps, steps, dtype = torch.float32) / timesteps
    alphas_cumprod = torch.cos((t + s) / (1 + s) * math.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    if clip_max is not None:
        betas = torch.clip(betas, 0, clip_max)
    return betas

def sigmoid_beta_schedule(timesteps, start = -3, end = 3, tau = 1, clamp_min = 1e-5):
    """
    Sigmoid beta schedule (https://arxiv.org/abs/2212.11972, Figure 8).

    The cumulative alphas follow a rescaled sigmoid evaluated between
    ``start / tau`` and ``end / tau``; the betas derived from them are clipped
    to [0, 0.999]. ``clamp_min`` is accepted but not used by this function.
    """
    grid = torch.linspace(0, timesteps, timesteps + 1, dtype = torch.float32) / timesteps
    sig_lo = torch.tensor(start / tau).sigmoid()
    sig_hi = torch.tensor(end / tau).sigmoid()
    sig_t = ((grid * (end - start) + start) / tau).sigmoid()
    cumulative = (-sig_t + sig_hi) / (sig_hi - sig_lo)
    cumulative = cumulative / cumulative[0]
    step_ratio = cumulative[1:] / cumulative[:-1]
    return torch.clip(1 - step_ratio, 0, 0.999)

def zero_beta_schedule(timesteps, start = -3, end = 3, tau = 1, clamp_min = 1e-5):
    """
    Sigmoid-shaped beta schedule (https://arxiv.org/abs/2212.11972, Figure 8).

    NOTE(review): despite its name, this performs exactly the same computation
    as ``sigmoid_beta_schedule`` -- confirm whether a different schedule was
    intended. ``clamp_min`` is accepted but not used.
    """
    n_points = timesteps + 1
    frac = torch.linspace(0, timesteps, n_points, dtype = torch.float32) / timesteps
    lower = torch.tensor(start / tau).sigmoid()
    upper = torch.tensor(end / tau).sigmoid()
    span = upper - lower
    curve = ((frac * (end - start) + start) / tau).sigmoid()
    acp = (-curve + upper) / span
    acp = acp / acp[0]
    raw_betas = 1 - (acp[1:] / acp[:-1])
    return torch.clip(raw_betas, 0, 0.999)

class GaussianDiffusion(nn.Module):
    def __init__(
        self,
        model,
        *,
        dim_inputs,
        location_encoding_type,
        dim_encoding,
        dim_hidden,
        dim_condition,
        data_size,
        num_train_grid_points,
        num_sample_grid_points,
        train_grid_filepath,
        sample_grid_filepath,
        device,
        train_mode,
        timesteps = 1000,
        sample_save_interval=10,
        sampling_timesteps = None,
        objective = 'pred_v',
        beta_schedule = 'cosine',
        noise_amplifier=1.0,
        schedule_fn_kwargs = None,
        ddim_sampling_eta = 0.,
        auto_normalize = True,
        offset_noise_strength = 0.,  # https://www.crosslabs.org/blog/diffusion-with-offset-noise
        min_snr_loss_weight = False, # https://arxiv.org/abs/2303.09556
        min_snr_gamma = 5
    ):
        """Build the diffusion wrapper: encoders, decoder, losses and schedule buffers.

        Args:
            model: Denoising network, called as ``model(x, t, condition_embs)``.
            dim_inputs: Expected size of the last dimension of ``forward`` inputs.
            location_encoding_type: One of "sphere2vec", "spherical_harmonics",
                "spherical_harmonics_coefficient", "spherical_harmonics_dirac", "rff".
            dim_encoding: Passed to the location encoder.
            dim_hidden: Passed to the location encoder / decoder; also the width
                of the sampled latent.
            dim_condition: Passed to the image encoder.
            data_size: Passed to the "spherical_harmonics_dirac" encoder.
            num_train_grid_points, num_sample_grid_points: Passed to the decoder.
            train_grid_filepath, sample_grid_filepath: Gallery paths passed to the decoder.
            device: Device handed to the encoder / decoder and used for ``.to(device)``.
            train_mode: "pretrained" selects ImageEmbeddingEncoder; anything
                else selects ImageEncoder.
            timesteps: Number of diffusion steps requested from the beta schedule.
            sample_save_interval: Snapshot interval used by ``p_sample_loop``.
            sampling_timesteps: Steps used at sampling time; defaults to
                ``timesteps``. A smaller value turns on DDIM sampling.
            objective: 'pred_noise', 'pred_x0' or 'pred_v'.
            beta_schedule: 'linear', 'cosine' or 'sigmoid'.
            noise_amplifier: Stored on the instance.
            schedule_fn_kwargs: Extra keyword arguments for the schedule
                function; None means no extra arguments.
            ddim_sampling_eta: Eta used by ``ddim_sample``.
            auto_normalize: Chooses the [0, 1] <-> [-1, 1] helpers (or identity)
                stored as ``normalize`` / ``unnormalize``.
            offset_noise_strength: Stored on the instance as the fallback strength.
            min_snr_loss_weight: When True, clamp the SNR at ``min_snr_gamma``
                before deriving the loss weights.
            min_snr_gamma: Clamp value for the SNR.

        Raises:
            ValueError: For an unknown ``beta_schedule`` or ``location_encoding_type``.
        """
        super().__init__()

        # A fresh dict per call instead of a shared mutable default argument.
        schedule_fn_kwargs = {} if schedule_fn_kwargs is None else schedule_fn_kwargs

        self.model = model
        self.dim_inputs = dim_inputs
        self.dim_encoding = dim_encoding
        self.dim_hidden = dim_hidden
        self.location_encoding_type = location_encoding_type
        self.self_condition = None

        self.objective = objective

        assert objective in {'pred_noise', 'pred_x0', 'pred_v'}, 'objective must be either pred_noise (predict noise) or pred_x0 (predict image start) or pred_v (predict v [v-parameterization as defined in appendix D of progressive distillation paper, used in imagen-video successfully])'

        if beta_schedule == 'linear':
            beta_schedule_fn = linear_beta_schedule
        elif beta_schedule == 'cosine':
            beta_schedule_fn = cosine_beta_schedule
        elif beta_schedule == 'sigmoid':
            beta_schedule_fn = sigmoid_beta_schedule
        else:
            raise ValueError(f'unknown beta schedule {beta_schedule}')

        self.noise_amplifier = noise_amplifier
        betas = beta_schedule_fn(timesteps, **schedule_fn_kwargs)

        alphas = 1. - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)
        alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value = 1.)

        timesteps, = betas.shape
        self.sample_save_interval = sample_save_interval
        self.num_timesteps = int(timesteps)

        # sampling related parameters
        self.sampling_timesteps = default(sampling_timesteps, timesteps) # default num sampling timesteps to number of timesteps at training

        assert self.sampling_timesteps <= timesteps
        self.is_ddim_sampling = self.sampling_timesteps < timesteps
        self.ddim_sampling_eta = ddim_sampling_eta

        if location_encoding_type == "sphere2vec":
            self.location_encoder = Sphere2VecLocationEncoder(dim_encoding, dim_hidden, device).to(device)
        elif location_encoding_type == "spherical_harmonics":
            self.location_encoder = SphericalHarmonicsLocationEncoder(dim_encoding, dim_hidden, device).to(device)
        elif location_encoding_type == "spherical_harmonics_coefficient":
            # NOTE(review): this class is not among the names imported from
            # gps_encoder at the top of the file, so this branch looks like it
            # raises NameError -- confirm and add the import if the class exists.
            self.location_encoder = SphericalHarmonicsCoefficientLocationEncoder(dim_encoding, device).to(device)
        elif location_encoding_type == "spherical_harmonics_dirac":
            self.location_encoder = SphericalHarmonicsDiracLocationEncoder(data_size, dim_encoding, device).to(device)
        elif location_encoding_type == "rff":
            self.location_encoder = RFFLocationEncoder(dim_hidden, file_dir="weights/location_encoder_weights.pth").to(device)
        else:
            # Raise (not assert) so the check survives `python -O`, matching the
            # beta-schedule validation above.
            raise ValueError("Unknown location encoder!")

        self.train_grid_filepath, self.sample_grid_filepath = train_grid_filepath, sample_grid_filepath
        self.num_train_grid_points, self.num_sample_grid_points = num_train_grid_points, num_sample_grid_points

        self.location_decoder = SphericalHarmonicsDiracLocationDecoder(dim_hidden, device,
                                                                       num_train_grid_points=self.num_train_grid_points,
                                                                       train_gallery_filepath=self.train_grid_filepath,
                                                                       num_sample_grid_points=self.num_sample_grid_points,
                                                                       sample_gallery_filepath=self.sample_grid_filepath).to(device)

        if train_mode == "pretrained":
            self.image_encoder = ImageEmbeddingEncoder(dim_condition, normalize=False)
        else:
            self.image_encoder = ImageEncoder(dim_condition, normalize=False)

        self.kl_loss = nn.KLDivLoss(reduction='none', log_target=True)
        self.cosine_loss = nn.CosineEmbeddingLoss(margin=0.01, reduction='none')
        # NOTE: despite the attribute name, this is an element-wise L1 loss.
        self.mse_loss = nn.L1Loss(reduction='none')

        # helper function to register buffer from float32 to float32

        register_buffer = lambda name, val: self.register_buffer(name, val.to(torch.float32))

        register_buffer('betas', betas)
        register_buffer('alphas_cumprod', alphas_cumprod)
        register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)

        # calculations for diffusion q(x_t | x_{t-1}) and others

        register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
        register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod))
        register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod))
        register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod))
        register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1))

        # calculations for posterior q(x_{t-1} | x_t, x_0)

        posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)

        # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)

        register_buffer('posterior_variance', posterior_variance)

        # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain

        register_buffer('posterior_log_variance_clipped', torch.log(posterior_variance.clamp(min =1e-20)))
        register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
        register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))

        # offset noise strength - in blogpost, they claimed 0.1 was ideal

        self.offset_noise_strength = offset_noise_strength

        # derive loss weight
        # snr - signal noise ratio

        snr = alphas_cumprod / (1 - alphas_cumprod)

        # https://arxiv.org/abs/2303.09556

        maybe_clipped_snr = snr.clone()
        if min_snr_loss_weight:
            maybe_clipped_snr.clamp_(max = min_snr_gamma)

        if objective == 'pred_noise':
            register_buffer('loss_weight', maybe_clipped_snr / snr)
        elif objective == 'pred_x0':
            register_buffer('loss_weight', maybe_clipped_snr)
        elif objective == 'pred_v':
            register_buffer('loss_weight', maybe_clipped_snr / (snr + 1))

        # auto-normalization of data [0, 1] -> [-1, 1] - can turn off by setting it to be False

        self.normalize = normalize_to_neg_one_to_one if auto_normalize else identity
        self.unnormalize = unnormalize_to_zero_to_one if auto_normalize else identity

    @property
    def device(self):
        return self.betas.device

    def predict_start_from_noise(self, x_t, t, noise):
        """Estimate x_0 from the noisy sample ``x_t`` and a noise estimate at timesteps ``t``.

        Computes ``sqrt_recip_alphas_cumprod[t] * x_t - sqrt_recipm1_alphas_cumprod[t] * noise``.
        """
        shape = x_t.shape
        scale_xt = extract(self.sqrt_recip_alphas_cumprod, t, shape)
        scale_noise = extract(self.sqrt_recipm1_alphas_cumprod, t, shape)
        return scale_xt * x_t - scale_noise * noise

    def predict_noise_from_start(self, x_t, t, x0):
        """Recover the noise implied by the noisy sample ``x_t`` and an x_0 estimate at timesteps ``t``.

        Computes ``(sqrt_recip_alphas_cumprod[t] * x_t - x0) / sqrt_recipm1_alphas_cumprod[t]``.
        """
        shape = x_t.shape
        scaled_xt = extract(self.sqrt_recip_alphas_cumprod, t, shape) * x_t
        denom = extract(self.sqrt_recipm1_alphas_cumprod, t, shape)
        return (scaled_xt - x0) / denom

    def predict_v(self, x_start, t, noise):
        """Build the v-parameterization target from ``x_start`` and ``noise`` at timesteps ``t``.

        Computes ``sqrt_alphas_cumprod[t] * noise - sqrt_one_minus_alphas_cumprod[t] * x_start``.
        """
        shape = x_start.shape
        signal_scale = extract(self.sqrt_alphas_cumprod, t, shape)
        noise_scale = extract(self.sqrt_one_minus_alphas_cumprod, t, shape)
        return signal_scale * noise - noise_scale * x_start

    def predict_start_from_v(self, x_t, t, v):
        """Estimate x_0 from the noisy sample ``x_t`` and a v prediction at timesteps ``t``.

        Computes ``sqrt_alphas_cumprod[t] * x_t - sqrt_one_minus_alphas_cumprod[t] * v``.
        """
        shape = x_t.shape
        signal_scale = extract(self.sqrt_alphas_cumprod, t, shape)
        noise_scale = extract(self.sqrt_one_minus_alphas_cumprod, t, shape)
        return signal_scale * x_t - noise_scale * v

    def q_posterior(self, x_start, x_t, t):
        """Return mean, variance and clipped log-variance of q(x_{t-1} | x_t, x_0).

        The mean is ``posterior_mean_coef1[t] * x_start + posterior_mean_coef2[t] * x_t``;
        the two variance terms are looked up from the registered buffers.
        """
        shape = x_t.shape
        coef_start = extract(self.posterior_mean_coef1, t, shape)
        coef_xt = extract(self.posterior_mean_coef2, t, shape)
        mean = coef_start * x_start + coef_xt * x_t
        variance = extract(self.posterior_variance, t, shape)
        log_variance_clipped = extract(self.posterior_log_variance_clipped, t, shape)
        return mean, variance, log_variance_clipped

    def model_predictions(self, x, t, condition_embs, x_self_cond = None, clip_x_start = False, rederive_pred_noise = False):
        """Run the network and convert its output into (pred_noise, pred_x_start).

        The network output is interpreted according to ``self.objective``.
        With ``clip_x_start`` the x_0 estimate is clamped to [-1, 1]; for the
        'pred_noise' objective the noise is then re-derived from the clamped
        estimate only if ``rederive_pred_noise`` is also set. ``x_self_cond``
        is accepted but not forwarded to the network.
        """
        raw_output = self.model(x, t, condition_embs)
        clip = partial(torch.clamp, min = -1., max = 1.) if clip_x_start else identity
        objective = self.objective

        if objective == 'pred_noise':
            pred_noise = raw_output
            x_start = clip(self.predict_start_from_noise(x, t, pred_noise))
            if clip_x_start and rederive_pred_noise:
                pred_noise = self.predict_noise_from_start(x, t, x_start)

        elif objective in ('pred_x0', 'pred_v'):
            # Both objectives first obtain x_0, then derive the matching noise.
            if objective == 'pred_x0':
                x_start = raw_output
            else:
                x_start = self.predict_start_from_v(x, t, raw_output)
            x_start = clip(x_start)
            pred_noise = self.predict_noise_from_start(x, t, x_start)

        return ModelPrediction(pred_noise, x_start)

    def condition_mean(self, cond_fn, mean, variance, x, t):
        """
        Compute the mean for the previous step, given a function cond_fn that
        computes the gradient of a conditional log probability with respect to
        x. In particular, cond_fn computes grad(log(p(y|x))), and we want to
        condition on y.
        This uses the conditioning strategy from Sohl-Dickstein et al. (2015).
        """
        # this fixes a bug in the official OpenAI implementation:
        # https://github.com/openai/guided-diffusion/issues/51 (see point 1)
        # use the predicted mean for the previous timestep to compute gradient
        gradient = cond_fn(mean, t)
        new_mean = (
                mean.float() + variance * gradient.float()
        )

        return new_mean

    def p_mean_variance(self, x, t, condition_embs, x_self_cond = None, clip_denoised = True):
        """Return posterior mean, variance, log-variance and the x_0 estimate for one reverse step.

        The x_0 estimate from ``model_predictions`` is clamped in place to
        [-1, 1] when ``clip_denoised`` is set, then fed to ``q_posterior``.
        """
        x_start = self.model_predictions(x, t, condition_embs, x_self_cond).pred_x_start

        if clip_denoised:
            x_start.clamp_(-1., 1.)

        mean, variance, log_variance = self.q_posterior(x_start = x_start, x_t = x, t = t)
        return mean, variance, log_variance, x_start

    @torch.inference_mode()
    def p_sample(self, x, t, condition_embs, x_self_cond = None):
        """Draw x_{t-1} from x_t for a single integer timestep ``t``.

        Returns ``(sample, x_start)``. Gaussian noise scaled by the posterior
        standard deviation is added for every step except ``t == 0``.
        """
        batch_size = x.shape[0]
        batched_times = torch.full((batch_size,), t, device = self.device, dtype = torch.long)
        model_mean, _, model_log_variance, x_start = self.p_mean_variance(
            x = x, t = batched_times, condition_embs=condition_embs,
            x_self_cond = x_self_cond, clip_denoised = True,
        )

        # no noise if t == 0
        if t > 0:
            noise = torch.randn_like(x)
        else:
            noise = 0.

        std = (0.5 * model_log_variance).exp()
        return model_mean + std * noise, x_start

    @torch.inference_mode()
    def p_sample_loop(self, shape, condition_embs, return_all_timesteps=False):
        """Run the full ancestral sampling chain and decode the result.

        Args:
            shape: Shape of the latent to sample; ``shape[0]`` is the batch size.
            condition_embs: Conditioning embeddings passed to every ``p_sample`` call.
            return_all_timesteps: When True, return the stacked snapshots
                (initial noise plus one snapshot each time
                ``(t + 1) % sample_save_interval == 0``) instead of the final sample.

        Returns:
            ``(ret, decoded)``. ``ret`` is the final sample, or the snapshots
            stacked along dim 1. ``decoded`` is the location decoder output
            reshaped to ``(batch, n_steps, 2)``, where ``n_steps`` is the number
            of snapshots, or 1 when only the final sample is returned.
        """
        batch, device = shape[0], self.device
        encoder = self.location_encoder

        img = torch.randn(shape, device=device)
        # Snapshots are stored de-normalized with the encoder's statistics.
        imgs = [encoder.normalizing_std * img + encoder.normalizing_mean]

        x_start = None
        for t in reversed(range(0, self.num_timesteps)):
            self_cond = x_start if self.self_condition else None
            img, x_start = self.p_sample(img, t, condition_embs, self_cond)

            if (t + 1) % self.sample_save_interval == 0:
                img_out = encoder.normalizing_std * img + encoder.normalizing_mean
                img_out[:, 0] = DEFAULT_Y0
                imgs.append(img_out)

        if return_all_timesteps:
            ret = torch.stack(imgs, dim = 1)
            n_steps = len(imgs)
        else:
            # Only the final sample is decoded, so exactly one step per batch
            # item. Using len(imgs) here made the reshape below fail whenever
            # any snapshot had been saved.
            # NOTE(review): `img` is still in normalized space here, unlike the
            # saved snapshots -- confirm whether it should be de-normalized too.
            ret = img
            n_steps = 1

        decoded = self.location_decoder(ret.reshape((-1, self.dim_hidden)))
        return ret, decoded.reshape((batch, n_steps, 2))
    # self.location_decoder.kl_decode(ret.reshape((-1, self.dim_hidden))).reshape((batch, n_steps, 2))

    @torch.inference_mode()
    def ddim_sample(self, shape, condition_embs = None, return_all_timesteps = False):
        """Sample with DDIM over ``self.sampling_timesteps`` steps.

        The parameter order matches ``p_sample_loop`` so that ``sample`` can
        call either sampler as ``fn(shape, condition_embs, return_all_timesteps=...)``.

        Args:
            shape: Shape of the latent to sample; ``shape[0]`` is the batch size.
            condition_embs: Conditioning embeddings forwarded to the model.
            return_all_timesteps: When True, return every intermediate sample
                stacked along dim 1 instead of only the final one.

        Returns:
            The final sample, or the stacked trajectory. Unlike
            ``p_sample_loop``, no decoded coordinates are returned.
        """
        batch, device = shape[0], self.device
        total_timesteps, sampling_timesteps = self.num_timesteps, self.sampling_timesteps
        eta = self.ddim_sampling_eta

        times = torch.linspace(-1, total_timesteps - 1, steps = sampling_timesteps + 1)   # [-1, 0, 1, 2, ..., T-1] when sampling_timesteps == total_timesteps
        times = list(reversed(times.int().tolist()))
        time_pairs = list(zip(times[:-1], times[1:])) # [(T-1, T-2), (T-2, T-3), ..., (1, 0), (0, -1)]

        img = torch.randn(shape, device = device)
        imgs = [img]

        x_start = None

        for time, time_next in tqdm(time_pairs, desc = 'sampling loop time step'):
            time_cond = torch.full((batch,), time, device = device, dtype = torch.long)
            self_cond = x_start if self.self_condition else None
            # The conditioning goes in the third slot; self_cond used to be
            # passed there by mistake, leaving the model unconditioned.
            pred_noise, x_start, *_ = self.model_predictions(img, time_cond, condition_embs, self_cond, clip_x_start = True, rederive_pred_noise = True)

            if time_next < 0:
                # Last step: the x_0 estimate is the final sample.
                img = x_start
                imgs.append(img)
                continue

            alpha = self.alphas_cumprod[time]
            alpha_next = self.alphas_cumprod[time_next]

            sigma = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
            c = (1 - alpha_next - sigma ** 2).sqrt()

            noise = torch.randn_like(img)

            img = x_start * alpha_next.sqrt() + \
                  c * pred_noise + \
                  sigma * noise

            imgs.append(img)

        ret = img if not return_all_timesteps else torch.stack(imgs, dim = 1)

        return ret

    @torch.inference_mode()
    def sample(self, conditional_embs, return_all_timesteps = False):
        """Encode the conditioning input and run the configured sampler.

        Dispatches to ``ddim_sample`` when ``is_ddim_sampling`` is set and to
        ``p_sample_loop`` otherwise, calling it with the latent shape
        ``(batch, dim_hidden)``, the encoded conditioning and
        ``return_all_timesteps``. The sampler's result is returned as-is.
        """
        cond = self.image_encoder(conditional_embs)
        if self.is_ddim_sampling:
            sampler = self.ddim_sample
        else:
            sampler = self.p_sample_loop
        latent_shape = (cond.size(0), self.dim_hidden)
        return sampler(latent_shape, cond,  return_all_timesteps=return_all_timesteps)

    @torch.inference_mode()
    def interpolate(self, x1, x2, t = None, lam = 0.5, condition_embs = None):
        """Interpolate two samples in noised space and denoise the blend.

        Both inputs are diffused to timestep ``t`` with ``q_sample``, mixed as
        ``(1 - lam) * xt1 + lam * xt2``, then denoised with ``p_sample`` for
        steps ``t - 1 .. 0``.

        Args:
            x1, x2: Tensors of identical shape.
            t: Timestep to diffuse to; defaults to ``num_timesteps - 1``.
            lam: Mixing weight of ``x2``.
            condition_embs: Conditioning embeddings forwarded to ``p_sample``.
                Defaults to None, which reproduces the previous (unconditioned) call.

        Returns:
            The denoised interpolated sample.
        """
        b, *_, device = *x1.shape, x1.device
        t = default(t, self.num_timesteps - 1)

        assert x1.shape == x2.shape

        t_batched = torch.full((b,), t, device = device, dtype = torch.long)
        xt1, xt2 = map(lambda x: self.q_sample(x, t = t_batched), (x1, x2))

        img = (1 - lam) * xt1 + lam * xt2

        x_start = None

        for i in tqdm(reversed(range(0, t)), desc = 'interpolation sample time step', total = t):
            self_cond = x_start if self.self_condition else None
            # p_sample takes (x, t, condition_embs, x_self_cond); self_cond used
            # to be passed in the conditioning slot.
            img, x_start = self.p_sample(img, i, condition_embs, self_cond)

        return img

    @autocast(enabled = False)
    def q_sample(self, x_start, t, noise = None):
        """Diffuse ``x_start`` forward to timesteps ``t``.

        Returns ``sqrt_alphas_cumprod[t] * x_start + sqrt_one_minus_alphas_cumprod[t] * noise``;
        standard normal noise is drawn when ``noise`` is None.
        """
        if noise is None:
            noise = torch.randn_like(x_start)

        shape = x_start.shape
        signal_scale = extract(self.sqrt_alphas_cumprod, t, shape)
        noise_scale = extract(self.sqrt_one_minus_alphas_cumprod, t, shape)
        return signal_scale * x_start + noise_scale * noise

    #### TO-DO: add offset noise ####
    def p_losses(self, inputs, idx, t, conditional_embs, preload=True, noise = None, offset_noise_strength = 0.01):
        """Compute the training losses for a batch at timesteps ``t``.

        Args:
            inputs: Raw location inputs; passed to the location encoder and used
                as the reference for the great-circle loss.
            idx: Sample indices forwarded to the location encoder and its anchors.
            t: Long tensor of timesteps, one per batch element.
            conditional_embs: Conditioning input, encoded by ``self.image_encoder``.
            preload: Forwarded to the location encoder.
            noise: Optional noise tensor; drawn from a standard normal when None.
                It is not modified.
            offset_noise_strength: Offset-noise scale.
                NOTE(review): because the default is 0.01 rather than None,
                ``self.offset_noise_strength`` is only used when None is passed
                explicitly -- confirm this is intended.

        Returns:
            Tuple ``(loss1, loss2, loss3, loss4, loss1_weight, loss2_weight,
            loss3_weight, loss4_weight, log_model_probs, log_target_probs)``:
            reverse-KL, great-circle, cosine and L1 losses (all unreduced over
            the batch), each loss multiplied by ``loss_weight[t]``, and the log
            probabilities of the model output and of the target at the anchors.
        """
        encoder = self.location_encoder

        # Encode the locations, then standardize with the encoder's statistics.
        x_input = encoder(inputs, idx, preload)
        x_start = (x_input - encoder.normalizing_mean) / encoder.normalizing_std

        conditional_embs = self.image_encoder(conditional_embs)

        noise = default(noise, lambda: torch.randn_like(x_start))

        # offset noise - https://www.crosslabs.org/blog/diffusion-with-offset-noise

        offset_noise_strength = default(offset_noise_strength, self.offset_noise_strength)

        if offset_noise_strength > 0.:
            offset_noise = torch.randn(x_start.shape[:2], device=self.device)
            # Out-of-place add so a caller-supplied `noise` tensor is not mutated.
            noise = noise + offset_noise_strength * offset_noise

        # noise sample

        x = self.q_sample(x_start=x_start, t=t, noise=noise)

        # Self-conditioning branch. `self.self_condition` is set to None in
        # __init__, so this is inactive; the estimate is also not consumed
        # below because the model is called without it.
        x_self_cond = None
        if self.self_condition and random() < 0.5:
            with torch.inference_mode():
                x_self_cond = self.model_predictions(x, t, conditional_embs).pred_x_start
                x_self_cond.detach_()

        # predict and take gradient step

        model_out = self.model(x, t, conditional_embs)

        # Map the prediction back with the encoder's statistics and pin column 0.
        model_out = encoder.normalizing_std * model_out + encoder.normalizing_mean
        model_out[:,0] = DEFAULT_Y0

        # NOTE(review): model_out is de-normalized above, but for 'pred_noise'
        # and 'pred_v' the target stays in normalized space; only 'pred_x0'
        # compares against the un-normalized encoding -- confirm intended.
        if self.objective == 'pred_noise':
            target = noise
        elif self.objective == 'pred_x0':
            target = x_input
        elif self.objective == 'pred_v':
            target = self.predict_v(x_start, t, noise)
        else:
            raise ValueError(f'unknown objective {self.objective}')

        anchors = encoder._get_anchors(idx, 2048)

        gps_outputs = self.location_decoder(model_out)

        log_target_probs = self.location_decoder.evaluate_spherical_log_probability(target, anchors)
        log_model_probs = self.location_decoder.evaluate_spherical_log_probability(model_out, anchors)

        # KLDivLoss(input, target) with log_target=True; summed over dim 1.
        reverse_kl_loss = torch.sum(self.kl_loss(log_target_probs, log_model_probs), dim=1)

        cosine_loss = self.cosine_loss(model_out, target, torch.ones(model_out.size(0), device=self.device))
        mse_loss = self.mse_loss(model_out, target)

        loss1 = reverse_kl_loss
        loss1_weight = loss1 * extract(self.loss_weight, t, loss1.shape)

        loss2 = great_circle_distance_loss(gps_outputs, inputs)
        loss2_weight = loss2 * extract(self.loss_weight, t, loss2.shape)

        loss3 = cosine_loss
        loss3_weight = loss3 * extract(self.loss_weight, t, loss3.shape)

        loss4 = mse_loss
        loss4_weight = loss4 * extract(self.loss_weight, t, loss4.shape)

        return loss1, loss2, loss3, loss4, loss1_weight, loss2_weight, loss3_weight, loss4_weight, log_model_probs, log_target_probs

    def forward(self, inputs, idx, conditional_emb, preload, last_step_only=False, *args, **kwargs):
        """Pick timesteps for the batch and return ``p_losses`` for them.

        ``inputs`` must be 2-D with last dimension ``self.dim_inputs``. With
        ``last_step_only`` every sample uses timestep ``num_timesteps - 1``;
        otherwise timesteps are drawn uniformly from ``[0, num_timesteps)``.
        Extra positional / keyword arguments are forwarded to ``p_losses``.
        """
        b, d = inputs.shape
        dim = self.dim_inputs
        assert d == dim, f'the dimension of inputs must be {dim}'

        device = inputs.device
        if last_step_only:
            t = torch.full((b,), self.num_timesteps - 1, device=device, dtype=torch.long)
        else:
            t = torch.randint(0, self.num_timesteps, (b,), device=device).long()

        return self.p_losses(inputs, idx, t, conditional_emb, preload, *args, **kwargs)

    def sanity_check(self, inputs, idx, conditional_embs, preload=True, noise = None, offset_noise_strength = 0.01):
        """Noise a batch to the last timestep, run the model once and return the pieces as NumPy arrays.

        Args:
            inputs: 2-D raw location inputs passed to the location encoder.
            idx: Sample indices forwarded to the location encoder.
            conditional_embs: Conditioning input, encoded by ``self.image_encoder``.
            preload: Forwarded to the location encoder.
            noise: Optional noise tensor; drawn from a standard normal when None.
                It is not modified.
            offset_noise_strength: Offset-noise scale; ``self.offset_noise_strength``
                is used only when None is passed explicitly.

        Returns:
            Tuple ``(clean_x, noise, noised_x, model_out, gps_outputs)`` of NumPy
            arrays: the encoded input, the noise used (including any offset
            noise), the noised input at ``t = num_timesteps - 1``, the raw model
            output, and the decoder output for it.

        NOTE(review): unlike ``p_losses``, the encoding is not standardized and
        the model output is not de-normalized here -- confirm this is intended.
        """
        b, _ = inputs.shape
        device = inputs.device
        t = (self.num_timesteps -1) * torch.ones((b,), device=device).long()

        x_start = self.location_encoder(inputs, idx, preload)

        clean_x = x_start.detach().cpu().numpy()

        conditional_embs = self.image_encoder(conditional_embs)

        noise = default(noise, lambda: torch.randn_like(x_start))

        # offset noise - https://www.crosslabs.org/blog/diffusion-with-offset-noise

        offset_noise_strength = default(offset_noise_strength, self.offset_noise_strength)

        if offset_noise_strength > 0.:
            offset_noise = torch.randn(x_start.shape[:2], device=self.device)
            # Out-of-place add so a caller-supplied `noise` tensor is not mutated.
            noise = noise + offset_noise_strength * offset_noise

        # noise sample

        x = self.q_sample(x_start = x_start, t = t, noise = noise)
        noised_x = x.detach().cpu().numpy()

        # Self-conditioning branch. `self.self_condition` is set to None in
        # __init__, so this is inactive; the estimate is also not consumed below.
        x_self_cond = None
        if self.self_condition and random() < 0.5:
            with torch.inference_mode():
                x_self_cond = self.model_predictions(x, t, conditional_embs).pred_x_start
                x_self_cond.detach_()

        model_out = self.model(x, t, conditional_embs)

        gps_outputs = self.location_decoder(model_out)

        return clean_x, noise.detach().cpu().numpy(), noised_x, model_out.detach().cpu().numpy(), gps_outputs.detach().cpu().numpy()
