# Root directory under which the preprocessed MNIST splits are written.
processed_data_dir = "./mnist/processed/"
# MNIST digit classes to extract.
digits_of_interest = [0, 1, 3, 6]

# Extract the selected MNIST digits and save each split (images, digit labels,
# placeholder encoder params) to disk in the format built below.
if __name__ == "__main__":
    import os

    import torch

    from data_utils import MNIST

    # The directory name embeds the list repr verbatim, e.g. "digits[0, 1, 3, 6]".
    save_dir = os.path.join(processed_data_dir, f"digits{digits_of_interest}")
    # exist_ok=True: re-running the script overwrites the saved splits instead
    # of aborting with FileExistsError.
    os.makedirs(save_dir, exist_ok=True)

    dataset = MNIST(
        root="./mnist",
        train_valid_split_ratio=[0.9, 0.1],
        digits_of_interest=digits_of_interest,
        n_test_samples=3000,
        n_train_samples=5000,
    )

    # One shuffled loader per split. batch_size equals the split length, so a
    # single batch holds the entire split.
    dataflow = dict()
    for split in dataset:
        sampler = torch.utils.data.RandomSampler(dataset[split])
        dataflow[split] = torch.utils.data.DataLoader(
            dataset[split],
            batch_size=len(dataset[split]),
            sampler=sampler,
            num_workers=0,
        )

    # Draw each split's single full batch, in the order train, valid, test,
    # and repackage it. Every split is held in memory at once (original
    # author's note: "only 25M, acceptable").
    data = {}
    for split in ("train", "valid", "test"):
        batch = next(iter(dataflow[split]))
        n_samples = len(batch["digit"])
        data[split] = {
            "images": batch["image"],
            "digits": batch["digit"],
            # No encoder parameters exist for raw MNIST; NaN marks each sample's
            # slot as "unset" while keeping the same record layout.
            "encoder_params": [torch.nan] * n_samples,
        }

    for split in data.keys():
        print(f"Saving {split}_set to disk")
        torch.save(data[split], os.path.join(save_dir, f"mnist_{split}.pt"))
