from nesim.utils.json_stuff import load_json_as_dict
import matplotlib.pyplot as plt
from nesim.utils.figure.figure_1 import apply_ratan_matplotlib_thing
# Project helper called once before any figure is created
# (presumably configures matplotlib styling -- confirm in nesim.utils.figure.figure_1).
apply_ratan_matplotlib_thing()

# Per-model effective-dimensionality records, read relative to the current
# working directory.
results = load_json_as_dict("results.json")

# Validation accuracies keyed by model name.
acc_results = load_json_as_dict(
    "/home/XXXX-4/repos/nesim/experiments/imagenet/resnet18/validation_acc/result_simple.json"
)
# The "eshed" accuracy is set (or overridden) by hand rather than taken from
# the JSON file.
acc_results.update(eshed=0.439)

# Named groups of layer names used to select which layers contribute to the
# mean effective dimensionality.
#
# "all" enumerates, for stages layer1..layer4 with two blocks (0 and 1) each,
# the entries "conv1" and "conv2"; the first block of stages 2-4 additionally
# contributes "downsample.0".  The generated order is stage-major, then
# block, then conv1 / conv2 / downsample.0.
layer_groups = {
    "all": [
        f"layer{stage}.{block}.{suffix}"
        for stage in (1, 2, 3, 4)
        for block in (0, 1)
        for suffix in (
            ("conv1", "conv2", "downsample.0")
            if stage > 1 and block == 0
            else ("conv1", "conv2")
        )
    ],
}
def _tau_from_model_name(model_name):
    """Return the tau value encoded in a supported model name.

    ``"eshed"`` maps to the string ``"eshed"``, names starting with
    ``"baseline"`` map to ``0``, and names starting with ``"all_topo"`` carry
    tau as their fourth underscore-separated field (for example
    ``"all_topo_scale_30_shrink_factor_3.0"`` yields ``30``).
    """
    if model_name == "eshed":
        return "eshed"
    if model_name.startswith("baseline"):
        return 0

    tau = float(model_name.split("_")[3])
    # Whole-number taus above 1 are reported as ints (30 rather than 30.0).
    # Fractional values such as 2.5 are kept as floats instead of being
    # truncated by int().
    if tau > 1 and tau.is_integer():
        return int(tau)
    return tau


def process_model_data(results, layer_names, acc_results, prefix="thing"):
    """
    Compute the mean effective dimensionality, tau, accuracy and plot label
    for every supported model in ``results``.

    Only models whose name is ``"eshed"`` or starts with ``"all_topo"`` or
    ``"baseline"`` are processed; every other entry is skipped.  For
    ``"eshed"`` the ``"base_model."`` substring is removed from layer names
    before they are matched against ``layer_names``.

    Args:
        results (list): List of dictionaries, each with a ``"model_name"``
            string and an ``"effective_dims"`` list of dictionaries holding
            ``"layer_name"`` and ``"ed"``.
        layer_names (iterable): Layer names to include in the mean.
        acc_results (dict): Dictionary mapping model names to their accuracy
            values.
        prefix (str): Text placed at the start of every label.

    Returns:
        all_effective_dims (list): Mean effective dimensionality per model.
        all_taus (list): Tau per model (``0`` for baselines, ``"eshed"`` for
            the eshed model).
        accuracies (list): Accuracy per model, looked up in ``acc_results``.
        labels (list): ``"<prefix> tau <tau>"`` label per model.

    Raises:
        ValueError: If a processed model has no layer listed in
            ``layer_names`` (the mean would otherwise divide by zero).
        KeyError: If a processed model is missing from ``acc_results``.
    """
    # Build the lookup once; membership is then O(1) for every layer.
    wanted_layers = frozenset(layer_names)

    all_effective_dims = []
    all_taus = []
    accuracies = []
    labels = []

    for data in results:
        model_name = data["model_name"]

        # Guard clause: ignore models this analysis does not cover.
        is_eshed = model_name == "eshed"
        if not is_eshed and not model_name.startswith(("all_topo", "baseline")):
            continue

        tau = _tau_from_model_name(model_name)

        selected_dims = []
        for layer_record in data["effective_dims"]:
            layer_name = layer_record["layer_name"]
            if is_eshed:
                # eshed layer names carry a "base_model." prefix that the
                # reference layer names do not have.
                layer_name = layer_name.replace("base_model.", "")
            if layer_name in wanted_layers:
                selected_dims.append(layer_record["ed"])

        if not selected_dims:
            raise ValueError(
                f"Model {model_name!r} has no layers matching layer_names; "
                "cannot compute a mean effective dimensionality."
            )

        mean_effective_dim = sum(selected_dims) / len(selected_dims)
        accuracy = acc_results[model_name]

        all_effective_dims.append(mean_effective_dim)
        all_taus.append(tau)
        accuracies.append(accuracy)
        labels.append(f"{prefix} tau {tau}")

        # Print the results for debugging
        print(f"Tau: {tau} ED: {mean_effective_dim} Accuracy: {accuracy}")

    return all_effective_dims, all_taus, accuracies, labels


# ResNet-18: mean effective dimensionality over every layer in
# layer_groups["all"], paired with the validation accuracies loaded above.
all_effective_dims, all_taus, accuracies, labels = process_model_data(
    results=results,
    layer_names=layer_groups["all"],
    acc_results=acc_results,
    prefix="rn18",
)

# ResNet-50: same computation on the resnet50 result files.
# Validation accuracies noted by the original author
# (presumably for these ResNet-50 runs -- confirm):
#   Model: baseline_scale_None_shrink_factor_3.0 Val acc Top 1: 0.74288
#   Model: all_topo_scale_1_shrink_factor_3.0 Val acc Top 1: 0.7129
#   Model: all_topo_scale_30_shrink_factor_3.0 Val acc Top 1: 0.62244
layer_names_rn50 = load_json_as_dict(
    "/home/XXXX-4/repos/nesim/training/imagenet/resnet50/layer_names.json"
)["all_conv_layers_except_first"]

all_effective_dims_rn50, all_taus_rn50, accuracies_rn50, labels_rn50 = process_model_data(
    results=load_json_as_dict("/home/XXXX-4/repos/nesim/experiments/imagenet/resnet50/effective_dimensionality/results.json"),
    layer_names=layer_names_rn50,
    acc_results=load_json_as_dict("/home/XXXX-4/repos/nesim/experiments/imagenet/resnet50/validation_acc/result_simple.json"),
    prefix="rn50",
)

# Merge the ResNet-50 points into the ResNet-18 lists.  All four lists are
# parallel (one entry per model), so all_taus is extended as well to keep it
# index-aligned with the others.
accuracies.extend(accuracies_rn50)
all_effective_dims.extend(all_effective_dims_rn50)
all_taus.extend(all_taus_rn50)
labels.extend(labels_rn50)

fig = plt.figure()
plt.scatter(accuracies, all_effective_dims)

# Annotate every point with its "<prefix> tau <tau>" label.
for accuracy, effective_dim, label in zip(accuracies, all_effective_dims, labels):
    plt.text(accuracy, effective_dim, label, fontsize=9, ha='right')

plt.xlabel("Val accuracy")
plt.ylabel("Mean Effective dim")
fig.savefig("acc_vs_ed.pdf")