#%% main.py
import run_lib
from absl import app
from absl import flags
from ml_collections.config_flags import config_flags
import logging
import os
import tensorflow as tf

class flags:
    """Minimal stand-in for absl ``FLAGS`` so the script can run as a notebook.

    NOTE: this class deliberately shadows ``from absl import flags`` imported
    above; only ``config``, ``workdir`` and ``mode`` are provided.
    """

    def __init__(self, workdir=None, mode=None, config=None):
        """Store the run configuration.

        Args:
          workdir: directory holding checkpoints, logs and samples for the run.
          mode: run mode label (e.g. "eval"); stored only.
          config: optional pre-built config object. When None (the default) the
            CelebA homotopy config is loaded, exactly as before, so existing
            callers are unaffected; pass another config to switch datasets
            without editing this class.
        """
        if config is None:
            # Alternative dataset configs; swap the import to change the default.
            # from configs.homotopy.svhn_ddpmpp import get_config
            # from configs.homotopy.cifar10_ddpmpp import get_config
            from configs.homotopy.celeba_ddpmpp import get_config

            config = get_config()
        self.config = config
        self.workdir = workdir
        self.mode = mode
        # self.eval_folder = "eval"

# save_dir = "homotopy_planar_cifar10_batchKFZs_Cov_Norm_Reg_sd0.01_sp1_d1_te1_et1e-4_Adam_lr1e-3_wd3e-5_a1e-3_b0.9_gc3_wu0_B256_b256_n1_VAEBMWRN-gelu"
# Working directory of the run to evaluate (the name appears to encode the
# training hyperparameters); swap with the line above to use the CIFAR-10 run.
save_dir = "homotopy_planar_celeba_batchKFZs_Cov_Norm_Reg_sd0.01_sp1_d1_te1_et1e-4_Adam_lr5e-5_wd3e-5_a1e-3_b0.5_gc3_wu0_B128_b128_n1_VAEBMWRN-gelu"

FLAGS = flags(workdir=save_dir, mode="eval")
tf.io.gfile.makedirs(FLAGS.workdir)
# Attach a file handler writing to <workdir>/stdout.txt to the root logger.
# Console output depends on handlers installed elsewhere (e.g. by absl) — confirm.
# NOTE(review): the builtin open() below only supports local paths; writing the
# log to Google Cloud Storage would require tf.io.gfile.GFile instead.
# The file is opened with 'w' (truncating any previous log) and is never closed
# here; the handler keeps writing to it. Re-running this cell adds another
# handler to the root logger, so each record is then written more than once.
gfile_stream = open(os.path.join(FLAGS.workdir, 'stdout.txt'), 'w')
handler = logging.StreamHandler(gfile_stream)
formatter = logging.Formatter('%(levelname)s - %(filename)s - %(asctime)s - %(message)s')
handler.setFormatter(formatter)
logger = logging.getLogger()
logger.addHandler(handler)
logger.setLevel('INFO')

## %% run_lib.py
import gc
import io
import os
import time
import copy

import numpy as np
import tensorflow as tf
import tensorflow_gan as tfgan
import logging
# Keep the import below for registering all model definitions
from models import resmlp, wideresnet, hybrid
import losses
import sampling
from models import utils as mutils
from models.ema import ExponentialMovingAverage
import datasets
import evaluation
import likelihood
import methods
from absl import flags
import torch
torch.cuda.empty_cache()
from torch.utils import tensorboard
from torchvision.utils import make_grid, save_image
from utils import save_checkpoint, restore_checkpoint
import datasets_utils.celeba

# Let TensorFlow grow its GPU memory on demand instead of reserving all of it up
# front. TF is presumably only used for data loading / file I/O here, while the
# model itself runs in PyTorch on the same GPUs.
gpus = tf.config.list_physical_devices('GPU')
if gpus:
  try:
    # Currently, memory growth needs to be the same across GPUs
    for gpu in gpus:
      tf.config.experimental.set_memory_growth(gpu, True)
    logical_gpus = tf.config.list_logical_devices('GPU')
    print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
  except RuntimeError as e:
    # Memory growth must be set before GPUs have been initialized
    print(e)

#%% run_lib.py - train
# Build the network, its EMA, the optimizer and the training-state dict, then
# load a checkpoint into that state.
config = FLAGS.config
workdir = FLAGS.workdir
# config.device = torch.device('cpu')

# Create directories for experimental logs
sample_dir = os.path.join(workdir, "samples")
tf.io.gfile.makedirs(sample_dir)

tb_dir = os.path.join(workdir, "tensorboard")
tf.io.gfile.makedirs(tb_dir)
writer = tensorboard.SummaryWriter(tb_dir)

# Initialize model.
net = mutils.create_model(config)
print("Number of model parameters: %.5e" %sum(p.numel() for p in net.parameters()))
ema = ExponentialMovingAverage(net.parameters(), decay=config.model.ema_rate)
optimizer = losses.get_optimizer(config, net.parameters())
print("Weight decay: %.5f" %optimizer.param_groups[0]['weight_decay'])
# Mutable run state shared with sampling code (which reads/writes 't_eval');
# presumably refreshed from disk by restore_checkpoint below — confirm.
state = dict(optimizer=optimizer, model=net, ema=ema, step=0, sigma_max=config.training.sigma_max, t_eval=0)

