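"""
Search ScanNet frames for "positive" frames of object instances.

A frame is positive for an instance when at least one of the instance's
ground-truth points projects into that frame. For every requested instance,
its positive frames are ranked by how close the instance's fit rate in the
frame is to 0.4.
"""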
import glob
import os
import sys

import numpy as np
from tqdm import tqdm

sys.path.append(r'../Refer_Judge')  # Change this
from utils.Projection.point_image_project import projection_interface, resize_crop_image
from jury_and_judge.config import ScanNet_Frame_ROOT, ScanNet_ROOT

# Glob pattern for the exported frames of one scene.
# Fill with FRAMES_ROOT.format(scene_name, sub_folder), e.g. sub_folder='color'.
FRAMES_ROOT = '/'.join((ScanNet_Frame_ROOT, '{}', '{}', '*.jpg'))
# Path template for per-scene ground-truth files.
# Fill with GT_CLOUD_ROOT.format(scene_name, file_name).
# NOTE(review): format(scene, '_vert.npy') yields <root>/<scene>/_vert.npy;
# confirm the files are not actually named <scene>_vert.npy.
GT_CLOUD_ROOT = '/'.join((ScanNet_ROOT, '{}', '{}'))

def _companion_path(image_path, folder, extension):
    """Map ``<scene>/color/<stem>.jpg`` to ``<scene>/<folder>/<stem><extension>``.

    Only the frame's own directory and file extension are swapped. A global
    ``str.replace`` would also rewrite any ``color`` / ``.jpg`` that happens
    to appear elsewhere in the dataset root.
    """
    color_dir, file_name = os.path.split(image_path)
    stem = os.path.splitext(file_name)[0]
    return os.path.join(os.path.dirname(color_dir), folder, stem + extension)


def task_for_one_scene(scene_name, messages_for_one_scene, *, target_rate=0.4):
    """Rank, for every requested instance of a scene, the frames it appears in.

    Args:
        scene_name: scene id used to locate the colour frames (via
            FRAMES_ROOT) and the ground-truth cloud files (via GT_CLOUD_ROOT).
        messages_for_one_scene: dict keyed by object id; the ids are passed
            through ``int()``. It is updated in place. Each instance found in
            at least one frame gets a list of frame file names, ordered by how
            close its fit rate is to ``target_rate``. Instances found in no
            frame keep their existing value.
        target_rate: preferred fit rate. Frames whose rate is closest to it
            come first.

    Returns:
        ``messages_for_one_scene`` (the same dict object).

    The fit rate of a frame is the number of projected cloud points carrying
    the instance's label divided by the number of set entries in the
    ``proj_2d_mask`` returned by ``projection_interface``. Frames that fail
    to project are reported and skipped.
    """
    cloud_path = GT_CLOUD_ROOT.format(scene_name, '_vert.npy')
    scene_cloud = np.load(cloud_path)
    scene_cloud_label = np.load(GT_CLOUD_ROOT.format(scene_name, '_ins_label.npy'))
    # The last column holds the per-point label. hstack also fails loudly if
    # the vertex and label files disagree on the number of points.
    scene_cloud_gt = np.hstack((scene_cloud, scene_cloud_label.reshape(-1, 1)))
    # glob order is filesystem dependent; sort so rank ties are reproducible.
    image_paths = sorted(glob.glob(FRAMES_ROOT.format(scene_name, 'color')))

    # The projection depends only on the frame, not on the instance. Run it
    # once per frame and keep a compact {label: projected point count} table
    # instead of re-projecting every frame for every instance.
    frame_stats = []
    for image_path in tqdm(image_paths, leave=False):
        try:
            proj_2d_mask, proj_3d_idx, _ = projection_interface(
                cloud_path, image_path,
                _companion_path(image_path, 'depth', '.png'),
                _companion_path(image_path, 'pose', '.txt'))
            scene_patch = scene_cloud_gt[proj_3d_idx]
            labels, counts = np.unique(scene_patch[:, -1], return_counts=True)
        except Exception as err:
            # Best effort: one unreadable frame must not discard the whole
            # scene. tqdm.write keeps the progress bars intact.
            tqdm.write('[{}] skipping frame {}: {!r}'.format(scene_name, image_path, err))
            continue
        frame_stats.append((os.path.basename(image_path),
                            dict(zip(labels.tolist(), counts.tolist())),
                            np.sum(proj_2d_mask)))

    for instance_id in list(messages_for_one_scene):
        try:
            # Points are matched against label == object id + 1. The label
            # file presumably stores ids offset by one — confirm with the data.
            label = int(instance_id) + 1
        except (TypeError, ValueError):
            tqdm.write('[{}] skipping non-integer object id {!r}'.format(scene_name, instance_id))
            continue
        ranked = []
        for frame_name, label_counts, useful_pix in frame_stats:
            instance_pix = label_counts.get(label, 0)
            if instance_pix > 0:
                ranked.append((frame_name, instance_pix / useful_pix))
        if ranked:
            # Stable sort: frames whose fit rate is closest to the target first.
            ranked.sort(key=lambda item: abs(item[1] - target_rate))
            messages_for_one_scene[instance_id] = [name for name, _ in ranked]

    return messages_for_one_scene

def iteration_pos(data):
    """Fill in ``positive_frames`` for the entries of ``data`` that lack it.

    Args:
        data: list of dicts with at least ``scene_id`` and ``object_id`` keys.
            Only entries without a ``positive_frames`` key decide which
            objects get computed, so an interrupted run can be resumed.

    Returns:
        ``data`` itself, updated in place. Within each scene, every entry
        whose ``object_id`` was computed gets ``positive_frames`` set to the
        value produced by ``task_for_one_scene``: a ranked list of frame
        names, or ``None`` if the object was found in no frame.
    """
    # Group once instead of rescanning the whole list for every scene.
    entries_by_scene = {}
    for entry in data:
        entries_by_scene.setdefault(entry['scene_id'], []).append(entry)

    for scene_name in tqdm(sorted(entries_by_scene), leave=False):
        scene_entries = entries_by_scene[scene_name]
        # dict.fromkeys keeps first-seen order and drops duplicate object ids.
        pending = dict.fromkeys(
            entry['object_id'] for entry in scene_entries
            if 'positive_frames' not in entry)
        if not pending:
            continue

        pending = task_for_one_scene(scene_name, pending)
        for entry in scene_entries:
            # Entries whose object was not computed are left untouched.
            if 'object_id' in entry and entry['object_id'] in pending:
                entry['positive_frames'] = pending[entry['object_id']]

    return data

if __name__ == '__main__':
    # iteration_pos() requires the annotation list; calling it without
    # arguments always raised TypeError.
    # NOTE(review): this assumes the annotations are stored as a JSON list of
    # dicts with "scene_id"/"object_id" keys (ScanRefer style). Adapt the
    # loading code if your data lives elsewhere.
    import argparse
    import json

    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument('input_json',
                        help='JSON list of entries with "scene_id" and "object_id" keys')
    parser.add_argument('output_json',
                        help='where to write the entries with "positive_frames" added')
    args = parser.parse_args()

    with open(args.input_json, encoding='utf-8') as f:
        annotations = json.load(f)
    annotations = iteration_pos(annotations)
    with open(args.output_json, 'w', encoding='utf-8') as f:
        json.dump(annotations, f)