#
# Copyright (C) 2023, Inria
# GRAPHDECO research group, https://team.inria.fr/graphdeco
# All rights reserved.
#
# This software is free for non-commercial, research and evaluation use
# under the terms of the LICENSE.md file.
#
# For inquiries contact  george.drettakis@inria.fr
#

import os
import sys
import time
import torch
import pyexr

from utils.loss_utils import ssim
from gaussian_renderer import render
from scene import Scene, GaussianModel

from tqdm import tqdm
from utils.image_utils import psnr
from argparse import ArgumentParser
from arguments import ModelParams, PipelineParams, OptimizationParams
from torchvision.utils import save_image

from omegaconf import OmegaConf
from omegaconf.dictconfig import DictConfig
from lpipsPyTorch import lpips


def tone_map(image: torch.Tensor, mu=5000.0):
    """Apply mu-law compression ``log(1 + mu * x) / log(1 + mu)`` element-wise.

    With this normalizer an input of 0 maps to 0 and an input of 1 maps to 1;
    values in between are boosted logarithmically (larger ``mu`` = stronger).

    Args:
        image: Tensor of (presumably non-negative) intensities.
        mu: Compression strength.

    Returns:
        Tensor of the same shape as ``image``.
    """
    # Normalizer log(1 + mu), built as a tensor on the same device as the input.
    normalizer = torch.log(torch.tensor(mu + 1, device=image.device))
    compressed = torch.log(1 + mu * image)
    return compressed / normalizer


def testing(dataset, opt, pipe, checkpoint: str,
             gaussian_dim, time_duration, num_pts, num_pts_ratio, rot_4d, force_sh_3d, test: bool = False):
    """Rebuild the Gaussian model from ``checkpoint`` and evaluate or render it.

    Args:
        dataset: Extracted model parameters (reads ``frame_ratio``, ``sh_degree``,
            ``white_background``; also forwarded to the evaluation routine as ``args``).
        opt: Optimization parameters, used for ``training_setup`` and ``restore``.
        pipe: Pipeline parameters (reads ``eval_shfs_4d``); forwarded to the renderer.
        checkpoint: Path of a file holding ``(model_params, iteration, hist_lum)``.
        gaussian_dim, rot_4d, force_sh_3d: Forwarded to ``GaussianModel``.
        time_duration: Two-element ``[start, end]`` time window.
        num_pts, num_pts_ratio: Forwarded to ``Scene``.
        test: If true, run ``testing_report`` (metrics); otherwise run ``inference``
            (write rendered images).
    """
    ratio = dataset.frame_ratio
    if ratio > 1:
        # Shrink the time window by the dataset's frame ratio.
        time_duration = [time_duration[0] / ratio, time_duration[1] / ratio]

    sh_degree_t = 2 if pipe.eval_shfs_4d else 0
    gaussians = GaussianModel(dataset.sh_degree, gaussian_dim=gaussian_dim, time_duration=time_duration,
                              rot_4d=rot_4d, force_sh_3d=force_sh_3d, sh_degree_t=sh_degree_t)
    scene = Scene(dataset, gaussians, num_pts=num_pts, num_pts_ratio=num_pts_ratio, time_duration=time_duration)
    # Set up the optimizer before restoring; restore() is also given ``opt``
    # (presumably to reload optimizer state — confirm in GaussianModel).
    gaussians.training_setup(opt)

    # NOTE(review): torch.load may unpickle arbitrary objects; only load trusted checkpoints.
    model_params, iteration, hist_lum = torch.load(checkpoint)
    gaussians.restore(model_params, opt)

    fill = 1 if dataset.white_background else 0
    background = torch.tensor([fill, fill, fill], dtype=torch.float32, device="cuda")

    runner = testing_report if test else inference
    with torch.no_grad():
        runner(scene, render, (pipe, background), hist_lum, dataset, iteration=iteration)


