import os, sys
import numpy as np
import imageio
import random
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm, trange
import datetime

from run_nerf_helpers import *
from MV_mae_encoder import MaskedViTEncoder
from load_blender import load_mae_data
from einops import rearrange, repeat
from torch.utils.tensorboard import SummaryWriter

# Compute device used for every tensor this module allocates.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Seed NumPy's global RNG once at import time for reproducibility.
np.random.seed(0)
# When True, render_rays prints a warning for outputs containing NaN/Inf.
DEBUG = False

def distance(x1, x2):
    """Return the squared Euclidean distance between x1 and x2, summed over dim 1."""
    return torch.abs(x1 - x2).square().sum(dim=1)


def batchify(fn, chunk):
    """Constructs a version of 'fn' that applies to smaller batches.

    Args:
      fn: callable invoked as fn(inputs_chunk, latent=latent_chunk).
      chunk: max rows per call, or None to return `fn` unchanged.
    Returns:
      A callable ret(inputs, latent=None) that slices `inputs` (and `latent`,
      when given, with the same row ranges) into `chunk`-sized pieces, calls
      `fn` on each, and concatenates the results along dim 0.
    """
    if chunk is None:
        return fn

    def ret(inputs, latent=None):
        # Latent rows are aligned with input rows, so slice them in lockstep.
        # A None latent is forwarded as-is instead of crashing on `None[i:j]`.
        return torch.cat([
            fn(inputs[i:i+chunk], latent=None if latent is None else latent[i:i+chunk])
            for i in range(0, inputs.shape[0], chunk)
        ], 0)
    return ret


def run_network(inputs, viewdirs, fn, embed_fn, embeddirs_fn, netchunk=1024*64, latent=None):
    """Prepares inputs and applies network 'fn'.

    Flattens every leading dimension of `inputs` (and `latent`) into rows,
    positionally embeds the points (and, if `viewdirs` is given, the per-ray
    view directions broadcast over the sample axis), runs `fn` in chunks of
    `netchunk` rows, then restores the leading shape of `inputs`.
    """
    in_dim = inputs.shape[-1]
    point_rows = torch.reshape(inputs, [-1, in_dim])
    features = embed_fn(point_rows)
    latent_rows = latent.reshape((-1, latent.shape[-1]))  # [B*N_samples, dim]

    if viewdirs is not None:
        # One direction per ray, repeated for every sample along that ray.
        dirs = viewdirs[:, None].expand(inputs.shape)
        dir_rows = torch.reshape(dirs, [-1, dirs.shape[-1]])
        features = torch.cat([features, embeddirs_fn(dir_rows)], -1)

    chunked_fn = batchify(fn, netchunk)
    out_rows = chunked_fn(features, latent=latent_rows)
    out_shape = list(inputs.shape[:-1]) + [out_rows.shape[-1]]
    return torch.reshape(out_rows, out_shape)


def batchify_rays(rays_flat, chunk=1024*32, latent=None, args=None, test_mode=None, **kwargs):
    """Render rays in smaller minibatches to avoid OOM.

    Args:
      rays_flat: packed ray tensor with one row per ray.
      chunk: max number of rays per render_rays() call.
      latent: latent code(s). If latent.shape[0] equals the number of rays
        (one code per ray), it is sliced in lockstep with the rays; otherwise
        (e.g. a single shared [1, dim] code) it is passed whole to every chunk.
      args, test_mode, **kwargs: forwarded to render_rays().
    Returns:
      dict mapping each render_rays() output key to the per-chunk results
      concatenated along dim 0.
    """
    n_rays = rays_flat.shape[0]
    # Per-ray latents must follow the ray chunking. Passing the full latent to
    # every chunk makes render_rays() raise NotImplementedError whenever
    # n_rays > chunk, because the latent no longer matches the chunk's ray count.
    per_ray_latent = latent is not None and latent.shape[0] == n_rays

    all_ret = {}
    for i in range(0, n_rays, chunk):
        chunk_latent = latent[i:i+chunk] if per_ray_latent else latent
        ret = render_rays(rays_flat[i:i+chunk], latent=chunk_latent, args=args, test_mode=test_mode, **kwargs)
        for k, v in ret.items():
            all_ret.setdefault(k, []).append(v)

    return {k: torch.cat(parts, 0) for k, parts in all_ret.items()}


def render(H, W, K, chunk=1024*32, rays=None, c2w=None, ndc=True,
                  near=0., far=1.,
                  use_viewdirs=False, c2w_staticcam=None, latent=None, args=None, test_mode=None,
                  **kwargs):
    """Render rays
    Args:
      H: int. Height of image in pixels.
      W: int. Width of image in pixels.
      K: camera intrinsics passed to get_rays(); K[0][0] is also used as the
        focal length for the NDC transform.
      chunk: int. Maximum number of rays to process simultaneously. Used to
        control maximum memory usage. Does not affect final results.
      rays: pair (rays_o, rays_d) of shape [2, batch_size, 3]. Ray origin and
        direction for each example in batch. Used only when c2w is None.
      c2w: array of shape [3, 4]. Camera-to-world transformation matrix. When
        given, rays for the full H x W image are generated with get_rays().
      ndc: bool. If True, represent ray origin, direction in NDC coordinates.
      near: float or array of shape [batch_size]. Nearest distance for a ray.
      far: float or array of shape [batch_size]. Farthest distance for a ray.
      use_viewdirs: bool. If True, use viewing direction of a point in space in model.
      c2w_staticcam: array of shape [3, 4]. If not None, use this transformation matrix for
       camera while using other c2w argument for viewing directions.
      latent: latent code(s) forwarded to batchify_rays() / render_rays().
      args, test_mode, **kwargs: forwarded to render_rays().
    Returns:
      A list [rgb_map, disp_map, depth_map, acc_map, extras]:
      rgb_map: [..., 3]. Predicted RGB values for rays.
      disp_map: [...]. Disparity map. Inverse of depth.
      depth_map: [...]. Expected depth along each ray.
      acc_map: [...]. Accumulated opacity (alpha) along a ray.
      extras: dict with every other output of render_rays().
      The leading [...] dims are those of rays_d (e.g. [H, W] when c2w is given).
    """
    if c2w is not None:
        # special case to render full image
        rays_o, rays_d = get_rays(H, W, K, c2w, device=device)
    else:
        # use provided ray batch
        rays_o, rays_d = rays

    if use_viewdirs:
        # provide ray directions as input
        viewdirs = rays_d
        if c2w_staticcam is not None:
            # special case to visualize effect of viewdirs; pass the device the
            # same way as the c2w branch above so both paths build rays alike
            rays_o, rays_d = get_rays(H, W, K, c2w_staticcam, device=device)
        viewdirs = viewdirs / torch.norm(viewdirs, dim=-1, keepdim=True)
        viewdirs = torch.reshape(viewdirs, [-1,3]).float()

    sh = rays_d.shape # [..., 3]
    if ndc:
        # for forward facing scenes
        rays_o, rays_d = ndc_rays(H, W, K[0][0], 1., rays_o, rays_d)

    # Create ray batch: [N_rays, 8] = origin(3) + direction(3) + near(1) + far(1),
    # plus 3 view-direction columns when use_viewdirs is set.
    rays_o = torch.reshape(rays_o, [-1,3]).float()
    rays_d = torch.reshape(rays_d, [-1,3]).float()

    near, far = near * torch.ones_like(rays_d[...,:1]), far * torch.ones_like(rays_d[...,:1])
    rays = torch.cat([rays_o, rays_d, near, far], -1)
    if use_viewdirs:
        rays = torch.cat([rays, viewdirs], -1)

    # Render and reshape back to the leading shape of the input rays
    all_ret = batchify_rays(rays, chunk, latent=latent, args=args, test_mode=test_mode, **kwargs)
    for k in all_ret:
        k_sh = list(sh[:-1]) + list(all_ret[k].shape[1:])
        all_ret[k] = torch.reshape(all_ret[k], k_sh)

    k_extract = ['rgb_map', 'disp_map', 'depth_map', 'acc_map']
    ret_list = [all_ret[k] for k in k_extract]
    ret_dict = {k : all_ret[k] for k in all_ret if k not in k_extract}
    return ret_list + [ret_dict]


