#
# Copyright (C) 2023, Inria
# GRAPHDECO research group, https://team.inria.fr/graphdeco
# All rights reserved.
#
# This software is free for non-commercial, research and evaluation use
# under the terms of the LICENSE.md file.
#
# For inquiries contact  george.drettakis@inria.fr
#

from pathlib import Path
import os
from PIL import Image
import torch
import torchvision.transforms.functional as tf
from utils.loss_utils import ssim
from lpipsPyTorch import lpips, LPIPS
import json
from tqdm import tqdm
from utils.image_utils import psnr
from argparse import ArgumentParser
from pytorch_msssim import ms_ssim
from multiprocessing.pool import ThreadPool
def readImages(renders_dir, gt_dir):
    """Load every render together with the ground-truth image of the same name.

    Each entry returned by ``os.listdir(renders_dir)`` is opened from both
    directories.  Both arguments are joined with the file name through the
    ``/`` operator, so they must be ``pathlib.Path``-like objects.

    Returns:
        A ``(renders, gts, image_names)`` tuple of three parallel lists, in
        the order ``os.listdir`` produced the names.  Each tensor has a
        leading batch dimension of one, is cut to at most its first three
        channels, and has been moved to the current CUDA device.
    """
    def _to_gpu_batch(picture):
        # to_tensor -> prepend batch axis -> keep channels 0..2 -> GPU.
        # Presumably the channel cut discards an alpha channel when present.
        batched = tf.to_tensor(picture).unsqueeze(0)
        return batched[:, :3, :, :].cuda()

    rendered_batches, reference_batches, names = [], [], []
    for entry in os.listdir(renders_dir):
        # Open both files before converting either one, so a missing
        # ground-truth image is reported before any tensor work happens.
        rendered_picture = Image.open(renders_dir / entry)
        reference_picture = Image.open(gt_dir / entry)
        rendered_batches.append(_to_gpu_batch(rendered_picture))
        reference_batches.append(_to_gpu_batch(reference_picture))
        names.append(entry)
    return rendered_batches, reference_batches, names

def readImages2(renders_dir, gt_dir):
    """Load renders and the ground-truth images their file names point to.

    A render file name must split on ``'_'`` into exactly three parts
    (``<cam>_<id>_<rest>``); the matching ground-truth file is
    ``<cam>_<id>.png`` inside ``gt_dir``.  Any other number of parts raises
    ``ValueError`` from the tuple unpacking.  Both directory arguments are
    joined with file names through ``/`` and so must be ``pathlib.Path``-like.

    Returns:
        A ``(renders, gts, image_names)`` tuple of three parallel lists in
        ``os.listdir`` order.  ``image_names`` holds the *ground-truth* file
        names, not the render file names.  Each tensor has a leading batch
        dimension of one, at most three channels, and lives on the current
        CUDA device.
    """
    def _batched_rgb_on_gpu(picture):
        # Convert to a tensor, add the batch axis, keep channels 0..2, upload.
        return tf.to_tensor(picture).unsqueeze(0)[:, :3, :, :].cuda()

    rendered_batches = []
    reference_batches = []
    matched_names = []
    for render_name in os.listdir(renders_dir):
        rendered_picture = Image.open(renders_dir / render_name)
        # Third component (e.g. a suffix plus extension) is not needed.
        camera, frame, _ = render_name.split('_')
        reference_name = f'{camera}_{frame}.png'
        reference_picture = Image.open(gt_dir / reference_name)
        rendered_batches.append(_batched_rgb_on_gpu(rendered_picture))
        reference_batches.append(_batched_rgb_on_gpu(reference_picture))
        matched_names.append(reference_name)
    return rendered_batches, reference_batches, matched_names

def _load_rgb_batch(image_path):
    """Read an image file as a batched CUDA tensor with at most three channels."""
    # The context manager closes the file deterministically; to_tensor has
    # already pulled the pixel data into the tensor by the time we leave it.
    with Image.open(image_path) as image:
        tensor = tf.to_tensor(image)
    return tensor.unsqueeze(0)[:, :3, :, :].cuda()


def _mean_of(values):
    """Average a list of per-image metric values into a Python float."""
    # NOTE(review): relies on every entry being a one-element tensor (or a
    # number), which is what torch.tensor() needs to build the vector.
    return torch.tensor(values).mean().item()


