"""
Based on: https://github.com/crowsonkb/k-diffusion
"""
import random

import numpy as np
import torch as th
import torch.nn as nn
import torch.nn.functional as F
from piq import LPIPS
from torchvision.transforms import RandomCrop
from . import dist_util, logger
import torch.distributed as dist
from torchvision.transforms import Normalize

from .nn import mean_flat, append_dims, append_zero
from .random_util import get_generator
from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
import blobfile as bf
import os
from torchvision.utils import make_grid, save_image
from pg_modules.projector import F_RandomProj
from pg_modules.diffaug import DiffAugment

def get_weightings(weight_schedule, snrs, sigma_data, t, s):
    """Return loss weights for the named ``weight_schedule``.

    Args:
        weight_schedule: One of ``"snr"``, ``"snr+1"``, ``"karras"``,
            ``"truncated-snr"``, ``"uniform"``, ``"uniform_g"``,
            ``"karras_weight"`` or ``"sq-t-inverse"``.
        snrs: Signal-to-noise ratios (``sigma ** -2``).
        sigma_data: Data standard deviation used by the Karras schedules.
        t, s: Start/end times; only ``"uniform_g"`` reads them and returns
            ``1 / (1 - s / t)``.

    Raises:
        NotImplementedError: If ``weight_schedule`` is not recognised.
    """
    def karras_weight():
        # Recover sigma from the SNR and apply (sigma^2 + sd^2) / (sigma * sd)^2.
        sigma = snrs ** -0.5
        return (sigma ** 2 + sigma_data ** 2) / (sigma * sigma_data) ** 2

    # Checked in order with ``==`` so lookup semantics match a plain if/elif.
    table = (
        ("snr", lambda: snrs),
        ("snr+1", lambda: snrs + 1),
        ("karras", lambda: snrs + 1.0 / sigma_data ** 2),
        ("truncated-snr", lambda: th.clamp(snrs, min=1.0)),
        ("uniform", lambda: th.ones_like(snrs)),
        ("uniform_g", lambda: 1. / (1. - s / t)),
        ("karras_weight", karras_weight),
        ("sq-t-inverse", lambda: 1. / snrs ** 0.25),
    )
    for name, compute in table:
        if weight_schedule == name:
            return compute()
    raise NotImplementedError()