def _scale_intrinsics(K, factor):
    """Return a copy of intrinsics K with its first two rows divided by `factor`.

    Assumes the standard 3x3 pinhole layout (focal lengths and principal point
    in rows 0-1), which makes the result match an image downsampled by
    `factor`. Accepts a torch tensor or any NumPy array-like; K is not mutated.
    """
    if torch.is_tensor(K):
        scaled = K.clone()
        if not scaled.is_floating_point():
            scaled = scaled.float()
    else:
        scaled = np.array(K, copy=True)
        if not np.issubdtype(scaled.dtype, np.floating):
            scaled = scaled.astype(np.float64)
    scaled[:2] = scaled[:2] / factor
    return scaled


def render_path(render_poses, hwf, K, chunk, render_kwargs, gt_imgs=None, savedir=None, render_factor=0, latent=None, args=None, test_mode=None):
    """Render one image per camera pose, optionally saving each RGB frame as a PNG.

    Args:
      render_poses: iterable of camera-to-world matrices; the top 3x4 of each is used.
      hwf: (H, W, focal) image height, width and focal length.
      K: camera intrinsics passed to render().
      chunk: max number of rays rendered simultaneously.
      render_kwargs: extra keyword arguments forwarded to render().
      gt_imgs: accepted for interface compatibility; currently unused.
      savedir: if not None, frame i is written to savedir/'{i:03d}.png'.
      render_factor: if non-zero, render at (H // render_factor, W // render_factor)
        with the intrinsics scaled to match.
      latent, args, test_mode: forwarded to render().
    Returns:
      (rgbs, disps): NumPy arrays with the per-pose results stacked on axis 0.
    """
    # The focal length is taken from K (see render()), so hwf's focal is not needed here.
    H, W, _ = hwf

    if render_factor != 0:
        # Render downsampled for speed. K must be scaled together with H and W;
        # otherwise rays are built with full-resolution intrinsics on a smaller
        # pixel grid and cover the wrong part of the view.
        H = H // render_factor
        W = W // render_factor
        K = _scale_intrinsics(K, render_factor)

    rgbs = []
    disps = []

    for i, c2w in enumerate(tqdm(render_poses)):
        rgb, disp, depth, acc, extras = render(H, W, K, chunk=chunk, c2w=c2w[:3,:4], latent=latent, args=args, test_mode=test_mode, **render_kwargs)

        rgbs.append(rgb.cpu().numpy())
        disps.append(disp.cpu().numpy())

        if savedir is not None:
            rgb8 = to8b(rgbs[-1])
            filename = os.path.join(savedir, '{:03d}.png'.format(i))
            imageio.imwrite(filename, rgb8)

    return np.stack(rgbs, 0), np.stack(disps, 0)

def create_nerf(args, basedir, expname):
    """Instantiate NeRF's MLP model(s), the latent encoder and the optimizer.

    Optionally restores checkpoints.

    Args:
      args: parsed options from config_parser().
      basedir: root log directory.
      expname: experiment name; checkpoints are searched in basedir/expname
        unless args.ft_path is given.
    Returns:
      (render_kwargs_train, render_kwargs_test, start, grad_vars, optimizer, latent_embed),
      where `start` is the restored global step (0 when no network checkpoint is loaded).
    """
    embed_fn, input_ch = get_embedder(args.multires, device=device, i=args.i_embed)

    input_ch_views = 0
    embeddirs_fn = None
    if args.use_viewdirs:
        embeddirs_fn, input_ch_views = get_embedder(args.multires_views, device=device, i=args.i_embed)
    output_ch = 5 if args.N_importance > 0 else 4
    skips = [4]

    latent_embed = MaskedViTEncoder(img_size=args.img_size, patch_size=args.patch_size, embed_dim=args.embed_dim, depth=args.vit_depth,
                                    num_heads=args.vit_num_heads, num_view=args.num_view, device=device, time_interval=args.time_interval,
                                    decoder_depth = args.decoder_depth, decoder_num_heads = args.decoder_num_heads,
                                    decoder_output_dim = args.decoder_output_dim, batch_size = args.batch_size,
                                    vit_encoder_mlp_dim = args.vit_encoder_mlp_dim, vit_decoder_mlp_dim = args.vit_decoder_mlp_dim,
                                    ).to(device)

    # The NeRF MLPs are built with latent_dim equal to the encoder's output width.
    latent_dim = latent_embed.decoder_output_dim

    # Encoder and NeRF MLP parameters are collected into one optimizer.
    grad_vars = list(latent_embed.parameters())

    model = NeRF(D=args.netdepth, W=args.netwidth,
                 input_ch=input_ch, output_ch=output_ch, skips=skips,
                 input_ch_views=input_ch_views, use_viewdirs=args.use_viewdirs,
                 latent_dim=latent_dim,
                 ).to(device)
    grad_vars += list(model.parameters())

    model_fine = None
    if args.N_importance > 0:
        model_fine = NeRF(D=args.netdepth_fine, W=args.netwidth_fine,
                          input_ch=input_ch, output_ch=output_ch, skips=skips,
                          input_ch_views=input_ch_views, use_viewdirs=args.use_viewdirs,
                          latent_dim=latent_dim,
                          ).to(device)
        grad_vars += list(model_fine.parameters())

    network_query_fn = lambda inputs, viewdirs, network_fn, latent : run_network(inputs, viewdirs, network_fn,
                                                                embed_fn=embed_fn,
                                                                embeddirs_fn=embeddirs_fn,
                                                                netchunk=args.netchunk,
                                                                latent = latent)

    # Create optimizer
    optimizer = torch.optim.AdamW(params=grad_vars, lr=args.lrate, betas=(0.9, 0.95))

    start = 0

    # Load checkpoints. map_location=device lets checkpoints saved on a GPU be
    # restored on a CPU-only machine (torch.load otherwise recreates the saved device).
    if args.ft_path is not None and args.ft_path != 'None':
        ckpts = [args.ft_path]
        # Encoder weights sit beside the network checkpoint: the last 4 characters
        # of ft_path (e.g. '.tar') are replaced by '_encoder.tar'.
        encoder_path = args.ft_path[:-4] + '_encoder.tar'
        print('Reloading from', encoder_path)
        latent_embed.load_state_dict(torch.load(encoder_path, map_location=device))
    else:
        exp_dir = os.path.join(basedir, expname)
        # The lexicographically last file is treated as the latest checkpoint;
        # presumably filenames use zero-padded step numbers — TODO confirm.
        tar_names = [f for f in sorted(os.listdir(exp_dir)) if 'tar' in f]
        ckpts = [os.path.join(exp_dir, f) for f in tar_names if 'encoder' not in f]
        ckpts_latent = [os.path.join(exp_dir, f) for f in tar_names if 'encoder' in f]
        if len(ckpts_latent) > 0 and not args.no_reload:
            ckpt_latent_path = ckpts_latent[-1]
            print('Reloading from', ckpt_latent_path)
            latent_embed.load_state_dict(torch.load(ckpt_latent_path, map_location=device))

    print('Found ckpts', ckpts)
    if len(ckpts) > 0 and not args.no_reload:
        ckpt_path = ckpts[-1]
        print('Reloading from', ckpt_path)
        ckpt = torch.load(ckpt_path, map_location=device)

        start = ckpt['global_step']
        optimizer.load_state_dict(ckpt['optimizer_state_dict'])

        # Load model
        model.load_state_dict(ckpt['network_fn_state_dict'])
        if model_fine is not None:
            model_fine.load_state_dict(ckpt['network_fine_state_dict'])

    render_kwargs_train = {
        'network_query_fn' : network_query_fn,
        'perturb' : args.perturb,
        'N_importance' : args.N_importance,
        'network_fine' : model_fine,
        'N_samples' : args.N_samples,
        'network_fn' : model,
        'use_viewdirs' : args.use_viewdirs,
        'white_bkgd' : args.white_bkgd,
        'raw_noise_std' : args.raw_noise_std,
    }

    # NDC only good for LLFF-style forward facing data
    if args.dataset_type != 'llff' or args.no_ndc:
        print('Not ndc!')
        render_kwargs_train['ndc'] = False
        render_kwargs_train['lindisp'] = args.lindisp

    # Test-time rendering is deterministic: no stratified jitter, no density noise.
    render_kwargs_test = {k : render_kwargs_train[k] for k in render_kwargs_train}
    render_kwargs_test['perturb'] = False
    render_kwargs_test['raw_noise_std'] = 0.

    return render_kwargs_train, render_kwargs_test, start, grad_vars, optimizer, latent_embed


