import glob
import json
import os

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.lines import Line2D
from matplotlib.patches import FancyArrowPatch

# Written with significant assistance by DeepSeek and Claude
# Generated plots were manually checked for consistency with the raw data

# Human-readable names for each training objective, keyed by the identifier
# used in result filenames.
OBJECTIVE_NAMES = dict(
    random="Random Sampling",
    dot="Dot",
    dist="Dist",
)

# Human-readable names for each normaliser configuration. Keys are the same
# tuples used to build result filenames, so they must compare equal to the
# entries of the `normalisers` list. ("F", 1) is displayed as "1.0".
NORMALISER_NAMES = {
    ("MS", 0.5, 0.5): "MS (0.5, 0.5)",
    ("MS", 0, 1): "MS (0, 1)",
    ("F", 0.5): "F (0.5)",
    ("F", 1): "F (1.0)",
}

# Define the parameter space
datasets = ["electricity", "pems-bay"]
n_hidden_values = [50, 100, 150, 250, 350, 450, 550, 650, 750]
objectives = ["random", "dot", "dist"]
# str() of each tuple is embedded verbatim in the result filenames
# (e.g. "('MS', 0, 1)"), so integer entries must stay integers.
normalisers = [("MS", 0.5, 0.5), ("MS", 0, 1), ("F", 0.5), ("F", 1)]


def load_metrics(filename):
    """Load the averaged test metrics from one ablation result file.

    Parameters
    ----------
    filename : str
        Path to a JSON file expected to hold
        ``{"averages": {"r2_test": ..., "rse_test": ...}, ...}``.

    Returns
    -------
    dict or None
        ``{"r2_test": ..., "rse_test": ...}`` on success. ``None`` if the file
        is missing, unreadable, not valid JSON, or lacks the expected keys;
        a message naming the file is printed in that case.
    """
    try:
        # JSON is UTF-8 by specification; don't depend on the platform default.
        with open(filename, "r", encoding="utf-8") as f:
            data = json.load(f)
        averages = data["averages"]
        return {
            "r2_test": averages["r2_test"],
            "rse_test": averages["rse_test"],
        }
    except FileNotFoundError:
        print(f"File not found: {filename}")
    except (OSError, ValueError, KeyError, TypeError) as e:
        # OSError: permission / other I/O problems.
        # ValueError: json.JSONDecodeError and UnicodeDecodeError.
        # KeyError / TypeError: the JSON does not have the expected structure.
        print(f"Error reading {filename}: {e}")
    return None


# results[dataset][n_hidden][objective][normaliser] -> metrics dict, or None
# when the corresponding result file could not be loaded.
results = {}
for dataset in datasets:
    results[dataset] = {}
    for n_hidden in n_hidden_values:
        results[dataset][n_hidden] = {}
        for objective in objectives:
            results[dataset][n_hidden][objective] = {}
            for normaliser in normalisers:
                filename = (
                    f"sswim/results_ablation/"
                    f"{dataset}_{n_hidden}_{objective}_{str(normaliser)}.json"
                )
                results[dataset][n_hidden][objective][normaliser] = load_metrics(filename)

# Styling shared by both panels of every figure.
colors = {'random': 'blue', 'dot': 'green', 'dist': 'red'}  # one colour per objective
# One marker per normaliser. The n-th MS and n-th F setting share a marker so
# the combined figure legend can describe both with a single entry.
markers = {
    normalisers[0]: 'o',
    normalisers[1]: 's',
    normalisers[2]: 'o',
    normalisers[3]: 's',
}
# RSE axis range. Points outside it are additionally flagged with triangles.
y_limits = (0, 1)


