"""
This code is used to search for misleading frames.
"""
import glob
import json
import os
import sys

import numpy as np
from tqdm import tqdm

sys.path.append(r'../Refer_Judge')  # Change this
from utils.Projection.point_image_project import projection_interface, resize_crop_image
from jury_and_judge.config import ScanNet_Frame_ROOT, ScanNet_ROOT

# Path templates; the `{}` placeholders are filled with str.format.
# Frame images: <frame root>/<scene>/<frame folder, e.g. 'color'>/*.jpg
FRAMES_ROOT = '/'.join((ScanNet_Frame_ROOT, '{}', '{}', '*.jpg'))
# Per-scene file: <scene root>/<scene>/<suffix, e.g. '_vert.npy'>
GT_CLOUD_ROOT = '/'.join((ScanNet_ROOT, '{}', '{}'))
# Aggregation annotation: <scene root>/<scene>/<scene>.aggregation.json
GT_instance_class_root = '/'.join((ScanNet_ROOT, '{}', '{}.aggregation.json'))

def _sibling_frame_path(image_path, folder, extension):
    """Return the file with *image_path*'s stem inside sibling directory *folder*.

    E.g. ``<scene>/color/000010.jpg`` -> ``<scene>/depth/000010.png``. Only the
    last directory component and the extension are replaced, so ``'color'`` or
    ``'.jpg'`` appearing elsewhere in the path is left untouched.
    """
    frame_dir, file_name = os.path.split(image_path)
    stem = os.path.splitext(file_name)[0]
    return os.path.join(os.path.dirname(frame_dir), folder, stem + extension)


def _order_frames_by_fit(fit_frame, fit_rates, target_fit_rate, sort_threshold):
    """Order frame names by how close their fit rate is to *target_fit_rate*.

    Frames are only reordered when there are more than *sort_threshold* of them;
    otherwise discovery order is kept. The sort is stable, so frames at equal
    distance keep their discovery order.
    """
    if len(fit_rates) <= sort_threshold:
        return list(fit_frame)
    ranked = sorted(zip(fit_rates, fit_frame),
                    key=lambda pair: abs(pair[0] - target_fit_rate))
    return [frame for _, frame in ranked]


def task_for_one_scene(scene_name, messages_for_one_scene, target_fit_rate=0.4, sort_threshold=3):
    """Find misleading frames for every target instance of one scene.

    A frame is misleading for a target when none of the scene points projected
    into it carry the target's label, but at least one carries the label of one
    of the target's distractors.

    Args:
        scene_name: Scene name used to fill the path templates.
        messages_for_one_scene: Dict ``{instance_id: [distractor object ids]}``.
            It is updated in place and returned.
        target_fit_rate: Found frames are ranked by how close their fit rate
            (distractor points / pixels set in the projection mask) is to this value.
        sort_threshold: Ranking is applied only when more than this many frames
            were found; otherwise frames stay in sorted-path order.

    Returns:
        ``messages_for_one_scene``. Every instance with a non-empty distractor
        list now maps to a list of frame file names. Instances with an empty
        list are left unchanged. Instances whose id cannot be converted with
        ``int`` are removed, with a warning, so callers do not mistake their
        distractor ids for frame names.
    """
    cloud_path = GT_CLOUD_ROOT.format(scene_name, '_vert.npy')
    scene_cloud = np.load(cloud_path)
    scene_cloud_label = np.load(GT_CLOUD_ROOT.format(scene_name, '_ins_label.npy'))
    # Append the instance label as the last column so projected point indices
    # map straight to labels.
    scene_cloud_gt = np.hstack((scene_cloud, scene_cloud_label.reshape(-1, 1)))
    # glob order is filesystem-dependent; sort it for reproducible output.
    scene_image_paths = sorted(glob.glob(FRAMES_ROOT.format(scene_name, 'color')))

    # Point labels are object ids shifted by +1 (convention kept from the
    # original code; presumably 0 marks unannotated points -- TODO confirm).
    targets = {}
    unusable_ids = []
    for instance_id, mis_leading_list in messages_for_one_scene.items():
        if len(mis_leading_list) == 0:
            continue
        try:
            target_label = int(instance_id) + 1
        except (TypeError, ValueError):
            unusable_ids.append(instance_id)
            continue
        targets[instance_id] = {
            'label': target_label,
            'distractors': np.asarray(mis_leading_list) + 1,
            'frames': [],
            'rates': [],
        }

    if targets:
        for image_path in tqdm(scene_image_paths, leave=False):
            depth_path = _sibling_frame_path(image_path, 'depth', '.png')
            pose_path = _sibling_frame_path(image_path, 'pose', '.txt')
            try:
                # The projection depends only on the frame, so compute it once
                # and share it across all target instances.
                proj_2d_mask, proj_3d_idx, _ = projection_interface(cloud_path, image_path,
                                                                    depth_path, pose_path)
                visible_labels = scene_cloud_gt[proj_3d_idx][:, -1]
            except Exception as err:  # best effort: one bad frame must not abort the scene
                tqdm.write(f'[{scene_name}] skipping frame {image_path}: {err!r}')
                continue

            all_useful_pix = np.sum(proj_2d_mask)
            frame_name = os.path.basename(image_path)
            for state in targets.values():
                if np.any(visible_labels == state['label']):
                    continue  # the target itself is visible: not misleading
                instance_pix = np.sum(np.isin(visible_labels, state['distractors']))
                if instance_pix > 0:
                    state['frames'].append(frame_name)
                    state['rates'].append(instance_pix / all_useful_pix)

    for instance_id, state in targets.items():
        messages_for_one_scene[instance_id] = _order_frames_by_fit(
            state['frames'], state['rates'], target_fit_rate, sort_threshold)
    for instance_id in unusable_ids:
        tqdm.write(f'[{scene_name}] dropping instance {instance_id!r}: id is not an integer')
        del messages_for_one_scene[instance_id]

    return messages_for_one_scene

