import os

import torch

# Build a train/dev split of the game ids found under ./datasets/<game> and
# save it to data_objects/<game>_data_split.tar.
game = "BeamRider"

# Number of game ids held out for the dev set; everything else is training.
num_dev = 100

# Every entry in the dataset directory is expected to be named by an integer
# game id (int() raises ValueError on anything else).
# os.listdir returns entries in arbitrary order, so sort numerically to make
# the split reproducible across machines and runs.
gameids = sorted(int(name) for name in os.listdir(f"./datasets/{game}"))

# Validate before writing anything, so a bad split is never saved to disk.
if len(gameids) < num_dev:
    raise ValueError(
        f"need at least {num_dev} game ids to build the dev set, "
        f"found {len(gameids)}"
    )

# The last `num_dev` ids (in sorted order) form the dev set. With the length
# check above, the two slices partition `gameids` exactly.
train_gids = gameids[:-num_dev]
dev_gids = gameids[-num_dev:]

# torch.save does not create missing parent directories.
os.makedirs("data_objects", exist_ok=True)

# save
torch.save({
    'train_gids': train_gids,
    'dev_gids': dev_gids
}, f'data_objects/{game}_data_split.tar')