"""
SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
SPDX-License-Identifier: LicenseRef-NvidiaProprietary
NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
property and proprietary rights in and to this material, related
documentation and any modifications thereto. Any use, reproduction,
disclosure or distribution of this material and related documentation
without an express license agreement from NVIDIA CORPORATION or
its affiliates is strictly prohibited.
"""

import argparse
import os
import pickle
import sys
from pathlib import Path
import re
from typing import Tuple, Any, Optional
import time
import shutil

import numpy as np
import torch
from easydict import EasyDict as edict
from torch import nn
from torch.utils.data import Dataset
from tqdm import tqdm
from skimage.metrics import structural_similarity
import lpips
from collections import OrderedDict
import torch.nn.functional as F
try:
    import imageio.v2 as imageio
except:
    import imageio
import pandas as pd

sys.path.append(".")
from lib import utils, temporalpoints
from run import render_viewpoints

# Shared LPIPS networks, built once at import time and used by error_lpips().
# Both are moved to CUDA here, so importing this module requires a GPU.
# AlexNet backbone: the original note says it gives the best forward scores.
lpips_metric_alex = lpips.LPIPS(net='alex').to('cuda')
# VGG backbone: closer to a "traditional" perceptual loss when used for
# optimization; this is the LPIPS variant used in other papers.
lpips_metric_vgg = lpips.LPIPS(net='vgg').to('cuda')


def resize_image_area(im: np.array, size: np.array) -> np.array:
    """
    Resizes an (H, W, C) image using area (box-average) interpolation.

    Args:
        im: image array laid out as height x width x channels.
        size: target size ordered as (width, height).

    Returns:
        The resized image as a (new_height, new_width, C) numpy array.
    """
    # F.interpolate wants the spatial size as (height, width).
    target_hw = tuple(size[::-1])
    batched_chw = torch.from_numpy(im).permute(2, 0, 1).unsqueeze(0)
    resized = F.interpolate(batched_chw, target_hw, mode='area')
    return resized.squeeze(0).permute(1, 2, 0).numpy()


def error_mse(im_pred: np.ndarray, im_gt: np.ndarray, mask: np.ndarray = None):
    """
    Computes the mean squared error over the RGB channels. Optionally applies mask.

    Args:
        im_pred: predicted image whose last axis holds the channels; only the
            first three channels are used.
        im_gt: ground-truth image with the same layout as ``im_pred``.
        mask: optional per-pixel mask with H*W elements (any shape). Only
            pixels where it is True / non-zero contribute to the error.

    Returns:
        Mean of the squared per-channel differences over the selected pixels.
        A mask that selects no pixel yields NaN (mean of an empty array).
    """
    # Linearize to (num_pixels, 3).
    im_pred = im_pred[..., :3].reshape(-1, 3)
    im_gt = im_gt[..., :3].reshape(-1, 3)

    if mask is not None:
        # Coerce to bool so the mask acts as a per-pixel selector. An integer
        # 0/1 mask used directly as an index would instead be interpreted as a
        # list of row indices (repeatedly selecting pixels 0 and 1).
        mask = np.asarray(mask, dtype=bool).reshape(-1)
        im_pred = im_pred[mask, :]
        im_gt = im_gt[mask, :]

    mse = (im_pred - im_gt) ** 2
    return mse.mean()


def error_psnr(im_pred: np.array, im_gt: np.array, mask: np.array = None):
    """
    Peak signal-to-noise ratio in dB, optionally restricted to a mask.

    Images are assumed to hold floats in [0, 1], so the peak value is 1.0.
    The MSE comes from error_mse(); an MSE of 0 (identical images) gives +inf.
    """
    peak = 1.0
    mse = error_mse(im_pred, im_gt, mask)
    # PSNR = 20*log10(MAX) - 10*log10(MSE)
    # https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio
    return 20 * np.log10(peak) - 10 * np.log10(mse)


def error_ssim(im_pred: np.array, im_gt: np.array, mask: np.array = None):
    """
    Structural similarity (SSIM) between two images, optionally masked.

    Only the first three channels are compared, with ``data_range=1.0``
    (values assumed to lie in [0, 1]). When a mask is given, pixels where it
    is 0/False are zeroed in both images and SSIM is still computed over the
    whole frame.
    """
    pred_rgb, gt_rgb = im_pred[..., :3], im_gt[..., :3]

    if mask is None:
        return structural_similarity(pred_rgb, gt_rgb, data_range=1.0, channel_axis=-1)

    # Broadcast the H x W mask over the three color channels.
    height, width = pred_rgb.shape[0], pred_rgb.shape[1]
    mask_rgb = mask.reshape(height, width, 1).repeat(3, axis=2)
    return structural_similarity(pred_rgb * mask_rgb, gt_rgb * mask_rgb, data_range=1.0, channel_axis=-1)


def error_lpips(im_pred: np.ndarray, im_gt: np.ndarray, mask: np.ndarray = None, flavor='alex'):
    """
    Computes the LPIPS perceptual distance. Optionally applies mask.

    Args:
        im_pred: predicted (H, W, C) image with values in [0, 1]; only the
            first three channels are used.
        im_gt: ground-truth image with the same layout as ``im_pred``.
        mask: optional array with H*W elements; pixels where it is 0/False
            are zeroed (black) in both images before scoring.
        flavor: 'alex' selects the AlexNet-based metric; any other value
            selects the VGG-based one.

    Returns:
        The LPIPS distance as a Python float (lower means more similar).
    """
    # RGB only.
    im_pred = im_pred[..., :3]
    im_gt = im_gt[..., :3]

    if mask is not None:
        mask = mask.reshape(im_pred.shape[0], im_pred.shape[1], 1).repeat(3, axis=2)
        im_pred = im_pred * mask
        im_gt = im_gt * mask

    if np.max(im_pred.shape) > 5000:
        # Too large to score at full resolution (SCU dataset): halve both sides.
        new_size = np.array([im_pred.shape[1], im_pred.shape[0]], int) // 2
        im_pred = resize_image_area(im_pred, new_size)
        im_gt = resize_image_area(im_gt, new_size)

    metric = lpips_metric_alex if flavor == 'alex' else lpips_metric_vgg
    # Feed inputs to whatever device the metric network lives on.
    device = next(metric.parameters()).device

    def to_lpips_input(im):
        # HWC numpy in [0, 1] -> 1xCxHxW float32 tensor in [-1, 1].
        # The float32 cast avoids a dtype mismatch with the network weights
        # when float64 images are passed in.
        tensor = torch.from_numpy(np.ascontiguousarray(im)).permute(2, 0, 1)[None, ...]
        return tensor.to(device=device, dtype=torch.float32) * 2 - 1

    # Pure evaluation: no autograd graph is needed even when called standalone.
    with torch.no_grad():
        loss = metric(to_lpips_input(im_pred), to_lpips_input(im_gt))

    return loss.item()

def error_iou(mask_pred: np.ndarray, mask_gt: np.ndarray):
    """
    Measures area of intersection over union of two binary masks.

    Args:
        mask_pred: predicted mask (boolean or 0/1 values).
        mask_gt: reference mask, broadcast-compatible with ``mask_pred``.

    Returns:
        |pred AND gt| / |pred OR gt|. If both masks are empty the union is
        zero; the masks then agree exactly, so 1.0 is returned instead of the
        NaN (and RuntimeWarning) that a 0/0 division would produce.
    """
    intersection = np.logical_and(mask_pred, mask_gt).sum()
    union = np.logical_or(mask_pred, mask_gt).sum()
    if union == 0:
        return 1.0
    return intersection / union

def imwritef(filename, im):
    """
    Saves a float image (values expected in [0, 1]) as an 8-bit file.

    Values are clipped to [0, 1], scaled by 255 and truncated to uint8
    before being handed to imageio.
    """
    as_uint8 = (np.clip(im, 0, 1) * 255).astype(np.uint8)
    imageio.imwrite(filename, as_uint8)

@torch.no_grad()
def compute_metrics(render_poses, HW, Ks, test_times, render_viewpoints_kwargs, imgs, output_path, iteration, frame_index, view_index):
    """
    Renders a single (frame, view) pair and scores it against the ground truth.

    Args:
        render_poses, HW, Ks, test_times: pose, image size, intrinsics and
            timestamp forwarded to ``render_viewpoints`` (the caller in this
            file passes each with a leading batch dimension of 1).
        render_viewpoints_kwargs: extra keyword arguments for
            ``render_viewpoints`` (e.g. the model).
        imgs: ground-truth image as a torch tensor laid out (H, W, C);
            presumably floats in [0, 1] - TODO confirm.
        output_path: pathlib.Path directory for debug PNGs, or None to skip.
        iteration, frame_index, view_index: identifiers copied into the
            result row and used in the output file names.

    Returns:
        OrderedDict with the render time, image size, PSNR/SSIM/LPIPS on the
        full image and restricted to ``mask_gt``, and the mask IoU.
    """
    # Render the requested view and record its wall-clock time.
    tic = time.time()
    rgb, disp, _weight, _ = render_viewpoints(
        render_poses=render_poses,
        HW=HW,
        Ks=Ks,
        test_times=test_times,
        verbose=False,
        **render_viewpoints_kwargs
    )
    render_seconds = time.time() - tic

    pred_rgb = rgb[0]
    # NOTE(review): whether `disp` is disparity or accumulated opacity depends
    # on render_viewpoints; the 0.5 threshold assumes the latter - confirm.
    pred_mask = disp[0] > 0.5
    # Pixels whose channels are all exactly 1.0 (pure white).
    # NOTE(review): with a white background this selects the *background*, so
    # the *_mask_gt metrics and mask_iou may not measure the foreground as
    # their names suggest - confirm intent.
    gt_mask = (imgs == 1.).all(dim=-1).cpu().numpy()
    gt_rgb = imgs.cpu().numpy()

    if output_path is not None:
        output_path.mkdir(mode=0o777, parents=True, exist_ok=True)
        stem = f'f{frame_index:06d}_v{view_index:03d}'
        to_save = (
            ('color_gt', gt_rgb),
            ('mask_gt', gt_mask),
            ('color_pred', pred_rgb),
            ('mask_pred', pred_mask),
        )
        for suffix, image in to_save:
            imwritef(output_path / f'{stem}_{suffix}.png', image)

    return OrderedDict([
        ('iteration', iteration),
        ('frame_id', frame_index),
        ('view_id', view_index),
        ('render_time', render_seconds),
        ('width', gt_rgb.shape[1]),
        ('height', gt_rgb.shape[0]),
        ('psnr', error_psnr(pred_rgb, gt_rgb)),
        ('ssim', error_ssim(pred_rgb, gt_rgb)),
        ('lpips', error_lpips(pred_rgb, gt_rgb, flavor='alex')),
        ('lpips_vgg', error_lpips(pred_rgb, gt_rgb, flavor='vgg')),
        ('psnr_mask_gt', error_psnr(pred_rgb, gt_rgb, gt_mask)),
        ('ssim_mask_gt', error_ssim(pred_rgb, gt_rgb, gt_mask)),
        ('lpips_mask_gt', error_lpips(pred_rgb, gt_rgb, gt_mask, flavor='alex')),
        ('lpips_vgg_mask_gt', error_lpips(pred_rgb, gt_rgb, gt_mask, flavor='vgg')),
        ('mask_iou', error_iou(pred_mask, gt_mask)),
    ])



def _scan_checkpoint_iterations(ckpt_dir) -> np.ndarray:
    """Return the sorted iteration numbers N of every ``temporalpoints_<N>.tar`` file in ``ckpt_dir``."""
    # Match exactly the file name that test() later loads, so unrelated
    # ``*_<N>.tar`` files (or names like ``*.tar.bak``) are never selected.
    pattern = re.compile(r'temporalpoints_(\d+)\.tar')
    found = []
    for path in Path(ckpt_dir).iterdir():
        match = pattern.fullmatch(path.name)
        if match:
            found.append(int(match.group(1)))
    return np.array(sorted(found), dtype=int)


def test(data_dict: dict, render_viewpoints_kwargs: dict, save_dir: str, ckpt_dir: str, mode: str = "test") -> None:
    """
    Evaluates the latest checkpoint in ``ckpt_dir`` on the test split.

    Up to 10 frames, evenly spaced over the test sequence, are rendered from
    every test camera via compute_metrics(). Per-image metrics are written to
    ``<save_dir>/test/stats.xlsx`` and summarized on stdout; images go under
    ``<save_dir>/test/frames/iter_<N>/``.

    WARNING: ``<save_dir>/test`` is deleted and recreated on every call.

    Args:
        data_dict: dataset dictionary; reads 'img_to_cam', 'i_test', 'images',
            'times', 'poses', 'Ks' and 'HW'.
        render_viewpoints_kwargs: keyword arguments forwarded to
            ``render_viewpoints``. A ``'model'`` entry is added per checkpoint
            on a copy; the caller's dict is not modified.
        save_dir: output directory.
        ckpt_dir: directory containing ``temporalpoints_<N>.tar`` checkpoints.
        mode: "test" or "novel_pose"; validated but currently has no other effect.

    Raises:
        FileNotFoundError: if ``ckpt_dir`` contains no matching checkpoint.
    """
    assert mode in ["test", "novel_pose"]

    # 'img_to_cam' has one entry per image (it is indexed with the same image
    # index as 'images' below), so it yields the number of test images without
    # copying the image data itself.
    test_img_to_cam = data_dict['img_to_cam'][data_dict['i_test']]
    test_cams = np.unique(test_img_to_cam)
    num_frames = len(test_img_to_cam) // len(test_cams)
    num_test_frames = min(10, num_frames)
    frames_for_test = np.linspace(0, num_frames - 1, num_test_frames).round().astype(int)
    # NOTE(review): striding by the camera count assumes test images are ordered
    # frame-major with one image per test camera per frame - confirm.
    view_test_times = data_dict['times'][data_dict['i_test']][::len(test_cams)]
    test_times = view_test_times[frames_for_test]

    res_dir = Path(save_dir) / 'test'
    shutil.rmtree(res_dir, ignore_errors=True)  # WARNING!!!! WILL DELETE FOLDER!!!
    res_dir.mkdir(parents=True, exist_ok=True)
    save_frames_dir = res_dir / 'frames'

    iterations = _scan_checkpoint_iterations(ckpt_dir)
    if len(iterations) == 0:
        raise FileNotFoundError(f'No temporalpoints_<iteration>.tar checkpoints found in {ckpt_dir}.')
    print(f'Loaded {len(iterations)} iterations from {iterations[0]} to {iterations[-1]}.')
    # Only the most recent checkpoint is evaluated (and has its images saved).
    iterations = iterations[-1:]
    print(f'Evaluation on {len(iterations)} iterations from {iterations[0]} to {iterations[-1]}.')
    print(f'Will save imgs from iterations {iterations}...')

    results = OrderedDict()

    for iteration in tqdm(iterations, desc='Iterations', position=0, leave=True):
        ckpt_path = os.path.join(ckpt_dir, f'temporalpoints_{iteration}.tar')
        tqdm.write(f'Loading snapshot from {ckpt_path}...')
        model = utils.load_model(temporalpoints.TemporalPoints, ckpt_path).to('cuda')
        # Copy rather than mutate the caller's kwargs dict.
        iter_kwargs = {**render_viewpoints_kwargs, 'model': model}
        iter_save_img_dir = save_frames_dir / f'iter_{iteration:08d}'

        # `t` (not `time`) to avoid shadowing the time module.
        for t in tqdm(test_times, desc='Frames', position=1, leave=False):
            video_frame_idx = torch.where(view_test_times == t)[0].item()
            # Loop-invariant across cameras.
            time_mask = (data_dict['times'] == t).cpu()
            for cam_idx in test_cams:
                fra_idx = torch.where(time_mask & (data_dict['img_to_cam'] == cam_idx))[0].item()
                cam = data_dict['img_to_cam'][fra_idx]

                row = compute_metrics(
                    data_dict['poses'][cam][None, ...],
                    data_dict['HW'][fra_idx][None, ...],
                    data_dict['Ks'][cam][None, ...],
                    data_dict['times'][fra_idx][None, ...],
                    iter_kwargs,
                    data_dict['images'][fra_idx],
                    iter_save_img_dir,
                    iteration,
                    video_frame_idx,
                    cam_idx,
                )

                for k, v in row.items():
                    results.setdefault(k, []).append(v)

                # '.3f' prints 3 decimals ('3f' only set a minimum width of 3).
                tqdm.write(f'Iter = {iteration} | Frame = {video_frame_idx} | View = {cam_idx} | PSNR = {row["psnr"]:.3f}')

    stats = pd.DataFrame(results)
    stats.to_excel(res_dir / 'stats.xlsx')
    print(stats.describe())
       


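# NOTE: the disabled entry point below is stale: test() now takes
# (data_dict, render_viewpoints_kwargs, save_dir, ckpt_dir, mode) instead of
# config paths, so these calls must be updated before re-enabling it.
# if __name__ == "__main__":
    # Evaluate novel-view and novel-pose reconstruction:
    # novel view -> learned pose, new camera.
    # novel pose -> novel pose, all cameras. Requires SMPL regression.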
    # parser = argparse.ArgumentParser(description='Save reconstructed images')
    # parser.add_argument('--exp_name', action='append', required=True)
    # args = parser.parse_args()
    # exp_names = args.exp_name

    # default_config = "confs/default.yml"
    # for exp_name in exp_names:
    #     config_path = f"confs/{exp_name}.yml"
    #     test(config_path, default_config, mode="test")  # novel view
    #     #test(config_path, default_config, mode="novel_pose")
