import h5py
import numpy as np
import torch
from torch.utils.data import Dataset
import json
import os
from torch.utils.data import DataLoader
from PIL import Image
import sys
import io
import pickle
from copy import deepcopy as dp
sys.path.append('')
from myutils.data_utils import (sample_appearance_indices,
                                VideoColorJitter,
                                IdentityTransform,
                                sample_train_layout_indices,
                                get_test_layout_indices,
                                fix_box)
from torchvision.transforms import (
    Compose,
    Normalize,
    RandomCrop,
    Resize,
    ToTensor,
)
from configs.configs import DataConfig
from torchvision.transforms import functional as TF
import math
from natsort import natsorted
import re


class MixAns2(Dataset):
    """Dataset of pre-extracted per-frame features with answer/label targets.

    Annotation tables (frame index, mapping, mask, boxes, object classes,
    relations, video sizes) are read from JSON files in ``__init__``; the
    frame features themselves live in an HDF5 file that is opened lazily on
    the first ``__getitem__`` call.

    Args:
        name: Base name used to build every annotation/feature file path.
        sample_each_clip: Number of frames sampled per item.
        node_nums: Number of nodes per frame used for the one-hot tensors.
        mapping_type: Selects which mapping JSON file is loaded (0-6).
        train: Forwarded to ``sample_appearance_indices``.

    Raises:
        ModuleNotFoundError: If ``mapping_type`` is not a supported value.
    """

    def __init__(self, name, sample_each_clip=16, node_nums=10, mapping_type=2, train=True):
        super().__init__()
        self.sample_rate = sample_each_clip
        self.node_nums = node_nums
        self.name = name
        self.train = train
        self.mapping_type = mapping_type

        # key -> list of frame ids, and key + '_label' -> label indices.
        self.json = self._load_json(os.path.join("", name + '.json'))

        # One entry per supported mapping_type; kept separate so each type can
        # point at its own directory.
        mapping_paths = {
            0: os.path.join('', name + '.json'),
            1: os.path.join('', name + '.json'),
            2: os.path.join('', name + '.json'),
            3: os.path.join('', name + '.json'),
            4: os.path.join('', 'type_1', 'test1.json'),
            5: os.path.join('', name + '.json'),
            6: os.path.join('', name + '.json'),
        }
        try:
            mapping_path = mapping_paths[mapping_type]
        except (KeyError, TypeError):
            # Exception type kept for backward compatibility with callers.
            raise ModuleNotFoundError(
                f"unsupported mapping_type: {mapping_type!r}"
            ) from None
        # Sequence of dicts with 'id', 'private', 'common' and 'token' entries.
        self.mapping = self._load_json(mapping_path)

        self.mask = self._load_json(os.path.join("", name + '_mask.json'))
        self.bbx = self._load_json(os.path.join("", name + '_bbx.json'))
        self.cls = self._load_json(os.path.join("", name + '_obj_cls.json'))
        self.rel = self._load_json(os.path.join("", name + '_rel.json'))

        self.video_path = os.path.join('', name + '.hdf5')
        self.video2size = self._load_json('')
        self.num_cls = 157
        self.obj_cls_num = 38
        self.rel_num = 30

    @staticmethod
    def _load_json(path):
        """Parse the JSON file at ``path`` and close the handle afterwards."""
        with open(path, 'r') as handle:
            return json.load(handle)

    @staticmethod
    def _normalize_boxes(bbx, video_size):
        """Scale boxes by the video size and prepend a full-frame box.

        Columns 0 and 2 are divided by ``video_size[0]`` and columns 1 and 3
        by ``video_size[1]`` (presumably x/width and y/height - TODO confirm).
        ``bbx`` is modified in place. A ``[0, 0, 1, 1]`` row is then inserted
        as the first box of every frame.
        """
        bbx[:, :, 0] /= video_size[0]
        bbx[:, :, 1] /= video_size[1]
        bbx[:, :, 2] /= video_size[0]
        bbx[:, :, 3] /= video_size[1]
        full_frame = torch.tensor([[0., 0., 1., 1.]]).to(bbx)
        # Tile to the actual frame count instead of a hard-coded 16 so the
        # concatenation below is valid for any sampling rate.
        full_frame = full_frame.unsqueeze(0).repeat(bbx.shape[0], 1, 1)
        return torch.cat([full_frame, bbx], dim=-2)

    def _load_frames(self, key, frame_ids, indices):
        """Read the sampled frames of ``key`` as a float32 (T, 11, 512) tensor."""
        clip = self.videos[key]
        feats = [
            torch.from_numpy(
                np.frombuffer(np.array(clip[frame_ids[i]]), dtype=np.float16)
            ).reshape(1, 11, 512)
            for i in indices
        ]
        return torch.concat(feats, dim=0).float()

    def __len__(self):
        """Return the number of entries in the loaded mapping."""
        return len(self.mapping)

    def open_video(self):
        """Open the HDF5 feature file read-only and store it on ``self.videos``."""
        self.videos = h5py.File(
            self.video_path,
            "r", libver="latest", swmr=True
        )

    #  person relation -> 0
    def __getitem__(self, idx: int):
        """Build one sample.

        Returns:
            Tuple ``(frames, bbx, mask, label, cls_ids, cls_cls, rel,
            private_label, common_label, token_tensor, mask_)`` where
            ``mask_`` is the logical negation of the raw mask, ``mask`` is the
            raw mask without its first column, duplicated along the last axis
            with a trailing singleton dimension, and ``rel`` omits node 0.
        """
        if not hasattr(self, "videos"):
            self.open_video()
        entry = self.mapping[idx]
        key = entry['id']
        private_ans = entry['private']
        common_ans = entry['common']
        tokens = entry['token']

        frame_ids = self.json[key]
        indices = sample_appearance_indices(
            self.sample_rate, len(frame_ids), self.train
        )
        video_size = self.video2size[key.split('.')[0]]
        frames = self._load_frames(key, frame_ids, indices)

        bbx = torch.tensor([self.bbx[key][i] for i in indices], dtype=torch.float32)
        mask = torch.tensor([self.mask[key][i] for i in indices], dtype=torch.long)
        # for gpnn
        mask_ = ~mask.bool()
        mask = mask[:, 1:]
        mask = torch.cat([mask, mask], dim=-1).unsqueeze(-1)

        cls_ids = torch.tensor([self.cls[key][i] for i in indices], dtype=torch.long)
        rel_ids = [self.rel[key][i] for i in indices]

        # Multi-hot targets.
        label_idx = [int(x) for x in self.json[key + '_label']]
        label = torch.zeros(self.num_cls, dtype=torch.float32)
        label[label_idx] = 1.0
        private_label = torch.zeros(self.num_cls + 1, dtype=torch.float32)
        private_label[private_ans] = 1.0
        common_label = torch.zeros(self.num_cls + 1, dtype=torch.float32)
        common_label[common_ans] = 1.0
        token_tensor = torch.zeros(self.num_cls, dtype=torch.float32)
        token_tensor[tokens] = 1.0

        bbx = self._normalize_boxes(bbx, video_size)

        # One-hot encodings of object classes and relations per frame and node.
        rel = torch.zeros((self.sample_rate, self.node_nums, self.rel_num), dtype=torch.float32)
        cls_cls = torch.zeros((self.sample_rate, self.node_nums, self.obj_cls_num), dtype=torch.float32)
        cls_cls.scatter_(2, cls_ids.unsqueeze(-1), 1.)
        for t in range(self.sample_rate):
            for node in range(self.node_nums):
                rel[t][node][rel_ids[t][node]] = 1.

        return (frames, bbx, mask, label, cls_ids, cls_cls, rel[:, 1:, :],
                private_label, common_label, token_tensor, mask_)

