import json
import re
from pathlib import Path
from typing import Dict, List, Tuple, Optional

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd


def find_target_json_files(root: Path) -> List[Path]:
    """
    Find all MI JSONs under root.

    New combined files 'mi_both_*encoder.json' are preferred. The legacy
    'mi_params_u_encoder.json' is returned only for directories that hold no
    combined file, so a single experiment directory is never reported through
    both formats.
    Excludes any path whose string form contains 'train10000_'.

    Returns the matching paths, de-duplicated and sorted.
    """
    def _is_kept(p: Path) -> bool:
        return "train10000_" not in str(p)

    combined = {p for p in root.rglob("mi_both_*encoder.json") if _is_kept(p)}
    # Directories that already provide a combined file must not also
    # contribute their legacy file.
    dirs_with_combined = {p.parent for p in combined}
    legacy = {
        p
        for p in root.rglob("mi_params_u_encoder.json")
        if _is_kept(p) and p.parent not in dirs_with_combined
    }
    return sorted(combined | legacy)


def parse_train_size_from_path(path: Path) -> int:
    """
    Return the train size encoded in the name of ``path``'s parent directory.

    The parent directory name is searched for ``train<digits>_`` (for example
    'train30000_beta0.1_lr0.0005' yields 30000).

    Raises:
        ValueError: if the parent directory name contains no such token.
    """
    parent_dir = path.parent
    match = re.search(r"train(\d+)_", parent_dir.name)
    if match is not None:
        return int(match.group(1))
    raise ValueError(f"train size not found in directory name: {path.parent}")


def _num_train_samples_from(data: Dict, json_path: Path, per_split_keys: Tuple[str, ...]) -> int:
    """Return num_train_samples from the first non-empty per-split list, else parse it from the path."""
    for key in per_split_keys:
        entries = data.get(key)
        if not entries:
            continue
        value = entries[0].get("num_train_samples")
        if value is not None:
            return int(value)
        break  # only the first non-empty per-split list is consulted
    # Resolved lazily: the directory name is only required when the JSON
    # itself does not carry the train size.
    return parse_train_size_from_path(json_path)


def read_summary_from_json(json_path: Path) -> Tuple[int, float, float]:
    """
    Read (num_train_samples, mean_upper_bound, std_upper_bound) from a MI JSON.

    Preference order for the bound statistics:
      1. 'combined_summary' (sum of two MI terms),
      2. 'if_upper_bounds_summary',
      3. mean and population std (ddof=0) computed from the 'upper_bound'
         values in 'if_upper_bounds_per_split'.

    num_train_samples is read from the first entry of the matching per-split
    list. Only when that entry does not provide it is the train size parsed
    from the directory name via parse_train_size_from_path.

    Raises:
        ValueError: if neither a summary nor per-split data is present, or if
            the train size has to come from the directory name and cannot be
            parsed from it.
        KeyError: if a summary lacks 'mean_upper_bound' / 'std_upper_bound',
            or a per-split entry lacks 'upper_bound' in the fallback path.
    """
    with json_path.open("r", encoding="utf-8") as f:
        data = json.load(f)

    summary = None
    per_split_keys: Tuple[str, ...] = ("if_upper_bounds_per_split",)
    if "combined_summary" in data:
        summary = data["combined_summary"]
        # num_train_samples may live in any of the component arrays.
        per_split_keys = (
            "combined_per_split",
            "if_upper_bounds_per_split",
            "mi_zu_upper_per_split",
        )
    elif "if_upper_bounds_summary" in data:
        summary = data["if_upper_bounds_summary"]

    if summary is not None:
        mean_ub = float(summary["mean_upper_bound"])
        std_ub = float(summary["std_upper_bound"])
        return _num_train_samples_from(data, json_path, per_split_keys), mean_ub, std_ub

    # Fallback: compute the statistics from the per-split entries.
    per_split = data.get("if_upper_bounds_per_split", [])
    if not per_split:
        raise ValueError(f"No per-split data in {json_path}")
    vals = [float(item["upper_bound"]) for item in per_split]
    num_samples = _num_train_samples_from(data, json_path, per_split_keys)  # assume constant across splits
    return num_samples, float(np.mean(vals)), float(np.std(vals, ddof=0))


