from ast import Assert
import os
import json
from argparse import ArgumentParser
from re import split
import numpy as np
import numpy.matlib
from tqdm import tqdm
import imageio
from PIL import Image
import time
import config as config
import config2 as config2
import math
import tifffile
# Seed NumPy's global RNG so the random view selection and ray shuffling
# performed later in this script are repeatable from run to run.
np.random.seed(0)

def get_freer_gpu():
    """Return the index of the GPU that reports the most free memory.

    Shells out to ``nvidia-smi``, stores the lines containing "Free" in a
    scratch file named ``tmp`` in the current working directory, and parses
    the third whitespace-separated field of every third line as one GPU's
    free-memory figure.

    Returns:
        The ``np.argmax`` index into the list of parsed free-memory values.

    Raises:
        ValueError: Raised by ``np.argmax`` when no "Free" line was captured
            (for example when ``nvidia-smi`` is not available).
    """
    os.system('nvidia-smi -q -d Memory | grep Free >tmp')
    # Read through a context manager so the scratch file handle is closed
    # deterministically instead of being left to the garbage collector.
    with open('tmp', 'r') as smi_output:
        free_lines = smi_output.readlines()
    # Only every third "Free" line is used.
    # NOTE(review): presumably the other two lines per GPU belong to other
    # memory sections of the report - confirm against local nvidia-smi output.
    memory_available = [int(line.split()[2]) for line in free_lines[::3]]
    best_gpu = np.argmax(memory_available)
    print(f"memory available = {memory_available}")
    print(f"np.argmax() = {best_gpu}")
    return best_gpu

# Pick the GPU with the most free memory and expose only that device to the
# libraries imported below. This must happen before jax is imported.
gpu = get_freer_gpu()
os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu)
print(f'gpu is {gpu}')

# Import jax only after setting the visible gpu
import jax
print(f"jax devices is {jax.devices()}")
import jax.numpy as jnp
from functools import partial
import plenoxel_ct
# NOTE(review): jax.lib.xla_bridge is an internal module; newer JAX releases
# may not expose it - confirm against the installed JAX version.
from jax.lib import xla_bridge
print(xla_bridge.get_backend().platform)
# When this file is imported rather than run as a script, request only a tiny
# fraction of GPU memory for the XLA client.
# NOTE(review): this is set after jax.devices() has already been called above;
# confirm the setting is still honoured at that point.
if __name__ != "__main__":
    os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '.001'


# Command-line interface. The arguments are parsed at module level (see the
# parse_args() call at the end of this section), so they are also parsed when
# this file is imported.
flags = ArgumentParser()


flags.add_argument(
    "--data_dir", '-d',
    type=str,
    default='/home/datasets/', # for all the datasets
    help="Dataset directory e.g. ct_scans/"
)
flags.add_argument(
    "--expname",
    type=str,
    default = 'newPNGspikeTest_radius7.5_res900_tv0.0002_lr0.002_nonneg_views720',
    help="Experiment name."
)
# The scene is appended to --data_dir by plain string concatenation, so both
# values need their trailing slash for get_data() to recognise the dataset.
flags.add_argument(
    "--scene",
    type=str,
    # default='jerry/', # for jerry projections
    default='spike/', #  for spike projections
    # default='ct_shepp',
    # default = 'scans/', 
    help="Name of the scene."
)
flags.add_argument(
    "--log_dir",
    type=str,
    default='jax_logs/',
    help="Directory to save outputs."
)
flags.add_argument(
    "--resolution",
    type=int,
    default=900, #900 for spike
    help="Grid size."
)
flags.add_argument(
    "--ini_rgb",
    type=float,
    default=0.0,
    help="Initial harmonics value in grid."
)
flags.add_argument(
    "--ini_sigma",
    type=float,
    default=0.1,
    help="Initial sigma value in grid."
)
# NOTE(review): the help text quotes radii of 1.3 / 1.5 while the default here
# is 7.5; the help string looks inherited from another setup.
flags.add_argument(
    "--radius", # affects resolution
    type=float,
    default=7.5,
    help="Grid radius. 1.3 works well on most scenes, but ship requires 1.5"
)
# With the default of -1, sh_dim = (harmonic_degree + 1)**2 evaluates to 0.
flags.add_argument(
    "--harmonic_degree",
    type=int,
    default=-1,
    help="Degree of spherical harmonics. Supports 0, 1, 2, 3, 4."
)
flags.add_argument(
    '--num_epochs',
    type=int,
    default=1,
    help='Epochs to train for.'
)
flags.add_argument(
    '--render_interval',
    type=int,
    default = 30,
    help='Render images during test/val step every x images.'
)
flags.add_argument(
    '--val_interval',
    type=int,
    default=2,
    help='Run test/val step every x epochs.'
)
flags.add_argument(
    '--lr_rgb',
    type=float,
    default=None,
    help='SGD step size for rgb. Default chooses automatically based on resolution.'
    )
# NOTE(review): the default is a concrete value (0.002), so the automatic
# resolution-based choice mentioned in the help text never applies to sigma
# unless FLAGS.lr_sigma is set to None programmatically.
flags.add_argument(
    '--lr_sigma',
    type=float,
    default=0.002,
    help='SGD step size for sigma. Default chooses automatically based on resolution.'
    )
flags.add_argument(
    '--physical_batch_size',
    type=int,
    default=3000,
    help='Number of rays per batch, to avoid OOM.'
    )
flags.add_argument(
    '--logical_batch_size',
    type=int,
    default=3000,
    help='Number of rays per optimization batch. Must be a multiple of physical_batch_size.'
    )
