""" Creates the parameters for the actual training process.  TODO (anonymous): link up to master parameter files """
import json
import os


def get_last_existing_dir(path):
    """Return the highest integer-named entry found directly under ``path``.

    Entries whose names cannot be parsed by ``int`` (stray files, logs, ...)
    are ignored instead of aborting the scan.

    Args:
        path: Directory whose entries are named after run indices.

    Returns:
        The largest index found, or ``-1`` when there is no integer-named
        entry, so that a caller doing ``count += 1`` starts numbering at 0.

    Raises:
        OSError: Propagated from ``os.listdir`` (e.g. ``path`` is missing).
    """
    indices = []
    for name in os.listdir(path):
        try:
            indices.append(int(name))
        except ValueError:
            # Not a run directory; skip it.
            continue
    return max(indices, default=-1)


def preview(arg_file, max_lines=25):
    """Print the first lines of ``arg_file`` between BEGIN/END markers.

    Files shorter than ``max_lines`` are printed in full; the previous
    ``next(fp)`` loop raised ``StopIteration`` on them.

    Args:
        arg_file: Path of the text file to preview.
        max_lines: Maximum number of lines to print (default 25).
    """
    print("BEGIN File Preview")
    with open(arg_file, "r") as fp:
        # zip stops at whichever runs out first: the line budget or the file.
        for _, line in zip(range(max_lines), fp):
            print(line.rstrip())
    print("END File Preview")
    print("Completed Succesfully")


def dump(fn, loc):
    """Write the configuration produced by ``fn`` as JSON under ``CONFIG_DIR``.

    Args:
        fn: Zero-argument callable returning a JSON-serialisable object.
        loc: File name (or relative path) joined onto ``CONFIG_DIR``.
    """
    target_path = os.path.join(CONFIG_DIR, loc)
    payload = fn()
    # Sorted keys and a 4-space indent keep the generated file stable and readable.
    with open(target_path, "w") as out_file:
        json.dump(payload, out_file, sort_keys=True, indent=4)


def build_v1(overwrite=True):
    """Build the ``v1`` configuration: one ``ts_net`` run on IMDB-BINARY.

    Args:
        overwrite: If true, run directories are numbered from 0; otherwise
            numbering continues after the index returned by
            ``get_last_existing_dir(model_dir)``.

    Returns:
        Dict with ``"meta_data"`` (model directory, ``num_runs`` = one past
        the last run index, search type) and ``"runs"`` (one dict per run).
    """
    model_dir = "/home/anonymous/trainable_scattering/models/v1/"
    first_index = 0 if overwrite else get_last_existing_dir(model_dir) + 1
    # 'COLLAB' and 'REDDIT-MULTI-5K' are a disabled alternative dataset list.
    dataset_names = ["IMDB-BINARY"]
    depth = 1
    runs = [
        {
            "model_dir": os.path.join(model_dir, str(index)),
            "dataset": name,
            "transform": "scatter",
            "model": "ts_net",
            "model_args": {"num_layers": depth, "epsilon": 1e-16},
        }
        for index, name in enumerate(dataset_names, start=first_index)
    ]
    meta_data = {
        "model_dir": model_dir,
        "num_runs": first_index + len(runs),
        "search_type": "grid",
    }
    return {"meta_data": meta_data, "runs": runs}


def build_laziness_v2(overwrite=True):
    """Build the ``v2`` grid: ``ts_net`` with and without trainable laziness.

    The grid is 1 seed x 4 datasets x {False, True} for
    ``trainable_laziness``, iterated in that nesting order.

    Args:
        overwrite: If true, run directories are numbered from 0; otherwise
            numbering continues after the index returned by
            ``get_last_existing_dir(model_dir)``.

    Returns:
        Dict with ``"meta_data"`` (model directory, ``num_runs`` = one past
        the last run index, search type) and ``"runs"`` (one dict per run).
    """
    model_dir = "/home/anonymous/trainable_scattering/models/v2/"
    first_index = 0 if overwrite else get_last_existing_dir(model_dir) + 1
    seed_count = 1
    depth = 1
    epochs = 1000
    grid = [
        (seed, name, lazy)
        for seed in range(seed_count)
        for name in ["IMDB-BINARY", "IMDB-MULTI", "COLLAB", "REDDIT-MULTI-5K"]
        for lazy in [False, True]
    ]
    runs = [
        {
            "model_dir": os.path.join(model_dir, str(index)),
            "dataset": name,
            "transform": "scatter",
            "model": "ts_net",
            "model_args": {
                "trainable_laziness": lazy,
                "num_layers": depth,
                "epsilon": 1e-16,
            },
            "seed": seed,
            "num_epochs": epochs,
        }
        for index, (seed, name, lazy) in enumerate(grid, start=first_index)
    ]
    meta_data = {
        "model_dir": model_dir,
        "num_runs": first_index + len(runs),
        "search_type": "grid",
    }
    return {"meta_data": meta_data, "runs": runs}


def test_train_split(overwrite=True):
    """Build the ``tt_split_v1`` grid of ``fast_ts_net`` runs.

    The grid is 10 seeds x 4 datasets x 4 three-way split fractions,
    iterated in that nesting order, with early stopping enabled.

    Args:
        overwrite: If true, run directories are numbered from 0; otherwise
            numbering continues after the index returned by
            ``get_last_existing_dir(model_dir)``.

    Returns:
        Dict with ``"meta_data"`` (model directory, ``num_runs`` = one past
        the last run index, search type) and ``"runs"`` (one dict per run).
    """
    model_dir = "/home/anonymous/trainable_scattering/models/tt_split_v1/"
    first_index = 0 if overwrite else get_last_existing_dir(model_dir) + 1
    seed_count = 10
    depth = 1
    epochs = 1000
    grid = [
        (seed, name, fractions)
        for seed in range(seed_count)
        for name in ["IMDB-BINARY", "IMDB-MULTI", "COLLAB", "REDDIT-MULTI-5K"]
        for fractions in [[0.8, 0.1, 0.1], [0.7, 0.1, 0.2], [0.4, 0.1, 0.5], [0.2, 0.1, 0.7]]
    ]
    runs = [
        {
            "model_dir": os.path.join(model_dir, str(index)),
            "dataset": name,
            "transform": "fast_scatter",
            "model": "fast_ts_net",
            "model_args": {
                "trainable_laziness": False,
                "num_layers": depth,
                "epsilon": 1e-16,
            },
            "seed": seed,
            "num_epochs": epochs,
            "splits": fractions,
            "early_stopping": True,
        }
        for index, (seed, name, fractions) in enumerate(grid, start=first_index)
    ]
    meta_data = {
        "model_dir": model_dir,
        "num_runs": first_index + len(runs),
        "search_type": "grid",
    }
    return {"meta_data": meta_data, "runs": runs}

def test_train_split_v2(overwrite=True):
    """Build the ``tt_split_v2`` grid of ``fast_ts_net`` runs.

    The grid is 10 validation seeds x 4 datasets x 4 two-element split
    fractions, iterated in that nesting order. The test seed is fixed to 0.

    Args:
        overwrite: If true, run directories are numbered from 0; otherwise
            numbering continues after the index returned by
            ``get_last_existing_dir(model_dir)``.

    Returns:
        Dict with ``"meta_data"`` (model directory, ``num_runs`` = one past
        the last run index, search type) and ``"runs"`` (one dict per run).
    """
    model_dir = "/home/anonymous/trainable_scattering/models/tt_split_v2/"
    first_index = 0 if overwrite else get_last_existing_dir(model_dir) + 1
    seed_count = 10
    depth = 1
    epochs = 1000
    grid = [
        (seed, name, fractions)
        for seed in range(seed_count)
        for name in ["IMDB-BINARY", "IMDB-MULTI", "COLLAB", "REDDIT-MULTI-5K"]
        for fractions in [[0.8, 0.1], [0.7, 0.1], [0.4, 0.1], [0.2, 0.1]]
    ]
    runs = [
        {
            "model_dir": os.path.join(model_dir, str(index)),
            "dataset": name,
            "transform": "fast_scatter",
            "model": "fast_ts_net",
            "model_args": {
                "trainable_laziness": False,
                "num_layers": depth,
                "epsilon": 1e-16,
            },
            "val_seed": seed,
            "test_seed": 0,
            "num_epochs": epochs,
            "splits": fractions,
            "early_stopping": True,
        }
        for index, (seed, name, fractions) in enumerate(grid, start=first_index)
    ]
    meta_data = {
        "model_dir": model_dir,
        "num_runs": first_index + len(runs),
        "search_type": "grid",
    }
    return {"meta_data": meta_data, "runs": runs}


