#
# Copyright (C) 2023, Inria
# GRAPHDECO research group, https://team.inria.fr/graphdeco
# All rights reserved.
#
# This software is free for non-commercial, research and evaluation use 
# under the terms of the LICENSE.md file.
#
# For inquiries contact  george.drettakis@inria.fr
#

import os
import torch
from random import randint
from utils.loss_utils import l1_loss, l2_loss, ssim, kl_divergence, compute_normal_loss, env_tv_loss
from gaussian_renderer import render, network_gui, render_lighting
import sys
from scene import Scene, GaussianModel, DeformModel, DeformEnvModel
from utils.general_utils import safe_state, get_linear_noise_func, get_expon_lr_func, calculate_eccentricities, safe_normalize, reflect 
import uuid
from tqdm import tqdm
import lpips
from utils.image_utils import psnr, apply_depth_colormap
from argparse import ArgumentParser, Namespace
from arguments import ModelParams, PipelineParams, OptimizationParams

try:
    from torch.utils.tensorboard import SummaryWriter

    TENSORBOARD_FOUND = True
except ImportError:
    TENSORBOARD_FOUND = False


def training(dataset, opt, pipe, testing_iterations, saving_iterations):
    """Jointly optimize the Gaussians, the deformation network and the
    environment-deformation network.

    The schedule is driven by thresholds read from ``opt``:

    * ``iteration < opt.warm_up``: no deformation is predicted (all offsets are 0.0).
    * ``opt.warm_up2 <= iteration < opt.warm_up3``: the deformation network is
      evaluated without gradients, ``requires_grad`` is switched off for
      xyz/opacity/scaling/rotation and densification is skipped.
    * ``iteration >= opt.warm_up2``: ``brdf_mlp.build_mips()`` is called before
      rendering and ``brdf_mlp.clamp_`` after the optimizer step.
    * ``iteration >= opt.warm_up2 + 2000``: a reflection-vector offset is
      predicted by ``deform_env``.
    * ``iteration >= opt.warm_up + 3000``: the normal and delta-normal losses
      are added to the photometric loss.

    Args:
        dataset: Extracted model parameters. Attributes read here: ``sh_degree``,
            ``brdf_envmap_res``, ``is_blender``, ``is_6dof``, ``white_background``,
            ``load2gpu_on_the_fly``, ``source_path`` and ``model_path``.
        opt: Extracted optimization parameters (iteration counts, warm-up
            thresholds, loss weights, densification settings).
        pipe: Extracted pipeline parameters, forwarded to ``render``.
        testing_iterations: Iterations at which ``training_report`` evaluates.
        saving_iterations: Iterations at which the scene and both deformation
            networks are saved.
    """
    tb_writer = prepare_output_and_logger(dataset)
    gaussians = GaussianModel(dataset.sh_degree, dataset.brdf_envmap_res)
    deform = DeformModel(dataset.is_blender, dataset.is_6dof)
    deform.train_setting(opt)
    deform_env = DeformEnvModel()
    deform_env.train_setting(opt)
    scene = Scene(dataset, gaussians)
    gaussians.training_setup(opt)

    bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0]
    background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")

    iter_start = torch.cuda.Event(enable_timing=True)
    iter_end = torch.cuda.Event(enable_timing=True)

    viewpoint_stack = None
    ema_loss_for_log = 0.0
    ema_normal_for_log = 0.0
    ema_delta_normal_for_log = 0.0
    ema_env_tv_for_log = 0.0
    progress_bar = tqdm(range(opt.iterations), desc="Training progress")
    for iteration in range(1, opt.iterations + 1):
        if network_gui.conn is None:
            network_gui.try_connect()
        while network_gui.conn is not None:
            try:
                net_image_bytes = None
                custom_cam, do_training, pipe.do_shs_python, pipe.do_cov_python, keep_alive, scaling_modifer = network_gui.receive()
                if custom_cam is not None:
                    # NOTE(review): the training render call below passes d_xyz as the
                    # fifth positional argument; here scaling_modifer lands in that
                    # slot. Confirm against render()'s signature before enabling the GUI.
                    net_image = render(custom_cam, gaussians, pipe, background, scaling_modifer)["render"]
                    net_image_bytes = memoryview(
                        (torch.clamp(net_image, min=0, max=1.0) * 255).byte().permute(1, 2, 0).contiguous().cpu().numpy())
                network_gui.send(net_image_bytes, dataset.source_path)
                if do_training and ((iteration < int(opt.iterations)) or not keep_alive):
                    break
            except Exception:
                # The viewer is best-effort: on any failure drop the connection
                # and carry on training.
                network_gui.conn = None

        iter_start.record()

        # Every 1000 its we increase the levels of SH up to a maximum degree
        if iteration % 1000 == 0:
            gaussians.oneupSHdegree()

        # Pick a random Camera; refill the stack once every camera has been used
        if not viewpoint_stack:
            viewpoint_stack = scene.getTrainCameras().copy()

        viewpoint_cam = viewpoint_stack.pop(randint(0, len(viewpoint_stack) - 1))
        if dataset.load2gpu_on_the_fly:
            viewpoint_cam.load2device()
        fid = viewpoint_cam.fid

        # Window in which the geometry-related Gaussian parameters are frozen.
        geometry_frozen = opt.warm_up2 <= iteration < opt.warm_up3

        d_reflvec = torch.tensor(0.0, device="cuda")
        if iteration < opt.warm_up:
            d_xyz, d_rotation, d_scaling = 0.0, 0.0, 0.0
        else:
            N = gaussians.get_xyz.shape[0]
            time_input = fid.unsqueeze(0).expand(N, -1)

            # Inside the freeze window the deformation is evaluated without
            # building an autograd graph.
            with torch.set_grad_enabled(not geometry_frozen):
                d_xyz, d_rotation, d_scaling = deform.step(gaussians.get_xyz.detach(), time_input)
            gb_pos = gaussians.get_xyz + d_xyz  # (N, 3)
            view_pos = viewpoint_cam.camera_center.repeat(gaussians.get_opacity.shape[0], 1)  # (N, 3)
            d_viewdir_normalized = safe_normalize(view_pos - gb_pos)
            normal, deform_delta_normal = gaussians.get_normal(gaussians.get_scaling, gaussians.get_rotation, d_scaling, d_rotation, d_viewdir_normalized)  # (N, 3)

        if iteration >= opt.warm_up2:
            gaussians.brdf_mlp.build_mips()
            if iteration >= (opt.warm_up2 + 2000):
                # Uses normal / view direction / time_input from the branch above,
                # so this relies on opt.warm_up2 + 2000 >= opt.warm_up.
                reflvec = safe_normalize(reflect(d_viewdir_normalized, normal))
                d_reflvec = deform_env.step(reflvec.detach(), time_input)

        for param_name in ("xyz", "opacity", "scaling", "rotation"):
            gaussians.set_requires_grad(param_name, state=not geometry_frozen)

        # Render
        render_pkg_re = render(viewpoint_cam, gaussians, pipe, background, d_xyz, d_rotation, d_scaling, d_reflvec, iteration, opt, dataset.is_6dof)
        image, viewspace_point_tensor, visibility_filter, radii = render_pkg_re["render"], render_pkg_re[
            "viewspace_points"], render_pkg_re["visibility_filter"], render_pkg_re["radii"]

        # Loss
        gt_image = viewpoint_cam.original_image.cuda()
        Ll1 = l1_loss(image, gt_image)
        loss = (1.0 - opt.lambda_dssim) * Ll1 + opt.lambda_dssim * (1.0 - ssim(image, gt_image))
        # env_tv is never updated in this loop; it is only logged (always as 0).
        env_tv = torch.tensor(0.0)
        normal_loss = torch.tensor(0.0)
        delta_normal_loss = torch.tensor(0.0)
        if iteration >= opt.warm_up + 3000:
            normal_loss = opt.lambda_normal * compute_normal_loss(render_pkg_re["normal"], render_pkg_re["normal_ref"])

            # The eccentricity weight only applies strictly after the threshold;
            # on the threshold iteration itself the weight is 0.0.
            if iteration > (opt.warm_up + 3000):
                lambda_delta_normal = calculate_eccentricities(torch.abs(gaussians.get_scaling + d_scaling)).unsqueeze(-1)
            else:
                lambda_delta_normal = 0.0
            delta_normal_loss = (lambda_delta_normal * (deform_delta_normal ** 2)).mean()

            loss = loss + normal_loss + delta_normal_loss

        loss.backward()

        iter_end.record()

        if dataset.load2gpu_on_the_fly:
            viewpoint_cam.load2device('cpu')

        with torch.no_grad():
            # Progress bar
            ema_loss_for_log = 0.4 * loss.item() + 0.6 * ema_loss_for_log
            ema_normal_for_log = 0.4 * normal_loss.item() + 0.6 * ema_normal_for_log
            ema_delta_normal_for_log = 0.4 * delta_normal_loss.item() + 0.6 * ema_delta_normal_for_log
            ema_env_tv_for_log = 0.4 * env_tv.item() + 0.6 * ema_env_tv_for_log
            if iteration % 10 == 0:
                loss_dict = {
                    "Loss": f"{ema_loss_for_log:.{5}f}",
                    "delta_normal": f"{ema_delta_normal_for_log:.{5}f}",
                    "normal": f"{ema_normal_for_log:.{5}f}",
                }
                progress_bar.set_postfix(loss_dict)
                progress_bar.update(10)
            if iteration == opt.iterations:
                progress_bar.close()

            if tb_writer is not None:
                tb_writer.add_scalar('train_loss_patches/delta_normal_loss', ema_delta_normal_for_log, iteration)
                tb_writer.add_scalar('train_loss_patches/normal_loss', ema_normal_for_log, iteration)
                tb_writer.add_scalar('train_loss_patches/env_tv_loss_loss', ema_env_tv_for_log, iteration)
            # Keep track of max radii in image-space for pruning
            gaussians.max_radii2D[visibility_filter] = torch.max(gaussians.max_radii2D[visibility_filter],
                                                                 radii[visibility_filter])

            # Log and save
            training_report(tb_writer, iteration, Ll1, loss, l1_loss, iter_start.elapsed_time(iter_end),
                            testing_iterations, scene, render, (pipe, background), deform, deform_env, opt,
                            dataset.load2gpu_on_the_fly, dataset.is_6dof)

            if iteration in saving_iterations:
                print("\n[ITER {}] Saving Gaussians".format(iteration))
                scene.save(iteration)
                # Use the dataset's model_path (filled in by prepare_output_and_logger
                # when empty) instead of the script-level `args` global, so the
                # weights land in the run's output folder and this function does not
                # depend on being executed from the __main__ block.
                deform.save_weights(dataset.model_path, iteration)
                deform_env.save_weights(dataset.model_path, iteration)

            # Densification (skipped while the geometry is frozen)
            if not geometry_frozen:
                if iteration < opt.densify_until_iter:
                    gaussians.add_densification_stats(viewspace_point_tensor, visibility_filter)

                    if iteration > opt.densify_from_iter and iteration % opt.densification_interval == 0:
                        size_threshold = 20 if iteration > opt.opacity_reset_interval else None
                        gaussians.densify_and_prune(opt.densify_grad_threshold, 0.005, scene.cameras_extent, size_threshold)

                    if iteration % opt.opacity_reset_interval == 0 or (
                            dataset.white_background and iteration == opt.densify_from_iter):
                        gaussians.reset_opacity()

            # Optimizer step
            if iteration < opt.iterations:
                gaussians.optimizer.step()
                gaussians.update_learning_rate(iteration)
                deform.optimizer.step()
                deform_env.optimizer.step()
                gaussians.optimizer.zero_grad(set_to_none=True)
                deform.optimizer.zero_grad()
                deform_env.optimizer.zero_grad()
                deform.update_learning_rate(iteration)
                deform_env.update_learning_rate(iteration)

            if iteration >= opt.warm_up2:
                gaussians.brdf_mlp.clamp_(min=0.0, max=1.0)


