import numpy as np
import cv2


def reproj_error(
    wcoords: tuple[np.ndarray],
    pcoords: tuple[np.ndarray],
    cam_mat: np.ndarray,
    dist_coeff: np.ndarray,
    rvecs: tuple[np.ndarray],
    tvecs: tuple[np.ndarray],
) -> np.ndarray:
    """
    Calculates reprojection error for each image using the given camera parameters.

    For every image the world points are projected with `cv2.projectPoints` and
    the error is the mean Euclidean distance (in pixels) between the detected
    and the projected points.

    Args:
        wcoords: Set of points in world coordinates, one for each image.
        pcoords: Pixel coordinates corresponding to `wcoords`. Each entry may be
            laid out as (N, 2) or (N, 1, 2); both are handled.
        cam_mat: Camera intrinsic matrix.
        dist_coeff: Distortion coefficients of the camera.
        rvecs: Rotation vectors, one for each calibration image.
        tvecs: Translation vectors, one for each calibration image.

    Returns:
        A 1-D array of errors, one for each set of points in `wcoords`.
        Empty if no images are given.

    Raises:
        ValueError: If the number of pixel coordinates of an image does not
            match the number of projected points.
    """
    errors = []
    for wc, pc, r, t in zip(wcoords, pcoords, rvecs, tvecs):
        proj_pc, _ = cv2.projectPoints(wc, r, t, cam_mat, dist_coeff)
        # projectPoints returns (N, 1, 2). Flatten both sides to (N, 2) so that
        # an (N, 2) `pc` cannot silently broadcast against it into (N, N, 2).
        diff = np.reshape(pc, (-1, 2)) - np.reshape(proj_pc, (-1, 2))
        errors.append(np.linalg.norm(diff, axis=-1).mean())
    # np.asarray (unlike np.stack) also accepts an empty list.
    return np.asarray(errors)


def stereo_reproj_error(
    wcoords: tuple[np.ndarray],
    pcoords_r: tuple[np.ndarray],
    params_s: dict,
    rvecs_l: tuple[np.ndarray],
    tvecs_l: tuple[np.ndarray],
) -> np.ndarray:
    """
    Calculates stereo reprojection error for a pair of calibrated cameras.

    World points are first moved into the left camera frame using the left
    camera pose of each image pair, then projected into the right camera using
    the stereo extrinsics (`R`, `T`) and the right camera intrinsics. The error
    is the mean Euclidean distance (in pixels) to the detected right camera
    points.

    Args:
        wcoords: Set of points in world coordinates, one for each image pair.
        pcoords_r: Right camera pixel coordinates corresponding to `wcoords`.
            Each entry may be laid out as (N, 2) or (N, 1, 2).
        params_s: Dict of all stereo parameters. This is same as the dict
            returned by `calib.multi_board_stereo()`. The keys "intrR",
            "distR", "R" and "T" are read.
        rvecs_l: Left camera rotation vectors, one for each image pair.
        tvecs_l: Left camera translation vectors, one for each image pair.

    Returns:
        A 1-D array of errors, one for each set of points in `wcoords`.
        Empty if no image pairs are given.

    Raises:
        KeyError: If a required key is missing from `params_s`.
        ValueError: If the number of pixel coordinates of an image pair does
            not match the number of projected points.
    """
    k_r, dist_r = params_s["intrR"], params_s["distR"]
    R, T = params_s["R"], params_s["T"]

    errors = []
    for wc, pc, r, t in zip(wcoords, pcoords_r, rvecs_l, tvecs_l):
        rot_l = cv2.Rodrigues(r)[0]  # rotation vector -> 3x3 rotation matrix
        # World -> left camera frame: X_l = R_l @ X_w + t_l, applied row-wise.
        cc_l = np.reshape(wc, (-1, 3)) @ rot_l.T + np.ravel(t)
        proj_pc, _ = cv2.projectPoints(cc_l, R, T, k_r, dist_r)
        # projectPoints returns (N, 1, 2). Flatten both sides to (N, 2) so that
        # an (N, 2) `pc` cannot silently broadcast against it into (N, N, 2).
        diff = np.reshape(pc, (-1, 2)) - np.reshape(proj_pc, (-1, 2))
        errors.append(np.linalg.norm(diff, axis=-1).mean())
    # np.asarray (unlike np.stack) also accepts an empty list.
    return np.asarray(errors)


def validation_report(
    R: np.ndarray,
    T: np.ndarray,
    rerr_l: float | None = None,
    rerr_r: float | None = None,
    rerr_s: float | None = None,
) -> str:
    """
    Generates a validation report for the stereo parameters R and T.

    Args:
        R: Stereo Rotation matrix.
        T: Stereo Translation vector.
        rerr_l: Left camera reprojection error. If None, not included in the report.
        rerr_r: Right camera reprojection error. If None, not included in the report.
        rerr_s: Stereo camera reprojection error. If None, not included in the report.

    Returns:
        The validation report text.
    """
    R, T = R.squeeze(), T.squeeze()

    out = ""
    if rerr_l is not None:
        out += f"Left Camera Reprojection Error:  {rerr_l}\n"
    if rerr_r is not None:
        out += f"Right Camera Reprojection Error: {rerr_r}\n"
    if rerr_s is not None:
        out += f"Stereo Reprojection Error:       {rerr_s}\n"

    if out:
        out += "\n"

    E = np.cross(np.eye(3), T) @ R
    out += f"Essential Matrix, E:\n{E}\n\n"

    out += f"Singular Values of E: {np.linalg.svd(E)[1]}\n"
    out += f"det(E):  {np.linalg.det(E)}\n"
    out += f"norm(E): {np.linalg.norm(E)}\n"
    out += f"rank(E): {np.linalg.matrix_rank(E)}\n\n"

    constraint_val = 2 * (E @ E.T @ E) - np.linalg.trace(E @ E.T) * E
    out += f"2E(E.T)E - tr(E(E.T))E:\n{constraint_val}"
    return out