def test_train_split_v3(overwrite=True):
    """Build the ``tt_split_v3`` grid of ``fast_ts_net`` runs.

    The grid is 5 validation seeds x 4 datasets x 4 two-element split
    fractions, iterated in that nesting order. The test seed is fixed to 0.

    Args:
        overwrite: If true, run directories are numbered from 0; otherwise
            numbering continues after the index returned by
            ``get_last_existing_dir(model_dir)``.

    Returns:
        Dict with ``"meta_data"`` (model directory, ``num_runs`` = one past
        the last run index, search type) and ``"runs"`` (one dict per run).
    """
    model_dir = "/home/anonymous/trainable_scattering/models/tt_split_v3/"
    first_index = 0 if overwrite else get_last_existing_dir(model_dir) + 1
    seed_count = 5
    depth = 1
    epochs = 1000
    grid = [
        (seed, name, fractions)
        for seed in range(seed_count)
        for name in ["IMDB-BINARY", "IMDB-MULTI", "COLLAB", "REDDIT-MULTI-5K"]
        for fractions in [[0.7, 0.2], [0.6, 0.2], [0.3, 0.2], [0.1, 0.2]]
    ]
    runs = [
        {
            "model_dir": os.path.join(model_dir, str(index)),
            "dataset": name,
            "transform": "fast_scatter",
            "model": "fast_ts_net",
            "model_args": {
                "trainable_laziness": False,
                "num_layers": depth,
                "epsilon": 1e-16,
            },
            "val_seed": seed,
            "test_seed": 0,
            "num_epochs": epochs,
            "splits": fractions,
            "early_stopping": True,
        }
        for index, (seed, name, fractions) in enumerate(grid, start=first_index)
    ]
    meta_data = {
        "model_dir": model_dir,
        "num_runs": first_index + len(runs),
        "search_type": "grid",
    }
    return {"meta_data": meta_data, "runs": runs}


def test_train_split_v4(overwrite=True):
    """Build the ``tt_split_v4`` grid of ``fast_rbf_net`` runs.

    The grid is 5 validation seeds x 4 datasets x 4 two-element split
    fractions, iterated in that nesting order. The test seed is fixed to 0.
    Original author's note: generated with old_split_dataset.

    Args:
        overwrite: If true, run directories are numbered from 0; otherwise
            numbering continues after the index returned by
            ``get_last_existing_dir(model_dir)``.

    Returns:
        Dict with ``"meta_data"`` (model directory, ``num_runs`` = one past
        the last run index, search type) and ``"runs"`` (one dict per run).
    """
    model_dir = "/home/anonymous/trainable_scattering/models/tt_split_v4/"
    first_index = 0 if overwrite else get_last_existing_dir(model_dir) + 1
    seed_count = 5
    depth = 1
    epochs = 1000
    grid = [
        (seed, name, fractions)
        for seed in range(seed_count)
        for name in ["IMDB-BINARY", "IMDB-MULTI", "COLLAB", "REDDIT-MULTI-5K"]
        for fractions in [[0.7, 0.2], [0.6, 0.2], [0.3, 0.2], [0.1, 0.2]]
    ]
    runs = [
        {
            "model_dir": os.path.join(model_dir, str(index)),
            "dataset": name,
            "transform": "fast_scatter",
            "model": "fast_rbf_net",
            "model_args": {
                "trainable_laziness": False,
                "num_layers": depth,
                "epsilon": 1e-16,
            },
            "val_seed": seed,
            "test_seed": 0,
            "num_epochs": epochs,
            "splits": fractions,
            "early_stopping": True,
        }
        for index, (seed, name, fractions) in enumerate(grid, start=first_index)
    ]
    meta_data = {
        "model_dir": model_dir,
        "num_runs": first_index + len(runs),
        "search_type": "grid",
    }
    return {"meta_data": meta_data, "runs": runs}


def test_train_split_v5(overwrite=True):
    """Build the ``tt_split_v5`` grid of ``fast_rbf_net`` runs on 4 datasets.

    Original author's note: match Feng's setting as close as possible.
    For every dataset, each split is paired with its own number of
    validation and test seeds; test seeds form the outer loop and
    validation seeds the inner loop.

    Args:
        overwrite: If true, run directories are numbered from 0; otherwise
            numbering continues after the index returned by
            ``get_last_existing_dir(model_dir)``.

    Returns:
        Dict with ``"meta_data"`` (model directory, ``num_runs`` = one past
        the last run index, search type) and ``"runs"`` (one dict per run).
    """
    model_dir = "/home/anonymous/trainable_scattering/models/tt_split_v5/"
    first_index = 0 if overwrite else get_last_existing_dir(model_dir) + 1
    # (splits, number of validation seeds, number of test seeds)
    split_plan = [
        ([8, 1, 1], 10, 10),
        ([7, 1, 2], 8, 5),
        ([4, 1, 5], 5, 2),
        ([2, 1, 7], 3, 10),
    ]
    depth = 1
    epochs = 1000
    grid = [
        (name, split, test_seed, val_seed)
        for name in ["IMDB-BINARY", "IMDB-MULTI", "COLLAB", "REDDIT-MULTI-5K"]
        for split, n_val, n_test in split_plan
        for test_seed in range(n_test)
        for val_seed in range(n_val)
    ]
    runs = [
        {
            "model_dir": os.path.join(model_dir, str(index)),
            "dataset": name,
            "transform": "fast_scatter",
            "model": "fast_rbf_net",
            "model_args": {
                "trainable_laziness": False,
                "num_layers": depth,
                "epsilon": 1e-16,
            },
            "test_seed": test_seed,
            "val_seed": val_seed,
            "num_epochs": epochs,
            "splits": split,
            "early_stopping": True,
        }
        for index, (name, split, test_seed, val_seed) in enumerate(grid, start=first_index)
    ]
    meta_data = {
        "model_dir": model_dir,
        "num_runs": first_index + len(runs),
        "search_type": "grid",
    }
    return {"meta_data": meta_data, "runs": runs}

def test_train_split_v6(overwrite=True):
    """Build the ``tt_split_v6`` grid of ``fast_rbf_net`` runs on 13 datasets.

    Original author's note: match Feng's setting as close as possible.
    For every dataset, each split is paired with its own number of
    validation and test seeds; test seeds form the outer loop and
    validation seeds the inner loop.

    Args:
        overwrite: If true, run directories are numbered from 0; otherwise
            numbering continues after the index returned by
            ``get_last_existing_dir(model_dir)``.

    Returns:
        Dict with ``"meta_data"`` (model directory, ``num_runs`` = one past
        the last run index, search type) and ``"runs"`` (one dict per run).
    """
    model_dir = "/home/anonymous/trainable_scattering/models/tt_split_v6/"
    first_index = 0 if overwrite else get_last_existing_dir(model_dir) + 1
    # (splits, number of validation seeds, number of test seeds)
    split_plan = [
        ([8, 1, 1], 10, 10),
        ([7, 1, 2], 8, 5),
        ([4, 1, 5], 5, 2),
        ([2, 1, 7], 3, 10),
    ]
    depth = 1
    epochs = 1000
    # "PROTEINS" was previously misspelled "PROTIENS", unlike every other
    # builder in this file.
    datasets = [
        "NCI1", "NCI109", "DD", "PROTEINS", "MUTAG", "PTC", "ENZYMES",
        "REDDIT-BINARY", "REDDIT-MULTI-12K", "IMDB-BINARY", "IMDB-MULTI",
        "COLLAB", "REDDIT-MULTI-5K",
    ]
    grid = [
        (name, split, test_seed, val_seed)
        for name in datasets
        for split, n_val, n_test in split_plan
        for test_seed in range(n_test)
        for val_seed in range(n_val)
    ]
    runs = [
        {
            "model_dir": os.path.join(model_dir, str(index)),
            "dataset": name,
            "transform": "fast_scatter",
            "model": "fast_rbf_net",
            "model_args": {
                "trainable_laziness": False,
                "num_layers": depth,
                "epsilon": 1e-16,
            },
            "test_seed": test_seed,
            "val_seed": val_seed,
            "num_epochs": epochs,
            "splits": split,
            "early_stopping": True,
        }
        for index, (name, split, test_seed, val_seed) in enumerate(grid, start=first_index)
    ]
    meta_data = {
        "model_dir": model_dir,
        "num_runs": first_index + len(runs),
        "search_type": "grid",
    }
    return {"meta_data": meta_data, "runs": runs}


def test_train_split_vtest(overwrite=True):
    """Build the small ``tt_split_vtest`` grid: ``fast_rbf_net`` on IMDB-BINARY.

    Original author's note: match Feng's setting as close as possible.
    Uses the single [8, 1, 1] split with 10 test seeds and 1 validation seed.

    Args:
        overwrite: If true, run directories are numbered from 0; otherwise
            numbering continues after the index returned by
            ``get_last_existing_dir(model_dir)``.

    Returns:
        Dict with ``"meta_data"`` (model directory, ``num_runs`` = one past
        the last run index, search type) and ``"runs"`` (one dict per run).
    """
    model_dir = "/home/anonymous/trainable_scattering/models/tt_split_vtest/"
    first_index = 0 if overwrite else get_last_existing_dir(model_dir) + 1
    # (splits, number of validation seeds, number of test seeds)
    split_plan = [([8, 1, 1], 1, 10)]
    depth = 1
    epochs = 1000
    datasets = ["IMDB-BINARY"]
    grid = [
        (name, split, test_seed, val_seed)
        for name in datasets
        for split, n_val, n_test in split_plan
        for test_seed in range(n_test)
        for val_seed in range(n_val)
    ]
    runs = [
        {
            "model_dir": os.path.join(model_dir, str(index)),
            "dataset": name,
            "transform": "fast_scatter",
            "model": "fast_rbf_net",
            "model_args": {
                "trainable_laziness": False,
                "num_layers": depth,
                "epsilon": 1e-16,
            },
            "test_seed": test_seed,
            "val_seed": val_seed,
            "num_epochs": epochs,
            "splits": split,
            "early_stopping": True,
        }
        for index, (name, split, test_seed, val_seed) in enumerate(grid, start=first_index)
    ]
    meta_data = {
        "model_dir": model_dir,
        "num_runs": first_index + len(runs),
        "search_type": "grid",
    }
    return {"meta_data": meta_data, "runs": runs}




