# Copyright (c) 2021 Copyright holder of the paper "Test-Time Adaptation to Distribution Shifts by Confidence Maximization and Input Transformation" submitted to NeurIPS 2021 for review
# All rights reserved.

"""Write a dataset to a h5 file."""
import argparse
import os

import numpy as np
import h5py
import tqdm
import yaml
import torch.utils.data

import sys
sys.path.append('..')
from data import get_datasets


def to_dataloader(dataset, batch_size=64, num_workers=12):
    """Wrap `dataset` in a non-shuffling `torch.utils.data.DataLoader`.

    Args:
        dataset: Dataset object accepted by `torch.utils.data.DataLoader`.
        batch_size: Number of samples per batch (default 64).
        num_workers: Number of loader worker processes (default 12).

    Returns:
        A `DataLoader` over `dataset` with shuffling disabled and
        `pin_memory` enabled.
    """
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        # Keep the dataset's own ordering so batches arrive sequentially.
        shuffle=False,
        pin_memory=True,
    )


def dataset_to_h5(dataset, outpath, img_shape=(224, 224, 3)):
    """Write a given dataset to a .h5 file under `outpath`.

    The file gets two HDF5 datasets: "imgs" (uint8, shape
    `(len(dataset), *img_shape)`) and "labels" (int, shape `(len(dataset),)`).
    Samples are read batch-wise through `to_dataloader` and written in the
    order the loader yields them.

    After a successful write, the file's permissions are set to 0o777 and
    `outpath` is appended as one line to the file "written_h5" in the current
    working directory.

    Args:
        dataset: Dataset yielding `(img, label)` pairs; must support `len()`.
            Presumably each image already has shape `img_shape` and fits in
            uint8 -- confirm against the dataset definition.
        outpath: Destination path of the .h5 file; an existing file is
            overwritten (opened with mode "w").
        img_shape: Per-sample image shape used to allocate the "imgs" dataset.
    """
    n_samples = len(dataset)

    # The context manager closes the HDF5 file even if loading or writing a
    # batch raises, so the handle is never leaked.
    with h5py.File(outpath, mode="w") as h5_file:
        img_store = h5_file.create_dataset(
            "imgs", (n_samples, *img_shape), np.uint8
        )
        label_store = h5_file.create_dataset("labels", (n_samples,), int)

        offset = 0  # index of the first row the next batch is written to
        for imgs, labels in tqdm.tqdm(to_dataloader(dataset)):
            bs = imgs.shape[0]  # the last batch may be smaller than the rest
            img_store[offset : offset + bs] = imgs
            label_store[offset : offset + bs] = labels
            offset += bs

    # NOTE(review): 0o777 makes the file world-writable; kept for backward
    # compatibility, but consider a tighter mode.
    os.chmod(outpath, 0o777)
    with open("written_h5", "a") as log_file:
        log_file.write(f"{outpath}\n")


if __name__ == "__main__":
    # Usage: <script> CONFIG OUTPATH
    #   CONFIG  -- YAML file passed to `get_datasets`
    #   OUTPATH -- destination of the .h5 file
    parser = argparse.ArgumentParser()
    parser.add_argument("config", type=argparse.FileType("r"))
    parser.add_argument("outpath")
    args = parser.parse_args()

    # argparse.FileType opened the config file; close it once it is parsed.
    with args.config as config_file:
        config = yaml.load(config_file, Loader=yaml.SafeLoader)

    datasets = get_datasets(config)
    # `get_datasets` appears to return a mapping; only its first value is
    # exported. Fail with a readable message instead of an IndexError when
    # the mapping is empty.
    if not datasets:
        parser.error("config does not define any dataset")
    dataset_to_h5(next(iter(datasets.values())), args.outpath)
