import os
import numpy as np
from PIL import Image
from torchvision.datasets import MNIST

from utils import log, stratified_pick_per_class
from augment import geom_module, opt_module, copy_module, poison_module


def load_mnist(data_root="./data"):
    """Fetch MNIST through torchvision and return both splits as NumPy arrays.

    Args:
        data_root: Directory passed to torchvision as the dataset root. It is
            created if it does not exist, and the data is downloaded into it
            when missing.

    Returns:
        Tuple ``(X_orig_full, y_orig_full, X_test_full, y_test_full)`` where
        the image arrays are cast to ``uint8`` and the label arrays to
        ``int64``.
    """
    os.makedirs(data_root, exist_ok=True)

    def _split_to_numpy(split):
        # Same conversion for both splits: tensor -> ndarray with a fixed dtype.
        images = split.data.numpy().astype(np.uint8)
        labels = split.targets.numpy().astype(np.int64)
        return images, labels

    # Instantiate the train split first, then the test split.
    train_split = MNIST(root=data_root, train=True, download=True)
    test_split = MNIST(root=data_root, train=False, download=True)

    X_orig_full, y_orig_full = _split_to_numpy(train_split)
    X_test_full, y_test_full = _split_to_numpy(test_split)

    log(f"X_orig_full shape: {X_orig_full.shape}, y_orig_full shape: {y_orig_full.shape}")
    log(f"X_test_full shape: {X_test_full.shape},  y_test_full shape: {y_test_full.shape}")
    return X_orig_full, y_orig_full, X_test_full, y_test_full


def subsample_mnist(
    X_orig_full,
    y_orig_full,
    X_test_full,
    y_test_full,
    per_class_train=10,
    per_class_test=100,
    rng=np.random,
):
    """Reduce the MNIST train/test arrays to a per-class subsample.

    Index selection is delegated to ``stratified_pick_per_class`` with
    ``num_classes=10``; the returned indices are applied to the images and
    labels of the matching split.

    Args:
        X_orig_full: Full training images.
        y_orig_full: Full training labels.
        X_test_full: Full test images.
        y_test_full: Full test labels.
        per_class_train: ``per_class`` value used for the training split.
        per_class_test: ``per_class`` value used for the test split.
        rng: Random source forwarded to ``stratified_pick_per_class``.

    Returns:
        Tuple ``(X_orig, y_orig, X_test, y_test)`` of the selected samples.
    """
    # Train indices are requested before test indices, so a shared rng is
    # consumed in a fixed order.
    train_sel = stratified_pick_per_class(
        y_orig_full, per_class=per_class_train, num_classes=10, rng=rng
    )
    test_sel = stratified_pick_per_class(
        y_test_full, per_class=per_class_test, num_classes=10, rng=rng
    )

    X_orig, y_orig = X_orig_full[train_sel], y_orig_full[train_sel]
    X_test, y_test = X_test_full[test_sel], y_test_full[test_sel]

    log(f"Subsampled train: X_orig {X_orig.shape}, y_orig {y_orig.shape}")
    log(f"Subsampled test:  X_test  {X_test.shape}, y_test  {y_test.shape}")
    return X_orig, y_orig, X_test, y_test


def load_cifar10(data_root="./data"):
    """Fetch CIFAR-10 through torchvision and return both splits as NumPy arrays.

    The images are taken from the dataset's raw ``data`` array instead of
    indexing the dataset sample by sample. Per-sample indexing decodes and
    transforms every image (twice, when image and label are fetched
    separately) and sends the pixels through a uint8 -> float -> uint8 round
    trip whose truncating cast can lose one intensity level if the product
    lands just below an integer.

    Args:
        data_root: Parent directory; the dataset lives in (and is downloaded
            to) ``<data_root>/cifar10``, which is created if missing.

    Returns:
        Tuple ``(X_tr, y_tr, X_te, y_te)``. Images are ``uint8`` in
        channel-first layout ``(N, 3, H, W)``; labels are ``int64``.
    """
    import torchvision.datasets as dsets

    cifar_root = os.path.join(data_root, "cifar10")
    os.makedirs(cifar_root, exist_ok=True)

    tr = dsets.CIFAR10(root=cifar_root, train=True, download=True)
    te = dsets.CIFAR10(root=cifar_root, train=False, download=True)

    def _split_to_numpy(split):
        # torchvision's CIFAR10 stores decoded images in ``data`` as
        # (N, H, W, C) uint8 and labels in ``targets`` as a list of ints.
        # Move channels first to match what ``ToTensor`` would have produced.
        images_hwc = np.asarray(split.data, dtype=np.uint8)
        images = np.ascontiguousarray(np.transpose(images_hwc, (0, 3, 1, 2)))
        labels = np.asarray(split.targets, dtype=np.int64)
        return images, labels

    X_tr, y_tr = _split_to_numpy(tr)
    X_te, y_te = _split_to_numpy(te)

    log(f"[CIFAR10] train {X_tr.shape}, test {X_te.shape}")
    return X_tr, y_tr, X_te, y_te


