import json
import re
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import matplotlib.pyplot as plt
from matplotlib.ticker import ScalarFormatter, MaxNLocator
import numpy as np


# -------------------------------
# Data utilities
# -------------------------------


TRAIN_DIR_RE = re.compile(r"train(\d+)_")


def parse_train_size_from_path(path: Path) -> int:
    """
    Extract the training-set size encoded as 'train<digits>_' in a path.

    The last path component is inspected first; if it does not match, the
    name of its parent directory is tried.
    Raises ValueError when neither component matches.
    """
    for component in (path.name, path.parent.name):
        found = TRAIN_DIR_RE.search(component)
        if found is not None:
            return int(found.group(1))
    raise ValueError(f"train size not found in path: {path}")


def count_encoder_depth_from_dirname(dirname: str) -> Optional[int]:
    """
    Infer the encoder depth from a directory name.

    The name is searched, in this order, for a prefix followed by an
    underscore-separated list of layer sizes:
    - Pattern 1: 'enc<units>_<units>_...'
    - Pattern 2: 'hidden<units>_<units>_...'
      e.g., hidden512_256_128_64 → 4
    - Pattern 3 (CNN-style): 'ch<channels>_<channels>_...'
      e.g., ch32_64_128_128 → 4

    The depth is the number of numeric entries in the list. Empty tokens are
    ignored, so the underscore that separates the list from the next name
    component (e.g. 'hidden512_256_latent16') is not counted as a layer.
    A pattern that yields no numeric entry at all is skipped.
    Returns None if no pattern is recognized.
    """
    for prefix in ("enc", "hidden", "ch"):
        match = re.search(prefix + r"([0-9_]+)", dirname)
        if not match:
            continue
        # The character class also captures leading/trailing underscores;
        # drop the empty tokens they produce before counting.
        layer_sizes = [tok for tok in match.group(1).split("_") if tok]
        if layer_sizes:
            return len(layer_sizes)
    return None


def read_aggregated_results(path: Path) -> Dict:
    """Read the JSON file at `path` and return its parsed content."""
    return json.loads(path.read_text())


def compute_abs_gap_stats_from_aggregated(agg: Dict, gap_key: str = "gap_recon_loss") -> Tuple[float, float]:
    """
    Return (mean, std) of |gap_key| over the per-split entries of an
    aggregated_results.json payload.

    Per-split entries are taken from 'individual_results' (or 'individual').
    When no entry carries gap_key, the result falls back to
    (abs(avg_<gap_key>), std_<gap_key>) from 'average_metrics'; the std is
    returned as stored and defaults to 0.0. A missing avg_<gap_key> raises
    KeyError.
    gap_key is expected to be one of {"gap_loss", "gap_recon_loss", "gap_kl_loss"}.
    """
    per_split = agg.get("individual_results") or agg.get("individual", [])
    abs_gaps = [abs(float(entry[gap_key])) for entry in per_split if gap_key in entry] if per_split else []
    if abs_gaps:
        return float(np.mean(abs_gaps)), float(np.std(abs_gaps, ddof=0))

    # Fallback: only the pre-aggregated averages are usable
    metrics = agg["average_metrics"]
    mean_gap = float(metrics[f"avg_{gap_key}"])  # KeyError if missing
    std_gap = float(metrics.get(f"std_{gap_key}", 0.0))
    return abs(mean_gap), std_gap


def read_avg_gap_from_aggregated(agg: Dict, gap_key: str = "gap_recon_loss") -> float:
    """
    Return the signed average gap (used for scatter plots).

    Prefers average_metrics['avg_<gap_key>']; when that is absent, the mean
    of gap_key over individual_results is used instead.
    Raises ValueError if neither source provides a value.
    """
    precomputed = agg.get("average_metrics", {}).get(f"avg_{gap_key}")
    if precomputed is not None:
        return float(precomputed)
    per_split = [
        float(entry[gap_key])
        for entry in (agg.get("individual_results") or [])
        if gap_key in entry
    ]
    if per_split:
        return float(np.mean(per_split))
    raise ValueError(f"avg_{gap_key} not found and individual_results empty")


def read_combined_upper_from_mi(mi_path: Path) -> float:
    """
    Read a single combined upper-bound value from an MI JSON file
    (mi_both_encoder.json).

    Sources are tried in this order and the first one that yields a value wins:
      1) data["combined_summary"]["mean_upper_bound"]
      2) mean of data["combined_per_split"][i]["combined_upper"]
      3) data["if_upper_bounds_summary"]["mean_upper_bound"]
      4) mean of data["if_upper_bounds_per_split"][i]["upper_bound"]
    Raises ValueError if the top-level JSON value is not an object or none of
    the sources yields a value.
    """
    with mi_path.open("r") as f:
        payload = json.load(f)
    if not isinstance(payload, dict):
        raise ValueError(f"combined upper not found in {mi_path}")

    # 1) combined summary
    combined_summary = payload.get("combined_summary")
    if isinstance(combined_summary, dict):
        summary_mean = combined_summary.get("mean_upper_bound")
        if summary_mean is not None:
            return float(summary_mean)

    # 2) per-split combined uppers
    combined_splits = payload.get("combined_per_split")
    if isinstance(combined_splits, list) and combined_splits:
        present = [entry.get("combined_upper") for entry in combined_splits if entry.get("combined_upper") is not None]
        if present:
            return float(np.mean([float(v) for v in present]))

    # 3) IF-only summary
    if_summary = payload.get("if_upper_bounds_summary")
    if if_summary:
        summary_mean = if_summary.get("mean_upper_bound")
        if summary_mean is not None:
            return float(summary_mean)

    # 4) IF-only per-split values
    if_splits = payload.get("if_upper_bounds_per_split")
    if if_splits:
        present = [entry.get("upper_bound") for entry in if_splits if entry.get("upper_bound") is not None]
        if present:
            return float(np.mean([float(v) for v in present]))

    raise ValueError(f"combined upper not found in {mi_path}")


# -------------------------------
# Search logic
# -------------------------------


@dataclass
class SplitPoint:
    """One per-split scatter point: a gap value paired with a combined upper bound."""

    # Encoder depth inferred from the model directory name (used to color points).
    enc_depth: int
    # Signed per-split gap value. Despite the field name, collectors store
    # whichever gap metric (gap_key) they were asked to read.
    gap_recon_loss: float
    # Per-split upper bound read from the MI JSON file.
    combined_upper: float


@dataclass
class IfZuGapPoint:
    """One per-split point holding a gap value with its IF and ZU upper-bound components."""

    # Encoder depth inferred from the model directory name.
    enc_depth: int
    # Signed per-split gap value for the selected gap metric.
    gap_value: float
    # Per-split IF component of the upper bound ('if_params_u_upper').
    if_params_u_upper: float
    # Per-split ZU component of the upper bound ('zu_upper').
    zu_upper: float