# Create checkpoints directory
print(workdir)
checkpoint_dir = os.path.join(workdir, "checkpoints")
# Intermediate checkpoints to resume training after pre-emption in cloud environments
# NOTE: despite its name, checkpoint_meta_dir is a file path. It currently points
# at a specific hard-coded checkpoint (checkpoint_1120000.pth) rather than the
# rolling checkpoints-meta file; edit it to evaluate a different checkpoint.
# checkpoint_meta_dir = os.path.join(workdir, "checkpoints", "checkpoint_1390000.pth")
checkpoint_meta_dir = os.path.join(workdir, "checkpoints", "checkpoint_1120000.pth")
# checkpoint_meta_dir = os.path.join(workdir, "checkpoints-meta", "checkpoint.pth")
tf.io.gfile.makedirs(checkpoint_dir)
# With the active path above this is the same directory as checkpoint_dir.
tf.io.gfile.makedirs(os.path.dirname(checkpoint_meta_dir))
# Resume training when intermediate checkpoints are detected
# NOTE(review): behavior for a missing file depends on restore_checkpoint (not shown).
state = restore_checkpoint(checkpoint_meta_dir, state, config.device)
initial_step = int(state['step'])

# Build data iterators
if config.data.dataset == 'CELEBA':
  # I cannot load CelebA from tfds loader. So I write a pytorch loader instead.
  train_ds, eval_ds = datasets_utils.celeba.get_celeba(config)
else:
  train_ds, eval_ds, _ = datasets.get_dataset(config, uniform_dequantization=config.data.uniform_dequantization)

train_iter = iter(train_ds)  # pytype: disable=wrong-arg-types
eval_iter = iter(eval_ds)  # pytype: disable=wrong-arg-types
# Create data normalizer and its inverse
scaler = datasets.get_data_scaler(config)
inverse_scaler = datasets.get_data_inverse_scaler(config)
# Setup methods
# Select the generative method from config.training.sde (case-insensitive).
# sampling_eps is the method-specific epsilon handed to get_sampling_fn below.
if config.training.sde.lower() == 'vpsde':
  sde = methods.VPSDE(config=config, beta_min=config.model.beta_min, beta_max=config.model.beta_max, N=config.model.num_scales)
  sampling_eps = 1e-3
elif config.training.sde.lower() == 'subvpsde':
  sde = methods.subVPSDE(config=config, beta_min=config.model.beta_min, beta_max=config.model.beta_max, N=config.model.num_scales)
  sampling_eps = 1e-3
elif config.training.sde.lower() == 'vesde':
  sde = methods.VESDE(config=config, sigma_min=config.model.sigma_min, sigma_max=config.model.sigma_max, N=config.model.num_scales)
  sampling_eps = 1e-5
elif config.training.sde.lower() == 'poisson':
  # PFGM
  sde = methods.Poisson(config=config)
  sampling_eps = config.sampling.z_min
elif config.training.sde.lower() == 'homotopy':
  # GrAPH
  sde = methods.Homotopy(config=config)
  sampling_eps = config.sampling.eps_z
else:
  raise NotImplementedError(f"Method {config.training.sde} unknown.")

# Build one-step training and evaluation functions
# (only optimize_fn and the flags below are built; the step functions
# themselves are commented out in this notebook)
optimize_fn = losses.optimization_manager(config)
reduce_mean = config.training.reduce_mean
method_name = config.training.sde.lower()
# train_step_fn = losses.get_step_fn(sde, train=True, optimize_fn=optimize_fn,
#                                     reduce_mean=reduce_mean, method_name=method_name)
# eval_step_fn = losses.get_step_fn(sde, train=False, optimize_fn=optimize_fn,
#                                   reduce_mean=reduce_mean, method_name=method_name)

# Building sampling functions
# NOTE(review): sampling_shape and sampling_fn are only defined when
# config.training.snapshot_sampling is enabled; later cells read sampling_shape.
if config.training.snapshot_sampling:
  sampling_shape = (25, config.data.num_channels,
                    config.data.image_size, config.data.image_size)
  sampling_fn = sampling.get_sampling_fn(config, sde, sampling_shape, inverse_scaler, sampling_eps)

num_train_steps = config.training.n_iters

# In case there are multiple hosts (e.g., TPU pods), only log to host 0
# logging.info("Starting training loop at step %d." % (initial_step,))
print(initial_step)
# Training is disabled in this notebook: the loop body is just `break`, so its
# only effect is binding `step` to `initial_step` when the range is non-empty.
for step in range(initial_step, num_train_steps + 1):
  break

# Fetch one training batch onto `config.device` and normalize it with `scaler`.
# Later cells reuse `batch` (e.g. as the inpainting source).
if config.data.dataset == 'CELEBA':
  # PyTorch CelebA loader: item [0] holds the images. Restart the iterator when
  # it is exhausted. Move to config.device (not .cuda()) so the batch lands on
  # the same device as the model, including CPU runs.
  try:
    batch = next(train_iter)[0].to(config.device)
  except StopIteration:
    train_iter = iter(train_ds)
    batch = next(train_iter)[0].to(config.device)
else:
  # tf.data path: use ._numpy() to avoid a copy, convert to float and permute
  # channels-last images to [B, C, H, W].
  batch = torch.from_numpy(next(train_iter)['image']._numpy()).to(config.device).float()
  batch = batch.permute(0, 3, 1, 2)
batch = scaler(batch)

# #%% losses.py
# import torch
# import torch.optim as optim
# import numpy as np
# from scipy import integrate
# import math
# import time
# from models import utils as mutils
# from methods import VESDE, VPSDE
# from models import utils_poisson

# from IPython.display import clear_output
# import matplotlib.pyplot as plt

# def get_diag_jac(func, inp):
#     inp.requires_grad = True
#     inp = inp.unsqueeze(1).repeat(1, inp.shape[-1], 1)
#     out = torch.diagonal(func(inp), 0, -2, -1)
#     jac = torch.autograd.grad(out, inp, torch.ones_like(out))[0]
#     grad = torch.diagonal(jac, 0, -2, -1)
#     return grad

# scaler = datasets.get_data_scaler(config)
# inverse_scaler = datasets.get_data_inverse_scaler(config)

