from nesim.losses.nesim_loss import (
    NesimConfig,
    NeighbourhoodCosineSimilarity,
    Ring,
)
from nesim.utils.json_stuff import load_json_as_dict
import os

# from nesim.configs.exponential_scale_scheduler import ExponentialDecayScale

# Every run regenerates the whole sweep, so start from a clean config folder.
config_folder = "./nesim_configs"
# Make sure the folder exists (save_json below writes into it), then drop any
# stale *.json configs from a previous run. This is done in Python rather than
# via `os.system("rm ...")`: the shell version is not portable and prints an
# error when the glob matches nothing.
os.makedirs(config_folder, exist_ok=True)
for stale_name in os.listdir(config_folder):
    stale_path = os.path.join(config_folder, stale_name)
    if stale_name.endswith(".json") and os.path.isfile(stale_path):
        os.remove(stale_path)

# Ring-loss strengths swept for every (radius, layer set) combination.
ring_loss_scales = [
    10,
    50,
    100,
    # 300,
]

# Maps a layer-set name (used in the generated filenames) to the list of layer
# names the nesim losses are attached to.
posible_nesim_layers = load_json_as_dict("possible_nesim_layers.json")

## watch only all conv layers for baseline
# Reuse the dict loaded above instead of reading the JSON file a second time.
posible_nesim_layers_baseline = {
    "all_conv_layers_except_conv1": posible_nesim_layers[
        "all_conv_layers_except_conv1"
    ]
}

# (inner, outer) ring radii swept; commented-out entries are alternative
# settings that are currently disabled.
radiuses = [
    # (0, 1),
    (1, 2),
    (2, 3),
    (3, 4),
    (4, 5),
    (5, 6),
    # (6, 7),
    # (7, 8),
    # (8, 9),
    # (3, 5),
    # (5, 7),
    # (7, 9),
    # (3, 6),
    # (6, 9),
]

num_configs = 0

# Write one config per (ring radii, loss scale, layer set) combination.
for radius_tuple in radiuses:
    inner_radius, outer_radius = radius_tuple[0], radius_tuple[1]
    for loss_scale in ring_loss_scales:
        for key, nesim_layer_names in posible_nesim_layers.items():
            layer_wise_configs = []
            for layer_name in nesim_layer_names:
                # Neighbourhood cosine similarity gets no scale: it is only
                # watched. The Ring entry carries the swept loss scale.
                watcher = NeighbourhoodCosineSimilarity(
                    layer_name=layer_name, scale=None
                )
                ring_loss = Ring(
                    layer_name=layer_name,
                    radius_inner=inner_radius,
                    radius_outer=outer_radius,
                    scale=loss_scale,
                )
                layer_wise_configs += [watcher, ring_loss]

            nesim_config = NesimConfig(layer_wise_configs=layer_wise_configs)
            filename = f"radius_{inner_radius}_{outer_radius}_loss_scale_{loss_scale}_layers_{key}.json"
            nesim_config.save_json(filename=os.path.join(config_folder, filename))
            num_configs += 1


# baseline run: the Ring metric is only watched (scale=None) on the baseline
# layer set, so no ring loss is applied. NeighbourhoodCosineSimilarity is not
# watched for the baseline.
key = "all_conv_layers_except_conv1"
nesim_layer_names = posible_nesim_layers_baseline[key]

# The sweep above finishes on the last radius / loss scale. Those values are
# spelled out here instead of relying on loop variables leaking out of the
# sweep. They keep the baseline filename unchanged; the loss scale in the name
# is NOT applied (scale=None below).
baseline_radius_tuple = radiuses[-1]
baseline_loss_scale = ring_loss_scales[-1]

# Build a fresh list. Appending to the list left over from the sweep would
# carry its NeighbourhoodCosineSimilarity and *scaled* Ring entries into the
# baseline, silently turning it into a regular ring-loss run.
layer_wise_configs = [
    Ring(
        layer_name=layer_name,
        radius_inner=baseline_radius_tuple[0],
        radius_outer=baseline_radius_tuple[1],
        scale=None,
    )
    for layer_name in nesim_layer_names
]

nesim_config = NesimConfig(
    layer_wise_configs=layer_wise_configs,
)

filename = (
    f"baseline_radius_{baseline_radius_tuple[0]}_{baseline_radius_tuple[1]}"
    f"_loss_scale_{baseline_loss_scale}_layers_{key}.json"
)

nesim_config.save_json(filename=os.path.join(config_folder, filename))
num_configs += 1

print(f"saved {num_configs} configs")