flags.add_argument(
    '--jitter',
    type=float,
    default=0.0,
    help='Take samples that are jittered within each voxel, where values are computed with trilinear interpolation. Parameter controls the std dev of the jitter, as a fraction of voxel_len.'
)
flags.add_argument(
    '--uniform',
    type=float,
    default=0.5,
    help='Initialize sample locations to be uniformly spaced at this interval (as a fraction of voxel_len), rather than at voxel intersections (default if uniform=0).'
)
flags.add_argument(
    '--occupancy_penalty',
    type=float,
    default=0.0,
    help='Penalty in the loss term for occupancy; encourages a sparse grid.'
)
flags.add_argument(
    '--reload_epoch',
    type=int,
    default=None,
    help='Epoch at which to resume training from a saved model.'
)
flags.add_argument(
    '--save_interval',
    type=int,
    default=1,
    help='Save the grid checkpoints after every x epochs.'
)
flags.add_argument(
    '--prune_epochs',
    type=int,
    nargs='+',
    default=[],
    help='List of epoch numbers when pruning should be done.'
)
flags.add_argument(
    '--prune_method',
    type=str,
    default='weight',
    help='Weight or sigma: prune based on contribution to training rays, or opacity.'
)
flags.add_argument(
    '--prune_threshold',
    type=float,
    default=0.001,
    help='Threshold for pruning voxels (either by weight or by sigma).'
)
flags.add_argument(
    '--split_epochs',
    type=int,
    nargs='+',
    default=[], # Try w/ 0, 1, or 2
    help='List of epoch numbers when splitting should be done.'
)
flags.add_argument(
    '--interpolation',
    type=str,
    default='trilinear',
    help='Type of interpolation to use. Options are constant, trilinear, or tricubic.'
)
flags.add_argument(
    '--nv',
    action='store_true',
    help='Use the Neural Volumes rendering formula instead of the Max (NeRF) rendering formula.'
)
flags.add_argument(
    '--ct',
    action='store_true',
    help='Optimize sigma only, based on the gt alpha channel.'
)
flags.add_argument(
    '--nonnegative',
    action='store_true',
    help='Clip stored grid values to be nonnegative. Intended for ct.'
)
# NOTE(review): FLAGS.num_views is not read anywhere in the visible code;
# get_data() uses its own hard-coded max_projections instead.
flags.add_argument(
    '--num_views',
    type=int,
    default=20,
    help='Number of CT projections to train with. Only used with Jerry-CBCT.'
)
flags.add_argument(
    '--cut_cube',
    action='store_true',
    help='cuts the cube in half and halves the radius aswell'
)

# Parse the command line and derive the module-level settings used below.
FLAGS = flags.parse_args()
# Plain concatenation: data_dir is only well-formed if --data_dir ends in '/'.
data_dir = FLAGS.data_dir + FLAGS.scene
radius = FLAGS.radius
# Re-seed so the random draws that follow start from a known RNG state.
np.random.seed(0)

# Data loader for the "jerry" CT dataset.
def get_ct_jerry(root, stage, max_projections, xoff, yoff, zoff):
    """Load the jerry projection images and their translated projection matrices.

    Args:
        root: Dataset directory containing ``proj_mat_jerry.csv``.
        stage: ``'train'`` keeps the randomly selected views, ``'test'`` keeps
            all remaining views; any other value returns every view.
        max_projections: Number of views selected for the training split.
        xoff, yoff, zoff: Volume offsets. Every projection matrix is
            right-multiplied by a translation of ``(-xoff, -yoff, -zoff)``.

    Returns:
        ``(focal, all_w2c, all_gt)`` where ``focal`` is the constant 100,
        ``all_w2c`` stacks the translated 3x4 matrices and ``all_gt`` stacks
        the inverted (``1 - image``) projections scaled to ``[0, 1]``.

    Raises:
        ValueError: Raised by the RNG when ``max_projections`` exceeds the
            number of loaded views.
    """
    print('LOAD DATA', root)

    projection_matrices = np.genfromtxt(os.path.join(root, 'proj_mat_jerry.csv'), delimiter=',')  # [719, 12]
    print(f'proj mat len is {len(projection_matrices)}')

    # Homogeneous translation T used to align volume and detector: P' = P @ T.
    translation = np.eye(4)
    translation[:3, 3] = (-xoff, -yoff, -zoff)

    # NOTE(review): the images are read from this fixed directory, not from
    # `root` - confirm that this is intended.
    image_dir = '/data/datasets/jerry/jerry_corr_proj'

    all_w2c = []
    all_gt = []
    # The last row of the CSV is skipped, exactly as in the previous version.
    for i in range(len(projection_matrices) - 1):
        image_path = os.path.join(image_dir, f'New_Cor_Proj{i:04d}.png')
        im_gt = imageio.imread(image_path).astype(np.float32) / 255.0
        all_gt.append(1 - im_gt)
        all_w2c.append(np.matmul(np.reshape(projection_matrices[i], (3, 4)), translation))

    all_gt = np.asarray(all_gt)
    all_w2c = np.asarray(all_w2c)

    focal = 100

    print(f'max is {len(all_w2c)}')
    # Draw the training views from a dedicated, fixed-seed generator. The
    # previous code drew from the global RNG on every call, so the 'train'
    # and 'test' calls used different selections and the test split was not
    # the complement of the training split.
    split_rng = np.random.RandomState(0)
    train_idx = split_rng.choice(len(all_w2c), max_projections, replace=False)
    mask = np.zeros(len(all_w2c), dtype=bool)
    mask[train_idx] = True

    if stage == 'train':
        all_gt = all_gt[mask]
        all_w2c = all_w2c[mask]
    elif stage == 'test':
        all_gt = all_gt[~mask]
        all_w2c = all_w2c[~mask]

    return focal, all_w2c, all_gt

