import os
import argparse

import cv2
from ultralytics import YOLO
from numpy import ndarray

from fsrcnn import FSRCNNInference
from config import Config as cfg
from tools import process_calib_json


def parse_args(argv: list[str] = None):
    """
    Parse the command-line arguments of the depth-estimation script.

    Args:
        argv: Argument list to parse. If None, argparse reads ``sys.argv[1:]``.

    Returns:
        An ``argparse.Namespace`` with ``left`` and ``right`` (input image
        paths), ``o`` (output image path, default ``"./a.png"``) and ``v``
        (visualization path; the empty string disables the visualization).
    """
    cli = argparse.ArgumentParser()

    # Positional inputs: the two images of the stereo pair.
    positional = (
        ("left", "Path of the image from the left stereo camera"),
        ("right", "Path of the image from the right stereo camera"),
    )
    for name, description in positional:
        cli.add_argument(name, help=description)

    cli.add_argument(
        "-o",
        default="./a.png",
        help="Path where the output image must be saved",
    )
    cli.add_argument(
        "-v",
        default="",
        help=(
            "Path to save the keypoint visualization. "
            "If not specified, no visualization will be created."
        ),
    )

    namespace = cli.parse_args(argv)
    return namespace


def detect_objs(model: YOLO, imgs: list[ndarray], conf_threshold: float = 0.25):
    """
    Run the YOLO detector on a batch of images and collect bounding boxes.

    Args:
        model: A loaded ultralytics ``YOLO`` model.
        imgs: Images to run detection on; they are passed to the model as a
            single batch.
        conf_threshold: Minimum confidence for a detection to be kept.

    Returns:
        One list per model result (i.e. per input image), each containing a
        tuple ``(x0, y0, x1, y1)`` for every detected box.
        An empty ``imgs`` list yields an empty list.
    """
    # len() rather than truthiness so a single ndarray (also accepted by
    # YOLO) does not raise an "ambiguous truth value" error.
    if len(imgs) == 0:
        return []

    results = model(imgs, conf=conf_threshold, verbose=False)
    # Each iterated box exposes `xyxy` as a one-row tensor; `[0]` takes that row.
    return [
        [tuple(box.xyxy.tolist()[0]) for box in result.boxes]
        for result in results
    ]


def match_keypoints(img_l, img_r) -> tuple:
    """
    Detect ORB keypoints in both images and match them with a cross-checked
    brute-force Hamming matcher.

    Args:
        img_l: Left image.
        img_r: Right image.

    Returns:
        A tuple (kps_l, kps_r, matches).

        1) kps_l: Keypoints of img_l
        2) kps_r: Keypoints of img_r
        3) matches: Matching keypoints between both images. This is empty
           when either image yields no keypoints/descriptors.
    """
    detector = cv2.ORB_create()
    matcher = cv2.BFMatcher(cv2.NORM_HAMMING, crossCheck=True)
    kps_l, des_l = detector.detectAndCompute(img_l, None)
    kps_r, des_r = detector.detectAndCompute(img_r, None)

    # detectAndCompute returns None descriptors when no keypoints are found,
    # and BFMatcher.match raises on None. Report "no matches" instead so the
    # caller's empty-matches check can handle it.
    if des_l is None or des_r is None or len(des_l) == 0 or len(des_r) == 0:
        return kps_l, kps_r, ()

    matches = matcher.match(des_l, des_r)
    return kps_l, kps_r, matches


def calc_depth(
    pt_l: tuple[float, float],
    pt_r: tuple[float, float],
    focal_len_pix: float,
    baseline: float
) -> float:
    """
    Calculate depth of a point in the world using stereo images.

    Uses ``depth = baseline * focal_len_pix / disparity`` with the horizontal
    disparity ``pt_l[0] - pt_r[0]``.

    Args:
        pt_l: Coordinates of the desired point in the left image.
        pt_r: Coordinates of the desired point in the right image.
        focal_len_pix: Focal length of both the cameras in pixels (may be
            fractional, e.g. when taken from a calibration).
        baseline: Distance between center of both cameras (in any unit).

    Returns:
        Perpendicular distance of the point in the world from the cameras,
        in the same units as `baseline`. Special cases:

        - ``float("inf")`` if ``pt_l[0] == pt_r[0]`` (zero disparity).
        - ``-1`` if ``pt_l[0] < pt_r[0]`` (negative disparity, which callers
          treat as a wrong match).
    """
    x_l, x_r = pt_l[0], pt_r[0]
    if x_l == x_r:
        return float("inf")
    if x_l < x_r:
        return -1
    return (baseline * focal_len_pix) / (x_l - x_r)


