from collections import defaultdict
from typing import overload
from base.base_dataset import TextObjectDataset
import pandas as pd
import os
import json
import nltk
import numpy as np
import random
import torch
import copy
from utils.util import compute_iou

class WebVid(TextObjectDataset):
    """
    WebVid Dataset.

    Metadata is read from a tab-separated file under ``./meta_data``: the first column
    is the caption and the second column is the video id. Files are resolved as
        videos:  <data_dir>/<split>/<video_id>.mp4
        objects: <object_dir>/<split>/<video_id>
    """
    def _load_metadata(self):
        """Read the split's metadata TSV into ``self.metadata``, optionally subsampled."""
        # NOTE(review): self.metadata_dir is ignored; the TSVs are read from a fixed
        # directory relative to the current working directory.
        metadata_dir = './meta_data'
        split_files = {
            'train': 'webvid_training_success_full.tsv',
            'val': 'webvid_validation_success_full.tsv',
            'test': 'webvid_validation_success_full.tsv',  # there is no test split; reuse val
        }
        target_split_fp = split_files[self.split]
        metadata = pd.read_csv(os.path.join(metadata_dir, target_split_fp), sep='\t')
        if self.subsample < 1:
            # No random_state is given, so the subset depends on NumPy's global RNG state.
            metadata = metadata.sample(frac=self.subsample)
        self.metadata = metadata

    @staticmethod
    def _column(sample, pos):
        """
        Return the ``pos``-th field of a metadata row.

        Rows taken with ``self.metadata.iloc[i]`` are pandas Series labelled by column
        name; positional ``row[int]`` access on them is deprecated (pandas >= 2.1), so
        ``.iloc`` is used whenever available, otherwise ``sample`` is indexed directly.
        """
        return sample.iloc[pos] if hasattr(sample, 'iloc') else sample[pos]

    def _get_video_path(self, sample):
        """Return ``(full_path, relative_path)`` of the sample's ``.mp4`` file."""
        rel_video_fp = self._column(sample, 1) + '.mp4'
        full_video_fp = os.path.join(self.data_dir, self.split, rel_video_fp)
        return full_video_fp, rel_video_fp

    def _get_caption(self, sample):
        """Return the sample's caption (first metadata column)."""
        return self._column(sample, 0)

    def _get_object_path(self, sample):
        """
        Return the location of the sample's object features.

        Args:
            sample: a metadata row.
        Returns:
            tuple: ``(relative_path, full_path)`` -- note the order is the reverse of
            ``_get_video_path``.
        """
        rel_object_fp = self._column(sample, 1)
        full_object_fp = os.path.join(self.object_dir, self.split, rel_object_fp)
        return rel_object_fp, full_object_fp