def get_ct_spike(root, stage, max_projections, xoff, yoff, zoff):
    """Load the spike projection images and their translated projection matrices.

    Args:
        root: Dataset directory containing ``proj_mat_720frames.csv``.
        stage: ``'train'`` keeps the randomly selected views, ``'test'`` keeps
            all remaining views; any other value returns every view.
        max_projections: Number of views selected for the training split.
        xoff, yoff, zoff: Volume offsets. Every projection matrix is
            right-multiplied by a translation of ``(-xoff, -yoff, -zoff)``.

    Returns:
        ``(focal, all_w2c, all_gt)`` where ``focal`` is the constant 100,
        ``all_w2c`` stacks the translated 3x4 matrices and ``all_gt`` stacks
        the inverted (``1 - image``) projections scaled to ``[0, 1]``.

    Raises:
        ValueError: Raised by the RNG when ``max_projections`` exceeds the
            number of loaded views.
    """
    print('LOAD DATA', root)
    projection_matrices = np.genfromtxt(os.path.join(root, 'proj_mat_720frames.csv'), delimiter=',')  # [719, 12]

    # Homogeneous translation T used to align volume and detector: P' = P @ T.
    translation = np.eye(4)
    translation[:3, 3] = (-xoff, -yoff, -zoff)

    # NOTE(review): the images are read from this fixed directory, not from
    # `root` - confirm that this is intended.
    image_dir = '/home/datasets/spike/NewSpike92_8_16_3_proj'

    all_w2c = []
    all_gt = []
    for i in range(len(projection_matrices)):
        image_path = os.path.join(image_dir, f'NewSpike92_8_16_33_proj_{i:04d}.png')
        im_gt = imageio.imread(image_path).astype(np.float32) / 255.0
        all_gt.append(1 - im_gt)
        all_w2c.append(np.matmul(np.reshape(projection_matrices[i], (3, 4)), translation))

    all_gt = np.asarray(all_gt)
    all_w2c = np.asarray(all_w2c)

    focal = 100

    # Draw the training views from a dedicated, fixed-seed generator. The
    # previous code drew from the global RNG on every call, so the 'train'
    # and 'test' calls used different selections and the test split was not
    # the complement of the training split.
    split_rng = np.random.RandomState(0)
    train_idx = split_rng.choice(len(all_w2c), max_projections, replace=False)
    mask = np.zeros(len(all_w2c), dtype=bool)
    mask[train_idx] = True

    if stage == 'train':
        all_gt = all_gt[mask]
        all_w2c = all_w2c[mask]
    elif stage == 'test':
        all_gt = all_gt[~mask]
        all_w2c = all_w2c[~mask]

    return focal, all_w2c, all_gt

def get_ct_shepp(root, stage, max_projections, xoff, yoff, zoff):
    """Load the Shepp phantom projection stack with the spike geometry.

    ``stage`` and ``max_projections`` are accepted for signature parity with
    the other loaders but are not used: every view is returned.

    Returns:
        ``(focal, all_w2c, all_gt)`` with ``focal`` fixed at 100, the
        translated 3x4 projection matrices, and the raw TIFF frames.
    """
    print('LOAD DATA', root)

    # The projection matrices are shared with the spike dataset.
    projection_matrices = np.genfromtxt(os.path.join('/home/datasets/spike/', 'proj_mat_720frames.csv'), delimiter=',')  # [719, 12]

    # Homogeneous translation by (-xoff, -yoff, -zoff).
    shift = np.eye(4)
    shift[0, 3] = -xoff
    shift[1, 3] = -yoff
    shift[2, 3] = -zoff

    # Alternative stack used previously:
    #   'synthetic0.2_projections_raw_radius5_reso128_H140_W128_dhw0.12.tif'
    tif_proj = tifffile.imread(os.path.join(root, 'fifthdensity_1_projections_raw_radius5_reso128_H140_W128_dhw0.12.tif'))

    # One frame and one translated matrix per row of the CSV.
    view_ids = range(len(projection_matrices))
    all_gt = np.asarray([tif_proj[view, :, :] for view in view_ids])
    all_w2c = np.asarray([
        np.matmul(np.reshape(projection_matrices[view], (3, 4)), shift)
        for view in view_ids
    ])

    focal = 100
    return focal, all_w2c, all_gt

def get_ct_synthetic(root, stage, max_projections, xoff, yoff, zoff):
    """Load the synthetic projection stack with the spike geometry.

    ``stage`` and ``max_projections`` are accepted for signature parity with
    the other loaders but are not used: every view is returned.

    Returns:
        ``(focal, all_w2c, all_gt)`` with ``focal`` fixed at 100, the
        translated 3x4 projection matrices, and the raw TIFF frames.
    """
    print('LOAD DATA', root)

    # The projection matrices are shared with the spike dataset.
    projection_matrices = np.genfromtxt(os.path.join('/home/datasets/spike/', 'proj_mat_720frames.csv'), delimiter=',')  # [719, 12]

    # Homogeneous translation by (-xoff, -yoff, -zoff).
    shift = np.eye(4)
    shift[0, 3] = -xoff
    shift[1, 3] = -yoff
    shift[2, 3] = -zoff

    tif_proj = tifffile.imread(os.path.join(root, 'semi_easy_synthetic_projections_raw_radius5_reso50_H700.tif'))

    # One frame and one translated matrix per row of the CSV.
    view_ids = range(len(projection_matrices))
    all_gt = np.asarray([tif_proj[view, :, :] for view in view_ids])
    all_w2c = np.asarray([
        np.matmul(np.reshape(projection_matrices[view], (3, 4)), shift)
        for view in view_ids
    ])

    focal = 100
    return focal, all_w2c, all_gt

