import os
import argparse

import cv2
import numpy as np
import torch
from PIL import Image
from numpy import ndarray
from ultralytics import YOLO
import transformers
from transformers import AutoModel, AutoImageProcessor

from config import Config as cfg
from tools import process_calib_json


def parse_args(argv: list[str] = None):
    """
    Parse the command-line arguments of the script.

    Args:
        argv: Argument strings to parse. When `None`, argparse falls back to
            the arguments of the running process.

    Returns:
        The parsed namespace with the attributes `left`, `right`, `o`
        (output image path) and `v` (visualization path, empty when the
        visualization is not requested).
    """
    cli = argparse.ArgumentParser()

    # Positional arguments, registered in the order they must be given.
    positional_help = {
        "left": "Path of the image from the left stereo camera",
        "right": "Path of the image from the right stereo camera",
    }
    for arg_name, help_text in positional_help.items():
        cli.add_argument(arg_name, help=help_text)

    output_help = "Path where the output image must be saved"
    cli.add_argument("-o", default="./a.png", help=output_help)

    visualization_help = (
        "Path to save the keypoint visualization. "
        "If not specified, no visualization will be created."
    )
    cli.add_argument("-v", default="", help=visualization_help)

    return cli.parse_args(argv)


def detect_objs(
    model: YOLO, imgs: list[ndarray], conf_threshold: float = 0.25
) -> list[list[tuple]]:
    """
    Run the detector on a batch of images and collect the box coordinates.

    Args:
        model: Detection model; it is called with `imgs` and the keyword
            arguments `conf` and `verbose`.
        imgs: Images to run the detector on.
        conf_threshold: Value forwarded to the model as `conf`.

    Returns:
        One list per result produced by the model. Each inner list holds one
        tuple per detected box, built from the first row of that box's
        `xyxy` coordinates. The outer list is empty when the model yields no
        results.
    """
    results = model(imgs, conf=conf_threshold, verbose=False)
    # `box.xyxy.tolist()` is indexed with [0] to take the single row that
    # belongs to this box.
    return [
        [tuple(box.xyxy.tolist()[0]) for box in result.boxes]
        for result in results
    ]


def calc_depth(disparity: float, focal_len_pix: int, baseline: float) -> float:
    """
    Convert a stereo disparity into a depth value.

    Args:
        disparity: The disparity value. Must not be negative.
        focal_len_pix: Focal length of both the cameras in pixels.
        baseline: Distance between center of both cameras (in any unit).

    Returns:
        `baseline * focal_len_pix / disparity`, i.e. the perpendicular
        distance of the point from the cameras in the same units as
        `baseline`. A disparity of exactly 0 yields `inf`.

    Raises:
        ValueError: If `disparity` is negative.
    """
    if disparity < 0:
        raise ValueError(f"Disparity ({disparity}) can't be negative")
    if disparity == 0:
        # Zero disparity corresponds to a point infinitely far away.
        return float("inf")
    numerator = baseline * focal_len_pix
    return numerator / disparity


