import numpy as np
import cv2
from tqdm import tqdm
import os
from pathlib import Path
import torch
import main_numcc
import util.misc as misc
from src.engine.engine import prepare_data_udf
from src.engine.engine_viz import generate_html_udf
from src.fns import *
from src.model.nu_mcc import NUMCC
import open3d as o3d

def run_viz_udf(model, samples, device, args, prefix):
    """Predict a coloured point cloud for one sample and append it to an HTML report.

    The sample is encoded once, the query points are restricted to a box
    around the predicted anchor centres, and the remaining queries are decoded
    in fixed-size passes.  Queries whose predicted unsigned distance is below
    ``args.udf_threshold`` are handed to ``move_points`` and then coloured.

    Args:
        model: network exposing ``encoder``, ``decoderl1``, ``decoderl2`` and
            ``fc_out``.
        samples: raw sample structure, forwarded to ``prepare_data_udf``.
        device: device forwarded to ``prepare_data_udf``.
        args: namespace; this function reads ``hr``, ``xyz_size``,
            ``xyz_size_hr``, ``shrink_threshold``, ``n_query_udf``,
            ``udf_threshold``, ``udf_n_iter``, ``save_pc`` and ``exp_name``.
        prefix: output path without extension; the report is appended to
            ``prefix + '.html'``.

    Side effects:
        Puts ``model`` in eval mode and sets ``requires_grad = False`` on all
        of its parameters (this is not undone).  Creates the ``.ply`` output
        folders when ``args.save_pc == 1`` and the parent folder of the HTML
        report.
    """
    model.eval()

    (seen_xyz, valid_seen_xyz, query_xyz, _unseen_rgb, _labels, seen_images,
     _gt_fps_xyz, seen_xyz_hr, valid_seen_xyz_hr) = prepare_data_udf(
        samples, device, is_train=False, is_viz=True, args=args)
    # Keep an untouched copy of the input image for the report.
    seen_images_no_preprocess = seen_images.clone()

    with torch.no_grad():
        seen_images_hr = None

        if args.hr == 1:
            seen_images_hr = preprocess_img(seen_images.clone(), res=args.xyz_size)
            seen_xyz_hr = shrink_points_beyond_threshold(seen_xyz_hr, args.shrink_threshold)

        seen_images = preprocess_img(seen_images)
        query_xyz = shrink_points_beyond_threshold(query_xyz, args.shrink_threshold)
        seen_xyz = shrink_points_beyond_threshold(seen_xyz, args.shrink_threshold)

        latent, up_grid_fea = model.encoder(seen_images, seen_xyz, valid_seen_xyz, up_grid_bypass=seen_images_hr)
        fea = model.decoderl1(latent)
        centers_xyz = fea['anchors_xyz']

    # Queries per decoder pass; scaled by the resolution ratio in high-res
    # mode.  Clamped to at least 1 so the pass count below is well defined.
    if not args.hr:
        max_n_queries_fwd = args.n_query_udf
    else:
        max_n_queries_fwd = int(args.n_query_udf * (args.xyz_size / args.xyz_size_hr) ** 2)
    max_n_queries_fwd = max(1, max_n_queries_fwd)

    # Keep only the queries strictly inside the bounding box of the first
    # batch element's centres, enlarged by `offset` on every side.
    # NOTE(review): the original comment gives centers_xyz as (1, 200, 3) -- confirm.
    offset = 0.3
    min_xyz = torch.min(centers_xyz, dim=1)[0][0] - offset
    max_xyz = torch.max(centers_xyz, dim=1)[0][0] + offset
    coords = query_xyz[:, :, :3]
    inside = ((coords > min_xyz) & (coords < max_xyz)).all(dim=-1)
    query_xyz = query_xyz[inside].unsqueeze(0)

    total_n_passes = int(np.ceil(query_xyz.shape[1] / max_n_queries_fwd))

    # Gradients w.r.t. the weights are switched off for everything below.
    # NOTE(review): presumably so move_points only differentiates w.r.t. the
    # point positions -- confirm against src.fns.
    for param in model.parameters():
        param.requires_grad = False

    # The conditioning inputs do not change between passes.
    if args.hr != 1:
        seen_points = seen_xyz
        valid_seen = valid_seen_xyz
    else:
        seen_points = seen_xyz_hr
        valid_seen = valid_seen_xyz_hr

    max_dist = 0.5
    udf_threshold = args.udf_threshold

    # Per-pass results are collected and joined once after the loop.
    point_chunks = []
    color_chunks = []

    for p_idx in tqdm(range(total_n_passes)):
        p_start = p_idx * max_n_queries_fwd
        p_end = (p_idx + 1) * max_n_queries_fwd
        cur_query_xyz = query_xyz[:, p_start:p_end]

        with torch.no_grad():
            pred = model.decoderl2(cur_query_xyz, seen_points, valid_seen, fea, up_grid_fea, custom_centers = None)
            pred = model.fc_out(pred)

        # Channel 0 is used as the unsigned distance, clipped to [0, max_dist].
        pred_udf = F.relu(pred[:, :, :1]).reshape((-1, 1))  # nQ, 1
        pred_udf = torch.clamp(pred_udf, max=max_dist)

        # Candidate points: queries predicted to lie within the threshold.
        pos = (pred_udf < udf_threshold).squeeze(-1)  # (nQ, )
        points = cur_query_xyz.squeeze(0)[pos].unsqueeze(0)  # (1, n, 3)

        if not pos.any():
            continue

        points = move_points(model, points, seen_points, valid_seen, fea, up_grid_fea, args, n_iter=args.udf_n_iter)

        # Predict the final colour at the moved positions.
        with torch.no_grad():
            pred = model.decoderl2(points, seen_points, valid_seen, fea, up_grid_fea)
            pred = model.fc_out(pred)

        # The remaining channels are read as 3 x 256 scores; the arg-max bin
        # of each colour channel is mapped to [0, 1].
        cur_color_out = pred[:, :, 1:].reshape((-1, 3, 256)).max(dim=2)[1] / 255.0
        color_chunks.append(cur_color_out.detach().cpu().numpy())
        point_chunks.append(points.detach().squeeze(0).cpu().numpy())

    # Seeding with an empty float64 array keeps the (0, 3) shape when nothing
    # was found and the float64 dtype otherwise.
    pred_points = np.concatenate([np.empty((0, 3))] + point_chunks, axis=0)
    pred_colors = np.concatenate([np.empty((0, 3))] + color_chunks, axis=0)

    img = (seen_images_no_preprocess[0].permute(1, 2, 0) * 255).cpu().numpy().copy().astype(np.uint8)

    fn_pc = None
    fn_pc_seen = None
    # NOTE(review): epoch is always None here, so the folders below are
    # literally named 'epochNone'; kept as-is to preserve the output paths.
    epoch = None
    if args.save_pc == 1:
        out_folder_ply = os.path.join('experiments/', f'{args.exp_name}', 'ply', 'epoch'+str(epoch).zfill(3))
        Path(out_folder_ply).mkdir(parents= True, exist_ok=True)
        prefix_pc = os.path.join(out_folder_ply, 'demo_udf')
        fn_pc = prefix_pc + '.ply'

        # seen
        out_folder_ply = os.path.join('experiments/', f'{args.exp_name}', 'ply_seen', 'epoch'+str(epoch).zfill(3))
        Path(out_folder_ply).mkdir(parents= True, exist_ok=True)
        prefix_pc = os.path.join(out_folder_ply, 'demo_udf')
        fn_pc_seen = prefix_pc +'_seen' +'.ply'

    html_path = prefix + '.html'
    # Make sure the report's folder exists before opening it for append.
    Path(html_path).parent.mkdir(parents=True, exist_ok=True)
    with open(html_path, 'a') as f:
        generate_html_udf(
            img,
            seen_xyz, seen_images_no_preprocess,
            pred_points,
            pred_colors,
            query_xyz,
            f,
            centers = centers_xyz,
            fn_pc=fn_pc,
            fn_pc_seen = fn_pc_seen,
            pointcloud_marker_size=3
        )


