from nesim.utils.json_stuff import load_json_as_dict, dict_to_json
import matplotlib.pyplot as plt
import argparse
from parsing import parse_eval_results
from nesim.utils.figure.figure_1 import apply_ratan_matplotlib_thing
from nesim.utils.folder import get_filenames_in_a_folder
import os

# Folder read by load_results() (relative to the current working directory).
# It must contain a <label>.json file for each of topo_1, topo_5, topo_10,
# topo_50, baseline and untrained.
RESULTS_FOLDER = "./results"

def load_results(results_folder, labels=None):
    """Load the evaluation results stored in ``results_folder`` in a fixed label order.

    Each file in the folder is keyed by its base name with a trailing
    ``.json`` removed (``<folder>/topo_5.json`` -> ``"topo_5"``). Only the
    files whose key appears in ``labels`` are actually read, so unrelated
    files in the folder are ignored.

    Args:
        results_folder: Folder containing one JSON results file per model variant.
        labels: Keys to load, in the order they should appear in the returned
            dict. Defaults to ``topo_1``, ``topo_5``, ``topo_10``, ``topo_50``,
            ``baseline``, ``untrained``.

    Returns:
        dict mapping each label to the dict loaded from its JSON file; the
        insertion order matches ``labels``.

    Raises:
        KeyError: If a requested label has no matching file in the folder.
    """
    if labels is None:
        labels = ("topo_1", "topo_5", "topo_10", "topo_50", "baseline", "untrained")

    # Index the folder by key first; nothing is parsed yet.
    paths_by_key = {}
    for filename in get_filenames_in_a_folder(results_folder):
        name = os.path.basename(filename)
        # Strip only a trailing ".json"; str.replace would also mangle names
        # that merely contain ".json" elsewhere (e.g. "a.json.bak" -> "a.bak").
        key = name[: -len(".json")] if name.endswith(".json") else name
        paths_by_key[key] = filename

    missing = [label for label in labels if label not in paths_by_key]
    if missing:
        raise KeyError(f"no results file for label(s) {missing} in {results_folder!r}")

    # Parse only the files that are requested, in the requested order.
    return {label: load_json_as_dict(paths_by_key[label]) for label in labels}

# Load the results once at import time, keyed by model label.
results = load_results(results_folder=RESULTS_FOLDER)

# Display name for each compression technique.
compression_technique_labels = {
    "downsampling": "Downsampling",
    "l1": "L1 Sparsity",
}

# Techniques to plot, in subplot order (top row first).
compression_techniques = ["l1", "downsampling"]

def _compression_tick_labels(factors) -> list:
    """Return string tick labels for ``factors``, with "x" appended to the last one.

    Only the final tick carries the "x" suffix, so the unit is not repeated on
    every tick. An empty input yields an empty list.
    """
    labels = [f"{factor}" for factor in factors]
    if labels:
        labels[-1] += "x"
    return labels


def plot_results(parsed_data: "list[dict]", filename: "str | None", fontsize=18, *, techniques=None):
    """Plot each series against compression factor, one subplot row per technique.

    Each entry of ``parsed_data`` is a dict with at least the keys ``"x"``
    (compression factors), ``"y"`` (values to plot), ``"label"`` and
    ``"color"``. An entry is drawn in the row whose technique equals the last
    space-separated word of its label, compared case-insensitively.

    Args:
        parsed_data: Series to plot (e.g. the output of ``parse_eval_results``
            in this script).
        filename: Path to save the figure to; ``None`` skips saving. The format
            is taken from the file extension when matplotlib supports it,
            otherwise the figure is saved as PDF.
        fontsize: Legend font size; tick labels use ``fontsize + 4``.
        techniques: Techniques to plot, one row each, top to bottom. Defaults
            to the module-level ``compression_techniques``.

    Returns:
        ``(fig, axes)``: the figure and a 1-D array with one Axes per row.

    Raises:
        ValueError: If ``parsed_data`` or ``techniques`` is empty.
    """
    if techniques is None:
        techniques = compression_techniques
    if not parsed_data:
        raise ValueError("parsed_data is empty; nothing to plot")
    if not techniques:
        raise ValueError("no compression techniques to plot")

    apply_ratan_matplotlib_thing()
    # squeeze=False keeps a 2-D array even for a single row, so axes[row] always works.
    fig, axes_grid = plt.subplots(
        nrows=len(techniques), ncols=1, figsize=(8, 9.5 * len(techniques)), squeeze=False
    )
    axes = axes_grid[:, 0]

    # All rows share one x-axis: positions 0..n-1 labelled with the first
    # series' compression factors. Assumes every series was evaluated at these
    # same factors -- TODO confirm against parse_eval_results.
    shared_factors = parsed_data[0]["x"]
    tick_labels = _compression_tick_labels(shared_factors)

    for row, technique in enumerate(techniques):
        ax = axes[row]
        for data in parsed_data:
            if data["label"].split(" ")[-1].lower() != technique:
                continue
            ax.plot(
                # Evenly spaced positions instead of the raw factors (presumably so
                # every factor gets equal room); one position per y value keeps
                # x and y the same length even for a shorter series.
                range(len(data["y"])),
                data["y"],
                # Shorten the legend entry by dropping " l1", "Topo " and parentheses
                # (assumes labels like "Topo (5) l1" -- TODO confirm).
                label=data["label"].replace(" l1", "").replace("Topo ", "").replace("(", "").replace(")", ""),
                c=data["color"],
                marker="o",
            )

        ax.set_xticks(range(len(shared_factors)), labels=tick_labels, fontsize=fontsize + 4)
        ax.set_ylim(top=11.5)
        ax.tick_params(labelsize=fontsize + 4, axis="y")
        ax.spines["top"].set_visible(False)
        ax.spines["right"].set_visible(False)

        # Legend only on the top row, placed automatically by matplotlib (loc="best").
        if row == 0:
            ax.legend(loc="best", fontsize=fontsize)

    if filename is not None:
        # Honour the extension (e.g. ".png") when matplotlib supports it; otherwise
        # fall back to PDF, the format this function always used before.
        extension = os.path.splitext(filename)[1].lstrip(".").lower()
        save_format = extension if extension in fig.canvas.get_supported_filetypes() else "pdf"
        fig.savefig(filename, bbox_inches="tight", format=save_format, dpi=300)

    return fig, axes


# Turn the raw results into per-series plotting data (max_num_compression_factors=8).
parsed_data = parse_eval_results(
    results=results,
    max_num_compression_factors=8,
)

# Keep a JSON copy of the parsed series, then render the figure; both files are
# written relative to the current working directory.
dict_to_json(dictionary=parsed_data, filename="parsed_results.json")
plot_results(
    parsed_data=parsed_data,
    filename="compression_vs_loss_openwebtext.pdf",
    fontsize=21,
)