import os
import os.path as osp
import cv2
import time
import math
import numpy
import numpy as np
import torch
import torch as th
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.utils.spectral_norm as spectral_norm
from torch.utils.data import Dataset

from PIL import Image
import imageio
from imageio.v2 import imread
from io import BytesIO
#os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
import sympy
import random
import numpy as np
import pandas as pd
import seaborn as sns
from scipy import stats
from scipy.optimize import minimize
from sympy import *
import copy
import json
import warnings
from absl import app, flags
import torch
#from torchmin import minimize
from tensorboardX import SummaryWriter
from torchvision.datasets import CIFAR10
from torchvision.utils import make_grid, save_image
from torchvision import transforms
from tqdm import trange
from tqdm import tqdm
import logging
from model import UNet
from score.both import get_inception_and_fid_score
from libs.iddpm import UNetModel,UNetModel4Pretrained,UNetModel4Pretrained2
from adan import Adan
from shampoo import Shampoo
# Command-line configuration (absl flags). Values are read through FLAGS.<name>
# after absl has parsed argv.
FLAGS = flags.FLAGS
flags.DEFINE_bool('train', False, help='train from scratch')
flags.DEFINE_bool('eval', False, help='load ckpt.pt and evaluate FID and IS')
# UNet: IDDPM
flags.DEFINE_integer('in_channel', 3, help='input channel of UNet')
flags.DEFINE_integer('out_channel', 3, help='output channel of UNet')
flags.DEFINE_integer('ch', 128, help='base channel of UNet')
flags.DEFINE_integer('num_res_blocks', 3, help='# resblock in each level')
flags.DEFINE_integer('num_heads', 4, help='Multi-Heads for attention')
flags.DEFINE_integer('dims', 2, help='1,2,3 dims')
flags.DEFINE_multi_integer('ch_mult', [1, 2, 2, 2], help='channel multiplier')
# The default below evaluates to [2, 4].
flags.DEFINE_multi_integer('attn', [32 // 16, 32 // 8], help='add attention to these levels')
flags.DEFINE_float('dropout', 0.3, help='dropout rate of resblock')
# NOTE(review): this help text is identical to the one on 'eval' and does not
# describe a scale/shift-norm switch -- looks copy-pasted, confirm and reword.
flags.DEFINE_bool('use_scale_shift_norm', True, help='load ckpt.pt and evaluate FID and IS')
flags.DEFINE_string('exp_name', 'CIFAR10', help='name of experiment')

flags.DEFINE_integer('head_out_channels', 3, help='the final layer of High order noise network')
flags.DEFINE_enum('mode', 'simple', ['simple','complex'], help='the mode for honn modeling')

# Gaussian Diffusion
flags.DEFINE_float('beta_1', 1e-4, help='start beta value')
flags.DEFINE_float('beta_T', 0.02, help='end beta value')
flags.DEFINE_integer('T', 1000, help='total diffusion training noising steps')
flags.DEFINE_enum('sample_type', 'ddpm', ['ddpm', 'analyticdpm', 'gmddpm','ddim'], help='sample type for sampling')
flags.DEFINE_enum('mean_type', 'epsilon', ['xprev', 'xstart', 'epsilon'], help='predict variable')
flags.DEFINE_enum('var_type', 'fixedlarge', ['fixedlarge', 'fixedsmall'], help='variance type')
# Training
flags.DEFINE_float('lr', 1e-4, help='target learning rate')
flags.DEFINE_float('grad_clip', 1., help="gradient norm clipping")
flags.DEFINE_integer('total_steps', 500001, help='total training steps')
flags.DEFINE_integer('img_size', 32, help='image size')
flags.DEFINE_integer('warmup', 5000, help='learning rate warmup')
flags.DEFINE_integer('batch_size', 128, help='batch size')
flags.DEFINE_integer('num_workers', 4, help='workers of Dataloader')
flags.DEFINE_integer('noise_order', 1, help="the order of noise used to training")
flags.DEFINE_float('ema_decay', 0.9999, help="ema decay rate")
flags.DEFINE_bool('parallel', False, help='multi gpu training')
# NOTE(review): help text is the same as 'logdir' although the default is a
# checkpoint file path -- presumably "pretrained checkpoint path"; confirm.
flags.DEFINE_string('pretrained_dir', './logs/iDDPM_CIFAR10_EPS/models/ckpt_1_450000.pt', help='log directory')

# Logging & Sampling
flags.DEFINE_string('logdir', './logs/iDDPM_CIFAR10_EPS', help='log directory')
flags.DEFINE_integer('sample_size', 64, "sampling size of images")
flags.DEFINE_integer('sample_step', 10000, help='frequency of sampling')
flags.DEFINE_integer('sample_steps', 1000, help='Sampling steps for generation stage')
# NOTE(review): help text duplicates 'sample_steps'. GaussianDiffusionSampler
# has a `t0` argument that it uses as the timestep the guide image is noised to
# and as the top of its sampling schedule -- presumably this flag feeds it;
# confirm against the call site and reword.
flags.DEFINE_integer('t0', 1200, help='Sampling steps for generation stage')
# Evaluation
flags.DEFINE_integer('save_step', 50000, help='frequency of saving checkpoints, 0 to disable during training')
flags.DEFINE_integer('eval_step', 0, help='frequency of evaluating model, 0 to disable during training')
flags.DEFINE_integer('num_images', 50000, help='the number of generated images for evaluation')
flags.DEFINE_bool('fid_use_torch', False, help='calculate IS and FID on gpu')
flags.DEFINE_bool('time_shift', False, help='whether the noised data is from t=1')
flags.DEFINE_bool('rescale_time', True, help='adjust the maxmimum time to input the network is 1000')
flags.DEFINE_bool('nll_training', False, help='training the model to fit the noise.pow(a)')
# NOTE(review): help text duplicates 'mode'; the choices are noise schedules.
flags.DEFINE_enum('noise_schedule', 'linear', ['linear','cosine'], help='the mode for honn modeling')
flags.DEFINE_string('fid_cache', './stats/cifar10.train.npz', help='FID cache')
# NOTE(review): help text duplicates 'var_type'; the choices are 'noise' / 'nll'.
flags.DEFINE_enum('model_type', 'noise', ['noise', 'nll'], help='variance type')
# Model Dir
# Despite the "_dir" suffix, the defaults below are checkpoint file paths.
flags.DEFINE_string('eps1_dir', './logs/iDDPM_CIFAR10_EPS/models/ckpt_1_300000.pt', help='eps1 model log directory')
flags.DEFINE_string('eps2_dir', './logs/iDDPM_CIFAR10_EPS2/models/ckpt_2_300000.pt', help='eps2 model log directory')
flags.DEFINE_string('eps3_dir', './logs/iDDPM_CIFAR10_complex_EPS3/models/ckpt_3_300000.pt', help='eps3 model log directory')
flags.DEFINE_string('eps4_dir', './logs/iDDPM_CIFAR10_complex_EPS4/models/ckpt_4_300000.pt', help='eps4 model log directory')

# Module-wide device handle, hard-coded to the first CUDA device.
device = torch.device('cuda:0')

def _rescale_timesteps_ratio(N, flag):
    if flag:
        return 1000.0 / float(N)
    return 1

def statistics2str(statistics):
    """Return ``str()`` of a dict holding every value formatted with ``.6g``.

    Keys and their order are kept; each value is replaced by its
    six-significant-digit string representation.
    """
    rendered = {}
    for key, value in statistics.items():
        rendered[key] = format(value, '.6g')
    return str(rendered)


def report_statistics(s, t, statistics):
    """Log one INFO line with the pair ``(s, t)`` and the formatted statistics.

    ``s`` and ``t`` are rendered with ``.6g``; ``statistics`` goes through
    ``statistics2str``.
    """
    message = f'[(s, r): ({s:.6g}, {t:.6g})] [{statistics2str(statistics)}]'
    logging.info(message)


class TemporaryGrad(object):
    """Context manager that enables autograd for the duration of a ``with`` block.

    The grad mode that was active on entry is stored in ``self.prev`` and put
    back on exit. ``__enter__`` returns nothing, so ``with ... as x`` binds None.
    """

    def __enter__(self):
        # Remember the caller's mode first, then force gradients on.
        mode_on_entry = torch.is_grad_enabled()
        self.prev = mode_on_entry
        torch.set_grad_enabled(True)

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        # Restore unconditionally; returning None lets any exception propagate.
        mode_on_entry = self.prev
        torch.set_grad_enabled(mode_on_entry)

class SDEditData(Dataset):
    """Dataset of (image, stroke) pairs stored as ``<k>.png`` in two directories.

    Item ``index`` is read from ``<image_path>/<index+1>.png`` and
    ``<stroke_path>/<index+1>.png``. The dataset length is the number of
    entries found in ``image_path`` when the dataset is constructed, so the
    directory is expected to contain files numbered from 1 without gaps.

    Args:
        cond_idx: accepted for call compatibility; not used by this class.
        filter_idx: accepted for call compatibility; not used by this class.
        image_path: directory holding the reference images.
        stroke_path: directory holding the stroke images.
        image_size: side length both images are resized to.
    """

    def __init__(self, cond_idx=1, filter_idx=0,
                 image_path="/home/aiops/allanguo/texttotext/image/",
                 stroke_path="/home/aiops/allanguo/texttotext/stroke/",
                 image_size=64):
        self.image_path = image_path
        self.stroke_path = stroke_path
        self.image_size = image_size

        # Directory listings taken once at construction; only the image
        # listing is consulted afterwards (for __len__).
        self.folders_image = [osp.join(self.image_path, d) for d in os.listdir(self.image_path)]
        self.folders_storke = [osp.join(self.stroke_path, d) for d in os.listdir(self.stroke_path)]
        print(len(self.folders_image))

    def __len__(self):
        return len(self.folders_image)

    def _load_rgb(self, directory, index):
        """Read ``<directory>/<index+1>.png`` as a 3-channel array resized to image_size."""
        # osp.join works with or without a trailing separator on ``directory``.
        path = osp.join(directory, str(index + 1) + '.png')
        pixels = imageio.v2.imread(path)
        if pixels.ndim == 2:
            # Single-channel file: replicate the channel three times.
            pixels = np.tile(pixels[:, :, None], (1, 1, 3))
        else:
            # Keep only the first three channels (drops e.g. an alpha channel).
            pixels = pixels[:, :, :3]
        target = (self.image_size, self.image_size)
        return numpy.array(Image.fromarray(pixels).resize(target))

    def __getitem__(self, index):
        """Return ``(im, st_raw, st)`` for item ``index``.

        ``im`` and ``st_raw`` are the resized image and stroke arrays as read
        (channels last). ``st`` is the stroke divided by 256, jittered with
        uniform noise in [0, 1/256), moved to channels-first and mapped with
        ``(x - 0.5) / 0.5``.
        """
        im = self._load_rgb(self.image_path, index)
        st_raw = self._load_rgb(self.stroke_path, index)

        st = st_raw / 256
        st = st + np.random.uniform(0, 1 / 256., st.shape)
        st = st.transpose((2, 0, 1))
        st = (st - 0.5) / 0.5
        return im, st_raw, st

def infiniteloop(dataloader):
    """Cycle over ``dataloader`` forever, yielding each item as a 3-tuple.

    Every item must unpack into exactly three values; the iterable is
    restarted from the beginning each time it is exhausted.
    """
    while True:
        for batch in dataloader:
            first, second, third = batch
            yield first, second, third

def solve_gmm(mean,cov,ske,kur,gt,timestep,report_dict):
    """Fit, element-wise, a two-component Gaussian mixture to three target moments.

    The mixture has fixed weights 1/3 and 2/3, component means ``x0`` and
    ``x1`` and a shared variance ``gt * beta``. The three unknowns are stacked
    into one tensor and updated with Adan for a fixed number of steps so that
    the mixture's first three raw moments approach ``mean``,
    ``mean**2 + cov`` and ``ske``.

    Args:
        mean: target first moment; also the starting point for ``x0``
            (``x1`` starts at ``mean - 1e-3``).
        cov: target variance (the second raw-moment target is ``mean**2 + cov``).
        ske: target third raw moment.
        kur: accepted for interface compatibility; the objective does not use it.
        gt: base variance that ``beta`` scales.
        timestep: accepted for interface compatibility; the step size is fixed.
        report_dict: dict updated in place with
            ``'mean optimize'``  - last evaluated loss divided by the initial loss,
            ``'3-max optimize'`` - same ratio for the maximum third-moment error.
            The "last evaluated" values are taken just before the final
            optimizer step.

    Returns:
        ``(x0, x1, beta, report_dict)`` where ``beta`` is clamped to
        ``[0.1, 1.2]``. The tensors are detached from the optimisation variable.
    """
    pi = 1/3            # weight of the first component
    iterations = 120
    lr = 0.02

    # One stacked leaf tensor [x0, x1, beta] so a single optimizer parameter
    # covers all unknowns.
    x0_init = torch.unsqueeze(mean, dim=0)
    x1_init = torch.unsqueeze(mean - 1e-3, dim=0)
    beta_init = torch.unsqueeze(torch.ones(size=mean.size()).to(mean.device) * 0.999, dim=0)
    x = torch.cat([x0_init, x1_init, beta_init], axis=0)

    def loss_f(tensor):
        """Return (mean of summed squared moment errors, max third-moment error)."""
        x0, x1, beta = tensor[0,...], tensor[1,...], tensor[2,...]
        beta = torch.clamp(beta, 0.1, 1.2)
        E0 = (pi*x0 + (1-pi)*x1 - mean).pow(2)
        E1 = (pi*(x0**2+gt*beta)+(1-pi)*(x1**2+gt*beta) - (mean**2+cov)).pow(2)
        E2 = (pi*(x0**3+3*x0*gt*beta)+(1-pi)*(x1**3+3*x1*gt*beta) - ske).pow(2)
        return ((E0+E1+E2)).mean(),E2.max()

    with TemporaryGrad():
        optimizer_solve = Adan([x],lr=lr,betas=(0.9,0.92,0.92))
        # Baseline loss before any update, used to normalise the report.
        pred_0,max_pre_E_2 = loss_f(x)
        x.requires_grad = True
        for _ in range(iterations):
            pred,max_E_2 = loss_f(x)
            optimizer_solve.zero_grad()
            pred.backward()
            optimizer_solve.step()

    # Detach the diagnostics so the dict does not keep the last autograd graph alive.
    report_dict['mean optimize'] = (pred/pred_0).detach()
    report_dict['3-max optimize'] = (max_E_2/max_pre_E_2).detach()

    # x has requires_grad=True; hand back detached views so that flag does not
    # propagate into whatever the caller computes next.
    solution = x.detach()
    return solution[0,...], solution[1,...],torch.clamp(solution[2,...], 0.1, 1.2),report_dict

"""
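Closed-form reference for the mixture fitted by solve_gmm
(weight 1/3 on mean x, weight 2/3 on mean y, shared variance Z1).
Z11: mean (first raw moment)
Z12: mean^2 + cov (second raw moment)
Z13: third raw moment (the `ske` argument of solve_gmm)
x  -> Z11 + 1.5874 (2. Z11^3 - 3. Z11 Z12 + Z13)^(1/3), 
y  -> 0.5 (2. Z11 - 1.5874 (2. Z11^3 - 3. Z11 Z12 + Z13)^(1/3)), 
Z1 -> (-1. Z11^2 + Z12 - 1.25992 (2. Z11^3 - 3. Z11 Z12 + Z13)^(2/3))
Here 1.5874 = 2^(2/3) and 1.25992 = 2^(1/3).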
"""

def extract(v, t, x_shape,ratio=None):
    """
    Gather ``v`` at the indices in ``t`` (along dim 0), cast to float32 and
    reshape to ``[t.shape[0], 1, 1, ...]`` with one singleton axis for each
    dimension of ``x_shape`` after the first, so the result broadcasts
    against a tensor of shape ``x_shape``.

    ``ratio`` is accepted but not used.
    """
    batch = t.shape[0]
    singleton_axes = [1] * (len(x_shape) - 1)
    picked = v.gather(0, t).float()
    return picked.view([batch] + singleton_axes)


class GaussianDiffusionSampler(nn.Module):
    def __init__(self, eps1_model, eps2_model, eps3_model, eps4_model, beta_1, beta_T, sample_T,
                 total_T=4000, img_size=32, sample_type='ddpm', time_shift=True,
                 noise_schedule='linear', rescale_time=True, model_type='noise', t0=1250):
        """Store the moment networks and precompute all schedule buffers.

        ``eps1_model`` .. ``eps4_model`` are kept as ``self.model``,
        ``self.cov_model``, ``self.eps3_model`` and ``self.eps4_model``.
        ``self.t_list`` holds ``sample_T + 1`` integer timesteps running from
        ``t0`` down to 0. The beta schedule has ``total_T`` entries and is
        either linear between ``beta_1`` and ``beta_T`` or, for any other
        ``noise_schedule`` value, the cosine schedule built below (betas
        capped at 0.999). All schedule tensors are registered as buffers.
        """
        assert sample_type in ['ddpm', 'analyticdpm', 'gmddpm','ddim']
        super().__init__()

        # Networks (assigned in the same order as the constructor arguments).
        self.model = eps1_model
        self.cov_model = eps2_model
        self.eps3_model = eps3_model
        self.eps4_model = eps4_model

        self.T = sample_T
        self.total_T = total_T
        self.model_type = model_type

        self.rescale_ratio = _rescale_timesteps_ratio(total_T, rescale_time)
        logging.info('the scale ratio for timesteps is {0}'.format(self.rescale_ratio))

        # Sampling schedule: sample_T + 1 integer timesteps from t0 down to 0.
        self.t0 = t0
        self.t_list = [int(max(step / sample_T * t0, 0)) for step in range(sample_T, -1, -1)]
        logging.info(self.t_list)

        self.img_size = img_size
        self.sample_type = sample_type
        self.time_shift = time_shift
        self.noise_schedule = noise_schedule

        # ---- beta schedule (float64) ------------------------------------
        if noise_schedule == 'linear':
            betas = torch.linspace(beta_1, beta_T, total_T).double()
        else:
            logging.info(noise_schedule)

            def g(u):
                return math.cos((u + 0.008) / 1.008 * math.pi / 2) ** 2

            cosine_betas = [
                min(1 - g((i + 1) / total_T) / g(i / total_T), 0.999)
                for i in range(total_T)
            ]
            # Going through numpy keeps the tensor in float64.
            betas = torch.tensor(np.array(cosine_betas))
        self.register_buffer('betas', betas)

        alphas = 1. - self.betas
        alphas_bar = torch.cumprod(alphas, dim=0)
        one_minus_alphas_bar = 1. - alphas_bar
        # Cumulative product shifted by one step, with 1 prepended for the first entry.
        alphas_bar_prev = F.pad(alphas_bar, [1, 0], value=1)[:total_T]

        # ---- quantities of q(x_t | x_0) and its inverse -----------------
        for name, value in (
            ('sqrt_one_minus_alphas_bar', torch.sqrt(one_minus_alphas_bar)),
            ('sqrt_alphas_bar', torch.sqrt(alphas_bar)),
            ('one_minus_alphas_bar', one_minus_alphas_bar),
            ('sqrt_recip_one_minus_alphas_bar', 1. / torch.sqrt(one_minus_alphas_bar)),
            ('sqrt_recip_alphas_bar', torch.sqrt(1. / alphas_bar)),
            ('sqrt_recipm1_alphas_bar', torch.sqrt(1. / alphas_bar - 1)),
        ):
            self.register_buffer(name, value)

        # ---- posterior q(x_{t-1} | x_t, x_0) ----------------------------
        posterior_var = self.betas * (1. - alphas_bar_prev) / one_minus_alphas_bar
        self.register_buffer('posterior_var', posterior_var)
        # posterior_var[0] is 0 (alphas_bar_prev[0] == 1), so its log is
        # replaced by the log of the second entry.
        self.register_buffer(
            'posterior_log_var_clipped',
            torch.log(torch.cat([posterior_var[1:2], posterior_var[1:]])))
        self.register_buffer(
            'posterior_mean_coef1',
            torch.sqrt(alphas_bar_prev) * self.betas / one_minus_alphas_bar)
        self.register_buffer(
            'posterior_mean_coef2',
            torch.sqrt(alphas) * (1. - alphas_bar_prev) / one_minus_alphas_bar)

    def diffusion_to_guide(self,x_0):
        """Noise ``x_0`` to timestep ``self.t0`` in a single step.

        Returns ``sqrt_alphas_bar[t0] * x_0 + sqrt_one_minus_alphas_bar[t0] * noise``
        where ``noise`` is freshly drawn standard-normal noise shaped like ``x_0``.
        """
        batch = x_0.shape[0]
        t = x_0.new_ones([batch, ], dtype=torch.long) * self.t0
        noise = torch.randn_like(x_0)
        signal_scale = extract(self.sqrt_alphas_bar, t, x_0.shape)
        noise_scale = extract(self.sqrt_one_minus_alphas_bar, t, x_0.shape)
        return signal_scale * x_0 + noise_scale * noise

    # use eps to estimate one order moment
    def predict_xpre_from_eps(self, x_t, t, eps):
        """Estimate the mean of the earlier state x_s and of x_0 from a noise prediction.

        ``s`` is ``t - self.ratio``; when that is negative for the first batch
        element, index 0 is used for the whole batch instead. The x_0 estimate
        is clamped to [-1, 1] before it enters the x_s mean, and the x_s mean
        is clamped to [-1000, 1000]. Several summary values are written to
        ``self.statistics``.

        Returns:
            ``(mean_xs, mean_x0)``.
        """
        assert x_t.shape == eps.shape
        shape = x_t.shape
        stats = self.statistics

        # Index of the earlier step, falling back to 0 when t - ratio < 0.
        s = t - self.ratio
        if s[0] < 0:
            s = t - t

        a_t = extract(self.sqrt_alphas_bar, t, shape)
        a_s = extract(self.sqrt_alphas_bar, s, shape)
        a_ts = a_t / a_s
        sigma_s = torch.sqrt(extract(self.one_minus_alphas_bar, s, shape))
        sigma_t = torch.sqrt(extract(self.one_minus_alphas_bar, t, shape))
        beta_ts = sigma_t**2 - a_ts**2*sigma_s**2

        mean_x0 = (x_t - sigma_t * eps)/a_t
        stats['xt_mean'] = x_t.mean().item()
        stats['eps_mean'] = eps.mean().item()
        stats['unclip mean_x0_mean'] = mean_x0.mean().item()
        mean_x0 = mean_x0.clamp(-1.,1.)
        stats['clip mean_x0_mean'] = mean_x0.mean().item()

        weight_xt = a_ts*sigma_s.pow(2)/(sigma_t.pow(2))
        weight_x0 = a_s*beta_ts/(sigma_t.pow(2))
        mean_xs = weight_xt * x_t + weight_x0 * mean_x0
        mean_xs = mean_xs.clamp(-1000.,1000.)
        stats['clip mean_xs_max'] = mean_xs.max().item()
        return mean_xs,mean_x0

    # use eps and eps2 to estimate one order moment
    def predict_xpre_cov_from_eps(self, x_t, t, eps):
        """Estimate the per-element variance of the earlier state x_s.

        ``self.cov_model`` is queried at ``t`` (or ``t + 1`` when
        ``self.time_shift`` is set), scaled by ``self.rescale_ratio``. Its
        output ``eps2`` is combined with ``eps`` into an x_0 variance estimate
        ``sigma_t^2 / a_t^2 * (eps2 - eps^2)``, clamped to [0, 1]. The
        returned variance is ``sigma2_small`` plus that estimate scaled by
        ``a_s^2 * beta_ts^2 / sigma_t^4``, again clamped to [0, 1].
        Intermediate means are recorded in ``self.statistics``.

        Returns:
            ``(model_var, eps2)``.
        """
        shape = x_t.shape
        stats = self.statistics

        net_t = (t+1) if self.time_shift else t
        eps2 = self.cov_model(x_t, net_t*self.rescale_ratio)

        # Index of the earlier step, falling back to 0 when t - ratio < 0.
        s = t - self.ratio
        if s[0] < 0:
            s = t - t

        a_t = extract(self.sqrt_alphas_bar, t, shape)
        a_s = extract(self.sqrt_alphas_bar, s, shape)
        a_ts = a_t / a_s
        sigma_t = torch.sqrt(extract(self.one_minus_alphas_bar, t, shape))
        sigma_s = torch.sqrt(extract(self.one_minus_alphas_bar, s, shape))
        beta_ts = sigma_t**2 - a_ts**2*sigma_s**2

        sigma2_small = (sigma_s**2*beta_ts)/(sigma_t**2)
        x0_scale = sigma_t.pow(2)/a_t.pow(2)
        cov_x0_pred = x0_scale * (eps2-eps.pow(2))
        stats['noise1 mean'] = eps.mean().item()
        stats['noise2 mean'] = eps2.mean().item()
        stats['cov_x0_coeffi'] = (sigma_t.pow(2)/a_t.pow(2)).mean().item()
        stats['unclip cov_x0_mean'] = cov_x0_pred.mean().item()
        cov_x0_pred = cov_x0_pred.clamp(0., 1.)
        stats['clip cov_x0_mean'] = cov_x0_pred.mean().item()

        offset = a_s.pow(2)*beta_ts.pow(2)/sigma_t.pow(4) * cov_x0_pred
        stats['offset'] = offset.mean().item()
        stats['offset_max'] = offset.max().item()
        stats['sigma2_small'] = sigma2_small.mean().item()

        model_var = (sigma2_small + offset).clamp(0., 1.)
        return model_var,eps2
    
    def ddpm_cov(self, x_t, t):
        """Return ``sigma_s^2 * beta_ts / sigma_t^2`` for the step from t to s.

        ``s`` is ``t - self.ratio``, or index 0 when that is negative for the
        first batch element. The mean of the result is stored under
        ``self.statistics['sigma2_small']``. ``x_t`` is only used for its shape.
        """
        shape = x_t.shape

        s = t - self.ratio
        if s[0] < 0:
            s = t - t

        sigma_t = torch.sqrt(extract(self.one_minus_alphas_bar, t, shape))
        sigma_s = torch.sqrt(extract(self.one_minus_alphas_bar, s, shape))
        # \alpha_{t|s}
        a_ts = extract(self.sqrt_alphas_bar, t, shape)/extract(self.sqrt_alphas_bar, s, shape)
        beta_ts = sigma_t**2 - a_ts**2*sigma_s**2

        small_var = (sigma_s**2*beta_ts)/(sigma_t**2)
        self.statistics['sigma2_small'] = small_var.mean().item()
        return small_var

    # use eps and eps2 and eps3 to estimate one order moment
    def predict_xpre_3moment_from_eps(self, x_t, t, eps, eps2, mean):
        """Estimate the third raw moment of the earlier state x_s.

        ``self.eps3_model`` is queried at ``t`` (or ``t + 1`` when
        ``self.time_shift`` is set), scaled by ``self.rescale_ratio``. First,
        second and third moment estimates of x_0 are formed from ``eps``,
        ``eps2`` and ``eps3``, clamped, and combined into the third moment of
        x_s. ``mean`` is accepted but not used.

        Returns:
            ``(skew_xs, eps3)``.
        """
        shape = x_t.shape
        stats = self.statistics

        net_t = (t+1) if self.time_shift else t
        eps3 = self.eps3_model(x_t, net_t*self.rescale_ratio)

        # Index of the earlier step, falling back to 0 when t - ratio < 0.
        s = t - self.ratio
        if s[0] < 0:
            s = t - t

        sigma_t = torch.sqrt(extract(self.one_minus_alphas_bar, t, shape))
        a_t = extract(self.sqrt_alphas_bar, t, shape)
        a_s = extract(self.sqrt_alphas_bar, s, shape)
        # \alpha_{t|s}
        a_ts = a_t / a_s
        sigma_s = torch.sqrt(extract(self.one_minus_alphas_bar, s, shape))
        beta_ts = sigma_t**2 - a_ts**2*sigma_s**2

        # Moment estimates of x_0.
        mean_x0 = ((x_t - sigma_t * eps)/a_t).clamp(-1., 1.)
        twom_x0 = 1/(a_t.pow(2))*(x_t.pow(2)+sigma_t.pow(2)*eps2-2*x_t*sigma_t*eps)
        twom_x0 = twom_x0.clamp(0., 1.)

        if self.model_type == 'noise':
            skew_x0 = 1/(a_t.pow(3))*(x_t.pow(3) - sigma_t.pow(3)*eps3 - 3*x_t.pow(2)*sigma_t*eps + 3*x_t*sigma_t.pow(2)*eps2)
        else:
            skew_x0 = 1/(a_t.pow(3))*(x_t.pow(3) + eps3)
        stats['unclip_x0_skew'] = skew_x0.mean().item()
        # Where the estimate exceeds the (clamped) mean in magnitude, use the mean instead.
        within_mean = torch.abs(skew_x0)<=torch.abs(mean_x0)
        skew_x0 = torch.where(within_mean,skew_x0,mean_x0).clamp(-1., 1.)
        stats['clip_x0_skew'] = skew_x0.mean().item()

        sigma2_small = (sigma_s**2*beta_ts)/(sigma_t**2)

        # x_s mean is lin_xt + coef_x0 * x_0; expand its cube term by term.
        lin_xt = a_ts*sigma_s.pow(2)/(sigma_t.pow(2)) * x_t
        coef_x0 = a_s*beta_ts/sigma_t.pow(2)
        skew_xs_part1 = (lin_xt).pow(3)+\
            3*(lin_xt).pow(2)*(coef_x0)*mean_x0 +\
            3*(lin_xt)*(coef_x0).pow(2)*twom_x0 +\
            (coef_x0).pow(3)*skew_x0
        skew_xs_part2 = 3*sigma2_small*(lin_xt + coef_x0 * mean_x0)
        skew_xs = skew_xs_part1+skew_xs_part2
        stats['clip_xs_skew'] = skew_xs.mean().item()
        return skew_xs,eps3
        
    #@torch.no_grad()
    def p_mean_variance(self, x_t, t):
        """Compute the reverse-step statistics for ``self.sample_type``.

        ``self.model`` is always queried first (at ``t + 1`` when
        ``self.time_shift`` is set, else ``t``; scaled by ``self.rescale_ratio``).

        Returns, depending on ``self.sample_type``:
            'ddpm':        ``(mean, var)``; ``var`` is the exponentiated
                           clipped posterior log-variance when
                           ``self.ratio == 1``, else ``self.ddpm_cov``.
            'analyticdpm': ``(mean, var)`` with ``var`` from
                           ``predict_xpre_cov_from_eps``.
            'ddim':        a single tensor,
                           ``sqrt(abar_s) * x0_hat + sqrt(1 - abar_s) * eps``
                           with ``s = t - self.ratio``.
            'gmddpm':      ``(mean, cov, third_moment, None, cov)``.

        Raises:
            NotImplementedError: for any other sample type.
        """
        net_t = (t+1) if self.time_shift else t
        eps = self.model(x_t, net_t*self.rescale_ratio)
        kind = self.sample_type

        if kind == 'ddpm':   # the model predicts epsilon
            model_mean, _ = self.predict_xpre_from_eps(x_t, t, eps=eps)
            if self.ratio == 1:
                log_var = extract(self.posterior_log_var_clipped, t, x_t.shape)
                return model_mean, torch.exp(log_var)
            return model_mean, self.ddpm_cov(x_t,t)

        if kind == 'analyticdpm':
            assert self.cov_model is not None
            model_mean, _ = self.predict_xpre_from_eps(x_t, t, eps=eps)
            model_var, _ = self.predict_xpre_cov_from_eps(x_t, t, eps)
            return model_mean, model_var

        if kind == 'ddim':
            # DDIM only needs the eps network; just the clamped x_0 estimate is used here.
            _, mean_x0 = self.predict_xpre_from_eps(x_t, t, eps=eps)
            s = t-self.ratio
            sqrt_a_s = extract(self.sqrt_alphas_bar, s, x_t.shape)
            sqrt_one_minus_a_s = 1/extract(self.sqrt_recip_one_minus_alphas_bar, s, x_t.shape)
            return sqrt_a_s * mean_x0 + sqrt_one_minus_a_s*eps

        if kind == 'gmddpm':
            assert self.eps3_model is not None
            mean, _ = self.predict_xpre_from_eps(x_t, t, eps=eps)
            cov, eps2 = self.predict_xpre_cov_from_eps(x_t, t, eps)
            skeness, _ = self.predict_xpre_3moment_from_eps(x_t, t, eps, eps2, mean)
            # Still evaluated (it records 'sigma2_small' in self.statistics),
            # but its return value is not part of the result.
            self.ddpm_cov(x_t,t)
            fmoment = None
            # NOTE(review): the fifth slot repeats `cov`; the caller in
            # forward() unpacks it under the name `sigma2_small` -- confirm
            # whether returning the ddpm_cov value was intended.
            return mean,cov,skeness,fmoment,cov

        raise NotImplementedError(self.sample_type)

    def forward(self, x_T,K=3):
        x_g = self.diffusion_to_guide(x_T)
        x_t = x_g
        for number_of_ts, time_step in enumerate(self.t_list):
            self.statistics = {}
            t = x_t.new_ones([x_t.shape[0], ], dtype=torch.long)  * time_step

            if time_step > 0:
                noise = torch.randn_like(x_t).to(x_T.device)
                self.ratio = int(time_step-self.t_list[number_of_ts+1])

            else:
                    ######################## Branch 1 ########################
                if self.time_shift:
                    eps = self.model(x_t, (t+1)*self.rescale_ratio)
                else:
                    eps = self.model(x_t, t*self.rescale_ratio)
                a_ts = extract(self.sqrt_alphas_bar, t, x_t.shape)
                sigma_t = torch.sqrt(extract(self.one_minus_alphas_bar, t, x_t.shape))
                beta_ts = (1-a_ts**2)
                x_0_1 = 1/a_ts*( x_t - eps * beta_ts/sigma_t)

                ######################## Branch 2 ########################
                if self.time_shift:
                    eps = self.model(x_t2, (t+1)*self.rescale_ratio)
                else:
                    eps = self.model(x_t2, t*self.rescale_ratio)
                a_ts = extract(self.sqrt_alphas_bar, t, x_t2.shape)
                sigma_t = torch.sqrt(extract(self.one_minus_alphas_bar, t, x_t2.shape))
                beta_ts = (1-a_ts**2)
                x_0_2 = 1/a_ts*( x_t2 - eps * beta_ts/sigma_t)
                
                    ######################## Branch 3 ########################
                if self.time_shift:
                    eps = self.model(x_t3, (t+1)*self.rescale_ratio)
                else:
                    eps = self.model(x_t3, t*self.rescale_ratio)
                a_ts = extract(self.sqrt_alphas_bar, t, x_t3.shape)
                sigma_t = torch.sqrt(extract(self.one_minus_alphas_bar, t, x_t3.shape))
                beta_ts = (1-a_ts**2)
                x_0_3 = 1/a_ts*( x_t3 - eps * beta_ts/sigma_t)
                report_statistics(torch.tensor(0.), torch.tensor(time_step), self.statistics)
                return torch.clip(x_0_1, -1, 1),torch.clip(x_0_2, -1, 1),torch.clip(x_0_3, -1, 1)

            # sample with mixture of Gaussian
            if self.sample_type == 'gmddpm':

                # Clip-Var Gaussian Sample
                if time_step-self.ratio <= 0:
                    mean,cov,tmoment,fmoment,sigma2_small = self.p_mean_variance(x_t=x_t, t=t)
                    var = cov
                    clip_pixel = 0.9
                    var_threshold = (clip_pixel * 2. / 255. * (math.pi / 2.) ** 0.5) ** 2
                    self.statistics['unclip var_mean'] = var.mean().item()
                    var = var.clamp(0., var_threshold)
                    self.statistics['clip var_mean'] = var.mean().item()
                    self.statistics['threshold for var'] = var_threshold
                    x_t = mean + var**0.5 * noise
                    report_statistics(torch.tensor(max(time_step-self.ratio,0)), torch.tensor(time_step), self.statistics)

                    mean,cov,tmoment,fmoment,sigma2_small = self.p_mean_variance(x_t=x_t2, t=t)
                    noise2 = torch.randn_like(x_t2).to(x_T.device)
                    x_t2 = mean + cov**0.5 * noise2

                    mean,cov,tmoment,fmoment,sigma2_small = self.p_mean_variance(x_t=x_t3, t=t)
                    noise3 = torch.randn_like(x_t3).to(x_T.device)
                    x_t3 = mean + cov**0.5 * noise3
                    continue
                
                # the first steps to divided the guided image into three bra
                if number_of_ts == 0:
                    mean,cov,tmoment,fmoment,sigma2_small = self.p_mean_variance(x_t=x_t, t=t)
                    self.statistics['moment error'] =  (torch.abs(tmoment-mean.pow(3)-3*mean*cov)).mean().item()
                    pre_cov = sigma2_small
                    mean1,mean2,beta,self.statistics = solve_gmm(mean,cov,tmoment,fmoment,pre_cov,time_step,self.statistics)
                    #var1  = mean1/mean1 * (pre_cov*beta).mean()
                    #var2  = mean1/mean1 * (pre_cov*beta).mean()
                    #var3  = mean1/mean1 * (pre_cov*beta).mean()
                    var1  = pre_cov*beta
                    var2  = pre_cov*beta
                    var3  = pre_cov*beta
                    ########################
                    mean_1 = torch.zeros(size=mean1.size()).to(mean1.device)
                    mean_2 = torch.zeros(size=mean1.size()).to(mean1.device)
                    mean_3 = torch.zeros(size=mean1.size()).to(mean1.device)
                    for n_count in range(mean1.size()[0]):
                        mean_1[n_count,...]= mean1[n_count,...]
                        mean_2[n_count,...]= mean1[n_count,...]
                        mean_3[n_count,...]= mean2[n_count,...]
                else:
                    mean_1 = torch.zeros(size=mean1.size()).to(mean1.device)
                    mean_2 = torch.zeros(size=mean1.size()).to(mean1.device)
                    mean_3 = torch.zeros(size=mean1.size()).to(mean1.device)
                    ######################## Branch 1 ########################
                    mean,cov,tmoment,fmoment,sigma2_small = self.p_mean_variance(x_t=x_t, t=t)
                    self.statistics['moment error'] =  (torch.abs(tmoment-mean.pow(3)-3*mean*cov)).mean().item()
                    pre_cov = sigma2_small
                    mean1,mean2,beta,self.statistics = solve_gmm(mean,cov,tmoment,fmoment,pre_cov,time_step,self.statistics)
                    #var1  = mean1/mean1 * (pre_cov*beta).mean()
                    var1  = pre_cov*beta
                    for n_count in range(mean.size()[0]):
                        if (torch.rand(size=(1,1))<torch.tensor(1/3))[0][0]:
                            mean_1[n_count,...]= mean1[n_count,...]
                        else:
                            mean_1[n_count,...]= mean2[n_count,...]
                    

                    ######################## Branch 2 ########################
                    mean,cov,tmoment,fmoment,sigma2_small = self.p_mean_variance(x_t=x_t2, t=t)
                    pre_cov = sigma2_small
                    mean1,mean2,beta,self.statistics = solve_gmm(mean,cov,tmoment,fmoment,pre_cov,time_step,self.statistics)
                    #var2  = mean1/mean1 * (pre_cov*beta).mean()
                    var2  = pre_cov*beta
                    for n_count in range(mean.size()[0]):
                        if (torch.rand(size=(1,1))<torch.tensor(1/3))[0][0]:
                            mean_2[n_count,...]= mean1[n_count,...]
                        else:
                            mean_2[n_count,...]= mean2[n_count,...]

                    ######################## Branch 3 ########################
                    mean,cov,tmoment,fmoment,sigma2_small = self.p_mean_variance(x_t=x_t3, t=t)
                    self.statistics['moment error'] =  (torch.abs(tmoment-mean.pow(3)-3*mean*cov)).mean().item()
                    pre_cov = sigma2_small
                    mean1,mean2,beta,self.statistics = solve_gmm(mean,cov,tmoment,fmoment,pre_cov,time_step,self.statistics)
                    #var3  = mean1/mean1 * (pre_cov*beta).mean()
                    var3  = pre_cov*beta
                    for n_count in range(mean.size()[0]):
                        if (torch.rand(size=(1,1))<torch.tensor(1/3))[0][0]:
                            mean_3[n_count,...]= mean1[n_count,...]
                        else:
                            mean_3[n_count,...]= mean2[n_count,...]

                self.statistics['Gaussian_cov'] = pre_cov.mean().item()
                #self.statistics['choosend_cov'] = var.mean().item()
                #self.statistics['choosend_cov_min'] = var.min().item()
                x_t = mean_1 + var1**0.5 * noise
                noise2 = torch.randn_like(x_t).to(x_T.device)
                noise3 = torch.randn_like(x_t).to(x_T.device)
                x_t2 = mean_2 + var2**0.5 * noise2
                x_t3 = mean_3 + var3**0.5 * noise3

            elif self.sample_type == 'ddim':
                mean = self.p_mean_variance(x_t=x_t, t=t)
                x_t  = mean
                if number_of_ts == len(self.t_list)-2:
                    x_t  = x_t
                    x_t2 = x_t
                    x_t3 = x_t

            # sample with DDPM/Imperfect Analytic-DPM (Bao et al. (2022))
            else:
                if time_step-self.ratio <= 0:
                    mean, var = self.p_mean_variance(x_t=x_t, t=t)
                    clip_pixel = 1
                    var_threshold = (clip_pixel * 2. / 255. * (math.pi / 2.) ** 0.5) ** 2
                    self.statistics['unclip var_mean'] = var.mean().item()
                    var = var.clamp(0., var_threshold)
                    self.statistics['clip var_mean'] = var.mean().item()
                    self.statistics['threshold for var'] = var_threshold
                    x_t = mean + var**0.5 * noise

                    noise2 = torch.randn_like(x_t).to(x_T.device)
                    noise3 = torch.randn_like(x_t).to(x_T.device)
                    mean, var = self.p_mean_variance(x_t=x_t2, t=t)
                    x_t2 = mean + var**0.5 * noise2

                    mean, var = self.p_mean_variance(x_t=x_t3, t=t)
                    x_t3 = mean + var**0.5 * noise3
                    continue
                if number_of_ts == 0:
                    mean, var = self.p_mean_variance(x_t=x_t, t=t)
                    #logging.info('var={}'.format(var))
                    x_t = mean + var**0.5 * noise
                    noise2 = torch.randn_like(x_t).to(x_T.device)
                    noise3 = torch.randn_like(x_t).to(x_T.device)
                    x_t2 = mean + var**0.5 * noise2
                    x_t3 = mean + var**0.5 * noise3
                else:
                    mean, var = self.p_mean_variance(x_t=x_t, t=t)
                    #logging.info('var={}'.format(var))
                    x_t = mean + var**0.5 * noise

                    noise2 = torch.randn_like(x_t).to(x_T.device)
                    noise3 = torch.randn_like(x_t).to(x_T.device)
                    mean, var = self.p_mean_variance(x_t=x_t2, t=t)
                    x_t2 = mean + var**0.5 * noise2
                    mean, var = self.p_mean_variance(x_t=x_t3, t=t)
                    x_t3 = mean + var**0.5 * noise3
            # logging the var-related result
            report_statistics(torch.tensor(max(time_step-self.ratio,0)), torch.tensor(time_step), self.statistics)

# Device used by the sampling code below. Prefer the first CUDA device, but
# fall back to the CPU when CUDA is not available so that moving modules and
# tensors to `device` does not fail on GPU-less hosts.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

def _preview_slice(batch):
    """Select the images that are tiled into the preview grids.

    Images 100..163 of the batch are used. When the batch holds 100 images or
    fewer that slice is empty, so the leading (at most 64) images are used
    instead and there is still something to tile.
    """
    preview = batch[100:164]
    return preview if preview.size(0) > 0 else batch[:64]


def _min_mse_over_branches(reference, branches):
    """Return the mean, over images, of the lowest per-image MSE among branches.

    For each branch the squared difference to ``reference`` is averaged over
    all non-batch elements (one value per image). The smallest value across
    the branches is kept for every image, and those minima are averaged.
    """
    per_branch = torch.stack([
        (reference - branch).pow(2).reshape(reference.size(0), -1).mean(dim=1)
        for branch in branches
    ], dim=0)
    return per_branch.min(dim=0).values.mean()


def Sample_parallel(net_sampler):
    """Generate three samples per input, save them to disk and print metrics.

    Batches are drawn from ``SDEditData`` until ``FLAGS.num_images`` inputs
    have been processed. Each batch yields three items; the first is kept as
    the real image, the second is ignored and the third is fed to the sampler.

    Printed, in order: the shape of the stacked generated images, the FID
    returned by ``get_inception_and_fid_score``, the best-of-three MSE against
    the sampler inputs, and the best-of-three MSE against the real images.

    Args:
        net_sampler: callable taking a batch on ``device`` and returning three
            batches of generated images.
            NOTE(review): outputs and inputs are rescaled with ``(x + 1) / 2``,
            so they are presumably in [-1, 1] -- confirm against the sampler.
    """
    save_file = './sample/sdedit/'+str(FLAGS.sample_type)+str(FLAGS.sample_steps)+str(FLAGS.t0)+'/'
    # Create the output directory together with any missing parents; an
    # already existing directory is fine, any other failure is raised here
    # instead of surfacing later when the first image is written.
    os.makedirs(save_file, exist_ok=True)

    images_real, image_stroke = [], []
    images_gen = ([], [], [])
    dataset = SDEditData()
    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=FLAGS.batch_size, shuffle=False,
        num_workers=4, drop_last=True)
    datalooper = infiniteloop(dataloader)

    for i in trange(0, FLAGS.num_images, FLAGS.batch_size):
        # The last step may need fewer images than a full batch.
        batch_size = min(FLAGS.batch_size, FLAGS.num_images - i)
        x0_1, _, y0_1 = next(datalooper)
        x0_1 = x0_1[:batch_size, ...]
        x_T = y0_1[:batch_size, ...].float()

        # Scale the real images by 1/256 and add uniform noise in [0, 1/256).
        image_real = x0_1 / 256
        image_real = np.array(image_real.cpu()) + np.random.uniform(0, 1 / 256., image_real.shape)
        # Move the last axis to position 1.
        # NOTE(review): assumes the loader yields channels-last images -- confirm.
        image_real = image_real.transpose((0, 3, 1, 2))
        images_real.append(torch.tensor(image_real))

        image_stroke.append((x_T.cpu() + 1) / 2)
        batch_images_1, batch_images_2, batch_images_3 = net_sampler(x_T.to(device))
        branches = tuple(
            (out.cpu() + 1) / 2
            for out in (batch_images_1, batch_images_2, batch_images_3))
        for collected, batch in zip(images_gen, branches):
            collected.append(batch)

        # One PNG per generated image, named <global index>_<branch>.png.
        for k in range(branches[0].size(0)):
            for branch_no, batch in enumerate(branches, start=1):
                save_image(batch[k, ...], save_file + str(i + k) + '_' + str(branch_no) + '.png')

        # Preview grids; both files are overwritten on every iteration.
        grid = make_grid(_preview_slice(branches[0]))
        save_image(grid, os.path.join(save_file, 'sample.png'))

        grid = (make_grid(_preview_slice(x_T) + 1)) / 2
        save_image(grid, os.path.join(save_file, 'stroke.png'))

    images1_f, images2_f, images3_f = (torch.cat(chunks, dim=0) for chunks in images_gen)
    images = torch.cat([images1_f, images2_f, images3_f], dim=0).numpy()
    print(images.shape)
    (IS, IS_std), FID = get_inception_and_fid_score(
        images, FLAGS.fid_cache, num_images=3*FLAGS.num_images,
        use_torch=FLAGS.fid_use_torch, verbose=True)
    print(FID)

    strokes = torch.cat(image_stroke, dim=0)
    reals = torch.cat(images_real, dim=0)
    generated = (images1_f, images2_f, images3_f)

    # Best-of-three distance to the sampler inputs, then to the real images.
    print(_min_mse_over_branches(strokes, generated))
    print(_min_mse_over_branches(reals, generated))





'''
    x0_1,st_1,y0_1 = next(datalooper)
    x_T = y0_1.float()
    print(x0_1.shape)
    print(x_T)

    x0_1 = x0_1[0,...].cpu().numpy()
    x0_n = x0_1.astype(np.uint8)
    x0 = Image.fromarray(x0_n).resize((256,256))

    st_1 = st_1[0,...].cpu().numpy()
    st_n = st_1.astype(np.uint8)
    y0 = Image.fromarray(st_n).resize((256,256))
    #print(x0_n.shape)
    #with torch.no_grad():
    #for i in trange(0, FLAGS.num_images, FLAGS.batch_size):
    #batch_size = min(FLAGS.batch_size, FLAGS.num_images - i)
    #x_T = torch.randn((batch_size, 3, FLAGS.img_size, FLAGS.img_size))

    batch_images_1,batch_images_2,batch_images_3= net_sampler(x_T.to(device))

    batch_images1 = (batch_images_1.cpu()+1)/2
    batch_images2 = (batch_images_2.cpu()+1)/2
    batch_images3 = (batch_images_3.cpu()+1)/2


    #save_image(batch_images1[0,...], save_file+str(1)+'.png')
    #save_image(batch_images2[0,...], save_file+str(2)+'.png')
    #save_image(batch_images3[0,...], save_file+str(3)+'.png')
    print(batch_images1[0,...].permute(1, 2, 0).size())
    batch_images1 = Image.fromarray(np.uint8(256*(batch_images1[0,...].permute(1, 2, 0)))).resize((256,256))
    batch_images2 = Image.fromarray(np.uint8(256*(batch_images2[0,...].permute(1, 2, 0)))).resize((256,256))
    batch_images3 = Image.fromarray(np.uint8(256*(batch_images3[0,...].permute(1, 2, 0)))).resize((256,256))
    batch_images1.save(save_file+str(1)+'.png')
    batch_images2.save(save_file+str(2)+'.png')
    batch_images3.save(save_file+str(3)+'.png')
    x0.save(save_file+'real.png')
    y0.save(save_file+'stroke.png')

        #images.append((batch_images + 1) / 2)
        #for kkk in range(batch_images.size()[0]):
        #    single_image = (batch_images[kkk,...]+1)/2
        #    try:
        #        save_image(single_image, save_file+str(i+kkk)+'.png')
        #    except:
        #        os.mkdir(save_file)
        #        save_image(single_image, save_file+str(i+kkk)+'.png')
        #grid = (make_grid(batch_images[:64,...]) + 1) / 2
        #path = os.path.join(
        #    save_file,'sample.png')
        #save_image(grid, path)
    #images = torch.cat(images, dim=0).numpy()
    #print(images.shape)
    #(IS, IS_std), FID = get_inception_and_fid_score(
    #    images, FLAGS.fid_cache, num_images=FLAGS.num_images,
    #    use_torch=FLAGS.fid_use_torch, verbose=True)
    #print(IS)
    #print(FID)
'''


def eval():
    """Build the networks and the sampler, then run ``Sample_parallel``.

    ``eps1_model`` is always created. ``eps2_model`` is created when
    ``FLAGS.sample_type`` is ``'analyticdpm'`` or ``'gmddpm'``, and
    ``eps3_model`` only for ``'gmddpm'`` (its checkpoint depends on
    ``FLAGS.model_type``). ``eps4_model`` is always ``None``. All networks are
    switched to evaluation mode and sampling runs under ``torch.no_grad()``.
    """
    # Keyword arguments shared by every UNet constructed below.
    unet_kwargs = dict(
        in_channels=FLAGS.in_channel,
        model_channels=FLAGS.ch,
        out_channels=FLAGS.out_channel,
        num_res_blocks=FLAGS.num_res_blocks,
        attention_resolutions=FLAGS.attn,
        dropout=FLAGS.dropout,
        channel_mult=FLAGS.ch_mult,
        conv_resample=True,
        dims=FLAGS.dims,
        num_classes=None,
        use_checkpoint=False,
        num_heads=FLAGS.num_heads,
        num_heads_upsample=-1,
        use_scale_shift_norm=FLAGS.use_scale_shift_norm,
        head_out_channels=FLAGS.head_out_channels,
        mode='complex',
    )
    # eps1 and eps2 are both initialised from this checkpoint file.
    pretrained_path = '/home/aiops/allanguo/cifar/logs/imagenet64_ema_eps_eps2_pretrained_complex_350000.ckpt.pth'

    # Checkpoints are mapped to the CPU while loading so that they can be read
    # on hosts lacking the device they were saved from; load_state_dict copies
    # the values into the freshly built modules, and the sampler is moved to
    # `device` further down.
    eps1_model = UNetModel4Pretrained2(**unet_kwargs)
    ckpt1 = torch.load(pretrained_path, map_location='cpu')
    eps1_model.load_state_dict(ckpt1)
    eps1_model.eval()

    # Networks that the selected sample type does not need stay None.
    eps2_model = eps3_model = eps4_model = None

    # Sampling for Extended Analytic DPM
    if FLAGS.sample_type in ('analyticdpm', 'gmddpm'):
        print('Sample IS not using DDPM')
        eps2_model = UNetModel4Pretrained(**unet_kwargs)
        ckpt2 = torch.load(pretrained_path, map_location='cpu')
        eps2_model.load_state_dict(ckpt2)
        eps2_model.eval()

        if FLAGS.sample_type == 'gmddpm':
            eps3_model = UNetModel4Pretrained(**unet_kwargs)
            if FLAGS.model_type == 'noise':
                ckpt3_path = './logs/iDDPM_Imagenet_EPS3/models/ckpt_3_1700000.pt'
            else:
                ckpt3_path = './logs/iDDPM_Imagenet_EPS3_nll/models/ckpt_3_1400000.pt'
            ckpt3 = torch.load(ckpt3_path, map_location='cpu')
            # This checkpoint stores the weights under the 'ema_model' key.
            eps3_model.load_state_dict(ckpt3['ema_model'])
            logging.info(ckpt3_path)
            eps3_model.eval()

    print(FLAGS.time_shift)
    print(FLAGS.attn)
    print(FLAGS.ch_mult)

    net_sampler = GaussianDiffusionSampler(
        eps1_model, eps2_model, eps3_model, eps4_model,
        FLAGS.beta_1, FLAGS.beta_T, FLAGS.sample_steps, FLAGS.T, FLAGS.img_size,
        FLAGS.sample_type, FLAGS.time_shift, FLAGS.noise_schedule,
        FLAGS.rescale_time, FLAGS.model_type, FLAGS.t0).to(device)

    if FLAGS.parallel:
        net_sampler = torch.nn.DataParallel(net_sampler)
    with torch.no_grad():
        Sample_parallel(net_sampler)

def main(argv):
    """Entry function handed to ``app.run``: mute FutureWarning, then evaluate.

    Args:
        argv: argument list supplied by the caller; it is not used.
    """
    del argv  # unused
    warnings.simplefilter('ignore', FutureWarning)
    eval()

# Script entry point: hand `main` to absl's `app.run`.
# NOTE(review): this call is not wrapped in an `if __name__ == "__main__":`
# guard, so it also executes whenever this module is imported.
app.run(main)