# grid
class MixAns3(Dataset):
    """Grid variant: mapping files are selected by ``test_i`` / ``test_j``.

    Annotation tables are read from JSON files in ``__init__``; the HDF5
    feature file is opened lazily on the first ``__getitem__`` call.

    Args:
        name: Base name used to build every annotation/feature file path.
        sample_each_clip: Number of frames sampled per item.
        node_nums: Number of nodes per frame used for the one-hot tensors.
        mapping_type: Selects which mapping JSON file is loaded
            (1, 2, 3, 4, 5 or 7).
        test_i: First sub-directory component of the mapping path.
        test_j: Second sub-directory component of the mapping path.
        train: Forwarded to ``sample_appearance_indices``.

    Raises:
        ModuleNotFoundError: If ``mapping_type`` is not a supported value.
    """

    def __init__(self, name, sample_each_clip=16, node_nums=10, mapping_type=1, test_i=1, test_j=1, train=True):
        super().__init__()
        self.sample_rate = sample_each_clip
        self.node_nums = node_nums
        self.name = name
        self.train = train
        self.mapping_type = mapping_type

        # key -> list of frame ids, and key + '_label' -> label indices.
        self.json = self._load_json(os.path.join("", name + '.json'))

        # One entry per supported mapping_type; kept separate so each type can
        # point at its own directory.
        mapping_paths = {
            1: os.path.join('', str(test_i), str(test_j), name + '.json'),
            2: os.path.join('', str(test_i), str(test_j), name + '.json'),
            3: os.path.join('', str(test_i), str(test_j), name + '.json'),
            4: os.path.join('', str(test_i), str(test_j), name + '.json'),
            5: os.path.join('', name + '.json'),
            7: os.path.join('', str(test_i), str(test_j), name + '.json'),
        }
        try:
            mapping_path = mapping_paths[mapping_type]
        except (KeyError, TypeError):
            # Exception type kept for backward compatibility with callers.
            raise ModuleNotFoundError(
                f"unsupported mapping_type: {mapping_type!r}"
            ) from None
        # Sequence of dicts with 'id', 'private', 'common' and 'token' entries.
        self.mapping = self._load_json(mapping_path)

        self.mask = self._load_json(os.path.join("", name + '_mask.json'))
        self.bbx = self._load_json(os.path.join("", name + '_bbx.json'))
        self.cls = self._load_json(os.path.join("", name + '_obj_cls.json'))
        self.rel = self._load_json(os.path.join("", name + '_rel.json'))

        self.video_path = os.path.join('', name + '.hdf5')
        self.video2size = self._load_json('')
        self.num_cls = 157
        self.obj_cls_num = 38
        self.rel_num = 30

    @staticmethod
    def _load_json(path):
        """Parse the JSON file at ``path`` and close the handle afterwards."""
        with open(path, 'r') as handle:
            return json.load(handle)

    @staticmethod
    def _normalize_boxes(bbx, video_size):
        """Scale boxes by the video size and prepend a full-frame box.

        Columns 0 and 2 are divided by ``video_size[0]`` and columns 1 and 3
        by ``video_size[1]`` (presumably x/width and y/height - TODO confirm).
        ``bbx`` is modified in place. A ``[0, 0, 1, 1]`` row is then inserted
        as the first box of every frame.
        """
        bbx[:, :, 0] /= video_size[0]
        bbx[:, :, 1] /= video_size[1]
        bbx[:, :, 2] /= video_size[0]
        bbx[:, :, 3] /= video_size[1]
        full_frame = torch.tensor([[0., 0., 1., 1.]]).to(bbx)
        # Tile to the actual frame count instead of a hard-coded 16 so the
        # concatenation below is valid for any sampling rate.
        full_frame = full_frame.unsqueeze(0).repeat(bbx.shape[0], 1, 1)
        return torch.cat([full_frame, bbx], dim=-2)

    def _load_frames(self, key, frame_ids, indices):
        """Read the sampled frames of ``key`` as a float32 (T, 11, 512) tensor."""
        clip = self.videos[key]
        feats = [
            torch.from_numpy(
                np.frombuffer(np.array(clip[frame_ids[i]]), dtype=np.float16)
            ).reshape(1, 11, 512)
            for i in indices
        ]
        return torch.concat(feats, dim=0).float()

    def __len__(self):
        """Return the number of entries in the loaded mapping."""
        return len(self.mapping)

    def open_video(self):
        """Open the HDF5 feature file read-only and store it on ``self.videos``."""
        self.videos = h5py.File(
            self.video_path,
            "r", libver="latest", swmr=True
        )

    #  person relation -> 0
    def __getitem__(self, idx: int):
        """Build one sample.

        Returns:
            Tuple ``(frames, bbx, mask, label, cls_ids, cls_cls, rel,
            private_label, common_label, token_tensor, mask_)``. Here ``mask``
            is the raw per-frame mask and ``mask_`` is a long vector of length
            ``num_cls + 1`` set to 1 at the label indices; ``rel`` omits
            node 0.
        """
        if not hasattr(self, "videos"):
            self.open_video()
        entry = self.mapping[idx]
        key = entry['id']
        private_ans = entry['private']
        common_ans = entry['common']
        tokens = entry['token']

        frame_ids = self.json[key]
        indices = sample_appearance_indices(
            self.sample_rate, len(frame_ids), self.train
        )
        video_size = self.video2size[key.split('.')[0]]
        frames = self._load_frames(key, frame_ids, indices)

        bbx = torch.tensor([self.bbx[key][i] for i in indices], dtype=torch.float32)
        mask = torch.tensor([self.mask[key][i] for i in indices], dtype=torch.long)
        cls_ids = torch.tensor([self.cls[key][i] for i in indices], dtype=torch.long)
        rel_ids = [self.rel[key][i] for i in indices]

        # Multi-hot targets.
        label_idx = [int(x) for x in self.json[key + '_label']]
        label = torch.zeros(self.num_cls, dtype=torch.float32)
        label[label_idx] = 1.0
        mask_ = torch.zeros(self.num_cls + 1, dtype=torch.long)
        mask_[label_idx] = 1
        private_label = torch.zeros(self.num_cls + 1, dtype=torch.float32)
        private_label[private_ans] = 1.0
        common_label = torch.zeros(self.num_cls + 1, dtype=torch.float32)
        common_label[common_ans] = 1.0
        token_tensor = torch.zeros(self.num_cls, dtype=torch.float32)
        token_tensor[tokens] = 1.0

        bbx = self._normalize_boxes(bbx, video_size)

        # One-hot encodings of object classes and relations per frame and node.
        rel = torch.zeros((self.sample_rate, self.node_nums, self.rel_num), dtype=torch.float32)
        cls_cls = torch.zeros((self.sample_rate, self.node_nums, self.obj_cls_num), dtype=torch.float32)
        cls_cls.scatter_(2, cls_ids.unsqueeze(-1), 1.)
        for t in range(self.sample_rate):
            for node in range(self.node_nums):
                rel[t][node][rel_ids[t][node]] = 1.

        return (frames, bbx, mask, label, cls_ids, cls_cls, rel[:, 1:, :],
                private_label, common_label, token_tensor, mask_)