def iteration_mis(data):
    """Attach a ``'misleading_frames'`` list to every item of *data* that lacks one.

    Items are grouped by ``item['scene_id']``. For each scene with pending items,
    the aggregation annotation is read once. The distractors of each target
    (``item['object_id']``) are the other annotated objects with the same label.
    ``task_for_one_scene`` then searches that scene's frames.

    Args:
        data: List of dicts with at least ``'scene_id'`` and ``'object_id'``.
            It is updated in place.

    Returns:
        The same ``data`` list. Items that already have ``'misleading_frames'``
        do not trigger a new search, so the function can be re-run on partial
        output.
    """
    items_by_scene = {}
    for d in data:
        items_by_scene.setdefault(d['scene_id'], []).append(d)

    for scene_name in tqdm(sorted(items_by_scene), leave=False):
        scene_items = items_by_scene[scene_name]
        # Unique object ids still needing a search, in first-seen order.
        pending_ids = list(dict.fromkeys(
            d['object_id'] for d in scene_items if 'misleading_frames' not in d))
        if not pending_ids:
            continue

        with open(GT_instance_class_root.format(scene_name, scene_name), encoding='utf-8') as f:
            seg_groups = json.load(f)['segGroups']
        # Look the target label up by objectId (the id the distractor search
        # compares against), not by list position.
        label_by_object_id = {int(g['objectId']): g['label'] for g in seg_groups}

        frames_for_one_scene = {}
        for object_id in pending_ids:
            target_id = int(object_id)
            target_name = label_by_object_id.get(target_id)
            if target_name is None:
                tqdm.write(f'[{scene_name}] object {object_id!r} not found in aggregation file')
                continue
            frames_for_one_scene[object_id] = [
                int(g['objectId']) for g in seg_groups
                if g['label'] == target_name and int(g['objectId']) != target_id
            ]

        if frames_for_one_scene:
            frames_for_one_scene = task_for_one_scene(scene_name, frames_for_one_scene)
            for item in scene_items:
                item_id = item.get('object_id')
                if item_id in frames_for_one_scene:
                    item['misleading_frames'] = frames_for_one_scene[item_id]

    return data

if __name__ == '__main__':
    # iteration_mis() requires the item list; the original call passed nothing
    # and always raised TypeError. Read the items from a JSON file instead.
    # NOTE(review): assumes a JSON list of dicts carrying 'scene_id' and
    # 'object_id' -- adjust if the data comes from elsewhere.
    import argparse

    parser = argparse.ArgumentParser(description='Search ScanNet frames for misleading views.')
    parser.add_argument('input', help='JSON file with a list of items having "scene_id" and "object_id".')
    parser.add_argument('-o', '--output', default=None,
                        help='Output JSON path (default: <input>_misleading.json).')
    args = parser.parse_args()

    with open(args.input, encoding='utf-8') as f:
        items = json.load(f)
    items = iteration_mis(items)

    output_path = args.output or os.path.splitext(args.input)[0] + '_misleading.json'
    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(items, f)