import json
import re
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import matplotlib.pyplot as plt
import numpy as np


# ----------------------------------
# Basic utilities
# ----------------------------------


# Captures the digits of a "train<N>_" token, e.g. "train30000_iwae" -> "30000".
TRAIN_DIR_RE = re.compile(r"train(\d+)_")


def parse_train_size_from_path(path: Path) -> int:
    """
    Extract the training-set size encoded in a path.

    The last path component is inspected first; if it carries no
    ``train<N>_`` token, the name of its parent directory is tried.

    Raises:
        ValueError: if neither name contains a ``train<N>_`` token.
    """
    for candidate_name in (path.name, path.parent.name):
        found = TRAIN_DIR_RE.search(candidate_name)
        if found is not None:
            return int(found.group(1))
    raise ValueError(f"train size not found in path: {path}")


def read_json(path: Path) -> Dict:
    """
    Load and return the JSON document stored at ``path``.

    The file is decoded explicitly as UTF-8 (the interchange encoding for
    JSON) instead of relying on the platform's locale default, so results do
    not depend on the machine the script runs on.

    Raises:
        OSError: if the file cannot be opened.
        ValueError: if the content is not valid UTF-8 / JSON
            (``json.JSONDecodeError`` is a ``ValueError`` subclass).
    """
    with path.open("r", encoding="utf-8") as f:
        return json.load(f)


def _read_combined_upper_per_split(mi_path: Path) -> List[float]:
    """
    Return upper-bound values read from one MI result file.

    Sources are tried in this order and the first one that yields at least
    one value wins:
      1. ``combined_per_split[*].combined_upper``
      2. ``if_upper_bounds_per_split[*].upper_bound``
      3. ``combined_summary.mean_upper_bound`` (returned as a one-element list)
      4. ``if_upper_bounds_summary.mean_upper_bound`` (one-element list)

    An empty list is returned when none of them is available.
    """
    payload = read_json(mi_path)

    # (top-level list key, field inside every list entry), in priority order.
    per_split_sources = (
        ("combined_per_split", "combined_upper"),
        ("if_upper_bounds_per_split", "upper_bound"),
    )
    for list_key, field in per_split_sources:
        entries = payload.get(list_key)
        if not (isinstance(entries, list) and entries):
            continue
        raw_values = [entry.get(field) for entry in entries]
        found = [float(raw) for raw in raw_values if raw is not None]
        if found:
            return found

    # No per-split data: fall back to a summary mean, if one is present.
    for summary_key in ("combined_summary", "if_upper_bounds_summary"):
        summary = payload.get(summary_key)
        if summary and summary.get("mean_upper_bound") is not None:
            return [float(summary["mean_upper_bound"])]
    return []


def _delta_method_std_after_sqrt(mean_x: float, std_x: float) -> float:
    """
    First-order (delta-method) estimate of the std of ``sqrt(X)``.

    Uses ``std_x / (2 * sqrt(mean_x))``. A non-positive mean yields 0.0, and
    the square root is floored at 1e-12 to avoid dividing by (almost) zero.
    """
    if mean_x <= 0.0:
        return 0.0
    root = max(np.sqrt(mean_x), 1e-12)
    return float(std_x / (2.0 * root))


def _detect_hvae_dir(base: Path, algo: str = "iwae") -> Path:
    """
    Automatically detect the HVAE experiment directory under ``base``.
    Example: iwae_hmlp_latent32_16_8_4_hidden512_256_128_64

    Detection strategy:
      1. Direct children of ``base`` whose name starts with ``"{algo}_"`` and
         contains a ``latent...`` spec with exactly 4 levels are candidates;
         the lexicographically smallest candidate path is returned.
      2. Otherwise, ``base`` is searched recursively for
         ``mi_hierarchical_ef_l1.json`` and the directory two levels above the
         first hit (in sorted order) whose name starts with ``"{algo}_"`` is
         returned.

    Raises:
        RuntimeError: if no matching directory is found.
    """
    prefix = f"{algo}_"
    candidates: List[Path] = []
    for d in base.iterdir():
        if not d.is_dir() or not d.name.startswith(prefix):
            continue
        # Prefer names containing a hierarchical latent spec with 4 levels
        m = re.search(r"latent([0-9_]+)", d.name)
        if not m:
            continue
        # The character class also consumes the "_" that separates the latent
        # spec from the next token ("32_16_8_4_" before "hidden..."), so empty
        # pieces must be dropped before counting the levels.
        latent_parts = [part for part in m.group(1).split("_") if part]
        if len(latent_parts) == 4:
            candidates.append(d)
    if candidates:
        # iterdir() yields entries in arbitrary order; take the smallest path
        # so the selection is reproducible across runs and platforms.
        return min(candidates)

    # Fallback: search recursively for mi_hierarchical_ef_l1.json and return the
    # experiment directory that contains it (sorted for reproducibility).
    for mi_file in sorted(base.rglob("mi_hierarchical_ef_l1.json")):
        exp_dir = mi_file.parent.parent  # one above train*/ (model directory)
        if exp_dir.name.startswith(prefix):
            return exp_dir
    raise RuntimeError(f"HVAE dir not found under {base} for algo={algo}")


# ----------------------------------
# Data classes
# ----------------------------------


@dataclass
class LayerUpperSeries:
    """Upper-bound statistics of one HVAE layer, indexed by training-set size."""

    # Layer index the series was computed for.
    layer: int
    # Training-set sizes; find_hvae_layer_upper_series fills this in ascending order.
    num_train: np.ndarray
    # Mean of the (optionally transformed) upper bound for each entry of num_train.
    mean_upper: np.ndarray
    # Standard deviation matching mean_upper, entry by entry.
    std_upper: np.ndarray


# ----------------------------------
# Aggregation
# ----------------------------------


def _per_split_upper_values(data: Dict) -> List[float]:
    """Return genuine per-split upper bounds from an MI document ([] if there are none)."""
    for list_key, field in (
        ("combined_per_split", "combined_upper"),
        ("if_upper_bounds_per_split", "upper_bound"),
    ):
        entries = data.get(list_key)
        if isinstance(entries, list) and entries:
            vals = [float(e[field]) for e in entries if e.get(field) is not None]
            if vals:
                return vals
    return []


def _upper_stats_from_mi(data: Dict, transform: str) -> Optional[Tuple[float, float]]:
    """Return (mean, std) of the upper bound for one MI document, or None if it has no usable data."""
    vals = _per_split_upper_values(data)
    if vals:
        arr = np.asarray(vals, dtype=float)
        if transform == "sqrt":
            arr = np.sqrt(np.clip(arr, a_min=0.0, a_max=None))
        return float(np.mean(arr)), float(np.std(arr, ddof=0))

    # Only summaries are available: use the stored mean/std directly.
    for summary_key in ("combined_summary", "if_upper_bounds_summary"):
        summary = data.get(summary_key)
        if summary and summary.get("mean_upper_bound") is not None:
            mu = float(summary["mean_upper_bound"])
            raw_std = summary.get("std_upper_bound")
            su = float(raw_std) if raw_std is not None else 0.0
            break
    else:
        return None
    if transform == "sqrt":
        return float(np.sqrt(max(mu, 0.0))), _delta_method_std_after_sqrt(mu, su)
    return mu, su