class KarrasDenoiser:
    def __init__(
        self,
        args,
        schedule_sampler,
        diffusion_schedule_sampler,
        feature_networks=None,
    ):
        self.args = args
        self.schedule_sampler = schedule_sampler
        self.diffusion_schedule_sampler = diffusion_schedule_sampler
        if args.loss_norm == "lpips":
            self.lpips_loss = LPIPS(replace_pooling=True, reduction="none")
        #elif args.loss_norm == 'StyleGAN-XL':
        self.feature_networks = feature_networks

        '''elif args.loss_norm == "StyleGAN-XL":
            backbones = ['deit_base_distilled_patch16_224', 'tf_efficientnet_lite0']
            feature_networks = []
            backbone_kwargs = {'im_res': self.args.image_size}
            for i, bb_name in enumerate(backbones):
                feat = F_RandomProj(bb_name, **backbone_kwargs)
                feature_networks.append([bb_name, feat])
            self.feature_networks = nn.ModuleDict(feature_networks)
            self.feature_networks = self.feature_networks.train(False).to(dist_util.dev())
            self.feature_networks.requires_grad_(False)'''
        self.num_timesteps = args.start_scales
        self.dist = nn.MSELoss(reduction='none')

    def get_snr(self, sigmas):
        """Return the signal-to-noise ratio ``sigmas ** -2``."""
        snr = sigmas ** -2
        return snr

    def get_sigmas(self, sigmas):
        """Identity mapping: hand back ``sigmas`` exactly as received."""
        result = sigmas
        return result

    def get_c_in(self, sigma):
        return 1 / (sigma**2 + self.args.sigma_data**2) ** 0.5

    def get_scalings(self, sigma):
        c_skip = self.args.sigma_data**2 / (sigma**2 + self.args.sigma_data**2)
        c_out = sigma * self.args.sigma_data / (sigma**2 + self.args.sigma_data**2) ** 0.5
        return c_skip, c_out

    def get_scalings_t(self, t, s):
        c_skip = th.zeros_like(t)
        c_out = ((t ** 2 + self.args.sigma_data ** 2) / (s ** 2 + self.args.sigma_data ** 2)) ** 0.5
        return c_skip, c_out

    def get_scalings_for_generalized_boundary_condition(self, t, s):
        if self.args.parametrization.lower() == 'euler':
            c_skip = s / t
        elif self.args.parametrization.lower() == 'variance':
            c_skip = (((s - self.args.sigma_min) ** 2 + self.args.sigma_data ** 2) / ((t - self.args.sigma_min) ** 2 + self.args.sigma_data ** 2)).sqrt()
        elif self.args.parametrization.lower() == 'euler_variance_mixed':
            c_skip = s / (t + 1.) + \
                     (((s - self.args.sigma_min) ** 2 + self.args.sigma_data ** 2) /
                      ((t - self.args.sigma_min) ** 2 + self.args.sigma_data ** 2)).sqrt() / (t + 1.)
        c_out = (1. - s / t)
        return c_skip, c_out

    def get_scalings_for_boundary_condition(self, sigma):
        c_skip = self.args.sigma_data**2 / (
            (sigma - self.args.sigma_min) ** 2 + self.args.sigma_data**2
        )
        c_out = (
            (sigma - self.args.sigma_min)
            * self.args.sigma_data
            / (sigma**2 + self.args.sigma_data**2) ** 0.5
        )
        return c_skip, c_out

    def calculate_adaptive_weight(self, loss1, loss2, last_layer=None):
        """Return the detached gradient-norm ratio of ``loss1`` to ``loss2``.

        Computes ``||d loss1/d last_layer|| / (||d loss2/d last_layer|| + 1e-4)``,
        clamps it to ``[0, 1e4]`` and detaches it. Gradients are taken with
        ``retain_graph=True`` so the caller can still backpropagate through
        the same graph afterwards.
        """
        grad_norms = [
            th.norm(th.autograd.grad(loss, last_layer, retain_graph=True)[0])
            for loss in (loss1, loss2)
        ]
        ratio = grad_norms[0] / (grad_norms[1] + 1e-4)
        return th.clamp(ratio, 0.0, 1e4).detach()

    def adopt_weight(self, weight, global_step, threshold=0, value=0.):
        """Return ``value`` while ``global_step < threshold``, otherwise ``weight``."""
        return value if global_step < threshold else weight

    def hinge_d_loss(self, logits_real, logits_fake, value=1.):
        """Hinge discriminator loss.

        Returns ``0.5 * (mean(relu(value - real)) + mean(relu(value + fake)))``.
        """
        real_term = F.relu(value - logits_real).mean()
        fake_term = F.relu(value + logits_fake).mean()
        return 0.5 * (real_term + fake_term)

    def vanilla_d_loss(self, logits_real, logits_fake):
        """Softplus (non-saturating) discriminator loss.

        Returns ``0.5 * (mean(softplus(-real)) + mean(softplus(fake)))``.
        """
        real_term = F.softplus(-logits_real).mean()
        fake_term = F.softplus(logits_fake).mean()
        return 0.5 * (real_term + fake_term)

    def rescaling_t(self, t):
        """Map noise level ``t`` to the network time input ``1000 * 0.25 * log(t)``.

        A tiny ``1e-44`` offset is added to guard against ``log(0)``.
        """
        scale = 1000 * 0.25
        return scale * th.log(t + 1e-44)

    def get_t(self, ind):
        if self.args.time_continuous:
            t = self.args.sigma_max ** (1 / self.args.rho) + ind * (
                    self.args.sigma_min ** (1 / self.args.rho) - self.args.sigma_max ** (1 / self.args.rho)
            )
            t = t ** self.args.rho
        else:
            t = self.args.sigma_max ** (1 / self.args.rho) + ind / (self.args.start_scales - 1) * (
                    self.args.sigma_min ** (1 / self.args.rho) - self.args.sigma_max ** (1 / self.args.rho)
            )
            t = t ** self.args.rho
        return t

    def pre_process(self, data, mode=False):
        """Turn encoder moments ``data`` into a scaled latent.

        ``data`` is wrapped in ``DiagonalGaussianDistribution``. The latent is
        its ``mode()`` when ``mode`` is true, otherwise a ``sample()``, and it
        is multiplied by a fixed latent scale factor.

        The scaling is done out of place. ``mode()`` may return a view into
        ``data`` (ldm's implementation returns the ``mean`` chunk of the
        parameters), so an in-place ``*=`` would silently rescale the caller's
        tensor and can break autograd.
        """
        encoded = DiagonalGaussianDistribution(data)
        latent = encoded.mode() if mode else encoded.sample()
        # NOTE(review): fixed latent scale factor; the original "* 0.45"
        # remark hints at an earlier alternative value — confirm its origin.
        return latent * 0.24578019976615906

    def get_num_heun_step(self, step):
        if self.args.num_heun_step_random:
            #if step % self.args.g_learning_period == 0:
            if self.args.time_continuous:
                num_heun_step = np.random.rand() * self.args.num_heun_step / self.args.start_scales
            else:
                if self.args.heun_step_strategy == 'uniform':
                    num_heun_step = np.random.randint(1,1+self.args.num_heun_step)
                elif self.args.heun_step_strategy == 'weighted':
                    p = np.array([i ** self.args.heun_step_multiplier for i in range(1,1+self.args.num_heun_step)])
                    p = p / sum(p)
                    num_heun_step = np.random.choice([i+1 for i in range(len(p))], size=1, p=p)[0]
            #else:
            #    if self.args.time_continuous:
            #        num_heun_step = np.random.rand() / self.args.d_learning_period +\
            #                        (self.args.d_learning_period - 1) / self.args.d_learning_period
            #    else:
                    #num_heun_step = np.random.randint((self.args.d_learning_period - 1) * self.args.start_scales //
                    #                                  self.args.d_learning_period, 1+self.args.num_heun_step)
            #        num_heun_step = self.args.num_heun_step
        else:
            if self.args.time_continuous:
                num_heun_step = self.args.num_heun_step / self.args.start_scales
            else:
                num_heun_step = self.args.num_heun_step
        return num_heun_step

    @th.no_grad()
    def heun_solver(self, x, ind, teacher_model, dims, num_step=1, **model_kwargs):
        """Advance ``x`` by ``num_step`` Heun (2nd-order) steps of the teacher ODE.

        Step ``k`` moves from ``get_t(ind + k)`` to ``get_t(ind + k + 1)``. It
        averages the slope ``(x - D(x, t)) / t`` at the start point and at
        the Euler-predicted end point. Runs without gradient tracking.

        Args:
            x: Current sample.
            ind: Starting step index passed to ``get_t``.
            teacher_model: Network evaluated via ``denoise_fn(..., teacher=True)``.
            dims: Number of dimensions of ``x``, used to broadcast times.
            num_step: Number of Heun steps.
            **model_kwargs: Forwarded to the teacher.

        Returns:
            The sample after ``num_step`` steps.
        """
        def slope(sample, time):
            denoised = self.denoise_fn(teacher_model, sample, time, None,
                                       ctm=False, teacher=True, **model_kwargs)
            return (sample - denoised) / append_dims(time, dims)

        for offset in range(num_step):
            t_cur = self.get_t(ind + offset)
            t_next = self.get_t(ind + offset + 1)
            d_cur = slope(x, t_cur)
            x_euler = x + d_cur * append_dims(t_next - t_cur, dims)
            d_next = slope(x_euler, t_next)
            x = x + (d_cur + d_next) * append_dims((t_next - t_cur) / 2, dims)
        return x

    def get_gan_estimate(self, estimate, step, x_t, t, t_dt, s, model, target_model, ctm, type=None, auxiliary_type=None, **model_kwargs):
        if self.args.gan_estimate_type == 'consistency':
            estimate = self.denoise_fn(model, x_t, t, s=th.ones_like(s) * self.args.sigma_min, ctm=ctm, **model_kwargs)
        elif self.args.gan_estimate_type == 'enable_grad':
            if self.args.auxiliary_type == 'enable_grad':
                estimate = estimate
            else:
                estimate = self.get_estimate(step, x_t, t, t_dt, s, model, target_model, ctm=ctm,
                                             auxiliary_type='enable_grad', **model_kwargs)
        elif self.args.gan_estimate_type == 'only_high_freq':
            estimate = self.get_estimate(step, x_t, t, t_dt, s, model, target_model, ctm=ctm,
                                         type='stop_grad', auxiliary_type='enable_grad', **model_kwargs)
        elif self.args.gan_estimate_type == 'same':
            estimate = estimate
        return estimate

    def get_estimate(self, step, x_t, t, t_dt, s, model, target_model, ctm, type=None, auxiliary_type=None, **model_kwargs):
        """Predict the sample reached from ``x_t`` at time ``t``.

        ``model`` first jumps from ``t`` to ``s``. In ``'ctm'`` training mode
        the result is carried further from ``s`` down to ``sigma_min`` by
        ``target_model``. ``step``, ``t_dt``, ``type`` and ``auxiliary_type``
        are accepted but not read here.
        """
        jumped = self.denoise_fn(model, x_t, t, s=s, ctm=ctm, **model_kwargs)
        if self.args.training_mode != 'ctm':
            return jumped
        floor = th.ones_like(s) * self.args.sigma_min
        return self.denoise_fn(target_model, jumped, s, s=floor, ctm=ctm, **model_kwargs)

    @th.no_grad()
    def get_target(self, step, x_t_dt, t_dt, s, model, target_model, ctm, **model_kwargs):
        """Compute the gradient-free consistency target from ``x_t_dt``.

        Same route as ``get_estimate`` but starting at ``t_dt``. The decorator
        already disables gradient tracking for the whole body, and the result
        is returned detached. ``step`` is not read.
        """
        target = self.denoise_fn(model, x_t_dt, t_dt, s=s, ctm=ctm, **model_kwargs)
        if self.args.training_mode == 'ctm':
            floor = th.ones_like(s) * self.args.sigma_min
            target = self.denoise_fn(target_model, target, s, s=floor, ctm=ctm, **model_kwargs)
        return target.detach()

    def denoise_fn(self, model, x, t, s, ctm=False, teacher=False, **model_kwargs):
        """Return only the denoised tensor from ``denoise`` (drops the raw model output)."""
        _raw_output, denoised = self.denoise(model, x, t, s, ctm, teacher, **model_kwargs)
        return denoised

    def denoise(self, model, x_t, t, s=None, ctm=False, teacher=False, **model_kwargs):
        """Run ``model`` on ``x_t`` and apply the output parametrization.

        The network is called as
        ``model(c_in(t) * x_t, rescaling_t(t), s=rescaling_t(s) or None,
        teacher=teacher, **model_kwargs)``.

        With ``ctm=True`` the raw output is first post-processed by
        ``args.inner_parametrization`` (``'edm'``, ``'scale'``; ``'no'`` or any
        other value leaves it unchanged). Teacher calls then convert it into a
        denoised sample according to ``args.parametrization`` and return
        early. Student calls apply the generalized ``t -> s`` boundary
        condition, which requires ``s``.

        With ``ctm=False`` the EDM scalings (teacher) or the consistency
        boundary-condition scalings (student) are applied. For
        ``args.data_name == 'church'`` these non-CTM teacher calls return the
        raw output as the denoised sample.

        Returns:
            ``(model_output, denoised)``.

        Raises:
            NotImplementedError: For a CTM teacher call with an unknown
                ``args.parametrization`` (previously an ``UnboundLocalError``).
            AssertionError: For a CTM student call without ``s``.
        """
        ndim = x_t.ndim
        rescaled_t = self.rescaling_t(t)
        # ``is not None``: ``s`` may be a tensor, so avoid ``!=`` comparison.
        rescaled_s = self.rescaling_t(s) if s is not None else None
        c_in = append_dims(self.get_c_in(t), ndim)
        model_output = model(c_in * x_t, rescaled_t, s=rescaled_s, teacher=teacher, **model_kwargs)
        if ctm:
            if self.args.inner_parametrization == 'edm':
                c_skip, c_out = [append_dims(x, ndim) for x in self.get_scalings(t)]
                model_output = c_out * model_output + c_skip * x_t
            elif self.args.inner_parametrization == 'scale':
                c_skip, c_out = [append_dims(x, ndim) for x in self.get_scalings_t(t, s)]
                model_output = c_out * model_output + c_skip * x_t
            if teacher:
                # Teacher path returns early; the 'church' override below is
                # not applied here.
                parametrization = self.args.parametrization.lower()
                if parametrization == 'euler':
                    denoised = model_output
                elif parametrization == 'variance':
                    sigma_min, sigma_data = self.args.sigma_min, self.args.sigma_data
                    denoised = model_output + append_dims(
                        (sigma_min ** 2 + sigma_data ** 2 - sigma_min * t)
                        / ((t - sigma_min) ** 2 + sigma_data ** 2), ndim) * x_t
                elif parametrization == 'euler_variance_mixed':
                    sigma_min, sigma_data = self.args.sigma_min, self.args.sigma_data
                    denoised = model_output + x_t - append_dims(
                        t / (t + 1.) * (1. + (t - sigma_min)
                                        / ((t - sigma_min) ** 2 + sigma_data ** 2)),
                        ndim) * x_t
                else:
                    raise NotImplementedError(
                        f"Unknown parametrization: {self.args.parametrization!r}")
                return model_output, denoised
            # Student: generalized boundary condition for the t -> s jump.
            assert s is not None
            c_skip, c_out = [
                append_dims(x, ndim)
                for x in self.get_scalings_for_generalized_boundary_condition(t, s)
            ]
            denoised = c_out * model_output + c_skip * x_t
        else:
            if teacher:
                scalings = self.get_scalings(t)
            else:
                scalings = self.get_scalings_for_boundary_condition(t)
            c_skip, c_out = [append_dims(x, ndim) for x in scalings]
            denoised = c_out * model_output + c_skip * x_t

        if self.args.data_name in ['church'] and teacher:
            denoised = model_output
        return model_output, denoised

    def get_consistency_loss(self, estimate, target, t_m, pretrained_classifier, classifier_vpsde, weights, step, loss_norm=''):
        if loss_norm == '':
            loss_norm = self.args.loss_norm
        estimate_out = estimate
        target_out = target
        if loss_norm.lower() == 'unet':
            #print("consistency loss UNet")
            mean_vp_tau, tau = classifier_vpsde.transform_unnormalized_wve_to_normalized_vp(t_m)
            rescaled_estimate = mean_vp_tau[:, None, None, None] * estimate
            estimated_features = pretrained_classifier(rescaled_estimate, timesteps=tau, feature=True,
                                                            out_res=self.args.out_res)
            rescaled_target = mean_vp_tau[:, None, None, None] * target
            #print("feature dim: ", estimated_features.shape, self.args.out_res)
            with th.no_grad():
                target_features = pretrained_classifier(rescaled_target, timesteps=tau, feature=True,
                                                             out_res=self.args.out_res)
            if self.args.discriminator_input == 'latent':
                estimate_out = rescaled_estimate
                target_out = rescaled_target
            else:
                estimate_out = estimated_features
                target_out = target_features
            if self.args.feature_aggregated:
                if pretrained_classifier.pool == 'attention':
                    numerator = 100.
                else:
                    numerator = 5.
                consistency_loss = 0.
                for k in range(len(estimated_features)):
                    consistency_loss += weights * mean_flat(
                        (estimated_features[k] / numerator - target_features[k] / numerator) ** 2)
            else:
                numerator = 1.
                if self.args.loss_type_ == 'l1':
                    a = '0.9-'
                    b = '0.9-'
                    consistency_loss = mean_flat(th.abs(
                        (estimated_features - target_features) / th.where(target_features.abs() > 1e-3,
                                                                          target_features,
                                                                          1e-3).detach()), percent=a if np.random.choice(2) else b)
                elif self.args.loss_type_ == 'l2':
                    consistency_loss = weights * mean_flat(
                        (estimated_features / numerator - target_features / numerator) ** 2)
            consistency_loss1 = mean_flat((rescaled_estimate - rescaled_target) ** 2)  # ,
            consistency_loss = consistency_loss1 + consistency_loss
        elif loss_norm == 'lpips':
            if estimate.shape[-2] < 256:
                estimate = F.interpolate(estimate, size=224, mode="bilinear")
                target = F.interpolate(
                    target, size=224, mode="bilinear"
                )
            consistency_loss = (self.lpips_loss(
                (estimate + 1) / 2.0,
                (target + 1) / 2.0, ) * weights)
            #if step < 300.:
            #    consistency_loss = consistency_loss + 0.01 * mean_flat((estimate - target) ** 2)
        elif loss_norm == 'l2':
            # print("consistency loss L2")
            mean_vp_tau, tau = classifier_vpsde.transform_unnormalized_wve_to_normalized_vp(t_m)
            rescaled_estimate = mean_vp_tau[:, None, None, None] * estimate
            rescaled_target = mean_vp_tau[:, None, None, None] * target
            estimate_out = rescaled_estimate
            target_out = rescaled_target
            consistency_loss = mean_flat((rescaled_estimate - rescaled_target) ** 2)
        elif loss_norm == "StyleGAN-XL":
            distances, estimate_features, target_features = [], [], []
            estimate_features, target_features = self.get_xl_feature(estimate, target, step=step)
            cnt = 0
            for _, _ in self.feature_networks.items():
                for fe in list(estimate_features[cnt].keys()):
                    norm_factor = th.sqrt(th.sum(estimate_features[cnt][fe] ** 2, dim=1, keepdim=True))
                    est_feat = estimate_features[cnt][fe] / (norm_factor + 1e-10)
                    norm_factor = th.sqrt(th.sum(target_features[cnt][fe] ** 2, dim=1, keepdim=True))
                    tar_feat = target_features[cnt][fe] / (norm_factor + 1e-10)
                    distances.append(self.dist(est_feat, tar_feat))
                cnt += 1
            #for d in distances:
            #    print(d.shape, d.mean(dim=[2, 3]).sum(dim=1).mean())
            consistency_loss = th.cat([d.mean(dim=[2, 3]) for d in distances], dim=1).sum(dim=1)
            return consistency_loss, estimate_features, target_features
        return consistency_loss, estimate_out, target_out

    def get_feature(self, input, feat, brightness, saturation, contrast, translation_x, translation_y,
                    offset_x, offset_y, name, step):
        """Extract backbone features from ``input`` (expected in ``[-1, 1]``).

        When ``args.augment`` is set, ``DiffAugment`` (color, translation,
        cutout) is applied to the first ``brightness.shape[0]`` samples with
        the supplied parameters; the remaining samples pass through unchanged.
        The batch is mapped to ``[0, 1]`` and normalized with
        ``feat.normstats``. If the input height is below 256 it is resized to
        224, and then it is passed through ``feat``.

        For small inputs, every ``args.save_period`` steps (``step >= 0``) the
        resized augmented batch is also saved via ``self.save`` under the name
        ``'{name}_{step}_augment'``.
        """
        if self.args.augment:
            n_aug = brightness.shape[0]
            augmented_head = DiffAugment(input[:n_aug], brightness, saturation, contrast,
                                         translation_x, translation_y, offset_x, offset_y,
                                         policy='color,translation,cutout')
            augmented = th.cat((augmented_head, input[n_aug:]))
        else:
            augmented = input
        # [-1, 1] -> [0, 1], then backbone-specific normalization.
        unit_range = augmented.add(1).div(2)
        normalized = Normalize(feat.normstats['mean'], feat.normstats['std'])(unit_range)
        if input.shape[-2] < 256:
            normalized = F.interpolate(normalized, 224, mode='bilinear', align_corners=False)
            if step % self.args.save_period == 0 and step >= 0:
                preview = F.interpolate(augmented, 224, mode='bilinear', align_corners=False)
                self.save(preview, logger.get_dir(), f'{name}_{step}_augment')
        return feat(normalized)

    def get_denoising_loss(self, model, x_start, model_kwargs, consistency_loss,
                           step, init_step):
        """Return the per-sample auxiliary denoising loss, optionally rebalanced.

        Sigmas are drawn from ``diffusion_schedule_sampler``. ``x_start`` is
        noised with Gaussian noise of that scale and denoised by ``model`` as
        a CTM teacher call with ``s = sigma``. The squared error is weighted
        by ``args.diffusion_weight_schedule`` (see ``get_weightings``).

        If ``args.apply_adaptive_weight`` is set, the loss is scaled by the
        gradient-norm ratio of ``consistency_loss`` to the denoising loss with
        respect to a late network layer (``calculate_adaptive_weight``). Before
        ``init_step`` the scale is 1.

        Args:
            model: Network being trained; layers are looked up through
                ``model.module`` (presumably a DDP wrapper — confirm).
            x_start: Clean data batch.
            model_kwargs: Extra keyword arguments for the model (a dict).
            consistency_loss: Per-sample consistency loss (only used for the
                adaptive weight).
            step: Current training step.
            init_step: Step before which the adaptive weight is disabled.
        """
        # The sampler's own weights are not used; weights come from
        # get_weightings below.
        sigmas, _ = self.diffusion_schedule_sampler.sample(x_start.shape[0], dist_util.dev())
        noise = th.randn_like(x_start)
        dims = x_start.ndim
        x_t = x_start + noise * append_dims(sigmas, dims)
        model_estimate = self.denoise_fn(model, x_t, sigmas, s=sigmas, ctm=True, teacher=True, **model_kwargs)
        snrs = self.get_snr(sigmas)
        denoising_weights = append_dims(
            get_weightings(self.args.diffusion_weight_schedule, snrs, self.args.sigma_data, None, None), dims)
        denoising_loss = mean_flat(denoising_weights * (model_estimate - x_start) ** 2)
        if self.args.apply_adaptive_weight:
            # Primary/fallback last layers for the supported architectures.
            # Lookups are deferred so a missing attribute/index is raised
            # inside the ``try`` below.
            if self.args.data_name in ['church']:
                primary_layer, fallback_layer = (
                    lambda: model.module.output_blocks[14][0].out_layers[3].weight,
                    lambda: model.module.output_blocks[11][0].out_layers[3].weight,
                )
            else:
                primary_layer, fallback_layer = (
                    lambda: model.module.model.dec['32x32_aux_conv'].weight,
                    lambda: model.module.output_blocks[11][0].out_layers[3].weight,
                )
            try:
                balance_weight = self.calculate_adaptive_weight(
                    consistency_loss.mean(), denoising_loss.mean(), last_layer=primary_layer())
            except Exception:
                # ``Exception`` (not a bare except) so KeyboardInterrupt and
                # SystemExit still propagate.
                balance_weight = self.calculate_adaptive_weight(
                    consistency_loss.mean(), denoising_loss.mean(), last_layer=fallback_layer())
        else:
            balance_weight = 1.
        if self.args.large_log:
            print("denoising weight: ", balance_weight)
        balance_weight = self.adopt_weight(balance_weight, step, threshold=init_step, value=1.)
        denoising_loss = denoising_loss * balance_weight
        return denoising_loss

    def get_discriminator_feature(self, img, discriminator):
        x = None
        features = []
        for res in discriminator.block_resolutions:
            block = getattr(discriminator, f'b{res}')
            x, img = block(x, img)
            if self.args.channelwise_normalization:
                norm_factor = th.sqrt(th.sum(x ** 2, dim=1, keepdim=True))
            else:
                norm_factor = th.sqrt(th.sum(x ** 2, dim=[2,3], keepdim=True))
            features.append(x / (norm_factor + 1e-10))
        return features

    def compute_distance(self, x_features, y_features):
        return [self.dist(x, y) for x, y in zip(x_features, y_features)]

    def get_xl_feature(self, estimate, target=None, discriminator=None, step=-1):
        logits_fake, logits_real = [], []
        estimate_features, target_features = [], []
        for bb_name, feat in self.feature_networks.items():
            # apply augmentation (x in [-1, 1])
            brightness = (th.rand(estimate.size(0) // 2, 1, 1, 1, dtype=estimate.dtype,
                                  device=estimate.device) - 0.5)
            # brightness = 0.
            saturation = (th.rand(estimate.size(0) // 2, 1, 1, 1, dtype=estimate.dtype,
                                  device=estimate.device) * 2)
            # saturation = 0.
            contrast = (th.rand(estimate.size(0) // 2, 1, 1, 1, dtype=estimate.dtype,
                                device=estimate.device) + 0.5)
            # contrast = 0.
            shift_x, shift_y = int(estimate.size(2) * self.args.shift_ratio + 0.5), int(
                estimate.size(3) * self.args.shift_ratio + 0.5)
            translation_x = th.randint(-shift_x, shift_x + 1, size=[estimate.size(0) // 2, 1, 1],
                                       device=estimate.device)
            translation_y = th.randint(-shift_y, shift_y + 1, size=[estimate.size(0) // 2, 1, 1],
                                       device=estimate.device)
            cutout_size = int(estimate.size(2) * self.args.cutout_ratio + 0.5), int(
                estimate.size(3) * self.args.cutout_ratio + 0.5)
            offset_x = th.randint(0, estimate.size(2) + (1 - cutout_size[0] % 2),
                                  size=[estimate.size(0) // 2, 1, 1], device=estimate.device)
            offset_y = th.randint(0, estimate.size(3) + (1 - cutout_size[1] % 2),
                                  size=[estimate.size(0) // 2, 1, 1], device=estimate.device)

            estimate_feature = self.get_feature(estimate, feat, brightness, saturation, contrast,
                                                translation_x, translation_y, offset_x, offset_y, 'estimate',
                                                step)
            estimate_features.append(estimate_feature)
            if discriminator:
                logits_fake += discriminator.module[bb_name](estimate_feature, None)

            if target != None:
                with th.no_grad():
                    target_feature = self.get_feature(target, feat, brightness, saturation, contrast,
                                                      translation_x, translation_y, offset_x, offset_y, 'target',
                                                      step)
                    target_features.append(target_feature)
                logits_real += discriminator.module[bb_name](target_feature, None)
        if discriminator:
            if target == None:
                return logits_fake
            else:
                return logits_fake, logits_real
        else:
            if target == None:
                return estimate_features
            else:
                return estimate_features, target_features



    def get_discriminator_loss(self, t, t_dt, model, estimate=None, target=None, consistency_loss=None,
                               learn_generator=True, free=False, pretrained_classifier=None, discriminator=None,
                               step=0, init_step=0, noise=None, estimate_features=None, target_features=None,):
        #print("loss_norm: ", self.args.loss_norm, learn_generator, self.args.d_architecture, self.args.discriminator_fix)
        if self.args.loss_norm == 'lpips':
            if learn_generator:
                if self.args.d_architecture == 'stylegan-ada':
                    #logits_fake = discriminator(estimate, c=th.ones_like(t) * 0.)
                    logits_fake = self.get_discriminator_feature(estimate, discriminator)
                elif self.args.d_architecture == 'stylegan-xl':
                    logits_fake = self.get_xl_feature(estimate, discriminator=discriminator)
                else:
                    logits_fake = discriminator(estimate, sigmoid=False)
                if target != None and self.args.discriminator_fix:
                    logits_real = self.get_discriminator_feature(target, discriminator)
                if self.args.discriminator_fix:
                    distances = self.compute_distance(logits_real, logits_fake)
                    g_loss = th.cat([d.mean(dim=[2, 3]) for d in distances], dim=1).sum(dim=1)
                else:
                    if self.args.d_architecture == 'stylegan-xl':
                        g_loss = sum([(-l).mean() for l in logits_fake]) / len(logits_fake)
                        #print("g_loss: ", g_loss.mean().item())
                    else:
                        g_loss = mean_flat(-logits_fake)
                if self.args.large_log:
                    print("g_loss: ", g_loss.mean().item())
                CTM_loss = consistency_loss.mean()
                if self.args.d_apply_adaptive_weight:
                    if self.args.data_name in ['church']:
                        try:
                            d_weight = self.calculate_adaptive_weight(CTM_loss, g_loss.mean(),
                                                                      last_layer=
                                                                      model.module.output_blocks[14][0].out_layers[
                                                                          3].weight)
                        except:
                            d_weight = self.calculate_adaptive_weight(CTM_loss, g_loss.mean(),
                                                                      last_layer=
                                                                      model.module.output_blocks[11][0].out_layers[
                                                                          3].weight)
                    else:
                        try:
                            d_weight = self.calculate_adaptive_weight(CTM_loss.mean(),
                                                                      g_loss.mean(),
                                                                      last_layer=
                                                                      model.module.output_blocks[11][0].out_layers[
                                                                          3].weight)
                        except:
                            d_weight = self.calculate_adaptive_weight(CTM_loss.mean(),
                                                                      g_loss.mean(),
                                                                      last_layer=model.module.model.dec[
                                                                          '32x32_aux_conv'].weight)
                    d_weight = th.clip(d_weight, 0.01, 10.)
                else:
                    d_weight = 1.
                print("discriminator weight: ", d_weight)
                discriminator_loss = self.adopt_weight(d_weight, step,
                                                       threshold=init_step + self.args.discriminator_start_itr) * g_loss
            else:
                if self.args.d_architecture == 'stylegan-xl':
                    logits_fake, logits_real = self.get_xl_feature(estimate.detach(), target=target.detach(), discriminator=discriminator)
                    loss_Dgen = sum([(F.relu(th.ones_like(l) + l)).mean() for l in logits_fake]) / len(logits_fake)
                    loss_Dreal = sum([(F.relu(th.ones_like(l) - l)).mean() for l in logits_real]) / len(logits_real)
                    print("real, fake: ", loss_Dreal.item(), loss_Dgen.item())
                    discriminator_loss = loss_Dreal + loss_Dgen
                    return discriminator_loss
                if ((self.args.g_learning_period - 1) * (step - init_step) // self.args.g_learning_period) % self.args.lazy_reg == 0:
                    target.requires_grad = True
                    logits_real = discriminator(target, sigmoid=False)
                else:
                    logits_real = discriminator(target.detach(), sigmoid=False)
                logits_fake = discriminator(estimate.detach(), sigmoid=False)
                if self.args.large_log:
                    print("logits_real: ", logits_real.reshape(-1).mean())
                    print("logits_fake: ", logits_fake.reshape(-1).mean())
                disc_factor = self.adopt_weight(1.0, step, threshold=0)
                discriminator_loss = disc_factor * self.hinge_d_loss(logits_real, logits_fake, value=self.args.hinge_value)
                if self.args.large_log:
                    print("discriminator loss: ", discriminator_loss.mean())
                if ((self.args.g_learning_period - 1) * (step - init_step) // self.args.g_learning_period) % self.args.lazy_reg == 0:
                    grad_real = th.autograd.grad(outputs=logits_real.sum(), inputs=target, create_graph=True)[0]
                    grad_real = grad_real.reshape(target.shape[0],-1)
                    grad_penalty = (grad_real.norm(2, dim=1) ** 2).reshape(-1)
                    grad_penalty = self.args.r1_gamma / 2 * grad_penalty
                    #print(discriminator_loss.shape, grad_penalty.shape)
                    print("disc, gp: ", discriminator_loss.mean(), grad_penalty.mean())
                    discriminator_loss += grad_penalty.mean()
            return discriminator_loss


        elif self.args.loss_norm == 'StyleGAN-XL':
            if learn_generator:
                logits_fake = []
                cnt = 0
                for bb_name, feat in self.feature_networks.items():
                    logits_fake += discriminator[bb_name](estimate_features[cnt], None)
                    cnt += 1
                g_loss = sum([(-l).mean() for l in logits_fake]) / len(logits_fake)
                if self.args.large_log:
                    print("g_loss: ", g_loss.mean().item())
                CTM_loss = consistency_loss.mean()
                if self.args.d_apply_adaptive_weight:
                    if self.args.data_name in ['church']:
                        try:
                            d_weight = self.calculate_adaptive_weight(CTM_loss, g_loss.mean(),
                                                                      last_layer=
                                                                      model.module.output_blocks[14][0].out_layers[
                                                                          3].weight)
                        except:
                            d_weight = self.calculate_adaptive_weight(CTM_loss, g_loss.mean(),
                                                                      last_layer=
                                                                      model.module.output_blocks[11][0].out_layers[
                                                                          3].weight)
                    else:
                        try:
                            d_weight = self.calculate_adaptive_weight(CTM_loss.mean(),
                                                                      g_loss.mean(),
                                                                      last_layer=
                                                                      model.module.output_blocks[11][0].out_layers[
                                                                          3].weight)
                        except:
                            d_weight = self.calculate_adaptive_weight(CTM_loss.mean(),
                                                                      g_loss.mean(),
                                                                      last_layer=model.module.model.dec[
                                                                          '32x32_aux_conv'].weight)
                    d_weight = th.clip(d_weight, 0.01, 10.)
                else:
                    d_weight = 1.
                print("discriminator weight: ", d_weight)
                discriminator_loss = self.adopt_weight(d_weight, step,
                                                       threshold=init_step + self.args.discriminator_start_itr) * g_loss
            else:

                loss_Dgen = sum([(F.relu(th.ones_like(l) + l)).mean() for l in logits_fake]) / len(logits_fake)
                loss_Dreal = sum([(F.relu(th.ones_like(l) - l)).mean() for l in logits_real]) / len(logits_real)
                discriminator_loss = loss_Dreal + loss_Dgen
                return discriminator_loss
        else:
            mean_vp_tau, tau = self.classifier_vpsde.transform_unnormalized_wve_to_normalized_vp(t_m)
            if learn_generator:
                fake = mean_vp_tau[:, None, None, None] * estimate
                if self.args.discriminator_input == 'feature':
                    fake = pretrained_classifier(fake, timesteps=tau, feature=True, out_res=self.args.d_out_res)
                logits_fake = discriminator(fake, tau, sigmoid=False)

                if self.args.d_architecture in ['unet', 'ddgan']:
                    g_loss = mean_flat(-logits_fake)
                elif self.args.d_architecture in ['ddgan_']:
                    g_loss = mean_flat(th.nn.functional.softplus(-logits_fake))

                CTM_loss = consistency_loss.mean()
                if self.args.d_apply_adaptive_weight:
                    if self.args.data_name in ['church']:
                        try:
                            d_weight = self.calculate_adaptive_weight(CTM_loss, g_loss.mean(),
                                                                 last_layer=model.module.output_blocks[14][0].out_layers[
                                                                     3].weight)
                        except:
                            d_weight = self.calculate_adaptive_weight(CTM_loss, g_loss.mean(),
                                                                 last_layer=model.module.output_blocks[11][0].out_layers[
                                                                     3].weight)
                    else:
                        try:
                            d_weight = self.calculate_adaptive_weight(CTM_loss.mean(),
                                                                            g_loss.mean(),
                                                                            last_layer=model.module.output_blocks[11][0].out_layers[3].weight)
                        except:
                            d_weight = self.calculate_adaptive_weight(CTM_loss.mean(),
                                                                            g_loss.mean(),
                                                                            last_layer=model.module.model.dec['32x32_aux_conv'].weight)
                    d_weight = th.clip(d_weight, 1., 10.)
                else:
                    d_weight = 1.
                print("d_weight: ", d_weight)
                discriminator_loss = self.adopt_weight(d_weight, step, threshold=init_step + self.args.discriminator_start_itr) * g_loss
            else:
                if self.args.discriminator_free_target or free:
                    target = mean_vp_tau[:, None, None, None] * (x_start + noise * append_dims(t_m, x_start.ndim))
                else:
                    target = mean_vp_tau[:, None, None, None] * d_input2

                fake = mean_vp_tau[:, None, None, None] * d_input
                print("discriminator input: ", self.args.discriminator_input)
                if self.args.discriminator_input == 'feature':
                    target = pretrained_classifier(target, timesteps=tau, feature=True, out_res=self.args.d_out_res)
                    fake = pretrained_classifier(fake, timesteps=tau, feature=True, out_res=self.args.d_out_res)
                if ((self.args.g_learning_period - 1) * (step - init_step) // self.args.g_learning_period) % self.args.lazy_reg == 0:
                    target.requires_grad = True
                    logits_real = discriminator(target, tau, sigmoid=False)
                else:
                    logits_real = discriminator(target.detach(), tau, sigmoid=False)
                logits_fake = discriminator(fake.detach(), tau, sigmoid=False)
                print("logits_real: ", logits_real.reshape(-1).mean())
                print("logits_fake: ", logits_fake.reshape(-1).mean())
                disc_factor = self.adopt_weight(1.0, step, threshold=0)
                if self.args.d_architecture in ['unet', 'ddgan']:
                    discriminator_loss = disc_factor * self.hinge_d_loss(logits_real, logits_fake, value=self.args.hinge_value)
                elif self.args.d_architecture in ['ddgan_']:
                    threshold = th.mean(th.nn.functional.softplus(logits_fake))
                    print("threshold: ", threshold)
                    if threshold > 0.1:
                        discriminator_loss = disc_factor * self.vanilla_d_loss(logits_real, logits_fake)
                    else:
                        discriminator_loss = th.zeros(logits_fake.shape[0], device=logits_fake.device)
                        discriminator_loss.requires_grad = True
                        return discriminator_loss

                if ((self.args.g_learning_period - 1) * (step - init_step) // self.args.g_learning_period) % self.args.lazy_reg == 0:
                    grad_real = th.autograd.grad(outputs=logits_real.sum(), inputs=target, create_graph=True)[0]
                    grad_real = grad_real.reshape(target.shape[0], -1)
                    grad_penalty = (grad_real.norm(2, dim=1) ** 2).reshape(-1)
                    grad_penalty = self.args.r1_gamma / 2 * grad_penalty
                    if self.args.large_log:
                        print("disc, gp: ", discriminator_loss.mean(), grad_penalty.mean())
                    discriminator_loss += grad_penalty.mean()
        return discriminator_loss

    def save(self, x, save_dir, name, npz=False):
        """Write a batch of images to ``save_dir`` as a PNG grid, optionally also as ``.npz``.

        Args:
            x: image batch; the code maps it from [-1, 1] to [0, 1] for the grid
                and to clamped uint8 [0, 255] for the archive (presumably
                (B, C, H, W), given the (0, 2, 3, 1) permute below).
            save_dir: output directory; the PNG is written through ``bf.BlobFile``.
            name: file stem shared by ``{name}.png`` and ``targets/{name}.npz``.
            npz: when True, also store the batch channel-last as uint8 in
                ``save_dir/targets/{name}.npz`` (creating ``targets`` if needed).
        """
        grid_rows = int(np.sqrt(x.shape[0]))
        grid = make_grid((x + 1.) / 2., grid_rows, padding=2)
        png_path = os.path.join(save_dir, f"{name}.png")
        with bf.BlobFile(png_path, "wb") as handle:
            save_image(grid, handle)
        if not npz:
            return
        as_uint8 = ((x + 1) * 127.5).clamp(0, 255).to(th.uint8)
        channels_last = as_uint8.permute(0, 2, 3, 1).contiguous().cpu().detach()
        os.makedirs(os.path.join(save_dir, 'targets'), exist_ok=True)
        np.savez(os.path.join(save_dir, f"targets/{name}.npz"), channels_last.numpy())



    def ctm_losses(
        self,
        step,
        model,
        x_start,
        model_kwargs=None,
        target_model=None,
        teacher_model=None,
        pretrained_classifier=None,
        classifier_vpsde=None,
        noise=None,
        discriminator=None,
        init_step=0,
        ctm=True,
    ):
        """Compute the CTM training losses for one training step.

        When ``args.consistency_weight > 0``:
          * on generator steps (``step % g_learning_period == 0``, or always when
            the discriminator is not trained) the result holds
            ``consistency_loss``, plus ``denoising_loss`` if
            ``args.diffusion_training`` is set, plus the generator-side
            ``d_loss`` once ``step - init_step >= discriminator_start_itr``;
          * on discriminator-only steps it holds just ``d_loss`` (a zero dummy
            when ``args.discriminator_fix`` is set).
        Otherwise only ``denoising_loss`` is computed.

        Args:
            step: current training step.
            model: online model being trained (called via ``self.denoise_fn``).
            x_start: clean data batch.
            model_kwargs: extra keyword args forwarded to every model call.
            target_model: target model used to build the consistency target.
            teacher_model: if truthy, the target trajectory comes from
                ``self.heun_solver`` with this model instead of ``target_model``.
            pretrained_classifier, classifier_vpsde: forwarded to
                ``self.get_consistency_loss``.
            noise: optional fixed noise; drawn with ``randn_like`` otherwise.
            discriminator: discriminator network(s); on generator steps it must
                be None exactly when ``args.g_learning_period == 1`` (asserted).
            init_step: step this run (re)started from; offsets the schedules.
            ctm: if True also sample an intermediate end time ``s``; otherwise
                ``s`` stays None.

        Returns:
            dict mapping loss names to loss tensors.
        """
        if self.args.data_name in ['church']:
            x_start = self.pre_process(x_start)

        if model_kwargs is None:
            model_kwargs = {}
        if noise is None:
            noise = th.randn_like(x_start)
        dims = x_start.ndim
        s = None
        terms = {}

        if self.args.consistency_weight > 0.:
            if step % self.args.g_learning_period != 0 and self.args.discriminator_training and self.args.discriminator_fix:
                # Discriminator step with a frozen discriminator: return a zero loss
                # that still supports backward().
                d_loss = th.zeros(x_start.shape[0], device=x_start.device)
                d_loss.requires_grad = True
                terms['d_loss'] = d_loss
                return terms
            cm = bool(th.rand(1) < self.args.cm_ratio)
            if step % self.args.g_learning_period != 0 and self.args.discriminator_training:
                cm = True
            if cm:
                num_heun_step = self.args.start_scales - 1
            else:
                num_heun_step = self.get_num_heun_step(step)
            # Rank 0 decides both the branch (cm) and the Heun step count; broadcast the
            # pair so that every rank takes the same branch with a matching step count.
            decision = [cm, num_heun_step]
            dist.broadcast_object_list(decision, 0)
            cm, num_heun_step = decision
            if self.args.large_log:
                print("cm, num heun step: ", cm, num_heun_step)
            indices, _ = self.schedule_sampler.sample_t(x_start.shape[0], x_start.device, num_heun_step,
                                                        self.args.time_continuous)
            t = self.get_t(indices)
            t_dt = self.get_t(indices + num_heun_step)
            if ctm:
                new_indices = self.schedule_sampler.sample_s(self.args, x_start.shape[0], x_start.device, indices,
                                                             num_heun_step, self.args.time_continuous,
                                                             N=self.args.start_scales)
                s = self.get_t(new_indices)
            x_t = x_start + noise * append_dims(t, dims)

            if step % self.args.g_learning_period == 0 or not self.args.discriminator_training:

                assert (discriminator is None) == (self.args.g_learning_period == 1)
                if cm:
                    # Jump straight to sigma_min. Use t as the shape reference when ctm=False
                    # left s as None (th.ones_like(None) would raise).
                    s_ref = s if s is not None else t
                    estimate = self.denoise_fn(model, x_t, t, s=th.ones_like(s_ref) * self.args.sigma_min, ctm=ctm,
                                               **model_kwargs)
                else:
                    estimate = self.get_estimate(step, x_t, t, t_dt, s, model, target_model, ctm=ctm, **model_kwargs)

                if teacher_model:
                    x_t_dt = self.heun_solver(x_t, indices, teacher_model, dims, num_step=num_heun_step,
                                              **model_kwargs).detach()
                else:
                    with th.no_grad():
                        x_t_dt = self.denoise_fn(target_model, x_t, t, s=t_dt, ctm=ctm, **model_kwargs)
                if cm:
                    target = x_t_dt
                else:
                    target = self.get_target(step, x_t_dt, t_dt, s, model, target_model, ctm=ctm, **model_kwargs)

                if step % self.args.save_period == 0:
                    self.save(estimate, logger.get_dir(), f'estimate_{step}')
                    self.save(target, logger.get_dir(), f'target_{step}')
                    self.save(x_t, logger.get_dir(), f'non_denoised_{step}')
                    self.save(x_t_dt, logger.get_dir(), f'denoised_{step}')

                snrs = self.get_snr(t)
                weights = get_weightings(self.args.weight_schedule, snrs, self.args.sigma_data, t, s)

                terms["consistency_loss"], estimate_features, target_features = self.get_consistency_loss(
                    estimate, target, s, pretrained_classifier, classifier_vpsde, weights, step - init_step,)
                if self.args.large_log:
                    if s is not None:
                        print(f"{step}-th step, t, t-dt, s, weight, loss: ", t[0].item(), t_dt[0].item(), s[0].item(), weights[0].item(), terms["consistency_loss"][0].item())
                    else:
                        print(f"{step}-th step, t, t-dt, weight, loss: ", t[0].item(), t_dt[0].item(), weights[0].item(), terms["consistency_loss"][0].item())
                if self.args.diffusion_training:
                    terms['denoising_loss'] = self.get_denoising_loss(model, x_start, model_kwargs,
                                                                      terms["consistency_loss"],
                                                                      step, init_step)

                if self.args.discriminator_training and step - init_step >= self.args.discriminator_start_itr:
                    estimate = self.get_gan_estimate(estimate, step, x_t, t, t_dt, s, model, target_model, ctm=ctm, **model_kwargs)
                    terms['d_loss'] = self.get_discriminator_loss(t, t_dt, model, estimate=estimate,
                                                                  target=target,
                                                                  consistency_loss=terms["consistency_loss"],
                                                                  discriminator=discriminator,
                                                                  step=step, init_step=init_step, noise=noise,
                                                                  estimate_features=estimate_features,
                                                                  target_features=target_features,)
                    # 'denoising_loss' only exists when diffusion_training is enabled.
                    print("d_loss: ", terms['d_loss'].shape,
                          terms['denoising_loss'].shape if 'denoising_loss' in terms else None)
            else:
                # Discriminator update: build the real/fake pair without training the generator.
                if self.args.discriminator_free_target:
                    target = x_start
                else:
                    if teacher_model:
                        x_t_dt = self.heun_solver(x_t, indices, teacher_model, dims, num_step=num_heun_step,
                                                  **model_kwargs).detach()
                    else:
                        with th.no_grad():
                            x_t_dt = self.denoise_fn(target_model, x_t, t, s=t_dt, ctm=ctm, **model_kwargs)
                    if cm:
                        target = x_t_dt
                    else:
                        target = self.get_target(step, x_t_dt, t_dt, s, model, target_model, ctm=ctm, **model_kwargs)
                estimate = self.get_gan_estimate(None, step, x_t, t, t_dt, s, model, target_model, ctm=ctm, **model_kwargs)
                terms['d_loss'] = self.get_discriminator_loss(t, t_dt, model,
                                                              estimate=estimate,
                                                              target=target,
                                                              learn_generator=False,
                                                              discriminator=discriminator,
                                                              step=step, init_step=init_step, noise=noise, free=False)
        else:
            terms = {'denoising_loss': self.get_denoising_loss(model, x_start, model_kwargs, None, step, init_step)}

        return terms

def karras_sample(
    diffusion,
    model,
    shape,
    steps,
    clip_denoised=True,
    progress=False,
    callback=None,
    model_kwargs=None,
    device=None,
    sigma_min=0.002,
    sigma_max=80,  # higher for highres?
    rho=7.0,
    sampler="heun",
    s_churn=0.0,
    s_tmin=0.0,
    s_tmax=float("inf"),
    s_noise=1.0,
    generator=None,
    ts=None,
    x_T=None,
    ctm=False,
    teacher=False,
    clip_output=True,
    train=False,
):
    """Draw samples from ``model`` with one of the samplers in this module.

    Args:
        diffusion: object providing
            ``denoise_fn(model, x_t, t, s, ctm, teacher, **model_kwargs)``.
        model: network passed through to ``diffusion.denoise_fn``.
        shape: shape of the initial noise when ``x_T`` is not given.
        steps: number of Karras noise levels (``steps + 1`` for the
            'progdist', 'euler' and 'exact' samplers).
        clip_denoised: clamp every denoiser output to [-1, 1].
        progress, callback: forwarded to the sampler.
        model_kwargs: extra keyword args for the model; defaults to none.
        device: device for the noise schedule and initial noise.
        sigma_min, sigma_max, rho: Karras schedule parameters.
        sampler: one of 'heun', 'dpm', 'ancestral', 'onestep', 'exact',
            'progdist', 'euler', 'multistep' (other values raise KeyError).
        s_churn, s_tmin, s_tmax, s_noise: churn settings for 'heun' and 'dpm'.
        generator: noise source; defaults to ``get_generator("dummy")``.
        ts: schedule indices for the 'multistep' and 'exact' samplers.
        x_T: initial noise; defaults to ``generator.randn(*shape) * sigma_max``.
        ctm, teacher: forwarded to ``denoise_fn``. For 'heun' they are also
            passed to the sampler, with ``teacher`` forced to False when ``train``.
        clip_output: clamp the final sample to [-1, 1].
        train: see ``ctm, teacher``.

    Returns:
        The final sample tensor.
    """
    if generator is None:
        generator = get_generator("dummy")
    if model_kwargs is None:
        # The denoiser below splats model_kwargs, so None would raise TypeError.
        model_kwargs = {}

    if sampler in ["progdist", 'euler', 'exact']:
        sigmas = get_sigmas_karras(steps + 1, sigma_min, sigma_max, rho, device=device)
    else:
        sigmas = get_sigmas_karras(steps, sigma_min, sigma_max, rho, device=device)

    if x_T is None:
        x_T = generator.randn(*shape, device=device) * sigma_max

    sample_fn = {
        "heun": sample_heun,
        "dpm": sample_dpm,
        "ancestral": sample_euler_ancestral,
        "onestep": sample_onestep,
        "exact": sample_exact,
        "progdist": sample_progdist,
        "euler": sample_euler,
        "multistep": stochastic_iterative_sampler,
    }[sampler]

    if sampler in ["heun", "dpm"]:
        sampler_args = dict(
            s_churn=s_churn, s_tmin=s_tmin, s_tmax=s_tmax, s_noise=s_noise
        )
    elif sampler in ["multistep", "exact"]:
        sampler_args = dict(
            ts=ts, t_min=sigma_min, t_max=sigma_max, rho=rho, steps=steps
        )
    else:
        sampler_args = {}
    if sampler in ['heun']:
        sampler_args['teacher'] = False if train else teacher
        sampler_args['ctm'] = ctm

    def denoiser(x_t, t, s=th.ones(x_T.shape[0], device=device)):
        denoised = diffusion.denoise_fn(model, x_t, t, s, ctm, teacher, **model_kwargs)
        if clip_denoised:
            denoised = denoised.clamp(-1, 1)
        return denoised

    x_0 = sample_fn(
        denoiser,
        x_T,
        sigmas,
        generator,
        progress=progress,
        callback=callback,
        **sampler_args,
    )
    if clip_output:
        print("clip output")
        return x_0.clamp(-1, 1)
    return x_0


def get_sigmas_karras(n, sigma_min, sigma_max, rho=7.0, device="cpu"):
    """Build the Karras et al. (2022) noise schedule.

    Produces ``n`` levels spaced linearly in ``sigma ** (1 / rho)`` from
    ``sigma_max`` down to ``sigma_min``, passes them through ``append_zero``,
    and moves the result to ``device``.
    """
    inv_rho = 1 / rho
    lo = sigma_min ** inv_rho
    hi = sigma_max ** inv_rho
    frac = th.linspace(0, 1, n)
    schedule = (hi + frac * (lo - hi)) ** rho
    return append_zero(schedule).to(device)


def to_d(x, sigma, denoised):
    """Return the Karras ODE derivative ``(x - denoised) / sigma``.

    ``sigma`` is broadcast against ``x`` via ``append_dims``.
    """
    residual = x - denoised
    sigma_b = append_dims(sigma, x.ndim)
    return residual / sigma_b


def get_ancestral_step(sigma_from, sigma_to):
    """Split an ancestral step from ``sigma_from`` to ``sigma_to``.

    Returns ``(sigma_down, sigma_up)``. The deterministic step goes to
    ``sigma_down``, and noise with scale ``sigma_up`` is then added, so that
    ``sigma_down**2 + sigma_up**2 == sigma_to**2``.
    """
    to_sq = sigma_to**2
    from_sq = sigma_from**2
    sigma_up = (to_sq * (from_sq - to_sq) / from_sq) ** 0.5
    sigma_down = (to_sq - sigma_up**2) ** 0.5
    return sigma_down, sigma_up


@th.no_grad()
def sample_euler_ancestral(model, x, sigmas, generator, progress=False, callback=None):
    """Ancestral sampler using Euler steps.

    Each step denoises at ``sigmas[i]`` and takes an Euler step down to
    ``sigma_down``. It then adds ``generator`` noise scaled by ``sigma_up``,
    both from ``get_ancestral_step(sigmas[i], sigmas[i + 1])``.
    """
    ones = x.new_ones([x.shape[0]])
    order = range(len(sigmas) - 1)
    if progress:
        from tqdm.auto import tqdm

        order = tqdm(order)

    for i in order:
        sigma_cur = sigmas[i]
        denoised = model(x, sigma_cur * ones)
        sigma_down, sigma_up = get_ancestral_step(sigma_cur, sigmas[i + 1])
        if callback is not None:
            callback({"x": x, "i": i, "sigma": sigma_cur, "sigma_hat": sigma_cur, "denoised": denoised})
        slope = to_d(x, sigma_cur, denoised)
        x = x + slope * (sigma_down - sigma_cur)
        x = x + generator.randn_like(x) * sigma_up
    return x


@th.no_grad()
def sample_midpoint_ancestral(model, x, ts, generator, progress=False, callback=None):
    """Integrate ``x`` over the times ``ts`` with explicit midpoint steps.

    ``model(x, t)`` is used as the velocity, and the step size is fixed at
    ``1 / len(ts)``. ``generator`` is accepted for signature parity but is not
    used (no noise is injected).
    """
    batch_ones = x.new_ones([x.shape[0]])
    h = 1 / len(ts)
    half = h / 2
    if progress:
        from tqdm.auto import tqdm

        ts = tqdm(ts)

    for tn in ts:
        dn = model(x, tn * batch_ones)
        dn_2 = model(x + half * dn, (tn + half) * batch_ones)
        x = x + h * dn_2
        if callback is not None:
            callback({"x": x, "tn": tn, "dn": dn, "dn_2": dn_2})
    return x


@th.no_grad()
def sample_exact(
    denoiser,
    x,
    sigmas,
    generator,
    progress=False,
    callback=None,
    ts=None,
    t_min=0.002,
    t_max=80.0,
    rho=7.0,
    steps=40,
):
    """Sample by chaining jumps ``sigmas[i] -> sigmas[i + 1]``.

    Each step calls ``denoiser(x, sigma_i, s=sigma_{i+1})`` and uses the output
    as the new ``x``. If a target level is exactly 0, the step is instead an
    Euler step using ``denoiser(x, sigma_i, s=sigma_i)``. The final transition
    ``sigmas[-2] -> sigmas[-1]`` is deliberately not taken, so the result sits
    at ``sigmas[-2]``.

    Args:
        denoiser: callable ``denoiser(x, t, s=...)``.
        x: starting batch at noise level ``sigmas[0]``.
        sigmas: descending noise levels; ignored when ``ts`` is non-empty.
        generator: unused; kept for a uniform sampler signature.
        progress: wrap the loop in ``tqdm``.
        callback: optional callable receiving a dict per step.
        ts: optional indices on a ``steps``-point Karras grid from ``t_max`` to
            ``t_min`` (exponent ``rho``). When non-empty they replace
            ``sigmas``, with a zero appended via ``append_zero``.

    Returns:
        The final sample.
    """
    s_in = x.new_ones([x.shape[0]])
    if ts is not None and len(ts) > 0:
        t_max_rho = t_max ** (1 / rho)
        t_min_rho = t_min ** (1 / rho)
        sigmas = th.tensor(
            [(t_max_rho + t_i / (steps - 1) * (t_min_rho - t_max_rho)) ** rho for t_i in ts]
        )
        sigmas = append_zero(sigmas).to(x.device)

    # Skip the last transition. Slice the range *before* wrapping it in tqdm,
    # because tqdm iterators do not support slicing.
    indices = range(len(sigmas) - 1)[:-1]
    if progress:
        from tqdm.auto import tqdm

        indices = tqdm(indices)

    for i in indices:
        sigma = sigmas[i]
        print(sigma, sigmas[i + 1])
        if sigmas[i + 1] != 0:
            denoised = denoiser(x, sigma * s_in, s=sigmas[i + 1] * s_in)
            x = denoised
        else:
            denoised = denoiser(x, sigma * s_in, s=sigma * s_in)
            d = to_d(x, sigma, denoised)
            dt = sigmas[i + 1] - sigma
            x = x + d * dt

        if callback is not None:
            callback(
                {
                    "x": x,
                    "i": i,
                    "sigma": sigmas[i],
                    "denoised": denoised,
                }
            )
    return x


@th.no_grad()
def sample_heun(
    denoiser,
    x,
    sigmas,
    generator,
    progress=False,
    callback=None,
    s_churn=0.0,
    s_tmin=0.0,
    s_tmax=float("inf"),
    s_noise=1.0,
    teacher=False,
    ctm=False,
):
    """Heun sampler with optional churn (Algorithm 2 of Karras et al., 2022).

    At each level the noise is optionally raised to ``sigma_hat`` (churn). The
    sample then takes a Heun step to the next level, or an Euler step when the
    next level is exactly 0. The denoiser is called with ``s=None`` when
    ``teacher`` is set and ``ctm`` is not; otherwise ``s`` equals the level
    being evaluated.
    """
    ones = x.new_ones([x.shape[0]])
    n_levels = len(sigmas)
    order = range(n_levels - 1)
    if progress:
        from tqdm.auto import tqdm

        order = tqdm(order)

    pass_none_s = teacher and not ctm

    def _denoise(x_in, level):
        # Evaluate the denoiser at ``level`` with the s argument this mode needs.
        return denoiser(x_in, level * ones, s=None if pass_none_s else level * ones)

    for i in order:
        sigma = sigmas[i]
        sigma_next = sigmas[i + 1]
        print("sigmas: ", sigma)
        churn_on = s_tmin <= sigma <= s_tmax
        gamma = min(s_churn / (n_levels - 1), 2**0.5 - 1) if churn_on else 0.0
        eps = generator.randn_like(x) * s_noise
        sigma_hat = sigma * (gamma + 1)
        if gamma > 0:
            x = x + eps * (sigma_hat**2 - sigma**2) ** 0.5
        denoised = _denoise(x, sigma_hat)
        d = to_d(x, sigma_hat, denoised)
        if callback is not None:
            callback({"x": x, "i": i, "sigma": sigma, "sigma_hat": sigma_hat, "denoised": denoised})
        dt = sigma_next - sigma_hat
        if sigma_next == 0:
            # Final step: plain Euler.
            x = x + d * dt
            continue
        # Heun correction: average the slopes at both ends of the step.
        x_2 = x + d * dt
        denoised_2 = _denoise(x_2, sigma_next)
        d_2 = to_d(x_2, sigma_next, denoised_2)
        x = x + (d + d_2) / 2 * dt
    return x


@th.no_grad()
def sample_euler(
    denoiser,
    x,
    sigmas,
    generator,
    progress=False,
    callback=None,
):
    """Deterministic first-order (Euler) sampler over the noise levels ``sigmas``.

    ``generator`` is unused; it is accepted so all samplers share one signature.
    """
    ones = x.new_ones([x.shape[0]])
    order = range(len(sigmas) - 1)
    if progress:
        from tqdm.auto import tqdm

        order = tqdm(order)

    for i in order:
        sigma, sigma_next = sigmas[i], sigmas[i + 1]
        denoised = denoiser(x, sigma * ones)
        slope = to_d(x, sigma, denoised)
        if callback is not None:
            callback({"x": x, "i": i, "sigma": sigma, "denoised": denoised})
        x = x + slope * (sigma_next - sigma)
    return x


@th.no_grad()
def sample_dpm(
    denoiser,
    x,
    sigmas,
    generator,
    progress=False,
    callback=None,
    s_churn=0.0,
    s_tmin=0.0,
    s_tmax=float("inf"),
    s_noise=1.0,
):
    """Second-order midpoint sampler (DPM-Solver-2 style) with optional churn.

    The slope is re-evaluated at a midpoint placed on a rho=3 Karras schedule
    between ``sigma_hat`` and the next level, then used for the full step.
    """
    ones = x.new_ones([x.shape[0]])
    last = len(sigmas) - 1
    order = range(last)
    if progress:
        from tqdm.auto import tqdm

        order = tqdm(order)

    for i in order:
        sigma, sigma_next = sigmas[i], sigmas[i + 1]
        gamma = min(s_churn / last, 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0
        eps = generator.randn_like(x) * s_noise
        sigma_hat = sigma * (gamma + 1)
        if gamma > 0:
            x = x + eps * (sigma_hat**2 - sigma**2) ** 0.5
        denoised = denoiser(x, sigma_hat * ones)
        d = to_d(x, sigma_hat, denoised)
        if callback is not None:
            callback({"x": x, "i": i, "sigma": sigma, "sigma_hat": sigma_hat, "denoised": denoised})
        sigma_mid = ((sigma_hat ** (1 / 3) + sigma_next ** (1 / 3)) / 2) ** 3
        x_mid = x + d * (sigma_mid - sigma_hat)
        denoised_mid = denoiser(x_mid, sigma_mid * ones)
        d_mid = to_d(x_mid, sigma_mid, denoised_mid)
        x = x + d_mid * (sigma_next - sigma_hat)
    return x


@th.no_grad()
def sample_onestep(
    distiller,
    x,
    sigmas,
    generator=None,
    progress=False,
    callback=None,
):
    """Generate in a single call: denoise ``x`` from the first level ``sigmas[0]``.

    ``distiller`` is called as ``distiller(x, t, None)``. ``generator``,
    ``progress`` and ``callback`` are ignored.
    """
    t_start = sigmas[0] * x.new_ones([x.shape[0]])
    return distiller(x, t_start, None)


@th.no_grad()
def stochastic_iterative_sampler(
    distiller,
    x,
    sigmas,
    generator,
    ts,
    progress=False,
    callback=None,
    t_min=0.002,
    t_max=80.0,
    rho=7.0,
    steps=40,
):
    """Multistep sampling over the schedule indices ``ts``.

    Each index is mapped to a noise level on a ``steps``-point Karras grid from
    ``t_max`` to ``t_min``. For every consecutive pair of indices, ``x`` is
    denoised with ``distiller(x, t, None)``. Fresh noise with std
    ``sqrt(next_t**2 - t_min**2)`` is then added, where ``next_t`` is clipped
    to ``[t_min, t_max]``. ``sigmas``, ``progress`` and ``callback`` are unused.
    """
    hi = t_max ** (1 / rho)
    lo = t_min ** (1 / rho)
    ones = x.new_ones([x.shape[0]])

    def index_to_sigma(idx):
        return (hi + idx / (steps - 1) * (lo - hi)) ** rho

    for cur, nxt in zip(ts[:-1], ts[1:]):
        t = index_to_sigma(cur)
        x0 = distiller(x, t * ones, None)
        next_t = np.clip(index_to_sigma(nxt), t_min, t_max)
        x = x0 + generator.randn_like(x) * np.sqrt(next_t**2 - t_min**2)

    return x


@th.no_grad()
def sample_progdist(
    denoiser,
    x,
    sigmas,
    generator=None,
    progress=False,
    callback=None,
):
    """Euler sampler used for the 'progdist' option.

    The last entry of ``sigmas`` (the zero level) is dropped, and Euler steps
    are taken between consecutive remaining levels. The result therefore stays
    at the smallest non-zero level. ``generator`` is unused.
    """
    ones = x.new_ones([x.shape[0]])
    levels = sigmas[:-1]  # skip the zero sigma

    order = range(len(levels) - 1)
    if progress:
        from tqdm.auto import tqdm

        order = tqdm(order)

    for i in order:
        sigma = levels[i]
        denoised = denoiser(x, sigma * ones)
        slope = to_d(x, sigma, denoised)
        if callback is not None:
            callback({"x": x, "i": i, "sigma": sigma, "denoised": denoised})
        x = x + slope * (levels[i + 1] - sigma)

    return x


@th.no_grad()
def iterative_colorization(
    distiller,
    images,
    x,
    ts,
    t_min=0.002,
    t_max=80.0,
    rho=7.0,
    steps=40,
    generator=None,
):
    """Iteratively colorize the grayscale content of ``images``.

    Pixels are expressed in an orthonormal 3x3 basis ``Q`` whose first column
    is the normalized luma vector (0.2989, 0.5870, 0.1140). After each
    distiller call, the first (luma) coordinate is taken from ``images`` and
    the other two from the clamped prediction. The result is then re-noised to
    the next level.

    Args:
        distiller: callable ``distiller(x, t)`` returning a denoised batch.
        images: reference batch; the einsums require 4-D input with 3 channels
            (presumably (B, 3, H, W)).
        x: initial noisy batch of the same layout.
        ts: indices on a ``steps``-point Karras grid from ``t_max`` to ``t_min``.
        t_min, t_max, rho, steps: parameters of that grid.
        generator: must provide ``randn_like``; the ``None`` default fails as
            soon as ``ts`` has two or more entries.

    Returns:
        ``(x, images)``: the final sample and the luma-only version of ``images``.
    """
    luma = np.asarray([0.2989, 0.5870, 0.1140])
    basis = np.eye(3)
    basis[:, 0] = luma / np.linalg.norm(luma)
    basis = np.linalg.qr(basis)[0]
    # QR fixes each column only up to sign; orient the luma axis positively.
    if np.sum(basis[:, 0]) < 0:
        basis = -basis
    Q = th.from_numpy(basis).to(dist_util.dev()).to(th.float32)

    mask = th.zeros(*x.shape[1:], device=dist_util.dev())
    mask[0, ...] = 1.0

    def keep_luma(src, dst):
        # Rotate both batches into the Q basis, take coordinate 0 from src and
        # the rest from dst, then rotate back.
        src_q = th.einsum("bchw,cd->bdhw", src, Q)
        dst_q = th.einsum("bchw,cd->bdhw", dst, Q)
        blended = src_q * mask + dst_q * (1.0 - mask)
        return th.einsum("bdhw,cd->bchw", blended, Q)

    hi = t_max ** (1 / rho)
    lo = t_min ** (1 / rho)
    ones = x.new_ones([x.shape[0]])

    def index_to_sigma(idx):
        return (hi + idx / (steps - 1) * (lo - hi)) ** rho

    images = keep_luma(images, th.zeros_like(images))

    for cur, nxt in zip(ts[:-1], ts[1:]):
        t = index_to_sigma(cur)
        x0 = th.clamp(distiller(x, t * ones), -1.0, 1.0)
        x0 = keep_luma(images, x0)
        next_t = np.clip(index_to_sigma(nxt), t_min, t_max)
        x = x0 + generator.randn_like(x) * np.sqrt(next_t**2 - t_min**2)

    return x, images


@th.no_grad()
def iterative_inpainting(
    distiller,
    images,
    x,
    ts,
    t_min=0.002,
    t_max=80.0,
    rho=7.0,
    steps=40,
    generator=None,
):
    """Inpaint images by iterative consistency sampling with a letter mask.

    A white square canvas with a black letter "S" defines the mask. Samples
    are handled in consecutive groups of 7. Even-numbered groups keep the
    reference pixels on the white background and regenerate the letter.
    Odd-numbered groups keep the letter pixels and regenerate the background.

    Args:
        distiller: callable ``distiller(x, sigma)`` returning a denoised
            estimate of ``x``; ``sigma`` holds one noise level per sample.
        images: reference images, presumably ``(B, 3, H, W)`` like ``x`` and
            on ``x.device``.
        x: initial noisy sample; ``x.shape[-1]`` sets the canvas size.
        ts: indices into a ``steps``-point Karras schedule (may be
            fractional); consecutive pairs define each step.
        t_min, t_max, rho: Karras schedule parameters.
        steps: number of points in the Karras schedule.
        generator: object providing ``randn_like``; when None,
            ``torch.randn_like`` is used.

    Returns:
        Tuple ``(x, images)``: the final sample and the masked reference,
        with regenerated pixels filled with -1.
    """
    from PIL import Image, ImageDraw, ImageFont

    image_size = x.shape[-1]
    # Consecutive samples that share one mask orientation. Previously this
    # required the batch size to be a multiple of 7; partial groups now work.
    group_size = 7

    # Render a black "S" on a white canvas.
    # NOTE(review): font size 250 and offset (50, 0) look tuned for 256x256
    # images. ImageFont.truetype raises OSError if "arial.ttf" is not found.
    canvas = Image.new("RGB", (image_size, image_size), color="white")
    font = ImageFont.truetype("arial.ttf", 250)
    ImageDraw.Draw(canvas).text((50, 0), "S", font=font, fill=(0, 0, 0))
    glyph = th.from_numpy(np.array(canvas).transpose(2, 0, 1)).to(x.device)

    # mask == 1 marks pixels copied from the reference images.
    background = glyph > 0.5
    letter = glyph < 0.5
    group_index = th.arange(x.shape[0], device=x.device) // group_size
    keep_background = (group_index % 2 == 0).view(-1, 1, 1, 1)
    mask = th.where(keep_background, background, letter).to(th.get_default_dtype())

    def replacement(x0, x1):
        # Masked pixels come from x0, the rest from x1.
        return x0 * mask + x1 * (1 - mask)

    t_max_rho = t_max ** (1 / rho)
    t_min_rho = t_min ** (1 / rho)

    def karras_sigma(idx):
        # Noise level at position ``idx`` of the ``steps``-point schedule.
        return (t_max_rho + idx / (steps - 1) * (t_min_rho - t_max_rho)) ** rho

    s_in = x.new_ones([x.shape[0]])
    images = replacement(images, -th.ones_like(images))

    for cur, nxt in zip(ts[:-1], ts[1:]):
        x0 = distiller(x, karras_sigma(cur) * s_in)
        x0 = th.clamp(x0, -1.0, 1.0)
        x0 = replacement(images, x0)
        # Clipping to t_min keeps the sqrt argument non-negative.
        next_t = np.clip(karras_sigma(nxt), t_min, t_max)
        # The default generator=None previously crashed here.
        noise = th.randn_like(x) if generator is None else generator.randn_like(x)
        x = x0 + noise * np.sqrt(next_t**2 - t_min**2)

    return x, images


@th.no_grad()
def iterative_superres(
    distiller,
    images,
    x,
    ts,
    t_min=0.002,
    t_max=80.0,
    rho=7.0,
    steps=40,
    generator=None,
    patch_size=8,
):
    """Super-resolve images by iterative consistency sampling with patch-mean replacement.

    The low-resolution observation is the per-patch mean of ``images``: each
    non-overlapping ``patch_size x patch_size`` patch is replaced by its
    average. At every step the denoised estimate keeps its own within-patch
    detail (components orthogonal to the patch mean), while the patch means
    are taken from the observation.

    Args:
        distiller: callable ``distiller(x, sigma)`` returning a denoised
            estimate of ``x``; ``sigma`` holds one noise level per sample.
        images: reference images, reshaped to ``(-1, 3, H, H)`` with
            ``H = x.shape[-1]``; must be on ``x.device``. Not modified.
        x: initial noisy sample (square images, 3 channels).
        ts: indices into a ``steps``-point Karras schedule (may be
            fractional); consecutive pairs define each step.
        t_min, t_max, rho: Karras schedule parameters.
        steps: number of points in the Karras schedule.
        generator: object providing ``randn_like``; when None,
            ``torch.randn_like`` is used.
        patch_size: side length of the averaging patch (default 8).

    Returns:
        Tuple ``(x, images)``: the final sample and the patch-averaged
        reference.

    Raises:
        ValueError: if the image size is not a multiple of ``patch_size``.
    """
    image_size = x.shape[-1]
    if image_size % patch_size != 0:
        # Without this check the reshapes below can silently scramble pixels.
        raise ValueError(
            f"image size {image_size} is not a multiple of patch_size {patch_size}"
        )
    n_side = image_size // patch_size
    dim = patch_size**2

    # Orthonormal basis of R^dim whose column 0 is the uniform (patch-mean)
    # direction; the remaining columns span within-patch detail.
    basis = np.eye(dim)
    basis[:, 0] = 1.0 / np.sqrt(dim)
    basis = np.linalg.qr(basis)[0]
    if np.sum(basis[:, 0]) < 0:
        basis = -basis
    Q = th.from_numpy(basis).to(device=x.device, dtype=th.float32)

    def to_patches(imgs):
        # (B, 3, H, W) -> (B, 3, n_patches, patch_size**2)
        return (
            imgs.reshape(-1, 3, n_side, patch_size, n_side, patch_size)
            .permute(0, 1, 2, 4, 3, 5)
            .reshape(-1, 3, n_side * n_side, dim)
        )

    def from_patches(patches):
        # Inverse of to_patches.
        return (
            patches.reshape(-1, 3, n_side, n_side, patch_size, patch_size)
            .permute(0, 1, 2, 4, 3, 5)
            .reshape(-1, 3, image_size, image_size)
        )

    def replacement(x0, x1):
        # Patch means from x0, within-patch detail from x1.
        c0 = th.einsum("bcnd,de->bcne", to_patches(x0), Q)
        c1 = th.einsum("bcnd,de->bcne", to_patches(x1), Q)
        mixed = th.cat([c0[..., :1], c1[..., 1:]], dim=-1)
        return from_patches(th.einsum("bcne,de->bcnd", mixed, Q))

    def average_image_patches(imgs):
        # Build a new tensor instead of assigning in place: when
        # image_size == patch_size, to_patches returns a view of ``imgs``.
        patches = to_patches(imgs)
        return from_patches(patches.mean(dim=-1, keepdim=True).repeat(1, 1, 1, dim))

    t_max_rho = t_max ** (1 / rho)
    t_min_rho = t_min ** (1 / rho)

    def karras_sigma(idx):
        # Noise level at position ``idx`` of the ``steps``-point schedule.
        return (t_max_rho + idx / (steps - 1) * (t_min_rho - t_max_rho)) ** rho

    s_in = x.new_ones([x.shape[0]])
    images = average_image_patches(images)

    for cur, nxt in zip(ts[:-1], ts[1:]):
        x0 = distiller(x, karras_sigma(cur) * s_in)
        x0 = th.clamp(x0, -1.0, 1.0)
        x0 = replacement(images, x0)
        # Clipping to t_min keeps the sqrt argument non-negative.
        next_t = np.clip(karras_sigma(nxt), t_min, t_max)
        # The default generator=None previously crashed here.
        noise = th.randn_like(x) if generator is None else generator.randn_like(x)
        x = x0 + noise * np.sqrt(next_t**2 - t_min**2)

    return x, images