def read_gap_per_split_from_aggregated(agg: Dict, gap_key: str = "gap_recon_loss") -> List[float]:
    """
    Return the signed per-split values of gap_key from individual_results.

    If no per-split entry carries gap_key, average_metrics['avg_<gap_key>']
    is returned as a single-element list; if that is missing too, the result
    is an empty list.
    """
    per_split = [
        float(entry[gap_key])
        for entry in (agg.get("individual_results") or [])
        if gap_key in entry
    ]
    if per_split:
        return per_split
    # Fallback: if only averages exist, treat as a single point
    mean_gap = agg.get("average_metrics", {}).get(f"avg_{gap_key}")
    if mean_gap is None:
        return []
    return [float(mean_gap)]


def read_combined_upper_per_split(mi_path: Path, expected_len: Optional[int] = None) -> List[float]:
    """
    Return per-split upper-bound values from an MI JSON file.

    Sources, in priority order:
      1) 'combined_upper' of each entry in 'combined_per_split'
      2) 'upper_bound' of each entry in 'if_upper_bounds_per_split'
         (upper bound for a single component only)
      3) 'mean_upper_bound' from 'combined_summary' (else from
         'if_upper_bounds_summary'), repeated expected_len times
    Entries whose value is None are dropped. Returns an empty list when no
    source applies, including when only a summary exists but expected_len
    is None or 0.
    """
    with mi_path.open("r") as f:
        payload = json.load(f)

    # 1) / 2) per-split lists
    per_split_sources = (
        ("combined_per_split", "combined_upper"),
        ("if_upper_bounds_per_split", "upper_bound"),
    )
    for list_key, value_key in per_split_sources:
        entries = payload.get(list_key)
        if not (isinstance(entries, list) and entries):
            continue
        raw = [entry.get(value_key) for entry in entries]
        found = [float(v) for v in raw if v is not None]
        if found:
            return found

    # 3) Use summary, repeated
    summary_mean = None
    for summary_key in ("combined_summary", "if_upper_bounds_summary"):
        summary = payload.get(summary_key)
        if summary and summary.get("mean_upper_bound") is not None:
            summary_mean = float(summary["mean_upper_bound"])
            break
    if summary_mean is None or not expected_len:
        return []
    return [summary_mean] * expected_len


def collect_train30000_split_points(base: Path, algo: str, gap_key: str = "gap_recon_loss") -> List[SplitPoint]:
    """
    Collect per-split (gap, combined_upper) points from all train30000 runs
    found recursively under `base` for the given algorithm.

    x: gap_key value (per split, signed)
    y: combined_upper (per split)
    color: encoder depth (inferred from the model directory name)
    Required files: aggregated_results.json, mi_both_encoder.json

    Runs whose model directory does not start with '<algo>_', whose depth
    cannot be inferred, or whose MI file is missing are skipped. Any error
    while reading a run is swallowed and that run is skipped (best effort).
    """
    assert algo in {"vae", "iwae"}
    algo_prefix = f"{algo}_"
    collected: List[SplitPoint] = []
    for agg_path in base.rglob("aggregated_results.json"):
        train_dir = agg_path.parent
        model_dir = train_dir.parent
        # Only train30000 runs of the requested algorithm
        # (prefix check only; naming variations after the prefix are allowed)
        if "train30000_" not in train_dir.name or not model_dir.name.startswith(algo_prefix):
            continue
        depth = count_encoder_depth_from_dirname(model_dir.name)
        if depth is None:
            continue
        # MI file: prefer the one under the train directory, else the model directory
        mi_candidates = (
            model_dir / train_dir.name / "mi_both_encoder.json",
            model_dir / "mi_both_encoder.json",
        )
        mi_path = next((cand for cand in mi_candidates if cand.exists()), None)
        if mi_path is None:
            continue
        try:
            gaps = read_gap_per_split_from_aggregated(read_aggregated_results(agg_path), gap_key=gap_key)
            uppers = read_combined_upper_per_split(mi_path, expected_len=len(gaps))
            # zip truncates to the shorter of the two lists
            for gap, upper in zip(gaps, uppers):
                collected.append(SplitPoint(enc_depth=depth, gap_recon_loss=float(gap), combined_upper=float(upper)))
        except Exception:
            continue
    return collected


def collect_train30000_if_zu_points(base: Path, algo: str, gap_key: str = "gap_recon_loss") -> List[IfZuGapPoint]:
    """
    Collect per-split (gap, IF, ZU) points from all train30000 runs found
    recursively under `base` for the given algorithm.

    x: gap_key (per split, signed; callers take the absolute value when
       computing correlations)
    y1: if_params_u_upper (per split)
    y2: zu_upper (per split)
    color: encoder depth (inferred from the model directory name)
    Required files: aggregated_results.json, mi_both_encoder.json

    Runs whose model directory does not start with '<algo>_', whose depth
    cannot be inferred, or whose MI file is missing are skipped. Any error
    while reading a run is swallowed and that run is skipped (best effort).
    """
    assert algo in {"vae", "iwae"}
    algo_prefix = f"{algo}_"
    collected: List[IfZuGapPoint] = []
    for agg_path in base.rglob("aggregated_results.json"):
        train_dir = agg_path.parent
        model_dir = train_dir.parent
        # Only train30000 runs of the requested algorithm
        # (prefix check only; naming variations after the prefix are allowed)
        if "train30000_" not in train_dir.name or not model_dir.name.startswith(algo_prefix):
            continue
        depth = count_encoder_depth_from_dirname(model_dir.name)
        if depth is None:
            continue
        # MI file: prefer the one under the train directory, else the model directory
        mi_candidates = (
            model_dir / train_dir.name / "mi_both_encoder.json",
            model_dir / "mi_both_encoder.json",
        )
        mi_path = next((cand for cand in mi_candidates if cand.exists()), None)
        if mi_path is None:
            continue
        try:
            gaps = read_gap_per_split_from_aggregated(read_aggregated_results(agg_path), gap_key=gap_key)
            if_values, zu_values = _read_if_zu_per_split(mi_path)
            # zip truncates to the shortest of the three lists
            for gap, if_upper, zu_upper in zip(gaps, if_values, zu_values):
                collected.append(
                    IfZuGapPoint(
                        enc_depth=depth,
                        gap_value=float(gap),
                        if_params_u_upper=float(if_upper),
                        zu_upper=float(zu_upper),
                    )
                )
        except Exception:
            continue
    return collected


@dataclass
class GapSeries:
    """Mean/std of |gap| per training-set size; the three arrays are index-aligned."""

    # Training-set sizes.
    num_train: np.ndarray
    # Mean of the absolute gap for each training-set size.
    mean_abs_gap: np.ndarray
    # Standard deviation of the gap statistic for each training-set size.
    std_abs_gap: np.ndarray


@dataclass
class UpperSeries:
    """Mean/std of the upper bound per training-set size; the three arrays are index-aligned."""

    # Training-set sizes.
    num_train: np.ndarray
    # Mean upper bound for each training-set size.
    mean_upper: np.ndarray
    # Standard deviation of the upper bound for each training-set size.
    std_upper: np.ndarray