class KPDetectMatcher:
    """
    Keypoint detector and matcher for a pair of stereo images.

    The backend is selected by `name`:
    - "brisk": OpenCV BRISK keypoints, matched with a cross-checked
      brute-force matcher using the Hamming norm.
    - anything else: `name` is passed to `AutoModel.from_pretrained` and
      `AutoImageProcessor.from_pretrained`; the loaded model is then used for
      detection and matching in one forward pass.
    """

    def __init__(self, name: str, device: str = "cpu") -> None:
        """
        Args:
            name: "brisk", or an identifier accepted by
                `AutoModel.from_pretrained`.
            device: Device the model and its inputs are moved to. It is only
                used (and only stored on the instance) when `name` is not
                "brisk".
        """
        self.name = name
        if self.name == "brisk":
            self.kp_detector = cv2.BRISK_create()
            self.kp_matcher = cv2.BFMatcher(cv2.NORM_HAMMING, crossCheck=True)
        else:
            self.kp_detector = AutoModel.from_pretrained(name).to(device)
            # Temporarily lower the transformers log level to "error" while
            # the image processor is loaded, then restore the previous level.
            tr_verbosity = transformers.logging.get_verbosity()
            transformers.logging.set_verbosity_error()
            self.img_processor = AutoImageProcessor.from_pretrained(name, use_fast=True)
            transformers.logging.set_verbosity(tr_verbosity)
            self.device = device

    def detect_match(
        self,
        img_l: ndarray,
        img_r: ndarray,
        boxes_xyxy: list[tuple] | None = None,
    ) -> tuple[list, list, list[list] | list]:
        """
        Performs keypoint detection and matching on `img_l` and `img_r`.

        Dispatches to the BRISK implementation when `self.name` is "brisk"
        and to the deep-learning implementation otherwise.

        Args:
            img_l: Left camera BGR image.
            img_r: Right camera BGR image.
            boxes_xyxy: List of bounding box coordinates in `img_l` inside which
                the keypoint detection and matching must be performed.
                If `None`, it will be performed on the whole image.

        Returns:
            1. List of keypoints detected in `img_l`.

            2. List of keypoints detected in `img_r`.

            3. List of list of keypoint matches, one for each box in `boxes_xyxy`.
                If `boxes_xyxy` is `None`, then just the list of matches for the
                whole image.
        """
        if self.name == "brisk":
            return self._brisk_detect_match(img_l, img_r, boxes_xyxy)
        else:
            return self._deep_detect_match(img_l, img_r, boxes_xyxy)

    def _brisk_detect_match(
        self,
        img_l: ndarray,
        img_r: ndarray,
        boxes_xyxy: list[tuple] | None = None,
    ) -> tuple[list, list, list[list] | list]:
        """
        Keypoint detection using BRISK and matching using efficient object-wise
        brute-force algorithm.

        When boxes are given, detection is restricted by masks built from the
        boxes, and matching is done separately for each box using only the
        descriptors of the keypoints that fall inside it. The indices of the
        returned matches refer to the full keypoint lists.

        Args and return value are the same as for `detect_match`.
        """
        mask_l, mask_r = None, None
        if boxes_xyxy is not None:
            # Box coordinates are rounded so they can be used as slice indices.
            boxes_xyxy = [[round(x) for x in coords] for coords in boxes_xyxy]
            # The masks drop the last axis of the image shape.
            # NOTE(review): assumes the images carry a trailing channel axis.
            mask_l, mask_r = [np.zeros(x.shape[:-1], np.uint8) for x in (img_l, img_r)]
            for i, (x0, y0, x1, y1) in enumerate(boxes_xyxy):
                mask_l[y0:y1, x0:x1] = 255  # 255 is easier to visualize
                # In the right image the search region spans the same rows but
                # starts at column 0 instead of x0.
                mask_r[y0:y1, 0:x1] = 255  # Accounting for the left shift

        kps_l, des_l = self.kp_detector.detectAndCompute(img_l, mask_l)
        kps_r, des_r = self.kp_detector.detectAndCompute(img_r, mask_r)
        # Without right descriptors, the left descriptors are discarded too,
        # so nothing is offered to the matcher as a query.
        # NOTE(review): presumably this avoids matching a non-empty query
        # against an empty train set - confirm against the cv2 matcher.
        if des_r is None:
            des_l = None
        if boxes_xyxy is None:
            matches = self.kp_matcher.match(des_l, des_r)
            return kps_l, kps_r, matches

        matches = [None] * len(boxes_xyxy)
        for i, (x0, y0, x1, y1) in enumerate(boxes_xyxy):
            # `box_dl` collects the descriptors of the left keypoints inside
            # this box; `idxs_l` maps a position in `box_dl` back to the index
            # of that keypoint in `kps_l`.
            box_dl, idxs_l = [], {}
            if des_l is None:
                # Makes the `zip` below yield nothing.
                des_l = []
            for j, (kp, des) in enumerate(zip(kps_l, des_l)):
                x, y = kp.pt
                if x >= x0 and x < x1 and y >= y0 and y < y1:
                    idxs_l[len(box_dl)] = j
                    box_dl.append(des)

            # Same bookkeeping for the right image; the region has no lower
            # bound on x, mirroring how `mask_r` was built.
            box_dr, idxs_r = [], {}
            if des_r is None:
                des_r = []
            for j, (kp, des) in enumerate(zip(kps_r, des_r)):
                x, y = kp.pt
                if x < x1 and y >= y0 and y < y1:
                    idxs_r[len(box_dr)] = j
                    box_dr.append(des)

            # The left descriptors are only stacked when both sides have at
            # least one descriptor; otherwise `None` is passed to the matcher.
            box_dl = np.stack(box_dl) if box_dl and box_dr else None
            box_dr = np.stack(box_dr) if box_dr else None
            box_matches = self.kp_matcher.match(box_dl, box_dr)
            # Translate the box-local descriptor positions back into indices
            # of `kps_l` / `kps_r`.
            for match in box_matches:
                match.queryIdx = idxs_l[match.queryIdx]
                match.trainIdx = idxs_r[match.trainIdx]
            matches[i] = box_matches

        return kps_l, kps_r, matches

    def _deep_detect_match(
        self,
        img_l: ndarray,
        img_r: ndarray,
        boxes_xyxy: list[tuple] | None = None,
    ) -> tuple[list, list, list[list] | list]:
        """
        Keypoint detection and matching using a DL model.

        The whole images are always processed; `boxes_xyxy` is only used
        afterwards to group the matches by the box that contains their left
        keypoint. All returned matches have a distance of 0.

        Args and return value are the same as for `detect_match`.
        """
        # The processor is fed RGB PIL images, so convert from OpenCV's BGR.
        img_l = cv2.cvtColor(img_l, cv2.COLOR_BGR2RGB)
        img_l = Image.fromarray(img_l)
        img_r = cv2.cvtColor(img_r, cv2.COLOR_BGR2RGB)
        img_r = Image.fromarray(img_r)

        with torch.inference_mode():
            inputs = self.img_processor([img_l, img_r], return_tensors="pt")
            inputs = {k: v.to(self.device) for k, v in inputs.items()}
            outputs = self.kp_detector(**inputs)
            # PIL's `size` is (width, height); reversing gives (height, width).
            # The size of the left image is used for both images of the pair.
            img_shapes = [[img_l.size[::-1]] * 2]
            matches = self.img_processor.post_process_keypoint_matching(
                outputs, img_shapes, threshold=0.95
            )
            kps_l = matches[0]["keypoints0"].cpu().numpy()
            kps_r = matches[0]["keypoints1"].cpu().numpy()

        kps_l = [cv2.KeyPoint(float(x), float(y), 0) for x, y in kps_l]
        kps_r = [cv2.KeyPoint(float(x), float(y), 0) for x, y in kps_r]
        # The i-th left keypoint is paired with the i-th right keypoint.
        # NOTE(review): relies on `keypoints0` and `keypoints1` being
        # index-aligned matched pairs - confirm against the processor docs.
        matches = [cv2.DMatch(i, i, 0) for i in range(len(kps_l))]

        if boxes_xyxy is None:
            return kps_l, kps_r, matches

        # A match is assigned to every box that contains its left keypoint.
        box_matches = [[] for _ in boxes_xyxy]
        for i, (x0, y0, x1, y1) in enumerate(boxes_xyxy):
            for j, kp in enumerate(kps_l):
                x, y = kp.pt
                if x >= x0 and x < x1 and y >= y0 and y < y1:
                    box_matches[i].append(matches[j])

        return kps_l, kps_r, box_matches


