import json
import yaml
from shutil import copyfile
from pathlib import Path
from argparse import ArgumentParser
import numpy as np

def _sample_jsons(files, n, skip_domains=("kitti",)):
    """Randomly subsample every JSON file list referenced by ``files``.

    Args:
        files: the ``opts["data"]["files"]`` mapping. Its ``"base"`` entry is a
            directory; every other entry maps a split key (e.g. train/val) to a
            ``{domain: json_file_name}`` dict, resolved relative to ``base``.
        n: maximum number of samples to keep per (split, domain).
        skip_domains: domains to ignore entirely.

    Returns:
        ``{split: {domain: [sample, ...]}}`` where each sample is an element of
        the loaded JSON list, chosen via ``np.random.permutation`` (the global
        NumPy RNG, so results follow ``np.random.seed``).
    """
    minijsons = {}
    for k, v in files.items():
        if k == "base":
            continue
        minijsons[k] = {}
        for domain, name in v.items():
            if domain in skip_domains:
                continue
            with open(Path(files["base"]) / name, "r", encoding="utf-8") as f:
                data = json.load(f)
            # One permutation call per JSON, in iteration order, so a given
            # seed always selects the same samples.
            perm = np.random.permutation(len(data))[:n]
            minijsons[k][domain] = [data[i] for i in perm]
    return minijsons


def _copy_samples(minijsons, write_dir):
    """Copy every file referenced by the sampled entries into ``write_dir``.

    Each sample is a ``{task: filepath}`` dict. A file is copied to
    ``write_dir/data/<split>/<domain>/<task>/<basename>``, and its path in
    the returned structure is rewritten to the matching relative path
    ``data/<split>/<domain>/<task>/<basename>``.

    Returns:
        ``{split: {domain: [rewritten_sample, ...]}}``.
    """
    iclrjsons = {}
    for k, v in minijsons.items():
        iclrjsons[k] = {}
        for domain, data in v.items():
            iclrjsons[k][domain] = []
            for s, sample in enumerate(data):
                iclrsample = {}
                for task, filepath in sample.items():
                    fname = Path(filepath).name
                    iclrsample[task] = f"data/{k}/{domain}/{task}/{fname}"
                    dirpath = write_dir / "data" / k / domain / task
                    dirpath.mkdir(exist_ok=True, parents=True)
                    copyfile(filepath, dirpath / fname)
                    print(f"{k}, {domain}, {s}, {task}", end="\r")
                iclrjsons[k][domain].append(iclrsample)
    return iclrjsons


def _write_jsons(iclrjsons, write_dir):
    """Dump each (split, domain) sample list to ``write_dir/jsons/<split>_<domain>.json``."""
    jsons_dir = write_dir / "jsons"
    jsons_dir.mkdir(exist_ok=True, parents=True)
    for k, v in iclrjsons.items():
        for domain, samples in v.items():
            with open(jsons_dir / f"{k}_{domain}.json", "w", encoding="utf-8") as f:
                json.dump(samples, f)


def main(argv=None):
    """Build a mini-dataset from the data files listed in an opts YAML file.

    Args:
        argv: argument list to parse; ``None`` means ``sys.argv[1:]``.
    """
    parser = ArgumentParser()
    parser.add_argument(
        "-o", "--opts", type=str, required=True, help="path to an opts yaml file"
    )
    parser.add_argument(
        "-w",
        "--write_dir",
        type=str,
        required=True,
        help="Where to write mini-dataset",
    )
    parser.add_argument(
        "-n", "--num_samples", type=int, default=4, help="Number of samples"
    )
    parser.add_argument(
        "-s", "--seed", type=int, default=1234, help="Numpy seed for the permutation"
    )
    args = parser.parse_args(argv)

    # Validate explicitly: `assert` is stripped under `python -O`.
    if args.num_samples <= 0:
        parser.error("--num_samples must be a positive integer")

    opts_path = Path(args.opts)
    write_dir = Path(args.write_dir)

    np.random.seed(args.seed)
    write_dir.mkdir(exist_ok=True, parents=True)

    with opts_path.open("r", encoding="utf-8") as f:
        opts = yaml.safe_load(f)
    files = opts["data"]["files"]

    minijsons = _sample_jsons(files, args.num_samples)
    iclrjsons = _copy_samples(minijsons, write_dir)
    # Terminate the "\r"-overwritten progress line so the prompt doesn't clobber it.
    print()
    _write_jsons(iclrjsons, write_dir)


if __name__ == "__main__":
    main()
