#
# Copyright (C) 2023, Inria
# GRAPHDECO research group, https://team.inria.fr/graphdeco
# All rights reserved.
#
# This software is free for non-commercial, research and evaluation use
# under the terms of the LICENSE.md file.
#
# For inquiries contact  george.drettakis@inria.fr
#

import os
import os.path as osp
import sys
import torch
from tqdm import tqdm
import torchvision
from time import time
import imageio as iio
import numpy as np
import concurrent.futures
import yaml
import json
import matplotlib.cm as cm
from argparse import ArgumentParser, Namespace

from arguments import ModelParams, PipelineParams, OptimizationParams, get_combined_args
from scene import Scene, GaussianModel
from gaussian_renderer import render, query
from utils.general_utils import safe_state
from utils.image_utils import metric_vol, metric_proj


def testing(
    dataset: ModelParams,
    opt: OptimizationParams,
    pipeline: PipelineParams,
    iteration: int,
    skip_proj_train: bool,
    skip_proj_test: bool,
    skip_vol: bool,
):
    """Load a trained model and run the requested evaluations.

    Results are written below ``<dataset.model_path>/eval/iter_<loaded_iter>``,
    where ``loaded_iter`` is the iteration reported by the loaded ``Scene``.

    Args:
        dataset: Model parameters. ``source_path`` must contain
            ``meta_data.json``; ``model_path`` is the root of the output tree.
        opt: Optimization parameters. Only ``scale_min_bound`` and
            ``scale_max_bound`` are read here.
        pipeline: Pipeline parameters, forwarded to the evaluation routines.
        iteration: Forwarded to ``Scene`` as ``load_iteration``.
        skip_proj_train: When true, do not evaluate the training projections.
        skip_proj_test: When true, do not evaluate the test projections.
        skip_vol: When true, do not evaluate the reconstructed volume.
    """
    with torch.no_grad():
        with open(osp.join(dataset.source_path, "meta_data.json"), "r") as handle:
            meta_data = json.load(handle)
        scanner_cfg = meta_data["scanner"]

        # The Gaussian scale bounds are given relative to the smallest
        # per-axis quotient sVoxel / nVoxel.
        d_voxel = torch.tensor(scanner_cfg["sVoxel"]) / torch.tensor(
            scanner_cfg["nVoxel"]
        )
        min_d_voxel = float(d_voxel.min())
        scale_bound = [
            opt.scale_min_bound * min_d_voxel,
            opt.scale_max_bound * min_d_voxel,
        ]

        gaussians = GaussianModel(scale_bound)
        scene = Scene(
            dataset,
            gaussians,
            load_iteration=iteration,
            shuffle=False,
        )
        save_path = osp.join(
            dataset.model_path,
            "eval",
            "iter_{}".format(scene.loaded_iter),
        )

        # Evaluate projections (train first, then test). The camera getters
        # are only called for the splits that are actually evaluated.
        projection_jobs = (
            (skip_proj_train, "proj_train", scene.getTrainCameras),
            (skip_proj_test, "proj_test", scene.getTestCameras),
        )
        for skip, split_name, get_cameras in projection_jobs:
            if skip:
                continue
            evaluate_projection(
                save_path,
                split_name,
                get_cameras(),
                gaussians,
                pipeline,
                meta_data,
            )

        # Evaluate volume reconstruction
        if not skip_vol:
            evaluate_volume(
                save_path,
                "volume",
                meta_data,
                gaussians,
                pipeline,
                scene.vol_gt,
            )


