# Split the per-action video dataset into train/test folders by copying files.
import os
from shutil import copyfile
import random
# --- Split configuration ---------------------------------------------------
# Presumably the number of action categories; nothing in the splitting loop
# reads it (TODO confirm whether it is needed elsewhere).
num_action = 6
# Number of videos per action that are held out for the test split.
test_per_action = 30

# Input root: one subdirectory per action, each holding that action's videos.
base_dir = '../processed_64'
# Output root; 'test/<action>' and 'train/<action>' folders are created here.
split_dir = 'mug64_split'

# Fix the module-level RNG so the per-action test sampling is repeatable.
random.seed(0)


def split_dataset(src_dir, dst_dir, n_test):
    """Copy each action's videos from *src_dir* into test/train splits under *dst_dir*.

    Every subdirectory ``src_dir/<action>`` is treated as one action.
    ``n_test`` of its entries are drawn with the module-level ``random``
    generator and copied to ``dst_dir/test/<action>``. The remaining entries
    are copied to ``dst_dir/train/<action>``. Non-directory entries directly
    under *src_dir* are skipped.

    Args:
        src_dir: Root directory with one subdirectory per action.
        dst_dir: Output root; ``test``/``train`` subfolders are created as needed.
        n_test: Number of entries per action to hold out for the test split.

    Raises:
        ValueError: If an action has fewer than ``n_test`` entries.
    """
    for split in ('test', 'train'):
        # exist_ok also repairs an output root left half-created by an earlier
        # run (the old mkdir logic skipped test/train if dst_dir existed).
        os.makedirs(os.path.join(dst_dir, split), exist_ok=True)

    # Sorted so the RNG is consumed in the same action order on every machine.
    for action in sorted(os.listdir(src_dir)):
        action_dir = os.path.join(src_dir, action)
        if not os.path.isdir(action_dir):
            continue

        # os.listdir order is filesystem-dependent; sort so the seeded sample
        # picks the same videos everywhere.
        videos = sorted(os.listdir(action_dir))
        print('Action {} contains {} videos'.format(action, len(videos)))
        if len(videos) < n_test:
            raise ValueError(
                'Action {} has only {} videos; cannot hold out {} for test'.format(
                    action, len(videos), n_test))

        test = random.sample(videos, n_test)
        test_set = set(test)  # O(1) membership instead of scanning the list
        train = [v for v in videos if v not in test_set]

        for split, names in (('test', test), ('train', train)):
            out_dir = os.path.join(dst_dir, split, action)
            os.makedirs(out_dir, exist_ok=True)
            for v in names:
                copyfile(os.path.join(action_dir, v), os.path.join(out_dir, v))


if __name__ == '__main__':
    split_dataset(base_dir, split_dir, test_per_action)