# This function takes in the given root and uses the appropriate
#   data loader to get the focal, c2w, and gt
def get_data(root, stage):
    """Load one split of a dataset, dispatching on ``root``.

    Args:
        root: Dataset directory. The jerry and spike loaders are selected by
            an exact path match, the synthetic and shepp loaders by a
            substring match; anything else is read as a directory holding a
            ``transforms_<stage>.json`` file plus PNG frames.
        stage: Split name, e.g. ``'train'`` or ``'test'``.

    Returns:
        ``(focal, all_c2w, all_gt)`` as produced by the selected loader. For
        the JSON layout ``focal`` is derived from ``camera_angle_x`` and the
        image width, and each image holds its alpha channel followed by two
        zero channels.
    """
    max_projections = 50
    # To align the volume and the detector, the projection matrices are
    # post-multiplied by a translation matrix T: P' = P * T.
    #
    # Offsets previously used for Jerry: xoff = 0.0, yoff = 0.0, zoff = -2.4.
    # NOTE(review): the Spike offsets below are passed to every CT loader,
    # including the jerry one - confirm before training on jerry.
    xoff = 0.0
    yoff = -1.3
    zoff = -5.2

    if root == '/home/datasets/jerry/':
        return get_ct_jerry(root, stage, max_projections, xoff, yoff, zoff)
    if root == '/home/datasets/spike/':
        return get_ct_spike(root, stage, max_projections, xoff, yoff, zoff)
    if 'ct_synthetic' in root:
        return get_ct_synthetic(root, stage, max_projections, xoff, yoff, zoff)
    if 'ct_shepp' in root:
        return get_ct_shepp(root, stage, max_projections, xoff, yoff, zoff)

    all_c2w = []
    all_gt = []

    data_path = os.path.join(root, stage)
    data_json = os.path.join(root, 'transforms_' + stage + '.json')
    print('LOAD DATA', data_path)
    # Context manager so the metadata file is closed right after parsing.
    with open(data_json, 'r') as json_file:
        j = json.load(json_file)

    for frame in tqdm(j['frames']):
        fpath = os.path.join(data_path, os.path.basename(frame['file_path']) + '.png')
        c2w = frame['transform_matrix']
        im_gt = imageio.imread(fpath).astype(np.float32) / 255.0
        # Keep the channels from index 3 on (alpha) and pad with two zero
        # channels so training can target alpha.
        im_gt = jnp.concatenate([im_gt[..., 3:], jnp.zeros((im_gt.shape[0], im_gt.shape[1], 2))], -1)
        all_c2w.append(c2w)
        all_gt.append(im_gt)
    focal = 0.5 * all_gt[0].shape[1] / np.tan(0.5 * j['camera_angle_x'])
    all_gt = np.asarray(all_gt)
    all_c2w = np.asarray(all_c2w)
    return focal, all_c2w, all_gt

# Calls get_data twice to obtain a training set and a test set, each made of
#   a focal, the c2w matrices, and the gt images.
# If the training focal differs from test_focal an AssertionError is raised.
# The height and width are taken from the first two dimensions of the first
#   training gt image.
# Finally the number of training and test c2w matrices is recorded.
# NOTE(review): these names (H, W, dH, dW, train_c2w, ...) are only defined
#   when the file runs as a script, yet module-level code further down uses
#   them unconditionally.
if __name__ == "__main__":
    print(f'the data directory is {data_dir}')
    focal, train_c2w, train_gt = get_data(data_dir, "train")
    test_focal, test_c2w, test_gt = get_data(data_dir, "test")
    assert focal == test_focal
    H, W = train_gt[0].shape[:2]
    # dW / dH scale the pixel column / row coordinates in get_rays_np.
    dW = 0.024
    dH = 0.024
    n_train_imgs = len(train_c2w)
    n_test_imgs = len(test_c2w)
    # Experiments whose name contains 'shepp' use a fixed detector size and
    # scale instead of the values derived above.
    if 'shepp' in FLAGS.expname:
        H = 140
        W = 128
        dH = 0.12
        dW = 0.12

# Sets log_dir to the existing log directory plus the experiment name and
#   creates the necessary directories if they do not exist.
# NOTE: plain string concatenation, so FLAGS.log_dir needs a trailing slash.
log_dir = FLAGS.log_dir + FLAGS.expname
os.makedirs(log_dir, exist_ok=True)


# Fill in any learning rate that was left as None with a power law of the
# grid resolution, and remember that this happened so main() can recompute
# the rates when a reloaded grid is split.
automatic_lr = False
if FLAGS.lr_rgb is None or FLAGS.lr_sigma is None:
    automatic_lr = True
if FLAGS.lr_rgb is None:
    FLAGS.lr_rgb = 150 * (FLAGS.resolution ** 1.75)
if FLAGS.lr_sigma is None:
    FLAGS.lr_sigma = 51.5 * (FLAGS.resolution ** 2.37)


# Either resume from the checkpoint directory of the requested epoch or start
# from a freshly initialised grid.
if FLAGS.reload_epoch is not None:
    reload_dir = os.path.join(log_dir, f'epoch_{FLAGS.reload_epoch}')
    print(f'Reloading the grid from {reload_dir}')
    data_dict = plenoxel_ct.load_grid(dirname=reload_dir, sh_dim = (FLAGS.harmonic_degree + 1)**2)
else:
    print(f'Initializing the grid')
    data_dict = plenoxel_ct.initialize_grid(resolution=FLAGS.resolution, ini_rgb=FLAGS.ini_rgb, ini_sigma=FLAGS.ini_sigma, harmonic_degree=FLAGS.harmonic_degree)

# low-pass filter the ground truth image so the effective resolution matches twice that of the grid
def lowpass(gt, resolution):
    """Low-pass filter a single image by resizing down and back up.

    The image is quantised to 8 bits, resized to a square of side
    ``2 * resolution``, and resized back to its original height and width.

    Args:
        gt: Image array with values expected in ``[0, 1]``; singleton
            dimensions are squeezed before conversion.
        resolution: Grid resolution; the intermediate image has side
            ``2 * resolution``.

    Returns:
        The filtered image as a float array in ``[0, 1]`` with ``gt``'s
        original height and width.
    """
    if gt.ndim > 3:
        print(f'lowpass called on image with more than 3 dimensions; did you mean to use multi_lowpass?')
    H = gt.shape[0]
    W = gt.shape[1]
    im = Image.fromarray((np.squeeze(np.asarray(gt))*255).astype(np.uint8))
    im = im.resize(size=(resolution*2, resolution*2))
    # PIL sizes are (width, height). The previous (H, W) order transposed the
    # output dimensions for non-square images; multi_lowpass already uses
    # (W, H).
    im = im.resize(size=(W, H))
    return np.asarray(im) / 255.0


# Low-pass filter a stack of images; the first dimension indexes the images.
# Each image is shrunk to twice the grid resolution and then resized back to
# its original size, which discards the detail the grid cannot represent.
def multi_lowpass(gt, resolution):
    """Return a copy of ``gt`` in which every image has been low-pass filtered.

    Args:
        gt: Stack of images with values expected in ``[0, 1]``; height and
            width are read from the third- and second-to-last axes.
        resolution: Grid resolution; the intermediate image has side
            ``2 * resolution``.

    Returns:
        An array shaped like ``gt`` holding the filtered images.
    """
    if gt.ndim < 4:
        print(f'multi_lowpass called on image with 3 or fewer dimensions; did you mean to use lowpass instead?')
    height, width = gt.shape[-3], gt.shape[-2]
    coarse_size = (resolution*2, resolution*2)
    filtered = np.copy(gt)
    for idx, frame in enumerate(gt):
        quantised = np.squeeze(frame * 255).astype(np.uint8)
        # Down to the coarse size, then back to the original (width, height).
        resized = Image.fromarray(quantised).resize(size=coarse_size).resize(size=(width, height))
        filtered[idx, ...] = np.asarray(resized) / 255.0
    return filtered