class WebVidObjectSelect(TextObjectDataset):
    """
    WebVid dataset yielding pre-extracted per-frame object features for each caption.

    Metadata is read from a tab-separated file under ``./meta_data``: the first column
    is the caption and the second column is the video id. Files are resolved as
        videos:  <data_dir>/<split>/<video_id>.mp4
        objects: <object_dir>/<split>/<video_id>/<frame_idx>.npz
    ``object_params['object_num']`` is the number of object slots kept per frame.
    """
    # Samples whose object files are missing or unreadable are replaced by randomly drawn
    # ones; give up after this many replacements instead of retrying forever.
    max_load_retries = 1000

    def __init__(self,
                 dataset_name,
                 text_params,
                 object_params,
                 data_dir,
                 object_dir,
                 metadata_dir=None,
                 split='train',
                 tsfms=None,
                 cut=None,
                 subsample=1,
                 sliding_window_stride=-1,
                 reader='cv2',
                 mask=False
                 ):
        super(WebVidObjectSelect, self).__init__(dataset_name, text_params, object_params, data_dir, object_dir,
                                                 metadata_dir, split, tsfms, cut, subsample, sliding_window_stride,
                                                 reader, mask)
        self.object_num = self.object_params['object_num']

    def _load_metadata(self):
        """Read the split's metadata TSV into ``self.metadata``, optionally subsampled."""
        # NOTE(review): self.metadata_dir is ignored; the TSVs are read from a fixed
        # directory relative to the current working directory.
        metadata_dir = './meta_data'
        split_files = {
            'train': 'webvid_training_success_full.tsv',
            'val': 'webvid_validation_success_full.tsv',
            'test': 'webvid_validation_success_full.tsv',  # there is no test split; reuse val
        }
        target_split_fp = split_files[self.split]
        metadata = pd.read_csv(os.path.join(metadata_dir, target_split_fp), sep='\t')
        if self.subsample < 1:
            # No random_state is given, so the subset depends on NumPy's global RNG state.
            metadata = metadata.sample(frac=self.subsample)
        self.metadata = metadata

    @staticmethod
    def _column(sample, pos):
        """
        Return the ``pos``-th field of a metadata row.

        Rows taken with ``self.metadata.iloc[i]`` are pandas Series labelled by column
        name; positional ``row[int]`` access on them is deprecated (pandas >= 2.1), so
        ``.iloc`` is used whenever available, otherwise ``sample`` is indexed directly.
        """
        return sample.iloc[pos] if hasattr(sample, 'iloc') else sample[pos]

    def _get_video_path(self, sample):
        """Return ``(full_path, relative_path)`` of the sample's ``.mp4`` file."""
        rel_video_fp = self._column(sample, 1) + '.mp4'
        full_video_fp = os.path.join(self.data_dir, self.split, rel_video_fp)
        return full_video_fp, rel_video_fp

    def _get_caption(self, sample):
        """Return the sample's caption (first metadata column)."""
        return self._column(sample, 0)

    def _get_object_path(self, sample):
        """
        Return the directory holding the sample's per-frame ``.npz`` object files.

        Args:
            sample: a metadata row.
        Returns:
            tuple: ``(relative_path, full_path)`` -- note the order is the reverse of
            ``_get_video_path``.
        """
        rel_object_fp = self._column(sample, 1)
        full_object_fp = os.path.join(self.object_dir, self.split, rel_object_fp)
        return rel_object_fp, full_object_fp

    def _select_frame_idxs(self, object_file_num):
        """
        Choose which of the ``object_file_num`` per-frame files to load.

        All frames are used when their count equals ``self.segments``; otherwise
        ``_sample_objects`` picks ``self.segments`` of them -- with its 'rand' strategy
        (then sorted) for training, and its 'uniform' strategy for other splits.
        """
        if self.segments == object_file_num:
            return list(range(self.segments))
        if self.split == 'train':
            return sorted(self._sample_objects(self.segments, object_file_num, sample='rand'))
        return self._sample_objects(self.segments, object_file_num, sample='uniform')

    def _try_load_item(self, item):
        """Build the sample dict for ``item``, or return None if its objects cannot be used."""
        item = item % len(self.metadata)
        sample = self.metadata.iloc[item]
        object_rel_fp, object_fp = self._get_object_path(sample)
        caption = self._get_caption(sample)
        if not os.path.exists(os.path.join(object_fp, '0.npz')):
            print("not exist object in: {}, select another sample".format(object_fp))
            return None
        # Every directory entry is counted as one frame file.
        object_file_num = len(os.listdir(object_fp))
        if object_file_num < 2:
            return None
        try:
            frame_idxs = self._select_frame_idxs(object_file_num)
            # presumably [segments, object_num, feat_dim + 6] -- see object_select_random
            object_feat, object_mask, object_len = read_object_from_disk_with_object_select(
                object_fp, frame_idxs, self.object_num)
        except Exception as e:
            # Deliberate best-effort: any load/selection failure just skips this sample.
            print("Fail to load selected objects {}, object_fp {}".format(e, object_fp))
            return None
        meta_arr = {'raw_captions': caption, 'paths': object_rel_fp, 'dataset': self.dataset_name}
        return {'object': object_feat, 'text': caption, 'meta': meta_arr,
                'object_mask': object_mask, 'object_len': object_len}

    def __getitem__(self, item):
        """
        Return ``{'object', 'text', 'meta', 'object_mask', 'object_len'}`` for ``item``.

        If the sample's objects are missing or fail to load, randomly drawn samples are
        tried instead. This is done iteratively, so long runs of bad samples cannot
        exhaust the call stack.

        Raises:
            RuntimeError: if the sample and ``max_load_retries`` replacements all fail.
        """
        for _ in range(self.max_load_retries + 1):
            data = self._try_load_item(item)
            if data is not None:
                return data
            item = random.randrange(len(self.metadata))
        raise RuntimeError('failed to load objects for {} consecutive samples'.format(self.max_load_retries + 1))