# inference dataset
# 1 all right,2 all wrong,3 padding, 4 (1,2,3)dataset
class InfDataset(Dataset):
    """Inference dataset; each sample additionally carries its own index.

    Annotation tables are read from JSON files in ``__init__``; the HDF5
    feature file is opened lazily on the first ``__getitem__`` call.

    Args:
        name: Base name used to build every annotation/feature file path.
        sample_each_clip: Number of frames sampled per item.
        node_nums: Number of nodes per frame used for the one-hot tensors.
        mapping_type: Selects which mapping JSON file is loaded (1-5).
        train: Forwarded to ``sample_appearance_indices``.

    Raises:
        ModuleNotFoundError: If ``mapping_type`` is not a supported value.
    """

    def __init__(self, name, sample_each_clip=16, node_nums=10, mapping_type=1, train=True):
        super().__init__()
        self.sample_rate = sample_each_clip
        self.node_nums = node_nums
        self.name = name
        self.train = train
        self.mapping_type = mapping_type

        # key -> list of frame ids, and key + '_label' -> label indices.
        self.json = self._load_json(os.path.join("", name + '.json'))

        # One entry per supported mapping_type; kept separate so each type can
        # point at its own directory.
        mapping_paths = {
            1: os.path.join('', name + '.json'),
            2: os.path.join('', name + '.json'),
            3: os.path.join('', name + '.json'),
            4: os.path.join('', name + '.json'),
            5: os.path.join('', name + '.json'),
        }
        try:
            mapping_path = mapping_paths[mapping_type]
        except (KeyError, TypeError):
            # Exception type kept for backward compatibility with callers.
            raise ModuleNotFoundError(
                f"unsupported mapping_type: {mapping_type!r}"
            ) from None
        # Sequence of dicts with 'id', 'private', 'common' and 'token' entries.
        self.mapping = self._load_json(mapping_path)

        self.mask = self._load_json(os.path.join("", name + '_mask.json'))
        self.bbx = self._load_json(os.path.join("", name + '_bbx.json'))
        self.cls = self._load_json(os.path.join("", name + '_obj_cls.json'))
        self.rel = self._load_json(os.path.join("", name + '_rel.json'))

        self.video_path = os.path.join('', name + '.hdf5')
        self.video2size = self._load_json('')
        self.num_cls = 157
        self.obj_cls_num = 38
        self.rel_num = 30

    @staticmethod
    def _load_json(path):
        """Parse the JSON file at ``path`` and close the handle afterwards."""
        with open(path, 'r') as handle:
            return json.load(handle)

    @staticmethod
    def _normalize_boxes(bbx, video_size):
        """Scale boxes by the video size and prepend a full-frame box.

        Columns 0 and 2 are divided by ``video_size[0]`` and columns 1 and 3
        by ``video_size[1]`` (presumably x/width and y/height - TODO confirm).
        ``bbx`` is modified in place. A ``[0, 0, 1, 1]`` row is then inserted
        as the first box of every frame.
        """
        bbx[:, :, 0] /= video_size[0]
        bbx[:, :, 1] /= video_size[1]
        bbx[:, :, 2] /= video_size[0]
        bbx[:, :, 3] /= video_size[1]
        full_frame = torch.tensor([[0., 0., 1., 1.]]).to(bbx)
        # Tile to the actual frame count instead of a hard-coded 16 so the
        # concatenation below is valid for any sampling rate.
        full_frame = full_frame.unsqueeze(0).repeat(bbx.shape[0], 1, 1)
        return torch.cat([full_frame, bbx], dim=-2)

    def _load_frames(self, key, frame_ids, indices):
        """Read the sampled frames of ``key`` as a float32 (T, 11, 512) tensor."""
        clip = self.videos[key]
        feats = [
            torch.from_numpy(
                np.frombuffer(np.array(clip[frame_ids[i]]), dtype=np.float16)
            ).reshape(1, 11, 512)
            for i in indices
        ]
        return torch.concat(feats, dim=0).float()

    def __len__(self):
        """Return the number of entries in the loaded mapping."""
        return len(self.mapping)

    def open_video(self):
        """Open the HDF5 feature file read-only and store it on ``self.videos``."""
        self.videos = h5py.File(
            self.video_path,
            "r", libver="latest", swmr=True
        )

    #  person relation -> 0
    def __getitem__(self, idx: int):
        """Build one sample.

        Returns:
            Tuple ``(frames, bbx, mask, label, cls_ids, cls_cls, rel,
            private_label, common_label, token_tensor, mask_, index)``. Here
            ``mask`` is the raw per-frame mask, ``mask_`` is a long vector of
            length ``num_cls + 1`` set to 1 at the label indices, ``rel``
            omits node 0 and ``index`` is ``idx`` wrapped in a tensor.
        """
        if not hasattr(self, "videos"):
            self.open_video()
        entry = self.mapping[idx]
        key = entry['id']
        private_ans = entry['private']
        common_ans = entry['common']
        tokens = entry['token']

        frame_ids = self.json[key]
        indices = sample_appearance_indices(
            self.sample_rate, len(frame_ids), self.train
        )
        video_size = self.video2size[key.split('.')[0]]
        frames = self._load_frames(key, frame_ids, indices)

        bbx = torch.tensor([self.bbx[key][i] for i in indices], dtype=torch.float32)
        mask = torch.tensor([self.mask[key][i] for i in indices], dtype=torch.long)
        cls_ids = torch.tensor([self.cls[key][i] for i in indices], dtype=torch.long)
        rel_ids = [self.rel[key][i] for i in indices]

        # Multi-hot targets.
        label_idx = [int(x) for x in self.json[key + '_label']]
        label = torch.zeros(self.num_cls, dtype=torch.float32)
        label[label_idx] = 1.0
        mask_ = torch.zeros(self.num_cls + 1, dtype=torch.long)
        mask_[label_idx] = 1
        private_label = torch.zeros(self.num_cls + 1, dtype=torch.float32)
        private_label[private_ans] = 1.0
        common_label = torch.zeros(self.num_cls + 1, dtype=torch.float32)
        common_label[common_ans] = 1.0
        token_tensor = torch.zeros(self.num_cls, dtype=torch.float32)
        token_tensor[tokens] = 1.0

        bbx = self._normalize_boxes(bbx, video_size)

        # One-hot encodings of object classes and relations per frame and node.
        rel = torch.zeros((self.sample_rate, self.node_nums, self.rel_num), dtype=torch.float32)
        cls_cls = torch.zeros((self.sample_rate, self.node_nums, self.obj_cls_num), dtype=torch.float32)
        cls_cls.scatter_(2, cls_ids.unsqueeze(-1), 1.)
        for t in range(self.sample_rate):
            for node in range(self.node_nums):
                rel[t][node][rel_ids[t][node]] = 1.

        return (frames, bbx, mask, label, cls_ids, cls_cls, rel[:, 1:, :],
                private_label, common_label, token_tensor, mask_, torch.tensor(idx))


