from pkgutil import get_data
import pickle
import argparse
import torch
import path
import sys
# Repository root = the directory two levels above this script. Appending it
# to sys.path lets the project-local `data` package (imported below) resolve
# when this file is executed directly as a script.
folder_path = path.Path(__file__).abspath().parent.parent
sys.path.append(folder_path)

from data.pytorch_datasets import get_dataset

# Number of training examples held out for validation, per supported dataset.
VAL_SIZES = {
    'svhn': 20000,
    'fmnist': 10000,
}

parser = argparse.ArgumentParser()
parser.add_argument('-d', '--data', default='svhn', type=str,
                    help='dataset name passed to get_dataset')
parser.add_argument('--seed', default=None, type=int,
                    help='seed for a reproducible train/val split; '
                         'omit to use torch\'s global default generator')

args = parser.parse_args()
print(args)

# Validate the dataset name before loading anything, so an unsupported
# choice fails fast instead of after a (possibly slow) dataset load.
if args.data not in VAL_SIZES:
    raise NotImplementedError(
        f"no validation size defined for dataset {args.data!r}; "
        f"supported: {sorted(VAL_SIZES)}"
    )
val_size = VAL_SIZES[args.data]

# The first element returned by get_dataset is the set that gets split.
trainset = get_dataset(args)[0]
train_size = len(trainset) - val_size

# random_split only checks that the lengths sum to len(trainset); a negative
# train_size would pass that check and yield a silently wrong split.
if train_size <= 0:
    raise ValueError(
        f"dataset {args.data!r} has {len(trainset)} examples, too few to "
        f"hold out {val_size} for validation"
    )

# With --seed the split is reproducible; without it, torch's global default
# generator is used, which matches random_split's own default.
generator = (torch.Generator().manual_seed(args.seed)
             if args.seed is not None else torch.default_generator)
train_ds, val_ds = torch.utils.data.random_split(
    trainset, [train_size, val_size], generator=generator)

# Each Subset keeps a reference to the full `trainset`, so the whole dataset
# object is pickled together with the split indices.
d = {"train_ds": train_ds, "val_ds": val_ds}

pkl_path = f'dataset_{args.data}_split.pkl'
with open(pkl_path, 'wb') as f:
    pickle.dump(d, f)