def evaluate_volume(
    save_path,
    name,
    meta_data,
    gaussians: GaussianModel,
    pipeline: PipelineParams,
    vol_gt,
):
    """Query the predicted volume, score it against ``vol_gt`` and save results.

    Everything is written under ``<save_path>/<name>``: per-slice images in
    ``slices/``, the metrics in ``eval.yml`` and both volumes as ``.npy``.

    Args:
        save_path: Parent directory of the evaluation output.
        name: Sub-directory name, also used in the final log line.
        meta_data: Dict whose ``"scanner"`` entry supplies ``offOrigin``,
            ``nVoxel`` and ``sVoxel`` for the volume query.
        gaussians: Model that is queried.
        pipeline: Pipeline parameters forwarded to ``query``.
        vol_gt: Ground-truth volume (assumed to be a 3D tensor -- it is sliced
            along its last axis using ``shape[2]``; TODO confirm).
    """
    out_dir = osp.join(save_path, name)
    slice_dir = osp.join(out_dir, "slices")
    os.makedirs(slice_dir, exist_ok=True)
    scanner = meta_data["scanner"]

    # Only the query call itself is timed.
    tic = time()
    query_result = query(
        gaussians,
        scanner["offOrigin"],
        scanner["nVoxel"],
        scanner["sVoxel"],
        pipeline,
    )
    duration = time() - tic
    vol_pred = query_result["vol"].clip(0, 1)

    psnr_3d, _ = metric_vol(vol_gt, vol_pred, "psnr")
    ssim_3d, ssim_3d_axis = metric_vol(vol_gt, vol_pred, "ssim")

    def as_slices(volume):
        # One entry per index of axis 2, each with a new leading axis.
        return [volume[..., k][None] for k in range(volume.shape[2])]

    for volume, suffix in ((vol_gt, "_gt"), (vol_pred, "_render")):
        multithread_write(as_slices(volume), slice_dir, suffix)

    eval_dict = dict()
    eval_dict["psnr_3d"] = psnr_3d
    eval_dict["ssim_3d"] = ssim_3d
    eval_dict["ssim_3d_x"] = ssim_3d_axis[0]
    eval_dict["ssim_3d_y"] = ssim_3d_axis[1]
    eval_dict["ssim_3d_z"] = ssim_3d_axis[2]
    eval_dict["duration (sec)"] = duration

    with open(osp.join(out_dir, "eval.yml"), "w") as f:
        yaml.dump(eval_dict, f, default_flow_style=False, sort_keys=False)

    for file_name, volume in (("vol_gt.npy", vol_gt), ("vol_pred.npy", vol_pred)):
        np.save(osp.join(out_dir, file_name), volume.cpu().numpy())

    print("{} query complete! psnr_3d: {}, ssim_3d: {}".format(name, psnr_3d, ssim_3d))


def evaluate_projection(save_path, name, views, gaussians, pipeline, meta_data):
    """Render every view, save the images and score them against ground truth.

    Output goes to ``<save_path>/<name>``: rendered and ground-truth images in
    ``projections/`` and the metrics in ``eval.yml``. If ``eval.yml`` already
    exists the split is considered done and nothing is recomputed. An empty
    ``views`` is skipped as well, since there is nothing to render or score.

    Args:
        save_path: Parent directory of the evaluation output.
        name: Sub-directory name, also used in log messages.
        views: Sized iterable of cameras; each must expose ``original_image``.
        gaussians: Model passed to ``render``.
        pipeline: Pipeline parameters passed to ``render``.
        meta_data: Unused here; kept so the signature stays unchanged.
    """
    save_path = osp.join(save_path, name)
    proj_save_path = osp.join(save_path, "projections")

    # If already rendered, skip.
    if osp.exists(osp.join(save_path, "eval.yml")):
        print("{} in {} already rendered. Skip.".format(name, save_path))
        return
    # Without views there is nothing to time, and torch.concat rejects an
    # empty list, so bail out instead of crashing.
    if len(views) == 0:
        print("{} has no views to render. Skip.".format(name))
        return
    os.makedirs(proj_save_path, exist_ok=True)

    gt_list = []
    render_list = []
    time1 = time()
    for view in tqdm(views, desc="render {}".format(name)):
        rendering = render(view, gaussians, pipeline)["render"]
        gt = view.original_image[0:3, :, :]
        gt_list.append(gt)
        render_list.append(rendering)
    elapsed = time() - time1

    # The timed window contains every render, so the frame count is
    # len(views), not len(views) - 1.
    fps = len(views) / elapsed if elapsed > 0 else float("inf")
    multithread_write(gt_list, proj_save_path, "_gt")
    multithread_write(render_list, proj_save_path, "_render")

    # Stack along the first axis, then move that axis last for metric_proj.
    images = torch.concat(render_list, 0).permute(1, 2, 0)
    gt_images = torch.concat(gt_list, 0).permute(1, 2, 0)
    psnr_2d, psnr_2d_projs = metric_proj(gt_images, images, "psnr")
    ssim_2d, ssim_2d_projs = metric_proj(gt_images, images, "ssim")
    eval_dict = {
        "psnr_2d": psnr_2d,
        "ssim_2d": ssim_2d,
        "psnr_2d_projs": psnr_2d_projs,
        "ssim_2d_projs": ssim_2d_projs,
    }
    with open(osp.join(save_path, "eval.yml"), "w") as f:
        yaml.dump(eval_dict, f, default_flow_style=False, sort_keys=False)
    print(
        "{} render complete. psnr_2d: {}. ssim_2d: {}. fps: {}".format(
            name, eval_dict["psnr_2d"], eval_dict["ssim_2d"], fps
        )
    )


