import os
import re
import copy
import time
import PIL.Image
import torch
import torch.nn.functional as F
from torch_utils import distributed as dist
from torchvision.utils import make_grid, save_image
import numpy as np
from utils import distance_point_to_line, ablation_sampler, opt_onetap_sampler, opt_sampler, opt_sampler_new, edm_sampler


#----------------------------------------------------------------------------
# Collect cosine statistics (.npz) used for drawing the figure
def monitor_cos_collect(
    outdir              = None,
    network_name        = None,
    batch_size          = 1000,
    net                 = None,
    device              = None,
    class_idx           = None,
    t_steps             = None,
    sigma               = None,
    sigma_deriv         = None,
    sigma_inv           = None,
    s                   = None,
    s_deriv             = None,
    solver              = None,
):
    """Collect per-step cosine statistics along sampling trajectories into one ``.npz`` file.

    Runs ``50000 // batch_size`` rounds. Each round draws a batch of Gaussian
    latents, runs ``ablation_sampler`` in ``'trajectory'`` mode and, for every
    sample and every step, computes the cosine between (final trajectory point
    minus current trajectory point) and the third tensor returned by the
    sampler at that step (named ``scores`` here). The result is written to
    ``<outdir>/CIFAR-10_cos/CIFAR-10_0.npz`` under the key ``cos``, one column
    per sample.

    Args:
        outdir: Root output directory; the ``CIFAR-10_cos`` sub-directory is created.
        network_name: Unused in this function.
        batch_size: Number of latents sampled per round.
        net: Network exposing ``img_resolution``, ``img_channels`` and ``label_dim``.
        device: Device on which latents and labels are created.
        class_idx: If given, every sample is conditioned on this class.
            Requires a class-conditional ``net`` (``class_labels`` is ``None`` otherwise).
        t_steps: 1-D tensor of time steps; its length is used as ``num_steps``.
        sigma, sigma_deriv, sigma_inv, s, s_deriv, solver: Forwarded unchanged
            to ``ablation_sampler``.
    """
    num_steps = t_steps.shape[0]
    img_res = net.img_resolution
    img_channels = net.img_channels

    # Labels are drawn once and shared by every round.
    class_labels = None
    if net.label_dim:
        class_labels = torch.eye(net.label_dim, device=device)[torch.randint(net.label_dim, size=[batch_size], device=device)]
    if class_idx is not None:
        class_labels[:, :] = 0
        class_labels[:, class_idx] = 1

    data_dir = os.path.join(outdir, 'CIFAR-10_cos')
    os.makedirs(data_dir, exist_ok=True)

    n_round = 50000 // batch_size
    # Collect the per-sample columns in a list and concatenate once at the end;
    # growing a tensor with torch.cat for every sample copies it each time.
    cos_columns = []
    for k in range(n_round):
        latents = torch.randn([batch_size, net.img_channels, img_res, img_res], device=device)
        traj, _, scores = ablation_sampler(
            net, latents, class_labels, randn_like=torch.randn_like, num_steps=num_steps,
            t_steps=t_steps, sigma=sigma, sigma_deriv=sigma_deriv, sigma_inv=sigma_inv, s=s,
            s_deriv=s_deriv, solver=solver, mode='trajectory',
        )
        # Original author's shape notes: traj -> (bs, num_steps+1, ch, r, r), scores -> (bs, num_steps, ch, r, r).
        traj = traj.reshape(-1, batch_size, img_channels, img_res, img_res).transpose(0, 1)
        scores = scores.reshape(-1, batch_size, img_channels, img_res, img_res).transpose(0, 1)
        for i in range(batch_size):
            scores_cosTest = scores[i].reshape(scores[i].shape[0], -1)
            # Direction from every intermediate point to the final point of the trajectory.
            traj_cosTest = (traj[i, -1] - traj[i, 0:-1]).reshape(traj[i].shape[0] - 1, -1)
            # Row-wise dot product (via bmm) divided by the product of the row norms.
            cos = (torch.bmm(traj_cosTest.unsqueeze(1), scores_cosTest.unsqueeze(2)).squeeze()) \
                / (torch.norm(traj_cosTest, p=2, dim=1) * torch.norm(scores_cosTest, p=2, dim=1))
            cos_columns.append(cos.unsqueeze(1))
        print(k+1, '|', n_round)

    cos_record = torch.cat(cos_columns, dim=1)
    np.savez(os.path.join(data_dir, 'CIFAR-10_0.npz'), cos=cos_record.cpu().numpy())
    