def evaluate(model_paths, target_method=None):
    """Compute SSIM, PSNR and LPIPS-vgg for rendered test images.

    Expected layout for every entry of ``model_paths``::

        <scene_dir>/test/<method>/gt/<cam>_<id>.png
        <scene_dir>/test/<method>/renders/<cam>_<id>_<rest>

    ``renders_opt`` is used when ``renders`` does not exist.  Render file
    names must split on ``'_'`` into exactly three parts; the first part is
    used as the camera key for the per-camera averages.

    For each evaluated method two files are written into its directory:
    ``results.json`` (averages over all images) and ``per_cam.json``
    (averages grouped by camera).  The overall averages are also printed.

    Args:
        model_paths: Iterable of scene directories.
        target_method: Method directory names to evaluate.  ``None`` or an
            empty collection evaluates every entry found under ``test``.

    Raises:
        Exception: Any error raised while processing a scene is re-raised
            after printing which scene failed.
    """
    if target_method is None:
        target_method = ()
    print("")

    # One network instance is shared by every scene and method.
    lpips_vgg = LPIPS(net_type='vgg').cuda()

    for scene_dir in model_paths:
        try:
            print("Scene:", scene_dir)
            test_dir = Path(scene_dir) / "test"

            for method in os.listdir(test_dir):
                if len(target_method) > 0 and method not in target_method:
                    continue
                print("Method:", method)

                method_dir = test_dir / method
                gt_dir = method_dir / "gt"
                renders_dir = method_dir / "renders"
                if not renders_dir.exists():
                    renders_dir = method_dir / "renders_opt"

                ssims = []
                psnrs = []
                lpipss = []
                # cam -> {"SSIM": [...], "PSNR": [...], "LPIPS-vgg": [...]}
                per_cam = {}

                # Sorted so progress and per-camera key order are reproducible.
                for fname in tqdm(sorted(os.listdir(renders_dir))):
                    cam, frame_id, _ = fname.split('_')
                    render = _load_rgb_batch(renders_dir / fname)
                    gt = _load_rgb_batch(gt_dir / f'{cam}_{frame_id}.png')

                    ssim_ = ssim(render, gt)
                    psnr_ = psnr(render, gt)
                    lpips_ = lpips_vgg(render, gt)

                    ssims.append(ssim_)
                    psnrs.append(psnr_)
                    lpipss.append(lpips_)

                    cam_metrics = per_cam.setdefault(
                        cam, {"SSIM": [], "PSNR": [], "LPIPS-vgg": []})
                    cam_metrics["SSIM"].append(ssim_)
                    cam_metrics["PSNR"].append(psnr_)
                    cam_metrics["LPIPS-vgg"].append(lpips_)

                # Reduce once and reuse the numbers for printing and JSON.
                summary = {
                    "SSIM": _mean_of(ssims),
                    "PSNR": _mean_of(psnrs),
                    "LPIPS-vgg": _mean_of(lpipss),
                }
                print("Scene: ", scene_dir,  "SSIM : {:>12.7f}".format(summary["SSIM"]))
                print("Scene: ", scene_dir,  "PSNR : {:>12.7f}".format(summary["PSNR"]))
                print("Scene: ", scene_dir,  "LPIPS-vgg: {:>12.7f}".format(summary["LPIPS-vgg"]))

                per_cam_summary = {}
                for cam, cam_metrics in per_cam.items():
                    per_cam_summary[cam] = {
                        metric_name: _mean_of(samples)
                        for metric_name, samples in cam_metrics.items()
                    }

                with open(method_dir / "results.json", 'w') as fp:
                    json.dump(summary, fp, indent=True)
                with open(method_dir / "per_cam.json", 'w') as fp:
                    json.dump(per_cam_summary, fp, indent=True)
        except Exception:
            print("Unable to compute metrics for model", scene_dir)
            # Bare raise keeps the original traceback intact.
            raise

if __name__ == "__main__":
    # Make cuda:0 the current device so every later .cuda() call lands there.
    device = torch.device("cuda:0")
    torch.cuda.set_device(device)

    # Set up command line argument parser
    parser = ArgumentParser(description="Metrics evaluation script parameters")
    # required=True means a default would never be used, so none is given.
    parser.add_argument('--model_paths', '-m', required=True, nargs="+", type=str,
                        help="scene directories, each containing a 'test' sub-directory")
    parser.add_argument('--target_methods', '-t', nargs="+", type=str, default=["ours_30000"],
                        help="method directory names under 'test' to evaluate")
    args = parser.parse_args()
    evaluate(args.model_paths, args.target_methods)