def raw2outputs(raw, z_vals, rays_d, raw_noise_std=0, white_bkgd=False, pytest=False, test_mode=None):
    """Transforms model's predictions to semantically meaningful values.
    Args:
        raw: [num_rays, num_samples along ray, 4]. Prediction from model.
        z_vals: [num_rays, num_samples along ray]. Integration time.
        rays_d: [num_rays, 3]. Direction of each ray.
        raw_noise_std: std of Gaussian noise added to the raw density before
            the ReLU; 0 disables it.
        white_bkgd: if True, composite the result over a white background.
        pytest: if True (and noise is enabled), replace the noise with
            NumPy uniform noise drawn after np.random.seed(0).
        test_mode: accepted for interface compatibility; unused.
    Returns:
        rgb_map: [num_rays, 3]. Estimated RGB color of a ray.
        disp_map: [num_rays]. Disparity map. Inverse of depth map.
        depth_map: [num_rays]. Estimated distance to object.
        acc_map: [num_rays]. Sum of weights along each ray.
        weights: [num_rays, num_samples]. Weights assigned to each sampled color.
    """
    # Allocate helper tensors on the inputs' device rather than a global one.
    dev = raw.device

    def raw2alpha(sigma, dists, act_fn=F.relu):
        return 1. - torch.exp(-act_fn(sigma) * dists)

    dists = z_vals[...,1:] - z_vals[...,:-1]
    # The last interval is effectively infinite.
    dists = torch.cat([dists, torch.tensor([1e10], device=dev).expand(dists[...,:1].shape)], -1)  # [N_rays, N_samples]

    # Scale by the direction norm to turn parametric steps into world distances.
    dists = dists * torch.norm(rays_d[...,None,:], dim=-1)

    rgb = torch.sigmoid(raw[...,:3])  # [N_rays, N_samples, 3]
    noise = 0.
    if raw_noise_std > 0.:
        noise = torch.randn(raw[...,3].shape, device=dev) * raw_noise_std

        # Overwrite randomly sampled data if pytest
        if pytest:
            np.random.seed(0)
            noise = np.random.rand(*list(raw[...,3].shape)) * raw_noise_std
            # torch.Tensor(ndarray, device=<cuda>) is rejected by the legacy
            # constructor; as_tensor accepts any device.
            noise = torch.as_tensor(noise, dtype=raw.dtype, device=dev)

    alpha = raw2alpha(raw[...,3] + noise, dists)  # [N_rays, N_samples]
    # Transmittance = exclusive cumulative product of (1 - alpha); +1e-10 avoids exact zeros.
    weights = alpha * torch.cumprod(torch.cat([torch.ones((alpha.shape[0], 1), device=dev), 1.-alpha + 1e-10], -1), -1)[:, :-1]
    rgb_map = torch.sum(weights[...,None] * rgb, -2)  # [N_rays, 3]

    depth_map = torch.sum(weights * z_vals, -1)
    acc_map = torch.sum(weights, -1)
    # Clamp the denominator: a ray with zero accumulated weight used to give
    # 0/0 = NaN disparity here. Unchanged whenever acc_map >= 1e-10.
    disp_map = 1./torch.max(1e-10 * torch.ones_like(depth_map), depth_map / torch.clamp(acc_map, min=1e-10))

    if white_bkgd:
        rgb_map = rgb_map + (1.-acc_map[...,None])

    return rgb_map, disp_map, depth_map, acc_map, weights


def _tile_latent(latent, n_rays, n_pts):
    """Repeat latent code(s) for every sample point along each ray.

    Supports a per-ray latent with latent.shape[0] == n_rays (tiled over the
    sample axis) or a single shared latent of shape [1, dim] (tiled over rays
    and samples). Any other layout raises NotImplementedError.
    """
    if latent.shape[0] == n_rays:
        return torch.tile(latent[:, None, :], (1, n_pts, 1))  # [B, dim] -> [B, n_pts, dim]
    if latent.shape[0] == 1 and latent.dim() == 2:
        # Test-time render: one latent shared by all rays -> [n_rays, n_pts, dim]
        return torch.tile(latent[:, None, :], (n_rays, n_pts, 1))
    raise NotImplementedError


