import sys
import os
import glob
from pathlib import Path

import cv2
import numpy as np
import torch
from PIL import Image
from transformers import AutoModel, AutoImageProcessor
import transformers

from .utils import image_window, imgs_in_dir
from .corners import BoardConfig, MBCornerDetector
from .metrics import stereo_reproj_error
from .types import StrPath
from .logging import log_no_prefix, log_fmt, logger as log


def multi_board_stereo(
    left_img_paths: StrPath | list[StrPath],
    right_img_paths: StrPath | list[StrPath],
    params_l: dict,
    params_r: dict,
    board: BoardConfig,
    visualize: bool = False,
) -> tuple[bool, dict, dict[str, float]]:
    """
    Calibrate a stereo setup with 2 cameras, using either ChArUco or Chessboards.

    The per-camera intrinsics are kept fixed (`cv2.CALIB_FIX_INTRINSIC`); only
    the relative pose between the two cameras is estimated.

    If directory paths are given, this function will ignore all those files in the
    left directory which do not have a corresponding (readable) file with the same
    name in the right directory.

    If lists of paths are given, then the left and right lists should align so
    that corresponding entries are the images taken at the same time. Entries
    may be `str` or `Path`; they are converted to `Path` internally.

    Args:
        left_img_paths: Path to a directory containing left calibration images or
            a list of paths of the left calibration images.
        right_img_paths: Path to a directory containing right calibration images or
            a list of paths of the right calibration images.
        params_l: Dict containing left camera parameters. This is same as the
            second return value of `calib.multi_board_calibrate()`. The keys
            `"intr"`, `"dist"` and `"imgs"` are read here.
        params_r: Dict containing right camera parameters. This is same as the
            second return value of `calib.multi_board_calibrate()`. The keys
            `"intr"`, `"dist"` and `"imgs"` are read here.
        board: A BoardConfig instance.
        visualize: Whether to show the corners on the images after detection.
            When OpenCV's window is open, you can press 'q' to terminate the
            calibration (the process exits via `sys.exit()`) or 'v' to turn off
            visualization. Press any other key to continue.

    Returns:
        (ret, params_s, img_errs)

        1. ret: Boolean indicating whether the calibration was successful.
           It is False when the board was not detected in both images of any
           pair.
        2. params_s: A dict containing the calculated stereo parameters.
        3. img_errs: A dict mapping left image paths (as `str`) to their stereo
           reprojection errors. Only pairs that are present in both
           `params_l["imgs"]` and `params_r["imgs"]` get an entry.

        `params_s` has the following structure (values come straight from
        `cv2.stereoCalibrate`; all of them except `img_size` are None when
        `ret` is False):
            1. `img_size`: `(width, height)` of the calibration images, or None
                if no image pair could be read.
            2. `intrL`: 3x3 intrinsic camera matrix for the left camera.
            3. `intrR`: 3x3 intrinsic camera matrix for the right camera.
            4. `distL`: Distortion coefficients for the left camera.
            5. `distR`: Distortion coefficients for the right camera.
            6. `R`: 3x3 rotation matrix that describes the orientation between
                the cameras.
            7. `T`: Translation between the cameras.

    Raises:
        TypeError: If the two path arguments are not of the same type.
        FileNotFoundError: If a directory is given and it contains no images.
        ValueError: If lists are given and their lengths differ.
        Exception: If a left image cannot be read, or the images of a pair
            differ in shape.
    """
    if not isinstance(left_img_paths, type(right_img_paths)):
        raise TypeError("`left_img_paths` and `right_img_paths` must be of same type")

    corners_size = board.corners_size
    square_side = board.square_side
    corner_detector = MBCornerDetector(board)

    if not isinstance(left_img_paths, (list, tuple)):
        left_img_dir = Path(left_img_paths)
        # Check for emptiness *before* touching the first element so that an
        # empty directory yields FileNotFoundError rather than IndexError.
        left_img_paths = sorted(Path(p) for p in imgs_in_dir(left_img_dir))
        if not left_img_paths:
            raise FileNotFoundError(f"No images found at {left_img_dir}")

        right_img_dir = Path(right_img_paths)
        right_img_paths = [right_img_dir / p.name for p in left_img_paths]
    else:
        if len(left_img_paths) != len(right_img_paths):
            raise ValueError(
                "`left_img_paths` and `right_img_paths` should be of same length"
            )
        # The code below relies on `Path.name`, so accept `str` entries too.
        left_img_paths = [Path(p) for p in left_img_paths]
        right_img_paths = [Path(p) for p in right_img_paths]

    criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 30, 0.001)

    # Board corner coordinates in the board's own frame (Z = 0 plane), scaled
    # by the physical square size.
    cw, ch = corners_size
    objp = np.zeros((cw * ch, 3), np.float32)
    grid = np.mgrid[0:cw, 0:ch].T.reshape(-1, 2)
    objp[:, :2] = grid * square_side

    objpoints = []  # 3D points in real world space
    imgpointsL = []  # 2D points for left camera
    imgpointsR = []  # 2D points for right camera
    used_left_paths, used_right_paths = [], []
    # (width, height) of the last successfully loaded grayscale pair. Tracked
    # explicitly instead of reading `imgL` after the loop, which could be an
    # unconverted 3-channel image (or undefined) when a right image is missing.
    img_size = None

    cv2_log_level = cv2.getLogLevel()
    # Temporarily lower OpenCV's verbosity so that unreadable right images do
    # not spam the log. NOTE(review): 2 is presumably LOG_LEVEL_ERROR.
    cv2.setLogLevel(2)
    try:
        for left_path, right_path in zip(left_img_paths, right_img_paths):
            # `str()` keeps this working on OpenCV builds that reject PathLike.
            imgL = cv2.imread(str(left_path))
            imgR = cv2.imread(str(right_path))
            if imgL is None:
                raise Exception(f"Can't read {left_path}")
            if imgR is None:
                # No (readable) right counterpart: skip this pair.
                continue

            imgL = cv2.cvtColor(imgL, cv2.COLOR_BGR2GRAY)
            imgR = cv2.cvtColor(imgR, cv2.COLOR_BGR2GRAY)
            if imgL.shape != imgR.shape:
                raise Exception(
                    f"{left_path.name}: Both left and right images must have same shape"
                )
            img_size = imgL.shape[::-1]

            # Normalizing images
            # findChessboardCorners normalize flag is not working as good as this
            imgL = cv2.equalizeHist(imgL)
            imgR = cv2.equalizeHist(imgR)

            retL, cornersL, btypeL = corner_detector.detect(imgL)
            assert btypeL in ("charuco", "chessb"), f"Invalid board type: {btypeL}"
            if retL:
                boardL = "CHESSBOARD" if btypeL == "chessb" else "CHARUCO"
            else:
                boardL = "NOT DETECTED"

            retR, cornersR, btypeR = corner_detector.detect(imgR)
            assert btypeR in ("charuco", "chessb"), f"Invalid board type: {btypeR}"
            if retR:
                boardR = "CHESSBOARD" if btypeR == "chessb" else "CHARUCO"
            else:
                boardR = "NOT DETECTED"

            det_info = f"{left_path.name}: {boardL} | {right_path.name}: {boardR}"
            log.info(det_info)

            if visualize:
                # Corners are drawn on 3-channel copies of the grayscale images.
                draw_L = cv2.drawChessboardCorners(
                    np.stack([imgL] * 3, axis=2), corners_size, cornersL, retL
                )
                draw_R = cv2.drawChessboardCorners(
                    np.stack([imgR] * 3, axis=2), corners_size, cornersR, retR
                )
                # `np.concatenate` also exists on NumPy < 2.0 (unlike `np.concat`).
                vis_img = np.concatenate([draw_L, draw_R], axis=1)
                keycode = image_window(det_info, vis_img)
                if keycode == ord("q"):
                    sys.exit()
                elif keycode == ord("v"):
                    visualize = False

            # A pair is only usable when the board was found in both views.
            if retL and retR:
                objpoints.append(objp)
                cornersL = cv2.cornerSubPix(
                    imgL, cornersL, (11, 11), (-1, -1), criteria
                )
                cornersR = cv2.cornerSubPix(
                    imgR, cornersR, (11, 11), (-1, -1), criteria
                )
                imgpointsL.append(cornersL)
                imgpointsR.append(cornersR)
                used_left_paths.append(str(left_path))
                used_right_paths.append(str(right_path))
    finally:
        # Restore the caller's OpenCV log level even if an error occurred.
        cv2.setLogLevel(cv2_log_level)

    if not imgpointsL:
        params_s = {
            "img_size": img_size,
            "intrL": None,
            "intrR": None,
            "distL": None,
            "distR": None,
            "R": None,
            "T": None,
        }
        return False, params_s, {}

    flags = cv2.CALIB_FIX_INTRINSIC  # Fix individual camera parameters
    _, intrL, distL, intrR, distR, R, T, _, _ = cv2.stereoCalibrate(
        objpoints,
        imgpointsL,
        imgpointsR,
        params_l["intr"],
        params_l["dist"],
        params_r["intr"],
        params_r["dist"],
        img_size,
        criteria=criteria,
        flags=flags,
    )
    params_s = {
        "img_size": img_size,
        "intrL": intrL,
        "intrR": intrR,
        "distL": distL,
        "distR": distR,
        "R": R,
        "T": T,
    }

    # Collect the per-image data from the single-camera calibrations that is
    # needed to compute the stereo reprojection error. Pairs that are missing
    # from either calibration are skipped.
    wc, pc_r, rvecs_l, tvecs_l = [], [], [], []
    final_left_paths = []
    iparams_l, iparams_r = params_l["imgs"], params_r["imgs"]
    for lp, rp in zip(used_left_paths, used_right_paths):
        ipl = iparams_l.get(lp)
        ipr = iparams_r.get(rp)
        if ipl is None or ipr is None:
            continue
        wc.append(ipl["wc"])
        pc_r.append(ipr["pc"])
        rvecs_l.append(ipl["r"])
        tvecs_l.append(ipl["t"])
        final_left_paths.append(lp)

    img_errs = {}
    if wc:
        rerrs = stereo_reproj_error(wc, pc_r, params_s, rvecs_l, tvecs_l)
        img_errs = {p: err.item() for p, err in zip(final_left_paths, rerrs)}
    return True, params_s, img_errs