def testing_report(scene: Scene, renderFunc, renderArgs, hist_luminance=None, args=None, iteration=None):
    """Render evaluation views and print averaged LDR (and, if available, HDR) metrics.

    Two splits are evaluated: five fixed training views (indices 5, 10, ..., 25,
    wrapped modulo the number of training cameras) and every test view. For each
    split, mean PSNR / SSIM / LPIPS of the clamped LDR render against the ground
    truth are printed together with the render throughput (FPS). HDR metrics are
    averaged only over views whose camera carries an ``image_hdr`` ground truth,
    and are printed only if at least one such view exists.

    Args:
        scene: Provides ``getTrainCameras()``, ``getTestCameras()`` and ``gaussians``;
            each camera entry is unpacked as ``(gt_image, viewpoint)``.
        renderFunc: Called as ``renderFunc(viewpoint, gaussians, *renderArgs,
            hist_luminance=..., train=False, iteration=...)``; its result must
            contain ``"render"`` and, for HDR views, ``"render_hdr"``.
        renderArgs: Extra positional arguments forwarded to ``renderFunc``.
        hist_luminance: Forwarded to ``renderFunc``.
        args: Object whose ``ckpt_dir`` attribute selects the output directory.
        iteration: Checkpoint iteration, forwarded to ``renderFunc`` and logged.
    """
    # NOTE: directory for optional per-view LDR PNG / HDR EXR dumps (currently disabled).
    os.makedirs(os.path.join(args.ckpt_dir, "images"), exist_ok=True)

    train_cameras = scene.getTrainCameras()
    num_train = len(train_cameras)
    test_cameras = scene.getTestCameras()
    validation_configs = (
        {'name': 'train',
         # Guard the modulo: an empty training set would raise ZeroDivisionError.
         'cameras': [train_cameras[idx % num_train] for idx in range(5, 30, 5)] if num_train > 0 else []},
        {'name': 'test', 'cameras': [test_cameras[idx] for idx in range(len(test_cameras))]},
    )

    for config in validation_configs:
        cameras = config['cameras']
        if not cameras:
            continue
        num_views = len(cameras)

        psnr_test = ssim_test = lpips_test = 0.0
        psnr_hdr_test = ssim_hdr_test = lpips_hdr_test = 0.0
        num_hdr_views = 0  # views that actually have an HDR ground truth
        render_time = 0.0

        for gt_image, viewpoint in tqdm(cameras):
            gt_image = gt_image.cuda()
            viewpoint = viewpoint.cuda()

            # CUDA kernels run asynchronously: synchronize on both sides so the
            # interval covers the finished render, not just kernel launches.
            torch.cuda.synchronize()
            start = time.perf_counter()
            render_pkg = renderFunc(viewpoint, scene.gaussians, *renderArgs,
                                    hist_luminance=hist_luminance, train=False, iteration=iteration)
            torch.cuda.synchronize()
            render_time += time.perf_counter() - start

            image = torch.clamp(render_pkg["render"], 0.0, 1.0)
            # Accumulate as Python floats so the printed metrics are plain numbers.
            psnr_test += psnr(image, gt_image).mean().item()
            ssim_test += ssim(image, gt_image).mean().item()
            lpips_test += lpips(image[None], gt_image[None]).item()

            gt_image_hdr = viewpoint.image_hdr
            if gt_image_hdr is None:
                continue

            # Presumably a NumPy array (it goes through torch.from_numpy);
            # TODO confirm its layout matches the render's.
            gt_image_hdr = tone_map(torch.from_numpy(gt_image_hdr).cuda())

            # Normalize the HDR render by its peak (when positive) before tone mapping.
            image_hdr = render_pkg["render_hdr"]
            peak = image_hdr.max()
            if peak > 0:
                image_hdr = torch.clamp(image_hdr / peak, 0.0, 1.0)
            image_hdr = tone_map(image_hdr)

            psnr_hdr_test += psnr(image_hdr, gt_image_hdr).mean().item()
            ssim_hdr_test += ssim(image_hdr, gt_image_hdr).mean().item()
            lpips_hdr_test += lpips(image_hdr[None], gt_image_hdr[None]).item()
            num_hdr_views += 1

        psnr_test /= num_views
        ssim_test /= num_views
        lpips_test /= num_views
        fps = num_views / render_time if render_time > 0 else float("inf")

        print("\n[ITER {}] Evaluating LDR {}: PSNR {} SSIM {} LPIPS {} FPS {:.2f}".format(
            iteration, config['name'], psnr_test, ssim_test, lpips_test, fps))

        if num_hdr_views > 0:
            # Average over the views that had HDR ground truth, not over all views.
            psnr_hdr_test /= num_hdr_views
            ssim_hdr_test /= num_hdr_views
            lpips_hdr_test /= num_hdr_views
            print("[ITER {}] Evaluating HDR {}: PSNR {} SSIM {} LPIPS {}".format(
                iteration, config['name'], psnr_hdr_test, ssim_hdr_test, lpips_hdr_test))

    torch.cuda.empty_cache()