# Module-level detector/matcher instance used by `disparity_each_obj`.
# NOTE: it is constructed at import time, so importing this module already
# builds the backend selected by `cfg.kp_detector_name`.
kp_detector = KPDetectMatcher(cfg.kp_detector_name, cfg.device)


def contract_boxes(boxes_xyxy: list[tuple], ratio: float) -> list[tuple]:
    """
    Scale every box about its center by `ratio`.

    Args:
        boxes_xyxy: Bounding boxes in XYXY format.
        ratio: Factor applied to the width and height of each box.

    Returns:
        A new list with the resized boxes. When `ratio` equals 1 the input
        list itself is returned unchanged.
    """
    if ratio == 1:
        return boxes_xyxy

    def _resize(box: tuple) -> tuple:
        left, top, right, bottom = box
        width = right - left
        height = bottom - top
        # Half of the size difference is removed from each side.
        margin_x = (width - width * ratio) / 2
        margin_y = (height - height * ratio) / 2
        return (left + margin_x, top + margin_y, right - margin_x, bottom - margin_y)

    return [_resize(box) for box in boxes_xyxy]


def find_outliers(arr: np.ndarray, num_iqr: float = 1.5) -> np.ndarray:
    """
    Uses Inter-Quartile Range (IQR) to detect outliers in `arr`.
    Any value lower than Q1 - IQR * num_iqr or higher than Q3 + IQR * num_iqr
    is flagged as an outlier.

    Args:
        arr: A 1d-array (or 1d array-like) containing the data.
        num_iqr: Factor which is multiplied by IQR before outlier flagging.

    Returns:
        An outlier mask array of dtype `uint8` with the same length as `arr`.
        The outliers' positions are set to 1 and other positions are set to 0.
        An empty input yields an empty mask.

    Raises:
        ValueError: If `arr` is not one-dimensional.
    """
    arr = np.asarray(arr)
    if arr.ndim != 1:
        raise ValueError("`arr.ndim` should exactly be 1")
    if arr.size == 0:
        # Percentiles are undefined for empty data; there is nothing to flag.
        return np.zeros(0, np.uint8)

    q1, q3 = np.percentile(arr, [25, 75])
    iqr = q3 - q1
    is_outlier = (arr < q1 - num_iqr * iqr) | (arr > q3 + num_iqr * iqr)
    return is_outlier.astype(np.uint8)


