import os
#os.environ["CUDA_VISIBLE_DEVICES"] = "1,2,3"
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset, Dataset, ConcatDataset

import torchvision
import torchvision.models as models
from torchvision import transforms

import numpy as np

import time
import random

from sklearn.decomposition import PCA
from sklearn.cluster import MiniBatchKMeans

from tqdm import tqdm # Recommended for progress bar

from datasets import load_dataset
from transformers import ViTModel, AutoModelForSequenceClassification, AutoTokenizer


# -----------------------------------------------------------------------------
# Device Setup
# -----------------------------------------------------------------------------
# Prefer the first visible CUDA device when one exists, otherwise run on CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Number of visible CUDA devices (0 on CPU-only hosts).
num_gpus = torch.cuda.device_count()

# -----------------------------------------------------------------------------
# Define ResNet18 for CIFAR-100 (MyResNet18)
# -----------------------------------------------------------------------------
class MyResNet18(nn.Module):
    """ResNet-18 adapted to small (32x32, CIFAR-style) images.

    The torchvision ResNet-18 stem is altered: the first convolution becomes a
    3x3 / stride-1 / padding-1 conv, and the initial max-pool is replaced by an
    identity. A new linear head maps the 512-d pooled features to
    ``num_classes`` logits.

    Note: ``base_model`` stays registered as a submodule (including its unused
    original ``fc``), so its parameters are part of ``state_dict()``. Saved
    checkpoints depend on that layout.
    """

    def __init__(self, num_classes=100):
        super().__init__()
        backbone = torchvision.models.resnet18(weights=None)
        # Small-image stem: avoid the aggressive early downsampling.
        backbone.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        backbone.maxpool = nn.Identity()
        self.base_model = backbone
        # All backbone children except the last one (the original classifier).
        trunk = list(backbone.children())[:-1]
        self.features = nn.Sequential(*trunk)
        self.flatten = nn.Flatten()
        self.fc = nn.Linear(512, num_classes)

    def forward(self, x):
        """Return logits of shape (batch, num_classes)."""
        pooled = self.features(x)  # (batch, 512, 1, 1)
        return self.fc(self.flatten(pooled))


# -----------------------------------------------------------------------------
# Define a DistilBERT Model for Text Classification (MyTransformer)
# -----------------------------------------------------------------------------
class MyTransformer(nn.Module):
    """DistilBERT (``distilbert-base-uncased``) sequence classifier returning raw logits.

    Args:
        preload: if True, load the pretrained checkpoint weights; otherwise build
            the same architecture from the checkpoint's configuration with
            randomly initialised weights.
        num_classes: number of output labels.
    """
    def __init__(self, preload=True, num_classes=4):
        super(MyTransformer, self).__init__()
        # "distilbert-base-uncased" is a smaller, faster model than DeBERTa v3 large
        checkpoint = "distilbert-base-uncased"
        if preload:
            self.base_model = AutoModelForSequenceClassification.from_pretrained(
                checkpoint,
                num_labels=num_classes
            )
        else:
            # from_pretrained(None) requires an explicit config + state_dict, so the
            # old call always failed. Build the architecture from its config instead
            # (random initialisation, no pretrained weights).
            from transformers import AutoConfig
            config = AutoConfig.from_pretrained(checkpoint, num_labels=num_classes)
            self.base_model = AutoModelForSequenceClassification.from_config(config)

    def forward(self, input_ids, attention_mask=None):
        """
        Hugging Face models typically accept:
          - input_ids
          - attention_mask
          - token_type_ids (optional, depending on model)
        We'll keep only input_ids/attention_mask for simplicity.

        Returns the classification logits tensor (``outputs.logits``).
        """
        outputs = self.base_model(input_ids=input_ids, attention_mask=attention_mask)
        # The logits are in outputs.logits
        return outputs.logits
    
# -----------------------------------------------------------------------------
# Define EfficientNetV2-L for ImageNet-100 (MyEfficientNetV2L)
# -----------------------------------------------------------------------------
class MyEfficientNetV2L(nn.Module):
    """torchvision EfficientNetV2-L with its output layer resized to ``num_classes``.

    Args:
        preload: if True, start from torchvision's ``IMAGENET1K_V1`` weights
            (ImageNet-1k); otherwise initialise randomly.
        num_classes: width of the replacement output layer.
    """

    def __init__(self, preload=True, num_classes=100):
        super().__init__()
        weights = models.EfficientNet_V2_L_Weights.IMAGENET1K_V1 if preload else None
        self.base_model = models.efficientnet_v2_l(weights=weights)
        # Entry 1 of the classifier Sequential is the output Linear; swap it
        # for a num_classes-wide one with the same input width.
        head = self.base_model.classifier
        head[1] = nn.Linear(head[1].in_features, num_classes)

    def forward(self, x):
        """Return logits of shape (batch, num_classes)."""
        return self.base_model(x)
    

# -----------------------------------------------------------------------------
# Define ViT for ImageNet-100 (vit)
# -----------------------------------------------------------------------------
class ViT(nn.Module):
    """ViT encoder (``google/vit-base-patch16-224``) with a linear head on the first token.

    Args:
        preload: if True, load the pretrained encoder weights; otherwise build
            the same architecture from the checkpoint's configuration with
            randomly initialised weights.
        num_classes: number of output logits.
    """
    def __init__(self, preload=True, num_classes=100):
        super(ViT, self).__init__()
        checkpoint = 'google/vit-base-patch16-224'
        if preload:
            self.base = ViTModel.from_pretrained(checkpoint)
        else:
            # from_pretrained(None) requires an explicit config + state_dict, so the
            # old call always failed. Build the same architecture, randomly initialised.
            from transformers import ViTConfig
            self.base = ViTModel(ViTConfig.from_pretrained(checkpoint))
        self.final = nn.Linear(self.base.config.hidden_size, num_classes)
        self.num_classes = num_classes
        # NOTE: not used in forward(); kept so existing references to `.relu` still work.
        self.relu = nn.ReLU()

    def forward(self, pixel_values):
        """Classify from the first ([CLS]) token of the last hidden state."""
        outputs = self.base(pixel_values=pixel_values)
        logits = self.final(outputs.last_hidden_state[:, 0])

        return logits

# -----------------------------------------------------------------------------
# Define a function to create the model based on the specified architecture
# -----------------------------------------------------------------------------

def create_model(model_name, device, preload=True, num_classes=100):
    """Instantiate the architecture named ``model_name`` and move it to ``device``.

    Supported names:
      - "resnet18":          CIFAR-style ``MyResNet18`` (random init)
      - "resnet18_imagenet": stock torchvision ResNet-18 (random init)
      - "berta_distill":     DistilBERT classifier (``MyTransformer``)
      - "efficientnetv2l":   ``MyEfficientNetV2L``
      - "vit":               ``ViT``

    ``preload`` is forwarded to the architectures that accept it
    (berta_distill, efficientnetv2l, vit). ``num_classes`` sets the output width
    of every architecture.

    If ``device`` is a CUDA device and more than one GPU is visible, the model
    is wrapped in ``nn.DataParallel``.

    Raises:
        ValueError: if ``model_name`` is not supported.
    """
    if model_name == "resnet18":
        model = MyResNet18(num_classes=num_classes)
    elif model_name == "resnet18_imagenet":
        # Previously hard-coded to 100 classes regardless of num_classes.
        model = models.resnet18(weights=None, num_classes=num_classes)
    elif model_name == "berta_distill":
        model = MyTransformer(preload, num_classes=num_classes)
    elif model_name == "efficientnetv2l":
        model = MyEfficientNetV2L(preload, num_classes=num_classes)
    elif model_name == "vit":
        # Previously ignored `preload` and always loaded pretrained weights.
        model = ViT(preload, num_classes=num_classes)
    else:
        raise ValueError(f"Unsupported model: {model_name}")

    model = model.to(device)

    # DataParallel needs the module on a CUDA device; don't wrap a CPU model.
    if str(device).startswith("cuda") and torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)
    return model