@dataclass
class BoundComponentsSeries:
    """Mean/std of the IF and ZU bound components per training-set size; all arrays are index-aligned."""

    # Training-set sizes.
    num_train: np.ndarray
    # Mean and standard deviation of the IF component for each training-set size.
    mean_if: np.ndarray
    std_if: np.ndarray
    # Mean and standard deviation of the ZU component for each training-set size.
    mean_zu: np.ndarray
    std_zu: np.ndarray


def find_abs_gap_series_for_4layer(base: Path, algo: str, gap_key: str = "gap_recon_loss") -> GapSeries:
    """
    Build the |gap_key| mean/std series over train sizes for a 4-layer
    encoder architecture (e.g. hidden512_256_128_64 or a 4-entry ch* list).

    The first directory under `base` (in iterdir order) that starts with
    '<algo>_' and has inferred depth 4 is used. Within it, every 'train*'
    directory matching the objective ('elbo' for vae, 'iwae' for iwae) and
    containing aggregated_results.json contributes one row; rows that fail
    to load are skipped. The result is sorted by train size.
    Raises RuntimeError if no model directory or no usable row is found.
    """
    assert algo in {"vae", "iwae"}
    # 4-layer candidates: MLP(hidden512_256_128_64) or CNN(ch*_ _ _ _)
    four_layer_dirs = [
        d
        for d in base.iterdir()
        if d.is_dir() and d.name.startswith(f"{algo}_") and count_encoder_depth_from_dirname(d.name) == 4
    ]
    if not four_layer_dirs:
        raise RuntimeError(f"4-layer model dir not found under {base} for {algo}")
    model_dir = four_layer_dirs[0]

    # Substring that marks a train directory as belonging to this algorithm
    objective_tag = {"vae": "elbo", "iwae": "iwae"}.get(algo)

    records: List[Tuple[int, float, float]] = []
    for train_dir in model_dir.iterdir():
        if not (train_dir.is_dir() and train_dir.name.startswith("train")):
            continue
        if objective_tag is None or objective_tag not in train_dir.name:
            continue
        agg_path = train_dir / "aggregated_results.json"
        if not agg_path.exists():
            continue
        try:
            payload = read_aggregated_results(agg_path)
            gap_mean, gap_std = compute_abs_gap_stats_from_aggregated(payload, gap_key=gap_key)
            size = parse_train_size_from_path(train_dir)
            records.append((size, float(gap_mean), float(gap_std)))
        except Exception:
            continue
    if not records:
        raise RuntimeError(f"No aggregated_results.json found for 4-layer series under {model_dir}")

    records.sort(key=lambda rec: rec[0])
    sizes, means, stds = zip(*records)
    return GapSeries(
        num_train=np.array(sizes, dtype=np.int64),
        mean_abs_gap=np.array(means, dtype=np.float64),
        std_abs_gap=np.array(stds, dtype=np.float64),
    )


def find_upper_series_for_4layer(base: Path, algo: str, transform: str = "none") -> UpperSeries:
    """
    Build the upper-bound mean/std series over train sizes for a 4-layer
    encoder (e.g. hidden512_256_128_64).

    The first directory under `base` (in iterdir order) that starts with
    '<algo>_' and has inferred depth 4 is used. For each matching 'train*'
    directory the MI file is taken from the train directory, else from the
    model directory. Per-split values are preferred; otherwise the stored
    summary mean/std is used. With transform == "sqrt" the per-split values
    are clipped at 0 and square-rooted before aggregation, and summary
    values are converted with a delta-method approximation.
    Rows that fail to load are skipped; the result is sorted by train size.
    Raises RuntimeError if no model directory or no usable row is found.
    """
    assert algo in {"vae", "iwae"}
    four_layer_dirs = [
        d
        for d in base.iterdir()
        if d.is_dir() and d.name.startswith(f"{algo}_") and count_encoder_depth_from_dirname(d.name) == 4
    ]
    if not four_layer_dirs:
        raise RuntimeError(f"4-layer model dir not found under {base} for {algo}")
    model_dir = four_layer_dirs[0]

    # Substring that marks a train directory as belonging to this algorithm
    objective_tag = {"vae": "elbo", "iwae": "iwae"}.get(algo)
    use_sqrt = transform == "sqrt"

    records: List[Tuple[int, float, float]] = []
    for train_dir in model_dir.iterdir():
        if not (train_dir.is_dir() and train_dir.name.startswith("train")):
            continue
        if objective_tag is None or objective_tag not in train_dir.name:
            continue
        mi_path = train_dir / "mi_both_encoder.json"
        if not mi_path.exists():
            mi_path = model_dir / "mi_both_encoder.json"
            if not mi_path.exists():
                continue
        try:
            per_split = read_combined_upper_per_split(mi_path)
            if per_split:
                # Per-split uppers exist: aggregate them directly
                values = np.asarray(per_split, dtype=float)
                if use_sqrt:
                    values = np.sqrt(np.clip(values, a_min=0.0, a_max=None))
                row_mean = float(np.mean(values))
                row_std = float(np.std(values, ddof=0))
            else:
                # Fallback to summary (combined first, then IF-only)
                with mi_path.open("r") as f:
                    payload = json.load(f)
                summary = payload.get("combined_summary") or payload.get("if_upper_bounds_summary")
                if not summary:
                    continue
                mu = float(summary.get("mean_upper_bound"))
                su = float(summary.get("std_upper_bound", 0.0))
                if use_sqrt:
                    eps = 1e-12
                    row_mean = float(np.sqrt(max(mu, 0.0)))
                    # delta method: Var(sqrt(X)) ≈ Var(X) / (4 * mu)
                    row_std = float(su / (2.0 * max(np.sqrt(mu + eps), eps))) if mu > 0 else 0.0
                else:
                    row_mean, row_std = mu, su
            size = parse_train_size_from_path(train_dir)
            records.append((size, row_mean, row_std))
        except Exception:
            continue

    if not records:
        raise RuntimeError(f"No MI files found for 4-layer series under {model_dir}")

    records.sort(key=lambda rec: rec[0])
    sizes, means, stds = zip(*records)
    return UpperSeries(
        num_train=np.array(sizes, dtype=np.int64),
        mean_upper=np.array(means, dtype=np.float64),
        std_upper=np.array(stds, dtype=np.float64),
    )


def _read_if_zu_per_split(mi_path: Path) -> Tuple[List[float], List[float]]:
    """
    Read per-split IF and ZU upper bounds from an MI JSON file.

    Both lists are taken from 'combined_per_split' ('if_params_u_upper' and
    'zu_upper'). If a list comes out empty, it falls back to
    'if_upper_bounds_per_split'/'upper_bound' (IF) or
    'mi_zu_upper_per_split'/'mi_upper_bound' (ZU). None values are dropped
    independently per list, so the two lists may differ in length.
    """
    with mi_path.open("r") as f:
        payload = json.load(f)

    def _present_floats(entries, field):
        # Keep only entries that actually carry the field
        return [float(entry.get(field)) for entry in entries if entry.get(field) is not None]

    combined = payload.get("combined_per_split")
    has_combined = isinstance(combined, list) and bool(combined)
    if_values: List[float] = _present_floats(combined, "if_params_u_upper") if has_combined else []
    zu_values: List[float] = _present_floats(combined, "zu_upper") if has_combined else []

    # Fallbacks
    if not if_values:
        if_fallback = payload.get("if_upper_bounds_per_split")
        if if_fallback:
            if_values = _present_floats(if_fallback, "upper_bound")
    if not zu_values:
        zu_fallback = payload.get("mi_zu_upper_per_split")
        if zu_fallback:
            zu_values = _present_floats(zu_fallback, "mi_upper_bound")
    return if_values, zu_values