# # Execute one training step
# model = state['model']
# optimizer = state['optimizer']
# # optimizer.zero_grad()
# train = True
# continuous = True
# eps = 1e-5
# sample_bool = False
# ind_bool = True
# # torch.manual_seed(123456)

# # Get configs
# data_dim = sde.config.data.channels * sde.config.data.image_size * sde.config.data.image_size
# batch_size = sde.config.training.batch_size
# ensemble_size = sde.config.training.small_batch_size
# assert batch_size == ensemble_size
# num_particles = sde.config.training.num_particles
# sample_size = sde.config.training.sample_size

# # Get the mini-batch
# batch = batch.reshape(batch_size, -1)
# samples_batch = batch[:ensemble_size, :data_dim]

# # Get prior sigma
# # sigma_prior = math.sqrt(sde.config.data.image_size * sde.config.data.image_size * sde.config.data.channels)
# # sigma_prior = math.sqrt(sde.config.data.image_size) * sde.config.data.channels
# # sigma_prior = math.sqrt(sde.config.data.image_size * sde.config.data.channels)
# mean_prior = sde.config.training.mean_prior
# sigma_prior = 1
# divisor = 1

# # Change-of-variable: z = -ln(t)
# z_t = lambda t: -torch.log(t); t_z = lambda z: torch.exp(-z)

# with torch.no_grad():
#   # Get dec (conditional likelihood) sigma
#   sigma_min = 0.01
#   sigma_range = (sigma_min, sigma_min)
#   # sigma_range = (sde.config.training.sigma_min, sde.config.training.sigma_max)
#   if sigma_range[0] == sigma_range[1]:
#     sigma_dec = sigma_range[0] * torch.ones(ensemble_size, num_particles, 1).to(samples_batch.device)
#   else:
#     eps_sigma = sde.config.training.eps_sigma
#     eps_log_range = (np.log(sigma_range[0] + eps_sigma), np.log(sigma_range[1] + eps_sigma))
#     eps_log_random = eps_log_range[0] + torch.rand(ensemble_size, num_particles, 1).to(samples_batch.device) * (eps_log_range[1] - eps_log_range[0])
#     sigma_dec = torch.exp(eps_log_random) - eps_sigma
#     if sde.config.training.invert_sigma: sigma_range[1] - sigma_dec

#   # Sample t from log-uniform distribution
#   t_range = (0, sde.config.training.t_end)
#   if sample_bool:
#     t_samples = t_range[1] * torch.ones(ensemble_size, num_particles, 1).to(samples_batch.device)
#   else:
#     # if sde.config.training.eps_min == sde.config.training.eps_max:
#     #   eps_t = sde.config.training.eps_min
#     # else:
#     eps_t_min = 1e-4
#     eps_t_range = (eps_t_min, eps_t_min)
#     eps_t = eps_t_range[0] + torch.rand(1).item() * (eps_t_range[1] - eps_t_range[0])
#     print("%e" % eps_t)
#     if eps_t == math.inf:
#         t_samples = t_range[0] + torch.rand(ensemble_size, num_particles, 1).to(samples_batch.device) * (t_range[1] - t_range[0])
#     else:
#       t_log_range = (np.log(t_range[0] + eps_t), np.log(t_range[1] + eps_t))
#       t_log_random = t_log_range[0] + torch.rand(ensemble_size, num_particles, 1).to(samples_batch.device) * (t_log_range[1] - t_log_range[0])
#       t_samples = torch.exp(t_log_random) - eps_t
#       if sde.config.training.invert_t: t_samples = t_range[0] - t_samples

#   # z_range = (0, z_max)
#   # z_log_range = (np.log(z_range[0] + eps_z), np.log(z_range[1] + eps_z))
#   # log_random = z_log_range[0] + torch.rand(ensemble_size, num_particles, 1).to(samples_batch.device) * (z_log_range[1] - z_log_range[0])
#   # samples_z = torch.exp(log_random) - eps_z

#   # Compute enc (posterior) mean and var
#   var_prior = sigma_prior**2
#   var_dec = t_range[1] * sigma_dec**2
#   # mean_enc = (t_samples * var_prior + mean_prior * var_dec) / (var_dec + t_samples * var_prior)
#   mean_enc_y = (t_samples * var_prior) / (var_dec + t_samples * var_prior)
#   mean_enc_prior = var_dec / (var_dec + t_samples * var_prior)
#   var_enc = var_prior * var_dec / (var_dec + t_samples * var_prior)

#   # print(t_samples.min(), t_samples.max())
#   # print(var_enc.min(), var_enc.max())
#   # plt.plot(sigma_dec.squeeze().cpu(), 4*np.ones_like(sigma_dec.squeeze().cpu()), '.r')
#   # plt.plot(t_samples.squeeze().cpu(), 3*np.ones_like(t_samples.squeeze().cpu()), '.k')
#   # plt.plot(mean_enc_y.squeeze().cpu(), 2*np.ones_like(mean_enc_y.squeeze().cpu()), '.g')
#   # plt.plot(mean_enc_prior.squeeze().cpu(), 1*np.ones_like(mean_enc_prior.squeeze().cpu()), '.g')
#   # plt.plot((1/sigma_prior)*np.sqrt(var_enc.squeeze().cpu()), 0*np.ones_like(var_enc.squeeze().cpu()), '.b')
#   # plt.show()

