#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
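Empirical checker for the geometry of Theorem 4.1, with two choices of Phi:
  - DC: Phi = flattened cross-entropy gradient vector (per subgroup, per class)
  - DM: Phi = subgroup mean embedding mu_g = mean(embed(x)) (per subgroup, per class)

For each class it compares MaxRes_COBRA against MaxRes_Vanilla, where
MaxRes = max over subgroups a of the squared L2 distance ||Phi_a - m||^2 to the
respective target m. It also prints the theorem condition <Δ^C_{a+|y}, s_y>
(a value <= 0 means the theorem's condition holds).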
"""

import argparse
import copy
import os
import random
import time
from dataclasses import dataclass
from typing import Any, Dict, List, Tuple, Optional

import numpy as np
import torch
import torch.nn as nn

# Your project utils (must exist)
from utils import (
    get_dataset,
    get_network,
    TensorDataset,
    epoch,
    ParamDiffAug,
    DiffAugment,
)

# ----------------------------
# Reproducibility helpers
# ----------------------------
def seed_everything(seed: int = 42) -> None:
    """Seed every RNG the script touches and force deterministic cuDNN.

    Seeds Python's ``random``, NumPy's global RNG, the torch CPU generator,
    the current CUDA device and all CUDA devices (in that order).

    Args:
        seed: Value passed to every seeding function.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # Trade cuDNN autotuning speed for run-to-run determinism.
    cudnn_backend = torch.backends.cudnn
    cudnn_backend.deterministic = True
    cudnn_backend.benchmark = False

def save_random_state() -> Dict[str, Any]:
    """Snapshot the torch, NumPy and Python RNG states (plus CUDA when available).

    Returns:
        Dict with keys ``"torch"``, ``"np"``, ``"random"`` and, only if CUDA is
        available, ``"cuda"`` (the list from ``torch.cuda.get_rng_state_all()``).
    """
    snapshot: Dict[str, Any] = dict(
        torch=torch.get_rng_state(),
        np=np.random.get_state(),
        random=random.getstate(),
    )
    if not torch.cuda.is_available():
        return snapshot
    snapshot["cuda"] = torch.cuda.get_rng_state_all()
    return snapshot

def load_random_state(state: Dict[str, Any]) -> None:
    """Restore RNG states previously captured by ``save_random_state``.

    Args:
        state: Dict with ``"torch"``, ``"np"`` and ``"random"`` entries and an
            optional ``"cuda"`` entry, which is restored only when CUDA is available.
    """
    restore_plan = (
        (torch.set_rng_state, "torch"),
        (np.random.set_state, "np"),
        (random.setstate, "random"),
    )
    for restore, key in restore_plan:
        restore(state[key])

    if torch.cuda.is_available() and "cuda" in state:
        torch.cuda.set_rng_state_all(state["cuda"])

# ----------------------------
# IO helpers
# ----------------------------
def safe_torch_load(path: str, device: str):
    """Load a checkpoint with ``torch.load``, disabling the weights-only unpickler when supported.

    SECURITY: ``weights_only=False`` uses the full pickle machinery and can run
    arbitrary code. Only load checkpoints from trusted sources.

    Support for the ``weights_only`` keyword is detected from ``torch.load``'s
    signature rather than by catching ``TypeError``. A ``TypeError`` raised
    *during* unpickling would otherwise trigger a retry without the keyword. On
    torch versions where ``weights_only`` defaults to True, that retry fails with
    an unrelated unpickling error and hides the real cause.

    Args:
        path: Checkpoint file path.
        device: Passed as ``map_location``.

    Returns:
        Whatever object ``torch.load`` deserializes.
    """
    import inspect

    try:
        load_params = inspect.signature(torch.load).parameters
    except (TypeError, ValueError):
        load_params = None

    if load_params is None:
        # Signature not introspectable: keep the original try-the-kwarg-first behaviour.
        try:
            return torch.load(path, map_location=device, weights_only=False)
        except TypeError:
            return torch.load(path, map_location=device)

    if "weights_only" in load_params:
        return torch.load(path, map_location=device, weights_only=False)
    # Older torch: no weights_only keyword.
    return torch.load(path, map_location=device)

def extract_syn_data(checkpoint: dict):
    """Pull ``(image_syn, label_syn)`` out of a distillation checkpoint.

    Accepted layouts of ``checkpoint['data']``:
      * ``(image_syn, label_syn)``
      * ``[(image_syn, label_syn), ...]`` (the first pair is used)

    Raises:
        KeyError: if ``'data'`` is missing (or stored as None).
    """
    payload = checkpoint.get("data", None)
    if payload is None:
        raise KeyError("Checkpoint has no key 'data'.")

    is_list_of_pairs = (
        isinstance(payload, (list, tuple))
        and bool(payload)
        and isinstance(payload[0], (list, tuple))
    )
    image_syn, label_syn = payload[0] if is_list_of_pairs else payload
    return image_syn, label_syn

# ----------------------------
# Training helpers
# ----------------------------
def make_loader_from_syn(image_syn: torch.Tensor, label_syn: torch.Tensor, batch_size: int = 256):
    """Build a shuffling, single-process DataLoader over synthetic images/labels.

    Args:
        image_syn: Synthetic images.
        label_syn: Matching labels.
        batch_size: Mini-batch size.
    """
    return torch.utils.data.DataLoader(
        TensorDataset(image_syn, label_syn),
        batch_size=batch_size,
        shuffle=True,
        num_workers=0,
    )

def make_loader_from_real(images: torch.Tensor, labels: torch.Tensor, batch_size: int = 256):
    """Build a shuffling, single-process DataLoader over real images/labels.

    Args:
        images: Real images.
        labels: Matching labels.
        batch_size: Mini-batch size.
    """
    return torch.utils.data.DataLoader(
        TensorDataset(images, labels),
        batch_size=batch_size,
        shuffle=True,
        num_workers=0,
    )

def train_on_loader(
    net: nn.Module,
    loader: torch.utils.data.DataLoader,
    args,
    epochs: int,
    lr: float = 0.01,
    weight_decay: float = 0.0,
) -> None:
    """Train ``net`` in place with plain SGD + cross-entropy for ``epochs`` passes.

    Each pass is delegated to the project's ``epoch("train", ...)`` helper with
    ``aug=True``.

    Args:
        net: Model to train (mutated in place).
        loader: Training batches.
        args: Run config; ``args.device`` is read here and the whole object is
            forwarded to ``epoch``.
        epochs: Number of passes over ``loader``.
        lr: SGD learning rate.
        weight_decay: SGD weight decay.
    """
    loss_fn = nn.CrossEntropyLoss().to(args.device)
    sgd = torch.optim.SGD(net.parameters(), lr=lr, weight_decay=weight_decay)
    net.train()
    for _pass_idx in range(epochs):
        epoch("train", loader, net, sgd, loss_fn, args, aug=True)

# ----------------------------
# Theorem 4.1 verification
# ----------------------------
@dataclass()
class ClassCheckResult:
    """Per-class outcome of the Theorem 4.1 check.

    Residuals are squared L2 distances from each subgroup's Phi to a target.
    The condition value is <Phi_{a+} - m*, m_van - m*>; <= 0 means the
    theorem's condition holds.
    """

    # Class label y.
    cls: int
    # max_a ||phi_a - m_van||^2 (pi-weighted vanilla target).
    max_res_van: float
    # max_a ||phi_a - m*||^2 (uniform COBRA barycenter).
    max_res_cobra: float
    # <Delta^C_{a+|y}, s_y>.
    condition_dot: float
    # Subgroup id with the largest COBRA residual.
    a_plus: int
    # Number of distinct subgroup ids present in the class.
    n_groups: int

def _get_embed_fn(net: nn.Module):
    # Handle DataParallel
    if hasattr(net, "module"):
        if not hasattr(net.module, "embed"):
            raise AttributeError("net.module has no attribute 'embed' (needed for DM phi).")
        return net.module.embed
    else:
        if not hasattr(net, "embed"):
            raise AttributeError("net has no attribute 'embed' (needed for DM phi).")
        return net.embed

def _as_tensor_2d(x: torch.Tensor) -> torch.Tensor:
    # Ensure [B, D]
    if isinstance(x, (tuple, list)):
        x = x[0]
    return x.view(x.size(0), -1)

@torch.enable_grad()
def compute_subgroup_phis_grad(
    net: nn.Module,
    imgs_c: torch.Tensor,
    colors_c: torch.Tensor,
    cls: int,
    criterion: nn.Module,
    args,
    max_samples_per_group: int = 256,
) -> Tuple[Dict[int, torch.Tensor], Dict[int, float]]:
    """DC-style Phi: flattened parameter gradients of the CE loss per subgroup.

    For every subgroup id ``g`` present in ``colors_c``:
      * record ``pi_g = count_g / len(imgs_c)``;
      * draw a random subset (``torch.randperm``) of at most
        ``max_samples_per_group`` of its samples, all labelled ``cls``;
      * compute ``criterion(net(batch), labels)`` and flatten its gradient
        w.r.t. every trainable parameter into one 1-D vector.

    Parameters that receive no gradient are zero-filled instead of dropped.
    This keeps every subgroup's vector the same length with aligned
    coordinates, so the later vector arithmetic is well-defined. Frozen
    parameters (``requires_grad=False``) are skipped, because
    ``torch.autograd.grad`` cannot differentiate with respect to them.

    Args:
        net: Model whose parameter gradients define Phi (put in eval mode).
        imgs_c: Samples of class ``cls``, on ``args.device``.
        colors_c: Subgroup id per sample, aligned with ``imgs_c``.
        cls: Class label used as the CE target.
        criterion: Loss module.
        args: Run config; ``args.device`` is read.
        max_samples_per_group: Cap on samples used per subgroup.

    Returns:
        ``(subgroup_phis, subgroup_pi)`` keyed by subgroup id.
    """
    net.eval()
    # Only trainable parameters can be differentiated; fixed once for all groups.
    params = [p for p in net.parameters() if p.requires_grad]

    subgroup_phis: Dict[int, torch.Tensor] = {}
    subgroup_pi: Dict[int, float] = {}

    total_c = int(imgs_c.shape[0])

    for grp in torch.unique(colors_c):
        g = int(grp.item())
        mask = (colors_c == grp)
        count = int(mask.sum().item())
        if count == 0:
            continue

        subgroup_pi[g] = count / total_c

        sample_size = min(count, max_samples_per_group)
        perm = torch.randperm(count, device=args.device)[:sample_size]
        img_batch = imgs_c[mask][perm]
        lab_batch = torch.full((sample_size,), cls, device=args.device, dtype=torch.long)

        net.zero_grad(set_to_none=True)
        loss = criterion(net(img_batch), lab_batch)

        grads = torch.autograd.grad(
            loss, params,
            retain_graph=False, create_graph=False, allow_unused=True
        )
        # Zero-fill unused parameters so coordinates line up across subgroups.
        subgroup_phis[g] = torch.cat([
            (torch.zeros_like(p) if grad is None else grad.detach()).reshape(-1)
            for p, grad in zip(params, grads)
        ])

    return subgroup_phis, subgroup_pi

@torch.no_grad()
def compute_subgroup_phis_embed_mean(
    net: nn.Module,
    imgs_c: torch.Tensor,
    colors_c: torch.Tensor,
    args,
    max_samples_per_group: int = 256,
) -> Tuple[Dict[int, torch.Tensor], Dict[int, float]]:
    """DM-style Phi: per-subgroup mean embedding mu_g = mean(embed(x)).

    For every subgroup id ``g`` present in ``colors_c``:
      * record ``pi_g = count_g / len(imgs_c)``;
      * draw a random subset (``torch.randperm``) of at most
        ``max_samples_per_group`` of its samples;
      * if ``args.dsa`` is truthy, apply ``DiffAugment`` to the batch;
      * embed with ``net.embed`` (or ``net.module.embed``), flatten to
        ``[B, D]`` and average over the batch.

    The DiffAugment seed is drawn from the (seeded) torch RNG rather than the
    wall clock. A wall-clock seed made Phi, and so the reported verdicts,
    differ between runs that used the same ``--seed``.

    Args:
        net: Model providing ``embed`` (put in eval mode).
        imgs_c: Samples of one class, on ``args.device``.
        colors_c: Subgroup id per sample, aligned with ``imgs_c``.
        args: Run config; ``device``, ``dsa`` and (when ``dsa``) ``dsa_strategy``
            and ``dsa_param`` are read.
        max_samples_per_group: Cap on samples used per subgroup.

    Returns:
        ``(subgroup_phis, subgroup_pi)`` keyed by subgroup id; each phi is 1-D.
    """
    net.eval()
    embed = _get_embed_fn(net)

    subgroup_phis: Dict[int, torch.Tensor] = {}
    subgroup_pi: Dict[int, float] = {}

    total_c = int(imgs_c.shape[0])

    for grp in torch.unique(colors_c):
        g = int(grp.item())
        mask = (colors_c == grp)
        count = int(mask.sum().item())
        if count == 0:
            continue

        subgroup_pi[g] = count / total_c

        sample_size = min(count, max_samples_per_group)
        perm = torch.randperm(count, device=args.device)[:sample_size]
        img_batch = imgs_c[mask][perm]

        # Optional: mimic DM-style DSA usage.
        if getattr(args, "dsa", False):
            # Reproducible seed from the seeded torch CPU generator (was time-based).
            seed = int(torch.randint(0, 100000, (1,)).item())
            img_batch = DiffAugment(img_batch, args.dsa_strategy, seed=seed, param=args.dsa_param)

        feat = _as_tensor_2d(embed(img_batch))  # [B, D]
        subgroup_phis[g] = feat.mean(dim=0).detach()  # [D]

    return subgroup_phis, subgroup_pi

def _compute_class_phis(net_phi, imgs_c, colors_c, cls, criterion, args, max_samples_per_group):
    """Dispatch to the DM (mean embedding) or DC (CE gradient, default) Phi estimator."""
    if args.objective.upper() == "DM":
        return compute_subgroup_phis_embed_mean(
            net=net_phi,
            imgs_c=imgs_c,
            colors_c=colors_c,
            args=args,
            max_samples_per_group=max_samples_per_group,
        )
    return compute_subgroup_phis_grad(
        net=net_phi,
        imgs_c=imgs_c,
        colors_c=colors_c,
        cls=cls,
        criterion=criterion,
        args=args,
        max_samples_per_group=max_samples_per_group,
    )


def _squared_l2_class_stats(
    subgroup_phis: Dict[int, torch.Tensor],
    subgroup_pi: Dict[int, float],
) -> Tuple[float, float, float, int]:
    """Compute (max_res_van, max_res_cobra, condition_dot, a_plus) for one class.

    m_van = sum_a pi_a * phi_a, m* = mean_a phi_a, residuals are squared L2,
    a+ is the (first) subgroup with the largest COBRA residual and
    condition_dot = <phi_{a+} - m*, m_van - m*>.

    Raises:
        ValueError: if no a+ can be selected, which happens when the residuals
            are NaN (non-finite Phi, e.g. from a diverged model).
    """
    m_van = None
    for g, phi in subgroup_phis.items():
        w = float(subgroup_pi[g])
        m_van = phi * w if m_van is None else (m_van + phi * w)

    m_cobra = torch.stack(list(subgroup_phis.values()), dim=0).mean(dim=0)

    max_res_van = -1.0
    max_res_cobra = -1.0
    a_plus: Optional[int] = None
    for g, phi in subgroup_phis.items():
        d_van = torch.norm(phi - m_van).pow(2).item()
        d_cob = torch.norm(phi - m_cobra).pow(2).item()
        if d_van > max_res_van:
            max_res_van = d_van
        if d_cob > max_res_cobra:
            max_res_cobra = d_cob
            a_plus = g

    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if a_plus is None:
        raise ValueError(
            "Could not select a+ (worst COBRA subgroup): residuals are not finite; "
            "subgroup Phi likely contains NaN/Inf."
        )

    delta_c = subgroup_phis[a_plus] - m_cobra
    s_y = m_van - m_cobra
    return max_res_van, max_res_cobra, torch.dot(delta_c, s_y).item(), a_plus


def _residual_verdict(max_res_cobra: float, max_res_van: float, eps: float = 1e-12) -> Tuple[str, str]:
    """Return (results key, printable verdict) comparing worst-case residuals within ``eps``."""
    if max_res_cobra < max_res_van - eps:
        return "better", "COBRA WINS"
    if max_res_cobra > max_res_van + eps:
        return "worse", "VANILLA WINS"
    return "same", "SAME"


def check_theorem_4_1(
    net_phi: nn.Module,
    images_all: torch.Tensor,
    labels_all: torch.Tensor,
    color_all: torch.Tensor,
    args,
    max_samples_per_group: int = 256,
    print_per_class: bool = True,
) -> Tuple[Dict[str, int], List["ClassCheckResult"]]:
    """
    Empirical check under squared L2 distance:
      - Vanilla target: m_van = sum_a pi_{a|y} phi_{a|y}
      - COBRA target:   m*    = mean_a phi_{a|y}  (uniform barycenter for squared L2)
      - Compare max_a ||phi_{a|y} - m||^2
      - Condition: <Δ^C_{a+|y}, s_y> <= 0 where:
            a+ = argmax_a ||phi_a - m*||^2,
            Δ^C_{a|y} = phi_a - m*,
            s_y = m_van - m*

    Classes 0..args.num_classes-1 with fewer than two subgroups are skipped.
    Phi is estimated with the DM estimator when ``args.objective`` is "DM"
    (case-insensitive) and with the DC gradient estimator otherwise.

    Returns:
        ``(results, logs)``. ``results`` counts "better"/"worse"/"same"
        (COBRA vs vanilla worst-case residual) and "cond_holds"/"cond_violates".
        ``logs`` holds one ``ClassCheckResult`` per checked class.

    Raises:
        ValueError: if a class's Phi statistics are non-finite.
    """
    print(f"\n[Analysis] Verifying Theorem 4.1 (squared L2) | objective={args.objective}")
    net_phi.eval()
    criterion = nn.CrossEntropyLoss().to(args.device)

    results = {"better": 0, "worse": 0, "same": 0, "cond_holds": 0, "cond_violates": 0}
    logs: List[ClassCheckResult] = []

    for c in range(int(args.num_classes)):
        idx = (labels_all == c)
        imgs_c = images_all[idx]
        colors_c = color_all[idx]
        n_groups = int(torch.unique(colors_c).numel())
        if n_groups < 2:
            continue

        subgroup_phis, subgroup_pi = _compute_class_phis(
            net_phi, imgs_c, colors_c, c, criterion, args, max_samples_per_group
        )
        if len(subgroup_phis) < 2:
            continue

        max_res_van, max_res_cobra, condition_dot, a_plus = _squared_l2_class_stats(
            subgroup_phis, subgroup_pi
        )

        if condition_dot <= 0:
            results["cond_holds"] += 1
            cond_str = "HOLDS"
        else:
            results["cond_violates"] += 1
            cond_str = "VIOLATES"

        outcome, verdict = _residual_verdict(max_res_cobra, max_res_van)
        results[outcome] += 1

        logs.append(
            ClassCheckResult(
                cls=c,
                max_res_van=max_res_van,
                max_res_cobra=max_res_cobra,
                condition_dot=condition_dot,
                a_plus=int(a_plus),
                n_groups=n_groups,
            )
        )

        if print_per_class:
            print(
                f"  Class {c:02d} | groups={n_groups} | "
                f"MaxRes_Van={max_res_van:.6f} | MaxRes_COBRA={max_res_cobra:.6f} | "
                f"<Δc,s_y>={condition_dot:.6e} ({cond_str}) -> {verdict}"
            )

    print("\n--- Summary ---")
    if len(logs) == 0:
        print("No multi-group classes found.")
        return results, logs

    print(f"Classes checked: {len(logs)}")
    print(f"COBRA better:   {results['better']}")
    print(f"COBRA worse:    {results['worse']}")
    print(f"Same:          {results['same']}")
    print(f"Condition holds (Thm applies):        {results['cond_holds']}")
    print(f"Condition violates (Thm doesn't apply): {results['cond_violates']}")
    return results, logs

# ----------------------------
# Main
# ----------------------------
def _build_arg_parser() -> argparse.ArgumentParser:
    """Build the CLI parser: same flags/defaults as before, plus ``--no_dsa``."""
    parser = argparse.ArgumentParser()

    parser.add_argument("--data_path", type=str, default="data")
    parser.add_argument("--model", type=str, default="ConvNet")
    parser.add_argument("--seed", type=int, default=42)

    # Which objectives/checkpoints to run
    parser.add_argument("--methods", nargs="+", default=["DM"], help="Which objectives to evaluate: DC DM")
    parser.add_argument("--datasets", nargs="+", default=["BFFHQ"])
    parser.add_argument("--ipcs", nargs="+", type=int, default=[50])
    parser.add_argument("--fair_crts", nargs="+", default=["-fairdd", "-COBRA"])

    # Checkpoint path pattern
    parser.add_argument("--ckpt_dir", type=str, default="./result")
    parser.add_argument(
        "--ckpt_pattern",
        type=str,
        default="res_{method}_{dataset}_{model}_{ipc}ipc{fair}.pt",
        help="Filename pattern inside ckpt_dir. {fair} should include the leading '-' (e.g., -nofair).",
    )

    # Synthetic training settings
    parser.add_argument("--syn_train_epochs", type=int, default=1000)
    parser.add_argument("--train_lr", type=float, default=0.01)

    # Optional: train a teacher ERM on real data and use it for Phi
    parser.add_argument("--phi_model", type=str, default="student", choices=["student", "teacher"])
    parser.add_argument("--teacher_train_epochs", type=int, default=0)

    # Phi estimation settings
    parser.add_argument("--max_samples_per_group", type=int, default=256)

    # DiffAug settings (also used by the DM Phi estimator when args.dsa is True).
    parser.add_argument("--dsa_strategy", type=str, default="color_crop_cutout_flip_scale_rotate")
    # `--dsa` used to be store_true with default=True, so DSA could never be disabled.
    # `--dsa` stays accepted for backward compatibility; `--no_dsa` turns it off.
    parser.add_argument("--dsa", dest="dsa", action="store_true", help="Enable DiffAugment (default).")
    parser.add_argument("--no_dsa", dest="dsa", action="store_false", help="Disable DiffAugment.")
    parser.set_defaults(dsa=True)
    return parser


def _load_real_tensors(dst_train, device):
    """Stack ``(img, label, group)`` items of ``dst_train`` into three device tensors.

    Each item is read once (previously ``dst_train[i]`` was indexed three
    times, repeating any per-item loading/transform work).

    Returns:
        ``(images_all, labels_all, color_all)``; labels and groups are ``torch.long``.
    """
    samples = [dst_train[i] for i in range(len(dst_train))]
    images_all = torch.cat([torch.unsqueeze(s[0], dim=0) for s in samples], dim=0).to(device)
    labels_all = torch.tensor([s[1] for s in samples], dtype=torch.long, device=device)
    color_all = torch.tensor([s[2] for s in samples], dtype=torch.long, device=device)
    return images_all, labels_all, color_all


def main():
    """Train students on each synthetic checkpoint and run the Theorem 4.1 check.

    For each dataset x method x ipc x fair criterion, this:
      1. loads the checkpoint;
      2. trains a fresh student on the synthetic set;
      3. runs ``check_theorem_4_1`` on the real training data, using either the
         student or an optional real-data teacher to compute Phi.
    """
    args = _build_arg_parser().parse_args()
    args.device = "cuda" if torch.cuda.is_available() else "cpu"
    args.dsa_param = ParamDiffAug()

    # Validate configuration up front. These errors used to surface only after
    # loading the dataset and (for phi_model) training a student for many epochs.
    methods = [m.upper() for m in args.methods]
    for method in methods:
        if method not in {"DC", "DM"}:
            raise ValueError(f"Unknown method '{method}'. Use DC or DM.")
    # A teacher is trained only when teacher_train_epochs > 0.
    if args.phi_model == "teacher" and not args.teacher_train_epochs > 0:
        raise ValueError("phi_model=teacher but teacher_train_epochs==0 (no teacher trained).")

    seed_everything(args.seed)
    base_state = save_random_state()

    for dataset_name in args.datasets:
        # reset RNG per dataset (optional, keeps comparisons fair)
        load_random_state(base_state)

        args.dataset = dataset_name
        channel, im_size, num_classes, class_names, mean, std, dst_train, dst_test, testloader = get_dataset(
            args.dataset, args.data_path
        )

        # Organize REAL data (expects dst_train[i] = (img, label, group))
        images_all, labels_all, color_all = _load_real_tensors(dst_train, args.device)

        args.num_classes = int(torch.unique(labels_all).numel())
        args.num_groups = int(torch.unique(color_all).numel())

        print("\n==============================")
        print(f"Dataset: {args.dataset} | classes={args.num_classes} | groups={args.num_groups} | im_size={im_size}")
        print("==============================")

        # Optional teacher (trained once per dataset)
        teacher_net = None
        if args.teacher_train_epochs > 0:
            print(f"\n[Teacher] Training ERM teacher on REAL data for {args.teacher_train_epochs} epochs...")
            teacher_net = get_network(args.model, channel, args.num_classes, im_size).to(args.device)
            real_loader = make_loader_from_real(images_all, labels_all, batch_size=256)
            train_on_loader(teacher_net, real_loader, args, epochs=args.teacher_train_epochs, lr=args.train_lr)
            print("[Teacher] Done.\n")

        for method in methods:
            args.objective = method  # controls how Phi is computed inside check_theorem_4_1

            for ipc in args.ipcs:
                args.ipc = int(ipc)

                for fair_crt in args.fair_crts:
                    ckpt_name = args.ckpt_pattern.format(
                        method=method, dataset=args.dataset, model=args.model, ipc=args.ipc, fair=fair_crt
                    )
                    ckpt_path = os.path.join(args.ckpt_dir, ckpt_name)

                    print(f"\n--- Running: method={method} dataset={args.dataset} ipc={args.ipc} fair_crt={fair_crt} ---")
                    print(f"Loading checkpoint: {ckpt_path}")

                    checkpoint = safe_torch_load(ckpt_path, args.device)
                    image_syn, label_syn = extract_syn_data(checkpoint)
                    image_syn = image_syn.to(args.device)
                    label_syn = label_syn.to(args.device)

                    # Handle an im_size mismatch per checkpoint only. Overwriting the
                    # dataset-level `im_size` would leak one checkpoint's size into
                    # every later run.
                    net_im_size = im_size
                    loaded_im_size = (int(image_syn.shape[2]), int(image_syn.shape[3]))
                    if tuple(loaded_im_size) != tuple(im_size):
                        print("\n[WARNING] Image size mismatch!")
                        print(f"  Dataset expects im_size={im_size}, checkpoint has {loaded_im_size}")
                        print("  -> Overriding im_size for network construction.")
                        net_im_size = loaded_im_size

                    # Train student on synthetic
                    student_net = get_network(args.model, channel, args.num_classes, net_im_size).to(args.device)
                    syn_loader = make_loader_from_syn(
                        copy.deepcopy(image_syn.detach()),
                        copy.deepcopy(label_syn.detach()),
                        batch_size=256
                    )

                    print(f"[Student] Training on SYN data for {args.syn_train_epochs} epochs (lr={args.train_lr})...")
                    train_on_loader(student_net, syn_loader, args, epochs=args.syn_train_epochs, lr=args.train_lr)
                    print("[Student] Done.")

                    # Choose Phi model (teacher availability was validated up front).
                    if args.phi_model == "teacher":
                        net_phi = teacher_net
                        print("[Phi] Using TEACHER model for Phi statistics.")
                    else:
                        net_phi = student_net
                        print("[Phi] Using STUDENT model for Phi statistics.")

                    # Run theorem check
                    check_theorem_4_1(
                        net_phi=net_phi,
                        images_all=images_all,
                        labels_all=labels_all,
                        color_all=color_all,
                        args=args,
                        max_samples_per_group=args.max_samples_per_group,
                        print_per_class=True,
                    )

if __name__ == "__main__":
    # Run the experiment only when executed as a script, not when imported.
    main()



# #!/usr/bin/env python3
# # -*- coding: utf-8 -*-

# import argparse
# import copy
# import random
# from dataclasses import dataclass
# from typing import Dict, List, Tuple, Any

# import numpy as np
# import torch
# import torch.nn as nn

# # Your project utils (must exist in your repo)
# from utils import (
#     get_dataset,
#     get_network,
#     TensorDataset,
#     epoch,
#     ParamDiffAug,
# )

# # ----------------------------
# # Reproducibility helpers
# # ----------------------------
# def seed_everything(seed: int = 42) -> None:
#     random.seed(seed)
#     np.random.seed(seed)
#     torch.manual_seed(seed)
#     torch.cuda.manual_seed(seed)
#     torch.cuda.manual_seed_all(seed)
#     torch.backends.cudnn.deterministic = True
#     torch.backends.cudnn.benchmark = False

# def save_random_state() -> Dict[str, Any]:
#     state = {
#         "torch": torch.get_rng_state(),
#         "np": np.random.get_state(),
#         "random": random.getstate(),
#     }
#     if torch.cuda.is_available():
#         state["cuda"] = torch.cuda.get_rng_state_all()
#     return state

# def load_random_state(state: Dict[str, Any]) -> None:
#     torch.set_rng_state(state["torch"])
#     np.random.set_state(state["np"])
#     random.setstate(state["random"])
#     if torch.cuda.is_available() and "cuda" in state:
#         torch.cuda.set_rng_state_all(state["cuda"])

# # ----------------------------
# # Training helpers
# # ----------------------------
# def train_on_loader(
#     net: nn.Module,
#     loader: torch.utils.data.DataLoader,
#     args,
#     epochs: int,
#     lr: float = 0.01,
#     weight_decay: float = 0.0,
# ) -> None:
#     criterion = nn.CrossEntropyLoss().to(args.device)
#     optimizer = torch.optim.SGD(net.parameters(), lr=lr, weight_decay=weight_decay)
#     net.train()
#     for _ in range(epochs):
#         epoch("train", loader, net, optimizer, criterion, args, aug=True)

# def make_loader_from_syn(image_syn: torch.Tensor, label_syn: torch.Tensor, batch_size: int = 256):
#     dst_syn = TensorDataset(image_syn, label_syn)
#     return torch.utils.data.DataLoader(dst_syn, batch_size=batch_size, shuffle=True, num_workers=0)

# def make_loader_from_real(images: torch.Tensor, labels: torch.Tensor, batch_size: int = 256):
#     dst = TensorDataset(images, labels)
#     return torch.utils.data.DataLoader(dst, batch_size=batch_size, shuffle=True, num_workers=0)

# # ----------------------------
# # Theorem 4.1 verification
# # ----------------------------
# @dataclass
# class ClassCheckResult:
#     cls: int
#     max_res_van: float
#     max_res_cobra: float
#     condition_dot: float  # <Δ^C_{a+|y}, s_y> should be <= 0 for Theorem 4.1 to apply
#     a_plus: int
#     n_groups: int

# def _flatten_grads(grads: List[torch.Tensor]) -> torch.Tensor:
#     return torch.cat([g.detach().reshape(-1) for g in grads if g is not None])

# @torch.enable_grad()
# def compute_subgroup_phis(
#     net: nn.Module,
#     imgs_c: torch.Tensor,
#     colors_c: torch.Tensor,
#     cls: int,
#     criterion: nn.Module,
#     device: str,
#     max_samples_per_group: int = 256,
# ) -> Tuple[Dict[int, torch.Tensor], Dict[int, float]]:
#     """
#     Returns:
#       subgroup_phis[g] = phi vector for subgroup g (gradient statistic)
#       subgroup_pi[g]   = pi_{a|y} for subgroup g
#     """
#     net.eval()  # keep BN/Dropout fixed; still compute grads
#     unique_groups = torch.unique(colors_c)
#     subgroup_phis: Dict[int, torch.Tensor] = {}
#     subgroup_pi: Dict[int, float] = {}

#     total_c = int(imgs_c.shape[0])
#     for grp in unique_groups:
#         grp_int = int(grp.item())
#         mask = (colors_c == grp)
#         count = int(mask.sum().item())
#         if count == 0:
#             continue

#         subgroup_pi[grp_int] = count / total_c

#         # sample for gradient estimation
#         sample_size = min(count, max_samples_per_group)
#         perm = torch.randperm(count, device=device)[:sample_size]
#         img_batch = imgs_c[mask][perm]
#         lab_batch = torch.full((sample_size,), cls, device=device, dtype=torch.long)

#         net.zero_grad(set_to_none=True)
#         out = net(img_batch)
#         loss = criterion(out, lab_batch)

#         grads = torch.autograd.grad(loss, list(net.parameters()), retain_graph=False, create_graph=False, allow_unused=True)
#         phi_vec = _flatten_grads(grads)

#         subgroup_phis[grp_int] = phi_vec

#     return subgroup_phis, subgroup_pi

# def check_theorem_4_1(
#     net_phi: nn.Module,
#     images_all: torch.Tensor,
#     labels_all: torch.Tensor,
#     color_all: torch.Tensor,
#     args,
#     max_samples_per_group: int = 256,
#     print_per_class: bool = True,
# ) -> Tuple[Dict[str, int], List[ClassCheckResult]]:
#     """
#     Empirical check of Theorem 4.1 for squared L2 distance case.

#     For each class y:
#       - compute subgroup phis: Φ_{T_{a|y}}
#       - vanilla target m_van = sum_a pi_{a|y} Φ_{a|y}
#       - cobra barycenter (uniform weights) m* = mean_a Φ_{a|y} (for squared L2)
#       - compute max residuals: max_a ||Φ_{a|y} - m||^2 for vanilla and cobra
#       - compute Theorem 4.1 condition: <Δ^C_{a+|y}, s_y> <= 0
#             where a+ is subgroup with worst COBRA residual,
#                   Δ^C_{a|y} = Φ_{a|y} - m*,
#                   s_y = m_van - m*
#     """
#     print("\n[Analysis] Verifying Theorem 4.1 (squared L2 case)")
#     net_phi.eval()
#     criterion = nn.CrossEntropyLoss().to(args.device)

#     results = {"better": 0, "worse": 0, "same": 0, "cond_holds": 0, "cond_violates": 0}
#     logs: List[ClassCheckResult] = []

#     num_classes = int(args.num_classes)
#     for c in range(num_classes):
#         idx = (labels_all == c)
#         imgs_c = images_all[idx]
#         colors_c = color_all[idx]

#         unique_groups = torch.unique(colors_c)
#         if unique_groups.numel() < 2:
#             continue

#         subgroup_phis, subgroup_pi = compute_subgroup_phis(
#             net=net_phi,
#             imgs_c=imgs_c,
#             colors_c=colors_c,
#             cls=c,
#             criterion=criterion,
#             device=args.device,
#             max_samples_per_group=max_samples_per_group,
#         )

#         if len(subgroup_phis) < 2:
#             continue

#         # targets
#         # vanilla: sum pi * phi
#         m_van = None
#         for g, phi in subgroup_phis.items():
#             w = float(subgroup_pi[g])
#             m_van = phi * w if m_van is None else (m_van + phi * w)

#         # cobra barycenter under squared L2 with uniform weights -> mean(phi)
#         m_cobra = torch.stack(list(subgroup_phis.values()), dim=0).mean(dim=0)

#         # max residuals (squared L2)
#         max_res_van = -1.0
#         max_res_cobra = -1.0
#         a_plus = None

#         for g, phi in subgroup_phis.items():
#             d_van = torch.norm(phi - m_van).pow(2).item()
#             d_cob = torch.norm(phi - m_cobra).pow(2).item()

#             if d_van > max_res_van:
#                 max_res_van = d_van
#             if d_cob > max_res_cobra:
#                 max_res_cobra = d_cob
#                 a_plus = g

#         # theorem condition (for Theorem 4.1)
#         # <Δ^C_{a+|y}, s_y> <= 0
#         delta_c = subgroup_phis[a_plus] - m_cobra
#         s_y = m_van - m_cobra
#         condition_dot = torch.dot(delta_c, s_y).item()

#         if condition_dot <= 0:
#             results["cond_holds"] += 1
#         else:
#             results["cond_violates"] += 1

#         # compare
#         if max_res_cobra < max_res_van - 1e-12:
#             results["better"] += 1
#             verdict = "COBRA WINS"
#         elif max_res_cobra > max_res_van + 1e-12:
#             results["worse"] += 1
#             verdict = "VANILLA WINS"
#         else:
#             results["same"] += 1
#             verdict = "SAME"

#         logs.append(
#             ClassCheckResult(
#                 cls=c,
#                 max_res_van=max_res_van,
#                 max_res_cobra=max_res_cobra,
#                 condition_dot=condition_dot,
#                 a_plus=int(a_plus),
#                 n_groups=int(unique_groups.numel()),
#             )
#         )

#         if print_per_class:
#             cond_str = "HOLDS" if condition_dot <= 0 else "VIOLATES"
#             print(
#                 f"  Class {c:02d} | groups={unique_groups.numel()} | "
#                 f"MaxRes_Van={max_res_van:.6f} | MaxRes_COBRA={max_res_cobra:.6f} | "
#                 f"<Δc,s_y>={condition_dot:.6e} ({cond_str}) -> {verdict}"
#             )

#     print("\n--- Summary ---")
#     if len(logs) == 0:
#         print("No multi-group classes found.")
#         return results, logs

#     print(f"Classes checked: {len(logs)}")
#     print(f"COBRA better:   {results['better']}")
#     print(f"COBRA worse:    {results['worse']}")
#     print(f"Same:          {results['same']}")
#     print(f"Condition holds (Thm applies):   {results['cond_holds']}")
#     print(f"Condition violates (Thm doesn't apply): {results['cond_violates']}")
#     return results, logs

# # ----------------------------
# # Main experiment
# # ----------------------------
# def safe_torch_load(path: str, device: str):
#     try:
#         return torch.load(path, map_location=device, weights_only=False)
#     except TypeError:
#         # older torch versions may not support weights_only
#         return torch.load(path, map_location=device)

# def extract_syn_data(checkpoint: dict):
#     """
#     Handles both:
#       checkpoint['data'] = (image_syn, label_syn)
#       checkpoint['data'] = [(image_syn, label_syn), ...]
#     """
#     data = checkpoint.get("data", None)
#     if data is None:
#         raise KeyError("Checkpoint has no key 'data'.")

#     if isinstance(data, (list, tuple)) and len(data) > 0 and isinstance(data[0], (list, tuple)):
#         # e.g. [(image_syn, label_syn), ...]
#         image_syn, label_syn = data[0]
#     else:
#         # e.g. (image_syn, label_syn)
#         image_syn, label_syn = data

#     return image_syn, label_syn

# def main():
#     parser = argparse.ArgumentParser()

#     # core args
#     parser.add_argument("--data_path", type=str, default="data")
#     parser.add_argument("--model", type=str, default="ConvNet")
#     parser.add_argument("--dsa_strategy", type=str, default="color_crop_cutout_flip_scale_rotate")
#     parser.add_argument("--seed", type=int, default=42)

#     # what to run
#     # parser.add_argument("--datasets", nargs="+", default=["BFFHQ", "CIFAR10_S_90"])
#     parser.add_argument("--datasets", nargs="+", default=["CIFAR10_S_90"])
#     # parser.add_argument("--ipcs", nargs="+", type=int, default=[10, 50])
#     parser.add_argument("--ipcs", nargs="+", type=int, default=[10])
#     # parser.add_argument("--fair_crts", nargs="+", default=["-nofair", "-fairdd", "-COBRA"])
#     parser.add_argument("--fair_crts", nargs="+", default=[ "-fairdd", "-COBRA"])

#     # paths
#     parser.add_argument(
#         "--ckpt_prefix",
#         type=str,
#         help="Prefix before dataset name in checkpoint path.",
#     )

#     # training lengths
#     parser.add_argument("--syn_train_epochs", type=int, default=1000)
#     parser.add_argument("--teacher_train_epochs", type=int, default=0,
#                         help="If >0, trains a teacher ERM model on REAL data and uses it for Phi stats.")
#     parser.add_argument("--train_lr", type=float, default=0.01)

#     # theorem check
#     parser.add_argument("--phi_model", type=str, default="student", choices=["student", "teacher"],
#                         help="Which model to use to compute subgroup Phi: student (trained on synthetic) or teacher (trained on real).")
#     parser.add_argument("--max_samples_per_group", type=int, default=256)

#     args = parser.parse_args()
#     args.device = "cuda" if torch.cuda.is_available() else "cpu"
#     args.dsa_param = ParamDiffAug()
#     args.dsa = True

#     seed_everything(args.seed)
#     base_random_state = save_random_state()

#     for dataset_name in args.datasets:
#         # reset RNG for fairness across datasets/conditions if desired
#         load_random_state(base_random_state)

#         args.dataset = dataset_name

#         channel, im_size, num_classes, class_names, mean, std, dst_train, dst_test, testloader = get_dataset(
#             args.dataset, args.data_path
#         )
#         args.num_classes = num_classes

#         # organize real data (expects dst_train[i] = (img, label, group))
#         images_all = [torch.unsqueeze(dst_train[i][0], dim=0) for i in range(len(dst_train))]
#         labels_all = [dst_train[i][1] for i in range(len(dst_train))]
#         color_all = [dst_train[i][2] for i in range(len(dst_train))]

#         images_all = torch.cat(images_all, dim=0).to(args.device)
#         labels_all = torch.tensor(labels_all, dtype=torch.long, device=args.device)
#         color_all = torch.tensor(color_all, dtype=torch.long, device=args.device)

#         args.num_classes = int(torch.unique(labels_all).numel())
#         args.num_groups = int(torch.unique(color_all).numel())

#         print(f"\n==============================")
#         print(f"Dataset: {args.dataset} | classes={args.num_classes} | groups={args.num_groups} | im_size={im_size}")
#         print(f"==============================")

#         # optional: train teacher on real data once per dataset
#         teacher_net = None
#         if args.teacher_train_epochs > 0:
#             print(f"\n[Teacher] Training ERM teacher on REAL data for {args.teacher_train_epochs} epochs...")
#             teacher_net = get_network(args.model, channel, args.num_classes, im_size).to(args.device)
#             real_loader = make_loader_from_real(images_all, labels_all, batch_size=256)
#             train_on_loader(teacher_net, real_loader, args, epochs=args.teacher_train_epochs, lr=args.train_lr)
#             print("[Teacher] Done.\n")

#         for ipc in args.ipcs:
#             args.ipc = int(ipc)

#             for fair_crt in args.fair_crts:
#                 # checkpoint path
#                 save_path = (
#                     f"{args.ckpt_prefix}{args.dataset}_ConvNet_{args.ipc}ipc{fair_crt}.pt"
#                 )

#                 print(f"\n--- Running: dataset={args.dataset} ipc={args.ipc} fair_crt={fair_crt} ---")
#                 print(f"Loading checkpoint: {save_path}")

#                 checkpoint = safe_torch_load(save_path, args.device)
#                 image_syn, label_syn = extract_syn_data(checkpoint)

#                 image_syn = image_syn.to(args.device)
#                 label_syn = label_syn.to(args.device)

#                 # safety check: im_size mismatch
#                 loaded_im_size = (int(image_syn.shape[2]), int(image_syn.shape[3]))
#                 if loaded_im_size != tuple(im_size):
#                     print("\n[WARNING] Dataset/ckpt image size mismatch detected!")
#                     print(f"  Config expects im_size={im_size}, ckpt has {loaded_im_size}")
#                     print("  -> Overriding im_size for network construction to prevent crash.")
#                     im_size = loaded_im_size

#                 # train student on synthetic data
#                 student_net = get_network(args.model, channel, args.num_classes, im_size).to(args.device)
#                 syn_loader = make_loader_from_syn(copy.deepcopy(image_syn.detach()),
#                                                   copy.deepcopy(label_syn.detach()),
#                                                   batch_size=256)

#                 print(f"[Student] Training on SYN data for {args.syn_train_epochs} epochs (lr={args.train_lr})...")
#                 train_on_loader(student_net, syn_loader, args, epochs=args.syn_train_epochs, lr=args.train_lr)
#                 print("[Student] Done.\n")

#                 # choose model for Phi computation
#                 if args.phi_model == "teacher":
#                     if teacher_net is None:
#                         raise ValueError("phi_model=teacher but teacher_train_epochs==0 (no teacher trained).")
#                     net_phi = teacher_net
#                     print("[Phi] Using TEACHER model for subgroup Phi statistics.")
#                 else:
#                     net_phi = student_net
#                     print("[Phi] Using STUDENT model for subgroup Phi statistics.")

#                 # theorem check
#                 _ = check_theorem_4_1(
#                     net_phi=net_phi,
#                     images_all=images_all,
#                     labels_all=labels_all,
#                     color_all=color_all,
#                     args=args,
#                     max_samples_per_group=args.max_samples_per_group,
#                     print_per_class=True,
#                 )

# if __name__ == "__main__":
#     main()



# # import argparse
# # import random
# # import numpy as np
# # import torch
# # import torch.nn as nn
# # import copy
# # from utils import get_dataset, get_network, TensorDataset, epoch
# # from utils import get_loops, get_dataset, get_network, get_eval_pool, evaluate_synset, get_daparam, match_loss, get_time, TensorDataset, epoch, DiffAugment, ParamDiffAug

# # # --- 1. The Verification Function (Theorem 4.1) ---
# # def check_theorem_geometry(net, images_all, labels_all, color_all, criterion, args):
# #     """
# #     Calculates Real Gradients for all subgroups and compares 
# #     the Max Residual of the Vanilla Target vs COBRA Target.
# #     """
# #     print(f"\n[Analysis] Verifying Theorem 4.1 on Converged Model...")
# #     net.eval()
    
# #     # Track success across classes
# #     results = {'better': 0, 'worse': 0, 'same': 0}
# #     residuals_log = []

# #     for c in range(args.num_classes):
# #         # A. Get Real Data for Class C
# #         indices = (labels_all == c)
# #         imgs_c = images_all[indices]
# #         colors_c = color_all[indices]
        
# #         unique_groups = torch.unique(colors_c)
# #         if len(unique_groups) < 2: continue # Skip single-group classes

# #         # B. Compute Subgroup Statistics (Gradients)
# #         subgroup_phis = {}
# #         subgroup_pi = {}
# #         total_c = len(imgs_c)
        
# #         for grp in unique_groups:
# #             mask = (colors_c == grp)
# #             count = mask.sum().item()
# #             subgroup_pi[grp.item()] = count / total_c
            
# #             # Sample batch for gradient estimation
# #             sample_size = min(count, 256) 
# #             perm = torch.randperm(count)[:sample_size]
# #             img_batch = imgs_c[mask][perm]
# #             lab_batch = torch.ones(sample_size, device=args.device, dtype=torch.long) * c
            
# #             # Forward/Backward to get Phi (Gradient Vector)
# #             net.zero_grad()
# #             output = net(img_batch)
# #             loss = criterion(output, lab_batch)
# #             grads = torch.autograd.grad(loss, list(net.parameters()))
            
# #             # Flatten to single vector
# #             phi_vec = torch.cat([g.detach().view(-1) for g in grads])
# #             subgroup_phis[grp.item()] = phi_vec

# #         # C. Construct Targets
# #         # Vanilla (Weighted): sum( pi * phi )
# #         m_van = sum(subgroup_phis[g] * subgroup_pi[g] for g in subgroup_phis)
        
# #         # COBRA (Barycenter): mean( phi )
# #         m_cobra = sum(subgroup_phis[g] for g in subgroup_phis) / len(subgroup_phis)

# #         # D. Calculate Worst-Case Residuals
# #         max_res_van = 0.0
# #         max_res_cobra = 0.0
        
# #         for g in subgroup_phis:
# #             phi = subgroup_phis[g]
# #             # Squared L2 Distance
# #             d_van = torch.norm(phi - m_van).item() ** 2
# #             d_cobra = torch.norm(phi - m_cobra).item() ** 2
            
# #             if d_van > max_res_van: max_res_van = d_van
# #             if d_cobra > max_res_cobra: max_res_cobra = d_cobra

# #         # E. Record Result
# #         if max_res_cobra < max_res_van:
# #             results['better'] += 1
# #         elif max_res_cobra > max_res_van:
# #             results['worse'] += 1
# #         else:
# #             results['same'] += 1
            
# #         residuals_log.append((c, max_res_van, max_res_cobra))
        
# #         print(f"  Class {c}: MaxRes_Van={max_res_van:.4f} | MaxRes_COBRA={max_res_cobra:.4f} "
# #               f"-> {'COBRA WINS' if max_res_cobra < max_res_van else 'VANILLA WINS'}")

# #     print("\n--- Summary ---")
# #     if len(residuals_log) > 0:
# #         print(f"COBRA reduced worst-case residual in {results['better']}/{len(residuals_log)} classes.")
# #     else:
# #         print("No multi-group classes found to verify theorem.")
# #     return results

# # # --- 2. Main Execution ---
# # def main():
# #     parser = argparse.ArgumentParser()
# #     # Default is CIFAR, but you are loading BFFHQ. Change default or pass arg.
# #     parser.add_argument('--dataset', type=str, default='CIFAR10_S_90', help='dataset') 
# #     parser.add_argument('--model', type=str, default='ConvNet', help='model')
# #     parser.add_argument('--ipc', type=int, default=10, help='image(s) per class')
# #     parser.add_argument('--data_path', type=str, default='data', help='path to dataset')
# #     # parser.add_argument('--load_path', type=str, required=True, help='path to .pt file') 
# #     parser.add_argument('--dsa_strategy', type=str, default='color_crop_cutout_flip_scale_rotate', help='differentiable Siamese augmentation strategy')
# #     # Other args...

    
    
# #     for datasetsss in ['BFFHQ', 'CIFAR10_S_90']:
# #     # for datasetsss in ['CIFAR10_S_90']:



# #         args = parser.parse_args()
# #         args.device = 'cuda' if torch.cuda.is_available() else 'cpu'
# #         args.dsa_param = ParamDiffAug()
# #         args.dsa = True
# #         args.dataset = datasetsss
# #         # eval_it_pool = [args.Iteration]
# #         channel, im_size, num_classes, class_names, mean, std, dst_train, dst_test, testloader = get_dataset(args.dataset, args.data_path)
# #         args.num_classes = num_classes

# #         load_random_state(random_state)
# #         # Organize Real Data
# #         images_all = [torch.unsqueeze(dst_train[i][0], dim=0) for i in range(len(dst_train))]
# #         labels_all = [dst_train[i][1] for i in range(len(dst_train))]
# #         color_all = [dst_train[i][2] for i in range(len(dst_train))] 
        
# #         images_all = torch.cat(images_all, dim=0).to(args.device)
# #         labels_all = torch.tensor(labels_all, dtype=torch.long, device=args.device)
# #         color_all = torch.tensor(color_all, dtype=torch.long, device=args.device)
# #         args.num_classes = len(torch.unique(labels_all))
# #         args.num_groups = len(torch.unique(color_all))


# #         for ipc in [10,50]:
# #             args.ipc = ipc
# #             for fair_crt in ['-nofair','-fairdd','-COBRA']:

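# #                 # NOTE(review): save_path is never defined in this legacy main(); the newer main() builds it as f"{args.ckpt_prefix}{args.dataset}_ConvNet_{args.ipc}ipc{fair_crt}.pt".
# #                 checkpoint = torch.load(save_path, map_location=args.device, weights_only=False)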





                
                


        
# #                 try:
# #                     image_syn, label_syn = checkpoint['data'][0] 
# #                 except:
# #                     image_syn, label_syn = checkpoint['data']    

# #                 image_syn = image_syn.to(args.device)
# #                 label_syn = label_syn.to(args.device)
                
# #                 # --- 2.1 SAFETY CHECK: Match Network to Loaded Data ---
# #                 loaded_im_size = (image_syn.shape[2], image_syn.shape[3])
# #                 if loaded_im_size != im_size:
# #                     print(f"\n[WARNING] Dataset mismatch detected!")
# #                     print(f"  Config '{args.dataset}' expects size {im_size}")
# #                     print(f"  Loaded .pt file contains size {loaded_im_size}")
# #                     print(f"  -> Overriding im_size to {loaded_im_size} to prevent crash.")
# #                     print(f"  -> PLEASE CHECK YOUR --dataset ARGUMENT! Comparison might be invalid if real data is different.\n")
# #                     im_size = loaded_im_size

# #                 # 3. Reconstruct "Converged" Model
# #                 net = get_network(args.model, channel, num_classes, im_size).to(args.device)
# #                 criterion = nn.CrossEntropyLoss().to(args.device)
# #                 net.train()
# #                 optimizer_net = torch.optim.SGD(net.parameters(), lr=0.01)  
# #                 optimizer_net.zero_grad()
                
# #                 image_syn_train = copy.deepcopy(image_syn.detach())
# #                 label_syn_train = copy.deepcopy(label_syn.detach())
# #                 dst_syn_train = TensorDataset(image_syn_train, label_syn_train)
# #                 trainloader = torch.utils.data.DataLoader(dst_syn_train, batch_size=256, shuffle=True, num_workers=0)
                
# #                 for il in range(1000):
# #                     epoch('train', trainloader, net, optimizer_net, criterion, args, aug=True)

# #                 print(('dataset:', args.dataset, 'ipc:', args.ipc, 'fair_crt:', fair_crt))
# #                 check_theorem_geometry(net, images_all, labels_all, color_all, criterion, args)


# # if __name__ == '__main__':
# #     def save_random_state():
# #         return {
# #             'torch': torch.get_rng_state(),
# #             'np': np.random.get_state(),
# #             'random': random.getstate(),
# #             'cuda': torch.cuda.get_rng_state_all()
# #         }
# #     def load_random_state(state):
# #         torch.set_rng_state(state['torch'])
# #         np.random.set_state(state['np'])
# #         random.setstate(state['random'])
# #         torch.cuda.set_rng_state_all(state['cuda'])

# #     seed=42
# #     random.seed(seed)
# #     np.random.seed(seed)
# #     torch.manual_seed(seed)
# #     torch.cuda.manual_seed(seed)
# #     torch.cuda.manual_seed_all(seed)
# #     torch.backends.cudnn.deterministic = True
# #     torch.backends.cudnn.benchmark = False

# #     # Save the current RNG state; main() restores it at the start of each dataset iteration.
# #     random_state = save_random_state()

# #     main()

