#
# Copyright (C) 2023, Inria
# GRAPHDECO research group, https://team.inria.fr/graphdeco
# All rights reserved.
#
# This software is free for non-commercial, research and evaluation use
# under the terms of the LICENSE.md file.
#
# For inquiries contact  george.drettakis@inria.fr
#

import os
from argparse import ArgumentParser
from os import makedirs

import cv2
import numpy as np
import torch
import torchvision
from arguments import ModelParams, PipelineParams, get_combined_args
from gaussian_renderer import render
from scene import Scene
from scene.planar_model import GaussianModelPlanes
from tqdm import tqdm
from utils.general_utils import safe_state
from utils.render_utils import generate_path, create_videos


def get_plane_colormap(num_planes):
    """Build an RGB lookup table with one colour row per plane plus one extra row.

    Row ``i`` encodes the integer ``(i + 1) * 100`` as base-256 digits
    ``(high, mid, low)`` and scales each digit by ``1 / 255``.
    ``get_plane_color`` looks plane ``k`` up at row ``k + 1``, so row 0 is never
    used for a plane.

    Args:
        num_planes: Number of planes in the model.

    Returns:
        ``np.ndarray`` of shape ``(num_planes + 1, 3)`` and dtype ``float64``.
    """
    segmentation_color = (np.arange(num_planes + 1) + 1) * 100
    # Floor division extracts integer base-256 digits. True division ("/") would
    # leak fractional parts into the high and mid channels.
    color_map = np.stack(
        [
            segmentation_color // (256 * 256),
            segmentation_color // 256 % 256,
            segmentation_color % 256,
        ],
        axis=1,
    )
    # The digits are integers now, so an in-place "/=" would fail to cast back to
    # int; divide out-of-place to obtain a float array.
    return color_map / 255.0


def get_plane_color(idx, cmap):
    """Return the display colour for plane index ``idx``.

    An index of ``-1`` (presumably the "no plane" sentinel) maps to black.
    Any other index ``k`` is read from row ``k + 1`` of ``cmap``.
    """
    if idx == -1:
        return np.array([0, 0, 0])
    return cmap[idx + 1]


def render_set(model_path, name, iteration, views, gaussians, pipeline, background):
    """Render every view in ``views`` and write the results to disk.

    Outputs are written under ``<model_path>/<name>/ours_<iteration>/``:

    * ``renders/`` - RGB renders;
    * ``gt/`` - the first three channels of each view's ``original_image``;
    * ``depth/`` - ``result["depth"]`` scaled by 1000 and saved as ``uint16``;
    * ``renders_colored_planes/`` - renders where each Gaussian is coloured by
      its plane id.

    A ``normal/`` directory is also created, but no normal maps are written.

    Args:
        model_path: Root output directory of the trained model.
        name: Sub-directory name for this set (e.g. ``"train"``, ``"test"``).
        iteration: Iteration number used in the ``ours_<iteration>`` folder name.
        views: Iterable of views passed to ``render``; each exposes
            ``original_image``.
        gaussians: Model exposing ``plane_ids`` and ``planes``.
        pipeline: Pipeline parameters forwarded to ``render``.
        background: Background colour tensor forwarded to ``render``.
    """
    base_path = os.path.join(model_path, name, "ours_{}".format(iteration))
    render_path = os.path.join(base_path, "renders")
    gts_path = os.path.join(base_path, "gt")
    depth_path = os.path.join(base_path, "depth")
    normal_path = os.path.join(base_path, "normal")
    render_override_path = os.path.join(base_path, "renders_colored_planes")

    for path in (render_path, gts_path, depth_path, normal_path, render_override_path):
        makedirs(path, exist_ok=True)

    # Per-Gaussian plane colours via one vectorised gather instead of a Python
    # loop calling ``.item()`` per Gaussian (a device sync each time on GPU).
    # Mirrors ``get_plane_color``: id k -> cmap[k + 1], id -1 -> black.
    plane_ids = torch.as_tensor(gaussians.plane_ids).reshape(-1).long()
    cmap = torch.as_tensor(
        get_plane_colormap(len(gaussians.planes)),
        dtype=torch.float32,
        device=plane_ids.device,
    )
    override_color = cmap[plane_ids + 1]
    override_color[plane_ids == -1] = 0.0
    override_color = override_color.to(device="cuda")

    for idx, view in enumerate(tqdm(views, desc="Rendering progress")):
        frame_name = "{0:05d}".format(idx) + ".png"

        result = render(view, gaussians, pipeline, background)
        rendering = result["render"]
        gt = view.original_image[0:3, :, :]

        torchvision.utils.save_image(rendering, os.path.join(render_path, frame_name))
        torchvision.utils.save_image(gt, os.path.join(gts_path, frame_name))

        # Depth is scaled by 1000 (presumably metres -> millimetres) and stored
        # as uint16. Sanitise NaN/inf and clamp to the uint16 range first:
        # casting out-of-range floats to uint16 is undefined and silently wraps
        # far or negative depths. Assumes result["depth"] is (1, H, W) so the
        # permute yields (H, W, 1) - TODO confirm.
        depth = torch.nan_to_num(
            result["depth"] * 1000.0, nan=0.0, posinf=65535.0, neginf=0.0
        )
        depth = depth.clamp(0.0, 65535.0).permute(1, 2, 0)
        depth = depth.detach().cpu().numpy().astype(np.uint16)
        cv2.imwrite(os.path.join(depth_path, frame_name), depth)

        # NOTE(review): normal export is disabled, although ``normal_path`` is
        # still created above. To re-enable, save (result["normal"] + 1) / 2 if
        # the renderer provides a "normal" entry.

        result_override = render(
            view, gaussians, pipeline, background, override_color=override_color
        )
        torchvision.utils.save_image(
            result_override["render"],
            os.path.join(render_override_path, frame_name),
        )