#   # Perturb data samples with gaussians
#   gaussians_x = torch.randn(ensemble_size, num_particles, data_dim).to(samples_batch.device)
#   # samples_x = mean_enc * samples_batch.unsqueeze(dim=1)
#   samples_x = mean_enc_y * samples_batch.unsqueeze(dim=1) + mean_enc_prior * mean_prior
#   samples_x += gaussians_x * torch.sqrt(var_enc)
#   samples_x = samples_x.reshape(ensemble_size * num_particles, -1)
#   print(samples_x.min(), samples_x.max(), samples_x.mean())
#   t_samples = t_samples.reshape(ensemble_size * num_particles, -1)
#   if sde.config.training.augment_z: 
#     z_samples = z_t(t_samples)
#     samples_x = torch.cat([samples_x, z_samples], dim=-1)

#   Const = data_dim * var_enc + (1 - mean_enc_y).pow(2) * samples_batch.unsqueeze(dim=1).pow(2).sum(dim=-1, keepdim=True)
#   Const = Const.reshape(ensemble_size * num_particles)
#   const = var_enc + (1 - mean_enc_y).pow(2) * samples_batch.unsqueeze(dim=1).pow(2)
#   const = const.reshape(ensemble_size * num_particles, -1)

#   if sample_bool:
#     samples_s, nfe = sampling_fn(model, state, sample_size=sample_size, method='RK23', eps=1e-2, rtol=1e-2, atol=1e-2, inverse_scale=False)
#     print("step: %d, nfe: %d" % (state['step'], nfe))
#     # print(samples_x.min(), samples_x.max(), samples_s.min(), samples_s.max())
#     samples_x[:sample_size] = samples_s

# # with torch.enable_grad():
# #   # Get model function
# #   net_fn = mutils.get_predict_fn(sde, model, train=train, continuous=continuous)

# #   # Predict scalar potential
# #   samples_x.requires_grad = True
# #   samples_net = samples_x
# #   psi = net_fn(samples_net)

# #   # Normalize potential by its mean
# #   # psi -= psi.mean(dim=0, keepdim=True)

# #   # Compute (backpropagate) N-dimensional Poisson field (gradient)
# #   drift = torch.autograd.grad(psi.squeeze(dim=-1), samples_x, torch.ones_like(psi.squeeze(dim=-1)), create_graph=True)[0]

# # # Compute drift norm
# # if sde.config.training.augment_z:
# #   x_drift, z_drift = torch.split(drift, [data_dim, 1], dim=-1)
# #   norm_x = x_drift.pow(2).sum(dim=-1)
# #   z_drift = z_drift + torch.reciprocal(t_samples)
# #   norm_z = z_drift.pow(2).sum(dim=-1)
# #   Norm = norm_x + norm_z
# # else:
# #   Norm = drift.pow(2).sum(dim=-1)

# with torch.no_grad():
#   # Compute Normalized Innovation Squared (Gamma)
#   if sde.config.training.augment_z: 
#     z_batch = z_t(t_range[1] * torch.ones(batch_size, 1).to(batch.device))
#     batch = torch.cat([batch, z_batch], dim=-1)

#   if ind_bool:
#     distance = batch.unsqueeze(dim=1) - samples_x
#   else:
#     distance = batch - samples_x

#   if sde.config.training.augment_z: 
#     distance_x, distance_z = torch.split(distance, [data_dim, 1], dim=-1)
#     innovation_x = distance_x.pow(2).sum(dim=-1)
#     innovation_z = distance_z.pow(2).sum(dim=-1)
#     innovation = innovation_x + innovation_z
#   else:
#     innovation = distance.pow(2).sum(dim=-1)
#   # innovation = innovation.sqrt()
  
#   if ind_bool:
#     Gamma = innovation.mean(dim=0)
#     Gamma -= Gamma.mean(dim=0, keepdim=True)
#   else:
#     Gamma = innovation - Const.mean(dim=0)
  
#   multiplier = 1
#   # multiplier *= t_samples.squeeze().sqrt()
#   # multiplier *= math.log(sde.config.training.t_end / 1e-5)

#   # divisor = 9
#   # divisor *= sigma_prior ** 2
#   divisor *= math.sqrt(data_dim)
#   # divisor *= state['sigma_max']**2

#   Gamma = Gamma * multiplier / (divisor + eps)

# # # Compute sample correlation between potential and NIS
# # Cov = torch.sum(Gamma * psi, dim=0) / (ensemble_size * num_particles - 1)
# # Cov = Cov.sum(dim=-1)
# # # vars = torch.sum(Gamma.pow(2), dim=-1) * torch.sum(psi.detach().pow(2), dim=-1)
# # # Corr = Cov / torch.sqrt(vars + eps)

# # Reg = psi.pow(2).mean(dim=-1)
# # Loss = 0.5 * (Cov + Norm) + 0.001 * Reg
# # Nll = state['sigma_max'] * torch.ones_like(Loss)

# # if train:
# #   optimizer = state['optimizer']
# #   optimizer.zero_grad()
# #   Loss.backward()
# #   optimize_fn(optimizer, model.parameters(), step=state['step'])

# sample = inverse_scaler(samples_x[:,:data_dim]).reshape(-1,config.data.num_channels, config.data.image_size, config.data.image_size).detach()
# nrow = int(np.sqrt(sample.shape[0]))
# image_grid = make_grid(sample, nrow, padding=2)
# sample = np.clip(sample.permute(0, 2, 3, 1).cpu().numpy() * 255, 0, 255).astype(np.uint8)
# save_path = './sample_batch.png'
# with tf.io.gfile.GFile(save_path, "wb") as fout:
#   save_image(image_grid, fout)
# img = plt.imread(save_path) 
# plt.imshow(img)
# plt.show()

# # print(const.mean(dim=-1))
# # print(innovation.mean(dim=-1))
# # print(Gamma.min(), Gamma.max())
# # plt.plot(Gamma.squeeze().cpu(), 0*np.ones_like(Gamma.squeeze().cpu()), '.k')
# # plt.show()

