
import os
from tqdm.auto import tqdm
from opt import config_parser
import matplotlib.pyplot as plt


import json, random
from renderer import *
from utils import *
from torch.utils.tensorboard import SummaryWriter
import datetime

from dataLoader import dataset_dict
import sys

from torch_efficient_distloss import eff_distloss



# Run on the CUDA device when one is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Ray renderer shared by training (reconstruction) and evaluation (render_test).
renderer = OctreeRender_trilinear_fast


class SimpleSampler:
    """Hand out random batches of indices from ``range(total)`` without replacement.

    Batches are consecutive slices of a random permutation. A new permutation
    is drawn whenever the next slice would extend past ``total``, so trailing
    indices that cannot fill a whole batch are skipped for that permutation.
    If ``batch`` exceeds ``total``, every call reshuffles and returns all
    ``total`` indices.
    """

    def __init__(self, total, batch):
        self.total = total
        self.batch = batch
        # Begin "past the end" so the very first nextids() call draws a permutation.
        self.curr = total
        self.ids = None

    def nextids(self):
        """Return the next batch of indices as an int64 tensor."""
        start = self.curr + self.batch
        if start + self.batch > self.total:
            start = 0
            self.ids = torch.LongTensor(np.random.permutation(self.total))
        self.curr = start
        return self.ids[start:start + self.batch]



class PatchSampler:
    """Sample random 2x2 pixel patches from a stack of equally sized frames.

    The index arithmetic treats the rays as frames of ``W * H`` pixels stored
    back to back, each frame row-major (index = x + y * W + frame * W * H).
    A patch is identified by its top-left pixel, which may be any pixel outside
    the last row and column, giving ``(W - 1) * (H - 1)`` candidates per frame.
    """

    def __init__(self, total, batch, W, H):
        self.total = total
        self.W = W
        self.H = H
        self.batch = batch
        self.frame_size = W * H
        self.frame_num = total // self.frame_size
        # Number of admissible top-left corners in a single frame.
        self.total_patch = (W - 1) * (H - 1)
        self.total_id = self.frame_num * self.total_patch
        # Begin past the end so the first nextids() call draws a permutation.
        self.curr = self.total_id
        print(f"frame_num: {self.frame_num} frame_size: {self.frame_size}")

    def nextids(self):
        """Return flat pixel indices of ``batch`` random 2x2 patches.

        The result holds ``4 * batch`` entries grouped by corner: all top-left
        pixels, then all top-right, then all bottom-left, then all bottom-right,
        so ``result.view(2, 2, -1)[r, c]`` selects row offset ``r`` and column
        offset ``c`` of every patch.
        """
        pos = self.curr + self.batch
        if pos + self.batch > self.total_id:
            self.ids = torch.LongTensor(np.random.permutation(self.total_id))
            pos = 0
        self.curr = pos
        corner = self.ids[pos:pos + self.batch]

        frame = corner // self.total_patch
        col = corner % (self.W - 1)
        row = (corner // (self.W - 1)) % (self.H - 1)
        top_left = col + row * self.W + frame * self.frame_size
        # Right neighbour is +1 and the pixel below is +W within a row-major frame.
        corner_offsets = (0, 1, self.W, self.W + 1)
        return torch.cat([top_left + off for off in corner_offsets], dim=0)



@torch.no_grad()
def export_mesh(args):
    """Export the density field of the checkpoint at ``args.ckpt`` to a PLY file.

    The model class named by ``args.model_name`` is rebuilt from the constructor
    kwargs stored in the checkpoint (placed on ``device``), its weights are
    loaded, and its dense alpha grid is handed to ``convert_sdf_samples_to_ply``
    with iso-level 0.005 and the model's bounding box. The output path is the
    checkpoint path with its extension replaced by ``.ply``.
    """
    ckpt = torch.load(args.ckpt, map_location=device)
    kwargs = ckpt['kwargs']
    kwargs.update({'device': device})
    # NOTE: eval() of a command-line class name -- only run with trusted arguments.
    tensorf = eval(args.model_name)(**kwargs)
    tensorf.load(ckpt)

    alpha, _ = tensorf.getDenseAlpha()
    # Replace the extension whatever its length; a fixed [:-3] slice is only
    # right for 3-character extensions such as ".th" (".pth" gave "x..ply").
    mesh_path = f'{os.path.splitext(args.ckpt)[0]}.ply'
    convert_sdf_samples_to_ply(alpha.cpu(), mesh_path, bbox=tensorf.aabb.cpu(), level=0.005)


@torch.no_grad()
def render_test(args):
    """Render and evaluate a trained model loaded from ``args.ckpt``.

    Depending on ``args.render_train``, ``args.render_test`` and
    ``args.render_path``, runs ``evaluation`` on the training views and/or the
    test views, and/or ``evaluation_path`` along the test dataset's
    ``render_path`` cameras. Output directories are created under the
    checkpoint's directory. Prints a message and returns early when the
    checkpoint path is unset or does not exist.
    """
    # Validate the checkpoint before loading any dataset so a bad path fails
    # fast; an unset (None) path is reported the same way as a missing file.
    if args.ckpt is None or not os.path.exists(args.ckpt):
        print('the ckpt path does not exists!!')
        return

    # init dataset
    dataset = dataset_dict[args.dataset_name]
    test_dataset = dataset(args.datadir, split='test', downsample=args.downsample_train, is_stack=True)
    white_bg = test_dataset.white_bg
    ndc_ray = args.ndc_ray

    # Rebuild the model from the constructor kwargs stored in the checkpoint.
    ckpt = torch.load(args.ckpt, map_location=device)
    kwargs = ckpt['kwargs']
    kwargs.update({'device': device})
    # NOTE: eval() of a command-line class name -- only run with trusted arguments.
    tensorf = eval(args.model_name)(**kwargs)
    tensorf.load(ckpt)

    logfolder = os.path.dirname(args.ckpt)
    if args.render_train:
        os.makedirs(f'{logfolder}/imgs_train_all', exist_ok=True)
        train_dataset = dataset(args.datadir, split='train', downsample=args.downsample_train, is_stack=True)
        # evaluation() returns a tuple of metric lists with the PSNRs first
        # (reconstruction() unpacks six values from it); a fixed 3-way unpack
        # fails on that, so keep the PSNRs and ignore the remaining metrics.
        PSNRs_test, *_ = evaluation(train_dataset, tensorf, args, renderer, f'{logfolder}/imgs_train_all/',
                                    N_vis=-1, N_samples=-1, white_bg=white_bg, ndc_ray=ndc_ray, device=device)
        print(f'======> {args.expname} train all psnr: {np.mean(PSNRs_test)} <========================')

    if args.render_test:
        os.makedirs(f'{logfolder}/{args.expname}/imgs_test_all', exist_ok=True)
        evaluation(test_dataset, tensorf, args, renderer, f'{logfolder}/{args.expname}/imgs_test_all/',
                   N_vis=-1, N_samples=-1, white_bg=white_bg, ndc_ray=ndc_ray, device=device)

    if args.render_path:
        c2ws = test_dataset.render_path
        os.makedirs(f'{logfolder}/{args.expname}/imgs_path_all', exist_ok=True)
        # NOTE(review): reso receives the whole down_sampling_ratio sequence here,
        # while the training renders pass one scalar per resolution -- confirm.
        evaluation_path(test_dataset, tensorf, c2ws, renderer, f'{logfolder}/{args.expname}/imgs_path_all/',
                        N_vis=-1, N_samples=-1, white_bg=white_bg, ndc_ray=ndc_ray, device=device,
                        reso=args.down_sampling_ratio)


@torch.no_grad()
def warping(allrays, ray_idx, H, W, focal, depth_map, c2w_j, startposition_in_allray, patch_mask, patch_size):
    """Project patch pixels of a source frame into a target frame j using per-ray depth.

    There are ``B`` centre rays, each owning ``patch_size**2`` patch pixels.
    ``ray_idx`` and ``patch_mask`` describe all ``B * patch_size**2`` pixels in
    ray-major order (every pixel of ray 0, then ray 1, ...), which is the layout
    ``patchify`` produces. ``depth_map``, ``c2w_j`` and
    ``startposition_in_allray`` hold one entry per centre ray. Each patch pixel
    is back-projected along its own ray using its centre ray's depth and then
    projected into its centre ray's target camera.

    Args:
        allrays: ray table indexed by ``ray_idx``; columns 0:3 are ray origins
            and 3:6 ray directions.
        ray_idx: flat indices into ``allrays``, one per patch pixel.
        H, W: height and width of the target image.
        focal: focal length in pixels.
        depth_map: (B,) depth per centre ray; its device is used for all work.
        c2w_j: (B, 3, 4) target camera-to-world matrices ``[R | t]``.
        startposition_in_allray: (B,) index in ``allrays`` where each target
            frame's rays begin.
        patch_mask: bool mask over patch pixels; False entries are not looked up
            in ``allrays`` (their rays stay zero, so their outputs are
            meaningless and callers must also apply ``patch_mask``).
        patch_size: patch side length ``k``.

    Returns:
        ``(position_in_allray, within_mask)``: for every patch pixel, the flat
        index ``start + v * W + u`` of the pixel it lands on in frame j, and a
        mask that is True only when the point lies in front of camera j and
        projects inside the ``W x H`` image.
    """
    dev = depth_map.device
    n_pix = patch_size ** 2
    # Replicate the per-ray values so entry r * n_pix + p pairs pixel p of ray r
    # with ray r's depth/pose/start. torch.cat([x] * n_pix) would tile them
    # (x0, x1, ..., x0, x1, ...) and mismatch rays whenever patch_size > 1.
    depth_patch = depth_map.repeat_interleave(n_pix, dim=0)
    c2w_patch = c2w_j.repeat_interleave(n_pix, dim=0)
    start_patch = startposition_in_allray.to(dev).repeat_interleave(n_pix, dim=0)

    rays = torch.zeros((ray_idx.shape[0], 6), device=dev)
    rays[patch_mask] = allrays[ray_idx[patch_mask].cpu()].to(dev)
    # 3D point in world coordinates: rays_o + rays_d * depth
    xyz = rays[:, 0:3] + depth_patch.unsqueeze(-1) * rays[:, 3:]
    # NOTE(review): x and y are negated before the world->camera transform;
    # presumably a dataset axis convention -- confirm.
    xyz[:, 0:2] = -xyz[:, 0:2]
    # World -> camera j: R^T (X - t)
    xyz = xyz.unsqueeze(-1) - c2w_patch[:, :, 3:]
    xyz = torch.matmul(c2w_patch[:, :, 0:3].transpose(1, 2), xyz).squeeze(-1)
    # Pinhole projection with depth measured along camera -z.
    cam_depth = -xyz[:, 2]
    u = torch.round(xyz[:, 0] / cam_depth * focal + W * .5).long()
    v = torch.round(xyz[:, 1] / cam_depth * focal + H * .5).long()
    # Points at or behind the camera (cam_depth <= 0) are mirrored by the
    # division above and could otherwise land inside the image.
    within_mask = (cam_depth > 0) & (u >= 0) & (u < W) & (v >= 0) & (v < H)

    position_in_allray = start_patch + v * W + u
    return position_in_allray, within_mask


@torch.no_grad()
def cal_reprojection_error(rgb, projected_rgb, mask, patch_size):
    """Return the mean colour reprojection error of every patch.

    Args:
        rgb: colours at the valid source pixels, one row per True entry of ``mask``.
        projected_rgb: colours at the corresponding warped pixels, same layout.
        mask: 1-D bool tensor over all patch pixels; each consecutive group of
            ``patch_size**2`` entries forms one patch.
        patch_size: patch side length.

    Returns:
        1-D tensor with one value per patch: the mean, over the patch's pixels,
        of the squared colour difference averaged over the last dimension.
        Pixels where ``mask`` is False are assigned a fixed error of 1.
    """
    pixels_per_patch = patch_size ** 2
    per_pixel = torch.ones(mask.shape[0], device=mask.device)
    per_pixel[mask] = ((rgb - projected_rgb) ** 2).mean(dim=-1)
    n_patches = per_pixel.shape[0] // pixels_per_patch
    return per_pixel.view(n_patches, pixels_per_patch).mean(dim=-1)


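# expand sampled rays of frame i into k x k pixel patches; the patches are then warped to frame j (warping) and their rgb reprojection error is measured (cal_reprojection_error)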
@torch.no_grad()
def patchify(ray_idx, H, W, patch_size, total_frame_len):
    """Expand each sampled ray index into the indices of a k x k pixel patch.

    ``ray_idx`` holds flat indices into frames of ``H * W`` pixels stored back
    to back, each row-major: frame = idx // (H*W), row = (idx % (H*W)) // W,
    column = idx % (H*W) % W. Around every ray the pixels at row/column offsets
    ``range(patch_size) - patch_size // 2`` are generated; for even patch sizes
    the window is therefore not centred on the ray.

    Returns:
        ``(patch_ray_idx, patch_mask)``, both flattened to
        ``len(ray_idx) * patch_size**2`` entries in ray-major order (all pixels
        of ray 0, then ray 1, ...) and placed on ``device``. ``patch_mask`` is
        False for pixels outside the H x W image or whose frame index is not in
        ``[0, total_frame_len)``.
    """
    frame_size = H * W
    n_pix = patch_size ** 2
    steps = torch.arange(patch_size, device=device) - patch_size // 2
    # Pixel p = a * patch_size + b of a patch is shifted by steps[b] rows and steps[a] columns.
    row_shift = steps.repeat(patch_size)
    col_shift = steps.repeat_interleave(patch_size)

    centre = ray_idx.to(device).unsqueeze(-1)  # (B, 1)
    in_frame = centre % frame_size
    frame = (centre // frame_size).expand(-1, n_pix)  # (B, k*k)
    row = in_frame // W + row_shift
    col = in_frame % W + col_shift

    patch_ray_idx = (frame * frame_size + row * W + col).view(-1)  # (B * k * k)
    inside = (col >= 0) & (col < W) & (row >= 0) & (row < H) & (frame >= 0) & (frame < total_frame_len)
    return patch_ray_idx, inside.view(-1)


def reconstruction(args):

    # init dataset
    dataset = dataset_dict[args.dataset_name]
    train_dataset = dataset(args.datadir, split='train', downsample=args.downsample_train, is_stack=False, frame_num=args.train_frame_num)
    test_dataset = dataset(args.datadir, split='test', downsample=args.downsample_train, is_stack=True, frame_num=args.test_frame_num)
    novel_dataset = dataset(args.datadir, split='novel', downsample=args.downsample_train, is_stack=False, frame_num=args.train_frame_num)
    white_bg = train_dataset.white_bg
    near_far = train_dataset.near_far
    ndc_ray = args.ndc_ray

    # init resolution
    upsamp_list = args.upsamp_list
    update_AlphaMask_list = args.update_AlphaMask_list
    n_lamb_sigma = args.n_lamb_sigma
    n_lamb_sh = args.n_lamb_sh

    
    if args.add_timestamp:
        logfolder = f'{args.basedir}/{args.expname}{datetime.datetime.now().strftime("-%Y%m%d-%H%M%S")}'
    else:
        logfolder = f'{args.basedir}/{args.expname}'
    

    # init log file
    os.makedirs(logfolder, exist_ok=True)
    os.makedirs(f'{logfolder}/imgs_vis', exist_ok=True)
    os.makedirs(f'{logfolder}/imgs_rgba', exist_ok=True)
    os.makedirs(f'{logfolder}/rgba', exist_ok=True)
    summary_writer = SummaryWriter(logfolder)



    # init parameters
    # tensorVM, renderer = init_parameters(args, train_dataset.scene_bbox.to(device), reso_list[0])
    print("Downsampling ratio:", args.down_sampling_ratio)
    aabb = train_dataset.scene_bbox.to(device)
    reso_cur = N_to_reso(args.N_voxel_init, aabb)
    nSamples = min(args.nSamples, cal_n_samples(reso_cur,args.step_ratio))
    nSamples_MR = min(args.nSamples, cal_n_samples(reso_cur,args.step_ratio*args.down_sampling_ratio[0]))
    nSamples_LR = min(args.nSamples, cal_n_samples(reso_cur,args.step_ratio*args.down_sampling_ratio[1]))
    

    if args.ckpt is not None:
        ckpt = torch.load(args.ckpt, map_location=device)
        kwargs = ckpt['kwargs']
        kwargs.update({'device':device})
        tensorf = eval(args.model_name)(**kwargs)
        tensorf.load(ckpt)
    else:
        tensorf = eval(args.model_name)(aabb, reso_cur, device,
                    density_n_comp=n_lamb_sigma, appearance_n_comp=n_lamb_sh, app_dim=args.data_dim_color, near_far=near_far,down_sampling_ratio = args.down_sampling_ratio,
                    shadingMode=args.shadingMode, alphaMask_thres=args.alpha_mask_thre, density_shift=args.density_shift, distance_scale=args.distance_scale,
                    pos_pe=args.pos_pe, view_pe=args.view_pe, fea_pe=args.fea_pe, featureC=args.featureC, step_ratio=args.step_ratio,fea2denseAct=args.fea2denseAct, rayMarch_weight_thres=1e-3)
        

    grad_vars = tensorf.get_optparam_groups(args.lr_init, args.lr_basis)
    if args.lr_decay_iters > 0:
        lr_factor = args.lr_decay_target_ratio_for_weight**(1/args.lr_decay_iters)
    else:
        args.lr_decay_iters = args.n_iters
        lr_factor = args.lr_decay_target_ratio_for_weight**(1/args.n_iters)

    print("lr decay", args.lr_decay_target_ratio_for_weight, args.lr_decay_iters)
    
    optimizer = torch.optim.Adam(grad_vars, betas=(0.9,0.99))


    #linear in logrithmic space
    N_voxel_list = (torch.round(torch.exp(torch.linspace(np.log(args.N_voxel_init), np.log(args.N_voxel_final), len(upsamp_list)+1))).long()).tolist()[1:]


    torch.cuda.empty_cache()
    PSNRs,PSNRs_test,SSIMs_test,LPIPSs_test = [],[0],[0],[0]

    allrays, allrgbs, alldepths, alldepthweights = train_dataset.all_rays, train_dataset.all_rgbs, train_dataset.all_depths, train_dataset.all_depth_weights
    # TODO: parameter for warping
    all_ids, allnearest_ids, all_poses = train_dataset.all_ids, train_dataset.all_nearest_ids, train_dataset.poses #get frame_id, nearest_frame_id, and poses_of_each_frame
    W, H = train_dataset.img_wh
    f = train_dataset.focal[0]
    frameid2_startpoints_in_allray = torch.tensor(train_dataset.frameid2_startpoints_in_allray) # get start position in "allray" for each view

    allrays_novel, all_ids_novel, allnearest_ids_novel, all_poses_novel = novel_dataset.all_rays, novel_dataset.all_ids, novel_dataset.all_nearest_ids, novel_dataset.render_path


    Ortho_reg_weight = args.Ortho_weight
    print("initial Ortho_reg_weight", Ortho_reg_weight)

    L1_reg_weight = args.L1_weight_inital
    print("initial L1_reg_weight", L1_reg_weight)

    TV_weight_density, TV_weight_app, TV_weight_color_density = args.TV_weight_density, args.TV_weight_app, args.TV_weight_color_density
    tvreg = TVLoss()
    colortvreg = colorTVLoss()
    print(f"initial TV_weight density: {TV_weight_density} appearance: {TV_weight_app} color_density: {TV_weight_color_density}")

    sparse_depth_weight = args.Sparse_Depth_weight
    print(f"initial sparse_depth_weight: {sparse_depth_weight}")

    self_depth_weight = args.Self_Depth_weight
    print(f"initial self_depth_weight: {self_depth_weight}")

    dist_loss_weight = args.Dist_weight
    print(f"initial dist_loss_weight: {dist_loss_weight}")

    depth_smooth_weight = args.Depth_Smooth_weight
    print(f"initial depth_smooth_weight: {depth_smooth_weight}")

    mr_color_weight = args.MR_color_weight
    print(f"initial mr_color_weight: {mr_color_weight}")

    lr_color_weight = args.LR_color_weight
    print(f"initial lr_color_weight: {lr_color_weight}")

    warping_patch_size = args.warping_patch_size
    print(f"warping_patch_size: {warping_patch_size}")
    
    occ_loss_weight = args.Occ_loss_weight
    reg_rate = args.reg_rate
    wb_prior = args.wb_prior
    wb_rate = args.wb_rate
    
    self_depth_start_iter = args.self_depth_start_iter
    reprojection_error_thr = args.reprojection_error_thr

    


    train_frame_len = len(args.train_frame_num)
    novel_frame_len = 60

    self_depth_ratio_high_list = []
    self_depth_ratio_mid_list = []
    self_depth_ratio_low_list = []
    repo_error_high_list = []
    repo_error_mid_list = []
    repo_error_low_list = []

    if depth_smooth_weight > 0:
        trainingSampler = SimpleSampler(allrays.shape[0], args.batch_size)
        novelSampler = PatchSampler(allrays_novel.shape[0], args.novel_batch_size, W, H)
        train_items = args.batch_size
        batch_size = args.batch_size  + args.novel_batch_size * 4
    else:
        trainingSampler = SimpleSampler(allrays.shape[0], args.batch_size)
        novelSampler = SimpleSampler(allrays_novel.shape[0], args.novel_batch_size)
        train_items = args.batch_size
        batch_size = args.batch_size + args.novel_batch_size


    pbar = tqdm(range(args.n_iters), miniters=args.progress_refresh_rate, file=sys.stdout)
    for iteration in pbar:

        ray_idx = trainingSampler.nextids()
        # set ray_idx first 10 items be random selected depth id
        if iteration < 1000:
            depth_mask = alldepths > 0
            depth_mask = depth_mask.view(-1)
            depth_mask_idx = torch.nonzero(depth_mask).view(-1)
            depth_mask_idx = depth_mask_idx[torch.randperm(depth_mask_idx.shape[0])[:10]]
            ray_idx[:10] = depth_mask_idx
        rays_train, rgb_train, depth_train = allrays[ray_idx].to(device), allrgbs[ray_idx].to(device), alldepths[ray_idx].to(device)
        
        if args.novel_batch_size > 0:
            ray_idx_novel = novelSampler.nextids()
            rays_novel = allrays_novel[ray_idx_novel].to(device)
        
            rays_train_all = torch.cat([rays_train, rays_novel], dim=0)
        else:
            rays_train_all = rays_train

        rgb_map, depth_map, weight, m, sigma, rgb_ray = renderer(rays_train_all, tensorf, chunk=batch_size, N_samples=nSamples, white_bg = white_bg, ndc_ray=ndc_ray, device=device, is_train=True, reso=1)
        rgb_map_MR, depth_map_MR, weight_MR, m_MR, sigma_MR, rgb_ray_MR = renderer(rays_train_all, tensorf, chunk=batch_size, N_samples=nSamples_MR, white_bg = white_bg, ndc_ray=ndc_ray, device=device, is_train=True, reso=args.down_sampling_ratio[0])
        rgb_map_LR, depth_map_LR, weight_LR, m_LR, sigma_LR, rgb_ray_LR = renderer(rays_train_all, tensorf, chunk=batch_size, N_samples=nSamples_LR, white_bg = white_bg, ndc_ray=ndc_ray, device=device, is_train=True, reso=args.down_sampling_ratio[1])
        

        if self_depth_weight > 0 and iteration > self_depth_start_iter:
            ids_train, nearest_ids_train = all_ids[ray_idx], allnearest_ids[ray_idx] # get correpsonding pose id for each ray
            c2w_train, nearest_c2w_train = all_poses[ids_train].to(device), all_poses[nearest_ids_train].to(device) # get correpsonding c2w matrix for each ray
            
            
            patch_ray_idx_i, patch_mask_i = patchify(ray_idx, H, W, warping_patch_size, train_frame_len)
            
            patch_ray_idx_j, patch_mask_j = warping(allrays, patch_ray_idx_i, H, W, f, depth_map[:train_items], nearest_c2w_train, frameid2_startpoints_in_allray[nearest_ids_train], patch_mask_i, warping_patch_size)
            patch_ray_idx_j_MR, patch_mask_j_MR = warping(allrays, patch_ray_idx_i, H, W, f, depth_map_MR[:train_items], nearest_c2w_train, frameid2_startpoints_in_allray[nearest_ids_train], patch_mask_i, warping_patch_size)
            patch_ray_idx_j_LR, patch_mask_j_LR = warping(allrays, patch_ray_idx_i, H, W, f, depth_map_LR[:train_items], nearest_c2w_train, frameid2_startpoints_in_allray[nearest_ids_train], patch_mask_i, warping_patch_size)

            with torch.no_grad():
                mask = patch_mask_i & patch_mask_j
                rgb = allrgbs[patch_ray_idx_i[mask].cpu()].to(device)
                projected_rgb  = allrgbs[patch_ray_idx_j[mask].cpu()].to(device)

                mask_MR = patch_mask_i & patch_mask_j_MR
                rgb_MR = allrgbs[patch_ray_idx_i[mask_MR].cpu()].to(device)
                projected_rgb_MR  = allrgbs[patch_ray_idx_j_MR[mask_MR].cpu()].to(device)

                mask_LR = patch_mask_i & patch_mask_j_LR
                rgb_LR = allrgbs[patch_ray_idx_i[mask_LR].cpu()].to(device)
                projected_rgb_LR  = allrgbs[patch_ray_idx_j_LR[mask_LR].cpu()].to(device)

            reprojection_error = cal_reprojection_error(rgb, projected_rgb, mask, warping_patch_size)
            reprojection_error_MR = cal_reprojection_error(rgb_MR, projected_rgb_MR, mask_MR, warping_patch_size)
            reprojection_error_LR = cal_reprojection_error(rgb_LR, projected_rgb_LR, mask_LR, warping_patch_size)

            # novel view (patch set 1 and use HR render color)
            if args.novel_batch_size > 0:
                ids_novel, nearest_ids_novel = all_ids_novel[ray_idx_novel], allnearest_ids_novel[ray_idx_novel]
                c2w_novel, nearest_c2w_novel = all_poses_novel[ids_novel].to(device), all_poses[nearest_ids_novel].to(device)
                
                patch_ray_idx_i_novel, patch_mask_i_novel = patchify(ray_idx_novel, H, W, 1, novel_frame_len)
                
                patch_ray_idx_j_novel, patch_mask_j_novel = warping(allrays_novel, patch_ray_idx_i_novel, H, W, f, depth_map[train_items:], nearest_c2w_novel, frameid2_startpoints_in_allray[nearest_ids_novel], patch_mask_i_novel, 1)
                patch_ray_idx_j_novel_MR, patch_mask_j_novel_MR = warping(allrays_novel, patch_ray_idx_i_novel, H, W, f, depth_map_MR[train_items:], nearest_c2w_novel, frameid2_startpoints_in_allray[nearest_ids_novel], patch_mask_i_novel, 1)
                patch_ray_idx_j_novel_LR, patch_mask_j_novel_LR = warping(allrays_novel, patch_ray_idx_i_novel, H, W, f, depth_map_LR[train_items:], nearest_c2w_novel, frameid2_startpoints_in_allray[nearest_ids_novel], patch_mask_i_novel, 1)
                
                with torch.no_grad():
                    mask_novel = patch_mask_i_novel & patch_mask_j_novel
                    rgb_novel = rgb_map[train_items:][mask_novel]
                    projected_rgb_novel  = allrgbs[patch_ray_idx_j_novel[mask_novel].cpu()].to(device)

                    mask_MR_novel = patch_mask_i_novel & patch_mask_j_novel_MR
                    rgb_MR_novel = rgb_map_MR[train_items:][mask_MR_novel]
                    projected_rgb_MR_novel  = allrgbs[patch_ray_idx_j_novel_MR[mask_MR_novel].cpu()].to(device)

                    mask_LR_novel = patch_mask_i_novel & patch_mask_j_novel_LR
                    rgb_LR_novel = rgb_map[train_items:][mask_LR_novel]
                    projected_rgb_LR_novel  = allrgbs[patch_ray_idx_j_novel_LR[mask_LR_novel].cpu()].to(device)
                
                reprojection_error_novel = cal_reprojection_error(rgb_novel, projected_rgb_novel, mask_novel, 1)
                reprojection_error_novel_MR = cal_reprojection_error(rgb_MR_novel, projected_rgb_MR_novel, mask_MR_novel, 1)
                reprojection_error_novel_LR = cal_reprojection_error(rgb_LR_novel, projected_rgb_LR_novel, mask_LR_novel, 1)

                reprojection_error = torch.cat([reprojection_error, reprojection_error_novel])
                reprojection_error_MR = torch.cat([reprojection_error_MR, reprojection_error_novel_MR])
                reprojection_error_LR = torch.cat([reprojection_error_LR, reprojection_error_novel_LR])

        HR_error = (rgb_map[:train_items] - rgb_train)  ** 2
        MR_error = (rgb_map_MR[:train_items] - rgb_train)  ** 2
        LR_error = (rgb_map_LR[:train_items] - rgb_train)  ** 2

        loss = torch.mean(HR_error)
        loss_MR = torch.mean(MR_error)
        loss_LR = torch.mean(LR_error)
        
        # loss
        total_loss = loss * 1.0 + loss_MR * mr_color_weight + loss_LR * lr_color_weight
        mr_color_weight *= lr_factor
        lr_color_weight *= lr_factor
        summary_writer.add_scalar('train/loss2', loss_MR.detach().item(), global_step=iteration)
        summary_writer.add_scalar('train/loss3', loss_LR.detach().item(), global_step=iteration)

        if Ortho_reg_weight > 0:
            loss_reg = tensorf.vector_comp_diffs()
            total_loss += Ortho_reg_weight*loss_reg
            summary_writer.add_scalar('train/reg', loss_reg.detach().item(), global_step=iteration)

        if L1_reg_weight > 0:
            loss_reg_L1 = tensorf.density_L1()
            total_loss += L1_reg_weight*loss_reg_L1
            summary_writer.add_scalar('train/reg_l1', loss_reg_L1.detach().item(), global_step=iteration)

        if TV_weight_density>0:
            TV_weight_density *= lr_factor
            loss_tv = tensorf.TV_loss_density(tvreg) * TV_weight_density
            loss_tv = loss_tv + tensorf.TV_loss_density_MR(tvreg) * TV_weight_density
            loss_tv = loss_tv + tensorf.TV_loss_density_LR(tvreg) * TV_weight_density
            total_loss = total_loss + loss_tv
            summary_writer.add_scalar('train/reg_tv_density', loss_tv.detach().item(), global_step=iteration)

        if TV_weight_app>0:
            TV_weight_app *= lr_factor
            loss_tv = tensorf.TV_loss_app(tvreg) * TV_weight_app
            loss_tv = loss_tv + tensorf.TV_loss_app_MR(tvreg) * TV_weight_app
            loss_tv = loss_tv + tensorf.TV_loss_app_LR(tvreg) * TV_weight_app
            total_loss = total_loss + loss_tv
            summary_writer.add_scalar('train/reg_tv_app', loss_tv.detach().item(), global_step=iteration)
        
        if TV_weight_color_density>0:
            TV_weight_color_density *= lr_factor
            loss_tv = tensorf.TV_loss_color_aware_density(colortvreg) * TV_weight_color_density
            loss_tv = loss_tv + tensorf.TV_loss_color_aware_density_MR(colortvreg) * TV_weight_color_density
            loss_tv = loss_tv + tensorf.TV_loss_color_aware_density_LR(colortvreg) * TV_weight_color_density
            total_loss = total_loss + loss_tv
            summary_writer.add_scalar('train/reg_color_tv', loss_tv.detach().item(), global_step=iteration)
        
        if sparse_depth_weight>0:
            sparse_depth_weight *= lr_factor
            depth_mask = depth_train > 0
            loss_depth = 0
            loss_depth = cal_disparity_loss(depth_map[:train_items][depth_mask], depth_train[depth_mask])*sparse_depth_weight
            loss_depth = loss_depth + cal_disparity_loss(depth_map_MR[:train_items][depth_mask], depth_train[depth_mask])*sparse_depth_weight
            loss_depth = loss_depth + cal_disparity_loss(depth_map_LR[:train_items][depth_mask], depth_train[depth_mask])*sparse_depth_weight

            total_loss = total_loss + loss_depth
            summary_writer.add_scalar('train/reg_depth', loss_depth.detach().item(), global_step=iteration)

        if dist_loss_weight>0:
            dist_loss_weight /= lr_factor
            dist_loss = eff_distloss(weight, m.detach(), 1/nSamples) * dist_loss_weight
            dist_loss = dist_loss + eff_distloss(weight_MR, m_MR.detach(), 1/nSamples_MR) * dist_loss_weight
            dist_loss = dist_loss + eff_distloss(weight_LR, m_LR.detach(), 1/nSamples_LR) * dist_loss_weight
            total_loss = total_loss + dist_loss 
            summary_writer.add_scalar('train/dist_loss', dist_loss.detach().item(), global_step=iteration)

        if self_depth_weight>0 and iteration > self_depth_start_iter:
            with torch.no_grad():
                _, min_idx = torch.min(torch.stack([reprojection_error, reprojection_error_MR, reprojection_error_LR]), dim=0)
                HR_mask = (min_idx == 0) & (reprojection_error > 0) & (reprojection_error < reprojection_error_thr)
                MR_mask = (min_idx == 1) & (reprojection_error_MR > 0) & (reprojection_error_MR < reprojection_error_thr)
                LR_mask = (min_idx == 2) & (reprojection_error_LR > 0) & (reprojection_error_LR < reprojection_error_thr)

                if  iteration % 100 == 0:
                    total_warped = torch.sum(HR_mask).float() + torch.sum(MR_mask).float() + torch.sum(LR_mask).float()
                    self_depth_ratio_high = torch.sum(HR_mask).float()/total_warped
                    self_depth_ratio_mid = torch.sum(MR_mask).float() / total_warped
                    self_depth_ratio_low = torch.sum(LR_mask).float() / total_warped
                    self_depth_ratio_high_list.append(self_depth_ratio_high.cpu())
                    self_depth_ratio_mid_list.append(self_depth_ratio_mid.cpu())
                    self_depth_ratio_low_list.append(self_depth_ratio_low.cpu())
                    repo_error_high_list.append(torch.mean(reprojection_error).cpu())
                    repo_error_mid_list.append(torch.mean(reprojection_error_MR).cpu())
                    repo_error_low_list.append(torch.mean(reprojection_error_LR).cpu())

            self_depth_loss = cal_depth_loss(depth_map[MR_mask], depth_map_MR[MR_mask].detach())*self_depth_weight
            self_depth_loss = self_depth_loss + cal_depth_loss(depth_map[LR_mask], depth_map_LR[LR_mask].detach())*self_depth_weight
            self_depth_loss = self_depth_loss + cal_depth_loss(depth_map_MR[HR_mask], depth_map[HR_mask].detach())*self_depth_weight
            self_depth_loss = self_depth_loss + cal_depth_loss(depth_map_MR[LR_mask], depth_map_LR[LR_mask].detach())*self_depth_weight
            self_depth_loss = self_depth_loss + cal_depth_loss(depth_map_LR[HR_mask], depth_map[HR_mask].detach())*self_depth_weight
            self_depth_loss = self_depth_loss + cal_depth_loss(depth_map_LR[MR_mask], depth_map_MR[MR_mask].detach())*self_depth_weight

            total_loss = total_loss + self_depth_loss
            summary_writer.add_scalar('train/self_depth_loss', self_depth_loss.detach().item(), global_step=iteration)

        if depth_smooth_weight>0:
            # depth_smooth_weight *= lr_factor
            depth_map_patch = depth_map[train_items:].view(2, 2, -1)
            depth_map_MR_patch = depth_map_MR[train_items:].view(2, 2, -1)
            depth_map_LR_patch = depth_map_LR[train_items:].view(2, 2, -1)
            depth_smooth_loss = DSLoss(depth_map_patch) * depth_smooth_weight
            depth_smooth_loss = depth_smooth_loss + DSLoss(depth_map_MR_patch) * depth_smooth_weight
            depth_smooth_loss = depth_smooth_loss + DSLoss(depth_map_LR_patch) * depth_smooth_weight
            total_loss = total_loss + depth_smooth_loss 
            summary_writer.add_scalar('train/depth_smooth_loss', depth_smooth_loss.detach().item(), global_step=iteration)
        
        if occ_loss_weight>0:
            reg_rate *= lr_factor
            wb_rate *= lr_factor
            occ_loss_weight *= lr_factor
            rgb_mask = torch.zeros_like(sigma)
            if args.wb_prior:
                rgb_ray_mean = rgb_ray.mean(-1)
                white_mask = torch.where(rgb_ray_mean > 0.99, 1, 0)
                black_mask = torch.where(rgb_ray_mean < 0.01, 1, 0)
                rgb_mask = (white_mask + black_mask)
                rgb_mask[:, int(nSamples * wb_rate):] = 0 # white or black background range

            rgb_mask[:, :int(nSamples * reg_rate)] = 1 # Penalize the points in reg_range close to the camera
            occ_loss = torch.mean(sigma * rgb_mask, -1)
            occ_loss = torch.mean(occ_loss)
            # occ_loss = occ_loss + torch.sum(sigma_LR * rgb_mask)
            total_loss = total_loss + occ_loss * occ_loss_weight
            summary_writer.add_scalar('train/occ_loss', occ_loss.detach().item(), global_step=iteration)
        
        # if depth_smooth_weight>0:
        #     depth_map_patch = depth_map[:train_items].view(2, 2, -1)
        #     depth_map_LR_patch = depth_map_LR[:train_items].view(2, 2, -1)
        #     rgb_map_patch = rgb_train.view(2, 2, -1, 3)
        #     rgb_map_LR_patch = rgb_train.view(2, 2, -1, 3)
        #     depth_smooth_loss = ColorDSLoss(rgb_map_patch.detach(), depth_map_patch) * depth_smooth_weight
        #     depth_smooth_loss = depth_smooth_loss + ColorDSLoss(rgb_map_LR_patch.detach(), depth_map_LR_patch) * depth_smooth_weight
        #     total_loss = total_loss + depth_smooth_loss

        # if iteration > 1000:
        #     novel_mse = torch.mean((rgb_map[train_items:][mask_novel] - projected_rgb_novel) ** 2)
        #     novel_mse = novel_mse + torch.mean((rgb_map_LR[train_items:][mask_LR_novel] - projected_rgb_LR_novel) ** 2)
        #     total_loss = total_loss + novel_mse * 0.1
        #     summary_writer.add_scalar('train/novel_mse', novel_mse.detach().item(), global_step=iteration)

        # if occ_rate>0:
        #     # occ_rate *= lr_factor
        #     # occ_loss_weight *= lr_factor
        #     rgb_mask = torch.zeros_like(sigma)
        #     rgb_mask[:, :int(nSamples*occ_rate)] = 1
        #     occ_loss = torch.sum(sigma * rgb_mask)
        #     occ_loss = occ_loss + torch.sum(sigma_LR * rgb_mask)
        #     total_loss = total_loss + occ_loss * occ_loss_weight
        #     summary_writer.add_scalar('train/occ_loss', occ_loss.detach().item(), global_step=iteration)


        # if True:
        #     warp_loss = torch.sum((within_mask < 0) ^ (within_mask_LR < 0)) * 0.01
        #     warp_loss = warp_loss + torch.sum((within_mask_novel < 0) ^ (within_mask_novel_LR < 0)) * 0.01
        #     total_loss = total_loss + warp_loss
        #     summary_writer.add_scalar('train/warp_loss', warp_loss.detach().item(), global_step=iteration)
            

        # if self_depth_weight>0:
        #     with torch.no_grad():
        #         HR_mask = (loss > loss2) & (reprojection_error[:train_items] == 0)
        #         LR_mask = (loss < loss2) & (reprojection_error_LR[:train_items] == 0)
        #     loss_simple_depth = torch.mean((depth_map[:train_items][HR_mask] - depth_map_LR[:train_items][HR_mask].detach())**2)*self_depth_weight
        #     loss_simple_depth = loss_simple_depth + torch.mean((depth_map[:train_items][LR_mask].detach() - depth_map_LR[:train_items][LR_mask])**2)*self_depth_weight
        #     total_loss = total_loss + loss_simple_depth
        #     summary_writer.add_scalar('train/unwarped_depth', loss_simple_depth.detach().item(), global_step=iteration)


        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        loss = loss.detach().item()
        
        PSNRs.append(-10.0 * np.log(loss) / np.log(10.0))
        summary_writer.add_scalar('train/PSNR', PSNRs[-1], global_step=iteration)
        summary_writer.add_scalar('train/mse', loss, global_step=iteration)
        

        for param_group in optimizer.param_groups:
            param_group['lr'] = param_group['lr'] * lr_factor

        # Print the current values of the losses.
        if iteration % args.progress_refresh_rate == 0:
            pbar.set_description(
                f'Iteration {iteration:05d}:'
                + f' train_psnr = {float(np.mean(PSNRs)):.2f}'
                + f' t_psnr = {float(np.mean(PSNRs_test)):.2f}'
                + f' t_ssim = {float(np.mean(SSIMs_test)):.2f}'
                + f' t_lpips = {float(np.mean(LPIPSs_test)):.2f}'
            )
            PSNRs = []


        if iteration % args.vis_every == args.vis_every - 1 and args.N_vis!=0:
            PSNRs_test, SSIMs_test, LPIPSs_test,masked_PSNRs_test,masked_SSIMs_test,masked_LPIPSs_test = evaluation(test_dataset,tensorf, args, renderer, f'{logfolder}/imgs_vis/', N_vis=args.N_vis,
                                    prtx=f'{iteration:06d}_', N_samples=nSamples, white_bg = white_bg, ndc_ray=ndc_ray)
            summary_writer.add_scalar('test/psnr', np.mean(PSNRs_test), global_step=iteration)
            summary_writer.add_scalar('test/ssim', np.mean(SSIMs_test), global_step=iteration)
            summary_writer.add_scalar('test/lpips', np.mean(LPIPSs_test), global_step=iteration)
            summary_writer.add_scalar('test/masked_psnr', np.mean(masked_PSNRs_test), global_step=iteration)
            summary_writer.add_scalar('test/masked_ssim', np.mean(masked_SSIMs_test), global_step=iteration)
            summary_writer.add_scalar('test/masked_lpips', np.mean(masked_LPIPSs_test), global_step=iteration)



        if iteration in update_AlphaMask_list:
            if reso_cur[0] * reso_cur[1] * reso_cur[2]<256**3:# update volume resolution
                reso_mask = reso_cur
            new_aabb = tensorf.updateAlphaMask(tuple(reso_mask))
            if iteration == update_AlphaMask_list[0]:
                tensorf.shrink(new_aabb)
                # tensorVM.alphaMask = None
                L1_reg_weight = args.L1_weight_rest
                print("continuing L1_reg_weight", L1_reg_weight)




        if iteration in upsamp_list:
            n_voxels = N_voxel_list.pop(0)
            reso_cur = N_to_reso(n_voxels, tensorf.aabb)
            nSamples = min(args.nSamples, cal_n_samples(reso_cur,args.step_ratio))
            tensorf.upsample_volume_grid(reso_cur)
            tensorf.downsample_volume_grid(tensorf.down_sampling_ratio)
            nSamples_MR = min(args.nSamples, cal_n_samples(reso_cur,args.step_ratio*args.down_sampling_ratio[0]))
            nSamples_LR = min(args.nSamples, cal_n_samples(reso_cur,args.step_ratio*args.down_sampling_ratio[1]))

            if args.lr_upsample_reset:
                print("reset lr to initial")
                lr_scale = 1 #0.1 ** (iteration / args.n_iters)
            else:
                lr_scale = args.lr_decay_target_ratio ** (iteration / args.n_iters)
            grad_vars = tensorf.get_optparam_groups(args.lr_init*lr_scale, args.lr_basis*lr_scale)
            optimizer = torch.optim.Adam(grad_vars, betas=(0.9, 0.99))
        

    tensorf.save(f'{logfolder}/{args.expname}.th')


    if args.render_train:
        os.makedirs(f'{logfolder}/imgs_train_all', exist_ok=True)
        train_dataset = dataset(args.datadir, split='train', downsample=args.downsample_train, is_stack=True, frame_num=args.train_frame_num)
        PSNRs_test,SSIMs_test,LPIPSs_test,_,_,_ = evaluation(train_dataset,tensorf, args, renderer, f'{logfolder}/imgs_train_all/',
                                N_vis=-1, N_samples=-1, white_bg = white_bg, ndc_ray=ndc_ray,device=device)
        print(f'======> {args.expname} train all psnr: {np.mean(PSNRs_test)} <========================')
        print(f'======> {args.expname} train all ssim: {np.mean(SSIMs_test)} <========================')
        print(f'======> {args.expname} train all lpips: {np.mean(LPIPSs_test)} <========================')

    if args.render_test:
        os.makedirs(f'{logfolder}/imgs_test_all', exist_ok=True)
        PSNRs_test,SSIMs_test,LPIPSs_test,masked_PSNRs_test,masked_SSIMs_test,masked_LPIPSs_test = evaluation(test_dataset,tensorf, args, renderer, f'{logfolder}/imgs_test_all/',
                                N_vis=-1, N_samples=-1, white_bg = white_bg, ndc_ray=ndc_ray,device=device)
        summary_writer.add_scalar('test/psnr_all', np.mean(PSNRs_test), global_step=iteration)
        print(f'======> {args.expname} test all psnr: {np.mean(PSNRs_test)} <========================')
        print(f'======> {args.expname} test all ssim: {np.mean(SSIMs_test)} <========================')
        print(f'======> {args.expname} test all lpips: {np.mean(LPIPSs_test)} <========================')

    if args.render_path:
        c2ws = test_dataset.render_path
        print('========>',c2ws.shape)
        os.makedirs(f'{logfolder}/imgs_path_all', exist_ok=True)
        evaluation_path(test_dataset,tensorf, c2ws, renderer, f'{logfolder}/imgs_path_all/',
                                N_vis=-1, N_samples=-1, white_bg = white_bg, ndc_ray=ndc_ray,device=device, reso=args.down_sampling_ratio)

    # subplot self_depth_ratio and reprojection_error
    if self_depth_weight > 0:
        plt.figure(figsize=(10, 5))
        plt.subplot(1, 2, 1)
        plt.plot(self_depth_ratio_high_list, label='high')
        plt.plot(self_depth_ratio_mid_list, label='mid')
        plt.plot(self_depth_ratio_low_list, label='low')
        plt.legend()
        plt.xlabel('iteration(*100)')
        plt.ylabel('self_depth_ratio')
        plt.subplot(1, 2, 2)
        # smooth the reprojection_error
        repo_error_high_list = np.array(repo_error_high_list)
        repo_error_mid_list = np.array(repo_error_mid_list)
        repo_error_low_list = np.array(repo_error_low_list)
        repo_error_high_list = (repo_error_high_list[1:] + repo_error_high_list[:-1]) / 2
        repo_error_mid_list = (repo_error_mid_list[1:] + repo_error_mid_list[:-1]) / 2
        repo_error_low_list = (repo_error_low_list[1:] + repo_error_low_list[:-1]) / 2
        plt.plot(repo_error_high_list, label='high')
        plt.plot(repo_error_mid_list, label='mid')
        plt.plot(repo_error_low_list, label='low')
        plt.legend()
        plt.xlabel('iteration(*100)')
        plt.ylabel('reprojection_error')
        plt.savefig(f'{logfolder}/self_depth_ratio.png')
        plt.close()
    

if __name__ == '__main__':
    # Script entry point: pin the default floating-point dtype and the RNG
    # seeds (torch + NumPy) before anything else runs, then parse the CLI /
    # config-file options and dispatch to the requested mode.
    _seed = 20211202
    torch.set_default_dtype(torch.float32)
    torch.manual_seed(_seed)
    np.random.seed(_seed)

    args = config_parser()
    print(args)

    # Optional mesh export happens first.
    # NOTE(review): this does not skip the render/train dispatch below, so an
    # export-only invocation will still fall through to `reconstruction` —
    # confirm this is intended.
    if args.export_mesh:
        export_mesh(args)

    # Render-only mode is used only when at least one render target was
    # requested; otherwise the full reconstruction (training) path runs.
    wants_render_only = args.render_only and (args.render_test or args.render_path)
    entry_point = render_test if wants_render_only else reconstruction
    entry_point(args)