def test_train_split_v7(overwrite=True):
    """Build the ``tt_split_v7`` grid: ``fast_ts_net`` on 13 datasets.

    Original author's note: match Feng's setting as close as possible.
    Uses the [8, 1, 1] split with 10 test seeds (outer loop) and 9
    validation seeds (inner loop) per dataset.

    Args:
        overwrite: If true, run directories are numbered from 0; otherwise
            numbering continues after the index returned by
            ``get_last_existing_dir(model_dir)``.

    Returns:
        Dict with ``"meta_data"`` (model directory, ``num_runs`` = one past
        the last run index, search type) and ``"runs"`` (one dict per run).
    """
    model_dir = "/home/anonymous/trainable_scattering/models/tt_split_v7/"
    first_index = 0 if overwrite else get_last_existing_dir(model_dir) + 1
    # (splits, number of validation seeds, number of test seeds).
    # The [7, 1, 2], [4, 1, 5] and [2, 1, 7] variants are currently disabled.
    split_plan = [([8, 1, 1], 9, 10)]
    depth = 1
    epochs = 1000
    datasets = [
        "NCI1", "NCI109", "DD", "PROTEINS", "MUTAG", "PTC", "ENZYMES",
        "REDDIT-BINARY", "REDDIT-MULTI-12K", "IMDB-BINARY", "IMDB-MULTI",
        "COLLAB", "REDDIT-MULTI-5K",
    ]
    grid = [
        (name, arch, split, test_seed, val_seed)
        for name in datasets
        for arch in ["fast_ts_net"]
        for split, n_val, n_test in split_plan
        for test_seed in range(n_test)
        for val_seed in range(n_val)
    ]
    runs = [
        {
            "model_dir": os.path.join(model_dir, str(index)),
            "dataset": name,
            "transform": "fast_scatter",
            "model": arch,
            "model_args": {
                "trainable_laziness": False,
                "num_layers": depth,
                "epsilon": 1e-16,
            },
            "test_seed": test_seed,
            "val_seed": val_seed,
            "num_epochs": epochs,
            "splits": split,
            "early_stopping": True,
        }
        for index, (name, arch, split, test_seed, val_seed) in enumerate(grid, start=first_index)
    ]
    meta_data = {
        "model_dir": model_dir,
        "num_runs": first_index + len(runs),
        "search_type": "grid",
    }
    return {"meta_data": meta_data, "runs": runs}


def test_train_split_v8(overwrite=True):
    """Build the ``tt_split_v8`` grid: ``gcn`` and ``baseline`` on IMDB-BINARY.

    Original author's note: match Feng's setting as close as possible.
    Uses the [8, 1, 1] split with a single test seed and a single
    validation seed, on ``scatter`` features.

    Args:
        overwrite: If true, run directories are numbered from 0; otherwise
            numbering continues after the index returned by
            ``get_last_existing_dir(model_dir)``.

    Returns:
        Dict with ``"meta_data"`` (model directory, ``num_runs`` = one past
        the last run index, search type) and ``"runs"`` (one dict per run).
    """
    model_dir = "/home/anonymous/trainable_scattering/models/tt_split_v8/"
    first_index = 0 if overwrite else get_last_existing_dir(model_dir) + 1
    # (splits, number of validation seeds, number of test seeds).
    # The [7, 1, 2], [4, 1, 5] and [2, 1, 7] variants are currently disabled.
    split_plan = [([8, 1, 1], 1, 1)]
    depth = 1
    epochs = 1000
    # The full 13-dataset list is disabled for this configuration.
    datasets = ["IMDB-BINARY"]
    grid = [
        (name, arch, split, test_seed, val_seed)
        for name in datasets
        for arch in ["gcn", "baseline"]
        for split, n_val, n_test in split_plan
        for test_seed in range(n_test)
        for val_seed in range(n_val)
    ]
    runs = [
        {
            "model_dir": os.path.join(model_dir, str(index)),
            "dataset": name,
            "transform": "scatter",
            "model": arch,
            "model_args": {
                "trainable_laziness": False,
                "num_layers": depth,
                "epsilon": 1e-16,
            },
            "test_seed": test_seed,
            "val_seed": val_seed,
            "num_epochs": epochs,
            "splits": split,
            "early_stopping": True,
        }
        for index, (name, arch, split, test_seed, val_seed) in enumerate(grid, start=first_index)
    ]
    meta_data = {
        "model_dir": model_dir,
        "num_runs": first_index + len(runs),
        "search_type": "grid",
    }
    return {"meta_data": meta_data, "runs": runs}


def test_train_split_v9(overwrite=True):
    """Build the ``tt_split_v9`` grid: ``fast_sort`` on 13 datasets.

    Original author's note: match Feng's setting as close as possible.
    Uses ``fast_scatter_sort`` features and the [8, 1, 1] split with 10
    test seeds (outer loop) and 9 validation seeds (inner loop).

    Args:
        overwrite: If true, run directories are numbered from 0; otherwise
            numbering continues after the index returned by
            ``get_last_existing_dir(model_dir)``.

    Returns:
        Dict with ``"meta_data"`` (model directory, ``num_runs`` = one past
        the last run index, search type) and ``"runs"`` (one dict per run).
    """
    model_dir = "/home/anonymous/trainable_scattering/models/tt_split_v9/"
    first_index = 0 if overwrite else get_last_existing_dir(model_dir) + 1
    # (splits, number of validation seeds, number of test seeds).
    # The [7, 1, 2], [4, 1, 5] and [2, 1, 7] variants are currently disabled.
    split_plan = [([8, 1, 1], 9, 10)]
    depth = 1
    epochs = 1000
    datasets = [
        "NCI1", "NCI109", "DD", "PROTEINS", "MUTAG", "PTC", "ENZYMES",
        "REDDIT-BINARY", "REDDIT-MULTI-12K", "IMDB-BINARY", "IMDB-MULTI",
        "COLLAB", "REDDIT-MULTI-5K",
    ]
    grid = [
        (name, arch, split, test_seed, val_seed)
        for name in datasets
        for arch in ["fast_sort"]
        for split, n_val, n_test in split_plan
        for test_seed in range(n_test)
        for val_seed in range(n_val)
    ]
    runs = [
        {
            "model_dir": os.path.join(model_dir, str(index)),
            "dataset": name,
            "transform": "fast_scatter_sort",
            "model": arch,
            "model_args": {
                "trainable_laziness": False,
                "num_layers": depth,
                "epsilon": 1e-16,
            },
            "test_seed": test_seed,
            "val_seed": val_seed,
            "num_epochs": epochs,
            "splits": split,
            "early_stopping": True,
        }
        for index, (name, arch, split, test_seed, val_seed) in enumerate(grid, start=first_index)
    ]
    meta_data = {
        "model_dir": model_dir,
        "num_runs": first_index + len(runs),
        "search_type": "grid",
    }
    return {"meta_data": meta_data, "runs": runs}


def test_train_split_v10(overwrite=True):
    """Build the ``tt_split_v10`` grid: ``gcn`` on 13 datasets.

    Original author's note: match Feng's setting as close as possible.
    Uses ``scatter`` features and the [8, 1, 1] split with 10 test seeds
    (outer loop) and 9 validation seeds (inner loop).

    Args:
        overwrite: If true, run directories are numbered from 0; otherwise
            numbering continues after the index returned by
            ``get_last_existing_dir(model_dir)``.

    Returns:
        Dict with ``"meta_data"`` (model directory, ``num_runs`` = one past
        the last run index, search type) and ``"runs"`` (one dict per run).
    """
    model_dir = "/home/anonymous/trainable_scattering/models/tt_split_v10/"
    first_index = 0 if overwrite else get_last_existing_dir(model_dir) + 1
    # (splits, number of validation seeds, number of test seeds).
    # The [7, 1, 2], [4, 1, 5] and [2, 1, 7] variants are currently disabled.
    split_plan = [([8, 1, 1], 9, 10)]
    depth = 1
    epochs = 1000
    datasets = [
        "NCI1", "NCI109", "DD", "PROTEINS", "MUTAG", "PTC", "ENZYMES",
        "REDDIT-BINARY", "REDDIT-MULTI-12K", "IMDB-BINARY", "IMDB-MULTI",
        "COLLAB", "REDDIT-MULTI-5K",
    ]
    grid = [
        (name, arch, split, test_seed, val_seed)
        for name in datasets
        for arch in ["gcn"]
        for split, n_val, n_test in split_plan
        for test_seed in range(n_test)
        for val_seed in range(n_val)
    ]
    runs = [
        {
            "model_dir": os.path.join(model_dir, str(index)),
            "dataset": name,
            "transform": "scatter",
            "model": arch,
            "model_args": {
                "trainable_laziness": False,
                "num_layers": depth,
                "epsilon": 1e-16,
            },
            "test_seed": test_seed,
            "val_seed": val_seed,
            "num_epochs": epochs,
            "splits": split,
            "early_stopping": True,
        }
        for index, (name, arch, split, test_seed, val_seed) in enumerate(grid, start=first_index)
    ]
    meta_data = {
        "model_dir": model_dir,
        "num_runs": first_index + len(runs),
        "search_type": "grid",
    }
    return {"meta_data": meta_data, "runs": runs}


