import numpy as np
import torch
import cv2


def yaw_to_rotation_matrix(yaw: float) -> np.ndarray:
    """Build the 3x3 rotation matrix for a yaw angle about the camera y-axis.

    The camera frame is x-right, y-down, z-forward. The returned matrix is
    ``[[cos, 0, -sin], [0, 1, 0], [sin, 0, cos]]`` evaluated at ``yaw``
    (radians), which ``rotation_matrix_to_yaw`` inverts.
    """
    cos_y, sin_y = np.cos(yaw), np.sin(yaw)
    first_row = [cos_y, 0, -sin_y]
    middle_row = [0, 1, 0]
    last_row = [sin_y, 0, cos_y]
    return np.array([first_row, middle_row, last_row])

def rotation_matrix_to_yaw(R: np.ndarray) -> float:
    """Recover the yaw angle (radians) encoded in a y-axis rotation matrix.

    Inverse of ``yaw_to_rotation_matrix`` for the camera frame x-right,
    y-down, z-forward: the angle is ``atan2(R[2, 0], R[0, 0])``, so the
    result lies in [-pi, pi].
    """
    sin_component = R[2, 0]
    cos_component = R[0, 0]
    return np.arctan2(sin_component, cos_component)


def normalize(v):
    """Return ``v`` divided by its ``np.linalg.norm``.

    An input whose norm is exactly zero is returned unchanged (the very
    same object), so no division by zero happens.
    """
    length = np.linalg.norm(v)
    return v if length == 0 else v / length


def rotation_matrix_from_vectors(vec1, vec2):
    """Return the 3x3 rotation matrix that turns the direction of ``vec1`` onto ``vec2``.

    Uses Rodrigues' formula with the unnormalized axis ``k = v1 x v2``
    (``v1``, ``v2`` the unit-length inputs)::

        R = I + [k]_x + [k]_x @ [k]_x * (1 - cos) / sin**2

    Args:
        vec1: source 3-vector of any non-zero length.
        vec2: target 3-vector of any non-zero length.

    Returns:
        (3, 3) array ``R`` with ``R @ v1 == v2``. Parallel inputs yield the
        identity; anti-parallel inputs yield a 180-degree rotation about an
        axis perpendicular to ``vec1``.

    Raises:
        ValueError: if either input has zero length (no direction defined).
    """
    vec1 = normalize(np.asarray(vec1, dtype=float))
    vec2 = normalize(np.asarray(vec2, dtype=float))
    # normalize() hands zero vectors back unchanged, so this detects them.
    if not (np.any(vec1) and np.any(vec2)):
        raise ValueError("rotation_matrix_from_vectors: input vectors must be non-zero")

    axis = np.cross(vec1, vec2)
    cos_theta = np.dot(vec1, vec2)
    sin_sq = np.dot(axis, axis)

    if cos_theta >= 0:
        # For unit vectors (1 - cos) / sin^2 == 1 / (1 + cos); this form is
        # well conditioned here and gives the identity for parallel inputs
        # instead of the 0/0 the original expression produced.
        factor = 1.0 / (1.0 + cos_theta)
    elif sin_sq > 1e-16:
        # Near anti-parallel, sin^2 from the cross product stays accurate.
        factor = (1.0 - cos_theta) / sin_sq
    else:
        # (Nearly) anti-parallel: the axis is undefined, so rotate by pi about
        # any unit axis perpendicular to vec1 (R = 2 u u^T - I).
        helper = np.eye(3)[np.argmin(np.abs(vec1))]
        perp = normalize(np.cross(vec1, helper))
        return 2.0 * np.outer(perp, perp) - np.eye(3)

    skew_symmetric = np.array([
        [0, -axis[2], axis[1]],
        [axis[2], 0, -axis[0]],
        [-axis[1], axis[0], 0]
    ])
    return np.eye(3) + skew_symmetric + (skew_symmetric @ skew_symmetric) * factor

