import os
import random
import sys

from argparse import Namespace

import lovely_tensors as lt
import torch

from torch.utils.tensorboard import SummaryWriter
from torchvision.utils import save_image
from tqdm import trange

import pdb

sys.path.append(os.path.realpath(os.path.join(os.path.dirname(__file__), "../")))

from arguments import ModelParams, OptimizationParams, PipelineParams
from gaussian_splatting.gm_fluid import GaussianModel
from helpers.helper_gaussian import get_model
from helpers.helper_parser import get_parser, write_args_to_file
from helpers.helper_pipe import get_render_pipe
from helpers.helper_train import prepare_output_and_logger
from scene import Scene
from utils.image_utils import psnr
from utils.loss_utils import distance_loss, l1_loss, l2_loss, ssim

def debug_view(cam, gaussians):
    """Print quick diagnostics about how the visual particles relate to ``cam``.

    Applies ``cam.world_view_transform`` to the homogeneous visual particle
    positions and prints:

    * the min/max of the first two transformed (w-divided) coordinates,
    * the fraction of particles with negative transformed z, and the fraction that
      also have both of those coordinates strictly inside (-1, 1),
    * the min/max of ``gaussians.get_visual_scaling``,
    * the mean of ``gaussians.get_visual_opacity``.

    Nothing is returned; this is a debugging aid only.
    """
    points = gaussians._visual_xyz
    homogeneous = torch.cat([points, torch.ones(len(points), 1, device=points.device)], 1)

    # NOTE(review): this treats world_view_transform as a column-vector matrix (M @ p).
    # If the camera stores it transposed (row-vector convention), the product is wrong -- confirm.
    transformed = (cam.world_view_transform @ homogeneous.T).T
    depth = transformed[:, 2]

    # NOTE(review): for an affine view matrix w stays 1, so these are camera-space
    # coordinates rather than true NDC (no projection matrix is applied here).
    projected = transformed[:, :3] / transformed[:, 3:4]
    u_coord = projected[:, 0]
    v_coord = projected[:, 1]

    in_front = depth < 0
    inside = (u_coord.abs() < 1) & (v_coord.abs() < 1) & in_front

    print("u range:", u_coord.min().item(), u_coord.max().item())
    print("v range:", v_coord.min().item(), v_coord.max().item())
    print(f"z<0:{in_front.float().mean():.2f}  in_frustum:{inside.float().mean():.2f}")
    print("scale", gaussians.get_visual_scaling.min().item(), gaussians.get_visual_scaling.max().item())
    print("alpha", gaussians.get_visual_opacity.mean().item())

    
