import json
import sys

sys.path.append("../../../apebench/")

import apebench  # noqa: E402

# All scenarios are single channel, so one scenario instance is used to count
# parameters for every architecture below.
# NOTE(review): this presumes the count depends only on the architecture spec
# and the (shared) channel configuration -- confirm against apebench.
scene = apebench.scenarios.difficulty.Advection()

# Architecture specs to count. Each family sweeps one ";"-separated field over
# the powers of two 1, 2, 4, ..., 2**max_exp, zero-padded to three digits.
# The meaning of the individual fields is defined by apebench's architecture
# parser (the `hc_exp` name suggests hidden channels -- TODO confirm).
architectures = [
    *[f"Conv;{2**hc_exp:03d};10;relu" for hc_exp in range(9 + 1)],
    *[f"UNet;{2**hc_exp:03d};2;relu" for hc_exp in range(6 + 1)],
    *[f"Res;{2**hc_exp:03d};8;relu" for hc_exp in range(9 + 1)],
    *[f"FNO;12;{2**hc_exp:03d};4;relu" for hc_exp in range(5 + 1)],
    *[f"Dil;2;{2**hc_exp:03d};2;relu" for hc_exp in range(8 + 1)],
]

# Map each spec to the parameter count reported by the scenario; insertion
# order (and hence the JSON key order) follows `architectures`.
param_dict = {net: scene.get_parameter_count(net) for net in architectures}

# The output path (like the sys.path entry at the top of the file) is relative
# to the current working directory, so run this script from its own folder.
with open("number_of_params.json", "w", encoding="utf-8") as f:
    json.dump(param_dict, f)
