"""
Evaluation script for 3D point-cloud semantic segmentation (per-point IoU / mIoU),
with multi-GPU support.

python eval.py --gpu_id=<gpu list> -n <expname> -c <conf> -D /home/group/data/chairs -F srn
"""
import sys
import os

sys.path.insert(
    0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "src"))
)

import torch
import numpy as np
import imageio
import skimage.measure
import util
from data import get_split_dataset
from model import make_model
from render import NeRFRenderer, UnisurfRenderer, common
import cv2
import tqdm
import ipdb
import warnings

from torchmetrics import MetricCollection, IoU
from torchmetrics.functional.classification.jaccard import _jaccard_from_confmat
import wandb
import matplotlib.cm as cm
# from metrics import MeanAccuracy, MeanIoU, Accuracy


def class_mious(metric):
    """Return the per-class intersection-over-union (IoU) for an accumulated IoU metric.

    Args:
        metric: a torchmetrics IoU metric whose ``confmat``, ``num_classes`` and
            ``ignore_index`` attributes are read.

    Returns:
        Whatever ``_jaccard_from_confmat`` yields for reduction ``"none"``
        (presumably one IoU value per class -- confirm against the installed
        torchmetrics version).
    """
    # Positional order matches the torchmetrics version this file imports IoU from:
    # (confmat, num_classes, ignore_index, absent_score, reduction).
    absent_score = -1
    reduction = "none"
    return _jaccard_from_confmat(
        metric.confmat,
        metric.num_classes,
        metric.ignore_index,
        absent_score,
        reduction,
    )

def save_pc(PC, PC_color, filename):
    """Write a colored point cloud to ``filename`` as a PLY ``vertex`` element.

    Args:
        PC: (N, 3) point coordinates, stored as float32 ``x``/``y``/``z``.
        PC_color: (N, 3) colors, stored as uint8 ``red``/``green``/``blue``.
        filename: output path handed to ``PlyData.write``.
    """
    from plyfile import PlyElement, PlyData

    vertex_dtype = [
        ('x', 'f4'), ('y', 'f4'), ('z', 'f4'),
        ('red', 'u1'), ('green', 'u1'), ('blue', 'u1'),
    ]
    # One row per point: coordinates followed by color.
    rows = np.concatenate((PC, PC_color), axis=1)
    vertices = np.array(list(map(tuple, rows)), dtype=vertex_dtype)
    vertex_element = PlyElement.describe(vertices, 'vertex')
    PlyData([vertex_element]).write(filename)

#  from pytorch_memlab import set_target_gpu
#  set_target_gpu(9)


