import os
import sys
cur_path=os.path.abspath(os.path.dirname(__file__))
sys.path.insert(0, cur_path+"/..")

from src.utils.point import fov2focal
from src.utils.project import ray_sample
from einops import rearrange, repeat
from scripts.point_cloud import PointCloud
from plyfile import PlyData, PlyElement
import torch
import numpy as np
import cv2
import json

# Enable OpenCV's OpenEXR codec so the *.exr depth maps can be decoded.
# NOTE(review): some OpenCV builds only honour this variable when it is set
# before `import cv2`; if EXR reads fail, move this line above that import.
os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"

# Compute device: the first CUDA GPU when one is available, otherwise the CPU.
cuda_device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# Per-frame accumulators populated by the script body.
depth_name_list, depth_image_list, mask_image_list, camera_to_world_list = [], [], [], []

# Output point-cloud file and image resolution in pixels.
ply_output_path = 'lego.ply'
width, height = 800, 800

def load_depth(depth_path, max_depth=1000):
    """Read a depth image and split it into depth values and a validity mask.

    Args:
        depth_path: Path of the depth image, read with ``cv2.IMREAD_ANYDEPTH``.
        max_depth: Depths strictly greater than this are treated as invalid
            (e.g. a background sentinel such as 65535) and set to 0.

    Returns:
        ``(depth, mask)``, each with a new leading axis of size 1, i.e. (1, H, W)
        assuming cv2 returns a single-channel (H, W) image. ``mask`` is True
        where the depth is valid.

    Raises:
        FileNotFoundError: If OpenCV cannot read ``depth_path``.
    """
    img = cv2.imread(depth_path, cv2.IMREAD_ANYDEPTH)
    if img is None:
        # cv2.imread signals a missing/unreadable file by returning None.
        raise FileNotFoundError(f"could not read depth image: {depth_path}")
    invalid = img > max_depth  # depth = 65535 (or any value above max_depth) is invalid
    img[invalid] = 0
    return img[None, ...], (~invalid)[None, ...]
    
def get_intrinsics_by_fov(fov):
    """Return normalized pinhole intrinsics for a field of view given in radians.

    The focal length is expressed in units of the image size (image spans
    [0, 1]), and the principal point sits at the centre (0.5, 0.5).

    Returns:
        A 3x3 ``torch.float32`` tensor.
    """
    f = 0.5 / np.tan(fov / 2)
    K = np.eye(3, dtype=np.float32)
    K[0, 0] = K[1, 1] = f
    K[:2, 2] = 0.5
    return torch.from_numpy(K)

def compute_intrinsics(width, height, fov):
    """Build a pixel-space 3x3 pinhole matrix with square pixels.

    The focal length comes from ``fov2focal(fov, width)`` and is used for both
    axes; the principal point is the image centre. The matrix is printed,
    then returned as a float64 tensor on ``cuda_device``.
    """
    focal = fov2focal(fov, width)
    K = torch.tensor(
        [
            [focal, 0, width / 2],
            [0, focal, height / 2],
            [0, 0, 1],
        ],
        dtype=torch.float64,
        device=cuda_device,
    )
    print(K)
    return K