def find_if_zu_series_for_4layer(base: Path, algo: str, normalize: bool = True) -> BoundComponentsSeries:
    """
    Build the IF/ZU mean/std series over train sizes for a 4-layer encoder.

    The first directory under `base` (in iterdir order) that starts with
    '<algo>_' and has inferred depth 4 is used. For each matching 'train*'
    directory the MI file is taken from the train directory, else from the
    model directory, and its per-split IF/ZU values are read once.

    When `normalize` is True, IF and ZU values are min-max normalized with
    the global min/max over all train sizes before the per-train mean/std
    (ddof=0) are computed; a zero range uses a denominator of 1.0.
    Directories that fail to load are skipped; the result is sorted by
    train size.
    Raises RuntimeError if no model directory is found, if no IF or no ZU
    value was read at all, or if no row could be computed.
    """
    assert algo in {"vae", "iwae"}
    # 4-layer model dir detection (same as others)
    candidates = [
        d
        for d in base.iterdir()
        if d.is_dir() and d.name.startswith(f"{algo}_") and count_encoder_depth_from_dirname(d.name) == 4
    ]
    if not candidates:
        raise RuntimeError(f"4-layer model dir not found under {base} for {algo}")
    model_dir = candidates[0]

    # Single read pass: collect the per-train values and, at the same time,
    # all raw values needed for the global min/max. Keeping the values avoids
    # re-opening every MI file and guarantees that normalization constants and
    # per-train statistics come from the same file content.
    all_ifs: List[float] = []
    all_zus: List[float] = []
    per_train: List[Tuple[int, List[float], List[float]]] = []
    for train_dir in model_dir.iterdir():
        if not train_dir.is_dir() or not train_dir.name.startswith("train"):
            continue
        is_algo = (algo == "vae" and "elbo" in train_dir.name) or (algo == "iwae" and "iwae" in train_dir.name)
        if not is_algo:
            continue
        mi_path = train_dir / "mi_both_encoder.json"
        if not mi_path.exists():
            mi_path = model_dir / "mi_both_encoder.json"
            if not mi_path.exists():
                continue
        try:
            ifs, zus = _read_if_zu_per_split(mi_path)
            if ifs:
                all_ifs.extend(ifs)
            if zus:
                all_zus.extend(zus)
            n = parse_train_size_from_path(train_dir)
            per_train.append((n, ifs, zus))
        except Exception:
            continue
    if not per_train or not all_ifs or not all_zus:
        raise RuntimeError(f"No IF/ZU values found for 4-layer series under {model_dir}")

    if_min, if_max = float(np.min(all_ifs)), float(np.max(all_ifs))
    zu_min, zu_max = float(np.min(all_zus)), float(np.max(all_zus))
    # Guard against a zero range (all values identical)
    if_denom = if_max - if_min if if_max > if_min else 1.0
    zu_denom = zu_max - zu_min if zu_max > zu_min else 1.0

    # Per-train mean/std, after global min-max normalization when requested
    rows: List[Tuple[int, float, float, float, float]] = []
    for n, ifs, zus in per_train:
        try:
            ifs_arr = np.asarray(ifs, dtype=float)
            zus_arr = np.asarray(zus, dtype=float)
            if normalize:
                ifs_arr = (ifs_arr - if_min) / if_denom
                zus_arr = (zus_arr - zu_min) / zu_denom
            rows.append(
                (
                    n,
                    float(np.mean(ifs_arr)),
                    float(np.std(ifs_arr, ddof=0)),
                    float(np.mean(zus_arr)),
                    float(np.std(zus_arr, ddof=0)),
                )
            )
        except Exception:
            continue
    if not rows:
        raise RuntimeError(f"Failed to compute IF/ZU series under {model_dir}")
    rows.sort(key=lambda x: x[0])
    num_train = np.array([r[0] for r in rows], dtype=np.int64)
    mean_if = np.array([r[1] for r in rows], dtype=np.float64)
    std_if = np.array([r[2] for r in rows], dtype=np.float64)
    mean_zu = np.array([r[3] for r in rows], dtype=np.float64)
    std_zu = np.array([r[4] for r in rows], dtype=np.float64)
    return BoundComponentsSeries(num_train=num_train, mean_if=mean_if, std_if=std_if, mean_zu=mean_zu, std_zu=std_zu)


# -------------------------------
# Plotting
# -------------------------------


def _rankdata_average(values: np.ndarray) -> np.ndarray:
    """Return 1-based ranks of `values`; tied values all receive the average of their ranks."""
    n = len(values)
    order = np.argsort(values, kind="mergesort")
    ordered = values[order]
    ranks = np.empty(n, dtype=float)
    # Sorted positions at which a new run of equal values begins
    run_starts = [0] + [k for k in range(1, n) if ordered[k] != ordered[k - 1]]
    run_ends = run_starts[1:] + [n]
    for lo, hi in zip(run_starts, run_ends):
        # The run occupies ranks lo+1 .. hi; assign their average to every member
        ranks[order[lo:hi]] = (lo + 1 + hi) / 2.0
    return ranks


def _safe_pearsonr(x: np.ndarray, y: np.ndarray) -> float:
    """
    Pearson correlation of x and y, or NaN when it is undefined: fewer than
    two samples in either array, or a standard deviation that is
    (numerically) zero.
    """
    undefined = float("nan")
    if min(x.size, y.size) < 2:
        return undefined
    if any(np.allclose(np.std(arr), 0.0) for arr in (x, y)):
        return undefined
    corr_matrix = np.corrcoef(x, y)
    return float(corr_matrix[0, 1])


def _spearmanr(x: np.ndarray, y: np.ndarray) -> float:
    """Spearman rank correlation: Pearson correlation of the average ranks (NaN if fewer than two samples)."""
    if min(x.size, y.size) < 2:
        return float("nan")
    return _safe_pearsonr(_rankdata_average(x), _rankdata_average(y))


def _kendall_tau_a(x: np.ndarray, y: np.ndarray) -> float:
    """
    Kendall tau-a: (concordant - discordant) / (n * (n - 1) / 2).

    Tied pairs count as neither concordant nor discordant but stay in the
    denominator. Returns NaN for fewer than two samples.
    """
    n = x.size
    if n < 2:
        return float("nan")
    net_concordance = 0
    for i in range(n - 1):
        # Sign of the product tells whether pair (i, j>i) moves in the same direction
        products = (x[i + 1:] - x[i]) * (y[i + 1:] - y[i])
        net_concordance += int(np.count_nonzero(products > 0)) - int(np.count_nonzero(products < 0))
    total_pairs = n * (n - 1) / 2.0
    return float(net_concordance / total_pairs)


