#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import json
import os
import sys

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib.gridspec import GridSpec


# -----------------------------
# A. Style utilities (Nature-like)
# -----------------------------
def apply_nature_style():
    """
    Install a compact, Nature-journal-like look in matplotlib's global rcParams.

    Effects (process-wide, applied to ``mpl.rcParams``):
      * 150 dpi on screen, 300 dpi when saving; font type 42 for PDF/PS
        output so fonts are embeddable.
      * Sans-serif font stack (Arial, Helvetica, DejaVu Sans) with STIX sans
        mathtext.
      * Top/right spines hidden, 0.8 pt axes lines, short outward ticks.
      * 16 pt axis labels, titles and tick labels; 11 pt frameless legend.
      * Light grid styling is defined, but grids are off by default and the
        global constrained-layout switch is disabled.
    """
    output = {
        "figure.dpi": 150,
        "savefig.dpi": 300,
        "pdf.fonttype": 42,
        "ps.fonttype": 42,
    }
    typography = {
        "font.family": "sans-serif",
        "font.sans-serif": ["Arial", "Helvetica", "DejaVu Sans"],
        "mathtext.fontset": "stixsans",
        "legend.fontsize": 11,
    }
    # Enlarged text for readability: every label/title/tick size is 16 pt.
    text_sizes = {
        key: 16
        for key in (
            "axes.labelsize",
            "axes.titlesize",
            "xtick.labelsize",
            "ytick.labelsize",
        )
    }
    axes_look = {
        "axes.spines.right": False,
        "axes.spines.top": False,
        "axes.linewidth": 0.8,
        "xtick.direction": "out",
        "ytick.direction": "out",
        "xtick.major.size": 3,
        "ytick.major.size": 3,
        "legend.frameon": False,
    }
    grid_and_layout = {
        "axes.grid": False,
        "grid.color": "#e6e6e6",
        "grid.linestyle": "-",
        "grid.linewidth": 0.4,
        "figure.constrained_layout.use": False,
    }
    mpl.rcParams.update(
        {**output, **typography, **text_sizes, **axes_look, **grid_and_layout}
    )


def add_panel_label(ax, label, x=0, y=1.02):
    """
    Write a bold panel tag such as 'a' or 'b' at the top-left of *ax*.

    *x* and *y* are in axes-fraction coordinates. The text is anchored by its
    bottom-right corner. With the defaults (0, 1.02) it therefore sits just
    left of the y-axis line and slightly above the top of the axes, outside
    the plotting area. This keeps it from shrinking the plot under a
    constrained layout.
    """
    style = {
        "fontsize": 16,
        "fontweight": "bold",
        "va": "bottom",
        "ha": "right",  # the text's right edge lands on x (the y-axis when x=0)
    }
    ax.text(x, y, label, transform=ax.transAxes, **style)


# -----------------------------
# B. IO helpers
# -----------------------------
def read_csv_or_die(path):
    """
    Load the CSV at *path* into a pandas DataFrame.

    Raises:
        FileNotFoundError: with message "Missing CSV: <path>" when *path*
            does not exist.
    """
    if os.path.exists(path):
        return pd.read_csv(path)
    raise FileNotFoundError(f"Missing CSV: {path}")


def read_json_or_die(path):
    """
    Parse the JSON document stored at *path*.

    The file is decoded explicitly as UTF-8, the encoding RFC 8259 requires
    for JSON exchanged between systems, rather than the platform's locale
    default. Non-ASCII content (e.g. "Δ") therefore loads the same on every OS.

    Raises:
        FileNotFoundError: with message "Missing JSON: <path>" when *path*
            does not exist.
        json.JSONDecodeError: if the file content is not valid JSON.
    """
    if not os.path.exists(path):
        raise FileNotFoundError(f"Missing JSON: {path}")
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


# -----------------------------
# C. Combined Figure Function
# -----------------------------


def _get_array(d, keys):
    for k in keys:
        if k in d and isinstance(d[k], (list, tuple)):
            return np.array(d[k], dtype=float)
    return np.array([])


