
import os
from tqdm.auto import tqdm
from opt import config_parser



import json, random
from renderer import *
from utils import *
from torch.utils.tensorboard import SummaryWriter
import datetime

from dataLoader import dataset_dict
import sys



device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

renderer = OctreeRender_trilinear_fast


class SimpleSampler:
    def __init__(self, total, batch):
        self.total = total
        self.batch = batch
        self.curr = total
        self.ids = None

    def nextids(self):
        self.curr+=self.batch
        if self.curr + self.batch > self.total:
            self.ids = torch.LongTensor(np.random.permutation(self.total))
            self.curr = 0
        return self.ids[self.curr:self.curr+self.batch]


@torch.no_grad()
def export_mesh(args):

    ckpt = torch.load(args.ckpt, map_location=device)
    kwargs = ckpt['kwargs']
    kwargs.update({'device': device})
    tensorf = eval(args.model_name)(**kwargs)
    tensorf.load(ckpt)

    alpha,_ = tensorf.getDenseAlpha()
    convert_sdf_samples_to_ply(alpha.cpu(), f'{args.ckpt[:-3]}.ply',bbox=tensorf.aabb.cpu(), level=0.005)


@torch.no_grad()
def render_test(args):
    # init dataset
    dataset = dataset_dict[args.dataset_name]
    test_dataset = dataset(args.datadir, split='test', downsample=args.downsample_train, is_stack=True, n_scales=args.n_scales)
    white_bg = test_dataset.white_bg
    ndc_ray = args.ndc_ray

    if not os.path.exists(args.ckpt):
        print('the ckpt path does not exists!!')
        return

    ckpt = torch.load(args.ckpt, map_location=device)
    kwargs = ckpt['kwargs']
    kwargs.update({'device': device})
    tensorf = eval(args.model_name)(**kwargs)
    tensorf.load(ckpt)

    logfolder = os.path.dirname(args.ckpt)
    if args.render_train:
        print(f"Evaluating on train set...")
        os.makedirs(f'{logfolder}/imgs_train_all', exist_ok=True)
        train_dataset = dataset(args.datadir, split='train', downsample=args.downsample_train, is_stack=True, n_scales=args.n_scales)
        PSNRs_test = evaluation_multiscale(
            train_dataset, tensorf, args, renderer, f'{logfolder}/imgs_train_all/', N_vis=-1,
            N_samples=-1, white_bg=white_bg, ndc_ray=ndc_ray, device=device, compute_extra_metrics=False
        )
        log = ', '.join([
            f'PSNR_800: {np.mean(PSNRs_test[800])}',
            f'PSNR_400: {np.mean(PSNRs_test[400])}',
            f'PSNR_200: {np.mean(PSNRs_test[200])}',
            f'PSNR_100: {np.mean(PSNRs_test[100])}',
            f'PSNR_all: {np.mean(PSNRs_test["all"])}'
        ])
        print(f'========================> {args.expname} (train_set) <========================')
        print(log)

    if args.render_test:
        print(f"Evaluating on test set...")
        os.makedirs(f'{logfolder}/imgs_test_all', exist_ok=True)
        PSNRs_test = evaluation_multiscale(
            test_dataset, tensorf, args, renderer, f'{logfolder}/imgs_test_all/', N_vis=-1,
            N_samples=-1, white_bg = white_bg, ndc_ray=ndc_ray, device=device, compute_extra_metrics=False
        )
        log = ', '.join([
            f'PSNR_800: {np.mean(PSNRs_test[800])}',
            f'PSNR_400: {np.mean(PSNRs_test[400])}',
            f'PSNR_200: {np.mean(PSNRs_test[200])}',
            f'PSNR_100: {np.mean(PSNRs_test[100])}',
            f'PSNR_all: {np.mean(PSNRs_test["all"])}'
        ])
        print(f'========================> {args.expname} (test_set) <========================')
        print(log)

    # if args.render_path:
    #     c2ws = test_dataset.render_path
    #     os.makedirs(f'{logfolder}/{args.expname}/imgs_path_all', exist_ok=True)
    #     evaluation_path(test_dataset,tensorf, c2ws, renderer, f'{logfolder}/{args.expname}/imgs_path_all/',
    #                             N_vis=-1, N_samples=-1, white_bg = white_bg, ndc_ray=ndc_ray,device=device)