def _annotate_and_regress(ax, xs: np.ndarray, ys: np.ndarray, line_color: str = "#333333") -> None:
    """
    Draw a dashed least-squares line through (xs, ys) on `ax` (only when at
    least two points exist) and write the Pearson/Spearman/Kendall
    correlations into the upper-left corner of the axes.
    """
    if xs.size >= 2:
        slope, intercept = np.polyfit(xs, ys, 1)
        lo, hi = float(np.min(xs)), float(np.max(xs))
        has_extent = np.isfinite(lo) and np.isfinite(hi) and lo != hi
        # Degenerate x-range: fall back to the two end points
        grid = np.linspace(lo, hi, 100) if has_extent else np.array([lo, hi])
        ax.plot(grid, slope * grid + intercept, linestyle="--", color=line_color, linewidth=4.0, label="linreg")

    pearson = _safe_pearsonr(xs, ys)
    spearman = _spearmanr(xs, ys)
    kendall = _kendall_tau_a(xs, ys)
    annotation = "\n".join(
        (
            f"Pearson={pearson:.3f}",
            f"Spearman={spearman:.3f}",
            f"Kendall={kendall:.3f}",
        )
    )
    ax.text(
        0.02,
        0.98,
        annotation,
        transform=ax.transAxes,
        va="top",
        ha="left",
        fontsize=12,
        bbox=dict(facecolor="white", alpha=0.65, edgecolor="none"),
    )


def plot_1x4(
    elbo_upper: UpperSeries,
    elbo_series: GapSeries,
    iwae_upper: UpperSeries,
    iwae_series: GapSeries,
    out_path: Path,
    gap_key: str = "gap_recon_loss",
):
    """
    Save a 1x4 figure of error-bar plots against the training-set size:
      1) Standard VAE: combined_upper (mean ± std)
      2) Standard VAE: |gap_key| (mean ± std)
      3) IWAE: combined_upper (mean ± std)
      4) IWAE: |gap_key| (mean ± std)

    The legend label of the gap panels is derived from `gap_key` (with a
    trailing '_loss' removed), so it always names the metric actually plotted.
    The parent directory of `out_path` is created if needed.
    """
    # e.g. 'gap_recon_loss' -> 'gap_recon', 'gap_kl_loss' -> 'gap_kl'
    gap_short = gap_key[: -len("_loss")] if gap_key.endswith("_loss") else gap_key
    gap_legend = f"|{gap_short}| (mean ± std)"
    upper_legend = "combined_upper (mean ± std)"

    fig, axes = plt.subplots(1, 4, figsize=(18, 4.2), dpi=140)

    def _errorbar_panel(ax, x, y, yerr, color: str, legend: str, ylabel: str, title: str) -> None:
        # Shared styling of one mean ± std panel
        ax.errorbar(x, y, yerr=yerr, marker="o", color=color, linestyle="-", capsize=3, label=legend)
        ax.set_xlabel("num_train_samples")
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        ax.grid(True, linestyle=":", alpha=0.4)
        ax.legend(loc="best", fontsize=8)

    # 1) ELBO: num_train vs avg_combined_upper ± std
    _errorbar_panel(
        axes[0],
        elbo_upper.num_train,
        elbo_upper.mean_upper,
        elbo_upper.std_upper,
        color="#1f77b4",
        legend=upper_legend,
        ylabel="combined_upper",
        title="Standard VAE: combined_upper vs num_train_samples",
    )
    # 2) ELBO: |gap| series
    _errorbar_panel(
        axes[1],
        elbo_series.num_train,
        elbo_series.mean_abs_gap,
        elbo_series.std_abs_gap,
        color="#d62728",
        legend=gap_legend,
        ylabel=f"|{gap_key}|",
        title=f"Standard VAE: train size vs |{gap_key}|",
    )
    # 3) IWAE: num_train vs avg_combined_upper ± std
    _errorbar_panel(
        axes[2],
        iwae_upper.num_train,
        iwae_upper.mean_upper,
        iwae_upper.std_upper,
        color="#ff7f0e",
        legend=upper_legend,
        ylabel="combined_upper",
        title="VAE (IWAE): combined_upper vs num_train_samples",
    )
    # 4) IWAE: |gap| series
    _errorbar_panel(
        axes[3],
        iwae_series.num_train,
        iwae_series.mean_abs_gap,
        iwae_series.std_abs_gap,
        color="#2ca02c",
        legend=gap_legend,
        ylabel=f"|{gap_key}|",
        title=f"VAE (IWAE): train size vs |{gap_key}|",
    )

    fig.tight_layout()
    out_path.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(out_path)
    print(f"Saved figure to {out_path}")


# -------------------------------
# Extra: train30000 scatter (per split)
# -------------------------------


def plot_scatter_train30000(
    elbo_points: List[SplitPoint],
    iwae_points: List[SplitPoint],
    out_path: Path,
    gap_key: str = "gap_recon_loss",
    y_transform: str = "none",
):
    """
    Save a 1x2 scatter figure (ELBO left, IWAE right) of per-split
    |gap| (x) against combined_upper (y), colored by encoder depth, with a
    regression line and correlation annotations on each panel.

    With y_transform == "sqrt_minmax" the y values are clipped at 0,
    square-rooted and min-max normalized over all points of the panel, so
    the plotted points, the regression line and the reported correlations
    all use the same scale. Any other value leaves y untouched.
    The parent directory of `out_path` is created if needed.
    """
    fig, axes = plt.subplots(1, 2, figsize=(12, 4.2), dpi=140)

    # Color assignment (by encoder depth), shared by both panels.
    # plt.get_cmap is used because matplotlib.cm.get_cmap no longer exists
    # in recent Matplotlib releases.
    depths = sorted({p.enc_depth for p in elbo_points} | {p.enc_depth for p in iwae_points})
    cmap = plt.get_cmap("tab10", max(len(depths), 1))
    depth_to_color = {d: cmap(i % cmap.N) for i, d in enumerate(depths)}

    # Helpers
    def _apply_y_transform(y_values: np.ndarray) -> np.ndarray:
        if y_transform == "sqrt_minmax":
            y_sqrt = np.sqrt(np.clip(y_values, a_min=0.0, a_max=None))
            y_min = float(np.min(y_sqrt)) if y_sqrt.size > 0 else 0.0
            y_max = float(np.max(y_sqrt)) if y_sqrt.size > 0 else 1.0
            denom = (y_max - y_min)
            if denom == 0.0:
                return np.zeros_like(y_sqrt)
            return (y_sqrt - y_min) / denom
        return y_values

    def _scatter(ax, points: List[SplitPoint], title: str, with_legend: bool = False) -> None:
        if not points:
            ax.text(0.5, 0.5, "No points", transform=ax.transAxes, ha="center", va="center")
            ax.set_axis_off()
            return
        # Transform once over the whole panel; per-depth subsets are then
        # sliced from the transformed arrays so every group shares one scale.
        all_x = np.abs(np.array([p.gap_recon_loss for p in points], dtype=float))
        all_y = _apply_y_transform(np.array([p.combined_upper for p in points], dtype=float))
        point_depths = np.array([p.enc_depth for p in points])
        for d in depths:
            in_group = point_depths == d
            if not in_group.any():
                continue
            ax.scatter(all_x[in_group], all_y[in_group], s=36, alpha=0.85, color=depth_to_color[d], label=f"enc_depth={d}")
        # Regression line and correlation annotations
        _annotate_and_regress(ax, all_x, all_y)
        ax.set_xlabel(f"|{gap_key}| (per split)")
        ax.set_ylabel("combined_upper (per split)" if y_transform == "none" else "normalized sqrt(combined_upper)")
        ax.set_title(title)
        ax.grid(True, linestyle=":", alpha=0.4)
        if with_legend:
            ax.legend(loc="best", fontsize=8, ncols=1)

    _scatter(axes[0], elbo_points, title="train30000: ELBO (per split)", with_legend=True)
    _scatter(axes[1], iwae_points, title="train30000: IWAE (per split)", with_legend=False)

    fig.tight_layout()
    out_path.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(out_path)
    print(f"Saved figure to {out_path}")