def find_hvae_layer_upper_series(
    base: Path,
    layer: int,
    algo: str = "iwae",
    transform: str = "sqrt",
) -> LayerUpperSeries:
    """
    For each HVAE layer l, aggregate the mean/std of combined_upper across train* sizes (over splits).
    If transform="sqrt", apply sqrt(max(upper, 0)) to per-split values before computing mean/std.
    If only summaries exist, approximate std via the delta method.

    Per-split values and summary values are told apart by inspecting the JSON
    document directly. Going through a reader that collapses a summary mean
    into a one-element list would make the summary case indistinguishable
    from a single split, and the delta-method std could never be used.

    Args:
        base: directory that contains the HVAE experiment directory.
        layer: layer index; selects ``mi_hierarchical_ef_l{layer}.json``.
        algo: "iwae" or "vae"; train directories must contain "iwae" / "elbo".
        transform: "sqrt" to work on square-rooted bounds; any other value
            leaves the values untransformed.

    Returns:
        A LayerUpperSeries sorted by training-set size.

    Raises:
        RuntimeError: if no usable MI file is found for ``layer``.
    """
    model_dir = _detect_hvae_dir(base, algo=algo)

    rows: List[Tuple[int, float, float]] = []
    for train_dir in model_dir.iterdir():
        if not train_dir.is_dir() or not train_dir.name.startswith("train"):
            continue
        # Filter by algo name (e.g., train directories containing "iwae")
        is_algo = (algo == "vae" and "elbo" in train_dir.name) or (algo == "iwae" and "iwae" in train_dir.name)
        if not is_algo:
            continue
        mi_path = train_dir / f"mi_hierarchical_ef_l{layer}.json"
        if not mi_path.exists():
            # Skip sizes that were not executed
            continue
        try:
            stats = _upper_stats_from_mi(read_json(mi_path), transform)
            if stats is None:
                continue
            n = parse_train_size_from_path(train_dir)
        except (OSError, ValueError, TypeError, KeyError, AttributeError):
            # Best effort: an unreadable or malformed result file only drops
            # that training size from the series.
            continue
        rows.append((n, stats[0], stats[1]))

    if not rows:
        raise RuntimeError(f"No HVAE MI files found for layer={layer} under {model_dir}")

    rows.sort(key=lambda x: x[0])
    num_train = np.array([r[0] for r in rows], dtype=np.int64)
    mean_upper = np.array([r[1] for r in rows], dtype=np.float64)
    std_upper = np.array([r[2] for r in rows], dtype=np.float64)
    return LayerUpperSeries(layer=layer, num_train=num_train, mean_upper=mean_upper, std_upper=std_upper)


# ----------------------------------
# Gap series (for HVAE)
# ----------------------------------


@dataclass
class GapSeries:
    """Absolute generalization-gap statistics indexed by training-set size."""

    # Training-set sizes; find_hvae_gap_series fills this in ascending order.
    num_train: np.ndarray
    # Mean absolute gap for each entry of num_train.
    mean_abs_gap: np.ndarray
    # Standard deviation matching mean_abs_gap, entry by entry.
    std_abs_gap: np.ndarray


@dataclass
class BoundComponentsLayerSeries:
    """Mean/std of the IF and ZU bound components of one layer, indexed by training-set size."""

    # Layer index the series was computed for.
    layer: int
    # Training-set sizes; find_hvae_if_zu_series fills this in ascending order.
    num_train: np.ndarray
    # Mean / std of the IF component for each entry of num_train.
    mean_if: np.ndarray
    std_if: np.ndarray
    # Mean / std of the ZU component for each entry of num_train.
    mean_zu: np.ndarray
    std_zu: np.ndarray


def _read_aggregated_results(path: Path) -> Dict:
    """Load an aggregated-results JSON file; a named alias for read_json."""
    return read_json(path)


def _compute_abs_gap_stats_from_aggregated(agg: Dict, gap_key: str = "gap_recon_loss") -> Tuple[float, float]:
    """
    Return ``(mean, std)`` of the absolute gap stored under ``gap_key``.

    Per-split entries (``individual_results``, else ``individual``) are used
    when at least one of them carries a non-null ``gap_key``: the result is
    the mean and population std of their absolute values. Otherwise the
    pre-computed ``average_metrics`` are used: ``|avg_<gap_key>|`` and
    ``std_<gap_key>`` (0.0 when absent).

    Raises:
        KeyError: if the fallback is needed but ``average_metrics`` or its
            ``avg_<gap_key>`` entry is missing.
    """

    def _from_average_metrics() -> Tuple[float, float]:
        metrics = agg["average_metrics"]
        mean_gap = float(metrics[f"avg_{gap_key}"])
        std_gap = float(metrics.get(f"std_{gap_key}", 0.0))
        return abs(mean_gap), std_gap

    per_split = agg.get("individual_results") or agg.get("individual", [])
    abs_gaps: List[float] = []
    if per_split:
        abs_gaps = [
            abs(float(entry.get(gap_key)))
            for entry in per_split
            if entry.get(gap_key) is not None
        ]
    if abs_gaps:
        return float(np.mean(abs_gaps)), float(np.std(abs_gaps, ddof=0))
    return _from_average_metrics()


def find_hvae_gap_series(
    base: Path,
    algo: str = "iwae",
    gap_key: str = "gap_loss",
) -> GapSeries:
    """
    Collect the absolute-gap mean/std for every training size of the HVAE run.

    Only ``train*`` directories of the detected experiment directory whose
    name contains the algo marker ("elbo" for algo="vae", "iwae" for
    algo="iwae") and that hold an ``aggregated_results.json`` are used.
    Directories whose file cannot be processed are skipped.

    Raises:
        RuntimeError: if no directory contributes a data point.
    """
    model_dir = _detect_hvae_dir(base, algo=algo)
    # Substring that must appear in a train directory name; None for an unknown algo.
    marker = {"vae": "elbo", "iwae": "iwae"}.get(algo)

    collected: List[Tuple[int, float, float]] = []
    for entry in model_dir.iterdir():
        if not (entry.is_dir() and entry.name.startswith("train")):
            continue
        if marker is None or marker not in entry.name:
            continue
        agg_file = entry / "aggregated_results.json"
        if not agg_file.exists():
            continue
        try:
            gap_mean, gap_std = _compute_abs_gap_stats_from_aggregated(
                _read_aggregated_results(agg_file), gap_key=gap_key
            )
            size = parse_train_size_from_path(entry)
            collected.append((size, float(gap_mean), float(gap_std)))
        except Exception:
            continue

    if not collected:
        raise RuntimeError(f"No aggregated_results.json found for HVAE under {model_dir}")

    ordered = sorted(collected, key=lambda row: row[0])
    sizes, means, stds = zip(*ordered)
    return GapSeries(
        num_train=np.array(sizes, dtype=np.int64),
        mean_abs_gap=np.array(means, dtype=np.float64),
        std_abs_gap=np.array(stds, dtype=np.float64),
    )


# ----------------------------------
# Correlation statistics, LaTeX tables, and plotting
# ----------------------------------


def _rankdata_average(values: np.ndarray) -> np.ndarray:
    """
    Rank ``values`` (1-based), giving tied entries the average of their ranks.

    The result has the same length as ``values``; position ``k`` holds the
    rank of ``values[k]``.
    """
    n = len(values)
    order = np.argsort(values, kind="mergesort")
    sorted_vals = values[order]
    ranks = np.empty(n, dtype=float)

    group_start = 0
    for pos in range(1, n + 1):
        # Keep extending the current tie group while the value matches its first member.
        if pos < n and sorted_vals[pos] == sorted_vals[group_start]:
            continue
        # The group occupies 1-based ranks group_start+1 .. pos; assign their average.
        ranks[order[group_start:pos]] = (group_start + pos + 1) / 2.0
        group_start = pos
    return ranks


