import os
import time
import argparse
from collections import defaultdict

import torch
import numpy as np
import cv2
from ultralytics import YOLO

from config import Config as cfg
from tools import process_calib_json
from image_inf import disparity_each_obj, calc_depth, draw_depths


def parse_args(argv: list[str] = None):
    """
    Parse the command-line arguments of the script.

    Args:
        argv: Arguments to parse, without the program name. When ``None``,
            ``argparse`` reads them from ``sys.argv``.

    Returns:
        An ``argparse.Namespace`` with the attributes ``left_vid``,
        ``right_vid``, ``o``, ``speed`` (float) and ``frame_diff`` (int).
    """
    frame_diff_help = (
        "Number of frames by which the right video is ahead of left video. "
        "Negative values signify left video is ahead. (default=0)"
    )

    cli = argparse.ArgumentParser()

    # Positional inputs: the two camera streams.
    for dest, text in (
        ("left_vid", "Video file from the left camera"),
        ("right_vid", "Video file from the right camera"),
    ):
        cli.add_argument(dest, help=text)

    cli.add_argument(
        "-o",
        help="Output video file path (default='a.mp4')",
        default="./a.mp4",
    )
    cli.add_argument(
        "-s",
        "--speed",
        help="Speed multiplier for output video",
        default=1.0,
        type=float,
    )
    cli.add_argument(
        "-f",
        "--frame-diff",
        help=frame_diff_help,
        default=0,
        type=int,
    )

    namespace = cli.parse_args(argv)
    return namespace