# -----------------------------------------------------------------------------
# DataLoaders
# -----------------------------------------------------------------------------
def collate_fn(batch):
    """
    Collate a list of ``(encoding, label)`` pairs into one batch.

    Each ``encoding`` is a mapping with ``"input_ids"`` and ``"attention_mask"``
    tensors of identical shape across the batch. These are stacked along a new
    leading dimension. Labels become a ``torch.long`` tensor.

    Returns:
        ({"input_ids": Tensor, "attention_mask": Tensor}, labels_tensor)
    """
    stacked_ids = torch.stack([encoding["input_ids"] for encoding, _ in batch], dim=0)
    stacked_masks = torch.stack([encoding["attention_mask"] for encoding, _ in batch], dim=0)
    targets = torch.tensor([target for _, target in batch], dtype=torch.long)
    return {"input_ids": stacked_ids, "attention_mask": stacked_masks}, targets

# -----------------------------------------------------------------------------
# Training Function (Initial Training on D_train)
# -----------------------------------------------------------------------------
def train(model, train_loader, val_loader, model_name, criterion, optimizer, num_epochs,
          scheduler=None, trans=False, *, early_stopping=False, patience=10, min_delta=0.0):
    """Train ``model`` for up to ``num_epochs`` epochs and save its weights.

    Each epoch does the following:
      * trains with autocast/GradScaler mixed precision (enabled when CUDA is
        available) and clips gradients to an L2 norm of 1.0;
      * computes validation accuracy with ``evaluate``;
      * steps ``scheduler``. Schedulers whose class name contains
        "ReduceLROnPlateau" receive the validation accuracy (so they should be
        built with ``mode="max"``); all others are stepped without arguments.

    Checkpoint selection:
      * ``early_stopping=False`` (default, the historical behaviour): the weights
        after the final epoch are kept, and training always runs all epochs.
      * ``early_stopping=True``: the weights with the best validation accuracy
        (improvement greater than ``min_delta``) are restored. Training stops
        after ``patience`` consecutive epochs without improvement.

    The kept weights are saved to ``../models/{model_name}.pth``; the directory
    is created if missing.

    Args:
        trans: if True, batches are ``(dict_of_tensors, labels)`` and the model
            is called as ``model(**inputs)``. A ``.logits`` attribute on the
            output is unwrapped.

    Returns:
        (total_training_time_seconds, model). ``model`` holds the kept weights.

    Raises:
        ValueError: if ``train_loader`` yields no samples in an epoch.
    """
    import copy, math, time
    from torch.cuda.amp import autocast, GradScaler

    use_amp = torch.cuda.is_available()
    scaler = GradScaler(enabled=use_amp)
    best_val, best_epoch = -math.inf, -1
    best_state = None  # only populated when early_stopping is enabled
    epochs_without_improvement = 0

    total_start_time = time.time()

    for epoch in range(num_epochs):
        model.train()
        running_loss, total, correct = 0.0, 0, 0

        for inputs, labels in train_loader:
            if trans:
                inputs = {k: v.to(device) for k, v in inputs.items()}
            else:
                inputs = inputs.to(device)
            labels = labels.to(device)

            optimizer.zero_grad(set_to_none=True)

            with autocast(enabled=use_amp):
                if trans:
                    out = model(**inputs)
                    outputs = out.logits if hasattr(out, "logits") else out
                else:
                    outputs = model(inputs)
                loss = criterion(outputs, labels)

            scaler.scale(loss).backward()
            # Unscale first so the clipping threshold applies to the true gradients.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            scaler.step(optimizer)
            scaler.update()

            batch_size = labels.size(0)
            running_loss += loss.item() * batch_size
            total += batch_size
            correct += (outputs.argmax(dim=1) == labels).sum().item()

        if total == 0:
            raise ValueError("train_loader yielded no samples; cannot train.")
        # Normalise by samples actually seen (correct with samplers / drop_last),
        # not by len(train_loader.dataset).
        epoch_loss = running_loss / total
        epoch_acc = correct / total

        val_accuracy = evaluate(model, val_loader, device, trans=trans)
        if scheduler is not None:
            if "ReduceLROnPlateau" in scheduler.__class__.__name__:
                scheduler.step(val_accuracy)
            else:
                scheduler.step()

        if (val_accuracy - best_val) > min_delta:
            best_val, best_epoch = val_accuracy, epoch
            epochs_without_improvement = 0
            if early_stopping:
                best_state = copy.deepcopy(model.state_dict())
            print(f"New best validation accuracy: {val_accuracy:.4f} (epoch {epoch+1})")
        else:
            epochs_without_improvement += 1

        print(f"Epoch [{epoch+1}/{num_epochs}] Loss: {epoch_loss:.4f} Accuracy: {epoch_acc:.4f} | Val: {val_accuracy:.4f}")

        if early_stopping and epochs_without_improvement >= patience:
            print(f"Early stopping after {patience} epochs without improvement (best @ epoch {best_epoch+1}).")
            break

    total_training_time = time.time() - total_start_time

    if best_state is not None:
        model.load_state_dict(best_state)  # roll back to the best-validation weights
    best_path = f"../models/{model_name}.pth"
    os.makedirs(os.path.dirname(best_path), exist_ok=True)
    torch.save(model.state_dict(), best_path)

    return total_training_time, model