def aggregate(mnist_root: Path) -> pd.DataFrame:
    """
    Collect the MI upper-bound summaries under ``mnist_root``, one row per train size.

    Every JSON returned by find_target_json_files is read with
    read_summary_from_json. Files that fail to load are reported with a
    '[WARN]' line and skipped. Runs sharing a train size are merged by
    averaging their means and their per-run stds. Column ``n`` holds the
    number of runs merged.

    Returns:
        DataFrame with columns num_train_samples, mean_upper_bound,
        std_upper_bound and n, sorted by num_train_samples.

    Raises:
        RuntimeError: if no file could be read.
    """
    records: List[Dict[str, object]] = []
    for json_path in find_target_json_files(mnist_root):
        try:
            n_train, mean_ub, std_ub = read_summary_from_json(json_path)
        except Exception as e:
            print(f"[WARN] skip {json_path}: {e}")
            continue
        records.append(
            {
                "num_train_samples": n_train,
                "mean_upper_bound": mean_ub,
                "std_upper_bound": std_ub,
                "path": str(json_path),
            }
        )
    if not records:
        raise RuntimeError("No valid JSON files found to aggregate.")

    # Several seeds/runs may share a train size; collapse them into one row.
    per_size = pd.DataFrame(records).groupby("num_train_samples", as_index=False)
    summary = per_size.agg(
        mean_upper_bound=("mean_upper_bound", "mean"),
        # The error band is the average of the per-run (per-seed) stds.
        std_upper_bound=("std_upper_bound", "mean"),
        n=("mean_upper_bound", "count"),
    )
    summary = summary.sort_values("num_train_samples")
    # A group whose averaged std is NaN gets a zero-width band.
    summary["std_upper_bound"] = summary["std_upper_bound"].fillna(0.0)
    return summary


def aggregate_gap_recon(mnist_root: Path) -> pd.DataFrame:
    """
    Aggregate avg_gap_recon_loss and std_gap_recon_loss per train size.

    For every directory that holds an MI JSON (per find_target_json_files),
    'evaluation_aggregated.json' in that directory is read once and its
    'average_metrics' values are collected. Files that cannot be read or
    parsed are reported with a '[WARN]' line and skipped. Runs sharing a
    train size are merged by averaging both values. Column ``n`` holds the
    number of experiment directories merged.

    Returns:
        DataFrame with columns num_train_samples, avg_gap_recon_loss,
        std_gap_recon_loss and n, sorted by num_train_samples. The frame is
        empty (same columns) when nothing could be read.
    """
    columns = ["num_train_samples", "avg_gap_recon_loss", "std_gap_recon_loss", "n"]

    # Several MI JSONs can sit in one experiment directory; visit each
    # directory once so its evaluation file is not counted repeatedly.
    exp_dirs = list(dict.fromkeys(p.parent for p in find_target_json_files(mnist_root)))

    rows: List[Dict[str, object]] = []
    for exp_dir in exp_dirs:
        eval_path = exp_dir / "evaluation_aggregated.json"
        try:
            with eval_path.open("r", encoding="utf-8") as f:
                data = json.load(f)
            metrics = data["average_metrics"]
            avg = float(metrics["avg_gap_recon_loss"])
            std = float(metrics["std_gap_recon_loss"])
            # parse_train_size_from_path inspects the *parent* directory name
            # of the path it is given. Passing the file path checks the
            # experiment directory's own name first; passing exp_dir then
            # falls back to the enclosing directory's name.
            try:
                n_train = parse_train_size_from_path(eval_path)
            except ValueError:
                n_train = parse_train_size_from_path(exp_dir)
        except Exception as e:
            print(f"[WARN] skip {eval_path}: {e}")
            continue
        rows.append({
            "num_train_samples": n_train,
            "avg_gap_recon_loss": avg,
            "std_gap_recon_loss": std,
            "path": str(eval_path),
        })
    if not rows:
        return pd.DataFrame(columns=columns)  # empty

    agg = (
        pd.DataFrame(rows)
        .groupby("num_train_samples", as_index=False)
        .agg(
            avg_gap_recon_loss=("avg_gap_recon_loss", "mean"),
            std_gap_recon_loss=("std_gap_recon_loss", "mean"),
            n=("avg_gap_recon_loss", "count"),
        )
        .sort_values("num_train_samples")
    )
    # A group whose averaged std is NaN gets a zero-width band.
    agg["std_gap_recon_loss"] = agg["std_gap_recon_loss"].fillna(0.0)
    return agg


