import os
import os.path as osp
import numpy as np
from rekognition_online_action_detection.utils.parser import load_cfg
# The 20 THUMOS action classes in label order. Entry i is the class stored in
# column i + 1 of the per-frame target arrays consumed by main().
all_class_name = (
    'BaseballPitch BasketballDunk Billiards CleanAndJerk CliffDiving '
    'CricketBowling CricketShot Diving FrisbeeCatch GolfSwing '
    'HammerThrow HighJump JavelinThrow LongJump PoleVault '
    'Shotput SoccerPenalty TennisSwing ThrowDiscus VolleyballSpiking'
).split()

def _save_sample(out_root, class_name, index, camera_inputs, motion_inputs, target):
    """Write one sample's rgb/flow/target arrays into ``out_root/class_name``.

    Files are named ``<class_name>_num_<index>_{rgb,flow,target}.npy``.
    """
    prefix = osp.join(out_root, class_name, class_name + "_num_" + str(index))
    np.save(prefix + "_rgb.npy", camera_inputs)
    np.save(prefix + "_flow.npy", motion_inputs)
    np.save(prefix + "_target.npy", target)


def _keep_single_action(target, cls, num_classes):
    """Return a copy of ``target`` in which ``cls`` is the only action label.

    Every other action column (1..num_classes) is zeroed. Frames left with no
    non-zero entry in any column are then marked in column 0. Columns outside
    0..num_classes are left untouched.
    """
    single = np.copy(target)
    others = [j for j in range(1, num_classes + 1) if j != cls]
    # Zero whole columns, not only entries equal to 1, so removal matches how
    # class presence is detected in main() (any non-zero column sum).
    single[:, others] = 0
    single[~single.any(axis=1), 0] = 1
    return single


def main(args, phase="train", data_root="data/THUMOS",
         new_data_root="data/THUMOS_only_replace_train"):
    """Split the sessions of ``phase`` into per-class feature/target files.

    For every session, the per-frame target and the visual/motion features are
    loaded from ``data_root``. Samples are written to
    ``<new_data_root>/<ClassName>/<ClassName>_num_<k>_{rgb,flow,target}.npy``,
    where ``k`` counts (from 1) the samples stored for that class:

    * a session whose target contains exactly one action class is saved once,
      with its target unchanged;
    * a session containing several action classes is saved once per class.
      All other action classes are removed from that copy of the target, and
      frames left unlabelled are set to column 0.

    Sessions with no action class are skipped.

    Args:
        args: config object (e.g. from ``load_cfg()``). Must provide
            ``DATA.<PHASE>_SESSION_SET``, ``INPUT.VISUAL_FEATURE`` and
            ``INPUT.MOTION_FEATURE``. Its ``data_root``, ``new_data_root``
            and ``training`` attributes are overwritten.
        phase: session split to process, e.g. ``"train"``.
        data_root: directory containing ``target_perframe/`` and the feature
            directories.
        new_data_root: output directory. It and the class subdirectories are
            created if missing.
    """
    args.new_data_root = new_data_root
    args.data_root = data_root
    num_classes = len(all_class_name)

    # makedirs also creates new_data_root itself and tolerates reruns; the
    # previous os.mkdir failed in both situations.
    for class_name in all_class_name:
        os.makedirs(osp.join(new_data_root, class_name), exist_ok=True)

    sessions = getattr(args.DATA, phase.upper() + '_SESSION_SET')
    args.training = phase == 'train'
    # Next 1-based file index for each class.
    next_index = [1] * num_classes

    for session in sessions:
        file_name = session + '.npy'
        target = np.load(osp.join(data_root, 'target_perframe', file_name))
        camera_inputs = np.load(
            osp.join(data_root, args.INPUT.VISUAL_FEATURE, file_name), mmap_mode='r')
        motion_inputs = np.load(
            osp.join(data_root, args.INPUT.MOTION_FEATURE, file_name), mmap_mode='r')

        # Action classes live in columns 1..num_classes. Column 0 and the
        # column after the last class (presumably "ambiguous" — confirm)
        # are excluded from presence detection.
        action_sums = target[:, 1:num_classes + 1].sum(axis=0)
        present = [int(c) + 1 for c in np.nonzero(action_sums)[0]]
        single_class = len(present) == 1

        for cls in present:
            if single_class:
                cls_target = target
            else:
                cls_target = _keep_single_action(target, cls, num_classes)
            _save_sample(new_data_root, all_class_name[cls - 1], next_index[cls - 1],
                         camera_inputs, motion_inputs, cls_target)
            next_index[cls - 1] += 1
    
    
if __name__ == '__main__':
    # Build the config from the command line, then run the per-class split.
    cfg = load_cfg()
    main(cfg)