import os
from src.data.gobjaverse import MultiViewDataset
from plyfile import PlyData, PlyElement
import numpy as np
from src.models.unet.models.position_map import get_position_map_from_depth
import torch
from einops import rearrange, repeat

def save_ply(path: str, xyz: np.ndarray):
    """Write a point cloud to ``path`` as a PLY file.

    The file contains a single ``vertex`` element whose properties are the
    float32 fields ``x``, ``y`` and ``z``.

    Args:
        path: Destination file path.
        xyz: Array-like of shape (N, 3); column 0 is x, column 1 is y and
            column 2 is z. Values are cast to float32 when stored.

    Raises:
        ValueError: If ``xyz`` is not two-dimensional with exactly 3 columns.
    """
    points = np.asarray(xyz)
    if points.ndim != 2 or points.shape[1] != 3:
        raise ValueError(f"xyz must have shape (N, 3), got {points.shape}")

    axes = ("x", "y", "z")
    vertices = np.empty(points.shape[0], dtype=[(axis, "f4") for axis in axes])
    # Fill each field from its column in one vectorized assignment instead of
    # building a Python tuple per point, which is slow for large clouds.
    for column, axis in enumerate(axes):
        vertices[axis] = points[:, column]

    PlyData([PlyElement.describe(vertices, "vertex")]).write(path)

def get_position_map(depth, cam2world_matrix, intrinsics, resolution, sensor_size=1):
    """Back-project depth maps into a per-pixel world-space position map.

    Each pixel is assigned a normalized image coordinate
    ``index * sensor_size / resolution`` (x follows the last/width axis, y the
    height axis), lifted into camera space using the intrinsics (including the
    skew term) and its depth value, and then transformed by the
    camera-to-world matrix.

    Args:
        depth: Tensor of shape (b, f, 1, h, w) holding the camera-space z of
            every pixel. ``h`` and ``w`` must be broadcast-compatible with
            ``resolution`` (in practice ``h == w == resolution``).
        cam2world_matrix: Tensor of shape (b, f, 4, 4). Its device is used for
            the pixel grid and its dtype for the returned tensor.
        intrinsics: Tensor whose last two axes form the camera matrix; entries
            [0, 0], [1, 1], [0, 2], [1, 2] and [0, 1] are read as fx, fy, cx,
            cy and skew. Expected shape is (b, f, 3, 3).
            NOTE(review): because pixel coordinates are scaled into
            [0, sensor_size), the intrinsics presumably have to be expressed
            in the same normalized units -- confirm against the dataset.
        resolution: Number of pixels along each side of the (square) grid.
        sensor_size: Extent that the ``resolution`` pixels are scaled to.

    Returns:
        Tensor of shape (b, f, h, w, 3) with the world-space xyz of every
        pixel, in the dtype of ``cam2world_matrix``.

    Raises:
        ValueError: If ``depth`` is not 5-dimensional or its channel axis
            (axis 2) does not have size 1.
    """
    # Unpacking enforces the 5-D layout; only the channel count is needed.
    _, _, channels, _, _ = depth.shape
    if channels != 1:
        raise ValueError(
            f"depth must have shape (b, f, 1, h, w), got {tuple(depth.shape)}"
        )
    z_cam = depth[:, :, 0]  # (b, f, h, w)

    # Trailing singleton axes let each per-camera scalar broadcast over (h, w).
    fx = intrinsics[..., 0, 0][..., None, None]
    fy = intrinsics[..., 1, 1][..., None, None]
    cx = intrinsics[..., 0, 2][..., None, None]
    cy = intrinsics[..., 1, 2][..., None, None]
    sk = intrinsics[..., 0, 1][..., None, None]

    pixel_axis = torch.arange(
        resolution, dtype=torch.float32, device=cam2world_matrix.device
    ) * (1.0 * sensor_size / resolution)
    # The grid is never materialized per batch/frame: x varies along the last
    # (width) axis and y along the height axis purely through broadcasting.
    x_cam = pixel_axis  # (resolution,)
    y_cam = pixel_axis[:, None]  # (resolution, 1)

    x_lift = (x_cam - cx + cy * sk / fy - sk * y_cam / fy) / fx * z_cam
    y_lift = (y_cam - cy) / fy * z_cam

    # Homogeneous camera-space points, (b, f, h, w, 4).
    cam_rel_points = torch.stack(
        (x_lift, y_lift, z_cam, torch.ones_like(z_cam)), dim=-1
    ).to(cam2world_matrix.dtype)

    # Apply each frame's 4x4 camera-to-world transform and drop the
    # homogeneous coordinate.
    world_rel_points = torch.einsum(
        "b f i j, b f h w j -> b f h w i", cam2world_matrix, cam_rel_points
    )[..., :3]

    return world_rel_points


if __name__ == "__main__":
    # Debug entry point: lift the depth maps of the first dataset sample to a
    # world-space position map and dump a point cloud plus preview images.
    resolution = 512
    dataset = MultiViewDataset(
        root_dir="data/gobjaverse/data",
        instance_file="data/gobjaverse/gobjaverse_280k_Food.json",
        bg_color="white",
        img_wh=(resolution, resolution), relative_pose=True)

    # Index the dataset once: every `dataset[0]` lookup re-runs __getitem__,
    # so a single fetch avoids repeated loading and guarantees that all
    # fields below come from the same returned sample.
    sample = dataset[0]
    # `[None]` prepends a batch axis to each per-sample tensor.
    depth = sample["depths"][None]
    c2w = sample["c2w"][None]  # f x 4 x 4 before batching
    intrinsics = sample["intrinsics"][None]
    mask = sample["masks"][None]
    # Channel-last so the mask broadcasts over the xyz axis of the position map.
    mask = rearrange(mask, "b f 1 h w -> b f h w 1")
    position_map = get_position_map(depth, c2w, intrinsics, resolution)

    # Masked-out pixels are zeroed, i.e. they all end up at the origin.
    position_map = position_map * mask
    position_map = position_map.cpu().numpy()
    points = position_map.reshape(-1, 3)
    save_ply("position_map1.ply", points)

    from PIL import Image
    # NOTE(review): plt is not used in the visible part of this script.
    import matplotlib.pyplot as plt

    # Tile all frames side by side along the width axis.
    position_map = rearrange(position_map, "b f h w c -> h (b f w) c")
    # NOTE(review): dividing by the maximum only bounds the upper end; any
    # negative coordinates are not mapped into [0, 255] before the uint8 cast
    # below. Confirm whether min-max normalization is intended.
    position_map = position_map / position_map.max()
    position_map = (position_map * 255).astype("uint8")
    position_map = Image.fromarray(position_map)
    position_map.save("position_map.png")

    # Save the depth maps, scaled by their maximum, as a grayscale strip.
    depth = depth.cpu().numpy()
    depth = (depth / depth.max() * 255).astype("uint8")
    depth = rearrange(depth, "b f 1 h w -> h (b f w)")
    depth = Image.fromarray(depth)
    depth.save("depth.png")

    # Save the masks as a grayscale strip with the same frame layout.
    mask = rearrange(mask, "b f h w 1 -> h (b f w)")
    mask = mask.cpu().numpy()
    mask = (mask * 255).astype("uint8")
    mask = Image.fromarray(mask)
    mask.save("mask.png")