def prepare_output_and_logger(args):
    """Create the run's output folder, dump the configuration and build the logger.

    When ``args.model_path`` is empty it is set in place to ``./output/<id>``,
    where ``<id>`` is the first ten characters of the ``OAR_JOB_ID`` environment
    variable, or of a fresh UUID when that variable is unset or empty. The
    folder is created if needed and ``str(Namespace(**vars(args)))`` is written
    to ``cfg_args`` inside it.

    Returns:
        A ``SummaryWriter`` on the output folder when tensorboard was
        importable, otherwise ``None``.
    """
    if not args.model_path:
        run_tag = os.getenv('OAR_JOB_ID') or str(uuid.uuid4())
        args.model_path = os.path.join("./output/", run_tag[0:10])

    # Set up output folder
    output_dir = args.model_path
    print("Output folder: {}".format(output_dir))
    os.makedirs(output_dir, exist_ok=True)
    cfg_path = os.path.join(output_dir, "cfg_args")
    with open(cfg_path, 'w') as cfg_log_f:
        cfg_log_f.write(str(Namespace(**vars(args))))

    # Create Tensorboard writer
    if not TENSORBOARD_FOUND:
        print("Tensorboard not available: not logging progress")
        return None
    return SummaryWriter(output_dir)


def training_report(tb_writer, iteration, Ll1, loss, l1_loss, elapsed, testing_iterations, scene: Scene, renderFunc,
                    renderArgs, deform, deform_env, opt, load2gpu_on_the_fly, is_6dof=False):
    """Log training scalars and, on test iterations, evaluate rendered views.

    Args:
        tb_writer: Tensorboard writer, or a falsy value to disable logging.
        iteration: Current training iteration (used as the global step).
        Ll1: L1 loss tensor of the current training step.
        loss: Total loss tensor of the current training step.
        l1_loss: Callable ``(image, gt_image) -> loss`` used for evaluation.
        elapsed: Iteration time to log under ``iter_time``.
        testing_iterations: Iterations at which the evaluation is run.
        scene: Scene providing the Gaussians and the train/test cameras.
        renderFunc: Render function; called with the same positional layout as
            the training render call.
        renderArgs: Extra positional arguments inserted after the Gaussians.
        deform: Deformation network (``step(xyz, time_input)``).
        deform_env: Environment deformation network (``step(reflvec, time_input)``).
        opt: Optimization parameters; ``opt.warm_up2`` gates the lighting image.
        load2gpu_on_the_fly: When true, cameras are moved to the device for
            rendering and back to the CPU afterwards.
        is_6dof: Forwarded to ``renderFunc`` as its last positional argument.
    """
    if tb_writer:
        tb_writer.add_scalar('train_loss_patches/l1_loss', Ll1.item(), iteration)
        tb_writer.add_scalar('train_loss_patches/total_loss', loss.item(), iteration)
        tb_writer.add_scalar('iter_time', elapsed, iteration)

    # Everything below only happens on the requested test iterations.
    if iteration not in testing_iterations:
        return

    torch.cuda.empty_cache()
    gaussians = scene.gaussians
    # Report test and samples of training set
    validation_configs = ({'name': 'test', 'cameras': scene.getTestCameras()},
                          {'name': 'train',
                           'cameras': [scene.getTrainCameras()[idx % len(scene.getTrainCameras())] for idx in
                                       range(5, 30, 5)]})

    for config in validation_configs:
        cameras = config['cameras']
        if not cameras:
            continue

        l1_test = []
        psnrs = []
        ssims = []
        lpipss = []
        for idx, viewpoint in enumerate(cameras):
            if load2gpu_on_the_fly:
                viewpoint.load2device()
            fid = viewpoint.fid
            xyz = gaussians.get_xyz
            time_input = fid.unsqueeze(0).expand(xyz.shape[0], -1)
            d_xyz, d_rotation, d_scaling = deform.step(xyz.detach(), time_input)
            gb_pos = gaussians.get_xyz + d_xyz  # (N, 3)
            view_pos = viewpoint.camera_center.repeat(gaussians.get_opacity.shape[0], 1)  # (N, 3)
            d_viewdir_normalized = safe_normalize(view_pos - gb_pos)
            normal, deform_delta_normal = gaussians.get_normal(gaussians.get_scaling, gaussians.get_rotation, d_scaling, d_rotation, d_viewdir_normalized)  # (N, 3)
            reflvec = safe_normalize(reflect(d_viewdir_normalized, normal))
            d_reflvec = deform_env.step(reflvec.detach(), time_input)
            # is_6dof is forwarded in the same position the training render call
            # uses; previously it was accepted by this function but dropped here.
            render_pkg = renderFunc(viewpoint, gaussians, *renderArgs, d_xyz, d_rotation, d_scaling, d_reflvec,
                                    iteration, opt, is_6dof)
            image = torch.clamp(render_pkg["render"], 0.0, 1.0)
            gt_image = torch.clamp(viewpoint.original_image.to("cuda"), 0.0, 1.0)

            if load2gpu_on_the_fly:
                viewpoint.load2device('cpu')
            if tb_writer and (idx < 5):
                tb_writer.add_images(config['name'] + "_view_{}/render".format(viewpoint.image_name),
                                     image[None], global_step=iteration)
                if iteration == testing_iterations[0]:
                    tb_writer.add_images(config['name'] + "_view_{}/ground_truth".format(viewpoint.image_name),
                                         gt_image[None], global_step=iteration)
                # Log every auxiliary image-like buffer returned by the renderer.
                for k, value in render_pkg.items():
                    if k == "render" or not torch.is_tensor(value) or value.dim() < 3:
                        continue
                    if k == "depth":
                        image_k = apply_depth_colormap(-value[0][..., None])
                        image_k = image_k.permute(2, 0, 1)
                    elif k == "alpha":
                        image_k = apply_depth_colormap(value[0][..., None], min=0., max=1.)
                        image_k = image_k.permute(2, 0, 1)
                    else:
                        if "normal" in k:
                            value = 0.5 + (0.5 * value)  # (-1, 1) -> (0, 1)
                        image_k = torch.clamp(value, 0.0, 1.0)
                    tb_writer.add_images(config['name'] + "_view_{}/{}".format(viewpoint.image_name, k), image_k[None], global_step=iteration)

            l1_test.append(l1_loss(image, gt_image))
            psnrs.append(psnr(image.unsqueeze(0), gt_image.unsqueeze(0)))
            ssims.append(ssim(image.unsqueeze(0), gt_image.unsqueeze(0)))
            # NOTE(review): lpips_fn is a module-level name that this file only
            # creates inside the __main__ block; importing this function from
            # another module requires defining it first.
            lpipss.append(lpips_fn(image.unsqueeze(0), gt_image.unsqueeze(0)).detach())

        # The lighting render takes only the Gaussians and a resolution, so it is
        # produced once per camera set instead of once per logged buffer.
        if tb_writer and iteration >= opt.warm_up2:
            lighting = render_lighting(gaussians, resolution=(gaussians.brdf_envmap_res, gaussians.brdf_envmap_res * 2))
            tb_writer.add_images(config['name'] + "/lighting", lighting[None], global_step=iteration)

        print("\n[ITER {}] Evaluating {}: Lpips {} PSNR {}".format(iteration, config['name'], torch.tensor(lpipss).mean(), torch.tensor(psnrs).mean()))
        if tb_writer:
            tb_writer.add_scalar(config['name'] + '/loss_viewpoint - l1_loss', torch.tensor(l1_test).mean(), iteration)
            tb_writer.add_scalar(config['name'] + '/loss_viewpoint - psnr', torch.tensor(psnrs).mean(), iteration)
            tb_writer.add_scalar(config['name'] + '/loss_viewpoint - ssim', torch.tensor(ssims).mean(), iteration)
            tb_writer.add_scalar(config['name'] + '/loss_viewpoint - lpips', torch.tensor(lpipss).mean(), iteration)

    if tb_writer:
        tb_writer.add_histogram("scene/opacity_histogram", gaussians.get_opacity, iteration)
        tb_writer.add_scalar('total_points', gaussians.get_xyz.shape[0], iteration)
    torch.cuda.empty_cache()



