import sys
from pathlib import Path

import cv2
import numpy as np

from .utils import image_window, imgs_in_dir
from .corners import BoardConfig, MBCornerDetector
from .metrics import reproj_error
from .types import StrPath
from .logging import log_no_prefix, log_fmt, logger as log


def multi_board_calibrate(
    img_paths: StrPath | list[StrPath],
    board: BoardConfig,
    visualize: bool = False,
    error_func=None,
) -> tuple[bool | float, dict, dict[str, float]]:
    """
    Calibrate a camera, using either ChArUco or Chessboard images.

    Every image is converted to grayscale and histogram-equalized before corner
    detection. Detected corners are refined with `cv2.cornerSubPix` and then
    passed to `cv2.calibrateCamera`.

    Args:
        img_paths: Path to a directory containing calibration images or a list of
            paths of the calibration images.
        board: A BoardConfig instance.
        visualize: Whether to show the corners on the images after detection.
            When OpenCV's window is open, you can press 'q' to terminate the
            program (via `sys.exit()`) or 'v' to turn off visualization for the
            rest of the images. Press any other key to continue.
        error_func: A function that should be called if corners are not detected in
            a particular image. It should take the image path as argument. If None,
            an Exception would be raised instead.

    Returns:
        (ret, camera_params, img_errs)

        1. ret: `False` if corners were not detected in any image. In that case
           `intr`/`dist` are None and `imgs`/`img_errs` are empty. Otherwise it
           is the first value returned by `cv2.calibrateCamera` (per the OpenCV
           docs, the overall RMS reprojection error, a float).
        2. camera_params: A dict containing the calculated camera parameters.
        3. img_errs: A dict mapping image paths (as `str`) to their reprojection
           errors. Only images in which corners were detected are included.

        `camera_params` has the following structure:
            1. `img_size`: `(width, height)` of the calibration images.
            2. `intr`: Intrinsic matrix.
            3. `dist`: Distortion coefficients.
            4. `imgs`: Dict mapping image paths to their image-specific outputs.
                The dict contains the following keys:
                a. `wc`: World coordinates.
                b. `pc`: Pixel coordinates corresponding to
                        the world coordinates.
                c. `r`: Rotation vector.
                d. `t`: Translation vector.

    Raises:
        FileNotFoundError: If `img_paths` is a directory without images, or an
            image cannot be read by OpenCV.
        ValueError: If `img_paths` is an empty list/tuple, or the corner detector
            reports an unknown board type.
        Exception: If the images do not all have the same size, or corners are
            not detected in an image and `error_func` is None.
    """
    corners_size = board.corners_size
    square_side = board.square_side
    corner_detector = MBCornerDetector(board)

    if not isinstance(img_paths, (list, tuple)):
        img_dir = img_paths
        img_paths = imgs_in_dir(img_dir)
        if not img_paths:
            raise FileNotFoundError(f"No images found at {img_dir}")
    elif not img_paths:
        raise ValueError("`img_paths` cannot be empty")

    img_paths = sorted(img_paths)

    # World coordinates of the board corners: a (cw x ch) grid on the z = 0
    # plane, spaced by the board's square side.
    cw, ch = corners_size
    objp = np.zeros((cw * ch, 3), np.float32)
    objp[:, :2] = np.mgrid[0:cw, 0:ch].T.reshape(-1, 2) * square_side

    # Sub-pixel refinement stops after 30 iterations or once the change per
    # iteration drops below 0.001.
    criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 30, 0.001)

    objpoints = []  # 3D points in real world space
    imgpoints = []  # 2D points in the image plane
    used_paths = []  # paths (str) of images whose corners were detected
    img_size = None
    for img_path in img_paths:
        img_path = Path(img_path)
        # Pass a `str`: not every OpenCV build accepts `pathlib.Path`.
        img = cv2.imread(str(img_path))
        # imread signals failure by returning None; check before converting,
        # otherwise cvtColor fails first with an unrelated OpenCV error.
        if img is None:
            raise FileNotFoundError(img_path)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)

        if img_size is None:
            img_size = img.shape[::-1]  # (width, height) of the gray image
        elif img_size != img.shape[::-1]:
            msg = (
                f"All images must be of same shape, but found "
                f"at least two different shapes [{img_size} != {img.shape[::-1]}]"
            )
            raise Exception(msg)

        # Normalizing images
        # findChessboardCorners normalize flag is not working as good as this
        img = cv2.equalizeHist(img)

        ret, corners, btype = corner_detector.detect(img)
        if not ret:
            detect_info = "NOT DETECTED"
        elif btype == "charuco":
            detect_info = "CHARUCO"
        elif btype == "chessb":
            detect_info = "CHESSBOARD"
        else:
            # Explicit raise (not `assert`) so the check survives `python -O`.
            raise ValueError(f"Invalid board type: {btype}")

        log.info(f"{img_path.name}: {detect_info}")

        if visualize:
            # NOTE(review): when `ret` is False, `corners` is whatever the
            # detector returns on failure; confirm drawChessboardCorners accepts it.
            vis_img = cv2.drawChessboardCorners(
                np.stack([img] * 3, axis=2), corners_size, corners, ret
            )
            keycode = image_window(f"{img_path.name} ({detect_info})", vis_img)
            if keycode == ord("q"):
                sys.exit()
            elif keycode == ord("v"):
                visualize = False

        if not ret:
            if error_func is None:
                raise Exception(f"Corners not detected for {img_path}")
            error_func(str(img_path))
            continue

        corners = cv2.cornerSubPix(img, corners, (11, 11), (-1, -1), criteria)
        objpoints.append(objp)
        imgpoints.append(corners)
        used_paths.append(str(img_path))

    if not imgpoints:
        params = {"img_size": img_size, "intr": None, "dist": None, "imgs": {}}
        return False, params, {}

    objpoints, imgpoints = tuple(objpoints), tuple(imgpoints)
    ret, intr, dist, rvecs, tvecs = cv2.calibrateCamera(
        objpoints, imgpoints, img_size, None, None
    )

    # Every entry of `objpoints` is the same `objp` array, so hand out copies.
    img_params = {
        p: {"wc": wc.copy(), "pc": pc, "r": r, "t": t}
        for p, wc, pc, r, t in zip(used_paths, objpoints, imgpoints, rvecs, tvecs)
    }
    params = {"img_size": img_size, "intr": intr, "dist": dist, "imgs": img_params}

    rerrs = reproj_error(objpoints, imgpoints, intr, dist, rvecs, tvecs)
    img_errs = {p: err.item() for p, err in zip(used_paths, rerrs)}
    return ret, params, img_errs