def multithread_write(image_list, path, suffix):
    """Save every image as PNG and NPY using a thread pool.

    Files are named ``<index:05d><suffix>.png`` / ``.npy`` inside ``path``,
    where ``index`` is the position of the image in ``image_list``. A write
    that fails in the pool is retried once in the calling thread; if it fails
    again a message is printed and the remaining images are still processed.

    Args:
        image_list: Sequence of image tensors. The ``.npy`` file stores
            element ``[0]`` of each image after conversion to NumPy.
        path: Existing directory that receives the files.
        suffix: Text inserted between the index and the file extension.
    """

    def write_image(image, count, path):
        # Returns (count, success) instead of raising so that one bad image
        # does not abort the whole batch.
        try:
            torchvision.utils.save_image(
                image, osp.join(path, "{0:05d}".format(count) + "{}.png".format(suffix))
            )
            np.save(
                osp.join(path, "{0:05d}".format(count) + "{}.npy".format(suffix)),
                image.cpu().numpy()[0],
            )
            return count, True
        except Exception:
            return count, False

    # Leaving the with-block waits for all submitted writes to finish.
    with concurrent.futures.ThreadPoolExecutor(max_workers=None) as executor:
        tasks = [
            executor.submit(write_image, image, index, path)
            for index, image in enumerate(image_list)
        ]

    for task in tasks:
        # The outcome lives in the future's result; the future object itself
        # never compares equal to False.
        index, succeeded = task.result()
        if succeeded:
            continue
        _, succeeded = write_image(image_list[index], index, path)
        if not succeeded:
            print(
                "Failed to write image {0:05d}{1} to {2}.".format(index, suffix, path)
            )


if __name__ == "__main__":
    # Command line interface: the three parameter groups register their own
    # options on the parser, followed by the script-specific switches.
    cli = ArgumentParser(description="Testing script parameters")
    model_params = ModelParams(cli, sentinel=True)
    optim_params = OptimizationParams(cli)
    pipe_params = PipelineParams(cli)

    cli.add_argument("--quiet", action="store_true")
    cli.add_argument("--iteration", default=-1, type=int)
    for skip_flag in ("--skip_proj_train", "--skip_proj_test", "--skip_vol"):
        cli.add_argument(skip_flag, action="store_true", default=False)
    cli_args = get_combined_args(cli)

    safe_state(cli_args.quiet)

    # Each evaluation stage runs unless its --skip_* switch was given.
    testing(
        dataset=model_params.extract(cli_args),
        opt=optim_params.extract(cli_args),
        pipeline=pipe_params.extract(cli_args),
        iteration=cli_args.iteration,
        skip_proj_train=cli_args.skip_proj_train,
        skip_proj_test=cli_args.skip_proj_test,
        skip_vol=cli_args.skip_vol,
    )
