import os
import sys
# Make the repository root (the parent of this script's directory) importable
# so that the `src.*` and `scripts.*` imports below resolve.
cur_path = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, f"{cur_path}/..")

from src.utils.point import fov2focal
from einops import rearrange, repeat
from scripts.point_cloud import PointCloud
from plyfile import PlyData, PlyElement
import torch
import numpy as np
import cv2
import json
from pytorch3d.ops import sample_farthest_points

# Environment switch consumed by OpenCV; set so the .exr depth maps read in
# load_depth can be decoded.
# NOTE(review): presumably this has to be in place before the first EXR read;
# confirm for the installed OpenCV build.
os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"

# Compute device: the first CUDA GPU when one is available, otherwise the CPU.
cuda_device = (
    torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
)

# Per-frame accumulators; the depth, mask and pose lists are filled by the
# frame-loading loop further down.
depth_name_list, depth_image_list = [], []
mask_image_list, camera_to_world_list = [], []

# Sampling budget.
FPS_NUM = 2 ** 9  # points requested from farthest-point sampling (512)
RANDOM_NUM = 10 ** 7  # upper bound on points kept by the random pre-sample

# Dataset location and the input/output files inside it.
dataset_path = 'data/nerf/my_synthetic/data'
ply_filename = f'lego_{FPS_NUM}.ply'
json_filename = 'transforms_test.json'
ply_output_path = os.path.join(dataset_path, ply_filename)
json_path = os.path.join(dataset_path, json_filename)

# Image size in pixels.
width = height = 800


def load_depth(depth_path, max_valid_depth=1000):
    """Read a depth image and zero out pixels that carry no valid depth.

    Args:
        depth_path: Path of the depth image; it is read with
            cv2.IMREAD_UNCHANGED, so the stored dtype and channels are kept.
        max_valid_depth: Values strictly greater than this are treated as
            invalid (the renderer marks "no hit" pixels with a huge value
            such as 65535).

    Returns:
        A tuple ``(img, mask)``: ``img`` is the loaded array with invalid
        entries set to 0, and ``mask`` is a boolean array of the same shape
        that is True where the depth is valid.

    Raises:
        FileNotFoundError: If OpenCV could not read the file.
    """
    img = cv2.imread(depth_path, cv2.IMREAD_UNCHANGED)
    if img is None:
        # cv2.imread reports failure by returning None rather than raising;
        # fail here with the path instead of a TypeError on the comparison.
        raise FileNotFoundError(
            f"Could not read depth image (missing or unreadable): {depth_path}"
        )
    # img = cv2.resize(img, self.img_wh, interpolation=cv2.INTER_NEAREST)
    invalid = img > max_valid_depth
    img[invalid] = 0
    return img, ~invalid

def point_random_sample(pc_coords: torch.Tensor, num_points: int) -> torch.Tensor:
    """Pick up to ``num_points`` rows of ``pc_coords`` uniformly without replacement.

    Args:
        pc_coords: Tensor whose first dimension enumerates the points
            (the caller passes an (N, 3) tensor).
        num_points: Maximum number of rows to keep. When it exceeds the
            number of available rows, all rows are returned in random order.

    Returns:
        The selected rows, in random order, on the same device as
        ``pc_coords``.
    """
    # Build the permutation on the points' own device so the index tensor
    # does not have to be copied from the host for GPU inputs.
    permutation = torch.randperm(pc_coords.shape[0], device=pc_coords.device)
    return pc_coords[permutation[:num_points]]
    

def compute_intrinsics(width, height, fov):
    """Assemble a 3x3 camera intrinsics matrix, print it and return it.

    The focal length comes from ``fov2focal(fov, width)`` and is used for both
    axes; the principal point is placed at the image centre
    (``width / 2``, ``height / 2``). The matrix is a float tensor created on
    ``cuda_device``.
    NOTE(review): fov is presumably the horizontal field of view in radians
    (the caller passes ``camera_angle_x``); confirm against fov2focal.
    """
    focal = fov2focal(fov, width)
    centre_x, centre_y = width / 2, height / 2
    rows = [
        [focal, 0, centre_x],
        [0, focal, centre_y],
        [0, 0, 1],
    ]
    K = torch.tensor(rows, dtype=torch.float, device=cuda_device)
    print(K)
    return K
    
def save_ply(path: str, xyz: np.ndarray):
    """Write a set of 3-D points to ``path`` as a PLY file.

    The file contains a single 'vertex' element with the float32 properties
    'x', 'y' and 'z'.

    Args:
        path: Destination file path.
        xyz: Array of shape (N, 3) holding one point per row.

    Raises:
        ValueError: If ``xyz`` is not two-dimensional with three columns.
    """
    xyz = np.asarray(xyz)
    if xyz.ndim != 2 or xyz.shape[1] != 3:
        raise ValueError(f"xyz must have shape (N, 3), got {xyz.shape}")

    field_names = ('x', 'y', 'z')
    vertices = np.empty(
        xyz.shape[0], dtype=[(name, 'f4') for name in field_names]
    )
    # Fill each field from its column in one vectorised assignment instead of
    # converting every row to a Python tuple.
    for column, name in enumerate(field_names):
        vertices[name] = xyz[:, column]

    PlyData([PlyElement.describe(vertices, 'vertex')]).write(path)


