import glob
import json
import os

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.lines import Line2D
from matplotlib.patches import FancyArrowPatch

# Written with significant assistance from ChatGPT
# Generated plots were manually checked for consistency with the raw data


# Define the parameter space
datasets = ["solar", "metr-la", "electricity", "pems-bay"]
horizons = [6, 24, 48, 96]
kernel = "hat"


def _load_averages(path):
    """Return the ``"averages"`` mapping stored in the JSON results file at *path*.

    Raises:
        FileNotFoundError: if *path* does not exist.
        KeyError: if the file has no top-level ``"averages"`` entry.
    """
    with open(path, 'r', encoding="utf-8") as f:
        return json.load(f)["averages"]


# Create a data structure to store results:
#   results[dataset][horizon] = {"sswim": ..., "sgd_single": ..., "sgd_total": ...}
# where sgd_total = (time of one SGD epoch) * (best epoch of the full SGD run).
results = {}
for dataset in datasets:
    results[dataset] = {}
    for horizon in horizons:
        entry = {}
        results[dataset][horizon] = entry
        stem = f"{dataset}_{kernel}_{horizon}.json"

        # S-SWIM timings are required: a missing file aborts the script.
        entry["sswim"] = _load_averages(f"sswim/results_timing/{stem}")["time"]

        # SGD results are optional. The timing file gives the per-epoch time and
        # the full-results file gives best_epoch; if either file is missing, both
        # SGD entries fall back to the placeholder 0.
        try:
            epoch_time = _load_averages(f"sgd/results_timing/{stem}")["time"]
            best_epoch = _load_averages(f"sgd/results/{stem}")["best_epoch"]
        except FileNotFoundError as e:
            print(f"Warning: {e}")
            entry["sgd_single"] = 0
            entry["sgd_total"] = 0
        else:
            entry["sgd_single"] = epoch_time
            entry["sgd_total"] = epoch_time * best_epoch

width = 0.15             # bar width
inner_gap = 2.5 * width  # gap between bar triples inside the same dataset group
group_gap = 4 * width    # extra gap between dataset groups (in data units)

positions = []          # x of each (dataset, horizon) slot; the SGD 1-epoch bar is centred here
horizon_labels = []     # tick label (the horizon value) for each slot
series_sswim = []
series_sgd_single = []
series_sgd_total = []
dataset_label_positions = {}  # dataset name -> x of the slot whose tick also shows the name


datasets = list(results.keys())  # preserve insertion order
# For each dataset keep its own sorted horizon list (not necessarily the same across datasets)
dataset_horizons = {d: sorted(results[d].keys()) for d in datasets}

# Distance between consecutive slots of one group. Slot positions and label
# centres are both computed as ``start_x + k * step`` so that the later
# tick-label matching sees identical floats.
step = width + inner_gap

