from nesim.utils.json_stuff import load_json_as_dict, dict_to_json
import matplotlib.pyplot as plt
import argparse
from parsing import parse_eval_results
from nesim.utils.figure.figure_1 import apply_ratan_matplotlib_thing
from nesim.utils.folder import get_filenames_in_a_folder
import os

# Folder (resolved against the current working directory) that holds one
# ``<label>.json`` evaluation-result file per model.
RESULTS_FOLDER: str = "./results"

def load_results(
    results_folder,
    labels=("topo_1", "topo_5", "topo_10", "topo_50", "baseline", "untrained"),
):
    """Load per-model evaluation results from ``<label>.json`` files.

    Only files whose extension is exactly ``.json`` and whose stem is one of
    *labels* are read (with ``load_json_as_dict``). Other files in the folder
    are ignored instead of being parsed as JSON.

    Args:
        results_folder: folder containing one ``<label>.json`` file per model.
        labels: model names to return, in the order the keys of the returned
            dict should have.

    Returns:
        dict mapping each label in *labels* (in that order) to its loaded
        results, or to an empty list when no ``<label>.json`` file exists.
    """
    wanted = set(labels)
    loaded = {}
    for filename in get_filenames_in_a_folder(results_folder):
        # splitext only strips the trailing extension, unlike
        # str.replace(".json", "") which removes the substring anywhere.
        stem, extension = os.path.splitext(os.path.basename(filename))
        if extension != ".json" or stem not in wanted:
            continue
        loaded[stem] = load_json_as_dict(filename)

    # Missing models map to an empty list so callers can iterate uniformly.
    return {label: loaded.get(label, []) for label in labels}

# Evaluation results, loaded once when this module is executed.
results = load_results(RESULTS_FOLDER)

# Function to plot the results showing the "rise in loss"
def _rise_in_loss_series(entries, compression_type, topo_type):
    """Extract the points plotted for one model and one compression type.

    Args:
        entries: list of result dicts for one model; each entry is read via
            its ``"compression_type"``, ``"factor"`` and ``"result"`` keys.
        compression_type: the ``"compression_type"`` value to keep
            (e.g. ``"l1"`` or ``"downsampling"``).
        topo_type: the model's label. Entries of the ``"untrained"`` model are
            kept regardless of their compression type.

    Returns:
        ``(factors, rise_in_loss)``: the factors of the kept entries (in file
        order) and each kept entry's ``result`` minus the ``result`` of the
        first kept entry. Both lists are empty when nothing is kept.
    """
    factors = []
    losses = []
    for entry in entries:
        if entry["compression_type"] != compression_type and topo_type != "untrained":
            continue
        # Factor-0 entries are skipped, so the reference for the rise in loss
        # is the first *remaining* entry.
        # NOTE(review): the y-label says "compared to no sparsity" and the
        # first x tick is labelled "no sparsity"; confirm against the result
        # files that the first non-zero-factor entry is that reference.
        if entry["factor"] == 0:
            continue
        factors.append(entry["factor"])
        losses.append(entry["result"])

    if not losses:
        return [], []

    initial_loss = losses[0]
    return factors, [loss - initial_loss for loss in losses]


def _plot_rise_in_loss_panel(ax, results, compression_type, topo_types, title, xlabel):
    """Draw one "rise in loss" subplot on *ax*.

    One line is drawn per model in *results* whose label is in *topo_types*
    and which has at least one matching entry. Lines are drawn in the
    iteration order of *results*, so that order also sets the legend order.
    Points are placed at x = 0, 1, 2, ... (their index), not at their factor
    value. The tick labels below are hard-coded.
    """
    for topo_type, entries in results.items():
        if topo_type not in topo_types:
            continue
        factors, rise_in_loss = _rise_in_loss_series(entries, compression_type, topo_type)
        if rise_in_loss:
            ax.plot(range(len(factors)), rise_in_loss, marker="o", label=str(topo_type))

    ax.set_title(title)
    ax.set_xlabel(xlabel)
    ax.set_ylabel("Rise in Loss (compared to no sparsity)")
    ax.legend()

    tick_factors = [0, 0.2, 0.4, 0.6, 0.8]
    ax.set_xticks(range(len(tick_factors)))
    ax.set_xticklabels(["no sparsity"] + [round(x, 3) for x in tick_factors[1:]])

    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)


def plot_results(
    results,
    save_path="perplexity_vs_fraction_of_masked_weights.pdf",
    show=True,
):
    """Plot the rise in loss under L1 masking and under downsampling.

    Draws two side-by-side subplots: L1 compression on the left and
    downsampling on the right. The figure is written to *save_path* and then,
    if *show* is true, displayed.

    Args:
        results: mapping from model label (e.g. ``"topo_1"``, ``"baseline"``)
            to a list of entry dicts with ``"compression_type"``,
            ``"factor"`` and ``"result"`` keys.
        save_path: file the figure is saved to.
        show: whether to call ``plt.show()`` after saving.
    """
    fig, axs = plt.subplots(1, 2, figsize=(12, 6))  # 1 row, 2 columns

    # The same set of models is plotted in both panels.
    model_labels = {"baseline", "topo_1", "topo_5", "topo_10", "topo_50"}

    _plot_rise_in_loss_panel(
        axs[0],
        results,
        compression_type="l1",
        topo_types=model_labels,
        title="L1 (Rise in Loss)",
        xlabel="Fraction of masked weights in topo layers",
    )
    _plot_rise_in_loss_panel(
        axs[1],
        results,
        compression_type="downsampling",
        topo_types=model_labels,
        title="Downsampling ",
        xlabel="Fraction of weights removed from topo layers",
    )

    plt.tight_layout()

    # Save before showing, as matplotlib recommends: a blocking show() closes
    # the figure once its window is dismissed.
    fig.savefig(save_path)
    if show:
        plt.show()

# Draw and save the figure only when this file is run as a script. Importing
# the module then does not open a plot window or write the PDF.
if __name__ == "__main__":
    plot_results(results)
