from multiprocessing import allow_connection_pickling
import os
from posixpath import split
from re import I
import time
from generate_rnd import generate_rnd
from generate_rnd_nn import generate_rnd_nn
from generate_sample_nn import generate_sample_nn

import imageio
from interpolate import random_interp
import numpy as np
from ppl import calc_ppl
from ppl_uniform import calc_ppl_uniform
import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

from data import set_up_data
from sampler import Sampler
from train_helpers import set_up_hyperparams, load_vaes, load_opt, accumulate_stats, save_model, update_ema, \
    save_latents_latest, save_latents, save_snoise
from utils import get_cpu_stats_over_ranks, ZippedDataset
from torch.optim import AdamW, SGD


def training_step_imle(H, n, targets, latents, snoise, vae, ema_vae, optimizer, loss_fn):
    """Take one IMLE optimisation step on a batch.

    No ELBO is back-propagated. The caller is expected to pass, for each data
    target, the latent code / spatial noise previously selected as its nearest
    neighbour among random samples; the step pulls the generated output for
    those codes towards the targets. That is the IMLE objective, see
    http://www.sfu.ca/~keli/projects/imle/ for details.

    Args:
        H: hyperparameters; ``H.ema_rate`` is read when ``ema_vae`` is given.
        n: unused (kept for interface compatibility).
        targets: target batch; permuted with ``(0, 3, 1, 2)`` before the loss
            (presumably channels-last -> channels-first; confirm against
            ``preprocess_fn``).
        latents, snoise: inputs passed straight to ``vae(latents, snoise)``.
        vae: the generator; its grads are zeroed at the start of the step.
        ema_vae: optional EMA copy updated after the step, or ``None``.
        optimizer: stepped unconditionally.
        loss_fn: ``loss_fn(generated, targets) -> scalar tensor``.

    Returns:
        The dict from ``get_cpu_stats_over_ranks`` (``loss_nans``, ``loss``),
        extended with ``skipped_updates`` (always 0), ``iter_time`` (seconds)
        and ``grad_norm`` (always 0.).
    """
    started_at = time.time()
    vae.zero_grad()

    targets_chw = targets.permute(0, 3, 1, 2)
    generated = vae(latents, snoise)
    loss = loss_fn(generated, targets_chw)
    loss.backward()

    nan_count = torch.isnan(loss).sum()
    stats = get_cpu_stats_over_ranks({'loss_nans': 1 if nan_count != 0 else 0, 'loss': loss})

    # The NaN / grad-norm skip guard is disabled, so every call applies the update.
    optimizer.step()
    if ema_vae is not None:
        update_ema(vae, ema_vae, H.ema_rate)

    stats.update(skipped_updates=0, iter_time=time.time() - started_at, grad_norm=0.)
    return stats


def eval_step(data_input, target, ema_vae):
    """Evaluate ``ema_vae`` on one batch without gradients.

    Returns the stats dict produced by ``ema_vae.forward(data_input, target)``
    after it has been gathered via ``get_cpu_stats_over_ranks``.
    """
    with torch.no_grad():
        batch_stats = ema_vae.forward(data_input, target)
    return get_cpu_stats_over_ranks(batch_stats)


def get_sample_for_visualization(data, preprocess_fn, num, dataset):
    """Draw the first batch of ``num`` items from ``data`` for visualisation.

    Args:
        data: a dataset whose batches are sequences; element 0 holds the images.
        preprocess_fn: applied to the raw batch; element 0 of its result is returned.
        num: batch size of the drawn batch.
        dataset: dataset name. For ``'ffhq_1024'`` the images are multiplied by
            255, cast to uint8 and permuted with ``(0, 2, 3, 1)`` (presumably
            float NCHW in [0, 1] -> uint8 NHWC); otherwise they are returned as-is.

    Returns:
        ``(orig_image, preprocessed)``.

    Raises:
        ValueError: if ``data`` yields no batch (instead of failing with an
            UnboundLocalError on the undefined loop variable).
    """
    try:
        x = next(iter(DataLoader(data, batch_size=num)))
    except StopIteration:
        raise ValueError('cannot draw a visualization batch from an empty dataset') from None

    if dataset == 'ffhq_1024':
        orig_image = (x[0] * 255.0).to(torch.uint8).permute(0, 2, 3, 1)
    else:
        orig_image = x[0]
    preprocessed = preprocess_fn(x)[0]
    return orig_image, preprocessed