class WebVidObjectFuseSelect(TextObjectDataset):
    """
    WebVid dataset yielding per-frame object features, a caption and an ITM label.

    Metadata is read from a tab-separated file under ``./meta_data``: the first column
    is the caption and the second column is the video id. Files are resolved as
        videos:  <data_dir>/<split>/<video_id>.mp4
        objects: <object_dir>/<split>/<video_id>/<frame_idx>.npz
    ``object_params['object_num']`` is the number of object slots kept per frame;
    ``text_params['use_itm']`` enables mismatched (negative) captions.
    """
    # Samples whose object files are missing or unreadable are replaced by randomly drawn
    # ones; give up after this many replacements instead of retrying forever.
    max_load_retries = 1000
    # Attempts at drawing a negative caption that differs from the positive one.
    max_negative_draws = 10

    def __init__(self,
                 dataset_name,
                 text_params,
                 object_params,
                 data_dir,
                 object_dir,
                 metadata_dir=None,
                 split='train',
                 tsfms=None,
                 cut=None,
                 subsample=1,
                 sliding_window_stride=-1,
                 reader='cv2',
                 mask=False,
                 ):
        super(WebVidObjectFuseSelect, self).__init__(dataset_name, text_params, object_params, data_dir, object_dir,
                                                     metadata_dir, split, tsfms, cut, subsample,
                                                     sliding_window_stride, reader, mask)
        self.object_num = self.object_params['object_num']
        self.use_itm = self.text_params['use_itm']

    def _load_metadata(self):
        """Read the split's metadata TSV into ``self.metadata``, optionally subsampled."""
        # NOTE(review): self.metadata_dir is ignored; the TSVs are read from a fixed
        # directory relative to the current working directory.
        metadata_dir = './meta_data'
        split_files = {
            'train': 'webvid_training_success_full.tsv',
            'val': 'webvid_validation_success_full.tsv',
            'test': 'webvid_validation_success_full.tsv',  # there is no test split; reuse val
        }
        target_split_fp = split_files[self.split]
        metadata = pd.read_csv(os.path.join(metadata_dir, target_split_fp), sep='\t')
        if self.subsample < 1:
            # No random_state is given, so the subset depends on NumPy's global RNG state.
            metadata = metadata.sample(frac=self.subsample)
        self.metadata = metadata

    @staticmethod
    def _column(sample, pos):
        """
        Return the ``pos``-th field of a metadata row.

        Rows taken with ``self.metadata.iloc[i]`` are pandas Series labelled by column
        name; positional ``row[int]`` access on them is deprecated (pandas >= 2.1), so
        ``.iloc`` is used whenever available, otherwise ``sample`` is indexed directly.
        """
        return sample.iloc[pos] if hasattr(sample, 'iloc') else sample[pos]

    def _get_video_path(self, sample):
        """Return ``(full_path, relative_path)`` of the sample's ``.mp4`` file."""
        rel_video_fp = self._column(sample, 1) + '.mp4'
        full_video_fp = os.path.join(self.data_dir, self.split, rel_video_fp)
        return full_video_fp, rel_video_fp

    def _get_caption(self, sample):
        """
        Return ``(caption, itm_label)`` for ``sample``.

        When ``self.use_itm`` is set, half of the time the caption is replaced by one
        drawn from another metadata row and labelled 0 (mismatch). Otherwise the
        sample's own caption is returned with label 1.
        """
        caption = self._column(sample, 0)
        if self.use_itm and random.random() < 0.5:
            num_rows = len(self.metadata)
            # Redraw so the "negative" is not simply this sample's own caption.
            for _ in range(self.max_negative_draws):
                negative = self.metadata.iloc[random.randrange(num_rows), 0]
                if negative != caption:
                    return negative, 0
            # No distinct caption found (e.g. single-row metadata): keep the positive pair.
        return caption, 1

    def _get_object_path(self, sample):
        """
        Return the directory holding the sample's per-frame ``.npz`` object files.

        Args:
            sample: a metadata row.
        Returns:
            tuple: ``(relative_path, full_path)`` -- note the order is the reverse of
            ``_get_video_path``.
        """
        rel_object_fp = self._column(sample, 1)
        full_object_fp = os.path.join(self.object_dir, self.split, rel_object_fp)
        return rel_object_fp, full_object_fp

    def _select_frame_idxs(self, object_file_num):
        """
        Choose which of the ``object_file_num`` per-frame files to load.

        All frames are used when their count equals ``self.segments``; otherwise
        ``_sample_objects`` picks ``self.segments`` of them -- with its 'rand' strategy
        (then sorted) for training, and its 'uniform' strategy for other splits.
        """
        if self.segments == object_file_num:
            return list(range(self.segments))
        if self.split == 'train':
            return sorted(self._sample_objects(self.segments, object_file_num, sample='rand'))
        return self._sample_objects(self.segments, object_file_num, sample='uniform')

    def _try_load_item(self, item):
        """Build the sample dict for ``item``, or return None if its objects cannot be used."""
        item = item % len(self.metadata)
        sample = self.metadata.iloc[item]
        object_rel_fp, object_fp = self._get_object_path(sample)
        caption, itm_label = self._get_caption(sample)
        if not os.path.exists(os.path.join(object_fp, '0.npz')):
            print("not exist object in: {}, select another sample".format(object_fp))
            return None
        # Every directory entry is counted as one frame file.
        object_file_num = len(os.listdir(object_fp))
        if object_file_num < 2:
            return None
        try:
            frame_idxs = self._select_frame_idxs(object_file_num)
            # presumably [segments, object_num, feat_dim + 6] -- see object_select_random
            object_feat, object_mask, object_len = read_object_from_disk_with_object_select(
                object_fp, frame_idxs, self.object_num)
        except Exception as e:
            # Deliberate best-effort: any load/selection failure just skips this sample.
            print("Fail to load selected objects {}, object_fp {}".format(e, object_fp))
            return None
        meta_arr = {'raw_captions': caption, 'paths': object_rel_fp, 'dataset': self.dataset_name}
        return {'object': object_feat, 'text': caption, 'meta': meta_arr, 'object_mask': object_mask,
                'object_len': object_len, 'itm_label': itm_label}

    def __getitem__(self, item):
        """
        Return ``{'object', 'text', 'meta', 'object_mask', 'object_len', 'itm_label'}``.

        If the sample's objects are missing or fail to load, randomly drawn samples are
        tried instead. This is done iteratively, so long runs of bad samples cannot
        exhaust the call stack.

        Raises:
            RuntimeError: if the sample and ``max_load_retries`` replacements all fail.
        """
        for _ in range(self.max_load_retries + 1):
            data = self._try_load_item(item)
            if data is not None:
                return data
            item = random.randrange(len(self.metadata))
        raise RuntimeError('failed to load objects for {} consecutive samples'.format(self.max_load_retries + 1))

