

# %%
import math
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib.ticker import LogLocator, NullFormatter

from plot_results import load_records

# %%
# ---- Input data and column names ------------------------------------------
RESULTS_PATH = Path("../../results") / "perturbation" / "perturbations.json"
X_COLUMN = "perturbation_norm"   # x-axis quantity (perturbation norm)
Y_COLUMN = "perturbed_loss"      # y-axis quantity (perturbed loss)
RANK_COLUMN = "rank"             # one curve / colour per rank value

# ---- Binning and figure appearance ----------------------------------------
NUM_BINS = 30                    # number of log-spaced norm bins
MIN_SAMPLES_PER_BIN = 3          # sparser (rank, bin) cells are not plotted
COLORMAP = "plasma"
FIGSIZE = (11, 6)

# ---- Theoretical transition bounds (Theorem 5.1, Eq. 24 & 25) -------------
# Leave as None to omit them, or set positive numbers (e.g. 120.0) to draw
# dashed vertical threshold lines and anchor the regime labels around them.
THEORETICAL_LOWER_BOUND = None
THEORETICAL_UPPER_BOUND = None
LOWER_BOUND_LABEL = "Lower bound (Eq. 24)"
UPPER_BOUND_LABEL = "Upper bound (Eq. 25)"

# %%
# Load the raw perturbation records through the project helper `load_records`
# (imported from plot_results). The next cell passes the result straight to
# pd.DataFrame, so presumably it is a list of per-trial dicts — TODO confirm
# against plot_results.
records = load_records(RESULTS_PATH)

# %%
# Build the analysis frame and keep only rows that can be drawn on log axes.
df = pd.DataFrame(records)
if RANK_COLUMN not in df.columns:
    raise ValueError(
        f"Expected '{RANK_COLUMN}' in records; found {sorted(df.columns)}")
if X_COLUMN not in df.columns or Y_COLUMN not in df.columns:
    raise ValueError(
        f"Expected '{X_COLUMN}' and '{Y_COLUMN}' in records; found {sorted(df.columns)}")

# Coerce to numeric first and drop missing/unparseable values, so a single
# bad rank entry drops its row instead of crashing the int cast below.
for column in (RANK_COLUMN, X_COLUMN, Y_COLUMN):
    df[column] = pd.to_numeric(df[column], errors="coerce")
df = df.dropna(subset=[RANK_COLUMN, X_COLUMN, Y_COLUMN])
df[RANK_COLUMN] = df[RANK_COLUMN].astype(int)

# Log-scale axes and log-spaced bins need strictly positive values.
df = df[(df[X_COLUMN] > 0) & (df[Y_COLUMN] > 0)].copy()
if df.empty:
    raise ValueError(
        f"No records with positive '{X_COLUMN}' and '{Y_COLUMN}' to plot.")

# %%
sns.set_theme(style="whitegrid")

if df.empty:
    raise ValueError(
        "No rows with positive perturbation norm and loss to bin.")

norm_min = float(df[X_COLUMN].min())
norm_max = float(df[X_COLUMN].max())
if norm_min <= 0:
    raise ValueError(
        "Perturbation norms must be positive for log-scale binning.")

if norm_max > norm_min:
    bin_edges = np.logspace(np.log10(norm_min), np.log10(norm_max), NUM_BINS + 1)
    # 10 ** log10(x) does not always round-trip exactly. Without pinning, the
    # smallest/largest samples can fall just outside the outer edges, be cut
    # to NaN and silently vanish from the summary.
    bin_edges[0] = norm_min
    bin_edges[-1] = norm_max
else:
    # Every sample shares one norm: logspace would return identical edges,
    # which pd.cut rejects, so use one narrow bin around that value instead.
    bin_edges = np.array([norm_min / 1.01, norm_max * 1.01])

# Assign bin indices, then attach the exact log-spaced intervals as the
# categories. pd.cut's own interval labels are rounded for display, and with
# include_lowest=True the first interval's left edge is shifted below the
# data minimum (possibly to <= 0 for small norms), which would break the
# geometric bin centers computed below.
bin_index = pd.cut(df[X_COLUMN], bins=bin_edges, labels=False, include_lowest=True)
df["norm_bin"] = pd.Categorical.from_codes(
    bin_index.fillna(-1).astype(int).to_numpy(),
    categories=pd.IntervalIndex.from_breaks(bin_edges, closed="right"),
    ordered=True,
)