def _safe_pearsonr(x: np.ndarray, y: np.ndarray) -> float:
    """
    Pearson correlation of ``x`` and ``y`` that degrades to NaN instead of failing.

    NaN is returned when either input has fewer than two elements or has a
    standard deviation that is numerically zero.
    """
    if min(x.size, y.size) < 2:
        return float("nan")
    for sample in (x, y):
        if np.allclose(np.std(sample), 0.0):
            return float("nan")
    return float(np.corrcoef(x, y)[0, 1])


def _spearmanr(x: np.ndarray, y: np.ndarray) -> float:
    """
    Spearman rank correlation: the Pearson correlation of the average ranks.

    Returns NaN when either input has fewer than two elements.
    """
    if min(x.size, y.size) < 2:
        return float("nan")
    return _safe_pearsonr(_rankdata_average(x), _rankdata_average(y))


def _kendall_tau_a(x: np.ndarray, y: np.ndarray) -> float:
    """
    Kendall's tau-a: (concordant - discordant) / (n * (n - 1) / 2).

    Tied pairs count as neither concordant nor discordant. Returns NaN for
    fewer than two observations.
    """
    n = x.size
    if n < 2:
        return float("nan")

    n_concordant = 0
    n_discordant = 0
    for k in range(n - 1):
        # Sign of the product says whether x and y move together for the pairs (k, k+1..n-1).
        products = (x[k + 1:] - x[k]) * (y[k + 1:] - y[k])
        n_concordant += int(np.count_nonzero(products > 0))
        n_discordant += int(np.count_nonzero(products < 0))

    n_pairs = n * (n - 1) / 2.0
    if n_pairs == 0:
        return float("nan")
    return float((n_concordant - n_discordant) / n_pairs)


def _apply_y_transform_values(values: np.ndarray, mode: str = "none") -> np.ndarray:
    """
    Apply the y-transform used for the correlation tables.

    - ``mode == "sqrt_minmax"``: take sqrt(max(y, 0)), then min-max normalize.
      A constant input maps to all zeros; an empty input stays empty.
    - any other mode: ``values`` is returned unchanged (same object).
    """
    if mode != "sqrt_minmax":
        return values

    rooted = np.sqrt(np.clip(values, a_min=0.0, a_max=None))
    if rooted.size == 0:
        return rooted

    lowest = float(np.min(rooted))
    highest = float(np.max(rooted))
    span = highest - lowest
    if span == 0.0:
        return np.zeros_like(rooted)
    return (rooted - lowest) / span


def _read_gap_per_split_from_aggregated_global(agg: Dict, gap_key: str = "gap_loss") -> List[float]:
    """
    Return the per-split gap values found in an aggregated-results document.

    Entries of ``individual_results`` that contain ``gap_key`` are returned
    in order. When there are none, the single value
    ``average_metrics["avg_<gap_key>"]`` is returned as a one-element list,
    or an empty list if that is missing too.
    """
    per_split = [
        float(entry[gap_key])
        for entry in (agg.get("individual_results") or [])
        if gap_key in entry
    ]
    if per_split:
        return per_split

    fallback = agg.get("average_metrics", {}).get(f"avg_{gap_key}")
    if fallback is None:
        return []
    return [float(fallback)]


def _read_if_zu_per_split_hvae(mi_path: Path) -> Tuple[List[float], List[float]]:
    """
    Read the per-split IF and ZU components from one MI result file.

    Primary source is ``combined_per_split`` (fields ``if_params_u_upper``
    and ``z1tol_upper``). A component that yields nothing there falls back to
    ``if_upper_bounds_per_split[*].upper_bound`` (IF) or
    ``mi_z_upper_per_split[*].mi_upper_bound`` (ZU).

    Returns:
        ``(if_values, zu_values)``; either list may be empty.
    """
    payload = read_json(mi_path)
    if_values: List[float] = []
    zu_values: List[float] = []

    combined = payload.get("combined_per_split")
    if isinstance(combined, list) and combined:
        for entry in combined:
            if_raw = entry.get("if_params_u_upper")
            if if_raw is not None:
                if_values.append(float(if_raw))
            zu_raw = entry.get("z1tol_upper")
            if zu_raw is not None:
                zu_values.append(float(zu_raw))

    # Fallbacks, applied independently per component.
    if not if_values and payload.get("if_upper_bounds_per_split"):
        if_values = [
            float(entry.get("upper_bound"))
            for entry in payload["if_upper_bounds_per_split"]
            if entry.get("upper_bound") is not None
        ]
    if not zu_values and payload.get("mi_z_upper_per_split"):
        zu_values = [
            float(entry.get("mi_upper_bound"))
            for entry in payload["mi_z_upper_per_split"]
            if entry.get("mi_upper_bound") is not None
        ]
    return if_values, zu_values