def reconstruction(args):

    # init dataset
    dataset = dataset_dict[args.dataset_name]
    train_dataset = dataset(args.datadir, split='train', downsample=args.downsample_train, is_stack=False, n_scales=args.n_scales)
    test_dataset = dataset(args.datadir, split='test', downsample=args.downsample_train, is_stack=True, n_scales=args.n_scales)
    white_bg = train_dataset.white_bg
    near_far = train_dataset.near_far
    ndc_ray = args.ndc_ray

    # init resolution
    upsamp_list = args.upsamp_list
    update_AlphaMask_list = args.update_AlphaMask_list
    n_lamb_sigma = args.n_lamb_sigma
    n_lamb_sh = args.n_lamb_sh

    
    if args.add_timestamp:
        logfolder = f'{args.basedir}/{args.expname}{datetime.datetime.now().strftime("-%Y%m%d-%H%M%S")}'
    else:
        logfolder = f'{args.basedir}/{args.expname}'
    

    # init log file
    os.makedirs(logfolder, exist_ok=True)
    os.makedirs(f'{logfolder}/imgs_vis', exist_ok=True)
    os.makedirs(f'{logfolder}/imgs_rgba', exist_ok=True)
    os.makedirs(f'{logfolder}/rgba', exist_ok=True)
    summary_writer = SummaryWriter(logfolder)



    # init parameters
    # tensorVM, renderer = init_parameters(args, train_dataset.scene_bbox.to(device), reso_list[0])
    aabb = train_dataset.scene_bbox.to(device)
    reso_cur = N_to_reso(args.N_voxel_init, aabb)
    nSamples = min(args.nSamples, cal_n_samples(reso_cur,args.step_ratio))


    if args.ckpt is not None:
        ckpt = torch.load(args.ckpt, map_location=device)
        kwargs = ckpt['kwargs']
        kwargs.update({'device':device})
        tensorf = eval(args.model_name)(**kwargs)
        tensorf.load(ckpt)
    else:
        tensorf = eval(args.model_name)(
            aabb, reso_cur, device, density_n_comp=n_lamb_sigma,
            appearance_n_comp=n_lamb_sh, app_dim=args.data_dim_color, near_far=near_far,
            shadingMode=args.shadingMode, alphaMask_thres=args.alpha_mask_thre, density_shift=args.density_shift,
            distance_scale=args.distance_scale, pos_pe=args.pos_pe, view_pe=args.view_pe, fea_pe=args.fea_pe,
            featureC=args.featureC, step_ratio=args.step_ratio, fea2denseAct=args.fea2denseAct,
            kernel_size=args.kernel_size, n_scales=args.n_scales, learnable_kernel=args.learnable_kernel,
            kernel_init=args.kernel_init, apply_kernel=args.begin_kernel == 0, apply_basis=args.apply_basis, 
            learnable_basis=args.learnable_basis, basis_init=args.basis_init, kernel_nonlinear=args.kernel_nonlinear
        )


    grad_vars = tensorf.get_optparam_groups(args.lr_init, args.lr_basis)
    if args.lr_decay_iters > 0:
        lr_factor = args.lr_decay_target_ratio**(1/args.lr_decay_iters)
    else:
        args.lr_decay_iters = args.n_iters
        lr_factor = args.lr_decay_target_ratio**(1/args.n_iters)

    print("lr decay", args.lr_decay_target_ratio, args.lr_decay_iters)
    
    optimizer = torch.optim.Adam(grad_vars, betas=(0.9,0.99))


    #linear in logrithmic space
    N_voxel_list = (torch.round(torch.exp(torch.linspace(np.log(args.N_voxel_init), np.log(args.N_voxel_final), len(upsamp_list)+1))).long()).tolist()[1:]


    torch.cuda.empty_cache()
    PSNRs = []
    PSNRs_test = {res: 0 for res in [800, 400, 200, 100, "all"]}

    all_rays = train_dataset.all_rays
    all_rgbs = train_dataset.all_rgbs
    all_scales = train_dataset.all_scales
    all_lossmults = train_dataset.all_lossmults

    if not args.ndc_ray:
        all_rays, all_rgbs, all_scales = tensorf.filtering_rays(all_rays, all_rgbs, all_scales, bbox_only=True)
    
    trainingSampler = SimpleSampler(all_rays.shape[0], args.batch_size)

    Ortho_reg_weight = args.Ortho_weight
    print("initial Ortho_reg_weight", Ortho_reg_weight)

    L1_reg_weight = args.L1_weight_inital
    print("initial L1_reg_weight", L1_reg_weight)
    TV_weight_density, TV_weight_app = args.TV_weight_density, args.TV_weight_app
    tvreg = TVLoss()
    print(f"initial TV_weight density: {TV_weight_density} appearance: {TV_weight_app}")


    pbar = tqdm(range(args.n_iters), position=0, leave=True, miniters=args.progress_refresh_rate, file=sys.stdout)
    for iteration in pbar:


        ray_idx = trainingSampler.nextids()
        rays_train = all_rays[ray_idx]
        rgbs_train = all_rgbs[ray_idx].to(device)
        scales = all_scales[ray_idx]
        lossmults = all_lossmults[ray_idx].to(device)

        rgb_map, alphas_map, depth_map, weights, uncertainty = renderer(
            rays_train, scales, tensorf, chunk=args.batch_size, N_samples=nSamples,
            white_bg = white_bg, ndc_ray=ndc_ray, device=device, is_train=True
        )
        
        if args.lossmult:
            loss = torch.mean(lossmults * (rgb_map - rgbs_train)**2)
        else:
            loss = torch.mean((rgb_map - rgbs_train)**2)

        # loss
        total_loss = loss
        if Ortho_reg_weight > 0:
            loss_reg = tensorf.vector_comp_diffs()
            total_loss += Ortho_reg_weight*loss_reg
            summary_writer.add_scalar('train/reg', loss_reg.detach().item(), global_step=iteration)
        if L1_reg_weight > 0:
            loss_reg_L1 = tensorf.density_L1()
            total_loss += L1_reg_weight*loss_reg_L1
            summary_writer.add_scalar('train/reg_l1', loss_reg_L1.detach().item(), global_step=iteration)

        if TV_weight_density>0:
            TV_weight_density *= lr_factor
            loss_tv = tensorf.TV_loss_density(tvreg) * TV_weight_density
            total_loss = total_loss + loss_tv
            summary_writer.add_scalar('train/reg_tv_density', loss_tv.detach().item(), global_step=iteration)
        if TV_weight_app>0:
            TV_weight_app *= lr_factor
            loss_tv = tensorf.TV_loss_app(tvreg)*TV_weight_app
            total_loss = total_loss + loss_tv
            summary_writer.add_scalar('train/reg_tv_app', loss_tv.detach().item(), global_step=iteration)

        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        loss = loss.detach().item()
        
        PSNRs.append(-10.0 * np.log(loss) / np.log(10.0))
        summary_writer.add_scalar('train/PSNR', PSNRs[-1], global_step=iteration)
        summary_writer.add_scalar('train/mse', loss, global_step=iteration)


        for param_group in optimizer.param_groups:
            param_group['lr'] = param_group['lr'] * lr_factor

        # Print the current values of the losses.
        if iteration % args.progress_refresh_rate == 0:
            pbar.set_description(
                f'Iteration {iteration:05d}:'
                + f' mse={loss:.6f}'
                + f' train_psnr={float(np.mean(PSNRs)):.2f}'
                + f' test_psnr_800={float(np.mean(PSNRs_test[800])):.2f}'
                + f' test_psnr_400={float(np.mean(PSNRs_test[400])):.2f}'
                + f' test_psnr_200={float(np.mean(PSNRs_test[200])):.2f}'
                + f' test_psnr_100={float(np.mean(PSNRs_test[100])):.2f}'
                + f' test_psnr={float(np.mean(PSNRs_test["all"])):.2f}'
            )
            PSNRs = []


        if iteration % args.vis_every == args.vis_every - 1 and args.N_vis != 0:
            PSNRs_test = evaluation_multiscale(
                test_dataset, tensorf, args, renderer, f'{logfolder}/imgs_vis/', N_vis=args.N_vis, prtx=f'{iteration:06d}_',
                N_samples=nSamples, white_bg = white_bg, ndc_ray=ndc_ray, compute_extra_metrics=False
            )
            summary_writer.add_scalar('test/psnr_800', np.mean(PSNRs_test[800]), global_step=iteration)
            summary_writer.add_scalar('test/psnr_400', np.mean(PSNRs_test[400]), global_step=iteration)
            summary_writer.add_scalar('test/psnr_200', np.mean(PSNRs_test[200]), global_step=iteration)
            summary_writer.add_scalar('test/psnr_100', np.mean(PSNRs_test[100]), global_step=iteration)
            summary_writer.add_scalar('test/psnr', np.mean(PSNRs_test["all"]), global_step=iteration)
        

        if iteration in update_AlphaMask_list:

            if reso_cur[0] * reso_cur[1] * reso_cur[2]<256**3:# update volume resolution
                reso_mask = reso_cur
            new_aabb = tensorf.updateAlphaMask(tuple(reso_mask))

            if iteration == update_AlphaMask_list[0]:
                tensorf.shrink(new_aabb)
                # tensorVM.alphaMask = None
                L1_reg_weight = args.L1_weight_rest
                print("continuing L1_reg_weight", L1_reg_weight)


            if not args.ndc_ray and iteration == update_AlphaMask_list[1]:
                # filter rays outside the bbox
                all_rays, all_rgbs, all_scales = tensorf.filtering_rays(all_rays, all_rgbs, all_scales)
                trainingSampler = SimpleSampler(all_rgbs.shape[0], args.batch_size)


        if iteration in upsamp_list:
            n_voxels = N_voxel_list.pop(0)
            reso_cur = N_to_reso(n_voxels, tensorf.aabb)
            nSamples = min(args.nSamples, cal_n_samples(reso_cur,args.step_ratio))
            tensorf.upsample_volume_grid(reso_cur)

            if args.lr_upsample_reset:
                print("reset lr to initial")
                lr_scale = 1 #0.1 ** (iteration / args.n_iters)
            else:
                lr_scale = args.lr_decay_target_ratio ** (iteration / args.n_iters)
            grad_vars = tensorf.get_optparam_groups(args.lr_init*lr_scale, args.lr_basis*lr_scale)
            optimizer = torch.optim.Adam(grad_vars, betas=(0.9, 0.99))

        
        if iteration == args.begin_kernel:
            tensorf.apply_kernel = True

        if iteration == args.end_lossmult:
            args.lossmult = False
        

    tensorf.save(f'{logfolder}/{args.expname}.th')


    if args.render_train:
        print(f"Evaluating on train set...")
        os.makedirs(f'{logfolder}/imgs_train_all', exist_ok=True)
        train_dataset = dataset(args.datadir, split='train', downsample=args.downsample_train, is_stack=True, n_scales=args.n_scales)
        PSNRs_test = evaluation_multiscale(
            train_dataset, tensorf, args, renderer, f'{logfolder}/imgs_train_all/', N_vis=-1,
            N_samples=-1, white_bg=white_bg, ndc_ray=ndc_ray, device=device, compute_extra_metrics=False
        )
        log = ', '.join([
            f'PSNR_800: {np.mean(PSNRs_test[800])}',
            f'PSNR_400: {np.mean(PSNRs_test[400])}',
            f'PSNR_200: {np.mean(PSNRs_test[200])}',
            f'PSNR_100: {np.mean(PSNRs_test[100])}',
            f'PSNR_all: {np.mean(PSNRs_test["all"])}'
        ])
        print(f'========================> {args.expname} (train_set) <========================')
        print(log)

    if args.render_test:
        print(f"Evaluating on test set...")
        os.makedirs(f'{logfolder}/imgs_test_all', exist_ok=True)
        PSNRs_test = evaluation_multiscale(
            test_dataset, tensorf, args, renderer, f'{logfolder}/imgs_test_all/', N_vis=-1,
            N_samples=-1, white_bg = white_bg, ndc_ray=ndc_ray, device=device, compute_extra_metrics=False
        )
        log = ', '.join([
            f'PSNR_800: {np.mean(PSNRs_test[800])}',
            f'PSNR_400: {np.mean(PSNRs_test[400])}',
            f'PSNR_200: {np.mean(PSNRs_test[200])}',
            f'PSNR_100: {np.mean(PSNRs_test[100])}',
            f'PSNR_all: {np.mean(PSNRs_test["all"])}'
        ])
        print(f'========================> {args.expname} (test_set) <========================')
        print(log)

    # if args.render_path:
    #     c2ws = test_dataset.render_path
    #     # c2ws = test_dataset.poses
    #     print('========>',c2ws.shape)
    #     os.makedirs(f'{logfolder}/imgs_path_all', exist_ok=True)
    #     evaluation_path(test_dataset,tensorf, c2ws, renderer, f'{logfolder}/imgs_path_all/',
    #                             N_vis=-1, N_samples=-1, white_bg = white_bg, ndc_ray=ndc_ray,device=device)


if __name__ == '__main__':

    torch.set_default_dtype(torch.float32)
    torch.manual_seed(20211202)
    np.random.seed(20211202)

    args = config_parser()
    print(args)

    if  args.export_mesh:
        export_mesh(args)

    if args.render_only and (args.render_test or args.render_path):
        render_test(args)
    else:
        reconstruction(args)

