import json
from argparse import ArgumentParser
from pathlib import Path
from collections import defaultdict

import torch

from experiments.utils import common_parser, set_logger


def generate_splits(
    data_path,
    save_path,
    name="fmnist_splits.json",
):
    """Resolve INR checkpoint paths per split and record each checkpoint's label.

    Reads the split index at ``data_path``, a JSON object whose ``"train"``,
    ``"val"`` and ``"test"`` entries are lists of checkpoint paths. Each entry
    is re-rooted next to ``data_path`` using only its last two path components.
    Each checkpoint is loaded with ``torch.load`` to read its ``"label"`` entry.
    The result is written to ``<save_path>/<name>`` as::

        {"train": {"path": [...], "label": [...]}, "val": {...}, "test": {...}}

    Args:
        data_path: Path to the JSON split index.
        save_path: Directory in which to write the output JSON. It is created
            (including parents) if it does not already exist.
        name: File name of the output JSON inside ``save_path``.

    Raises:
        KeyError: If the split index lacks one of the splits, or a loaded
            checkpoint has no ``"label"`` key.
    """
    data_path = Path(data_path)
    # The index is identical for every split, so read and parse it only once.
    with open(data_path, "r") as f:
        data = json.load(f)

    data_split = {}
    for split in ("train", "val", "test"):
        print(f"Processing {split} split")
        # Keep only "<subdir>/<file>" from each stored path and re-root it beside
        # the index file (presumably because the index was written with paths
        # from a different base directory -- confirm with the producer).
        paths = [
            (data_path.parent / Path(*Path(di).parts[-2:])).as_posix()
            for di in data[split]
        ]
        labels = [
            _to_json_label(torch.load(path, map_location="cpu")["label"])
            for path in paths
        ]
        data_split[split] = {"path": paths, "label": labels}
        print(f"Finished processing {split} split")

    save_dir = Path(save_path)
    # Create the output directory so a fresh checkout does not fail on open().
    save_dir.mkdir(parents=True, exist_ok=True)
    with open(save_dir / name, "w") as file:
        json.dump(data_split, file)


def _to_json_label(label):
    """Convert a tensor label to plain Python numbers so json.dump accepts it."""
    if isinstance(label, torch.Tensor):
        # tolist() yields a Python scalar for 0-d tensors and a list otherwise.
        return label.tolist()
    return label


if __name__ == "__main__":
    # ArgumentParser's first positional argument is `prog` (the program name
    # shown in usage/error lines), so the title belongs in `description`.
    parser = ArgumentParser(
        description="INR Classification - Fashion MNIST - preprocess data",
        parents=[common_parser],
    )
    parser.add_argument(
        "--name", type=str, default="fmnist_splits.json", help="json file name"
    )
    # Override defaults for options presumably declared by `common_parser`
    # (save_path / data_path) -- confirm against experiments.utils.
    parser.set_defaults(
        save_path="dataset",
        data_path="dataset/fmnist_inrs/splits.json",
    )
    args = parser.parse_args()

    set_logger()

    generate_splits(
        args.data_path,
        args.save_path,
        name=args.name,
    )