#----------------------------------------------------------------------------
# generate all trajs mentioned in the main text
def monitor_all_trajs(
    outdir              = None,
    network_name        = None,
    net                 = None,
    images_all          = None,
    images_sel          = None,
    test_images         = None,
    test_labels         = None,
    device              = None,
    t_steps             = None,
    sigma               = None,
    sigma_deriv         = None,
    sigma_inv           = None,
    s                   = None,
    s_deriv             = None,
    solver              = None,
):
    """Render all trajectory variants for one test image into a single image grid.

    The test image is noised at every time step, the noisiest version is used
    as the starting latent, and seven image sequences are stacked (each flipped
    along the step axis) and saved as
    ``<outdir>/monitor_denoiser_diff_imgs_<network_name>.png``:

    1. the noised test image,
    2. the network sampling trajectory (``traj[0:-1]``),
    3. the second output of ``ablation_sampler`` in ``'traj_both'`` mode (``x_hat``),
    4. ``opt_onetap_sampler`` applied to the network trajectory,
    5. the trajectory returned by ``opt_sampler``,
    6. the network evaluated on trajectory 5,
    7. ``opt_onetap_sampler`` applied to trajectory 5.

    Three L2 distances between these sequences are printed.

    Args:
        outdir: Directory the grid image is written to (created if missing).
        network_name: Used in the file name; a trailing ``'_uncond'`` segment
            disables ``test_labels``.
        net: Network exposing ``img_resolution`` and ``img_channels``; also called directly.
        images_all: Data set tensor forwarded to the optimal-denoiser helpers.
        images_sel: Unused in this function.
        test_images: Image tensor that is noised and used as the start point.
        test_labels: Class labels for ``test_images`` (ignored for ``uncond`` networks).
        device: Device on which the noise is created.
        t_steps: 1-D tensor of time steps; its length is used as ``num_steps``.
        sigma, sigma_deriv, sigma_inv, s, s_deriv, solver: Schedule callables /
            solver name forwarded to the samplers.
    """
    if network_name.split('_')[-1] == 'uncond':
        test_labels = None

    num_steps = t_steps.shape[0]
    img_res = net.img_resolution
    img_channels = net.img_channels

    # Time step discretization.
    T = copy.deepcopy(t_steps)
    T = T.reshape(-1, 1, 1, 1)

    scale_temp = s(t_steps)
    try:
        scale_temp = scale_temp.reshape(num_steps, 1, 1, 1)
    except (AttributeError, RuntimeError, ValueError):
        # s() may return a plain scalar (no .reshape) or a value that cannot be
        # reshaped per step; it is then used as-is and broadcasts.
        pass

    # add noise to test image (one shared noise sample for all time steps)
    xs = test_images.repeat(T.shape[0], 1, 1, 1)
    x_noise = scale_temp * xs + (sigma(T) * scale_temp) * torch.randn([1, img_channels, img_res, img_res], device=device)

    # use the noisy image at the first time step as the new start point
    latents = (x_noise[0].reshape(1, img_channels, img_res, img_res)) / (sigma(T[0]) * s(T[0]))
    traj, x_hat = ablation_sampler(
        net, latents, test_labels, randn_like=torch.randn_like, num_steps=num_steps,
        t_steps=t_steps, sigma=sigma, sigma_deriv=sigma_deriv, sigma_inv=sigma_inv, s=s,
        s_deriv=s_deriv, solver=solver, mode='traj_both',
    )

    # denoise images using optimal denoiser at different noise levels
    x_opt = opt_onetap_sampler(t_steps, images_all, traj, img_channels, img_res, sigma=sigma, s=s)
    x_opt_Euler = opt_sampler(
        latents, t_steps, images_all, img_channels, img_res, sigma=sigma, sigma_deriv=sigma_deriv,
        sigma_inv=sigma_inv, s=s, s_deriv=s_deriv, solver=solver, mode='trajectory',
    )
    x_opt_Euler_opt = opt_onetap_sampler(t_steps, images_all, x_opt_Euler, img_channels, img_res, sigma=sigma, s=s)

    x_opt_Euler_hat = net(x_opt_Euler[0:-1] / scale_temp, sigma(t_steps), test_labels)

    print('Dist btween 2 and 5\n', torch.norm(traj - x_opt_Euler, p=2, dim=(1, 2, 3)))
    # NOTE(review): in the grid order below x_opt is the 4th sequence, so the
    # "3" in this label may be a leftover -- confirm before relying on it.
    print('Dist btween 3 and 7\n', torch.norm(x_opt - x_opt_Euler_opt, p=2, dim=(1, 2, 3)))
    print('Dist btween 6 and 7\n', torch.norm(x_opt_Euler_hat - x_opt_Euler_opt, p=2, dim=(1, 2, 3)))

    # Every sequence is flipped along the step axis before stacking.
    sequences = [
        x_noise,
        traj[0:-1],
        x_hat,
        x_opt,
        x_opt_Euler[0:-1],
        x_opt_Euler_hat,
        x_opt_Euler_opt,
    ]
    out = torch.cat([torch.flip(seq, dims=[0]) for seq in sequences], dim=0)

    # draw image grid
    images = torch.clamp(out / 2 + 0.5, 0, 1)
    os.makedirs(outdir, exist_ok=True)
    nrow = T.shape[0]
    image_grid = make_grid(images, nrow, padding=0)
    # save_image() forwards extra keyword arguments to make_grid(), not to the
    # image writer, so the former dpi=500 never affected the file and is
    # rejected by make_grid() versions without **kwargs.
    save_image(image_grid, os.path.join(outdir, "monitor_denoiser_diff_imgs_{}.png".format(network_name)))