def plot_family(ax, dataset, family, title):
    """Plot test RSE against hidden-layer width for one normaliser family.

    Draws one line per (objective, normaliser) pair whose normaliser tuple
    starts with ``family`` (e.g. "MS" or "F"). Missing results (``None``
    entries in ``results``) are skipped. Each point outside ``y_limits``
    also gets a triangle just inside the axis, pointing towards the
    clipped value.

    Parameters
    ----------
    ax : matplotlib Axes to draw on.
    dataset : key into ``results``.
    family : first element of the normaliser tuples to include.
    title : panel title.
    """
    low, high = y_limits
    margin = 0.05 * (high - low)
    ax.set_ylim(y_limits)
    for objective in objectives:
        color = colors[objective]
        for normaliser in [n for n in normalisers if n[0] == family]:
            if normaliser in NORMALISER_NAMES:
                label = f"{OBJECTIVE_NAMES[objective]} - {NORMALISER_NAMES[normaliser]}"
            else:
                label = f"{objective} - {normaliser}"

            x_vals = []
            y_vals = []
            for n_hidden in n_hidden_values:
                entry = results[dataset][n_hidden][objective][normaliser]
                if entry is None:  # result file missing or unreadable
                    continue
                y_val = entry["rse_test"]
                x_vals.append(n_hidden)
                y_vals.append(y_val)

                # The line itself is clipped by the axis limits, so mark
                # out-of-range points with a triangle pointing down (below
                # range) or up (above range).
                if y_val < low or y_val > high:
                    below = y_val < low
                    ax.scatter(
                        [n_hidden],
                        [low + margin if below else high - margin],
                        marker="v" if below else "^",
                        s=50,
                        facecolors=color,
                        edgecolors='black',
                        linewidths=1.5,
                        zorder=5,  # draw above the lines
                    )

            if x_vals:
                ax.plot(x_vals, y_vals,
                        marker=markers[normaliser],
                        color=color,
                        label=label,
                        linewidth=2,
                        markersize=8,
                        markeredgecolor='black',  # keeps marker shapes legible
                        markeredgewidth=1.2,
                        markerfacecolor=color)

    ax.set_title(title)
    ax.set_xlabel('Number of Neurons')
    ax.set_ylabel('RSE')
    ax.grid(True, linestyle='--', alpha=0.7)


def legend_entries():
    """Return ``(handles, labels)`` proxy artists for the shared figure legend.

    There is one entry per objective and marker shape. Each entry names both
    the MS and the F normaliser drawn with that marker.
    """
    handles = []
    labels = []
    for objective in objectives:
        color = colors[objective]
        for ms_params, f_params, legend_marker in (
            ("(0.5, 0.5)", "(0.5)", 'o'),
            ("(0, 1)", "(1)", 's'),
        ):
            # NOTE(review): "Fl" differs from the "F" used elsewhere (panel
            # title, NORMALISER_NAMES); confirm it is intentional.
            labels.append(f"{OBJECTIVE_NAMES[objective]} - MS{ms_params} / Fl{f_params}")
            handles.append(Line2D([0], [0],
                                  marker=legend_marker,
                                  linestyle='-',
                                  color=color,
                                  markersize=8,
                                  markerfacecolor=color,
                                  markeredgecolor='black',
                                  markeredgewidth=1.2))
    return handles, labels


# savefig does not create missing directories.
os.makedirs("artefacts", exist_ok=True)
saved_paths = []
for dataset in datasets:
    fig, axes = plt.subplots(1, 2, figsize=(16, 7))
    plot_family(axes[0], dataset, "MS", 'MS Normalisers')
    plot_family(axes[1], dataset, "F", 'F Normalisers')

    # The legend fills column-major. With one column per objective, each
    # column holds one objective and each row one marker shape.
    handles, labels = legend_entries()
    fig.legend(handles, labels,
               loc='lower center',
               bbox_to_anchor=(0.5, -0.01),
               ncol=len(objectives),
               fontsize=15,
               frameon=True,
               fancybox=True,
               shadow=True)

    fig.tight_layout()
    fig.subplots_adjust(bottom=0.2)  # make room for the legend
    for ext in ("svg", "png"):
        path = f"artefacts/{dataset}_ablation_study.{ext}"
        fig.savefig(path, format=ext, bbox_inches="tight", pad_inches=0)
        saved_paths.append(path)
    plt.show()
    plt.close(fig)

print("Plots saved as " + ", ".join(saved_paths))