"""SAMPLING ONLY."""

import torch
import numpy as np
from tqdm import tqdm
from functools import partial

from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like

#changes
#import closed-form discriminator codes here
from discriminator.grad_disc_torch import discriminator_model_RBF, find_rbf_centres_weights

from PIL import Image
import PIL.Image

import pywt

import tensorflow as tf
import torch.nn as nn

import matplotlib.pyplot as plt 
import scienceplots
plt.style.use('science')

'''
#import trainable discriminator codes here
from trainable_discriminator import grad_disc
'''     

class DDIMSampler(object):

    def __init__(self, model, schedule="linear", **kwargs):
        super().__init__()

        self.model = model
        self.ddpm_num_timesteps = model.num_timesteps
        self.schedule = schedule

        self.decode = 1 #this is to study the evolution of any variable, like pred_x0, or direction, x_prev

        self.decode_var = 0 #this is to calculate variance either on the pixel space or latent space. 1 -> variance in the pixel space, 0 -> variance in the latent space. 

        self.decoder = model.decode_first_stage #decoder is a part of the model and not an object of the class DDIMSampler. 

        self.loss = nn.MSELoss()
        self.pred_loss = []

        #lists to store l2 norm plots 
        self.x0_pred = []
        self.xt_dir = []
        self.disc_grad = []
        self.e_t = [] #stores the unmodified noise estimation

        #lists to store the mean variance of a batch across latent space pixels
        self.mean_var = []

        #lists to store the mean variance of a batch across image pixels
        self.decode = 0
        self.mean_var_dec = []
        self.decoder = model.decode_first_stage

        self.e_count = 0
        self.d_count = 0

        #lists to store noise variances
        #self.noise_variance = []
        #self.time_shifts = []

    def register_buffer(self, name, attr):
        if type(attr) == torch.Tensor:
            if attr.device != torch.device("cuda"):
                attr = attr.to(torch.device("cuda"))
        setattr(self, name, attr)


    def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
        """Pre-compute and register the DDPM / DDIM schedule quantities.

        Args:
            ddim_num_steps: Number of DDIM steps, forwarded to ``make_ddim_timesteps``.
            ddim_discretize: Discretisation method, forwarded to ``make_ddim_timesteps``.
            ddim_eta: Eta forwarded to ``make_ddim_sampling_parameters``; it also
                scales the sigmas registered for the original number of steps.
            verbose: Forwarded to both helper functions.
        """
        self.ddim_timesteps = make_ddim_timesteps(
            ddim_discr_method=ddim_discretize,
            num_ddim_timesteps=ddim_num_steps,
            num_ddpm_timesteps=self.ddpm_num_timesteps,
            verbose=verbose,
        )

        cumprod = self.model.alphas_cumprod
        assert cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'

        def as_model_float(tensor):
            # Detached float32 copy placed on the model's device.
            return tensor.clone().detach().to(torch.float32).to(self.model.device)

        cumprod_cpu = cumprod.cpu()

        # Schedule tensors copied from the model, followed by the derived
        # quantities for q(x_t | x_{t-1}) and others.
        ddpm_buffers = (
            ('betas', as_model_float(self.model.betas)),
            ('alphas_cumprod', as_model_float(cumprod)),
            ('one_minus_alphas_cumprod', as_model_float(1. - cumprod)),
            ('alphas_cumprod_prev', as_model_float(self.model.alphas_cumprod_prev)),
            ('sqrt_alphas_cumprod', as_model_float(np.sqrt(cumprod_cpu))),
            ('sqrt_one_minus_alphas_cumprod', as_model_float(np.sqrt(1. - cumprod_cpu))),
            ('log_one_minus_alphas_cumprod', as_model_float(np.log(1. - cumprod_cpu))),
            ('sqrt_recip_alphas_cumprod', as_model_float(np.sqrt(1. / cumprod_cpu))),
            ('sqrt_recipm1_alphas_cumprod', as_model_float(np.sqrt(1. / cumprod_cpu - 1))),
        )
        for buffer_name, buffer_value in ddpm_buffers:
            self.register_buffer(buffer_name, buffer_value)

        # DDIM sampling parameters.
        ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(
            alphacums=cumprod_cpu,
            ddim_timesteps=self.ddim_timesteps,
            eta=ddim_eta,
            verbose=verbose,
        )
        self.register_buffer('ddim_sigmas', ddim_sigmas)
        self.register_buffer('ddim_alphas', ddim_alphas)
        self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
        self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))

        # Sigmas over the full (original) number of steps, built from the
        # buffers registered above.
        prev = self.alphas_cumprod_prev
        cur = self.alphas_cumprod
        full_step_sigmas = ddim_eta * torch.sqrt((1 - prev) / (1 - cur) * (1 - cur / prev))
        self.register_buffer('ddim_sigmas_for_original_num_steps', full_step_sigmas)

    ## trainable Discriminator Guidance. Can be used for ablation and put in the Appendix!
    def grad_disc_trainable(self, discriminator, x, t, y = None, **kwargs):

        with torch.enable_grad():

            x_in = x.detach().requires_grad_(True)

            #make it a 1D tensor of size same as batch_size of x_in
            t = torch.ones(x_in.shape[0], device = x_in.device) * t 

            pr = discriminator(x_in, t, sigmoid = True).view(-1)

            pr = torch.clip(pr, min=1e-5, max=1 - 1e-5)
            log_density_ratio = torch.log(pr) - torch.log(1 - pr)

            #print("log_density_ratio mean is: ", log_density_ratio.mean())

            dg = torch.autograd.grad(log_density_ratio.sum(), x_in)[0]

            return dg

    #this code has been taken from "diffusers".
    def save_img_grid(self, image, path, gridh = 3, gridw = 3,img_resolution = 256, img_channels = 3):
        """Tile a batch of images into one grid image and save it to ``path``.

        Args:
            image: Batch tensor; it is reshaped to ``(gridh, gridw, *image.shape[1:])``,
                so it must hold ``gridh * gridw`` images.
                NOTE(review): presumably (N, C, H, W) with values in about [-1, 1] - confirm.
            path: Destination handed to ``PIL.Image.save``.
            gridh: Number of grid rows.
            gridw: Number of grid columns.
            img_resolution: Side length of each tile in pixels.
            img_channels: Channel count of the output; the image is always
                written in 'RGB' mode.
        """
        # Map to [0, 255] and quantise.
        pixels = (image * 127.5 + 128).clip(0, 255).to(torch.uint8)

        # Split the batch axis into grid rows/columns, then interleave the grid
        # axes with the spatial axes so a plain reshape produces the mosaic.
        tiles = pixels.reshape(gridh, gridw, *pixels.shape[1:])
        tiles = tiles.permute(0, 3, 1, 4, 2)
        canvas = tiles.reshape(gridh * img_resolution, gridw * img_resolution, img_channels)

        PIL.Image.fromarray(canvas.cpu().numpy(), 'RGB').save(path)
        

    #this is the function to obtain the Tc value. 
    '''
    Function definition

    num - number of samples
    ds - dataset iterable
    '''

    ##########################################################################################################
    #this part of the code is for estimating the cutoff value, Tc and other statistics as mentioned in the bias24 paper. 
    #randomly sample 'num' images from the dataset - these will be x0

    def sample_dataset_samples(self, num, ds):

        device = self.model.betas.device #this is the schedule. 

        #sample from real dataset
        images_real = next(ds)

        #scale the dataset to [-1,1]
        scaler = lambda x: 2. * x - 1.
        images_real = scaler(images_real).to(device)

        #encode the samples from real dataset to the latent space
        with torch.no_grad():
            images_real = self.model.encode_first_stage(images_real).to(device)

        return images_real, device

    def extract(self, a, t, x_shape):

        batch_size = t.shape[0]
        out = a.gather(-1, t)

        return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)

    # forward diffusion (using the nice property :)
    def q_sample(self, x_start, t, alphas_cumprod, sqrt_one_minus_alphas_cumprod, alphas_prev_cumprod, noise = None):

        if noise is None:
            noise = torch.randn_like(x_start)

        sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
        sqrt_alphas_cumprod_t = self.extract(sqrt_alphas_cumprod, t, x_start.shape)

        sqrt_one_minus_alphas_cumprod_t = self.extract(sqrt_one_minus_alphas_cumprod, t, x_start.shape)

        sqrt_alphas_prev_cumprod = torch.sqrt(alphas_prev_cumprod)
        sqrt_alphas_prev_cumprod_t = self.extract(sqrt_alphas_prev_cumprod, t, x_start.shape)

        return (sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise), sqrt_alphas_cumprod_t, sqrt_one_minus_alphas_cumprod_t, sqrt_alphas_prev_cumprod_t

    #predict_x0 using the network
    def predict_x0(self, x, t, a_t, sqrt_one_minus_at, sqrt_a_prev, c = None):
        
        with torch.no_grad():
            e_t = self.model.apply_model(x, t, c)

        # current prediction for x_0
        pred_x0 = sqrt_a_prev * (x - sqrt_one_minus_at * e_t) / a_t

        return pred_x0

    #calculate sample variance
    def get_sample_var(self, noisy_real_images, t):
        """Record two variance estimates of a noisy batch at slot ``t``.

        Writes the variance over all four tensor dimensions to
        ``self.sample_var[t]`` and the batch mean of the per-image
        wavelet-based estimates to ``self.sample_var_wavelets[t]``.

        Args:
            noisy_real_images: 4-D batch tensor.
            t: Index into the two result lists.
        """
        overall_var = noisy_real_images.var(dim=[0, 1, 2, 3])
        self.sample_var[t] = overall_var.tolist()

        per_image_wavelet_var = self.get_vec_torch(noisy_real_images)
        self.sample_var_wavelets[t] = np.mean(per_image_wavelet_var)

    #apply forward process and generate the ground truth samples x[1,T]
    def forward_process(self, num, ds, batch_size):
        """Diffuse one real batch to every timestep and record its variance.

        For each timestep index of ``self.model.betas`` the batch is noised
        with ``q_sample`` and the result is handed to ``get_sample_var``.

        Args:
            num: Forwarded to ``sample_dataset_samples``.
            ds: Dataset iterator forwarded to ``sample_dataset_samples``.
            batch_size: Length of the timestep index tensor built per step.
                NOTE(review): presumably must match the batch yielded by ``ds`` - confirm.
        """
        model = self.model
        n_timesteps = model.betas.shape[0]
        alpha_bar = model.alphas_cumprod
        alpha_bar_prev = model.alphas_cumprod_prev
        sqrt_one_minus_alpha_bar = model.sqrt_one_minus_alphas_cumprod

        clean_latents, device = self.sample_dataset_samples(num, ds)

        for step in tqdm(range(n_timesteps)):
            step_index = torch.full((batch_size,), step, device=device, dtype=torch.long)

            # Only the noised batch is used; q_sample also returns the
            # per-sample coefficients.
            noised, _, _, _ = self.q_sample(
                clean_latents, step_index, alpha_bar, sqrt_one_minus_alpha_bar, alpha_bar_prev
            )

            self.get_sample_var(noised, step)

            # Disabled MSE study (kept for reference):
            #   predicted_x0 = self.predict_x0(noised, step_index, <sqrt_alphas_cumprod_t>,
            #                                  <sqrt_one_minus_alphas_cumprod_t>, <sqrt_alphas_prev_cumprod_t>)
            #   with torch.no_grad():
            #       error = self.loss(clean_latents, predicted_x0)
            #       self.pred_loss[step] += (error.cpu().numpy())

    #TODO. for multiple epochs.
    def obtain_cutoff_value(self, num, ds, batch_size):
        """Run the forward process ``num // batch_size`` times, then dump the variances.

        The per-timestep result lists are re-created at the start of every
        pass, so only the statistics of the last pass reach ``plot_variance``
        (aggregation over passes is still a TODO).

        Args:
            num: Total number of samples; determines the number of passes.
            ds: Dataset iterator handed to ``forward_process`` and ``plot_variance``.
            batch_size: Batch size handed to ``forward_process``.
        """
        passes = num // batch_size

        # Multi-epoch variant (disabled):
        #   self.pred_loss = [0] * self.model.betas.shape[0]
        #   self.sample_var_combined = [[] for i in range(epochs)]

        for _ in range(passes):
            n_timesteps = self.model.betas.shape[0]

            # One slot per timestep: pixel-level variance and wavelet-based
            # noise variance.
            self.sample_var = [0] * n_timesteps
            self.sample_var_wavelets = [0] * n_timesteps

            self.forward_process(num, ds, batch_size)

            # Multi-epoch variant (disabled):
            #   self.sample_var_combined[i] = self.sample_var

        # self.pred_loss = [(i // epochs) for i in self.pred_loss]
        # self.plot_mse()
        self.plot_variance(ds)
    ##########################################################################################################

    #from the plot, observe the time step from where the mse is negligible. This point can be approximated to be the cut off
    def plot_mse(self):
        """Plot ``self.pred_loss`` against timesteps 0..999 and save the figure.

        The figure is written to ``network_error.png`` using the 'science'
        matplotlib style.
        """
        # NOTE: the x-axis is fixed at 1000 steps, so self.pred_loss must have
        # exactly 1000 entries.
        timesteps = np.arange(1000)

        with plt.style.context('science'):
            plt.plot(timesteps, self.pred_loss)
            plt.xlabel("Timesteps (t)")
            plt.ylabel("MSE")
            plt.title("Network prediction error")
            plt.legend()
            plt.savefig("network_error.png")
    
    def plot_variance(self, ds):
        """Save the collected variance estimates to ``.npy`` files and exit.

        Writes ``var_wavelets.npy``, ``var.npy`` and ``var_schedule.npy`` to the
        current directory and then terminates the process with ``exit(0)``.

        Args:
            ds: Unused.
        """
        np.save("var_wavelets.npy", self.sample_var_wavelets)
        np.save("var.npy", self.sample_var)

        schedule_var = torch.square(self.model.sqrt_one_minus_alphas_cumprod).cpu().numpy()
        np.save("var_schedule.npy", schedule_var)

        # NOTE(review): the process stops here, so the plotting code below is
        # currently unreachable.
        exit(0)

        timesteps = np.arange(1000)
        curves = (
            (self.sample_var_wavelets, "wavelet-based noise estimate"),
            (self.sample_var, "pixel-level noise estimate"),
            (schedule_var, "pre-defined noise schedule"),
        )
        for values, curve_label in curves:
            plt.plot(timesteps, values, label=curve_label)

        plt.xlabel("Timesteps")
        plt.ylabel("Variance")
        plt.title("Variance estimates")
        plt.legend()
        plt.savefig("variance_estimates_celeba256.pdf")


    @torch.no_grad()
    def sample(self,
               S,
               batch_size,
               shape,
               conditioning=None,
               callback=None,
               normals_sequence=None,
               img_callback=None,
               quantize_x0=False,
               eta=0.,
               mask=None,
               x0=None,
               temperature=1.,
               noise_dropout=0.,
               score_corrector=None,
               corrector_kwargs=None,
               verbose=True,
               x_T=None,
               log_every_t=100,
               unconditional_guidance_scale=1.,
               unconditional_conditioning=None,
               dg = 0,
               ds_iter = None,
               dataset = None, 
               seed1 = 0,
               seed2 = 0,
               trainable = 0,
               discriminator = None,
               tss = 0,
               cut_off_value = 0,
               window_size = 0,

               # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
               **kwargs
               ):
        """Build the DDIM schedule for ``S`` steps and run ``ddim_sampling``.

        Args:
            S: Number of DDIM steps.
            batch_size: Number of samples to generate.
            shape: ``(C, H, W)`` of a single sample.
            conditioning: Optional conditioning (tensor or dict of tensors); a
                warning is printed when its batch dimension differs from
                ``batch_size``.
            eta: DDIM eta handed to ``make_schedule``.
            normals_sequence: Unused.
            **kwargs: Unused.

            All remaining arguments are forwarded unchanged to
            ``ddim_sampling`` (``quantize_x0`` is passed as ``quantize_denoised``).

        Returns:
            ``(samples, intermediates)`` exactly as returned by ``ddim_sampling``.
        """
        if conditioning is not None:
            if isinstance(conditioning, dict):
                first_key = list(conditioning.keys())[0]
                cbs = conditioning[first_key].shape[0]
            else:
                cbs = conditioning.shape[0]
            if cbs != batch_size:
                print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")

        self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)

        # sampling
        C, H, W = shape
        size = (batch_size, C, H, W)
        print(f'Data shape for DDIM sampling is {size}, eta {eta}')

        # Arguments added on top of the stock DDIM sampler (discriminator
        # guidance and time-shift settings).
        guidance_args = dict(
            dg=dg,
            ds_iter=ds_iter,
            dataset=dataset,
            seed1=seed1,
            seed2=seed2,
            trainable=trainable,
            discriminator=discriminator,
            tss=tss,
            cut_off_value=cut_off_value,
            window_size=window_size,
        )

        return self.ddim_sampling(
            conditioning, size,
            callback=callback,
            img_callback=img_callback,
            quantize_denoised=quantize_x0,
            mask=mask, x0=x0,
            ddim_use_original_steps=False,
            noise_dropout=noise_dropout,
            temperature=temperature,
            score_corrector=score_corrector,
            corrector_kwargs=corrector_kwargs,
            x_T=x_T,
            log_every_t=log_every_t,
            unconditional_guidance_scale=unconditional_guidance_scale,
            unconditional_conditioning=unconditional_conditioning,
            **guidance_args,
        )

    #this is to just check the langevin-type update on gradient of discriminator. 
    def only_dg(self, x, x_next, ds_iter, discriminator_rbf):
        """Debug loop: update ``x`` using only the RBF discriminator gradient.

        Runs 500 iterations. Each one draws a real batch, encodes it, refits
        the RBF centres/weights, and adds 100 times the (fake - real)
        discriminator gradient to ``x`` in place. Min/max statistics are
        printed along the way.

        Args:
            x: Tensor being updated; modified in place.
            x_next: Second sample handed to ``find_rbf_centres_weights`` on the
                first iteration; afterwards it aliases ``x``.
            ds_iter: Iterator yielding real image batches.
            discriminator_rbf: Object exposing ``set_cw`` and callable on ``x``,
                returning ``(real_grad, fake_grad)``.
        """
        step_scale = 100

        for _ in tqdm(range(500)):
            # Real batch, rescaled with x -> 2x - 1 and moved next to x.
            real_batch = (2. * next(ds_iter) - 1.).to(x.device)

            # Encode the real batch into the latent space.
            with torch.no_grad():
                real_batch = self.model.encode_first_stage(real_batch).to(x.device)

            # Obtain centres and weights, then install them in the discriminator.
            C_d, C_g, D_d, D_g = find_rbf_centres_weights(real_batch, x_next, real_batch.shape[0])
            discriminator_rbf.set_cw([C_d, D_d,C_g, D_g])

            # Discriminator gradients evaluated at the current sample.
            real_grad, fake_grad = discriminator_rbf(x)
            denoised = step_scale * torch.reshape(fake_grad - real_grad, x.shape)

            print("denoised - after accumulation: ", torch.min(denoised))
            print("denoised - after accumulation: ", torch.max(denoised))

            # x_next now refers to the same tensor as x, so the in-place update
            # below is visible through both names.
            x_next = x

            print("x - before adding: ", torch.min(x))
            print("x - before adding: ", torch.max(x))

            x += denoised  # langevin type update

            print("x - after adding: ", torch.min(x))
            print("x - after adding: ", torch.max(x))

            # Decode and keep the first nine images; saving is left to the user.
            x_grid = self.decoder(x)[:9]

    @torch.no_grad()
    def ddim_sampling(self, cond, shape,
                      x_T=None, ddim_use_original_steps=False,
                      callback=None, timesteps=None, quantize_denoised=False,
                      mask=None, x0=None, img_callback=None, log_every_t=100,
                      temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
                      unconditional_guidance_scale=1., unconditional_conditioning=None, dg = 0, ds_iter = None, dataset = None, seed1 = 0, seed2 = 1, trainable = 0, discriminator = None, Ts = None, tss = 0, cut_off_value = 0, window_size = 0):
        """Run the reverse DDIM loop, optionally with time-shift handling.

        Args:
            cond: Conditioning forwarded to ``p_sample_ddim``.
            shape: Full sample shape ``(batch, C, H, W)``.
            x_T: Optional starting tensor; when omitted, noise is drawn.
            ddim_use_original_steps: Use the full DDPM step range instead of
                the DDIM subset.
            callback: Optional callable invoked as ``callback(i)`` per step.
            timesteps: Optional step count used to truncate the DDIM schedule.
            mask, x0: When ``mask`` is given, ``x0`` is re-noised each step and
                blended with the current sample (``x0`` is then required).
            img_callback: Optional callable invoked as ``img_callback(pred_x0, i)``.
            log_every_t: Interval (in DDIM index) for recording intermediates.
            seed1, seed2: Seeds for the initial noise. Both are applied to
                torch's global generator one after the other, so ``seed2`` is
                the one that effectively determines the draws.
            dataset: Unused by the active code.
            Ts: Ignored; the time-shift state always starts as ``None``.
            tss, cut_off_value: Time-shift switch and cut-off step used in the
                loop below.
            Remaining arguments are forwarded unchanged to ``p_sample_ddim``.

        Returns:
            ``(img, intermediates)`` where ``intermediates`` holds the lists
            ``'x_inter'`` and ``'pred_x0'``.
        """
        device = self.model.betas.device
        b = shape[0]

        if x_T is None:
            # torch.manual_seed returns the global default generator, so both
            # names refer to the same object, which ends up seeded with seed2.
            # Kept as is so previously generated samples stay reproducible.
            generator1 = torch.manual_seed(seed1)
            generator2 = torch.manual_seed(seed2)

            img = torch.randn(shape, generator = generator1).to(device)
            img_next = torch.randn(shape, generator = generator2).to(device)
        else:
            img = x_T
            # Fix: this branch used to leave img_next (and denoised) undefined,
            # so the lines below raised UnboundLocalError whenever a start
            # tensor was supplied. The auxiliary sample is drawn the same way
            # as in the branch above.
            generator2 = torch.manual_seed(seed2)
            img_next = torch.randn(shape, generator = generator2).to(device)

        # Placeholder handed to p_sample_ddim; initialisation is not necessary.
        denoised = torch.zeros(shape).to(device)

        discriminator_rbf = discriminator_model_RBF(b, tuple(img.shape[1:]), device)

        # To study only the discriminator gradient, call
        #   self.only_dg(img, img_next, ds_iter, discriminator_rbf)
        # here and stop.

        if timesteps is None:
            timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
        elif timesteps is not None and not ddim_use_original_steps:
            subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
            timesteps = self.ddim_timesteps[:subset_end]

        intermediates = {'x_inter': [img], 'pred_x0': [img]}
        time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps)
        total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
        print(f"Running DDIM Sampling with {total_steps} timesteps")

        iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)

        Ts = None  # initialization for time-shift

        for i, step in enumerate(iterator):

            # Disabled alternatives: skip the step while a summary of Ts
            # (min or mean) is already at or below the cut-off but the
            # nominal step is still above it.

            # Different images in the batch have different variances, so this
            # is purely heuristic: every shifted time at or below the cut-off
            # is reset to the nominal step while that step is above the cut-off.
            if tss and Ts and (step > cut_off_value):
                Ts = [shifted if shifted > cut_off_value else step for shifted in Ts]

            index = total_steps - i - 1
            ts = torch.full((b,), step, device=device, dtype=torch.long)

            if mask is not None:
                assert x0 is not None
                img_orig = self.model.q_sample(x0, ts)  # TODO: deterministic forward pass?
                img = img_orig * mask + (1. - mask) * img

            outs = self.p_sample_ddim(img, img_next, denoised, discriminator_rbf, cond, ts, i = i, index=index, use_original_steps=ddim_use_original_steps,
                                      quantize_denoised=quantize_denoised, temperature=temperature,
                                      noise_dropout=noise_dropout, score_corrector=score_corrector,
                                      corrector_kwargs=corrector_kwargs,
                                      unconditional_guidance_scale=unconditional_guidance_scale,
                                      unconditional_conditioning=unconditional_conditioning, dg = dg, ds_iter = ds_iter, trainable = trainable, discriminator = discriminator, Ts = Ts, tss = tss, cut_off_value = cut_off_value, window_size = window_size)

            img, pred_x0, img_next, Ts = outs

            if callback: callback(i)
            if img_callback: img_callback(pred_x0, i)

            if index % log_every_t == 0 or index == total_steps - 1:
                intermediates['x_inter'].append(img)
                intermediates['pred_x0'].append(pred_x0)

        # Disabled: saving the l2-norm / variance lists per dataset
        # ('celeba' -> 'celeba256', 'ffhq' -> 'ffhq256',
        # 'lsun_church' -> 'lsun_church256') and printing self.time_shifts.

        print("e_theta NFEs: ", self.e_count)
        print("d evaluations: ", self.d_count)

        return img, intermediates

    '''
    add the code for discriminator here. 
    '''

    #this code block cannot be changed.
    def get_var(self, x):
        """Estimate the noise variance of ``x`` from its wavelet detail coefficients.

        Args:
            x: Array-like signal; the transform runs along axis 0.

        Returns:
            The squared noise-level estimate ``(median(|detail|) / 0.6745) ** 2``.
        """
        # Single-level discrete wavelet transform with the 'db4' wavelet.
        coeffs = pywt.dwt(x, 'db4', axis =0)
        # Estimate noise variance from detail coefficients
        # (pywt.dwt returns (approximation, detail); index 1 is the detail part).
        detail_coeffs = coeffs[1]
        #sigma_n = float(tfp.stats.percentile(tf.math.abs(detail_coeffs),q = 50)) / 0.6745
        # Median of the absolute detail coefficients divided by 0.6745;
        # NOTE(review): this looks like the usual median-based robust sigma
        # estimate for Gaussian noise - confirm against the reference paper.
        sigma_n = np.median(np.abs(detail_coeffs)) / 0.6745
        return sigma_n ** 2 #returns the variance

    def get_vec(self, x_noisy):
        """TensorFlow variant of the per-image wavelet noise-variance estimate.

        Each channel of each image is rearranged so that every second row is
        mirrored left-to-right, flattened to a column vector, and passed to
        ``get_var``. The per-channel estimates are averaged per image.

        NOTE(review): the reshape calls below hard-code 64x64 planes (4096
        values), so other spatial sizes are not supported here;
        ``get_vec_torch`` uses the actual row/column counts instead.

        Args:
            x_noisy: Torch batch tensor, permuted below from channel-second to
                channel-last layout.

        Returns:
            List with one mean variance estimate per image in the batch.
        """

        # Move the channel axis to the end: (b, ch, r, c) -> (b, r, c, ch).
        x_noisy = x_noisy.permute(0,2,3,1)

        b,r,c,ch = x_noisy.shape #these are the image dimensions

        #r = c = 6. Usually keep this even.
        # One row of ones and one row of zeros, each as wide as the image.
        o = tf.ones([1,c])
        o = tf.cast(o, tf.float32)

        z_1 = tf.zeros([1,c])
        z_1 = tf.cast(z_1, tf.float32)
        
        #print(o,z_1)

        # A selects rows 0, 2, 4, ...; B selects rows 1, 3, 5, ...
        A = tf.concat((o,z_1),axis = 0)
        B = tf.concat((z_1,o),axis = 0)

        # Repeat the two-row patterns r//2 times to build full-height masks.
        Abig = tf.tile(A,(r//2,1))
        Bbig = tf.tile(B,(r//2,1))   

        #print(Abig)
        #print(Bbig)

        z = [] #this is 128 x 1

        for x in x_noisy:

            #x = tf.image.rgb_to_grayscale(x) 
            zz = []
            x = x.cpu().numpy()

            for y in range(ch):

                # A new v1 session is opened per channel and closed below; the
                # ops in this block do not reference it explicitly.
                sess = tf.compat.v1.Session()
                with tf.device('/GPU:0'):
            
                    xx = tf.cast(x[:,:,y],tf.float32)
                    #xx = tf.reshape(xx, shape = (32,32))
                    xx = tf.reshape(xx, shape = (64,64))
                    
                    # Left-right mirrored copy of the plane.
                    x_flip = tf.experimental.numpy.fliplr(xx)

                    odd = tf.multiply(x_flip,Bbig) #odd rows should be flipped

                    even = tf.multiply(xx,Abig)     #even rows no changes

                    # Even rows as they are, odd rows mirrored.
                    Correct = odd + even
                    
                    #Correct = tf.reshape(Correct,(65536,1))
                    Correct = tf.reshape(Correct,(4096,1))
                    Correct = np.array(Correct)

                    zz.append(self.get_var(Correct)) #1x3

                sess.close() 
                tf.compat.v1.reset_default_graph() 
            
            z.append(np.mean(zz)) #stores variance of all images in the batch. Mean of all three RGB channels

        return z

    def get_vec_torch(self, x_noisy):
        """Per-image wavelet noise-variance estimate (PyTorch variant).

        Each channel plane is rearranged so that every second row is mirrored
        left-to-right, flattened to a column vector, and passed to
        ``get_var``. The per-channel estimates are averaged per image.

        Args:
            x_noisy: Batch tensor, permuted below from channel-second to
                channel-last layout. The masks are built from ``rows // 2``
                row pairs, so the row count must be even.

        Returns:
            List with one mean variance estimate per image in the batch.
        """
        device = x_noisy.device

        # Channel axis last: (b, ch, r, c) -> (b, r, c, ch).
        channels_last = x_noisy.permute(0, 2, 3, 1)
        _, rows, cols, channels = channels_last.shape

        # Two-row patterns tiled to full height: keep_mask is 1 on rows
        # 0, 2, 4, ... and mirror_mask is 1 on rows 1, 3, 5, ...
        ones_row = torch.ones([1, cols], dtype=torch.float32, device=device)
        zeros_row = torch.zeros([1, cols], dtype=torch.float32, device=device)
        keep_mask = torch.tile(torch.cat((ones_row, zeros_row), dim=0), (rows // 2, 1))
        mirror_mask = torch.tile(torch.cat((zeros_row, ones_row), dim=0), (rows // 2, 1))

        per_image_var = []
        for image in channels_last:
            channel_vars = []
            for channel in range(channels):
                plane = image[:, :, channel].to(torch.float32).reshape(rows, cols)
                mirrored = torch.flip(plane, dims=[1])

                # Odd rows come from the mirrored plane, even rows unchanged.
                snake = torch.mul(mirrored, mirror_mask) + torch.mul(plane, keep_mask)

                column = snake.reshape(rows * cols, 1).cpu().numpy()
                channel_vars.append(self.get_var(column))

            # Mean over the channels of this image.
            per_image_var.append(np.mean(channel_vars))

        torch.cuda.empty_cache()
        return per_image_var

    #code from Bias24 paper
    def apply_time_shift(self, img_list, t_next, alpha_list, cut_off_value, window_size):

        x_pre = img_list #[-1]
        n = x_pre.shape[0] #batch_size

        '''
        #if using data variance
        var = torch.var(x_pre.view(x_pre.size()[0],-1),dim=-1)
        '''

        with torch.no_grad():
            var = torch.as_tensor(self.get_vec_torch(img_list)).to(x_pre.device)

        var.reshape(-1,1)

        if t_next - window_size > 0 and t_next+window_size+1<len(alpha_list):
            time_list = alpha_list[t_next-window_size:t_next+window_size+1]
        elif t_next-window_size <= 0:
            time_list = alpha_list[0:t_next+window_size+1]
        elif t_next+window_size+1 >= len(alpha_list):
            time_list = alpha_list[t_next-window_size:]

        time_list = time_list.tolist()

        time_list = torch.tensor([time_list]*var.size()[0])

        var = var.unsqueeze(1).expand_as(time_list)
        
        dist = (var - time_list.to(x_pre.device)) ** 2

        next_t = torch.argmin(dist,dim=1)#next timestep

        if t_next - window_size > 0:
            n_t = next_t + t_next-window_size

        else:
            n_t = next_t 
        
        if t_next > cut_off_value:
            next_t = torch.as_tensor(n_t).to(x_pre.device)
        else:
            next_t = (torch.ones(n) * (t_next)).to(x_pre.device)

        torch.cuda.empty_cache()

        #each image in the batch will have a different variance to jump to
        return next_t

    @torch.no_grad()
    def p_sample_ddim(self, x, x_next, denoised, discriminator_rbf, c, t, i, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
                      temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
                      unconditional_guidance_scale=1., unconditional_conditioning=None, dg = 0, ds_iter = None, trainable = 0, discriminator = None, Ts = None, tss = 0, cut_off_value = 0, window_size = 0):
        b, *_, device = *x.shape, x.device

        #obtains the current time step
        t_cur = t[0].item()

        """
        t: this is as per DDPM steps. Starts from 999. 

        t should be a tensor of size same as the same batch size.

        i: this keeps a track of the number of forward iterations. Starts from 0. 

        index: this is as per DDIM steps. Starts from 499. 

        For eg: 999 in DDPM is equivalent to 499 in DDIM steps
        """

        if (tss and Ts and (t_cur > cut_off_value)): #500 value for Tc is just an estimate. Ts should not be none.

            print("*****{present}**********Shifting to the range**********: {minimum} and {maximum}".format(minimum = min(Ts), maximum = max(Ts), present = t[0].item()))

            alphas_ddpm = self.model.alphas_cumprod
            alphas_prev_ddpm = self.model.alphas_cumprod_prev
            one_minus_alphas_ddpm = self.one_minus_alphas_cumprod
            sqrt_one_minus_alphas_ddpm = self.model.sqrt_one_minus_alphas_cumprod
            sigmas_ddpm = self.ddim_sigmas_for_original_num_steps #since it is DDIM, this should always be zero


            #self.time_shifts.append(Ts) #append only if taking the mean variance

            #TODO. ensure that t and Ts are of the same format. 
            #obtain Ts as just an index arg min.

            #t = torch.full((b,), Ts, device=device, dtype=torch.long) #only if Ts is a single value
            #this is the index of the time-step, Ts. extract from DDPM parameters 
            index_ts = Ts

            '''
            #select parameters corresponding to the currently considered timestep - DDPM
            a_t_ddpm = torch.full((b,1,1,1), alphas_ddpm[index_ts], device = device)
            one_minus_a_t_ddpm = torch.full((b,1,1,1), one_minus_alphas_ddpm[index_ts], device = device)
            a_prev_ddpm = torch.full((b,1,1,1), alphas_prev_ddpm[index_ts], device=device)
            sigma_t_ddpm = torch.full((b,1,1,1), sigmas_ddpm[index_ts], device=device)
            sqrt_one_minus_at_ddpm = torch.full((b,1,1,1), sqrt_one_minus_alphas[index_ts], device=device)
            '''

            a_t_ddpm = alphas_ddpm[index_ts].unsqueeze(1).unsqueeze(1).unsqueeze(1)
            one_minus_a_t_ddpm = one_minus_alphas_ddpm[index_ts].unsqueeze(1).unsqueeze(1).unsqueeze(1)
            a_prev_ddpm = alphas_prev_ddpm[index_ts].unsqueeze(1).unsqueeze(1).unsqueeze(1)
            sigma_t_ddpm = sigmas_ddpm[index_ts].unsqueeze(1).unsqueeze(1).unsqueeze(1)
            sqrt_one_minus_at_ddpm = sqrt_one_minus_alphas_ddpm[index_ts].unsqueeze(1).unsqueeze(1).unsqueeze(1)

            #ensure that all the DDIM parameters now equal DDPM parameters.
            a_t = a_t_ddpm
            a_prev = a_prev_ddpm
            sigma_t = sigma_t_ddpm
            sqrt_one_minus_at = sqrt_one_minus_at_ddpm

            Ts_tensor = torch.as_tensor(Ts).to(device)

            t = Ts_tensor

        
        if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
            e_t = self.model.apply_model(x, t, c)

            self.e_count += 1

        else:
            x_in = torch.cat([x] * 2)
            t_in = torch.cat([t] * 2)
            c_in = torch.cat([unconditional_conditioning, c])
            e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
            e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)

        if score_corrector is not None:
            assert self.model.parameterization == "eps"
            e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
        
        #############################################################################################################################

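        # Select the parameters for the current timestep. The DDIM schedule is used when time-shifting is
        # disabled (tss is falsy), when Ts is None or empty (so `not Ts` is true), or when t_cur is below cut_off_value.
        # NOTE(review): the time-shift branch above tests `t_cur > cut_off_value` while this one tests
        # `t_cur < cut_off_value`; with tss and Ts set and t_cur == cut_off_value neither branch runs, so
        # a_t / a_prev / sigma_t / sqrt_one_minus_at are not assigned in this part of the method. Confirm this case cannot occur.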
        if ((not tss) or (not Ts) or (t_cur < cut_off_value)):

            #DDIM parameters
            alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
            alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
            sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
            sigmas = self.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
            #NOTE: ddim_sigmas_for_original_num_steps is read from the sampler (self), not from self.model; self.model.ddim_sigmas_for_original_num_steps is not defined.

            a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
            a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
            sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
            sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
            
        #############################################################################################################################

        """
        #modified epsilon here. - closed-form discriminator. This can be done as an ablation experiment and reported in Appendix!. Algebraically same as the current update, all depends on the weighting factor. 
        if(dg > 0 and trainable == 0):

            #sample from real dataset
            images_real = next(ds_iter)
            #scale the dataset to [-1,1]
            scaler = lambda x: 2. * x - 1.
            images_real = scaler(images_real).to(device)

            #encode the samples from real dataset to the latent space
            with torch.no_grad():
                images_real = self.model.encode_first_stage(images_real).to(device)

            #obtain centres and weights
            C_d, C_g, D_d, D_g = find_rbf_centres_weights(images_real, x_next, images_real.shape[0])

            #set centres and weights
            discriminator_rbf.set_cw([C_d, D_d,C_g, D_g])
                
            #obtain discriminator gradients
            real_grad, fake_grad = discriminator_rbf(x) 

            grad_disc = (fake_grad - real_grad)
            denoised = torch.reshape(grad_disc, x.shape)
            denoised = - 10*(sqrt_one_minus_at * denoised)

            #gradient of discriminator - l2 norm
            C = denoised.view(denoised.shape[0], -1)
            self.grad_disc.append(torch.mean(torch.norm(C, dim = 1)).cpu().numpy())
            
            D = e_t.view(e_t.shape[0], -1)
            self.e_t.append(torch.mean(torch.norm(D, dim = 1)).cpu().numpy())

            e_t = e_t + denoised
        """
        '''
        D = e_t.view(e_t.shape[0], -1)
        self.e_t.append(torch.mean(torch.norm(D, dim = 1)).cpu().numpy())
        '''
        
        # current prediction for x_0
        pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()

        '''
        # current prediction for x_0 - l2 norm
        A = pred_x0.view(pred_x0.shape[0], -1)
        self.x0_pred.append(torch.mean(torch.norm(A, dim = 1)).cpu().numpy())
        '''

        if quantize_denoised:
            pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)

        # direction pointing to x_t
        dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t

        '''
        # direction pointing to x_t - l2 norm
        B = dir_xt.view(dir_xt.shape[0], -1)
        self.xt_dir.append(torch.mean(torch.norm(B, dim = 1)).cpu().numpy())
        '''

        noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
        if noise_dropout > 0.:
            noise = torch.nn.functional.dropout(noise, p=noise_dropout)

        x_prev = (a_prev.sqrt() * pred_x0 + dir_xt + noise)

        #add the gradient of discriminator part here as an additional term in the DDIM update equation. closed-form discriminator
        if((dg > 0) and (trainable == 0) and (t_cur > cut_off_value)):

            #sample from real dataset
            images_real = next(ds_iter)
            #scale the dataset to [-1,1]
            scaler = lambda x: 2. * x - 1.
            images_real = scaler(images_real).to(device)

            #encode the samples from real dataset to the latent space
            with torch.no_grad():
                images_real = self.model.encode_first_stage(images_real).to(device)

            #obtain centres and weights
            C_d, C_g, D_d, D_g = find_rbf_centres_weights(images_real, x_next, images_real.shape[0])

            D_d = D_d.to(C_d.device)
            D_g = D_g.to(C_g.device)

            #set centres and weights
            discriminator_rbf.set_cw([C_d, D_d,C_g, D_g])
                
            #obtain discriminator gradients
            #real_grad, fake_grad = discriminator_rbf(x)
            real_grad, fake_grad = discriminator_rbf(x) #changes

            grad_disc = (fake_grad - real_grad)

            #earlier: no accumulation of grad_disc
            #different weighting factors can be done as an ablation.
            denoised = (dg/((i+1)**2))*torch.reshape(grad_disc, x.shape) #quadratically annealing it: the weight decays as 1/(i+1)^2
            #denoised = (dg/(i+1))*torch.reshape(grad_disc, x.shape) #linearly annealing it

            '''
            #these changes can be ignored
            ##############################################################
            #changes
            #accumulation of grad_disc. Start with unit weight.
            grad_disc = (dg)*torch.reshape(grad_disc, x.shape)

            #langevin-type discriminator update goes here.
            denoised += grad_disc
            ##############################################################
            '''
            '''
            #gradient of discriminator - l2 norm
            C = denoised.view(denoised.shape[0], -1)
            self.disc_grad.append(torch.mean(torch.norm(C, dim = 1)).cpu().numpy())
            '''

            # score + grad. disc.
            x_prev += denoised
            f_dg = dg
            #torch.cuda.empty_cache()

            self.d_count += 1


        #add the gradient of the discriminator here as an additional term in the DDIM update equation. The trainable discriminator can be done as an ablation experiment.
        elif (dg > 0 and trainable == 1 and (not(discriminator) == False) and (t_cur > cut_off_value)):

            denoised = self.grad_disc_trainable(discriminator, x, t)
            x_prev += (dg/(i+1)) * denoised

            '''
            #gradient of discriminator - l2 norm
            C = denoised.view(denoised.shape[0], -1)
            self.disc_grad.append(torch.mean(torch.norm(C, dim = 1)).cpu().numpy())
            '''

        '''
        #TODO. in the image space.
        if self.decode:

            x_prev_dec = self.decoder(x_prev)

            x_prev_dec_copy = x_prev_dec
            x_prev_dec_var = torch.var(x_prev_dec_copy, dim = [1,2,3])

            self.mean_var_dec.append(torch.mean(x_prev_dec_var, dim = 0).cpu().numpy())

        x_prev_copy = x_prev
        x_prev_var = torch.var(x_prev_copy, dim = [1,2,3])

        self.mean_var.append(torch.mean(x_prev_var, dim=0).cpu().numpy())
        '''

        '''
        self.noise_variance.append(get_vec(x_prev))
        '''

        #only if TSS is true, implement the time-shift sampler. This part can be optimized
        if (tss and (t_cur > cut_off_value)):

            #DDPM parameters
            #create variables to store for DDPM configurations.
            alphas_ddpm = self.model.alphas_cumprod
            alphas_prev_ddpm = self.model.alphas_cumprod_prev
            one_minus_alphas_ddpm = self.one_minus_alphas_cumprod
            sqrt_one_minus_alphas_ddpm = self.model.sqrt_one_minus_alphas_cumprod
            sigmas_ddpm = self.ddim_sigmas_for_original_num_steps #since it is DDIM, this should always be zero

            #TODO. perform time-shift here. Update Ts
            #TODO - ensure the dimensions are the same as given in their github code.

            with torch.no_grad():
                
                img_list = x_prev

                #time consuming.
                if self.decode_var:
                    img_list = self.decoder(x_prev)

                Ts = self.apply_time_shift(img_list, t_cur, one_minus_alphas_ddpm, cut_off_value, window_size).tolist()

            '''
            x_prev_copy = x_prev

            #this was my part of the code. Don't use this. Use the module from original code associated with the paper. 
            x_prev_var = torch.var(x_prev_copy, dim = [1,2,3])

            mean_var = (torch.mean(x_prev_var, dim=0))
            dist = ((mean_var - one_minus_alphas_ddpm.to(device)) ** 2)

            Ts = torch.argmin(dist, dim=0).item()
                
            self.time_shifts.append(Ts)
            print("\nGoing to Ts: ", Ts)
            '''


        #make self.decode = 1 only when the sample evolution has to be studied. 
        self.decode = 0
        #study the evolution of predicted x0. Decode the predicted_x0 or x_prev
        if self.decode and (i%5 == 0):
            
            '''
            #saving x_pred - prediction - x0
            x_pred_sample = self.decoder(pred_x0)
            x_grid = x_pred_sample[:9]
            self.save_img_grid(x_grid, "") #save as needed
            '''

            '''
            #saving x_prev - updated sample
            x_prev_sample = self.decoder(x_prev)
            x_grid = x_prev_sample[:9]
            self.save_img_grid(x_grid, "") #save as needed
            '''
            
            #saving denoised - disc. only updated sample
            denoised_sample = self.decoder(denoised)
            x_grid = denoised_sample[:9]
            self.save_img_grid(x_grid, "") #save as needed

        return x_prev, pred_x0, x, Ts