def train_loop_imle(H, data_train, data_valid, preprocess_fn, vae, ema_vae, logprint):
    """Run the IMLE training loop.

    The training set is optionally truncated to ``H.subset_len`` examples and
    processed in splits of ``H.n_split`` (``-1`` = whole set). For a split, each
    "epoch" of the inner loop:

    1. re-measures every example's distance to the sample of its currently
       selected latent (``sampler.calc_dists_existing``);
    2. marks an example for resampling when its distance is below its adaptive
       threshold and it has not been resampled for ``H.imle_staleness`` epochs,
       or unconditionally after ``H.imle_force_resample`` epochs (on the very
       first epoch everything is resampled unless latents are restored from
       ``H.restore_latent_path``);
    3. lowers the thresholds of marked examples to ``dist * (1 - H.change_coef)``
       and resamples them;
    4. makes one shuffled pass of optimizer steps over (example, selected latent)
       pairs, logging, visualising and checkpointing on the configured schedules.

    NOTE(review): the inner ``while True`` has no exit (the old ``break`` was
    commented out), so only the first split is ever trained and the
    ``H.num_epochs`` loop never advances. ``data_valid`` is unused.
    """
    subset_len = len(data_train)
    if H.subset_len != -1:
        subset_len = H.subset_len
    # Keep only the first `subset_len` examples, materialised as a TensorDataset.
    for first_batch in DataLoader(data_train, batch_size=subset_len):
        data_train = TensorDataset(first_batch[0])
        break
    print(len(data_train))

    optimizer, scheduler, cur_eval_loss, iterate, starting_epoch = load_opt(H, vae, logprint)

    early_evals = {1, *(2 ** exp for exp in range(3, 14))}
    stats = []
    iters_since_starting = 0
    H.ema_rate = torch.as_tensor(H.ema_rate)

    n_split = H.n_split
    if n_split == -1:
        n_split = len(data_train)

    sampler = Sampler(H, n_split, preprocess_fn)

    # Epochs elapsed since each example's latent was last resampled.
    last_updated = torch.zeros((n_split), dtype=torch.int16).cuda()
    # Resample count per example (bookkeeping only); int32 so it cannot wrap like int8 did at 127.
    times_updated = torch.zeros((n_split), dtype=torch.int32).cuda()
    # Per-example distance threshold below which a resample is allowed.
    change_thresholds = torch.empty((n_split)).cuda()
    change_thresholds[:] = H.change_threshold

    epoch = -1
    for outer in range(H.num_epochs):
        for split_ind, split_x_tensor in enumerate(DataLoader(data_train, batch_size=n_split, pin_memory=True)):
            split_x_tensor = split_x_tensor[0].contiguous()
            split_x = TensorDataset(split_x_tensor)
            sampler.init_projection(split_x_tensor)
            viz_batch_original, _ = get_sample_for_visualization(split_x, preprocess_fn, H.num_images_visualize, H.dataset)

            print('doing for {}th data outer batch - {}'.format(split_ind, len(split_x)))

            while True:
                epoch += 1
                last_updated[:] = last_updated + 1

                # Decide which examples get a new nearest-neighbour latent this epoch.
                sampler.selected_dists[:] = sampler.calc_dists_existing(split_x_tensor, vae, dists=sampler.selected_dists)
                dists_in_threshold = sampler.selected_dists < change_thresholds
                updated_enough = last_updated >= H.imle_staleness
                updated_too_much = last_updated >= H.imle_force_resample
                in_threshold = torch.logical_and(dists_in_threshold, updated_enough)
                all_conditions = torch.logical_or(in_threshold, updated_too_much)
                to_update = torch.nonzero(all_conditions, as_tuple=False).squeeze(1)
                print(epoch)

                if epoch == 0:
                    if os.path.isfile(str(H.restore_latent_path)):
                        latents = torch.load(H.restore_latent_path)
                        sampler.selected_latents[:] = latents[:]
                        x = next(iter(DataLoader(split_x, batch_size=H.num_images_visualize, pin_memory=True)))
                        batch_slice = slice(0, x[0].size()[0])
                        latents = sampler.selected_latents[batch_slice]
                        with torch.no_grad():
                            snoise = [s[batch_slice] for s in sampler.selected_snoise]
                            generate_for_NN(sampler, x[0], latents, snoise, viz_batch_original.shape, vae,
                                            f'{H.save_dir}/NN-samples_{outer}-{split_ind}-vae.png', logprint)
                        print('loaded latest latents')

                        # NOTE(review): gated on restore_latent_path but reads restore_threshold_path;
                        # restoring latents without a threshold file fails here.
                        threshold = torch.load(H.restore_threshold_path)
                        change_thresholds[:] = threshold[:]
                        print('loaded thresholds', torch.mean(change_thresholds))
                    else:
                        # Fresh start: every example needs an initial nearest neighbour.
                        to_update = sampler.entire_ds

                change_thresholds[to_update] = sampler.selected_dists[to_update].clone() * (1 - H.change_coef)

                print(f'{to_update.shape[0]} data examples need update')

                sampler.imle_sample(split_x, vae)
                sampler.imle_sample_force(split_x_tensor, vae, to_update)

                last_updated[to_update] = 0
                times_updated[to_update] = times_updated[to_update] + 1
                save_latents_latest(H, split_ind, sampler.selected_latents)
                save_latents_latest(H, split_ind, change_thresholds, name='threshold_latest')

                if to_update.shape[0] >= H.num_images_visualize:
                    viz_idx = to_update[:H.num_images_visualize]
                    latents = sampler.selected_latents[viz_idx]
                    with torch.no_grad():
                        generate_for_NN(sampler, split_x_tensor[viz_idx], latents,
                                        [s[viz_idx] for s in sampler.selected_snoise],
                                        viz_batch_original.shape, vae,
                                        f'{H.save_dir}/NN-samples_{epoch}-vae.png', logprint)

                # One shuffled pass of optimizer steps over (example, selected latent) pairs.
                comb_dataset = ZippedDataset(split_x, TensorDataset(sampler.selected_latents))
                data_loader = DataLoader(comb_dataset, batch_size=H.n_batch, pin_memory=True, shuffle=True)
                for cur, indices in data_loader:
                    x = cur[0]
                    latents = cur[1][0]
                    _, target = preprocess_fn(x)
                    cur_snoise = [s[indices] for s in sampler.selected_snoise]
                    stat = training_step_imle(H, target.shape[0], target, latents, cur_snoise, vae, ema_vae, optimizer, sampler.calc_loss)
                    stats.append(stat)
                    scheduler.step()

                    if iterate % H.iters_per_print == 0 or iters_since_starting in early_evals:
                        logprint(model=H.desc, type='train_loss', latest=stat['loss'], lr=scheduler.get_last_lr()[0],
                                 epoch=epoch + starting_epoch, step=iterate,
                                 **accumulate_stats(stats, H.iters_per_print))

                    # Parenthesised so the rank-0 guard covers both triggers; previously
                    # `and` bound only to the early-eval branch and every rank wrote images.
                    want_images = iterate % H.iters_per_images == 0 or (
                        iters_since_starting in early_evals and H.dataset != 'ffhq_1024')
                    if want_images and H.rank == 0:
                        with torch.no_grad():
                            generate_images_initial(H, sampler, viz_batch_original,
                                                    sampler.selected_latents[0: H.num_images_visualize],
                                                    [s[0: H.num_images_visualize] for s in sampler.selected_snoise],
                                                    viz_batch_original.shape, vae, ema_vae,
                                                    f'{H.save_dir}/samples-{iterate}.png', logprint)

                    iterate += 1
                    iters_since_starting += 1
                    if iterate % H.iters_per_save == 0:
                        logprint(model=H.desc, type='train_loss', epoch=epoch + starting_epoch, step=iterate,
                                 **accumulate_stats(stats, H.iters_per_print))
                        fp = os.path.join(H.save_dir, 'latest')
                        logprint(f'Saving model@ {iterate} to {fp}')
                        save_model(fp, vae, ema_vae, optimizer, H)
                        save_latents_latest(H, split_ind, sampler.selected_latents)
                        save_latents_latest(H, split_ind, change_thresholds, name='threshold_latest')

                    if iterate % H.iters_per_ckpt == 0:
                        save_model(os.path.join(H.save_dir, f'iter-{iterate}'), vae, ema_vae, optimizer, H)
                        save_latents(H, iterate, split_ind, sampler.selected_latents)
                        save_latents(H, iterate, split_ind, change_thresholds, name='threshold')
                        save_snoise(H, iterate, sampler.selected_snoise)