def read_object_from_disk_with_object_select(object_path, frame_idxs, object_num):
    """
    Load the objects of the requested frames and pack them into fixed-size arrays.

    Args:
        object_path (str): directory holding one ``<frame_idx>.npz`` file per frame.
        frame_idxs (list[int]): frames to load.
        object_num (int): number of object slots per frame.
    Returns:
        tuple: ``(object_feat, object_mask, object_len_list)`` exactly as produced by
        ``object_select_random`` for the loaded frames.
    """
    per_frame_objects = read_all_object_from_disk(object_path, frame_idxs)
    # No per-frame quota list is passed, so object_select_random applies its default rule.
    return object_select_random(per_frame_objects, object_num, None)


def object_select_random(object_info, object_num, object_num_list=None, keep_ratio=1.0):
    """
    Select a fixed number of object slots per frame and stack them into arrays.

    Args:
        object_info (dict): {frame_idx: {'feat': ndarray [K, D], 'objects_conf': ndarray [K],
            'objects_id': ndarray [K], 'bbox': ndarray (one row per object),
            'spatial_feature': ndarray (one row per object)}, ...}, e.g. as returned by
            ``read_all_object_from_disk``. The frame dicts are modified in place.
        object_num (int): number of object slots per frame in the output.
        object_num_list (list[int] | None): if given, ``object_num_list[i]`` objects are
            drawn uniformly at random (with replacement) from the i-th frame in sorted
            frame order. Values above ``object_num`` are clamped to ``object_num``.
        keep_ratio (float): if ``object_num_list`` is None, keep the first
            ``int(object_num * keep_ratio)`` objects of each frame (clamped to
            ``object_num``), or all of them when the frame has fewer. With the default
            1.0 this keeps up to ``object_num`` objects.
    Returns:
        object_feat (torch.Tensor): [num_frames, object_num, D + S], 'feat' concatenated
            with 'spatial_feature'. Slots past a frame's object count repeat its last
            selected object (edge padding).
        object_mask (np.ndarray): [num_frames, object_num], 1.0 for real objects and 0.0
            for padding.
        object_len_list (list[int]): number of real objects per frame.
    Raises:
        ValueError: if a frame has no objects, or zero objects end up selected
            (edge padding needs at least one row).
    """
    all_keys = ('feat', 'objects_conf', 'objects_id', 'bbox', 'spatial_feature')
    padded_keys = ('feat', 'bbox', 'spatial_feature')
    o_num = min(int(object_num * keep_ratio), object_num)
    idxs = sorted(object_info.keys())
    for i, idx in enumerate(idxs):
        frame = object_info[idx]
        num_objects = len(frame['objects_id'])
        if num_objects == 0:
            raise ValueError('frame {} has no objects to select from'.format(idx))
        if object_num_list is not None:
            obj_num = min(object_num_list[i], object_num)
            selected = [random.randrange(num_objects) for _ in range(obj_num)]
        else:
            # Objects are taken in stored order (read_all_object_from_disk sorts them by
            # descending confidence).
            obj_num = min(num_objects, o_num)
            selected = slice(0, obj_num)
        for key in all_keys:
            frame[key] = frame[key][selected]
        frame['object_len'] = obj_num
        res = object_num - obj_num
        for key in padded_keys:
            frame[key] = np.pad(frame[key], ((0, res), (0, 0)), 'edge')

    object_len_list = [object_info[idx]['object_len'] for idx in idxs]
    object_mask = np.zeros((len(idxs), object_num))
    for row, length in enumerate(object_len_list):
        object_mask[row, :length] = 1
    feat_tensor = torch.from_numpy(np.stack([object_info[idx]['feat'] for idx in idxs], axis=0))
    spatial_feat_tensor = torch.from_numpy(np.stack([object_info[idx]['spatial_feature'] for idx in idxs], axis=0))
    object_feat = torch.cat([feat_tensor, spatial_feat_tensor], dim=-1)

    return object_feat, object_mask, object_len_list