# -------------------------------
# Correlation and LaTeX table generation
# -------------------------------


def _apply_y_transform_values(values: np.ndarray, mode: str = "none") -> np.ndarray:
    """
    Apply the optional y transform.

    mode == "sqrt_minmax": clip at 0, take the square root, then min-max
    normalize (an all-equal input maps to zeros; an empty input stays empty).
    Any other mode returns `values` unchanged.
    """
    if mode != "sqrt_minmax":
        return values
    rooted = np.sqrt(np.clip(values, a_min=0.0, a_max=None))
    if rooted.size == 0:
        return rooted
    lowest = float(np.min(rooted))
    highest = float(np.max(rooted))
    span = highest - lowest
    if span == 0.0:
        return np.zeros_like(rooted)
    return (rooted - lowest) / span


def _compute_corr_metrics(xs: np.ndarray, ys: np.ndarray) -> Tuple[int, float, float, float]:
    """Return (sample count of xs, Pearson, Spearman, Kendall tau-a) for the paired arrays."""
    return (
        int(xs.size),
        _safe_pearsonr(xs, ys),
        _spearmanr(xs, ys),
        _kendall_tau_a(xs, ys),
    )


def write_if_zu_correlation_table(
    mnist_root: Path,
    fashion_root: Path,
    out_path: Path,
    gap_key: str = "gap_recon_loss",
    y_transform: str = "none",
    algo_filter: str = "both",
):
    """
    For n=30000 per-split data, compute correlations between |gap| and IF/ZU
    (Pearson/Spearman/Kendall) and output a LaTeX table.

    - Datasets: MNIST, Fashion-MNIST
    - Algorithms: VAE(ELBO), IWAE (selected by algo_filter: "vae", "iwae" or "both")
    - Components: IF(params_u_upper), ZU(zu_upper), IF+ZU (their sum)

    Rows are emitted per dataset, then per algorithm, then per component.
    A dataset/algorithm without any points yields NaN rows. When
    algo_filter is "both", an 'Algo' column is added so that VAE and IWAE
    rows can be told apart; for a single algorithm the column layout is
    Dataset / Component / Pearson / Spearman / Kendall. Underscores in
    gap_key and y_transform are escaped in the caption so it compiles.
    The parent directory of `out_path` is created if needed.
    """

    def _tex(text: str) -> str:
        # A bare underscore is a LaTeX error outside math mode
        return text.replace("_", "\\_")

    # rows: dataset, algo, component(IF|ZU|IF+ZU), pearson, spearman, kendall
    rows: List[Tuple[str, str, str, float, float, float]] = []
    selected_algos = [a for a in ("vae", "iwae") if algo_filter in (a, "both")]
    components = ("IF", "ZU", "IF+ZU")
    nan = float("nan")

    for dataset_name, base in (("MNIST", mnist_root), ("F-MNIST", fashion_root)):
        for algo in selected_algos:
            # Scan the directory tree once per dataset/algorithm and reuse
            # the points for all three components.
            pts = collect_train30000_if_zu_points(base, algo=algo, gap_key=gap_key)
            if not pts:
                rows.extend((dataset_name, algo.upper(), comp, nan, nan, nan) for comp in components)
                continue
            xs = np.abs(np.asarray([p.gap_value for p in pts], dtype=float))
            if_vals = np.asarray([p.if_params_u_upper for p in pts], dtype=float)
            zu_vals = np.asarray([p.zu_upper for p in pts], dtype=float)
            raw_by_component = {"IF": if_vals, "ZU": zu_vals, "IF+ZU": if_vals + zu_vals}
            for comp in components:
                ys = _apply_y_transform_values(raw_by_component[comp], mode=y_transform)
                _, pr, sp, kd = _compute_corr_metrics(xs, ys)
                rows.append((dataset_name, algo.upper(), comp, pr, sp, kd))

    # The algorithm column is only needed when both algorithms share the table
    show_algo = algo_filter == "both"

    # LaTeX output
    lines: List[str] = []
    lines.append("\\begin{table}[t]")
    lines.append("\\centering")
    lines.append("\\small")
    if show_algo:
        lines.append("\\begin{tabular}{l l l r r r}")
        lines.append("\\hline")
        lines.append("Dataset & Algo & Component & Pearson & Spearman & Kendall \\\\")
    else:
        lines.append("\\begin{tabular}{l l r r r}")
        lines.append("\\hline")
        lines.append("Dataset & Component & Pearson & Spearman & Kendall \\\\")
    lines.append("\\hline")
    for ds, al, comp, pr, sp, kd in rows:
        if show_algo:
            lines.append(f"{ds} & {al} & {comp} & {pr:.3f} & {sp:.3f} & {kd:.3f} \\\\")
        else:
            lines.append(f"{ds} & {comp} & {pr:.3f} & {sp:.3f} & {kd:.3f} \\\\")
    lines.append("\\hline")
    lines.append("\\end{tabular}")
    lines.append(
        f"\\caption{{Correlation between $|${_tex(gap_key)}$|$ and bounded components (incl. IF+ZU) at n=30000. "
        f"Algo: {algo_filter.upper()}, y-transform: {_tex(y_transform)}.}}"
    )
    lines.append("\\label{tab:if_zu_corr_train30000}")
    lines.append("\\end{table}")

    out_path.parent.mkdir(parents=True, exist_ok=True)
    out_path.write_text("\n".join(lines))
    print(f"Saved LaTeX table to {out_path}")


# -------------------------------
# Entry point
# -------------------------------