def compute_tv(t):
    """Return the total variation of ``t`` along its first three axes.

    For each of the three axes, the mean absolute difference between
    neighbouring entries is computed; the three means are summed.
    """
    def mean_abs_step(ahead, behind):
        # Mean absolute difference between two shifted views of the grid.
        return jnp.abs(ahead - behind).mean()

    tv_x = mean_abs_step(t[1:, :, :], t[:-1, :, :])
    tv_y = mean_abs_step(t[:, 1:, :], t[:, :-1, :])
    tv_z = mean_abs_step(t[:, :, 1:], t[:, :, :-1])
    return tv_x + tv_y + tv_z

def get_loss(data_dict, c2w, gt, H, W, focal, resolution, radius, harmonic_degree, jitter, uniform, key, sh_dim, occupancy_penalty, interpolation, nv):
    """Return the whole-image loss for one pose.

    The loss is the mean squared error between the rendered ``rgb`` and the
    low-pass filtered ground truth, plus ``occupancy_penalty`` times the mean
    of the positive part of the last grid in ``data_dict``.

    Args:
        data_dict: Sequence of grids; only its last entry enters the
            occupancy term. It is forwarded unchanged to the renderer.
        c2w: Camera matrix handed to ``plenoxel_ct.get_rays``.
        gt: Ground-truth image for this pose.
        H, W, focal: Image size and focal length used to build the rays.
        resolution, radius, harmonic_degree, jitter, uniform, key,
            interpolation, nv: Forwarded to ``plenoxel_ct.render_rays``.
        sh_dim: Unused here; kept so existing callers keep working.
        occupancy_penalty: Weight of the occupancy term.

    Returns:
        The scalar loss.
    """
    rays = plenoxel_ct.get_rays(H, W, focal, c2w)
    rgb, _disp, _acc, _weights, _voxel_ids = plenoxel_ct.render_rays(data_dict, rays, resolution, key, radius, harmonic_degree, jitter, uniform, interpolation, nv)
    mse = jnp.mean((rgb - lowpass(gt, resolution))**2)
    # The former `indices, data = data_dict` unpack was unused and raised
    # whenever data_dict did not hold exactly two entries, so it was removed.
    return mse + occupancy_penalty * jnp.mean(jax.nn.relu(data_dict[-1]))

@partial(jax.jit, static_argnums=(3,4,5,6,7,9,11,12))
def get_loss_rays(data_dict, rays, gt, resolution, radius, harmonic_degree, jitter, uniform, key, sh_dim, occupancy_penalty, interpolation, nv):
    """Return the loss for one batch of rays (jit-compiled).

    Sums three terms: the mean squared error between the rendered ``acc``
    and channel 0 of ``gt``, the occupancy penalty on the positive part of
    the last grid in ``data_dict``, and a total-variation term on that same
    grid with a fixed weight of 0.0003.
    """
    _, _, acc, _, _ = plenoxel_ct.render_rays(data_dict, rays, resolution, key, radius, harmonic_degree, jitter, uniform, interpolation, nv)
    sigma_grid = data_dict[-1]

    # Data term: only the first target channel is optimised.
    residual = acc - gt[...,0]
    data_term = jnp.mean(residual**2)

    # Regularisers on the last grid.
    occupancy_term = occupancy_penalty * jnp.mean(jax.nn.relu(sigma_grid))
    smoothness_term = 0.0003 * compute_tv(sigma_grid)

    return data_term + occupancy_term + smoothness_term

def get_rays_np(H, W, dH, dW, w2c):
    """Build one ray per detector pixel from a 3x4 projection matrix.

    The matrix is split as ``[M | p4]``. The ray origin is the source
    position ``-inv(M) @ p4`` and the direction of pixel ``(u, v)`` is
    ``inv(M) @ [u, v, 1]``, where ``u`` and ``v`` are the pixel-centre
    column and row indices scaled by ``dW`` and ``dH``.

    Args:
        H, W: Detector height and width in pixels.
        dH, dW: Scale applied to the row / column pixel-centre coordinates.
        w2c: Array of shape (3, 4) holding the projection matrix.

    Returns:
        ``(rays_o, rays_d)``, two arrays of shape ``(H, W, 3)``.

    Raises:
        numpy.linalg.LinAlgError: If the left 3x3 block of ``w2c`` is singular.
    """
    # Split the projection matrix into its 3x3 part and its last column.
    M = w2c[:,0:3]
    p4 = w2c[:,-1]

    # NOTE: the previous version also derived uo, vo, aU, aV, sdd and shiftVo
    # here. None of them was used, and the math.sqrt calls raised ValueError
    # for valid matrices whose rows are not normalised, so they were removed.

    # Source position in the world reference system.
    M_inv = np.linalg.inv(M)
    srcPos = -np.matmul(M_inv,p4)

    # Pixel-centre coordinates, scaled by dW / dH.
    u, v = jnp.meshgrid(jnp.linspace(0, W-1, W) + 0.5, jnp.linspace(0, H-1, H) + 0.5)
    u = u * dW
    v = v * dH
    dirs = jnp.stack([u, v, jnp.ones_like(u)], -1)
    # Row-wise product followed by a sum applies M_inv to every [u, v, 1].
    rays_d = jnp.sum(dirs[..., jnp.newaxis, :] * M_inv, -1)
    rays_o = jnp.broadcast_to(srcPos,rays_d.shape)
    return rays_o, rays_d