def disparity_each_obj(
    img_l: ndarray,
    img_r: ndarray,
    boxes_xyxy: list[tuple],
    match_dist_thr: float = 50.0,
    kp_reg_factor: float = 1.0,
) -> tuple[list[float | str], list, list, list[list]]:
    """
    Calculate disparity value of each object whose bounding boxes are provided.
    Bounding boxes should be provided in XYXY format (not XYWH).

    Args:
        img_l: Left camera image.
        img_r: Right camera image.
        boxes_xyxy: Bounding box coordinates of each object in the image.
        match_dist_thr: Only the keypoint pairs with match distance at most this
            value will be selected.
        kp_reg_factor: The factor by which each of `boxes_xyxy` is resized
            (about its center) before performing keypoint detection on objects.

    Returns:
        A tuple containing:
            1. A disparity value for each box provided in `boxes_xyxy`: the
                mean horizontal offset (left x minus right x) of the matched
                keypoints that remain after IQR outlier removal.
                If the disparity can't be calculated, one of the following
                strings will be present instead of the disparity for that box:
                - NKD: No keypoint match within `match_dist_thr` was found.
                - NEG: Negative average disparity.

            2. List of keypoints detected in `img_l`.

            3. List of keypoints detected in `img_r`.

            4. Filtered lists of matches (distance at most `match_dist_thr`).
                Each list corresponds to one box in `boxes_xyxy`.
    """
    boxes_xyxy = contract_boxes(boxes_xyxy, kp_reg_factor)
    kps_l, kps_r, matches = kp_detector.detect_match(img_l, img_r, boxes_xyxy)

    avg_disps = []
    kept_matches = []
    for box_matches in matches:
        close_matches = [m for m in box_matches if m.distance <= match_dist_thr]
        kept_matches.append(close_matches)
        if not close_matches:
            avg_disps.append("NKD")
            continue

        # Horizontal offset between the two keypoints of every match.
        disps = np.array(
            [
                kps_l[m.queryIdx].pt[0] - kps_r[m.trainIdx].pt[0]
                for m in close_matches
            ]
        )
        disps = disps[find_outliers(disps) == 0]

        mean_disp = disps.mean().item()
        avg_disps.append(mean_disp if mean_disp >= 0 else "NEG")

    return avg_disps, kps_l, kps_r, kept_matches


def draw_depths(img: ndarray, boxes_xyxy: list[tuple], depths: list[float]) -> ndarray:
    """
    Annotate every detected object with its bounding box and a depth label.

    The label is placed above the box when the text height fits between the
    top of the image and the box, and below the box otherwise. All label texts
    are rendered after all rectangles have been drawn. The drawing calls
    receive `img` itself, not a copy.

    Args:
        img: An OpenCV image.
        boxes_xyxy: Bounding box coordinates in XYXY format.
        depths: Depth values corresponding to boxes in `boxes_xyxy`. A string
            entry is used verbatim as the label; any other entry is formatted
            with two decimals followed by "m".

    Returns:
        The resulting image after drawing.
    """
    box_color = (0, 0, 171)  # BGR; (171, 0, 0) in RGB
    text_color = (255, 255, 255)
    font = {
        "fontFace": cv2.FONT_HERSHEY_SIMPLEX,
        "fontScale": 0.5,
        "thickness": 1,
    }
    label_padding = 2

    pending_labels = []
    for box, depth in zip(boxes_xyxy, depths):
        x0, y0, x1, y1 = (round(coord) for coord in box)
        label = depth if isinstance(depth, str) else f"{depth:.2f}m"

        img = cv2.rectangle(img, (x0, y0), (x1, y1), box_color, 2)

        (text_w, text_h), _ = cv2.getTextSize(label, **font)
        if y0 - text_h >= 0:
            # Label background sits on top of the box.
            bg_top = y0 - text_h - label_padding
            bg_bottom = y0
        else:
            # Not enough room above: hang the label under the box.
            bg_top = y1
            bg_bottom = y1 + text_h + label_padding
        img = cv2.rectangle(
            img, (x0, bg_top), (x0 + text_w, bg_bottom), box_color, cv2.FILLED
        )
        pending_labels.append((label, (x0, bg_top + text_h)))

    for label, origin in pending_labels:
        img = cv2.putText(img, label, origin, color=text_color, **font)

    return img