def evaluate(H, ema_vae, data_valid, preprocess_fn):
    """Evaluate ``ema_vae`` on this rank's shard of ``data_valid``.

    The data is sharded with ``DistributedSampler`` (``H.mpi_size`` replicas,
    this process at ``H.rank``) and read in batches of ``H.n_batch`` (the last
    partial batch is dropped). Each batch goes through ``preprocess_fn`` and
    ``eval_step``.

    Returns:
        dict with ``n_batches``, ``filtered_elbo`` (mean of the finite ``'elbo'``
        values) and the per-key mean of every stat the last batch reported.
        Raises IndexError when no batch was evaluated.
    """
    shard = DistributedSampler(data_valid, num_replicas=H.mpi_size, rank=H.rank)
    loader = DataLoader(data_valid, batch_size=H.n_batch, drop_last=True, pin_memory=True, sampler=shard)

    per_batch = []
    for batch in loader:
        inputs, targets = preprocess_fn(batch)
        per_batch.append(eval_step(inputs, targets, ema_vae))

    elbos = [entry['elbo'] for entry in per_batch]
    finite_elbos = np.array(elbos)[np.isfinite(elbos)]
    filtered = np.mean(finite_elbos)
    averaged = {key: np.mean([entry[key] for entry in per_batch]) for key in per_batch[-1]}
    return dict(n_batches=len(elbos), filtered_elbo=filtered, **averaged)