def test_train_split_v12(overwrite=True):
    """Build the ``tt_split_v12`` grid: ``graph_sage`` on 13 datasets.

    Original author's note: match Feng's setting as close as possible.
    Uses ``scatter`` features and the [8, 1, 1] split with 10 test seeds
    (outer loop) and 9 validation seeds (inner loop).

    Args:
        overwrite: If true, run directories are numbered from 0; otherwise
            numbering continues after the index returned by
            ``get_last_existing_dir(model_dir)``.

    Returns:
        Dict with ``"meta_data"`` (model directory, ``num_runs`` = one past
        the last run index, search type) and ``"runs"`` (one dict per run).
    """
    model_dir = "/home/anonymous/trainable_scattering/models/tt_split_v12/"
    first_index = 0 if overwrite else get_last_existing_dir(model_dir) + 1
    # (splits, number of validation seeds, number of test seeds).
    # The [7, 1, 2], [4, 1, 5] and [2, 1, 7] variants are currently disabled.
    split_plan = [([8, 1, 1], 9, 10)]
    depth = 1
    epochs = 1000
    datasets = [
        "NCI1", "NCI109", "DD", "PROTEINS", "MUTAG", "PTC", "ENZYMES",
        "REDDIT-BINARY", "REDDIT-MULTI-12K", "IMDB-BINARY", "IMDB-MULTI",
        "COLLAB", "REDDIT-MULTI-5K",
    ]
    grid = [
        (name, arch, split, test_seed, val_seed)
        for name in datasets
        for arch in ["graph_sage"]
        for split, n_val, n_test in split_plan
        for test_seed in range(n_test)
        for val_seed in range(n_val)
    ]
    runs = [
        {
            "model_dir": os.path.join(model_dir, str(index)),
            "dataset": name,
            "transform": "scatter",
            "model": arch,
            "model_args": {
                "trainable_laziness": False,
                "num_layers": depth,
                "epsilon": 1e-16,
            },
            "test_seed": test_seed,
            "val_seed": val_seed,
            "num_epochs": epochs,
            "splits": split,
            "early_stopping": True,
        }
        for index, (name, arch, split, test_seed, val_seed) in enumerate(grid, start=first_index)
    ]
    meta_data = {
        "model_dir": model_dir,
        "num_runs": first_index + len(runs),
        "search_type": "grid",
    }
    return {"meta_data": meta_data, "runs": runs}


def test_train_split_v11_attention(overwrite=True):
    """Build the ``tt_split_v11_attention_short`` grid: ``attention_plain`` runs.

    Original author's note: attention_plain model, attention without
    sorting. One run per dataset: the [8, 1, 1] split with a single test
    seed and a single validation seed, on ``scatter`` features.

    Args:
        overwrite: If true, run directories are numbered from 0; otherwise
            numbering continues after the index returned by
            ``get_last_existing_dir(model_dir)``.

    Returns:
        Dict with ``"meta_data"`` (model directory, ``num_runs`` = one past
        the last run index, search type) and ``"runs"`` (one dict per run).
    """
    model_dir = "/home/anonymous/trainable_scattering/models/tt_split_v11_attention_short/"
    first_index = 0 if overwrite else get_last_existing_dir(model_dir) + 1
    # (splits, number of validation seeds, number of test seeds).
    # The [7, 1, 2], [4, 1, 5] and [2, 1, 7] variants are currently disabled.
    split_plan = [([8, 1, 1], 1, 1)]
    depth = 1
    epochs = 1000
    datasets = [
        "NCI1", "NCI109", "DD", "PROTEINS", "MUTAG", "PTC", "ENZYMES",
        "REDDIT-BINARY", "REDDIT-MULTI-12K", "IMDB-BINARY", "IMDB-MULTI",
        "COLLAB", "REDDIT-MULTI-5K",
    ]
    grid = [
        (name, arch, split, test_seed, val_seed)
        for name in datasets
        for arch in ["attention_plain"]
        for split, n_val, n_test in split_plan
        for test_seed in range(n_test)
        for val_seed in range(n_val)
    ]
    runs = [
        {
            "model_dir": os.path.join(model_dir, str(index)),
            "dataset": name,
            "transform": "scatter",
            "model": arch,
            "model_args": {
                "trainable_laziness": False,
                "num_layers": depth,
                "epsilon": 1e-16,
            },
            "test_seed": test_seed,
            "val_seed": val_seed,
            "num_epochs": epochs,
            "splits": split,
            "early_stopping": True,
        }
        for index, (name, arch, split, test_seed, val_seed) in enumerate(grid, start=first_index)
    ]
    meta_data = {
        "model_dir": model_dir,
        "num_runs": first_index + len(runs),
        "search_type": "grid",
    }
    return {"meta_data": meta_data, "runs": runs}


def v13(overwrite=True):
    """Build the ``v13`` grid: ``gcn`` on 13 datasets over three splits.

    Original author's note: match Feng's setting as close as possible.
    Each split is paired with its own number of validation and test seeds;
    test seeds form the outer loop and validation seeds the inner loop.

    Args:
        overwrite: If true, run directories are numbered from 0; otherwise
            numbering continues after the index returned by
            ``get_last_existing_dir(model_dir)``.

    Returns:
        Dict with ``"meta_data"`` (model directory, ``num_runs`` = one past
        the last run index, search type) and ``"runs"`` (one dict per run).
    """
    model_dir = "/home/anonymous/trainable_scattering/models/v13/"
    first_index = 0 if overwrite else get_last_existing_dir(model_dir) + 1
    # (splits, number of validation seeds, number of test seeds)
    split_plan = [
        ([7, 1, 2], 8, 5),
        ([4, 1, 5], 5, 2),
        ([2, 1, 7], 3, 10),
    ]
    depth = 1
    epochs = 1000
    datasets = [
        "NCI1", "NCI109", "DD", "PROTEINS", "MUTAG", "PTC", "ENZYMES",
        "REDDIT-BINARY", "REDDIT-MULTI-12K", "IMDB-BINARY", "IMDB-MULTI",
        "COLLAB", "REDDIT-MULTI-5K",
    ]
    grid = [
        (name, arch, split, test_seed, val_seed)
        for name in datasets
        for arch in ["gcn"]
        for split, n_val, n_test in split_plan
        for test_seed in range(n_test)
        for val_seed in range(n_val)
    ]
    runs = [
        {
            "model_dir": os.path.join(model_dir, str(index)),
            "dataset": name,
            "transform": "scatter",
            "model": arch,
            "model_args": {
                "trainable_laziness": False,
                "num_layers": depth,
                "epsilon": 1e-16,
            },
            "test_seed": test_seed,
            "val_seed": val_seed,
            "num_epochs": epochs,
            "splits": split,
            "early_stopping": True,
        }
        for index, (name, arch, split, test_seed, val_seed) in enumerate(grid, start=first_index)
    ]
    meta_data = {
        "model_dir": model_dir,
        "num_runs": first_index + len(runs),
        "search_type": "grid",
    }
    return {"meta_data": meta_data, "runs": runs}

def v15(overwrite=True):
    """Build the ``v15`` grid: ``graph_sage`` on 13 datasets over three splits.

    Each split is paired with its own number of validation and test seeds;
    test seeds form the outer loop and validation seeds the inner loop.

    Args:
        overwrite: If true, run directories are numbered from 0; otherwise
            numbering continues after the index returned by
            ``get_last_existing_dir(model_dir)``.

    Returns:
        Dict with ``"meta_data"`` (model directory, ``num_runs`` = one past
        the last run index, search type) and ``"runs"`` (one dict per run).
    """
    model_dir = "/home/anonymous/trainable_scattering/models/v15/"
    first_index = 0 if overwrite else get_last_existing_dir(model_dir) + 1
    # (splits, number of validation seeds, number of test seeds)
    split_plan = [
        ([7, 1, 2], 8, 5),
        ([4, 1, 5], 5, 2),
        ([2, 1, 7], 3, 10),
    ]
    depth = 1
    epochs = 1000
    datasets = [
        "NCI1", "NCI109", "DD", "PROTEINS", "MUTAG", "PTC", "ENZYMES",
        "REDDIT-BINARY", "REDDIT-MULTI-12K", "IMDB-BINARY", "IMDB-MULTI",
        "COLLAB", "REDDIT-MULTI-5K",
    ]
    grid = [
        (name, arch, split, test_seed, val_seed)
        for name in datasets
        for arch in ["graph_sage"]
        for split, n_val, n_test in split_plan
        for test_seed in range(n_test)
        for val_seed in range(n_val)
    ]
    runs = [
        {
            "model_dir": os.path.join(model_dir, str(index)),
            "dataset": name,
            "transform": "scatter",
            "model": arch,
            "model_args": {
                "trainable_laziness": False,
                "num_layers": depth,
                "epsilon": 1e-16,
            },
            "test_seed": test_seed,
            "val_seed": val_seed,
            "num_epochs": epochs,
            "splits": split,
            "early_stopping": True,
        }
        for index, (name, arch, split, test_seed, val_seed) in enumerate(grid, start=first_index)
    ]
    meta_data = {
        "model_dir": model_dir,
        "num_runs": first_index + len(runs),
        "search_type": "grid",
    }
    return {"meta_data": meta_data, "runs": runs}