def depthmap_to_pts3d(depth, focal, cu=None, cv=None, mask=None):
    """Back-project a depth map into camera-space 3-D points (pinhole model).

    Each pixel (u, v) with depth d becomes
    ``((u - cu) * d / focal, (v - cv) * d / focal, d)``.

    Args:
        depth: (H, W) depth map.
        focal: focal length in pixels, used for both axes.
        cu, cv: principal point; default to the image centre
            ``((W - 1) / 2, (H - 1) / 2)``.
        mask: optional boolean mask selecting which pixels to back-project.

    Returns:
        Without ``mask``: an (H, W, 3) float64 array of points.
        With ``mask``: an (N, 3) array holding only the masked points whose
        z (depth) is non-zero, in row-major pixel order.
    """
    height, width = depth.shape
    cu = (width - 1) * 0.5 if cu is None else cu
    cv = (height - 1) * 0.5 if cv is None else cv

    cols, rows = np.meshgrid(np.arange(width), np.arange(height))
    out = np.zeros((height, width, 3))

    if mask is None:
        out[..., 0] = depth * (cols - cu) / focal
        out[..., 1] = depth * (rows - cv) / focal
        out[..., 2] = depth
        return out

    z = depth[mask]
    out[mask, 0] = z * (cols[mask] - cu) / focal
    out[mask, 1] = z * (rows[mask] - cv) / focal
    out[mask, 2] = z
    # Unmasked pixels stay at zero and are dropped together with zero-depth ones.
    return out[out[..., 2] != 0]
    

def meshgrid2d(x, y, device="cpu"):
    """Enumerate the pixel coordinates of an x-by-y grid as a flat tensor.

    Returns a float32 tensor of shape (x * y, 2) whose rows are ``[col, row]``
    pairs, with the column index varying fastest (row-major order).
    """
    col_values = torch.linspace(0.0, x - 1, x, device=device)
    row_values = torch.linspace(0.0, y - 1, y, device=device)

    cols = col_values.view(1, x).expand(y, x).reshape(-1)
    rows = row_values.view(y, 1).expand(y, x).reshape(-1)

    return torch.stack((cols, rows), dim=1).float()


def rigid_points_registration_numpy(src_pts: np.ndarray, tgt_pts: np.ndarray, weights: np.ndarray = None, compute_scaling: bool = False):
    """Estimate the rigid (or similarity) transform aligning ``src_pts`` to ``tgt_pts``.

    Closed-form weighted Kabsch / Umeyama solution of
    ``min sum_i w_i * || tgt_i - (scale * R @ src_i + t) ||^2``.

    Args:
        src_pts: (N, D) source points.
        tgt_pts: (N, D) target points in correspondence with ``src_pts``.
        weights: optional (N,) non-negative per-point weights (any numeric
            dtype). The caller's array is not modified. Defaults to uniform.
        compute_scaling: if True, also estimate a uniform scale factor.

    Returns:
        ``(R, t, scale)`` where ``R`` is a (D, D) rotation with det +1, ``t``
        is a (D,) translation and ``scale`` is the estimated scale (``1``
        when ``compute_scaling`` is False), such that
        ``tgt_pts ~= scale * src_pts @ R.T + t``.
    """
    src_pts = np.asarray(src_pts)
    tgt_pts = np.asarray(tgt_pts)
    if weights is None:
        weights = np.ones(src_pts.shape[0])

    # Normalize into a new float array: never mutate the caller's weights
    # (an in-place divide would also fail for integer arrays). The epsilon
    # guards against an all-zero weight vector.
    w = np.asarray(weights, dtype=np.float64)
    w = w / (w.sum() + 1e-12)

    # Weighted centroids sum_i w_i x_i (weights sum to one).
    src_pts_mean = w @ src_pts
    tgt_pts_mean = w @ tgt_pts
    src_pts_centered = src_pts - src_pts_mean
    tgt_pts_centered = tgt_pts - tgt_pts_mean

    cov = (w[:, None] * src_pts_centered).T @ tgt_pts_centered
    U, S, Vh = np.linalg.svd(cov)
    R = Vh.T @ U.T
    if np.linalg.det(R) < 0:
        # The best orthogonal fit is a reflection: flip the axis of the
        # smallest singular value, and flip that singular value too so the
        # Umeyama scale below uses trace(D @ S).
        Vh[-1, :] *= -1
        S[-1] *= -1
        R = Vh.T @ U.T

    if compute_scaling:
        src_variance = np.sum(w * np.sum(src_pts_centered ** 2, axis=1))
        scale = np.sum(S) / src_variance
        t = tgt_pts_mean - scale * (src_pts_mean @ R.T)
        return R, t, scale

    t = tgt_pts_mean - (src_pts_mean @ R.T)
    return R, t, 1
    

