"""
Full evaluation script, including PSNR+SSIM and segmentation (accuracy/mIoU) evaluation with multi-GPU support.

python eval.py --gpu_id=<gpu list> -n <expname> -c <conf> -D /home/group/data/chairs -F srn
"""
import sys
import os

sys.path.insert(
    0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "src"))
)

import torch
import numpy as np
import imageio
import skimage.measure
import util
from data import get_split_dataset
from model import make_model
from render import NeRFRenderer, UnisurfRenderer, common
import cv2
import tqdm
import ipdb
import warnings
import wandb

from torchmetrics import MetricCollection, IoU, Accuracy
from torchmetrics.functional.classification.jaccard import _jaccard_from_confmat
import matplotlib.cm as cm

#  from pytorch_memlab import set_target_gpu
#  set_target_gpu(9)

def class_mious(metric):
    """Return IoU scores computed from ``metric``'s accumulated confusion matrix.

    ``metric`` must expose ``confmat``, ``num_classes`` and ``ignore_index``.
    The trailing ``-1`` and ``"none"`` are passed positionally to torchmetrics'
    private ``_jaccard_from_confmat`` helper; presumably they are the score
    for absent classes and the reduction mode (one score per class) -- confirm
    against the installed torchmetrics version.
    """
    confusion = metric.confmat
    class_count = metric.num_classes
    ignored = metric.ignore_index
    return _jaccard_from_confmat(confusion, class_count, ignored, -1, "none")

def extra_args(parser):
    """Register the evaluation-specific command line options on ``parser``.

    Args:
        parser: an ``argparse.ArgumentParser``-like object; it is mutated in
            place via ``add_argument``.

    Returns:
        The same ``parser`` instance, so the call can be chained.
    """
    parser.add_argument(
        "--split",
        type=str,
        default="val",
        help="Split of data to use train | val | test",
    )
    parser.add_argument(
        "--source",
        "-P",
        type=str,
        default="135",
        # The script parses this value with ``split(',')``.
        help="Source view(s) for each object, as comma-separated indices. Alternatively, specify -L to viewlist file and leave this blank.",
    )
    parser.add_argument(
        "--eval_view_list", type=str, default=None, help="Path to eval view list"
    )
    parser.add_argument("--coarse", action="store_true", help="Coarse network as fine")
    parser.add_argument(
        "--no_compare_gt",
        action="store_true",
        help="Skip GT comparison (metric won't be computed) and only render images",
    )
    parser.add_argument(
        "--multicat",
        action="store_true",
        help="Prepend category id to object id. Specify if model fits multiple categories.",
    )
    parser.add_argument(
        "--viewlist",
        "-L",
        type=str,
        default="",
        help="Path to source view list e.g. src_dvr.txt; if specified, overrides source/P",
    )
    parser.add_argument(
        "--nviews",
        type=str,
        default="",
        help="Number of views during training, not used here",
    )
    parser.add_argument(
        "--batch_size", "-B", type=int, default=1, help="Object batch size ('SB')"
    )
    parser.add_argument(
        "--skip_count", type=int, default=1, help="views every skip_count"
    )
    parser.add_argument(
        "--output",
        "-O",
        type=str,
        default="",
        help="If specified, saves generated images to directory",
    )
    parser.add_argument(
        "--include_src", action="store_true", help="Include source views in calculation"
    )
    parser.add_argument(
        "--write_output", action="store_true", help="Write out files"
    )
    parser.add_argument(
        "--same_src_target", action="store_true", help="Use same source and target views"
    )
    parser.add_argument(
        "--scale", type=float, default=1.0, help="Video scale relative to input size"
    )
    parser.add_argument("--write_depth", action="store_true", help="Write depth image")
    parser.add_argument(
        "--write_compare", action="store_true", help="Write GT comparison image"
    )
    parser.add_argument(
        "--eval_all", action="store_true", help="Eval all target images"
    )
    parser.add_argument(
        "--free_pose",
        action="store_true",
        help="Set to indicate poses may change between objects. In most of our datasets, the test set has fixed poses.",
    )
    parser.add_argument(
        "--background_grayscale", type=float, default=1.0, help="Background grayscale color, -1 to 1 (black to white)"
    )
    parser.add_argument(
        "--eval_ray_batch_size", type=int, default=80000, help="Eval ray batch size"
    )
    parser.add_argument(
        "--use_last",
        action="store_true",
        # Consumed as ``load_best=not args.use_last`` when loading weights.
        help="Load the latest checkpoint instead of the best one.",
    )
    return parser


