import os
import time
import argparse
from functools import partial

import numpy as np
import cv2
from ultralytics import YOLO

from image_inf import detect_objs, match_keypoints, depth_each_obj, draw_depths
from fsrcnn import FSRCNNInference
from config import Config as cfg
from tools import process_calib_json


def parse_args(argv: list[str] = None):
    """Parse the command-line options for the stereo depth video tool.

    Args:
        argv: Argument strings to parse, excluding the program name. When
            ``None``, argparse falls back to ``sys.argv[1:]``.

    Returns:
        An ``argparse.Namespace`` with ``left_vid``, ``right_vid``, ``o``
        and ``speed`` attributes.
    """
    cli = argparse.ArgumentParser()

    # Positional inputs: one video per camera (left and right).
    positional = (
        ("left_vid", "Video file from the left camera"),
        ("right_vid", "Video file from the right camera"),
    )
    for dest, text in positional:
        cli.add_argument(dest, help=text)

    # Output location and playback speed.
    cli.add_argument(
        "-o",
        default="./a.mp4",
        help="Output video file path (default='a.mp4')",
    )
    cli.add_argument(
        "-s",
        "--speed",
        type=float,
        default=1.0,
        help="Speed multiplier for output video",
    )

    namespace = cli.parse_args(argv)
    return namespace


def inference_frame(
    img_l: np.ndarray,
    img_r: np.ndarray,
    yolo: YOLO,
    sr_model: FSRCNNInference,
) -> np.ndarray:
    """
    Run stereo depth estimation on a single left/right frame pair.

    Objects are detected on the left image with ``yolo.track``. Both images
    are then passed through the super-resolution model and resized back to
    the left image's original size. Keypoints are matched between them, and
    per-object depths are drawn onto the original left image.

    Args:
        img_l: Left camera BGR image.
        img_r: Right camera BGR image.
        yolo: A trained ultralytics YOLO model.
        sr_model: A super-resolution model. It is called with a list
            ``[left_rgb, right_rgb]`` and its result is unpacked as
            ``(left, right)``.

    Returns:
        `img_l` with bounding boxes and depths drawn on it. If no keypoint
        matches are found, `img_l` is returned without depth annotations,
        so callers always receive an image rather than ``None``.
    """
    # Detection runs on the original left frame. `persist=True` is passed to
    # ultralytics' `track`, presumably so track IDs carry over between
    # successive frames — confirm against detect_objs.
    boxes = detect_objs(
        partial(yolo.track, persist=True),
        [img_l], cfg.conf_threshold
    )[0]

    img_l_orig = img_l
    # ndarray shape is (height, width, ...); cv2.resize wants (width, height).
    out_size = img_l_orig.shape[1::-1]

    # Super-resolution: the model is fed RGB images, so convert from and back
    # to OpenCV's BGR ordering around the call.
    img_l = cv2.cvtColor(img_l, cv2.COLOR_BGR2RGB)
    img_r = cv2.cvtColor(img_r, cv2.COLOR_BGR2RGB)
    img_l, img_r = sr_model([img_l, img_r])
    img_l = cv2.cvtColor(img_l, cv2.COLOR_RGB2BGR)
    img_r = cv2.cvtColor(img_r, cv2.COLOR_RGB2BGR)

    # Resize both images back to the original left-frame size, so keypoint
    # coordinates are in the same pixel frame as `boxes` (which were detected
    # on the original left image).
    img_l = cv2.resize(img_l, out_size, interpolation=cv2.INTER_CUBIC)
    img_r = cv2.resize(img_r, out_size, interpolation=cv2.INTER_CUBIC)

    # Keypoint detection and matching
    kps_l, kps_r, matches = match_keypoints(img_l, img_r)
    if not matches:
        print("No keypoint matches found between both the images.")
        # Return the unannotated frame instead of None. A None result would
        # crash callers that write every returned frame, e.g. a
        # cv2.VideoWriter.
        return img_l_orig

    # Depth calculation
    depths, _ = depth_each_obj(boxes, kps_l, kps_r, matches)
    return draw_depths(img_l_orig, boxes, depths)