def subsample_cifar10(
    X_orig_full,
    y_orig_full,
    X_test_full,
    y_test_full,
    per_class_train=10,
    per_class_test=100,
    rng=np.random,
):
    """Reduce the CIFAR-10 train/test arrays to a per-class subsample.

    Index selection is delegated to ``stratified_pick_per_class`` with
    ``num_classes=10``; the returned indices are applied to the images and
    labels of the matching split.

    Args:
        X_orig_full: Full training images.
        y_orig_full: Full training labels.
        X_test_full: Full test images.
        y_test_full: Full test labels.
        per_class_train: ``per_class`` value used for the training split.
        per_class_test: ``per_class`` value used for the test split.
        rng: Random source forwarded to ``stratified_pick_per_class``.

    Returns:
        Tuple ``(X_orig, y_orig, X_test, y_test)`` of the selected samples.
    """
    # Train indices are requested before test indices, so a shared rng is
    # consumed in a fixed order.
    train_sel = stratified_pick_per_class(
        y_orig_full, per_class=per_class_train, num_classes=10, rng=rng
    )
    test_sel = stratified_pick_per_class(
        y_test_full, per_class=per_class_test, num_classes=10, rng=rng
    )

    X_orig, y_orig = X_orig_full[train_sel], y_orig_full[train_sel]
    X_test, y_test = X_test_full[test_sel], y_test_full[test_sel]

    log(f"[CIFAR10] Subsampled train {X_orig.shape}, test {X_test.shape}")
    return X_orig, y_orig, X_test, y_test