def write_hvae_if_zu_correlation_table(
    mnist_root: Path,
    fashion_root: Path,
    out_path: Path,
    algo: str = "iwae",
    gap_key: str = "gap_loss",
    y_transform: str = "none",
    include_sum: bool = False,
):
    """
    For n=30000 per-split data, compute correlations between |gap| and IF/ZU (and optionally IF+ZU)
    using Pearson/Spearman/Kendall, and output a LaTeX table.

    For every split the layer with the smallest upper bound is selected and
    the IF/ZU values of that layer are correlated with |gap|. A dataset for
    which the required files are missing contributes NaN rows.

    - Datasets: MNIST, Fashion-MNIST
    - algo: "iwae" or "vae" (default: iwae)
    - gap_key: key in aggregated_results.json (default: gap_loss)
    - y_transform: {"none", "sqrt_minmax"}
    - include_sum: if True, also include IF+ZU

    The table is written to ``out_path`` (parent directories are created).
    ``gap_key`` and ``y_transform`` are LaTeX-escaped in the caption, because
    a bare "_" in text mode does not compile.
    """
    nan = float("nan")
    components = ["IF", "ZU", "IF+ZU"] if include_sum else ["IF", "ZU"]

    def _latex_escape(text: str) -> str:
        # "_" is only legal in math mode; "\_" prints a literal underscore.
        return text.replace("_", "\\_")

    def _is_algo_train_dir(d: Path) -> bool:
        if not (d.is_dir() and d.name.startswith("train")):
            return False
        return (algo == "vae" and "elbo" in d.name) or (algo == "iwae" and "iwae" in d.name)

    def summarize_one_dataset(dataset_label: str, base: Path) -> List[Tuple[str, str, float, float, float]]:
        def _nan_rows() -> List[Tuple[str, str, float, float, float]]:
            # Placeholder rows used whenever the dataset cannot be evaluated.
            return [(dataset_label, comp, nan, nan, nan) for comp in components]

        try:
            model_dir = _detect_hvae_dir(base, algo=algo)
        except Exception:
            # No data
            return _nan_rows()

        # Prefer train=30000; otherwise choose the maximum n
        train_dirs = [d for d in model_dir.iterdir() if _is_algo_train_dir(d)]
        if not train_dirs:
            return _nan_rows()
        train_dirs_sorted = sorted(train_dirs, key=parse_train_size_from_path)
        chosen_dir = next(
            (td for td in train_dirs_sorted if parse_train_size_from_path(td) == 30000),
            train_dirs_sorted[-1],
        )

        agg_path = chosen_dir / "aggregated_results.json"
        if not agg_path.exists():
            return _nan_rows()

        try:
            agg = _read_aggregated_results(agg_path)
            gaps = _read_gap_per_split_from_aggregated_global(agg, gap_key=gap_key)
        except Exception:
            gaps = []

        # Read per-split values per layer; the three lists stay index-aligned
        # because a layer is appended to all of them or to none.
        uppers_per_layer: List[List[float]] = []
        ifs_per_layer: List[List[float]] = []
        zus_per_layer: List[List[float]] = []
        for li in (1, 2, 3, 4):
            mi_path = chosen_dir / f"mi_hierarchical_ef_l{li}.json"
            if not mi_path.exists():
                continue
            try:
                uppers = _read_combined_upper_per_split(mi_path)
                ifs_l, zus_l = _read_if_zu_per_split_hvae(mi_path)
                if uppers:
                    uppers_per_layer.append([float(v) for v in uppers])
                    ifs_per_layer.append([float(v) for v in ifs_l])
                    zus_per_layer.append([float(v) for v in zus_l])
            except Exception:
                continue

        if not gaps or not uppers_per_layer or not ifs_per_layer or not zus_per_layer:
            return _nan_rows()

        # For each split, choose the layer with the smallest upper bound and take IF/ZU from that layer
        n = min([len(gaps)] + [len(u) for u in uppers_per_layer])
        xs_all: List[float] = []
        chosen_if: List[float] = []
        chosen_zu: List[float] = []
        for i in range(n):
            # n never exceeds any layer's length, so every layer has a value at i.
            li_min = int(np.argmin([u[i] for u in uppers_per_layer]))
            if i < len(ifs_per_layer[li_min]) and i < len(zus_per_layer[li_min]):
                xs_all.append(abs(float(gaps[i])))
                chosen_if.append(float(ifs_per_layer[li_min][i]))
                chosen_zu.append(float(zus_per_layer[li_min][i]))

        # Compute correlations
        def _corr(xs: List[float], ys: List[float]) -> Tuple[float, float, float]:
            x = np.asarray(xs, dtype=float)
            y = _apply_y_transform_values(np.asarray(ys, dtype=float), mode=y_transform)
            return _safe_pearsonr(x, y), _spearmanr(x, y), _kendall_tau_a(x, y)

        rows_local: List[Tuple[str, str, float, float, float]] = [
            (dataset_label, "IF") + _corr(xs_all, chosen_if),
            (dataset_label, "ZU") + _corr(xs_all, chosen_zu),
        ]
        if include_sum:
            chosen_sum = (np.asarray(chosen_if, dtype=float) + np.asarray(chosen_zu, dtype=float)).tolist()
            rows_local.append((dataset_label, "IF+ZU") + _corr(xs_all, chosen_sum))
        return rows_local

    # Build rows
    rows: List[Tuple[str, str, float, float, float]] = []
    rows.extend(summarize_one_dataset("MNIST", mnist_root))
    rows.extend(summarize_one_dataset("F-MNIST", fashion_root))

    # LaTeX output
    lines: List[str] = []
    lines.append("\\begin{table}[t]")
    lines.append("\\centering")
    lines.append("\\small")
    lines.append("\\begin{tabular}{l l r r r}")
    lines.append("\\hline")
    lines.append("Dataset & Component & Pearson & Spearman & Kendall \\\\")
    lines.append("\\hline")
    for ds, comp, pr, sp, kd in rows:
        if np.isnan(pr) or np.isnan(sp) or np.isnan(kd):
            lines.append(f"{ds} & {comp} & NaN & NaN & NaN \\\\")
        else:
            lines.append(f"{ds} & {comp} & {pr:.3f} & {sp:.3f} & {kd:.3f} \\\\")
    lines.append("\\hline")
    lines.append("\\end{tabular}")
    # NOTE(review): the caption always says n=30000 even when the largest
    # available training size was used as a fallback.
    lines.append(
        f"\\caption{{Correlation between |{_latex_escape(gap_key)}| and bounded components at n=30000. "
        f"Algo: H-{algo.upper()}, y-transform: {_latex_escape(y_transform)}.}}"
    )
    lines.append("\\label{tab:hvae_if_zu_corr_train30000}")
    lines.append("\\end{table}")

    out_path.parent.mkdir(parents=True, exist_ok=True)
    out_path.write_text("\n".join(lines))
    print(f"Saved LaTeX table to {out_path}")


def write_hvae_if_zu_correlation_table_per_layer(
    mnist_root: Path,
    fashion_root: Path,
    out_path: Path,
    algo: str = "iwae",
    gap_key: str = "gap_loss",
    y_transform: str = "none",
    include_sum: bool = True,
):
    """
    For n=30000 per-split data, compute correlations between |gap| and IF/ZU (and IF+ZU)
    per layer (l=1..4) using Pearson/Spearman/Kendall, and output a LaTeX table where each
    correlation type has 4 columns (l=1..4).

    A layer with fewer than two paired observations, or a dataset whose files
    are missing, is reported as NaN.

    - Datasets: MNIST, Fashion-MNIST
    - algo: "iwae" or "vae" (default: iwae)
    - gap_key: key in aggregated_results.json (default: gap_loss)
    - y_transform: {"none", "sqrt_minmax"}
    - include_sum: if True, also include IF+ZU

    The table is written to ``out_path`` (parent directories are created).
    ``gap_key`` and ``y_transform`` are LaTeX-escaped in the caption, because
    a bare "_" in text mode does not compile.
    """
    nan = float("nan")
    components = ["IF", "ZU"] + (["IF+ZU"] if include_sum else [])
    CorrTriple = Tuple[List[float], List[float], List[float]]

    def _latex_escape(text: str) -> str:
        # "_" is only legal in math mode; "\_" prints a literal underscore.
        return text.replace("_", "\\_")

    def _nan_result() -> Dict[str, CorrTriple]:
        # Fresh lists per component so callers can never alias them.
        return {comp: ([nan] * 4, [nan] * 4, [nan] * 4) for comp in components}

    def _is_algo_train_dir(d: Path) -> bool:
        if not (d.is_dir() and d.name.startswith("train")):
            return False
        return (algo == "vae" and "elbo" in d.name) or (algo == "iwae" and "iwae" in d.name)

    def summarize_one_dataset(dataset_label: str, base: Path) -> Dict[str, CorrTriple]:
        """Return: {component: (pearson_by_l[4], spearman_by_l[4], kendall_by_l[4])}"""
        try:
            model_dir = _detect_hvae_dir(base, algo=algo)
        except Exception:
            # No data
            return _nan_result()

        # Prefer train=30000; otherwise choose the maximum n
        train_dirs = [d for d in model_dir.iterdir() if _is_algo_train_dir(d)]
        if not train_dirs:
            return _nan_result()
        train_dirs_sorted = sorted(train_dirs, key=parse_train_size_from_path)
        chosen_dir: Path = next(
            (td for td in train_dirs_sorted if parse_train_size_from_path(td) == 30000),
            train_dirs_sorted[-1],
        )

        agg_path = chosen_dir / "aggregated_results.json"
        if not agg_path.exists():
            return _nan_result()

        try:
            agg = _read_aggregated_results(agg_path)
            gaps = _read_gap_per_split_from_aggregated_global(agg, gap_key=gap_key)
        except Exception:
            gaps = []

        # Per-split values per layer; slot li-1 stays empty when layer li is unavailable.
        ifs_per_layer: List[List[float]] = [[] for _ in range(4)]
        zus_per_layer: List[List[float]] = [[] for _ in range(4)]
        for li in (1, 2, 3, 4):
            mi_path = chosen_dir / f"mi_hierarchical_ef_l{li}.json"
            if not mi_path.exists():
                continue
            try:
                ifs_l, zus_l = _read_if_zu_per_split_hvae(mi_path)
                ifs_per_layer[li - 1] = [float(v) for v in ifs_l]
                zus_per_layer[li - 1] = [float(v) for v in zus_l]
            except Exception:
                continue

        def _corrs_by_layer(ys_per_layer: List[List[float]]) -> CorrTriple:
            pr_list: List[float] = []
            sp_list: List[float] = []
            kd_list: List[float] = []
            for li_idx in range(4):
                ys_raw = np.asarray(ys_per_layer[li_idx], dtype=float)
                # Pair gaps and component values up to the shorter length.
                n_li = int(min(len(gaps), ys_raw.size))
                if n_li < 2:
                    pr_list.append(nan)
                    sp_list.append(nan)
                    kd_list.append(nan)
                    continue
                xs = np.abs(np.asarray(gaps[:n_li], dtype=float))
                ys = _apply_y_transform_values(ys_raw[:n_li], mode=y_transform)
                pr_list.append(_safe_pearsonr(xs, ys))
                sp_list.append(_spearmanr(xs, ys))
                kd_list.append(_kendall_tau_a(xs, ys))
            return pr_list, sp_list, kd_list

        result: Dict[str, CorrTriple] = {
            "IF": _corrs_by_layer(ifs_per_layer),
            "ZU": _corrs_by_layer(zus_per_layer),
        }

        # IF+ZU (per layer)
        if include_sum:
            sum_per_layer: List[List[float]] = []
            for li_idx in range(4):
                a = np.asarray(ifs_per_layer[li_idx], dtype=float)
                b = np.asarray(zus_per_layer[li_idx], dtype=float)
                n_li = int(min(a.size, b.size))
                sum_per_layer.append((a[:n_li] + b[:n_li]).tolist() if n_li else [])
            result["IF+ZU"] = _corrs_by_layer(sum_per_layer)

        return result

    # Collect row data
    mnist_rows = summarize_one_dataset("MNIST", mnist_root)
    fm_rows = summarize_one_dataset("F-MNIST", fashion_root)

    # LaTeX output (3 correlations × 4 columns each)
    lines: List[str] = []
    lines.append("\\begin{table}[t]")
    lines.append("\\centering")
    lines.append("\\small")
    lines.append("\\begin{tabular}{l l r r r r r r r r r r r r}")
    lines.append("\\hline")
    lines.append("Dataset & Component & \\multicolumn{4}{c}{Pearson} & \\multicolumn{4}{c}{Spearman} & \\multicolumn{4}{c}{Kendall} \\\\")
    lines.append(" &  & l=1 & l=2 & l=3 & l=4 & l=1 & l=2 & l=3 & l=4 & l=1 & l=2 & l=3 & l=4 \\\\")
    lines.append("\\hline")

    def _fmt(vals: List[float]) -> List[str]:
        return [("NaN" if (v is None or (isinstance(v, float) and np.isnan(v))) else f"{v:.3f}") for v in vals]

    def _append_rows(ds_label: str, rows_map: Dict[str, CorrTriple]):
        for comp in components:
            pr, sp, kd = rows_map.get(comp, ([nan] * 4, [nan] * 4, [nan] * 4))
            cells = [ds_label, comp] + _fmt(pr) + _fmt(sp) + _fmt(kd)
            lines.append(" & ".join(cells) + " \\\\")

    _append_rows("MNIST", mnist_rows)
    _append_rows("F-MNIST", fm_rows)

    lines.append("\\hline")
    lines.append("\\end{tabular}")
    algo_label = f"H-{algo.upper()}"
    # NOTE(review): the caption always says n=30000 even when the largest
    # available training size was used as a fallback.
    lines.append(
        f"\\caption{{Correlation between |{_latex_escape(gap_key)}| and bounded components at n=30000 "
        f"by focus layer (l=1..4). Algo: {algo_label}, y-transform: {_latex_escape(y_transform)}.}}"
    )
    lines.append("\\label{tab:hvae_if_zu_corr_train30000_by_layer}")
    lines.append("\\end{table}")

    out_path.parent.mkdir(parents=True, exist_ok=True)
    out_path.write_text("\n".join(lines))
    print(f"Saved LaTeX table to {out_path}")

