import sys
import argparse
from pathlib import Path
import shutil
import random

import cv2
from qreader import QReader

import calib
from calib import BoardConfig, StrPath
from config import Config as cfg
from video_inf import main as video_inf


# ---------------------------- Calibration Config --------------------------------
# Calibration board description; main() passes these four values to BoardConfig.
corners_size: tuple[int, int] = (10, 7)
board_type: str = "chessb"  # "chessb", "charuco" or "both"
square_side: float = 0.095  # in meters
marker_side: float = 0.75 * square_side  # three quarters of the square side

# Iteration parameters
# Target errors: the mono value is passed to calib.iter_mb_calibrate and the
# stereo value to calib.iter_mb_stereo.
mono_target_err: float = 0.5
stereo_target_err: float = 1.0
# Shared by the mono and the stereo calibration calls in main().
reject_per_iter: int = 1
max_imgs_per_iter: int = 100
max_iter: int = 100
# --------------------------------------------------------------------------------


def parse_args(argv: list[str] = None):
    desc = (
        "An end-to-end script which performs QR-based video synchronization, "
        "calibration board image extraction, calibration and depth estimation "
        "using two stereo videos. "
        "The YOLO model used for object detection and depth estimation "
        "parameters, both can be set in `config.py`."
    )
    parser = argparse.ArgumentParser(description=desc)

    parser.add_argument("left_vid", help="Video file from the left camera")
    parser.add_argument("right_vid", help="Video file from the right camera")
    parser.add_argument("-o", default="./", help="Output directory (default='./')")
    return parser.parse_args(argv)


def qr_frame_diff(left_vid: StrPath, right_vid: StrPath) -> int | None:
    if not Path(left_vid).exists():
        raise FileNotFoundError(str(left_vid))
    if not Path(right_vid).exists():
        raise FileNotFoundError(str(right_vid))

    vid_l = cv2.VideoCapture(left_vid)
    if not vid_l.isOpened():
        raise RuntimeError(f"Failure to open {left_vid}")

    vid_r = cv2.VideoCapture(right_vid)
    if not vid_r.isOpened():
        raise RuntimeError(f"Failure to open {right_vid}")

    qreader = QReader()
    qr_frame = {}
    fdiff = None
    frame_no = 1
    while True:
        ret, frame_l = vid_l.read()
        if not ret:
            break
        ret, frame_r = vid_r.read()
        if not ret:
            break

        frame_l = cv2.cvtColor(frame_l, cv2.COLOR_BGR2RGB)
        frame_r = cv2.cvtColor(frame_r, cv2.COLOR_BGR2RGB)
        text_l = qreader.detect_and_decode(frame_l)
        text_r = qreader.detect_and_decode(frame_r)

        if text_l and text_l[0] is not None:
            num = int(text_l[0].replace("Number: ", ""))
            if num in qr_frame and qr_frame[num] < 0:
                fdiff = frame_no + qr_frame[num]
                break
            else:
                qr_frame[num] = frame_no

        if text_r and text_r[0] is not None:
            num = int(text_r[0].replace("Number: ", ""))
            if num in qr_frame and qr_frame[num] > 0:
                fdiff = -frame_no + qr_frame[num]
                break
            else:
                qr_frame[num] = -frame_no

        frame_no += 1

    vid_l.release()
    vid_r.release()
    return fdiff