class ConformitySeparationCrossEntropy(nn.Module):
    """
    Cross-entropy plus a hinge that pushes the true class's conformity score
    away from the closest strictly-worse class.

    The conformity score of a class is ``-log_softmax(logits)`` (lower = more
    conforming). For each sample, let ``s_y`` be the true class's score, and
    ``s_next`` the smallest score strictly greater than ``s_y``. If no such
    class exists, ``s_next = s_y`` and the gap is 0. The per-sample loss is

        L = CE + alpha * relu(margin - (s_next - s_y))  >= 0

    ``reduction`` is one of "mean", "sum", or "none" (per-sample vector).
    """
    def __init__(self, alpha: float = 2.0, margin: float = 10.0, reduction: str = "mean"):
        super().__init__()
        self.alpha = alpha
        self.margin = margin
        if reduction not in {"mean", "sum", "none"}:
            raise ValueError("reduction must be 'mean', 'sum', or 'none'")
        self.reduction = reduction

    @staticmethod
    def _extract_logits(outputs) -> torch.Tensor:
        """Pull a logits tensor out of a tensor, HF-style output, dict, or tuple/list."""
        if isinstance(outputs, torch.Tensor):
            return outputs
        if hasattr(outputs, "logits"):
            return outputs.logits
        if isinstance(outputs, dict) and "logits" in outputs:
            return outputs["logits"]
        is_sequence = isinstance(outputs, (tuple, list))
        if is_sequence and len(outputs) > 0 and isinstance(outputs[0], torch.Tensor):
            return outputs[0]
        raise TypeError("Unsupported model output type; expected Tensor or object with .logits")

    def forward(self, outputs, targets: torch.Tensor) -> torch.Tensor:
        logits = self._extract_logits(outputs)  # (B, C)

        per_sample_ce = F.cross_entropy(logits, targets, reduction="none")

        scores = F.log_softmax(logits, dim=1).neg()                       # (B, C), lower is better
        own = scores.gather(1, targets.view(-1, 1)).squeeze(1)            # (B,)

        # Hide every class that is not strictly worse than the true one, then
        # take the closest remaining score.
        strictly_worse = scores > own.unsqueeze(1)
        nearest = scores.masked_fill(~strictly_worse, torch.inf).min(dim=1).values
        # No strictly-worse class -> fall back to the true score (zero gap).
        nearest = torch.where(torch.isfinite(nearest), nearest, own)

        hinge = F.relu(self.margin - (nearest - own))
        total = per_sample_ce + self.alpha * hinge

        reducer = {"mean": torch.mean, "sum": torch.sum}.get(self.reduction)
        return total if reducer is None else reducer(total)
    
# -----------------------------------------------------------------------------
# Evaluation Function
# -----------------------------------------------------------------------------
def evaluate(model, dataloader, device, trans=False):
    """Return the top-1 accuracy of ``model`` over ``dataloader``.

    Args:
        model: classifier; set to eval mode as a side effect.
        dataloader: iterable of ``(inputs, labels)`` batches.
        device: device the inputs and labels are moved to.
        trans: if True, ``inputs`` is a dict of tensors passed as keyword
            arguments (``model(**inputs)``).

    Model outputs may be a logits tensor or an object exposing ``.logits``,
    which is unwrapped the same way ``train`` does.

    Returns:
        Accuracy in [0, 1], or 0.0 when the loader yields no samples.
    """
    model.eval()
    total = 0
    correct = 0
    with torch.no_grad():
        for inputs, labels in dataloader:
            if trans:
                inputs = {k: v.to(device) for k, v in inputs.items()}
                outputs = model(**inputs)
            else:
                outputs = model(inputs.to(device))
            outputs = getattr(outputs, "logits", outputs)
            labels = labels.to(device)

            correct += (outputs.argmax(dim=1) == labels).sum().item()
            total += labels.size(0)
    # Avoid ZeroDivisionError on an empty loader.
    return correct / total if total > 0 else 0.0

# -----------------------------------------------------------------------------
# Nonconformity Score & Conformal Prediction
# -----------------------------------------------------------------------------
def compute_nonconformity_score(probs, true_labels, score_func_name, eps=1e-8):
    """
    Computes nonconformity scores for the true labels based on the selected method.

    Options:
      - "one_minus":         1 - p(true)
      - "pmax_minus":        max_k p(k) - p(true)
      - "neg_log":           -log(p(true) + eps)
      - "margin":            1 - [ p(true) - max_{k != true} p(k) ]
      - "entropy_weighted":  (1 + H_norm) * [1 - p(true)]  where H_norm = H/ln(K)

    Parameters
    ----------
    probs : Tensor, shape (n, K)
        Softmax probabilities over K classes.
    true_labels : LongTensor, shape (n,)
        Ground-truth labels (0 <= y < K).
    score_func_name : str
        Which score to compute.
    eps : float
        Small constant to avoid log(0).

    Returns
    -------
    Tensor, shape (n,)
        Nonconformity scores, always 1-D, even when n == 1.

    Raises
    ------
    ValueError
        If ``score_func_name`` is not one of the options above.
    """
    n, K = probs.size()
    true_labels = true_labels.reshape(-1)
    # squeeze(1), not squeeze(): a bare squeeze() turns a single-sample batch
    # into a 0-d tensor, which torch.cat cannot concatenate later.
    p_true = probs.gather(1, true_labels.view(-1, 1)).squeeze(1)  # (n,)

    if score_func_name == "one_minus":
        # Least confidence
        return 1.0 - p_true

    if score_func_name == "pmax_minus":
        # Difference between highest prob and true prob
        return probs.max(dim=1).values - p_true

    if score_func_name == "neg_log":
        # Negative log-likelihood
        return -torch.log(p_true + eps)

    if score_func_name == "margin":
        # Mask out the true class to find the runner-up. The row index is
        # created on probs' device.
        probs_minus_true = probs.clone()
        rows = torch.arange(n, device=probs.device)
        probs_minus_true[rows, true_labels] = -float('inf')
        p_second = probs_minus_true.max(dim=1).values
        return 1.0 - (p_true - p_second)

    if score_func_name == "entropy_weighted":
        # Normalised Shannon entropy H / ln(K), then reweight least confidence.
        H = -torch.sum(probs * torch.log(probs + eps), dim=1)
        H_norm = H / torch.log(torch.tensor(K, device=probs.device).float())
        return (1.0 + H_norm) * (1.0 - p_true)

    raise ValueError(f"Unknown nonconformity score function: {score_func_name}")

def conformal_prediction_quantile_and_returnall(
    model, calib_dataset, alpha, score_func_name="one_minus",
    device="cuda", batch_size=512, num_workers=4, trans=False
):
    """
    Calibrate a split-conformal threshold on ``calib_dataset``.

    Every calibration sample is scored against its true label with
    ``compute_nonconformity_score(softmax(model(x)), y, score_func_name)``.
    The threshold is the finite-sample-corrected split-conformal quantile: the
    ``ceil((n + 1) * (1 - alpha))``-th smallest of the ``n`` scores. The plain
    ``(1 - alpha)`` empirical quantile can land one rank lower, which breaks the
    ``>= 1 - alpha`` coverage guarantee. If the required rank exceeds ``n``
    (the calibration set is too small for ``alpha``), ``q_hat`` is ``+inf``, so
    every class is admitted.

    Returns:
      q_hat : float
        The conformal threshold.
      scores_arr : ndarray, shape (n_calib,)
        All calibration nonconformity scores in dataset order.

    Raises:
      ValueError: if ``alpha`` is not in (0, 1) or the calibration set is empty.
    """
    import math

    if not 0.0 < alpha < 1.0:
        raise ValueError(f"alpha must be in (0, 1), got {alpha}")

    model.eval()

    loader = DataLoader(
        calib_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True
    )

    all_scores = []

    with torch.no_grad():
        for inputs, labels in loader:
            if trans:
                inputs = {k: v.to(device) for k, v in inputs.items()}
                outputs = model(**inputs)
            else:
                outputs = model(inputs.to(device))

            probs = F.softmax(outputs, dim=1)
            batch_scores = compute_nonconformity_score(probs, labels.to(device), score_func_name)
            # reshape(-1) keeps single-sample batches 1-D so torch.cat works.
            all_scores.append(batch_scores.reshape(-1).cpu())

    scores_tensor = torch.cat(all_scores, dim=0) if all_scores else torch.empty(0)
    n = scores_tensor.numel()
    if n == 0:
        raise ValueError("Calibration dataset is empty; cannot compute q_hat.")

    # The tiny epsilon absorbs float error such as (n + 1) * 0.9 landing a hair
    # above an integer, which would otherwise bump the rank by one.
    rank = max(1, math.ceil((n + 1) * (1.0 - alpha) - 1e-9))
    if rank > n:
        q_hat = float("inf")
    else:
        q_hat = torch.kthvalue(scores_tensor, rank).values.item()

    scores_arr = scores_tensor.numpy()
    return q_hat, scores_arr