def depth2pcd(depth, intrinsics, pose):
    """Back-project a 2-D depth map into a world-space point cloud (NumPy path).

    Args:
        depth: 2-D depth map indexed as ``depth[row, col]``.
        intrinsics: 3x3 camera matrix; it is inverted with ``np.linalg.inv``.
        pose: Camera-to-world matrix applied to homogeneous (4, K) points,
            e.g. a 4x4 transform.

    Returns:
        ``torch.float32`` tensor of shape (K, 3): one point per pixel whose
        depth is below 65504.
    """
    K_inv = np.linalg.inv(intrinsics)
    # Overwrite the homogeneous term with -1. For a standard K (bottom row
    # [0, 0, 1]) this gives z = -depth, presumably for a -Z-forward
    # (OpenGL-style) camera -- TODO confirm.
    K_inv[2, 2] = -1
    # Flip the map upside down before taking pixel indices.
    flipped = np.flipud(depth)
    # Row/column indices of pixels below 65504 (the float16 maximum).
    rows, cols = np.where(flipped < 65504)
    # image coordinates -> camera coordinates: (u, v, 1) * depth, shape (3, K)
    scaled_pixels = np.stack([cols, rows, np.ones_like(cols)], axis=0) * flipped[rows, cols]
    cam_points = np.dot(K_inv, scaled_pixels)
    # camera coordinates -> world coordinates via homogeneous coordinates
    homogeneous = np.concatenate([cam_points, np.ones((1, cam_points.shape[1]))], axis=0)
    world_points = np.dot(pose, homogeneous).T[:, :3]
    return torch.from_numpy(world_points).to(torch.float32)

def ray_sample(cam2world_matrix, intrinsics, resolution, sensor_size=1, depth=None):
    """Back-project a square pixel grid through pinhole cameras into world space.

    Args:
        cam2world_matrix: (N, 4, 4) camera-to-world transforms.
        intrinsics: (N, 3, 3) camera matrices. fx, fy, cx, cy and the skew
            term ``intrinsics[:, 0, 1]`` are read.
        resolution: Pixels along each side of the square grid; the grid has
            M = resolution**2 samples at pixel centres.
        sensor_size: Extent of the grid in the same units as ``intrinsics``
            (e.g. ``resolution`` for pixel-space intrinsics, 1 for normalized).
        depth: Optional per-sample depth, broadcastable to (N, M, 3), e.g.
            (N, M, 1) or (N, M, 3). When None, points lie at camera-space z = 1.

    Returns:
        (N, M, 3) float32 tensor of world-space points.

    Note:
        The grid comes from ``meshgrid(indexing="ij")`` flattened row-major
        and is NOT flipped. ``x_cam`` therefore takes the grid's first (row)
        index and ``y_cam`` the second (column) index. Callers must order
        ``depth`` to match; flattening an (H, W) image as ``(w h)`` does this
        when H == W.
    """
    N, M = cam2world_matrix.shape[0], resolution ** 2
    device = cam2world_matrix.device
    fx = intrinsics[:, 0, 0].unsqueeze(-1)
    fy = intrinsics[:, 1, 1].unsqueeze(-1)
    cx = intrinsics[:, 0, 2].unsqueeze(-1)
    cy = intrinsics[:, 1, 2].unsqueeze(-1)
    sk = intrinsics[:, 0, 1].unsqueeze(-1)

    # Pixel-centre coordinates of the grid, scaled to sensor units: (2, R, R).
    axis = torch.arange(resolution, dtype=torch.float32, device=device)
    uv = torch.stack(torch.meshgrid(axis, axis, indexing="ij"))
    uv = uv * (1.0 * sensor_size / resolution) + (0.5 * sensor_size / resolution)

    # (2, R, R) -> (N, M, 2), row-major over the grid ("c h w -> b (h w) c").
    uv = uv.reshape(2, M).transpose(0, 1).unsqueeze(0).expand(N, M, 2)
    x_cam = uv[:, :, 0]
    y_cam = uv[:, :, 1]
    z_cam = torch.ones((N, M), device=device)

    # Invert the pinhole projection (including skew) at unit depth.
    x_lift = (x_cam - cx + cy * sk / fy - sk * y_cam / fy) / fx * z_cam
    y_lift = (y_cam - cy) / fy * z_cam

    cam_xyz = torch.stack((x_lift, y_lift, z_cam), dim=-1)
    if depth is not None:
        cam_xyz = cam_xyz * depth
    # Append the homogeneous 1 (kept unscaled by depth).
    cam_rel_points = torch.cat((cam_xyz, torch.ones_like(cam_xyz[..., :1])), dim=-1)
    cam_rel_points = cam_rel_points.to(torch.float32)

    world_rel_points = torch.bmm(
        cam2world_matrix, cam_rel_points.permute(0, 2, 1)
    ).permute(0, 2, 1)[:, :, :3]

    return world_rel_points

