# -*- coding: utf-8 -*-
"""
目的：
1) 单一训练生态：严格对齐 DC 的训练生态（SGD+Momentum=0.9+WD=5e-4、DC增广策略开关、
   epoch=300/1000、批大小=256、学习率与中途衰减）
2) 生态覆盖：在 1) 的基础上扩展到多架构/多优化器/增广
"""
import os, re, copy, argparse, time, json
import csv
import numpy as np
import os
os.environ.setdefault("CUDA_DEVICE_ORDER", "PCI_BUS_ID")
def _set_visible_device(dev_arg: str):
    dev_arg = str(dev_arg).strip()
    if ',' in dev_arg:
        # 多卡逗号分隔，例如 "0,1,2,3"
        os.environ["CUDA_VISIBLE_DEVICES"] = dev_arg
    else:
        os.environ["CUDA_VISIBLE_DEVICES"] = dev_arg
#_set_visible_device(os.environ.get("DD_ECO_DEVICE", "2")) 

import torch
import torch.multiprocessing as mp
torch.backends.cudnn.benchmark = True
import torch
import torch.nn as nn
import matplotlib.pyplot as plt

from utils import (
    get_dataset, get_network, get_daparam, TensorDataset,
    epoch, ParamDiffAug, DiffAugment
)

# ----------------- 工具函数：日志 -----------------
def log(msg):
    """Print ``msg`` to stdout prefixed by the local time as ``[HH:MM:SS]``; flushes immediately."""
    stamp = time.strftime("[%H:%M:%S]", time.localtime())
    line = "{} {}".format(stamp, msg)
    print(line, flush=True)

# ----------------- 选择 pt 文件（支持 *_ipc{n}.pt 和 *_{n}ipc.pt）-----------------
def find_pt_for_ipc(root, ipc, dataset_tag=None):
    """Pick the ``.pt`` file in ``root`` whose IPC tag best matches ``ipc``.

    Recognised filename tags are ``*_ipc{n}.pt`` and ``*_{n}ipc.pt``; files
    without such a tag are ignored.

    Args:
        root: directory containing candidate ``.pt`` files.
        ipc: requested number of synthetic images per class.
        dataset_tag: optional dataset name (e.g. ``"CIFAR10"``). Files whose name
            contains it as a standalone token (case-insensitive, not embedded in a
            longer alphanumeric run) are searched first. Token matching matters:
            a plain substring test treats ``CIFAR100`` files as ``CIFAR10`` files
            and ``FashionMNIST`` files as ``MNIST`` files.

    Returns:
        Full path of the chosen file, or ``None`` if ``root`` is not a directory
        or no file carries a recognisable IPC tag.

    Within the preferred bucket first, then the remaining files:
      1. an exact IPC match wins (lexicographically last name on ties);
      2. otherwise the nearest IPC (smallest |n - ipc|, ties by name).
    """
    if not os.path.isdir(root):
        return None
    files = sorted(f for f in os.listdir(root) if f.endswith(".pt"))

    ipc_suffix_a = re.compile(r"_ipc(\d+)\.pt$")   # e.g. xxx_ipc10.pt
    ipc_suffix_b = re.compile(r"_(\d+)ipc\.pt$")   # e.g. xxx_10ipc.pt

    def _ipc_of(fn):
        # IPC encoded in the filename, or None when there is no recognised tag.
        m = ipc_suffix_a.search(fn) or ipc_suffix_b.search(fn)
        return int(m.group(1)) if m else None

    if dataset_tag:
        tag_re = re.compile(
            rf"(?<![A-Za-z0-9]){re.escape(str(dataset_tag))}(?![A-Za-z0-9])",
            re.IGNORECASE,
        )
        prefer = [f for f in files if tag_re.search(f)]
    else:
        prefer = []
    prefer_set = set(prefer)
    others = [f for f in files if f not in prefer_set]

    for bucket in (prefer, others):
        exact, approx = [], []
        for fn in bucket:
            cid = _ipc_of(fn)
            if cid is None:
                continue
            if cid == ipc:
                exact.append(fn)
            else:
                approx.append((abs(cid - ipc), fn))
        if exact:
            return os.path.join(root, max(exact))
        if approx:
            return os.path.join(root, min(approx)[1])
    return None

# ----------------- 载入合成数据-----------------
def load_synth_from_pt(pt_path, exp_idx=-1):
    """Load a synthetic training set ``(images, labels)`` from a ``.pt`` checkpoint.

    Supported layouts (the checkpoint must be a dict):
      1) DC: ``{'data': [[images, labels], [images, labels], ...], ...}`` -- one
         entry per experiment; ``exp_idx`` selects the entry (negative indices
         allowed, default -1 = last).
      2) MTT: ``{'images': ..., 'labels': ..., ...}``.
      3) ``{'img_syn': ..., 'lab_syn': ...}``.

    Args:
        pt_path: path of the checkpoint file.
        exp_idx: experiment index for the DC layout (ignored otherwise).

    Returns:
        ``(images, labels)`` on CPU: images converted with ``.float()``, labels with
        ``.long()`` and flattened to 1-D.

    Raises:
        RuntimeError: the payload is not a dict, has an empty ``'data'`` entry, or
            matches none of the supported layouts.
        IndexError: ``exp_idx`` is out of range for the DC layout.

    NOTE(security): ``torch.load`` unpickles the file -- only load trusted checkpoints.
    """
    ckpt = torch.load(pt_path, map_location='cpu')

    # A non-dict payload used to reach ``ckpt.keys()`` in the error branch and
    # crash with AttributeError instead of the intended RuntimeError.
    if not isinstance(ckpt, dict):
        raise RuntimeError(f"[ERR] 不支持的 .pt 结构：{pt_path} 的类型为 {type(ckpt).__name__}")

    if 'data' in ckpt:                                     # DC layout
        data = ckpt['data']
        if not data:
            raise RuntimeError(f"[ERR] {pt_path} 不含有效 'data'")
        n_exp = len(data)
        idx = exp_idx + n_exp if exp_idx < 0 else exp_idx   # support negative indices
        if not 0 <= idx < n_exp:
            # Report the index the caller passed, not the internally shifted one.
            raise IndexError(f"[ERR] exp_idx={exp_idx} 越界（共有 {n_exp} 个 exp）")
        images, labels = data[idx]
    elif 'images' in ckpt and 'labels' in ckpt:            # MTT layout
        images, labels = ckpt['images'], ckpt['labels']
    elif 'img_syn' in ckpt and 'lab_syn' in ckpt:          # alternative naming
        images, labels = ckpt['img_syn'], ckpt['lab_syn']
    else:
        raise RuntimeError(f"[ERR] 不支持的 .pt 结构：{pt_path} 的键为 {list(ckpt.keys())[:10]}")

    # Normalise device / dtype; labels are flattened (reshape also accepts
    # non-contiguous tensors, unlike view).
    if torch.is_tensor(images):
        images = images.detach().cpu().float()
    else:
        images = torch.tensor(images).float()

    if torch.is_tensor(labels):
        labels = labels.detach().cpu().long().reshape(-1)
    else:
        labels = torch.tensor(labels).long().reshape(-1)

    return images, labels