def train(args: Namespace, model_args: ModelParams, optim_args: OptimizationParams, pipe_args: PipelineParams):
    """Fit and simulate a fluid Gaussian model over all training timestamps, then optionally extrapolate.

    The run has four phases, executed in order:

    1. First frame, visual fit: visual particles are created from the point cloud and
       optimized against the training cameras of the earliest timestamp for
       ``optim_args.iterations_per_time_first`` iterations.
    2. First frame, stabilization: hidden particles are created from the point cloud and
       ``optim_args.stable_iterations`` rounds of guess / constraint projection / confirm
       are run without updating the visual particles.
    3. Every later training timestamp: new particles are emitted, a simulated guess is
       produced and projected onto the gas constraints, then the current-frame quantities
       are optimized against that timestamp's cameras (image, distance, estimated-position
       and gas-constraint losses) and the guess is confirmed.
    4. If ``optim_args.future_pred_frames > 0``: that many extra frames are simulated with
       no fitting, and the train and test cameras of time index 0 are rendered for each.

    Particles, quantities and checkpoints are saved under ``scene.model_path`` throughout,
    and scalars/images are logged to TensorBoard.

    Args:
        args: Top-level parsed arguments (only forwarded to ``write_args_to_file`` here).
        model_args: Model/scene parameters (model name, loader, white-background flag, ...).
        optim_args: Optimization, simulation and loss-weight hyperparameters.
        pipe_args: Rendering pipeline parameters; ``pipe_args.rd_pipe`` selects the renderer.

    Returns:
        None. Returns early after phase 3 when ``optim_args.future_pred_frames <= 0``.
    """

    write_args_to_file(args, model_args, optim_args, pipe_args, "training")

    tb_writer = prepare_output_and_logger(model_args)
    render_func, GRsetting, GRzer = get_render_pipe(pipe_args.rd_pipe)

    print(f"Model: {model_args.model}")
    Gaussian = get_model(model_args.model)

    gaussians: GaussianModel = Gaussian()

    scene = Scene(model_args, gaussians, loader=model_args.loader)

    num_channel = 3  # this is the render channel

    bg_color = 1 if model_args.white_background else 0
    bg_color = [bg_color] * num_channel
    background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")

    # Group cameras by timestamp. Keys are 0-based positions in the sorted list of unique
    # timestamps, so key 0 is the earliest timestamp.
    # NOTE: `.copy()` is shallow -- the camera objects themselves are shared with the scene.
    train_camera_list = scene.get_train_cameras().copy()
    train_cam_dict = {}
    unique_timestamps = sorted(list(set([cam.timestamp for cam in train_camera_list])))

    for i, timestamp in enumerate(unique_timestamps):
        train_cam_dict[i] = [cam for cam in train_camera_list if cam.timestamp == timestamp]

    print("Start training")
    print(f"Num of unique training timestamps: {len(unique_timestamps)}")

    # Same grouping for the test cameras (unique_timestamps is reused for the test set).
    test_camera_list = scene.get_test_cameras().copy()
    test_cam_dict = {}
    unique_timestamps = sorted(list(set([cam.timestamp for cam in test_camera_list])))

    for i, timestamp in enumerate(unique_timestamps):
        test_cam_dict[i] = [cam for cam in test_camera_list if cam.timestamp == timestamp]

    gaussians.setup_constants(optim_args)

    checkpoint_path = os.path.join(scene.model_path, "checkpoint")
    quantities_path = os.path.join(scene.model_path, "quantities")
    quantities_sim_path = os.path.join(scene.model_path, "quantities_sim")
    quantities_optim_path = os.path.join(scene.model_path, "quantities_optim")

    ##########################################################################################
    ####################### First frame, optimize visual particles' xyz ######################
    ##########################################################################################
    # gaussians.create_particles_visual()
    gaussians.create_particles_visual_from_pcd()
    gaussians.prepare_visual_particles_for_rendering()

    cur_time_index = 0

    cur_viewpoint_set = train_cam_dict[cur_time_index]
    # NOTE(review): raises KeyError if the scene has no test cameras (test_cam_dict is empty).
    cur_test_viewpoint_set = test_cam_dict[cur_time_index]
    current_time_iterations = optim_args.iterations_per_time_first

    gaussians.save_particles_optimization_first(quantities_optim_path, cur_time_index, 0)

    # Use 1 based index for saving and testing
    testing_iterations = [current_time_iterations]

    gaussians.training_setup_first_visual(optim_args)

    desc_str = f"Optimizing first frame visual"
    postfix = {"Visual": gaussians._visual_xyz.shape[0]}
    for itr in trange(1, current_time_iterations + 1, desc=desc_str, postfix=postfix, leave=True):
        gaussians.total_iterations += 1

        gaussians.update_learning_rate_first_visual(itr)

        gaussians.zero_gradient_cache_first_visual()

        # NOTE(review): random.sample raises ValueError if optim_args.batch exceeds the
        # number of cameras at this timestamp.
        cam_index = random.sample(cur_viewpoint_set, optim_args.batch)

        # Each view is back-propagated separately; its gradients are stashed by
        # cache_gradient_first_visual() and combined after the loop by
        # set_batch_gradient_first_visual(batch) (presumably averaged -- confirm in GaussianModel).
        for i in range(optim_args.batch):
            viewpoint_cam = cam_index[i]
            render_pkg = render_func(
                viewpoint_cam,
                gaussians,
                pipe_args,
                background,
                GRsetting=GRsetting,
                GRzer=GRzer,
                pos_type="visual",
            )
            image = render_pkg["render"]

            gt_image = viewpoint_cam.original_image.float().cuda()

            view_name = viewpoint_cam.image_name

            # Loss = (1 - lambda_dssim) * L1 + lambda_dssim * (1 - SSIM) + lambda_first_distance * distance term
            l1_value = l1_loss(image, gt_image)
            ssim_value = 1.0 - ssim(image, gt_image)
            dist_value = distance_loss(gaussians.get_visual_xyz, optim_args.distance_threshold_visual)
            weight_loss = 0.0
            weight_loss += (1.0 - optim_args.lambda_dssim) * l1_value
            weight_loss += optim_args.lambda_dssim * ssim_value
            weight_loss += optim_args.lambda_first_distance * dist_value
            loss = weight_loss

            t_idx = cur_time_index
            loss_prefix_str = f"train_loss_frame_{t_idx:03d}"
            tb_writer.add_scalar(f"{loss_prefix_str}/l1_{view_name}", l1_value.item(), itr)
            tb_writer.add_scalar(f"{loss_prefix_str}/ssim_{view_name}", ssim_value.item(), itr)
            tb_writer.add_scalar(f"{loss_prefix_str}/dist_{view_name}", dist_value.item(), itr)
            tb_writer.add_scalar(f"{loss_prefix_str}/total_{view_name}", loss.item(), itr)

            loss.backward()
            gaussians.cache_gradient_first_visual()
            gaussians.optimizer.zero_grad()

        gaussians.set_batch_gradient_first_visual(optim_args.batch)

        gaussians.optimizer.step()
        gaussians.optimizer.zero_grad()

        if itr % 10 == 0:
            gaussians.save_particles_optimization_first(quantities_optim_path, cur_time_index, itr)

        if itr in testing_iterations:
            training_report(
                cur_time_index,
                # cur_viewpoint_set,
                cur_test_viewpoint_set,
                tb_writer,
                gaussians.total_iterations,
                scene,
                render_func,
                pipe_args,
                background,
                GRsetting,
                GRzer,
                scale=False,
                save_gt=itr == testing_iterations[0],
            )

    ####################################################################################################
    ####################### First frame, initialize hidden particles' xyz, velocity ####################
    ####################################################################################################

    gaussians.detach_visual_and_scale()

    gaussians.create_particles_hidden_from_pcd()

    num_hidden = gaussians._xyz.shape[0]
    num_visual = gaussians._visual_xyz.shape[0]
    tb_writer.add_scalar("num_particles/hidden", num_hidden, gaussians.total_sim_iterations)
    tb_writer.add_scalar("num_particles/visual", num_visual, gaussians.total_sim_iterations)
    tb_writer.add_scalar("num_particles/total", num_hidden + num_visual, gaussians.total_sim_iterations)

    gaussians.save_particles_simulation(quantities_sim_path, gaussians.total_sim_iterations)
    gaussians.total_sim_iterations += 1

    # In stable iterations, we don't update the visual particles
    for stable_iter in trange(optim_args.stable_iterations, desc="Stabilizing first frame", leave=True):
        gaussians.remove_invalid_particles()
        gaussians.guess_hidden_particles(stable=True)
        gaussians.save_particles_simulation_guess(quantities_sim_path, gaussians.total_sim_iterations)
        # update_solver_counts() is called solver_iterations times before any projection
        # (semantics live in GaussianModel -- not visible here).
        for _ in range(optim_args.solver_iterations):
            gaussians.update_solver_counts()
        for _ in range(optim_args.solver_iterations):
            ret_values = gaussians.project_gas_constraints()
            for ret_k, ret_v in ret_values.items():
                tb_writer.add_scalar(f"sim_stable/{ret_k}", ret_v, gaussians.total_tb_log_iterations)
            if "elapsed_time" in ret_values:
                tb_writer.add_scalar("elapsed_time", ret_values["elapsed_time"], gaussians.total_tb_log_iterations)
            gaussians.total_tb_log_iterations += 1

        gaussians.confirm_guess_hidden_particles()

        num_hidden = gaussians._xyz.shape[0]
        num_visual = gaussians._visual_xyz.shape[0]
        tb_writer.add_scalar("num_particles/hidden", num_hidden, gaussians.total_sim_iterations)
        tb_writer.add_scalar("num_particles/visual", num_visual, gaussians.total_sim_iterations)
        tb_writer.add_scalar("num_particles/total", num_hidden + num_visual, gaussians.total_sim_iterations)

        gaussians.save_particles_simulation(quantities_sim_path, gaussians.total_sim_iterations)
        gaussians.total_sim_iterations += 1

    gaussians.save_particles_frame(quantities_path, 0)
    gaussians.save_all(checkpoint_path, 0)

    ####################################################################################################
    ####################### Current frame, simulation and fitting ######################################
    ####################################################################################################
    ## the first hidden stage is skipped, as the theory is not clear
    ## currently, we have the fitted visual_xyz, and initialized hidden xyz

    gaussians.prepare_emitter_points_from_pcd()

    simulation_ratio = optim_args.simulation_ratio
    # NOTE(review): `assert` is stripped under `python -O`; raise explicitly if this must hold.
    assert simulation_ratio == 1
    wind_since = optim_args.wind_since
    use_wind = False

    desc_str = "Simulating and optimizing current frame"
    for cur_time_index in trange(1, len(train_cam_dict), desc=desc_str, leave=True):
        # NOTE(review): hard-coded schedule -- emission ratios drop to 0.1 from frame 20 onward.
        if cur_time_index == 20:
            gaussians.emit_ratio_hidden = 0.1
            gaussians.emit_ratio_visual = 0.1
        # if cur_time_index == 70:
        #     gaussians.emit_ratio_hidden = 0.04
        #     gaussians.emit_ratio_visual = 0.02

        gaussians.emit_new_particles()

        # if cur_time_index <= 80:
        #     gaussians.remove_invalid_particles()

        # if cur_time_index % 10 ==1:
        num_hidden = gaussians._xyz.shape[0]
        num_visual = gaussians._visual_xyz.shape[0]
        # desc_str is reused below as the description of this frame's inner progress bar.
        desc_str = f" {cur_time_index} num_particles/hidden, {num_hidden}, visual, {num_visual}, total, {num_hidden + num_visual}"
            # print("num_particles/hidden", num_hidden)
            # print("num_particles/visual", num_visual)
            # print("num_particles/total", num_hidden + num_visual)

        # Wind is enabled from frame `wind_since` onward; a negative value disables it.
        use_wind = wind_since >= 0 and cur_time_index >= wind_since

        # estimate_xyz is kept as the regression target for the exyz loss below.
        estimate_xyz = gaussians.guess_hidden_particles(use_wind=use_wind)

        gaussians.save_particles_simulation_guess(quantities_sim_path, gaussians.total_sim_iterations)

        for _ in range(optim_args.solver_iterations):
            gaussians.update_solver_counts()
        for _ in range(optim_args.solver_iterations):
            ret_values = gaussians.project_gas_constraints()
            for k, v in ret_values.items():
                tb_writer.add_scalar(f"sim_frame_{cur_time_index:03d}/{k}", v, gaussians.total_tb_log_iterations)
            if "elapsed_time" in ret_values:
                tb_writer.add_scalar("elapsed_time", ret_values["elapsed_time"], gaussians.total_tb_log_iterations)
            gaussians.total_tb_log_iterations += 1

        # setup the visual particles for current frame
        gaussians.training_setup_current(optim_args)
        gaussians.prepare_visual_particles_for_rendering()

        # NOTE: cur_viewpoint_set is the same list object as train_cam_dict[cur_time_index]
        # (unless replaced by the sparse-view filter below), so appends mutate the dict entry.
        cur_viewpoint_set = train_cam_dict[cur_time_index]
        num_cams = len(cur_viewpoint_set)
        if optim_args.local_inter:
            # Collect one camera list per entry of inter_iters, taken from a frame offset by
            # (i + 1) * inter_gap * num_cams time indices (forward if in range, else backward).
            # NOTE(review): the offset is scaled by the camera count, which looks suspicious for a
            # time-index jump; and when cur_time_index - jump_gap < 0 the lookup raises KeyError.
            added_viewpoint_set = []
            for  i in range(len(optim_args.inter_iters)):
                jump_gap = (i+1)*optim_args.inter_gap * num_cams
                if cur_time_index + jump_gap < len(train_cam_dict):
                    added_viewpoint_set.append(train_cam_dict[cur_time_index + jump_gap])
                else:
                    added_viewpoint_set.append(train_cam_dict[cur_time_index - jump_gap])

        # Iteration budget ramps linearly from iters_min (frame 0) toward iters_max (last frame).
        iters_min = optim_args.iterations_per_time_current
        iters_max = optim_args.iterations_per_time_current_max
        current_time_iterations = iters_min + (iters_max - iters_min) * cur_time_index / len(train_cam_dict)
        current_time_iterations = int(current_time_iterations)

        if optim_args.sparse_views_from_time_index > 0 and cur_time_index >= optim_args.sparse_views_from_time_index:
            # sparse views
            sparse_viewpoint_set = []
            for viewpoint in cur_viewpoint_set:
                if viewpoint.image_name in optim_args.sparse_views:
                    sparse_viewpoint_set.append(viewpoint)
            cur_viewpoint_set = sparse_viewpoint_set
            current_time_iterations = optim_args.iterations_per_time_current_sparse

        # testing_iterations = [1, current_time_iterations // 2, current_time_iterations]
        testing_iterations = [current_time_iterations]

        # Here we save the visual particles for the current frame before optimization
        gaussians.save_particles_optimization(quantities_optim_path, gaussians.get_visual_xyz, cur_time_index, 0)

        # desc_str = f"Optimizing frame {cur_time_index}"

        postfix = {"Hidden": gaussians._xyz.shape[0], "Visual": gaussians._visual_xyz.shape[0]}

        for itr in (pbar := trange(1, current_time_iterations + 1, desc=desc_str, postfix=postfix, leave=False)):
            gaussians.total_iterations += 1

            gaussians.update_learning_rate_current(itr)

            gaussians.zero_gradient_cache_current()

            if optim_args.local_inter:
                if itr in optim_args.inter_iters and len(added_viewpoint_set):
                    # Render the first camera of another frame with the current model and use the
                    # render as a pseudo ground truth for that view.
                    candidate_view = added_viewpoint_set.pop(0)[0]
                    render_pkg = render_func(
                        candidate_view,
                        gaussians,
                        pipe_args,
                        background,
                        GRsetting=GRsetting,
                        GRzer=GRzer,
                        pos_type="guess_visual_nn",
                        scale=True,
                    )
                    image = render_pkg["render"]
                    # NOTE(review): this overwrites `original_image` on a camera object that is
                    # shared with train_cam_dict (and the scene), so the real ground truth of that
                    # other frame is lost for the rest of training; consider copying the camera.
                    candidate_view.original_image = image.detach().clone()
                    cur_viewpoint_set.append(candidate_view)
                    # pdb.set_trace()

            # NOTE(review): random.sample raises ValueError if optim_args.batch exceeds the number
            # of views (e.g. after sparse-view filtering).
            cam_index = random.sample(cur_viewpoint_set, optim_args.batch)
            for i in range(optim_args.batch):
                viewpoint_cam = cam_index[i]
                render_pkg = render_func(
                    viewpoint_cam,
                    gaussians,
                    pipe_args,
                    background,
                    GRsetting=GRsetting,
                    GRzer=GRzer,
                    pos_type="guess_visual_nn",
                    scale=True,
                )
                image = render_pkg["render"]
                visual_xyz = render_pkg["render_xyz"]

                gt_image = viewpoint_cam.original_image.float().cuda()
                view_name = viewpoint_cam.image_name

                l1_value = l1_loss(image, gt_image)
                ssim_value = 1.0 - ssim(image, gt_image)

                dist_value = distance_loss(visual_xyz, optim_args.distance_threshold_visual)

                # Keep the optimized estimate (rescaled by scale_factor) close to the simulated guess.
                fake_estimated_xyz = gaussians._estimate_xyz_nn * gaussians.scale_factor
                exyz_loss_value = l2_loss(fake_estimated_xyz, estimate_xyz)

                # Push the gas-constraint ratios (current estimate and next-step guess) toward 1.
                gas_cs_p_ratio = gaussians.get_gas_constraints_from_exyz_nn()
                gt_value = torch.ones_like(gas_cs_p_ratio)
                gas_cs_loss_value = l2_loss(gas_cs_p_ratio, gt_value)

                next_gas_cs_p_ratio = gaussians.get_gas_constraints_from_vel_nn_guess()
                next_gt_value = torch.ones_like(next_gas_cs_p_ratio)
                next_gas_cs_loss_value = l2_loss(next_gas_cs_p_ratio, next_gt_value)

                weight_loss = 0.0
                weight_loss = weight_loss + (1.0 - optim_args.lambda_dssim) * l1_value * optim_args.lambda_image
                weight_loss = weight_loss + optim_args.lambda_dssim * ssim_value * optim_args.lambda_image

                weight_loss = weight_loss + optim_args.lambda_current_distance * dist_value

                weight_loss = weight_loss + optim_args.lambda_exyz * exyz_loss_value

                weight_loss = weight_loss + optim_args.lambda_gas_constraints * gas_cs_loss_value
                weight_loss = weight_loss + optim_args.lambda_next_gas_constraints * next_gas_cs_loss_value

                loss = weight_loss

                t_idx = cur_time_index
                loss_prefix_str = f"train_loss_frame_{t_idx:03d}"
                tb_writer.add_scalar(f"{loss_prefix_str}/l1_{view_name}", l1_value.item(), itr)
                tb_writer.add_scalar(f"{loss_prefix_str}/ssim_{view_name}", ssim_value.item(), itr)

                tb_writer.add_scalar(f"{loss_prefix_str}/dist_{view_name}", dist_value.item(), itr)

                tb_writer.add_scalar(f"{loss_prefix_str}/exyz_{view_name}", exyz_loss_value.item(), itr)

                tb_writer.add_scalar(f"{loss_prefix_str}/gas_cs_{view_name}", gas_cs_loss_value.item(), itr)
                tb_writer.add_scalar(f"{loss_prefix_str}/next_gas_cs_{view_name}", next_gas_cs_loss_value.item(), itr)

                tb_writer.add_scalar(f"{loss_prefix_str}/total_{view_name}", loss.item(), itr)

                optim_prefix_str = f"optim_frame_{t_idx:03d}"
                tb_writer.add_scalar(f"{optim_prefix_str}/gas_cs_p_ratio", gas_cs_p_ratio.mean().item(), itr)
                tb_writer.add_scalar(f"{optim_prefix_str}/next_gas_cs_p_ratio", next_gas_cs_p_ratio.mean().item(), itr)
                # retain_graph=True: presumably because part of the graph is shared across the
                # per-view backward passes of this batch -- confirm before removing.
                loss.backward(retain_graph=True)

                gaussians.cache_gradient_current()
                gaussians.optimizer.zero_grad()
                if hasattr(gaussians, "optimizer_alpha") and gaussians.optimizer_alpha is not None:
                    gaussians.optimizer_alpha.zero_grad()

            gaussians.set_batch_gradient_current(optim_args.batch)
            if hasattr(gaussians, "optimizer_alpha") and gaussians.optimizer_alpha is not None:
                gaussians.optimizer_alpha.step()
                gaussians.optimizer_alpha.zero_grad()
            gaussians.optimizer.step()
            gaussians.optimizer.zero_grad()

            # visual_xyz here is the render_xyz of the last view rendered in this batch.
            if itr % 10 == 0:
                gaussians.save_particles_optimization(quantities_optim_path, visual_xyz, cur_time_index, itr)

            if itr in testing_iterations:
                with torch.no_grad():
                    # NOTE(review): this passes the *training* views of the current frame, which
                    # training_report logs under the "test" name; the first frame uses test views.
                    training_report(
                        cur_time_index,
                        cur_viewpoint_set,
                        tb_writer,
                        itr,
                        scene,
                        render_func,
                        pipe_args,
                        background,
                        GRsetting,
                        GRzer,
                        pos_type="guess_visual_nn",
                        save_gt=itr == testing_iterations[0],
                        scale=True,
                    )
        # if cur_time_index == 1 or cur_time_index == 13:
        #     pdb.set_trace()
        # print('Testing ...')
        # for test_time_index in trange(1, len(test_cam_dict), desc=desc_str, leave=True):
        #     # cur_viewpoint_set = train_cam_dict[cur_time_index]
        #     cur_test_viewpoint_set = test_cam_dict[test_time_index]

        #     with torch.no_grad():
        #         training_report(
        #             test_time_index, 
        #             # cur_viewpoint_set,
        #             cur_test_viewpoint_set,
        #             tb_writer,
        #             itr,
        #             scene,
        #             render_func,
        #             pipe_args,
        #             background,
        #             GRsetting,
        #             GRzer,
        #             pos_type="guess_visual_nn",
        #             save_gt=itr == testing_iterations[0],
        #             scale=True,
        #         )

        # Commit the optimized estimates for this frame before moving to the next one.
        gaussians.confirm_guess_hidden_particles_from_nn()
        gaussians.update_visual_xyz_from_nn()
        gaussians.confirm_guess_hidden_particles_wo_velocity()

        num_hidden = gaussians._xyz.shape[0]
        num_visual = gaussians._visual_xyz.shape[0]
        tb_writer.add_scalar("num_particles/hidden", num_hidden, gaussians.total_sim_iterations)
        tb_writer.add_scalar("num_particles/visual", num_visual, gaussians.total_sim_iterations)
        tb_writer.add_scalar("num_particles/total", num_hidden + num_visual, gaussians.total_sim_iterations)

        gaussians.save_particles_simulation(quantities_sim_path, gaussians.total_sim_iterations)
        gaussians.save_particles_frame(quantities_path, cur_time_index)
        gaussians.save_all(checkpoint_path, cur_time_index)
        gaussians.total_sim_iterations += 1

    ####################################################################################################
    ####################### Future prediction frame, simulation ######################################
    ####################################################################################################

    # Drop the per-frame optimization state; future frames are simulated without fitting.
    gaussians._estimate_xyz_nn_grad = None
    gaussians._estimate_xyz_nn = None
    gaussians._velocity_nn_grad = None
    gaussians._velocity_nn = None
    gaussians.optimizer = None

    # First future frame index: one past the last fitted frame (1 if the per-frame loop never ran).
    cur_time_index += 1
    future_pred_frames = optim_args.future_pred_frames
    if future_pred_frames <= 0:
        print("No future prediction frames")
        return

    decay_frames_future_p0 = optim_args.decay_frames_future_p0
    p0_recon = gaussians.p0
    p0_future = optim_args.p0_future

    for future_time_index in (pbar := trange(future_pred_frames, desc=f"Predicting future frames", leave=True)):
        future_frame_index = cur_time_index + future_time_index
        # p0 moves linearly from p0_recon to p0_future over decay_frames_future_p0 frames, then stays.
        # NOTE(review): raises ZeroDivisionError if decay_frames_future_p0 == 0.
        gaussians.p0 = p0_future + (p0_recon - p0_future) * (1 - min(1, future_time_index / decay_frames_future_p0))
        gaussians.remove_invalid_particles()
        # NOTE(review): called with an argument here but without one in the per-frame loop --
        # confirm the intended signature of emit_new_particles.
        gaussians.emit_new_particles(future_time_index)
        gaussians.guess_hidden_particles()
        gaussians.save_particles_simulation_guess(quantities_sim_path, gaussians.total_sim_iterations)
        for _ in range(optim_args.solver_iterations_future):
            gaussians.update_solver_counts()
        for i in range(optim_args.solver_iterations_future):
            ret_values = gaussians.project_gas_constraints()
            for k, v in ret_values.items():
                tb_writer.add_scalar(f"fut_frame_{future_frame_index:03d}/{k}", v, gaussians.total_tb_log_iterations)
            if "elapsed_time" in ret_values:
                tb_writer.add_scalar("elapsed_time", ret_values["elapsed_time"], gaussians.total_tb_log_iterations)
            gaussians.total_tb_log_iterations += 1

        gaussians.confirm_guess_hidden_particles()
        gaussians.update_visual_particles()

        tb_writer.add_scalar("p0", gaussians.p0, gaussians.total_sim_iterations)
        num_hidden = gaussians._xyz.shape[0]
        num_visual = gaussians._visual_xyz.shape[0]
        tb_writer.add_scalar("num_particles/hidden", num_hidden, gaussians.total_sim_iterations)
        tb_writer.add_scalar("num_particles/visual", num_visual, gaussians.total_sim_iterations)
        tb_writer.add_scalar("num_particles/total", num_hidden + num_visual, gaussians.total_sim_iterations)

        gaussians.prepare_visual_particles_for_rendering()

        # Render every train and test camera of time index 0 (the earliest timestamp) with the
        # simulated particles; files are named render_{train,test}_<frame:03d>_<image_name>.png.
        time_index = 0
        viewpoint_set = train_cam_dict[time_index]
        for viewpoint_cam in viewpoint_set:
            render_pkg = render_func(
                viewpoint_cam,
                gaussians,
                pipe_args,
                background,
                GRsetting=GRsetting,
                GRzer=GRzer,
                pos_type="visual",
                scale=True,
            )

            image = render_pkg["render"]
            save_image(
                image,
                os.path.join(
                    scene.model_path,
                    "training_render",
                    f"render_train_{future_frame_index:03d}_{viewpoint_cam.image_name}.png",
                ),
            )
        test_viewpoint_set = test_cam_dict[time_index]
        for viewpoint_cam in test_viewpoint_set:
            render_pkg = render_func(
                viewpoint_cam,
                gaussians,
                pipe_args,
                background,
                GRsetting=GRsetting,
                GRzer=GRzer,
                pos_type="visual",
                scale=True,
            )

            image = render_pkg["render"]
            save_image(
                image,
                os.path.join(
                    scene.model_path,
                    "training_render",
                    f"render_test_{future_frame_index:03d}_{viewpoint_cam.image_name}.png",
                ),
            )

        gaussians.save_particles_simulation(quantities_sim_path, gaussians.total_sim_iterations)
        gaussians.save_particles_frame(quantities_path, future_frame_index)
        gaussians.save_all(checkpoint_path, future_frame_index)
        gaussians.total_sim_iterations += 1

        post_fix = {}
        post_fix["Hidden"] = gaussians._xyz.shape[0]
        post_fix["Visual"] = gaussians._visual_xyz.shape[0]
        post_fix["Total"] = post_fix["Hidden"] + post_fix["Visual"]
        pbar.set_postfix(post_fix)


