from functools import partial
from dataclasses import dataclass
from typing import Literal

import cv2
import numpy as np


@dataclass
class BoardConfig:
    """
    Configuration for the calibration board.

    Args:
        corners_size: Number of chessboard corners along (width, height).
        type: Can be one of "chessb", "charuco" or "both".
            1. "chessb": Only perform chessboard calibration.
            2. "charuco": Only perform ChArUco calibration.
            3. "both": On each image, first try ChArUco and then try chessboard
                calibration.
        square_side: Length of the side of the squares of the chessboard.
        marker_side: Side length of the ArUco markers. Must be strictly smaller
            than `square_side` whenever ChArUco detection is enabled (`type` is
            'charuco' or 'both'). Ignored if `type` is 'chessb'.
        aruco_dict: ArUco dict's opencv enum (for `cv2.aruco.getPredefinedDictionary`).
            Ignored if `type` is 'chessb'.

    Raises:
        ValueError: If `type` is not one of the supported values, or if ChArUco
            detection is enabled and `marker_side >= square_side`.
    """

    corners_size: tuple[int, int]
    type: Literal["chessb", "charuco", "both"] = "both"
    square_side: float = 1.0
    marker_side: float = 0.75
    aruco_dict: int = cv2.aruco.DICT_6X6_250

    def __post_init__(self):
        if self.type not in ("chessb", "charuco", "both"):
            raise ValueError("Invalid board type")
        # "both" also attempts ChArUco detection, so it needs the same geometric
        # constraint as "charuco": each marker must fit inside its square.
        # Checking only "charuco" let e.g. square_side=0.03 with the default
        # marker_side=0.75 through unnoticed.
        uses_charuco = self.type in ("charuco", "both")
        if uses_charuco and self.marker_side >= self.square_side:
            raise ValueError(
                "For charuco board, marker_side must be smaller than square_side"
            )


class CharucoCornerDetector:
    """
    Find the inner corners of a ChArUco board in an image, with no prior camera
    calibration required.

    Args:
        corners_size: Count of inner chessboard corners as (width, height). The
            board has one more square than corners along each axis.
        square_side: Side length of a single chessboard square.
        marker_side: Side length of a single ArUco marker.
        aruco_dict: OpenCV enum selecting the predefined ArUco dictionary
            (passed to cv2.aruco.getPredefinedDictionary).
    """

    def __init__(
        self,
        corners_size: tuple[int, int],
        square_side: float,
        marker_side: float,
        aruco_dict: int = cv2.aruco.DICT_6X6_250,
    ):
        marker_dict = cv2.aruco.getPredefinedDictionary(aruco_dict)
        cols, rows = corners_size
        # N inner corners along an axis correspond to N + 1 squares on that axis.
        squares_xy = (cols + 1, rows + 1)
        charuco_board = cv2.aruco.CharucoBoard(
            squares_xy, square_side, marker_side, marker_dict
        )
        self.detector = cv2.aruco.CharucoDetector(charuco_board)
        # Total number of inner corners a complete detection must return.
        self.num_corners = cols * rows

    def detect(self, img: np.ndarray) -> tuple[bool, np.ndarray | None]:
        """
        Locate the ChArUco board corners in `img`.

        Args:
            img: An OpenCV grayscale image.

        Returns:
            A `(ret, corners)` pair:

            1. ret: True only when every corner of the board was found.
            2. corners: The corner array produced by `detectBoard` (its first
               output), or None when `detectBoard` yields no corners.
        """
        found = self.detector.detectBoard(img)[0]
        if found is None:
            return False, found
        complete = found.shape[0] == self.num_corners
        return complete, found


class MBCornerDetector:
    """
    Multi-board corner detector to detect corners in both ChArUco and Chessboards.

    Args:
        board: A BoardConfig instance. Its `type` selects which detectors are
            built: "charuco" (ChArUco only), "chessb" (chessboard only) or
            "both" (ChArUco first, chessboard as fallback).

    Raises:
        ValueError: If `board.type` is not "chessb", "charuco" or "both". Without
            this check no detector would be built and `detect` could never
            succeed.
    """

    def __init__(self, board: BoardConfig):
        # Fail fast instead of silently building a detector that never detects
        # (e.g. if `board.type` was mutated after BoardConfig validation).
        if board.type not in ("chessb", "charuco", "both"):
            raise ValueError("Invalid board type")

        self.charuco_detector = None
        if board.type in ("charuco", "both"):
            self.charuco_detector = CharucoCornerDetector(
                board.corners_size,
                board.square_side,
                board.marker_side,
                board.aruco_dict,
            )

        self.chessb_detector = None
        if board.type in ("chessb", "both"):
            # Only adaptive thresholding is requested for chessboard detection.
            self.chessb_detector = partial(
                cv2.findChessboardCorners,
                patternSize=board.corners_size,
                flags=cv2.CALIB_CB_ADAPTIVE_THRESH,
            )

    def detect(self, img: np.ndarray) -> tuple[bool, np.ndarray | None, str]:
        """
        Detect calibration board corners in the given image.

        ChArUco detection runs first when enabled. Chessboard detection runs
        only when ChArUco detection is disabled or did not find every corner.

        Args:
            img: An opencv grayscale image.

        Returns:
            (ret, corners, board_type)

            1. ret: A boolean indicating whether all the corners were detected.
            2. corners: The corners reported by the last detector tried. This may
                be None when detection fails (the ChArUco detector can return
                None).
            3. board_type: Either "charuco" or "chessb", naming the detector whose
                result is returned. When `ret` is False, this only names the last
                detector attempted, not a board that was actually found.
        """
        # Placeholder result; only returned if neither detector is configured,
        # which __init__ prevents unless the attributes are changed afterwards.
        ret, corners, board_type = False, np.empty(()), ""
        if self.charuco_detector is not None:
            board_type = "charuco"
            ret, corners = self.charuco_detector.detect(img)
        if not ret and self.chessb_detector is not None:
            board_type = "chessb"
            ret, corners = self.chessb_detector(img)
        return ret, corners, board_type