def find_hvae_if_zu_series(
    base: Path,
    layer: int,
    algo: str = "iwae",
    normalize: bool = False,
) -> BoundComponentsLayerSeries:
    """
    Collect mean/std of the per-split IF and ZU components of one HVAE layer
    for every available training size.

    Each ``mi_hierarchical_ef_l{layer}.json`` is parsed exactly once; the
    parsed values are reused both for the global min/max and for the
    per-size statistics. A directory that fails to load or whose size cannot
    be parsed contributes nothing, not even to the normalization range.

    Args:
        base: directory that contains the HVAE experiment directory.
        layer: layer index selecting the MI file.
        algo: "iwae" or "vae"; train directories must contain "iwae" / "elbo".
        normalize: if True, min-max normalize IF and ZU with the global
            min/max over all training sizes before computing statistics.

    Returns:
        A BoundComponentsLayerSeries sorted by training-set size.

    Raises:
        RuntimeError: if no IF or no ZU value is found at all.
    """
    model_dir = _detect_hvae_dir(base, algo=algo)

    # (train size, IF values, ZU values) for every usable train directory.
    per_train: List[Tuple[int, List[float], List[float]]] = []
    for train_dir in model_dir.iterdir():
        if not train_dir.is_dir() or not train_dir.name.startswith("train"):
            continue
        is_algo = (algo == "vae" and "elbo" in train_dir.name) or (algo == "iwae" and "iwae" in train_dir.name)
        if not is_algo:
            continue
        mi_path = train_dir / f"mi_hierarchical_ef_l{layer}.json"
        if not mi_path.exists():
            continue
        try:
            ifs, zus = _read_if_zu_per_split_hvae(mi_path)
            n = parse_train_size_from_path(train_dir)
        except (OSError, ValueError, TypeError, KeyError, AttributeError):
            # Best effort: skip unreadable or malformed result files.
            continue
        per_train.append((n, ifs, zus))

    all_ifs = [v for _, ifs, _ in per_train for v in ifs]
    all_zus = [v for _, _, zus in per_train for v in zus]
    if not per_train or not all_ifs or not all_zus:
        raise RuntimeError(f"No IF/ZU values found for HVAE layer={layer} under {model_dir}")

    if_min, if_max = float(np.min(all_ifs)), float(np.max(all_ifs))
    zu_min, zu_max = float(np.min(all_zus)), float(np.max(all_zus))
    # A degenerate range falls back to 1.0 so normalization never divides by zero.
    if_denom = if_max - if_min if if_max > if_min else 1.0
    zu_denom = zu_max - zu_min if zu_max > zu_min else 1.0

    rows: List[Tuple[int, float, float, float, float]] = []
    for n, ifs, zus in per_train:
        ifs_arr = np.asarray(ifs, dtype=float)
        zus_arr = np.asarray(zus, dtype=float)
        if normalize:
            ifs_arr = (ifs_arr - if_min) / if_denom
            zus_arr = (zus_arr - zu_min) / zu_denom
        rows.append((
            n,
            float(np.mean(ifs_arr)),
            float(np.std(ifs_arr, ddof=0)),
            float(np.mean(zus_arr)),
            float(np.std(zus_arr, ddof=0)),
        ))
    if not rows:
        raise RuntimeError(f"Failed to compute IF/ZU series for HVAE layer={layer} under {model_dir}")

    rows.sort(key=lambda x: x[0])
    num_train = np.array([r[0] for r in rows], dtype=np.int64)
    mean_if = np.array([r[1] for r in rows], dtype=np.float64)
    std_if = np.array([r[2] for r in rows], dtype=np.float64)
    mean_zu = np.array([r[3] for r in rows], dtype=np.float64)
    std_zu = np.array([r[4] for r in rows], dtype=np.float64)
    return BoundComponentsLayerSeries(layer=layer, num_train=num_train, mean_if=mean_if, std_if=std_if, mean_zu=mean_zu, std_zu=std_zu)


