import os
import sys
cur_path=os.path.abspath(os.path.dirname(__file__))
sys.path.insert(0, cur_path+"/..")

from src.utils.point import fov2focal
from einops import rearrange, repeat
from scripts.point_cloud import PointCloud
from plyfile import PlyData, PlyElement
import torch
import numpy as np
import cv2
import json

# Switch on OpenCV's OpenEXR support, needed for the .exr depth maps read below.
# NOTE(review): cv2 is already imported at this point; on some platforms the
# flag may have to be set before the import - confirm.
os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"

# Compute device: the first GPU when CUDA is available, otherwise the CPU.
cuda_device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# Dataset location, output file and image size.
path = 'data/gamepad'
ply_output_path = 'data/gamepad/demo.ply'
split = "train"
width = height = 256

# Accumulators filled while the dataset is read (four distinct lists).
depth_name_list, depth_image_list = [], []
mask_image_list, camera_to_world_list = [], []




def load_depth(depth_path):
    """Load a depth image and mask out its invalid pixels.

    Args:
        depth_path: Path of the depth file; it is read with
            ``cv2.IMREAD_UNCHANGED``.

    Returns:
        A ``(depth, mask)`` pair of arrays with the same shape. ``depth`` is
        the loaded image with every value above 1000 replaced by 0, and
        ``mask`` is a boolean array that is True where the depth was kept.

    Raises:
        OSError: If OpenCV could not read ``depth_path``.
    """
    img = cv2.imread(depth_path, cv2.IMREAD_UNCHANGED)
    if img is None:
        # cv2.imread reports a missing or undecodable file by returning None;
        # fail here with the path instead of a TypeError on the comparison.
        raise OSError(f"could not read depth image: {depth_path}")
    invalid = img > 1000  # depth = 65535 is invalid
    img[invalid] = 0
    return img, ~invalid

def compute_intrinsics(width, height, fov):
    """Build the 3x3 camera intrinsics matrix for a ``width`` x ``height`` image.

    The focal length is ``fov2focal(fov, width)`` and is shared by both axes;
    the principal point is the image centre ``(width / 2, height / 2)``.

    Returns:
        A float64 tensor of shape (3, 3) placed on ``cuda_device``.
    """
    focal = fov2focal(fov, width)
    rows = [
        [focal, 0, width / 2],
        [0, focal, height / 2],
        [0, 0, 1],
    ]
    return torch.tensor(rows, dtype=torch.float64, device=cuda_device)
    
def save_ply(path: str, xyz: np.ndarray):
    """Write a point cloud to a PLY file holding a single ``vertex`` element.

    Args:
        path: Destination file path.
        xyz: Array of shape (N, 3) with one x, y, z position per row; the
            values are stored as 32-bit floats.

    Raises:
        ValueError: If ``xyz`` is not a 2-D array with exactly three columns.
    """
    xyz = np.asarray(xyz)
    if xyz.ndim != 2 or xyz.shape[1] != 3:
        raise ValueError(f"xyz must have shape (N, 3), got {xyz.shape}")

    field_names = ('x', 'y', 'z')
    vertices = np.empty(xyz.shape[0], dtype=[(name, 'f4') for name in field_names])
    # Fill one column at a time rather than building a Python tuple per point.
    for column, name in enumerate(field_names):
        vertices[name] = xyz[:, column]
    PlyData([PlyElement.describe(vertices, 'vertex')]).write(path)


def ray_sample(cam2world_matrix, intrinsics, resolution, depth_tensor):
    """Back-project per-pixel depths into world-space points.

    A ``resolution`` x ``resolution`` grid of pixel centres (integer index
    + 0.5) is lifted through the inverse intrinsics onto the z = 1 plane of
    each camera, scaled by the depth and transformed with the camera-to-world
    matrix.

    Pixels are flattened with the x index varying slowest, i.e. flat index
    ``x * resolution + y``; ``depth_tensor`` must use the same ordering.

    Args:
        cam2world_matrix: (N, 4, 4) camera-to-world transforms. Its dtype and
            device are used for every tensor created here.
        intrinsics: (N, 3, 3) matrices; fx, fy, cx, cy and the skew term
            ``[0, 1]`` are read from them.
        resolution: Side length in pixels of the (square) image.
        depth_tensor: Depths that broadcast against (N, resolution**2, 3),
            e.g. shape (N, resolution**2, 1).

    Returns:
        Tensor of shape (N, resolution**2, 3) with the world-space points.
    """
    num_cams, num_pixels = cam2world_matrix.shape[0], resolution**2
    device, dtype = cam2world_matrix.device, cam2world_matrix.dtype

    # Work in one dtype throughout: torch.bmm below rejects mixed dtypes, so
    # nothing is left to implicit type promotion in stack/cat.
    intrinsics = intrinsics.to(device=device, dtype=dtype)
    depth_tensor = depth_tensor.to(device=device, dtype=dtype)

    # Keep a trailing axis on each intrinsic so it broadcasts over the pixels.
    fx = intrinsics[:, 0, 0].unsqueeze(-1)
    fy = intrinsics[:, 1, 1].unsqueeze(-1)
    cx = intrinsics[:, 0, 2].unsqueeze(-1)
    cy = intrinsics[:, 1, 2].unsqueeze(-1)
    sk = intrinsics[:, 0, 1].unsqueeze(-1)

    # Pixel centres; with "ij" indexing the first grid varies along the outer
    # axis, which gives the x-slowest flattening described above.
    centres = torch.arange(resolution, dtype=dtype, device=device) + 0.5
    x_grid, y_grid = torch.meshgrid(centres, centres, indexing="ij")
    x_cam = x_grid.reshape(1, num_pixels).expand(num_cams, num_pixels)
    y_cam = y_grid.reshape(1, num_pixels).expand(num_cams, num_pixels)
    z_cam = torch.ones((num_cams, num_pixels), dtype=dtype, device=device)

    # Invert the intrinsics (including skew) for points on the z = 1 plane.
    x_lift = (x_cam - cx + cy * sk / fy - sk * y_cam / fy) / fx * z_cam
    y_lift = (y_cam - cy) / fy * z_cam
    cam_rel_points = torch.stack((x_lift, y_lift, z_cam), dim=-1)

    # Scale by depth (points are deliberately not normalised to unit length),
    # then append w = 1 so the 4x4 transform also applies the translation.
    cam_depth_points = depth_tensor * cam_rel_points
    homogeneous = torch.cat(
        (
            cam_depth_points,
            torch.ones((num_cams, num_pixels, 1), dtype=dtype, device=device),
        ),
        dim=-1,
    )
    world_points = torch.bmm(
        cam2world_matrix, homogeneous.permute(0, 2, 1)
    ).permute(0, 2, 1)
    return world_points[:, :, :3]


