
# -*- coding: utf-8 -*-
"""
extrapolate_A_predict_k_with_plots_fixed.py

Fixes:
- bound_hcov imports `bound as B` (no evaluate_synset). We use B.epoch("test", ...) to evaluate.
- bound_hcov.get_dataset signature requires (dataset, data_path) and returns 9 outputs; we unpack accordingly.
- Ensure args has the fields expected by B.epoch (i.e. bound.epoch, reached as BH.B.epoch): dsa_param, dc_aug_param, ipc, etc.

Experiment A (predictive extrapolation) for Scaling/Coverage Laws.

Goal:
- Fit laws ONLY on a seen configuration subset (model families).
- Predict performance (worst-case synthetic risk) for:
    (i) the existing discrete k_list (with ground-truth from training),
    (ii) extra k values (k_pred) that have no distilled set or were not trained (predictions only, no ground truth).

Definitions (as requested):
- We DO NOT train on real data baselines. We use a constant baseline Acc(D)=1,
  so the "gap" becomes worst-case synthetic risk:
      Delta_subset(k) = max_{mf in subset} (1 - Acc(mf, synth_k))
- Aggregation mode: worst-case (max) over the subset.

Laws:
1) Scaling Law (fit on seen):
      Delta(k) ≈ b + a / sqrt(k)
2) Coverage Law (fit on seen):
      Delta(k) ≈ b + a * sqrt(Hcov(subset)/k)

Prediction on unseen:
- Scaling: uses the same (a,b) from seen.
- Coverage: uses same (a,b) from seen but swaps Hcov_seen -> Hcov_unseen.

This file reuses strict-ecology infrastructure from bound_hcov.py:
- get_network_ecology
- find_pt_for_ipc, load_synth_from_pt
- estimate_update_signature, estimate_hcov
- B.epoch for training/eval
"""

import os
import json
import math
import csv
import argparse
from dataclasses import dataclass
from typing import Dict, List, Tuple, Optional

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt

# Reuse code from bound_hcov.py (same directory or PYTHONPATH)
import bound_hcov as BH


# ----------------------------
# Fixed seen/unseen split (from user)
# ----------------------------
# Model families whose results are used to fit both laws.
SEEN_MFS = [
    "ConvNet",
    "ConvNetD1", "ConvNetD2", "ConvNetD4",
    "ConvNetW32", "ConvNetW64", "ConvNetW256",
    "ConvNetAR", "ConvNetAL", "ConvNetASwish",
    "ConvNetLN", "ConvNetGN",
]
# Held-out families, kept exactly as the user listed them. Entries are
# deduplicated and any family already in SEEN_MFS is removed when UNSEEN_MFS
# is derived below.
UNSEEN_MFS_RAW = [
    "ConvNetD3",
    "ConvNetW128",
    # The user listed ConvNetAL twice. It is kept here verbatim, but because it
    # is also in SEEN_MFS it is filtered out of UNSEEN_MFS and never treated as unseen.
    "ConvNetAL", "ConvNetAL",
]

def dedupe_keep_order(xs: List[str]) -> List[str]:
    """Return a new list with the duplicates of *xs* removed.

    The first occurrence of every item is kept, in its original position
    relative to the other survivors; ``dict`` preserves insertion order, so
    ``dict.fromkeys`` performs the order-stable deduplication directly.
    """
    return list(dict.fromkeys(xs))

# Drop accidental repeats, then keep only the unseen families that are disjoint
# from the seen split (the raw unseen list repeats a seen family).
SEEN_MFS = dedupe_keep_order(SEEN_MFS)
UNSEEN_MFS = [name for name in dedupe_keep_order(UNSEEN_MFS_RAW) if name not in SEEN_MFS]


