import json
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path
import os

# Written with significant assistance by Claude
# Generated plots were manually checked for consistency with the raw data

# Make sure the folder that receives the saved figures is present
os.makedirs("artefacts", exist_ok=True)

# Datasets and forecast horizons covered by the comparison
datasets = ["Solar", "Metr-la", "Electricity", "Pems-bay"]
horizons = [6, 24, 48]

# One row per metric: (lookup key, legend label, line colour).
# Keeping the three attributes side by side means the derived structures
# below cannot drift apart; the row order is the order of `metrics_list`.
_metric_table = [
    ("best", "Entropy Heuristic", "#1f77b4"),
    (("FourierMag", "FourierMag"), "Fourier Magnitude", "#ff7f0e"),
    (("L2", "L2"), "L2", "#2ca02c"),
    (("BandedFourierHigh", "BandedFourierHigh"), "High-Frequency", "#d62728"),
    (("CosineDistance", "CosineDistance"), "Cosine Distance", "#9467bd"),
    (("FourierAngle", "FourierAngle"), "Fourier Phase", "#8c564b"),
    (("BandedFourierLow", "BandedFourierLow"), "Low-Frequency", "#e377c2"),
]

# Metric keys, in table order
metrics_list = [key for key, _label, _colour in _metric_table]

# Colors for different metrics (consistent across plots)
colors = {key: colour for key, _label, colour in _metric_table}

# Labels for metrics (for legend)
metric_labels = {key: label for key, label, _colour in _metric_table}

# Empty score series for every (dataset, metric) pair; each inner list is a
# distinct object so the series can be appended to independently.
results = {
    name: {key: {"r2": [], "rse": []} for key in metrics_list}
    for name in datasets
}

# Load all data: one JSON file per (dataset, horizon, metric) combination.
base_path = Path("sswim/results_metric")
missing_files = []

for dataset in datasets:
    for H in horizons:
        for metric in metrics_list:
            # Tuple metrics are interpolated with str(), so the file name
            # carries the tuple repr, e.g. "solar_6_('L2', 'L2').json".
            filename = f"{dataset.lower()}_{H}_{metric}.json"
            filepath = base_path / filename

            # Start from NaN so that every failure path below still records
            # exactly one value per horizon and the series stay aligned
            # with `horizons`.
            r2_value = rse_value = np.nan

            try:
                with open(filepath, 'r', encoding='utf-8') as f:
                    data = json.load(f)

                # Extract averages; a file without an "averages" section
                # keeps the NaN placeholders.
                if "averages" in data:
                    averages = data["averages"]
                    # Look up both scores before recording either one: a
                    # missing "rse_test" must not leave a stray r2 entry.
                    r2_value, rse_value = averages["r2_test"], averages["rse_test"]

            except FileNotFoundError:
                missing_files.append(filename)
            except (json.JSONDecodeError, KeyError) as e:
                print(f"Error reading {filename}: {e}")

            results[dataset][metric]["r2"].append(r2_value)
            results[dataset][metric]["rse"].append(rse_value)

if missing_files:
    print(f"Warning: {len(missing_files)} files not found")
    print(f"First few missing: {missing_files[:5]}")

# Create the figure: a single row with one panel per dataset.
# squeeze=False always yields a 2-D axes array, so flattening it gives a
# 1-D sequence of panels regardless of how many datasets are configured.
fig, axes = plt.subplots(1, len(datasets), figsize=(16, 5), squeeze=False)
axes = axes.ravel()

# Plot RSE against horizon, one line per metric in each dataset panel
for col, dataset in enumerate(datasets):
    ax = axes[col]

    for metric in metrics_list:
        rse_values = results[dataset][metric]["rse"]
        if not all(np.isnan(rse_values)):  # Only plot if we have data
            ax.plot(horizons, rse_values,
                    marker='o',
                    color=colors[metric],
                    label=metric_labels[metric],
                    linewidth=2,
                    markersize=6,
                    # The "best" series is solid; all others are dashed
                    linestyle=('dashed' if metric != "best" else 'solid'))

    ax.set_xlabel('Horizon', fontsize=11)
    if col == 0:
        # Only the leftmost panel carries the y-axis label
        ax.set_ylabel('RSE', fontsize=12, fontweight='bold')
    ax.grid(True, alpha=0.3)
    ax.set_xticks(horizons)
    ax.set_xlim(horizons[0] - 2, horizons[-1] + 2)

    # Set y-axis limits with some padding
    all_rse = [v for m in metrics_list for v in results[dataset][m]["rse"] if not np.isnan(v)]
    if all_rse:
        y_min, y_max = min(all_rse), max(all_rse)
        y_range = y_max - y_min
        # With a single distinct value the padded limits would be identical;
        # matplotlib warns about that and expands them itself, so keep the
        # automatic limits in that case.
        if y_range > 0:
            ax.set_ylim(y_min - 0.05 * y_range, y_max + 0.05 * y_range)

# Add a single legend for the entire figure.
# Handles are gathered from every panel (first occurrence per label) so that
# a metric with no data in the first dataset still gets a legend entry.
legend_entries = {}
for legend_ax in axes:
    for handle, label in zip(*legend_ax.get_legend_handles_labels()):
        legend_entries.setdefault(label, handle)

# Present the entries in the order of `metrics_list`
labels = [metric_labels[m] for m in metrics_list if metric_labels[m] in legend_entries]
handles = [legend_entries[label] for label in labels]

if handles:  # Only add legend if we have data
    fig.legend(handles, labels,
               loc='center',
               bbox_to_anchor=(0.5, 0.05),
               ncol=len(metrics_list),
               fontsize=15,
               frameon=True,
               fancybox=True,
               shadow=True)

# Adjust layout to prevent overlap, then reserve room at the bottom for the legend
plt.tight_layout()
plt.subplots_adjust(bottom=0.2, top=0.99)

# Save the figure
output_path = "artefacts/metric_comp.svg"
plt.savefig(output_path, format='svg', bbox_inches='tight', dpi=150,  pad_inches=0)
print(f"Plot saved to {output_path}")

# Also save as PNG for quick viewing
plt.savefig("artefacts/metric_comp.png", format='png', bbox_inches='tight', dpi=150, pad_inches=0)
print("Also saved as PNG for preview")

plt.show()