class VideoDispEstimator:
    """
    Used to calculate disparity of the objects in a video stream.

    Objects are associated across frames through the track ids produced by
    `obj_detector.track` (called with ``persist=True``). For every known id the
    estimator stores the latest disparity, smoothed with an exponential moving
    average, together with the frame number at which that value was last set.

    Args:
        obj_detector: An ultralytics YOLO instance.
        conf: Confidence threshold to pass to `track` method of `obj_detector`.
        expire_in: Number of frames a stored disparity stays valid. Once it
            has expired, the object is deleted from the state if it is not
            detected in the current `update` frame; if it is detected, its
            disparity is re-initialised from the current measurement instead
            of being smoothed with the stale value.
        sm_factor: Smoothing factor for the EMA of the disparities, i.e. the
            weight given to the newest measurement (1 disables smoothing).

    Raises:
        ValueError: If `sm_factor` is outside [0, 1] or `expire_in` < 1.
    """

    def __init__(
        self,
        obj_detector: YOLO,
        conf: float = 0.25,
        expire_in: int = 300,
        sm_factor: float = 0.3,
    ):
        if not 0 <= sm_factor <= 1:
            raise ValueError("`sm_factor` should be in [0, 1] range")
        if expire_in < 1:
            raise ValueError("`expire_in` should be greater than 0")

        self.obj_detector = obj_detector
        self.conf = conf
        self.expire_in = expire_in
        self.sf = sm_factor
        # track id -> {"disparity": float or error string, "frame": frame no.}
        self.state = {}

    def _detect(self, frame_l: np.ndarray):
        """Run the persistent tracker on `frame_l` and return the YOLO `Boxes`."""
        return self.obj_detector.track(
            frame_l,
            conf=self.conf,
            persist=True,
            verbose=False,
        )[0].boxes

    @staticmethod
    def _track_ids(boxes) -> list[int]:
        """Return the track id of every box as an int (requires `boxes.is_track`)."""
        return [int(obj_id) for obj_id in boxes.id.tolist()]

    @staticmethod
    def _drop_all(boxes) -> None:
        """Remove every detection from `boxes` in place, keeping its column layout."""
        data = boxes.data
        boxes.data = torch.empty(
            0, data.shape[1], dtype=data.dtype, device=data.device
        )

    def _is_expired(self, entry: dict, frame_no: int) -> bool:
        """Whether the disparity in a state `entry` is `expire_in` or more frames old."""
        return entry["frame"] + self.expire_in <= frame_no

    def update(
        self,
        frame_l: np.ndarray,
        frame_r: np.ndarray,
        frame_no: int,
    ) -> tuple:
        """
        Update the state of the estimator.
        Also returns the detections and corresponding disparities for the
        current frame.

        Args:
            frame_l: Left camera BGR frame.
            frame_r: Right camera BGR frame.
            frame_no: The current frame no. in the video.

        Returns:
            A tuple containing:
                1. YOLO `Boxes` object that contains the object detections
                    for `frame_l`. Detections without a track id are removed,
                    since they cannot be associated with the state.

                2. A list with one disparity value per box (EMA-smoothed for
                    objects already in the state). If the calculation fails
                    for a box that has a previous disparity, the previous
                    value is returned. Otherwise, one of the following
                    strings will be present instead of the disparity:
                    - NKD: No keypoint detected.
                    - NEG: Negative average disparity.
        """
        boxes = self._detect(frame_l)

        if boxes.is_track:
            obj_ids = self._track_ids(boxes)
            # Keypoint detection and matching
            disps = list(
                disparity_each_obj(
                    frame_l,
                    frame_r,
                    boxes.xyxy.tolist(),
                    cfg.match_dist_thr,
                    cfg.kp_region_factor,
                )[0]
            )
        else:
            # Detections without track ids can't be matched with the state.
            # Drop them (as `track` does) so that `boxes` and `disps` stay
            # the same length.
            self._drop_all(boxes)
            obj_ids, disps = [], []

        index_of = {obj_id: i for i, obj_id in enumerate(obj_ids)}

        # Forget objects that are not detected now and whose disparity is stale.
        for obj_id in list(self.state):
            if obj_id not in index_of and self._is_expired(self.state[obj_id], frame_no):
                del self.state[obj_id]

        # Syncing the disparities with the state
        for obj_id, i in index_of.items():
            disp = disps[i]
            prev = self.state.get(obj_id)

            if (
                prev is None
                or self._is_expired(prev, frame_no)
                or isinstance(prev["disparity"], str)
            ):
                # New object, stale value, or no valid value yet: start the
                # EMA from the current measurement (which may be an error code).
                self.state[obj_id] = {"disparity": disp, "frame": frame_no}
                continue

            prev_disp = prev["disparity"]
            if isinstance(disp, str):
                # Current measurement failed: report the last known value and
                # leave the state (including its frame stamp) untouched.
                disps[i] = prev_disp
                continue
            # Presumably `disparity_each_obj` yields floats or error strings
            # only; these asserts guard that assumption.
            assert isinstance(disp, float), type(disp)
            assert isinstance(prev_disp, float), type(prev_disp)

            disp_sm = self.sf * disp + (1 - self.sf) * prev_disp
            disps[i] = disp_sm
            self.state[obj_id] = {"disparity": disp_sm, "frame": frame_no}

        return boxes, disps

    def track(
        self,
        frame_l: np.ndarray,
    ) -> tuple:
        """
        Returns the detections and corresponding disparities.
        No disparity computation is done. Only previous disparities are returned.
        The disparity values in the state of the estimator are also not
        modified in any way.

        If an object with a new id is detected, it is ignored and not returned.
        Detections without a track id are not returned either.

        Args:
            frame_l: Left camera BGR frame.

        Returns:
            A tuple containing:
                1. YOLO `Boxes` object that contains the object detections
                    for `frame_l` whose ids are present in the state.

                2. A disparity value corresponding to each returned box.
                    If there was an error during the calculation, one of the
                    following strings will be present instead of the disparity
                    for that box:
                    - NKD: No keypoint detected.
                    - NEG: Negative average disparity.
        """
        boxes = self._detect(frame_l)
        if not boxes.is_track:
            self._drop_all(boxes)
            return boxes, []

        obj_ids = self._track_ids(boxes)
        keep = [obj_id in self.state for obj_id in obj_ids]
        # Boolean row mask; an all-False mask yields an empty (0, n_cols) tensor.
        boxes.data = boxes.data[
            torch.tensor(keep, dtype=torch.bool, device=boxes.data.device)
        ]
        disps = [
            self.state[obj_id]["disparity"] for obj_id in obj_ids if obj_id in self.state
        ]
        return boxes, disps


# Disparities are recomputed (keypoint detection and matching) at most once per
# this many source frames; frames in between only run the tracker and reuse
# the stored values.
_UPDATE_INTERVAL = 10


def _open_capture(path: str):
    """Open `path` with `cv2.VideoCapture`, raising `RuntimeError` if it fails."""
    vid = cv2.VideoCapture(path)
    if not vid.isOpened():
        raise RuntimeError(f"Failure to open {path}")
    return vid


