import os
import sys
cur_path=os.path.abspath(os.path.dirname(__file__))
sys.path.insert(0, cur_path+"/..")

from src.utils.point import fov2focal
from einops import rearrange, repeat
from scripts.point_cloud import PointCloud
from plyfile import PlyData, PlyElement
import torch
import numpy as np
import cv2
import json

# Opt in to OpenCV's OpenEXR codec so .exr depth maps can be decoded.
# NOTE(review): cv2 is imported above this line; some OpenCV builds may only
# honour this flag when it is set before `import cv2` — confirm if EXR reads fail.
os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"

# Run on the first CUDA GPU when one is visible, otherwise on the CPU.
cuda_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Accumulators populated while scanning meta.json further down the script.
depth_name_list, depth_image_list, mask_image_list, camera_to_world_list = [], [], [], []

# Input scene, output file and the expected (square) depth-map size.
split = "train"
path = 'data/demo/0e0a38f2d5424c00b7ce9a121a4d4017'
ply_output_path = 'data/demo/0e0a38f2d5424c00b7ce9a121a4d4017/demo.ply'
width, height = 512, 512




def load_depth(depth_path, max_depth=1000):
    """Read a depth image and mask out invalid pixels.

    The image is read with ``cv2.IMREAD_ANYDEPTH`` so 16-bit / float data is
    kept as stored (no rescaling). Pixels whose value exceeds ``max_depth``
    (e.g. the 65535 "no hit" value of a 16-bit map) or is non-finite are
    treated as invalid: they are set to 0 in the returned depth map and
    marked False in the returned mask.

    Args:
        depth_path: path of the depth image file.
        max_depth: largest depth value still considered valid (strict ``>``
            comparison). Defaults to 1000, the threshold this script used.

    Returns:
        ``(depth, mask)`` with a leading singleton axis added, i.e.
        ``(1, H, W)`` for a single-channel image; ``mask`` is True where the
        depth is valid.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``depth_path``.
    """
    img = cv2.imread(depth_path, cv2.IMREAD_ANYDEPTH)
    if img is None:
        # cv2.imread reports missing/unreadable files by returning None, not raising.
        raise FileNotFoundError(f"could not read depth image: {depth_path}")
    # isfinite also catches NaN, which a plain `> max_depth` test would keep as valid.
    invalid = ~np.isfinite(img) | (img > max_depth)
    img[invalid] = 0
    return img[None, ...], ~invalid[None, ...]

def compute_intrinsics(width, height, fov, device=None):
    """Build a 3x3 pinhole intrinsics matrix with square pixels.

    Args:
        width: image width in pixels; also passed to ``fov2focal``.
        height: image height in pixels.
        fov: field of view handed to ``fov2focal(fov, width)`` (presumably the
            horizontal FOV in radians — TODO confirm in src.utils.point).
        device: torch device for the result; ``None`` (default) uses the
            module-level ``cuda_device`` as before.

    Returns:
        float64 tensor ``[[fx, 0, cx], [0, fy, cy], [0, 0, 1]]`` with
        ``fy == fx`` and the principal point at the image centre.
    """
    if device is None:
        device = cuda_device
    fx = fov2focal(fov, width)
    fy = fx  # square pixels: the same focal length (in pixels) on both axes
    cx, cy = width / 2, height / 2
    # No debug print here: callers that want to inspect the matrix can print it.
    return torch.tensor(
        [[fx, 0, cx], [0, fy, cy], [0, 0, 1]], dtype=torch.float64, device=device
    )
    
def save_ply(path: str, xyz: np.ndarray):
    """Write an ``(N, 3)`` array of points to ``path`` as a vertex-only PLY file.

    Coordinates are stored as float32 vertex properties ``x``, ``y``, ``z``.

    Raises:
        ValueError: if ``xyz`` is not two-dimensional with exactly 3 columns.
    """
    xyz = np.asarray(xyz)
    if xyz.ndim != 2 or xyz.shape[1] != 3:
        raise ValueError(f"expected xyz of shape (N, 3), got {xyz.shape}")
    axes = ('x', 'y', 'z')
    vertices = np.empty(xyz.shape[0], dtype=[(axis, 'f4') for axis in axes])
    # Fill each structured field with a whole column instead of building N tuples.
    for column, axis in enumerate(axes):
        vertices[axis] = xyz[:, column]
    PlyData([PlyElement.describe(vertices, 'vertex')]).write(path)