def _drawMatches_on_PIL(
    img_l: Image.Image, img_r: Image.Image, kps_l: np.ndarray, kps_r: np.ndarray
) -> np.ndarray:
    """
    Draw matches on PIL images using arrays of corresponding points.

    Args:
        img_l: Left PIL image (any mode; it is converted to RGB for drawing).
        img_r: Right PIL image (any mode; it is converted to RGB for drawing).
        kps_l: Iterable of `(x, y)` keypoints in the left image.
        kps_r: Iterable of `(x, y)` keypoints in the right image. The i-th entry
            is treated as the match of the i-th entry of `kps_l`.

    Returns:
        The image produced by `cv2.drawMatches` with both views and a line
        for every match.
    """
    cv_kps_l = [cv2.KeyPoint(float(x), float(y), 0) for x, y in kps_l]
    cv_kps_r = [cv2.KeyPoint(float(x), float(y), 0) for x, y in kps_r]
    # Keypoints are already paired by index, so match i <-> i.
    cv_matches = [cv2.DMatch(i, i, 0) for i in range(len(cv_kps_l))]
    # Force 3-channel RGB before flipping the channel axis to BGR. Without
    # this, a grayscale image (2-D array) would be mirrored horizontally and
    # an RGBA image would end up as ABGR.
    cv_img_l = np.ascontiguousarray(np.array(img_l.convert("RGB"))[..., ::-1])
    cv_img_r = np.ascontiguousarray(np.array(img_r.convert("RGB"))[..., ::-1])
    vis_img = cv2.drawMatches(
        cv_img_l,
        cv_kps_l,
        cv_img_r,
        cv_kps_r,
        cv_matches,
        None,
        matchesThickness=2,
        flags=cv2.DRAW_MATCHES_FLAGS_NOT_DRAW_SINGLE_POINTS,
    )
    return vis_img