def pad_image(im, value):
    """Pad a 3-D tensor to a square along its first two dimensions.

    Args:
        im: tensor indexed as ``(dim0, dim1, dim2)``; only the first two
            dimensions are compared and padded.
        value: fill value for the padded region.

    Returns:
        A new tensor whose first two dimensions both equal
        ``max(im.shape[0], im.shape[1])``.  Padding is appended at the end of
        the shorter dimension (after the last column when ``dim0 > dim1``,
        after the last row otherwise); an already square input gets an empty
        pad and keeps its values.
    """
    dim0, dim1, dim2 = im.shape[0], im.shape[1], im.shape[2]
    if dim0 > dim1:
        pad_shape, cat_dim = (dim0, dim0 - dim1, dim2), 1
    else:
        pad_shape, cat_dim = (dim1 - dim0, dim1, dim2), 0

    # Allocate the filler on im's device so non-CPU inputs can be
    # concatenated.  Its dtype stays at the torch default so that type
    # promotion inside torch.cat is unaffected.
    filler = torch.zeros(pad_shape, device=im.device) + value
    return torch.cat([im, filler], dim=cat_dim)


def normalize(seen_xyz):
    """Rescale and re-centre points using statistics of their finite entries.

    A point (last-dimension vector) counts as finite when the sum of its
    coordinates is finite.  The whole tensor is first divided by the mean of
    the per-axis standard deviations of the finite points, then shifted so
    that the finite points of the scaled tensor have zero mean.  Non-finite
    points take part in neither statistic but are still divided and shifted.

    Args:
        seen_xyz: tensor of points with the coordinates in the last dimension.

    Returns:
        The normalised tensor, same shape as the input.
    """
    def finite_points(xyz):
        # Select the points whose coordinate sum is finite.
        return xyz[torch.isfinite(xyz.sum(dim=-1))]

    per_axis_std = finite_points(seen_xyz).var(dim=0) ** 0.5
    scaled = seen_xyz / per_axis_std.mean()

    # The finite mask is re-evaluated on the scaled tensor on purpose.
    centroid = finite_points(scaled).mean(axis=0)
    return scaled - centroid


