# -*- coding: utf-8 -*-
"""
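bound_hcov.py (FINAL, STRICT ECOLOGY VERSION)

Trains each requested model family on synthetic sets of several sizes k,
estimates the coverage term Hcov (log of the number of greedy cover centers
of the families' one-step update signatures), and writes
(sqrt(Hcov) / sqrt(k), worst-case test error) points.

- Training ecology == coverage ecology (the same model families are used for both)
- The ecology space is strictly restricted to the existing model families
- No illegal combinations and no ad-hoc kwargs
- Output format strictly matches bound.py: coverage_points.csv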
"""

import os
import argparse
import math
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

# ---------------- Pin the visible GPU early ----------------
# CUDA_VISIBLE_DEVICES only takes effect if it is set before CUDA is
# initialized. `torch` is already imported above, but PyTorch initializes CUDA
# lazily, so setting the variable here, before `bound` / `utils` are imported
# below, still works as long as nothing above has touched the GPU.
def _set_visible_device(dev):
    """Expose only GPU ``dev`` to this process; do nothing when ``dev`` is None."""
    if dev is None:
        return
    os.environ["CUDA_VISIBLE_DEVICES"] = str(dev)


# Pre-parse just --device (unknown flags are left for main()'s parser).
# NOTE: the default "0" means an externally set CUDA_VISIBLE_DEVICES is
# overwritten unless --device is passed explicitly.
_p = argparse.ArgumentParser(add_help=False)
_p.add_argument("--device", type=str, default="0")
_tmp, _ = _p.parse_known_args()
_set_visible_device(_tmp.device)
# -----------------------------------------------------------

import bound as B
from utils import get_dataset, TensorDataset, get_daparam, ParamDiffAug, get_network

# Reuse bound.py's helpers under short local names so this script logs,
# locates/loads synthetic checkpoints and writes CSVs exactly as bound.py does.
(
    log,
    find_pt_for_ipc,
    load_synth_from_pt,
    save_csv,
) = (
    B.log,
    B.find_pt_for_ipc,
    B.load_synth_from_pt,
    B.save_csv,
)

# ==================================================
# Supported ecology space (STRICT)
# ==================================================

# Closed set of model-family names that get_network_ecology() accepts; each
# accepted name is passed unchanged to utils.get_network. Grouped by the
# variant axis given in the original grouping comments.
SUPPORTED_MODEL_FAMILIES = set().union(
    # base
    ("ConvNet",),
    # depth
    ("ConvNetD1", "ConvNetD2", "ConvNetD3", "ConvNetD4"),
    # width
    ("ConvNetW32", "ConvNetW64", "ConvNetW128", "ConvNetW256"),
    # activation
    ("ConvNetAS", "ConvNetAR", "ConvNetAL", "ConvNetASwish"),
    # normalization
    ("ConvNetNN", "ConvNetBN", "ConvNetLN", "ConvNetIN", "ConvNetGN"),
    # special combination
    ("ConvNetASwishBN",),
)


def get_network_ecology(model_family, channel, num_classes, im_size):
    """
    Build a network for ``model_family`` after checking it belongs to the ecology.

    Args:
        model_family: architecture name; must be in SUPPORTED_MODEL_FAMILIES.
        channel, num_classes, im_size: forwarded unchanged to ``get_network``.

    Returns:
        The result of ``get_network(model_family, channel, num_classes, im_size)``.

    Raises:
        ValueError: if ``model_family`` is not a supported family; the message
            lists every supported name in sorted order.
    """
    if model_family in SUPPORTED_MODEL_FAMILIES:
        return get_network(model_family, channel, num_classes, im_size)

    allowed = sorted(SUPPORTED_MODEL_FAMILIES)
    raise ValueError(
        f"[ECOLOGY ERROR] Unsupported model family: {model_family}\n"
        f"Supported families:\n{allowed}"
    )


# ==================================================
# Update-metric utilities (for Hcov)
# ==================================================

def _flatten_grads(net):
    """
    Concatenate every populated ``.grad`` of ``net`` into a single 1-D tensor.

    Parameters whose ``.grad`` is None (not reached by backward, or cleared via
    ``zero_grad(set_to_none=True)``) are skipped.

    Returns:
        1-D tensor of all gradient entries in ``net.parameters()`` order. If no
        gradient is populated, a single zero on the parameters' device (CPU when
        the net has no parameters at all).
    """
    params = list(net.parameters())
    # reshape rather than view: a grad may be non-contiguous (e.g. when the
    # parameter itself is channels_last), and view(-1) raises on such tensors.
    gs = [p.grad.reshape(-1) for p in params if p.grad is not None]
    if not gs:
        device = params[0].device if params else torch.device("cpu")
        return torch.zeros(1, device=device)
    return torch.cat(gs)


def _one_step_update_(net, grads, lr, optim_name: str, eps: float = 1e-8):
    """
    Apply one optimizer step to ``net``'s parameters in place.

    ``grads`` is matched positionally against
    ``[p for p in net.parameters() if p.requires_grad]`` (iteration stops at the
    shorter of the two); a ``None`` entry leaves its parameter untouched.

    - SGD:   p <- p - lr * g
    - Adam:  p <- p - lr * g / (|g| + eps)
             This is Adam's first step from m = v = 0 after bias correction
             (m_hat = g, v_hat = g**2, so sqrt(v_hat) = |g|).

    Raises:
        ValueError: if ``optim_name`` is not "sgd" or "adam". This is checked
            before touching any parameter, so a bad name is reported even when
            every grad is None.
    """
    if optim_name not in ("sgd", "adam"):
        raise ValueError(f"Unknown optim for Hcov: {optim_name}")

    trainable = [p for p in net.parameters() if p.requires_grad]
    with torch.no_grad():
        for p, g in zip(trainable, grads):
            if g is None:
                continue
            if optim_name == "sgd":
                p.add_(g, alpha=-lr)
            else:
                # |g| equals sqrt(g**2) but cannot overflow to inf for huge g.
                p.addcdiv_(g, g.abs().add_(eps), value=-lr)

def estimate_update_signature(
    model_family, images, labels, args, seeds, num_batches, probe_seed=0
):
    """
    Estimate the one-step update signature of a model family on fixed probe batches.

    For every seed, a freshly initialised ``model_family`` network takes one
    optimizer step (``args.hcov_optim`` with ``args.hcov_lr``) on each probe
    batch. The signature of that step is the change of the batch logits,
    flattened and L2-normalised.

    Args:
        model_family: name accepted by ``get_network_ecology``.
        images: synthetic images (moved to ``args.device``).
        labels: labels for ``images``. They are cast to long and, if out of
            range, remapped: {-1, 1} -> {0, 1}, otherwise shifted so the
            minimum becomes 0.
        args: namespace providing ``device``, ``hcov_batch``, ``num_classes``,
            ``channel``, ``im_size``, ``hcov_lr`` and ``hcov_optim``.
        seeds: iterable of ints; one network initialisation per seed.
        num_batches: maximum number of probe batches (full batches only).
        probe_seed: seed of the dedicated RNG that shuffles the probe set.
            The same value yields the same batches, in the same sample order,
            for every model family. This is what makes signatures of
            different families comparable.

    Returns:
        (mean, samples): ``samples`` is a numpy array with one row per
        (seed, batch) pair, each row of length batch_size * n_logits; ``mean``
        is its row mean.

    Raises:
        ValueError: if labels are still out of range after remapping.
        RuntimeError: if no full probe batch can be formed.
    """
    device = args.device
    criterion = nn.CrossEntropyLoss().to(device)

    # Fixed probe batch size so every signature has the same dimension.
    bs = min(args.hcov_batch, images.shape[0])

    labels = labels.clone()
    if labels.dtype != torch.long:
        labels = labels.long()

    lab_min, lab_max = labels.min().item(), labels.max().item()
    if lab_min < 0 or lab_max >= args.num_classes:
        if set(torch.unique(labels).tolist()) == {-1, 1}:
            labels = (labels > 0).long()
        else:
            labels = labels - lab_min

    # Explicit check (not assert): it must survive `python -O`, and
    # out-of-range targets would otherwise fail deep inside CrossEntropyLoss.
    if labels.min() < 0 or labels.max() >= args.num_classes:
        raise ValueError(
            f"Illegal labels after fix: min={labels.min()}, max={labels.max()}"
        )

    # Shuffle with a private, fixed-seed generator. Using the global RNG
    # would make the batches depend on whatever ran before this call,
    # including the previous family's seeding and initialisation. The
    # signatures compared across families would then refer to different
    # images or a different sample order.
    gen = torch.Generator()
    gen.manual_seed(probe_seed)
    ds = TensorDataset(images.to(device), labels.to(device))
    loader = DataLoader(
        ds, batch_size=bs, shuffle=True, drop_last=True, generator=gen
    )

    batches = []
    for x, y in loader:
        batches.append((x, y))
        if len(batches) >= num_batches:
            break
    if not batches:
        raise RuntimeError("Not enough samples for a full probe batch (drop_last=True).")

    u_samples = []
    for s in seeds:
        torch.manual_seed(s)
        np.random.seed(s)

        net = get_network_ecology(model_family, args.channel, args.num_classes, args.im_size).to(device)
        net.train()

        for x, y in batches:
            # 1) gradients at the current parameters
            net.zero_grad(set_to_none=True)
            logits = net(x)
            loss = criterion(logits, y)
            loss.backward()

            # logits "before" the step: the forward pass above already used
            # the current parameters on this batch, so reuse it.
            logits0 = logits.detach()

            grads = [p.grad for p in net.parameters() if p.requires_grad]

            # 2) one in-place optimizer step
            _one_step_update_(net, grads, lr=args.hcov_lr, optim_name=args.hcov_optim)

            # 3) logits after the step (no autograd graph needed)
            with torch.no_grad():
                logits1 = net(x)

            # 4) signature = normalised vec(delta logits)
            df = (logits1 - logits0).reshape(-1)
            df = df / (df.norm() + 1e-8)
            u_samples.append(df.cpu().numpy())

    u_samples = np.stack(u_samples, axis=0)
    return u_samples.mean(axis=0), u_samples



def estimate_hcov(u_means, u_samples, r_c_unused=None):
    """
    Estimate Hcov = log(#centers) of a greedy r-cover of the family signatures.

    Args:
        u_means: dict family -> 1-D mean update signature (numpy array).
        u_samples: dict family -> 2-D array [n_samples, dim] of the individual
            signatures averaged into ``u_means[family]``.
        r_c_unused: ignored; kept for call-site compatibility.

    Returns:
        log of the number of greedy cover centers; exactly 0.0 when a single
        center suffices (including when fewer than two families are given).

    Method:
        1. Radius r (within-family noise): for each family with >= 2 samples,
           r_k is the q-quantile of its sample-to-mean distances with
           q = 1 - 1/n clipped to [0.80, 0.99] (roughly "all but one sample
           inside"). r is the median of the r_k, floored at 1e-8, and is 1e-8
           when no family has >= 2 samples.
        2. Deterministic greedy cover of the family means: repeatedly choose,
           among the still-uncovered means, the one whose r-ball contains the
           most uncovered means (ties -> first in set iteration order).
    """
    keys = list(u_means.keys())
    M = len(keys)
    if M <= 1:
        return 0.0  # a single center: log(1) = 0

    # 1) Adaptive radius r
    r_list = []
    for k in keys:
        mu = u_means[k]
        S = u_samples[k]  # [n_samples, dim]
        n = S.shape[0]
        if n <= 1:
            continue  # spread cannot be measured from one sample
        d = np.linalg.norm(S - mu[None, :], axis=1)
        # Clip so very small / very large n do not give extreme quantiles.
        q = float(np.clip(1.0 - 1.0 / float(n), 0.80, 0.99))
        r_list.append(float(np.quantile(d, q)))

    # Median is robust to a single unusually noisy family.
    r = max(float(np.median(r_list)), 1e-8) if r_list else 1e-8

    # 2) Distances between family means and r-ball membership
    U = np.stack([u_means[k] for k in keys], axis=0)  # [M, dim]
    D = np.linalg.norm(U[:, None, :] - U[None, :, :], axis=-1)  # [M, M]

    # Every point always covers itself. With finite data D[i, i] = 0 <= r
    # already guarantees this. A NaN signature (or NaN r) would otherwise
    # leave cover[i] empty, and the greedy loop below would never terminate.
    cover = [set(np.where(D[i] <= r)[0].tolist()) | {i} for i in range(M)]

    # 3) Deterministic greedy set cover
    uncovered = set(range(M))
    n_centers = 0
    while uncovered:
        best_i = None
        best_gain = -1
        for i in uncovered:
            gain = len(cover[i] & uncovered)
            if gain > best_gain:
                best_gain, best_i = gain, i
        n_centers += 1
        uncovered -= cover[best_i]

    # n_centers >= 1 here, so no epsilon is needed (and one center gives
    # exactly 0.0, consistent with the M <= 1 early return).
    return math.log(n_centers)



# ==================================================
# Main
# ==================================================

def _ipc_for_k(k, k_is_ipc, num_classes):
    """
    Return the images-per-class (IPC) value used to locate the checkpoint for budget ``k``.

    With ``--k_is_ipc`` the budget already is an IPC. Otherwise ``k`` is a total
    image count, split evenly over the classes (at least 1 per class).
    """
    return k if k_is_ipc else max(1, k // num_classes)


def _split_csv(text):
    """Split a comma-separated CLI value into stripped, non-empty tokens."""
    return [tok.strip() for tok in text.split(",") if tok.strip()]


def _parse_int_list(parser, flag, text):
    """Parse ``text`` as comma-separated ints; report bad input via ``parser.error`` (exits)."""
    try:
        values = [int(tok) for tok in _split_csv(text)]
    except ValueError:
        parser.error(f"{flag} must be comma-separated integers, got {text!r}")
    if not values:
        parser.error(f"{flag} must contain at least one integer")
    return values


def _sanitize_labels(labels, num_classes):
    """
    Cast ``labels`` to long and map them into {0, ..., num_classes - 1}.

    Out-of-range labels are remapped: {-1, 1} -> {0, 1}; anything else is shifted
    so its minimum becomes 0 (e.g. {1, ..., C} -> {0, ..., C-1}).

    Raises:
        ValueError: if the labels are still out of range after remapping.
    """
    if labels.dtype != torch.long:
        labels = labels.long()

    lab_min, lab_max = labels.min().item(), labels.max().item()
    if lab_min < 0 or lab_max >= num_classes:
        if set(torch.unique(labels).tolist()) == {-1, 1}:
            labels = (labels > 0).long()
        else:
            labels = labels - lab_min

    # Explicit check (not assert) so it also runs under `python -O`.
    if labels.min() < 0 or labels.max() >= num_classes:
        raise ValueError(
            f"Illegal labels after fix: min={labels.min()}, max={labels.max()}"
        )
    return labels


def main():
    """
    CLI entry point: train, estimate Hcov, and write ``<out_dir>/coverage_points.csv``.

    1. For every k in --k_list whose synthetic checkpoint exists (``valid_k``),
       train each family in --cov_model_family from scratch on that set
       (SGD, momentum 0.9, 301 epochs) and record its test accuracy.
    2. Compute update signatures for every family on the synthetic set of the
       largest valid k.
    3. For m = 1..M, take the first m families (CLI order), estimate H(m), and
       for every valid k emit the point
       (sqrt(H(m)) / sqrt(k), worst-case test error over those m families).
    """
    p = argparse.ArgumentParser()
    p.add_argument("--dataset", type=str, default="MNIST")
    p.add_argument("--synth_root", type=str, required=True)
    p.add_argument("--k_list", type=str, required=True)
    p.add_argument("--k_is_ipc", action="store_true")
    p.add_argument("--out_dir", type=str, default="hcov")
    # Consumed by the import-time pre-parse (CUDA_VISIBLE_DEVICES); declared
    # here too so this parser accepts it.
    p.add_argument("--device", type=str, default="0")

    # training
    p.add_argument("--lr_net", type=float, default=0.01)
    p.add_argument("--batch_train", type=int, default=256)

    # ecology
    p.add_argument("--coverage", action="store_true")  # not read in this file
    p.add_argument(
        "--cov_model_family",
        type=str,
        required=True,
        help="Comma-separated model families, e.g. ConvNet,ConvNetBN,ConvNetASwish",
    )

    # hcov
    p.add_argument("--hcov_seeds", type=str, default="0")
    p.add_argument("--hcov_num_batches", type=int, default=2)

    p.add_argument("--plot", action="store_true")
    p.add_argument("--hcov_batch", type=int, default=256)
    p.add_argument("--hcov_lr", type=float, default=0.01)
    p.add_argument("--hcov_optim", type=str, default="sgd", choices=["sgd", "adam"])

    args = p.parse_args()

    # Validate the cheap CLI inputs before any dataset loading or training.
    k_list = _parse_int_list(p, "--k_list", args.k_list)
    if min(k_list) <= 0:
        p.error(f"--k_list values must be positive, got {k_list}")
    seeds = tuple(_parse_int_list(p, "--hcov_seeds", args.hcov_seeds))
    model_families = _split_csv(args.cov_model_family)
    if not model_families:
        p.error("--cov_model_family must name at least one model family")
    unknown = [mf for mf in model_families if mf not in SUPPORTED_MODEL_FAMILIES]
    if unknown:
        p.error(
            f"unsupported --cov_model_family entries {unknown}; "
            f"supported: {sorted(SUPPORTED_MODEL_FAMILIES)}"
        )

    # CUDA_VISIBLE_DEVICES was set at import time, so the chosen GPU is "cuda".
    args.device = "cuda" if torch.cuda.is_available() else "cpu"
    args.dsa_param = ParamDiffAug()
    args.dc_aug_param = None

    # dataset
    channel, im_size, num_classes, _, _, _, _, _, testloader = get_dataset(
        args.dataset, "./data"
    )
    print("args.dataset =", args.dataset)
    print("num_classes =", num_classes)
    args.channel, args.im_size, args.num_classes = channel, im_size, num_classes

    # ---------------- Training + evaluation ----------------
    synth_acc = {mf: {} for mf in model_families}
    valid_k = []
    criterion = nn.CrossEntropyLoss().to(args.device)
    epochs = 300  # the loop runs epochs + 1 passes

    for k in k_list:
        ipc = _ipc_for_k(k, args.k_is_ipc, num_classes)
        args.ipc = ipc
        log(f"[K] k={k} (IPC={ipc})")

        pt = find_pt_for_ipc(args.synth_root, ipc, args.dataset.upper())
        print(pt)
        if pt is None:
            continue
        valid_k.append(k)

        images, labels = load_synth_from_pt(pt)
        labels = _sanitize_labels(labels, num_classes)

        # Same data for every family: move it to the device and wrap it once.
        ds = TensorDataset(images.to(args.device), labels.to(args.device))
        loader = DataLoader(ds, batch_size=args.batch_train, shuffle=True)

        for mf in model_families:
            log(f"[ECO] train {mf}")
            net = get_network_ecology(mf, channel, num_classes, im_size).to(args.device)
            optimizer = torch.optim.SGD(net.parameters(), lr=args.lr_net, momentum=0.9)

            for _ in range(epochs + 1):
                B.epoch("train", loader, net, optimizer, criterion, args, aug=False)

            _, acc = B.epoch(
                "test", testloader, net, optimizer, criterion, args, aug=False
            )
            log(f"[ECO][{mf}] test_acc = {acc:.4f}")
            synth_acc[mf][k] = float(acc)

    # ---------------- Hcov ----------------
    if not valid_k:
        raise RuntimeError(
            f"No synthetic checkpoint found under {args.synth_root} for any k in {k_list}"
        )

    # Reference set: the largest k that actually has a checkpoint, converted to
    # IPC exactly as in the training loop (k is a total image count unless
    # --k_is_ipc is set).
    ref_k = max(valid_k)
    ref_ipc = _ipc_for_k(ref_k, args.k_is_ipc, num_classes)
    args.ipc = ref_ipc
    pt = find_pt_for_ipc(args.synth_root, ref_ipc, args.dataset.upper())
    print(pt)
    # Labels are sanitized inside estimate_update_signature.
    ref_images, ref_labels = load_synth_from_pt(pt)

    # Hard-coded ConvNet baselines on the real dataset. They are informational
    # only: nothing below depends on them, so datasets without an entry are
    # accepted instead of failing after all training has finished.
    real_baseline = {
        "MNIST": 0.9958,
        "CIFAR10": 0.8145,
        "CIFAR100": 0.5235,
    }.get(args.dataset)
    if real_baseline is not None:
        log(f"[REF] ConvNet real-data baseline acc = {real_baseline:.4f}")

    u_means, u_samples = {}, {}
    for mf in model_families:
        u_means[mf], u_samples[mf] = estimate_update_signature(
            mf, ref_images, ref_labels, args, seeds, args.hcov_num_batches
        )

    # ---------------- Coverage points ----------------
    # Ecology A_m = first m families in CLI order, so it grows one family at a time.
    X, Y = [], []
    for m in range(1, len(model_families) + 1):
        subset = model_families[:m]
        H_m = estimate_hcov(
            {mf: u_means[mf] for mf in subset},
            {mf: u_samples[mf] for mf in subset},
        )

        for k in valid_k:
            errs = [1.0 - synth_acc[mf][k] for mf in subset if k in synth_acc[mf]]
            if not errs:
                continue
            # Worst-case (max) test error over the families in A_m.
            X.append(math.sqrt(H_m) / math.sqrt(k))
            Y.append(float(max(errs)))

    os.makedirs(args.out_dir, exist_ok=True)
    save_csv(
        os.path.join(args.out_dir, "coverage_points.csv"),
        ["sqrtH_over_sqrtk", "delta_risk"],
        list(zip(X, Y)),
    )

    if args.plot:
        B.plot_coverage_law_with_dump(
            np.array(X),
            np.array(Y),
            title=f"{args.dataset}: Coverage Law (Hcov)",
            out_png=os.path.join(args.out_dir, "coverage_law.png"),
        )


if __name__ == "__main__":
    # Run the training / Hcov pipeline only when executed as a script, not on import.
    main()
