
import numpy as np
import gzip, pickle
import os, sys

def filter_dataset_by_labels(dataset, labels):
    """Return a copy of *dataset* restricted to samples with the given labels.

    Parameters
    ----------
    dataset : dict
        Mapping ``split -> {key -> array}``. Every non-empty split must
        contain a ``'labels'`` entry; all entries of a split are indexed with
        the same boolean mask, so they are assumed to be NumPy arrays whose
        first axis has the same length as ``'labels'``.
    labels : sequence
        Label values to keep. Any number of labels is accepted.

    Returns
    -------
    dict
        New nested dict with the same splits and keys, each array reduced to
        the rows whose label is in *labels*. The input is not modified.

    Raises
    ------
    KeyError
        If a non-empty split has no ``'labels'`` entry.
    """
    filtered = {}
    for split, arrays in dataset.items():
        filtered[split] = {}
        if not arrays:
            # Nothing to filter (and no 'labels' entry to build a mask from).
            continue
        # The mask depends only on the split, so build it once rather than
        # once per key.
        mask = np.isin(arrays['labels'], labels)
        for key, values in arrays.items():
            filtered[split][key] = values[mask]
    return filtered


if __name__ == '__main__':

    DATA_DIR = ''
    DATA_FILE = 'real_sph_mnist-no_rotate_train-cz=10-b=30-lmax=10-normalize=avg_sqrt_power-quad_weights=True'
    LABELS = [1, 7]  # labels to keep; any number of labels is supported

    # NOTE: pickle.load executes arbitrary code from the file; only run this
    # on dataset files from a trusted source.
    with gzip.open(os.path.join(DATA_DIR, DATA_FILE + '.gz'), 'rb') as f:
        dataset = pickle.load(f)

    new_dataset = filter_dataset_by_labels(dataset, LABELS)

    # Output name is the input name plus a '-labels=<l0>,<l1>,...' suffix.
    out_name = DATA_FILE + '-labels={}'.format(','.join(map(str, LABELS))) + '.gz'
    with gzip.open(os.path.join(DATA_DIR, out_name), 'wb') as f:
        pickle.dump(new_dataset, f)
    