def object_select_v1(object_info, object_num):
    """
    Select up to ``object_num`` objects per frame, dropping redundant detections.

    Args:
        object_info (dict): {frame_idx: {'feat': ndarray [K, D], 'objects_conf': ndarray,
            'objects_id': ndarray, 'bbox': ndarray, 'spatial_feature': ndarray}, ...},
            e.g. as returned by ``read_all_object_from_disk``. An 'object_len' entry is
            written into each input frame dict as a side effect.
        object_num (int): number of object slots per frame in the output.
    Select rules:
        1. First frame: keep its first ``object_num`` objects, or all of them
           (edge-padded) if there are fewer.
        2. Later frames: each of the previous frame's leading objects is matched to the
           current-frame object with the highest cosine feature similarity. If both have
           the same 'objects_id' and their boxes have IoU > 0.5, that current-frame
           object is dropped as redundant.
        3. Within a frame, of two objects with the same 'objects_id' and IoU > 0.5, the
           later one is dropped.
        4. The remaining objects keep their stored order and are truncated to
           ``object_num`` or edge-padded up to it. If none remain, ``object_num``
           objects are drawn at random.
    Returns:
        object_feat (torch.Tensor): [num_frames, object_num, D + S], 'feat' concatenated
            with 'spatial_feature'.
        object_mask (np.ndarray): [num_frames, object_num], 1.0 for real objects.
        object_len_list (list[int]): number of real objects per frame.
    """
    idxs = sorted(object_info.keys())
    new_object_info = defaultdict(dict)
    for i, idx in enumerate(idxs):
        frame = object_info[idx]
        selected = new_object_info[idx]
        if i == 0:
            num_objects = len(frame['objects_id'])
            if num_objects > object_num:
                for key in ('feat', 'objects_conf', 'objects_id', 'bbox', 'spatial_feature'):
                    selected[key] = frame[key][:object_num]
                frame['object_len'] = [object_num]
                selected['object_len'] = [object_num]
            else:
                res = object_num - num_objects
                selected['feat'] = np.pad(frame['feat'], ((0, res), (0, 0)), 'edge')
                selected['objects_conf'] = frame['objects_conf']
                selected['objects_id'] = frame['objects_id']
                selected['bbox'] = np.pad(frame['bbox'], ((0, res), (0, 0)), 'edge')
                selected['spatial_feature'] = np.pad(frame['spatial_feature'], ((0, res), (0, 0)), 'edge')
                frame['object_len'] = [num_objects]
                selected['object_len'] = [num_objects]
            continue

        pre_info = object_info[idxs[i - 1]]
        # The input dicts are never truncated, so this takes the previous frame's leading
        # (unfiltered) detections.
        pre_feat = pre_info['feat'][:pre_info['object_len'][0]]
        pre_obj_ids = pre_info['objects_id']
        pre_bbox = pre_info['bbox']
        curr_feat = frame['feat']
        curr_obj_ids = frame['objects_id']
        curr_bbox = frame['bbox']
        keep = [True] * len(curr_obj_ids)

        # Row-wise cosine similarity between previous- and current-frame objects.
        eps = 1e-12
        pre_unit = pre_feat / np.maximum(np.linalg.norm(pre_feat, axis=1, keepdims=True), eps)
        curr_unit = curr_feat / np.maximum(np.linalg.norm(curr_feat, axis=1, keepdims=True), eps)
        track_candidates = np.argmax(pre_unit @ curr_unit.T, axis=1)

        # Drop current objects that are a tracked object at nearly the same position.
        for p, c in enumerate(track_candidates):
            if pre_obj_ids[p] == curr_obj_ids[c] and compute_iou(list(pre_bbox[p]), list(curr_bbox[c])) > 0.5:
                keep[c] = False

        # Drop later duplicates (same id, IoU > 0.5) within the current frame.
        for a in range(len(curr_obj_ids) - 1):
            for b in range(a + 1, len(curr_obj_ids)):
                if curr_obj_ids[a] == curr_obj_ids[b] and compute_iou(list(curr_bbox[a]), list(curr_bbox[b])) > 0.5:
                    keep[b] = False

        # Ascending indices preserve the stored order (descending confidence when loaded
        # by read_all_object_from_disk), so truncation keeps the leading objects.
        maintain_idx = [k for k, flag in enumerate(keep) if flag]

        if len(maintain_idx) > object_num:
            maintain_idx = maintain_idx[:object_num]
            selected['feat'] = frame['feat'][maintain_idx]
            selected['objects_id'] = frame['objects_id'][maintain_idx]
            selected['bbox'] = frame['bbox'][maintain_idx]
            selected['spatial_feature'] = frame['spatial_feature'][maintain_idx]
            frame['object_len'] = [len(maintain_idx)]
            selected['object_len'] = [len(maintain_idx)]
        else:
            if not maintain_idx:
                maintain_idx = [random.randrange(len(frame['objects_id'])) for _ in range(object_num)]
            res = object_num - len(maintain_idx)
            selected['feat'] = np.pad(frame['feat'][maintain_idx], ((0, res), (0, 0)), 'edge')
            selected['objects_id'] = frame['objects_id'][maintain_idx]
            selected['bbox'] = np.pad(frame['bbox'][maintain_idx], ((0, res), (0, 0)), 'edge')
            selected['spatial_feature'] = np.pad(frame['spatial_feature'][maintain_idx], ((0, res), (0, 0)), 'edge')
            # The next frame then tracks against all of this frame's original detections.
            frame['object_len'] = [len(frame['objects_id'])]
            selected['object_len'] = [len(selected['objects_id'])]

    feat_list = [new_object_info[i]['feat'] for i in idxs]
    spatial_feat_list = [new_object_info[i]['spatial_feature'] for i in idxs]
    object_len_list = [new_object_info[i]['object_len'][0] for i in idxs]
    object_mask = np.zeros((len(idxs), object_num))
    for row, length in enumerate(object_len_list):
        object_mask[row, :length] = 1
    feat_tensor = torch.from_numpy(np.stack(feat_list, axis=0))
    spatial_feat_tensor = torch.from_numpy(np.stack(spatial_feat_list, axis=0))
    object_feat = torch.cat([feat_tensor, spatial_feat_tensor], dim=-1)

    return object_feat, object_mask, object_len_list
    