#----------------------------------------------------------------------------
# collect norm npz stats for different trajs
def monitor_denoiser_std_collect(
    outdir              = None,
    network_name        = None,
    batch_size          = 20,
    net                 = None,
    images_all          = None,
    images_all_cond     = None,
    device              = None,
    class_idx           = None,
    t_steps             = None,
    sigma               = None,
    sigma_deriv         = None,
    sigma_inv           = None,
    s                   = None,
    s_deriv             = None,
    solver              = None,
):
    """Collect L2 distances between network and optimal-denoiser trajectories.

    Does nothing if ``<outdir>/CIFAR-10_denoiser_std`` already contains files.
    Otherwise runs ``50000 // batch_size`` rounds and writes one
    ``CIFAR-10_<k>.npz`` per round with three arrays:

    * ``diff_25``: network trajectory vs. ``opt_sampler_new`` trajectory,
    * ``diff_34``: network denoiser output vs. ``opt_onetap_sampler`` on the
      network trajectory,
    * ``diff_67``: network evaluated on the ``opt_sampler_new`` trajectory vs.
      ``opt_onetap_sampler`` on that trajectory.

    Args:
        outdir: Root output directory.
        network_name: Unused in this function.
        batch_size: Number of latents per round.
        net: Network exposing ``img_resolution``, ``img_channels``, ``label_dim``; also called directly.
        images_all: Data set tensor forwarded to the optimal-denoiser helpers.
        images_all_cond: Unused in this function.
        device: Device on which latents and labels are created.
        class_idx: If given, every sample is conditioned on this class
            (requires a class-conditional ``net``).
        t_steps: 1-D tensor of time steps; its length is used as ``num_steps``.
        sigma, sigma_deriv, sigma_inv, s, s_deriv, solver: Forwarded to the samplers.
    """
    std_dir = os.path.join(outdir, 'CIFAR-10_denoiser_std')
    os.makedirs(std_dir, exist_ok=True)

    # Existing statistics are never overwritten.
    if len(os.listdir(std_dir)) != 0:
        return

    images_all_bs = images_all.unsqueeze(1).repeat(1, batch_size, 1, 1, 1)

    num_steps = t_steps.shape[0]
    img_res = net.img_resolution
    img_channels = net.img_channels

    # Time steps tiled per sample: [t_0..t_N, t_0..t_N, ...].
    t_temp = t_steps.repeat(batch_size)
    scale_temp = s(t_temp)
    try:
        scale_temp = scale_temp.reshape(t_temp.shape[0], 1, 1, 1)
    except (AttributeError, RuntimeError, ValueError):
        # s() may return a plain scalar (no .reshape) or a value that cannot be
        # reshaped per step; it is then used as-is and broadcasts.
        pass

    n_round = 50000 // batch_size
    for k in range(n_round):
        class_labels = None
        if net.label_dim:
            class_labels = torch.eye(net.label_dim, device=device)[torch.randint(net.label_dim, size=[batch_size], device=device)]
        if class_idx is not None:
            class_labels[:, :] = 0
            class_labels[:, class_idx] = 1

        # `if class_labels:` would evaluate a multi-element tensor as a bool and
        # raise; test for presence explicitly. Each label row is repeated
        # num_steps times consecutively to match the sample-major layout of t_temp.
        if class_labels is not None:
            class_temp = class_labels.repeat_interleave(num_steps, dim=0)
        else:
            class_temp = None

        # generate (Euler/denoised/optimal) trajectory
        # (shape notes below are the original author's annotations)
        latents = torch.randn([batch_size, img_channels, img_res, img_res], device=device)
        traj, traj_denoised = ablation_sampler(
            net, latents, class_labels, randn_like=torch.randn_like, num_steps=num_steps,
            t_steps=t_steps, sigma=sigma, sigma_deriv=sigma_deriv, sigma_inv=sigma_inv, s=s,
            s_deriv=s_deriv, solver=solver, mode='traj_both',
        )                                                                                                           # (bs*(num_steps+1), ch, r, r)
        traj = traj.reshape(-1, batch_size, img_channels, img_res, img_res).transpose(0, 1)                         # (bs, num_steps+1, ch, r, r)
        traj_denoised = traj_denoised.reshape(-1, batch_size, img_channels, img_res, img_res).transpose(0, 1)       # (bs, num_steps, ch, r, r)

        traj_temp = traj[:, 0:-1].reshape(-1, img_channels, img_res, img_res)                                       # (bs*num_steps, ch, r, r)

        traj_opt = opt_onetap_sampler(t_temp, images_all, traj_temp, img_channels, img_res, sigma=sigma, s=s)       # (bs*num_steps, ch, r, r)
        traj_opt = traj_opt.reshape(batch_size, -1, img_channels, img_res, img_res)                                 # (bs, num_steps, ch, r, r)

        traj_opt_Euler = opt_sampler_new(
            latents, t_steps, images_all_bs, img_channels, img_res, randn_like=torch.randn_like,
            sigma=sigma, sigma_deriv=sigma_deriv, sigma_inv=sigma_inv, s=s,
            s_deriv=s_deriv, solver=solver, mode='trajectory',
        )                                                                                                           # (bs*(num_steps+1), ch, r, r)
        traj_opt_Euler = traj_opt_Euler.reshape(-1, batch_size, img_channels, img_res, img_res).transpose(0, 1)     # (bs, num_steps+1, ch, r, r)

        traj_temp = traj_opt_Euler[:, 0:-1].reshape(-1, img_channels, img_res, img_res)                             # (bs*num_steps, ch, r, r)
        traj_opt_denoised = net(traj_temp / scale_temp, sigma(t_temp), class_temp).reshape(batch_size, -1, img_channels, img_res, img_res)

        traj_opt_Euler_opt = opt_onetap_sampler(t_temp, images_all, traj_temp, img_channels, img_res, sigma=sigma, s=s)  # (bs*num_steps, ch, r, r)
        traj_opt_Euler_opt = traj_opt_Euler_opt.reshape(batch_size, -1, img_channels, img_res, img_res)             # (bs, num_steps, ch, r, r)

        # per-sample, per-step L2 distances between the paired sequences
        diff_25 = torch.norm(traj - traj_opt_Euler, p=2, dim=(2, 3, 4)).detach().cpu().numpy()
        diff_34 = torch.norm(traj_denoised - traj_opt, p=2, dim=(2, 3, 4)).detach().cpu().numpy()
        diff_67 = torch.norm(traj_opt_denoised - traj_opt_Euler_opt, p=2, dim=(2, 3, 4)).detach().cpu().numpy()

        out_dir = os.path.join(std_dir, 'CIFAR-10_{}.npz'.format(k))
        np.savez(out_dir, diff_25=diff_25, diff_34=diff_34, diff_67=diff_67)

        print(k + 1, '|', n_round)


