import json
import matplotlib.pyplot as plt
import numpy as np

# Load the evaluation results. JSON is UTF-8 by specification, so decode
# explicitly instead of relying on the platform's locale encoding (which is
# not UTF-8 on e.g. Windows and can mis-decode non-ASCII content).
with open('./eval_final_ffa_topo_and_tdann_25_only.json', 'r', encoding='utf-8') as f:
    data_rn = json.load(f)

with open('eval_results_tdann_3.json', 'r', encoding='utf-8') as file:
    data_tdann = json.load(file)

# Per-model result dicts passed to plot_combined_best_layers below.
# NOTE(review): the plotting code expects the ResNet entries to be nested as
# topo_mode -> checkpoint -> layer -> metrics, and the TDANN entry as
# checkpoint -> layer -> metrics; confirm against the file producers.
resnet_18_data = data_rn["resnet18"]
resnet_50_data = data_rn["resnet50"]
tdann_data = data_tdann["tdann"]

def _best_layer(layers_dict, metric, higher_is_better=True):
    """Return ``(layer_name, value)`` of the best-scoring layer for ``metric``.

    Ties keep the first layer encountered. NaN scores never win a comparison
    and are therefore ignored. If no layer has a usable score (empty dict or
    all NaN), ``(None, None)`` is returned.
    """
    best_layer = None
    best_value = None
    best_score = float('-inf')
    for layer_name, values in layers_dict.items():
        value = values[metric]
        # Negate when lower is better so a single ">" comparison works for both.
        score = value if higher_is_better else -value
        if score > best_score:
            best_layer, best_value, best_score = layer_name, value, score
    return best_layer, best_value


def _iter_checkpoints(tdann_dict, resnet18_dict, resnet50_dict):
    """Yield ``(model_key, label_prefix, layers_dict)`` in plotting order.

    TDANN checkpoints come first, then ResNet-18, then ResNet-50. The ResNet
    dicts carry an extra ``topo_mode`` nesting level, which is folded into
    the label prefix.
    """
    for checkpoint_name, layers_dict in tdann_dict.items():
        yield "TDANN", f"TDANN: {checkpoint_name}", layers_dict
    for model_key, nested in (("RN18", resnet18_dict), ("RN50", resnet50_dict)):
        for topo_mode, checkpoints_dict in nested.items():
            for checkpoint_name, layers_dict in checkpoints_dict.items():
                yield model_key, f"{model_key} {topo_mode}: {checkpoint_name}", layers_dict


def plot_combined_best_layers(tdann_dict, resnet18_dict, resnet50_dict, metric="pearson_r",
                              higher_is_better=True):
    """Plot one bar per checkpoint showing the score of its best layer.

    Args:
        tdann_dict: mapping ``checkpoint -> layer -> metrics dict``.
        resnet18_dict: mapping ``topo_mode -> checkpoint -> layer -> metrics dict``.
        resnet50_dict: same nesting as ``resnet18_dict``.
        metric: key looked up in every metrics dict.
        higher_is_better: if True (default), the best layer has the largest
            value. Pass False for error-style metrics where smaller is better.

    Bars are numbered 1..N on the x-axis. The legend maps each number to its
    model, checkpoint and winning layer. Checkpoints with no usable layer
    score (no layers, or only NaN values) are skipped.

    Side effects: opens a new figure, writes
    ``combined_best_layer_per_checkpoint_with_legend.png`` to the current
    directory, and calls ``plt.show()``.
    """
    best_values = []
    model_names = []
    legend_labels = []

    for model_key, label_prefix, layers_dict in _iter_checkpoints(
            tdann_dict, resnet18_dict, resnet50_dict):
        best_layer, best_value = _best_layer(layers_dict, metric, higher_is_better)
        if best_layer is None:
            # Nothing to plot; a -inf placeholder bar would break the axes.
            continue
        best_values.append(best_value)
        model_names.append(model_key)
        legend_labels.append(f"{label_prefix} -> Layer: {best_layer}")

    plt.figure(figsize=(30, 8))

    x_positions = np.arange(len(best_values))
    color_map = {'TDANN': 'red', 'RN18': 'skyblue', 'RN50': 'green'}
    color_list = [color_map[name] for name in model_names]

    plt.bar(x_positions, best_values, color=color_list)
    plt.grid()
    plt.xlabel("Models and Checkpoints")
    plt.ylabel(f"Best {metric.upper()} Value")
    plt.title(f"Best {metric.upper()} Layer from Each Model and Checkpoint")

    # Short numeric tick labels; the legend carries the full description.
    plt.xticks(x_positions, [str(i + 1) for i in range(len(best_values))], rotation=45, ha='right')

    # Put each value label just beyond the bar end: above positive bars,
    # below negative ones, so it never sits inside the bar.
    for idx, value in enumerate(best_values):
        plt.text(idx, value, f"{round(value, 3)}", ha='center',
                 va='bottom' if value >= 0 else 'top', fontsize=8)

    # Invisible line handles used purely as text rows in the legend.
    handles = [
        plt.Line2D([0], [0], color='none', lw=0, label=f"{number}: {label}")
        for number, label in enumerate(legend_labels, start=1)
    ]
    plt.legend(handles=handles, title="Layer Mappings", loc='upper left', bbox_to_anchor=(1, 1))

    plt.tight_layout()
    # The legend is anchored outside the axes; bbox_inches="tight" grows the
    # saved image so the legend is not clipped.
    plt.savefig("combined_best_layer_per_checkpoint_with_legend.png", bbox_inches="tight")
    plt.show()

# Render, save and display the best-layer comparison using Pearson correlation.
plot_combined_best_layers(
    tdann_dict=tdann_data,
    resnet18_dict=resnet_18_data,
    resnet50_dict=resnet_50_data,
    metric="pearson_r",
)
