from eval_utils import (
    convert_from_uvd,
    gen_gt_tracklet,
    get_prediction,
    get_mask_for_first_image,
    get_maskrcnn_model,
    get_iou,
    accumulate_pairwise_pose,
    calc_metric_for_scale,
    gen_center,
    get_pose,
    get_scannet_pose,
    check_box_in_frame,
)
from visualization_utils import (
    get_random_color,
    random_color_masks,
    gen_segmentation_view,
    gen_bounding_box,
    gen_frames,
)

import numpy as np
import cv2
import json
import os
from PIL import Image
import subprocess
import torchvision
from scipy.spatial.transform import Rotation as R
import scipy
from scipy import optimize
import warnings
import argparse
warnings.filterwarnings("ignore")

# Command-line interface of the evaluation script. The parsed namespace is
# read by the __main__ section and is also handed to helper code as `args`.
parser = argparse.ArgumentParser(
    description="Evaluate predicted camera poses on ScanNet scenes using object tracklets."
)
parser.add_argument(
    "--model",
    type=str,
    default="ssl_nyuv2",
    help="the model to run for evaluation script",
)
parser.add_argument(
    "--scannet_folder",
    type=str,
    help="the folder where we store scannet data",
)
parser.add_argument(
    "--eval_metric",
    type=str,
    default="PointDis",
    help="the optimization metric used for evaluation",
)
# min_ratio / max_ratio are forwarded as the min/max percentile arguments of
# the ground-truth bounding-box generation.
parser.add_argument(
    "--min_ratio",
    type=float,
    default=45,
    help="min ratio of bounding box lower bound",
)
parser.add_argument(
    "--max_ratio",
    type=float,
    default=55,
    help="max ratio of bounding box upper bound",
)
parser.add_argument(
    "--visualize",
    help="if set, generates the visualization video",
    action="store_true",
)
# NOTE(review): the defaults below (width 480, height 640) look swapped
# relative to the 640x480 frame size assumed by the __main__ section. They are
# left unchanged because the namespace is passed to helper code whose
# expectations are not visible here -- confirm before changing.
parser.add_argument(
    "--image_width",
    type=float,
    default=480,
    help="image width in pixels",
)
parser.add_argument(
    "--image_height",
    type=float,
    default=640,
    help="image height in pixels",
)
# Fixed palette of 3-channel colours; the __main__ section picks one entry per
# detected mask when drawing visualisation frames.
color_banks = [
    # single-channel colours
    (255, 0, 0), (0, 255, 0), (0, 0, 255),
    # two-channel mixes
    (255, 255, 0), (255, 0, 255), (0, 255, 255),
    # mid grey
    (127, 127, 127),
]