def v16(overwrite=True):
    """Build the ``v16`` grid: ``fast_ts_net`` on ``fast_scatter_cat`` features.

    Covers 7 datasets with the [8, 1, 1] split, 10 test seeds (outer loop)
    and 9 validation seeds (inner loop).

    Args:
        overwrite: If true, run directories are numbered from 0; otherwise
            numbering continues after the index returned by
            ``get_last_existing_dir(model_dir)``.

    Returns:
        Dict with ``"meta_data"`` (model directory, ``num_runs`` = one past
        the last run index, search type) and ``"runs"`` (one dict per run).
    """
    model_dir = "/home/anonymous/trainable_scattering/models/v16/"
    first_index = 0 if overwrite else get_last_existing_dir(model_dir) + 1
    # (splits, number of validation seeds, number of test seeds).
    # The [7, 1, 2], [4, 1, 5] and [2, 1, 7] variants are currently disabled.
    split_plan = [([8, 1, 1], 9, 10)]
    depth = 1
    epochs = 1000
    datasets = ["NCI1", "NCI109", "DD", "PROTEINS", "MUTAG", "PTC", "ENZYMES"]
    grid = [
        (name, arch, split, test_seed, val_seed)
        for name in datasets
        for arch in ["fast_ts_net"]
        for split, n_val, n_test in split_plan
        for test_seed in range(n_test)
        for val_seed in range(n_val)
    ]
    runs = [
        {
            "model_dir": os.path.join(model_dir, str(index)),
            "dataset": name,
            "transform": "fast_scatter_cat",
            "model": arch,
            "model_args": {
                "trainable_laziness": False,
                "num_layers": depth,
                "epsilon": 1e-16,
            },
            "test_seed": test_seed,
            "val_seed": val_seed,
            "num_epochs": epochs,
            "splits": split,
            "early_stopping": True,
        }
        for index, (name, arch, split, test_seed, val_seed) in enumerate(grid, start=first_index)
    ]
    meta_data = {
        "model_dir": model_dir,
        "num_runs": first_index + len(runs),
        "search_type": "grid",
    }
    return {"meta_data": meta_data, "runs": runs}
def v14(overwrite=True):
    """
    Match Feng's setting as close as possible.

    Builds the v14 grid: model "fast_rbf_sort" with transform
    "fast_scatter_sort", one run per (dataset, split setting, test seed,
    val seed).

    Args:
        overwrite: when True, run directories are numbered from 0; when False,
            numbering continues after get_last_existing_dir(model_dir).

    Returns:
        {"meta_data": ..., "runs": [...]}, where "num_runs" is one more than
        the index of the last run directory.
    """
    model_dir = "/home/anonymous/trainable_scattering/models/v14/"
    last_index = -1 if overwrite else get_last_existing_dir(model_dir)
    num_layers = 1
    num_epochs = 1000
    datasets = [
        "NCI1",
        "NCI109",
        "DD",
        "PROTEINS",
        "MUTAG",
        "PTC",
        "ENZYMES",
        "REDDIT-BINARY",
        "REDDIT-MULTI-12K",
        "IMDB-BINARY",
        "IMDB-MULTI",
        "COLLAB",
        "REDDIT-MULTI-5K",
    ]
    models = ["fast_rbf_sort"]
    # Each entry is (splits, number of val seeds, number of test seeds).
    # Original note on the val seeds: "fixed with 0.2 validation set size".
    # Disabled settings: ([7, 1, 2], 8, 5), ([4, 1, 5], 5, 2), ([2, 1, 7], 3, 10).
    split_settings = [([8, 1, 1], 9, 10)]

    # Enumerate the grid in the order dataset > model > splits > test > val.
    grid = [
        (dataset, model, splits, test_seed, val_seed)
        for dataset in datasets
        for model in models
        for splits, n_val_seeds, n_test_seeds in split_settings
        for test_seed in range(n_test_seeds)
        for val_seed in range(n_val_seeds)
    ]
    runs = [
        {
            "model_dir": os.path.join(model_dir, str(run_index)),
            "dataset": dataset,
            "transform": "fast_scatter_sort",
            "model": model,
            "model_args": {
                "trainable_laziness": False,
                "num_layers": num_layers,
                "epsilon": 1e-16,
            },
            "test_seed": test_seed,
            "val_seed": val_seed,
            "num_epochs": num_epochs,
            "splits": splits,
            "early_stopping": True,
        }
        for run_index, (dataset, model, splits, test_seed, val_seed) in enumerate(
            grid, start=last_index + 1
        )
    ]
    meta_data = {
        "model_dir": model_dir,
        "num_runs": last_index + 1 + len(runs),
        "search_type": "grid",
    }
    return {"meta_data": meta_data, "runs": runs}


def test_train_split_baseline(overwrite=True):
    """
    Match Feng's setting as close as possible.

    Builds the baseline grid: model "baseline" with transform "scatter", one
    run per (dataset, split setting, test seed, val seed).

    Args:
        overwrite: when True, run directories are numbered from 0; when False,
            numbering continues after get_last_existing_dir(model_dir).

    Returns:
        {"meta_data": ..., "runs": [...]}, where "num_runs" is one more than
        the index of the last run directory.
    """
    model_dir = "/home/anonymous/trainable_scattering/models/tt_split_baseline/"
    last_index = -1 if overwrite else get_last_existing_dir(model_dir)
    num_layers = 1
    num_epochs = 1000
    datasets = [
        "NCI1",
        "NCI109",
        "DD",
        "PROTEINS",
        "MUTAG",
        "PTC",
        "ENZYMES",
        "REDDIT-BINARY",
        "REDDIT-MULTI-12K",
        "IMDB-BINARY",
        "IMDB-MULTI",
        "COLLAB",
        "REDDIT-MULTI-5K",
    ]
    models = ["baseline"]
    # Each entry is (splits, number of val seeds, number of test seeds).
    # Original note on the val seeds: "fixed with 0.2 validation set size".
    # Disabled settings: ([7, 1, 2], 8, 5), ([4, 1, 5], 5, 2), ([2, 1, 7], 3, 10).
    split_settings = [([8, 1, 1], 9, 10)]

    # Enumerate the grid in the order dataset > model > splits > test > val.
    grid = [
        (dataset, model, splits, test_seed, val_seed)
        for dataset in datasets
        for model in models
        for splits, n_val_seeds, n_test_seeds in split_settings
        for test_seed in range(n_test_seeds)
        for val_seed in range(n_val_seeds)
    ]
    runs = [
        {
            "model_dir": os.path.join(model_dir, str(run_index)),
            "dataset": dataset,
            "transform": "scatter",
            "model": model,
            "model_args": {
                "trainable_laziness": False,
                "num_layers": num_layers,
                "epsilon": 1e-16,
            },
            "test_seed": test_seed,
            "val_seed": val_seed,
            "num_epochs": num_epochs,
            "splits": splits,
            "early_stopping": True,
        }
        for run_index, (dataset, model, splits, test_seed, val_seed) in enumerate(
            grid, start=last_index + 1
        )
    ]
    meta_data = {
        "model_dir": model_dir,
        "num_runs": last_index + 1 + len(runs),
        "search_type": "grid",
    }
    return {"meta_data": meta_data, "runs": runs}

def test_casp_regression_v1(overwrite=True):
    """
    Build the CASP regression grid v1: "fast_ts_net_regression" on "fast_scatter".

    One run is emitted per (dataset, model, split setting, test seed, val seed).

    Args:
        overwrite: when True, run directories are numbered from 0; when False,
            numbering continues after get_last_existing_dir(model_dir).

    Returns:
        {"meta_data": ..., "runs": [...]}, where "num_runs" is one more than
        the index of the last run directory.
    """
    model_dir = "/home/anonymous/trainable_scattering/models/tt_casp_v1/"
    last_index = -1 if overwrite else get_last_existing_dir(model_dir)
    num_layers = 1
    num_epochs = 1000
    datasets = ["CASP"]
    models = ["fast_ts_net_regression"]
    # Each entry is (splits, number of val seeds, number of test seeds).
    split_settings = [([8, 1, 1], 1, 1)]

    # Enumerate the grid in the order dataset > model > splits > test > val.
    grid = [
        (dataset, model, splits, test_seed, val_seed)
        for dataset in datasets
        for model in models
        for splits, n_val_seeds, n_test_seeds in split_settings
        for test_seed in range(n_test_seeds)
        for val_seed in range(n_val_seeds)
    ]
    runs = [
        {
            "model_dir": os.path.join(model_dir, str(run_index)),
            "dataset": dataset,
            "transform": "fast_scatter",
            "model": model,
            "model_args": {
                "trainable_laziness": False,
                "num_layers": num_layers,
                "epsilon": 1e-16,
            },
            "test_seed": test_seed,
            "val_seed": val_seed,
            "num_epochs": num_epochs,
            "splits": splits,
            "early_stopping": True,
        }
        for run_index, (dataset, model, splits, test_seed, val_seed) in enumerate(
            grid, start=last_index + 1
        )
    ]
    meta_data = {
        "model_dir": model_dir,
        "num_runs": last_index + 1 + len(runs),
        "search_type": "grid",
    }
    return {"meta_data": meta_data, "runs": runs}

