from vgn.perception import *


def pixel2cam(u, v, depth, intrinsic) -> np.ndarray:
    """Back-project a pixel with a depth value into a 3-vector.

    Computes ``inv(K) @ [u, v, 1]`` and scales the result by ``depth``.

    Args:
        u: Pixel coordinate placed first in the homogeneous pixel vector.
        v: Pixel coordinate placed second in the homogeneous pixel vector.
        depth: Scalar multiplied onto the un-projected vector.
        intrinsic: A ``CameraIntrinsic`` or a ``(3, 3)`` ``np.ndarray``.

    Returns:
        The resulting length-3 array.

    Raises:
        ValueError: If ``intrinsic`` is neither of the accepted forms.
    """
    K = _get_intrinsic_matrix(intrinsic)
    homogeneous_pixel = np.array([u, v, 1])
    ray = np.linalg.inv(K) @ homogeneous_pixel
    return ray * depth


def cam2pixel(x, y, z, intrinsic) -> np.ndarray:
    """Project a 3D point through the intrinsic matrix to integer pixels.

    Computes ``K @ [x, y, z]``, divides by the third component, and returns
    the first two components converted to integers.

    Args:
        x: First coordinate of the point.
        y: Second coordinate of the point.
        z: Third coordinate of the point.
        intrinsic: A ``CameraIntrinsic`` or a ``(3, 3)`` ``np.ndarray``.

    Returns:
        A length-2 integer array. The conversion truncates toward zero; it
        does not round to the nearest pixel.

    Raises:
        ValueError: If ``intrinsic`` is neither of the accepted forms.
    """
    intrinsic = _get_intrinsic_matrix(intrinsic)
    pos = intrinsic @ np.array([x, y, z])
    # NOTE(review): there is no guard for a zero third component; the
    # division then produces inf/nan before the integer cast.
    pos = pos / pos[2]
    # ``np.int`` was only an alias of the builtin ``int`` and was removed in
    # NumPy 1.24; using ``int`` keeps the same dtype on every NumPy version.
    return pos[:2].astype(int)


def cam2world(x, y, z, extrinsic) -> np.ndarray:
    """Apply the inverse of the extrinsic matrix to a 3D point.

    The point is lifted to homogeneous form ``[x, y, z, 1]``, multiplied by
    ``inv(extrinsic)``, and the first three components are returned.

    Args:
        x: First coordinate of the point.
        y: Second coordinate of the point.
        z: Third coordinate of the point.
        extrinsic: A ``Transform`` or a ``(4, 4)`` ``np.ndarray``.

    Returns:
        The transformed point as a length-3 array.

    Raises:
        ValueError: If ``extrinsic`` is neither of the accepted forms.
    """
    matrix = _get_extrinsic_matrix(extrinsic)
    homogeneous = np.array([x, y, z, 1])
    transformed = np.linalg.inv(matrix) @ homogeneous
    return transformed[:3]


def world2cam(U, V, W, extrinsic) -> np.ndarray:
    """Apply the extrinsic matrix to a 3D point.

    The point is lifted to homogeneous form ``[U, V, W, 1]``, multiplied by
    the extrinsic matrix, and the first three components are returned.

    Args:
        U: First coordinate of the point.
        V: Second coordinate of the point.
        W: Third coordinate of the point.
        extrinsic: A ``Transform`` or a ``(4, 4)`` ``np.ndarray``.

    Returns:
        The transformed point as a length-3 array.

    Raises:
        ValueError: If ``extrinsic`` is neither of the accepted forms.
    """
    matrix = _get_extrinsic_matrix(extrinsic)
    homogeneous = np.array([U, V, W, 1])
    return (matrix @ homogeneous)[:3]


def pixel2world(u, v, depth, intrinsic, extrinsic) -> np.ndarray:
    """Chain ``pixel2cam`` and ``cam2world`` for a pixel with depth.

    Args:
        u: Pixel coordinate forwarded to ``pixel2cam``.
        v: Pixel coordinate forwarded to ``pixel2cam``.
        depth: Depth value forwarded to ``pixel2cam``.
        intrinsic: Intrinsic parameters forwarded to ``pixel2cam``.
        extrinsic: Extrinsic parameters forwarded to ``cam2world``.

    Returns:
        The length-3 array produced by ``cam2world``.
    """
    cam_point = pixel2cam(u, v, depth, intrinsic)
    return cam2world(*cam_point, extrinsic)


def world2pixel(U, V, W, intrinsic, extrinsic) -> np.ndarray:
    """Chain ``world2cam`` and ``cam2pixel`` for a 3D point.

    Args:
        U: First coordinate forwarded to ``world2cam``.
        V: Second coordinate forwarded to ``world2cam``.
        W: Third coordinate forwarded to ``world2cam``.
        intrinsic: Intrinsic parameters forwarded to ``cam2pixel``.
        extrinsic: Extrinsic parameters forwarded to ``world2cam``.

    Returns:
        The length-2 array produced by ``cam2pixel``.
    """
    cam_point = world2cam(U, V, W, extrinsic)
    return cam2pixel(*cam_point, intrinsic)


def _get_intrinsic_matrix(intrinsic):
    """Normalize an intrinsic argument to a matrix.

    Args:
        intrinsic: A ``CameraIntrinsic`` (its ``K`` attribute is returned)
            or a ``(3, 3)`` ``np.ndarray`` (returned unchanged).

    Returns:
        The intrinsic matrix.

    Raises:
        ValueError: If ``intrinsic`` is neither of the accepted forms.
    """
    if isinstance(intrinsic, CameraIntrinsic):
        return intrinsic.K

    is_raw_matrix = isinstance(intrinsic, np.ndarray) and intrinsic.shape == (3, 3)
    if not is_raw_matrix:
        raise ValueError("intrinsic error")
    return intrinsic


def _get_extrinsic_matrix(extrinsic):
    """Normalize an extrinsic argument to a matrix.

    Args:
        extrinsic: A ``Transform`` (the result of its ``as_matrix()`` is
            returned) or a ``(4, 4)`` ``np.ndarray`` (returned unchanged).

    Returns:
        The extrinsic matrix.

    Raises:
        ValueError: If ``extrinsic`` is neither of the accepted forms.
    """
    if isinstance(extrinsic, Transform):
        return extrinsic.as_matrix()

    is_raw_matrix = isinstance(extrinsic, np.ndarray) and extrinsic.shape == (4, 4)
    if not is_raw_matrix:
        raise ValueError("extrinsic error")
    return extrinsic