def main(argv: list[str] = None) -> None:
    """
    Generate an annotated depth video from a pair of stereo camera videos.

    Frames are read from both input videos in lockstep until either one runs
    out. If ``cfg.calib_json`` is set, frames are optionally remapped with the
    maps from that calibration file. Each kept pair is passed to
    :func:`inference_frame`, and the resulting left frame is written to
    ``args.o``. If inference yields no image, the (remapped) left frame is
    written instead, so the output timeline stays intact.

    With ``--speed N`` (N >= 1), only about every N-th input frame is
    processed and written. The output keeps the input FPS, so playback is
    sped up by roughly N.

    Args:
        argv: Command-line arguments, excluding the program name. ``None``
            makes argparse use ``sys.argv[1:]``.

    Raises:
        ValueError: If the speed multiplier is smaller than 1.
        FileNotFoundError: If either input video does not exist.
    """
    args = parse_args(argv)
    if args.speed < 1:
        raise ValueError("Speed multiplier cannot be smaller than 1")

    if not os.path.exists(args.left_vid):
        raise FileNotFoundError(f"{args.left_vid} doesn't exist!")
    if not os.path.exists(args.right_vid):
        raise FileNotFoundError(f"{args.right_vid} doesn't exist!")

    vid_l = cv2.VideoCapture(args.left_vid)
    vid_r = cv2.VideoCapture(args.right_vid)
    vid_writer = None
    try:
        # Output geometry and frame rate are taken from the left video.
        w = int(vid_l.get(cv2.CAP_PROP_FRAME_WIDTH))
        h = int(vid_l.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fps = vid_l.get(cv2.CAP_PROP_FPS)

        if cfg.calib_json:
            f, b, maps_L, maps_R, _ = process_calib_json(cfg.calib_json, (w, h))
            cfg.focal_len_pix = f
            cfg.baseline = b

        # os.path.dirname() is "" for a bare filename such as "out.mp4", and
        # os.makedirs("") raises. Only create a directory when there is one.
        out_dir = os.path.dirname(args.o)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        vid_writer = cv2.VideoWriter(args.o, cv2.VideoWriter_fourcc(*"mp4v"), fps, (w, h))

        yolo = YOLO(cfg.yolo_model)
        fsrcnn = FSRCNNInference(cfg.fsrcnn_model, cfg.scaling_factor, cfg.device)

        frame_count = 0      # input frame pairs consumed (including skipped ones)
        frames_written = 0   # frames written to the output video
        frames_total = min(int(vid_l.get(cv2.CAP_PROP_FRAME_COUNT)),
                           int(vid_r.get(cv2.CAP_PROP_FRAME_COUNT)))
        # perf_counter is monotonic and high-resolution, which suits
        # elapsed-time measurement.
        start_time = time.perf_counter()
        avg_fps = 0.0
        print("Generating output video...")

        while vid_l.isOpened():
            if frame_count % 30 == 0:
                elapsed = time.perf_counter() - start_time
                # Elapsed time can be exactly 0.0 on the first iteration.
                if elapsed > 0:
                    avg_fps = frame_count / elapsed
            # The frame count property may be 0 (or negative) when the
            # container does not report it. Avoid dividing by it in that case.
            if frames_total > 0:
                percent_done = (frame_count / frames_total) * 100
            else:
                percent_done = 0.0
            print(f"{percent_done:>5.1f}% completed | Avg. FPS: {avg_fps:<9.2f}", end="\r")

            ret, frame_l = vid_l.read()
            if not ret:
                break
            ret, frame_r = vid_r.read()
            if not ret:
                break

            # Skip input frames so that frames_written tracks frame_count / speed.
            if frames_written > frame_count / args.speed:
                frame_count += 1
                continue

            if cfg.calib_json:
                frame_l = cv2.remap(frame_l, *maps_L, cv2.INTER_CUBIC)
                frame_r = cv2.remap(frame_r, *maps_R, cv2.INTER_CUBIC)
            out_frame = inference_frame(frame_l, frame_r, yolo, fsrcnn)
            if out_frame is None:
                # VideoWriter.write cannot take None. Keep the frame slot by
                # writing the left frame unannotated.
                out_frame = frame_l
            vid_writer.write(out_frame)
            frame_count += 1
            frames_written += 1
    finally:
        # Release capture/writer handles even if inference raises, so the
        # output container is finalized where possible.
        vid_l.release()
        vid_r.release()
        if vid_writer is not None:
            vid_writer.release()
        # End the "\r"-overwritten progress line.
        print()


if __name__ == "__main__":
    # Script entry point. main() returns None, so this exits with status 0
    # on success, just as falling off the end of the script would.
    raise SystemExit(main())