#----------------------------------------------------------------------------
def monitor_interpolation(
    outdir              = None,
    network_name        = None,
    batch_size          = 1,
    inter_size          = 9,
    net                 = None,
    images_all          = None,
    device              = None,
    class_idx           = None,
    sigma_max           = 80,
    sigma_min           = 0.002,
    t_steps             = None,
    sigma               = None,
    sigma_deriv         = None,
    sigma_inv           = None,
    s                   = None,
    s_deriv             = None,
    solver              = None,
):
    """Compare linear, normalised-linear and spherical latent interpolation.

    Two latents are drawn with fixed seeds (128 and 2059), ``inter_size``
    intermediate latents are built with each of the three strategies, and all
    of them are sampled with ``ablation_sampler``. Two kinds of images are
    written to ``outdir``:

    * ``monitor_interpolation_<network_name>.png``: final samples, one row per strategy;
    * ``monitor_interpolation_knn_<network_name>_<mode>.png``: for each final
      sample, the sample followed by its 10 nearest images (L2) in ``images_all``,
      one sample per row.

    Args:
        outdir: Output directory (created if missing).
        network_name: Used in the output file names.
        batch_size: Number of class-label rows drawn for conditional networks.
        inter_size: Number of intermediate points between the two end latents.
        net: Network exposing ``img_resolution``, ``img_channels`` and ``label_dim``.
        images_all: Data set tensor searched for nearest neighbours.
        device: Device on which latents and labels are created.
        class_idx: If given, conditions on this class (requires a class-conditional ``net``).
        sigma_max, sigma_min: Unused in this function.
        t_steps: 1-D tensor of time steps; its length is used as ``num_steps``.
        sigma, sigma_deriv, sigma_inv, s, s_deriv, solver: Forwarded to ``ablation_sampler``.
    """
    # Labels are drawn before the seeds are fixed, exactly as before.
    class_labels = None
    if net.label_dim:
        class_labels = torch.eye(net.label_dim, device=device)[torch.randint(net.label_dim, size=[batch_size], device=device)]
    if class_idx is not None:
        class_labels[:, :] = 0
        class_labels[:, class_idx] = 1

    num_steps = t_steps.shape[0]
    img_res = net.img_resolution
    img_channels = net.img_channels

    torch.manual_seed(128)      # selected left white car
    latents_start = torch.randn([1, img_channels, img_res, img_res], device=device)
    torch.manual_seed(2059)     # selected right white car
    latents_end = torch.randn([1, img_channels, img_res, img_res], device=device)

    # get intermediate points for three interpolation strategies
    alphas = torch.linspace(0, 1, inter_size + 2).to(device)
    rhos = alphas / (alphas**2 + (1 - alphas)**2)**(0.5)
    # angle between the two latents, used by the spherical interpolation
    phi = torch.acos(torch.sum(latents_start * latents_end, dim=(1, 2, 3)) / (torch.norm(latents_start, p=2, dim=(1, 2, 3)) * torch.norm(latents_end, p=2, dim=(1, 2, 3))))
    rho1 = torch.sin((1 - alphas.unsqueeze(1)) * phi.unsqueeze(0)) / torch.sin(phi.unsqueeze(0))
    rho2 = torch.sin(alphas.unsqueeze(1) * phi.unsqueeze(0)) / torch.sin(phi.unsqueeze(0))

    alphas = alphas.unsqueeze(1).unsqueeze(2).unsqueeze(3)
    rhos = rhos.unsqueeze(1).unsqueeze(2).unsqueeze(3)
    rho1 = rho1.unsqueeze(2).unsqueeze(3)
    rho2 = rho2.unsqueeze(2).unsqueeze(3)

    linear_inters = (1 - alphas) * latents_start + alphas * latents_end
    n_linear_inters = (1 - rhos**2)**(0.5) * latents_start + rhos * latents_end
    slerp_inters = rho1 * latents_start + rho2 * latents_end

    nrow = inter_size + 2

    def sample_end_points(start_latents):
        """Sample from the given latents and return the last ``nrow`` trajectory images."""
        trajs, _, _ = ablation_sampler(
            net, start_latents, class_labels, randn_like=torch.randn_like, num_steps=num_steps,
            t_steps=t_steps, sigma=sigma, sigma_deriv=sigma_deriv, sigma_inv=sigma_inv, s=s,
            s_deriv=s_deriv, solver=solver, mode='trajectory',
        )
        # Use the net's own channel count / resolution instead of a fixed 3x32x32.
        return trajs.reshape(-1, nrow, img_channels, img_res, img_res)[-1]

    # generate interpolation trajectories (same order as before: linear, n_linear, slerp)
    modes = ['linear', 'n_linear', 'slerp']
    interp_collect = [sample_end_points(inters) for inters in (linear_inters, n_linear_inters, slerp_inters)]

    os.makedirs(outdir, exist_ok=True)

    # concatenate and save results
    temp = torch.cat(interp_collect, dim=0)
    images = torch.clamp(temp / 2 + 0.5, 0, 1)
    image_grid = make_grid(images, nrow, padding=0)
    save_image(image_grid, os.path.join(outdir, "monitor_interpolation_{}.png".format(network_name)))

    # calculate knn (k=10) for all end points
    num_neighbors = 10
    for mode, imgs in zip(modes, interp_collect):
        pieces = []
        row_width = 1
        for j in range(nrow):
            query = imgs[j].unsqueeze(0)
            distance = torch.norm(images_all - query, p=2, dim=(1, 2, 3))
            sorted_dist, sorted_indices = torch.sort(distance, descending=False, dim=-1)
            nearest = images_all[sorted_indices[0:num_neighbors]]
            pieces.append(query)
            pieces.append(nearest)
            row_width = 1 + nearest.shape[0]
        neighbors = torch.cat(pieces, dim=0)
        images = torch.clamp(neighbors / 2 + 0.5, 0, 1)
        # One row per query: the query followed by its neighbours. The row
        # width is therefore 1 + k, which equals inter_size + 2 only for the
        # default inter_size of 9.
        image_grid = make_grid(images, row_width, padding=0)
        save_image(image_grid, os.path.join(outdir, "monitor_interpolation_knn_{}_{}.png".format(network_name, mode)))