def render_rays(ray_batch,
                network_fn,
                network_query_fn,
                N_samples,
                retraw=False,
                lindisp=False,
                perturb=0.,
                N_importance=0,
                network_fine=None,
                white_bkgd=False,
                raw_noise_std=0.,
                verbose=False,
                pytest=False,
                latent=None,
                args=None,
                test_mode=None,
                ):
    """Volumetric rendering.
    Args:
      ray_batch: array of shape [batch_size, ...]. All information necessary
        for sampling along a ray, including: ray origin, ray direction, min
        dist, max dist, and unit-magnitude viewing direction.
      network_fn: function. Model for predicting RGB and density at each point
        in space.
      network_query_fn: function used for passing queries to network_fn.
      N_samples: int. Number of different times to sample along each ray.
      retraw: bool. If True, include model's raw, unprocessed predictions.
      lindisp: bool. If True, sample linearly in inverse depth rather than in depth.
      perturb: float, 0 or 1. If non-zero, each ray is sampled at stratified
        random points in time.
      N_importance: int. Number of additional times to sample along each ray.
        These samples are only passed to network_fine.
      network_fine: "fine" network with same spec as network_fn.
      white_bkgd: bool. If True, assume a white background.
      raw_noise_std: std of noise added to the raw density (see raw2outputs).
      verbose: bool. Accepted for compatibility; unused.
      pytest: bool. If True, use seeded NumPy random numbers for reproducibility.
      latent: per-ray latent [batch_size, dim] or shared latent [1, dim].
      args: accepted for interface compatibility; unused here.
      test_mode: forwarded to raw2outputs.
    Returns:
      rgb_map: [num_rays, 3]. Estimated RGB color of a ray. Comes from fine model.
      disp_map: [num_rays]. Disparity map. 1 / depth.
      depth_map: [num_rays]. Expected depth. Comes from fine model.
      acc_map: [num_rays]. Accumulated opacity along each ray. Comes from fine model.
      raw: [num_rays, num_samples, 4]. Raw predictions from model.
      rgb0: See rgb_map. Output for coarse model.
      disp0: See disp_map. Output for coarse model.
      depth0: See depth_map. Output for coarse model.
      acc0: See acc_map. Output for coarse model.
      z_std: [num_rays]. Standard deviation of distances along ray for each
        sample.
    """
    N_rays = ray_batch.shape[0]
    rays_o, rays_d = ray_batch[:,0:3], ray_batch[:,3:6] # [N_rays, 3] each
    viewdirs = ray_batch[:,-3:] if ray_batch.shape[-1] > 8 else None

    bounds = torch.reshape(ray_batch[...,6:8], [-1,1,2])
    near, far = bounds[...,0], bounds[...,1] # [-1,1]

    t_vals = torch.linspace(0., 1., steps=N_samples, device=device)
    if not lindisp:
        z_vals = near * (1.-t_vals) + far * (t_vals)
    else:
        z_vals = 1./(1./near * (1.-t_vals) + 1./far * (t_vals))

    z_vals = z_vals.expand([N_rays, N_samples])

    if perturb > 0.:
        # get intervals between samples
        mids = .5 * (z_vals[...,1:] + z_vals[...,:-1])
        upper = torch.cat([mids, z_vals[...,-1:]], -1)
        lower = torch.cat([z_vals[...,:1], mids], -1)
        # stratified samples in those intervals
        t_rand = torch.rand(z_vals.shape, device=device)

        # Pytest, overwrite u with numpy's fixed random numbers
        if pytest:
            np.random.seed(0)
            t_rand = np.random.rand(*list(z_vals.shape))
            # torch.Tensor(ndarray, device=<cuda>) is rejected by the legacy constructor.
            t_rand = torch.as_tensor(t_rand, dtype=z_vals.dtype, device=device)

        z_vals = lower + (upper - lower) * t_rand

    pts = rays_o[...,None,:] + rays_d[...,None,:] * z_vals[...,:,None] # [N_rays, N_samples, 3]

    tiled_latent = _tile_latent(latent, rays_o.shape[0], N_samples)
    raw = network_query_fn(pts, viewdirs, network_fn, tiled_latent)
    rgb_map, disp_map, depth_map, acc_map, weights = raw2outputs(raw, z_vals, rays_d, raw_noise_std, white_bkgd, pytest=pytest, test_mode=test_mode)

    if N_importance > 0:

        rgb_map_0, disp_map_0, depth_map_0, acc_map_0 = rgb_map, disp_map, depth_map, acc_map

        # Hierarchical sampling: draw extra depths from the coarse weights.
        z_vals_mid = .5 * (z_vals[...,1:] + z_vals[...,:-1])
        z_samples = sample_pdf(z_vals_mid, weights[...,1:-1], N_importance, device=device, det=(perturb==0.), pytest=pytest)
        z_samples = z_samples.detach()

        z_vals, _ = torch.sort(torch.cat([z_vals, z_samples], -1), -1)
        pts = rays_o[...,None,:] + rays_d[...,None,:] * z_vals[...,:,None] # [N_rays, N_samples + N_importance, 3]

        run_fn = network_fn if network_fine is None else network_fine
        tiled_latent = _tile_latent(latent, rays_o.shape[0], N_samples + N_importance)
        raw = network_query_fn(pts, viewdirs, run_fn, tiled_latent)
        rgb_map, disp_map, depth_map, acc_map, weights = raw2outputs(raw, z_vals, rays_d, raw_noise_std, white_bkgd, pytest=pytest,
                                                                                   test_mode=test_mode)

    ret = {'rgb_map' : rgb_map, 'disp_map' : disp_map, 'depth_map' : depth_map, 'acc_map' : acc_map}

    if retraw:
        ret['raw'] = raw
    if N_importance > 0:
        ret['rgb0'] = rgb_map_0
        ret['disp0'] = disp_map_0
        ret['depth0'] = depth_map_0
        ret['acc0'] = acc_map_0
        ret['z_std'] = torch.std(z_samples, dim=-1, unbiased=False)  # [N_rays]

    # Check DEBUG first: evaluating .any() on device tensors forces a host sync,
    # which previously ran on every chunk even with DEBUG disabled.
    if DEBUG:
        for k in ret:
            if torch.isnan(ret[k]).any() or torch.isinf(ret[k]).any():
                print(f"! [Numerical Error] {k} contains nan or inf.")

    return ret


def config_parser():
    """Build the command-line / config-file argument parser for training and rendering.

    Returns:
      A configargparse.ArgumentParser; `--config` points at a config file whose
      entries supply values for the other options.
    """
    import configargparse
    parser = configargparse.ArgumentParser()
    parser.add_argument('--config', is_config_file=True,
                        help='config file path')
    parser.add_argument("--expname", type=str,
                        help='experiment name')
    parser.add_argument("--basedir", type=str, default='./logs/',
                        help='where to store ckpts and logs')
    parser.add_argument("--datadir", type=str, default='./data/llff/fern',
                        help='input data directory (train() overrides it with the path stored in '
                             './configs/<dataset_type>_dataset_path.txt)')

    # training options
    parser.add_argument("--time_interval", type=int, default=4,
                        help='time interval of the multi view input images')
    parser.add_argument("--img_size", type=int, default=128,
                        help='input image size')
    parser.add_argument("--patch_size", type=int, default=16,
                        help='size of each patch')
    parser.add_argument("--embed_dim", type=int, default=256,
                        help='embedding dimension of each patch')
    parser.add_argument("--mask_ratio", type=float, default=0.9,
                        help='mask_ratio')
    parser.add_argument("--vit_depth", type=int, default=8,
                        help='depth of the ViT encoder (MaskedViTEncoder depth)')
    parser.add_argument("--vit_num_heads", type=int, default=4,
                        help='number of attention heads of the ViT encoder (MaskedViTEncoder num_heads)')
    parser.add_argument("--num_view", type=int, default=3,
                        help='number of view')
    parser.add_argument("--batch_size",   type=int, default=16)
    parser.add_argument("--decoder_depth",   type=int, default=2)
    parser.add_argument("--decoder_num_heads",   type=int, default=2)
    parser.add_argument("--decoder_output_dim",   type=int, default=256)
    parser.add_argument("--log_freq",   type=int, default=10)
    parser.add_argument("--vit_encoder_mlp_dim", type=int, default=2048)
    parser.add_argument("--vit_decoder_mlp_dim", type=int, default=2048)
    parser.add_argument("--enc_contrastive_margin", type=float, default=2.0)
    parser.add_argument("--netdepth", type=int, default=8,
                        help='layers in network')
    parser.add_argument("--netwidth", type=int, default=256,
                        help='channels per layer')
    parser.add_argument("--netdepth_fine", type=int, default=8,
                        help='layers in fine network')
    parser.add_argument("--netwidth_fine", type=int, default=256,
                        help='channels per layer in fine network')
    parser.add_argument("--N_rand", type=int, default=32*32*4,
                        help='batch size (number of random rays per gradient step)')
    parser.add_argument("--lrate", type=float, default=5e-4,
                        help='learning rate')
    parser.add_argument("--lrate_decay", type=int, default=250,
                        help='exponential learning rate decay (in 1000 steps)')
    parser.add_argument("--chunk", type=int, default=1024*32,
                        help='number of rays processed in parallel, decrease if running out of memory')
    parser.add_argument("--netchunk", type=int, default=1024*64,
                        help='number of pts sent through network in parallel, decrease if running out of memory')
    parser.add_argument("--no_reload", action='store_true',
                        help='do not reload weights from saved ckpt')
    parser.add_argument("--ft_path", type=str, default=None,
                        help='specific .tar checkpoint to reload; encoder weights are read from the same '
                             'path with its last 4 characters replaced by "_encoder.tar"')

    # rendering options
    parser.add_argument("--N_samples", type=int, default=64,
                        help='number of coarse samples per ray')
    parser.add_argument("--N_importance", type=int, default=0,
                        help='number of additional fine samples per ray')
    parser.add_argument("--perturb", type=float, default=1.,
                        help='set to 0. for no jitter, 1. for jitter')
    parser.add_argument("--use_viewdirs", action='store_true',
                        help='use full 5D input instead of 3D')
    parser.add_argument("--i_embed", type=int, default=0,
                        help='set 0 for default positional encoding, -1 for none')
    parser.add_argument("--multires", type=int, default=10,
                        help='log2 of max freq for positional encoding (3D location)')
    parser.add_argument("--multires_views", type=int, default=4,
                        help='log2 of max freq for positional encoding (2D direction)')
    parser.add_argument("--raw_noise_std", type=float, default=0.,
                        help='std dev of noise added to regularize sigma_a output, 1e0 recommended')

    parser.add_argument("--render_test", action='store_true',
                        help='render the test set instead of render_poses path')
    parser.add_argument("--render_factor", type=int, default=0,
                        help='downsampling factor to speed up rendering, set 4 or 8 for fast preview')

    # central-crop training options
    parser.add_argument("--precrop_iters", type=int, default=0,
                        help='number of steps to train on central crops')
    parser.add_argument("--precrop_frac", type=float,
                        default=.5, help='fraction of img taken for central crops')

    # dataset options
    parser.add_argument("--dataset_type", type=str, default='llff',
                        help='options: hammer / drawer / window / push / peg / stick '
                             '(train() rejects any other value, including the default)')
    parser.add_argument("--testskip", type=int, default=8,
                        help='will load 1/N images from test/val sets, useful for large datasets like deepvoxels')

    ## deepvoxels flags
    parser.add_argument("--shape", type=str, default='greek',
                        help='options : armchair / cube / greek / vase')

    ## blender flags
    parser.add_argument("--white_bkgd", action='store_true',
                        help='set to render synthetic data on a white bkgd (always use for dvoxels)')
    parser.add_argument("--half_res", action='store_true',
                        help='load blender synthetic data at 400x400 instead of 800x800')

    ## llff flags
    parser.add_argument("--factor", type=int, default=8,
                        help='downsample factor for LLFF images')
    parser.add_argument("--no_ndc", action='store_true',
                        help='do not use normalized device coordinates (set for non-forward facing scenes)')
    parser.add_argument("--lindisp", action='store_true',
                        help='sampling linearly in disparity rather than depth')
    parser.add_argument("--spherify", action='store_true',
                        help='set for spherical 360 scenes')
    parser.add_argument("--llffhold", type=int, default=8,
                        help='will take every 1/N images as LLFF test set, paper uses 8')

    # logging/saving options
    parser.add_argument("--i_print",   type=int, default=100,
                        help='frequency of console printout and metric logging')
    parser.add_argument("--i_img",     type=int, default=500,
                        help='frequency of tensorboard image logging')
    parser.add_argument("--i_weights", type=int, default=10000,
                        help='frequency of weight ckpt saving')
    parser.add_argument("--i_testset", type=int, default=50000,
                        help='frequency of testset saving')
    parser.add_argument("--i_video",   type=int, default=50000,
                        help='frequency of render_poses video saving')
    parser.add_argument("--episode_num",   type=int, default=100,
                        help='epi num for mae data')
    parser.add_argument("--N_iters",   type=int, default=300001,
                        help='training iter')
    return parser