# ----------------- 与 DC 完全一致的单一生态评估 -----------------
def evaluate_single_ecology(images_train, labels_train, testloader, args):
    """Train a fresh ``args.model`` on the given (synthetic) set with the DC recipe; return test accuracy.

    Recipe (kept aligned with DC's ``evaluate_synset``):
      - network: ``get_network(args.model, args.channel, args.num_classes, args.im_size)``;
      - optimizer: SGD(lr=args.lr_net, momentum=0.9, weight_decay=5e-4);
      - epochs: 1000 when ``get_daparam`` returns a strategy other than 'none',
        otherwise 300 (loop runs ``Epoch + 1`` passes, as in DC);
      - batch size: ``args.batch_train``;
      - lr decay: multiplied by 0.1 once, after epoch ``Epoch // 2 + 1``; the
        optimizer is rebuilt there, so its momentum state starts fresh (as in DC).

    ``args.dc_aug_param`` is set to the chosen augmentation parameters while
    training and restored to its previous value afterwards.

    Returns:
        The test accuracy returned by ``epoch('test', ...)``.
    """
    net = get_network(args.model, args.channel, args.num_classes, args.im_size).to(args.device)
    net.train()

    # Same call as in DC: the model is passed as both training and evaluation architecture.
    dc_aug_param = get_daparam(args.dataset, args.model, args.model, args.ipc)
    use_aug = (dc_aug_param is not None) and (dc_aug_param.get('strategy', 'none') != 'none')
    Epoch = 1000 if use_aug else 300

    lr = float(args.lr_net)
    optimizer = torch.optim.SGD(net.parameters(), lr=lr, momentum=0.9, weight_decay=5e-4)
    criterion = nn.CrossEntropyLoss().to(args.device)
    lr_schedule = [Epoch // 2 + 1]

    dst_train = TensorDataset(images_train.to(args.device), labels_train.to(args.device))
    trainloader = torch.utils.data.DataLoader(dst_train, batch_size=args.batch_train,
                                              shuffle=True, num_workers=0, pin_memory=False)

    # BUGFIX: the augmentation params used to be computed but never handed to
    # ``epoch`` (main initialises args.dc_aug_param = None), so with use_aug=True the
    # DC augmentation was presumably a no-op -- DC's reference ``epoch``/``augment``
    # read args.dc_aug_param (verify against the local utils). Expose them for this
    # run only and restore the caller's value afterwards.
    prev_dc_aug_param = getattr(args, 'dc_aug_param', None)
    args.dc_aug_param = dc_aug_param
    try:
        t0 = time.time()
        for ep in range(Epoch + 1):
            loss_train, acc_train = epoch('train', trainloader, net, optimizer, criterion, args, aug=use_aug)
            if ep in lr_schedule:
                lr *= 0.1
                optimizer = torch.optim.SGD(net.parameters(), lr=lr, momentum=0.9, weight_decay=5e-4)
            if ep % 50 == 0 or ep == Epoch:
                log(f"[S-DC] epoch={ep:04d}/{Epoch} loss={loss_train:.4f} acc_train={acc_train:.4f}")

        loss_test, acc_test = epoch('test', testloader, net, optimizer, criterion, args, aug=False)
    finally:
        args.dc_aug_param = prev_dc_aug_param
    log(f"[S-DC] 完成：train_time={int(time.time()-t0)}s, test_acc={acc_test:.4f}")
    return acc_test

# ----------------- 生态覆盖评估（多架构×多优化器×增广）-----------------
def evaluate_coverage(images_train, labels_train, testloader, args, cov_arch, cov_optim, cov_aug):
    """Evaluate the synthetic set over the grid architecture x optimizer x augmentation.

    Args:
        images_train, labels_train: synthetic training set.
        testloader: test loader passed to ``epoch('test', ...)``.
        args: run configuration (device, lr_net, batch_train, dataset, ipc, ...).
        cov_arch: architecture names for ``get_network`` (e.g. ['ConvNet', 'ResNet18']).
        cov_optim: optimizer names; each must be 'sgd' or 'adam'.
        cov_aug: augmentation modes; 'dc' uses ``get_daparam(dataset, arch, arch, ipc)``,
            anything else disables augmentation.

    Per combination the DC loop is used: 1000 epochs if augmentation is active,
    else 300; lr x0.1 once after epoch ``Epoch // 2 + 1`` (optimizer rebuilt).
    ``args.dc_aug_param`` is set per combination and restored on exit.

    Returns:
        ``{(arch, optim, aug): test_acc}`` with float accuracies.

    Raises:
        ValueError: an optimizer name other than 'sgd'/'adam'.
    """
    def _make_optimizer(name, params, lr):
        # Both optimizers use weight decay 5e-4; SGD additionally uses momentum 0.9.
        if name == 'sgd':
            return torch.optim.SGD(params, lr=lr, momentum=0.9, weight_decay=5e-4)
        if name == 'adam':
            return torch.optim.Adam(params, lr=lr, weight_decay=5e-4)
        raise ValueError(f"未知优化器：{name}")

    # The training set does not depend on the grid point: build it (and its loader) once.
    dst_train = TensorDataset(images_train.to(args.device), labels_train.to(args.device))
    trainloader = torch.utils.data.DataLoader(dst_train, batch_size=args.batch_train,
                                              shuffle=True, num_workers=0, pin_memory=False)

    results = {}
    prev_dc_aug_param = getattr(args, 'dc_aug_param', None)
    try:
        for arch in cov_arch:
            for optim in cov_optim:
                for aug in cov_aug:
                    log(f"[COV] 评估组合: arch={arch} | optim={optim} | aug={aug}")
                    net = get_network(arch, args.channel, args.num_classes, args.im_size).to(args.device)
                    net.train()

                    if aug == 'dc':
                        dc_aug_param = get_daparam(args.dataset, arch, arch, args.ipc)
                        use_aug = (dc_aug_param is not None) and (dc_aug_param.get('strategy', 'none') != 'none')
                    else:
                        dc_aug_param = {'strategy': 'none'}
                        use_aug = False
                    # BUGFIX: expose the params to ``epoch``; previously args.dc_aug_param
                    # stayed None, presumably turning DC augmentation into a no-op.
                    args.dc_aug_param = dc_aug_param

                    Epoch = 1000 if use_aug else 300
                    lr = float(args.lr_net)
                    optimizer = _make_optimizer(optim, net.parameters(), lr)
                    criterion = nn.CrossEntropyLoss().to(args.device)
                    lr_schedule = [Epoch // 2 + 1]

                    t0 = time.time()
                    for ep in range(Epoch + 1):
                        loss_train, acc_train = epoch('train', trainloader, net, optimizer, criterion, args, aug=use_aug)
                        if ep in lr_schedule:
                            lr *= 0.1
                            optimizer = _make_optimizer(optim, net.parameters(), lr)
                        if ep % 50 == 0 or ep == Epoch:
                            log(f"[COV][{arch}|{optim}|{aug}] epoch={ep:04d}/{Epoch} loss={loss_train:.4f} acc={acc_train:.4f}")

                    loss_test, acc_test = epoch('test', testloader, net, optimizer, criterion, args, aug=False)
                    log(f"[COV][{arch}|{optim}|{aug}] 完成：time={int(time.time()-t0)}s, test_acc={acc_test:.4f}")
                    results[(arch, optim, aug)] = float(acc_test)
    finally:
        args.dc_aug_param = prev_dc_aug_param
    return results

# ----------------- 画图：k-曲线 -----------------


def _linfit_numpy(x, y):
    """ 最小二乘线性拟合 y = a*x + b，返回 (a, b, R^2)；x,y 为 1D numpy 数组 """
    x = np.asarray(x).reshape(-1)
    y = np.asarray(y).reshape(-1)
    m = np.isfinite(x) & np.isfinite(y)
    x, y = x[m], y[m]
    if len(x) < 2:
        return 0.0, float(np.mean(y) if len(y) else 0.0), 0.0
    X = np.vstack([x, np.ones_like(x)]).T
    beta, *_ = np.linalg.lstsq(X, y, rcond=None)
    a, b = beta[0], beta[1]
    y_hat = a * x + b
    ss_res = np.sum((y - y_hat) ** 2)
    ss_tot = np.sum((y - np.mean(y)) ** 2) + 1e-12
    r2 = 1.0 - ss_res / ss_tot
    return float(a), float(b), float(r2)

def plot_single_inv_sqrtk_with_fit(k_list, y_list, title, out_png,
                                   ylabel="Test Acc",
                                   scatter_label="实验散点",
                                   line_label="线性拟合（y=ax+b）",
                                   dpi=200):
    """Scatter a single-ecology metric against 1/sqrt(k) and overlay its least-squares line.

    Args:
        k_list: k values (list or array); clamped below at 1e-12 before 1/sqrt.
        y_list: metric for each k (e.g. test accuracy or Δrisk).
        title: figure title.
        out_png: output image path; its directory is created (cwd when it has none).
        ylabel: y-axis label.
        scatter_label: legend label of the points.
        line_label: legend label of the fit; the coefficients are appended to it.
        dpi: resolution passed to ``savefig``.

    The figure is closed after saving; the path and the fit coefficients are logged.
    """
    k_values = np.asarray(k_list, dtype=float)
    metric = np.asarray(y_list, dtype=float)
    inv_sqrt_k = 1.0 / np.sqrt(np.maximum(k_values, 1e-12))

    # Global linear fit metric ≈ slope * (1/sqrt(k)) + intercept
    slope, intercept, r_sq = _linfit_numpy(inv_sqrt_k, metric)
    fitted = slope * inv_sqrt_k + intercept

    target_dir = os.path.dirname(out_png) or "."
    os.makedirs(target_dir, exist_ok=True)

    fit_label = f"{line_label}\na={slope:.4f}, b={intercept:.4f}, R²={r_sq:.4f}"
    plt.figure(figsize=(6.0, 4.2))
    plt.scatter(inv_sqrt_k, metric, label=scatter_label)
    plt.plot(inv_sqrt_k, fitted, linestyle='--', label=fit_label)
    plt.xlabel('1/√k')
    plt.ylabel(ylabel)
    plt.title(title)
    plt.grid(True, ls='--', alpha=0.5)
    plt.legend()
    plt.tight_layout()
    plt.savefig(out_png, dpi=dpi)
    plt.close()

    log(f"[FIG] 保存曲线到 {out_png}")
    log(f"[FIT] y = a*x + b   a={slope:.6f}, b={intercept:.6f}, R^2={r_sq:.6f}")
    log(f"[FIT] y = a*x + b   a={a:.6f}, b={b:.6f}, R^2={r2:.6f}")

def plot_k_curve(k_list, acc_dict, title, out_png):
    """Plot test accuracy against k (log-scale x axis), one line per legend, and save it.

    Args:
        k_list: k values shared by every curve (x axis).
        acc_dict: ``{legend -> [acc_k1, acc_k2, ...]}``, each list aligned with ``k_list``.
        title: figure title.
        out_png: output image path; its parent directory is created when present.
    """
    out_dir = os.path.dirname(out_png)
    if out_dir:
        # os.makedirs('') raises, which used to crash for a bare filename.
        os.makedirs(out_dir, exist_ok=True)

    plt.figure(figsize=(6.0, 4.2))
    try:
        for name, accs in acc_dict.items():
            plt.plot(k_list, accs, marker='o', label=name)
        plt.xscale('log')
        plt.xlabel('k (每类样本数 / IPC)')
        plt.ylabel('Test Acc')
        plt.title(title)
        plt.grid(True, ls='--', alpha=0.5)
        plt.legend()
        plt.tight_layout()
        plt.savefig(out_png, dpi=200)
    finally:
        # The figure used to be left open, leaking one figure per call.
        plt.close()
    log(f"[FIG] 保存曲线到 {out_png}")

def _linfit_numpy(x, y):
    x = np.asarray(x).reshape(-1)
    y = np.asarray(y).reshape(-1)
    m = np.isfinite(x) & np.isfinite(y)
    x, y = x[m], y[m]
    if len(x) < 2:
        return 0.0, float(np.mean(y) if len(y) else 0.0), 0.0
    X = np.vstack([x, np.ones_like(x)]).T
    beta, *_ = np.linalg.lstsq(X, y, rcond=None)
    a, b = float(beta[0]), float(beta[1])
    y_hat = a * x + b
    ss_res = float(np.sum((y - y_hat) ** 2))
    ss_tot = float(np.sum((y - np.mean(y)) ** 2) + 1e-12)
    r2 = 1.0 - ss_res / ss_tot
    return a, b, r2

def train_real_baseline_for_ecology(arch, optim, aug, args, testloader, dst_train=None):
    """Train one ``arch|optim|aug`` ecology on the REAL training set and return its test accuracy.

    Uses the same recipe as the coverage evaluation: 1000 epochs when augmentation
    is active, else 300; lr x0.1 once after epoch ``Epoch // 2 + 1`` (optimizer
    rebuilt); SGD(momentum=0.9) or Adam, both with weight decay 5e-4.

    Args:
        arch: architecture name for ``get_network``.
        optim: 'sgd' or 'adam'.
        aug: 'dc' selects ``get_daparam(dataset, arch, arch, ipc)``; anything else
            disables augmentation.
        args: run configuration (device, lr_net, batch_train, dataset, ipc, ...).
        testloader: test loader passed to ``epoch('test', ...)``.
        dst_train: optional real training dataset. When ``None`` (default, the
            previous behaviour) it is loaded with
            ``get_dataset(args.dataset, data_path='./data')``; passing it in avoids
            reloading the dataset for every ecology.

    Returns:
        Test accuracy as a float.

    Raises:
        ValueError: an optimizer name other than 'sgd'/'adam'.
    """
    if aug == 'dc':
        dc_aug_param = get_daparam(args.dataset, arch, arch, args.ipc)
        use_aug = (dc_aug_param is not None) and (dc_aug_param.get('strategy', 'none') != 'none')
    else:
        dc_aug_param = {'strategy': 'none'}
        use_aug = False
    Epoch = 1000 if use_aug else 300

    def _make_optimizer(params, lr):
        # Both optimizers use weight decay 5e-4; SGD additionally uses momentum 0.9.
        if optim == 'sgd':
            return torch.optim.SGD(params, lr=lr, momentum=0.9, weight_decay=5e-4)
        if optim == 'adam':
            return torch.optim.Adam(params, lr=lr, weight_decay=5e-4)
        raise ValueError(f"未知优化器：{optim}")

    net = get_network(arch, args.channel, args.num_classes, args.im_size).to(args.device)
    lr = float(args.lr_net)
    optimizer = _make_optimizer(net.parameters(), lr)
    criterion = nn.CrossEntropyLoss().to(args.device)
    lr_schedule = [Epoch // 2 + 1]

    if dst_train is None:
        _, _, _, _, _, _, dst_train, _, _ = get_dataset(args.dataset, data_path='./data')
    trainloader = torch.utils.data.DataLoader(dst_train, batch_size=args.batch_train,
                                              shuffle=True, num_workers=0, pin_memory=False)

    # BUGFIX: expose the augmentation params to ``epoch`` (args.dc_aug_param was left
    # None by main, presumably making DC augmentation a no-op); restore afterwards.
    prev_dc_aug_param = getattr(args, 'dc_aug_param', None)
    args.dc_aug_param = dc_aug_param
    try:
        t0 = time.time()
        for ep in range(Epoch + 1):
            loss_train, acc_train = epoch('train', trainloader, net, optimizer, criterion, args, aug=use_aug)
            if ep in lr_schedule:
                lr *= 0.1
                optimizer = _make_optimizer(net.parameters(), lr)
            if ep % 10 == 0 or ep == Epoch:
                log(f"[REAL][{arch}|{optim}|{aug}] ep={ep:04d}/{Epoch} loss={loss_train:.4f} acc={acc_train:.4f}")
        _, acc_test = epoch('test', testloader, net, optimizer, criterion, args, aug=False)
    finally:
        args.dc_aug_param = prev_dc_aug_param
    log(f"[REAL][{arch}|{optim}|{aug}] test_acc={acc_test:.4f}, time={int(time.time()-t0)}s")
    return float(acc_test)

def ecology_distance(name_a, name_b, w=(1., 0.5, 0.25)):
    """Weighted mismatch distance between two ``"arch|optim|aug"`` legends.

    Each differing field adds its weight from ``w`` (arch, optim, aug order);
    identical legends are at distance 0. Raises ValueError if either legend does
    not split into exactly three ``|``-separated fields.
    """
    arch_a, optim_a, aug_a = name_a.split('|')
    arch_b, optim_b, aug_b = name_b.split('|')
    arch_term = w[0] * (arch_a != arch_b)
    optim_term = w[1] * (optim_a != optim_b)
    aug_term = w[2] * (aug_a != aug_b)
    return arch_term + optim_term + aug_term

def packing_number(legends, rho=1.0, w=(1., 0.5, 0.25)):
    """Size of a greedily built rho-separated subset of ``legends``.

    Legends are visited in the given order; one is kept when its
    ``ecology_distance`` (weights ``w``) to every already-kept legend is >= ``rho``.
    Returns how many were kept.
    """
    kept = []
    for candidate in legends:
        separated = all(ecology_distance(candidate, other, w) >= rho for other in kept)
        if separated:
            kept.append(candidate)
    return len(kept)

def build_coverage_points(cov_curves, real_baselines, k_list,
                          subset_mode="prefix", subset_trials=8, rng_seed=0):
    """Aggregate multi-ecology curves into coverage-law scatter points.

    Args:
        cov_curves: ``{legend -> [acc_syn(k1), acc_syn(k2), ...]}``.
        real_baselines: ``{legend -> acc_real}`` (real-data baseline, same ecology);
            must contain every legend of ``cov_curves``.
        k_list: k values aligned with the curves.
        subset_mode: 'prefix' (first m legends in sorted order) or, for any other
            value, 'random' (mean over ``subset_trials`` random m-subsets).
        subset_trials: number of random subsets per m (values < 1 are treated as 1).
        rng_seed: seed for the random subsets.

    Returns:
        ``(X, Y)`` 1-D arrays. For every coverage size m = 1..N and every aligned k:
        ``X = sqrt(log m) / sqrt(k)`` and ``Y`` = mean over the chosen ecologies of
        ``acc_real - acc_syn(k)``. (``log`` uses ``m + 1e-12``, so m = 1 gives X ≈ 0.)
        Empty inputs give two empty arrays.

    Curves are truncated to the shortest curve and never beyond ``len(k_list)``,
    so X and Y always have equal length.
    """
    rng = np.random.default_rng(rng_seed)
    legends = sorted(cov_curves.keys())   # sorted so 'prefix' mode is reproducible
    N = len(legends)
    if N == 0:
        # Previously crashed on min() of an empty sequence.
        return np.asarray([]), np.asarray([])

    # Align to the shortest curve AND to k_list; a shorter k_list used to produce
    # X/Y of different lengths (or a broadcasting error).
    L = min(min(len(v) for v in cov_curves.values()), len(k_list))
    ks = np.asarray(k_list[:L], dtype=float)

    # Δrisk per ecology does not depend on the subset: compute it once.
    deltas = {
        name: real_baselines[name] - np.asarray(cov_curves[name][:L], dtype=float)
        for name in legends
    }
    trials = max(1, int(subset_trials))   # avoid 0/0 when subset_trials <= 0

    Xs, Ys = [], []
    for m in range(1, N + 1):
        H = np.log(m + 1e-12)
        x_row = np.sqrt(H) / np.sqrt(ks)
        if subset_mode == "prefix":
            delta_mean = np.stack([deltas[name] for name in legends[:m]], axis=0).mean(0)
        else:  # random-subset average
            delta_accum = np.zeros(L, dtype=float)
            for _ in range(trials):
                chosen = rng.choice(legends, size=m, replace=False)
                delta_accum += np.stack([deltas[name] for name in chosen], axis=0).mean(0)
            delta_mean = delta_accum / trials
        Xs.extend(x_row.tolist())
        Ys.extend(delta_mean.tolist())
    return np.asarray(Xs), np.asarray(Ys)

def plot_coverage_law(X, Y, title, out_png, ylabel="Δrisk (acc_real - acc_syn)", dpi=220):
    """Save the coverage-law figure: scatter of (X, Y) plus one global least-squares line.

    Meant for checking Δ ~ sqrt(H)/sqrt(k). The fit comes from ``_linfit_numpy``; its
    coefficients appear in the legend and are logged together with the output path.
    The output directory is created (cwd when ``out_png`` has none); the figure is
    closed after saving.
    """
    slope, intercept, r_sq = _linfit_numpy(X, Y)
    fitted = slope * X + intercept

    target_dir = os.path.dirname(out_png) or "."
    os.makedirs(target_dir, exist_ok=True)

    fit_label = f"线性拟合: a={slope:.4f}, b={intercept:.4f}, R²={r_sq:.4f}"
    plt.figure(figsize=(6.4, 4.6))
    plt.scatter(X, Y, s=18, label="实验散点")
    plt.plot(X, fitted, 'r--', label=fit_label)
    plt.xlabel("√H / √k   （H≈log(覆盖生态数)）")
    plt.ylabel(ylabel)
    plt.title(title)
    plt.grid(True, ls='--', alpha=0.5)
    plt.legend()
    plt.tight_layout()
    plt.savefig(out_png, dpi=dpi)
    plt.close()

    log(f"[FIG] 覆盖律图保存：{out_png}")
    log(f"[FIT] y = a*x + b   a={slope:.6f}, b={intercept:.6f}, R^2={r_sq:.6f}")

def _ensure_dir(d: str):
    os.makedirs(d, exist_ok=True)

def fit_linear(x: np.ndarray, y: np.ndarray):
    """Least-squares line ``y ≈ a*x + b``; returns ``(a, b, r2)`` as Python floats.

    Inputs are flattened to 1-D float arrays and pairs with NaN/inf are dropped
    (a single NaN used to make ``np.polyfit`` fail). Degenerate inputs do not raise:
      - no usable points  -> ``(0.0, 0.0, 0.0)``;
      - one usable point  -> ``(0.0, y0, 1.0)``, a horizontal line through it.
    ``r2`` is 1.0 when the remaining ``y`` has zero variance (original convention).
    """
    x = np.asarray(x, dtype=float).reshape(-1)
    y = np.asarray(y, dtype=float).reshape(-1)
    finite = np.isfinite(x) & np.isfinite(y)
    x, y = x[finite], y[finite]

    if x.size == 0:
        return 0.0, 0.0, 0.0
    if x.size == 1:
        return 0.0, float(y[0]), 1.0

    a, b = np.polyfit(x, y, deg=1)
    y_hat = a * x + b
    ss_res = float(np.sum((y - y_hat) ** 2))
    ss_tot = float(np.sum((y - np.mean(y)) ** 2))
    r2 = 1.0 - ss_res / ss_tot if ss_tot > 0 else 1.0
    return float(a), float(b), r2

def save_csv(path: str, header: list[str], rows: list[tuple]):
    """Write ``header`` followed by every row of ``rows`` to ``path`` as UTF-8 CSV.

    The file is overwritten. Its parent directory is created when ``path`` has one;
    a bare filename is written to the current directory (previously this crashed
    because ``os.makedirs('')`` raises).
    """
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    with open(path, 'w', newline='', encoding='utf-8') as f:
        writer = csv.writer(f)
        writer.writerow(header)
        writer.writerows(rows)

def plot_single_inv_sqrtk_with_fit_err(
    k_list: list[int],
    acc_mean: list[float],
    acc_std: list[float],
    *,
    title: str,
    out_png: str,
    ylabel: str = "Test Acc",
    scatter_label: str = "mean ± std",
    line_label: str = "Linear fit (y = a·x + b)"
):
    """Single ecology: plot mean accuracy (± std error bars) against 1/sqrt(k) with a linear fit.

    Args:
        k_list: k values (must be > 0 for a finite 1/sqrt(k)).
        acc_mean, acc_std: per-k mean and std of the accuracy, aligned with ``k_list``.
        title, ylabel: figure title and y-axis label.
        out_png: output image path.
        scatter_label, line_label: legend labels; fit coefficients are appended to ``line_label``.

    Besides ``out_png``, writes into the same directory:
      - ``single_ecology_points.csv``: k, 1/sqrt(k), acc_mean, acc_std per point;
      - ``single_ecology_fit.csv``: a, b, R^2 from ``fit_linear`` and the point count.
    """
    # A bare filename has dirname '' (os.makedirs('') raises); use cwd instead so the
    # directory creation and the sibling CSV paths both work.
    out_dir = os.path.dirname(out_png) or "."
    _ensure_dir(out_dir)

    k_arr = np.array(k_list, dtype=float)
    x = 1.0 / np.sqrt(k_arr)
    y = np.array(acc_mean, dtype=float)
    yerr = np.array(acc_std, dtype=float)

    a, b, r2 = fit_linear(x, y)
    x_line = np.linspace(x.min(), x.max(), 200)
    y_line = a * x_line + b

    plt.figure()
    try:
        plt.errorbar(x, y, yerr=yerr, fmt='o', label=scatter_label)
        plt.plot(x_line, y_line, label=f"{line_label}; a={a:.4f}, b={b:.4f}, R²={r2:.4f}")
        plt.xlabel("1 / √k")
        plt.ylabel(ylabel)
        plt.title(title)
        plt.legend()
        plt.tight_layout()
        plt.savefig(out_png, dpi=200)
    finally:
        # Always release the figure, even if plotting/saving fails.
        plt.close()

    save_csv(
        os.path.join(out_dir, "single_ecology_points.csv"),
        header=["k", "inv_sqrt_k", "acc_mean", "acc_std"],
        rows=list(zip(k_list, x, y, yerr))
    )
    save_csv(
        os.path.join(out_dir, "single_ecology_fit.csv"),
        header=["a", "b", "r2", "n_points"],
        rows=[(a, b, r2, len(k_list))]
    )

def plot_coverage_law_with_dump(
    X: np.ndarray, Y: np.ndarray, *, title: str, out_png: str, ylabel: str = "Δrisk (acc_real - acc_syn)"
):
    """Coverage law: plot Δrisk against sqrt(H)/sqrt(k) with a linear fit and dump CSVs.

    Args:
        X: sqrt(H)/sqrt(k) values; Y: matching Δrisk values (same length, non-empty).
        title, ylabel: figure title and y-axis label.
        out_png: output image path.

    Besides ``out_png``, writes into the same directory:
      - ``coverage_points.csv``: every (x, Δrisk) point;
      - ``coverage_fit.csv``: a, b, R^2 from ``fit_linear`` and the point count.
    """
    # A bare filename has dirname '' (os.makedirs('') raises); use cwd instead.
    out_dir = os.path.dirname(out_png) or "."
    _ensure_dir(out_dir)
    X = np.array(X, dtype=float)
    Y = np.array(Y, dtype=float)

    a, b, r2 = fit_linear(X, Y)
    x_line = np.linspace(X.min(), X.max(), 200)
    y_line = a * x_line + b

    plt.figure()
    try:
        plt.plot(X, Y, 'o', label="Ecology subsets")
        plt.plot(x_line, y_line, label=f"Linear fit; a={a:.4f}, b={b:.4f}, R²={r2:.4f}")
        plt.xlabel("√H / √k")
        plt.ylabel(ylabel)
        plt.title(title)
        plt.legend()
        plt.tight_layout()
        plt.savefig(out_png, dpi=200)
    finally:
        # Always release the figure, even if plotting/saving fails.
        plt.close()

    save_csv(
        os.path.join(out_dir, "coverage_points.csv"),
        header=["sqrtH_over_sqrtk", "delta_risk"],
        rows=list(zip(X.tolist(), Y.tolist()))
    )
    save_csv(
        os.path.join(out_dir, "coverage_fit.csv"),
        header=["a", "b", "r2", "n_points"],
        rows=[(a, b, r2, len(X))]
    )

# ----------------- 主流程 -----------------
# Hard-coded real-data test accuracies per dataset, keyed by "arch|optim|aug".
# Ecologies missing here are trained on the real set by main().
_KNOWN_REAL_BASELINES = {
    'CIFAR10': {
        "ConvNet|adam|none": 0.7696,
        "ConvNet|adam|dc": 0.7687,
        "ConvNet|sgd|none": 0.8145,
        "ConvNet|sgd|dc": 0.8143,
        "LeNet|adam|none": 0.6424,
        "LeNet|adam|dc": 0.6355,
        "LeNet|sgd|none": 0.6312,
        "LeNet|sgd|dc": 0.6295,
        "ResNet18|adam|none": 0.7810,
        "ResNet18|adam|dc": 0.7809,
        "ResNet18|sgd|none": 0.8198,
        "ResNet18|sgd|dc": 0.8256,
        "AlexNet|adam|dc": 0.5623,
        "AlexNet|adam|none": 0.7321,
        "AlexNet|sgd|dc": 0.7989,
        "AlexNet|sgd|none": 0.7985,
        "MLP|adam|dc": 0.4901,
        "MLP|adam|none": 0.4923,
        "MLP|sgd|dc": 0.5103,
        "MLP|sgd|none": 0.5143,
        "VGG11|adam|dc": 0.7523,
        "VGG11|adam|none": 0.7488,
        "VGG11|sgd|dc": 0.8099,
        "VGG11|sgd|none": 0.8168,
    },
    'MNIST': {
        "ConvNet|adam|none": 0.9923,
        "ConvNet|adam|dc": 0.9916,
        "ConvNet|sgd|none": 0.9958,
        "ConvNet|sgd|dc": 0.9958,
        "LeNet|adam|none": 0.9897,
        "LeNet|adam|dc": 0.9912,
        "LeNet|sgd|none": 0.9920,
        "LeNet|sgd|dc": 0.9919,
        "ResNet18|adam|none": 0.9914,
        "ResNet18|adam|dc": 0.9892,
        "ResNet18|sgd|none": 0.9959,
        "ResNet18|sgd|dc": 0.9959,
        "AlexNet|adam|dc": 0.9917,
        "AlexNet|adam|none": 0.9901,
        "AlexNet|sgd|dc": 0.9912,
        "AlexNet|sgd|none": 0.9923,
        "MLP|adam|dc": 0.9751,
        "MLP|adam|none": 0.9784,
        "MLP|sgd|dc": 0.9833,
        "MLP|sgd|none": 0.9831,
    },
    'CIFAR100': {
        "AlexNet|sgd|none": 0.4211,
        "ConvNet|adam|none": 0.4425,
        "ConvNet|sgd|none": 0.5235,
        "LeNet|adam|none": 0.3040,
        "LeNet|sgd|none": 0.2510,
        "ResNet18|adam|none": 0.4221,
        "ResNet18|sgd|none": 0.5026,
    },
}


def _build_arg_parser():
    """Return the command-line parser used by :func:`main`."""
    p = argparse.ArgumentParser()
    # Data / model / paths
    p.add_argument('--dataset', type=str, default='CIFAR10')
    p.add_argument('--model', type=str, default='ConvNet')        # architecture of the single ecology
    p.add_argument('--synth_root', type=str, required=True)       # directory containing *.pt files
    p.add_argument('--k_list', type=str, required=True)           # e.g. "1,2,4,6,8,12,18,28,51,100,200"
    p.add_argument('--k_is_ipc', action='store_true', default=False,
                   help='若传入的 k 就是 IPC（每类张数），打开该开关')
    p.add_argument('--out_dir', type=str, default='figs')
    # Default None keeps any CUDA_VISIBLE_DEVICES already set in the environment.
    p.add_argument('--device', type=str, default=None,
                   help='value for CUDA_VISIBLE_DEVICES, e.g. "0" or "0,1"; omit to keep the environment')
    p.add_argument('--repeats_single', type=int, default=5, help='单一生态重复评测次数求 mean/std')

    # Key hyper-parameters of the training ecology
    p.add_argument('--lr_net', type=float, default=0.01)
    p.add_argument('--batch_train', type=int, default=256)
    p.add_argument('--dsa_strategy', type=str, default='None', help='differentiable Siamese augmentation strategy')
    p.add_argument('--method', type=str, default='DC', help='DC/DSA')

    # Ecology-coverage switch and grid
    p.add_argument('--coverage', action='store_true', default=False)
    p.add_argument('--cov_arch', type=str, default='ConvNet,ResNet18')
    p.add_argument('--cov_optim', type=str, default='sgd,adam')
    p.add_argument('--cov_aug', type=str, default='none,dc')
    p.add_argument('--cov_subset_mode', type=str, default='prefix', choices=['prefix', 'random'])
    p.add_argument('--cov_subset_trials', type=int, default=8)
    return p


def _split_csv_arg(text):
    """Split a comma-separated CLI value, dropping empty/whitespace-only items (e.g. trailing commas)."""
    return [item.strip() for item in text.split(',') if item.strip()]


def main():
    """Evaluate synthetic sets over a list of k values, then plot and dump the results.

    Steps:
      1. parse CLI args and, when ``--device`` is given, set ``CUDA_VISIBLE_DEVICES``
         before CUDA is initialised;
      2. for every k: map it to an IPC (k itself with ``--k_is_ipc``, else
         ``max(1, k // num_classes)``), load the matching synthetic ``.pt``, run the
         single DC ecology ``--repeats_single`` times and, with ``--coverage``, the
         architecture x optimizer x augmentation grid; k values without a ``.pt`` are
         skipped;
      3. plot/dump the single-ecology curve and, with ``--coverage``, the coverage law
         (training any real-data baselines not in the hard-coded table first).
    """
    try:
        mp.set_start_method("spawn", force=True)
    except RuntimeError:
        pass

    args = _build_arg_parser().parse_args()
    if args.device is not None:
        # BUGFIX: this was assigned after torch.cuda.init(); CUDA only honours the
        # variable when it initialises, so --device had no effect.
        os.environ['CUDA_VISIBLE_DEVICES'] = args.device

    use_cuda = torch.cuda.is_available()
    if use_cuda:
        torch.cuda.init()
        device = torch.device("cuda:0")
        print("[CUDA] init ok on", torch.cuda.get_device_name(device))
    else:
        print("[WARN] CUDA 不可用，先在 CPU 上调试。")

    args.dsa_param = ParamDiffAug()
    args.dsa = args.method == 'DSA'
    args.dc_aug_param = None

    # Load the real dataset (test loader + metadata).
    log(f"[LOAD] dataset={args.dataset}")
    channel, im_size, num_classes, class_names, mean, std, dst_train, dst_test, testloader = \
        get_dataset(args.dataset, data_path='./data')
    args.channel, args.im_size, args.num_classes = channel, im_size, num_classes
    args.device = 'cuda' if use_cuda else 'cpu'

    k_list = [int(x) for x in _split_csv_arg(args.k_list)]
    dataset_tag = args.dataset.upper()
    repeats = max(1, args.repeats_single)   # at least one run so mean/std are defined
    cov_arch = _split_csv_arg(args.cov_arch)
    cov_optim = _split_csv_arg(args.cov_optim)
    cov_aug = _split_csv_arg(args.cov_aug)

    used_k = []           # k values actually evaluated, aligned with the curves below
    single_curve = []     # [(acc_mean, acc_std)] per evaluated k
    cov_curves = {}       # legend -> [acc per evaluated k]
    single_rows = []      # per-run rows across ALL k for single_ecology_runs.csv

    for k in k_list:
        ipc = k if args.k_is_ipc else max(1, k // num_classes)
        log(f"[K] 处理 k={k} (解释为 IPC={ipc})")

        pt_path = find_pt_for_ipc(args.synth_root, ipc, dataset_tag=dataset_tag)
        if pt_path is None:
            log(f"[WARN] 未找到与 IPC≈{ipc} 对应的 .pt，跳过该点")
            continue
        log(f"[LOAD] 载入合成数据: {pt_path}")
        images_syn, labels_syn = load_synth_from_pt(pt_path, exp_idx=-1)

        # === Single training ecology ===
        args.ipc = ipc
        acc_runs = [evaluate_single_ecology(images_syn, labels_syn, testloader, args)
                    for _ in range(repeats)]
        acc_mean = float(np.mean(acc_runs))
        acc_std = float(np.std(acc_runs, ddof=1)) if len(acc_runs) > 1 else 0.0
        single_curve.append((acc_mean, acc_std))
        # BUGFIX: record the k that was actually evaluated; slicing k_list later
        # misaligned k and accuracy whenever an earlier k was skipped.
        used_k.append(k)

        # BUGFIX: the CSV used to be rewritten with only the current k's runs.
        single_rows.extend((k, ipc, i, v) for i, v in enumerate(acc_runs))
        save_csv(
            os.path.join(args.out_dir, "single_ecology_runs.csv"),
            header=["k", "ipc", "run_idx", "acc"],
            rows=single_rows
        )

        # === Ecology coverage ===
        if args.coverage:
            cov_res = evaluate_coverage(images_syn, labels_syn, testloader, args,
                                        cov_arch, cov_optim, cov_aug)
            for (arch, optim, aug), acc in cov_res.items():
                cov_curves.setdefault(f"{arch}|{optim}|{aug}", []).append(acc)

    # ---- Plots ----
    os.makedirs(args.out_dir, exist_ok=True)
    if single_curve:
        plot_single_inv_sqrtk_with_fit_err(
            used_k,
            [m for (m, _) in single_curve],
            [s for (_, s) in single_curve],
            title=f"{args.dataset}: Single Training Ecosystem (DC) — Acc vs 1/√k",
            out_png=os.path.join(args.out_dir, f"{args.dataset}_single_inv_sqrtk.png"),
            ylabel="Test Acc",
            scatter_label="mean ± std",
            line_label="Linear fit (y = a·x + b)",
        )

    if args.coverage and cov_curves:
        L = min(len(v) for v in cov_curves.values())
        cov_k = used_k[:L]
        trimmed = {name: vals[:L] for name, vals in cov_curves.items()}

        # === 1) Real-data baseline per ecology (hard-coded table, train the rest) ===
        real_baselines = dict(_KNOWN_REAL_BASELINES.get(args.dataset, {}))
        baseline_path = os.path.join(args.out_dir, args.dataset + "_real_baselines.json")
        for name in sorted(trimmed):
            if name not in real_baselines:
                arch, optim, aug = name.split('|')
                real_baselines[name] = train_real_baseline_for_ecology(arch, optim, aug, args, testloader)
        with open(baseline_path, "w", encoding="utf-8") as f:
            json.dump(real_baselines, f, indent=2)

        # === 2) (X, Y) points: X = sqrt(log m)/sqrt(k), Y = mean Δrisk ===
        X, Y = build_coverage_points(trimmed, real_baselines, cov_k,
                                     subset_mode=args.cov_subset_mode,
                                     subset_trials=args.cov_subset_trials, rng_seed=0)

        # === 3) Coverage-law figure: Δrisk vs √H/√k (scatter + linear fit) ===
        plot_coverage_law_with_dump(
            X, Y,
            title=f"{args.dataset}: Coverage Law Δ ≈ C·√H/√k",
            out_png=os.path.join(args.out_dir, f"{args.dataset}_coverage_law.png"),
            ylabel="Δrisk (acc_real - acc_syn)"
        )

        cov_rows = []
        for name, vals in trimmed.items():
            arch, optim, aug = name.split('|')
            acc_real = real_baselines[name]
            for ki, acc_syn in zip(cov_k, vals):
                cov_rows.append((name, arch, optim, aug, ki, acc_syn, acc_real, acc_real - acc_syn))
        save_csv(
            os.path.join(args.out_dir, "coverage_ecology_detail.csv"),
            header=["legend", "arch", "optim", "aug", "k", "acc_syn", "acc_real", "delta_risk"],
            rows=cov_rows
        )
    log("全部完成。")

# Script entry point: run main() only when this file is executed directly, not when imported.
if __name__ == "__main__":
    main()