def generate_for_NN(sampler, orig, initial, snoise, shape, ema_vae, fname, logprint):
    """Write a two-row image grid: originals on top, generated images below.

    ``initial[:shape[0]]`` is moved to CUDA and decoded with
    ``sampler.sample(..., ema_vae, snoise)``. ``snoise`` is passed through
    without trimming. Both rows are ``shape[0]`` images wide. ``shape[1:]``
    gives the per-image dims, and the final reshape assumes 3 channels.
    """
    count = shape[0]
    generated = sampler.sample(initial[:count].cuda(), ema_vae, snoise)
    rows = [orig[:count], generated]

    stacked = np.concatenate(rows, axis=0)
    stacked = stacked.reshape((len(rows), count, *shape[1:]))
    grid = stacked.transpose([0, 2, 1, 3, 4]).reshape([len(rows) * shape[1], count * shape[2], 3])

    logprint(f'printing samples to {fname}')
    imageio.imwrite(fname, grid)


def generate_images_initial(H, sampler, orig, initial, snoise, shape, vae, ema_vae, fname, logprint):
    """Write a visualisation grid of originals, reconstructions and random samples.

    Rows, top to bottom, each ``shape[0]`` images wide:
      1. ``orig[:shape[0]]``;
      2. samples decoded from ``initial[:shape[0]]`` with the given ``snoise``;
      3. ``H.num_temperatures_visualize`` rows, each with a freshly drawn random
         latent and freshly redrawn noise;
      4. two rows re-using the most recent random latent with redrawn noise;
      5. the most recent random latent with ``sampler.neutral_snoise``;
      6. a newly drawn random latent with ``sampler.neutral_snoise``.

    Side effect: the leading ``shape[0]`` rows of every buffer in
    ``sampler.snoise_tmp`` are overwritten in place with normal draws.
    Random latents are allocated on CUDA. ``ema_vae`` is accepted but unused.
    """
    width = shape[0]
    rows = [orig[:width], sampler.sample(initial[:width], vae, snoise)]

    def redraw_noise():
        # Slices are views, so normal_() refreshes the sampler's own buffers.
        return [buf[:width].normal_() for buf in sampler.snoise_tmp]

    def neutral_noise():
        return [buf[:width] for buf in sampler.neutral_snoise]

    z = torch.randn([width, H.latent_dim], dtype=torch.float32).cuda()
    for _ in range(H.num_temperatures_visualize):
        z.normal_()
        rows.append(sampler.sample(z, vae, redraw_noise()))

    for _ in range(2):
        rows.append(sampler.sample(z, vae, redraw_noise()))

    rows.append(sampler.sample(z, vae, neutral_noise()))

    fixed_noise = neutral_noise()
    z.normal_()
    rows.append(sampler.sample(z, vae, fixed_noise))

    grid = np.concatenate(rows, axis=0).reshape((len(rows), width, *shape[1:]))
    grid = grid.transpose([0, 2, 1, 3, 4]).reshape([len(rows) * shape[1], width * shape[2], 3])

    logprint(f'printing samples to {fname}')
    imageio.imwrite(fname, grid)