# Per (rank, norm bin): mean/std/count of the loss. observed=True keeps only
# bins that actually contain samples.
summary = (
    df.groupby([RANK_COLUMN, "norm_bin"], observed=True)
    .agg(
        mean_value=(Y_COLUMN, "mean"),
        std_value=(Y_COLUMN, "std"),
        count=(Y_COLUMN, "size"),
    )
    .reset_index()
)
# Single-sample groups have an undefined (NaN) std; draw them without a band.
summary["std_value"] = summary["std_value"].fillna(0.0)


def _bin_center(interval: pd.Interval) -> float:
    """Return the geometric center of a (log-spaced) bin interval.

    The geometric mean ``sqrt(left * right)`` is the bin's midpoint on a log
    axis, which is where the binned statistics are plotted.

    ``pd.cut(..., include_lowest=True)`` can report the first interval with
    its left edge shifted below the data minimum, possibly to zero or below.
    The geometric mean is undefined there (``math.sqrt`` of a non-positive
    product raises ``ValueError``), so the right edge is returned instead.
    """
    left = float(interval.left)
    right = float(interval.right)
    if left <= 0:
        return right
    return math.sqrt(left * right)


# Plot each bin at its geometric (log-axis) center, then discard (rank, bin)
# cells with too few samples for a meaningful mean/std.
summary = summary.assign(norm_bin_center=summary["norm_bin"].apply(_bin_center))
summary = summary.loc[summary["count"].ge(MIN_SAMPLES_PER_BIN)]
if summary.shape[0] == 0:
    no_bins_message = (
        "No bins have enough samples for plotting. "
        "Try lowering MIN_SAMPLES_PER_BIN or NUM_BINS."
    )
    raise ValueError(no_bins_message)

# %%
# Binned mean +/- std of the perturbed loss against perturbation norm: one
# line and shaded band per rank, on log-log axes.
fig, ax = plt.subplots(figsize=FIGSIZE)

ranks = sorted(summary[RANK_COLUMN].unique())
cmap = plt.get_cmap(COLORMAP)
if len(ranks) > 1:
    # Spread the ranks evenly over the whole colormap.
    color_map = {rank: cmap(i / (len(ranks) - 1))
                 for i, rank in enumerate(ranks)}
else:
    color_map = {ranks[0]: cmap(0.6)}

# The y-axis is log-scaled, so a band edge at or below zero cannot be drawn
# and would otherwise be clipped towards the bottom of the axes. When a bin's
# std exceeds its mean, floor the lower edge at the smallest loss in df
# instead (df only holds positive losses after the earlier filtering cell).
band_floor = float(df[Y_COLUMN].min())

for rank in ranks:
    subset = summary[summary[RANK_COLUMN] ==
                     rank].sort_values("norm_bin_center")
    band_lower = subset["mean_value"] - subset["std_value"]
    band_lower = band_lower.where(band_lower > 0, band_floor)
    ax.plot(
        subset["norm_bin_center"],
        subset["mean_value"],
        color=color_map[rank],
        linewidth=2,
        label=str(rank),
    )
    ax.fill_between(
        subset["norm_bin_center"],
        band_lower,
        subset["mean_value"] + subset["std_value"],
        color=color_map[rank],
        alpha=0.22,
        linewidth=0,
    )

if "baseline_loss" in df.columns:
    # NOTE(review): assumes every record shares a single baseline loss —
    # confirm. Use the first valid numeric entry, so a missing value in the
    # first row no longer hides the line, and skip values a log axis cannot show.
    baseline_values = pd.to_numeric(
        df["baseline_loss"], errors="coerce").dropna()
    if baseline_values.empty or baseline_values.iloc[0] <= 0:
        print("Skipping baseline line: first valid 'baseline_loss' is missing "
              "or non-positive.")
    else:
        baseline_loss = float(baseline_values.iloc[0])
        ax.axhline(
            baseline_loss,
            color="gray",
            linestyle="--",
            linewidth=1.0,
            label="Baseline loss",
        )

if THEORETICAL_LOWER_BOUND is not None:
    if THEORETICAL_LOWER_BOUND <= 0:
        raise ValueError(
            "THEORETICAL_LOWER_BOUND must be positive for log-scale axes.")
    ax.axvline(
        THEORETICAL_LOWER_BOUND,
        color="black",
        linestyle="--",
        linewidth=1.2,
        label=LOWER_BOUND_LABEL,
    )