def plot_upper_1x4(
    series_list: List[LayerUpperSeries],
    out_path: Path,
):
    """
    Plot the mean upper bound (with a +/- std band) of each layer in a 1x4 grid
    and save the figure to ``out_path``.

    Series are drawn in ascending layer order, one per subplot. Parent
    directories of ``out_path`` are created as needed. The figure is closed
    after saving (also on failure) so repeated calls do not accumulate open
    matplotlib figures.

    Raises:
        IndexError: if more than four series are supplied.
    """
    fig, axes = plt.subplots(1, 4, figsize=(18, 4.2), dpi=140)
    try:
        for i, series in enumerate(sorted(series_list, key=lambda s: s.layer)):
            ax = axes[i]
            x = series.num_train
            y = series.mean_upper
            yerr = series.std_upper
            ax.plot(x, y, marker="o", linestyle="-", linewidth=2.5)
            ax.fill_between(x, y - yerr, y + yerr, alpha=0.2)
            ax.set_xlabel("num_train_samples")
            ax.set_ylabel("Our hierarchical bound (approx.)")
            ax.set_title(f"HVAE (IWAE): layer {series.layer}")
            ax.grid(True, linestyle=":", alpha=0.4)

        fig.tight_layout()
        out_path.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(out_path)
    finally:
        # pyplot keeps a reference to every figure it creates until it is closed.
        plt.close(fig)
    print(f"Saved figure to {out_path}")


def _combo_gap_per_split(agg: Dict, gap_key: str) -> List[float]:
    """Return per-split gap values from an aggregated-results dict.

    Reads ``gap_key`` from every entry of ``individual_results`` that has it.
    If none do, falls back to the single ``avg_<gap_key>`` value stored under
    ``average_metrics``; returns an empty list when that is missing as well.
    """
    individuals = agg.get("individual_results") or []
    vals = [float(item[gap_key]) for item in individuals if gap_key in item]
    if vals:
        return vals
    avg = agg.get("average_metrics", {}).get(f"avg_{gap_key}")
    return [float(avg)] if avg is not None else []


def _combo_minmax(values_by_layer: Dict[int, List[float]]) -> Tuple[float, float]:
    """Return ``(min, denom)`` for global min-max scaling over all layers' values.

    ``denom`` is ``max - min`` when that is positive and ``1.0`` otherwise, so
    the caller can always divide by it. With no values at all, ``(0.0, 1.0)``
    is returned.
    """
    flat = np.asarray([v for vals in values_by_layer.values() for v in vals], dtype=float)
    lo = float(np.min(flat)) if flat.size else 0.0
    hi = float(np.max(flat)) if flat.size else 1.0
    denom = (hi - lo) if hi > lo else 1.0
    return lo, denom


def _combo_normalized_stats(values: List[float], lo: float, denom: float) -> Tuple[float, float]:
    """Return ``(mean, std)`` of ``(values - lo) / denom``; ``(nan, 0.0)`` if empty."""
    arr = np.asarray(values, dtype=float)
    if not arr.size:
        return float("nan"), 0.0
    norm = (arr - lo) / denom
    return float(np.mean(norm)), float(np.std(norm, ddof=0))


def _combo_draw_scatter_panel(
    ax,
    gaps: List[float],
    uppers_by_layer: Dict[int, List[float]],
    transform: str,
    *,
    dataset_label: str,
    algo_label: str,
    title_fs: int,
    label_fs: int,
    tick_fs: int,
) -> None:
    """Scatter |gap| against the per-split minimum upper bound over layers.

    Only the first ``n`` splits are used, where ``n`` is the shortest length
    among ``gaps`` and every layer's upper list. A least-squares line is added
    when there are at least two points with distinct finite x values, and the
    Pearson / Spearman / Kendall coefficients are annotated in the corner.
    """
    n = min([len(gaps)] + [len(u) for u in uppers_by_layer.values()])
    xs = np.abs(np.asarray(gaps[:n], dtype=float))
    # Row = layer, column = split; the column-wise minimum picks the tightest layer per split.
    stacked = np.asarray([u[:n] for u in uppers_by_layer.values()], dtype=float)
    ys = np.min(stacked, axis=0)
    if transform == "sqrt":
        ys = np.sqrt(np.clip(ys, a_min=0.0, a_max=None))

    ax.scatter(xs, ys, s=70, alpha=0.85, color="#1f77b4")

    # Regression line over all points
    if xs.size >= 2:
        x_min = float(np.min(xs))
        x_max = float(np.max(xs))
        if np.isfinite(x_min) and np.isfinite(x_max) and x_min != x_max:
            slope, intercept = np.polyfit(xs, ys, 1)
            x_line = np.linspace(x_min, x_max, 100)
            ax.plot(x_line, slope * x_line + intercept, linestyle="--", linewidth=4.0, color="#333333")

    # Correlation metrics annotation
    pearson = _safe_pearsonr(xs, ys)
    spearman = _spearmanr(xs, ys)
    kendall = _kendall_tau_a(xs, ys)
    text = (
        f"Pearson={pearson:.3f}\n"
        f"Spearman={spearman:.3f}\n"
        f"Kendall={kendall:.3f}"
    )
    ax.text(0.02, 0.98, text, transform=ax.transAxes, va="top", ha="left", fontsize=12,
            bbox=dict(facecolor="white", alpha=0.65, edgecolor="none"))
    ax.set_xlabel("Generalization gap in loss (per split)", fontsize=label_fs)
    ax.set_ylabel("min. of Our bound (approx.)", fontsize=label_fs)
    ax.set_title(f"{dataset_label} ({algo_label}): correlation", fontsize=title_fs, fontweight="bold")
    ax.grid(True, linestyle=":", alpha=0.4)
    ax.tick_params(labelsize=tick_fs)


def _combo_draw_components_panel(
    ax,
    ifs_by_layer: Dict[int, List[float]],
    zus_by_layer: Dict[int, List[float]],
    *,
    title_fs: int,
    label_fs: int,
    tick_fs: int,
    shade_alpha: float,
) -> None:
    """Plot min-max normalized IF and ZU (mean ± std over splits) against layer 1..4.

    Values are looked up by layer number, so a layer without data appears as a
    gap (NaN mean) at its own x position instead of shifting later layers.
    Normalization uses one global min/max per component so layers stay comparable.
    """
    layer_ids = (1, 2, 3, 4)
    if_lo, if_denom = _combo_minmax(ifs_by_layer)
    zu_lo, zu_denom = _combo_minmax(zus_by_layer)

    if_stats = [_combo_normalized_stats(ifs_by_layer.get(li, []), if_lo, if_denom) for li in layer_ids]
    zu_stats = [_combo_normalized_stats(zus_by_layer.get(li, []), zu_lo, zu_denom) for li in layer_ids]
    if_means = np.array([m for m, _ in if_stats], dtype=float)
    if_stds = np.array([s for _, s in if_stats], dtype=float)
    zu_means = np.array([m for m, _ in zu_stats], dtype=float)
    zu_stds = np.array([s for _, s in zu_stats], dtype=float)

    IF_COLOR = "#009E73"
    ZU_COLOR = "#CC79A7"
    layers = np.array(layer_ids, dtype=float)
    # Line (IF)
    ax.plot(layers, if_means, marker="o", color=IF_COLOR, linewidth=4, label=r"$I(\phi; U \mid X^n)$ (IF approx.)")
    ax.fill_between(layers, if_means - if_stds, if_means + if_stds, color=IF_COLOR, alpha=shade_alpha)
    # Line (ZU)
    ax.plot(layers, zu_means, marker="s", color=ZU_COLOR, linewidth=4, linestyle='-', label=r"$I(Z^n; U \mid \phi, X^n)$ (upper bound)")
    ax.fill_between(layers, zu_means - zu_stds, zu_means + zu_stds, color=ZU_COLOR, alpha=0.12)
    ax.set_xlabel("layer l", fontsize=label_fs)
    ax.set_ylabel("normalized components", fontsize=label_fs)
    ax.set_title("Focused layer vs components", fontsize=title_fs, fontweight="bold")
    ax.set_xticks([1, 2, 3, 4])
    # Fixed limits: normalized values live in [0, 1]; a small margin is kept below zero.
    ax.set_ylim(-0.05, 1.0)
    ax.grid(True, linestyle=":", alpha=0.4)
    ax.legend(loc="best", fontsize=13)
    ax.tick_params(labelsize=tick_fs)