def render_pose_rays(data_dict, c2w, H, W, focal, resolution, radius, harmonic_degree, jitter, uniform, key, sh_dim, batch_size, interpolation, nv):
    """Render one full view in chunks of at most ``batch_size`` rays.

    Rays are built with ``get_rays_np`` using the module-level ``dH`` and
    ``dW``. For every chunk the rendered ``acc`` is placed in the first of
    three channels and the other two are zero.

    Returns:
        ``(rgb, disp, None, None)`` where ``rgb`` has shape ``(H, W, 3)`` and
        ``disp`` has shape ``(H, W)``.
    """
    origins, directions = get_rays_np(H, W, dH, dW, c2w)
    origins = np.reshape(origins, [-1,3])
    directions = np.reshape(directions, [-1,3])

    total_rays = H*W
    colour_chunks = []
    disp_chunks = []
    for start in range(0, total_rays, batch_size):
        stop = min(total_rays, start + batch_size)
        # Per-ray keys are only sliced (and needed) when jitter is enabled.
        chunk_key = key[start:stop] if jitter > 0 else None
        rendered = plenoxel_ct.render_rays(data_dict, (origins[start:stop], directions[start:stop]), resolution, chunk_key, radius, harmonic_degree, jitter, uniform, interpolation, nv)
        _, disp_chunk, acc_chunk, _, _ = jax.lax.stop_gradient(rendered)

        acc_column = acc_chunk[:,jnp.newaxis]
        zero_column = jnp.zeros_like(acc_chunk)[:,jnp.newaxis]
        colour_chunks.append(jnp.concatenate([acc_column, zero_column, zero_column], axis=-1))
        disp_chunks.append(disp_chunk)

    rgb = jnp.reshape(jnp.concatenate(colour_chunks, axis=0), (H, W, 3))
    disp = jnp.reshape(jnp.concatenate(disp_chunks, axis=0), (H, W))
    return rgb, disp, None, None


def run_test_step(i, data_dict, test_c2w, test_gt, H, W, focal, FLAGS, key, name_appendage=''):
    """Render every test view and return the mean PSNR.

    Every ``FLAGS.render_interval``-th view (when that interval is positive)
    is written to ``log_dir`` as a PNG showing the ground truth next to the
    rendering.

    Args:
        i: Epoch tag used in the names of the saved images.
        data_dict: Grid data handed to the renderer.
        test_c2w, test_gt: Test projection matrices and ground-truth images.
        H, W, focal: Image size and focal length forwarded to the renderer.
        FLAGS: Parsed command-line options.
        key: Random keys forwarded to the renderer (may be None).
        name_appendage: Optional suffix for the saved file names.

    Returns:
        The PSNR averaged over all test views.

    Raises:
        ZeroDivisionError: If ``test_c2w`` is empty.
    """
    print('Evaluating')
    sh_dim = (FLAGS.harmonic_degree + 1)**2
    n_views = len(test_c2w)
    tpsnr = 0.0
    # A single progress bar that advances once per view and is always closed.
    # Previously a second bar wrapped the loop and this one only moved on
    # rendered frames and was never closed.
    with tqdm(total=n_views) as pb:
        for j, (c2w, gt) in enumerate(zip(test_c2w, test_gt)):
            rgb, disp, _, _ = render_pose_rays(data_dict, c2w, H, W, focal, FLAGS.resolution, radius, FLAGS.harmonic_degree, FLAGS.jitter, FLAGS.uniform, key, sh_dim, FLAGS.physical_batch_size, FLAGS.interpolation, FLAGS.nv)
            # Replicate the first rendered channel and the single-channel
            # ground truth into three channels for comparison.
            rgb = jnp.concatenate((rgb[...,0,jnp.newaxis], rgb[...,0,jnp.newaxis], rgb[...,0,jnp.newaxis]), axis=-1)
            gt = jnp.concatenate((gt[...,jnp.newaxis], gt[...,jnp.newaxis], gt[...,jnp.newaxis]), axis=-1)

            mse = jnp.mean((rgb - gt)**2)
            psnr = -10.0 * np.log(mse) / np.log(10.0)
            tpsnr += psnr

            if FLAGS.render_interval > 0 and j % FLAGS.render_interval == 0:
                vis = jnp.concatenate((gt, rgb), axis = 1)
                # Clip before the 8-bit cast so out-of-range values saturate
                # instead of wrapping around.
                frame = (np.clip(np.asarray(vis), 0.0, 1.0)*255).astype(np.uint8)
                imageio.imwrite(f"{log_dir}/{j:04}_{i:04}{name_appendage}.png", frame)

            # Running mean over the views seen so far (previously divided by
            # the total number of views, which under-reported until the end).
            pb.set_postfix_str(f"psnr = {tpsnr / (j + 1)}", refresh = False)
            pb.update(1)
            del rgb, disp
    tpsnr /= n_views
    return tpsnr


def update_grid(old_grid, lr, grid_grad):
    """Return ``old_grid`` after one gradient-descent step.

    Args:
        old_grid: Current grid values (a JAX array).
        lr: Step size.
        grid_grad: Gradient (or pre-scaled update direction) for the grid.

    Returns:
        ``old_grid - lr * grid_grad``, with negative entries raised to 0 when
        ``FLAGS.nonnegative`` is set.
    """
    stepped = old_grid.at[...].add(-1 * lr * grid_grad)
    if FLAGS.nonnegative:
        # jnp.maximum is used instead of jnp.clip(..., a_min=0): the `a_min`
        # keyword is deprecated in newer JAX releases, while maximum gives the
        # same result on every version.
        return jnp.maximum(stepped, 0)
    return stepped



def update_grids(old_grid, lrs, grid_grad):
    """Apply one descent step to the last grid only and return the list.

    ``old_grid`` is modified in place; every entry except the last (the
    sigma grid, the only one optimised for CT) is left untouched.
    """
    sigma_slot = -1
    stepped_sigma = update_grid(old_grid[sigma_slot], lrs[sigma_slot], grid_grad[sigma_slot])
    old_grid[sigma_slot] = stepped_sigma
    return old_grid

