import sys
sys.path.append(".")

from dataset import create_dataloader, create_ssl_loader
from tqdm import tqdm
import h5py
import numpy as np

# Benchmark to inspect. Swap the two lines below to look at the other one.
# dataset = "SumMe"
dataset = "TVSum"

# Names of the per-video h5 datasets read below, in the order they are unpacked.
h5_fields = ("user_summary", "change_points", "n_frames", "picks")

# Only the first training sample of the first split is examined: both loops
# stop after a single iteration.
for split_id, (train_loader, test_loader) in enumerate(create_dataloader(dataset)):
    for feature, gtscore, dataset_name, video_num in train_loader:
        # Notes carried over from the original author (not verifiable from
        # this script alone):
        #   video_num: key of the video inside the h5 file, e.g. video_10
        #   feature:   (1,3,T,D) GoogleNet features; T is the number of frames
        #              after sampling, D is the output dimension (1024 by default)
        #   gtscore:   (T) per-frame importance score in the range 0..1;
        #              gtscore.sum() may be larger than 1
        print(dataset)
        print("feature", feature.shape)

        h5_path = f'./data/eccv16_dataset_{dataset_name.lower()}_google_pool5.h5'
        with h5py.File(h5_path, 'r') as hdf:
            # Notes carried over from the original author:
            #   user_summary:   (num_user, n_frames)
            #   all_shot_bound: (2,-1) shot segmentation, [[start_id,end_id],...]
            #                   NOTE(review): the example layout reads like
            #                   (num_shots, 2) -- confirm against the printed shape
            #   all_positions:  (T) position of each sampled frame within the
            #                   full n_frames sequence
            video_group = hdf[video_num]
            # Load everything into memory before printing, while the file is open.
            user_summary, all_shot_bound, n_frames, all_positions = (
                np.array(video_group[field]) for field in h5_fields
            )

            print(
                "user_summary",
                user_summary.shape,
                user_summary.max(),
                user_summary.mean(),
                user_summary.sum(),
            )
            print("all_shot_bound", all_shot_bound.shape, all_shot_bound.max())
            print("n_frames", n_frames)
            print("all_positions", all_positions.shape)
        break
    break