# ----------------------------
# Helpers: delta definition
# ----------------------------
def worst_delta_from_acc(synth_acc: Dict[str, Dict[int, float]], mfs: List[str], k: int) -> float:
    """Return the worst-case synthetic risk of a model-family subset at *k*.

    Delta_subset(k) = max_{mf in subset} (1 - acc_synth(mf, k)).

    Args:
        synth_acc: Nested mapping ``synth_acc[mf][k]`` -> accuracy of model
            family *mf* trained on the synthetic set for *k* (``1 - acc`` is
            taken as the risk).
        mfs: Model families forming the subset; must be non-empty.
        k: Inner key to look up for every family.

    Returns:
        The largest ``1 - acc`` over the subset.

    Raises:
        ValueError: If *mfs* is empty (the maximum is undefined).
        KeyError: If the accuracy for some (mf, k) pair is missing.
    """
    if not mfs:
        # Without this, max() fails with an unhelpful "empty sequence" error.
        raise ValueError(f"Cannot compute worst-case delta over an empty subset (k={k})")
    risks = []
    for mf in mfs:
        if mf not in synth_acc or k not in synth_acc[mf]:
            raise KeyError(f"Missing synth_acc for mf={mf}, k={k}")
        risks.append(1.0 - float(synth_acc[mf][k]))
    return max(risks)


# ----------------------------
# Fitting / prediction
# ----------------------------
@dataclass
class FitResult:
    """Result of a least-squares line fit ``y = b + a*x`` (produced by ``fit_line``)."""

    a: float     # slope: coefficient on x
    b: float     # intercept
    rmse: float  # root-mean-square residual over the fitted points
    r2: float    # coefficient of determination (0.0 when y is ~constant)

def fit_line(x: np.ndarray, y: np.ndarray) -> FitResult:
    """Fit ``y = b + a*x`` by ordinary least squares.

    Args:
        x: 1-D array-like of predictor values.
        y: 1-D array-like of targets, same length as ``x``.

    Returns:
        FitResult with slope ``a``, intercept ``b``, the in-sample RMSE and R².
        R² is reported as 0.0 when ``y`` has (numerically) zero variance,
        where it would otherwise be undefined.

    Raises:
        ValueError: If ``x``/``y`` are not 1-D, differ in length, or are empty.
    """
    x = np.asarray(x, dtype=np.float64)
    y = np.asarray(y, dtype=np.float64)
    # Explicit checks rather than ``assert``: asserts are stripped under ``python -O``.
    if x.ndim != 1 or y.ndim != 1 or x.shape[0] != y.shape[0]:
        raise ValueError(
            f"fit_line expects two 1-D arrays of equal length, got shapes {x.shape} and {y.shape}"
        )
    if x.shape[0] == 0:
        # An empty fit would otherwise silently yield NaN rmse/r2.
        raise ValueError("fit_line needs at least one data point")
    design = np.column_stack([x, np.ones_like(x)])
    coef, *_ = np.linalg.lstsq(design, y, rcond=None)
    a, b = float(coef[0]), float(coef[1])
    residuals = (a * x + b) - y
    rmse = float(np.sqrt(np.mean(residuals ** 2)))
    ss_tot = float(np.sum((y - y.mean()) ** 2))
    r2 = float(1.0 - np.sum(residuals ** 2) / ss_tot) if ss_tot > 1e-12 else 0.0
    return FitResult(a=a, b=b, rmse=rmse, r2=r2)

def fit_scaling_law(k_list: List[int], delta_list: List[float]) -> FitResult:
    """Fit the scaling law ``Delta(k) ≈ b + a / sqrt(k)``.

    Each k is turned into the feature ``1/sqrt(k)``, and a straight line is
    fitted against the observed deltas with ``fit_line``.
    """
    features = [1.0 / math.sqrt(k) for k in k_list]
    return fit_line(
        np.asarray(features, dtype=np.float64),
        np.asarray(delta_list, dtype=np.float64),
    )

def fit_coverage_law(k_list: List[int], delta_list: List[float], H_seen: float) -> FitResult:
    """Fit the coverage law ``Delta(k) ≈ b + a * sqrt(H_seen / k)``.

    The feature for each k is ``sqrt(H_seen) / sqrt(k)``; a straight line is
    fitted against the observed deltas with ``fit_line``.
    """
    coverage_feature = [math.sqrt(H_seen) / math.sqrt(k) for k in k_list]
    return fit_line(
        np.asarray(coverage_feature, dtype=np.float64),
        np.asarray(delta_list, dtype=np.float64),
    )