def fig_combined_circuits(
    summary_csv,
    robust_json,
    out_pdf="fig_circuits_combined.pdf",
    *,
    pm_ylim=(0, 1),
    ring_dm_ylim=(0, 2.5),
    paren_dm_ylim=(1, 2),
):
    """
    Build the combined circuits figure (panels a-f) and save it to *out_pdf*.

    Left side: four head x layer heatmaps: pointer mass and ablation
    delta-margin, for rings and for branches. Each heatmap has the top-2
    pointer-mass cells outlined in white.
    Right side: robustness curves with pointer mass on the left y-axis and
    delta-margin on a twin right y-axis, one point per span/depth bin.

    Args:
        summary_csv: CSV with columns layer, head, pointer_ring,
            pointer_paren, delta_margin_ring, delta_margin_paren. Each row is
            placed at [head, layer]; later duplicate rows overwrite earlier ones.
        robust_json: JSON holding the robustness curves. Several alternative
            key names are accepted for each series (see the _get_array calls).
        out_pdf: output file path.
        pm_ylim: y-limits of the pointer-mass axis in both curve panels.
        ring_dm_ylim: y-limits of the delta-margin axis in the ring panel.
        paren_dm_ylim: y-limits of the delta-margin axis in the branch panel.
            Pass None for any *_ylim to let matplotlib autoscale. The defaults
            reproduce the previous fixed limits; data outside them is clipped
            from view.

    Raises:
        FileNotFoundError: if either input file is missing.
        KeyError: if the CSV lacks one of the required columns.
        ValueError: if no CSV row has a numeric (layer, head) pair.
    """
    # --- Part 1: Data loading ---
    df = read_csv_or_die(summary_csv).copy()
    needed = [
        "layer",
        "head",
        "pointer_ring",
        "pointer_paren",
        "delta_margin_ring",
        "delta_margin_paren",
    ]
    missing = [c for c in needed if c not in df.columns]
    if missing:
        raise KeyError(f"{summary_csv} is missing required column(s): {missing}")
    for c in needed:
        df[c] = pd.to_numeric(df[c], errors="coerce")
    # Rows whose layer/head failed to parse (coerced to NaN) cannot be placed
    # on the grid; left in, int(NaN) would raise an opaque ValueError below.
    df = df.dropna(subset=["layer", "head"])
    if df.empty:
        raise ValueError(f"No rows with numeric 'layer' and 'head' in {summary_csv}")
    L, H = int(df["layer"].max()) + 1, int(df["head"].max()) + 1

    data = read_json_or_die(robust_json)
    ring_pm_curve = _get_array(
        data, ["ring_pointer_mass", "ring_pointer", "pointer_mass_ring"]
    )
    ring_dM_curve = _get_array(
        data, ["ring_delta_margin", "ring_dM", "delta_margin_ring"]
    )
    par_pm_curve = _get_array(
        data, ["paren_pointer_mass", "paren_pointer", "pointer_mass_paren"]
    )
    par_dM_curve = _get_array(
        data, ["paren_delta_margin", "paren_dM", "delta_margin_paren"]
    )
    ring_name, par_name = data.get("ring_head", "ring"), data.get("paren_head", "paren")
    ring_bins = _get_array(data, ["ring_bins", "span_bins", "distance"])
    par_bins = _get_array(data, ["paren_bins", "depth_bins", "depth"])

    # --- Part 2: Data processing ---
    heads = df["head"].astype(int).to_numpy()
    layers = df["layer"].astype(int).to_numpy()

    def to_matrix(value_col):
        """Scatter *value_col* into an (H, L) matrix; cells with no row stay NaN."""
        mat = np.full((H, L), np.nan, dtype=float)
        values = df[value_col].to_numpy(dtype=float, na_value=np.nan)
        for h, layer, v in zip(heads, layers, values):
            mat[h, layer] = v
        return mat

    ring_pm, parn_pm = to_matrix("pointer_ring"), to_matrix("pointer_paren")
    ring_dm, parn_dm = to_matrix("delta_margin_ring"), to_matrix("delta_margin_paren")

    def fmt_bin(x):
        """Integers without a decimal point; other values via :g (no truncation)."""
        x = float(x)
        return str(int(x)) if x.is_integer() else f"{x:g}"

    def bin_labels(arr, n):
        """
        Return exactly n tick labels for a curve with n points.

        n+1 strictly increasing values are treated as bin edges ("a–b").
        n values are used as-is. Anything else falls back to "1".."n".
        """
        if n >= 1 and arr.size == n + 1 and np.all(np.diff(arr) > 0):
            return [f"{fmt_bin(a)}–{fmt_bin(b)}" for a, b in zip(arr[:-1], arr[1:])]
        if arr.size == n:
            return [fmt_bin(x) for x in arr]
        return [str(i + 1) for i in range(n)]

    def align(x, y):
        """Truncate both series to their common length."""
        n = min(len(x), len(y))
        return x[:n], y[:n]

    ring_pm_curve, ring_dM_curve = align(ring_pm_curve, ring_dM_curve)
    par_pm_curve, par_dM_curve = align(par_pm_curve, par_dM_curve)
    # Labels are chosen only after the curve lengths are known.
    ring_labels = bin_labels(ring_bins, len(ring_pm_curve))
    par_labels = bin_labels(par_bins, len(par_pm_curve))

    def finite_range(mats, lo_default, hi_default):
        """(min, max) over finite entries of *mats*, or the defaults if none."""
        vals = np.stack(mats)
        vals = vals[np.isfinite(vals)]
        if vals.size == 0:
            return lo_default, hi_default
        return float(vals.min()), float(vals.max())

    _, pm_vmax = finite_range([ring_pm, parn_pm], 0.0, 1.0)
    dm_vmin, dm_vmax = finite_range([ring_dm, parn_dm], 0.0, 1.0)
    dm_vmin = min(dm_vmin, 0.0)  # the delta-margin colour scale always includes 0
    extent = [-0.5, L - 0.5, H - 0.5, -0.5]

    def top_positions(mat, k=2):
        """(layer, head) cells holding the k largest finite values, largest first."""
        flat = np.where(np.isfinite(mat), mat, -np.inf).ravel()
        k = int(min(k, np.count_nonzero(np.isfinite(flat))))
        if k <= 0:
            return []
        idxs = np.argpartition(flat, -k)[-k:]
        idxs = idxs[np.argsort(flat[idxs])][::-1]
        return [np.unravel_index(idx, mat.shape)[::-1] for idx in idxs]

    def draw_white_rect(ax, lh):
        """Outline the heatmap cell at (layer, head) with a white box."""
        layer_idx, head_idx = lh
        rect = mpl.patches.Rectangle(
            (layer_idx - 0.5, head_idx - 0.5),
            1,
            1,
            fill=False,
            edgecolor="#FFFFFF",
            linewidth=1.8,
        )
        ax.add_patch(rect)

    def plot_curve_panel(ax, labels, y1, y2, xlabel, title, panel_char):
        """Plot y1 (pointer mass, left axis) and y2 (delta-margin, twin right
        axis) against categorical bins; return (ax, twin_ax)."""
        x = np.arange(len(labels))
        ax.plot(
            x,
            y1,
            marker="o",
            markersize=4,
            linewidth=1.6,
            color="#1f77b4",
            label="Pointer mass",
        )
        ax.set_ylabel("Pointer mass", color="#1f77b4")
        ax.tick_params(axis="y", colors="#1f77b4")
        ax2 = ax.twinx()
        ax2.plot(
            x,
            y2,
            marker="s",
            markersize=4,
            linewidth=1.6,
            linestyle="--",
            color="#d62728",
            label="Δ-margin",
        )
        ax2.set_ylabel("Δ-margin", color="#d62728")
        ax2.tick_params(axis="y", colors="#d62728")
        ax.set_xlabel(xlabel)
        ax.set_xticks(x)
        ax.set_xticklabels(labels, rotation=0)
        ax.grid(True, axis="y", linestyle="-", linewidth=0.4, color="#e6e6e6")
        ax.spines["left"].set_color("#1f77b4")
        # The twin axes is drawn after *ax* at the same position, so its own
        # (uncoloured) left spine would cover the blue one. Its right spine may
        # also be hidden by an "axes.spines.right: False" style. Fix both so
        # the coloured spines actually show.
        ax2.spines["left"].set_visible(False)
        ax2.spines["right"].set_visible(True)
        ax2.spines["right"].set_color("#d62728")
        ax.set_title(title)
        add_panel_label(ax, panel_char)
        return ax, ax2

    def apply_ylim(ax, lim):
        """Set y-limits unless *lim* is None (autoscale)."""
        if lim is not None:
            ax.set_ylim(*lim)

    # --- Part 3: Plotting setup using a single, unified GridSpec ---
    fig = plt.figure(figsize=(12.5, 6.5), constrained_layout=True)
    try:
        # Columns: two heatmaps, a thin colorbar column, then the curve panel.
        gs = GridSpec(
            2,
            4,
            figure=fig,
            width_ratios=[1.8, 1.8, 0.1, 1.8],
            hspace=0.05,
            wspace=0.07,
        )

        ax_r_pm = fig.add_subplot(gs[0, 0])
        ax_p_pm = fig.add_subplot(gs[0, 1])
        cax_pm = fig.add_subplot(gs[0, 2])
        ax_curve_r = fig.add_subplot(gs[0, 3])

        ax_r_dm = fig.add_subplot(gs[1, 0])
        ax_p_dm = fig.add_subplot(gs[1, 1])
        cax_dm = fig.add_subplot(gs[1, 2])
        ax_curve_p = fig.add_subplot(gs[1, 3])

        # --- Part 4: Heatmaps (left side) ---
        # Missing cells are rendered as 0 (nan_to_num), not left blank.
        im_r_pm = ax_r_pm.imshow(
            np.nan_to_num(ring_pm, nan=0.0),
            cmap="viridis",
            vmin=0.0,
            vmax=pm_vmax,
            aspect="auto",
            extent=extent,
        )
        ax_p_pm.imshow(
            np.nan_to_num(parn_pm, nan=0.0),
            cmap="viridis",
            vmin=0.0,
            vmax=pm_vmax,
            aspect="auto",
            extent=extent,
        )
        im_r_dm = ax_r_dm.imshow(
            np.nan_to_num(ring_dm, nan=0.0),
            cmap="magma",
            vmin=dm_vmin,
            vmax=dm_vmax,
            aspect="auto",
            extent=extent,
        )
        ax_p_dm.imshow(
            np.nan_to_num(parn_dm, nan=0.0),
            cmap="magma",
            vmin=dm_vmin,
            vmax=dm_vmax,
            aspect="auto",
            extent=extent,
        )

        for ax in [ax_r_pm, ax_p_pm, ax_r_dm, ax_p_dm]:
            ax.set_xlabel("Layer")
            ax.set_xticks(np.arange(0, L, max(1, L // 8)))
            ax.set_yticks(np.arange(0, H, max(1, H // 8)))

        # Y-axis label only on the first column; the second column shares the
        # same head axis, so its tick labels are dropped to save space.
        ax_r_pm.set_ylabel("Head")
        ax_r_dm.set_ylabel("Head")
        ax_p_pm.set_yticklabels([])
        ax_p_dm.set_yticklabels([])

        ax_r_pm.set_title("Rings: pointer mass")
        ax_p_pm.set_title("Branches: pointer mass")
        ax_r_dm.set_title("Rings: ablation")
        ax_p_dm.set_title("Branches: ablation")

        add_panel_label(ax_r_pm, "a")
        add_panel_label(ax_p_pm, "b")
        add_panel_label(ax_r_dm, "c")
        add_panel_label(ax_p_dm, "d")

        fig.colorbar(im_r_pm, cax=cax_pm).set_label("Pointer mass")
        fig.colorbar(im_r_dm, cax=cax_dm).set_label("Δ-margin")

        # Highlight the top pointer-mass heads on both the pointer-mass and the
        # ablation map of the same task.
        for lh in top_positions(ring_pm, k=2):
            draw_white_rect(ax_r_pm, lh)
            draw_white_rect(ax_r_dm, lh)
        for lh in top_positions(parn_pm, k=2):
            draw_white_rect(ax_p_pm, lh)
            draw_white_rect(ax_p_dm, lh)

        # --- Part 5: Robustness curves (right side) ---
        ax_r_pm_curve, ax_r_dm_curve = plot_curve_panel(
            ax_curve_r,
            ring_labels,
            ring_pm_curve,
            ring_dM_curve,
            "Opener–closer span",
            f"Rings: {ring_name}",
            "e",
        )
        apply_ylim(ax_r_pm_curve, pm_ylim)
        apply_ylim(ax_r_dm_curve, ring_dm_ylim)
        ax_p_pm_curve, ax_p_dm_curve = plot_curve_panel(
            ax_curve_p,
            par_labels,
            par_pm_curve,
            par_dM_curve,
            "Branch depth",
            f"Branches: {par_name}",
            "f",
        )
        apply_ylim(ax_p_pm_curve, pm_ylim)
        apply_ylim(ax_p_dm_curve, paren_dm_ylim)

        # --- Part 6: Save ---
        fig.savefig(out_pdf, bbox_inches="tight")
    finally:
        # Release the figure even when plotting/saving fails.
        plt.close(fig)
    print(f"[ok] wrote {out_pdf}")


# -----------------------------
# D. Main
# -----------------------------
def main(base_dir="./"):
    """
    Apply the global plot style and render the combined circuits figure.

    Inputs are read from <base_dir>/suite/pointer_suite_summary.csv and
    <base_dir>/suite/robustness_curves.json. The figure is written to
    <base_dir>/figures/fig_circuits_combined.pdf, and the figures/ directory
    is created if needed.
    """
    apply_nature_style()
    suite = os.path.join(base_dir, "suite")
    figures = os.path.join(base_dir, "figures")
    os.makedirs(figures, exist_ok=True)

    fig_combined_circuits(
        os.path.join(suite, "pointer_suite_summary.csv"),
        os.path.join(suite, "robustness_curves.json"),
        out_pdf=os.path.join(figures, "fig_circuits_combined.pdf"),
    )


if __name__ == "__main__":
    try:
        main("./")
    except Exception as e:
        # Top-level boundary: report on stderr and exit non-zero so shells,
        # Make and CI see the failure instead of a successful (0) exit.
        print(f"[error] {e}", file=sys.stderr)
        sys.exit(1)
    else:
        print("[done] All figures generated.")