def extract_calib_imgs(
    left_vid: StrPath,
    right_vid: StrPath,
    frame_diff: int,
    board: BoardConfig,
    left_savedir: StrPath,
    right_savedir: StrPath,
) -> None:
    """Save every synchronized frame pair in which both views show the board.

    The leading video is advanced by ``|frame_diff|`` frames, then both videos
    are read in lockstep. Detection runs on a grayscale, histogram-equalized
    copy of each frame; the original frames are what gets written, as
    ``f_<index>.png`` with the same index in both directories.

    Args:
        left_vid: Path of the left camera video.
        right_vid: Path of the right camera video.
        frame_diff: Frame offset between the videos. A positive value skips
            that many frames of the left video, a negative value skips
            ``-frame_diff`` frames of the right video.
        board: Board description handed to ``calib.MBCornerDetector``.
        left_savedir: Directory for the left images (created if missing).
        right_savedir: Directory for the right images (created if missing).

    Raises:
        FileNotFoundError: If either video path does not exist.
        RuntimeError: If a video cannot be opened or an image cannot be
            written.
    """
    if not Path(left_vid).exists():
        raise FileNotFoundError(str(left_vid))
    if not Path(right_vid).exists():
        raise FileNotFoundError(str(right_vid))

    # str(): not every OpenCV build accepts path-like objects as a filename.
    vid_l = cv2.VideoCapture(str(left_vid))
    vid_r = None
    try:
        if not vid_l.isOpened():
            raise RuntimeError(f"Failure to open {left_vid}")

        vid_r = cv2.VideoCapture(str(right_vid))
        if not vid_r.isOpened():
            raise RuntimeError(f"Failure to open {right_vid}")

        frames_l = int(vid_l.get(cv2.CAP_PROP_FRAME_COUNT))
        frames_r = int(vid_r.get(cv2.CAP_PROP_FRAME_COUNT))
        if frame_diff > 0:
            vid_l.set(cv2.CAP_PROP_POS_FRAMES, frame_diff)
            frames_l -= frame_diff
        elif frame_diff < 0:
            skip = -frame_diff
            vid_r.set(cv2.CAP_PROP_POS_FRAMES, skip)
            frames_r -= skip

        corner_detector = calib.MBCornerDetector(board)

        left_savedir = Path(left_savedir)
        right_savedir = Path(right_savedir)
        left_savedir.mkdir(parents=True, exist_ok=True)
        right_savedir.mkdir(parents=True, exist_ok=True)

        def board_found(frame) -> bool:
            # Detection works on an equalized grayscale copy of the frame.
            gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
            gray = cv2.equalizeHist(gray)
            ok, _, _ = corner_detector.detect(gray)
            return bool(ok)

        frame_count = 0
        found = 0
        frames_total = min(frames_l, frames_r)
        print("Extracting calibration images...")
        while True:
            if frames_total > 0:
                percent_done = (frame_count / frames_total) * 100
                progress = f"{percent_done:>5.1f}% completed"
            else:
                # The frame count is unavailable, or the offset exceeds it,
                # so a percentage cannot be computed.
                progress = f"{frame_count} frames processed"
            print(f"{progress} | {found} pairs found", end="\r")

            ret, frame_l = vid_l.read()
            if not ret:
                break
            ret, frame_r = vid_r.read()
            if not ret:
                break

            # Both detections are evaluated on every pair, as before.
            found_l = board_found(frame_l)
            found_r = board_found(frame_r)

            if found_l and found_r:
                name = f"f_{frame_count:06d}.png"
                for target, frame in (
                    (left_savedir / name, frame_l),
                    (right_savedir / name, frame_r),
                ):
                    # imwrite reports failure through its return value only.
                    if not cv2.imwrite(str(target), frame):
                        raise RuntimeError(f"Failure to write {target}")
                found += 1

            frame_count += 1

        print()
    finally:
        # Release the captures on every exit path, including errors.
        vid_l.release()
        if vid_r is not None:
            vid_r.release()

def pair_randomize_files(left_dir: StrPath, right_dir: StrPath) -> None:
    """Rename the files of two directories to shuffled, zero-padded indices.

    Files that carry the same name in both directories are treated as a pair
    and receive the same new index, so pairs stay matched by name. Files
    present in only one directory receive an index of their own. Every file
    is renamed to ``<index:06d><original suffix>`` inside its own directory.

    Renaming happens in two passes through temporary names, so a new name can
    never overwrite a file that is still waiting to be renamed (for example
    when the directories already contain ``000000.png``-style names).

    Args:
        left_dir: First directory; only its regular files are renamed.
        right_dir: Second directory; only its regular files are renamed.
    """
    left_dir = Path(left_dir)
    right_dir = Path(right_dir)
    left_by_name = {p.name: p for p in left_dir.iterdir() if p.is_file()}
    right_by_name = {p.name: p for p in right_dir.iterdir() if p.is_file()}

    # One group per index: a (left, right) pair or a single unpaired file.
    # Sorting the names makes the pre-shuffle order independent of set and
    # filesystem ordering, so a seeded RNG reproduces the same result.
    groups: list[tuple[Path, ...]] = []
    for name in sorted(left_by_name.keys() | right_by_name.keys()):
        groups.append(
            tuple(
                by_name[name]
                for by_name in (left_by_name, right_by_name)
                if name in by_name
            )
        )

    random.shuffle(groups)

    # Pass 1: move every file out of the way under a unique temporary name.
    token = f"{random.getrandbits(64):016x}"
    staged: list[tuple[Path, Path]] = []
    for index, group in enumerate(groups):
        for path in group:
            final_path = path.with_name(f"{index:06d}{path.suffix}")
            temp_path = path.with_name(f".{token}_{len(staged)}{path.suffix}")
            path.rename(temp_path)
            staged.append((temp_path, final_path))

    # Pass 2: all original names are gone now, so the final names are free.
    for temp_path, final_path in staged:
        temp_path.rename(final_path)