def plot_combo_gap_and_upper(
    ds_root: Path,
    gap_series: GapSeries,
    series_list: List[LayerUpperSeries],
    out_path: Path,
    algo: str,
    gap_key: str,
    transform: str,
):
    """Render the 1x4 "combo" figure and save it to ``out_path``.

    Panels, left to right:
      1. Generalization gap (mean ± std) against train size, from ``gap_series``.
      2. Upper-bound curves (mean ± std) of every layer in ``series_list``, overlaid.
      3. Scatter of per-split |gap| against the per-split minimum upper bound
         over layers, for the train directory with the largest train size.
      4. Min-max normalized IF / ZU components per layer for that same directory.

    Panels 3 and 4 are best-effort: they are left blank when no matching train
    directory or ``aggregated_results.json`` exists, and a failure while
    building them is reported on stdout without aborting the figure.

    Args:
        ds_root: Dataset results directory; its name selects the title label.
        gap_series: Read through ``num_train``, ``mean_abs_gap``, ``std_abs_gap``.
        series_list: Per-layer series read through ``layer``, ``num_train``,
            ``mean_upper``, ``std_upper``.
        out_path: Destination image path; parent directories are created.
        algo: ``"iwae"`` or ``"vae"``; selects the model directory, the train
            sub-directories (names containing ``"iwae"`` / ``"elbo"``) and the title label.
        gap_key: Key of the per-split gap metric in the aggregated results.
        transform: ``"sqrt"`` applies a clipped square root to the scatter's
            y values; any other value leaves them untouched.
    """
    fig, axes = plt.subplots(1, 4, figsize=(22, 4.5), dpi=140)
    try:
        # Style (consistent with plot_mnist_vae_results.py)
        title_fs = 14
        label_fs = 14
        tick_fs = 13
        shade_alpha = 0.20

        # Labels for titles
        dataset_tag = ds_root.name
        if dataset_tag == "mnist":
            dataset_label = "MNIST"
        elif dataset_tag in ("fashion_mnist", "fashion-mnist"):
            dataset_label = "F-MNIST"
        else:
            dataset_label = dataset_tag
        algo_label = "H-IWAE" if algo == "iwae" else "H-VAE"

        # Left: gap (mean ± std)
        ax = axes[0]
        # NOTE(review): the 0.9 factor presumably converts the nominal train size to the
        # portion actually used for training — TODO confirm against the experiment setup.
        x = 0.9 * gap_series.num_train
        y = gap_series.mean_abs_gap
        yerr = gap_series.std_abs_gap
        color_gap = "#d62728"
        ax.plot(x, y, marker="o", linestyle="-", color=color_gap, linewidth=4)
        ax.fill_between(x, y - yerr, y + yerr, color=color_gap, alpha=shade_alpha)
        ax.set_xlabel("num_train_samples", fontsize=label_fs)
        ax.set_ylabel("Generalization gap in loss", fontsize=label_fs)
        ax.set_title(f"{dataset_label} ({algo_label}): train size v.s. Gen. gap", fontsize=title_fs, fontweight="bold")
        ax.grid(True, linestyle=":", alpha=0.4)
        ax.tick_params(labelsize=tick_fs)

        # Middle: overlay all layers (mean ± std)
        ax = axes[1]
        layers_sorted = sorted(series_list, key=lambda s: s.layer)
        if not layers_sorted:
            ax.text(0.5, 0.5, "No upper series", ha="center", va="center")
        else:
            # plt.get_cmap is used because matplotlib.cm.get_cmap was removed in matplotlib 3.9.
            cmap = plt.get_cmap("tab10", len(layers_sorted))
            for i, series in enumerate(layers_sorted):
                x = 0.9 * series.num_train
                y = series.mean_upper
                yerr = series.std_upper
                col = cmap(i % cmap.N)
                ax.plot(x, y, marker="o", linestyle="-", color=col, linewidth=4, label=f"layer {series.layer}")
                ax.fill_between(x, y - yerr, y + yerr, color=col, alpha=shade_alpha)
            # Only request a legend when labelled curves exist.
            ax.legend(loc="best", fontsize=13)
        ax.set_xlabel("num_train_samples", fontsize=label_fs)
        ax.set_ylabel("Our bound (approx.)", fontsize=label_fs)
        ax.set_title(f"{dataset_label} ({algo_label}): train size v.s. Our bound (approx.)", fontsize=title_fs, fontweight="bold")
        ax.grid(True, linestyle=":", alpha=0.4)
        ax.tick_params(labelsize=tick_fs)

        # Right panels: use the train directory with the largest n.
        model_dir = _detect_hvae_dir(ds_root, algo=algo)
        train_dirs = [
            d for d in model_dir.iterdir()
            if d.is_dir()
            and d.name.startswith("train")
            and ((algo == "vae" and "elbo" in d.name) or (algo == "iwae" and "iwae" in d.name))
        ]
        if train_dirs:
            target_train_dir = max(train_dirs, key=parse_train_size_from_path)
            agg_path = target_train_dir / "aggregated_results.json"
            if agg_path.exists():
                try:
                    agg = _read_aggregated_results(agg_path)
                    gaps = _combo_gap_per_split(agg, gap_key)
                    # Per-split values keyed by layer number, so a missing layer file
                    # cannot shift the remaining layers onto the wrong position.
                    uppers_by_layer: Dict[int, List[float]] = {}
                    ifs_by_layer: Dict[int, List[float]] = {}
                    zus_by_layer: Dict[int, List[float]] = {}
                    for li in (1, 2, 3, 4):
                        mi_path = target_train_dir / f"mi_hierarchical_ef_l{li}.json"
                        if not mi_path.exists():
                            continue
                        uppers = _read_combined_upper_per_split(mi_path)
                        ifs_l, zus_l = _read_if_zu_per_split_hvae(mi_path)
                        if uppers:
                            uppers_by_layer[li] = [float(v) for v in uppers]
                            ifs_by_layer[li] = [float(v) for v in ifs_l]
                            zus_by_layer[li] = [float(v) for v in zus_l]
                    if uppers_by_layer:
                        _combo_draw_scatter_panel(
                            axes[2], gaps, uppers_by_layer, transform,
                            dataset_label=dataset_label, algo_label=algo_label,
                            title_fs=title_fs, label_fs=label_fs, tick_fs=tick_fs,
                        )
                        try:
                            _combo_draw_components_panel(
                                axes[3], ifs_by_layer, zus_by_layer,
                                title_fs=title_fs, label_fs=label_fs, tick_fs=tick_fs,
                                shade_alpha=shade_alpha,
                            )
                        except Exception as exc:
                            # Best-effort panel: report and keep the rest of the figure.
                            print(f"Warning: could not draw the IF/ZU components panel: {exc}")
                except Exception as exc:
                    # Best-effort panels: report and still save the left/middle panels.
                    print(f"Warning: could not draw the correlation panels from {agg_path}: {exc}")

        # Component-vs-train-size panels are not produced in this mode (use components_1x4 instead)

        fig.tight_layout()
        out_path.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(out_path)
    finally:
        # pyplot keeps every figure registered until it is closed explicitly.
        plt.close(fig)
    print(f"Saved figure to {out_path}")