def ray_sample(cam2world_matrix, intrinsics, resolution, depth_tensor):
    """Back-project per-pixel depths of square images into world-space points.

    Args:
        cam2world_matrix: ``(N, 4, 4)`` camera-to-world transforms; only the
            top three rows (rotation and translation) are used. All maths is
            done in this tensor's dtype.
        intrinsics: ``(N, 3, 3)`` pinhole matrices in pixel units (``fx`` at
            [0, 0], ``fy`` at [1, 1], ``cx``/``cy`` at [0, 2]/[1, 2], skew at [0, 1]).
        resolution: side length of the square image; ``M = resolution**2`` rays.
        depth_tensor: per-ray depths broadcastable to ``(N, M, 3)``, e.g.
            ``(N, M, 1)`` or ``(N, M, 3)``. Ray ``k`` samples the pixel centre
            ``x = k // resolution + 0.5``, ``y = k % resolution + 0.5``, so the
            depth map must be flattened with x (the column) as the slow axis,
            as the ``'n c h w -> n (w h c) d'`` rearrangement in this script does.
            Each depth is applied along the unit-length ray direction, i.e. it
            is treated as a distance along the ray (TODO confirm the renderer
            stores ray distance rather than planar z).

    Returns:
        ``(N, M, 3)`` world-space points in ``cam2world_matrix``'s dtype.
    """
    num_views = cam2world_matrix.shape[0]
    num_rays = resolution ** 2
    dtype = cam2world_matrix.dtype
    device = cam2world_matrix.device

    # Work in one dtype throughout; mixing float32/float64 previously relied on
    # implicit promotion in stack/cat and broke bmm for float32 matrices.
    intrinsics = intrinsics.to(dtype)
    fx = intrinsics[:, 0, 0].unsqueeze(-1)
    fy = intrinsics[:, 1, 1].unsqueeze(-1)
    cx = intrinsics[:, 0, 2].unsqueeze(-1)
    cy = intrinsics[:, 1, 2].unsqueeze(-1)
    sk = intrinsics[:, 0, 1].unsqueeze(-1)

    # Pixel-centre grid. With "ij" indexing and row-major flattening the first
    # meshgrid output varies slowest along k; it is used as x below.
    centres = torch.arange(resolution, dtype=dtype, device=device) + 0.5
    slow, fast = torch.meshgrid(centres, centres, indexing="ij")
    x_cam = slow.reshape(1, num_rays).expand(num_views, num_rays)
    y_cam = fast.reshape(1, num_rays).expand(num_views, num_rays)

    # Invert the skewed pinhole projection onto the z = 1 plane.
    y_lift = (y_cam - cy) / fy
    x_lift = (x_cam - cx - sk * y_lift) / fx
    cam_rel_points = torch.stack((x_lift, y_lift, torch.ones_like(x_lift)), dim=-1)
    ray_dirs = torch.nn.functional.normalize(cam_rel_points, dim=-1)

    cam_depth_points = depth_tensor.to(dtype) * ray_dirs  # (N, M, 3), camera frame

    # Rigid transform to world: p_world = R @ p_cam + t, applied row-wise.
    rotation = cam2world_matrix[:, :3, :3]
    translation = cam2world_matrix[:, :3, 3]
    return cam_depth_points @ rotation.transpose(1, 2) + translation.unsqueeze(1)


# Alternative NeRF-synthetic layout:
# with open(os.path.join(path, "transforms_{}.json".format(split)), "r") as f:
with open(os.path.join(path, "meta.json"), "r") as f:
    meta = json.load(f)

# Only the first location (view) is used for now; widen the slice for more.
for location_index, locations in enumerate(meta["locations"][:1]):
    frames = locations["frames"]
    depth_paths = [
        os.path.join(path, frame["name"]) for frame in frames if frame["type"] == "depth"
    ]
    if not depth_paths:
        # Otherwise a stale depth_path from a previous location (or a NameError) would be used.
        raise ValueError(f"location {location_index} in meta.json has no depth frame")
    depth_name_list.extend(depth_paths)
    # If a location lists several depth frames, the last one is loaded (unchanged behaviour).
    depth_path = depth_paths[-1]
    depth_image, mask_image = load_depth(depth_path)
    depth_image_list.append(depth_image)
    mask_image_list.append(mask_image)
    camera_to_world_list.append(locations["transform_matrix"])


if not camera_to_world_list:
    raise ValueError(f"no views were loaded from {os.path.join(path, 'meta.json')}")

camera_to_world = torch.tensor(camera_to_world_list, dtype=torch.float64, device=cuda_device) # [N, 4, 4]
depth_image_tensor = torch.tensor(np.array(depth_image_list), dtype=torch.float64, device=cuda_device) # [N, 1, H, W]
# ray_sample() assumes square resolution x resolution images; fail loudly rather
# than silently pairing depths with the wrong rays.
depth_hw = tuple(depth_image_tensor.shape[-2:])
if width != height or depth_hw != (height, width):
    raise ValueError(f"expected square {height}x{width} depth maps, got {depth_hw}")
# Flatten with w as the slow axis to match ray_sample()'s ray ordering.
depth_image_tensor = repeat(depth_image_tensor, 'n c h w -> n (w h c) d', d=3) # [N, M, 3]
mask_image_tensor = torch.tensor(np.array(mask_image_list), dtype=torch.bool, device=cuda_device) # [N, 1, H, W]
mask_image_tensor = rearrange(mask_image_tensor, 'n c h w -> n (w h c)') # [N, M], same order as depths
# change from OpenGL/Blender camera axes (Y up, Z back) to COLMAP (Y down, Z forward)
camera_to_world[:, :3, 1:3] *= -1
world_to_camera = torch.linalg.inv(camera_to_world).to(torch.float64)

intrinsics = torch.tensor(meta["K_matrix"], dtype=torch.float64, device=cuda_device)
intrinsics = repeat(intrinsics, 'a b -> n a b', n=camera_to_world.shape[0])

# pc_coords: N, (H*W), 3 world-space points, one per pixel
pc_coords = ray_sample(camera_to_world, intrinsics, width, depth_image_tensor)
# Keep only pixels with valid depth: boolean (N, M) index -> (K, 3)
pc_coords = pc_coords[mask_image_tensor]

pc_class = PointCloud(pc_coords.detach().cpu().numpy(), {})
pc_subsampled = pc_class.random_sample(1000000)
pc_fps = pc_subsampled.farthest_point_sample(10000)
save_ply(ply_output_path, pc_fps.coords)