if __name__ == "__main__":
    # Entry point. For every scene in a fixed list of ScanNet scenes:
    #   1. load ground-truth and predicted camera poses,
    #   2. pick start frames (100, 1100, ...) and advance each until Mask R-CNN
    #      returns at least one mask,
    #   3. for every mask, build a ground-truth tracklet of bounding boxes and
    #      brute-force the depth value that minimises the chosen metric when
    #      the predicted poses are used,
    #   4. record per-object and per-scene results, optionally render a video.
    # All results are finally dumped to a JSON file.
    np.random.seed(0)
    args = parser.parse_args()
    if args.scannet_folder is None:
        # Without this, os.path.join below fails with an opaque TypeError.
        parser.error("--scannet_folder is required")

    # Camera intrinsics, once as a 4x4 matrix and once as a flat vector
    # holding the same four numbers (577.87, 577.87, 319.5, 239.5).
    cam_int = np.array(
        [
            [577.87, 0.0, 319.5, 0.0],
            [0.0, 577.87, 239.5, 0.0],
            [0.0, 0.0, 1.0, 0.0],
            [0.0, 0.0, 0.0, 1.0],
        ]
    )
    calib = np.array([577.87, 577.87, 319.5, 239.5])

    # Frame size used for the in-frame test and for the video layout.
    img_width, img_height = 640, 480
    # The trajectory canvas is pasted to the right of the frame, so it must be
    # exactly as tall as the frame.
    canvas_size = img_height
    max_video_frames = 100

    model_name = args.model
    # NOTE(review): args.eval_metric is parsed but not consulted here; the
    # metric stays fixed because the "mIOU" path below has no reference depth.
    metric_used = "PointDis"
    log_metric = {}

    # Path templates, filled with (scene_name, frame_index).
    base_img_path = os.path.join(args.scannet_folder, "{}/color/{}.jpg")
    base_depth_path = os.path.join(args.scannet_folder, "{}/depth/{}.npy")

    # The detector does not depend on the scene, so it is created only once.
    model = get_maskrcnn_model()

    # scene0707_00 ... scene0726_00
    scene_names = ["scene{:04d}_00".format(idx) for idx in range(707, 727)]

    for scene_name in scene_names:
        img_num = len(
            os.listdir(
                os.path.join(args.scannet_folder, "{}/color".format(scene_name))
            )
        )

        gt_poses = get_scannet_pose(scene_name, world2cam=True)

        pred_pose_np = get_pose(scene_name, model_name, world2cam=True, args=args)
        if pred_pose_np is None:
            print("Can't calculate pose for {}".format(scene_name))
            continue
        # Never use more predicted poses than there are ground-truth poses.
        pred_pose_np = pred_pose_np[:gt_poses.shape[0]]

        frames = None
        metric_agg = []
        # Reset per scene so the scene summary can never report a value that
        # was computed for a previous scene.
        optimal_scale = None

        for first_frame_ind in range(100, img_num, 1000):
            # Advance until the detector returns at least one mask.
            while first_frame_ind < img_num:
                try:
                    img_path = base_img_path.format(scene_name, first_frame_ind)
                    mask, pred_class = get_mask_for_first_image(
                        img_path, model, visualize=True, mask_score=0.9
                    )
                    if mask.shape[0] > 0:
                        break
                except Exception as exc:
                    # Best effort: a failing frame is skipped, but the reason is
                    # reported so real errors are not mistaken for "no object".
                    print(
                        "No object detected for {} frame number {}: {}".format(
                            scene_name, first_frame_ind, exc
                        )
                    )
                first_frame_ind += 1
            if first_frame_ind >= img_num:
                continue

            # The depth map depends only on the frame, so load it once for all
            # masks. NOTE(review): the division by 1000 presumably converts
            # millimetres to metres -- confirm against the data export.
            gt_depth = np.load(base_depth_path.format(scene_name, first_frame_ind)) / 1000.0

            for mask_ind in range(mask.shape[0]):
                # Use ScanNet depth, mask and cam pose to generate tracklets
                _, obj_vertices = gen_gt_tracklet(
                    cam_int, gt_poses, first_frame_ind, mask[mask_ind], gt_depth
                )
                gt_bounding_box = gen_bounding_box(
                    obj_vertices,
                    min_percentile=args.min_ratio,
                    max_percentile=args.max_ratio,
                    first_frame_ind=first_frame_ind,
                )
                if first_frame_ind not in gt_bounding_box:
                    continue

                # Tracklet length = number of ground-truth boxes that pass the
                # in-frame check.
                bb_cnt = sum(
                    1 for gt_bb in gt_bounding_box.values() if check_box_in_frame(gt_bb)
                )

                # Depth scale optimization to find the optimal depth for PD
                print("Start optimization for {}".format(scene_name))
                ref_depth = None

                def _objective(scale):
                    # Metric value for one candidate depth/scale.
                    return calc_metric_for_scale(
                        scale,
                        gt_bounding_box,
                        cam_int,
                        pred_pose_np,
                        mask[mask_ind],
                        ref_depth,
                        first_frame_ind,
                        metric=metric_used,
                        local_only=True,
                    )

                # Pure grid search over [0.1, 20.0) in steps of 0.05; finish=None
                # disables the local polishing step.
                res_brute = scipy.optimize.brute(
                    _objective,
                    ranges=[(0.1, 20.0, 0.05)],
                    full_output=True,
                    finish=None,
                )
                optimal_scale = res_brute[0]
                print("Finished optimization for {}".format(scene_name))

                print("{} is {}, optimal depth scale is {}".format(metric_used, res_brute[1], optimal_scale))
                if metric_used == "PointDis":
                    # Centre of the first-frame box (sum of its first two and
                    # last two entries, halved).
                    first_box = gt_bounding_box[first_frame_ind]
                    pt = (first_box[:2] + first_box[2:]) / 2.0
                    if pt[0] < 0 or pt[1] < 0 or pt[0] > img_width or pt[1] > img_height:
                        continue
                    # For PointDis the optimised value is used directly as the
                    # depth of the centre point.
                    obj_centers = gen_center(cam_int, pred_pose_np, first_frame_ind, pt, optimal_scale)
                elif metric_used == "mIOU":
                    # NOTE(review): unreachable while metric_used is fixed to
                    # "PointDis"; ref_depth is None, so this division would
                    # fail if the branch were ever enabled.
                    scaled_depth = ref_depth / optimal_scale

                    obj_centers, obj_vertices = gen_gt_tracklet(
                        cam_int, pred_pose_np, first_frame_ind, mask[mask_ind], scaled_depth
                    )
                    bounding_boxes = gen_bounding_box(obj_vertices)

                metric_pd = res_brute[1]
                metric_agg.append(metric_pd)
                log_metric[scene_name + "_" + str(first_frame_ind) + "_" + str(mask_ind)] = {
                    "OptimalScale": optimal_scale,
                    metric_used: metric_pd,
                    "Class": int(pred_class[mask_ind]),
                    "TrackletLength": bb_cnt,
                }
                if args.visualize:
                    frames = gen_frames(
                        frames,
                        pred_pose_np,
                        gt_bounding_box,
                        obj_centers,
                        base_img_path,
                        scene_name,
                        first_frame_ind,
                        calib=calib,
                        # Wrap around so more masks than palette entries cannot
                        # raise IndexError.
                        color=color_banks[mask_ind % len(color_banks)],
                    )

        if not metric_agg:
            # Nothing was evaluated: there is no scale, no metric and no
            # rendered frame to report for this scene.
            print("No object evaluated for {}; skipping scene summary".format(scene_name))
            continue

        # Scene-level entry: last optimal scale found in this scene and the
        # mean metric over all evaluated objects.
        log_metric[scene_name] = {
            "OptimalScale": optimal_scale,
            metric_used: np.mean(np.array(metric_agg)),
            "Ratio": args.max_ratio - args.min_ratio
        }
        if args.visualize and frames is not None:
            os.makedirs("./video_samples", exist_ok=True)
            output_video_path = "./video_samples/{}_{}.webm".format(model_name, scene_name)
            fourcc = cv2.VideoWriter_fourcc(*'vp80')
            out = cv2.VideoWriter(
                output_video_path, fourcc, 5.0, (img_width + canvas_size, img_height)
            )
            # White canvas; circles accumulate on it from frame to frame.
            canvas = np.full((canvas_size, canvas_size, 3), 255, dtype=np.uint8)
            # Only the first frames are rendered, bounded by what exists.
            n_video_frames = min(max_video_frames, img_num, len(pred_pose_np))
            try:
                for img_ind in range(n_video_frames):
                    if img_ind in frames:
                        local_frame = frames[img_ind]
                    else:
                        local_frame = cv2.imread(base_img_path.format(scene_name, img_ind))
                    if local_frame is None:
                        # cv2.imread returns None for missing/unreadable files.
                        print("Can't read frame {} of {}".format(img_ind, scene_name))
                        continue
                    output_frame = np.zeros(
                        (img_height, img_width + canvas_size, 3), dtype=np.uint8
                    )
                    output_frame[:, :img_width, :] = local_frame
                    pose = pred_pose_np[img_ind]
                    # Plot entries [0, 3] and [2, 3] of the pose on the canvas,
                    # scaled by 120 and centred.
                    cv2.circle(
                        canvas,
                        (
                            int(pose[0, 3] * 120.0) + canvas_size // 2,
                            int(pose[2, 3] * 120.0) + canvas_size // 2,
                        ),
                        radius=5,
                        color=(0, 255, 0),
                        thickness=-1,
                    )
                    output_frame[:, img_width:, :] = canvas
                    # Write the composed frame; its size matches the writer.
                    out.write(output_frame)
            finally:
                out.release()

    result_path = "./scannet_results_0603/result_scannet_{}_ratio{}_frontend_only.json".format(args.model, (args.max_ratio-args.min_ratio))
    # Create the output directory so a long run is not lost at the last step.
    os.makedirs(os.path.dirname(result_path), exist_ok=True)
    with open(result_path, "w") as outfile:
        json.dump(log_metric, outfile, indent=4)