def extra_args(parser):
    """Register this evaluation script's command-line options on ``parser``.

    Used as the ``extra_args`` callback of ``util.args.parse_args``.

    Returns:
        The same ``parser``, with the options added.
    """
    parser.add_argument(
        "--split",
        type=str,
        default="val",
        help="Split of data to use train | val | test",
    )
    # NOTE(review): not read in this file.
    parser.add_argument(
        "--nviews",
        type=str,
        default="",
        help="Number of views during training, not used here",
    )
    # NOTE(review): the DataLoader in this file hard-codes batch_size=1; not read here.
    parser.add_argument(
        "--batch_size", "-B", type=int, default=1, help="Object batch size ('SB')"
    )
    parser.add_argument(
        "--source",
        "-P",
        type=str,
        default="135",
        help="Comma-separated source view index(es) for each object. Alternatively, specify -L to viewlist file and leave this blank.",
    )
    parser.add_argument("--coarse", action="store_true", help="Coarse network as fine")
    parser.add_argument(
        "--no_compare_gt",
        action="store_true",
        help="Don't show per-object/running metrics in the progress bar (metrics are still accumulated and logged at the end)",
    )
    parser.add_argument(
        "--multicat",
        action="store_true",
        help="Prepend category id to object id. Specify if model fits multiple categories.",
    )
    parser.add_argument(
        "--viewlist",
        "-L",
        type=str,
        default="",
        help="Path to source view list e.g. src_dvr.txt; if specified, overrides source/P",
    )

    parser.add_argument(
        "--output",
        "-O",
        type=str,
        default="eval",
        help="Root directory for per-object outputs written with --write_compare",
    )
    # NOTE(review): not read in this file.
    parser.add_argument(
        "--include_src", action="store_true", help="Include source views in calculation"
    )
    # NOTE(review): not read in this file.
    parser.add_argument(
        "--scale", type=float, default=1.0, help="Video scale relative to input size"
    )
    # NOTE(review): not read in this file.
    parser.add_argument("--write_depth", action="store_true", help="Write depth image")
    parser.add_argument(
        "--write_compare",
        action="store_true",
        help="Write predicted/GT segmentation point clouds (.ply) and a source-view image per object",
    )
    parser.add_argument(
        "--all_src",
        action="store_true",
        help="Use every odd-indexed view (1, 3, 5, ...) as a source view, overriding --source/--viewlist",
    )
    # NOTE(review): not read in this file.
    parser.add_argument(
        "--free_pose",
        action="store_true",
        help="Set to indicate poses may change between objects. In most of our datasets, the test set has fixed poses.",
    )
    parser.add_argument(
        "--use_last",
        action="store_true",
        help="Load the latest checkpoint instead of the best one (load_best=False)",
    )
    # NOTE(review): not read in this file.
    parser.add_argument(
        "--background_grayscale", type=float, default=1.0, help="Background grayscale color, -1 to 1 (black to white)"
    )
    parser.add_argument(
        "--eval_ray_batch_size",
        type=int,
        default=80000,
        help="Maximum number of query points per forward pass during evaluation",
    )
    return parser


# Parse CLI + config, then build the evaluation dataset/loader (one object per batch).
args, conf = util.args.parse_args(
    extra_args, default_conf="conf/resnet_fine_mv.conf", default_expname="shapenet",
)

device = util.get_cuda(args.gpu_id[0])

dset = get_split_dataset(
    args.dataset_format, args.datadir, want_split=args.split, training=False, load_pc=True, level=args.level, category=args.category
)
data_loader = torch.utils.data.DataLoader(
    dset, batch_size=1, shuffle=False, num_workers=5, pin_memory=False
)

output_dir = args.output.strip()
has_output = len(output_dir) > 0
# The wandb run uses the full experiment name (before it is shortened below).
wandb.init(config=conf, name=args.name, project=f"SegNerf_3d_{args.split}")
# Output files are grouped under the first three "_"-separated fields of the name.
# Joining a slice keeps shorter names intact instead of raising IndexError.
args.name = "_".join(args.name.split("_")[:3])
output_dir = os.path.join(output_dir, args.name)




# Propagate the dataset's class count into both the renderer and model configs.
n_classes = dset.n_classes
for _section in ("renderer", "model"):
    conf[_section]['n_classes'] = n_classes