def run_test_eval(H, ema_vae, data_test, preprocess_fn, logprint):
    """Evaluate on ``data_test``, print every stat, then log them as ``test_loss``."""
    print('evaluating')
    results = evaluate(H, ema_vae, data_test, preprocess_fn)
    print('test results')
    for name, value in results.items():
        print(name, value)
    logprint(type='test_loss', **results)


def write_images2(H, ema_vae, fname, logprint, sample_fn=None):
    """Write a 16 x 8 grid of unconditional samples to ``fname``.

    Each of the 16 rows redraws a batch of 8 standard-normal latents
    (``H.latent_dim`` wide, on CUDA) and decodes it with
    ``sample_fn(latents, ema_vae)``. The result is presumably an array-like
    image batch shaped ``(8, H, W, 3)``; the final reshape assumes 3 channels.

    This module neither defines nor imports a ``sample`` function, so the
    sampling routine must be supplied by the caller.

    Raises:
        ValueError: if ``sample_fn`` is not given.
    """
    if sample_fn is None:
        raise ValueError('write_images2 requires sample_fn(latents, model)')

    mb = 8
    n_rows = 16
    latents = torch.empty([mb, H.latent_dim], dtype=torch.float32).cuda()
    batches = []
    for _ in range(n_rows):
        latents.normal_()
        batches.append(sample_fn(latents, ema_vae))

    img_shape = batches[-1].shape
    im = np.concatenate(batches, axis=0).reshape((n_rows, mb, *img_shape[1:])).transpose(
        [0, 2, 1, 3, 4]).reshape([n_rows * img_shape[1], mb * img_shape[2], 3])
    logprint(f'printing samples to {fname}')
    imageio.imwrite(fname, im)


