from parsing import parse_eval_results
import argparse
from nesim.utils.json_stuff import load_json_as_dict
import numpy as np
import matplotlib.pyplot as plt
from nesim.utils.figure.figure_1 import apply_ratan_matplotlib_thing
from nesim.utils.folder import get_filenames_in_a_folder
import os
from parsing import color_map

# Folder scanned by load_results for the per-run "<run name>.json" evaluation files.
RESULTS_FOLDER = "./results"

def hex_to_rgb(hex_color: str) -> tuple:
    """Convert a hex colour string to an ``(r, g, b)`` tuple of ints in 0-255.

    Accepts ``"#RRGGBB"`` / ``"RRGGBB"`` and the short form ``"#RGB"`` / ``"RGB"``
    (each short-form digit is doubled, so ``"#f80"`` is ``"#ff8800"``).

    Args:
        hex_color: the colour string, with or without a leading ``#``.

    Returns:
        A 3-tuple of ints.

    Raises:
        ValueError: if the string (after stripping ``#``) is not exactly 3 or 6
            hexadecimal digits.
    """
    digits = hex_color.lstrip("#")

    # Expand short form (#RGB) to full form (#RRGGBB).
    if len(digits) == 3:
        digits = "".join(c * 2 for c in digits)

    # Validate explicitly: otherwise a 5-digit string silently yields a truncated
    # last channel, and int(..., 16) would also accept signs/whitespace like "+f".
    if len(digits) != 6 or any(c not in "0123456789abcdefABCDEF" for c in digits):
        raise ValueError(f"expected a #RGB or #RRGGBB colour, got {hex_color!r}")

    return tuple(int(digits[i:i + 2], 16) for i in (0, 2, 4))

def load_results(
    results_folder,
    labels=("topo_1", "topo_5", "topo_10", "topo_50", "baseline", "untrained"),
):
    """Load the evaluation results of each model run from *results_folder*.

    Each file is keyed by its base name with a trailing ``.json`` removed
    (``results/topo_1.json`` -> ``"topo_1"``). Only files whose key is one of
    *labels* are parsed, and the returned dict is ordered like *labels*.

    Args:
        results_folder: folder passed to ``get_filenames_in_a_folder``.
        labels: run names to return, in the desired order.

    Returns:
        dict mapping each label to whatever ``load_json_as_dict`` returned for
        its file.

    Raises:
        KeyError: if any label has no matching file in *results_folder*.
    """
    suffix = ".json"
    wanted = set(labels)

    # Map run name -> file path. Only the trailing ".json" is stripped (the old
    # str.replace removed the substring anywhere in the name), and unrelated
    # files are never parsed, so stray files in the folder cannot break loading.
    paths = {}
    for path in get_filenames_in_a_folder(results_folder):
        name = os.path.basename(path)
        key = name[: -len(suffix)] if name.endswith(suffix) else name
        if key in wanted:
            paths[key] = path

    missing = [label for label in labels if label not in paths]
    if missing:
        raise KeyError(f"no result file for run(s) {missing} in {results_folder!r}")

    return {label: load_json_as_dict(paths[label]) for label in labels}

# Load every run's evaluation results once, at import time; the per-factor
# plotting loop at the bottom of this script passes this dict to plot_loss_increase.
results = load_results(results_folder=RESULTS_FOLDER)

# Tick label shown under each bar, keyed by run name. There is no "untrained"
# entry: plot_loss_increase only looks up the non-untrained runs here.
x_labels = dict(
    baseline="Baseline",
    topo_1="$\\tau$ = 1",
    topo_5="$\\tau$ = 5",
    topo_10="$\\tau$ = 10",
    topo_50="$\\tau$ = 50",
)

# Display name for each compression method, keyed by its compression_type value.
compression_type_labels = dict(
    downsampling="Downsampling",
    l1="L1 Sparsity",
)

def _uncompressed_loss(entries, model):
    """Return ``result`` of the first entry whose ``compression_type`` is None."""
    for item in entries:
        if item["compression_type"] is None:
            return item["result"]
    raise ValueError(
        f"no uncompressed (compression_type=None) result for model {model!r}"
    )


def _mean_compressed_loss(entries, compression_type, factor):
    """Return the mean ``result`` over entries matching both arguments, or None if none match."""
    losses = [
        item["result"]
        for item in entries
        if item["compression_type"] == compression_type and item["factor"] == factor
    ]
    return sum(losses) / len(losses) if losses else None


