"""
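Evaluate on real images matching input/*_normalize.png: query the model on a
dense 3D grid and write density-thresholded colored point clouds
(rgb_<th>.ply and seg_<th>.ply) to output/<image name>/.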
"""
import sys
import os

ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
sys.path.insert(0, os.path.join(ROOT_DIR, "src"))

import util
import torch
import numpy as np
from model import make_model
from render import NeRFRenderer
import torchvision.transforms as T
import tqdm
import imageio
from PIL import Image

# Color table: integer id -> (R, G, B) triple, stored as floats in the 0-255 range.
# Presumably the ids are semantic class labels -- confirm against the dataset.
# NOTE: the table is not contiguous. There is no entry for id 31 (and 40 is
# commented out), and the last entry uses the key -1 (pure red). Any consumer
# that indexes `list(color_map.values())` by position is therefore shifted by
# one from id 31 onward; look colors up by key to avoid that.
color_map = {
        0: (0., 0., 0.),
        1: (174., 199., 232.),
        2: (152., 223., 138.),
        3: (31., 119., 180.),
        4: (255., 187., 120.),
        5: (188., 189., 34.),
        6: (140., 86., 75.),
        7: (255., 152., 150.),
        8: (214., 39., 40.),
        9: (197., 176., 213.),
        10: (148., 103., 189.),
        11: (196., 156., 148.),
        12: (23., 190., 207.),
        13: (100., 85., 144.),
        14: (247., 182., 210.),
        15: (66., 188., 102.),
        16: (219., 219., 141.),
        17: (140., 57., 197.),
        18: (202., 185., 52.),
        19: (51., 176., 203.),
        20: (200., 54., 131.),
        21: (92., 193., 61.),
        22: (78., 71., 183.),
        23: (172., 114., 82.),
        24: (255., 127., 14.),
        25: (91., 163., 138.),
        26: (153., 98., 156.),
        27: (140., 153., 101.),
        28: (158., 218., 229.),
        29: (100., 125., 154.),
        30: (178., 127., 135.),
        # (no entry for id 31)
        32: (146., 111., 194.),
        33: (44., 160., 44.),
        34: (112., 128., 144.),
        35: (96., 207., 209.),
        36: (227., 119., 194.),
        37: (213., 92., 176.),
        38: (94., 106., 211.),
        39: (82., 84., 163.),
        # 40: (100., 85., 144.),
        -1: (255., 0., 0.),
    }
def save_pc(PC, PC_color, filename):
    """Write a colored point cloud to ``filename`` as a PLY file.

    Args:
        PC: array-like of shape (n, 3) with the x, y, z coordinates of each
            point; stored as 32-bit floats.
        PC_color: array-like of shape (n, 3) with the red, green, blue value
            of each point in the 0-255 range; stored as unsigned bytes, so
            any fractional part is dropped by the integer cast.
        filename: destination path handed to ``PlyData.write``.

    Raises:
        ValueError: if the inputs are not two (n, 3) arrays of equal length.
    """
    # Imported lazily so the module can be loaded without plyfile installed.
    from plyfile import PlyElement, PlyData

    points = np.asarray(PC)
    colors = np.asarray(PC_color)
    if points.ndim != 2 or points.shape[1] != 3 or colors.shape != points.shape:
        raise ValueError(
            "save_pc expects PC and PC_color of shape (n, 3), got {} and {}".format(
                points.shape, colors.shape
            )
        )

    vertex_dtype = [
        ('x', 'f4'), ('y', 'f4'), ('z', 'f4'),
        ('red', 'u1'), ('green', 'u1'), ('blue', 'u1'),
    ]
    # Fill the structured array one field at a time. This avoids creating a
    # Python tuple per point, which dominates run time for clouds with
    # millions of points.
    vertices = np.empty(points.shape[0], dtype=vertex_dtype)
    for column, field in enumerate(('x', 'y', 'z')):
        vertices[field] = points[:, column]
    for column, field in enumerate(('red', 'green', 'blue')):
        vertices[field] = colors[:, column]

    PlyData([PlyElement.describe(vertices, 'vertex')]).write(filename)

def extra_args(parser):
    """Register this script's command-line options on ``parser``.

    ``parser`` only needs an ``add_argument`` method with argparse semantics
    (e.g. ``argparse.ArgumentParser``). The same object is returned so the
    call can be chained.
    """
    # (flags, keyword arguments) pairs, in the order they are registered.
    option_table = [
        (
            ("--input", "-I"),
            {
                "type": str,
                "default": os.path.join(ROOT_DIR, "input"),
                "help": "Image directory",
            },
        ),
        (
            ("--output", "-O"),
            {
                "type": str,
                "default": os.path.join(ROOT_DIR, "output"),
                "help": "Output directory",
            },
        ),
        (
            ("--size",),
            {"type": int, "default": 128, "help": "Input image maxdim"},
        ),
        (
            ("--out_size",),
            {
                "type": str,
                "default": "128",
                "help": "Output image size, either 1 or 2 number (w h)",
            },
        ),
        (
            ("--focal",),
            {"type": float, "default": 131.25, "help": "Focal length"},
        ),
        (
            ("--radius",),
            {"type": float, "default": 1.3, "help": "Camera distance"},
        ),
        (("--z_near",), {"type": float, "default": 0.8}),
        (("--z_far",), {"type": float, "default": 1.8}),
        (
            ("--elevation", "-e"),
            {
                "type": float,
                "default": 0.0,
                "help": "Elevation angle (negative is above)",
            },
        ),
        (
            ("--num_views",),
            {
                "type": int,
                "default": 24,
                "help": "Number of video frames (rotated views)",
            },
        ),
        (
            ("--fps",),
            {"type": int, "default": 15, "help": "FPS of video"},
        ),
        (
            ("--gif",),
            {"action": "store_true", "help": "Store gif instead of mp4"},
        ),
        (
            ("--no_vid",),
            {
                "action": "store_true",
                "help": "Do not store video (only image frames will be written)",
            },
        ),
    ]
    for flags, settings in option_table:
        parser.add_argument(*flags, **settings)
    return parser


# Parse the command line (project-wide options plus the ones registered by
# extra_args) and obtain the experiment configuration.
args, conf = util.args.parse_args(
    extra_args, default_expname="srn_car", default_data_format="srn",
)
# Set before load_weights(args) is called below.
# NOTE(review): presumably this makes load_weights restore the saved
# checkpoint of the experiment -- confirm in the model code.
args.resume = True

# Everything runs on the first GPU listed in --gpu_id.
device = util.get_cuda(args.gpu_id[0])
net = make_model(conf["model"]).to(device=device).load_weights(args)
# Number of segmentation channels; used later to slice the network output.
n_classes = conf["model"]['n_classes']
renderer = NeRFRenderer.from_conf(
    conf["renderer"], eval_batch_size=args.ray_batch_size
).to(device=device)
# Only `render_par.net` is used further down, to query the network directly.
render_par = renderer.bind_parallel(net, args.gpu_id, simple_output=True).eval()

# Depth bounds; they are copied into every row of the query tensor built below.
z_near, z_far = args.z_near, args.z_far
focal = torch.tensor(args.focal, dtype=torch.float32, device=device)

# Size the input image is resized to before encoding.
in_sz = args.size
# --out_size is a whitespace-separated string: one number (square output)
# or two numbers given as "W H".
sz = [int(token) for token in args.out_size.split()]
if len(sz) == 1:
    H = W = sz[0]
elif len(sz) == 2:
    W, H = sz
else:
    # Raise explicitly instead of using `assert`, which is stripped when
    # Python runs with -O and would leave the bad value unchecked.
    raise ValueError(
        "--out_size expects 1 or 2 numbers (w h), got {!r}".format(args.out_size)
    )


# Dense N x N x N lattice of query points covering the cube [-1, 1]^3.
N = 257
# Cast the 1-D axis to float32 *before* building the grid: the values are the
# same as casting afterwards, but the three N^3 meshgrid outputs and the
# stacked result are never materialized in float64.
t = np.linspace(-1.0, 1.0, N).astype(np.float32)
# np.meshgrid defaults to indexing="xy", so along the first grid axis it is
# the second coordinate (y) that varies. Each cell stores its own (x, y, z),
# so this only affects traversal order, not the coordinates themselves.
query_pts = np.stack(np.meshgrid(t, t, t), -1)  # (N, N, N, 3), float32
sh = query_pts.shape
all_rays_pts = torch.from_numpy(query_pts.reshape([-1, 3]))
# View directions are all zero for these grid queries.
all_rays_dir = torch.zeros_like(all_rays_pts)
all_rays_z_near = torch.ones((all_rays_pts.shape[0], 1)) * z_near
all_rays_z_far = torch.ones((all_rays_pts.shape[0], 1)) * z_far
# One row per grid point: [x, y, z, dx, dy, dz, z_near, z_far] -> (N^3, 8).
all_rays = torch.cat(
    [all_rays_pts, all_rays_dir, all_rays_z_near, all_rays_z_far], dim=-1
).to(device=device)

# The grid rows are used in place of camera rays, which would otherwise come from:
# render_rays = util.gen_rays(render_poses, W, H, focal, z_near, z_far).to(device=device)
render_rays = all_rays


# os.listdir returns entries in arbitrary order; sort them so images are
# processed in a reproducible order.
inputs_all = sorted(os.listdir(args.input))
# Only preprocessed images (suffix "_normalize.png") are evaluated.
inputs = [
    os.path.join(args.input, x) for x in inputs_all if x.endswith("_normalize.png")
]
os.makedirs(args.output, exist_ok=True)

if not inputs:
    if not inputs_all:
        print("No input images found, please place an image into ./input")
    else:
        print("No processed input images found, did you run 'scripts/preproc.py'?")
    sys.exit(1)

# Fixed camera pose for the input view: a 4x4 identity matrix whose entry
# [2, 3] (the z translation of a homogeneous transform) is set to --radius.
cam_pose = torch.eye(4, device=device)
cam_pose[2, -1] = args.radius
print("SET DUMMY CAMERA")
print(cam_pose)

image_to_tensor = util.get_image_to_tensor_balanced()

# Number of grid points pushed through the network per forward pass.
QUERY_CHUNK_SIZE = 80000

# Color table indexed by class id. It is built from the *keys* of color_map,
# not from the position of its values, because color_map is not contiguous
# (id 31 has no entry): a positional lookup would give every id >= 31 the
# color of the following entry. Ids without an entry use color_map[-1].
seg_palette = np.array(
    [
        color_map.get(label, color_map[-1])
        for label in range(max(max(color_map) + 1, n_classes))
    ],
    dtype=np.float64,
)

with torch.no_grad():
    for i, image_path in enumerate(inputs):
        print("IMAGE", i + 1, "of", len(inputs), "@", image_path)
        # Close the file handle as soon as the pixels have been converted.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert("RGB")
        image = T.Resize(in_sz)(image)
        image = image_to_tensor(image).to(device=device)

        # Condition the network on this single input view.
        net.encode(
            image.unsqueeze(0), cam_pose.unsqueeze(0), focal,
        )

        grid_queries = render_rays.view(-1, 8)
        print("Querying", grid_queries.shape[0], "grid points")
        all_sigma = []
        all_rgb = []
        all_seg = []
        for rays in tqdm.tqdm(torch.split(grid_queries, QUERY_CHUNK_SIZE, dim=0)):
            rays = rays.unsqueeze(0)
            # Columns 0:3 are the point coordinates, 3:6 the view directions.
            out = render_par.net(rays[..., :3], coarse=False, viewdirs=rays[..., 3:6])
            # NOTE(review): this reads channel 0 as density, channels 1:4 as
            # RGB and the last n_classes channels as class scores -- confirm
            # against the model's output layout.
            all_sigma.append(out[0, :, 0].cpu())
            all_rgb.append(out[0, :, 1:4].cpu())
            # Reduce to a class id before leaving the device so that only one
            # integer per point is transferred instead of n_classes floats.
            all_seg.append(out[0, :, -n_classes:].argmax(-1).cpu())

        all_rgb = torch.clamp(
            torch.cat(all_rgb, dim=0).reshape(N, N, N, 3), 0.0, 1.0
        ).numpy()  # (N, N, N, 3)
        all_seg = torch.cat(all_seg, dim=0).reshape(N, N, N).numpy()  # (N, N, N)
        all_sigma = torch.cat(all_sigma, dim=0).reshape(N, N, N).numpy()  # (N, N, N)

        # The output folder depends only on the image, so create it once
        # rather than once per threshold.
        im_name = os.path.basename(os.path.splitext(image_path)[0])
        obj_out_dir = os.path.join(args.output, im_name)
        os.makedirs(obj_out_dir, exist_ok=True)

        # Export one pair of point clouds per density threshold 1..9.
        for th in range(1, 10):
            pt_select_idx = all_sigma > th

            # query_pts is the NumPy grid the query rows were built from, so
            # the selected coordinates can be gathered without touching the
            # GPU copy.
            all_pts = query_pts[pt_select_idx]

            out_file = os.path.join(obj_out_dir, "rgb_" + str(th) + ".ply")
            save_pc(all_pts, all_rgb[pt_select_idx] * 255., out_file)

            # Color only the selected points instead of the whole volume.
            out_file = os.path.join(obj_out_dir, "seg_" + str(th) + ".ply")
            save_pc(all_pts, seg_palette[all_seg[pt_select_idx]], out_file)