def backtrack(H, sampler, vae, preprocess_fn, data, logprint):
    """Recover latents for ``data`` by gradient descent through a fixed generator.

    One latent of size ``H.latent_dim`` and one noise tensor per
    ``sampler.snoise_tmp`` buffer are created per example. When
    ``H.restore_latent_path`` is set, they are instead loaded from
    ``latent-best.npy`` / ``snoise-best-{res}.npy`` in that directory.

    Optimisation uses AdamW with lr ``H.latent_lr`` over the latents, plus the
    noise when ``H.space == 'w'``. After every pass over ``data`` the mean
    distance is recomputed. On improvement, the latents and noise are saved to
    ``H.save_dir``, and each example's sample and original are written as
    ``{i}.png`` / ``{i}-real.png``. The loop stops once
    ``H.reconstruct_iter_num`` examples have been processed in total.
    """
    n = data.shape[0]

    def _restore(path):
        # Fresh float32 CUDA leaf that requires grad. Same result as
        # torch.tensor(loaded, ...) without the copy-construct warning.
        loaded = torch.as_tensor(torch.load(path))
        return loaded.detach().to(device='cuda', dtype=torch.float32).clone().requires_grad_(True)

    latents = torch.randn([n, H.latent_dim], requires_grad=True, dtype=torch.float32, device='cuda')
    snoise = [torch.randn([n, s.shape[1], s.shape[2], s.shape[3]], requires_grad=True, dtype=torch.float32, device='cuda') for s in sampler.snoise_tmp]

    if H.restore_latent_path:
        logprint('restoring latent path')
        latents = _restore(f'{H.restore_latent_path}/latent-best.npy')
        snoise = [_restore(f'{H.restore_latent_path}/snoise-best-{s.shape[2]}.npy') for s in sampler.snoise_tmp]

    params = [latents] + snoise if H.space == 'w' else [latents]
    latent_optimizer = AdamW(params, lr=H.latent_lr)
    dists = torch.empty([n], dtype=torch.float32).cuda()

    sampler.calc_dists_existing(data, vae, dists=dists, latents=latents, snoise=snoise)
    print(f'initial dists: {dists.mean()}')

    best_loss = np.inf
    num_iters = 0

    while num_iters < H.reconstruct_iter_num:
        data_loader = DataLoader(ZippedDataset(data, TensorDataset(latents)), batch_size=H.n_batch)
        for cur, indices in data_loader:
            # `data` is a raw tensor, so the whole zipped item goes to preprocess_fn
            # (presumably it reads element 0, the image batch — confirm).
            lat = cur[1][0]
            _, target = preprocess_fn(cur)
            cur_snoise = [s[indices] for s in snoise]
            training_step_imle(H, target.shape[0], target, lat, cur_snoise, vae, None, latent_optimizer, sampler.calc_loss)
            # training_step_imle only zeroes the generator's grads; clear the optimised leaves here.
            latents.grad.zero_()
            for s in snoise:
                s.grad.zero_()
        num_iters += len(data)

        logprint(f'iteration: {num_iters}')

        sampler.calc_dists_existing(data, vae, dists=dists, latents=latents, snoise=snoise)
        cur_mean = dists.mean()
        logprint(f'cur mean: {cur_mean}, best: {best_loss}')
        if cur_mean < best_loss:
            torch.save(latents.detach(), f'{H.save_dir}/latent-best.npy')
            for s in snoise:
                torch.save(s.detach(), f'{H.save_dir}/snoise-best-{s.shape[2]}.npy')
            logprint(f'improved: {cur_mean}')
            best_loss = cur_mean
            # Visualisation only: no autograd graph is needed.
            with torch.no_grad():
                for i in range(n):
                    samp = sampler.sample(latents[i:i+1], vae, [s[i:i+1] for s in snoise])
                    imageio.imwrite(f'{H.save_dir}/{i}.png', samp[0])
                    imageio.imwrite(f'{H.save_dir}/{i}-real.png', data[i])


def reconstruct(H, sampler, vae, preprocess_fn, images, latents, snoise, name, logprint):
    """Optimise ``latents`` so the generator reproduces ``images``.

    Only ``latents`` is given to the AdamW optimiser (lr ``H.latent_lr``);
    ``snoise`` stays fixed. An initial comparison grid is written first. Then
    ``H.latent_epoch`` rounds of ``H.reconstruct_iter_num`` steps are run.
    Every 50th step of a round does three things:
      - prints the loss;
      - writes ``{H.save_dir}/{name}-{step}.png``;
      - saves the latents to ``{H.save_dir}/reconstruct-latest.npy``.

    ``latents`` must be a leaf tensor with ``requires_grad=True``. Its ``.grad``
    is zeroed here after every step, because ``training_step_imle`` only zeroes
    the generator's grads.

    NOTE(review): grid file names repeat across rounds, so later rounds
    overwrite earlier ones.
    """
    latent_optimizer = AdamW([latents], lr=H.latent_lr)
    generate_for_NN(sampler, images, latents.detach(), snoise, images.shape, vae,
                    f'{H.save_dir}/{name}-initial.png', logprint)

    for _round in range(H.latent_epoch):
        for step in range(H.reconstruct_iter_num):
            _, target = preprocess_fn([images])
            stat = training_step_imle(H, target.shape[0], target, latents, snoise, vae, None,
                                      latent_optimizer, sampler.calc_loss)
            latents.grad.zero_()

            if step % 50 != 0:
                continue
            print('loss is: ', stat['loss'])
            generate_for_NN(sampler, images, latents.detach(), snoise, images.shape, vae,
                            f'{H.save_dir}/{name}-{step}.png', logprint)
            torch.save(latents.detach(), '{}/reconstruct-latest.npy'.format(H.save_dir))