# Parse the command line and experiment config; extra_args adds the eval-only flags.
args, conf = util.args.parse_args(
    extra_args, default_conf="conf/resnet_fine_mv.conf", default_expname="shapenet",
)
# args.resume = True

# The first listed GPU hosts the model and the metric state.
device = util.get_cuda(args.gpu_id[0])

dset = get_split_dataset(
    args.dataset_format, args.datadir, want_split=args.split, training=False, level=args.level, category=args.category
)
# The dataset decides the number of segmentation classes; propagate it to both configs.
n_classes = dset.n_classes
conf["renderer"]['n_classes'] = n_classes
conf["model"]['n_classes'] = n_classes
# Palette for label images: white for index 0, followed by 51 RGB samples of the
# reversed "Paired" colormap, sub-sampled with stride 52 // n_classes.
color_map = np.concatenate([np.array([[1., 1., 1.]]), cm.Paired_r(np.linspace(0, 1, 51))[:, :3]],axis=0)[::52//n_classes]

# One object per batch, in dataset order.
data_loader = torch.utils.data.DataLoader(
    dset, batch_size=1, shuffle=False, num_workers=5, pin_memory=False
)

output_dir = args.output.strip()
# Files are written only when an output directory is given AND --write_output is set.
has_output = len(output_dir) > 0 and args.write_output
# wandb gets the full experiment name; afterwards args.name is shortened.
wandb.init(config=conf, name=args.name, project=f"SegNerf_2d_{args.split}")
# Keep at most the first three '_'-separated parts of the name. Slicing (rather
# than indexing parts 0..2) also accepts shorter names such as the default
# "shapenet", which previously raised IndexError.
args.name = "_".join(args.name.split("_")[:3])
output_dir = os.path.join(output_dir, args.name)

# Running sums over objects for the image metrics.
total_psnr = 0.0
total_ssim = 0.0
cnt = 0

# Segmentation metrics; label 0 is ignored in all of them.
seg_metrics = MetricCollection({
                            'acc': Accuracy(ignore_index=0),
                            'macc': Accuracy(ignore_index=0, num_classes=n_classes, average='macro'),
                            'miou': IoU(ignore_index=0, num_classes=n_classes),
                            }).to(device)


# When writing output, resume from "finish.txt" in the output directory: each
# valid record is "<object> <psnr> <ssim> <count>", and objects listed there
# are skipped by the evaluation loop.
if has_output:
    finish_path = os.path.join(output_dir, "finish.txt")
    os.makedirs(output_dir, exist_ok=True)
    if not os.path.exists(finish_path):
        finished = set()
    else:
        with open(finish_path, "r") as f:
            lines = [row.strip().split() for row in f.readlines()]
        # Ignore malformed records (anything that is not exactly four fields).
        lines = [row for row in lines if len(row) == 4]
        finished = {row[0] for row in lines}
        total_psnr = sum(float(row[1]) for row in lines)
        total_ssim = sum(float(row[2]) for row in lines)
        cnt = sum(int(row[3]) for row in lines)
        # TODO: the segmentation metrics (acc / macc / miou) are not stored in
        # finish.txt yet, so they are not restored on resume.
        if cnt <= 0:
            # Nothing usable was recorded; start again from clean float sums.
            total_psnr = 0.0
            total_ssim = 0.0
        else:
            print("resume psnr", total_psnr / cnt, "ssim", total_ssim / cnt)

    # Line-buffered append handle, so each finished object is recorded promptly.
    finish_file = open(finish_path, "a", buffering=1)
    print("Writing images to", output_dir)


# Build the network, load its checkpoint (the best one unless --use_last is
# given) and switch it to eval mode.
net = make_model(conf["model"]).to(device=device).load_weights(args, load_best=not args.use_last).eval()
# The renderer needs to know which output heads the model provides.
conf["renderer"]['use_rgb_head'] = conf["model"]['use_rgb_head']
conf["renderer"]['use_seg_head'] = conf["model"]['use_seg_head']
renderer_type = conf["renderer"]["type"].lower()

# Both renderers are constructed the same way; only the class differs.
renderer_classes = {"nerf": NeRFRenderer, "unisurf": UnisurfRenderer}
if renderer_type not in renderer_classes:
    # Fail here with a clear message instead of a NameError on `renderer` below.
    raise ValueError(
        "Unsupported renderer type '{}'; expected one of {}".format(
            renderer_type, sorted(renderer_classes)
        )
    )
renderer = renderer_classes[renderer_type].from_conf(
    conf["renderer"], lindisp=dset.lindisp, eval_batch_size=args.eval_ray_batch_size
).to(device=device)

use_rgb_head = conf["model"]['use_rgb_head']
use_seg_head = conf["model"]['use_seg_head']

if args.coarse:
    net.mlp_fine = None

if renderer.n_coarse < 64:
    # Ensure decent sampling resolution
    renderer.n_coarse = 64
if args.coarse:
    renderer.n_coarse = 64
    renderer.n_fine = 128
    renderer.using_fine = True

# Wrap the renderer and network for (multi-)GPU rendering on the requested devices.
render_par = renderer.bind_parallel(net, args.gpu_id, simple_output=True).eval()

z_near = dset.z_near
z_far = dset.z_far

# Source views come either from a per-object lookup file (-L) or from --source.
use_source_lut = len(args.viewlist) > 0
if use_source_lut:
    print("Using views from list", args.viewlist)
    with open(args.viewlist, "r") as f:
        tmp = [x.strip().split() for x in f.readlines()]
    # Each row is "<category> <object> <view> [<view> ...]", keyed by "<category>/<object>".
    source_lut = {
        x[0] + "/" + x[1]: torch.tensor(list(map(int, x[2:])), dtype=torch.long)
        for x in tmp
    }
    sources = [source_lut]
else:
    # --source is a comma-separated list of view indices.
    source = torch.tensor(sorted(map(int, args.source.split(','))), dtype=torch.long)
    sources = [source]
if args.eval_all:
    # One single-view source per step.
    # NOTE(review): the 250 upper bound is hard-coded; objects with fewer
    # views will fail when the source mask is indexed -- confirm dataset size.
    sources = [torch.tensor([i]) for i in range(250)[::args.skip_count]]

# Number of views per object, taken from the first dataset item.
NV = dset[0]["images"].shape[0]

if args.eval_view_list is not None:
    # The first line of the file lists the target view indices.
    with open(args.eval_view_list, "r") as f:
        eval_views = torch.tensor(list(map(int, f.readline().split())))
    target_view_mask = torch.zeros(NV, dtype=torch.bool)
    target_view_mask[eval_views] = 1
elif args.same_src_target:
    # Placeholder only: with --same_src_target the evaluation loop replaces the
    # target mask with the per-object source mask before it is ever read.
    # The former `target_view_mask = source` raised NameError when combined
    # with --viewlist, where no module-level `source` exists.
    target_view_mask = torch.zeros(NV, dtype=torch.bool)
else:
    target_view_mask = torch.zeros(NV, dtype=torch.bool)
    target_view_mask[::args.skip_count] = 1
# Template mask; the loop clones it for each object.
target_view_mask_init = target_view_mask

# Ray cache, reused across objects unless poses or sources change.
all_rays = None
rays_spl = []

src_view_mask = None
total_objs = len(data_loader)


# Main evaluation loop: for each object, render the target views from the
# source views, optionally write the results to disk, and accumulate
# PSNR/SSIM and segmentation metrics. Autograd is disabled throughout.
with torch.no_grad():
    for obj_idx, data in enumerate(data_loader):
        print(
            "OBJECT",
            obj_idx,
            "OF",
            total_objs,
            "PROGRESS",
            obj_idx / total_objs * 100.0,
            "%",
            data["path"][0],
        )
        dpath = data["path"][0]
        obj_basename = os.path.basename(dpath)
        cat_name = os.path.basename(os.path.dirname(dpath))
        obj_name = cat_name + "_" + obj_basename if args.multicat else obj_basename
        if has_output and obj_name in finished:
            # Already recorded in finish.txt by a previous run.
            print("(skip)")
            continue
        images = data["images"][0]  # (NV, 3, H, W)
        labels = data["labels"][0]  # (NV, 1, H, W)

        if args.background_grayscale != 1.0:
            # Repaint every pixel labelled 0 with the requested gray value.
            bckgd_mask = labels == 0
            images[bckgd_mask.repeat(1, 3, 1, 1)] = args.background_grayscale

        for source in sources:
            NV, _, H, W = images.shape
            if args.scale != 1.0:
                Ht = int(H * args.scale)
                Wt = int(W * args.scale)
                if abs(Ht / args.scale - H) > 1e-10 or abs(Wt / args.scale - W) > 1e-10:
                    warnings.warn(
                        "Inexact scaling, please check {} times ({}, {}) is integral".format(
                            args.scale, H, W
                        )
                    )
                H, W = Ht, Wt

            # (Re)build the masks and rays on the first pass, or whenever the
            # sources or poses may differ between objects.
            if all_rays is None or use_source_lut or args.free_pose or args.eval_all:
                if use_source_lut:
                    obj_id = cat_name + "/" + obj_basename
                    source = source_lut[obj_id]

                NS = len(source)
                src_view_mask = torch.zeros(NV, dtype=torch.bool)
                src_view_mask[source] = 1

                focal = data["focal"][0]
                if isinstance(focal, float):
                    focal = torch.tensor(focal, dtype=torch.float32)
                focal = focal[None]

                c = data.get("c")
                if c is not None:
                    c = c[0].to(device=device).unsqueeze(0)

                poses = data["poses"][0]  # (NV, 4, 4)
                src_poses = poses[src_view_mask].to(device=device)  # (NS, 4, 4)

                if args.same_src_target:
                    target_view_mask = src_view_mask
                else:
                    target_view_mask = target_view_mask_init.clone()
                    if not (args.include_src or args.same_src_target):
                        # Exclude the source views from the targets.
                        target_view_mask *= ~src_view_mask

                novel_view_idxs = target_view_mask.nonzero(as_tuple=False).reshape(-1)

                if renderer_type == "nerf":
                    poses = poses[target_view_mask]  # (NV[-NS], 4, 4)
                    all_rays = (
                        util.gen_rays(
                            poses.reshape(-1, 4, 4),
                            W,
                            H,
                            focal * args.scale,
                            z_near,
                            z_far,
                            c=c * args.scale if c is not None else None,
                        )
                        .reshape(-1, 8)
                        .to(device=device)
                    )  # ((NV[-NS])*H*W, 8)
                    poses = None

                elif renderer_type == "unisurf":
                    net.target_poses = poses

                    # For unisurf the "rays" are integer index triples built
                    # from the (view, H, W) index grid of the target views.
                    mesh = torch.meshgrid(torch.arange(NV), torch.arange(H), torch.arange(W))
                    all_rays = torch.stack(mesh, axis=-1).to(device)
                    all_rays = all_rays[target_view_mask].reshape(-1, 3)

                focal = focal.to(device=device)

            rays_spl = torch.split(all_rays, args.eval_ray_batch_size, dim=0)  # Creates views

            n_gen_views = len(novel_view_idxs)

            # Upload the source images once per object instead of once per ray batch.
            src_images = images[src_view_mask].to(device=device).unsqueeze(0)

            if renderer_type == "nerf":
                net.encode(
                    src_images,
                    src_poses.unsqueeze(0),
                    focal,
                    c=c,
                )
            all_depth = []
            if use_rgb_head:
                all_rgb = []
            if use_seg_head:
                all_seg = []
            for rays in tqdm.tqdm(rays_spl):
                if renderer_type == "nerf":
                    ret_dict = render_par(rays[None])
                elif renderer_type == "unisurf":
                    ret_dict = render_par(rays[None], src_images,
                                          src_poses.unsqueeze(0), focal, c)
                if use_rgb_head:
                    rgb = ret_dict["rgb"][0].cpu()
                    all_rgb.append(rgb)
                if use_seg_head:
                    seg = ret_dict["seg"][0].cpu()
                    all_seg.append(seg)
                depth = ret_dict["depth"][0].cpu()
                all_depth.append(depth)

            if use_rgb_head:
                all_rgb = torch.cat(all_rgb, dim=0)
                all_rgb = torch.clamp(
                    all_rgb.reshape(n_gen_views, H, W, 3), 0.0, 1.0
                ).numpy()  # (NV-NS, H, W, 3)
                if renderer_type == "unisurf":
                    # Swap the two spatial axes and flip the first one
                    # (presumably to match the GT image orientation -- confirm).
                    all_rgb = all_rgb.transpose(0, 2, 1, 3)
                    all_rgb = all_rgb[:, ::-1]
            if use_seg_head:
                all_seg = torch.cat(all_seg, dim=0)
                # Per-pixel class = argmax over the class scores.
                all_seg = all_seg.reshape(n_gen_views, H, W, n_classes).argmax(-1).numpy()  # (NV-NS, H, W)
                if renderer_type == "unisurf":
                    all_seg = all_seg.transpose(0, 2, 1)
                    # copy() gives positive strides, which torch.from_numpy needs below.
                    all_seg = all_seg[:, ::-1].copy()

            all_depth = torch.cat(all_depth, dim=0)
            # Normalise depth relative to the near/far planes.
            all_depth = (all_depth - z_near) / (z_far - z_near)
            all_depth = all_depth.reshape(n_gen_views, H, W).numpy()

            if has_output:
                obj_out_dir = os.path.join(output_dir, obj_name)
                os.makedirs(obj_out_dir, exist_ok=True)
                for i in range(n_gen_views):
                    if use_rgb_head:
                        out_file = os.path.join(
                            obj_out_dir, "{:06}.png".format(novel_view_idxs[i].item())
                        )
                        imageio.imwrite(out_file, (all_rgb[i] * 255.).astype(np.uint8))
                    if use_seg_head:
                        out_seg_file = os.path.join(
                            obj_out_dir, "{:06}_seg.png".format(novel_view_idxs[i].item())
                        )
                        seg_cmap = color_map[all_seg[i].astype(np.uint32)] * 255.
                        imageio.imwrite(out_seg_file, seg_cmap.astype(np.uint8))

                    if args.write_depth:
                        out_depth_file = os.path.join(
                            obj_out_dir, "{:06}_depth.exr".format(novel_view_idxs[i].item())
                        )
                        out_depth_norm_file = os.path.join(
                            obj_out_dir,
                            "{:06}_depth_norm.png".format(novel_view_idxs[i].item()),
                        )
                        depth_cmap_norm = util.cmap(all_depth[i])
                        # EXR is a floating-point format: write the normalised
                        # depth as float32. Casting it to uint8 truncated
                        # nearly every value to 0.
                        imageio.imwrite(out_depth_file, all_depth[i].astype(np.float32))
                        imageio.imwrite(out_depth_norm_file, depth_cmap_norm.astype(np.uint8))

            curr_ssim = 0.0
            curr_psnr = 0.0
            if not args.no_compare_gt:
                # Ground-truth images are stored in [-1, 1]; map them to [0, 1].
                images_0to1 = images * 0.5 + 0.5  # (NV, 3, H, W)
                images_gt = images_0to1[target_view_mask]
                rgb_gt_all = (
                    images_gt.permute(0, 2, 3, 1).contiguous().numpy()
                )  # (NV-NS, H, W, 3)
                rgb_src_all = (
                    images_0to1[src_view_mask].permute(0, 2, 3, 1).contiguous().numpy()
                )  # (NS, H, W, 3)
                if use_seg_head:
                    seg_gt = labels[target_view_mask]
                    seg_gt_all = (
                        seg_gt.permute(0, 2, 3, 1).contiguous().numpy()
                    ).squeeze(-1)  # (NV-NS, H, W)
                for view_idx in range(n_gen_views):
                    if use_rgb_head:
                        ssim = skimage.metrics.structural_similarity(
                            all_rgb[view_idx],
                            rgb_gt_all[view_idx],
                            multichannel=True,
                            data_range=1,
                        )
                        psnr = skimage.metrics.peak_signal_noise_ratio(
                            all_rgb[view_idx], rgb_gt_all[view_idx], data_range=1
                        )
                        curr_ssim += ssim
                        curr_psnr += psnr

                    if use_seg_head:
                        # Calling the collection updates the running state and
                        # returns the metrics of this view alone.
                        curr_seg_metrics = seg_metrics(torch.from_numpy(all_seg[view_idx]).to(device), torch.from_numpy(seg_gt_all[view_idx]).to(device).short())

                    # obj_out_dir only exists when output is enabled; without
                    # this guard --write_compare alone raised NameError.
                    if args.write_compare and has_output:
                        out_file = os.path.join(
                            obj_out_dir,
                            "{:06}_compare.png".format(novel_view_idxs[view_idx].item()),
                        )
                        # Side by side: source views, GT, then the predictions.
                        out_im = [im for im in rgb_src_all]
                        out_im.append(rgb_gt_all[view_idx])
                        if use_rgb_head:
                            out_im.extend([all_rgb[view_idx]])
                        if use_seg_head:
                            cur_seg_cmap = color_map[all_seg[view_idx].astype(np.uint32)]
                            cur_seg_gt_cmap = color_map[seg_gt_all[view_idx].astype(np.uint32)]
                            out_im.extend([cur_seg_cmap, cur_seg_gt_cmap])
                        out_im = np.hstack(out_im)
                        imageio.imwrite(out_file, (out_im * 255).astype(np.uint8))

            if use_rgb_head:
                # Average over the rendered views, then add one sample to the totals.
                curr_psnr /= n_gen_views
                curr_ssim /= n_gen_views
                curr_cnt = 1
                total_psnr += curr_psnr
                total_ssim += curr_ssim
                cnt += curr_cnt
            if not args.no_compare_gt:
                if use_rgb_head:
                    print(
                        "curr psnr",
                        curr_psnr,
                        "ssim",
                        curr_ssim,
                        "running psnr",
                        total_psnr / cnt,
                        "running ssim",
                        total_ssim / cnt,
                    )
                    if has_output:
                        # Record the object so that a resumed run can skip it.
                        finish_file.write(
                            "{} {} {} {}\n".format(obj_name, curr_psnr, curr_ssim, curr_cnt)
                        )
                if use_seg_head:
                    running_seg_metrics = seg_metrics.compute()
                    # curr_* holds the last view's metrics; running_* covers
                    # everything accumulated so far.
                    curr_seg_metrics = {k: v.item() for (k, v) in curr_seg_metrics.items()}
                    running_seg_metrics = {k: v.item() for (k, v) in running_seg_metrics.items()}
                    print(
                        curr_seg_metrics,
                        running_seg_metrics
                    )
# Final summary: log the aggregated metrics to wandb and to stdout.
if use_rgb_head:
    if cnt > 0:
        wandb.log({'psnr': total_psnr / cnt, 'ssim': total_ssim / cnt})
        print("final psnr", total_psnr / cnt, "ssim", total_ssim / cnt)
    else:
        # Nothing was evaluated (e.g. empty split); avoid dividing by zero.
        print("no objects evaluated; skipping final psnr/ssim")
if use_seg_head:
    # running_seg_metrics is only assigned by the evaluation loop when at least
    # one object was compared against ground truth (it stays undefined with
    # --no_compare_gt or when every object was skipped).
    final_seg_metrics = globals().get("running_seg_metrics")
    if final_seg_metrics is None:
        print("no segmentation metrics accumulated; skipping final seg metrics")
    else:
        wandb.log(final_seg_metrics)
        print("final seg metrics", final_seg_metrics)
        class_ious = class_mious(seg_metrics['miou'])
        print("class ious", class_ious)
        # wandb.log expects a dict; the previous `{'class_ious: ', value}` was a
        # set literal. The tensor is converted to a plain list for logging.
        wandb.log({'class_ious': class_ious.tolist()})
