import os

import matplotlib.pyplot as plt

from nesim.utils.json_stuff import load_json_as_dict

# Per-model analysis records. The plotting code iterates this as a sequence of
# dicts, reading each entry's "model_name" and "effective_dims" keys.
results_path = "results.json"
results = load_json_as_dict(results_path)
# Groups of layers (by network depth) whose per-layer effective dimensions are
# averaged together for plotting. Entries are compared against each record's
# "layer_name"; the names look like ResNet-style module paths (presumably a
# ResNet-18 backbone -- confirm against the model definition).
layer_groups = {
    "late": [
        "layer4.0.conv1",
        "layer4.0.conv2",
        "layer4.0.downsample.0",
        "layer4.1.conv1",
        "layer4.1.conv2",
    ],
    "mid": [
        "layer2.0.conv1",
        "layer2.0.conv2",
        "layer2.0.downsample.0",
        "layer2.1.conv1",
        "layer2.1.conv2",

        "layer3.0.conv1",
        "layer3.0.conv2",
        "layer3.0.downsample.0",
        "layer3.1.conv1",
        "layer3.1.conv2",
    ],
    "early": [
        "layer1.0.conv1",
        "layer1.0.conv2",
        "layer1.1.conv1",
        "layer1.1.conv2",
    ],
}
# "all" is every grouped layer, ordered early -> mid -> late. It is derived
# rather than hand-copied so it can never drift out of sync with the groups
# above; insertion keeps it as the last key, matching the plotting order.
layer_groups["all"] = (
    layer_groups["early"] + layer_groups["mid"] + layer_groups["late"]
)

# Collect the models on the tau sweep once (the same set is plotted for every
# layer group): every "all_topo*" model, with tau parsed from its name, plus
# the "baseline*" model(s).
sweep_models = []  # (is_baseline, tau, model_record) tuples
for data in results:
    model_name = data["model_name"]
    if model_name.startswith("all_topo"):
        # Tau is the 4th "_"-separated field of the name; values above 1 are
        # shown as ints so tick labels read e.g. "5" rather than "5.0".
        tau = float(model_name.split("_")[3])
        if tau > 1:
            tau = int(tau)
        sweep_models.append((False, tau, data))
    elif model_name.startswith("baseline"):
        sweep_models.append((True, 0, data))

# Order the x axis explicitly (baseline first, then ascending tau; the sort is
# stable for ties) instead of relying on the order of results.json, and label
# the baseline by identity rather than assuming it is the first entry.
sweep_models.sort(key=lambda entry: (not entry[0], entry[1]))
all_taus = [
    "Baseline" if is_baseline else tau for is_baseline, tau, _ in sweep_models
]
x_positions = range(len(sweep_models))

fig = plt.figure()
for group_name, layer_names in layer_groups.items():
    group_layers = set(layer_names)  # O(1) membership tests below
    all_effective_dims = []

    for _, tau, data in sweep_models:
        layer_dims = [
            layer_data["ed"]
            for layer_data in data["effective_dims"]
            if layer_data["layer_name"] in group_layers
        ]
        if not layer_dims:
            # Fail with context instead of a bare ZeroDivisionError.
            raise ValueError(
                f"Model {data['model_name']!r} has no effective-dim entries "
                f"for any layer in group {group_name!r}"
            )
        mean_effective_dim = sum(layer_dims) / len(layer_dims)
        all_effective_dims.append(mean_effective_dim)
        print(f"Tau: {tau} ED: {mean_effective_dim}")

    plt.plot(x_positions, all_effective_dims, label=group_name, marker="o")

plt.xticks(ticks=x_positions, labels=all_taus)
plt.xlabel("Tau")
plt.ylabel("Effective dim")
plt.legend()
filename = "assets/grouped_layers.png"
# savefig does not create missing parent directories.
os.makedirs(os.path.dirname(filename), exist_ok=True)
fig.savefig(filename)
print(filename)

############ MODEL ACC VS EFFECTIVE DIM