# Shared styling for the multi-panel figures assembled by main().
_CLI_MNIST_COLOR = "#1f77b4"  # unified MNIST color (blue)
_CLI_FASHION_COLOR = "#ff7f0e"  # unified Fashion-MNIST color (orange)
_CLI_IF_COLOR = "#009E73"
_CLI_ZU_COLOR = "#CC79A7"
_CLI_TITLE_FS = 14
_CLI_LABEL_FS = 14
_CLI_TICK_FS = 13
_CLI_SHADE_ALPHA = 0.20
_CLI_GAP_YLABEL = "Generalization gap in loss"
_CLI_BOUND_YLABEL = "Our bound (approx.)"
_CLI_IF_LABEL = r"$I(\phi ; U \mid  X^n)$ (IF approx.)"
_CLI_ZU_LABEL = r"$I(Z^n ; U \mid  \phi, X^n)$ (upper bound)"


def _cli_plot_mean_band(ax, num_train, mean, std, color, ylabel, title):
    """Draw a mean line with a shaded mean±std band on ``ax``."""
    # NOTE(review): the 0.9 factor presumably converts the nominal train size
    # into the number of samples actually trained on — confirm.
    x = 0.9 * num_train
    ax.plot(x, mean, marker="o", markersize=10, color=color, linestyle="-", linewidth=4)
    ax.fill_between(x, mean - std, mean + std, color=color, alpha=_CLI_SHADE_ALPHA)
    ax.set_xlabel("num_train_samples", fontsize=_CLI_LABEL_FS)
    ax.set_ylabel(ylabel, fontsize=_CLI_LABEL_FS)
    ax.set_title(title, fontsize=_CLI_TITLE_FS, fontweight="bold")
    ax.grid(True, linestyle=":", alpha=0.4)
    ax.tick_params(labelsize=_CLI_TICK_FS)


def _cli_plot_gap_and_bound(gap_ax, bound_ax, root, algo, gap_key, color, prefix):
    """Load the 4-layer series under ``root`` and draw the |gap| and bound panels."""
    upper = find_upper_series_for_4layer(root, algo=algo, transform="sqrt")
    series = find_abs_gap_series_for_4layer(root, algo=algo, gap_key=gap_key)
    _cli_plot_mean_band(
        gap_ax, series.num_train, series.mean_abs_gap, series.std_abs_gap,
        color, _CLI_GAP_YLABEL, f"{prefix}: train size v.s. Gen. gap",
    )
    _cli_plot_mean_band(
        bound_ax, upper.num_train, upper.mean_upper, upper.std_upper,
        color, _CLI_BOUND_YLABEL, f"{prefix}: train size v.s. Our bound",
    )


def _cli_plot_split_scatter(ax, root, algo, gap_key, color, prefix):
    """Scatter per-split |gap| against the sqrt + min-max scaled combined upper bound."""
    points = collect_train30000_split_points(root, algo=algo, gap_key=gap_key)
    # NOTE(review): the x-values always read `gap_recon_loss`, whatever
    # `gap_key` is; presumably the collector stores the selected gap in that
    # field — verify against collect_train30000_split_points.
    xs = np.abs(np.array([p.gap_recon_loss for p in points], dtype=float))
    ys = np.array([p.combined_upper for p in points], dtype=float)
    # Negative bounds are clipped to 0 before the square root.
    ys = np.sqrt(np.clip(ys, 0.0, None))
    if ys.size:
        y_min, y_max = float(np.min(ys)), float(np.max(ys))
        span = y_max - y_min
        # A constant series would divide by zero; fall back to a unit span.
        ys = (ys - y_min) / (span if span != 0 else 1.0)
    ax.scatter(xs, ys, s=70, alpha=0.85, color=color)
    _annotate_and_regress(ax, xs, ys)
    ax.set_xlabel("Generalization gap in loss (per split)", fontsize=_CLI_LABEL_FS)
    ax.set_ylabel(_CLI_BOUND_YLABEL, fontsize=_CLI_LABEL_FS)
    ax.set_title(f"{prefix}: correlation", fontsize=_CLI_TITLE_FS, fontweight="bold")
    ax.grid(True, linestyle=":", alpha=0.4)
    ax.tick_params(labelsize=_CLI_TICK_FS)
    # Fewer, rotated x tick labels so they do not overlap.
    ax.xaxis.set_major_locator(MaxNLocator(nbins=5, prune='both'))
    ax.xaxis.set_major_formatter(ScalarFormatter(useMathText=True))
    for tick_label in ax.get_xticklabels():
        tick_label.set_rotation(30)
        tick_label.set_ha('right')


def _cli_plot_if_zu(ax, root, algo, prefix, show_legend):
    """Draw the IF (left axis) and ZU (right twin axis) mean±std series, unnormalized."""
    comps = find_if_zu_series_for_4layer(root, algo=algo, normalize=False)
    x = 0.9 * comps.num_train
    # Left axis: IF component.
    if_line, = ax.plot(x, comps.mean_if, marker="o", markersize=10, color=_CLI_IF_COLOR, linewidth=4, label=_CLI_IF_LABEL)
    ax.fill_between(x, comps.mean_if - comps.std_if, comps.mean_if + comps.std_if, color=_CLI_IF_COLOR, alpha=_CLI_SHADE_ALPHA)
    ax.set_ylabel(_CLI_IF_LABEL, fontsize=_CLI_LABEL_FS, color=_CLI_IF_COLOR)
    ax.tick_params(axis='y', labelcolor=_CLI_IF_COLOR, labelsize=_CLI_TICK_FS)
    ax.tick_params(axis='x', labelsize=_CLI_TICK_FS)
    ax.spines['left'].set_color(_CLI_IF_COLOR)
    # Right axis: ZU component, on its own y-scale.
    twin = ax.twinx()
    zu_line, = twin.plot(x, comps.mean_zu, marker="s", markersize=10, color=_CLI_ZU_COLOR, linewidth=4, label=_CLI_ZU_LABEL)
    twin.fill_between(x, comps.mean_zu - comps.std_zu, comps.mean_zu + comps.std_zu, color=_CLI_ZU_COLOR, alpha=_CLI_SHADE_ALPHA)
    twin.set_ylabel(_CLI_ZU_LABEL, fontsize=_CLI_LABEL_FS, color=_CLI_ZU_COLOR)
    twin.tick_params(axis='y', labelcolor=_CLI_ZU_COLOR, labelsize=_CLI_TICK_FS)
    twin.spines['right'].set_color(_CLI_ZU_COLOR)
    ax.set_xlabel("num_train_samples", fontsize=_CLI_LABEL_FS)
    ax.set_title(f"{prefix}: bound components", fontsize=_CLI_TITLE_FS, fontweight="bold")
    ax.grid(True, linestyle=":", alpha=0.4)
    if show_legend:
        # Lines live on two different axes, so the handles are passed explicitly.
        ax.legend(handles=[if_line, zu_line], loc="center right", fontsize=13, framealpha=0.9)


def _cli_save_figure(fig, out_path):
    """Tighten the layout, create the parent directory and write ``fig`` to ``out_path``."""
    fig.tight_layout()
    out_path.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(out_path)
    print(f"Saved figure to {out_path}")