# -----------------------------------------------------------------------------
# Helper: Get Conformal Prediction Set for a Batch
# -----------------------------------------------------------------------------
def get_cp_prediction_set(model, inputs, q_hat, score_func_name="one_minus", eps=1e-8, trans=False):
    """
    For each sample in inputs, returns the set of class indices c for which the candidate
    nonconformity score is <= q_hat.

    score_func_name options:
      - "one_minus"
      - "pmax_minus"
      - "neg_log"
      - "margin"
      - "entropy_weighted"
    """
    model.eval()
    prediction_sets = []
    with torch.no_grad():
        if trans:
            # inputs = {k: v.to(device) for k, v in inputs.items()}
            outputs = model(**inputs)                # (B, K)
        else:
            outputs = model(inputs)                   # (B, K)

        probs   = F.softmax(outputs, dim=1)       # (B, K)
        B, K    = probs.shape

        # Precompute per-sample entropy for entropy_weighted
        if score_func_name == "entropy_weighted":
            log_probs = torch.log(probs + eps)
            H = -torch.sum(probs * log_probs, dim=1)                   # (B,)
            H_norm = H / torch.log(torch.tensor(K, device=probs.device).float())

        for i in range(B):
            p = probs[i]                     # (K,)

            if score_func_name == "one_minus":
                candidate_scores = 1.0 - p

            elif score_func_name == "pmax_minus":
                p_max = p.max()
                candidate_scores = p_max - p

            elif score_func_name == "neg_log":
                candidate_scores = -torch.log(p + eps)

            elif score_func_name == "margin":
                # for each class c: 1 - (p[c] - max_{k!=c} p[k])
                candidate_scores = torch.empty_like(p)
                for c in range(K):
                    p_c = p[c]
                    tmp = p.clone()
                    tmp[c] = -float('inf')
                    p_second = tmp.max()
                    candidate_scores[c] = 1.0 - (p_c - p_second)

            elif score_func_name == "entropy_weighted":
                # (1 + H_norm[i]) * [1 - p[c]]
                candidate_scores = (1.0 + H_norm[i]) * (1.0 - p)

            else:
                raise ValueError("Unknown nonconformity score function: " + score_func_name)

            # select all classes whose score ≤ q_hat
            mask = candidate_scores <= q_hat
            selected = mask.nonzero(as_tuple=True)[0].cpu().tolist()
            prediction_sets.append(selected)

    return prediction_sets


# -----------------------------------------------------------------------------
# Metric: CCUCR
# -----------------------------------------------------------------------------
def compute_ccucr(model, forget_loader, retained_loader, q_hat, score_func_name, c=100, trans=False):
    """
    Conformal forgetting / retention rates restricted to small prediction sets.

    For each loader, prediction sets come from ``get_cp_prediction_set`` with
    ``q_hat`` and ``score_func_name``. Only samples whose set has at most ``c``
    classes are counted:

      * forget loader ('efc'):         fraction of counted samples whose true
                                       label is NOT in the prediction set;
      * retained loader ('coverage'):  fraction of counted samples whose true
                                       label IS in the prediction set.

    Each rate is 0.0 when no sample qualifies. As a side effect, the average
    prediction-set size over ALL samples of each loader is printed.

    Returns:
        (r_d, f_d): retained coverage rate, forget exclusion rate.
    """
    def compute_ccucr_D(loader, method):
        if method not in ("efc", "coverage"):
            raise ValueError(f"Unknown CCUCR method: {method}")

        counted = 0        # samples whose set size is <= c
        hits = 0           # counted samples satisfying the method's condition
        set_size_sum = 0   # over all samples, for the printed average
        set_count = 0

        model.eval()
        with torch.no_grad():
            for inputs, labels in loader:
                if trans:
                    inputs = {k: v.to(device) for k, v in inputs.items()}
                else:
                    inputs = inputs.to(device)

                pred_sets = get_cp_prediction_set(model, inputs, q_hat, score_func_name, trans=trans)
                # One host transfer per batch instead of an .item() sync per sample.
                for true_label, cp_set in zip(labels.tolist(), pred_sets):
                    set_size_sum += len(cp_set)
                    set_count += 1

                    if len(cp_set) > c:
                        continue  # sets larger than c are excluded from the rate
                    counted += 1
                    covered = true_label in cp_set
                    # 'coverage' counts covered labels; 'efc' counts excluded ones.
                    if covered == (method == "coverage"):
                        hits += 1

        if set_count > 0:
            avg_set_size = set_size_sum / set_count
            print(f"Average prediction set size {method}: {avg_set_size:.4f}")
        else:
            print(f"Average prediction set size {method}: n/a (no samples).")

        return (hits / counted) if counted > 0 else 0.0

    f_d = compute_ccucr_D(forget_loader, 'efc')
    r_d = compute_ccucr_D(retained_loader, 'coverage')
    return r_d, f_d

# -----------------------------------------------------------------------------
# Helper: CR metric
# -----------------------------------------------------------------------------
def compute_cr(model, forget_loader, retained_loader, q_hat, score_func_name, trans=False):
    """
    Coverage-per-admitted-class ratio of conformal prediction sets.

    For a loader D, with prediction sets C(x) from ``get_cp_prediction_set``:
        CR_D = #{(x, y) in D : y in C(x)} / sum_{(x, y) in D} |C(x)|
    The ratio is 0.0 when every prediction set is empty.

    Returns:
        (r_cr, f_cr): retained-set ratio, forget-set ratio.
    """
    def _as_set(cp_set):
        # Normalise a tensor, sequence, or bare label into a set of ints.
        if isinstance(cp_set, torch.Tensor):
            return {int(v) for v in cp_set.detach().view(-1).tolist()}
        if isinstance(cp_set, (list, tuple, set)):
            return {int(v) for v in cp_set}
        return {int(cp_set)}

    def _ratio(loader):
        hits = 0       # samples whose true label is in their set
        admitted = 0   # total number of classes across all sets
        model.eval()
        with torch.no_grad():
            for inputs, labels in loader:
                inputs = (
                    {k: v.to(device) for k, v in inputs.items()}
                    if trans
                    else inputs.to(device)
                )
                labels = labels.to(device)

                sets = get_cp_prediction_set(model, inputs, q_hat, score_func_name, trans=trans)
                for idx, raw_set in enumerate(sets):
                    members = _as_set(raw_set)
                    admitted += len(members)
                    hits += int(int(labels[idx].item()) in members)
        if admitted == 0:
            return 0.0
        return hits / admitted

    f_cr = _ratio(forget_loader)
    r_cr = _ratio(retained_loader)
    return r_cr, f_cr

