import matplotlib.pyplot as plt
from nesim.utils.json_stuff import load_json_as_dict
from nesim.utils.figure.figure_1 import apply_ratan_matplotlib_thing


import argparse

# Command-line interface: one optional boolean switch. When `--marchenko` is
# given, the Marchenko result files are loaded and plotted instead of the
# default ones (see the branch at the bottom of this script).
parser = argparse.ArgumentParser(
    description="Example script with --marchenko flag",
)
parser.add_argument(
    "--marchenko",
    required=False,
    action="store_true",
    help="Use Marchenko flag",
)
args = parser.parse_args()

# Line colour per topography scale (looked up as colors[model["topo_scale"]]),
# plus a black line for the baseline model. Numeric keys match by value, so a
# scale of 10 or 10.0 resolves to the same entry.
_TOPO_SCALE_COLOR_PAIRS = (
    (50.0, "#66B2A6"),
    (10, "#FEE0D2"),
    (5, "#FF9966"),
    (1, "#FB6A4A"),
    (0.5, "#CB181D"),
    ("baseline", "#000000"),
)
colors = dict(_TOPO_SCALE_COLOR_PAIRS)
# Alternative blue palette, kept for reference only (not used):
# colors = { 
#     50.0: "#66B2A6",  # Stronger pastel green
#     10: "#4DA6D8",    # Rich teal
#     5: "#0099CC",     # Vibrant blue
#     1: "#007ACC",     # Deep pastel blue
#     0.5: "#005DA2",   # Darker blue
#     "baseline": "#888888"  # Neutral grey for baseline
# }

all_layer_names = load_json_as_dict(
    "/home/XXXX-4/repos/nesim/training/imagenet/resnet18/layer_names.json"
)
def plot_effective_dimensionality(data, filename="result.png"):
    """Plot per-layer effective dimensionality, one subplot per topography mode.

    Two vertically stacked subplots share the x axis: ``all_topo`` (top) and
    ``end_topo`` (bottom). A model is drawn on a subplot when the mode name is
    a substring of its ``model_name`` or its name contains ``"baseline"``; its
    line colour is ``colors[model["topo_scale"]]``. On the ``end_topo``
    subplot, dashed light-grey vertical lines mark every layer that appears in
    ``all_layer_names["last_conv_layers_in_each_block"]``.

    A single figure-level legend is built from the lines of the bottom
    subplot and labelled by topography scale (tau), or ``"baseline"``.

    Args:
        data: Sequence of model records; each is read for ``"model_name"``,
            ``"topo_scale"`` and ``"effective_dims"`` (a list of dicts with
            ``"layer_name"`` and ``"ed"`` keys).
        filename: Path the figure is saved to before it is shown.
    """
    apply_ratan_matplotlib_thing()
    fontsize = 18

    # Two subplots, stacked vertically, sharing the layer (x) axis.
    fig, axs = plt.subplots(2, 1, figsize=(10, 12), sharex=True)
    modes = ["all_topo", "end_topo"]

    lines = []   # legend handles, taken from the last subplot only
    labels = []  # legend labels matching `lines`

    for idx, (ax, mode) in enumerate(zip(axs, modes)):
        is_legend_axis = idx == len(modes) - 1
        # Category positions (x indices) of block-final conv layers, gathered
        # over all models on this axis so each marker is drawn exactly once.
        marked_layer_indices = set()

        for model in data:
            model_name = model["model_name"]
            if mode not in model_name and "baseline" not in model_name:
                continue

            layer_names = [entry["layer_name"] for entry in model["effective_dims"]]
            effective_dims = [entry["ed"] for entry in model["effective_dims"]]

            (line,) = ax.plot(
                layer_names,
                effective_dims,
                label=model_name,
                marker="o",
                c=colors[model["topo_scale"]],
            )

            if "end_topo" in mode:
                topo_layer_names = all_layer_names["last_conv_layers_in_each_block"]
                marked_layer_indices.update(
                    layer_index
                    for layer_index, layer_name in enumerate(layer_names)
                    if layer_name in topo_layer_names
                )

            if is_legend_axis:
                topo_scale = model["topo_scale"]
                lines.append(line)
                labels.append(
                    "baseline" if topo_scale == "baseline" else f"$\\tau = {topo_scale}$"
                )

        # Drawn after the data lines so the markers sit on top of them; one
        # artist per position rather than one duplicate per model.
        for layer_index in sorted(marked_layer_indices):
            ax.axvline(x=layer_index, c="#D3D3D3", linestyle="--")

        ax.set_ylabel("Effective Dimensionality", fontsize=fontsize)
        ax.set_ylim(0, 90)
        ax.set_title(mode.replace("_", " ").title(), fontsize=fontsize)
        ax.tick_params(axis="x", rotation=90)
        ax.spines["top"].set_visible(False)
        ax.spines["right"].set_visible(False)

    # Same font size as the y labels and titles.
    axs[-1].set_xlabel("Layer", fontsize=fontsize)

    # A single legend shared by both subplots, anchored in the right margin.
    fig.legend(lines, labels, loc="lower right", bbox_to_anchor=(1.0, 0.2), title="Model Name")

    # rect leaves the right 15% of the figure free for the legend.
    plt.tight_layout(rect=[0, 0, 0.85, 1])
    plt.savefig(filename)
    plt.show()


def remove_explained_var(data):
    """Drop ``num_components_explaining_most_of_variance`` from every layer entry.

    Walks each record's ``"effective_dims"`` list and removes that field from
    every per-layer dict **in place**. The same ``data`` object is returned
    for convenience. Entries that do not carry the field are left untouched.
    Calling this twice, or on results produced without that field, is
    therefore safe.

    Args:
        data: Sequence of records, each with an ``"effective_dims"`` sequence
            of dicts.

    Returns:
        The input ``data`` object, mutated.
    """
    for record in data:
        for layer_entry in record["effective_dims"]:
            # pop() with a default instead of `del`: a missing key is not an error.
            layer_entry.pop("num_components_explaining_most_of_variance", None)
    return data

# Import the submodule explicitly: a bare `import scipy` does not guarantee
# that `scipy.io` is loaded on older SciPy releases.
import scipy.io
from pathlib import Path

# Select the result set: the "_marchenko" files when --marchenko is passed,
# otherwise the default ones. The JSON is exported to .mat (with the
# explained-variance field stripped in place) and then plotted.
_result_suffix = "_marchenko" if args.marchenko else ""

data = load_json_as_dict(f"results{_result_suffix}.json")
scipy.io.savemat(
    f"results{_result_suffix}.mat",
    {"data": remove_explained_var(data)},
)

_figure_path = Path(f"assets/effective_dim{_result_suffix}.png")
# savefig does not create missing directories.
_figure_path.parent.mkdir(parents=True, exist_ok=True)
plot_effective_dimensionality(data=data, filename=str(_figure_path))