# %% sampling.py
import functools
import torch
import numpy as np
import abc
import time
import math

from models.utils import from_flattened_numpy, to_flattened_numpy, get_predict_fn
from scipy import integrate
import methods
from models import utils as mutils
from tqdm import tqdm
from PIL import Image
from torchvision.utils import make_grid, save_image

from IPython.display import clear_output
import matplotlib.pyplot as plt
from scipy.spatial import geometric_slerp
from mask import get_masked_fn

# torch.cuda.empty_cache()
# Sample with the EMA weights: stash the current weights and copy the EMA
# shadow weights into the model (ema.restore at the end of this cell puts the
# original weights back).
model = state['model']
ema.store(model.parameters())
ema.copy_to(model.parameters())

data_dim = sde.config.data.channels * sde.config.data.image_size * sde.config.data.image_size
# NOTE(review): if snapshot_sampling is off, sampling_shape is never defined
# (neither here nor in the setup cell) and the next statements raise NameError.
if config.training.snapshot_sampling:
  sampling_shape = (150, config.data.num_channels, config.data.image_size, config.data.image_size)
shape = sampling_shape
sample_size = shape[0]
device = config.device
# Optional preset initial state and masks; x = None means "draw it below".
x = None
mask = None
grad_mask = None
# Integrator: 'Euler' uses the fixed-step loop; any other name is passed to
# scipy.integrate.solve_ivp.
# method='RK23'
method='RK45'
# method='Euler'
# method='DOP853'
rtol = 1e-3
atol = 1e-3
# Starting time t of the ODE (mapped to z = z_t(eps_z)); also the initial
# value of the augmented z coordinate.
eps_z = 1e-10
# Scale of the change of variable z = -c * ln(t).
c = 1
# Diagnostics collected per drift evaluation: mean |drift| and its time t.
Drift = []
Time = []
step_size = 0
# torch.manual_seed(49)
# Initial-state modes, checked in order: inpainting, then interpolation,
# otherwise plain prior sampling.
inpainting = False
interpolation = False
# Number of Euler steps (solve_ivp methods overwrite nfe with solution.nfev).
nfe = 100
# Final time t the sampler integrates to.
t_end = 1.65
# Scale factor for the renormalized initial samples (plain prior sampling only).
c_norm = 1

if inpainting:
  # Condition on the first `sample_size` images of the batch fetched earlier.
  batch = batch[:sample_size]
  t_sample = 2.5e-5
else:
  batch = None
  t_sample = 0

# Output path for the final sample grid.
# save_path_sample = './cifar10_small.png'
# save_path_sample = './cifar10_big.png' 
# save_path_sample = './cifar10_interp_big.png'
# save_path_sample = './celeba_small.png'
save_path_sample = './celeba_big.png'
# save_path_sample = './celeba_interp_small.png'
# save_path_sample = './celeba_interp_big.png'