if THEORETICAL_UPPER_BOUND is not None:
    if THEORETICAL_UPPER_BOUND <= 0:
        raise ValueError(
            "THEORETICAL_UPPER_BOUND must be positive for log-scale axes.")
    ax.axvline(
        THEORETICAL_UPPER_BOUND,
        color="black",
        linestyle="--",
        linewidth=1.2,
        label=UPPER_BOUND_LABEL,
    )

ax.set_xscale("log")
ax.set_yscale("log")
ax.set_xlabel("Perturbation Frobenius norm (||Delta||)")
ax.set_ylabel("Perturbed loss")
ax.set_title("Perturbation Norm vs. Loss (Binned Mean +/- Std)")

# Unlabelled minor ticks at 2..9 x each decade, with a fainter minor grid.
ax.xaxis.set_minor_locator(LogLocator(base=10.0, subs=np.arange(2, 10) * 0.1))
ax.xaxis.set_minor_formatter(NullFormatter())
ax.yaxis.set_minor_locator(LogLocator(base=10.0, subs=np.arange(2, 10) * 0.1))
ax.yaxis.set_minor_formatter(NullFormatter())
ax.grid(True, which="major", linewidth=0.6, alpha=0.6)
ax.grid(True, which="minor", linewidth=0.3, alpha=0.3)


def _geom_interp(low: float, high: float, frac: float) -> float:
    """Interpolate between ``low`` and ``high`` linearly in log space.

    ``frac=0`` gives ``low``, ``frac=1`` gives ``high``, and ``frac=0.5``
    gives their geometric mean. Both endpoints must be positive.
    """
    log_low = math.log(low)
    log_span = math.log(high) - log_low
    return math.exp(log_low + frac * log_span)


xmin, xmax = ax.get_xlim()
ymin, ymax = ax.get_ylim()

# Vertical label positions: fixed fractions of the log-scaled y-range.
stable_y, phase_y, collapsed_y = (
    _geom_interp(ymin, ymax, frac) for frac in (0.18, 0.55, 0.82)
)

both_bounds_known = (
    THEORETICAL_LOWER_BOUND is not None
    and THEORETICAL_UPPER_BOUND is not None
)
if not both_bounds_known:
    # No theoretical thresholds: spread the labels across the x-range.
    stable_x = _geom_interp(xmin, xmax, 0.12)
    phase_x = _geom_interp(xmin, xmax, 0.45)
    collapsed_x = _geom_interp(xmin, xmax, 0.78)
else:
    # Phase label at the log-midpoint of the bounds; the other two labels sit
    # 60% of the way into the region beyond each bound when it is in view.
    phase_x = math.sqrt(THEORETICAL_LOWER_BOUND * THEORETICAL_UPPER_BOUND)
    stable_x = (
        _geom_interp(xmin, THEORETICAL_LOWER_BOUND, 0.6)
        if THEORETICAL_LOWER_BOUND > xmin
        else _geom_interp(xmin, xmax, 0.12)
    )
    collapsed_x = (
        _geom_interp(THEORETICAL_UPPER_BOUND, xmax, 0.6)
        if THEORETICAL_UPPER_BOUND < xmax
        else _geom_interp(xmin, xmax, 0.78)
    )

label_box = {
    "boxstyle": "round,pad=0.25",
    "facecolor": "white",
    "alpha": 0.7,
    "edgecolor": "none",
}
for text_x, text_y, region_name in (
    (stable_x, stable_y, "Stable Regime"),
    (phase_x, phase_y, "Phase Transition"),
    (collapsed_x, collapsed_y, "Collapsed Regime"),
):
    ax.text(text_x, text_y, region_name,
            fontsize=10, fontweight="bold", bbox=label_box)

ax.legend(title="rank", frameon=False, loc="upper left")
fig.tight_layout()
plt.savefig("perturbation_norm_vs_loss_binned.pdf")
plt.show()

# %%
# Re-bin the norms into (up to) six quantile bins for the per-rank table and
# box plots below; this overwrites the log-spaced "norm_bin" column.
unique_norms = df[X_COLUMN].nunique()