def main(argv: list[str] = None) -> None:
    """
    Estimate object depths in a pair of stereo videos and write an annotated video.

    Each written output frame is the (optionally rectified) left frame with the
    depth of every tracked object drawn on it. Disparities are recomputed at
    most once every `_UPDATE_INTERVAL` source frames; in between, the last
    known disparities of tracked objects are reused.

    Args:
        argv: Command-line arguments without the program name; ``None`` makes
            ``argparse`` read ``sys.argv``.

    Raises:
        ValueError: If the speed multiplier is below 1, the frame sizes of the
            two videos differ, or the left video reports a non-positive FPS.
        FileNotFoundError: If either input video doesn't exist.
        RuntimeError: If a video can't be opened or the output can't be created.
    """
    from contextlib import ExitStack

    args = parse_args(argv)
    if args.speed < 1:
        raise ValueError("Speed multiplier cannot be smaller than 1")

    for path in (args.left_vid, args.right_vid):
        if not os.path.exists(path):
            raise FileNotFoundError(f"{path} doesn't exist!")

    # Every handle registered on the stack is released, even on errors.
    with ExitStack() as stack:
        vid_l = _open_capture(args.left_vid)
        stack.callback(vid_l.release)
        vid_r = _open_capture(args.right_vid)
        stack.callback(vid_r.release)

        w = int(vid_l.get(cv2.CAP_PROP_FRAME_WIDTH))
        h = int(vid_l.get(cv2.CAP_PROP_FRAME_HEIGHT))
        w_r = int(vid_r.get(cv2.CAP_PROP_FRAME_WIDTH))
        h_r = int(vid_r.get(cv2.CAP_PROP_FRAME_HEIGHT))
        if w != w_r or h != h_r:
            raise ValueError("left and right videos must have equal frame size")

        frames_l = int(vid_l.get(cv2.CAP_PROP_FRAME_COUNT))
        frames_r = int(vid_r.get(cv2.CAP_PROP_FRAME_COUNT))
        fps = vid_l.get(cv2.CAP_PROP_FPS)
        # `fps * 10` becomes the estimator's `expire_in`, which must be >= 1.
        if fps <= 0:
            raise ValueError(f"Could not read a valid FPS from {args.left_vid}")

        # Align the streams: a positive `frame_diff` skips that many frames of
        # the left video, a negative one skips frames of the right video.
        if args.frame_diff > 0:
            vid_l.set(cv2.CAP_PROP_POS_FRAMES, args.frame_diff)
            frames_l -= args.frame_diff
        elif args.frame_diff < 0:
            vid_r.set(cv2.CAP_PROP_POS_FRAMES, -args.frame_diff)
            frames_r += args.frame_diff

        maps_l = maps_r = None
        if cfg.calib_json:
            f, b, maps_l, maps_r, _ = process_calib_json(cfg.calib_json, (w, h))
            cfg.focal_len_pix = f
            cfg.baseline = b
        f, b = cfg.focal_len_pix, cfg.baseline

        # dirname() is "" for a bare file name and os.makedirs("") raises.
        out_dir = os.path.dirname(args.o)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        vid_writer = cv2.VideoWriter(
            args.o, cv2.VideoWriter_fourcc(*"mp4v"), fps, (w, h)
        )
        stack.callback(vid_writer.release)
        if not vid_writer.isOpened():
            raise RuntimeError(f"Failure to open {args.o} for writing")

        yolo = YOLO(cfg.yolo_model)
        disp_estimator = VideoDispEstimator(yolo, cfg.conf_threshold, fps * 10)

        frame_count = 0  # source frame pairs consumed so far
        frames_written = 0
        last_update = None  # frame_count of the most recent disparity update
        # The reported count may be 0 or, after a large frame_diff, negative.
        frames_total = max(0, min(frames_l, frames_r))
        start_time = time.time()
        avg_fps = 0
        print("Generating output video...")

        while True:
            if frame_count % 30 == 0:
                avg_fps = frame_count / (time.time() - start_time + 1e-10)
            percent_done = (frame_count / frames_total) * 100 if frames_total else 0.0
            print(f"{percent_done:>5.1f}% completed | Avg. FPS: {avg_fps:<9.2f}", end="\r")

            ret, frame_l = vid_l.read()
            if not ret:
                break
            ret, frame_r = vid_r.read()
            if not ret:
                break

            # Skip source frames so that only about one in `speed` is written.
            if frames_written > frame_count / args.speed:
                frame_count += 1
                continue

            if maps_l is not None:
                frame_l = cv2.remap(frame_l, *maps_l, cv2.INTER_CUBIC)
                frame_r = cv2.remap(frame_r, *maps_r, cv2.INTER_CUBIC)

            # Count in source frames so the refresh rate doesn't depend on
            # `speed` (a modulo test on frame_count would, since only some
            # frame numbers are ever processed).
            if last_update is None or frame_count - last_update >= _UPDATE_INTERVAL:
                boxes, disps = disp_estimator.update(frame_l, frame_r, frame_count + 1)
                last_update = frame_count
            else:
                boxes, disps = disp_estimator.track(frame_l)

            depths = [d if isinstance(d, str) else calc_depth(d, f, b) for d in disps]
            out_frame = draw_depths(frame_l, boxes.xyxy.tolist(), depths)
            vid_writer.write(out_frame)

            frame_count += 1
            frames_written += 1

    print()


if __name__ == "__main__":
    # Script entry point: with argv=None, `main` parses sys.argv.
    main(argv=None)