def ray_sample(cam2world_matrix, intrinsics, resolution, depth_tensor):
    """Back-project per-pixel depths of square images into world-space points.

    Every pixel centre is lifted through the inverse pinhole model (skew
    included) to a camera-space point with z = 1, scaled by the pixel's
    depth, and then moved to world space with the camera-to-world transform.

    Pixel order: the flattened pixel index is ``a * resolution + b`` where
    ``a`` drives the x (u) coordinate and ``b`` drives the y (v) coordinate,
    i.e. x varies slowest. ``depth_tensor`` must be flattened the same way.

    Args:
        cam2world_matrix: (N, 4, 4) camera-to-world matrices; only the top
            three rows (rotation and translation) are used.
        intrinsics: (N, 3, 3) intrinsics; fx, fy, cx, cy and the skew entry
            [0, 1] are read.
        resolution: Side length of the square image in pixels.
        depth_tensor: Per-pixel depth along the camera z axis, broadcastable
            against (N, resolution**2, 3), e.g. (N, resolution**2, 1).

    Returns:
        Tensor of shape (N, resolution**2, 3) with the world-space points.
    """
    device = cam2world_matrix.device

    # Pixel-centre coordinates. With "ij" indexing the first grid varies along
    # the slow axis of the flattened index; that grid supplies x.
    pixel_centers = (
        torch.arange(resolution, dtype=torch.float32, device=device) + 0.5
    )
    x_grid, y_grid = torch.meshgrid(pixel_centers, pixel_centers, indexing="ij")
    # Shape (1, M): shared by all cameras and broadcast against the (N, 1)
    # intrinsics below instead of being copied N times.
    x_cam = x_grid.reshape(1, -1)
    y_cam = y_grid.reshape(1, -1)

    fx = intrinsics[:, 0, 0].unsqueeze(-1)
    fy = intrinsics[:, 1, 1].unsqueeze(-1)
    cx = intrinsics[:, 0, 2].unsqueeze(-1)
    cy = intrinsics[:, 1, 2].unsqueeze(-1)
    sk = intrinsics[:, 0, 1].unsqueeze(-1)

    # Inverse of the intrinsics applied to (u, v, 1), skew term included.
    x_lift = (x_cam - cx + cy * sk / fy - sk * y_cam / fy) / fx
    y_lift = (y_cam - cy) / fy
    cam_rel_points = torch.stack(
        (x_lift, y_lift, torch.ones_like(x_lift)), dim=-1
    )
    # Rays are deliberately not normalised: the depth scales the z = 1 point.
    cam_depth_points = depth_tensor * cam_rel_points

    # world = R @ p + t, written for row vectors as p @ R^T + t.
    rotation = cam2world_matrix[:, :3, :3]
    cam_locs_world = cam2world_matrix[:, :3, 3]
    world_rel_depth_points = (
        torch.matmul(cam_depth_points, rotation.transpose(1, 2))
        + cam_locs_world.unsqueeze(1)
    )
    return world_rel_depth_points


# Load the camera metadata: the shared field of view and one pose per frame.
with open(json_path, "r", encoding="utf-8") as f:
    transforms = json.load(f)

# To debug on a subset of views, filter the frame list here, e.g.
#   transforms["frames"] = [transforms["frames"][i] for i in (129, 134)]

# Collect the depth map, validity mask and camera-to-world pose of every frame.
for frame in transforms["frames"]:
    # The depth map sits next to the frame's image: same path stem with a
    # "_depth_0001.exr" suffix, resolved against the dataset directory.
    depth_path = os.path.join(
        dataset_path, frame["file_path"] + "_depth_0001.exr"
    )
    depth_image, mask_image = load_depth(depth_path)
    depth_image_list.append(depth_image)
    mask_image_list.append(mask_image)
    camera_to_world_list.append(frame["transform_matrix"])


# Stack the per-frame data into batched tensors on the compute device.
camera_to_world = torch.tensor(camera_to_world_list, dtype=torch.float, device=cuda_device)  # [N, 4, 4]

# The stacked images are laid out [N, H, W, C]. Pixels are flattened with W as
# the slow axis because ray_sample enumerates pixels with x varying slowest.
depth_image_tensor = torch.tensor(np.array(depth_image_list), dtype=torch.float, device=cuda_device)  # [N, H, W, C]
depth_image_tensor = rearrange(depth_image_tensor, 'n h w c -> n (w h) c')  # [N, W*H, C]
mask_image_tensor = torch.tensor(np.array(mask_image_list), dtype=torch.bool, device=cuda_device)  # [N, H, W, C]
mask_image_tensor = rearrange(mask_image_tensor, 'n h w c -> n (w h) c')  # [N, W*H, C]

# change from OpenGL/Blender camera axes (Y up, Z back) to COLMAP (Y down, Z forward)
camera_to_world[:, :3, 1:3] *= -1
# NOTE(review): world_to_camera is not referenced again in the visible part of
# this file; kept so the module-level name stays available.
world_to_camera = torch.linalg.inv(camera_to_world).to(torch.float)

# One shared intrinsics matrix, replicated for every camera.
intrinsics = compute_intrinsics(width, height, transforms["camera_angle_x"])
intrinsics = repeat(intrinsics, 'a b -> n a b', n=camera_to_world.shape[0])  # [N, 3, 3]

# Back-project every pixel to a world-space point.
pc_coords = ray_sample(camera_to_world, intrinsics, width, depth_image_tensor)  # [N, W*H, 3]

# Keep only points whose depth is valid in every channel. Selecting whole rows
# with a per-pixel mask yields [K, 3] directly and does not depend on the mask
# having exactly three identical channels.
valid_points = mask_image_tensor.all(dim=-1)  # [N, W*H]
pc_coords = pc_coords[valid_points]  # [K, 3]

# Thin the cloud: a cheap random pre-sample, then farthest-point sampling.
pc_random = point_random_sample(pc_coords, RANDOM_NUM)
pc_fps, _ = sample_farthest_points(points=pc_random.reshape(1, -1, 3), K=FPS_NUM)
pc_fps = pc_fps.reshape(-1, 3)
save_ply(ply_output_path, pc_fps.detach().cpu().numpy())
