import argparse
import json
import os
import warnings

import cv2

import numpy as np
import scipy
from eval_utils import calc_metric_for_scale, gen_center, get_pose, check_box_in_frame
from visualization_utils import gen_frames, undistort_img, get_intrinsics

# Silence every Python warning for the whole run of the script.
warnings.filterwarnings("ignore")

# Command-line interface of the evaluation script. Option names, types and
# defaults are what the main section reads from ``parser.parse_args()``.
parser = argparse.ArgumentParser(
    description=(
        "Evaluate a pose model on annotated tracklets by searching for the "
        "optimal depth scale of every tracklet."
    )
)
parser.add_argument(
    "--model",
    type=str,
    default="ssl_nyuv2",
    help="the model to run for evaluation script",
)
# NOTE(review): the main section hard-codes the metric to "PointDis" and does
# not read this option — confirm whether it is meant to be wired through.
parser.add_argument(
    "--eval_metric",
    type=str,
    default="PointDis",
    help="the optimization metric used for evaluation",
)
parser.add_argument(
    "--img_width", type=int, default=640, help="image width in pixel of the input video"
)
parser.add_argument(
    "--img_height",
    type=int,
    default=480,
    help="image height in pixel of the input video",
)
parser.add_argument(
    "--scenario", type=str, default="kitchen", help="the scenario of the video"
)
parser.add_argument(
    "--visualize", help="if set, generates the visualization video", action="store_true"
)


def extract_tracklet(
    dataset_json, video_id, clip_id, track_id, img_width=640, img_height=480, calib=None,
):
    """Collect the ground-truth boxes of one track, rescaled to the target size.

    The track is read from
    ``dataset_json["videos"][video_id]["clips"][clip_id]["annotations"][0]
    ["query_sets"][track_id]["lt_track"]``. Every entry supplies ``x``, ``y``,
    ``width``, ``height`` and ``frame_number``; ``original_width`` /
    ``original_height`` are optional and fall back to 1920 x 1440.

    Args:
        dataset_json: parsed annotation structure indexed as shown above.
        video_id: index into ``dataset_json["videos"]``.
        clip_id: index into the video's ``"clips"`` list.
        track_id: key into the clip's ``"query_sets"`` mapping.
        img_width: width the x coordinates are rescaled to.
        img_height: height the y coordinates are rescaled to.
        calib: optional calibration vector. When it holds more than four
            values, the box corners are undistorted with
            ``cv2.undistortPoints`` using ``get_intrinsics(calib)[:3, :3]`` as
            camera matrix and ``calib[4:]`` as distortion coefficients.
            ``None`` (the default) or a vector of at most four values skips
            undistortion.

    Returns:
        ``(gt_tracklet, first_frame_ind)`` where ``gt_tracklet`` maps each
        frame number to an integer array ``[x1, y1, x2, y2]`` and
        ``first_frame_ind`` is the smallest frame number, or ``np.inf`` when
        the track has no entries.
    """
    tracklets = dataset_json["videos"][video_id]["clips"][clip_id]["annotations"][0][
        "query_sets"
    ][track_id]["lt_track"]

    # The default ``calib=None`` must not reach ``len()``.
    needs_undistort = calib is not None and len(calib) > 4
    # The camera matrix does not depend on the track entry, so build it once.
    K = get_intrinsics(calib)[:3, :3] if needs_undistort else None

    gt_tracklet = {}
    first_frame_ind = np.inf
    for track in tracklets:
        x, y, w, h = track["x"], track["y"], track["width"], track["height"]
        if "original_width" not in track or "original_height" not in track:
            ow = 1920
            oh = 1440
        else:
            ow, oh = track["original_width"], track["original_height"]
        # Force a floating dtype: integer annotations would otherwise create an
        # integer array, truncating the in-place rescale below and handing
        # cv2.undistortPoints a non-floating input.
        bb = np.array([x, y, x + w, y + h], dtype=np.float64)
        bb[0:3:2] = bb[0:3:2] / ow * img_width
        bb[1:4:2] = bb[1:4:2] / oh * img_height
        if needs_undistort:
            # Treat the box as its two corner points; P=K keeps the result in
            # pixel coordinates of the same camera matrix.
            bb = cv2.undistortPoints(bb.reshape((2, 2)), K, calib[4:], P=K).flatten()
        # ``np.int`` was removed from NumPy; the builtin ``int`` is equivalent.
        gt_tracklet[track["frame_number"]] = bb.astype(int)
        if track["frame_number"] < first_frame_ind:
            first_frame_ind = track["frame_number"]
    return gt_tracklet, first_frame_ind


def check_moving(attributes):
    is_moving = False
    is_loc_changed = False
    for attribute in attributes:
        is_moving = is_moving or attribute["is_moving"]
        is_loc_changed = is_loc_changed or attribute["is_location_changed"]
    return is_moving or is_loc_changed