def _imread_or_raise(path, flags=None):
    """Read an image with OpenCV, raising instead of returning ``None``."""
    image = cv2.imread(path) if flags is None else cv2.imread(path, flags)
    if image is None:
        raise FileNotFoundError(f'could not read image file: {path!r}')
    return image


def main_demo(args):
    """Run the UDF demo on one RGB image with its depth and segmentation maps.

    Loads the model from the checkpoint configured in ``args``, back-projects
    the depth map into a point map with pinhole rays, masks it with the
    segmentation, crops both the points and the image to the object's bounding
    box, pads them to a square, resizes them, and hands the result to
    ``run_viz_udf``.

    Args:
        args: namespace; this function reads ``device``, ``image``, ``depth``,
            ``seg``, ``fov``, ``xyz_size`` and ``output`` and forwards the
            whole namespace to the model and helper functions.

    Raises:
        FileNotFoundError: if any of the three input images cannot be read.
        ValueError: if the segmentation mask has no foreground pixel.
    """
    misc.init_distributed_mode(args)

    model = NUMCC(args=args)
    model = model.to(args.device)
    misc.load_model(args=args, model_without_ddp=model, optimizer=None, loss_scaler=None)

    rgb = _imread_or_raise(args.image)
    H, W = rgb.shape[0], rgb.shape[1]

    # Depth is resized to the image resolution and negated before being used
    # as the per-ray scale.
    depth = _imread_or_raise(args.depth, cv2.IMREAD_UNCHANGED).astype(np.float32)
    resized = cv2.resize(depth, (W, H), interpolation=cv2.INTER_LINEAR)
    resized = resized * -1

    rays = o3d.t.geometry.RaycastingScene.create_rays_pinhole(
        fov_deg=args.fov,
        center=[0, 0, 0],
        eye=[1, 0, 0],
        up=[0, 1, 0],
        width_px=W,
        height_px=H,
    )

    # Flatten the (H, W) ray grid with a mask that selects every pixel, so the
    # result always reshapes back to (H, W, 3) below.
    every_pixel = np.ones((H, W), dtype=bool)
    flat_rays = rays[every_pixel]
    # Each ray stores origin (first 3 values) and direction (last 3 values).
    points = flat_rays[:, :3] + flat_rays[:, 3:] * resized.reshape((-1, 1))
    pcd = o3d.t.geometry.PointCloud(points)
    obj = torch.from_numpy(np.asarray(pcd.to_legacy().points)).float()

    # OpenCV loads BGR; reorder the channels to RGB and scale to [0, 1].
    seen_rgb = (torch.tensor(rgb).float() / 255)[..., [2, 1, 0]]
    seen_xyz = obj.reshape(H, W, 3)

    seg = _imread_or_raise(args.seg, cv2.IMREAD_UNCHANGED)
    mask = torch.tensor(cv2.resize(seg, (W, H))).bool()
    if not mask.any():
        raise ValueError(f'segmentation mask {args.seg!r} has no foreground pixels')

    # Background points are marked invalid with inf before normalisation.
    seen_xyz[~mask] = float('inf')
    seen_xyz = normalize(seen_xyz)

    # Bounding box of the foreground, grown by a fixed margin.  The lower
    # bounds are clamped to 0; slicing tolerates the upper bounds overshooting.
    foreground = mask.nonzero()
    bottom, right = foreground.max(dim=0)[0]
    top, left = foreground.min(dim=0)[0]

    margin = 10
    bottom = bottom + margin
    right = right + margin
    top = max(top - margin, 0)
    left = max(left - margin, 0)

    seen_xyz = seen_xyz[top:bottom+1, left:right+1]
    seen_rgb = seen_rgb[top:bottom+1, left:right+1]

    seen_xyz = pad_image(seen_xyz, float('inf'))
    seen_rgb = pad_image(seen_rgb, 0)

    # Image -> (1, 3, 800, 800).
    seen_rgb = torch.nn.functional.interpolate(
        seen_rgb.permute(2, 0, 1)[None],
        size=[800, 800],
        mode="bilinear",
        align_corners=False,
    )

    seen_xyz_ori = seen_xyz.clone()

    # Point map -> (1, 112, 112, 3).
    seen_xyz = torch.nn.functional.interpolate(
        seen_xyz.permute(2, 0, 1)[None],
        size=[112, 112],
        mode="bilinear",
        align_corners=False,
    ).permute(0, 2, 3, 1)

    # Second point map at args.xyz_size resolution, from the un-resized crop.
    seen_xyz_hr = torch.nn.functional.interpolate(
        seen_xyz_ori.permute(2, 0, 1)[None],
        size=[args.xyz_size, args.xyz_size],
        mode="bilinear",
        align_corners=False,
    ).permute(0, 2, 3, 1)

    # NOTE(review): the second entry looks like zero-filled placeholders for
    # data the demo does not have -- confirm against prepare_data_udf.
    samples = [
        [seen_xyz, seen_rgb, seen_xyz_hr],
        [torch.zeros((20000, 3)), torch.zeros((20000, 3))],
    ]
    # Prepare the data on the same device the model was moved to.
    run_viz_udf(model, samples, args.device, args, prefix=args.output)


if __name__ == '__main__':
    # Start from the parser provided by main_numcc and add the demo-only options.
    parser = main_numcc.get_args_parser()
    parser.add_argument('--image', default='demo/web/gudetama/im.png', type=str, help='input image file')
    parser.add_argument('--depth', default='demo/web/gudetama/depth.png', type=str, help='input depth file')
    parser.add_argument('--seg', default='demo/web/gudetama/seg.png', type=str, help='input seg file')
    parser.add_argument('--output', default='demo/output', type=str, help='output path')
    parser.add_argument('--checkpoint', default='pretrained/udf-ep99.pth', type=str, help='model checkpoint')
    parser.add_argument('--batch_query', default=48000, type=int, help='batch query for repulsive UDF')
    # Parsed as float so fractional fields of view are accepted; whole numbers
    # such as the default still work.
    parser.add_argument('--fov', default=35.0, type=float, help='Field of view for depth to point cloud')

    parser.set_defaults(eval=True)

    args = parser.parse_args()
    # Map the demo option names onto the attribute names the rest of the
    # code reads (args.resume, args.n_query_udf).
    args.resume = args.checkpoint
    args.n_query_udf = args.batch_query
    main_demo(args)