def predict_delta_scaling(fit: FitResult, k: int) -> float:
    return float(fit.b + fit.a / math.sqrt(k))

def predict_delta_coverage(fit: FitResult, H: float, k: int) -> float:
    return float(fit.b + fit.a * (math.sqrt(H) / math.sqrt(k)))


# ----------------------------
# Train/eval on synth for one mf,k
# ----------------------------
def sanitize_labels(labels: torch.Tensor, num_classes: int) -> torch.Tensor:
    """Return a copy of *labels* as int64 class indices in ``[0, num_classes)``.

    If the raw labels fall outside that range, one repair is attempted:
      * labels whose unique values are exactly {-1, 1} are mapped to {0, 1};
      * otherwise all labels are shifted so the minimum becomes 0.

    Args:
        labels: Label tensor; it is cast to ``torch.long`` and not modified in place.
        num_classes: Number of valid classes.

    Returns:
        A new ``torch.long`` tensor with every entry in ``[0, num_classes)``
        (an empty input yields an empty long tensor).

    Raises:
        ValueError: If the labels are still out of range after the repair.
    """
    labels = labels.clone().long()
    if labels.numel() == 0:
        # Nothing to validate; min()/max() would raise on an empty tensor.
        return labels
    lab_min, lab_max = int(labels.min()), int(labels.max())
    if lab_min < 0 or lab_max >= num_classes:
        if set(torch.unique(labels).tolist()) == {-1, 1}:
            labels = (labels > 0).long()
        else:
            labels = labels - lab_min
    # Explicit check instead of ``assert`` so it still runs under ``python -O``.
    new_min, new_max = int(labels.min()), int(labels.max())
    if new_min < 0 or new_max >= num_classes:
        raise ValueError(
            f"Illegal labels after fix: min={new_min}, max={new_max} (num_classes={num_classes})"
        )
    return labels

def train_eval_on_synth(
    mf: str,
    images: torch.Tensor,
    labels: torch.Tensor,
    testloader: DataLoader,
    channel: int,
    num_classes: int,
    im_size: Tuple[int, int],
    args
) -> float:
    """
    Train a fresh ``mf`` network on a synthetic set and return its test accuracy.

    The network comes from ``BH.get_network_ecology`` and is optimised with SGD
    (lr ``args.lr_net``, momentum 0.9) and cross-entropy loss. Training performs
    ``args.epochs + 1`` passes of ``BH.B.epoch("train", ...)`` over the synthetic
    set (batch size ``args.batch_train``, shuffled, augmentation disabled).
    Evaluation is a single ``BH.B.epoch("test", ...)`` pass over ``testloader``.
    NOTE: bound.py does not provide evaluate_synset in this codebase, so the
    train/eval loop is assembled here from B.epoch.
    """
    run_device = args.device
    model = BH.get_network_ecology(mf, channel, num_classes, im_size).to(run_device)
    opt = torch.optim.SGD(model.parameters(), lr=args.lr_net, momentum=0.9)
    loss_fn = nn.CrossEntropyLoss().to(run_device)

    synth_loader = DataLoader(
        TensorDataset(images.to(run_device), labels.to(run_device)),
        batch_size=args.batch_train,
        shuffle=True,
    )

    n_train_passes = args.epochs + 1
    for _ in range(n_train_passes):
        BH.B.epoch("train", synth_loader, model, opt, loss_fn, args, aug=False)

    _, test_acc = BH.B.epoch("test", testloader, model, opt, loss_fn, args, aug=False)
    return float(test_acc)