# Precompute one (origin, direction, target) triple per training pixel so the
# training loop can draw shuffled ray batches across all images.
# NOTE(review): H, W, dH, dW, train_c2w and train_gt are only defined when the
# file runs as a script, so importing this module reaches this code without
# them.
if FLAGS.physical_batch_size is not None:
    # Disabled shortcut for reloading previously saved rays from disk.
    if False:
        print(f'reloading saved rays')
        rays_rgb = np.load('rays.npy')
    else:
        print(f'precomputing all the training rays')
        # Precompute all the training rays and shuffle them
        t0 = time.time()

        rays = np.stack([get_rays_np(H, W, dH, dW, p) for p in train_c2w[:,:3,:4]], 0) # [N, ro+rd, H, W, 3]  # 42 seconds
        t1 = time.time()
        print(f'stack took {t1 - t0} seconds')

        # Replicate the single-channel projections into three identical
        # channels so they can be concatenated with the 3-vector rays.
        train_gt = np.concatenate([train_gt[...,None], train_gt[...,None], train_gt[...,None]], -1)
        rays_rgb = np.concatenate([rays, multi_lowpass(train_gt[:,None], FLAGS.resolution).astype(np.float32)], 1)  # [N, ro+rd+rgb, H, W, 3]  # 19 seconds
        t2 = time.time()
        print(f'concatenate took {t2 - t1} seconds')

        rays_rgb = np.transpose(rays_rgb, [0,2,3,1,4]) # [N, H, W, ro+rd+rgb, 3] 
        t3 = time.time()
        print(f'transpose took {t3 - t2} seconds')

        # Flatten images and pixels into one axis: one row per ray.
        rays_rgb = np.reshape(rays_rgb, [-1,3,3]) # [N*H*W, ro+rd+rgb, 3]  # 12 seconds
        t4 = time.time()
        print(f'reshape took {t4 - t3} seconds')

        rays_rgb = rays_rgb.take(np.random.permutation(rays_rgb.shape[0]), axis=0)  # 34 seconds
        print(f'permutation took {time.time() - t4} seconds')


# Without jitter no random keys are needed, so every key-related name is None.
# NOTE(review): when FLAGS.jitter is non-zero these names are not defined
# anywhere in the visible code, although main() refers to them.
if FLAGS.jitter == 0:
    render_keys = None
    keys = None
    split_keys = None
    split_keys_partial = None

@jax.jit
def rmsprop_update(avg_g, data_grad):
    """Update the running average of squared gradients for the last grid.

    Only the last entry changes: it becomes 0.9 times its previous value plus
    0.1 times the squared gradient of the last grid. All other entries are
    passed through unchanged.
    """
    running = list(avg_g)
    latest_sq_grad = data_grad[-1]**2
    running[-1] = 0.9 * avg_g[-1] + 0.1*latest_sq_grad
    return running