class VisualizeDataset(Dataset):
    """Dataset variant that also returns raw boxes, the key and frame indices.

    Annotation tables are read from JSON files in ``__init__``; the HDF5
    feature file is opened lazily on the first ``__getitem__`` call. Frame
    sampling always passes ``False`` as the train flag.

    Args:
        name: Base name used to build every annotation/feature file path.
        sample_each_clip: Number of frames sampled per item.
        node_nums: Number of nodes per frame used for the one-hot tensors.
        mapping_type: Selects which mapping JSON file is loaded (1-5).
        test_i: First sub-directory component of the mapping path.
        test_j: Second sub-directory component of the mapping path.

    Raises:
        ModuleNotFoundError: If ``mapping_type`` is not a supported value.
    """

    def __init__(self, name, sample_each_clip=16, node_nums=10, mapping_type=1, test_i=1, test_j=1):
        super().__init__()
        self.sample_rate = sample_each_clip
        self.node_nums = node_nums
        self.name = name
        self.mapping_type = mapping_type

        # key -> list of frame ids, and key + '_label' -> label indices.
        self.json = self._load_json(os.path.join("", name + '.json'))

        # One entry per supported mapping_type; kept separate so each type can
        # point at its own directory.
        mapping_paths = {
            1: os.path.join('', str(test_i), str(test_j), name + '.json'),
            2: os.path.join('', str(test_i), str(test_j), name + '.json'),
            3: os.path.join('', str(test_i), str(test_j), name + '.json'),
            4: os.path.join('', str(test_i), str(test_j), name + '.json'),
            5: os.path.join('', name + '.json'),
        }
        try:
            mapping_path = mapping_paths[mapping_type]
        except (KeyError, TypeError):
            # Exception type kept for backward compatibility with callers.
            raise ModuleNotFoundError(
                f"unsupported mapping_type: {mapping_type!r}"
            ) from None
        # Sequence of dicts with 'id', 'private', 'common' and 'token' entries.
        self.mapping = self._load_json(mapping_path)

        self.mask = self._load_json(os.path.join("", name + '_mask.json'))
        self.bbx = self._load_json(os.path.join("", name + '_bbx.json'))
        self.cls = self._load_json(os.path.join("", name + '_obj_cls.json'))
        self.rel = self._load_json(os.path.join("", name + '_rel.json'))

        self.video_path = os.path.join('', name + '.hdf5')
        self.video2size = self._load_json('')
        self.num_cls = 157
        self.obj_cls_num = 38
        self.rel_num = 30

    @staticmethod
    def _load_json(path):
        """Parse the JSON file at ``path`` and close the handle afterwards."""
        with open(path, 'r') as handle:
            return json.load(handle)

    @staticmethod
    def _normalize_boxes(bbx, video_size):
        """Scale boxes by the video size and prepend a full-frame box.

        Columns 0 and 2 are divided by ``video_size[0]`` and columns 1 and 3
        by ``video_size[1]`` (presumably x/width and y/height - TODO confirm).
        ``bbx`` is modified in place. A ``[0, 0, 1, 1]`` row is then inserted
        as the first box of every frame.
        """
        bbx[:, :, 0] /= video_size[0]
        bbx[:, :, 1] /= video_size[1]
        bbx[:, :, 2] /= video_size[0]
        bbx[:, :, 3] /= video_size[1]
        full_frame = torch.tensor([[0., 0., 1., 1.]]).to(bbx)
        # Tile to the actual frame count instead of a hard-coded 16 so the
        # concatenation below is valid for any sampling rate.
        full_frame = full_frame.unsqueeze(0).repeat(bbx.shape[0], 1, 1)
        return torch.cat([full_frame, bbx], dim=-2)

    def _load_frames(self, key, frame_ids, indices):
        """Read the sampled frames of ``key`` as a float32 (T, 11, 512) tensor."""
        clip = self.videos[key]
        feats = [
            torch.from_numpy(
                np.frombuffer(np.array(clip[frame_ids[i]]), dtype=np.float16)
            ).reshape(1, 11, 512)
            for i in indices
        ]
        return torch.concat(feats, dim=0).float()

    def __len__(self):
        """Return the number of entries in the loaded mapping."""
        return len(self.mapping)

    def open_video(self):
        """Open the HDF5 feature file read-only and store it on ``self.videos``."""
        self.videos = h5py.File(
            self.video_path,
            "r", libver="latest", swmr=True
        )

    #  person relation -> 0
    def __getitem__(self, idx: int):
        """Build one sample for visualisation.

        Returns:
            Tuple ``(frames, bbx, bbx_, key, indices, label, cls_ids, cls_cls,
            rel, private_label, common_ans, token_tensor, mask, mask_)``.
            ``bbx_`` holds the boxes before normalisation and without the
            prepended full-frame row; ``common_ans`` is returned as a long
            tensor of the raw answer indices rather than a multi-hot vector;
            ``mask_`` is the logical negation of the raw mask and ``mask`` is
            the raw mask without its first column, duplicated along the last
            axis with a trailing singleton dimension; ``rel`` omits node 0.
        """
        if not hasattr(self, "videos"):
            self.open_video()
        entry = self.mapping[idx]
        key = entry['id']
        private_ans = entry['private']
        common_ans = entry['common']
        tokens = entry['token']

        frame_ids = self.json[key]
        indices = sample_appearance_indices(
            self.sample_rate, len(frame_ids), False
        )
        video_size = self.video2size[key.split('.')[0]]
        frames = self._load_frames(key, frame_ids, indices)

        bbx = torch.tensor([self.bbx[key][i] for i in indices], dtype=torch.float32)
        cls_ids = torch.tensor([self.cls[key][i] for i in indices], dtype=torch.long)
        rel_ids = [self.rel[key][i] for i in indices]
        mask = torch.tensor([self.mask[key][i] for i in indices], dtype=torch.long)
        # for gpnn
        mask_ = ~mask.bool()
        mask = mask[:, 1:]
        mask = torch.cat([mask, mask], dim=-1).unsqueeze(-1)

        # Multi-hot targets.
        label_idx = [int(x) for x in self.json[key + '_label']]
        label = torch.zeros(self.num_cls, dtype=torch.float32)
        label[label_idx] = 1.0
        private_label = torch.zeros(self.num_cls + 1, dtype=torch.float32)
        private_label[private_ans] = 1.0
        token_tensor = torch.zeros(self.num_cls, dtype=torch.float32)
        token_tensor[tokens] = 1.0

        # Keep an untouched copy first: normalisation below works in place.
        bbx_ = bbx.clone()
        bbx = self._normalize_boxes(bbx, video_size)

        # One-hot encodings of object classes and relations per frame and node.
        rel = torch.zeros((self.sample_rate, self.node_nums, self.rel_num), dtype=torch.float32)
        cls_cls = torch.zeros((self.sample_rate, self.node_nums, self.obj_cls_num), dtype=torch.float32)
        cls_cls.scatter_(2, cls_ids.unsqueeze(-1), 1.)
        for t in range(self.sample_rate):
            for node in range(self.node_nums):
                rel[t][node][rel_ids[t][node]] = 1.

        return (frames, bbx, bbx_, key, indices,
                label, cls_ids, cls_cls, rel[:, 1:, :], private_label,
                torch.tensor(common_ans, dtype=torch.long), token_tensor, mask, mask_)