def build_providers(
    X_owner,
    y_owner,
    *,
    rng=np.random,
    poison_prob=1.0,
    booster_owner_count=None,
    copier_owner_count=None,
    poisoner_owner_count=None,
    booster_source: str = "none",
    booster_samples_root: str = None,
    booster_variant_owner: str = "orig",
    booster_variant_anchor: str = "aug",
    booster_samples_root_owner: str = None,
    booster_samples_root_anchor: str = None,
):
    """Build the anchor, booster, copier and poisoner datasets from the owner set.

    The anchor is obtained by running ``geom_module`` and then ``opt_module``
    over every owner sample. Each other provider is the concatenation of two
    parts: a "from owner" part built from ``*_owner_count`` randomly chosen
    indices and a "from anchor" part built from ``N - *_owner_count`` randomly
    chosen indices. For the copier and poisoner the parts are produced by
    ``copy_module`` / ``poison_module`` applied to the owner and anchor
    samples at those indices. For the booster the chosen indices only supply
    class labels; the images are read from PNG sample directories.

    Args:
        X_owner: Owner images, indexed along axis 0. Booster synthesis reads
            ``X_owner.shape[1]`` and ``X_owner.shape[2]`` as image height and
            width.
        y_owner: Owner labels, indexed along axis 0.
        rng: Object exposing ``choice`` (defaults to the ``numpy.random``
            module); also forwarded to ``poison_module``.
        poison_prob: Forwarded to ``poison_module`` as ``poisoning_prob``.
        booster_owner_count: Size of the booster "from owner" part.
        copier_owner_count: Size of the copier "from owner" part.
        poisoner_owner_count: Size of the poisoner "from owner" part. Each of
            the three counts defaults to ``N // 2`` and is clipped to
            ``[0, N]``.
        booster_source: ``"gan"``, ``"ddpm"``, ``"ddim"`` or ``"fm"``
            (case-insensitive) enables the booster; any other value yields an
            empty booster.
        booster_samples_root: Sample directory used for both booster parts
            unless a per-part root is given.
        booster_variant_owner: Variant sub-directory for the "from owner"
            part (falls back to ``"orig"`` when falsy).
        booster_variant_anchor: Variant sub-directory for the "from anchor"
            part (falls back to ``"aug"`` when falsy).
        booster_samples_root_owner: Root for the "from owner" part; takes
            precedence over ``booster_samples_root``.
        booster_samples_root_anchor: Root for the "from anchor" part; takes
            precedence over ``booster_samples_root``.

    Returns:
        Tuple ``(nodes, providers)``. ``nodes`` maps each node name (owner,
        anchor and the ``*_from_owner`` / ``*_from_anchor`` parts) to a dict
        with ``"X"``, ``"y"`` and, except for the owner, ``"parents"``
        (``{"node": ..., "index": ...}``). ``providers`` maps ``"owner"``,
        ``"anchor"``, ``"booster"``, ``"copier"`` and ``"poisoner"`` to dicts
        with ``"X"``, ``"y"`` and ``"role"``.

    Raises:
        ValueError: From the default-root lookup if it is asked for an
            unsupported source name.
    """
    N = len(X_owner)
    half = N // 2
    if booster_owner_count is None:
        booster_owner_count = half
    if copier_owner_count is None:
        copier_owner_count = half
    if poisoner_owner_count is None:
        poisoner_owner_count = half
    booster_owner_count = int(np.clip(booster_owner_count, 0, N))
    copier_owner_count = int(np.clip(copier_owner_count, 0, N))
    poisoner_owner_count = int(np.clip(poisoner_owner_count, 0, N))

    # Anchor: geometric pass followed by the opt pass, one output per owner sample.
    X_anchor_geom, y_anchor_geom = geom_module(X_owner, y_owner, per_sample=1)
    X_anchor, y_anchor = opt_module(X_anchor_geom, y_anchor_geom, per_sample=1)
    parents_anchor = {
        "node": np.array(["owner"] * N, dtype=object),
        "index": np.arange(N, dtype=np.int64),
    }

    def _no_index():
        """Return an empty int64 index array."""
        return np.array([], dtype=np.int64)

    def _no_parents():
        """Return a parents record with no entries."""
        return {"node": np.array([], dtype=object), "index": _no_index()}

    def _parents_of(node_name, idx):
        """Return a parents record pointing at ``idx`` within ``node_name``."""
        return {
            "node": np.array([node_name] * idx.size, dtype=object),
            "index": idx.astype(np.int64),
        }

    def _poison_parents(node_name, idx, rel):
        """Map poison_module's relative parent indices back to positions in ``idx``."""
        if idx.size > 0 and rel.size > 0:
            index = idx[rel].astype(np.int64)
        else:
            index = _no_index()
        return {"node": np.array([node_name] * rel.size, dtype=object), "index": index}

    def _pick(count):
        """Draw ``count`` distinct indices from ``range(N)``; no draw when count is 0."""
        return rng.choice(N, size=count, replace=False) if count > 0 else _no_index()

    def _take(arr, idx):
        """Select ``arr[idx]``, or an empty slice of ``arr`` when ``idx`` is empty."""
        return arr[idx] if idx.size > 0 else arr[:0]

    def _join(part_a, part_b, empty):
        """Concatenate two parts along axis 0, or return ``empty`` if both are empty."""
        if part_a.shape[0] + part_b.shape[0] > 0:
            return np.concatenate([part_a, part_b], axis=0)
        return empty

    def _default_samples_root_for(source_name: str, variant: str = None) -> str:
        """Return the default MNIST sample directory for a source and variant."""
        this_dir = os.path.dirname(__file__)
        vsub = (variant or "").strip().lower()
        if source_name in ("gan", "fm"):
            ds_first = os.path.join(this_dir, "generative", source_name, "samples", "mnist")
            return os.path.join(ds_first, vsub) if vsub else ds_first
        if source_name in ("ddpm", "ddim"):
            # Both diffusion samplers share one tree; the sampler name is the leaf.
            ds_first = os.path.join(this_dir, "generative", "diffusion", "samples", "mnist")
            if vsub:
                return os.path.join(ds_first, vsub, source_name)
            return os.path.join(ds_first, source_name)
        raise ValueError(source_name)

    def _read_png_as_uint8_gray(path, out_size=(28, 28)):
        """Load an image as a uint8 grayscale array, resized to ``out_size`` = (H, W)."""
        # The context manager releases the file handle; convert() returns an
        # independent in-memory image that stays valid afterwards.
        with Image.open(path) as src:
            img = src.convert("L")
        # PIL sizes are (width, height), hence the swapped comparison.
        if out_size is not None and (img.size[0] != out_size[1] or img.size[1] != out_size[0]):
            img = img.resize((out_size[1], out_size[0]), Image.BILINEAR)
        return np.array(img, dtype=np.uint8)

    def _collect_class_image_paths(root_dir):
        """Map each class 0..9 to the sorted PNG paths found under ``root_dir``.

        Layouts are tried in order and the first one present wins:
        ``classwise/class_XX/``, then ``class_XX/``, then flat files whose
        name contains ``_cXX_``.
        """
        class_to_paths = {c: [] for c in range(10)}
        if not os.path.isdir(root_dir):
            return class_to_paths
        for layout_base in (os.path.join(root_dir, "classwise"), root_dir):
            found = False
            for c in range(10):
                cdir = os.path.join(layout_base, f"class_{c:02d}")
                if not os.path.isdir(cdir):
                    continue
                found = True
                for fn in sorted(os.listdir(cdir)):
                    if fn.lower().endswith(".png"):
                        class_to_paths[c].append(os.path.join(cdir, fn))
            if found:
                return class_to_paths
        import re
        patt = re.compile(r"_c(\d{2})_")
        for fn in sorted(os.listdir(root_dir)):
            if not fn.lower().endswith(".png"):
                continue
            m = patt.search(fn)
            if not m:
                continue
            c = int(m.group(1))
            if 0 <= c <= 9:
                class_to_paths[c].append(os.path.join(root_dir, fn))
        return class_to_paths

    # Booster parts default to empty and carry no parent links.
    X_boost_from_owner, y_boost_from_owner = X_owner[:0], y_owner[:0]
    X_boost_from_anchor, y_boost_from_anchor = X_owner[:0], y_owner[:0]
    parents_boost_from_owner = _no_parents()
    parents_boost_from_anchor = _no_parents()
    X_booster = X_owner[:0]
    y_booster = y_owner[:0]

    booster_source = (booster_source or "").lower()
    if booster_source in {"gan", "ddpm", "ddim", "fm"}:
        owner_cnt = booster_owner_count
        anchor_cnt = N - owner_cnt

        variant_owner = (booster_variant_owner or "orig").lower()
        variant_anchor = (booster_variant_anchor or "aug").lower()

        # Precedence: per-part root, then the shared root, then the default tree.
        if booster_samples_root_owner is not None:
            root_owner = booster_samples_root_owner
        elif booster_samples_root is not None:
            root_owner = booster_samples_root
        else:
            root_owner = _default_samples_root_for(booster_source, variant_owner)
        if booster_samples_root_anchor is not None:
            root_anchor = booster_samples_root_anchor
        elif booster_samples_root is not None:
            root_anchor = booster_samples_root
        else:
            root_anchor = _default_samples_root_for(booster_source, variant_anchor)

        def _with_fallback(base_root, source, variant):
            """Keep ``base_root`` if it has a ``classwise`` dir, else rebuild it from its parent(s)."""
            if os.path.isdir(os.path.join(base_root, "classwise")):
                return base_root
            if source in {"ddpm", "ddim"}:
                # Diffusion roots end in <variant>/<sampler>: go up two levels.
                base = os.path.join(os.path.dirname(os.path.dirname(base_root)), variant or "")
                return os.path.join(base, source)
            return os.path.join(os.path.dirname(base_root), variant or "")

        root_owner = _with_fallback(root_owner, booster_source, variant_owner)
        root_anchor = _with_fallback(root_anchor, booster_source, variant_anchor)

        cls_paths_owner = _collect_class_image_paths(root_owner)
        cls_paths_anchor = _collect_class_image_paths(root_anchor)

        idx_owner_for_booster = _pick(owner_cnt)
        # The two booster parts use disjoint owner indices whenever enough remain.
        remaining_idx = np.setdiff1d(np.arange(N, dtype=np.int64), idx_owner_for_booster, assume_unique=False)
        if anchor_cnt > 0:
            if remaining_idx.size >= anchor_cnt:
                idx_anchor_for_booster = rng.choice(remaining_idx, size=anchor_cnt, replace=False)
            else:
                idx_anchor_for_booster = rng.choice(N, size=anchor_cnt, replace=False)
        else:
            idx_anchor_for_booster = _no_index()

        # Size loaded samples to the owner images so they always fit the
        # output buffer allocated below (28x28 for standard MNIST).
        sample_hw = (int(X_owner.shape[1]), int(X_owner.shape[2]))

        def synth_from_paths(indices, paths_map):
            """Build one booster sample per index, cycling through that class's PNGs."""
            X_out = np.empty((indices.size, sample_hw[0], sample_hw[1]), dtype=X_owner.dtype)
            y_out = np.empty(indices.size, dtype=y_owner.dtype)
            cursors = {c: 0 for c in range(10)}
            for k, i in enumerate(indices):
                c = int(y_owner[i])
                paths = paths_map.get(c, [])
                if len(paths) == 0:
                    # No sample file for this class: reuse the owner sample itself.
                    X_out[k] = X_owner[i]
                    y_out[k] = y_owner[i]
                    continue
                j = cursors[c] % len(paths)
                cursors[c] += 1
                X_out[k] = _read_png_as_uint8_gray(paths[j], out_size=sample_hw)
                y_out[k] = c
            return X_out, y_out

        if idx_owner_for_booster.size > 0:
            X_boost_from_owner, y_boost_from_owner = synth_from_paths(idx_owner_for_booster, cls_paths_owner)
        if idx_anchor_for_booster.size > 0:
            X_boost_from_anchor, y_boost_from_anchor = synth_from_paths(idx_anchor_for_booster, cls_paths_anchor)

        X_booster = _join(X_boost_from_owner, X_boost_from_anchor, X_owner[:0])
        y_booster = _join(y_boost_from_owner, y_boost_from_anchor, y_owner[:0])

    # Copier: owner and anchor indices are drawn independently and may overlap.
    owner_cnt = copier_owner_count
    anchor_cnt = N - owner_cnt

    idx_owner_for_copier = _pick(owner_cnt)
    idx_anchor_for_copier = _pick(anchor_cnt)

    X_copier_from_owner, y_copier_from_owner = copy_module(
        _take(X_owner, idx_owner_for_copier),
        _take(y_owner, idx_owner_for_copier),
        per_sample=1,
    )
    X_copier_from_anchor, y_copier_from_anchor = copy_module(
        _take(X_anchor, idx_anchor_for_copier),
        _take(y_anchor, idx_anchor_for_copier),
        per_sample=1,
    )

    parents_copier_from_owner = _parents_of("owner", idx_owner_for_copier)
    parents_copier_from_anchor = _parents_of("anchor", idx_anchor_for_copier)

    X_copier = _join(X_copier_from_owner, X_copier_from_anchor, X_owner[:0])
    y_copier = _join(y_copier_from_owner, y_copier_from_anchor, y_owner[:0])

    # Poisoner: same index scheme as the copier, processed by poison_module.
    owner_cnt = poisoner_owner_count
    anchor_cnt = N - owner_cnt

    idx_owner_for_poisoner = _pick(owner_cnt)
    idx_anchor_for_poisoner = _pick(anchor_cnt)

    # The third return value (a mask) is not used here.
    X_poisoner_from_owner, y_poisoner_from_owner, _, parent_idx_rel_owner = poison_module(
        _take(X_owner, idx_owner_for_poisoner),
        _take(y_owner, idx_owner_for_poisoner),
        per_sample=1,
        poisoning_prob=poison_prob,
        rng=rng,
        aug_type="geom+opt",
    )
    X_poisoner_from_anchor, y_poisoner_from_anchor, _, parent_idx_rel_anchor = poison_module(
        _take(X_anchor, idx_anchor_for_poisoner),
        _take(y_anchor, idx_anchor_for_poisoner),
        per_sample=1,
        poisoning_prob=poison_prob,
        rng=rng,
        aug_type="geom+opt",
    )

    parents_poisoner_from_owner = _poison_parents("owner", idx_owner_for_poisoner, parent_idx_rel_owner)
    parents_poisoner_from_anchor = _poison_parents("anchor", idx_anchor_for_poisoner, parent_idx_rel_anchor)

    X_poisoner = _join(X_poisoner_from_owner, X_poisoner_from_anchor, X_owner[:0])
    y_poisoner = _join(y_poisoner_from_owner, y_poisoner_from_anchor, y_owner[:0])

    nodes = {
        "owner": {"X": X_owner, "y": y_owner},
        "anchor": {"X": X_anchor, "y": y_anchor, "parents": parents_anchor},
        "booster_from_owner": {"X": X_boost_from_owner, "y": y_boost_from_owner, "parents": parents_boost_from_owner},
        "booster_from_anchor": {"X": X_boost_from_anchor, "y": y_boost_from_anchor, "parents": parents_boost_from_anchor},
        "copier_from_owner": {"X": X_copier_from_owner, "y": y_copier_from_owner, "parents": parents_copier_from_owner},
        "copier_from_anchor": {"X": X_copier_from_anchor, "y": y_copier_from_anchor, "parents": parents_copier_from_anchor},
        "poisoner_from_owner": {"X": X_poisoner_from_owner, "y": y_poisoner_from_owner, "parents": parents_poisoner_from_owner},
        "poisoner_from_anchor": {"X": X_poisoner_from_anchor, "y": y_poisoner_from_anchor, "parents": parents_poisoner_from_anchor},
    }

    providers = {
        "owner": {"X": X_owner, "y": y_owner, "role": "original"},
        "anchor": {"X": X_anchor, "y": y_anchor, "role": "reliable"},
        "booster": {"X": X_booster, "y": y_booster, "role": "reliable"},
        "copier": {"X": X_copier, "y": y_copier, "role": "harmful"},
        "poisoner": {"X": X_poisoner, "y": y_poisoner, "role": "harmful"},
    }

    return nodes, providers