def object_select(object_info, object_num):
    """
    Select up to ``object_num`` objects per frame, following objects across frames.

    Args:
        object_info (dict): {frame_idx: {'feat': ndarray [K, D], 'objects_conf': ndarray,
            'objects_id': ndarray, 'bbox': ndarray, 'spatial_feature': ndarray}, ...},
            e.g. as returned by ``read_all_object_from_disk``. Frame dicts are modified in
            place and gain an 'object_len' entry.
        object_num (int): number of object slots per frame in the output.
    Select rules:
        1. First frame: keep its first ``object_num`` objects, or all of them
           (edge-padded) if there are fewer.
        2. Later frames: each selected object of the previous frame is matched to the
           current-frame object with the highest cosine feature similarity. The match is
           kept when both have the same 'objects_id'.
        3. If fewer than ``object_num`` objects were tracked, every object whose id was
           not selected in any earlier frame is appended in stored order. If nothing was
           selected at all, ``object_num`` objects are drawn at random.
        4. The result is truncated to ``object_num`` or edge-padded up to it.
    Returns:
        object_feat (torch.Tensor): [num_frames, object_num, D + S], 'feat' concatenated
            with 'spatial_feature'.
        object_mask (np.ndarray): [num_frames, object_num], 1.0 for real objects.
        object_len_list (list[int]): number of real objects per frame.
    """
    object_set = set()
    idxs = sorted(object_info.keys())
    for i, idx in enumerate(idxs):
        frame = object_info[idx]
        if i == 0:
            num_objects = len(frame['objects_id'])
            if num_objects > object_num:
                for key in ('feat', 'objects_conf', 'objects_id', 'bbox', 'spatial_feature'):
                    frame[key] = frame[key][:object_num]
                frame['object_len'] = object_num
            else:
                res = object_num - num_objects
                frame['feat'] = np.pad(frame['feat'], ((0, res), (0, 0)), 'edge')
                frame['bbox'] = np.pad(frame['bbox'], ((0, res), (0, 0)), 'edge')
                frame['spatial_feature'] = np.pad(frame['spatial_feature'], ((0, res), (0, 0)), 'edge')
                frame['object_len'] = num_objects
            object_set.update(list(frame['objects_id']))
            continue

        pre_info = object_info[idxs[i - 1]]
        pre_feat = pre_info['feat'][:pre_info['object_len']]
        pre_obj_ids = pre_info['objects_id']
        curr_feat = frame['feat']
        curr_obj_ids = frame['objects_id']

        # Row-wise cosine similarity between previous- and current-frame objects.
        eps = 1e-12
        pre_unit = pre_feat / np.maximum(np.linalg.norm(pre_feat, axis=1, keepdims=True), eps)
        curr_unit = curr_feat / np.maximum(np.linalg.norm(curr_feat, axis=1, keepdims=True), eps)
        track_candidates = np.argmax(pre_unit @ curr_unit.T, axis=1)
        track_object_idx = [c for p, c in enumerate(track_candidates) if pre_obj_ids[p] == curr_obj_ids[c]]

        if len(track_object_idx) < object_num:
            new_ids = set(curr_obj_ids) - object_set
            supp_object_idx = [np.where(curr_obj_ids == obj_id)[0] for obj_id in new_ids]
            if supp_object_idx:
                # Sorted so supplementary objects follow the stored (confidence) order
                # rather than set iteration order.
                track_object_idx.extend(np.sort(np.concatenate(supp_object_idx, axis=0)))

        if len(track_object_idx) > object_num:
            track_object_idx = track_object_idx[:object_num]
            frame['feat'] = frame['feat'][track_object_idx]
            frame['objects_id'] = frame['objects_id'][track_object_idx]
            frame['bbox'] = frame['bbox'][track_object_idx]
            frame['spatial_feature'] = frame['spatial_feature'][track_object_idx]
            frame['object_len'] = len(track_object_idx)
        else:
            if not track_object_idx:
                # Fill every slot; drawing any other fixed count would overflow or
                # under-fill the object_num slots.
                track_object_idx = [random.randrange(len(frame['objects_id'])) for _ in range(object_num)]
            res = object_num - len(track_object_idx)
            frame['feat'] = np.pad(frame['feat'][track_object_idx], ((0, res), (0, 0)), 'edge')
            frame['objects_id'] = frame['objects_id'][track_object_idx]
            frame['bbox'] = np.pad(frame['bbox'][track_object_idx], ((0, res), (0, 0)), 'edge')
            frame['spatial_feature'] = np.pad(frame['spatial_feature'][track_object_idx], ((0, res), (0, 0)), 'edge')
            frame['object_len'] = len(track_object_idx)

        object_set.update(list(frame['objects_id']))

    feat_list = [object_info[i]['feat'] for i in idxs]
    spatial_feat_list = [object_info[i]['spatial_feature'] for i in idxs]
    object_len_list = [object_info[i]['object_len'] for i in idxs]
    object_mask = np.zeros((len(idxs), object_num))
    for row, length in enumerate(object_len_list):
        object_mask[row, :length] = 1
    feat_tensor = torch.from_numpy(np.stack(feat_list, axis=0))
    spatial_feat_tensor = torch.from_numpy(np.stack(spatial_feat_list, axis=0))
    object_feat = torch.cat([feat_tensor, spatial_feat_tensor], dim=-1)

    return object_feat, object_mask, object_len_list

