import os

import h5py  # type: ignore
import numpy as np  # type: ignore
import torchmeta  # type: ignore

# Dataset root passed to torchmeta for the download; the HDF5 file is read from
# and mu.txt / std.txt are written to ROOT/miniimagenet/.
# NOTE(review): hard-coded, machine-specific path — adjust for your environment.
ROOT = "/home/datasets/"


def _channel_mean_std(arrays):
    """Return ``(mean, std)`` over all axes except the last, pooled across arrays.

    The arrays are treated as if concatenated along axis 0, so they may differ
    in their leading dimensions but must share the size of the last axis. The
    standard deviation is the population one (``ddof=0``), matching ``np.std``.
    Two passes are made so that only one array at a time is held in float64,
    rather than a float64 copy of the entire dataset.

    Args:
        arrays: Iterable of array-likes (e.g. NumPy arrays or h5py datasets).
            It is materialized into a list because it is traversed twice.

    Returns:
        Tuple of two 1-D float64 arrays, each of length ``shape[-1]``.

    Raises:
        ValueError: If ``arrays`` is empty or contains no elements.
    """
    arrays = list(arrays)

    # Pass 1: per-channel sums and element counts -> mean.
    count = 0
    total = None
    for arr in arrays:
        data = np.asarray(arr, dtype=np.float64)
        flat = data.reshape(-1, data.shape[-1])
        partial = flat.sum(axis=0)
        total = partial if total is None else total + partial
        count += flat.shape[0]
    if count == 0:
        raise ValueError("no data to compute normalization statistics from")
    mean = total / count

    # Pass 2: sum of squared deviations from the mean. This is numerically
    # safer than E[x^2] - E[x]^2.
    sq_dev = np.zeros_like(mean)
    for arr in arrays:
        data = np.asarray(arr, dtype=np.float64)
        flat = data.reshape(-1, data.shape[-1])
        sq_dev += ((flat - mean) ** 2).sum(axis=0)
    std = np.sqrt(sq_dev / count)
    return mean, std


def main() -> None:
    """Download mini-ImageNet if needed and save normalization statistics.

    Instantiates torchmeta's ``MiniImagenet`` meta-train split with
    ``download=True`` so the data exists under ``ROOT``. It then reads every
    array in the ``"datasets"`` group of ``ROOT/miniimagenet/train_data.hdf5``
    and computes the mean and population standard deviation over all axes
    except the last. These are presumably per-channel statistics of
    channel-last images, as the original reduction axes implied (TODO
    confirm). Finally the two vectors are written with ``np.savetxt`` to
    ``mu.txt`` and ``std.txt`` in the same directory.
    """
    # Make sure the dataset is downloaded first.
    _ = torchmeta.datasets.miniimagenet.MiniImagenet(ROOT, download=True, meta_train=True, num_classes_per_task=5)

    # Find the right normalization values and save them.
    out_dir = os.path.join(ROOT, "miniimagenet")
    filepath = os.path.join(out_dir, "train_data.hdf5")
    # Open read-only explicitly: h5py < 3 defaulted to append mode, which
    # needs write access. The `with` releases the file handle once the data
    # has been read.
    with h5py.File(filepath, "r") as hdf:
        group = hdf["datasets"]
        mu, std = _channel_mean_std(group[name] for name in group)

    np.savetxt(os.path.join(out_dir, "mu.txt"), mu)
    np.savetxt(os.path.join(out_dir, "std.txt"), std)
    print("DONE")


# Run the download + statistics computation only when executed as a script,
# not when this module is imported.
if __name__ == "__main__":
    main()
