import argparse
import os

import torch


def data_gen(
    n_nodes: int,
    batch_size: int,
    flag: str,
    n_training_files: int = 10,
    root: str = ".",
) -> None:
    """Generate uniform random 2-D node coordinates and save them with ``torch.save``.

    Each saved tensor has shape ``(batch_size, n_nodes, 2)`` with entries drawn
    from ``torch.rand`` (uniform on ``[0, 1)``). Files are written into
    ``<root>/<flag>_data/``, which is created if it does not exist.

    For ``flag == "training"``, ``n_training_files`` files named
    ``training_data_<n_nodes>_<batch_size>_<i>`` (``i = 0 .. n_training_files - 1``)
    are written, each a fresh draw from the same seeded generator. For
    ``"validation"`` / ``"testing"`` a single file
    ``<flag>_data_<n_nodes>_<batch_size>`` is written.

    Args:
        n_nodes: Size of the second tensor dimension (number of nodes).
        batch_size: Size of the first tensor dimension.
        flag: One of ``"training"``, ``"validation"``, ``"testing"``; selects
            the seed, output folder and file naming.
        n_training_files: Number of training files to write (default 10).
            Ignored for non-training flags.
        root: Directory under which ``<flag>_data`` is created (default:
            the current working directory).

    Raises:
        KeyError: If ``flag`` is not one of the three supported values. This is
            raised before anything is written to disk.
    """
    # The seed values carry no particular meaning; they only need to be fixed
    # per split so each split is reproducible and distinct from the others.
    flag2seed = {
        "training": 0,
        "validation": 1,
        "testing": 2,
    }
    seed = flag2seed[flag]  # unknown flag -> KeyError before touching disk

    # A private generator keeps the caller's global torch RNG state intact.
    # For CPU tensors it yields the same sequence as torch.manual_seed(seed).
    generator = torch.Generator().manual_seed(seed)

    folder = os.path.join(root, f"{flag}_data")
    # exist_ok avoids the check-then-create race of isdir() followed by mkdir().
    os.makedirs(folder, exist_ok=True)

    def sample() -> torch.Tensor:
        """Draw one (batch_size, n_nodes, 2) tensor from the seeded generator."""
        return torch.rand(size=[batch_size, n_nodes, 2], generator=generator)

    if flag == "training":
        for i in range(n_training_files):
            torch.save(
                sample(),
                os.path.join(folder, f"training_data_{n_nodes}_{batch_size}_{i}"),
            )
    else:
        torch.save(
            sample(),
            os.path.join(folder, f"{flag}_data_{n_nodes}_{batch_size}"),
        )


if __name__ == "__main__":
    # Command-line entry point: parse options and generate the requested split(s).
    parser = argparse.ArgumentParser(
        description="Generate uniform random 2-D node coordinates and save them with torch.save."
    )
    parser.add_argument(
        "--n-nodes", type=int, default=400, help="number of nodes per instance"
    )
    parser.add_argument(
        "--batch-size", type=int, default=3, help="number of instances per saved tensor"
    )
    parser.add_argument(
        "--flag",
        type=str,
        choices=["training", "validation", "testing", "all"],
        required=True,
        help="which split to generate; 'all' generates training, validation and testing",
    )

    args = parser.parse_args()

    # 'all' expands to every split, in the same order as before.
    splits = ["training", "validation", "testing"] if args.flag == "all" else [args.flag]
    for split in splits:
        data_gen(args.n_nodes, args.batch_size, split)