def read_all_object_from_disk(object_path, frame_idxs):
    """
    Load the detected objects of each requested frame, sorted by descending confidence.

    Each frame is read from ``<object_path>/<frame_idx>.npz``. The file must contain
    arrays 'x' (per-object features) and 'bbox' (rows of x1, y1, x2, y2 in the units
    of the image size), plus a pickled 'info' dict with 'objects_conf', 'objects_id',
    'image_w' and 'image_h'.

    Args:
        object_path (str): directory holding the per-frame ``.npz`` files.
        frame_idxs (iterable[int]): frames to load.
    Returns:
        dict: {frame_idx: {'feat': ndarray, 'objects_conf': ndarray, 'objects_id': ndarray,
               'bbox': ndarray, 'spatial_feature': ndarray [K, 6]}, ...}
        'spatial_feature' holds (x1, y1, x2, y2, w, h) normalized by the image width and
        height. A frame whose file cannot be opened (OSError) maps to an empty dict.
    """
    object_info = {}
    for index in frame_idxs:
        frame_fp = os.path.join(object_path, '{}.npz'.format(index))
        frame_object_info = {}
        try:
            # The context manager closes the underlying zip file handle. allow_pickle is
            # required for the 'info' dict -- only load trusted feature files.
            with np.load(frame_fp, allow_pickle=True) as npz:
                features = npz['x']
                boxes = npz['bbox']
                info = npz['info'].item()
        except OSError:
            # Missing or unreadable file: this frame's entry is left as an empty dict.
            print("object is wrong or not existed in : {}".format(frame_fp))
        else:
            confident = info['objects_conf']
            object_ids = info['objects_id']
            order = np.argsort(confident)[::-1]
            confident = confident[order]
            boxes = boxes[order]
            features = features[order]
            object_ids = object_ids[order]

            image_width = info['image_w']
            image_height = info['image_h']
            scaled_x = boxes[:, 0] / image_width
            scaled_y = boxes[:, 1] / image_height
            scaled_width = (boxes[:, 2] - boxes[:, 0]) / image_width
            scaled_height = (boxes[:, 3] - boxes[:, 1]) / image_height
            spatial_features = np.stack(
                (scaled_x, scaled_y, scaled_x + scaled_width, scaled_y + scaled_height, scaled_width, scaled_height),
                axis=1)

            frame_object_info['feat'] = features
            frame_object_info['objects_conf'] = confident
            frame_object_info['objects_id'] = object_ids
            frame_object_info['bbox'] = boxes
            frame_object_info['spatial_feature'] = spatial_features
        object_info[index] = frame_object_info
    return object_info