#----------------------------------------------------------------------------
# generate all interpolated sampling trajs for FID evaluation
def monitor_interpolation_generate(
    outdir              = None,
    network_name        = None,
    batch_size          = 125,
    inter_size          = 9,
    net                 = None,
    images_all          = None,
    device              = None,
    class_idx           = None,
    sigma_max           = 80,
    sigma_min           = 0.002,
    t_steps             = None,
    sigma               = None,
    sigma_deriv         = None,
    sigma_inv           = None,
    s                   = None,
    s_deriv             = None,
    solver              = None,
    strategy            = 'slerp',
):
    """Generate samples from interpolated latents and save them as PNG files.

    The 50000 image ids ``0..49999`` are split into ``50000 // batch_size``
    batches that are distributed round-robin over the ranks. For every batch
    two sets of Gaussian latents are drawn, ``inter_size + 2`` interpolated
    latents (end points included) are built with ``strategy``, and all of them
    are sampled with ``edm_sampler``. Images are written to
    ``<outdir>/traj_samples/interp_<strategy>_<k>/<id rounded down to 1000>/<id>.png``
    where ``k`` is the interpolation index.

    Args:
        outdir: Root output directory.
        network_name: Unused in this function.
        batch_size: Number of latent pairs per batch.
        inter_size: Number of intermediate points between the two end latents.
        net: Network exposing ``img_resolution``, ``img_channels`` and ``label_dim``.
        images_all: Unused in this function.
        device: Device on which latents and labels are created.
        class_idx: If given, conditions on this class (requires a class-conditional ``net``).
        sigma_max, sigma_min: Unused in this function.
        t_steps: 1-D tensor of time steps; only its length is used (as ``num_steps``).
        sigma, sigma_deriv, sigma_inv, s, s_deriv, solver: Unused in this function.
        strategy: One of ``'linear'``, ``'n_linear'``, ``'slerp'``.
    """
    assert strategy in ['linear', 'n_linear', 'slerp']

    # Image ids 0..49999; they are only used to name the output files.
    seeds = list(range(50000))
    # NOTE: if 50000 is not divisible by batch_size, tensor_split yields batches
    # larger than batch_size and zip() below drops the surplus ids.
    num_batches = 50000 // batch_size
    all_batches = torch.as_tensor(seeds).tensor_split(num_batches)
    rank_batches = all_batches[dist.get_rank() :: dist.get_world_size()]

    num_steps = t_steps.shape[0]
    img_res = net.img_resolution
    img_channels = net.img_channels
    num_points = inter_size + 2

    # Iterate over this rank's own batches so that no trailing batch is skipped
    # when the batch count is not a multiple of the world size.
    n_round = len(rank_batches)
    for i, batch_seeds in enumerate(rank_batches):
        time_round = time.time()

        class_labels = None
        if net.label_dim:
            class_labels = torch.eye(net.label_dim, device=device)[torch.randint(net.label_dim, size=[batch_size], device=device)]
        if class_idx is not None:
            class_labels[:, :] = 0
            class_labels[:, class_idx] = 1
        # `if class_labels:` would evaluate a multi-element tensor as a bool and
        # raise. Labels are tiled once per interpolation point, matching the
        # (point, sample) flattening of the latents below.
        if class_labels is not None:
            class_labels = class_labels.repeat(num_points, 1)

        latents_start = torch.randn([batch_size, img_channels, img_res, img_res], device=device)
        latents_end = torch.randn([batch_size, img_channels, img_res, img_res], device=device)

        alphas = torch.linspace(0, 1, num_points).to(device)
        if strategy == 'linear':
            alphas = alphas.unsqueeze(1).unsqueeze(2).unsqueeze(3).unsqueeze(4)
            inters = (1 - alphas) * latents_start + alphas * latents_end
        elif strategy == 'n_linear':
            rhos = alphas / (alphas**2 + (1 - alphas)**2)**(0.5)
            rhos = rhos.unsqueeze(1).unsqueeze(2).unsqueeze(3).unsqueeze(4)
            inters = (1 - rhos**2)**(0.5) * latents_start + rhos * latents_end
        else:
            # per-pair angle between start and end latent
            phi = torch.acos(torch.sum(latents_start * latents_end, dim=(1, 2, 3)) / (torch.norm(latents_start, p=2, dim=(1, 2, 3)) * torch.norm(latents_end, p=2, dim=(1, 2, 3))))
            rho1 = torch.sin((1 - alphas.unsqueeze(1)) * phi.unsqueeze(0)) / torch.sin(phi.unsqueeze(0))
            rho2 = torch.sin(alphas.unsqueeze(1) * phi.unsqueeze(0)) / torch.sin(phi.unsqueeze(0))
            inters = rho1.unsqueeze(2).unsqueeze(3).unsqueeze(4) * latents_start + rho2.unsqueeze(2).unsqueeze(3).unsqueeze(4) * latents_end
        inters = inters.reshape(-1, img_channels, img_res, img_res)
        images = edm_sampler(net, inters, class_labels, randn_like=torch.randn_like, num_steps=num_steps)

        # Save images.
        images = images.reshape(num_points, batch_size, img_channels, img_res, img_res)
        subdirs = True
        for k in range(num_points):
            images_np = (images[k] * 127.5 + 128).clip(0, 255).to(torch.uint8).permute(0, 2, 3, 1).cpu().numpy()
            for seed, image_np in zip(batch_seeds, images_np):
                image_dir = os.path.join(outdir, 'traj_samples', 'interp_{}_{}'.format(strategy, k), f'{seed-seed%1000:06d}') if subdirs else outdir
                os.makedirs(image_dir, exist_ok=True)
                image_path = os.path.join(image_dir, f'{seed:06d}.png')
                if image_np.shape[2] == 1:
                    PIL.Image.fromarray(image_np[:, :, 0], 'L').save(image_path)
                else:
                    PIL.Image.fromarray(image_np, 'RGB').save(image_path)

        dist.print0(i+1, '|', n_round, '|', time.time()-time_round)