def _cli_dataset_rows(base_root, mnist_root):
    """Return (root, title label, color) for the two datasets, in plotting order."""
    return (
        (mnist_root, "MNIST", _CLI_MNIST_COLOR),
        (base_root / "fashion_mnist", "F-MNIST", _CLI_FASHION_COLOR),
    )


def _cli_plot_2x4(base_root, mnist_root, algo, algo_label, gap_key, out_path):
    """Build the 2x4 figure: one row per dataset with gap, bound, scatter and components."""
    fig, axes = plt.subplots(2, 4, figsize=(20, 7.5), dpi=140)
    for row, (root, ds_label, color) in enumerate(_cli_dataset_rows(base_root, mnist_root)):
        prefix = f"{ds_label} ({algo_label})"
        _cli_plot_gap_and_bound(axes[row, 0], axes[row, 1], root, algo, gap_key, color, prefix)
        _cli_plot_split_scatter(axes[row, 2], root, algo, gap_key, color, prefix)
        # Only the top row carries the legend for the twin-axis panel.
        _cli_plot_if_zu(axes[row, 3], root, algo, prefix, show_legend=(row == 0))
    _cli_save_figure(fig, out_path)


def _cli_plot_1x4_both(base_root, mnist_root, algo, algo_label, gap_key, out_path):
    """Build the 1x4 figure: gap and bound panels for the first dataset, then the second."""
    fig, axes = plt.subplots(1, 4, figsize=(20, 4.2), dpi=140)
    for idx, (root, ds_label, color) in enumerate(_cli_dataset_rows(base_root, mnist_root)):
        prefix = f"{ds_label} ({algo_label})"
        _cli_plot_gap_and_bound(axes[2 * idx], axes[2 * idx + 1], root, algo, gap_key, color, prefix)
    _cli_save_figure(fig, out_path)


def main():
    """Command-line entry point: parse arguments and produce the requested figure or table.

    Modes (``--mode``):
      - ``1x4``: delegates to ``plot_1x4`` with the VAE and IWAE series.
      - ``scatter``: delegates to ``plot_scatter_train30000``.
      - ``iwae_2x3`` / ``elbo_2x3``: a 2x4 grid (the mode names are historical),
        one row per dataset, for IWAE or VAE respectively.
      - ``iwae_1x4_both`` / ``elbo_1x4_both``: a 1x4 strip with the gap and bound
        panels for both datasets.
      - ``corr_table``: delegates to ``write_if_zu_correlation_table``.

    The first dataset root is ``<root>/<--dataset>`` (panel titles still read
    "MNIST"); the second is always ``<root>/fashion_mnist``. When ``--out`` is
    not given, a mode-specific default under ``results/figures`` (or
    ``results/tables`` for ``corr_table``) is used.
    """
    import argparse

    parser = argparse.ArgumentParser(description="Plot MNIST VAE(MLP) ELBO/IWAE results")
    parser.add_argument("--root", type=str, default=None, help="Root of results/experiments (default: project_root/results/experiments)")
    parser.add_argument("--dataset", type=str, default="mnist", help="Dataset subdirectory (e.g., mnist)")
    parser.add_argument("--mode", type=str, choices=["1x4", "scatter", "iwae_2x3", "elbo_2x3", "elbo_1x4_both", "iwae_1x4_both", "corr_table"], default="1x4", help="Output mode")
    parser.add_argument("--gap", type=str, choices=["gap_loss", "gap_recon_loss", "gap_kl_loss"], default="gap_recon_loss", help="Gap metric to use")
    parser.add_argument("--y_transform", type=str, choices=["none", "sqrt_minmax"], default="none", help="y-axis transform for scatter (applied to combined_upper)")
    parser.add_argument("--out", type=str, default=None, help="Output image path (default: mode-specific name)")
    parser.add_argument("--algo", type=str, choices=["vae", "iwae", "both"], default="both", help="Algorithms to include for corr_table")
    args = parser.parse_args()

    # The project root is the parent of the directory containing this script.
    project_root = Path(__file__).resolve().parents[1]
    base_root = Path(args.root) if args.root else (project_root / "results" / "experiments")
    mnist_root = base_root / args.dataset

    out_dir = project_root / "results" / "figures"

    def resolve_out(default: Path) -> Path:
        # An explicit --out always wins over the mode-specific default.
        return Path(args.out) if args.out else default

    # (algo, title tag, default file name) for the modes that share a layout.
    grid_modes = {
        "iwae_2x3": ("iwae", "IWAE", "iwae_mnist_fashion_2x4.png"),
        "elbo_2x3": ("vae", "VAE", "elbo_mnist_fashion_2x4.png"),
    }
    strip_modes = {
        "elbo_1x4_both": ("vae", "VAE", "elbo_mnist_fashion_1x4.png"),
        "iwae_1x4_both": ("iwae", "IWAE", "iwae_mnist_fashion_1x4.png"),
    }

    if args.mode == "1x4":
        # Series aggregation
        elbo_upper = find_upper_series_for_4layer(mnist_root, algo="vae")
        iwae_upper = find_upper_series_for_4layer(mnist_root, algo="iwae")
        elbo_series = find_abs_gap_series_for_4layer(mnist_root, algo="vae", gap_key=args.gap)
        iwae_series = find_abs_gap_series_for_4layer(mnist_root, algo="iwae", gap_key=args.gap)
        out_path = resolve_out(out_dir / "mnist_vae_mlp_elbo_iwae_1x4.png")
        plot_1x4(elbo_upper, elbo_series, iwae_upper, iwae_series, out_path, gap_key=args.gap)
    elif args.mode == "scatter":
        # Scatter plot (train30000 per split)
        elbo_points = collect_train30000_split_points(mnist_root, algo="vae", gap_key=args.gap)
        iwae_points = collect_train30000_split_points(mnist_root, algo="iwae", gap_key=args.gap)
        out_path = resolve_out(out_dir / "mnist_vae_mlp_elbo_iwae_scatter_train30000.png")
        plot_scatter_train30000(elbo_points, iwae_points, out_path, gap_key=args.gap, y_transform=args.y_transform)
    elif args.mode in grid_modes:
        algo, algo_label, default_name = grid_modes[args.mode]
        _cli_plot_2x4(base_root, mnist_root, algo, algo_label, args.gap, resolve_out(out_dir / default_name))
    elif args.mode in strip_modes:
        algo, algo_label, default_name = strip_modes[args.mode]
        _cli_plot_1x4_both(base_root, mnist_root, algo, algo_label, args.gap, resolve_out(out_dir / default_name))
    elif args.mode == "corr_table":
        # Compute correlations for both MNIST and F-MNIST, and output a LaTeX table
        fashion_root = base_root / "fashion_mnist"
        tables_dir = project_root / "results" / "tables"
        write_if_zu_correlation_table(
            mnist_root=mnist_root,
            fashion_root=fashion_root,
            out_path=resolve_out(tables_dir / "if_zu_corr_train30000.tex"),
            gap_key=args.gap,
            y_transform=args.y_transform,
            algo_filter=args.algo,
        )


# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()