# Sampling: build the initial state `x`, then move it along the drift given by
# the gradient of the scalar potential `psi` predicted by `model`, either with
# fixed Euler steps in t or with scipy's solve_ivp in z = -c * ln(t).
with torch.no_grad():
  data_dim = sde.config.data.channels * sde.config.data.image_size * sde.config.data.image_size
  # batch_size = sde.config.training.batch_size
  # ensemble_size = sde.config.training.small_batch_size
  # num_particles = sde.config.training.num_particles

  # Change-of-variable: z = -c * ln(t), with inverse t = exp(-z / c)
  z_t = lambda t: -c * math.log(t); t_z = lambda z: math.exp(-z / c)

  # Initial sample
  if x is None:
    # Prior and decoder scales (commented lines are earlier sigma_prior choices)
    # sigma_prior = math.sqrt(sde.config.data.image_size * sde.config.data.image_size * sde.config.data.channels)
    # sigma_prior = math.sqrt(sde.config.data.image_size) * sde.config.data.channels
    # sigma_prior = math.sqrt(sde.config.data.image_size * sde.config.data.channels)
    mean_prior = sde.config.training.mean_prior
    sigma_prior = sde.config.training.sigma_prior
    sigma_dec = sde.config.training.sigma_min

    var_prior = sigma_prior**2
    var_dec = sigma_dec**2

    # Sample from prior
    # Gaussian weights of the starting distribution at time t_sample: data term
    # weighted by mean_enc_y, prior mean by mean_enc_prior, noise variance
    # var_enc. With t_sample = 0 these are 0, 1 and var_prior respectively.
    mean_enc_y = (t_sample * var_prior) / (var_dec + t_sample * var_prior)
    mean_enc_prior = var_dec / (var_dec + t_sample * var_prior)
    var_enc = var_prior * var_dec / (var_dec + t_sample * var_prior)

    if inpainting:
      # Centered box mask; masked_fill zeroes the pixels where the mask is True.
      config.eval.mask = True
      config.eval.mask_type = "box_center"
      config.eval.mask_box_size = 32
      mask_fn = get_masked_fn(config)
      mask_init = mask_fn(config).bool() # [H, W]
      mask = mask_init.detach().clone()[None, None].expand(*batch.shape[:2], -1, -1) # [B, C, H, W]
      # mask = ~mask
      # grad_mask = mask.reshape(len(mask), -1).to(device) 
      masked_x = batch.masked_fill(mask, 0) # [B, C, H, W]
      # x = masked_x.to(device)
      gaussian = torch.randn(sample_size, data_dim).to(device)
      x = math.sqrt(var_enc) * gaussian
      if batch is not None:
        x += mean_enc_y * masked_x.reshape(len(masked_x), -1).to(device)
        x += mean_enc_prior * mean_prior
      # x = torch.where(mask, gaussian.reshape(sampling_shape), x.reshape(sampling_shape))
      x = x.reshape(len(x), -1).to(device) 

    elif interpolation:
      # Spherical interpolation (n_interp points) between consecutive prior
      # draws, rescaled by their mean norm.
      # NOTE(review): the result is a 3-D float64 tensor with
      # (sample_size // n_interp) * n_interp rows (140 for the defaults 150/20),
      # while later steps (augment_z concat, reshape to new_shape /
      # (sample_size, -1)) assume sample_size rows of width data_dim — confirm.
      gaussian = torch.randn(sample_size, data_dim)
      x_init = math.sqrt(var_enc) * gaussian.to(device)
      n_interp = 20
      t_vals = np.linspace(0, 1, n_interp)
      unit_vec = x_init.detach().cpu().numpy().astype(np.double)
      unit_mag = np.sqrt(np.sum(unit_vec**2, axis=1, keepdims=True))
      unit_vec /= unit_mag
      x = []
      for i in range(round(sample_size // n_interp)):
        interp = geometric_slerp(unit_vec[i], unit_vec[i+1], t_vals)
        x.append(torch.from_numpy(interp))
      x = torch.stack(x).to(device) * unit_mag.mean()

    else:
      # Plain prior sampling; afterwards each sample keeps only its direction and
      # is rescaled to c_norm times the mean norm over the samples.
      gaussian = torch.randn(sample_size, data_dim)
      x =  math.sqrt(var_enc) * gaussian.to(device)
      if batch is not None:
        x += mean_enc_y * batch.reshape(len(batch), -1).to(device)
        # .mean(axis=0, keepdims=True)
        x += mean_enc_prior * mean_prior
      norm_x = torch.sqrt(torch.sum(x**2, axis=1, keepdims=True))
      x = (x / norm_x) * c_norm * norm_x.mean()

    # Keep the starting state for visualization.
    x_prior = x.clone()

    # Optionally append z = z_t(eps_z) as an extra coordinate.
    if sde.config.training.augment_z: 
      x = torch.cat([x, z_t(eps_z) * torch.ones(sample_size, 1).to(device)], dim=-1)

  # x = x.view(shape).float()
  new_shape = (sample_size, sde.config.data.channels, sde.config.data.image_size, sde.config.data.image_size)

  # t = np.log(sde.config.sampling.z_max)
  # x = to_flattened_numpy(x)
  state['t_eval'] = 0
  
  if method == 'Euler':
##================================================================================================================##
    # Fixed-step explicit Euler in t over nfe+1 grid points from 0 to t_end.
    for t in np.linspace(0, t_end, nfe+1) ** 1:
    # for t in (1 - np.linspace(0, 1, nfe+1) ** 1)[::-1]:

      # Distance from the previous grid point; 0 on the first pass, so that pass
      # evaluates the drift without moving x.
      step_size =  abs(t - state['t_eval'])
      state['t_eval'] = t

      with torch.enable_grad():

        # Get model function
        net_fn = get_predict_fn(sde, model, train=False)

        # Predict scalar potential (FC)
        x.requires_grad = True
        samples_net = x
        if sde.config.training.augment_t:
          samples_net = torch.cat([samples_net, t * torch.ones(sample_size, 1).to(device).type(torch.float32)], dim=-1)
        psi = net_fn(samples_net).squeeze(dim=-1)

        # Normalize field by its mean
        # psi -= psi.mean(dim=0, keepdim=True)

        # Compute (backpropagate) N-dimensional Poisson field (gradient)
        drift = torch.autograd.grad(psi, x, torch.ones_like(psi))[0]

        if grad_mask is not None:
          drift = drift * grad_mask

      x += drift * step_size
      # x += torch.randn_like(x) * math.sqrt(2*step_size)
    
      Drift.append(drift.abs().mean().cpu().numpy())
      Time.append(t)

      # Notebook display: short pause, clear the cell output, show progress.
      time.sleep(0.01)
      clear_output(wait=False)
      print(t, psi.mean())

##================================================================================================================##
  else:
    # Right-hand side for solve_ivp: `l` is the integration variable z and `x`
    # the flattened state; returns the flattened dx/dz.
    def ode_func(l, x):

      # Prepare potential network input
      # z = math.exp(l)
      z = l
      t = t_z(z)
      # t = l

      samples_x = from_flattened_numpy(x, (sample_size, -1)).to(device).type(torch.float32)

      # step_size = t - state['t_eval']
      # if step_size >= 0.01:
      #   print('add_noise')
      #   samples_x += torch.randn_like(samples_x) * math.sqrt(step_size)
      #   state['t_eval'] = t

      # The augmented coordinate is overwritten with the solver's current z.
      if sde.config.training.augment_z: 
        # z_pred = samples_x[:,-1][:,None]
        z_samples = z * torch.ones(sample_size, 1).to(device).type(torch.float32)
        samples_x = torch.cat([samples_x[:,:-1], z_samples], dim=-1)
      samples_x.requires_grad = True

      with torch.enable_grad():
        # Get model function
        net_fn = get_predict_fn(sde, model, train=False)

        # Predict scalar potential (FC)
        samples_net = samples_x
        if sde.config.training.augment_t:
          samples_net = torch.cat([samples_net, t * torch.ones(sample_size, 1).to(device).type(torch.float32)], dim=-1)
        psi = net_fn(samples_net).squeeze(dim=-1)

        # Normalize field by its mean
        # psi -= psi.mean(dim=0, keepdim=True)

        # Compute (backpropagate) N-dimensional Poisson field (gradient)
        drift = torch.autograd.grad(psi, samples_x, torch.ones_like(psi))[0]

        # Predicted normalized Poisson field
        # dt/dz for t = exp(-z). NOTE(review): t_z uses exp(-z / c), whose
        # derivative is -(1/c) * exp(-z / c); this matches only when c == 1.
        dt_dz = - math.exp(-z)
        if sde.config.training.augment_z:
          # Chain rule dx/dz = dx/dt * dt/dz; the z coordinate advances with
          # dz/dz = 1. dt_dz_pred (from the model's z-gradient) only feeds the
          # `diff` diagnostic printed below.
          dx_dt, dz_dt = torch.split(drift, [data_dim, 1], dim=-1)
          dt_dz_pred = 1 / (dz_dt + 1e-5)
          drift = torch.cat([dx_dt * dt_dz, torch.ones_like(dz_dt)], dim=-1) 
          # dz_dl = z
          # drift *= dz_dl
          diff = (dt_dz - dt_dz_pred).abs().mean()
        else:
          drift = drift * dt_dz

        if grad_mask is not None:
          drift = drift * grad_mask

      # Notebook display. NOTE: the 0.1 s pause runs on every RHS evaluation,
      # so it adds nfev * 0.1 s to the total sampling time.
      time.sleep(0.1)
      clear_output(wait=False)
      print(t, state['t_eval'], step_size)
      if sde.config.training.augment_z:
        print(z, diff.item())
      print(psi.mean().item(), drift.abs().mean().item())

      Drift.append(drift.abs().mean().cpu().numpy())
      Time.append(t)

      return to_flattened_numpy(drift)

    # Black-box ODE solver for the probability flow ODE.
    # Note that we integrate in z = -c * ln(t) (see z_t / t_z above) to accelerate
    # the ODE simulation: z runs from z_t(eps_z) down to z_t(t_end), i.e. t
    # increases from eps_z to t_end.
    # boundary = [0, t_end]
    boundary = [z_t(eps_z), z_t(t_end)]
    # boundary = [np.log(sde.config.training.z_max), np.log(eps)]
    solution = integrate.solve_ivp(ode_func, boundary, to_flattened_numpy(x), rtol=rtol, atol=atol, method=method)

    # Number of RHS (drift) evaluations the solver actually used.
    nfe = solution.nfev
    # Final state = last column of the solution, back to a float32 tensor.
    x = torch.tensor(solution.y[:,-1]).reshape(sample_size, -1).to(device).type(torch.float32)
    y = x
    # Detach the augmented z dimension.
    if sde.config.training.augment_z: x, _ = torch.split(x, [data_dim, 1], dim=-1)
##================================================================================================================##
  
  # Reshape to images and map back to the data range with inverse_scaler.
  x = inverse_scaler(x.reshape(new_shape))
  # Drift evaluations used (for Euler this is the configured step count nfe).
  n = nfe
  print(n)

def _save_and_show_grid(images, path):
  """Tile `images` into a grid, save it as an image at `path`, and display it.

  Args:
    images: image tensor laid out as [B, C, H, W], already in the data range
      produced by `inverse_scaler`.
    path: output file path (opened through tf.io.gfile).

  Returns:
    The images as a uint8 NumPy array in [B, H, W, C], scaled to 0..255 and
    clipped — the value the previous inline code left in `sample`.
  """
  # Roughly square grid: floor(sqrt(B)) images per row.
  nrow = int(np.sqrt(images.shape[0]))
  image_grid = make_grid(images, nrow, padding=2)
  with tf.io.gfile.GFile(path, "wb") as fout:
    save_image(image_grid, fout)
  plt.imshow(plt.imread(path))
  plt.show()
  return np.clip(images.permute(0, 2, 3, 1).cpu().numpy() * 255, 0, 255).astype(np.uint8)


# Starting state of the sampler.
sample = _save_and_show_grid(inverse_scaler(x_prior.view(shape)), './sample_prior.png')

# Data images the start was conditioned on (only set when inpainting).
if batch is not None:
  sample = _save_and_show_grid(inverse_scaler(batch.view(shape)), './sample_data.png')

# Final samples. For inpainting, keep generated pixels only where the mask is
# True and the original data everywhere else.
sample = x.clone()
if inpainting:
  sample = torch.where(mask, sample, inverse_scaler(batch))
sample = _save_and_show_grid(sample, save_path_sample)

# Put the non-EMA weights back into the model.
ema.restore(model.parameters())

# Mean |drift| versus time for every drift evaluation.
plt.plot(Time, Drift, '.')
plt.show()

# %%
#%% inpainting
from mask import get_masked_fn
from tqdm import tqdm

from IPython.display import clear_output
import matplotlib.pyplot as plt

# Inpainting evaluation setup: override the eval settings for a centered box
# mask and build a sampler for batches of config.eval.batch_size images.
eval_dir = save_dir
# Checkpoint label; only used in log messages and the sample directory name.
ckpt = 45

# NOTE(review): these assignments mutate the shared config object, so anything
# built from `config` afterwards (sampler, data loaders below) sees them.
config.eval.num_samples = 2500
config.eval.batch_size = 128
config.eval.save_samples = False
config.eval.mask = True
config.eval.mask_type = "box_center"
config.eval.mask_box_size = 32

# NOTE(review): adds an extra round even when num_samples divides evenly; this
# value (and this_sample_dir below) is not used in the code shown here.
num_sampling_rounds = config.eval.num_samples // config.eval.batch_size + 1
# Directory to save samples for this checkpoint (a single path per checkpoint).
this_sample_dir = os.path.join(eval_dir, f"ckpt_{ckpt}")
tf.io.gfile.makedirs(this_sample_dir)

sampling_shape = (config.eval.batch_size,
                  config.data.num_channels,
                  config.data.image_size, config.data.image_size)
sampling_fn = sampling.get_sampling_fn(config, sde, sampling_shape, inverse_scaler, eps=1e-4, rtol=1e-5, atol=1e-5)

# Build data iterators (same selection as in the setup cell, rebuilt with the
# updated config)
if config.data.dataset == 'CELEBA':
  # I cannot load CelebA from tfds loader. So I write a pytorch loader instead.
  train_ds, eval_ds = datasets_utils.celeba.get_celeba(config)
else:
  train_ds, eval_ds, _ = datasets.get_dataset(config, uniform_dequantization=config.data.uniform_dequantization)

class ContinueIteration(Exception):
  """Signal from `load_batch` that the fetched batch should be skipped.

  Raised when a CelebA batch does not contain exactly config.eval.batch_size
  images; the evaluation loop catches it and moves on to the next batch.
  """


def load_batch(ds, ds_iter):
  """Fetch the next evaluation batch and normalize it with `scaler`.

  Args:
    ds: the dataset `ds_iter` iterates over; used to restart the CelebA
      iterator once it is exhausted.
    ds_iter: current iterator over `ds`.

  Returns:
    (batch, ds_iter): the scaled batch on config.device and the iterator to use
    on the next call (a fresh one if the CelebA iterator was exhausted).

  Raises:
    ContinueIteration: if a CelebA batch (from the current iterator) has a size
      other than config.eval.batch_size, e.g. a trailing partial batch.
  """
  if config.data.dataset == 'CELEBA':
    try:
      batch = next(ds_iter)[0]
    except StopIteration:
      # Exhausted: restart from the beginning of the dataset.
      ds_iter = iter(ds)
      batch = next(ds_iter)[0]
    else:
      # The sampler was built for a fixed batch size, so skip odd-sized batches.
      if len(batch) != config.eval.batch_size:
        raise ContinueIteration
    # Use config.device (not .cuda()) so data and model share a device.
    batch = batch.to(config.device)
  else:
    # tf.data path: ._numpy() avoids a copy; permute channels-last to [B, C, H, W].
    batch = torch.from_numpy(next(ds_iter)['image']._numpy()).to(config.device).float()
    batch = batch.permute(0, 3, 1, 2)
  batch = scaler(batch)
  return batch, ds_iter

def single_sampling_iter(iter_id, x=None, mask=None, grad_mask=None):
  """Run one sampling round with `net` and return the samples in two formats.

  Args:
    iter_id: round index; only used for logging.
    x, mask, grad_mask: forwarded unchanged to `sampling_fn` (e.g. the masked
      input and its mask for inpainting).

  Returns:
    (samples, samples_torch): `samples` is a uint8 NumPy array of shape
    [B, H, W, C] scaled to 0..255 and clipped; `samples_torch` is a detached
    copy of the sampler output viewed as [B, C, H, W].
  """
  print("sampling iter")
  # Lazy %-args: the message is only formatted if INFO is enabled.
  logging.info("sampling -- ckpt: %d, round: %d", ckpt, iter_id)

  # NOTE(review): the EMA swap is disabled here, so this samples with whatever
  # weights `net` currently holds.
  # ema.store(net.parameters())
  # ema.copy_to(net.parameters())
  samples, n = sampling_fn(net, state, x=x, mask=mask, grad_mask=grad_mask)
  print("nfe:", n)
  # ema.restore(net.parameters())

  # Independent copy in image layout; works whether the sampler returns images
  # or flattened vectors.
  samples_torch = samples.detach().clone().view(
    -1, config.data.num_channels, config.data.image_size, config.data.image_size)

  # Derive the NumPy images from the [B, C, H, W] view so the permute below
  # always sees a 4-D tensor.
  samples = np.clip(samples_torch.permute(0, 2, 3, 1).cpu().numpy() * 255., 0, 255).astype(np.uint8)
  samples = samples.reshape(
    (-1, config.data.image_size, config.data.image_size, config.data.num_channels))

  return samples, samples_torch

# NOTE: assumes CelebA (if not, then need to change dataset pre-processing)
# (also relies on len(eval_ds) below, which the PyTorch CelebA loader supports)
print("USING MASK")
mask_fn = get_masked_fn(config)

# NOTE: assume mask is fixed for all images, if not move this to within the loop below
# 
# with torch.no_grad():
mask_init = mask_fn(config).bool() # [H, W]
eval_n_iters = len(eval_ds)
eval_ds_iter = iter(eval_ds)
for i in tqdm(range(eval_n_iters)):
  try:
    batch, eval_ds_iter = load_batch(eval_ds, eval_ds_iter) # batch: [B, C, H, W]
  except ContinueIteration:
    # load_batch raises this for a CelebA batch whose size differs from
    # config.eval.batch_size; skip that batch.
    print("continue invoked")
    continue
  # Broadcast the [H, W] mask over batch and channel dims (expand -> a view).
  mask = mask_init.detach().clone()[None, None].expand(*batch.shape[:2], -1, -1) # [B, C, H, W]
  print(f"x.shape: {batch.shape}, mask.shape: {mask.shape}")
  # Zero the masked (mask == True) region of the input.
  masked_x = batch.masked_fill(mask, 0.0) # [B, C, H, W]
  # NOTE(review): this `break` stops after the first usable batch, so the
  # sampling call below is unreachable; remove it to actually run inpainting.
  break

  _, samples = single_sampling_iter(i, x=masked_x, mask=mask, grad_mask=None)
  # break

# Visualize the masked input of the last processed batch.
# NOTE(review): raises NameError if no batch was processed (e.g. all skipped).
samples = inverse_scaler(masked_x)
nrow = int(np.sqrt(samples.shape[0]))
image_grid = make_grid(samples, nrow, padding=2)
save_path = './sample_x2.png'
with tf.io.gfile.GFile(save_path, "wb") as fout:
  save_image(image_grid, fout)
img = plt.imread(save_path) 
plt.imshow(img)
plt.show()