from nesim.utils.json_stuff import dict_to_json, load_json_as_dict
import matplotlib.pyplot as plt
from tqdm import tqdm
from pprint import pprint

intermediate_layer_configs = load_json_as_dict("intermediate_layer_configs.json")

# Display name (used in progress bars) -> key used both in
# intermediate_layer_configs.json and in the results/<key>_<layer>.json filenames.
method_names = {"TDANN": "tdann", "Ours": "ours"}

final_plot_results = {"tdann": [], "ours": []}

# Summary metric name -> per-fold list in the results JSON whose last element
# (the final epoch's value) is averaged across folds.
_FOLD_METRIC_KEYS = {
    "train_loss": "train_losses",
    "validation_loss": "validation_losses",
    "validation_correlation_score": "correlation_scores",
}


def _mean_final_epoch_metrics(fold_results, source):
    """Average each metric's last recorded value across all folds.

    Args:
        fold_results: Sequence with one entry per fold; each entry maps the
            keys in ``_FOLD_METRIC_KEYS`` values to a list of per-epoch values.
            (Assumed to be a JSON list, since the folds are iterated in order —
            TODO confirm against the script that writes results/*.json.)
        source: Filename used only to make the error message actionable.

    Returns:
        Dict with keys "train_loss", "validation_loss" and
        "validation_correlation_score", each the mean of the final values.

    Raises:
        ValueError: If ``fold_results`` contains no folds.
        KeyError: If a fold is missing one of the expected metric lists.
    """
    if not fold_results:
        raise ValueError(f"{source} contains no folds; cannot average metrics")
    num_folds = len(fold_results)
    return {
        summary_key: sum(fold[fold_key][-1] for fold in fold_results) / num_folds
        for summary_key, fold_key in _FOLD_METRIC_KEYS.items()
    }


for name, method_key in method_names.items():

    configs = intermediate_layer_configs[method_key]

    for config in tqdm(configs, desc=f"Plotting {name}"):
        result_json_filename = f"results/{method_key}_{config['layer_name']}.json"
        result = load_json_as_dict(result_json_filename)

        mean_across_fold_results = _mean_final_epoch_metrics(
            result, result_json_filename
        )
        final_plot_results[method_key].append(
            {"layer_name": config["layer_name"], "results": mean_across_fold_results}
        )

pprint(final_plot_results)
dict_to_json(final_plot_results, "results.json")

import numpy as np


def plot_correlation_scores(
    data_dict,
    filename=None,
    *,
    title="Brain model performance\n(murty185 + resnet18 + convmapper)",
):
    """Draw a grouped bar chart of per-layer validation correlation scores.

    One bar pair is drawn per layer position: the "tdann" entry on the left
    and the "ours" entry on the right. Bars are paired by index, so both lists
    must describe the same layers in the same order. The x-axis tick labels
    come from the "tdann" entries, with any "base_model." substring removed.

    Args:
        data_dict: Mapping with keys "tdann" and "ours", each a list of
            ``{"layer_name": str, "results": {"validation_correlation_score": ...}}``
            entries.
        filename: If truthy, the figure is saved there (missing parent
            directories are created when it is a path) and then closed;
            otherwise the figure is shown with ``plt.show()``.
        title: Plot title; defaults to the previously hard-coded title.

    Raises:
        KeyError: If "tdann"/"ours" or an expected per-entry key is missing.
        ValueError: If the two methods report a different number of layers.
    """
    import os
    from pathlib import Path

    tdann_data = data_dict["tdann"]
    ours_data = data_dict["ours"]

    # Only the "tdann" names are used as tick labels for both bar series.
    layer_labels = [
        entry["layer_name"].replace("base_model.", "") for entry in tdann_data
    ]
    correlation_scores_tdann = [
        entry["results"]["validation_correlation_score"] for entry in tdann_data
    ]
    correlation_scores_ours = [
        entry["results"]["validation_correlation_score"] for entry in ours_data
    ]

    # Bars are paired by index. Without this check a mismatch either fails
    # deep inside matplotlib, or—when one side has exactly one layer—is
    # silently broadcast into a row of identical, misleading bars.
    if len(correlation_scores_tdann) != len(correlation_scores_ours):
        raise ValueError(
            "tdann and ours must report the same number of layers, got "
            f"{len(correlation_scores_tdann)} and {len(correlation_scores_ours)}"
        )

    bar_width = 0.35
    x = np.arange(len(layer_labels))

    # Dedicated figure so the bars are not drawn onto whatever pyplot figure
    # happens to be current.
    fig, ax = plt.subplots()
    ax.bar(
        x - bar_width / 2, correlation_scores_tdann, bar_width, label="tdann", alpha=0.7
    )
    ax.bar(
        x + bar_width / 2,
        correlation_scores_ours,
        bar_width,
        label="ours",
        alpha=0.7,
        color="green",
    )

    # Rotated, right-aligned labels so long layer names stay readable.
    ax.set_xticks(x)
    ax.set_xticklabels(layer_labels, rotation=45, ha="right")
    ax.set_xlabel("Layer Name")
    ax.set_ylabel("Correlation Score")
    ax.set_title(title)

    # Anchor point at 1.5x the axes height places the legend above the plot
    # area (presumably to keep it clear of the bars).
    ax.legend(loc="upper left", bbox_to_anchor=(0.0, 1.5))
    ax.grid()
    fig.tight_layout()

    if filename:
        # savefig does not create directories; file-like targets are passed through.
        if isinstance(filename, (str, os.PathLike)):
            Path(filename).parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(filename, bbox_inches="tight")
        plt.close(fig)
    else:
        plt.show()


# Render the aggregated per-layer correlation scores and write the chart to
# disk; passing a filename makes plot_correlation_scores save instead of show.
correlation_plot_path = "figures/correlation_scores.png"
plot_correlation_scores(data_dict=final_plot_results, filename=correlation_plot_path)