def depth_each_obj(
    boxes_xyxy: list[tuple],
    kps_l: list,
    kps_r: list,
    matches: list,
    focal_len_pix: float | None = None,
    baseline: float | None = None,
) -> tuple[list[float | str], list]:
    """
    Calculate depth of each object whose bounding boxes are provided.
    Bounding boxes should be provided in XYXY format (not XYWH).

    Args:
        boxes_xyxy: Bounding box coordinates of each object in the image.
        kps_l: List of keypoints in the left image.
        kps_r: List of keypoints in the right image.
        matches: List of matches performed on `kps_l` and `kps_r` using OpenCV.
            (E.g. Matching performed by BFMatcher).
        focal_len_pix: Focal length in pixels. Defaults to
            ``cfg.focal_len_pix`` (read at call time).
        baseline: Distance between the cameras. Defaults to ``cfg.baseline``
            (read at call time).

    Returns:
        A tuple containing:
            1. A depth value for each box provided in `boxes_xyxy`.
                If there is an error in a depth value, one of the following
                strings will be added instead of the depth for that box:
                - NKD: No keypoint detected.
                - WKM: Wrong keypoint match (b/w left and right image).

            2. A filtered list of matches, one corresponding to each box in
                `boxes_xyxy`, which were used to calculate the depths. For
                each box this is the lowest-distance match whose left
                keypoint lies inside the box. If there is no keypoint inside
                a box, the list contains `None` for that box.
    """
    if focal_len_pix is None:
        focal_len_pix = cfg.focal_len_pix
    if baseline is None:
        baseline = cfg.baseline

    box_matches = [None] * len(boxes_xyxy)
    for match in matches:
        x, y = kps_l[match.queryIdx].pt
        # Consider every box containing the keypoint. Overlapping or nested
        # boxes each keep their own best match instead of the first box
        # "stealing" the keypoint from the others.
        for i, (x0, y0, x1, y1) in enumerate(boxes_xyxy):
            if not (x0 <= x <= x1 and y0 <= y <= y1):
                continue
            best = box_matches[i]
            if best is None or match.distance < best.distance:
                box_matches[i] = match

    depths = []
    for match in box_matches:
        if match is None:
            depths.append("NKD")
            continue
        pt_l = kps_l[match.queryIdx].pt
        pt_r = kps_r[match.trainIdx].pt
        depth = calc_depth(pt_l, pt_r, focal_len_pix, baseline)
        # calc_depth signals a negative disparity with -1.
        depths.append("WKM" if depth == -1 else depth)

    return depths, box_matches


def draw_depths(img: ndarray, boxes_xyxy: list[tuple], depths: list[float | str]) -> ndarray:
    """
    Draw the bounding box along with the depth value for each detected object.

    Args:
        img: An OpenCV (BGR) image. It is not modified; drawing happens on a copy.
        boxes_xyxy: Bounding box coordinates in XYXY format.
        depths: Depth values corresponding to boxes in `boxes_xyxy`. A string
            entry (an error code) is drawn verbatim. A number is drawn with
            two decimals and an "m" suffix.

    Returns:
        A new image with the boxes and labels drawn.
    """
    color_bg = (0, 0, 171)  # BGR order, i.e. RGB (171, 0, 0)
    text_style = {
        "fontFace": cv2.FONT_HERSHEY_SIMPLEX,
        "fontScale": 0.5,
        "thickness": 1,
    }
    padding = 2

    img = img.copy()  # don't draw on the caller's array
    text_draw_params = []
    for (x0, y0, x1, y1), depth in zip(boxes_xyxy, depths):
        x0, y0, x1, y1 = round(x0), round(y0), round(x1), round(y1)
        text = depth if isinstance(depth, str) else f"{depth:.2f}m"
        img = cv2.rectangle(img, (x0, y0), (x1, y1), color_bg, 2)
        (w, h), _ = cv2.getTextSize(text, **text_style)

        # Prefer a label above the box. Fall back to below it when the label,
        # including its padding, would be cut off by the image's top edge.
        if y0 - h - padding >= 0:
            xt0, yt0 = x0, y0 - h - padding
            xt1, yt1 = x0 + w, y0
        else:
            xt0, yt0 = x0, y1
            xt1, yt1 = x0 + w, y1 + h + padding
        img = cv2.rectangle(img, (xt0, yt0), (xt1, yt1), color_bg, cv2.FILLED)
        text_draw_params.append((text, (xt0, yt0 + h)))

    # Text is drawn after all rectangles so later boxes cannot paint over
    # earlier labels.
    for text, origin in text_draw_params:
        img = cv2.putText(img, text, origin, color=(255, 255, 255), **text_style)

    return img