def plot(df_upper: pd.DataFrame, df_gap: Optional[pd.DataFrame], out_path: Path) -> None:
    """
    Plot the upper bound (and optionally the gap recon loss) against train size.

    Args:
        df_upper: frame with columns num_train_samples, mean_upper_bound and
            std_upper_bound; drawn on the left axis with a ±1 std band.
        df_gap: optional frame with columns num_train_samples,
            avg_gap_recon_loss and std_gap_recon_loss; when given and
            non-empty it is drawn on a twin right axis with a ±1 std band.
        out_path: file the figure is saved to; missing parent directories
            are created.

    The figure is closed after saving so repeated calls do not accumulate
    open pyplot figures.
    """
    x_u = df_upper["num_train_samples"].to_numpy()
    y_u = df_upper["mean_upper_bound"].to_numpy()
    yerr_u = df_upper["std_upper_bound"].to_numpy()

    fig, ax1 = plt.subplots(figsize=(7, 4.5), dpi=160)
    try:
        # Left axis: upper bound
        ax1.plot(x_u, y_u, marker="o", color="#1f77b4", label="upper_bound (mean)")
        ax1.fill_between(x_u, y_u - yerr_u, y_u + yerr_u, color="#1f77b4", alpha=0.2, label="upper ±1 std")
        ax1.set_xlabel("num_train_samples")
        ax1.set_ylabel("upper_bound")
        ax1.set_title("MNIST: upper_bound vs num_train_samples")
        ax1.grid(True, linestyle=":", alpha=0.5)

        handles, labels = ax1.get_legend_handles_labels()

        # Right axis: avg_gap_recon_loss ± std
        if df_gap is not None and not df_gap.empty:
            x_g = df_gap["num_train_samples"].to_numpy()
            y_g = df_gap["avg_gap_recon_loss"].to_numpy()
            yerr_g = df_gap["std_gap_recon_loss"].to_numpy()
            ax2 = ax1.twinx()
            ax2.plot(x_g, y_g, marker="^", color="#ff7f0e", label="gap_recon_loss (avg)")
            ax2.fill_between(x_g, y_g - yerr_g, y_g + yerr_g, color="#ff7f0e", alpha=0.2, label="gap ±1 std")
            ax2.set_ylabel("avg_gap_recon_loss")
            h2, l2 = ax2.get_legend_handles_labels()
            handles += h2
            labels += l2

        # One legend covering the entries of both axes.
        ax1.legend(handles, labels, loc="best")
        fig.tight_layout()
        out_path.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(out_path)
    finally:
        # Figures created through pyplot stay registered until closed.
        plt.close(fig)
    print(f"Saved figure to {out_path}")

def main(argv: Optional[List[str]] = None) -> None:
    """
    Command-line entry point: aggregate the MI JSONs and save the figure.

    Args:
        argv: argument list to parse; ``None`` (the default) makes argparse
            read the process command line.

    Options:
        --root: directory to search (default: <project>/results/experiments).
        --dataset: optional sub-directory of the root to restrict the search to.
        --out: output figure path (default:
            <project>/results/figures/upper_bound_vs_num_train_samples.png).

    A failure while aggregating the gap recon loss is reported with a
    '[WARN]' line and the figure is drawn with the upper bound only.
    """
    import argparse
    parser = argparse.ArgumentParser(description="Plot upper bound vs num_train_samples")
    parser.add_argument("--root", type=str, default=None, help="Root directory under which to search experiments (default: results/experiments)")
    parser.add_argument("--dataset", type=str, default=None, help="Optional dataset subdir (e.g., mnist, fashion_mnist)")
    parser.add_argument("--out", type=str, default=None, help="Output figure path (default under results/figures)")
    args = parser.parse_args(argv)

    # The project root is taken to be the directory above the one holding this file.
    project_root = Path(__file__).resolve().parents[1]
    base_root = Path(args.root) if args.root else (project_root / "results" / "experiments")
    search_root = base_root / args.dataset if args.dataset else base_root
    df_upper = aggregate(search_root)
    # Gap recon currently only supported for MNIST layout; keep it optional
    # but say why the curve is missing.
    try:
        df_gap = aggregate_gap_recon(search_root)
    except Exception as e:
        print(f"[WARN] gap recon aggregation failed; plotting upper bound only: {e}")
        df_gap = None
    if args.out:
        out_path = Path(args.out)
    else:
        out_path = project_root / "results" / "figures" / "upper_bound_vs_num_train_samples.png"
    # plot() creates the parent directory of out_path, so the default figures
    # directory is not created when --out points elsewhere.
    plot(df_upper, df_gap, out_path)

# Run the command-line entry point only when executed as a script, not on import.
if __name__ == "__main__":
    main()