def _ensure_parent_dir(path: str) -> None:
    """Create the directory that will contain `path`, if the path names one."""
    parent = os.path.dirname(path)
    # `os.makedirs("")` raises, so a bare file name (no directory part) is skipped.
    if parent:
        os.makedirs(parent, exist_ok=True)


def _write_image(path: str, img: ndarray) -> None:
    """Write `img` to `path`, creating the parent directory when needed."""
    _ensure_parent_dir(path)
    # `cv2.imwrite` signals failure through its return value.
    if not cv2.imwrite(path, img):
        raise OSError(f"Failed to write image to {path}")


def main(argv: list[str] | None = None) -> None:
    """
    Run the whole pipeline: load the stereo pair, optionally rectify it,
    detect objects in the left image, estimate a disparity and a depth for
    each object and save the annotated left image.

    Args:
        argv: Command-line arguments; forwarded to `parse_args`.

    Raises:
        ValueError: If an input path doesn't exist, an input can't be read as
            an image, or the two images differ in size.
        OSError: If an output image can't be written.
    """
    args = parse_args(argv)

    for path in (args.left, args.right):
        if not os.path.exists(path):
            raise ValueError(f"{path} doesn't exist")
    img_l = cv2.imread(args.left)
    img_r = cv2.imread(args.right)
    # `cv2.imread` returns None instead of raising when decoding fails.
    for path, img in ((args.left, img_l), (args.right, img_r)):
        if img is None:
            raise ValueError(f"{path} couldn't be read as an image")

    # Second-to-last axis first, then the ones before it: (width, height)
    # for an image stored as (height, width, channels).
    img_shape = img_l.shape[-2::-1]
    if img_r.shape[-2::-1] != img_shape:
        raise ValueError("Both left and right images must be of the same size.")

    # Rectification and undistortion
    if cfg.calib_json:
        f, b, maps_L, maps_R, _ = process_calib_json(cfg.calib_json, img_shape)
        # The calibration values replace the ones configured in `cfg`.
        cfg.focal_len_pix = f
        cfg.baseline = b
        img_l = cv2.remap(img_l, *maps_L, cv2.INTER_CUBIC)
        img_r = cv2.remap(img_r, *maps_R, cv2.INTER_CUBIC)

    # Object detection
    boxes = detect_objs(YOLO(cfg.yolo_model), [img_l], cfg.conf_threshold)[0]

    # Keypoint detection and matching
    disps, kps_l, kps_r, matches = disparity_each_obj(
        img_l, img_r, boxes, cfg.match_dist_thr, cfg.kp_region_factor
    )

    # Depth calculation; error codes (strings) are passed through unchanged.
    depths = [
        disp
        if isinstance(disp, str)
        else calc_depth(disp, cfg.focal_len_pix, cfg.baseline)
        for disp in disps
    ]

    if args.v:
        # Must happen before `draw_depths`, which draws on `img_l`.
        all_matches = [m for box_matches in matches for m in box_matches]
        match_vis = cv2.drawMatches(
            img_l,
            kps_l,
            img_r,
            kps_r,
            all_matches,
            None,
            matchesThickness=2,
            flags=cv2.DRAW_MATCHES_FLAGS_NOT_DRAW_SINGLE_POINTS,
        )
        _write_image(args.v, match_vis)

    # Generating final output
    img_out = draw_depths(img_l, boxes, depths)
    _write_image(args.o, img_out)


# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