def test_casp_regression_v2(overwrite=True):
    """
    Build the CASP regression grid v2: "gcn_regression" with transform "none".

    One run is emitted per (dataset, model, split setting, test seed, val seed).

    Args:
        overwrite: when True, run directories are numbered from 0; when False,
            numbering continues after get_last_existing_dir(model_dir).

    Returns:
        {"meta_data": ..., "runs": [...]}, where "num_runs" is one more than
        the index of the last run directory.
    """
    model_dir = "/home/anonymous/trainable_scattering/models/tt_casp_v2/"
    last_index = -1 if overwrite else get_last_existing_dir(model_dir)
    num_layers = 1
    num_epochs = 1000
    datasets = ["CASP"]
    models = ["gcn_regression"]
    # Each entry is (splits, number of val seeds, number of test seeds).
    split_settings = [([8, 1, 1], 1, 1)]

    # Enumerate the grid in the order dataset > model > splits > test > val.
    grid = [
        (dataset, model, splits, test_seed, val_seed)
        for dataset in datasets
        for model in models
        for splits, n_val_seeds, n_test_seeds in split_settings
        for test_seed in range(n_test_seeds)
        for val_seed in range(n_val_seeds)
    ]
    runs = [
        {
            "model_dir": os.path.join(model_dir, str(run_index)),
            "dataset": dataset,
            "transform": "none",
            "model": model,
            "model_args": {
                "trainable_laziness": False,
                "num_layers": num_layers,
                "epsilon": 1e-16,
            },
            "test_seed": test_seed,
            "val_seed": val_seed,
            "num_epochs": num_epochs,
            "splits": splits,
            "early_stopping": True,
        }
        for run_index, (dataset, model, splits, test_seed, val_seed) in enumerate(
            grid, start=last_index + 1
        )
    ]
    meta_data = {
        "model_dir": model_dir,
        "num_runs": last_index + 1 + len(runs),
        "search_type": "grid",
    }
    return {"meta_data": meta_data, "runs": runs}


def test_casp_regression_v3(overwrite=True):
    """
    Build the CASP2 regression grid v3: GCN versus fast_ts_net.

    "gcn_regression" runs with transform "none"; "fast_ts_net_regression"
    runs with transform "fast_scatter".

    Args:
        overwrite: when True, run directories are numbered from 0; when False,
            numbering continues after get_last_existing_dir(model_dir).

    Returns:
        {"meta_data": ..., "runs": [...]}, where "num_runs" is one more than
        the index of the last run directory.
    """
    model_dir = "/home/anonymous/trainable_scattering/models/tt_casp_v3/"
    last_index = -1 if overwrite else get_last_existing_dir(model_dir)
    num_layers = 1
    num_epochs = 1000
    datasets = ["CASP2"]
    # (model, transform) pairs, in run order.
    model_transforms = [
        ("gcn_regression", "none"),
        ("fast_ts_net_regression", "fast_scatter"),
    ]
    # Each entry is (splits, number of val seeds, number of test seeds).
    split_settings = [([8, 1, 1], 1, 1)]

    # Enumerate the grid in the order dataset > model > splits > test > val.
    grid = [
        (dataset, model, transform, splits, test_seed, val_seed)
        for dataset in datasets
        for model, transform in model_transforms
        for splits, n_val_seeds, n_test_seeds in split_settings
        for test_seed in range(n_test_seeds)
        for val_seed in range(n_val_seeds)
    ]
    runs = [
        {
            "model_dir": os.path.join(model_dir, str(run_index)),
            "dataset": dataset,
            "transform": transform,
            "model": model,
            "model_args": {
                "trainable_laziness": False,
                "num_layers": num_layers,
                "epsilon": 1e-16,
            },
            "test_seed": test_seed,
            "val_seed": val_seed,
            "num_epochs": num_epochs,
            "splits": splits,
            "early_stopping": True,
        }
        for run_index, (
            dataset,
            model,
            transform,
            splits,
            test_seed,
            val_seed,
        ) in enumerate(grid, start=last_index + 1)
    ]
    meta_data = {
        "model_dir": model_dir,
        "num_runs": last_index + 1 + len(runs),
        "search_type": "grid",
    }
    return {"meta_data": meta_data, "runs": runs}


def test_casp_regression_v4(overwrite=True):
    """
    Build the CASP2 regression grid v4: GraphSAGE and baseline, transform "none".

    One run is emitted per (dataset, model, split setting, test seed, val seed).

    Args:
        overwrite: when True, run directories are numbered from 0; when False,
            numbering continues after get_last_existing_dir(model_dir).

    Returns:
        {"meta_data": ..., "runs": [...]}, where "num_runs" is one more than
        the index of the last run directory.
    """
    model_dir = "/home/anonymous/trainable_scattering/models/tt_casp_v4/"
    last_index = -1 if overwrite else get_last_existing_dir(model_dir)
    num_layers = 1
    num_epochs = 1000
    datasets = ["CASP2"]
    models = ["graph_sage_regression", "baseline_regression"]
    # Each entry is (splits, number of val seeds, number of test seeds).
    split_settings = [([8, 1, 1], 1, 1)]

    # Enumerate the grid in the order dataset > model > splits > test > val.
    grid = [
        (dataset, model, splits, test_seed, val_seed)
        for dataset in datasets
        for model in models
        for splits, n_val_seeds, n_test_seeds in split_settings
        for test_seed in range(n_test_seeds)
        for val_seed in range(n_val_seeds)
    ]
    runs = [
        {
            "model_dir": os.path.join(model_dir, str(run_index)),
            "dataset": dataset,
            "transform": "none",
            "model": model,
            "model_args": {
                "trainable_laziness": False,
                "num_layers": num_layers,
                "epsilon": 1e-16,
            },
            "test_seed": test_seed,
            "val_seed": val_seed,
            "num_epochs": num_epochs,
            "splits": splits,
            "early_stopping": True,
        }
        for run_index, (dataset, model, splits, test_seed, val_seed) in enumerate(
            grid, start=last_index + 1
        )
    ]
    meta_data = {
        "model_dir": model_dir,
        "num_runs": last_index + 1 + len(runs),
        "search_type": "grid",
    }
    return {"meta_data": meta_data, "runs": runs}


def test_casp_regression_v5(overwrite=True):
    """
    Build the CASP2 regression grid v5 over four regression models.

    "fast_ts_net_regression" runs with transform "fast_scatter"; every other
    model runs with transform "none".

    Args:
        overwrite: when True, run directories are numbered from 0; when False,
            numbering continues after get_last_existing_dir(model_dir).

    Returns:
        {"meta_data": ..., "runs": [...]}, where "num_runs" is one more than
        the index of the last run directory.
    """
    model_dir = "/home/anonymous/trainable_scattering/models/tt_casp_5/"
    last_index = -1 if overwrite else get_last_existing_dir(model_dir)
    num_layers = 1
    num_epochs = 1000
    datasets = ["CASP2"]
    # (model, transform) pairs, in run order.
    model_transforms = [
        ("gcn_regression", "none"),
        ("fast_ts_net_regression", "fast_scatter"),
        ("graph_sage_regression", "none"),
        ("baseline_regression", "none"),
    ]
    # Each entry is (splits, number of val seeds, number of test seeds).
    split_settings = [([8, 1, 1], 1, 1)]

    # Enumerate the grid in the order dataset > model > splits > test > val.
    grid = [
        (dataset, model, transform, splits, test_seed, val_seed)
        for dataset in datasets
        for model, transform in model_transforms
        for splits, n_val_seeds, n_test_seeds in split_settings
        for test_seed in range(n_test_seeds)
        for val_seed in range(n_val_seeds)
    ]
    runs = [
        {
            "model_dir": os.path.join(model_dir, str(run_index)),
            "dataset": dataset,
            "transform": transform,
            "model": model,
            "model_args": {
                "trainable_laziness": False,
                "num_layers": num_layers,
                "epsilon": 1e-16,
            },
            "test_seed": test_seed,
            "val_seed": val_seed,
            "num_epochs": num_epochs,
            "splits": splits,
            "early_stopping": True,
        }
        for run_index, (
            dataset,
            model,
            transform,
            splits,
            test_seed,
            val_seed,
        ) in enumerate(grid, start=last_index + 1)
    ]
    meta_data = {
        "model_dir": model_dir,
        "num_runs": last_index + 1 + len(runs),
        "search_type": "grid",
    }
    return {"meta_data": meta_data, "runs": runs}