# Text file holding the dataset root directory, per supported dataset type.
_DATASET_PATH_FILES = {
    'hammer': './configs/hammer_dataset_path.txt',
    'drawer': './configs/drawer_dataset_path.txt',
    'window': './configs/window_dataset_path.txt',
    'push': './configs/push_dataset_path.txt',
    'peg': './configs/peg_dataset_path.txt',
    'stick': './configs/stick_dataset_path.txt',
}

# (near, far) ray-sampling bounds per dataset type.
_SCENE_BOUNDS = {
    'hammer': (0.02258, 3.),
    'push': (0.02258, 3.),
    'window': (0.02258, 3.),
    'stick': (0.02258, 3.),
    'peg': (0.02258, 3.),
    'drawer': (0.02343, 3.),
}

# Object-focused ray sampling: for each dataset type, an ordered sequence of
# (semantic_label, divisor) pairs; up to ``N_rand // divisor`` rays are drawn
# from pixels carrying that semantic label.
_OBJECT_RAY_QUOTAS = {
    'stick': ((3, 2),),
    'hammer': ((3, 3), (5, 6)),
    'push': ((4, 2),),
    'window': ((3, 3), (4, 6)),
    'drawer': ((3, 3), (4, 6)),
    'peg': ((3, 6), (4, 3)),
}


def _sample_object_ray_indices(semantics_flat, dataset_type, N_rand):
    """Return indices of rays that land on the labelled objects of ``dataset_type``.

    For every (label, divisor) entry of ``_OBJECT_RAY_QUOTAS[dataset_type]``,
    draws ``N_rand // divisor`` distinct indices (without replacement) among the
    positions of the 1-D ``semantics_flat`` equal to ``label``; when fewer
    positions carry the label, all of them are taken instead. The per-label
    picks are concatenated in table order.

    Raises:
        NotImplementedError: if ``dataset_type`` has no quota entry.
    """
    if dataset_type not in _OBJECT_RAY_QUOTAS:
        raise NotImplementedError
    # Single host copy instead of one per label comparison.
    semantics_np = semantics_flat.cpu().numpy()
    picked = []
    for label, divisor in _OBJECT_RAY_QUOTAS[dataset_type]:
        candidates = np.argwhere(semantics_np == label)[:, 0]
        quota = N_rand // divisor
        # np.random.choice(replace=False) cannot draw more than it has, so
        # fall back to every labelled pixel when the population is too small.
        if len(candidates) >= quota:
            candidates = np.random.choice(candidates, quota, replace=False)
        picked.append(candidates)
    return np.concatenate(picked)