def save_ply(path: str, xyz: np.ndarray):
    """Write a point cloud to ``path`` as a PLY file with float32 x/y/z vertices.

    Args:
        path: Output file path.
        xyz: Array-like of shape (K, 3) with one point per row.

    Raises:
        ValueError: If ``xyz`` is not of shape (K, 3).
    """
    axes = ('x', 'y', 'z')
    xyz = np.asarray(xyz, dtype=np.float32)
    if xyz.ndim != 2 or xyz.shape[1] != len(axes):
        raise ValueError(f"expected xyz of shape (K, 3), got {xyz.shape}")
    elements = np.empty(xyz.shape[0], dtype=[(axis, 'f4') for axis in axes])
    # Fill each structured field column-wise instead of building one Python
    # tuple per point (this script writes ~1M points).
    for col, axis in enumerate(axes):
        elements[axis] = xyz[:, col]
    el = PlyElement.describe(elements, 'vertex')
    PlyData([el]).write(path)


# Load the transforms file of the test split.
data_dir = "data/nerf/my_synthetic/data/"
with open(os.path.join(data_dir, "transforms_test.json"), "r", encoding="utf-8") as f:
    transforms = json.load(f)

# Use (at most) the first 10 frames. Slicing, unlike indexing with range(10),
# does not fail when the split contains fewer frames.
num_frames = 10
transforms["frames"] = transforms["frames"][:num_frames]

for frame in transforms["frames"]:
    # Each frame's depth map is stored as <file_path>_depth_0001.exr.
    depth_path = os.path.join(data_dir, frame["file_path"] + "_depth_0001.exr")
    depth_image, mask_image = load_depth(depth_path)
    depth_image_list.append(depth_image)
    mask_image_list.append(mask_image)
    camera_to_world_list.append(frame["transform_matrix"])


camera_to_world = torch.tensor(camera_to_world_list, dtype=torch.float, device=cuda_device)  # [N, 4, 4]
depth_image_tensor = torch.tensor(np.array(depth_image_list), dtype=torch.float, device=cuda_device)  # [N, 1, H, W]
mask_image_tensor = torch.tensor(np.array(mask_image_list), dtype=torch.bool, device=cuda_device)  # [N, 1, H, W]

# Flatten per-pixel data column-major ("(w h)"). ray_sample's unflipped ij grid
# puts the row index in x_cam, so this ordering pairs every depth value (and
# mask entry) with the correct pixel. This only holds because width == height.
depth_image_tensor = repeat(depth_image_tensor, 'n c h w -> n (w h c) d', d=3)
valid_mask = rearrange(mask_image_tensor, 'n c h w -> n (w h c)')  # [N, H*W]

# change from OpenGL/Blender camera axes (Y up, Z back) to COLMAP (Y down, Z forward)
camera_to_world[:, :, 1:3] *= -1
world_to_camera = torch.linalg.inv(camera_to_world).to(torch.float)

# Pixel-space intrinsics shared by all frames: [N, 3, 3].
intrinsics = compute_intrinsics(width, height, transforms["camera_angle_x"])
intrinsics = repeat(intrinsics, 'a b -> n a b', n=camera_to_world.shape[0])

# pc_coords: [N, H*W, 3] world-space points, one per pixel.
pc_coords = ray_sample(camera_to_world, intrinsics, width, sensor_size=width, depth=depth_image_tensor)
# Keep only pixels with valid depth. load_depth zeroes invalid depths, which
# would otherwise collapse every background pixel onto its camera centre.
pc_coords = pc_coords[valid_mask].reshape(-1, 3)

pc_class = PointCloud(pc_coords.detach().cpu().numpy(), {})
pc_subsampled = pc_class.random_sample(1000000)
save_ply(ply_output_path, pc_subsampled.coords)