def test_casp_regression_v6(overwrite=True):
    """
    Build the CASP2 regression grid v6: six models, each with its own transform.

    Three test seeds and one val seed are generated per (model, transform)
    pair.

    Args:
        overwrite: when True, run directories are numbered from 0; when False,
            numbering continues after get_last_existing_dir(model_dir).

    Returns:
        {"meta_data": ..., "runs": [...]}, where "num_runs" is one more than
        the index of the last run directory.
    """
    model_dir = "/home/anonymous/trainable_scattering/models/tt_casp_6/"
    last_index = -1 if overwrite else get_last_existing_dir(model_dir)
    num_layers = 1
    num_epochs = 1000
    datasets = ["CASP2"]
    # (model, transform) pairs, in run order.
    model_transforms = [
        ("gcn_regression", "scatter"),
        ("fast_sort_regression", "fast_scatter_sort"),
        ("fast_rbf_regression", "fast_scatter"),
        ("fast_ts_net_regression", "fast_scatter"),
        ("graph_sage_regression", "scatter"),
        ("baseline_regression", "scatter"),
    ]
    # Each entry is (splits, number of val seeds, number of test seeds).
    split_settings = [([8, 1, 1], 1, 3)]

    # Enumerate the grid in the order dataset > model > splits > test > val.
    grid = [
        (dataset, model, transform, splits, test_seed, val_seed)
        for dataset in datasets
        for model, transform in model_transforms
        for splits, n_val_seeds, n_test_seeds in split_settings
        for test_seed in range(n_test_seeds)
        for val_seed in range(n_val_seeds)
    ]
    runs = [
        {
            "model_dir": os.path.join(model_dir, str(run_index)),
            "dataset": dataset,
            "transform": transform,
            "model": model,
            "model_args": {
                "trainable_laziness": False,
                "num_layers": num_layers,
                "epsilon": 1e-16,
            },
            "test_seed": test_seed,
            "val_seed": val_seed,
            "num_epochs": num_epochs,
            "splits": splits,
            "early_stopping": True,
        }
        for run_index, (
            dataset,
            model,
            transform,
            splits,
            test_seed,
            val_seed,
        ) in enumerate(grid, start=last_index + 1)
    ]
    meta_data = {
        "model_dir": model_dir,
        "num_runs": last_index + 1 + len(runs),
        "search_type": "grid",
    }
    return {"meta_data": meta_data, "runs": runs}
def test_casp_regression_v7(overwrite=True):
    """
    Build the CASP2 regression grid v7: GCN and GraphSAGE on "scatter".

    Three test seeds and one val seed are generated per model.

    Args:
        overwrite: when True, run directories are numbered from 0; when False,
            numbering continues after get_last_existing_dir(model_dir).

    Returns:
        {"meta_data": ..., "runs": [...]}, where "num_runs" is one more than
        the index of the last run directory.
    """
    model_dir = "/home/anonymous/trainable_scattering/models/tt_casp_7/"
    last_index = -1 if overwrite else get_last_existing_dir(model_dir)
    num_layers = 1
    num_epochs = 1000
    datasets = ["CASP2"]
    # (model, transform) pairs, in run order.
    model_transforms = [
        ("gcn_regression", "scatter"),
        ("graph_sage_regression", "scatter"),
    ]
    # Each entry is (splits, number of val seeds, number of test seeds).
    split_settings = [([8, 1, 1], 1, 3)]

    # Enumerate the grid in the order dataset > model > splits > test > val.
    grid = [
        (dataset, model, transform, splits, test_seed, val_seed)
        for dataset in datasets
        for model, transform in model_transforms
        for splits, n_val_seeds, n_test_seeds in split_settings
        for test_seed in range(n_test_seeds)
        for val_seed in range(n_val_seeds)
    ]
    runs = [
        {
            "model_dir": os.path.join(model_dir, str(run_index)),
            "dataset": dataset,
            "transform": transform,
            "model": model,
            "model_args": {
                "trainable_laziness": False,
                "num_layers": num_layers,
                "epsilon": 1e-16,
            },
            "test_seed": test_seed,
            "val_seed": val_seed,
            "num_epochs": num_epochs,
            "splits": splits,
            "early_stopping": True,
        }
        for run_index, (
            dataset,
            model,
            transform,
            splits,
            test_seed,
            val_seed,
        ) in enumerate(grid, start=last_index + 1)
    ]
    meta_data = {
        "model_dir": model_dir,
        "num_runs": last_index + 1 + len(runs),
        "search_type": "grid",
    }
    return {"meta_data": meta_data, "runs": runs}


def v17(overwrite=True):
    """
    Match Feng's setting as close as possible.

    Builds the v17 grid: models "gcn", "graph_sage" and "baseline" with
    transform "scatter_cat", one run per (dataset, model, split setting,
    test seed, val seed).

    Args:
        overwrite: when True, run directories are numbered from 0; when False,
            numbering continues after get_last_existing_dir(model_dir).

    Returns:
        {"meta_data": ..., "runs": [...]}, where "num_runs" is one more than
        the index of the last run directory.
    """
    model_dir = "/home/anonymous/trainable_scattering/models/v17/"
    last_index = -1 if overwrite else get_last_existing_dir(model_dir)
    num_layers = 1
    num_epochs = 1000
    datasets = ["NCI1", "NCI109", "DD", "PROTEINS", "MUTAG", "PTC", "ENZYMES"]
    models = ["gcn", "graph_sage", "baseline"]
    # Each entry is (splits, number of val seeds, number of test seeds).
    # Original note on the val seeds: "fixed with 0.2 validation set size".
    # Disabled settings: ([7, 1, 2], 8, 5), ([4, 1, 5], 5, 2), ([2, 1, 7], 3, 10).
    split_settings = [([8, 1, 1], 9, 10)]

    # Enumerate the grid in the order dataset > model > splits > test > val.
    grid = [
        (dataset, model, splits, test_seed, val_seed)
        for dataset in datasets
        for model in models
        for splits, n_val_seeds, n_test_seeds in split_settings
        for test_seed in range(n_test_seeds)
        for val_seed in range(n_val_seeds)
    ]
    runs = [
        {
            "model_dir": os.path.join(model_dir, str(run_index)),
            "dataset": dataset,
            "transform": "scatter_cat",
            "model": model,
            "model_args": {
                "trainable_laziness": False,
                "num_layers": num_layers,
                "epsilon": 1e-16,
            },
            "test_seed": test_seed,
            "val_seed": val_seed,
            "num_epochs": num_epochs,
            "splits": splits,
            "early_stopping": True,
        }
        for run_index, (dataset, model, splits, test_seed, val_seed) in enumerate(
            grid, start=last_index + 1
        )
    ]
    meta_data = {
        "model_dir": model_dir,
        "num_runs": last_index + 1 + len(runs),
        "search_type": "grid",
    }
    return {"meta_data": meta_data, "runs": runs}


def v18(overwrite=True):
    """
    attention_plain model, attention without sorting

    Builds the v18 grid: model "attention_plain" with transform "scatter",
    one run per (dataset, split setting, test seed, val seed).

    Args:
        overwrite: when True, run directories are numbered from 0; when False,
            numbering continues after get_last_existing_dir(model_dir).

    Returns:
        {"meta_data": ..., "runs": [...]}, where "num_runs" is one more than
        the index of the last run directory.
    """
    model_dir = "/home/anonymous/trainable_scattering/models/v18_attention/"
    last_index = -1 if overwrite else get_last_existing_dir(model_dir)
    num_layers = 1
    num_epochs = 1000
    datasets = ["NCI1", "NCI109", "DD", "PROTEINS", "MUTAG", "PTC", "ENZYMES"]
    models = ["attention_plain"]
    # Each entry is (splits, number of val seeds, number of test seeds).
    # Original note on the val seeds: "fixed with 0.2 validation set size".
    # Disabled settings: ([7, 1, 2], 8, 5), ([4, 1, 5], 5, 2), ([2, 1, 7], 3, 10).
    split_settings = [([8, 1, 1], 9, 10)]

    # Enumerate the grid in the order dataset > model > splits > test > val.
    grid = [
        (dataset, model, splits, test_seed, val_seed)
        for dataset in datasets
        for model in models
        for splits, n_val_seeds, n_test_seeds in split_settings
        for test_seed in range(n_test_seeds)
        for val_seed in range(n_val_seeds)
    ]
    runs = [
        {
            "model_dir": os.path.join(model_dir, str(run_index)),
            "dataset": dataset,
            "transform": "scatter",
            "model": model,
            "model_args": {
                "trainable_laziness": False,
                "num_layers": num_layers,
                "epsilon": 1e-16,
            },
            "test_seed": test_seed,
            "val_seed": val_seed,
            "num_epochs": num_epochs,
            "splits": splits,
            "early_stopping": True,
        }
        for run_index, (dataset, model, splits, test_seed, val_seed) in enumerate(
            grid, start=last_index + 1
        )
    ]
    meta_data = {
        "model_dir": model_dir,
        "num_runs": last_index + 1 + len(runs),
        "search_type": "grid",
    }
    return {"meta_data": meta_data, "runs": runs}