def inference(scene: Scene, renderFunc, renderArgs, hist_luminance=None, args=None, iteration=None):
    """Render every train and test view at ``args.exp_time`` and save LDR PNGs.

    Images are written to
    ``<args.ckpt_dir>/inference/ldr_<iteration>_<image_name>_<exp_time>.png``.
    Each camera entry is unpacked as ``(gt_image, viewpoint)``; the ground-truth
    image is not used. The viewpoint's ``exp_time`` is overwritten with
    ``args.exp_time`` before rendering.

    Args:
        scene: Provides ``getTrainCameras()``, ``getTestCameras()`` and ``gaussians``.
        renderFunc: Called as ``renderFunc(viewpoint, gaussians, *renderArgs,
            hist_luminance=..., train=False, iteration=...)``; must return a
            mapping with a ``"render"`` entry.
        renderArgs: Extra positional arguments forwarded to ``renderFunc``.
        hist_luminance: Forwarded to ``renderFunc``.
        args: Object providing ``ckpt_dir`` and ``exp_time``.
        iteration: Checkpoint iteration; forwarded and used in file names.
    """
    out_dir = os.path.join(args.ckpt_dir, "inference")
    os.makedirs(out_dir, exist_ok=True)

    splits = (
        ('train', scene.getTrainCameras()),
        ('test', scene.getTestCameras()),
    )

    for _split_name, cameras in splits:
        if not (cameras and len(cameras) > 0):
            continue

        for _gt, view in tqdm(cameras):
            view.exp_time = args.exp_time
            view = view.cuda()

            out = renderFunc(view, scene.gaussians, *renderArgs,
                             hist_luminance=hist_luminance, train=False, iteration=iteration)
            ldr = torch.clamp(out["render"], 0.0, 1.0)

            file_name = "ldr_{}_{}_{}.png".format(iteration, view.image_name, args.exp_time)
            save_image(ldr, os.path.join(out_dir, file_name))

        torch.cuda.empty_cache()


def recursive_merge(key, host, target=None):
    """Copy the config entry ``host[key]`` onto the matching attribute of ``target``.

    Nested ``DictConfig`` sections are walked recursively and flattened: each leaf
    is matched by its own key name only (parent section names are ignored), so
    leaves with the same name in different sections overwrite each other in
    iteration order. Leaves with no matching attribute on ``target`` are reported
    and skipped. Values from the config replace whatever ``target`` already holds,
    including values given on the command line.

    Args:
        key: Key of the entry in ``host`` to merge.
        host: Mapping holding ``key`` (an OmegaConf ``DictConfig`` at the top level).
        target: Object receiving the values. Defaults to the module-level ``args``
            namespace created by the script entry point, which is what the
            two-argument form always used.
    """
    value = host[key]
    if isinstance(value, DictConfig):
        for sub_key in value.keys():
            recursive_merge(sub_key, value, target)
        return

    # Resolve the default lazily so the global ``args`` is only required when a
    # leaf is actually merged.
    dest = args if target is None else target
    if hasattr(dest, key):
        setattr(dest, key, value)
    else:
        print(f"Key {key} not found in args, skipping merge.")
        

if __name__ == "__main__":
    # Set up command line argument parser
    parser = ArgumentParser(description="Testing script parameters")
    lp = ModelParams(parser)
    op = OptimizationParams(parser)
    pp = PipelineParams(parser)
    parser.add_argument("--config", type=str,
                        help="OmegaConf file whose keys override matching arguments")
    parser.add_argument("--ckpt_dir", type=str,
                        help="directory containing chkpnt_best.pth (may also be set in --config)")
    parser.add_argument("--test", action="store_true", default=False)
    parser.add_argument("--exp_time", type=float, default=1.0)
    parser.add_argument("--gaussian_dim", type=int, default=4)
    parser.add_argument("--time_duration", nargs=2, type=float, default=[0.0, 1.0])
    parser.add_argument('--num_pts', type=int, default=100_000)
    parser.add_argument('--num_pts_ratio', type=float, default=1.0)
    parser.add_argument("--rot_4d", action="store_true")
    parser.add_argument("--force_sh_3d", action="store_true")

    args = parser.parse_args(sys.argv[1:])

    # Merge config entries onto ``args`` (recursive_merge writes into this
    # module-level ``args``); config values override command-line values.
    # Without --config the command-line/default values are used unchanged.
    if args.config is not None:
        cfg = OmegaConf.load(args.config)
        for k in cfg.keys():
            recursive_merge(k, cfg)

    # Printed after the merge so a config-provided exposure time is reported.
    print(f"Exposure time: {args.exp_time}")

    if args.ckpt_dir is None:
        parser.error("--ckpt_dir must be given on the command line or in the config")

    print("Loading from " + args.ckpt_dir)

    lp = lp.extract(args)
    op = op.extract(args)
    pp = pp.extract(args)

    # Script-level options consumed by inference()/testing_report() via ``dataset``.
    lp.exp_time = args.exp_time
    lp.ckpt_dir = args.ckpt_dir

    # NOTE(review): the original comment mentions an alternative "_hdr.pth"
    # checkpoint — confirm which file should be loaded.
    ckpt_path = os.path.join(args.ckpt_dir, "chkpnt_best.pth")

    testing(lp, op, pp, ckpt_path,
            args.gaussian_dim, args.time_duration, args.num_pts, args.num_pts_ratio, args.rot_4d, args.force_sh_3d,
            test=args.test)

    # All done
    print("\nTesting complete.\n")