def deep_stereo(
    left_img_dir: StrPath,
    right_img_dir: StrPath,
    params_l: dict,
    params_r: dict,
    hf_model: str = "magic-leap-community/superglue_outdoor",
    device: str = "cpu",
    match_thr: float = 0.95,
    visualize: bool = False,
) -> dict:
    """
    Calibrate a stereo setup with 2 cameras, via keypoint matching using a
    huggingface model.

    Keypoint matches from all image pairs are pooled and passed to a single
    `cv2.recoverPose` call together with the given per-camera intrinsics.

    This function will ignore all those files in `left_img_dir` which do not have
    a corresponding file with the same name in `right_img_dir`.

    Args:
        left_img_dir: Directory containing images from the left camera. Only
            `.png`, `.jpg` and `.jpeg` files (lower or upper case extension)
            are considered.
        right_img_dir: Directory containing images from the right camera.
            Names of the corresponding left and right images should be the same
            in both directories.
        params_l: Dict containing left camera parameters. This is same as the
            second return value of `calib.multi_board_calibrate()`. The keys
            `"intr"` and `"dist"` are read here.
        params_r: Dict containing right camera parameters. This is same as the
            second return value of `calib.multi_board_calibrate()`. The keys
            `"intr"` and `"dist"` are read here.
        hf_model: Value to put in `AutoModel.from_pretrained` method.
        device: Device on which the model will be ran.
        match_thr: Match confidence threshold. Matches with confidence values lower
            than this will not be used for calibration.
        visualize: Whether to show the detected points on the images.
            When OpenCV's window is open, you can press 'q' to terminate the
            calibration (the process exits via `sys.exit()`) or 'v' to turn off
            visualization. Press any other key to continue.

    Returns:
        A dict containing the intrinsic camera parameters and extrinsic stereo
        parameters. The dict's structure will be as follows:

        1. `img_size`: `(width, height)` of the calibration images.
        2. `intrL`: `params_l["intr"]`, passed through unchanged.
        3. `intrR`: `params_r["intr"]`, passed through unchanged.
        4. `distL`: `params_l["dist"]`, passed through unchanged.
        5. `distR`: `params_r["dist"]`, passed through unchanged.
        6. `R`: 3x3 rotation matrix that describes the orientation between
            the cameras, as returned by `cv2.recoverPose`.
        7. `T`: Translation between the cameras, as returned by
            `cv2.recoverPose`. NOTE(review): OpenCV documents this translation
            as being recovered only up to scale - confirm before using it as
            a metric baseline.

    Raises:
        FileNotFoundError: If `left_img_dir` contains no images.
        ValueError: If no left image has a counterpart in `right_img_dir`.
        Exception: If image sizes are inconsistent or the pose cannot be
            computed.
    """
    img_exts = (".png", ".jpg", ".jpeg")
    img_exts += tuple(x.upper() for x in img_exts)

    left_img_dir = str(left_img_dir)
    right_img_dir = Path(right_img_dir)
    # Sorted so that the order in which matches are pooled is deterministic.
    left_img_paths = sorted(
        p for p in glob.glob(os.path.join(left_img_dir, "*")) if p.endswith(img_exts)
    )
    # Explicit raise (not `assert`) so the check survives `python -O`.
    if not left_img_paths:
        raise FileNotFoundError(f"No images found at {left_img_dir}")

    # Silence transformers while loading the processor, then restore the
    # caller's verbosity even if loading fails.
    tr_verbosity = transformers.logging.get_verbosity()
    transformers.logging.set_verbosity_error()
    try:
        processor = AutoImageProcessor.from_pretrained(hf_model, use_fast=True)
    finally:
        transformers.logging.set_verbosity(tr_verbosity)

    model = AutoModel.from_pretrained(hf_model).to(device)
    log.info(f"[ Running on {device.upper()} ]")

    kpL = []
    kpR = []
    img_size = None
    with torch.inference_mode():
        for left_path in left_img_paths:
            left_path = Path(left_path)
            right_path = right_img_dir / left_path.name
            # Load the pixel data inside the context manager so that the file
            # handle is released right away; the loaded image stays usable.
            try:
                with Image.open(right_path) as img_r:
                    img_r.load()
            except FileNotFoundError:
                # No counterpart in the right directory: skip this image.
                continue
            with Image.open(left_path) as img_l:
                img_l.load()

            log.info(left_path.name)

            if img_l.size != img_r.size:
                raise Exception(
                    f"{left_path.name}: Both left and right images must have same size"
                )
            if img_size is None:
                img_size = img_l.size
            elif img_size != img_l.size:
                msg = (
                    f"All images must be of same shape, but found "
                    f"at least two different shapes [{img_size} != {img_l.size}]"
                )
                raise Exception(msg)

            # One (left, right) pair per forward pass.
            inputs = processor([img_l, img_r], return_tensors="pt")
            inputs = {k: v.to(device) for k, v in inputs.items()}
            outputs = model(**inputs)
            # PIL's `size` is (width, height); the post-processor is given
            # (height, width) for both images of the pair.
            img_shapes = [[img_l.size[::-1]] * 2]
            matches = processor.post_process_keypoint_matching(
                outputs, img_shapes, threshold=match_thr
            )
            kps_l = matches[0]["keypoints0"].cpu().numpy()
            kps_r = matches[0]["keypoints1"].cpu().numpy()

            if visualize:
                vis_img = _drawMatches_on_PIL(img_l, img_r, kps_l, kps_r)
                keycode = image_window(left_path.name, vis_img)
                if keycode == ord("q"):
                    sys.exit()
                elif keycode == ord("v"):
                    visualize = False

            kpL.append(kps_l)
            kpR.append(kps_r)

    if not kpL:
        # Same exception type that concatenating an empty list would raise,
        # but with an actionable message.
        raise ValueError(
            f"No image in {left_img_dir} has a counterpart with the same name "
            f"in {right_img_dir}"
        )

    # `np.concatenate` also exists on NumPy < 2.0 (unlike `np.concat`).
    kpL = np.concatenate(kpL).astype(np.float32)
    kpR = np.concatenate(kpR).astype(np.float32)
    assert kpL.ndim == 2 and kpL.shape[1] == 2, kpL.shape
    assert kpR.ndim == 2 and kpR.shape[1] == 2, kpR.shape
    assert kpL.shape == kpR.shape, f"{kpL.shape} != {kpR.shape}"

    ret, _, R, T, _ = cv2.recoverPose(
        kpL,
        kpR,
        params_l["intr"],
        params_l["dist"],
        params_r["intr"],
        params_r["dist"],
        threshold=1.0,
    )
    if not ret:
        raise Exception("Camera pose cannot be computed!")

    calib_dict = {
        "img_size": img_size,
        "intrL": params_l["intr"],
        "intrR": params_r["intr"],
        "distL": params_l["dist"],
        "distR": params_r["dist"],
        "R": R,
        "T": T,
    }
    return calib_dict