def v19(overwrite=True):
    """
    Build the v19 grid: model "fast_sort" with transform "fast_scatter_sort".

    One run is emitted per (dataset, split setting, test seed, val seed).

    Args:
        overwrite: when True, run directories are numbered from 0; when False,
            numbering continues after get_last_existing_dir(model_dir).

    Returns:
        {"meta_data": ..., "runs": [...]}, where "num_runs" is one more than
        the index of the last run directory.
    """
    model_dir = "/home/anonymous/trainable_scattering/models/v19_sort/"
    last_index = -1 if overwrite else get_last_existing_dir(model_dir)
    num_layers = 1
    num_epochs = 1000
    datasets = [
        "NCI1",
        "NCI109",
        "DD",
        "PROTEINS",
        "MUTAG",
        "PTC",
        "ENZYMES",
        "REDDIT-BINARY",
        "REDDIT-MULTI-12K",
        "IMDB-BINARY",
        "IMDB-MULTI",
        "COLLAB",
        "REDDIT-MULTI-5K",
    ]
    models = ["fast_sort"]
    # Each entry is (splits, number of val seeds, number of test seeds).
    # Original note on the val seeds: "fixed with 0.2 validation set size".
    # Disabled settings: ([7, 1, 2], 8, 5), ([4, 1, 5], 5, 2), ([2, 1, 7], 3, 10).
    split_settings = [([8, 1, 1], 9, 10)]

    # Enumerate the grid in the order dataset > model > splits > test > val.
    grid = [
        (dataset, model, splits, test_seed, val_seed)
        for dataset in datasets
        for model in models
        for splits, n_val_seeds, n_test_seeds in split_settings
        for test_seed in range(n_test_seeds)
        for val_seed in range(n_val_seeds)
    ]
    runs = [
        {
            "model_dir": os.path.join(model_dir, str(run_index)),
            "dataset": dataset,
            "transform": "fast_scatter_sort",
            "model": model,
            "model_args": {
                "trainable_laziness": False,
                "num_layers": num_layers,
                "epsilon": 1e-16,
            },
            "test_seed": test_seed,
            "val_seed": val_seed,
            "num_epochs": num_epochs,
            "splits": splits,
            "early_stopping": True,
        }
        for run_index, (dataset, model, splits, test_seed, val_seed) in enumerate(
            grid, start=last_index + 1
        )
    ]
    meta_data = {
        "model_dir": model_dir,
        "num_runs": last_index + 1 + len(runs),
        "search_type": "grid",
    }
    return {"meta_data": meta_data, "runs": runs}


# Dataset names iterated by train_all_shallow(); one block of runs is
# generated per entry, in this order.
# NOTE(review): "PROTIENS" looks like a misspelling of "PROTEINS" (which is
# also listed). Confirm whether it is intentional before removing it, since
# dropping an entry shifts the numbering of every later run directory.
TU_DATASETS = [
    "NCI1",
    "NCI109",
    "DD",
    "PROTIENS",
    "PROTEINS",
    "MUTAG",
    "PTC",
    "ENZYMES",
    "REDDIT-BINARY",
    "REDDIT-MULTI-12K",
    "IMDB-BINARY",
    "IMDB-MULTI",
    "COLLAB",
    "REDDIT-MULTI-5K",
]


def train_all_shallow():
    """
    Build the run list for the shallow models on "fast_scatter" features.

    For every dataset in TU_DATASETS and every split setting, one run is
    emitted per (test seed, val seed) pair.  Runs are numbered consecutively
    from 0; the number is used both as the run directory name and as the
    "i" field.

    Returns:
        {"meta_data": ..., "runs": [...]}, where "num_runs" is the total
        number of runs generated.
    """
    model_dir = "/home/anonymous/trainable_scattering/models/shallow/"
    # The three lists are parallel: entry k of each describes split setting k.
    num_val_seeds = [9, 8, 5, 3]  # fixed with 0.2 validation set size
    num_test_seeds = [10, 5, 2, 10]
    all_splits = [[8, 1, 1], [7, 1, 2], [4, 1, 5], [2, 1, 7]]
    arg_list = []
    for dataset in TU_DATASETS:
        for splits, n_val_seeds, n_test_seeds in zip(
            all_splits, num_val_seeds, num_test_seeds
        ):
            for test_seed in range(n_test_seeds):
                for val_seed in range(n_val_seeds):
                    # The next free index is the number of runs built so far.
                    run_index = len(arg_list)
                    arg_list.append(
                        {
                            "dataset": dataset,
                            "model_dir": os.path.join(model_dir, str(run_index)),
                            "splits": splits,
                            "test_seed": test_seed,
                            "transform": "fast_scatter",
                            "val_seed": val_seed,
                            "i": run_index,
                        }
                    )
    meta_data = {
        "model_dir": model_dir,
        "num_runs": len(arg_list),
        "search_type": "grid",
    }
    return {"meta_data": meta_data, "runs": arg_list}


def train_all_shallow_cat():
    """
    Build the run list for the shallow models on "fast_scatter_cat" features.

    Only the first split setting (8/1/1) is used.  For each dataset one run
    is emitted per (test seed, val seed) pair.  Runs are numbered
    consecutively from 0; the number is used both as the run directory name
    and as the "i" field.

    Returns:
        {"meta_data": ..., "runs": [...]}, where "num_runs" is the total
        number of runs generated.
    """
    model_dir = "/home/anonymous/trainable_scattering/models/shallow_cat/"
    # The three lists are parallel: entry k of each describes split setting k.
    num_val_seeds = [9, 8, 5, 3]  # fixed with 0.2 validation set size
    num_test_seeds = [10, 5, 2, 10]
    all_splits = [[8, 1, 1], [7, 1, 2], [4, 1, 5], [2, 1, 7]]
    # Only the first setting is enabled; widen the slice to use more.
    enabled_settings = list(zip(all_splits, num_val_seeds, num_test_seeds))[:1]
    datasets = ["NCI1", "NCI109", "DD", "PROTEINS", "MUTAG", "PTC", "ENZYMES"]
    arg_list = []
    for dataset in datasets:
        for splits, n_val_seeds, n_test_seeds in enabled_settings:
            for test_seed in range(n_test_seeds):
                for val_seed in range(n_val_seeds):
                    # The next free index is the number of runs built so far.
                    run_index = len(arg_list)
                    arg_list.append(
                        {
                            "dataset": dataset,
                            "model_dir": os.path.join(model_dir, str(run_index)),
                            "splits": splits,
                            "test_seed": test_seed,
                            "transform": "fast_scatter_cat",
                            "val_seed": val_seed,
                            "i": run_index,
                        }
                    )
    meta_data = {
        "model_dir": model_dir,
        "num_runs": len(arg_list),
        "search_type": "grid",
    }
    return {"meta_data": meta_data, "runs": arg_list}




    #df = pd.read_pickle('shallow_results.pkl')

if __name__ == "__main__":
    # dump() reads CONFIG_DIR as a module-level global, so it has to be
    # assigned before the first dump() call.
    CONFIG_DIR = "/home/anonymous/trainable_scattering/config/"
    # (builder function, output file name) pairs, written in this order.
    config_builders = (
        (build_v1, "v1.json"),
        (build_laziness_v2, "v2.json"),
        (test_train_split, "tt_split_v1.json"),
        (test_train_split_v2, "tt_split_v2.json"),
        (test_train_split_v3, "tt_split_v3.json"),
        (test_train_split_v4, "tt_split_v4.json"),
        (test_train_split_v5, "tt_split_v5.json"),
        (test_train_split_v6, "tt_split_v6.json"),
        (test_train_split_v7, "tt_split_v7.json"),
        (test_train_split_v8, "tt_split_v8.json"),
        (test_train_split_v9, "tt_split_v9.json"),
        (test_train_split_v10, "tt_split_v10.json"),
        (test_train_split_v11_attention, "tt_split_v11_attention_short.json"),
        (test_train_split_v12, "tt_split_v12_sage.json"),
        (v13, "v13_gcn_low_data.json"),
        (v14, "v14_sort_rbf.json"),
        (v15, "v15_graphsage_low_data.json"),
        (v16, "v16_cat.json"),
        (v17, "v17_baselines.json"),
        (v18, "v18_attention.json"),
        (v19, "v19_sort.json"),
        (test_train_split_vtest, "tt_split_vtest.json"),
        (test_train_split_baseline, "tt_split_baseline.json"),
        (test_casp_regression_v1, "tt_casp_v1.json"),
        (test_casp_regression_v2, "tt_casp_v2.json"),
        (test_casp_regression_v3, "casp_v3.json"),
        (test_casp_regression_v4, "casp_v4.json"),
        (test_casp_regression_v5, "casp_v5.json"),
        (test_casp_regression_v6, "casp_v6.json"),
        (test_casp_regression_v7, "casp_v7.json"),
        (train_all_shallow, "shallow.json"),
        (train_all_shallow_cat, "shallow_cat.json"),
    )
    for builder, file_name in config_builders:
        dump(builder, file_name)