# Color palette for point-cloud dumps: a white first row followed by 51 RGB colors
# sampled from matplotlib's reversed "Paired" colormap; every (52 // n_classes)-th
# row is kept, then scaled to 0-255.
_palette = np.concatenate(
    [np.array([[1., 1., 1.]]), cm.Paired_r(np.linspace(0, 1, 51))[:, :3]],
    axis=0,
)
color_map = _palette[::52 // n_classes] * 255.

# Point-wise IoU over n_classes - 1 classes: the eval loop drops prediction
# channel 0 and shifts GT labels down by one before updating this metric.
seg_metrics = MetricCollection({
    'miou': IoU(num_classes=n_classes - 1),
}).to(device)


# Build the network and load its weights (best checkpoint unless --use_last).
net = make_model(conf["model"]).to(device=device).load_weights(args, load_best=not args.use_last).eval()

# The renderer is configured with the same output heads as the model.
conf["renderer"]['use_rgb_head'] = conf["model"]['use_rgb_head']
conf["renderer"]['use_seg_head'] = conf["model"]['use_seg_head']
renderer_type = conf["renderer"]["type"].lower()

_RENDERER_CLASSES = {"nerf": NeRFRenderer, "unisurf": UnisurfRenderer}
if renderer_type not in _RENDERER_CLASSES:
    # Fail here with a clear message instead of a NameError on `renderer` further down.
    raise ValueError(
        f"Unsupported renderer type {renderer_type!r}; "
        f"expected one of {sorted(_RENDERER_CLASSES)}"
    )
renderer = _RENDERER_CLASSES[renderer_type].from_conf(
    conf["renderer"], lindisp=dset.lindisp, eval_batch_size=args.eval_ray_batch_size
).to(device=device)

use_rgb_head = conf["model"]['use_rgb_head']
use_seg_head = conf["model"]['use_seg_head']

if args.coarse:
    net.mlp_fine = None

if renderer.n_coarse < 64:
    # Ensure decent sampling resolution
    renderer.n_coarse = 64
if args.coarse:
    renderer.n_coarse = 64
    renderer.n_fine = 128
    renderer.using_fine = True

render_par = renderer.bind_parallel(net, args.gpu_id, simple_output=True)
pts_par = renderer.bind_pts_parallel(net, args.gpu_id)

z_near, z_far = dset.z_near, dset.z_far

# Source views come either from a per-object list file (-L) or from the global -P list.
use_source_lut = len(args.viewlist) > 0
if not use_source_lut:
    source = torch.tensor(
        sorted(int(idx) for idx in args.source.split(',')), dtype=torch.long
    )
else:
    print("Using views from list", args.viewlist)
    with open(args.viewlist, "r") as view_file:
        view_rows = [line.strip().split() for line in view_file]
    # Each row "<col0> <col1> <view> <view> ..." becomes {"col0/col1": LongTensor(views)}.
    source_lut = {
        f"{row[0]}/{row[1]}": torch.tensor([int(v) for v in row[2:]], dtype=torch.long)
        for row in view_rows
    }



def _point_query_rays(pts):
    """Pack (P, 3) query points into (P, 8) pseudo-rays.

    Row layout: xyz (3), a zero direction (3), z_near (1), z_far (1). Only the
    first three columns are passed to the network below.
    """
    n_pts = pts.shape[0]
    return torch.cat(
        [
            pts,
            torch.zeros_like(pts),
            torch.ones((n_pts, 1)) * z_near,
            torch.ones((n_pts, 1)) * z_far,
        ],
        dim=-1,
    )


def _source_view_mask(n_views, cat_name, obj_basename):
    """Return a (n_views,) bool mask selecting the source (conditioning) views of one object.

    --all_src selects every odd-indexed view. Otherwise the indices come from the
    viewlist (-L) when one was given, else from --source (-P).
    """
    mask = torch.zeros(n_views, dtype=torch.bool)
    if args.all_src:
        mask[1::2] = True
    elif use_source_lut:
        # Viewlist keys are "<col0>/<col1>"; assumed here to be
        # "<category dir>/<object dir>" -- TODO confirm against the list-file format.
        mask[source_lut[cat_name + "/" + obj_basename]] = True
    else:
        mask[source] = True
    return mask


with torch.no_grad():
    with tqdm.tqdm(total=len(data_loader)) as t:
        for data in data_loader:
            dpath = data["path"][0]
            obj_basename = os.path.basename(dpath)
            cat_name = os.path.basename(os.path.dirname(dpath))
            obj_name = cat_name + "_" + obj_basename if args.multicat else obj_basename

            images = data["images"][0]  # (NV, 3, H, W)
            all_rays_pts = data["pts"][0]  # (P, 3)
            pts_labels = data["pts_labels"][0]  # (P)
            all_rays = _point_query_rays(all_rays_pts).to(device=device)
            # Query at most eval_ray_batch_size points per forward pass.
            rays_spl = torch.split(all_rays, args.eval_ray_batch_size, dim=0)

            poses = data["poses"][0]  # (NV, 4, 4)
            src_view_mask = _source_view_mask(images.shape[0], cat_name, obj_basename)
            src_poses = poses[src_view_mask].to(device=device)  # (NS, 4, 4)
            # Every chunk uses the same source images, so move them to the device once.
            src_images = images[src_view_mask].to(device=device).unsqueeze(0)

            focal = data["focal"][0]
            if isinstance(focal, float):
                focal = torch.tensor(focal, dtype=torch.float32)
            focal = focal[None].to(device)

            c = data.get("c")
            if c is not None:
                c = c[0].to(device=device).unsqueeze(0)

            if renderer_type == "nerf":
                # NeRF-style models are encoded once per object; the unisurf path
                # instead receives the source views on every point query below.
                net.encode(src_images, src_poses.unsqueeze(0), focal, c=c)

            seg_chunks = []
            for rays in rays_spl:
                query_pts = rays.unsqueeze(0)[..., :3]  # (1, chunk, 3)
                viewdirs = torch.zeros_like(query_pts)
                if renderer_type == "nerf":
                    out = pts_par(query_pts, coarse=False, viewdirs=viewdirs)
                elif renderer_type == "unisurf":
                    out = pts_par(query_pts, src_images, src_poses.unsqueeze(0),
                                  focal, c,
                                  coarse=False, viewdirs=viewdirs)
                # The last n_classes output channels are the segmentation scores.
                seg_chunks.append(out[0, :, -n_classes:].cpu())

            # Channel 0 is excluded from the prediction and GT labels are shifted down
            # by one, so metric class k corresponds to dataset label k + 1.
            all_seg = torch.cat(seg_chunks, dim=0)[:, 1:].argmax(-1)
            curr_seg_metrics = seg_metrics(all_seg.to(device), pts_labels.to(device).short() - 1)

            if not args.no_compare_gt:
                running_seg_metrics = {
                    k: f'{v.item()*100:.2f}%' for k, v in seg_metrics.compute().items()
                }
                postfix = {
                    'c_' + k: f'{v.item()*100:.2f}%' for k, v in curr_seg_metrics.items()
                }
                postfix.update(running_seg_metrics)
                # t.update() below redraws the bar, so skip the extra refresh here.
                t.set_postfix(postfix, refresh=False)

            if args.write_compare:
                obj_out_dir = os.path.join(output_dir, obj_name)
                os.makedirs(obj_out_dir, exist_ok=True)
                pts_xyz = all_rays_pts.reshape(-1, 3)
                # Predictions are shifted by +1 so their palette row matches the GT label.
                pred_colors = color_map[(all_seg + 1).numpy()]
                save_pc(pts_xyz, pred_colors.reshape(-1, 3),
                        os.path.join(obj_out_dir, "pred_seg.ply"))
                gt_colors = color_map[pts_labels.long().numpy()]
                save_pc(pts_xyz, gt_colors.reshape(-1, 3),
                        os.path.join(obj_out_dir, "gt_seg.ply"))

                # Assumes images are in [-1, 1]; map to [0, 1] before writing.
                images_0to1 = images * 0.5 + 0.5  # (NV, 3, H, W)
                rgb_src_all = images_0to1[src_view_mask].permute(0, 2, 3, 1).contiguous().numpy()  # (NS, H, W, 3)
                out_im = np.hstack(list(rgb_src_all))
                imageio.imwrite(os.path.join(obj_out_dir, "src.png"),
                                (out_im * 255).astype(np.uint8))

            t.update()

        wandb.log({k: v.item() for k, v in seg_metrics.compute().items()})
        class_ious = class_mious(seg_metrics['miou'])
        print("class ious", class_ious)
        # wandb.log expects a dict; log one scalar per metric class (dataset label = index + 1).
        wandb.log({f'class_iou/{i}': iou.item() for i, iou in enumerate(class_ious)})