# ----------------------------
# Hcov estimation per subset
# ----------------------------
def estimate_hcov_for_subset(mfs: List[str], ref_images: torch.Tensor, ref_labels: torch.Tensor, args) -> float:
    """Estimate Hcov for a subset of model families on a shared reference probe.

    For every family (in ``mfs`` order), ``BH.estimate_update_signature`` is
    called with ``args.hcov_seeds`` and ``args.hcov_num_batches``. It returns a
    (mean, samples) pair, and the per-family results are collected into two
    dicts keyed by family. Both dicts are then passed to ``BH.estimate_hcov``.

    Returns:
        The Hcov estimate converted to a Python float.
    """
    u_means = {}
    u_samples = {}
    for family in mfs:
        u_means[family], u_samples[family] = BH.estimate_update_signature(
            family, ref_images, ref_labels, args, args.hcov_seeds, args.hcov_num_batches
        )
    return float(BH.estimate_hcov(u_means, u_samples))


# ----------------------------
# Plotting (unseen actual vs predicted curve + scatter)
# ----------------------------
def plot_unseen_curve_and_scatter(
    out_dir: str,
    valid_k: List[int],
    delta_unseen: List[float],
    pred_scal: List[float],
    pred_cov: List[float],
):
    """Save two plots comparing actual and predicted Δ(k) on the unseen subset.

    Files written into ``out_dir``:
      * ``unseen_actual_vs_pred_curve.png``: actual Δ(k) as markers plus both
        law predictions as lines, against k on a log x-axis;
      * ``unseen_pred_vs_actual_scatter.png``: predicted vs actual Δ(k) for
        both laws, with a dashed y = x reference line.

    Args:
        out_dir: Existing directory for the PNG files.
        valid_k: k values, in any order (they are sorted for the curve plot).
        delta_unseen: Actual Δ(k), aligned with ``valid_k``.
        pred_scal: Scaling-law predictions, aligned with ``valid_k``.
        pred_cov: Coverage-law predictions, aligned with ``valid_k``.

    Raises:
        ValueError: If there are no points or the inputs differ in length.
    """
    ks = np.asarray(valid_k, dtype=np.float64)
    actual = np.asarray(delta_unseen, dtype=np.float64)
    scal = np.asarray(pred_scal, dtype=np.float64)
    cov = np.asarray(pred_cov, dtype=np.float64)
    n_points = ks.shape[0]
    if n_points == 0:
        raise ValueError("plot_unseen_curve_and_scatter: no points to plot")
    if not (actual.shape[0] == scal.shape[0] == cov.shape[0] == n_points):
        raise ValueError(
            "plot_unseen_curve_and_scatter: valid_k, delta_unseen, pred_scal and "
            "pred_cov must have the same length"
        )

    # ---- curve
    # Sort by k: line plots join points in input order, so an unsorted
    # valid_k would otherwise draw zig-zagging prediction curves.
    order = np.argsort(ks, kind="stable")
    fig, ax = plt.subplots()
    try:
        ax.plot(ks[order], actual[order], marker="o", linestyle="None", label="Actual (unseen)")
        ax.plot(ks[order], scal[order], linestyle="-", label="Pred Scaling")
        ax.plot(ks[order], cov[order], linestyle="-", label="Pred Coverage")
        ax.set_xscale("log")
        ax.set_xlabel("k (log scale)")
        ax.set_ylabel("Worst-case synthetic risk Δ(k)")
        ax.set_title("Unseen: Actual vs Predicted Δ(k)")
        ax.legend()
        fig.tight_layout()
        fig.savefig(os.path.join(out_dir, "unseen_actual_vs_pred_curve.png"), dpi=200)
    finally:
        # Close even on failure so figures do not accumulate in pyplot.
        plt.close(fig)

    # ---- scatter
    fig, ax = plt.subplots()
    try:
        ax.scatter(scal, actual, label="Scaling", marker="o")
        ax.scatter(cov, actual, label="Coverage", marker="x")
        # The y=x reference spans every plotted value. np.concatenate (not
        # list ``+``) stays correct when array inputs are passed.
        all_vals = np.concatenate([scal, cov, actual])
        lo, hi = float(all_vals.min()), float(all_vals.max())
        ax.plot([lo, hi], [lo, hi], linestyle="--", label="y=x")
        ax.set_xlabel("Predicted Δ(k)")
        ax.set_ylabel("Actual Δ(k)")
        ax.set_title("Unseen: Predicted vs Actual (scatter)")
        ax.legend()
        fig.tight_layout()
        fig.savefig(os.path.join(out_dir, "unseen_pred_vs_actual_scatter.png"), dpi=200)
    finally:
        plt.close(fig)