# -----------------------------------------------------------------------------
# Helper: Harmonic Mean
# -----------------------------------------------------------------------------

def harmonic_mean(numbers):
    """
    Compute the harmonic mean of an iterable of positive numbers.

    Parameters:
        numbers (iterable of float): The input values. Any iterable is accepted,
            including generators.

    Returns:
        float: The harmonic mean ``n / sum(1 / x)``. If any value is zero, a
        warning is printed and ``0`` is returned instead (the limit of the
        harmonic mean as one element approaches 0 from above).

    Raises:
        ValueError: If ``numbers`` is empty.
    """
    values = list(numbers)  # materialise so generators work and len() is valid
    if not values:
        raise ValueError("At least one number is required to compute the harmonic mean.")
    # Check for zeros before doing any division work.
    if any(x == 0 for x in values):
        print("Warning: One or more numbers are zero, which will lead to division by zero in harmonic mean calculation.")
        return 0
    return len(values) / sum(1.0 / x for x in values)

from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import StratifiedShuffleSplit, cross_val_score
import torch.nn.functional as F
import numpy as np
import torch

def compute_membership_inference_attack(
    model: torch.nn.Module,
    member_loader: torch.utils.data.DataLoader,
    nonmember_loader: torch.utils.data.DataLoader,
    device: torch.device,
    n_splits: int = 10,
    random_state: int = 0,
    trans: bool = False,
) -> "tuple[np.ndarray, float]":
    """
    Membership-inference attack: a random forest that separates members from non-members.

    Each sample's model output is summarised as ``4 + k`` features, where
    ``K`` is the number of classes and ``k = int(0.08 * K) + 1``. For example,
    K = 100 gives k = 9 and 13 features. The features are:
      1) per-sample cross-entropy loss
      2) entropy of the softmax distribution
      3) margin between the two largest softmax probabilities
         (equal to the top probability when K == 1)
      4 .. 3+k) the k largest softmax probabilities, in descending order
      4+k) L2 norm of the logits

    Members are labelled 1 and non-members 0. A ``RandomForestClassifier``
    (200 trees) is scored by accuracy under ``StratifiedShuffleSplit``
    cross-validation with ``n_splits`` splits.

    Returns:
        scores: per-split attack accuracies.
        advantage: mean accuracy minus the majority-class baseline.
    """
    model.eval()

    def _batch_features(logits, targets):
        """Feature matrix (batch, 4 + k) for one batch; see the docstring for the layout."""
        probs = F.softmax(logits, dim=1)
        num_classes = probs.shape[1]
        k_top = int(0.08 * num_classes) + 1  # always <= num_classes

        # One descending top-k covers both the margin (needs 2) and the top-k block.
        ranked = probs.topk(min(num_classes, max(k_top, 2)), dim=1).values
        if num_classes >= 2:
            margin = ranked[:, 0] - ranked[:, 1]
        else:
            margin = ranked[:, 0]  # no runner-up: treat its probability as 0

        loss = F.cross_entropy(logits, targets, reduction='none')
        entropy = -(probs * torch.log(probs + 1e-12)).sum(dim=1)
        logit_norm = logits.norm(dim=1)

        return torch.cat([
            loss.unsqueeze(1),
            entropy.unsqueeze(1),
            margin.unsqueeze(1),
            ranked[:, :k_top],
            logit_norm.unsqueeze(1),
        ], dim=1)

    def gather_features(loader):
        feats = []
        with torch.no_grad():
            for x, y in loader:
                if trans:
                    x = {k: v.to(device) for k, v in x.items()}
                    logits = model(**x)
                else:
                    logits = model(x.to(device))
                logits = getattr(logits, "logits", logits)  # unwrap HF-style outputs
                feats.append(_batch_features(logits, y.to(device)).cpu().numpy())
        return np.vstack(feats)

    # Build feature matrix X and membership labels y (1 = member).
    X_mem = gather_features(member_loader)
    X_non = gather_features(nonmember_loader)
    X = np.vstack([X_mem, X_non])
    y = np.concatenate([
        np.ones(len(X_mem), dtype=int),
        np.zeros(len(X_non), dtype=int),
    ])

    attack = RandomForestClassifier(
        n_estimators=200,
        max_depth=None,
        random_state=random_state,
        n_jobs=-1
    )

    cv = StratifiedShuffleSplit(
        n_splits=n_splits,
        random_state=random_state
    )
    scores = cross_val_score(
        attack, X, y,
        cv=cv,
        scoring="accuracy",
        n_jobs=-1
    )

    # Majority-class baseline: the accuracy of always guessing the larger group.
    total = X_mem.shape[0] + X_non.shape[0]
    baseline = max(len(X_mem) / total, len(X_non) / total)
    return scores, float(np.mean(scores) - baseline)