# ----------------------------------
# Entry point
# ----------------------------------


def main():
    """Command-line entry point: parse arguments and produce the requested output.

    Modes (``--mode``):
      * ``upper_1x4``: per-layer upper-bound curves, one panel per layer.
      * ``combo_gap_upper``: gap curve, overlaid layer bounds, correlation
        scatter and normalized components in a single 1x4 figure.
      * ``components_1x4``: IF (left axis) and ZU (right axis) against train
        size, one panel per layer.
      * ``components_corr_table``: writes a TeX correlation table.

    Figures default to ``<project_root>/results/figures`` and the table to
    ``<project_root>/results/tables`` unless ``--out`` is given.

    Raises:
        SystemExit: In the two upper-bound modes when no layer yields a series.
    """
    import argparse

    parser = argparse.ArgumentParser(description="Plot MNIST Hierarchical VAE (HVAE) results")
    parser.add_argument("--root", type=str, default=None, help="Root of results/experiments (default: project_root/results/experiments)")
    parser.add_argument("--dataset", type=str, default="mnist", help="Dataset subdirectory (e.g., mnist)")
    parser.add_argument("--algo", type=str, choices=["iwae", "vae"], default="iwae", help="HVAE estimation method (default: iwae)")
    parser.add_argument("--mode", type=str, choices=["upper_1x4", "combo_gap_upper", "components_1x4", "components_corr_table"], default="upper_1x4", help="Output mode")
    parser.add_argument("--transform", type=str, choices=["none", "sqrt"], default="sqrt", help="Transform applied to upper values")
    parser.add_argument("--gap", type=str, choices=["gap_loss", "gap_recon_loss", "gap_kl_loss"], default="gap_loss", help="Gap metric (for components_corr_table)")
    parser.add_argument("--y_transform", type=str, choices=["none", "sqrt_minmax"], default="none", help="y transform for IF/ZU (for components_corr_table)")
    parser.add_argument("--out", type=str, default=None, help="Output image path (default: auto)")
    args = parser.parse_args()

    project_root = Path(__file__).resolve().parents[1]
    base_root = Path(args.root) if args.root else (project_root / "results" / "experiments")
    ds_root = base_root / args.dataset
    out_dir = project_root / "results" / "figures"
    base_name = f"{ds_root.name}_hvae_{args.algo}"

    if args.mode in ("upper_1x4", "combo_gap_upper"):
        series_list: List[LayerUpperSeries] = []
        for layer in (1, 2, 3, 4):
            try:
                s = find_hvae_layer_upper_series(ds_root, layer=layer, algo=args.algo, transform=args.transform)
                series_list.append(s)
            except Exception as exc:
                # A missing layer is tolerated, but say why it was dropped.
                print(f"Warning: skipping layer {layer}: {exc}")
        if not series_list:
            raise SystemExit("No HVAE series found for any layer")
        if args.mode == "upper_1x4":
            out_path = Path(args.out) if args.out else (out_dir / f"{base_name}_upper_1x4.png")
            plot_upper_1x4(series_list, out_path)
        else:
            # combo: left=gap, center=4 layers overlay, right=scatter (by layer)
            # NOTE: --gap is not consulted here; this mode always uses "gap_loss".
            gap_series = find_hvae_gap_series(ds_root, algo=args.algo, gap_key="gap_loss")
            out_path = Path(args.out) if args.out else (out_dir / f"{base_name}_combo_gap_upper.png")
            plot_combo_gap_and_upper(ds_root, gap_series, series_list, out_path, algo=args.algo, gap_key="gap_loss", transform=args.transform)
    elif args.mode == "components_1x4":
        # For each layer, plot IF/ZU in a 1x4 layout
        fig, axes = plt.subplots(1, 4, figsize=(20, 4.5), dpi=140)
        try:
            # Style (consistent with plot_mnist_vae_results.py)
            title_fs = 14
            label_fs = 14
            tick_fs = 13
            shade_alpha = 0.20
            IF_COLOR = "#009E73"
            ZU_COLOR = "#CC79A7"
            for idx, li in enumerate((1, 2, 3, 4)):
                ax = axes[idx]
                try:
                    comps = find_hvae_if_zu_series(ds_root, layer=li, algo=args.algo, normalize=False)
                except Exception:
                    # Mark the panel as empty instead of failing the whole figure.
                    ax.text(0.5, 0.5, f"No data (layer {li})", ha="center", va="center")
                    ax.set_axis_off()
                    continue
                # NOTE(review): the 0.9 factor presumably converts the nominal train size to
                # the portion actually used for training — TODO confirm.
                x = 0.9 * comps.num_train
                # IF on the left y-axis.
                ax.plot(x, comps.mean_if, marker="o", markersize=8, color=IF_COLOR, linewidth=4, label="IF")
                ax.fill_between(x, comps.mean_if - comps.std_if, comps.mean_if + comps.std_if, color=IF_COLOR, alpha=shade_alpha)
                ax.set_ylabel(r"$I(\phi; U | X^n)$ (IF approx.)", fontsize=label_fs, color=IF_COLOR)
                ax.tick_params(axis='y', labelcolor=IF_COLOR, labelsize=tick_fs)
                ax.spines['left'].set_color(IF_COLOR)
                # ZU on a twin right y-axis, since the two are plotted unnormalized.
                ax2 = ax.twinx()
                ax2.plot(x, comps.mean_zu, marker="s", markersize=7, color=ZU_COLOR, linewidth=4, linestyle='-', label="ZU")
                ax2.fill_between(x, comps.mean_zu - comps.std_zu, comps.mean_zu + comps.std_zu, color=ZU_COLOR, alpha=0.12)
                ax2.set_ylabel(r"$I(Z^n; U | \phi, X^n)$ (upper)", fontsize=label_fs, color=ZU_COLOR)
                ax2.tick_params(axis='y', labelcolor=ZU_COLOR, labelsize=tick_fs)
                ax2.spines['right'].set_color(ZU_COLOR)
                ax.set_xlabel("num_train_samples", fontsize=label_fs)
                ax.set_title(f"Layer {li}", fontsize=title_fs, fontweight="bold")
                ax.tick_params(axis='x', labelsize=tick_fs)
                ax.grid(True, linestyle=":", alpha=0.4)
            out_path = Path(args.out) if args.out else (out_dir / f"{base_name}_components_1x4.png")
            out_path.parent.mkdir(parents=True, exist_ok=True)
            fig.tight_layout()
            fig.savefig(out_path)
        finally:
            # pyplot keeps every figure registered until it is closed explicitly.
            plt.close(fig)
        print(f"Saved figure to {out_path}")
    elif args.mode == "components_corr_table":
        # For both MNIST and F-MNIST, compute correlations between |gap| and IF/ZU
        # for the layer achieving the minimum upper bound at each split, and save as TeX.
        # NOTE: mnist_root follows --dataset, while the Fashion-MNIST root is always
        # <root>/fashion_mnist.
        fashion_root = base_root / "fashion_mnist"
        tables_dir = project_root / "results" / "tables"
        out_path = Path(args.out) if args.out else (tables_dir / "hvae_if_zu_corr_train30000.tex")
        write_hvae_if_zu_correlation_table(
            mnist_root=ds_root,
            fashion_root=fashion_root,
            out_path=out_path,
            algo=args.algo,
            gap_key=args.gap,
            y_transform=args.y_transform,
            include_sum=True,
        )


# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()


