# --- XDER helpers ---
def _xder_norm_stats(ds_name):
    # Per-dataset (mean, std) used to de/normalize buffer items.
    # Tweak if your own codebase uses different stats.
    if ds_name.lower() in ('cifar10', 'cifar-10', 'cifar100', 'cifar-100'):
        mean = (0.4914, 0.4822, 0.4465)
        std  = (0.2023, 0.1994, 0.2010)
    elif ds_name.lower() in ('tinyimagenet', 'tiny-imagenet', 'tiny_imagenet'):
        mean = (0.480, 0.448, 0.398)
        std  = (0.277, 0.269, 0.282)
    else:  # CUB / ImageNet-like
        mean = (0.485, 0.456, 0.406)
        std  = (0.229, 0.224, 0.225)
    return mean, std

def _make_xder_collate(mean, std):
    mean_t = torch.tensor(mean).view(1, -1, 1, 1)
    std_t  = torch.tensor(std).view(1, -1, 1, 1)
    def collate_fn(batch):
        # Works for items shaped like (x, y) or (x, y, *rest); ignores extra fields.
        xs, ys = [], []
        for item in batch:
            xs.append(item[0])
            ys.append(item[1])
        inputs  = torch.stack(xs, dim=0)                # already augmented+normalized by the dataset transform
        labels  = torch.as_tensor(ys)
        # "not_aug_inputs" ≈ de-normalized tensor; this is what XDer stores in the buffer.
        # (We don't undo random crops/augs — not needed. Buffer will re-normalize via transform.)
        not_aug = inputs * std_t + mean_t
        not_aug = torch.clamp(not_aug, 0.0, 1.0)
        return inputs, labels, not_aug
    return collate_fn

class _XderDatasetAdapter:
    """
    Minimal adapter providing the bits your XDer expects from its 'dataset' arg:
      - SIZE
      - get_denormalization_transform() -> object with .mean and .std
      - train_loader (used only by XDer.end_task)
    """
    def __init__(self, exp_dataset, batch_size, mean, std, size_hw, collate_fn):
        self.SIZE = (3, size_hw, size_hw) if isinstance(size_hw, int) else size_hw
        self._denorm = type("Denorm", (), {"mean": mean, "std": std})()
        self.train_loader = DataLoader(
            exp_dataset, batch_size=batch_size, shuffle=False, num_workers=1, drop_last=False, collate_fn=collate_fn
        )
    def get_denormalization_transform(self):
        return self._denorm