x = 0.0
for d in datasets:
    hs = dataset_horizons[d]
    start_x = x
    for i, h in enumerate(hs):
        positions.append(start_x + i * step)
        horizon_labels.append(f"{h}")
        series_sswim.append(results[d][h]['sswim'])
        series_sgd_single.append(results[d][h]['sgd_single'])
        series_sgd_total.append(results[d][h]['sgd_total'])
    n = len(hs)
    # The dataset name goes under the middle slot (the left one of the two middle
    # slots for an even count); an empty group falls back to its start.
    dataset_label_positions[d] = start_x + ((n - 1) // 2) * step if n > 0 else start_x
    # Next group starts one step after this group's last slot, plus group_gap.
    x = start_x + n * step + group_gap

positions = np.array(positions)

# Create plot
fig, ax = plt.subplots(figsize=(16, 6.3))

# Plot three series as grouped bars around each slot: S-SWIM on the left,
# SGD single epoch on the slot itself, SGD total on the right.
bars_sswim = ax.bar(positions - width, series_sswim, width, label=r'\textit{S-SWIM}')
bars_sgd_single = ax.bar(positions, series_sgd_single, width, label=r'\textit{SGD} $1$ Epoch')
bars_sgd_total = ax.bar(positions + width, series_sgd_total, width, label=r'\textit{SGD} Total')

# Log scale on y-axis
ax.set_yscale('log')

# X-axis: one tick per slot showing the horizon; the slot chosen as a dataset's
# centre additionally gets the dataset name on a second line.
ax.set_xticks(positions)
# Slots and centres come from the same arithmetic, so matches are normally
# exact; the tolerance only guards against floating-point noise.
tol = 1e-6
composite_labels = []
for pos, h_label in zip(positions, horizon_labels):
    matched_name = next(
        (name for name, cpos in dataset_label_positions.items() if abs(pos - cpos) < tol),
        None,
    )
    if matched_name is None:
        composite_labels.append(h_label)
    else:
        composite_labels.append(f"{h_label}\n           {matched_name.capitalize()} (H)")

ax.set_xticklabels(composite_labels, rotation=0, ha='center', fontsize=14)

# Labels, title, legend
ax.set_ylabel(r"Time (seconds) — $\log_{10}$", fontsize=14)
# The legend labels use LaTeX markup (\textit), so render them with usetex.
with plt.rc_context({"text.usetex": True, "font.family": "serif"}):
    ax.legend(fontsize="16")
ax.grid(axis='y', linestyle='--', alpha=0.7, which='both')

# Annotate bars with their numeric values (format controlled by ``fmt``).
def annotate(bars, fmt="{:.2f}", fontsize=11, sswim_hs=None, *, ax=None,
             near_threshold=75, raise_frac=0.15, lower_frac=0.3):
    """Write each bar's height in a rounded white box just above the bar.

    Args:
        bars: Iterable of bar patches (e.g. the container returned by
            ``Axes.bar``); each must provide ``get_height``, ``get_x`` and
            ``get_width``.
        fmt: ``str.format`` template applied to each height.
        fontsize: Font size of the annotation text.
        sswim_hs: Optional reference heights, index-aligned with ``bars``.
            When a bar's height differs from its reference by less than
            ``near_threshold`` (absolute, in data units), the label anchor is
            nudged to reduce overlap with the reference label: up by
            ``raise_frac * h`` if the bar is taller, otherwise down by
            ``lower_frac * h``.
        ax: Axes to draw on; defaults to the current axes (``plt.gca()``).
        near_threshold, raise_frac, lower_frac: Tuning constants for the nudge.

    Returns:
        The height of every bar, in order, including bars that were not labelled.
    """
    if ax is None:
        ax = plt.gca()
    bbox = dict(boxstyle="round,pad=0.3", facecolor="white", edgecolor="gray", alpha=0.85)
    # A non-positive height (e.g. the 0 placeholder for missing results) has no
    # position on a log axis, so there is nothing meaningful to label there.
    skip_nonpositive = ax.get_yscale() == "log"
    heights = []
    for idx, bar in enumerate(bars):
        h = bar.get_height()
        heights.append(h)
        if skip_nonpositive and h <= 0:
            continue
        shift = 0
        if sswim_hs is not None:
            ref = sswim_hs[idx]
            if abs(h - ref) < near_threshold:
                shift = h * raise_frac if h > ref else -h * lower_frac
        ax.annotate(fmt.format(h),
                    xy=(bar.get_x() + bar.get_width() / 2, h + shift),
                    xytext=(0, 6), textcoords="offset points",
                    ha='center', va='bottom', fontsize=fontsize,
                    bbox=bbox)
    return heights


# Label every bar. The SGD 1-epoch labels are nudged relative to the S-SWIM
# heights (see ``annotate``); the SGD total labels are not nudged.
sswim_h = annotate(bars_sswim, "{:.1f}")
annotate(bars_sgd_single, "{:.1f}", sswim_hs=sswim_h)
annotate(bars_sgd_total, "{:.1f}", sswim_hs=None)

plt.tight_layout()
# savefig does not create missing directories, so ensure the output folder exists.
os.makedirs("artefacts", exist_ok=True)
plt.savefig("artefacts/benchmark_time.svg", format="svg", bbox_inches="tight", pad_inches=0)
plt.show()