def find_relative_pose(kpts1, kpts2, K, ransac_threshold=0.5):
    """Estimate the relative rotation and translation direction between two views.

    Keypoints are normalized with the intrinsics ``K`` (shared by both
    images), an essential matrix is fitted with RANSAC via
    ``cv2.findEssentialMat``, and the pose is recovered from the RANSAC
    inliers with ``cv2.recoverPose``.

    Args:
        kpts1: (N, 2) pixel coordinates in the first image.
        kpts2: (N, 2) matching pixel coordinates in the second image.
        K: (3, 3) intrinsics matrix used for both images.
        ransac_threshold: RANSAC inlier threshold in pixels; it is converted
            to normalized image units by dividing by the mean focal length.

    Returns:
        ``(R, t)``: the rotation matrix and the squeezed translation vector
        returned by ``cv2.recoverPose`` (translation is only known up to scale).

    Raises:
        ValueError: if no essential matrix could be estimated.
    """
    principal_point = K[[0, 1], [2, 2]][None]   # (1, 2): cx, cy
    focal_lengths = K[[0, 1], [0, 1]][None]     # (1, 2): fx, fy
    kpts1 = (kpts1 - principal_point) / focal_lengths
    kpts2 = (kpts2 - principal_point) / focal_lengths

    # Pixel threshold -> normalized-coordinate threshold.
    ransac_thr = ransac_threshold / np.mean([K[0, 0], K[1, 1]])

    E, mask = cv2.findEssentialMat(
        kpts1,
        kpts2,
        np.eye(3),
        method=cv2.RANSAC,
        prob=0.999,
        threshold=ransac_thr,
        maxIters=1000
    )
    if E is None or mask is None or E.size == 0:
        raise ValueError("find_relative_pose: essential matrix estimation failed")

    mask = mask.ravel() == 1
    inliers1, inliers2 = kpts1[mask], kpts2[mask]

    # findEssentialMat may stack several candidate solutions as a (3k, 3)
    # array, while recoverPose needs a single 3x3 matrix: keep the candidate
    # with the most points passing the cheirality check.
    best_count, R, t = -1, None, None
    for E_candidate in np.split(E, E.shape[0] // 3):
        # NOTE(review): the positional 1e9 mirrors the original call; which
        # OpenCV overload it binds to (distanceThresh vs. the R output
        # placeholder) depends on the cv2 version - confirm.
        count, R_candidate, t_candidate, _ = cv2.recoverPose(
            E_candidate, inliers1, inliers2, np.eye(3), 1e9
        )
        if count > best_count:
            best_count, R, t = count, R_candidate, t_candidate
    return R, t.squeeze()



def geotrf(Trf, pts, ncol=None, norm=False):
    """ Apply a geometric transformation to a set of 2-D or 3-D points.

    Trf: numpy array or torch tensor with at least 2 dims. Its last two dims
         hold the transformation matrix; any leading dims are batch dims that
         must match the leading dims of ``pts``. If the matrix has one more
         column than ``pts`` has coordinates, each point p is transformed as
         ``Trf @ [p; 1]`` (augmented / homogeneous form).
    pts: numpy/torch/tuple/list of coordinates. Shape must be (...,2) or (...,3).
         Converted to a numpy array, or to a torch tensor of ``Trf``'s dtype,
         to match the type of ``Trf``.

    ncol: int. number of columns of the result (2 or 3). Defaults to
          ``pts.shape[-1]``.
    norm: float. if truthy, every output point is divided by its last output
          coordinate, then multiplied by ``norm`` when ``norm != 1``.

    Returns the transformed points with shape ``pts.shape[:-1] + (ncol,)``,
    keeping the first ``ncol`` output coordinates.
    """
    assert Trf.ndim >= 2
    if isinstance(Trf, np.ndarray):
        pts = np.asarray(pts)
    elif isinstance(Trf, torch.Tensor):
        # NOTE(review): only the dtype is matched; pts is not moved to Trf's device.
        pts = torch.as_tensor(pts, dtype=Trf.dtype)

    # adapt shape if necessary
    # Leading shape of the input points, used to restore the output layout.
    output_reshape = pts.shape[:-1]
    ncol = ncol or pts.shape[-1]

    # optimized code
    # Fast path: batched (B, m, m) torch matrices applied to (B, H, W, d) point maps.
    if (isinstance(Trf, torch.Tensor) and isinstance(pts, torch.Tensor) and
            Trf.ndim == 3 and pts.ndim == 4):
        d = pts.shape[3]
        if Trf.shape[-1] == d:
            pts = torch.einsum("bij, bhwj -> bhwi", Trf, pts)
        elif Trf.shape[-1] == d + 1:
            # NOTE(review): only the top d rows (linear block + translation)
            # are applied, so no homogeneous coordinate is produced and `norm`
            # divides by the d-th coordinate here, whereas the generic path
            # below keeps all rows of a square (d+1) matrix - confirm intended.
            pts = torch.einsum("bij, bhwj -> bhwi", Trf[:, :d, :d], pts) + Trf[:, None, None, :d, d]
        else:
            raise ValueError(f'bad shape, not ending with 3 or 4, for {pts.shape=}')
    else:
        if Trf.ndim >= 3:
            # Flatten Trf's batch dims to one and bring pts to (B, N, d).
            n = Trf.ndim - 2
            assert Trf.shape[:n] == pts.shape[:n], 'batch size does not match'
            Trf = Trf.reshape(-1, Trf.shape[-2], Trf.shape[-1])

            if pts.ndim > Trf.ndim:
                # Trf == (B,d,d) & pts == (B,H,W,d) --> (B, H*W, d)
                pts = pts.reshape(Trf.shape[0], -1, pts.shape[-1])
            elif pts.ndim == 2:
                # Trf == (B,d,d) & pts == (B,d) --> (B, 1, d)
                pts = pts[:, None, :]

        if pts.shape[-1] + 1 == Trf.shape[-1]:
            # Augmented matrix: computes Trf @ [p; 1] for every point; the
            # result has Trf.shape[-2] columns (all rows of Trf are kept).
            Trf = Trf.swapaxes(-1, -2)  # transpose Trf
            pts = pts @ Trf[..., :-1, :] + Trf[..., -1:, :]
        elif pts.shape[-1] == Trf.shape[-1]:
            Trf = Trf.swapaxes(-1, -2)  # transpose Trf
            pts = pts @ Trf
        else:
            # Fallback: apply the matrix to the transposed points directly.
            # NOTE(review): which input shapes are expected to reach this
            # branch is not evident from this function - confirm.
            pts = Trf @ pts.T
            if pts.ndim >= 2:
                pts = pts.swapaxes(-1, -2)

    if norm:
        # Projective division by the last output coordinate.
        pts = pts / pts[..., -1:]  # DONT DO /= BECAUSE OF WEIRD PYTORCH BUG
        if norm != 1:
            pts *= norm

    res = pts[..., :ncol].reshape(*output_reshape, ncol)
    return res