def _remove_unavailable(left_paths, right_paths, dict_l, dict_r):
    """
    Filter aligned left/right path lists down to the pairs known to both dicts.

    A pair is kept only when `str(left)` is a key of `dict_l` and `str(right)`
    is a key of `dict_r`. The relative order of the kept pairs is preserved.

    Returns:
        Two lists `(left, right)` holding the kept paths, still aligned.
    """
    kept_pairs = [
        (left, right)
        for left, right in zip(left_paths, right_paths)
        if str(left) in dict_l and str(right) in dict_r
    ]
    kept_left = [left for left, _ in kept_pairs]
    kept_right = [right for _, right in kept_pairs]
    return kept_left, kept_right


def iter_mb_stereo(
    left_img_paths: StrPath | list[StrPath],
    right_img_paths: StrPath | list[StrPath],
    params_l: dict,
    params_r: dict,
    board: BoardConfig,
    target_err: float,
    reject_per_iter: int = 1,
    max_imgs_per_iter: int = 100,
    max_iter: int | None = None,
    visualize: bool = False,
) -> tuple[bool, dict, dict[str, float]]:
    """
    Iterative multi-board stereo calibration.

    Iteratively improves the calibration by rejecting the calibration image pairs
    with high stereo reprojection error. Each iteration calls
    `multi_board_stereo()` on the current selection of pairs, drops the
    `reject_per_iter` pairs with the highest error and tops the selection up
    with not-yet-used pairs (if any) for the next iteration.

    Only pairs whose left path is a key of `params_l["imgs"]` and whose right
    path is a key of `params_r["imgs"]` are considered.

    Args:
        left_img_paths: Path to a directory containing left calibration images or
            a list of paths of the left calibration images.
        right_img_paths: Path to a directory containing right calibration images or
            a list of paths of the right calibration images.
        params_l: Dict containing left camera parameters. This is same as the
            second return value of `calib.multi_board_calibrate()`.
        params_r: Dict containing right camera parameters. This is same as the
            second return value of `calib.multi_board_calibrate()`.
        board: A BoardConfig instance.
        target_err: The calibration process will stop when error falls at or below
            this value.
        reject_per_iter: Number of calibration images with highest error values to
            reject per iteration. Increasing this value will speed up the
            calibration process at the cost of precision of image selection.
            Must be at least 1.
        max_imgs_per_iter: Maximum number of images to use for calibration per
            iteration. This limits the memory usage during the calibration process.
            Must be at least 1.
        max_iter: Maximum number of iterations after which the function returns.
            `None` means no limit. Must be at least 1 when given.
        visualize: Whether to show the corners on the images after detection.
            When OpenCV's window is open, you can press 'q' to terminate the
            calibration or 'v' to turn off visualization for that iteration.
            Press any other key to continue.

    Returns:
        (ret, params_s, img_errs) of the last `multi_board_stereo()` call.

        1. ret: Boolean indicating whether the calibration was successful.
        2. params_s: A dict containing the calculated stereo parameters.
        3. img_errs: A dict mapping image paths to their reprojection errors.

        `params_s` has the following structure:
            1. `img_size`: `(width, height)` of the calibration images.
            2. `intrL`: 3x3 intrinsic camera matrix for the left camera.
            3. `intrR`: 3x3 intrinsic camera matrix for the right camera.
            4. `distL`: Distortion coefficients for the left camera.
            5. `distR`: Distortion coefficients for the right camera.
            6. `R`: 3x3 rotation matrix that describes the orientation between
                the cameras.
            7. `T`: Translation between the cameras.

    Raises:
        ValueError: If `target_err` is negative, if `reject_per_iter`,
            `max_imgs_per_iter` or `max_iter` is smaller than 1, or if list
            inputs differ in length.
        TypeError: If the two path arguments are not of the same type.
        FileNotFoundError: If a directory is given and it contains no images.
        Exception: If no image pair is present in both `params_l["imgs"]` and
            `params_r["imgs"]`.
    """
    if target_err < 0:
        raise ValueError("`target_err` has to be non-negative")
    # `paths[:-0]` is an empty list, so 0 would silently drop every image.
    if reject_per_iter < 1:
        raise ValueError("`reject_per_iter` has to be at least 1")
    if max_imgs_per_iter < 1:
        raise ValueError("`max_imgs_per_iter` has to be at least 1")
    if max_iter is not None and max_iter < 1:
        raise ValueError("`max_iter` has to be at least 1")

    if not isinstance(left_img_paths, type(right_img_paths)):
        raise TypeError("`left_img_paths` and `right_img_paths` must be of same type")

    if not isinstance(left_img_paths, (list, tuple)):
        left_img_dir = Path(left_img_paths)
        # Check for emptiness *before* touching the first element so that an
        # empty directory yields FileNotFoundError rather than IndexError.
        left_img_paths = sorted(Path(p) for p in imgs_in_dir(left_img_dir))
        if not left_img_paths:
            raise FileNotFoundError(f"No images found at {left_img_dir}")

        right_img_dir = Path(right_img_paths)
        right_img_paths = [right_img_dir / p.name for p in left_img_paths]
    else:
        if len(left_img_paths) != len(right_img_paths):
            raise ValueError(
                "`left_img_paths` and `right_img_paths` should be of same length"
            )
        # Later iterations rebuild the left paths as `Path` objects from the
        # keys of `img_errs`, so the left->right map must be keyed by `Path`.
        left_img_paths = [Path(p) for p in left_img_paths]
        right_img_paths = [Path(p) for p in right_img_paths]

    left_img_paths, right_img_paths = _remove_unavailable(
        left_img_paths, right_img_paths, params_l["imgs"], params_r["imgs"]
    )
    if not left_img_paths:
        raise Exception(
            "No image pair in left and right image lists is present in "
            + "`params_l['imgs']` and `params_r['imgs']`"
        )
    lr_map = dict(zip(left_img_paths, right_img_paths))
    del right_img_paths

    # First batch, plus the pool of pairs that have not been used yet.
    rem_left_paths = left_img_paths[max_imgs_per_iter:]
    left_img_paths = left_img_paths[:max_imgs_per_iter]

    rerr = float("inf")
    max_iter = float("inf") if max_iter is None else max_iter
    ret, params, img_errs = False, {}, None
    hit_max_iter = False
    iter_no = 1
    # The body always runs at least once, so `ret`/`params` are always the
    # result of a real calibration when the function returns.
    while True:
        if iter_no > 1:
            log_no_prefix()
        iter_prefix = f"Iter {iter_no}, "
        log_fmt.prefix += iter_prefix
        try:
            if img_errs is not None:
                # Keep the pairs with the lowest error, drop the worst ones.
                sorted_img_errs = sorted(img_errs.items(), key=lambda x: x[1])
                left_img_paths = [Path(x[0]) for x in sorted_img_errs]
                left_img_paths = left_img_paths[:-reject_per_iter]
                if rem_left_paths:
                    add_count = max(0, max_imgs_per_iter - len(left_img_paths))
                    left_img_paths += rem_left_paths[:add_count]
                    rem_left_paths = rem_left_paths[add_count:]
                if not left_img_paths:
                    # Everything got rejected; keep the previous results.
                    log.info("No image pairs left to reject. Calibration stopped!")
                    break

            right_img_paths = [lr_map[p] for p in left_img_paths]
            ret, params, img_errs = multi_board_stereo(
                left_img_paths, right_img_paths, params_l, params_r, board, visualize
            )
            if img_errs:
                rerr = sum(img_errs.values()) / len(img_errs)
                log.info(f"Reprojection Error: {rerr}")
            elif not rem_left_paths:
                # Nothing usable in this batch and nothing left to try.
                break
            # Otherwise this batch produced no errors but unused pairs
            # remain: fall through and try the next batch.
        finally:
            # Always undo the per-iteration prefix, including on `break` and
            # when an exception propagates.
            if log_fmt.prefix.endswith(iter_prefix):
                log_fmt.prefix = log_fmt.prefix[: -len(iter_prefix)]

        if rerr <= target_err:
            break
        if iter_no >= max_iter:
            hit_max_iter = True
            break
        iter_no += 1

    if hit_max_iter:
        # Temporarily decorate the prefix for this one message, then restore.
        if log_fmt.prefix:
            log_fmt.prefix += "| "
        log_fmt.prefix = "\n" + log_fmt.prefix

        log.info("Maximum iterations reached. Calibration stopped!")

        if log_fmt.prefix.endswith("| "):
            log_fmt.prefix = log_fmt.prefix[:-2]
        log_fmt.prefix = log_fmt.prefix[1:]

    return ret, params, img_errs