def plot_loss_increase(data, factor=9, filename: str = "figure.png", fontsize=23):
    """Plot the loss increase caused by compression, one panel per compression type.

    The top panel shows ``"l1"`` and the bottom panel ``"downsampling"``. Each
    model except ``"untrained"`` gets a bar. The bar height is the mean loss over
    that model's entries with the panel's compression type at *factor*, minus the
    model's uncompressed loss. A grey dashed line marks, per model, the mean
    untrained loss minus that model's uncompressed loss ("As good as untrained").

    Args:
        data: maps each model name (``"untrained"``, ``"baseline"``,
            ``"topo_1"``, ...) to a list of evaluation dicts. Each dict has the
            keys ``"compression_type"`` (None for the uncompressed run),
            ``"factor"`` and ``"result"``.
        factor: compression factor whose results are averaged for the bars.
        filename: path the figure is saved to (at 300 dpi).
        fontsize: font size of tick labels and bar value labels.

    Raises:
        ValueError: if a model has no uncompressed (``compression_type is None``) entry.
    """
    import warnings

    model_types = ["untrained", "baseline", "topo_1", "topo_5", "topo_10", "topo_50"]
    # "untrained" is only drawn as the dashed reference line, never as a bar.
    plotted_models = model_types[1:]
    x = np.arange(len(plotted_models))

    # Uncompressed loss per model: the zero point for every delta below. Using the
    # same lookup for bars and the dashed line keeps them consistent (the old code
    # assumed data[model][0] was the uncompressed entry for the dashed line).
    base_losses = {model: _uncompressed_loss(data[model], model) for model in model_types}

    # The untrained reference is the same for every model and panel; compute it once.
    # NOTE(review): this averages over *all* untrained entries (every compression
    # type and factor), not only the uncompressed one. Confirm that is intended.
    untrained_losses = [item["result"] for item in data["untrained"]]
    mean_untrained_loss = sum(untrained_losses) / len(untrained_losses)
    untrained_deltas = [mean_untrained_loss - base_losses[model] for model in plotted_models]

    fig, axes = plt.subplots(nrows=2, ncols=1, figsize=(10, 20))
    for ax, compression_type in zip(axes, ["l1", "downsampling"]):
        loss_increases = []
        for model in plotted_models:
            mean_loss = _mean_compressed_loss(data[model], compression_type, factor)
            if mean_loss is None:
                # Still draw a zero-height bar (as before), but say so: a missing
                # factor would otherwise look like "no loss increase".
                warnings.warn(
                    f"no {compression_type!r} results at factor={factor} for {model!r}; plotting 0"
                )
                loss_increases.append(0)
            else:
                loss_increases.append(mean_loss - base_losses[model])

        bars = ax.bar(
            [x_labels[model] for model in plotted_models],
            loss_increases,
            label=compression_type,
            # NOTE(review): both panels (including L1) use the "downsampling"
            # colours from color_map. Confirm this is intentional.
            color=[color_map[model]["downsampling"] for model in plotted_models],
        )

        # Categorical bars sit at positions 0..n-1, so the numeric x lines up with them.
        ax.plot(x, untrained_deltas, "--", label="As good as untrained", color="#808080")
        ax.scatter(x, untrained_deltas, color="#808080", zorder=3)

        ax.tick_params(axis="both", labelsize=fontsize)
        ax.bar_label(bars, padding=3, fmt="%.2f", fontsize=fontsize)
        ax.spines["top"].set_visible(False)
        ax.spines["right"].set_visible(False)
        ax.set_yticks(ticks=[0, 1, 2, 3, 4], labels=[0, 1, 2, 3, 4])

    fig.tight_layout()
    # Save before show(): on inline/non-interactive backends show() may close the
    # figure, which would leave an empty file on disk.
    fig.savefig(filename, dpi=300)
    plt.show()
    # Release the figure; this function is called once per factor in a loop.
    plt.close(fig)

# Render one figure per compression factor. savefig does not create missing
# directories, so make sure the output folder exists first.
os.makedirs("assets", exist_ok=True)
for factor in [3, 6, 9, 18, 36, 45, 81]:
    plot_loss_increase(
        data=results,
        factor=factor,
        filename=f"assets/delta_loss_factor_{factor}.pdf",
        fontsize=29,
    )