from nesim.losses.nesim_loss import NesimConfig, Ring1D
from nesim.utils.json_stuff import load_json_as_dict
import os

# Start from a clean output directory: delete any previously generated *.json
# configs so stale files from an earlier sweep are not mixed with new ones.
# This is done in Python rather than via `rm` so that it needs no shell, does
# not crash when the directory is missing, and only removes regular *.json files.
os.makedirs("./nesim_configs", exist_ok=True)
for _old_config in os.listdir("./nesim_configs"):
    _old_config_path = os.path.join("./nesim_configs", _old_config)
    if _old_config.endswith(".json") and os.path.isfile(_old_config_path):
        os.remove(_old_config_path)

# Names of the candidate layers, read from the "all_conv_layers_except_conv1"
# entry of possible_nesim_layers.json.
layer_names = load_json_as_dict("possible_nesim_layers.json")[
    "all_conv_layers_except_conv1"
]
# (freq_inner, freq_outer) pairs to sweep over. Uncomment entries to include
# them in the generated configs.
radiuses = [
    # (0, 10),
    # (0, 1),
    # (0, 2),
    # (0, 3),
    # (0, 4),
    # (0, 5),
    # (0, 6),
    # (0, 7),
    # (0, 8),
    (0, 9),
]

# Write one config file per (target layer, radius) pair. Every layer gets a
# Ring1D entry with the same frequencies, but only the target layer has
# scale=1.0; all others get scale=None (presumably leaving their loss inactive;
# confirm against Ring1D).
for layer_name in layer_names:
    for radius_tuple in radiuses:
        freq_inner, freq_outer = radius_tuple
        layer_wise_configs_ours = [
            Ring1D(
                layer_name=name,
                freq_inner=freq_inner,
                freq_outer=freq_outer,
                scale=1.0 if layer_name == name else None,
            )
            for name in layer_names
        ]
        nesim_config = NesimConfig(
            layer_wise_configs=layer_wise_configs_ours,
        )
        nesim_config.save_json(
            filename=f"./nesim_configs/ring_loss_radius_{freq_inner}_{freq_outer}_layer_{layer_name}.json"
        )


# Baseline config: the same Ring1D entry for every layer, with scale=None on
# all of them so that no layer is singled out. The frequencies come explicitly
# from the last entry of `radiuses`. This avoids relying on a loop variable
# leaking out of the sweep, which raised NameError when `radiuses` was empty.
baseline_freq_inner, baseline_freq_outer = radiuses[-1]
layer_wise_configs_baseline = [
    Ring1D(
        layer_name=name,
        freq_inner=baseline_freq_inner,
        freq_outer=baseline_freq_outer,
        scale=None,
    )
    for name in layer_names
]

nesim_config = NesimConfig(
    layer_wise_configs=layer_wise_configs_baseline,
)

nesim_config.save_json(filename="./nesim_configs/baseline.json")