def set_seed(seed):
    """
    Seed Python's ``random``, NumPy, and PyTorch (CPU and, if present, all CUDA
    devices), and force deterministic cuDNN (benchmark mode off).

    Returns:
        A new ``torch.Generator`` seeded with ``seed``, for use with
        samplers/splits that accept an explicit generator.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    generator = torch.Generator()
    generator.manual_seed(seed)
    return generator

# -----------------------------------------------------------------------------
# Helper: Build Base models
# -----------------------------------------------------------------------------
def build_base_model(model_type, model_filename, train_loader, val_loader, device, num_classes, trans=False):
    """
    Create a model of the requested architecture and train it.

    Every architecture is trained with cross-entropy loss and a LinearLR schedule.
    The schedule scales the learning rate from 1.0x down to 0.001x over the
    training epochs. Architectures differ in optimizer, base learning rate, epoch
    count, and whether ``create_model`` receives ``preload=True``.

    Args:
        model_type: one of "resnet18", "resnet18_imagenet", "efficientnetv2l",
            "vit", "berta_distill".
        model_filename: forwarded to ``train`` (presumably the checkpoint path;
            confirm in ``train``).
        train_loader: forwarded to ``train``.
        val_loader: forwarded to ``train``.
        device: forwarded to ``create_model``.
        num_classes: forwarded to ``create_model``.
        trans: forwarded to ``train``.

    Returns:
        (model, training_time): the model returned by ``train`` and the elapsed
        time it reported.

    Raises:
        ValueError: if ``model_type`` is not supported.
    """
    # model_type -> (pass preload=True, optimizer kind, base lr, epochs, short "s" unit in log)
    recipes = {
        "resnet18":          (True,  "sgd",   0.1,  80, True),
        "resnet18_imagenet": (False, "sgd",   0.1,  80, False),
        "efficientnetv2l":   (True,  "adamw", 0.01, 80, True),
        "vit":               (True,  "sgd",   0.01, 15, False),
        "berta_distill":     (True,  "sgd",   0.01, 15, False),
    }
    if model_type not in recipes:
        raise ValueError(f"Unsupported model_type: {model_type}")
    use_preload, optimizer_kind, lr, epochs, short_unit = recipes[model_type]

    if use_preload:
        model = create_model(model_type, device, preload=True, num_classes=num_classes)
    else:
        # resnet18_imagenet is created without the preload flag (create_model's default applies).
        model = create_model(model_type, device, num_classes=num_classes)

    if optimizer_kind == "adamw":
        optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=5e-4)
    else:
        optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=5e-4)
    scheduler = optim.lr_scheduler.LinearLR(
        optimizer, start_factor=1.0, end_factor=0.001, total_iters=epochs
    )
    criterion = nn.CrossEntropyLoss()

    print("Starting initial training on D_train (all classes)...")
    training_time, model = train(
        model, train_loader, val_loader, model_filename,
        criterion, optimizer, epochs, scheduler, trans=trans,
    )
    if short_unit:
        print(f"Initial training completed in {training_time:.2f}s")
    else:
        print(f"Initial training completed in {training_time:.2f} seconds")

    return model, training_time

# -----------------------------------------------------------------------------
# Helper: Choose Dataset and Completely Load It
# -----------------------------------------------------------------------------
# This class and the helper functions below serve the text datasets ('news' and
# '20_newsgroups'); they are defined at module level so load_dataset_and_transform can use them.
class HuggingFaceSubset(Dataset):
    """
    Index-remapping view over a tokenized Hugging Face dataset.

    Item ``idx`` is row ``indices[idx]`` of ``hf_dataset``, returned as an
    ``(inputs, label)`` pair. ``inputs`` is a dict holding the row's
    ``input_ids`` and ``attention_mask``. ``label`` is the row's ``labels``
    value converted to a Python scalar via ``.item()``.
    """

    def __init__(self, hf_dataset, indices):
        self.hf_dataset = hf_dataset  # the underlying (torch-formatted) dataset
        self.indices = indices        # positions into hf_dataset exposed by this view

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        record = self.hf_dataset[self.indices[idx]]
        model_inputs = {key: record[key] for key in ("input_ids", "attention_mask")}
        return model_inputs, record["labels"].item()

# DistilBERT tokenizer (slow/Python implementation, use_fast=False) used by tokenize_fn for the
# text datasets. It is created at import time, so importing this module loads
# "distilbert-base-uncased" — presumably fetched from the Hugging Face hub when not cached.
# NOTE(review): consider lazy creation if import-time cost or network access matters.
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased", use_fast=False)

def tokenize_fn(examples):
    """
    Tokenize a batch's ``combined_text`` column with the module-level ``tokenizer``.

    Every sequence is truncated and padded to exactly 128 tokens. The tokenizer's
    output mapping (``input_ids``, ``attention_mask``, ...) is returned unchanged,
    ready for ``datasets.Dataset.map(..., batched=True)``.
    """
    texts = examples["combined_text"]
    encoded = tokenizer(texts, max_length=128, padding="max_length", truncation=True)
    return encoded

def combine_text(ex):
    """
    Build the ``combined_text`` field from an example's ``text`` column.

    A missing or empty ``text`` becomes ``""``. The result has its leading and
    trailing whitespace stripped. No separate body column is used, so the text
    is joined with an empty second part, and the separator is stripped again.
    """
    title = ex["text"] or ""
    return {"combined_text": " ".join((title, "")).strip()}



def _load_hf_text_split(hf_name, n_work):
    """
    Load a Hugging Face text-classification dataset, tokenize it, and wrap its
    "train"/"test" splits as HuggingFaceSubset datasets covering every row.

    Returns:
        (D_train, test_dataset)
    """
    raw_dataset = load_dataset(hf_name)
    processed_dataset = raw_dataset.map(combine_text, num_proc=n_work)
    train_data = processed_dataset["train"]
    test_data = processed_dataset["test"]

    # Apply tokenizer and format for PyTorch
    print("Tokenizing dataset...")
    train_data = train_data.map(tokenize_fn, batched=True, num_proc=n_work)
    test_data = test_data.map(tokenize_fn, batched=True, num_proc=n_work)

    train_data = train_data.rename_column("label", "labels")
    test_data = test_data.rename_column("label", "labels")

    train_data.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])
    test_data.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])
    print("Tokenization complete.")

    D_train = HuggingFaceSubset(train_data, list(range(len(train_data))))
    test_dataset = HuggingFaceSubset(test_data, list(range(len(test_data))))
    return D_train, test_dataset


def _load_base_datasets(dataset_name, n_work):
    """
    Load the full train and test sets for ``dataset_name``.

    Returns:
        (D_train, test_dataset, num_classes)

    Raises:
        ValueError: if ``dataset_name`` is not supported.
    """
    if dataset_name == "cifar100":
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5071, 0.4867, 0.4408),
                                 (0.2675, 0.2565, 0.2761)),
        ])
        D_train = torchvision.datasets.CIFAR100(
            root='../../../data/cifar100', train=True, download=True, transform=transform
        )
        test_dataset = torchvision.datasets.CIFAR100(
            root='../../../data/cifar100', train=False, download=True, transform=transform
        )
        return D_train, test_dataset, 100

    if dataset_name == "imagenet":
        # ImageNet-100 stored as ImageFolder directories (train.X / val.X).
        transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406],
                                 [0.229, 0.224, 0.225]),
        ])
        D_train = torchvision.datasets.ImageFolder(
            root="../../../data/imagenet-100/train.X", transform=transform
        )
        test_dataset = torchvision.datasets.ImageFolder(
            root="../../../data/imagenet-100/val.X", transform=transform
        )
        return D_train, test_dataset, 100

    if dataset_name == "news":
        D_train, test_dataset = _load_hf_text_split("ag_news", n_work)
        return D_train, test_dataset, 4  # AG News has 4 classes

    if dataset_name == "20_newsgroups":
        D_train, test_dataset = _load_hf_text_split("SetFit/20_newsgroups", n_work)
        return D_train, test_dataset, 20  # 20 Newsgroups has 20 classes

    raise ValueError(f"Unsupported dataset: {dataset_name}")


def _per_sample_labels(dataset):
    """
    Return the class label of every sample in ``dataset``, in index order.

    Uses ``dataset.targets`` when present (torchvision CIFAR100/ImageFolder).
    Otherwise it indexes each sample once and takes element [1] of the
    (input, label) pair.
    """
    targets = getattr(dataset, "targets", None)
    if targets is not None:
        return targets
    return [dataset[i][1] for i in range(len(dataset))]


def _pick_forgotten_groups(num_groups, forgot_set, rng):
    """Permute the group ids ``0..num_groups-1`` with ``rng`` and return the first ``forgot_set`` as a set."""
    perm = torch.randperm(num_groups, generator=rng).tolist()
    return set(perm[:forgot_set])


def _partition_by_group(indices, groups, forgotten):
    """Split ``indices`` into (forgotten, retained) lists by ``groups[i] in forgotten``, preserving order."""
    forget, retain = [], []
    for i in indices:
        if groups[i] in forgotten:
            forget.append(i)
        else:
            retain.append(i)
    return forget, retain


def _build_forget_retain_splits(D_train, test_dataset, train_groups, test_groups, forgotten, rng):
    """
    Split the train/test sets into calibration, forget and retain subsets.

    Test set: a random 80% becomes D_calib (further split into D_calib_forget and
    D_calib_retain). The remaining 20% is split into D_f (forgotten groups) and
    D_r (retained groups).

    Train set: a random 90% becomes D_train_main. Of the held-out 10%, the first
    half is D_calib_val and the rest is split into D_f_val and D_r_val.
    D_train_main is split into S_subset (forgotten) and R_subset (retained).
    R_small is a random draw from R the size of S.

    Returns:
        dict mapping split names to ``Subset`` objects.
    """
    # --- Test split: 80% calibration, remaining 20% divided by group ---
    perm_test = torch.randperm(len(test_dataset), generator=rng).tolist()
    num_calib_test = int(0.8 * len(perm_test))
    calib_test_idx = perm_test[:num_calib_test]
    f_test_idx, r_test_idx = _partition_by_group(perm_test[num_calib_test:], test_groups, forgotten)
    calib_forget_idx, calib_retain_idx = _partition_by_group(calib_test_idx, test_groups, forgotten)

    # --- Train split: 90% train_main, 10% held out ---
    perm_train = torch.randperm(len(D_train), generator=rng).tolist()
    num_train_main = int(0.9 * len(perm_train))
    train_main_idx = perm_train[:num_train_main]
    train_val_idx = perm_train[num_train_main:]

    # Held-out 10%: first 50% -> calib_val, remainder -> forget_val / retain_val by group.
    num_calib_val = int(0.5 * len(train_val_idx))
    calib_val_idx = train_val_idx[:num_calib_val]
    f_val_idx, r_val_idx = _partition_by_group(train_val_idx[num_calib_val:], train_groups, forgotten)

    D_train_main = Subset(D_train, train_main_idx)

    # S/R indices are positions *within* D_train_main, so look up each position's original group.
    main_groups = [train_groups[i] for i in train_main_idx]
    S_indices, R_indices = _partition_by_group(range(len(main_groups)), main_groups, forgotten)
    S_subset = Subset(D_train_main, S_indices)
    R_subset = Subset(D_train_main, R_indices)

    # R_small: random retained sample as large as S (shuffled with Python's global `random`).
    shuffled_r = R_indices[:]
    random.shuffle(shuffled_r)
    R_small = Subset(D_train_main, shuffled_r[:len(S_indices)])

    return {
        "D_calib": Subset(test_dataset, calib_test_idx),
        "D_f": Subset(test_dataset, f_test_idx),
        "D_r": Subset(test_dataset, r_test_idx),
        "D_calib_forget": Subset(test_dataset, calib_forget_idx),
        "D_calib_retain": Subset(test_dataset, calib_retain_idx),
        # S and R partition D_train_main, so this reorders it (S first) without changing its samples.
        "D_train_main": Subset(D_train_main, S_indices + R_indices),
        "D_calib_val": Subset(D_train, calib_val_idx),
        "D_f_val": Subset(D_train, f_val_idx),
        "D_r_val": Subset(D_train, r_val_idx),
        "S_subset": S_subset,
        "R_subset": R_subset,
        "R_small": R_small,
    }


def load_dataset_and_transform(dataset_name, forgot_set, rng, mode, batch_size=256, n_work=2):
    """
    Load a dataset, partition it into forget/retain groups, and build DataLoaders.

    Samples are grouped according to ``mode``:
      * "label": by class label;
      * "embedding_cluster": by K-Means cluster of pretrained ResNet18 embeddings
        (``create_feature_clusters`` with ``num_classes`` clusters);
      * "pca": by dominant principal component of the raw inputs
        (``pca_projector`` with ``num_classes`` components).
    ``forgot_set`` of the ``num_classes`` group ids are drawn with ``rng`` and
    form the forgotten set.

    Args:
        dataset_name: "cifar100", "imagenet", "news" or "20_newsgroups".
        forgot_set: number of groups to forget.
        rng: torch.Generator driving group selection, the split permutations and
            DataLoader shuffling. R_small is drawn with Python's global ``random``.
        mode: "label", "embedding_cluster" or "pca".
        batch_size: batch size for the returned DataLoaders and feature extraction.
        n_work: worker count for the DataLoaders and the HF ``map`` calls.

    Returns:
        (loaders_and_datasets, forgotten). The first element is a dict of
        DataLoaders and datasets. The second is the set of forgotten group ids:
        class labels, cluster ids or PC indices, depending on ``mode``.

    Raises:
        ValueError: for an unsupported ``mode`` or ``dataset_name``.
    """
    # Validate the mode before any (potentially expensive) dataset loading.
    if mode not in ("embedding_cluster", "pca", "label"):
        raise ValueError(f"Unsupported mode: {mode}")

    D_train, test_dataset, num_classes = _load_base_datasets(dataset_name, n_work)

    if mode == "embedding_cluster":
        # Pass the module-level device so CPU-only machines do not hit the "cuda" default.
        train_groups, test_groups = create_feature_clusters(
            D_train, test_dataset, n_clusters=num_classes, batch_size=batch_size, device=device
        )
        forgotten = _pick_forgotten_groups(num_classes, forgot_set, rng)
        print(f"Forgotten feature clusters selected: {forgotten}")
    elif mode == "pca":
        # Use as many PCs as groups drawn below, so every selectable id can occur.
        train_groups, test_groups = pca_projector(
            D_train, test_dataset, batch_size, n_components=num_classes
        )
        print("PCA projection applied.")
        forgotten = _pick_forgotten_groups(num_classes, forgot_set, rng)
        print(f"Forgotten PC clusters selected: {forgotten}")
    else:  # mode == "label"
        # Labels are gathered once up front (text datasets have no `.targets`).
        test_groups = _per_sample_labels(test_dataset)
        train_groups = _per_sample_labels(D_train)
        if dataset_name not in ("20_newsgroups", "news"):
            print(len(test_groups))
            print(len(train_groups))
        forgotten = _pick_forgotten_groups(num_classes, forgot_set, rng)
        print(f"Forgotten classes selected: {forgotten}")

    splits = _build_forget_retain_splits(
        D_train, test_dataset, train_groups, test_groups, forgotten, rng
    )

    def make_loader(ds):
        # Every loader shuffles with the shared `rng` for reproducible ordering.
        return DataLoader(ds, batch_size=batch_size, shuffle=True, num_workers=n_work, generator=rng)

    train_loader = make_loader(splits["D_train_main"])
    calibration_loader = make_loader(splits["D_calib"])
    forget_loader = make_loader(splits["D_f"])
    retained_loader = make_loader(splits["D_r"])
    calibration_val_loader = make_loader(splits["D_calib_val"])
    forget_val_loader = make_loader(splits["D_f_val"])
    retained_val_loader = make_loader(splits["D_r_val"])
    S_loader = make_loader(splits["S_subset"])
    R_loader = make_loader(splits["R_subset"])
    R_small_loader = make_loader(splits["R_small"])
    calibration_forget_loader = make_loader(splits["D_calib_forget"])
    calibration_retain_loader = make_loader(splits["D_calib_retain"])
    print("Data loaders created successfully.")

    return {
        "train_loader": train_loader,
        "calibration_loader": calibration_loader,
        "forget_loader": forget_loader,
        "retained_loader": retained_loader,
        "calibration_val_loader": calibration_val_loader,
        "forget_val_loader": forget_val_loader,
        "retained_val_loader": retained_val_loader,
        "S_loader": S_loader,
        "R_loader": R_loader,
        "calibration_forget_loader": calibration_forget_loader,
        "calibration_retain_loader": calibration_retain_loader,
        "D_calib_val": splits["D_calib_val"],
        "D_f_val": splits["D_f_val"],
        "D_r_val": splits["D_r_val"],
        "D_calib": splits["D_calib"],
        "D_f": splits["D_f"],
        "D_r": splits["D_r"],
        "test_dataset": test_dataset,
        "R_small_loader": R_small_loader,
    }, forgotten


def create_feature_clusters(D_train, D_test, n_clusters=100, batch_size=64, device=None, trans=False,
                            num_workers=4):
    """
    Assign every train/test sample to a K-Means cluster of pretrained ResNet18 embeddings.

    The feature extractor is an ImageNet-pretrained torchvision ResNet18 whose
    final ``fc`` layer is replaced by ``Identity``. MiniBatchKMeans (fixed
    ``random_state=42``) is fitted on the training embeddings only. Both splits
    are then labelled with the predicted cluster.

    Args:
        D_train (Dataset): training dataset yielding ``(input, label)`` pairs.
        D_test (Dataset): test dataset yielding ``(input, label)`` pairs.
        n_clusters (int): number of clusters (e.g. 100 to mimic classes).
        batch_size (int): batch size used for embedding extraction.
        device (str | torch.device | None): device for the feature extractor.
            ``None`` (default) picks "cuda" when available, otherwise "cpu".
        trans (bool): if True, each batch input is treated as a dict of tensors
            and passed to the model as keyword arguments.
        num_workers (int): DataLoader worker processes for extraction.

    Returns:
        tuple[np.ndarray, np.ndarray]: cluster ids for the training set and
        for the test set, in dataset order.

    Raises:
        ValueError: if either dataset yields no batches.
    """
    if device is None:
        # Fall back to CPU so the function also runs on machines without CUDA.
        device = "cuda" if torch.cuda.is_available() else "cpu"

    # 1. Pretrained ResNet18 as a frozen feature extractor (classification head removed).
    model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
    model.fc = torch.nn.Identity()
    model.to(device)
    model.eval()

    def get_embeddings(dataset):
        """Run ``model`` over ``dataset`` in order and stack the embeddings row-wise."""
        embeddings = []
        data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
        with torch.no_grad():
            for images, _ in tqdm(data_loader, desc="Extracting embeddings"):
                if trans:
                    # NOTE(review): torchvision's ResNet forward takes a single tensor, so this
                    # keyword-argument path looks unusable with this model — confirm before use.
                    images = {k: v.to(device) for k, v in images.items()}
                    emb = model(**images)
                else:
                    images = images.to(device)
                    emb = model(images)
                embeddings.append(emb.cpu().numpy())
        if not embeddings:
            raise ValueError("Cannot extract embeddings from an empty dataset.")
        return np.concatenate(embeddings, axis=0)

    # 2. Extract embeddings for both training and test sets
    train_embeddings = get_embeddings(D_train)
    test_embeddings = get_embeddings(D_test)

    # 3. Cluster the training embeddings (MiniBatchKMeans is cheaper than KMeans on large sets)
    print(f"Clustering {len(train_embeddings)} training embeddings into {n_clusters} clusters...")
    kmeans = MiniBatchKMeans(n_clusters=n_clusters, n_init='auto', random_state=42, batch_size=256)
    kmeans.fit(train_embeddings)

    # 4. Assign cluster IDs to train and test sets using the train-fitted centroids
    train_cluster_ids = kmeans.predict(train_embeddings)
    test_cluster_ids = kmeans.predict(test_embeddings)

    print("Clustering complete.")
    return train_cluster_ids, test_cluster_ids

def pca_projector(D_train, test_dataset, batch_size, n_components=100):
    """
    Group samples by the dominant principal component of their raw inputs.

    A PCA with ``n_components`` components is fitted on the flattened training
    inputs, and the flattened test inputs are projected into the same space.
    Each sample is assigned the index of the component whose coefficient has
    the largest absolute value.

    Args:
        D_train (Dataset): training dataset yielding ``(tensor, label)`` pairs.
        test_dataset (Dataset): test dataset yielding ``(tensor, label)`` pairs.
        batch_size (int): batch size used while reading the inputs.
        n_components (int): number of principal components (and thus groups).

    Returns:
        (train_pc_cluster, test_pc_cluster): integer arrays with values in
        ``[0, n_components)``, one entry per sample in dataset order.
    """
    def flatten_inputs(dataset):
        # One row per sample: each input batch is viewed as (B, -1) and stacked.
        loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=2)
        rows = [batch.view(batch.size(0), -1).numpy() for batch, _ in loader]
        return np.concatenate(rows, axis=0)

    pca = PCA(n_components=n_components)
    # Fit on the training inputs only; the test set reuses the same projection.
    train_projection = pca.fit_transform(flatten_inputs(D_train))
    test_projection = pca.transform(flatten_inputs(test_dataset))

    train_pc_cluster = np.abs(train_projection).argmax(axis=1)
    test_pc_cluster = np.abs(test_projection).argmax(axis=1)
    return train_pc_cluster, test_pc_cluster




def get_input_dims(dataset: Dataset) -> torch.Size:
    """
    Return the shape of the input tensor of the first sample in ``dataset``.

    The sample may take any of these forms:
      * a bare tensor;
      * a bare dict;
      * an ``(input, label)`` tuple/list whose first element is a tensor or a
        dict of tensors.

    For dict inputs, the first key among ``pixel_values``, ``input_ids``,
    ``image``, ``features`` whose value is a tensor wins. Otherwise the first
    tensor value in iteration order is used.

    Args:
        dataset (torch.utils.data.Dataset): dataset (or Subset) to inspect.

    Returns:
        torch.Size: shape of the located input tensor of ``dataset[0]``.

    Raises:
        ValueError: if the dataset is empty.
        TypeError: if no tensor can be located in the first sample's input.
    """
    if len(dataset) == 0:
        raise ValueError("Cannot determine dimensions from an empty dataset.")

    sample = dataset[0]
    # Log only the sample's type: printing the sample itself dumps its full tensor contents.
    print(f"Sample type: {type(sample).__name__}")

    # Isolate the input payload: element 0 of an (input, label) pair, else the sample itself.
    # An empty tuple/list is kept as-is so it reaches the TypeError below instead of an IndexError.
    if isinstance(sample, (tuple, list)) and len(sample) > 0:
        data_payload = sample[0]
    else:
        data_payload = sample

    data_tensor = None
    if isinstance(data_payload, torch.Tensor):
        data_tensor = data_payload
    elif isinstance(data_payload, dict):
        # Preferred keys first (in priority order), then every value as a fallback.
        priority_keys = ("pixel_values", "input_ids", "image", "features")
        candidates = [data_payload[key] for key in priority_keys if key in data_payload]
        candidates.extend(data_payload.values())
        data_tensor = next((value for value in candidates if isinstance(value, torch.Tensor)), None)

    if data_tensor is None:
        raise TypeError(f"Could not find a torch.Tensor in the resolved data payload. Payload type: {type(data_payload)}")

    print(f"Data tensor shape: {data_tensor.shape}, dtype: {data_tensor.dtype}")
    return data_tensor.shape