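"""
3D reconstruction evaluation script with multi-GPU support: queries the model on a
dense grid, extracts meshes with marching cubes and reports Chamfer / F-score
metrics against ground-truth point clouds.
Usage: python eval.py --gpu_id=<gpu list> -n <expname> -c <conf> -D /home/group/data/chairs -F srn
"""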
import sys
import os

sys.path.insert(
    0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "src"))
)

import torch
import numpy as np
import imageio
import skimage.measure
import util
from data import get_split_dataset
from model import make_model
from render import NeRFRenderer, UnisurfRenderer, common
import cv2
import tqdm
import ipdb
import warnings
import mcubes
import trimesh
import wandb

from torchmetrics import MetricCollection
from metrics import Chamfer_FScore
import matplotlib.cm as cm


def save_pc(PC, PC_color, filename):
    """Write a coloured point cloud to a PLY file.

    Args:
        PC: per-point positions; concatenated column-wise with ``PC_color``
            so that every row supplies the six vertex fields (x, y, z, red,
            green, blue).
        PC_color: per-point colours, stored as unsigned bytes.
        filename: destination path of the PLY file.
    """
    from plyfile import PlyElement, PlyData

    vertex_dtype = [
        ('x', 'f4'), ('y', 'f4'), ('z', 'f4'),
        ('red', 'u1'), ('green', 'u1'), ('blue', 'u1'),
    ]
    merged_rows = np.concatenate((PC, PC_color), axis=1)
    # Structured arrays are built from a sequence of per-row tuples.
    vertex_records = np.array(list(map(tuple, merged_rows)), dtype=vertex_dtype)
    vertex_element = PlyElement.describe(vertex_records, 'vertex')
    PlyData([vertex_element]).write(filename)

#  from pytorch_memlab import set_target_gpu
#  set_target_gpu(9)

def extra_args(parser):
    """Register the evaluation-specific command-line options on ``parser``.

    Args:
        parser: an ``argparse.ArgumentParser`` (or compatible) instance.

    Returns:
        The same ``parser``, so it can be used as a parse-args hook.
    """
    parser.add_argument(
        "--split",
        type=str,
        default="val",
        help="Split of data to use train | val | test",
    )
    parser.add_argument(
        "--source",
        "-P",
        type=str,
        default="135",
        help="Source view(s) for each object. Alternatively, specify -L to viewlist file and leave this blank.",
    )
    parser.add_argument(
        "--eval_view_list", type=str, default=None, help="Path to eval view list"
    )
    parser.add_argument("--coarse", action="store_true", help="Coarse network as fine")
    parser.add_argument(
        "--no_compare_gt",
        action="store_true",
        help="Skip GT comparison (metric won't be computed) and only render images",
    )
    parser.add_argument(
        "--multicat",
        action="store_true",
        help="Prepend category id to object id. Specify if model fits multiple categories.",
    )
    parser.add_argument(
        "--viewlist",
        "-L",
        type=str,
        default="",
        help="Path to source view list e.g. src_dvr.txt; if specified, overrides source/P",
    )

    parser.add_argument(
        "--output",
        "-O",
        type=str,
        default="eval",
        help="Output root directory; per-object results are written under <output>/<expname>/<object>",
    )
    parser.add_argument(
        "--include_src", action="store_true", help="Include source views in calculation"
    )
    parser.add_argument(
        "--same_src_target", action="store_true", help="Use same source and target views"
    )
    parser.add_argument(
        "--scale", type=float, default=1.0, help="Video scale relative to input size"
    )
    parser.add_argument("--write_depth", action="store_true", help="Write depth image")
    parser.add_argument(
        "--write_compare",
        action="store_true",
        help="Write the extracted meshes and the source-view image for each object",
    )
    parser.add_argument(
        "--write_rgb", action="store_true", help="Write rgb mesh (requires --write_compare)"
    )
    parser.add_argument(
        "--write_seg", action="store_true", help="Write seg mesh (requires --write_compare)"
    )
    parser.add_argument(
        "--free_pose",
        action="store_true",
        help="Set to indicate poses may change between objects. In most of our datasets, the test set has fixed poses.",
    )
    parser.add_argument(
        "--background_grayscale", type=float, default=1.0, help="Background grayscale color, -1 to 1 (black to white)"
    )
    parser.add_argument(
        "--batch_size", "-B", type=int, default=4, help="Object batch size ('SB')"
    )
    parser.add_argument(
        "--nviews",
        "-V",
        type=str,
        default="1",
        help="Number of source views (multiview); put multiple (comma delim) to pick randomly per batch ('NV')",
    )
    parser.add_argument(
        "--eval_ray_batch_size", type=int, default=80000, help="Eval ray batch size"
    )
    parser.add_argument("--postprocess", action="store_true", help="Postprocess mesh")
    parser.add_argument(
        "--postprocess_th", type=float, default=1.1, help="Postprocess diag th"
    )
    parser.add_argument(
        "--use_last",
        action="store_true",
        # Previously a copy-paste of --free_pose's help; the flag actually
        # toggles load_best=False when loading weights.
        help="Load the last saved checkpoint instead of the best one.",
    )
    return parser


args, conf = util.args.parse_args(
    extra_args, default_conf="conf/resnet_fine_mv.conf", default_expname="shapenet",
)
# args.resume = True

device = util.get_cuda(args.gpu_id[0])

dset = get_split_dataset(
    args.dataset_format, args.datadir, want_split=args.split, training=False,
    load_pc=True, level=args.level, category=args.category,
)
n_classes = dset.n_classes
conf["renderer"]["n_classes"] = n_classes
conf["model"]["n_classes"] = n_classes

# Vertex-colour palette in 0-255: white first, then 51 colours of the reversed
# "Paired" colormap, subsampled with step 52 // n_classes.
color_map = np.concatenate(
    [np.array([[1.0, 1.0, 1.0]]), cm.Paired_r(np.linspace(0, 1, 51))[:, :3]], axis=0
)[:: 52 // n_classes] * 255.0

data_loader = torch.utils.data.DataLoader(
    dset, batch_size=1, shuffle=False, num_workers=5, pin_memory=False
)

output_dir = args.output.strip()
has_output = len(output_dir) > 0
wandb.init(config=conf, name=args.name, project=f"SegNerf_3d_dense_{args.split}")
# Keep (at most) the first three "_"-separated fields of the experiment name;
# shorter names are kept whole instead of raising IndexError.
# NOTE(review): `args` is later passed to load_weights, so this truncated name
# may also select which checkpoint is loaded — confirm.
args.name = "_".join(args.name.split("_")[:3])
output_dir = os.path.join(output_dir, args.name)


# Validate the renderer type before the (expensive) model load: an unknown
# type used to leave `renderer` undefined and fail later with a NameError.
renderer_type = conf["renderer"]["type"].lower()
_RENDERER_CLASSES = {"nerf": NeRFRenderer, "unisurf": UnisurfRenderer}
if renderer_type not in _RENDERER_CLASSES:
    raise ValueError(
        f"Unsupported renderer type {renderer_type!r}; "
        f"expected one of {sorted(_RENDERER_CLASSES)}"
    )

net = make_model(conf["model"]).to(device=device).load_weights(args, load_best=not args.use_last).eval()
conf["renderer"]["use_rgb_head"] = conf["model"]["use_rgb_head"]
conf["renderer"]["use_seg_head"] = conf["model"]["use_seg_head"]

renderer = _RENDERER_CLASSES[renderer_type].from_conf(
    conf["renderer"], lindisp=dset.lindisp, eval_batch_size=args.eval_ray_batch_size,
).to(device=device)

use_rgb_head = conf["model"]["use_rgb_head"]
use_seg_head = conf["model"]["use_seg_head"]

if args.coarse:
    # --coarse: use the coarse network as the fine one.
    net.mlp_fine = None

if renderer.n_coarse < 64:
    # Ensure decent sampling resolution
    renderer.n_coarse = 64
if args.coarse:
    renderer.n_coarse = 64
    renderer.n_fine = 128
    renderer.using_fine = True

render_par = renderer.bind_parallel(net, args.gpu_id, simple_output=True).eval()
pts_par = renderer.bind_pts_parallel(net, args.gpu_id).eval()

z_near = dset.z_near
z_far = dset.z_far

# Source views are either looked up per object from --viewlist (lines of
# "<category> <object> <view ids...>", keyed as "<category>/<object>") or taken
# from the comma-separated --source list for every object.
use_source_lut = len(args.viewlist) > 0
if use_source_lut:
    print("Using views from list", args.viewlist)
    with open(args.viewlist, "r") as f:
        # Skip blank lines (e.g. a trailing newline) instead of failing on row[0].
        view_rows = [line.split() for line in f if line.strip()]
    source_lut = {
        row[0] + "/" + row[1]: torch.tensor(list(map(int, row[2:])), dtype=torch.long)
        for row in view_rows
    }
else:
    source = torch.tensor(
        sorted(int(v) for v in args.source.split(",")), dtype=torch.long
    )

metrics_3D_rec = MetricCollection({
    "3D_L2": Chamfer_FScore.Chamfer_FScore(p=2, dist_sync_on_step=True),
    "3D_L1": Chamfer_FScore.Chamfer_FScore(p=1, dist_sync_on_step=True),
})

# Dense N^3 query grid over [-1, 1]^3. np.meshgrid defaults to "xy" indexing,
# so axis 0 of query_pts varies the y coordinate and axis 1 varies x.
N = 256
t = np.linspace(-1.0, 1.0, N)
query_pts = np.stack(np.meshgrid(t, t, t), -1).astype(np.float32)
all_rays_pts = torch.from_numpy(query_pts.reshape([-1, 3]))
all_rays_dir = torch.zeros_like(all_rays_pts)
all_rays_z_near = torch.ones((all_rays_pts.shape[0], 1)) * z_near
all_rays_z_far = torch.ones((all_rays_pts.shape[0], 1)) * z_far
# Columns: xyz (3), zero direction (3), z_near (1), z_far (1).
all_rays = torch.cat(
    [all_rays_pts, all_rays_dir, all_rays_z_near, all_rays_z_far], dim=-1
).to(device=device)

total_objs = len(data_loader)

# One metric collection per mesh variant; the prefix is prepended to each
# metric name: "den_" (density mesh) and "den_seg_" (density masked by seg).
all_metrics = {
    k: metrics_3D_rec.clone(prefix=k).to(device) for k in ["den_", "den_seg_"]
}
def _grid_to_world(vertices):
    """Map marching-cubes vertices from grid-index space to query coordinates.

    Grid index ``i`` was queried at ``t[i] = t[0] + i * (t[-1] - t[0]) / (N - 1)``
    (``np.linspace``), so dividing by ``N`` (as before) shrank and shifted the
    mesh. ``np.meshgrid`` used "xy" indexing, so grid axis 0 is the y
    coordinate and axis 1 is x: the first two columns are swapped.
    """
    world = t[0] + vertices * (t[-1] - t[0]) / (N - 1)
    return world[:, [1, 0, 2]]


def _extract_mesh(field, th):
    """Run marching cubes on a flat (N**3,) field at iso-level ``th``."""
    vertices, triangles = mcubes.marching_cubes(field.reshape(N, N, N).cpu().numpy(), th)
    return trimesh.Trimesh(_grid_to_world(vertices), triangles)


def _keep_compact_components(mesh, max_diag):
    """Greedily merge the largest watertight components that fit in a box.

    Components are visited from most to fewest vertices; a component is kept
    only if the bounding-box diagonal of the kept union stays <= ``max_diag``.
    Returns ``mesh`` unchanged when it has no watertight component (this used
    to crash with an IndexError).
    """
    components = list(trimesh.graph.split(mesh, only_watertight=True))
    if not components:
        return mesh
    order = np.argsort([len(m.vertices) for m in components])[::-1]
    kept = [order[0]]
    for idx in order[1:]:
        candidate = kept + [idx]
        verts = np.concatenate([components[i].vertices for i in candidate])
        if np.linalg.norm(verts.max(axis=0) - verts.min(axis=0)) <= max_diag:
            kept = candidate
    chosen = [components[i] for i in kept]
    # Face indices are local to each component: offset them by the number of
    # vertices preceding that component in the concatenated vertex array.
    offsets = np.cumsum([0] + [len(m.vertices) for m in chosen[:-1]])
    vertices = np.concatenate([m.vertices for m in chosen])
    faces = np.concatenate([m.faces + off for m, off in zip(chosen, offsets)])
    return trimesh.Trimesh(vertices, faces)


def _query_points(pts, src_images, src_poses, focal, c):
    """Evaluate the network at points ``pts`` of shape (1, P, 3).

    For "nerf" the conditioning presumably comes from the preceding
    ``net.encode`` call; "unisurf" receives the source views explicitly.
    """
    viewdirs = torch.zeros_like(pts)
    if renderer_type == "nerf":
        return pts_par(pts, coarse=False, viewdirs=viewdirs)
    if renderer_type == "unisurf":
        return pts_par(pts, src_images, src_poses, focal, c,
                       coarse=False, viewdirs=viewdirs)
    raise ValueError(f"Unsupported renderer type {renderer_type!r}")


def _export_colored_meshes(mesh, tag, th, obj_out_dir, cond):
    """Write RGB- and/or segmentation-coloured copies of ``mesh``.

    Colours are queried at the vertices of ``mesh`` itself, so the coloured
    files match the exported (possibly post-processed) geometry.
    """
    if len(mesh.vertices) == 0:
        return
    pts = torch.from_numpy(np.ascontiguousarray(mesh.vertices, dtype=np.float32)[None]).to(device)
    out = _query_points(pts, *cond)
    if args.write_rgb:
        rgb = torch.clamp(out[0, :, 1:4], 0.0, 1.0)
        mesh_rgb = trimesh.Trimesh(mesh.vertices, mesh.faces, vertex_colors=rgb.cpu().numpy() * 255.0)
        mesh_rgb.export(os.path.join(obj_out_dir, f"rgb_{tag}{th}.ply"))
    if args.write_seg:
        # Argmax over the last n_classes - 1 logits (class 0 is masked out as
        # background elsewhere), shifted by one so palette entry 0 stays unused.
        seg = out[0, :, -n_classes + 1:].argmax(-1)
        seg_cmap = color_map[seg.cpu().numpy() + 1]
        mesh_seg = trimesh.Trimesh(mesh.vertices, mesh.faces, vertex_colors=seg_cmap)
        mesh_seg.export(os.path.join(obj_out_dir, f"seg_{tag}{th}.ply"))


def _evaluate_field(field, th, tag, obj_name, obj_out_dir, gt_pts, cond):
    """Extract, optionally export, and score one mesh from a dense field.

    ``tag`` ("den_" / "den_seg_") selects both the output file prefix and the
    metric collection in ``all_metrics`` that is updated.
    """
    mesh = _extract_mesh(field, th)
    if args.postprocess:
        mesh = _keep_compact_components(mesh, args.postprocess_th)
    if args.write_compare:
        os.makedirs(obj_out_dir, exist_ok=True)
        mesh.export(os.path.join(obj_out_dir, f"{tag}{th}.ply"))
        if args.write_rgb or args.write_seg:
            _export_colored_meshes(mesh, tag, th, obj_out_dir, cond)
    try:
        sampled_verts, _ = trimesh.sample.sample_surface(mesh, count=10000)
        all_metrics[tag]([torch.from_numpy(sampled_verts).float().to(device)], [gt_pts])
    except Exception as err:
        # Best-effort: the object is skipped and does not contribute to the
        # aggregated metric; report why instead of failing silently.
        print(f"Error sampling mesh for {obj_name}: {err!r}")


# The grid chunks are identical for every object, so split only once.
rays_spl = torch.split(all_rays, args.eval_ray_batch_size, dim=0)
th_range = [5] if renderer_type == "nerf" else [0.5]

with torch.no_grad():
    for obj_idx, data in enumerate(data_loader):
        print(
            "OBJECT",
            obj_idx,
            "OF",
            total_objs,
            "PROGRESS",
            obj_idx / total_objs * 100.0,
            "%",
            data["path"][0],
        )
        dpath = data["path"][0]
        obj_basename = os.path.basename(dpath)
        cat_name = os.path.basename(os.path.dirname(dpath))
        obj_name = cat_name + "_" + obj_basename if args.multicat else obj_basename
        obj_out_dir = os.path.join(output_dir, obj_name)

        images = data["images"][0]  # (NV, 3, H, W)
        gt_pts = data["pts"][0].to(device)
        NV = images.shape[0]

        # Per-object views from --viewlist (previously ignored, leaving
        # `source` undefined), otherwise the shared --source views.
        obj_source = source_lut[cat_name + "/" + obj_basename] if use_source_lut else source
        src_view_mask = torch.zeros(NV, dtype=torch.bool)
        src_view_mask[obj_source] = True
        poses = data["poses"][0]  # (NV, 4, 4)
        src_poses = poses[src_view_mask].to(device=device)  # (NS, 4, 4)
        src_images = images[src_view_mask].to(device=device).unsqueeze(0)

        focal = data["focal"][0]
        if isinstance(focal, float):
            focal = torch.tensor(focal, dtype=torch.float32)
        focal = focal[None].to(device)

        c = data.get("c")
        if c is not None:
            c = c[0].to(device=device).unsqueeze(0)

        if renderer_type == "nerf":
            net.encode(src_images, src_poses.unsqueeze(0), focal, c=c)
        cond = (src_images, src_poses.unsqueeze(0), focal, c)

        # Dense evaluation over the query grid: density, plus the per-point
        # argmax class when a segmentation head is present (argmax per chunk
        # avoids holding all N^3 x n_classes logits).
        sigma_chunks, seg_chunks = [], []
        for rays in tqdm.tqdm(rays_spl):
            out = _query_points(rays[None, :, :3], *cond)
            sigma_chunks.append(out[0, :, 0])
            if use_seg_head:
                seg_chunks.append(out[0, :, -n_classes:].argmax(-1))
        all_sigma = torch.cat(sigma_chunks, dim=0)
        if use_seg_head:
            all_seg = torch.cat(seg_chunks, dim=0)
            # Zero the density wherever the predicted class is 0.
            sigma_seg = all_sigma * (all_seg != 0)

        for th in th_range:
            _evaluate_field(all_sigma, th, "den_", obj_name, obj_out_dir, gt_pts, cond)
            if use_seg_head:
                _evaluate_field(sigma_seg, th, "den_seg_", obj_name, obj_out_dir, gt_pts, cond)

        if args.write_compare:
            images_0to1 = images * 0.5 + 0.5  # (NV, 3, H, W)
            rgb_src_all = images_0to1[src_view_mask].permute(0, 2, 3, 1).contiguous().numpy()
            out_im = np.hstack(list(rgb_src_all))
            os.makedirs(obj_out_dir, exist_ok=True)
            imageio.imwrite(os.path.join(obj_out_dir, "src.png"), (out_im * 255).astype(np.uint8))

    for prefix, collection in all_metrics.items():
        if prefix == "den_seg_" and not use_seg_head:
            # Never updated without a segmentation head.
            continue
        results = collection.compute()
        sub_names = collection["3D_L1"].returned_metrics()
        # Flatten each metric's vector of values into "<metric>_<sub-metric>" keys.
        flat_results = {
            f"{metric_key}_{sub_name}": value.item()
            for metric_key, values in results.items()
            for sub_name, value in zip(sub_names, values)
        }
        wandb.log(flat_results)
        print(flat_results)