def _read_image(path: str) -> ndarray:
    """Load an image with OpenCV, raising ValueError if it cannot be decoded."""
    img = cv2.imread(path)
    # cv2.imread signals failure (e.g. not an image) by returning None.
    if img is None:
        raise ValueError(f"{path} could not be read as an image")
    return img


def _ensure_parent_dir(path: str) -> None:
    """Create the directory that will contain `path`, if it names one."""
    parent = os.path.dirname(path)
    # A bare filename has no directory part, and os.makedirs("") would raise.
    if parent:
        os.makedirs(parent, exist_ok=True)


def _enhance_pair(img_l: ndarray, img_r: ndarray) -> tuple[ndarray, ndarray]:
    """Super-resolve both BGR images with FSRCNN, then resize them back to img_l's size."""
    size = img_l.shape[1::-1]  # (width, height), as cv2.resize expects
    sr_model = FSRCNNInference(cfg.fsrcnn_model, cfg.scaling_factor, cfg.device)
    # The SR model is fed RGB; OpenCV images are BGR.
    rgb_pair = [cv2.cvtColor(img, cv2.COLOR_BGR2RGB) for img in (img_l, img_r)]
    sr_l, sr_r = sr_model(rgb_pair)
    return tuple(
        cv2.resize(cv2.cvtColor(img, cv2.COLOR_RGB2BGR), size,
                   interpolation=cv2.INTER_CUBIC)
        for img in (sr_l, sr_r)
    )


def main(argv: list[str] = None) -> None:
    """
    Command-line entry point: estimate and draw the depth of detected objects.

    Pipeline:
    1. Load the stereo pair.
    2. Optionally rectify/undistort it with ``cfg.calib_json``.
    3. Run YOLO detection on the left image.
    4. Apply FSRCNN super-resolution and resize back to the original size.
    5. Match ORB keypoints and compute a depth per box.
    6. Write the annotated output image (and optionally a match visualization).

    Args:
        argv: Command-line arguments; None means ``sys.argv[1:]``.

    Raises:
        ValueError: If an input path does not exist, cannot be decoded as an
            image, or the two images differ in size.
        OSError: If an output image cannot be written.
    """
    args = parse_args(argv)

    for path in (args.left, args.right):
        if not os.path.exists(path):
            raise ValueError(f"{path} doesn't exist")
    img_l = _read_image(args.left)
    img_r = _read_image(args.right)

    img_shape = img_l.shape[-2::-1]  # (width, height) for a colour image
    if img_r.shape[-2::-1] != img_shape:
        raise ValueError("Both left and right images must be of the same size.")

    # Rectification and undistortion
    if cfg.calib_json:
        f, b, maps_l, maps_r, _ = process_calib_json(cfg.calib_json, img_shape)
        # Later depth computation reads these calibrated values from cfg.
        cfg.focal_len_pix = f
        cfg.baseline = b
        img_l = cv2.remap(img_l, *maps_l, cv2.INTER_CUBIC)
        img_r = cv2.remap(img_r, *maps_r, cv2.INTER_CUBIC)

    # Object detection
    boxes = detect_objs(YOLO(cfg.yolo_model), [img_l], cfg.conf_threshold)[0]

    # Keypoints are found on super-resolved copies resized back to the
    # original resolution, so their coordinates line up with the boxes.
    img_l_sr, img_r_sr = _enhance_pair(img_l, img_r)
    kps_l, kps_r, matches = match_keypoints(img_l_sr, img_r_sr)
    if not matches:
        print("No keypoint matches found between both the images.")
        return

    # Depth calculation
    depths, box_matches = depth_each_obj(boxes, kps_l, kps_r, matches)

    if args.v:
        _ensure_parent_dir(args.v)
        used_matches = [m for m in box_matches if m is not None]
        match_vis = cv2.drawMatches(
            img_l, kps_l, img_r, kps_r, used_matches,
            None, matchesThickness=2,
            flags=cv2.DRAW_MATCHES_FLAGS_NOT_DRAW_SINGLE_POINTS,
        )
        if not cv2.imwrite(args.v, match_vis):
            raise OSError(f"Failed to write visualization to {args.v}")

    # Generating final output
    img_out = draw_depths(img_l, boxes, depths)
    _ensure_parent_dir(args.o)
    if not cv2.imwrite(args.o, img_out):
        raise OSError(f"Failed to write output image to {args.o}")


if __name__ == "__main__":
    # Script entry point: parse arguments from sys.argv and run the pipeline.
    main(argv=None)