#----------------------------------------------------------------------------
# collect forward/backward norm npz stat for drawing
def monitor_normTraj_collect(
    forward_steps       = 500,
    outdir              = None,
    net                 = None,
    images_all          = None,
    device              = None,
    class_idx           = None,
    sigma_max           = 80,
    sigma_min           = 0.002,
    t_steps             = None,
    sigma               = None,
    sigma_deriv         = None,
    sigma_inv           = None,
    s                   = None,
    s_deriv             = None,
    solver              = None,
):
    """Collect L2-norm statistics for the forward (noising) and backward (sampling) process.

    Forward: if ``<outdir>/forward_norms`` is empty, the per-image L2 norm of
    ``images_all`` is saved as ``CIFAR-10_0.npz`` and, for ``t = i / forward_steps``
    with ``i = 1..forward_steps``, the norm of
    ``images_all * s(t) + s(t) * sigma(t) * noise`` is saved as ``CIFAR-10_<i>.npz``.

    Backward: if ``<outdir>/CIFAR-10_backward_denoiser_norms`` is empty, 100
    batches of ``500`` latents are run through ``ablation_sampler`` in
    ``'norm_denoiser'`` mode and its two outputs are saved per batch to
    ``CIFAR-10_backward_norms`` and ``CIFAR-10_backward_denoiser_norms``.

    All arrays are stored under the key ``l2norms``.

    Args:
        forward_steps: Number of forward noise levels.
        outdir: Root output directory.
        net: Network exposing ``img_resolution``, ``img_channels`` and ``label_dim``.
        images_all: Data set tensor used for the forward statistics.
        device: Device on which noise, latents and labels are created.
        class_idx: If given, conditions on this class (requires a class-conditional ``net``).
        sigma_max, sigma_min: Unused in this function.
        t_steps: 1-D tensor of time steps; its length is used as ``num_steps``.
        sigma, sigma_deriv, sigma_inv, s, s_deriv, solver: Schedule callables /
            solver name; forwarded to ``ablation_sampler``.
    """
    dataset_name = 'CIFAR-10'
    num_steps = t_steps.shape[0]
    img_res = net.img_resolution
    img_channels = net.img_channels

    # forward
    forward_dir = os.path.join(outdir, 'forward_norms')
    os.makedirs(forward_dir, exist_ok=True)

    if len(os.listdir(forward_dir)) == 0:
        norm_raw = torch.norm(images_all, p=2, dim=(1, 2, 3))
        np.savez(os.path.join(forward_dir, '{}_{}.npz'.format(dataset_name, 0)), l2norms=norm_raw.cpu().numpy())
        for i in range(forward_steps):
            t = torch.tensor(1 * (i + 1) / forward_steps)
            norm_forward = torch.norm(images_all * s(t) + s(t) * sigma(t) * torch.randn_like(images_all, device=device), p=2, dim=(1, 2, 3))
            np.savez(os.path.join(forward_dir, '{}_{}.npz'.format(dataset_name, i+1)), l2norms=norm_forward.cpu().numpy())

    # backward and backward_denoiser
    # Format only the directory name: formatting the joined path would also
    # interpret any braces contained in `outdir`.
    backward_dir1 = os.path.join(outdir, '{}_backward_norms'.format(dataset_name))
    backward_dir2 = os.path.join(outdir, '{}_backward_denoiser_norms'.format(dataset_name))
    os.makedirs(backward_dir1, exist_ok=True)
    os.makedirs(backward_dir2, exist_ok=True)
    # Only the denoiser directory decides whether the backward pass is skipped.
    num_npz = len(os.listdir(backward_dir2))
    K = 100
    batch_size = 50000 // K

    if num_npz == 0:
        for i in range(K):
            class_labels = None
            if net.label_dim:
                class_labels = torch.eye(net.label_dim, device=device)[torch.randint(net.label_dim, size=[batch_size], device=device)]
            if class_idx is not None:
                class_labels[:, :] = 0
                class_labels[:, class_idx] = 1

            latents = torch.randn(size=(batch_size, img_channels, img_res, img_res), device=device)
            norm_backward1, norm_backward2 = ablation_sampler(
                net, latents, class_labels, randn_like=torch.randn_like, num_steps=num_steps,
                t_steps=t_steps, sigma=sigma, sigma_deriv=sigma_deriv, sigma_inv=sigma_inv, s=s,
                s_deriv=s_deriv, solver=solver, mode='norm_denoiser',
            )
            np.savez(os.path.join(backward_dir1, '{}_{}.npz'.format(dataset_name, i)), l2norms=norm_backward1.cpu().numpy())
            np.savez(os.path.join(backward_dir2, '{}_{}.npz'.format(dataset_name, i)), l2norms=norm_backward2.cpu().numpy())
            print(i+1, '|', K)
            