def iter_mb_calibrate(
    img_paths: StrPath | list[StrPath],
    board: BoardConfig,
    target_err: float,
    reject_per_iter: int = 1,
    max_imgs_per_iter: int = 100,
    max_iter: int | None = None,
    visualize: bool = False,
    error_func=None,
) -> tuple[bool | float, dict, dict[str, float]]:
    """
    Iterative multi-board calibration.

    Iteratively improves the calibration by rejecting the calibration images
    with high reprojection error. Each iteration keeps the images of the
    previous iteration except the `reject_per_iter` ones with the highest
    error. It then tops the set up to `max_imgs_per_iter` with images not used
    yet, and re-runs `multi_board_calibrate`.

    The loop stops when any of these happens:
        - the mean per-image error is at or below `target_err`;
        - `max_iter` iterations have run;
        - an iteration detected no corners and there are no unused images left;
        - rejecting would leave no image to calibrate with.
    In the last case the results of the previous iteration are returned.

    Args:
        img_paths: Path to a directory containing calibration images or a list of
            paths of the calibration images.
        board: A BoardConfig instance.
        target_err: The calibration process will stop when error falls at or below
            this value.
        reject_per_iter: Number of calibration images with highest error values to
            reject per iteration (at least 1). Increasing this value will speed up
            the calibration process at the cost of precision of image selection.
        max_imgs_per_iter: Maximum number of images to use for calibration per
            iteration (at least 1). This limits the memory usage during the
            calibration process.
        max_iter: Maximum number of iterations (at least 1) after which the
            function returns. None means no limit.
        visualize: Whether to show the corners on the images after detection.
            When OpenCV's window is open, you can press 'q' to terminate the
            calibration or 'v' to turn off visualization for that specific iteration.
            Press any other key to continue.
        error_func: A function that should be called if corners are not detected in
            a particular image. It should take the image path as argument. If None,
            an Exception would be raised instead.

    Returns:
        (ret, camera_params, img_errs) of the last `multi_board_calibrate` call
        whose results are kept.

        1. ret: `False` if no corners were detected. Otherwise the first value
           returned by `cv2.calibrateCamera` (the overall RMS reprojection error).
        2. camera_params: A dict containing the calculated camera parameters.
        3. img_errs: A dict mapping image paths to their reprojection errors.

        `camera_params` has the following structure:
            1. `img_size`: `(width, height)` of the calibration images.
            2. `intr`: Intrinsic matrix.
            3. `dist`: Distortion coefficients.
            4. `imgs`: Dict mapping image paths to their image-specific outputs.
                The dict contains the following keys:
                a. `wc`: World coordinates.
                b. `pc`: Pixel coordinates corresponding to
                        the world coordinates.
                c. `r`: Rotation vector.
                d. `t`: Translation vector.

    Raises:
        ValueError: For invalid numeric arguments or an empty `img_paths` list.
        FileNotFoundError: If `img_paths` is a directory without images.
        Any exception raised by `multi_board_calibrate`.
    """
    if target_err < 0:
        raise ValueError("`target_err` has to be non-negative")
    if reject_per_iter < 1:
        raise ValueError("`reject_per_iter` has to be at least 1")
    if max_imgs_per_iter < 1:
        raise ValueError("`max_imgs_per_iter` has to be at least 1")
    if max_iter is not None and max_iter < 1:
        raise ValueError("`max_iter` has to be at least 1")

    if not isinstance(img_paths, (list, tuple)):
        img_dir = Path(img_paths)
        img_paths = imgs_in_dir(img_dir)
        if not img_paths:
            raise FileNotFoundError(f"No images found at {img_dir}")
    elif not img_paths:
        raise ValueError("`img_paths` cannot be empty")

    # Normalize to `str`: the keys of `img_errs` are `str`. Mixing them with
    # `Path` objects would make the `sorted()` in multi_board_calibrate raise
    # TypeError.
    all_paths = sorted(str(p) for p in img_paths)
    img_paths = all_paths[:max_imgs_per_iter]
    rem_paths = all_paths[max_imgs_per_iter:]  # images not used yet

    max_iter = float("inf") if max_iter is None else max_iter
    rerr = float("inf")
    img_errs = None
    iter_no = 1
    # max_iter >= 1, so at least one calibration runs and `ret`/`params` are bound.
    while iter_no <= max_iter:
        if iter_no > 1:
            log_no_prefix()
        iter_prefix = f"Iter {iter_no}, "
        log_fmt.prefix += iter_prefix
        try:
            if img_errs is not None:
                # Rank by error (ascending) and drop the worst `reject_per_iter`.
                ranked = sorted(img_errs, key=img_errs.get)
                img_paths = ranked[: max(len(ranked) - reject_per_iter, 0)]
                if rem_paths:
                    add_count = max_imgs_per_iter - len(img_paths)
                    img_paths += rem_paths[:add_count]
                    rem_paths = rem_paths[add_count:]
                if not img_paths:
                    # Keep the previous iteration's results instead of failing.
                    log.info("No images left to calibrate with. Calibration stopped!")
                    break

            ret, params, img_errs = multi_board_calibrate(
                img_paths, board, visualize=visualize, error_func=error_func
            )
            if not img_errs:
                # Nothing detected in this batch: retry with unused images, if any.
                rerr = float("inf")
                if not rem_paths:
                    break
            else:
                rerr = sum(img_errs.values()) / len(img_errs)
                log.info(f"Reprojection Error: {rerr}")
        finally:
            # Restore the log prefix even on `break` or an exception.
            assert log_fmt.prefix.endswith(iter_prefix), log_fmt.prefix
            log_fmt.prefix = log_fmt.prefix[: -len(iter_prefix)]

        iter_no += 1
        if rerr <= target_err:
            break

    if rerr > target_err and iter_no > max_iter:
        if log_fmt.prefix:
            log_fmt.prefix += "| "
        log_fmt.prefix = "\n" + log_fmt.prefix

        log.info("Maximum iterations reached. Calibration stopped!")

        if log_fmt.prefix.endswith("| "):
            log_fmt.prefix = log_fmt.prefix[:-2]
        log_fmt.prefix = log_fmt.prefix[1:]

    return ret, params, img_errs