class MixLocal(Dataset):
    """Dataset variant that adds per-frame localisation targets.

    In addition to the usual annotation tables, an IoU JSON file is loaded
    that maps each key to a dict from answer index (as a string) to a list of
    frame indices. The HDF5 feature file is opened lazily on the first
    ``__getitem__`` call.

    Args:
        name: Base name used to build every annotation/feature file path.
        sample_each_clip: Number of frames sampled per item.
        node_nums: Number of nodes per frame used for the one-hot tensors.
        mapping_type: Selects which mapping JSON file is loaded (0-6).
        train: Forwarded to ``sample_appearance_indices``.

    Raises:
        ModuleNotFoundError: If ``mapping_type`` is not a supported value.
    """

    def __init__(self, name, sample_each_clip=16, node_nums=10, mapping_type=2, train=True):
        super().__init__()
        self.sample_rate = sample_each_clip
        self.node_nums = node_nums
        self.name = name
        self.train = train
        self.mapping_type = mapping_type

        # key -> list of frame ids, and key + '_label' -> label indices.
        self.json = self._load_json(os.path.join("", name + '.json'))

        # One entry per supported mapping_type; kept separate so each type can
        # point at its own file.
        mapping_paths = {
            0: '',
            1: '',
            2: '',
            3: os.path.join('', name + '.json'),
            4: os.path.join('', 'type_1', 'test1.json'),
            5: os.path.join('', name + '.json'),
            6: os.path.join('', name + '.json'),
        }
        try:
            mapping_path = mapping_paths[mapping_type]
        except (KeyError, TypeError):
            # Exception type kept for backward compatibility with callers.
            raise ModuleNotFoundError(
                f"unsupported mapping_type: {mapping_type!r}"
            ) from None
        # Sequence of dicts with 'id', 'private', 'common' and 'token' entries.
        self.mapping = self._load_json(mapping_path)

        self.mask = self._load_json(os.path.join("", name + '_mask.json'))
        self.bbx = self._load_json(os.path.join("", name + '_bbx.json'))
        self.cls = self._load_json(os.path.join("", name + '_obj_cls.json'))
        self.rel = self._load_json(os.path.join("", name + '_rel.json'))
        self.ioufile = self._load_json('')

        self.video_path = os.path.join('', name + '.hdf5')
        self.video2size = self._load_json('')
        self.num_cls = 157
        self.obj_cls_num = 38
        self.rel_num = 30

    @staticmethod
    def _load_json(path):
        """Parse the JSON file at ``path`` and close the handle afterwards."""
        with open(path, 'r') as handle:
            return json.load(handle)

    @staticmethod
    def _normalize_boxes(bbx, video_size):
        """Scale boxes by the video size and prepend a full-frame box.

        Columns 0 and 2 are divided by ``video_size[0]`` and columns 1 and 3
        by ``video_size[1]`` (presumably x/width and y/height - TODO confirm).
        ``bbx`` is modified in place. A ``[0, 0, 1, 1]`` row is then inserted
        as the first box of every frame.
        """
        bbx[:, :, 0] /= video_size[0]
        bbx[:, :, 1] /= video_size[1]
        bbx[:, :, 2] /= video_size[0]
        bbx[:, :, 3] /= video_size[1]
        full_frame = torch.tensor([[0., 0., 1., 1.]]).to(bbx)
        # Tile to the actual frame count instead of a hard-coded 16 so the
        # concatenation below is valid for any sampling rate.
        full_frame = full_frame.unsqueeze(0).repeat(bbx.shape[0], 1, 1)
        return torch.cat([full_frame, bbx], dim=-2)

    def _load_frames(self, key, frame_ids, indices):
        """Read the sampled frames of ``key`` as a float32 (T, 11, 512) tensor."""
        clip = self.videos[key]
        feats = [
            torch.from_numpy(
                np.frombuffer(np.array(clip[frame_ids[i]]), dtype=np.float16)
            ).reshape(1, 11, 512)
            for i in indices
        ]
        return torch.concat(feats, dim=0).float()

    def __len__(self):
        """Return the number of entries in the loaded mapping."""
        return len(self.mapping)

    def open_video(self):
        """Open the HDF5 feature file read-only and store it on ``self.videos``."""
        self.videos = h5py.File(
            self.video_path,
            "r", libver="latest", swmr=True
        )

    #  person relation -> 0
    def __getitem__(self, idx: int):
        """Build one sample with frame-level localisation targets.

        Returns:
            Tuple ``(frames, bbx, mask, label, cls_ids, cls_cls, rel,
            private_label, common_ans, token_tensor, mask_, frame_ans,
            frame_flag)``. ``common_ans`` is the raw answer list wrapped in an
            extra leading dimension as a long tensor; ``frame_ans`` is a
            length-16 long vector set to 1 at the frame indices the IoU file
            lists for the common answers; ``frame_flag`` is a one-element long
            tensor that is 1 when at least one such frame exists, else 0;
            ``mask_`` is the logical negation of the raw mask and ``mask`` is
            the raw mask without its first column, duplicated along the last
            axis with a trailing singleton dimension; ``rel`` omits node 0.
        """
        if not hasattr(self, "videos"):
            self.open_video()
        entry = self.mapping[idx]
        key = entry['id']
        private_ans = entry['private']
        common_ans = entry['common']
        tokens = entry['token']

        frame_ids = self.json[key]
        indices = sample_appearance_indices(
            self.sample_rate, len(frame_ids), self.train
        )
        video_size = self.video2size[key.split('.')[0]]
        frames = self._load_frames(key, frame_ids, indices)

        bbx = torch.tensor([self.bbx[key][i] for i in indices], dtype=torch.float32)
        mask = torch.tensor([self.mask[key][i] for i in indices], dtype=torch.long)
        # for gpnn
        mask_ = ~mask.bool()
        mask = mask[:, 1:]
        mask = torch.cat([mask, mask], dim=-1).unsqueeze(-1)

        cls_ids = torch.tensor([self.cls[key][i] for i in indices], dtype=torch.long)
        rel_ids = [self.rel[key][i] for i in indices]

        label_idx = [int(x) for x in self.json[key + '_label']]
        label = torch.zeros(self.num_cls, dtype=torch.float32)
        label[label_idx] = 1.0

        # Collect every frame index the IoU file associates with a common answer.
        frame_dict = self.ioufile[key]
        frame_ids_local = []
        for ans in common_ans:
            hits = frame_dict.get(str(ans))
            if hits is not None:
                frame_ids_local.extend(hits)
        frame_ids_local = list(set(frame_ids_local))

        frame_flag = torch.zeros(1, dtype=torch.long)
        frame_flag[0] = 1 if frame_ids_local else 0
        # NOTE(review): length is fixed at 16 rather than self.sample_rate
        # because the index space of the IoU file is not visible here - confirm.
        frame_ans = torch.zeros(16, dtype=torch.long)
        frame_ans[frame_ids_local] = 1

        private_label = torch.zeros(self.num_cls + 1, dtype=torch.float32)
        private_label[private_ans] = 1.0
        token_tensor = torch.zeros(self.num_cls, dtype=torch.float32)
        token_tensor[tokens] = 1.0

        bbx = self._normalize_boxes(bbx, video_size)

        # One-hot encodings of object classes and relations per frame and node.
        rel = torch.zeros((self.sample_rate, self.node_nums, self.rel_num), dtype=torch.float32)
        cls_cls = torch.zeros((self.sample_rate, self.node_nums, self.obj_cls_num), dtype=torch.float32)
        cls_cls.scatter_(2, cls_ids.unsqueeze(-1), 1.)
        for t in range(self.sample_rate):
            for node in range(self.node_nums):
                rel[t][node][rel_ids[t][node]] = 1.

        return (frames, bbx, mask, label, cls_ids, cls_cls, rel[:, 1:, :],
                private_label, torch.tensor([common_ans], dtype=torch.long),
                token_tensor, mask_, frame_ans, frame_flag)