# ----------------------------
# Main
# ----------------------------
def parse_int_list(s: str) -> List[int]:
    """Parse a comma-separated string such as ``"1, 2,4"`` into ``[1, 2, 4]``.

    Whitespace around each field is stripped and empty fields are skipped;
    a non-integer field raises ``ValueError`` (from ``int``).
    """
    values: List[int] = []
    for field in s.split(","):
        token = field.strip()
        if token:
            values.append(int(token))
    return values

def main():
    """Run Experiment A: fit both laws on the seen families and predict Δ(k).

    Pipeline:
      1. Parse and validate the CLI arguments. The k values are checked before
         any training, so a bad ``--k_list``/``--k_pred`` fails immediately.
      2. Load the dataset metadata and test loader via ``BH.get_dataset``.
      3. For each k in ``--k_list`` whose distilled ``.pt`` file exists, train
         every seen and unseen model family on it and record its test accuracy.
      4. Estimate Hcov for both subsets on the synthetic set of the largest
         available k.
      5. Fit the scaling and coverage laws on the seen subset's worst-case risk.
         Predict for seen/unseen at the trained k and at ``--k_pred``, then
         write a CSV report, a JSON meta file and the unseen-subset plots.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--dataset", type=str, default="CIFAR10")
    p.add_argument("--data_path", type=str, default="./data", help="Dataset root (passed to utils.get_dataset)")
    p.add_argument("--synth_root", type=str, required=True, help="Root folder that contains distilled .pt files")
    p.add_argument("--out_dir", type=str, default="out_extrapolateA_pred")
    p.add_argument("--device", type=str, default="0")
    p.add_argument("--lr_net", type=float, default=0.01)
    p.add_argument("--batch_train", type=int, default=256)
    p.add_argument("--epochs", type=int, default=300)
    p.add_argument("--k_is_ipc", action="store_true", help="Interpret k_list as IPC directly (else k -> IPC via k//num_classes)")
    p.add_argument("--k_list", type=str, default="1,2,4,8,18,28,100,200")
    p.add_argument("--k_pred", type=str, default="6,12,51", help="Extra k values to predict (no training needed)")

    # Hcov args (match bound_hcov)
    p.add_argument("--hcov_batch", type=int, default=256)
    p.add_argument("--hcov_lr", type=float, default=0.01)
    p.add_argument("--hcov_optim", type=str, default="sgd", choices=["sgd", "adam"])
    p.add_argument("--hcov_num_batches", type=int, default=2)
    p.add_argument("--hcov_seeds", type=str, default="0,1,2,3", help="Comma-separated seeds for signature estimation")

    args = p.parse_args()

    # Parse and validate k values BEFORE any training. Previously, an invalid
    # k (e.g. 0 in --k_pred) only crashed inside 1/sqrt(k) after every model had
    # been trained, before any result was written. Duplicates are dropped so
    # each k is trained once and counted once in the fit. Sorting gives
    # ascending curves and report rows.
    try:
        k_list = sorted(set(parse_int_list(args.k_list)))
        k_pred = sorted(set(parse_int_list(args.k_pred)))
    except ValueError as err:
        p.error(f"--k_list/--k_pred must be comma-separated integers ({err})")
    non_positive = [k for k in k_list + k_pred if k <= 0]
    if non_positive:
        p.error(f"k values must be positive (both laws use 1/sqrt(k)); got {non_positive}")

    # device / seeds
    # NOTE: the --device flag is parsed but not used; the device is always
    # picked automatically (CUDA when available, else CPU).
    args.device = "cuda" if torch.cuda.is_available() else "cpu"
    args.hcov_seeds = tuple(parse_int_list(args.hcov_seeds))

    # args required by BH.B.epoch
    args.dsa_param = BH.ParamDiffAug()
    args.dc_aug_param = None

    os.makedirs(args.out_dir, exist_ok=True)

    # dataset (BH.get_dataset returns 9 values in this codebase; only the shape
    # info and the test loader are used here)
    channel, im_size, num_classes, _, _, _, _, _, testloader = BH.get_dataset(args.dataset, args.data_path)
    args.channel, args.im_size, args.num_classes = channel, im_size, num_classes

    # Validate model families by building each network once before any training.
    all_mfs = SEEN_MFS + UNSEEN_MFS
    for mf in all_mfs:
        _ = BH.get_network_ecology(mf, channel, num_classes, im_size)

    print("[Split] seen =", SEEN_MFS)
    print("[Split] unseen =", UNSEEN_MFS)

    def k_to_ipc(k: int) -> int:
        """Map k to the IPC of the distilled set that is loaded for it."""
        return k if args.k_is_ipc else max(1, k // num_classes)

    # Without --k_is_ipc, several k can collapse onto one IPC: every
    # k < 2*num_classes maps to ipc=1. Such k share one synthetic set, so the
    # fit would treat repeated runs on identical data as different sizes.
    k_by_ipc: Dict[int, List[int]] = {}
    for k in k_list:
        k_by_ipc.setdefault(k_to_ipc(k), []).append(k)
    for ipc, ks in k_by_ipc.items():
        if len(ks) > 1:
            print(f"[WARN] k values {ks} all map to ipc={ipc} and will be trained on the same synthetic set")

    # Train/eval synth for discrete k_list only
    synth_acc: Dict[str, Dict[int, float]] = {mf: {} for mf in all_mfs}
    valid_k: List[int] = []
    for k in k_list:
        ipc = k_to_ipc(k)
        args.ipc = ipc

        pt = BH.find_pt_for_ipc(args.synth_root, ipc, args.dataset.upper())
        if pt is None:
            print(f"[WARN] missing pt for k={k} (ipc={ipc}), skip")
            continue
        valid_k.append(k)
        images, labels = BH.load_synth_from_pt(pt)
        labels = sanitize_labels(labels, num_classes)

        for mf in all_mfs:
            print(f"[TrainSynth] k={k} mf={mf}")
            acc = train_eval_on_synth(mf, images, labels, testloader, channel, num_classes, im_size, args)
            synth_acc[mf][k] = acc
            print(f"[Acc] k={k} mf={mf} acc={acc:.4f}")

    if len(valid_k) < 3:
        raise RuntimeError(f"Need at least 3 valid k values to fit a line, got {valid_k}")

    # Reference probe for Hcov: use largest available k from valid_k
    ref_k = max(valid_k)
    ref_ipc = k_to_ipc(ref_k)
    ref_pt = BH.find_pt_for_ipc(args.synth_root, ref_ipc, args.dataset.upper())
    if ref_pt is None:
        raise RuntimeError("Cannot find reference pt for Hcov estimation.")
    ref_images, ref_labels = BH.load_synth_from_pt(ref_pt)
    ref_labels = sanitize_labels(ref_labels, num_classes)

    H_seen = estimate_hcov_for_subset(SEEN_MFS, ref_images, ref_labels, args)
    H_unseen = estimate_hcov_for_subset(UNSEEN_MFS, ref_images, ref_labels, args)
    print(f"[Hcov] H_seen={H_seen:.4f}, H_unseen={H_unseen:.4f}")

    # Build delta curves for seen/unseen over valid_k
    delta_seen = [worst_delta_from_acc(synth_acc, SEEN_MFS, k) for k in valid_k]
    delta_unseen = [worst_delta_from_acc(synth_acc, UNSEEN_MFS, k) for k in valid_k]

    # Fit laws on seen ONLY
    fit_scal = fit_scaling_law(valid_k, delta_seen)
    fit_cov = fit_coverage_law(valid_k, delta_seen, H_seen)

    print(f"[Fit][Scaling] a={fit_scal.a:.6f} b={fit_scal.b:.6f} rmse={fit_scal.rmse:.6f} r2={fit_scal.r2:.4f}")
    print(f"[Fit][Coverage] a={fit_cov.a:.6f} b={fit_cov.b:.6f} rmse={fit_cov.rmse:.6f} r2={fit_cov.r2:.4f}")

    # Predictions on unseen for valid_k (for plots)
    pred_unseen_scal = [predict_delta_scaling(fit_scal, k) for k in valid_k]
    pred_unseen_cov = [predict_delta_coverage(fit_cov, H_unseen, k) for k in valid_k]

    # Prepare report rows
    rows = []

    def add_row(subset_name: str, k: int, delta_actual: Optional[float],
                delta_pred_scal: float, delta_pred_cov: float, H: float):
        """Append one CSV row; accuracies are 1 - delta (blank when actual is None)."""
        acc_actual = (1.0 - delta_actual) if delta_actual is not None else None
        rows.append({
            "subset": subset_name,
            "k": int(k),
            "H_subset": float(H),
            "delta_actual": None if delta_actual is None else float(delta_actual),
            "acc_actual": None if acc_actual is None else float(acc_actual),
            "delta_pred_scaling": float(delta_pred_scal),
            "acc_pred_scaling": float(1.0 - delta_pred_scal),
            "delta_pred_coverage": float(delta_pred_cov),
            "acc_pred_coverage": float(1.0 - delta_pred_cov),
        })

    # For k in valid_k: include actual + predicted
    for k, dS, dU in zip(valid_k, delta_seen, delta_unseen):
        add_row("seen", k, dS,
                predict_delta_scaling(fit_scal, k),
                predict_delta_coverage(fit_cov, H_seen, k),
                H_seen)
        add_row("unseen", k, dU,
                predict_delta_scaling(fit_scal, k),
                predict_delta_coverage(fit_cov, H_unseen, k),
                H_unseen)

    # For k_pred: predictions only (no actual)
    for k in k_pred:
        add_row("seen_pred_only", k, None,
                predict_delta_scaling(fit_scal, k),
                predict_delta_coverage(fit_cov, H_seen, k),
                H_seen)
        add_row("unseen_pred_only", k, None,
                predict_delta_scaling(fit_scal, k),
                predict_delta_coverage(fit_cov, H_unseen, k),
                H_unseen)

    # Save CSV + meta
    csv_path = os.path.join(args.out_dir, "extrapolation_pred_report.csv")
    with open(csv_path, "w", newline="", encoding="utf-8") as f:
        w = csv.DictWriter(f, fieldnames=list(rows[0].keys()))
        w.writeheader()
        w.writerows(rows)

    meta = {
        "seen_mfs": SEEN_MFS,
        "unseen_mfs": UNSEEN_MFS,
        "valid_k": valid_k,
        "k_pred": k_pred,
        # IPC of the synthetic set used for each trained k.
        "ipc_of_k": {k: k_to_ipc(k) for k in valid_k},
        "H_seen": H_seen,
        "H_unseen": H_unseen,
        "fit_scaling": fit_scal.__dict__,
        "fit_coverage": fit_cov.__dict__,
        "delta_seen": delta_seen,
        "delta_unseen": delta_unseen,
        # Raw per-family accuracies; otherwise they are lost after the run
        # (JSON stores the int k keys as strings).
        "synth_acc": synth_acc,
        "note": "Delta = worst-case synthetic risk = max_mf(1-acc_synth). Fit on seen only.",
    }
    meta_path = os.path.join(args.out_dir, "extrapolation_meta.json")
    with open(meta_path, "w", encoding="utf-8") as f:
        json.dump(meta, f, indent=2)

    # Plots (unseen only)
    plot_unseen_curve_and_scatter(
        out_dir=args.out_dir,
        valid_k=valid_k,
        delta_unseen=delta_unseen,
        pred_scal=pred_unseen_scal,
        pred_cov=pred_unseen_cov,
    )

    print(f"[Done] wrote: {csv_path}")
    print(f"[Done] wrote: {meta_path}")
    print(f"[Done] wrote: {os.path.join(args.out_dir, 'unseen_actual_vs_pred_curve.png')}")
    print(f"[Done] wrote: {os.path.join(args.out_dir, 'unseen_pred_vs_actual_scatter.png')}")

if __name__ == "__main__":
    # Run the experiment only when executed as a script, not when imported.
    main()