#----------------------------------------------------------------------------
# collect deviation npz stat for drawing
def monitor_deviation_collect(
    outdir              = None,
    network_name        = None,
    batch_size          = 500,
    net                 = None,
    images_all          = None,
    device              = None,
    class_idx           = None,
    t_steps             = None,
    sigma               = None,
    sigma_deriv         = None,
    sigma_inv           = None,
    s                   = None,
    s_deriv             = None,
    solver              = None,
):
    """Collect trajectory-deviation statistics and store one ``.npz`` per round.

    Nothing is computed when ``<outdir>/CIFAR-10_deviation`` already holds
    files. Otherwise ``50000 // batch_size`` rounds are run; each round samples
    a batch with ``ablation_sampler`` (mode ``'traj_both'``), passes both
    returned sequences to ``distance_point_to_line`` (mode ``'norm'``) and
    saves the two results as ``norm_traj`` and ``norm_traj_denoiser`` in
    ``CIFAR-10_<round>.npz``.

    ``network_name`` and ``images_all`` are not used here.
    """
    num_steps = t_steps.shape[0]
    img_res = net.img_resolution
    img_channels = net.img_channels

    std_dir = os.path.join(outdir, 'CIFAR-10_deviation')
    os.makedirs(std_dir, exist_ok=True)

    total_rounds = 50000 // batch_size
    # Existing statistics are left untouched.
    if len(os.listdir(std_dir)) != 0:
        return

    sample_shape = [batch_size, img_channels, img_res, img_res]
    for round_idx in range(total_rounds):
        class_labels = None
        if net.label_dim:
            picks = torch.randint(net.label_dim, size=[batch_size], device=device)
            class_labels = torch.eye(net.label_dim, device=device)[picks]
        if class_idx is not None:
            class_labels[:, :] = 0
            class_labels[:, class_idx] = 1

        # Sample the trajectory and the accompanying denoised sequence.
        latents = torch.randn(sample_shape, device=device)
        traj, traj_denoised = ablation_sampler(
            net, latents, class_labels,
            randn_like=torch.randn_like, num_steps=num_steps, t_steps=t_steps,
            sigma=sigma, sigma_deriv=sigma_deriv, sigma_inv=sigma_inv,
            s=s, s_deriv=s_deriv, solver=solver, mode='traj_both',
        )
        # Regroup the step-major output so that the batch axis comes first.
        traj = traj.reshape(-1, *sample_shape).transpose(0, 1)
        traj_denoised = traj_denoised.reshape(-1, *sample_shape).transpose(0, 1)

        # Projection distance for both sequences.
        stats = {}
        for key, points in (('norm_traj', traj), ('norm_traj_denoiser', traj_denoised)):
            stats[key] = distance_point_to_line(points, img_channels, img_res, bs=batch_size, mode='norm').cpu().numpy()

        np.savez(os.path.join(std_dir, 'CIFAR-10_{}.npz'.format(round_idx)), **stats)

        print(round_idx + 1, '|', total_rounds)


#----------------------------------------------------------------------------
# collect npz stats for plotting the distance from the current point to the final point
def monitor_trajDistance_collect(
    outdir              = None,
    network_name        = None,
    batch_size          = 100,
    net                 = None,
    images_all          = None,
    device              = None,
    class_idx           = None,
    t_steps             = None,
    sigma               = None,
    sigma_deriv         = None,
    sigma_inv           = None,
    s                   = None,
    s_deriv             = None,
    solver              = None,
):
    """Collect the L2 distance of every trajectory point to the final sample.

    Nothing is computed when ``<outdir>/CIFAR-10_dist_traj`` already holds
    files. Otherwise ``50000 // batch_size`` rounds are run; each round samples
    a batch with ``ablation_sampler`` (mode ``'traj_both'``) and saves, in
    ``CIFAR-10_<round>.npz``:

    * ``traj_dist``: distance of each trajectory point except the last to the last one,
    * ``denoised_dist``: distance of each entry of the second sampler output
      to that same last trajectory point.

    ``network_name`` and ``images_all`` are not used here.
    """
    num_steps = t_steps.shape[0]
    img_res = net.img_resolution
    img_channels = net.img_channels

    traj_dist_dir = os.path.join(outdir, 'CIFAR-10_dist_traj')
    os.makedirs(traj_dist_dir, exist_ok=True)

    # Statistics are only generated when none exist yet.
    if len(os.listdir(traj_dist_dir)) != 0:
        return

    total_rounds = 50000 // batch_size
    sample_shape = [batch_size, img_channels, img_res, img_res]
    for round_idx in range(total_rounds):
        class_labels = None
        if net.label_dim:
            picks = torch.randint(net.label_dim, size=[batch_size], device=device)
            class_labels = torch.eye(net.label_dim, device=device)[picks]
        if class_idx is not None:
            class_labels[:, :] = 0
            class_labels[:, class_idx] = 1

        # Sampling trajectory plus the accompanying denoised sequence.
        latents = torch.randn(sample_shape, device=device)
        traj, traj_denoised = ablation_sampler(
            net, latents, class_labels,
            randn_like=torch.randn_like, num_steps=num_steps, t_steps=t_steps,
            sigma=sigma, sigma_deriv=sigma_deriv, sigma_inv=sigma_inv,
            s=s, s_deriv=s_deriv, solver=solver, mode='traj_both',
        )
        traj = traj.reshape(-1, *sample_shape).transpose(0, 1)
        traj_denoised = traj_denoised.reshape(-1, *sample_shape).transpose(0, 1)

        # Distance of every point to the last trajectory point of its sample.
        final_point = traj[:, -1].unsqueeze(1)
        distance_traj = torch.norm(traj[:, 0:-1] - final_point, p=2, dim=(2, 3, 4))
        distance_traj_denoised = torch.norm(traj_denoised - final_point, p=2, dim=(2, 3, 4))

        target = os.path.join(traj_dist_dir, 'CIFAR-10_{}.npz'.format(round_idx))
        np.savez(
            target,
            traj_dist=distance_traj.cpu().numpy(),
            denoised_dist=distance_traj_denoised.cpu().numpy(),
        )
        print(round_idx+1, '|', total_rounds)


