from pathlib import Path

import matplotlib.pyplot as plt

from nesim.utils.figure.figure_1 import apply_ratan_matplotlib_thing
from nesim.utils.json_stuff import load_json_as_dict
from nesim.utils.model_info import convert_number_to_human_readable
# Apply the project's shared matplotlib styling before any figure is created.
apply_ratan_matplotlib_thing()

# Single font size reused for titles, tick labels, axis labels and legends.
fontsize = 21

# One row with two panels: ResNet18 goes on the left (ax1), ResNet50 on the
# right (ax2). Axes are shared so both panels use identical scales.
fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(18, 8), sharex=True, sharey=True)
ax1, ax2 = axes

# ResNet18 sparsity sweeps to plot on the left subplot (ax1). Each mode selects
# one results file named results_<model_prefix>_sparsify_<sparsify_layers>.json.
modes = [
    {
        "model_prefix": "all_topo",
        "sparsify_layers": "all",
        "label": "All topo (Sparsify all layers)",
    }
]

# Pruning criterion whose results are plotted. Defined once, before the loop:
# it is loop-invariant, and code further down the script reads this name too,
# so it must exist even if `modes` is empty.
compression_type = "l1"

for mode in modes:
    filename = f'results_{mode["model_prefix"]}_sparsify_{mode["sparsify_layers"]}.json'
    results = load_json_as_dict(filename)

    for model_name in results:
        single_model_results = results[model_name][compression_type]
        if not single_model_results:
            # Nothing to plot, and the reference lookup val_accs[0] would fail.
            continue

        fraction_of_masked_weights = [x["fraction_of_masked_weights"] for x in single_model_results]
        val_accs = [x["val_acc"] for x in single_model_results]

        # Accuracy drop relative to the first entry of the sweep.
        # NOTE(review): assumes the first entry is the unpruned/reference run — confirm.
        drop_in_acc = [val_accs[0] - acc for acc in val_accs]

        if model_name.startswith("baseline"):
            label = "rn18 Baseline"
        else:
            # Non-baseline model names carry tau as their 4th "_"-separated field.
            tau = float(model_name.split("_")[3])
            label = f"rn18 $\\tau = {tau}$"

        ax1.plot(
            fraction_of_masked_weights,
            drop_in_acc,
            marker="o",
            linestyle="-",
            label=label,
        )

# Set title for ResNet18 plot
ax1.set_title("Drop in Validation Accuracy L1 Sparsity rn18", fontsize=fontsize)

# Plot ResNet50 data on the right subplot (ax2).
# NOTE(review): machine-specific absolute path; it only resolves on the machine
# that produced the results — consider making it configurable.
rn50_results_path = (
    "/home/XXXX-4/repos/nesim/experiments/imagenet/resnet50/sparsity/results_all_topo_sparsify_all.json"
)
results = load_json_as_dict(rn50_results_path)

# Pruning criterion to plot. Set explicitly here instead of relying on the
# variable left over from the ResNet18 loop above.
compression_type = "l1"

for model_name in results:
    single_model_results = results[model_name][compression_type]
    if not single_model_results:
        # Nothing to plot, and the reference lookup val_accs[0] would fail.
        continue

    fraction_of_masked_weights = [x["fraction_of_masked_weights"] for x in single_model_results]
    val_accs = [x["val_acc"] for x in single_model_results]

    # Accuracy drop relative to the first entry of the sweep.
    # NOTE(review): assumes the first entry is the unpruned/reference run — confirm.
    drop_in_acc = [val_accs[0] - acc for acc in val_accs]

    if model_name.startswith("baseline"):
        label = "rn50 Baseline"
    else:
        # Non-baseline model names carry tau as their 4th "_"-separated field.
        tau = float(model_name.split("_")[3])
        label = f"rn50 $\\tau = {tau}$"

    ax2.plot(
        fraction_of_masked_weights,
        drop_in_acc,
        marker="o",
        linestyle="-",
        label=label,
    )

# Set title for ResNet50 plot
ax2.set_title("Drop in Validation Accuracy L1 Sparsity rn50", fontsize=fontsize)

# Y tick positions 0..1 in steps of 0.25. The y-axis is shared (sharey=True),
# so these ticks apply to both panels; only the left panel shows tick labels.
y_values = [0, 0.25, 0.5, 0.75, 1]
ax1.set_yticks(ticks=y_values)
ax1.set_yticklabels([f"{y}" for y in y_values], fontsize=fontsize)

# Figure-level x and y axis labels shared by both panels.
fig.supxlabel("Fraction of Masked Weights", fontsize=fontsize)
fig.supylabel("Drop in Accuracy (Higher = Worse)", fontsize=fontsize)

# Remove top and right spines for both subplots.
for ax in (ax1, ax2):
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)

# Add legends to both subplots.
ax1.legend(fontsize=fontsize)
ax2.legend(fontsize=fontsize)

# Adjust layout for better fit.
plt.tight_layout()

# Save the figure. savefig does not create missing directories, so ensure the
# output directory exists first.
output_filename = 'assets/l1_compression_drop_in_acc_fraction_weights_dual.pdf'
Path(output_filename).parent.mkdir(parents=True, exist_ok=True)
fig.savefig(output_filename)
print(f"Saved: {output_filename}")