def build_providers_cifar(
    X_owner,
    y_owner,
    *,
    rng=np.random,
    poison_prob=1.0,
    booster_owner_count=None,
    copier_owner_count=None,
    poisoner_owner_count=None,
    booster_source: str = "none",
    booster_samples_root: str = None,
    booster_variant_owner: str = "orig",
    booster_variant_anchor: str = "aug",
    booster_samples_root_owner: str = None,
    booster_samples_root_anchor: str = None,
):
    """Build the anchor, booster, copier and poisoner datasets for CIFAR-10.

    Images are held channel-first. The augmentation modules are fed
    channel-last arrays, so every call is wrapped in a layout conversion and
    the result is converted back and cast to ``uint8``.

    The anchor is obtained by running ``geom_module`` and then ``opt_module``
    over every owner sample. Each other provider is the concatenation of a
    "from owner" part (``*_owner_count`` randomly chosen indices) and a
    "from anchor" part (``N - *_owner_count`` randomly chosen indices). For
    the booster the chosen indices only supply class labels; the images are
    read from PNG sample directories.

    Args:
        X_owner: Owner images, indexed along axis 0. Booster synthesis reads
            ``X_owner.shape[2]`` and ``X_owner.shape[3]`` as height and
            width, i.e. it expects ``(N, 3, H, W)``.
        y_owner: Owner labels, indexed along axis 0.
        rng: Object exposing ``choice`` (defaults to the ``numpy.random``
            module); also forwarded to ``poison_module``.
        poison_prob: Forwarded to ``poison_module`` as ``poisoning_prob``.
        booster_owner_count: Size of the booster "from owner" part.
        copier_owner_count: Size of the copier "from owner" part.
        poisoner_owner_count: Size of the poisoner "from owner" part. Each of
            the three counts defaults to ``N // 2`` and is clipped to
            ``[0, N]``.
        booster_source: ``"gan"``, ``"ddpm"``, ``"ddim"``, ``"fm"`` or
            ``"none"`` (case-insensitive).
        booster_samples_root: Sample directory used for both booster parts
            unless a per-part root is given.
        booster_variant_owner: Variant sub-directory for the "from owner"
            part (falls back to ``"orig"`` when falsy).
        booster_variant_anchor: Variant sub-directory for the "from anchor"
            part (falls back to ``"aug"`` when falsy).
        booster_samples_root_owner: Root for the "from owner" part; takes
            precedence over ``booster_samples_root``.
        booster_samples_root_anchor: Root for the "from anchor" part; takes
            precedence over ``booster_samples_root``.

    Returns:
        Tuple ``(nodes, providers)``. ``nodes`` maps each node name (owner,
        anchor and the ``*_from_owner`` / ``*_from_anchor`` parts) to a dict
        with ``"X"``, ``"y"`` and, except for the owner, ``"parents"``
        (``{"node": ..., "index": ...}``). ``providers`` maps ``"owner"``,
        ``"anchor"``, ``"booster"``, ``"copier"`` and ``"poisoner"`` to dicts
        with ``"X"``, ``"y"`` and ``"role"``.

    Raises:
        ValueError: If ``booster_source`` is not one of the supported values.
    """
    N = len(X_owner)
    half = N // 2
    if booster_owner_count is None:
        booster_owner_count = half
    if copier_owner_count is None:
        copier_owner_count = half
    if poisoner_owner_count is None:
        poisoner_owner_count = half
    booster_owner_count = int(np.clip(booster_owner_count, 0, N))
    copier_owner_count = int(np.clip(copier_owner_count, 0, N))
    poisoner_owner_count = int(np.clip(poisoner_owner_count, 0, N))

    def _to_hwc(arr):
        """Move channels last for a 4-D array whose axis 1 has length 3; else pass through."""
        return np.transpose(arr, (0, 2, 3, 1)) if (arr.ndim == 4 and arr.shape[1] == 3) else arr

    def _to_chw(arr):
        """Move channels first for a 4-D array whose last axis has length 3; else pass through."""
        return np.transpose(arr, (0, 3, 1, 2)) if (arr.ndim == 4 and arr.shape[-1] == 3) else arr

    def _no_index():
        """Return an empty int64 index array."""
        return np.array([], dtype=np.int64)

    def _no_parents():
        """Return a parents record with no entries."""
        return {"node": np.array([], dtype=object), "index": _no_index()}

    def _parents_of(node_name, idx):
        """Return a parents record pointing at ``idx`` within ``node_name``."""
        return {
            "node": np.array([node_name] * idx.size, dtype=object),
            "index": idx.astype(np.int64),
        }

    def _poison_parents(node_name, idx, rel):
        """Map poison_module's relative parent indices back to positions in ``idx``."""
        if idx.size == 0:
            return _no_parents()
        index = idx[rel].astype(np.int64) if rel.size > 0 else _no_index()
        return {"node": np.array([node_name] * rel.size, dtype=object), "index": index}

    def _pick(count):
        """Draw ``count`` distinct indices from ``range(N)``; no draw when count is 0."""
        return rng.choice(N, size=count, replace=False) if count > 0 else _no_index()

    def _take(arr, idx):
        """Select ``arr[idx]``, or an empty slice of ``arr`` when ``idx`` is empty."""
        return arr[idx] if idx.size > 0 else arr[:0]

    def _join(part_a, part_b, empty):
        """Concatenate two parts along axis 0, or return ``empty`` if both are empty."""
        if part_a.shape[0] + part_b.shape[0] > 0:
            return np.concatenate([part_a, part_b], axis=0)
        return empty

    # Anchor: geometric pass followed by the opt pass, one output per owner sample.
    X_owner_hwc = _to_hwc(X_owner)
    X_anchor_geom_hwc, y_anchor_geom = geom_module(
        X_owner_hwc, y_owner, per_sample=1,
        angle_range=(0.0, 0.0),
        translate_frac=0.10,
        scale_range=(0.95, 1.05),
        fill=(123, 116, 103),
    )
    X_anchor_hwc, y_anchor = opt_module(
        X_anchor_geom_hwc, y_anchor_geom, per_sample=1,
        brightness_range=(0.8, 1.2),
        contrast_range=(0.8, 1.2),
        blur_kernel=1,
        blur_sigma=(1.0, 1.0),
    )
    X_anchor = _to_chw(X_anchor_hwc).astype(np.uint8)
    parents_anchor = {"node": np.array(["owner"] * N, dtype=object), "index": np.arange(N, dtype=np.int64)}

    def _default_samples_root_for(source_name: str, variant: str = None) -> str:
        """Return the default CIFAR-10 sample directory for a source and variant."""
        this_dir = os.path.dirname(__file__)
        vsub = (variant or "").strip().lower()
        if source_name in ("gan", "fm"):
            base = os.path.join(this_dir, "generative", source_name, "samples", "cifar10")
            return os.path.join(base, vsub) if vsub else base
        if source_name in ("ddpm", "ddim"):
            # Both diffusion samplers share one tree; the sampler name is the leaf.
            base = os.path.join(this_dir, "generative", "diffusion", "samples", "cifar10")
            if vsub:
                return os.path.join(base, vsub, source_name)
            return os.path.join(base, source_name)
        raise ValueError(source_name)

    def _collect_class_image_paths_rgb(root_dir):
        """Map each class 0..9 to the sorted PNG paths found under ``root_dir``.

        Layouts are tried in order and the first one present wins:
        ``classwise/class_XX/``, then ``class_XX/``, then flat files whose
        name matches ``class[_-]?<1-2 digits>``.
        """
        class_to_paths = {c: [] for c in range(10)}
        if not os.path.isdir(root_dir):
            return class_to_paths
        for layout_base in (os.path.join(root_dir, "classwise"), root_dir):
            found = False
            for c in range(10):
                cdir = os.path.join(layout_base, f"class_{c:02d}")
                if not os.path.isdir(cdir):
                    continue
                found = True
                for fn in sorted(os.listdir(cdir)):
                    if fn.lower().endswith(".png"):
                        class_to_paths[c].append(os.path.join(cdir, fn))
            if found:
                return class_to_paths
        import re
        patt = re.compile(r"class[_-]?(\d{1,2})")
        for fn in sorted(os.listdir(root_dir)):
            if not fn.lower().endswith(".png"):
                continue
            m = patt.search(fn)
            if not m:
                continue
            c = int(m.group(1))
            if 0 <= c <= 9:
                class_to_paths[c].append(os.path.join(root_dir, fn))
        return class_to_paths

    booster_source_l = (booster_source or "").lower()
    if booster_source_l in {"gan", "ddpm", "ddim", "fm"}:
        owner_cnt = booster_owner_count
        anchor_cnt = N - owner_cnt
        variant_owner = (booster_variant_owner or "orig").lower()
        variant_anchor = (booster_variant_anchor or "aug").lower()

        # Precedence: per-part root, then the shared root, then the default tree.
        if booster_samples_root_owner is not None:
            root_owner = booster_samples_root_owner
        elif booster_samples_root is not None:
            root_owner = booster_samples_root
        else:
            root_owner = _default_samples_root_for(booster_source_l, variant_owner)
        if booster_samples_root_anchor is not None:
            root_anchor = booster_samples_root_anchor
        elif booster_samples_root is not None:
            root_anchor = booster_samples_root
        else:
            root_anchor = _default_samples_root_for(booster_source_l, variant_anchor)

        cls_paths_owner = _collect_class_image_paths_rgb(root_owner)
        cls_paths_anchor = _collect_class_image_paths_rgb(root_anchor)

        idx_owner_for_booster = _pick(owner_cnt)
        # The two booster parts use disjoint owner indices whenever enough remain.
        remaining_idx = np.setdiff1d(np.arange(N, dtype=np.int64), idx_owner_for_booster, assume_unique=False)
        if anchor_cnt > 0:
            if remaining_idx.size >= anchor_cnt:
                idx_anchor_for_booster = rng.choice(remaining_idx, size=anchor_cnt, replace=False)
            else:
                idx_anchor_for_booster = rng.choice(N, size=anchor_cnt, replace=False)
        else:
            idx_anchor_for_booster = _no_index()

        def synth_from_paths(indices, paths_map):
            """Build one channel-first booster sample per index, cycling through that class's PNGs."""
            Xs = []
            ys = []
            cursors = {c: 0 for c in range(10)}
            target_h, target_w = int(X_owner.shape[2]), int(X_owner.shape[3])
            for i in indices:
                c = int(y_owner[i])
                paths = paths_map.get(c, [])
                if len(paths) == 0:
                    # No sample file for this class: reuse the owner sample itself.
                    Xs.append(X_owner[i])
                    ys.append(y_owner[i])
                    continue
                j = cursors[c] % len(paths)
                cursors[c] += 1
                # The context manager releases the file handle; convert()
                # returns an independent in-memory image.
                with Image.open(paths[j]) as src:
                    img = src.convert("RGB")
                if img.size != (target_w, target_h):
                    img = img.resize((target_w, target_h), Image.BILINEAR)
                arr = np.array(img, dtype=np.uint8)
                # PIL yields (H, W, C) but the owner fallback above is
                # (C, H, W). Convert each loaded image now so both kinds can
                # be stacked together when only some classes have files.
                Xs.append(np.transpose(arr, (2, 0, 1)))
                ys.append(c)
            if Xs:
                Xs = np.stack(Xs, axis=0).astype(np.uint8)
                ys = np.array(ys, dtype=np.int64)
            else:
                Xs = X_owner[:0]
                ys = y_owner[:0]
            return Xs, ys

        if idx_owner_for_booster.size > 0:
            X_boost_from_owner, y_boost_from_owner = synth_from_paths(idx_owner_for_booster, cls_paths_owner)
        else:
            X_boost_from_owner, y_boost_from_owner = X_owner[:0], y_owner[:0]
        if idx_anchor_for_booster.size > 0:
            X_boost_from_anchor, y_boost_from_anchor = synth_from_paths(idx_anchor_for_booster, cls_paths_anchor)
        else:
            X_boost_from_anchor, y_boost_from_anchor = X_owner[:0], y_owner[:0]

        # Booster samples come from files, so no parent links are recorded.
        parents_boost_from_owner = _no_parents()
        parents_boost_from_anchor = _no_parents()
        X_booster = _join(X_boost_from_owner, X_boost_from_anchor, X_owner[:0])
        y_booster = _join(y_boost_from_owner, y_boost_from_anchor, y_owner[:0])
    elif booster_source_l == "none":
        X_boost_from_owner, y_boost_from_owner = X_owner[:0], y_owner[:0]
        X_boost_from_anchor, y_boost_from_anchor = X_owner[:0], y_owner[:0]
        parents_boost_from_owner = _no_parents()
        parents_boost_from_anchor = _no_parents()
        X_booster = X_owner[:0]
        y_booster = y_owner[:0]
    else:
        raise ValueError("booster_source must be one of {'gan','ddpm','ddim','fm','none'}")

    # Copier: owner and anchor indices are drawn independently and may overlap.
    owner_cnt = copier_owner_count
    anchor_cnt = N - owner_cnt
    idx_owner_for_copier = _pick(owner_cnt)
    idx_anchor_for_copier = _pick(anchor_cnt)

    X_copier_from_owner_hwc, y_copier_from_owner = copy_module(
        _to_hwc(_take(X_owner, idx_owner_for_copier)),
        _take(y_owner, idx_owner_for_copier),
        per_sample=1,
    )
    X_copier_from_owner = _to_chw(X_copier_from_owner_hwc).astype(np.uint8)

    X_copier_from_anchor_hwc, y_copier_from_anchor = copy_module(
        _to_hwc(_take(X_anchor, idx_anchor_for_copier)),
        _take(y_anchor, idx_anchor_for_copier),
        per_sample=1,
    )
    X_copier_from_anchor = _to_chw(X_copier_from_anchor_hwc).astype(np.uint8)

    parents_copier_from_owner = _parents_of("owner", idx_owner_for_copier)
    parents_copier_from_anchor = _parents_of("anchor", idx_anchor_for_copier)

    X_copier = _join(X_copier_from_owner, X_copier_from_anchor, X_owner[:0])
    y_copier = _join(y_copier_from_owner, y_copier_from_anchor, y_owner[:0])

    # Poisoner: same index scheme as the copier, processed by poison_module.
    owner_cnt = poisoner_owner_count
    anchor_cnt = N - owner_cnt
    idx_owner_for_poisoner = _pick(owner_cnt)
    idx_anchor_for_poisoner = _pick(anchor_cnt)

    _xo_p = _take(X_owner, idx_owner_for_poisoner)
    _yo_p = _take(y_owner, idx_owner_for_poisoner)
    _xa_p = _take(X_anchor, idx_anchor_for_poisoner)
    _ya_p = _take(y_anchor, idx_anchor_for_poisoner)

    # The third return value (a mask) is not used here.
    X_poisoner_from_owner_hwc, y_poisoner_from_owner, _, parent_idx_rel_owner = poison_module(
        _to_hwc(_xo_p), _yo_p, per_sample=1, poisoning_prob=poison_prob, rng=rng, aug_type="geom+opt",
    )
    X_poisoner_from_anchor_hwc, y_poisoner_from_anchor, _, parent_idx_rel_anchor = poison_module(
        _to_hwc(_xa_p), _ya_p, per_sample=1, poisoning_prob=poison_prob, rng=rng, aug_type="geom+opt",
    )

    X_poisoner_from_owner = _to_chw(X_poisoner_from_owner_hwc).astype(np.uint8)
    X_poisoner_from_anchor = _to_chw(X_poisoner_from_anchor_hwc).astype(np.uint8)

    parents_poisoner_from_owner = _poison_parents("owner", idx_owner_for_poisoner, parent_idx_rel_owner)
    parents_poisoner_from_anchor = _poison_parents("anchor", idx_anchor_for_poisoner, parent_idx_rel_anchor)

    X_poisoner = _join(X_poisoner_from_owner, X_poisoner_from_anchor, X_owner[:0])
    y_poisoner = _join(y_poisoner_from_owner, y_poisoner_from_anchor, y_owner[:0])

    nodes = {
        "owner": {"X": X_owner, "y": y_owner},
        "anchor": {"X": X_anchor, "y": y_anchor, "parents": parents_anchor},
        "booster_from_owner": {"X": X_boost_from_owner, "y": y_boost_from_owner, "parents": parents_boost_from_owner},
        "booster_from_anchor": {"X": X_boost_from_anchor, "y": y_boost_from_anchor, "parents": parents_boost_from_anchor},
        "copier_from_owner": {"X": X_copier_from_owner, "y": y_copier_from_owner, "parents": parents_copier_from_owner},
        "copier_from_anchor": {"X": X_copier_from_anchor, "y": y_copier_from_anchor, "parents": parents_copier_from_anchor},
        "poisoner_from_owner": {"X": X_poisoner_from_owner, "y": y_poisoner_from_owner, "parents": parents_poisoner_from_owner},
        "poisoner_from_anchor": {"X": X_poisoner_from_anchor, "y": y_poisoner_from_anchor, "parents": parents_poisoner_from_anchor},
    }

    providers = {
        "owner": {"X": X_owner, "y": y_owner, "role": "original"},
        "anchor": {"X": X_anchor, "y": y_anchor, "role": "reliable"},
        "booster": {"X": X_booster, "y": y_booster, "role": "reliable"},
        "copier": {"X": X_copier, "y": y_copier, "role": "harmful"},
        "poisoner": {"X": X_poisoner, "y": y_poisoner, "role": "harmful"},
    }
    return nodes, providers