#----------------------------------------------------------------------------
# generate the whole sampling/denoising trajs for FID evaluation
def monitor_traj_generate(
    outdir              = None,
    network_name        = None,
    batch_size          = 200,
    net                 = None,
    device              = None,
    class_idx           = None,
    sigma_max           = 80,
    sigma_min           = 0.002,
    t_steps             = None,
    sigma               = None,
    sigma_deriv         = None,
    sigma_inv           = None,
    s                   = None,
    s_deriv             = None,
    solver              = None,
):
    """Generate full sampling / denoising trajectories and save every step as PNG files.

    The 50000 image ids ``0..49999`` are split into ``50000 // batch_size``
    batches that are distributed round-robin over the ranks. Every batch is
    sampled with ``edm_sampler`` in ``'traj_both'`` mode. For step index
    ``k = 0..num_steps-1`` the k-th entry of the first output is written to
    ``traj_<num_steps-k>`` and the k-th entry of the second output to
    ``traj_de_<num_steps-k>``; the last entry of the second output is
    additionally written to ``traj_0``. All folders live under
    ``<outdir>/traj_gen/`` and are sub-divided by image id rounded down to 1000.

    Args:
        outdir: Root output directory.
        network_name: Unused in this function.
        batch_size: Number of latents per batch.
        net: Network exposing ``img_resolution``, ``img_channels`` and ``label_dim``.
        device: Device on which latents and labels are created.
        class_idx: If given, conditions on this class (requires a class-conditional ``net``).
        sigma_max, sigma_min: Unused in this function.
        t_steps: 1-D tensor of time steps; only its length is used (as ``num_steps``).
        sigma, sigma_deriv, sigma_inv, s, s_deriv, solver: Unused in this function.
    """
    # Image ids 0..49999; they are only used to name the output files.
    seeds = list(range(50000))
    # NOTE: if 50000 is not divisible by batch_size, tensor_split yields batches
    # larger than batch_size and zip() below drops the surplus ids.
    num_batches = 50000 // batch_size
    all_batches = torch.as_tensor(seeds).tensor_split(num_batches)
    rank_batches = all_batches[dist.get_rank() :: dist.get_world_size()]

    num_steps = t_steps.shape[0]
    img_res = net.img_resolution
    img_channels = net.img_channels

    def save_batch(batch_images, folder, batch_seeds):
        """Convert one batch from [-1, 1] floats to uint8 RGB and write one PNG per id."""
        images_np = (batch_images * 127.5 + 128).clip(0, 255).to(torch.uint8).permute(0, 2, 3, 1).cpu().numpy()
        for seed, image_np in zip(batch_seeds, images_np):
            image_dir = os.path.join(outdir, 'traj_gen', folder, f'{seed-seed%1000:06d}')
            os.makedirs(image_dir, exist_ok=True)
            image_path = os.path.join(image_dir, f'{seed:06d}.png')
            PIL.Image.fromarray(image_np, 'RGB').save(image_path)

    # Iterate over this rank's own batches so that no trailing batch is skipped
    # when the batch count is not a multiple of the world size.
    n_round = len(rank_batches)
    for i, batch_seeds in enumerate(rank_batches):
        time_round = time.time()

        class_labels = None
        if net.label_dim:
            class_labels = torch.eye(net.label_dim, device=device)[torch.randint(net.label_dim, size=[batch_size], device=device)]
        if class_idx is not None:
            class_labels[:, :] = 0
            class_labels[:, class_idx] = 1

        latents = torch.randn([batch_size, img_channels, img_res, img_res], device=device)

        traj, traj_de = edm_sampler(net, latents, class_labels, randn_like=torch.randn_like, num_steps=num_steps, mode='traj_both')

        # Save images.
        traj = traj.reshape(-1, batch_size, img_channels, img_res, img_res)
        traj_de = traj_de.reshape(-1, batch_size, img_channels, img_res, img_res)

        for k in range(num_steps):
            # Folders are numbered by remaining steps; this was a fixed 18,
            # which is only correct for num_steps == 18.
            remaining = num_steps - k
            save_batch(traj[k], 'traj_{}'.format(remaining), batch_seeds)
            save_batch(traj_de[k], 'traj_de_{}'.format(remaining), batch_seeds)

        # The end point (traj_0) is taken from the last entry of the denoised sequence.
        save_batch(traj_de[-1], 'traj_{}'.format(0), batch_seeds)

        dist.print0(i+1, '|', n_round, '|', time.time()-time_round)