if __name__ == "__main__":
    # Seed torch's RNG first.
    torch.manual_seed(3333)

    # Command line: each parameter group registers its own options on the parser.
    parser = ArgumentParser(description="Training script parameters")
    model_group = ModelParams(parser)
    optim_group = OptimizationParams(parser)
    pipeline_group = PipelineParams(parser)

    # LPIPS metric (VGG backbone); training_report reads this module-level name.
    lpips_fn = lpips.LPIPS(net='vgg').to("cuda")

    # Script-specific options.
    parser.add_argument('--ip', type=str, default="127.0.0.1")
    parser.add_argument('--port', type=int, default=6009)
    parser.add_argument('--detect_anomaly', action='store_true', default=False)
    default_test_iterations = list(range(5000, 40001, 1000))
    parser.add_argument("--test_iterations", nargs="+", type=int,
                        default=default_test_iterations)
    default_save_iterations = [6_000, 10_000, 20_000, 30_000, 40_000]
    parser.add_argument("--save_iterations", nargs="+", type=int, default=default_save_iterations)
    parser.add_argument("--quiet", action="store_true")

    args = parser.parse_args(sys.argv[1:])
    # The final iteration is always saved.
    args.save_iterations.append(args.iterations)

    print("Optimizing " + args.model_path)

    # Initialize system state (RNG)
    safe_state(args.quiet)

    # Configure and run training. The GUI server is not started here:
    # network_gui.init(args.ip, args.port)
    torch.autograd.set_detect_anomaly(args.detect_anomaly)
    model_args = model_group.extract(args)
    optim_args = optim_group.extract(args)
    pipeline_args = pipeline_group.extract(args)
    training(model_args, optim_args, pipeline_args, args.test_iterations, args.save_iterations)

    # All done
    print("\nTraining complete.")