def save_one(orig, shape, fname, logprint):
    """Write the single image ``orig`` to ``fname``.

    ``shape`` is a batch-style shape whose ``[1:]`` entries are the image dims.
    The final reshape forces 3 channels, so ``shape`` is presumably
    ``(N, H, W, 3)``.
    """
    tile = np.concatenate([orig], axis=0)
    tile = tile.reshape((1, 1, *shape[1:])).transpose([0, 2, 1, 3, 4])
    image = tile.reshape([shape[1], shape[2], 3])

    logprint(f'printing samples to {fname}')
    imageio.imwrite(fname, image)


def _last_batch(dataset, batch_size):
    """Return element 0 of the final batch of ``DataLoader(dataset, batch_size)``.

    Raises:
        ValueError: if ``dataset`` produces no batch.
    """
    last = None
    for batch in DataLoader(dataset, batch_size=batch_size):
        last = batch[0]
    if last is None:
        raise ValueError('dataset produced no batches')
    return last


def main(H=None):
    """Set up hyperparameters, data and models, then run the action named by ``H.mode``.

    Args:
        H: optional pre-built hyperparameter namespace; when falsy the one from
            ``set_up_hyperparams()`` is used. ``set_up_hyperparams()`` runs
            regardless, because it also provides ``logprint``.
    """
    # torch.autograd.set_detect_anomaly(True)
    H_cur, logprint = set_up_hyperparams()
    if not H:
        H = H_cur
    H, data_train, data_valid_or_test, preprocess_fn = set_up_data(H)
    vae, ema_vae = load_vaes(H, logprint)

    if H.mode == 'eval':
        # Unconditional generation: H.num_images_to_generate // H.n_batch batches,
        # each image written as its own PNG.
        with torch.no_grad():
            sampler = Sampler(H, len(data_train), preprocess_fn)
            n_samp = H.n_batch
            temp_latent_rnds = torch.randn([n_samp, H.latent_dim], dtype=torch.float32).cuda()
            for i in range(0, H.num_images_to_generate // n_samp):
                if i % 100 == 0:
                    print(i * n_samp)
                temp_latent_rnds.normal_()
                tmp_snoise = [s[:n_samp].normal_() for s in sampler.snoise_tmp]
                samp = sampler.sample(temp_latent_rnds, vae, tmp_snoise)
                for j in range(n_samp):
                    imageio.imwrite(f'{H.save_dir}/{i * n_samp + j}.png', samp[j])

    elif H.mode == 'reconstruct':
        n_split = H.n_split
        if n_split == -1:
            n_split = len(data_train)
        # Use the split at index 13 (hard-coded), or the last split if there are fewer.
        # Batch by the normalised n_split: DataLoader rejects batch_size=-1.
        split_x = None
        for split_ind, split_x_tensor in enumerate(DataLoader(data_train, batch_size=n_split, pin_memory=True)):
            if split_ind == 14:
                break
            split_x = TensorDataset(split_x_tensor[0])
        if split_x is None:
            raise ValueError('reconstruct mode needs a non-empty training set')

        for param in vae.parameters():
            param.requires_grad = False
        viz_batch_original, _ = get_sample_for_visualization(split_x, preprocess_fn,
                                                             H.num_images_visualize, H.dataset)
        n_images = viz_batch_original.shape[0]
        if os.path.isfile(str(H.restore_latent_path)):
            latents = torch.tensor(torch.load(H.restore_latent_path), requires_grad=True)
        else:
            latents = torch.randn([n_images, H.latent_dim], requires_grad=True)
        sampler = Sampler(H, n_split, preprocess_fn)
        # reconstruct() only optimises the latents, so it needs one fixed noise draw,
        # one tensor per sampler noise buffer with the batch dimension set to n_images.
        snoise = [torch.randn([n_images, *s.shape[1:]], dtype=s.dtype, device=s.device)
                  for s in sampler.snoise_tmp]
        reconstruct(H, sampler, vae, preprocess_fn, viz_batch_original, latents, snoise,
                    'reconstruct', logprint)

    elif H.mode == 'backtrack':
        for param in vae.parameters():
            param.requires_grad = False
        split_x = _last_batch(data_train, H.subset_len)
        print(f'split shape is {split_x.shape}')
        sampler = Sampler(H, H.subset_len, preprocess_fn)
        backtrack(H, sampler, vae, preprocess_fn, split_x, logprint)

    elif H.mode == 'train':
        train_loop_imle(H, data_train, data_valid_or_test, preprocess_fn, vae, ema_vae, logprint)

    elif H.mode == 'ppl':
        sampler = Sampler(H, H.subset_len, preprocess_fn)
        calc_ppl(H, vae, sampler)

    elif H.mode == 'ppl_uniform':
        sampler = Sampler(H, H.subset_len, preprocess_fn)
        calc_ppl_uniform(H, vae, sampler)

    elif H.mode == 'interpolate':
        with torch.no_grad():
            split_x = _last_batch(data_train, H.subset_len)
            viz_batch_original, _ = get_sample_for_visualization(split_x, preprocess_fn,
                                                                 H.num_images_visualize, H.dataset)
            sampler = Sampler(H, H.subset_len, preprocess_fn)
            for i in range(H.num_images_to_generate):
                random_interp(H, sampler, (0, 256, 256, 3), vae, f'{H.save_dir}/interp-{i}.png', logprint)

    elif H.mode == 'generate_rnd':
        with torch.no_grad():
            split_x = _last_batch(data_train, H.subset_len)
            viz_batch_original, _ = get_sample_for_visualization(split_x, preprocess_fn,
                                                                 H.num_images_visualize, H.dataset)
            sampler = Sampler(H, H.subset_len, preprocess_fn)
            generate_rnd(H, sampler, (0, 256, 256, 3), vae, f'{H.save_dir}/rnd.png', logprint)

    elif H.mode == 'generate_rnd_nn':
        with torch.no_grad():
            split_x = _last_batch(data_train, len(data_train))
            viz_batch_original, _ = get_sample_for_visualization(split_x, preprocess_fn,
                                                                 H.num_images_visualize, H.dataset)
            sampler = Sampler(H, H.subset_len, preprocess_fn)
            generate_rnd_nn(H, split_x,  sampler, (0, 256, 256, 3), vae, f'{H.save_dir}', logprint, preprocess_fn)

    elif H.mode == 'generate_sample_nn':
        with torch.no_grad():
            split_x = _last_batch(data_train, len(data_train))
            viz_batch_original, _ = get_sample_for_visualization(split_x, preprocess_fn,
                                                                 H.num_images_visualize, H.dataset)
            sampler = Sampler(H, H.subset_len, preprocess_fn)
            generate_sample_nn(H, split_x,  sampler, (0, 256, 256, 3), vae, f'{H.save_dir}/rnd2.png', logprint, preprocess_fn)

    elif H.mode == 'backtrack_interpolate':
        with torch.no_grad():
            split_x = _last_batch(data_train, H.subset_len)
            viz_batch_original, _ = get_sample_for_visualization(split_x, preprocess_fn,
                                                                 H.num_images_visualize, H.dataset)
            sampler = Sampler(H, H.subset_len, preprocess_fn)
            latents = torch.tensor(torch.load(f'{H.restore_latent_path}/latent-best.npy'), requires_grad=True, dtype=torch.float32, device='cuda')
            print(latents.shape)
            # Loaded but currently unused: interpolation runs with sn1 = sn2 = None.
            snoise = [torch.tensor(torch.load(f'{H.restore_latent_path}/snoise-best-{s.shape[2]}.npy'), requires_grad=True, dtype=torch.float32, device='cuda') for s in sampler.snoise_tmp]
            # Interpolate between each consecutive pair of backtracked latents.
            # NOTE(review): writes under ./test rather than H.save_dir.
            for i in range(latents.shape[0] - 1):
                lat0 = latents[i:i+1]
                lat1 = latents[i+1:i+2]
                # sn1 = [s[i:i+1] for s in snoise]
                # sn2 = [s[i+1:i+2] for s in snoise]
                sn1 = None
                sn2 = None
                random_interp(H, sampler, (0, 256, 256, 3), vae, f'test/interp-{i}.png', logprint, lat0, lat1, sn1, sn2)


if __name__ == "__main__":
    # Script entry point; main() obtains its hyperparameters from set_up_hyperparams().
    main()