# Load the dataset description; "locations" and "camera_angle_x" are read from it.
with open(os.path.join(path, "meta.json"), "r") as f:
    meta = json.load(f)

# Randomly pick up to 10 locations. The sample size is clamped so that a
# dataset with fewer than 10 locations does not make np.random.choice raise.
num_locations = len(meta["locations"])
random_idx = np.random.choice(num_locations, min(10, num_locations), replace=False)
locations = [meta["locations"][i] for i in random_idx]

for location in locations:
    frames = location["frames"]
    if not frames:
        # Without frames there is no depth file for this location; skipping it
        # avoids reusing the depth path left over from a previous location.
        print("skipping location without frames", file=sys.stderr)
        continue
    for frame in frames:
        idx = frame["name"].split(".")[0][7:]
        depth_path = os.path.join(path, f'depth_{idx}.exr')
        depth_name_list.append(depth_path)
    # One depth map is loaded per location (the one of its last frame), which
    # keeps the depth maps aligned 1:1 with the per-location transform matrices.
    depth_image, mask_image = load_depth(depth_path)
    depth_image_list.append(depth_image)
    mask_image_list.append(mask_image)
    camera_to_world_list.append(location["transform_matrix"])

if not depth_image_list:
    raise ValueError(f"no depth maps could be loaded from {path}")

camera_to_world = torch.tensor(camera_to_world_list, dtype=torch.float64, device=cuda_device)  # [N, 4, 4]

depth_array = np.array(depth_image_list)
mask_array = np.array(mask_image_list)
if depth_array.ndim == 3:
    # Single-channel images stack to [N, H, W]; add the channel axis that the
    # einops patterns below expect.
    depth_array = depth_array[..., None]
    mask_array = mask_array[..., None]

# Flatten the pixels with the width index varying slowest, which is the pixel
# order ray_sample produces.
depth_image_tensor = torch.tensor(depth_array, dtype=torch.float64, device=cuda_device)  # [N, H, W, C]
depth_image_tensor = repeat(depth_image_tensor, 'n h w c-> n (w h) c')
mask_image_tensor = torch.tensor(mask_array, dtype=torch.bool, device=cuda_device)  # [N, H, W, C]
mask_image_tensor = repeat(mask_image_tensor, 'n h w c-> n (w h) c')
# change from OpenGL/Blender camera axes (Y up, Z back) to COLMAP (Y down, Z forward)
camera_to_world[:, :3, 1:3] *= -1
# NOTE: world_to_camera is not used further in this script.
world_to_camera = torch.linalg.inv(camera_to_world).to(torch.float64)

intrinsics = compute_intrinsics(width, height, meta["camera_angle_x"])
intrinsics = repeat(intrinsics, 'a b -> n a b', n=camera_to_world.shape[0])

# ray_sample takes a single resolution, so the images are treated as
# width x width here.
pc_coords = ray_sample(camera_to_world, intrinsics, width, depth_image_tensor)
# pc_coords: N, (W H), 3
# Keep a pixel only when every channel of its depth was valid; this selects
# whole points even if the channel masks ever disagree.
valid_pixels = mask_image_tensor.bool().all(dim=-1)  # [N, (W H)]
pc_coords = pc_coords[valid_pixels].reshape(-1, 3)

# Subsample the cloud (random first, then farthest-point) and write it out.
pc_class = PointCloud(pc_coords.detach().cpu().numpy(), {})
pc_subsampled = pc_class.random_sample(1000000)
pc_fps = pc_subsampled.farthest_point_sample(10000)
save_ply(ply_output_path, pc_fps.coords)