def main():
    """Run the training loop: optimise, prune/split, checkpoint and evaluate.

    Works on the module-level state declared ``global`` below. Each epoch
    shuffles the data, takes RMSProp-scaled (or accumulated) gradient steps
    on the last grid of ``data_dict``, optionally prunes and/or splits the
    grid, saves a checkpoint every ``FLAGS.save_interval`` epochs and
    evaluates every ``FLAGS.val_interval`` epochs. If the reloaded epoch is
    already the final one, a single evaluation is run instead of training.
    """
    global rays_rgb, keys, render_keys, data_dict, FLAGS, radius, train_c2w, train_gt, test_c2w, test_gt, automatic_lr
    start_epoch = 0
    sh_dim = (FLAGS.harmonic_degree + 1)**2
    if FLAGS.reload_epoch is not None:
        start_epoch = FLAGS.reload_epoch + 1
    # When resuming, apply the prune / split scheduled for the reloaded epoch.
    if np.isin(FLAGS.reload_epoch, FLAGS.prune_epochs):
        data_dict = plenoxel_ct.prune_grid(data_dict, method=FLAGS.prune_method, threshold=FLAGS.prune_threshold, train_c2w=train_c2w, H=H, W=W, focal=focal, batch_size=FLAGS.physical_batch_size, resolution=FLAGS.resolution, key=render_keys, radius=FLAGS.radius, harmonic_degree=FLAGS.harmonic_degree, jitter=FLAGS.jitter, uniform=FLAGS.uniform, interpolation=FLAGS.interpolation)
    if np.isin(FLAGS.reload_epoch, FLAGS.split_epochs):
        print(FLAGS.resolution)
        data_dict = plenoxel_ct.split_grid(data_dict)
        FLAGS.resolution = FLAGS.resolution * 2
        if automatic_lr:
            FLAGS.lr_rgb = 150 * (FLAGS.resolution ** 1.75)
            FLAGS.lr_sigma = 51.5 * (FLAGS.resolution ** 2.37)

    # RMSProp state: running average of squared gradients, one slot per grid.
    avg_g = [0 for g_i in data_dict]

    for i in range(start_epoch, FLAGS.num_epochs):
        batched = FLAGS.physical_batch_size is not None

        # Shuffle data before each epoch
        if not batched:
            temp = list(zip(train_c2w, train_gt))
            np.random.shuffle(temp)
            train_c2w, train_gt = zip(*temp)
            progress_total = len(train_c2w)
        else:
            assert FLAGS.logical_batch_size % FLAGS.physical_batch_size == 0
            # Shuffle rays over all training images
            rays_rgb = rays_rgb.take(np.random.permutation(rays_rgb.shape[0]), axis=0)
            grads_per_step = FLAGS.logical_batch_size // FLAGS.physical_batch_size
            n_steps = len(rays_rgb) // FLAGS.logical_batch_size
            accumulate = FLAGS.logical_batch_size > FLAGS.physical_batch_size
            progress_total = n_steps * grads_per_step

        # The bar advances once per physical batch, so size it accordingly
        # (it was previously sized by the number of test views).
        pb = tqdm(total=progress_total, desc = f"Epoch {i}")

        if not batched:
            # NOTE(review): this whole-image path computes the loss and its
            # gradient but never applies an update to data_dict.
            occupancy_penalty = FLAGS.occupancy_penalty / len(train_c2w)
            for j, (c2w, gt) in tqdm(enumerate(zip(train_c2w, train_gt)), total=len(train_c2w)):
                if FLAGS.jitter > 0:
                    splitkeys = split_keys(keys)
                    keys = splitkeys[...,0,:]
                    subkeys = splitkeys[...,1,:]
                else:
                    subkeys = None
                mse, data_grad = jax.value_and_grad(lambda grid: get_loss(grid, c2w, gt, H, W, focal, FLAGS.resolution, radius, FLAGS.harmonic_degree, FLAGS.jitter, FLAGS.uniform, subkeys, sh_dim, occupancy_penalty, FLAGS.interpolation, FLAGS.nv))(data_dict)
        else:
            occupancy_penalty = FLAGS.occupancy_penalty / n_steps

            for k in tqdm(range(n_steps)):
                logical_grad = None
                for j in range(grads_per_step):
                    if FLAGS.jitter > 0:
                        splitkeys = split_keys_partial(keys)
                        keys = splitkeys[...,0,:]
                        subkeys = splitkeys[...,1,:]
                    else:
                        subkeys = None
                    effective_j = k*grads_per_step + j
                    batch = rays_rgb[effective_j*FLAGS.physical_batch_size:(effective_j+1)*FLAGS.physical_batch_size] # [B, 2+1, 3*?]
                    batch_rays, target_s = (batch[:,0,:], batch[:,1,:]), batch[:,2,:]
                    mse, data_grad = jax.value_and_grad(lambda grid: get_loss_rays(grid, batch_rays, target_s, FLAGS.resolution, radius, FLAGS.harmonic_degree, FLAGS.jitter, FLAGS.uniform, subkeys, sh_dim, occupancy_penalty, FLAGS.interpolation, FLAGS.nv))(data_dict)
                    pb.set_postfix_str(f"psnr = {-10*jnp.log10(mse)}, grad norm = {jnp.linalg.norm(data_grad[0])}", refresh = False)
                    pb.update(1)

                    if accumulate:
                        # Sum the gradients of all physical batches that make
                        # up one logical batch.
                        if logical_grad is None:
                            logical_grad = data_grad
                        else:
                            logical_grad = [a + b for a, b in zip(logical_grad, data_grad)]
                        del data_grad
                    del mse, batch, batch_rays, target_s, subkeys, effective_j

                # Per-grid step sizes with a cosine decay over the epoch.
                lrs = [FLAGS.lr_rgb / grads_per_step]*sh_dim + [FLAGS.lr_sigma / grads_per_step]
                lrs  = [lr * np.cos(k /(n_steps + 10) * (np.pi/2)) for lr in lrs]

                # Feed RMSProp with the gradient that actually exists at this
                # point. data_grad has been deleted when accumulating, which
                # used to raise a NameError here.
                step_grad = logical_grad if accumulate else data_grad
                avg_g = rmsprop_update(avg_g, step_grad)

                if accumulate:
                    data_dict = update_grids(data_dict, lrs, logical_grad)
                else:
                    # RMSProp-scaled step on the last grid only.
                    data_dict[-1] = update_grid(data_dict[-1], lrs[-1], data_grad[-1]/ (jnp.sqrt(avg_g[-1]) + 1e-10))
                    del data_grad
                del step_grad, logical_grad

        if np.isin(i, FLAGS.prune_epochs):
            data_dict = plenoxel_ct.prune_grid(data_dict, method=FLAGS.prune_method, threshold=FLAGS.prune_threshold, train_c2w=train_c2w, H=H, W=W, focal=focal, batch_size=FLAGS.physical_batch_size, resolution=FLAGS.resolution, key=render_keys, radius=FLAGS.radius, harmonic_degree=FLAGS.harmonic_degree, jitter=FLAGS.jitter, uniform=FLAGS.uniform, interpolation=FLAGS.interpolation)

        if np.isin(i, FLAGS.split_epochs):
            print(f'at epoch {i}, about to split. res = {data_dict[0].shape}, flags.res = {FLAGS.resolution} ')
            data_dict = plenoxel_ct.split_grid(data_dict)
            FLAGS.lr_rgb = FLAGS.lr_rgb * 3
            FLAGS.lr_sigma = FLAGS.lr_sigma * 3
            FLAGS.resolution = FLAGS.resolution * 2
            print(f'at epoch {i}, finished split. res = {data_dict[0].shape}, flags.res = {FLAGS.resolution} ')
            # Recompute all the training rays at the new resolution and shuffle them
            rays = np.stack([get_rays_np(H, W, dH, dW, p) for p in train_c2w[:,:3,:4]], 0) # [N, ro+rd, H, W, 3]
            rays_rgb = np.concatenate([rays, multi_lowpass(train_gt[:,None], FLAGS.resolution).astype(np.float32)], 1) # [N, ro+rd+rgb, H, W, 3]
            rays_rgb = np.transpose(rays_rgb, [0,2,3,1,4]) # [N, H, W, ro+rd+rgb, 3]
            rays_rgb = np.reshape(rays_rgb, [-1,3,3]) # [N*H*W, ro+rd+rgb, 3]
            rays_rgb = rays_rgb.take(np.random.permutation(rays_rgb.shape[0]), axis=0)
            # Bring the RMSProp state to the new grid resolution as well.
            avg_g = plenoxel_ct.split_grid(avg_g)

        pb.close()

        if i % FLAGS.save_interval == FLAGS.save_interval - 1 or i == FLAGS.num_epochs - 1:
            print(f'Saving checkpoint at epoch {i}')
            plenoxel_ct.save_grid(data_dict, os.path.join(log_dir, f'epoch_{i}'))

        if i % FLAGS.val_interval == FLAGS.val_interval - 1 or i == FLAGS.num_epochs - 1:
            validation_psnr = run_test_step(i + 1, data_dict, test_c2w, test_gt, H, W, focal, FLAGS, render_keys)
            print(f'at epoch {i}, test psnr is {validation_psnr}')

    # Nothing left to train (resumed after the final epoch): evaluate once.
    # This check used to sit inside the loop above, where it could never run
    # because the loop body is skipped exactly when the condition holds.
    if start_epoch == FLAGS.num_epochs:
        validation_psnr = run_test_step(start_epoch + 1, data_dict, test_c2w, test_gt, H, W, focal, FLAGS, render_keys)
        print(f'at epoch {start_epoch}, test psnr is {validation_psnr}')
    
if __name__ == "__main__":
    main()