@torch.no_grad
def training_report(
    cur_time_index: int,
    # cur_viewpoint_set: list,
    cur_test_viewpoint_set: list,
    tb_writer: SummaryWriter,
    cur_iteration: int,
    scene: Scene,
    render_func: callable,
    pipe_args: PipelineParams,
    background: torch.Tensor,
    GRsetting: dict,
    GRzer: dict,
    pos_type: str = "visual",
    save_gt=True,
    scale=False,
):
    """Render a set of viewpoints, save/log the images, and log view-averaged L1 and PSNR.

    For each viewpoint in ``cur_test_viewpoint_set``, ``scene.gaussians`` is rendered with
    ``render_func``. The render and both ground truths (``original_image`` and
    ``original_image_real``) are clamped to [0, 1]. The render is saved under
    ``<scene.model_path>/training_render`` and logged to TensorBoard; when ``save_gt`` is
    true, the ground truths are saved and logged as well. Mean L1 and PSNR against each
    ground truth are then logged as scalars tagged with ``cur_time_index``.

    Args:
        cur_time_index: Frame index, used in the scalar tag names.
        cur_test_viewpoint_set: Cameras to evaluate. If it is empty, nothing is logged.
        tb_writer: TensorBoard writer.
        cur_iteration: Global step for the images and scalars.
        scene: Scene providing ``gaussians`` and ``model_path``.
        render_func: Render function returned by ``get_render_pipe``.
        pipe_args: Pipeline parameters forwarded to ``render_func``.
        background: Background color tensor forwarded to ``render_func``.
        GRsetting: Rasterizer setting forwarded to ``render_func``.
        GRzer: Rasterizer forwarded to ``render_func``.
        pos_type: Which particle positions ``render_func`` should render.
        save_gt: Also save and log the ground-truth images.
        scale: Forwarded to ``render_func``.
    """

    validation_configs = (
        {"name": "test", "viewpoint_set": cur_test_viewpoint_set},
        # {"name": "train", "viewpoint_set": cur_viewpoint_set},
    )

    render_dir = os.path.join(scene.model_path, "training_render")

    for config in validation_configs:
        viewpoint_set = config["viewpoint_set"]
        num_views = len(viewpoint_set)
        if num_views == 0:
            # Nothing to evaluate; averaging below would divide by zero.
            continue

        l1_test, l1_test_real = 0.0, 0.0
        psnr_test, psnr_test_real = 0.0, 0.0
        for viewpoint in viewpoint_set:
            rendered = render_func(
                viewpoint,
                scene.gaussians,
                pipe_args,
                background,
                override_color=None,
                GRsetting=GRsetting,
                GRzer=GRzer,
                pos_type=pos_type,
                scale=scale,
            )
            # Camera id = the text after the last "cam" in the image's parent directory name.
            cam_name = viewpoint.image_path.split("/")[-2].split("cam")[-1]
            image = torch.clamp(rendered["render"], 0.0, 1.0)
            # Move the ground truths to wherever the render lives, instead of assuming "cuda".
            gt_image = torch.clamp(viewpoint.original_image.to(image.device), 0.0, 1.0)
            real_gt_image = torch.clamp(viewpoint.original_image_real.to(image.device), 0.0, 1.0)

            save_image(image, os.path.join(render_dir, f"render_cam{cam_name}_{viewpoint.image_name}"))
            if save_gt:
                save_image(gt_image, os.path.join(render_dir, f"gt_cam{cam_name}_{viewpoint.image_name}"))
                save_image(real_gt_image, os.path.join(render_dir, f"real_cam{cam_name}_{viewpoint.image_name}"))

            # Image name without its extension, used in the TensorBoard tag.
            na = viewpoint.image_name.split('.')[-2]
            tb_writer.add_images(
                f"frame_view_{na}_cam{cam_name}/render",
                image[None],
                global_step=cur_iteration,
            )
            if save_gt:
                tb_writer.add_images(
                    f"frame_view_{na}_cam{cam_name}/ground_truth",
                    gt_image[None],
                    global_step=cur_iteration,
                )

            l1_test += l1_loss(image, gt_image).mean()
            psnr_test += psnr(image, gt_image).mean()
            # Accumulate (previously these were overwritten with `=`, so the "mean" below was
            # just the last view's value divided by the number of views).
            l1_test_real += l1_loss(image, real_gt_image).mean()
            psnr_test_real += psnr(image, real_gt_image).mean()

        l1_test /= num_views
        psnr_test /= num_views
        l1_test_real /= num_views
        psnr_test_real /= num_views

        # print(f"[ITER {cur_iteration} Evaluation {config['name']}] L1: {l1_test}, PSNR: {psnr_test}")

        tb_writer.add_scalar(f"eval_{config['name']}/frame_{cur_time_index:03d} - l1", l1_test, cur_iteration)
        tb_writer.add_scalar(f"eval_{config['name']}/frame_{cur_time_index:03d} - psnr", psnr_test, cur_iteration)
        tb_writer.add_scalar(
            f"eval_{config['name']}/frame_{cur_time_index:03d} - l1_real", l1_test_real, cur_iteration
        )
        tb_writer.add_scalar(
            f"eval_{config['name']}/frame_{cur_time_index:03d} - psnr_real", psnr_test_real, cur_iteration
        )


if __name__ == "__main__":
    # Install lovely_tensors' compact tensor reprs so any debug printing stays readable.
    lt.monkey_patch()

    # get_parser() yields the raw CLI namespace plus the extracted model/optim/pipeline params.
    cli_args, model_params, optim_params, pipe_params = get_parser()
    train(cli_args, model_params, optim_params, pipe_params)

    # All done
    print("Training complete.")