# Lists of UUID strings grouped under three scenario names ("kitchen",
# "non-kitchen", "new_selection"). The first "kitchen" entry is the clip UID
# that the main section of this script evaluates.
# NOTE(review): presumably every entry is a clip UID and the keys correspond to
# the values accepted by --scenario — confirm against get_pose(). The table is
# not referenced anywhere else in the visible part of this file.
SCENARIO_DICT = {
    "kitchen": [
        "1a3273d3-f6a7-4ff3-8339-c5eeca5b6999",
        "5d0a4731-3a90-44ad-b2b4-cd2d6022eeff",
        "9b491af0-67ca-4280-b664-4b1ae78616c1",
        "148833ae-0717-4119-a700-3d04205375a9",
        "dbb00676-b795-4847-93bb-e3bf6117768e",
    ],
    "non-kitchen": [
        "2a3e77c8-05fb-425c-8302-a37006137f12",
        "264f1cfe-da89-4a5e-8a37-fadf451d3091",
        "201283ea-435c-44ec-a0cf-8782f0d6de8c",
        "e8cf3e90-ada3-4350-9f75-cc346349929d",
        "5afaa19d-a01b-44ed-a74b-c52f1c6e5532",
        "cc2e86f0-05db-4d7f-88b2-cbb06cc7620e",
    ],
    "new_selection": [
        "0c9c4d10-b21c-40c6-a048-6f65d3ee9318",
        "1faee988-40b2-4618-92d4-265fd6564a83",
        "2ce4e683-1261-4e81-9182-ce8419f265ec",
        "2cf112fc-e155-4244-9861-73f90ee7a21b",
        "5b7f27f4-ac37-454a-a73a-0209d1cf2308",
        "5de7cee8-24d9-4692-a50c-532fc5f2d416",
        "6d5e75b2-9287-4e75-bfa8-de89b020f997",
        "8d73bfaf-0d29-4691-a57e-8e480bb48f84",
        "043cbd3f-70d5-41de-b44e-a04f6b7a4cf1",
        "49fccf5c-bc35-4af5-919a-1467a638c474",
    ],
}

def search_optimal_scale(objective):
    """Grid-search the scalar that minimizes ``objective``.

    Two ``scipy.optimize.brute`` passes are run with ``finish=None``: a fine
    one with ``ranges=[(0.1, 3.0, 0.05)]`` and a coarse one with
    ``ranges=[(1.0, 200.0, 2.0)]``. The pass that reached the strictly lower
    objective value wins; on a tie the coarse pass is kept.

    Args:
        objective: callable evaluated by ``scipy.optimize.brute`` for every
            grid point (SciPy passes the candidate as a 1-D array) and
            returning the scalar value to minimize.

    Returns:
        ``(scale, value)``: the winning grid point and its objective value.
    """
    # Import the submodule explicitly: a bare ``import scipy`` does not load
    # ``scipy.optimize`` on every SciPy version.
    import scipy.optimize

    fine = scipy.optimize.brute(
        objective, ranges=[(0.1, 3.0, 0.05)], full_output=True, finish=None
    )
    coarse = scipy.optimize.brute(
        objective, ranges=[(1.0, 200.0, 2.0)], full_output=True, finish=None
    )
    if fine[1] < coarse[1]:
        return fine[0], fine[1]
    return coarse[0], coarse[1]


