import matplotlib.pyplot as plt
from nesim.utils.json_stuff import load_json_as_dict
import os

results_json_folder = "./results"
figures_folder = "./figures"

# Models whose per-layer correlation results are loaded and compared.
model_names = ["baseline", "topo"]

# Layer identifiers of the form "layer{S}.{B}.conv{C}" for S in 1..4,
# B in 0..1 and C in 1..2, ordered S-major then B then C
# (layer1.0.conv1, layer1.0.conv2, layer1.1.conv1, ..., layer4.1.conv2).
# This order is also the x-axis order of the plot.
layer_names = [
    f"layer{stage}.{block}.conv{conv}"
    for stage in range(1, 5)
    for block in range(2)
    for conv in range(1, 3)
]

# Per-model list of the "pearson_correlation" value read from each layer's
# result file, in the same order as layer_names.
# Keys are derived from model_names (rather than hard-coded) so that adding a
# model to model_names does not raise a KeyError on the append below.
size_animacy_corr_scores = {model_name: [] for model_name in model_names}
for model_name in model_names:
    for layer_name in layer_names:
        # Each result file is expected at "<results>/<model>_<layer>.json".
        result = load_json_as_dict(
            os.path.join(results_json_folder, f"{model_name}_{layer_name}.json")
        )
        size_animacy_corr_scores[model_name].append(result["pearson_correlation"])

# Plot one line per model: the correlation score at each layer, with the x axis
# labelled by layer name in the order of layer_names.
fig = plt.figure()
plt.title("Correlation between Size and Animacy\n Small -> Animate\nBig -> Inanimate")
for model_name, scores in size_animacy_corr_scores.items():
    plt.plot(scores, label=model_name)
plt.legend()
plt.grid()
plt.xlabel("Layers ->")
plt.ylabel("Correlation")
# plt.plot places the i-th score at x=i, so tick i carries layer_names[i].
plt.xticks(
    ticks=list(range(len(layer_names))),
    labels=layer_names,
    rotation=90,
)
plt.ylim(-0.5, 0.5)
plt.tight_layout()

# savefig does not create missing directories; make sure the output folder exists.
os.makedirs(figures_folder, exist_ok=True)
pdf_file_path = os.path.join(figures_folder, "correlation_across_layers.pdf")
fig.savefig(pdf_file_path, format="pdf")