import glob
import os

from nesim.losses.nesim_loss import (
    NesimConfig,
    NeighbourhoodCosineSimilarity,
    LaplacianPyramid,
)
from nesim.utils.json_stuff import load_json_as_dict

# Output directory for the generated NeSim config files. Any *.json left over
# from a previous run is deleted, so the folder holds only this run's configs.
config_folder = "./nesim_configs"
# Done in-process rather than via os.system: `rm *.json` printed an error when
# no configs existed yet, the shell commands are POSIX-only, and their exit
# codes were silently ignored.
os.makedirs(config_folder, exist_ok=True)
for stale_config in glob.glob(os.path.join(config_folder, "*.json")):
    os.remove(stale_config)

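"""
Configs generated by this script:

1. End topo:
    - apply the topo loss only to the last conv layer of every resnet block
    - sweep over the loss scale tau (the `loss_scales` values in `configurations` below)
2. All topo:
    - apply the topo loss to all conv layers in the model except the first
    - sweep over the loss scale tau (the `loss_scales` values in `configurations` below)
3. Baseline (currently commented out in `configurations`):
    - compute the topo loss, but don't backpropagate on it
"""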



# Groups of layer names keyed by how they were selected (read from layer_names.json).
layer_names = load_json_as_dict("layer_names.json")

# Each shrink factor is a one-element list (the generation loop checks this).
# Per the original note, 3.0 means the cortical sheet is downsampled by 3x3.
shrink_factors = [[3.0]]

# Map each topographic configuration to the layer group it applies the loss to.
# Both sweep the same loss scales (a scale of 1 was dropped from the sweep).
configurations = {
    config_type: {
        "layer_names": layer_names[layer_group],
        "loss_scales": [30, 40],
    }
    for config_type, layer_group in (
        ("end_topo", "last_conv_layers_in_each_block"),
        ("all_topo", "all_conv_layers_except_first"),
    )
}
# A "baseline" entry (topo loss computed but not backpropagated) is currently
# disabled; it would use layer_names["all_conv_layers_except_first"] with
# "loss_scales": [None].

# Validate every shrink factor before writing anything, so a malformed entry
# cannot leave a partially generated config folder behind. This is an explicit
# raise rather than `assert`, which is stripped when Python runs with -O.
# TODO: LaplacianPyramid currently takes shrink_factor as a one-element list;
# simplify this interface.
for shrink_factor in shrink_factors:
    if len(shrink_factor) != 1:
        raise ValueError(
            "Shrink factor is supposed to be a list of length 1, "
            f"got {shrink_factor!r}."
        )

num_configs = 0

# Write one NesimConfig JSON per (config type, loss scale, shrink factor) combination.
for config_type, settings in configurations.items():
    for loss_scale in settings["loss_scales"]:
        for shrink_factor in shrink_factors:
            # One LaplacianPyramid loss per selected layer. All of them share this
            # combination's loss scale and shrink factor. To also monitor
            # neighbourhood cosine similarity without training on it, add
            # NeighbourhoodCosineSimilarity(layer_name=layer_name, scale=None).
            layer_wise_configs = [
                LaplacianPyramid(
                    layer_name=layer_name,
                    scale=loss_scale,
                    shrink_factor=shrink_factor,
                )
                for layer_name in settings["layer_names"]
            ]
            nesim_config = NesimConfig(layer_wise_configs=layer_wise_configs)

            filename = os.path.join(
                config_folder,
                f"{config_type}_scale_{loss_scale}_shrink_factor_{shrink_factor[0]}.json",
            )
            nesim_config.save_json(filename=filename)
            print(f"Saved: {filename}")
            num_configs += 1

print(f"saved {num_configs} configs")