def _run_calibration(label: str, calibrate, *args):
    """Run one calibration routine with ``label`` appended to the log prefix.

    Args:
        label: "Left", "Right" or "Stereo"; used for the log prefix and for
            the failure message.
        calibrate: Callable returning ``(ok, params, errors)`` where
            ``errors`` offers ``values()`` and ``len()``.
        *args: Positional arguments forwarded to ``calibrate``.

    Returns:
        Tuple ``(params, mean_error)`` with the mean of ``errors.values()``.

    Raises:
        RuntimeError: If ``calibrate`` reports failure.
    """
    base_prefix = calib.log_fmt.prefix
    calib.log_fmt.prefix = base_prefix + f"{label} "
    try:
        ok, params, errors = calibrate(*args)
    finally:
        # Restore the saved prefix even when the calibration raises.
        # NOTE(review): assumes the callee leaves log_fmt.prefix unchanged.
        calib.log_fmt.prefix = base_prefix
    if not ok:
        raise RuntimeError(f"{label} camera calibration failed!")
    return params, sum(errors.values()) / len(errors)


def main(argv: list[str] | None = None) -> None:
    """Run the whole pipeline: sync, image extraction, calibration, depth video.

    The following are written below the output directory: ``frame_diff.txt``,
    the ``left``/``right`` calibration image directories, ``calib.json``,
    ``calib_report.txt`` and ``de.mp4``.

    Args:
        argv: Command-line arguments; ``None`` means ``sys.argv[1:]``.

    Raises:
        RuntimeError: If QR synchronization or any calibration stage fails.
    """
    args = parse_args(argv)

    left_vid = Path(args.left_vid)
    right_vid = Path(args.right_vid)
    out_dir = Path(args.o)
    fdiff_path = out_dir / "frame_diff.txt"
    left_dir = out_dir / "left"
    right_dir = out_dir / "right"
    calib_path = out_dir / "calib.json"
    report_path = out_dir / "calib_report.txt"
    de_path = out_dir / "de.mp4"

    # A falsy answer other than None aborts; a truthy one only adds a blank
    # line before the pipeline output.
    replace = calib.confirm_replace(
        [fdiff_path, left_dir, right_dir, calib_path, report_path, de_path]
    )
    if replace is not None and not replace:
        sys.exit()
    elif replace:
        print()

    print("Running QR frame synchronization...")
    fdiff = qr_frame_diff(left_vid, right_vid)
    if fdiff is None:
        raise RuntimeError("QR frame synchronization failed!!")
    print("Frame diff:", fdiff)
    fdiff_path.parent.mkdir(parents=True, exist_ok=True)
    with open(fdiff_path, "w") as f:
        f.write(str(fdiff))
    print("Frame diff saved to", fdiff_path, "\n")

    board = BoardConfig(corners_size, board_type, square_side, marker_side)

    # Start from empty image directories so stale images are not calibrated on.
    shutil.rmtree(left_dir, ignore_errors=True)
    shutil.rmtree(right_dir, ignore_errors=True)
    extract_calib_imgs(left_vid, right_vid, fdiff, board, left_dir, right_dir)
    print("Left calibration images saved at", left_dir)
    print("Right calibration images saved at", right_dir, "\n")

    print("Randomizing the order of calibration images...")
    pair_randomize_files(left_dir, right_dir)
    print("Done")

    iter_args = (reject_per_iter, max_imgs_per_iter, max_iter)

    print("\nCalibrating left camera...")
    params_l, rerr_l = _run_calibration(
        "Left", calib.iter_mb_calibrate, left_dir, board, mono_target_err, *iter_args
    )

    print("\nCalibrating right camera...")
    params_r, rerr_r = _run_calibration(
        "Right", calib.iter_mb_calibrate, right_dir, board, mono_target_err, *iter_args
    )

    print("\nCalibrating stereo setup...")
    params_s, rerr_s = _run_calibration(
        "Stereo",
        calib.iter_mb_stereo,
        left_dir,
        right_dir,
        params_l,
        params_r,
        board,
        stereo_target_err,
        *iter_args,
    )

    print()
    print("Left Camera Reprojection Error: ", rerr_l)
    print("Right Camera Reprojection Error:", rerr_r)
    print("Stereo Reprojection Error:      ", rerr_s)

    calib_path.parent.mkdir(parents=True, exist_ok=True)
    calib.save_params(params_s, calib_path)
    print("\nCalibrated parameters saved at", calib_path)

    report = calib.validation_report(
        params_s["R"], params_s["T"], rerr_l, rerr_r, rerr_s
    )
    with open(report_path, "w") as f:
        f.write(report)
    print("Validation report saved at", report_path, "\n")

    # Point the inference config at the fresh calibration before running it.
    cfg.calib_json = str(calib_path)
    video_inf([str(left_vid), str(right_vid), "-o", str(de_path), "-f", str(fdiff)])
    print("Depth estimated video saved as", de_path)


# Run the pipeline only when this file is executed as a script; importing the
# module must not start it.
if __name__ == "__main__":
    main()