def train():
    """Train the latent-conditioned NeRF decoder jointly with the multi-view ViT encoder.

    Parses options via ``config_parser``, loads the episodic multi-view dataset
    with ``load_mae_data``, then for each iteration up to ``args.N_iters``:
      * samples a batch of time windows plus negative windows (``get_batch_images``),
      * encodes one randomly chosen input view (mask ratio ``args.mask_ratio``)
        together with ``V // 3`` reference views encoded with mask ratio 0,
      * renders rays from the non-reference views at the last time step,
        conditioned on the encoder latent, and
      * minimizes photometric MSE plus ``0.0004`` times a margin-based
        contrastive loss between the input-view feature (anchor), the
        reference feature (positive) and the negative window's feature.
    Checkpoints, test renders and TensorBoard scalars are written under
    ``args.basedir/<dataset_type>/<timestamp>_<expname>``.

    Raises:
        NotImplementedError: if ``args.dataset_type`` is not a supported dataset.
    """

    parser = config_parser()
    args = parser.parse_args()
    if args.dataset_type not in _DATASET_PATH_FILES:
        raise NotImplementedError
    # The dataset root is stored in a small per-dataset text file.
    with open(_DATASET_PATH_FILES[args.dataset_type], 'r') as path_file:
        args.datadir = path_file.read().strip()

    # Load data
    K = None
    images, poses, render_poses, hwf, i_split, semantics, depths = load_mae_data(args.datadir, args.half_res, args.testskip, args.episode_num, args.num_view, dataset_type=args.dataset_type)
    print(f'Loaded {args.dataset_type} data', images.shape, render_poses.shape, hwf, args.datadir)
    i_train = i_split

    i_test = [0]

    near, far = _SCENE_BOUNDS[args.dataset_type]

    if args.white_bkgd and images.shape[-1] == 4:
        # Alpha-composite RGBA onto a white background.
        images = images[...,:3]*images[...,-1:] + (1.-images[...,-1:])
    else:
        images = images[...,:3]

    # Original images, poses, render_poses shapes:
    #   images:       [V, NumEpi*Length, H, W, C]
    #   poses:        [V, NumEpi*Length, 4, 4]
    #   render_poses: [num_render_poses(40), 4, 4]
    #   depths:       [V, NumEpi*Length, H, W]
    V, NL, H, W, C = images.shape

    print(f'episode num : {args.episode_num}, time interval : {args.time_interval}, mask ratio : {args.mask_ratio}, lrdecay : {args.lrate_decay}')
    print("Currently, assume all episode trajectory length is same as 120!")
    episode_length = int(NL/args.episode_num)
    images = images.reshape(V, args.episode_num, episode_length, H, W, C) # [V, episode_num, length, H, W, C],
    images = torch.permute(images, (1,2,3,0,4,5)) # [episode_num, length, H, V, W, C]
    semantics = semantics.reshape(V, args.episode_num, episode_length, H, W) # [V, NL, H, W] -> [V, episode_num, length, H, W]
    semantics = torch.permute(semantics, (1,2,3,0,4)) # [episode_num, length, H, V, W]

    # Cast intrinsics to right types
    H, W, focal = hwf
    H, W = int(H), int(W)
    hwf = [H, W, focal]

    if K is None:
        K = np.array([
            [focal, 0, 0.5*W],
            [0, focal, 0.5*H],
            [0, 0, 1]
        ])

    if args.render_test:
        render_poses = np.array(poses[:, i_test])

    # Create log dir and copy the config file
    basedir = args.basedir

    timestamp = datetime.datetime.now().strftime("%Y.%m.%d/%H%M%S")
    expname = args.dataset_type+'/'+timestamp +'_'+ args.expname

    logdir = os.path.join(basedir, expname)
    tb_writer = SummaryWriter(logdir)

    os.makedirs(os.path.join(basedir, expname), exist_ok=True)
    f = os.path.join(basedir, expname, 'args.txt')
    with open(f, 'w') as file:
        for arg in sorted(vars(args)):
            attr = getattr(args, arg)
            file.write('{} = {}\n'.format(arg, attr))
    if args.config is not None:
        f = os.path.join(basedir, expname, 'config.txt')
        with open(f, 'w') as file, open(args.config, 'r') as config_file:
            file.write(config_file.read())

    # Create nerf model
    render_kwargs_train, render_kwargs_test, start, grad_vars, optimizer, latent_embed = create_nerf(args, basedir, expname)
    global_step = start

    bds_dict = {
        'near' : near,
        'far' : far,
    }
    render_kwargs_train.update(bds_dict)
    render_kwargs_test.update(bds_dict)

    # Move testing data to GPU
    render_poses = torch.Tensor(render_poses).to(device)

    # Prepare raybatch tensor if batching random rays
    N_rand = args.N_rand
    time_interval = args.time_interval

    print(f'N rand : {N_rand} images shape : {images.shape} pose shape : {poses.shape}') # semantics shape : {semantics.shape}') # pose : [V, NL, 4,4]

    # For random ray batching; rays come from the first frame's pose of each view.
    print('get rays')
    rays = np.stack([get_rays_np(H, W, K, p) for p in poses[:,0,:3,:4].numpy()], 0) # [V, ro+rd(2), H, W, 3]

    i_batch = 0

    N_iters = args.N_iters
    print('Begin')
    print('TRAIN views are', i_train)
    print('TEST views are', i_test)

    start = start + 1
    for i in trange(start, N_iters):
        time0 = time.time()
        logging_dict = {}

        batch_images, batch_semantics, batch_negatives = \
            get_batch_images(images, args.batch_size, args.time_interval,
                                        semantics=semantics, device=device)
        B, T, H, V, W, C = batch_images.shape
        batch_images_for_ViT, reshaped_batch_images, batch_semantics_for_ViT, reshaped_batch_semantics \
            = get_batch_images_for_mae_encoder(batch_images, batch_semantics)

        # tile along the batch
        tiled_rays = torch.tile(torch.from_numpy(rays[None, :]), (args.batch_size, 1, 1, 1, 1, 1)).to(device)
        tiled_rays = tiled_rays.reshape(-1, 2, H, W, 3) # [BV, ro+rd(2), H, W, 3]

        m = args.mask_ratio

        batch_images_for_ViT = torch.split(batch_images_for_ViT, W, dim = 3) # [B, T, H, W, C] * V
        num_ref_view = V // 3
        assert num_ref_view == (args.num_view // 3) # if you get error in this line, please modify self.decoder_img_tokens in MV_mae_encoder.py file
        # First pick is the input view; the remaining picks are reference views.
        view_index = np.random.choice(V, size=1+num_ref_view, replace=False)
        ref_view_index = view_index[1:].tolist()
        view_index = int(view_index[0])
        ref_batch_images_for_ViT = torch.cat([batch_images_for_ViT[index] for index in ref_view_index], dim=3) # [B, T, H, ref_V*W, C]

        batch_images_for_ViT = batch_images_for_ViT[view_index] # [B, T, H, W, C]

        if batch_negatives is not None:
            negative_view_index = view_index
            negative_ref_view_index = ref_view_index
            negative_ref_batch_images_for_ViT = torch.cat([batch_negatives[:,:,:,index] for index in negative_ref_view_index], dim=3) # [B, T, H, ref_V*W, C]
            negative_batch_images_for_ViT = batch_negatives[:,:,:,negative_view_index].to(device)   # [B, T, H, W, C]

        # Forward ViT encoder
        # input view
        latent, mask, ids_restore = latent_embed.SinCro_image_encoder(batch_images_for_ViT, m, time_interval, is_ref=False)
        # reference views; masking ratio = 0
        assert ref_batch_images_for_ViT.shape[3] != args.img_size
        ref_batch_images_for_ViT = torch.split(ref_batch_images_for_ViT, args.img_size, dim=3)  # [B, T, H, W, C] * V
        ref_time_interval = args.time_interval

        ref_batch_images_for_ViT = torch.cat(ref_batch_images_for_ViT, dim=0) # [VB, T, H, W, C]
        ref, _, _ = latent_embed.SinCro_image_encoder(ref_batch_images_for_ViT, 0, ref_time_interval, is_ref=True)
        ref = rearrange(ref[:,1:,:], 'b (ref_T hw) d -> b ref_T hw d', ref_T=ref_time_interval)[:,-1]   # [V*B, H'W', embed]
        ref = rearrange(ref, '(v b) hw d -> b (v hw) d', b=latent.shape[0]) # [B, VH'W', embed]

        # Forward ViT decoder
        # ref shape: [B, (ref_T=1)VH'W', embed]
        latent, mask, ids_restore = latent_embed.SinCro_state_encoder(latent, ref, mask, ids_restore)

        # Read right after the state-encoder call above; the negative pass
        # below overwrites latent_embed.input_feature.
        anchor_latent = latent_embed.input_feature   # [B, feat_dim]
        positive_latent = latent_embed.ref_feature  # [B, feat_dim]

        with torch.no_grad():
            negative_latent, negative_mask, negative_ids_restore = latent_embed.SinCro_image_encoder(negative_batch_images_for_ViT,
                                                                                                m, time_interval, is_ref=False)
            # reference views; masking ratio = 0
            assert negative_ref_batch_images_for_ViT.shape[3] != args.img_size
            negative_ref_batch_images_for_ViT = torch.split(negative_ref_batch_images_for_ViT, args.img_size, dim=3)  # [B, T, H, W, C] * V
            negative_ref_batch_images_for_ViT = torch.cat(negative_ref_batch_images_for_ViT, dim=0)   # [VB, T, H, W, C]
            negative_ref_T = negative_ref_batch_images_for_ViT.shape[1]
            negative_ref, _, _ = latent_embed.SinCro_image_encoder(negative_ref_batch_images_for_ViT, 0, ref_time_interval, is_ref=True)

            # Forward ViT decoder
            # [VB, ref_TH'W', embed] when B, T=1, V=(1/3)*total_V  -->  [B, ref_TVH'W', embed]
            negative_ref = rearrange(negative_ref[:,1:,:], 'b (ref_T hw) d -> b ref_T hw d', ref_T=negative_ref_T)[:,-1]   # [V_ref*B, H'W', embed]
            negative_ref = rearrange(negative_ref, '(v b) hw d -> b (v hw) d', b=negative_latent.shape[0]) # [B, VH'W', embed]
            negative_latent, _, _ = latent_embed.SinCro_state_encoder(negative_latent, negative_ref, negative_mask, negative_ids_restore)
            negative_latent = latent_embed.input_feature    # [B, feat_dim]

        latent = latent.reshape((B, T, 1, -1))  # [B, T, V=1, feat_dim]
        latent = repeat(latent, 'b t v d -> b t (v mv) d', mv=V)    # [B, T, V, feat_dim]
        latent = latent.permute(1,0,2,3) # [T,B,V,dim]
        latent = latent.reshape(T, B*V, -1) # [T,BV,dim]
        assert latent.shape[-1]==render_kwargs_train['network_fn'].latent_dim

        loss = 0

        # Forward NeRF decoder (only the last time step of the window is rendered)
        for t in range(args.time_interval):
            if t < (args.time_interval-1):
                continue
            images_at_t = reshaped_batch_images[t][:, None] # [BV, 1, H, W, C]
            rays_rgb = torch.cat([tiled_rays, images_at_t], 1)
            remain_view_index = np.delete(np.arange(V), ref_view_index)
            rays_rgb = rays_rgb.reshape(B,V,3,H,W,C)[:,remain_view_index] # [B, V-ref_V, ro+rd+rgb, H, W, 3]
            rays_rgb = rays_rgb.reshape(-1,3,H,W,C) # [B*(V-ref_V), ro+rd+rgb, H, W, 3]

            if reshaped_batch_semantics is not None:
                semantics_at_t = reshaped_batch_semantics[t] # [BV, H, W]
                semantics_at_t = semantics_at_t.reshape(B,V,H,W)[:,remain_view_index]   # [B, V-ref_V, H, W]
                semantics_at_t = semantics_at_t.reshape(-1,H,W) # [B*(V-ref_V), H, W]

            if i < args.precrop_iters:
                dH = int(H//2 * args.precrop_frac)
                dW = int(W//2 * args.precrop_frac)
                if i==start:
                    print(f"[Config] Center cropping of size {2*dH} x {2*dW} is enabled until iter {args.precrop_iters}")
                rays_rgb = rays_rgb[:, :, H//2-dH : H//2+dH, W//2-dW : W//2+dW] # [BV, ro+rd+rgb, reduced_H, reduced_W, 3]
                reduced_H = rays_rgb.shape[2]
                reduced_W = rays_rgb.shape[3]
                if reshaped_batch_semantics is not None:
                    semantics_at_t = semantics_at_t[:, H//2-dH : H//2+dH, W//2-dW : W//2+dW] # [BV, reduced_H, reduced_W]

            rays_rgb = torch.permute(rays_rgb, [0,2,3,1,4]) # [BV, H, W, ro+rd+rgb, 3]
            rays_rgb = rays_rgb.reshape(-1,3,3).float() # [BV*H*W, ro+rd+rgb, 3]

            if reshaped_batch_semantics is not None:
                semantics_at_t = semantics_at_t.reshape(-1) # [BV*H*W]

            # Oversample rays on the task-relevant objects, fill the rest uniformly.
            object_rays_indices = _sample_object_ray_indices(semantics_at_t, args.dataset_type, N_rand)
            random_shuffle_indices = np.random.randint(rays_rgb.shape[0], size=N_rand-len(object_rays_indices))
            random_shuffle_indices = np.concatenate((object_rays_indices, random_shuffle_indices))
            rays_rgb = rays_rgb[random_shuffle_indices]
            batch = rays_rgb

            batch = torch.transpose(batch, 0, 1) # [ro+rd+rgb, B', 3]
            batch_rays, target_s = batch[:2], batch[2] # [ro+rd, B', dim], [rgb, B', C]

            if i < args.precrop_iters:
                tile_H, tile_W = reduced_H, reduced_W
            else:
                tile_H, tile_W = H, W
            # Gather the per-ray latent with the same indices used for the rays.
            latent_at_t = torch.tile(latent[t][:, None, :], (1, tile_H*tile_W, 1)) # [BV,dim] -> [BV, H*W,dim]
            latent_at_t = latent_at_t.reshape(B, V, tile_H*tile_W,
                                                render_kwargs_train['network_fn'].latent_dim)[:, remain_view_index]   # [BV, H*W,dim] -> [B, V, H*W, dim] -> [B, V', H*W, dim]
            latent_at_t = latent_at_t.reshape(-1, render_kwargs_train['network_fn'].latent_dim)[random_shuffle_indices]  # [B, V', H*W, dim] -> [BV'HW, dim] -> [B' (from BV'HW), dim]

            #####  Core optimization loop  #####
            rgb, disp, depth, acc, extras = render(H, W, K, chunk=args.chunk, rays=batch_rays,
                                                    verbose=i < 10, retraw=True, latent=latent_at_t, args=args,
                                                    **render_kwargs_train)

            img_loss = img2mse(rgb, target_s)

            trans = extras['raw'][...,-1]

            psnr = mse2psnr(img_loss, device)
            img_loss0 = 0
            if 'rgb0' in extras:
                img_loss0 = img2mse(extras['rgb0'], target_s)
                psnr0 = mse2psnr(img_loss0, device)

            # Margin-based contrastive loss: pull anchor toward positive, push from negative.
            d_positive = distance(anchor_latent, positive_latent)
            d_negative = distance(anchor_latent, negative_latent)
            contrastive_loss = torch.clamp(args.enc_contrastive_margin + d_positive - d_negative, min=0.0).mean()
            loss += 0.0004*contrastive_loss
            logging_dict.update({'contrastive_loss': contrastive_loss.item()})

            loss += img_loss + img_loss0

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        logging_dict.update({'loss' : loss.item(),
                            'psnr' : psnr.item(),
                            'img_loss' : img_loss.item(),
                            })
        if 'rgb0' in extras:
            logging_dict.update({'psnr0' : psnr0.item(),
                                'img_loss0' : img_loss0.item(),
                                })

        ############################################################################################################
        # NOTE: IMPORTANT!
        ###   update learning rate   ###
        # Exponential decay: lr is multiplied by 0.1 every args.lrate_decay * 1000 steps.
        decay_rate = 0.1
        decay_steps = args.lrate_decay * 1000
        new_lrate = args.lrate * (decay_rate ** (global_step / decay_steps))
        for param_group in optimizer.param_groups:
            param_group['lr'] = new_lrate
        ################################

        dt = time.time()-time0
        #####           end            #####

        # Rest is logging
        if i%args.i_weights==0:
            path = os.path.join(basedir, expname, '{:06d}.tar'.format(i))
            torch.save({
                'global_step': global_step,
                'network_fn_state_dict': render_kwargs_train['network_fn'].state_dict(),
                'network_fine_state_dict': render_kwargs_train['network_fine'].state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
            }, path)

            path = os.path.join(basedir, expname, '{:06d}_encoder.tar'.format(i))
            torch.save(latent_embed.state_dict(), path)
            print('Saved checkpoints at', path)

        if i%args.i_testset==0 and i > 0:
            with torch.no_grad():
                testsavedir = os.path.join(basedir, expname, 'testset_{:06d}'.format(i))
                os.makedirs(testsavedir, exist_ok=True)
                time_index_list = [0, 40, 80, 110]
                test_time_interval = args.time_interval
                episode_index = 0
                num_ref_view = V // 3
                view_index = np.random.choice(V, size=1+num_ref_view, replace=False)
                ref_view_index = view_index[1:].tolist()
                view_index = int(view_index[0])
                for _, time_index in enumerate(time_index_list):
                    test_images, test_semantics, test_negatives = \
                        get_batch_images(images, 1, test_time_interval,
                                                    episode_index = episode_index, time_index = time_index, semantics=semantics, device=device)
                    B, T, H, V, W, C = test_images.shape
                    test_images_for_ViT, reshaped_test_images, test_semantics_for_ViT, reshaped_test_semantics \
                        = get_batch_images_for_mae_encoder(test_images, test_semantics)
                    test_images_for_ViT = test_images_for_ViT.float().to(device)

                    m = 0.0

                    test_images_for_ViT = torch.split(test_images_for_ViT, W, dim = 3) # [B, T, H, W, C] * V

                    test_recon_view_index = np.arange(V)
                    ref_test_images_for_ViT = torch.cat([test_images_for_ViT[index] for index in ref_view_index], dim=3) # [B, T, H, ref_V*W, C]
                    test_images_for_ViT = test_images_for_ViT[view_index] # [B, T, H, W, C]
                    # Forward ViT encoder
                    # input view
                    latent, mask, ids_restore = latent_embed.SinCro_image_encoder(test_images_for_ViT, m, test_time_interval, is_ref=False)
                    # reference views; masking ratio = 0
                    assert ref_test_images_for_ViT.shape[3] != args.img_size
                    ref_test_images_for_ViT = torch.split(ref_test_images_for_ViT, args.img_size, dim=3)  # [B, T, H, W, C] * V

                    ref_time_interval = args.time_interval

                    ref_test_images_for_ViT = torch.cat(ref_test_images_for_ViT, dim=0) # [VB, T, H, W, C]
                    ref, _, _ = latent_embed.SinCro_image_encoder(ref_test_images_for_ViT, 0, ref_time_interval, is_ref=True)
                    ref = rearrange(ref[:,1:,:], 'b (ref_T hw) d -> b ref_T hw d', ref_T=ref_time_interval)[:,-1]   # [V*B, H'W', embed]
                    ref = rearrange(ref, '(v b) hw d -> b (v hw) d', b=latent.shape[0]) # [B, VH'W', embed]

                    # Forward ViT decoder
                    # [B, ref_TVH'W', embed]
                    test_latent, mask, ids_restore = latent_embed.SinCro_state_encoder(latent, ref, mask, ids_restore)

                    reshaped_test_latent = test_latent.reshape((B, T, 1, -1))    #[B(1), T, V=1, dim]
                    reshaped_test_latent = repeat(reshaped_test_latent, 'b t v d -> b t (v mv) d', mv=V)     #[B(1), T, V, dim]

                    # all poses used in training [V, 4, 4]
                    test_poses = poses[:, 0].clone().detach().to(device)

                    # time, viewpoint-wise latent
                    for t in range(T):
                        if t < (T-1):
                            continue
                        for v in test_recon_view_index:
                            test_rgb, test_disp = render_path(test_poses[v:v+1], hwf, K, args.chunk, render_kwargs_test, gt_imgs=None, latent=reshaped_test_latent[:,t,v], args=args, test_mode=True)
                            # [1,H,W,C]
                            test_rgb8 = to8b(test_rgb[0])
                            test_filename = os.path.join(testsavedir, f'recon_v{v}_by_latent_v{view_index}_at_seq{t}_in_interval{test_time_interval}_epi{episode_index}_time{time_index}.png')
                            imageio.imwrite(test_filename, test_rgb8)

                print('Saved test set')

        if i%args.i_print==0:
            tqdm.write(f"[TRAIN] Iter: {i} Loss: {loss.item()}  PSNR: {psnr.item()}")

        if i % args.log_freq==0:
            for k,v in logging_dict.items():
                tb_writer.add_scalar(k,v, i)

        global_step += 1

    # Flush pending TensorBoard events and release the event file.
    tb_writer.close()


def get_batch_images(images, batch_size, time_interval, episode_index = None, 
                               time_index = None, semantics = None, device=None):
    """Slice windows of consecutive multi-view frames (and optional semantics).

    Two modes:
      * test mode (``episode_index`` given): returns the single window
        ``images[episode_index, time_index:time_index+time_interval]``; the
        window is shorter than ``time_interval`` if it runs past the episode end.
      * training mode: samples ``batch_size`` random (episode, start) windows,
        plus one negative window per sample from the same episode at an
        independently drawn start frame.

    Args:
        images: tensor unpacked as [episode_num, length, H, V, W, C].
        batch_size: number of windows to sample (training mode only).
        time_interval: window length T.
        episode_index: episode to read in test mode; ``None`` selects training mode.
        time_index: start frame in test mode; required when ``episode_index`` is set.
        semantics: optional tensor unpacked as [episode_num, length, H, V, W].
        device: target device for the returned tensors.

    Returns:
        (batch_images [B, T, H, V, W, C],
         batch_semantics [B, T, H, V, W] or None,
         batch_negatives [B, T, H, V, W, C] or None in test mode)
    """
    episode_num, length, H, V, W, C = images.shape

    if episode_index is not None:
        # This is only for test, i.e., do not sample batch_negatives
        assert time_index is not None
        if episode_num == 1 and episode_index >= 1:
            # Visualization case: only one episode is loaded, so any non-zero
            # episode index (including 1) is out of range; fall back to it.
            episode_index = 0
        batch_images = images[episode_index, time_index:time_index+time_interval][None, :].to(device) # [1,T,H,V,W,C]
        if semantics is not None:
            batch_semantics = semantics[episode_index, time_index:time_index+time_interval][None, :].to(device) # [1,T,H,V,W]
        else:
            batch_semantics = None
        batch_negatives = None
    else:
        # sample batch for training; start frames keep the whole window in range
        episode_indices = np.random.randint(episode_num, size=batch_size)
        time_indices = np.random.randint(length-time_interval+1, size=batch_size)

        # [B, T, H, V, W, C]
        batch_images = torch.stack([images[epi_index][time_index:time_index+time_interval].to(device) \
                                    for epi_index, time_index in zip(episode_indices, time_indices)])

        if semantics is not None:
            # [B, T, H, V, W]
            batch_semantics = torch.stack([semantics[epi_index][time_index:time_index+time_interval].to(device) \
                                    for epi_index, time_index in zip(episode_indices, time_indices)])
        else:
            batch_semantics = None
        # Negatives: same episodes, independently drawn start frames.
        negative_time_indices = np.random.randint(length-time_interval+1, size=batch_size)
        # [B, T, H, V, W, C]
        batch_negatives = torch.stack([images[epi_index][negative_time_index:negative_time_index+time_interval].to(device) \
                                    for epi_index, negative_time_index in zip(episode_indices, negative_time_indices)])

    return batch_images, batch_semantics, batch_negatives

def get_batch_images_for_mae_encoder(batch_images, batch_semantics=None):
    """Produce the two layouts of a multi-view batch used downstream.

    Args:
        batch_images: tensor unpacked as [B, T, H, V, W, C].
        batch_semantics: optional tensor laid out as [B, T, H, V, W].

    Returns:
        (images_side_by_side [B, T, H, V*W, C]  -- each row holds the V views'
                                                   rows concatenated along width,
         images_per_view     [T, B*V, H, W, C],
         semantics_side_by_side [B, T, H, V*W] or None,
         semantics_per_view     [T, B*V, H, W] or None)
    """
    n_batch, n_time, height, n_view, width, n_chan = batch_images.shape

    # [B, T, H, V, W, C] -> [T, B, V, H, W, C] -> [T, B*V, H, W, C]
    images_per_view = batch_images.permute(1, 0, 3, 2, 4, 5).reshape(n_time, -1, height, width, n_chan)
    # Merge the adjacent (V, W) axes: views end up side by side along width.
    images_side_by_side = batch_images.reshape(n_batch, n_time, height, n_view * width, n_chan)

    if batch_semantics is None:
        return images_side_by_side, images_per_view, None, None

    # Same two layouts for the channel-less semantic maps.
    semantics_per_view = batch_semantics.permute(1, 0, 3, 2, 4).reshape(n_time, -1, height, width)
    semantics_side_by_side = batch_semantics.reshape(n_batch, n_time, height, n_view * width)
    return images_side_by_side, images_per_view, semantics_side_by_side, semantics_per_view


if __name__=='__main__':
    # Make CUDA float tensors the default only when a GPU exists; forcing it
    # unconditionally crashes on CPU-only hosts, contradicting the module's
    # `device` fallback to CPU.
    if torch.cuda.is_available():
        torch.set_default_tensor_type('torch.cuda.FloatTensor')

    train()