if unique_norms > 1:
    desired_bins = min(6, unique_norms)
    # duplicates="drop" merges repeated quantile edges caused by tied norms.
    df["norm_bin"] = pd.qcut(df[X_COLUMN], q=desired_bins, duplicates="drop")
else:
    # At most one distinct norm: nothing to split, so use one shared label.
    only_bin_label = "all norms"
    df["norm_bin"] = pd.Categorical(
        [only_bin_label] * len(df), categories=[only_bin_label], ordered=True
    )

# %%
# Loss statistics per (quantile norm bin, rank); `group_columns` is also used
# by the box-plot cell below.
group_columns = ["norm_bin", RANK_COLUMN]
loss_stats = {
    "mean_value": (Y_COLUMN, "mean"),
    "std_value": (Y_COLUMN, "std"),
    "count": (Y_COLUMN, "size"),
}
norm_rank_summary = (
    df.groupby(group_columns, observed=True).agg(**loss_stats).reset_index()
)

# Interval bin labels print more readably as plain strings.
printable_summary = norm_rank_summary.copy()
printable_summary["norm_bin"] = printable_summary["norm_bin"].astype(str)
print("Loss grouped by perturbation_norm bins and rank:")
print(printable_summary)

# %%
# Box plots of the loss per rank, one facet per norm bin, restricted to
# (norm bin, rank) groups with enough samples to summarise.
min_samples = 10
has_enough_samples = norm_rank_summary["count"].ge(min_samples)
valid_summary = norm_rank_summary.loc[has_enough_samples]

if not valid_summary.empty:
    kept_groups = valid_summary[group_columns]
    filtered_df = df.merge(kept_groups, on=group_columns, how="inner")
    # Drop bin categories with no remaining rows so no empty facets appear.
    filtered_df["norm_bin"] = filtered_df["norm_bin"].cat.remove_unused_categories()

    norm_rank_plot = sns.catplot(
        data=filtered_df,
        kind="box",
        x=RANK_COLUMN,
        y=Y_COLUMN,
        col="norm_bin",
        col_wrap=3,
        height=4,
        sharey=False,
    )
    norm_rank_plot.set_axis_labels("Rank", "Perturbed loss")
    norm_rank_plot.set(yscale="log")
    grid_figure = norm_rank_plot.fig
    grid_figure.subplots_adjust(top=0.9)  # leave headroom for the suptitle
    grid_figure.suptitle(
        "Loss distribution grouped by perturbation norm bins and rank"
    )
else:
    print(
        f"No perturbation_norm and rank groups have at least {min_samples} samples. "
        "Skipping plot."
    )

# %%
# Optional delta-loss analysis, run only when the records carry both a
# requested target norm and a delta loss.
if {"target_norm", "delta_loss"}.issubset(df.columns):
    # Coerce rather than astype(float): malformed entries become NaN, which
    # groupby and relplot skip, instead of aborting the cell.
    df["target_norm"] = pd.to_numeric(df["target_norm"], errors="coerce")
    df["delta_loss"] = pd.to_numeric(df["delta_loss"], errors="coerce")
    # NOTE(review): not read elsewhere in this file; kept so the column still
    # exists for any downstream cells.
    df["rank_numeric"] = df[RANK_COLUMN].astype(int)

    delta_summary = (
        df.groupby(["target_norm", RANK_COLUMN])
        .agg(
            mean_delta_loss=("delta_loss", "mean"),
            std_delta_loss=("delta_loss", "std"),
            trials=("delta_loss", "size"),
        )
        .reset_index()
    )

    print("Delta loss summary by target_norm and rank:")
    print(
        delta_summary.assign(**{
            RANK_COLUMN: delta_summary[RANK_COLUMN].astype(str),
            "target_norm": delta_summary["target_norm"].map("{:.4f}".format),
        })
    )

    # The x-axis is log-scaled, so only positive target norms can be drawn.
    positive_target = df[df["target_norm"] > 0]
    if positive_target.empty:
        print("No positive target_norm values; skipping delta-loss plot.")
    else:
        delta_plot = sns.relplot(
            data=positive_target,
            x="target_norm",
            y="delta_loss",
            hue=RANK_COLUMN,
            kind="scatter",
            height=5,
            aspect=1.4,
        )
        delta_plot.set(xscale="log")

plt.show()

# %%
