import os

import matplotlib.pyplot as plt

from nesim.utils.json_stuff import load_json_as_dict
from nesim.utils.model_info import convert_number_to_human_readable

# Font size shared by tick labels, axis labels, subplot titles and the legend.
fontsize = 21

# Line colour for each tau value parsed from a model name (float keys),
# plus black for the baseline model.
colors = {
    tau_value: hex_code
    for tau_value, hex_code in (
        (50.0, "#FEE0D2"),
        (10.0, "#FF9966"),
        (5.0, "#FB6A4A"),
        (1.0, "#CB181D"),
        ("baseline", "#000000"),
    )
}

# Experiment configurations to plot. "model_prefix" and "sparsify_layers" are
# combined into the results filename and the subplot titles; "label" is a
# human-readable description of the configuration.
modes = [
    {"model_prefix": prefix, "sparsify_layers": layers, "label": description}
    for prefix, layers, description in (
        ("end_topo", "end", "End topo\nSparsify end layers only"),
        ("all_topo", "all", "All topo\nSparsify all layers"),
    )
]

# Create one figure per mode and save it to assets/ as a PDF. Each figure has
# one row per compression type; within a row, every model in the mode's
# results file is drawn as a line of validation accuracy per compression step,
# coloured by its tau value (black for the baseline).
compression_types = ["l1", "downsampling"]
y_values = [0, 0.25, 0.5, 0.75, 1]
compression_factors_x_axis = [1, 5, 10, 15, 20]

# fig.savefig does not create missing directories.
os.makedirs("assets", exist_ok=True)

for mode in modes:
    run_name = f'{mode["model_prefix"]}_sparsify_{mode["sparsify_layers"]}'
    fig, ax = plt.subplots(nrows=len(compression_types), ncols=1, figsize=(10, 15))
    results = load_json_as_dict(f"results_{run_name}.json")

    for row, compression_type in enumerate(compression_types):
        for model_name, model_entry in results.items():
            # Presumably a list of dicts with "downsample_factor" and "val_acc"
            # keys (the two keys read below) -- TODO confirm against the writer.
            single_model_results = model_entry[compression_type]

            downsample_factors = [x["downsample_factor"] for x in single_model_results]
            val_accs = [x["val_acc"] for x in single_model_results]

            if model_name.startswith("baseline"):
                tau = "baseline"
                label = "baseline"
            else:
                # The fourth "_"-separated field of the model name holds tau.
                tau = float(model_name.split("_")[3])
                # ":g" prints 50.0 as "50" but keeps non-integer taus intact.
                label = f"$\\tau = {tau:g}$"

            # Every line is labelled: each mode has its own figure, so lines
            # from the "all_topo" mode would otherwise be missing from the legend.
            # NOTE(review): points sit at indices 0..n-1 while the ticks below are
            # placed at factor - 1, which assumes downsample_factors == 1, 2, ..., n.
            ax[row].plot(
                range(len(downsample_factors)),
                val_accs,
                marker="o",
                color=colors[tau],
                label=label,
            )

        # Y axis: accuracy in [0, 1].
        ax[row].set_yticks(ticks=y_values, labels=[f"{y}" for y in y_values], fontsize=fontsize)
        ax[row].set_ylim(0, 1)

        # X axis: tick at index factor - 1, labelled "<factor>x".
        ax[row].set_xticks(
            [x - 1 for x in compression_factors_x_axis],
            labels=[f"{x}x" for x in compression_factors_x_axis],
            fontsize=fontsize,
        )

        # Only the bottom row carries the shared x-axis label.
        if row == len(compression_types) - 1:
            ax[row].set_xlabel("Sparsity", fontsize=fontsize)
        ax[row].set_ylabel("Validation Acc.", fontsize=fontsize)

        ax[row].spines['top'].set_visible(False)
        ax[row].spines['right'].set_visible(False)
        ax[row].grid()
        title = run_name.replace("_", " ") + f" ({compression_type})"
        ax[row].set_title(title, fontsize=fontsize)

    # Legend on the bottom subplot only; both rows plot the same models with
    # the same colours.
    ax[-1].legend(fontsize=fontsize)
    fig.tight_layout()

    output_filename = f"assets/results_{run_name}.pdf"
    fig.savefig(output_filename)
    print(f"Saved: {output_filename}")
    # Release the figure so figures do not accumulate across loop iterations.
    plt.close(fig)