def render_sets(
    dataset: ModelParams,
    iteration: int,
    pipeline: PipelineParams,
    skip_train: bool,
    skip_test: bool,
    skip_video: bool,
    path_type: str,
    n_frames: int = 240,
):
    """Load a planar Gaussian model and render its train, test and trajectory sets.

    Args:
        dataset: Model parameters; ``sh_degree``, ``model_path`` and
            ``white_background`` are read here.
        iteration: Checkpoint iteration, forwarded to ``Scene`` as
            ``load_iteration``.
        pipeline: Pipeline parameters forwarded to the renderer.
        skip_train: If true, do not render the training cameras.
        skip_test: If true, do not render the test cameras.
        skip_video: If true, do not render the generated trajectory or its videos.
        path_type: Trajectory type passed to ``generate_path``.
        n_frames: Number of frames in the generated trajectory (default 240).
    """
    with torch.no_grad():
        gaussians = GaussianModelPlanes(dataset.sh_degree)
        scene = Scene(dataset, gaussians, load_iteration=iteration, shuffle=False)

        bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0]
        background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")

        if not skip_train:
            render_set(
                dataset.model_path,
                "train",
                scene.loaded_iter,
                scene.getTrainCameras(),
                gaussians,
                pipeline,
                background,
            )

        if not skip_test:
            render_set(
                dataset.model_path,
                "test",
                scene.loaded_iter,
                scene.getTestCameras(),
                gaussians,
                pipeline,
                background,
            )

        if not skip_video:
            print("render videos ...")
            # Use ``dataset.model_path`` (as ``render_set`` does) rather than the
            # script-level ``args`` global, which only exists when this file is
            # executed as ``__main__``.
            traj_dir = os.path.join(
                dataset.model_path, "traj", "ours_{}".format(scene.loaded_iter)
            )
            os.makedirs(traj_dir, exist_ok=True)
            cam_traj = generate_path(
                scene.getTrainCameras(),
                n_frames=n_frames,
                path_type=path_type,
            )
            render_set(
                dataset.model_path,
                "traj",
                scene.loaded_iter,
                cam_traj,
                gaussians,
                pipeline,
                background,
            )

            create_videos(
                base_dir=traj_dir,
                input_dir=traj_dir,
                out_name="render_traj",
                num_frames=n_frames,
            )


if __name__ == "__main__":
    # Command-line interface: model/pipeline option groups plus render switches.
    parser = ArgumentParser(description="Testing script parameters")
    model = ModelParams(parser, sentinel=True)
    pipeline = PipelineParams(parser)
    parser.add_argument("--iteration", default=-1, type=int)
    for switch in ("--skip_train", "--skip_test", "--skip_video"):
        parser.add_argument(switch, action="store_true")
    parser.add_argument(
        "--path_type",
        type=str,
        choices=["interpolated", "ellipse"],
        default="interpolated",
    )
    parser.add_argument("--quiet", action="store_true")

    args = get_combined_args(parser)
    print("Rendering " + args.model_path)

    # Initialize system state (RNG) through the project helper.
    safe_state(args.quiet)

    render_sets(
        dataset=model.extract(args),
        iteration=args.iteration,
        pipeline=pipeline.extract(args),
        skip_train=args.skip_train,
        skip_test=args.skip_test,
        skip_video=args.skip_video,
        path_type=args.path_type,
    )