if __name__ == "__main__":  # noqa: C901
    # Parse the command line first so that ``--help`` and argument errors do
    # not depend on the data files being present.
    args = parser.parse_args()
    img_width = args.img_width
    img_height = args.img_height
    model_name = args.model

    # visualization wouldn't work for the sample version
    # NOTE(review): this path has no "{}" placeholders, so the ``.format()``
    # call in the video export below returns it unchanged.
    BASE_IMG_PATH = "./"
    # Per-tracklet and per-clip results.
    # NOTE(review): filled in below but not written out anywhere in the
    # visible part of this script.
    log_metric = {}
    with open("./egotracks_with_additional_annotation.json", "r") as f:
        dataset_json = json.load(f)
    with open('./video_meta.json', 'r') as f:
        metadata_json = json.load(f)

    color_banks = [
        (255, 0, 0),
        (0, 255, 0),
        (0, 0, 255),
        (255, 255, 0),
        (255, 0, 255),
        (0, 255, 255),
        (127, 127, 127),
    ]
    clip_uids = ["1a3273d3-f6a7-4ff3-8339-c5eeca5b6999"]
    # Number of poses expected per clip; also the number of frames exported.
    img_num = 1500
    # NOTE(review): ``--eval_metric`` is parsed but the metric is fixed here.
    metric_used = "PointDis"

    for clip_uid in clip_uids:
        # Find Camera Intrinsics
        calib_file = metadata_json[clip_uid]
        calib_filepath = "./calib/{}.txt".format(calib_file)
        calib = np.loadtxt(calib_filepath, delimiter=" ")

        fx, fy, cx, cy = calib[:4]
        cam_int = np.eye(4)
        cam_int[0, 0] = fx
        cam_int[0, 2] = cx
        cam_int[1, 1] = fy
        cam_int[1, 2] = cy

        # Locate the clip inside the annotation file. The indices are reset
        # for every clip so that a missing clip can neither reuse the previous
        # clip's indices nor fail with a NameError. As before, the last
        # matching entry wins.
        video_id = None
        clip_id = None
        for i, video in enumerate(dataset_json["videos"]):
            for j, clip in enumerate(video["clips"]):
                if clip["source_clip_uid"] == clip_uid:
                    video_id = i
                    clip_id = j
        if video_id is None:
            print("Can't find clip {} in the annotations".format(clip_uid))
            continue

        # Best effort: any failure while obtaining the poses skips the clip.
        try:
            pred_pose_np = get_pose(
                clip_uid,
                model_name,
                dataset="ego4d",
                world2cam=True,
                scenario=args.scenario,
            )
        except Exception as e:
            print(e)
            continue
        if pred_pose_np is None:
            print("Can't calculate pose for {}".format(clip_uid))
            continue
        # Explicit check instead of ``assert``, which is stripped under ``-O``.
        if pred_pose_np.shape[0] != img_num:
            print(
                "Expected {} poses for {}, got {}".format(
                    img_num, clip_uid, pred_pose_np.shape[0]
                )
            )
            continue

        query_sets = dataset_json["videos"][video_id]["clips"][clip_id]["annotations"][
            0
        ]["query_sets"]
        pd_agg = []
        frames = None
        color_id = 0
        for track_id in query_sets:
            query = query_sets[track_id]
            if "attributes_occurrence" not in query:
                continue
            # Tracks flagged as moving / relocated are not evaluated.
            if check_moving(query["attributes_occurrence"]):
                continue
            if "lt_track" not in query:
                continue
            gt_bounding_box, first_frame_ind = extract_tracklet(
                dataset_json,
                video_id,
                clip_id,
                track_id,
                img_width=img_width,
                img_height=img_height,
                calib=calib,
            )
            if not gt_bounding_box:
                # No boxes means no first frame to anchor the object center on.
                print("Skip empty tracklet {} of {}".format(track_id, clip_uid))
                continue
            print("Start optimization for {}, tracklet {}".format(clip_uid, track_id))
            # Loop variables are bound as defaults so the objective does not
            # depend on later rebinding of the names.
            optimal_scale, metric_pd = search_optimal_scale(
                lambda x, gt_tracklet=gt_bounding_box, start=first_frame_ind: calc_metric_for_scale(
                    x,
                    gt_tracklet,
                    cam_int,
                    pred_pose_np,
                    None,
                    None,
                    start,
                    metric=metric_used,
                    local_only=True,
                )
            )
            print("Finished simulation for {}".format(clip_uid))
            print(
                "ORE is {}, optimal depth scale is {}".format(
                    metric_pd, optimal_scale
                )
            )
            if metric_used == "PointDis":
                # Center of the first-frame box, placed at the optimal depth.
                pt = (
                    gt_bounding_box[first_frame_ind][:2]
                    + gt_bounding_box[first_frame_ind][2:]
                ) / 2.0
                depth = optimal_scale
                obj_centers = gen_center(
                    cam_int, pred_pose_np, first_frame_ind, pt, depth
                )
            pd_agg.append(metric_pd)
            if args.visualize:
                frames = gen_frames(
                    frames,
                    pred_pose_np,
                    gt_bounding_box,
                    obj_centers,
                    BASE_IMG_PATH,
                    clip_uid,
                    0,
                    # Cycle through the palette instead of running off its end.
                    color=color_banks[color_id % len(color_banks)],
                    calib=calib,
                )
            color_id += 1
            log_metric[clip_uid + "_" + str(track_id)] = {
                "OptimalScale": optimal_scale,
                metric_used: metric_pd,
            }

        if not pd_agg:
            # Nothing to average (and no frames to export) for this clip.
            print("No tracklet was evaluated for {}".format(clip_uid))
            continue
        mean_metric = np.mean(np.array(pd_agg))
        log_metric["{}".format(clip_uid)] = {
            metric_used: mean_metric,
        }
        print("Average ORE is {}".format(mean_metric))
        if args.visualize and frames is not None:
            annotated_ids = set(frames.keys())
            # Create the output directory up front so the writer has somewhere
            # to put the file.
            os.makedirs("./video_samples", exist_ok=True)
            output_video_path = "./video_samples/{}_{}.webm".format(
                model_name, clip_uid
            )
            fourcc = cv2.VideoWriter_fourcc(*"vp80")
            # NOTE(review): the frame size is fixed to 640x480 regardless of
            # --img_width / --img_height — confirm against gen_frames' output.
            out = cv2.VideoWriter(output_video_path, fourcc, 5.0, (640, 480))
            for img_ind in range(0, img_num):
                if img_ind in annotated_ids:
                    out.write(frames[img_ind])
                else:
                    local_frame = cv2.imread(BASE_IMG_PATH.format(clip_uid, img_ind))
                    out.write(local_frame)
            out.release()