import os
#os.environ["CUDA_VISIBLE_DEVICES"] = "1,2,3"
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset, Dataset, ConcatDataset

import torchvision
import torchvision.models as models
from torchvision import transforms

import contextlib

import numpy as np

import time
import random

from sklearn.decomposition import PCA
from sklearn.cluster import MiniBatchKMeans

from tqdm import tqdm # Recommended for progress bar

from datasets import load_dataset
from transformers import ViTModel, AutoModelForSequenceClassification, AutoTokenizer

# -----------------------------------------------------------------------------
# Device Setup
# -----------------------------------------------------------------------------
# Run on the GPU when CUDA is usable from this process, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
# Number of CUDA devices visible to this process (0 on CPU-only hosts).
num_gpus = torch.cuda.device_count()

# -----------------------------------------------------------------------------
# Define ResNet18 for CIFAR-100 (MyResNet18)
# -----------------------------------------------------------------------------
class MyResNet18(nn.Module):
    """ResNet-18 classifier with a stem adapted to 32x32 (CIFAR-style) images.

    The torchvision ResNet-18 is built without pretrained weights, its first
    convolution is replaced by a 3x3 / stride-1 convolution and its max-pool
    by an identity. Every child of that backbone except the last one is
    chained into ``features``; a fresh ``Linear(512, num_classes)`` head
    produces the logits.

    Note: the backbone stays registered as ``base_model`` as well, so its
    layers are reachable under both the ``base_model.`` and ``features.``
    prefixes, and the backbone's own last child is not used by ``forward``.
    """

    def __init__(self, num_classes=100):
        super().__init__()
        self.base_model = torchvision.models.resnet18(weights=None)

        # Small inputs: keep spatial resolution in the stem (3x3 conv, stride 1)
        # and skip the initial down-sampling max-pool.
        stem_conv = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.base_model.conv1 = stem_conv
        self.base_model.maxpool = nn.Identity()

        # Everything but the backbone's final child becomes the feature extractor.
        backbone_layers = list(self.base_model.children())
        backbone_layers.pop()
        self.features = nn.Sequential(*backbone_layers)

        self.flatten = nn.Flatten()
        self.fc = nn.Linear(512, num_classes)

    def forward(self, x):
        """Return class logits of shape (batch, num_classes) for images ``x``."""
        pooled = self.features(x)        # (batch, 512, 1, 1)
        embedding = self.flatten(pooled)  # (batch, 512)
        return self.fc(embedding)


# -----------------------------------------------------------------------------
# Define a DistilBERT Model for Text Classification (MyTransformer)
# -----------------------------------------------------------------------------
class MyTransformer(nn.Module):
    """DistilBERT sequence classifier that returns raw logits.

    Args:
        preload: when true, load the pretrained "distilbert-base-uncased"
            weights. When false, build the same architecture from its
            configuration with randomly initialised weights (the configuration
            is still resolved through ``from_pretrained``, i.e. from the local
            cache or the hub).
        num_classes: number of output labels of the classification head.
    """

    def __init__(self, preload=True, num_classes=4):
        super().__init__()
        # "distilbert-base-uncased" is a smaller, faster model than DeBERTa v3 large
        checkpoint = "distilbert-base-uncased"
        if preload:
            self.base_model = AutoModelForSequenceClassification.from_pretrained(
                checkpoint,
                num_labels=num_classes
            )
        else:
            # from_pretrained(None, ...) cannot resolve a model; an untrained
            # network has to be instantiated from a configuration instead.
            from transformers import AutoConfig

            config = AutoConfig.from_pretrained(checkpoint, num_labels=num_classes)
            self.base_model = AutoModelForSequenceClassification.from_config(config)

    def forward(self, input_ids, attention_mask=None):
        """Run the classifier and return only the logits.

        Hugging Face models typically accept:
          - input_ids
          - attention_mask
          - token_type_ids (optional, depending on model)
        Only input_ids/attention_mask are forwarded here for simplicity.
        """
        outputs = self.base_model(input_ids=input_ids, attention_mask=attention_mask)
        # The logits are in outputs.logits
        return outputs.logits
    
# -----------------------------------------------------------------------------
# Define EfficientNetV2-L for ImageNet-100 (MyEfficientNetV2L)
# -----------------------------------------------------------------------------
class MyEfficientNetV2L(nn.Module):
    """torchvision EfficientNetV2-L with its final linear layer resized.

    Args:
        preload: when truthy, start from the ``IMAGENET1K_V1`` weights;
            otherwise start from random initialisation.
        num_classes: output width of the replaced classifier layer.
    """

    def __init__(self, preload=True, num_classes=100):
        super().__init__()
        # Pretrained ImageNet-1k weights, or none at all.
        weights = models.EfficientNet_V2_L_Weights.IMAGENET1K_V1 if preload else None
        self.base_model = models.efficientnet_v2_l(weights=weights)

        # Swap the last classifier layer so it emits `num_classes` logits.
        head = self.base_model.classifier
        head[1] = nn.Linear(head[1].in_features, num_classes)

    def forward(self, x):
        """Return the logits produced by the wrapped network for ``x``."""
        logits = self.base_model(x)
        return logits
    

# -----------------------------------------------------------------------------
# Define ViT for ImageNet-100 (vit)
# -----------------------------------------------------------------------------
class ViT(nn.Module):
    """ViT-Base/16 encoder with a linear classification head on the first token.

    Args:
        preload: when true, load the pretrained 'google/vit-base-patch16-224'
            encoder. When false, build the same architecture from its
            configuration with randomly initialised weights (the configuration
            is still resolved through ``from_pretrained``, i.e. from the local
            cache or the hub).
        num_classes: number of output classes of the linear head.
    """

    def __init__(self, preload=True, num_classes=100):
        super().__init__()
        checkpoint = 'google/vit-base-patch16-224'
        if preload:
            self.base = ViTModel.from_pretrained(checkpoint)
        else:
            # from_pretrained(None) cannot resolve a model; an untrained
            # encoder has to be instantiated from a configuration instead.
            from transformers import ViTConfig

            self.base = ViTModel(ViTConfig.from_pretrained(checkpoint))
        self.final = nn.Linear(self.base.config.hidden_size, num_classes)
        self.num_classes = num_classes
        # Not used by forward(); kept so the attribute stays available.
        self.relu = nn.ReLU()

    def forward(self, pixel_values):
        """Return logits of shape (batch, num_classes) for ``pixel_values``."""
        outputs = self.base(pixel_values=pixel_values)
        # Classify from the hidden state at sequence position 0.
        logits = self.final(outputs.last_hidden_state[:, 0])

        return logits

# -----------------------------------------------------------------------------
# Define a function to create the model based on the specified architecture
# -----------------------------------------------------------------------------

def create_model(model_name, device, preload=True, num_classes=100):
    """Build the requested architecture, move it to ``device`` and return it.

    Args:
        model_name: one of "resnet18", "resnet18_imagenet", "berta_distill",
            "efficientnetv2l" or "vit".
        device: target device for the model.
        preload: forwarded to the architectures that take it ("berta_distill"
            and "efficientnetv2l") to choose between pretrained and random
            initialisation.
        num_classes: number of output classes.

    Returns:
        The model, wrapped in ``nn.DataParallel`` when more than one CUDA
        device is visible.

    Raises:
        ValueError: if ``model_name`` is not one of the supported names.
    """
    if model_name == "resnet18":
        model = MyResNet18(num_classes=num_classes)

    elif model_name == "resnet18_imagenet":
        # Stock torchvision ResNet-18; honour the requested class count
        # instead of a hard-coded 100.
        model = models.resnet18(weights=None, num_classes=num_classes)

    elif model_name == "berta_distill":
        model = MyTransformer(preload, num_classes=num_classes)

    elif model_name == "efficientnetv2l":
        model = MyEfficientNetV2L(preload, num_classes=num_classes)

    elif model_name == "vit":
        # NOTE(review): `preload` is not forwarded here, so ViT is always
        # built with its default (pretrained) initialisation.
        model = ViT(num_classes=num_classes)

    else:
        raise ValueError(f"Unsupported model: {model_name}")

    model = model.to(device)

    # Wrap for multi-GPU if needed.
    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)
    return model

# -----------------------------------------------------------------------------
# DataLoaders
# -----------------------------------------------------------------------------
def collate_fn(batch):
    """Collate ``(encoding, label)`` pairs into batched tensors.

    Each encoding is a mapping holding ``"input_ids"`` and
    ``"attention_mask"`` tensors. The per-sample tensors are stacked along a
    new leading dimension and the labels become one ``torch.long`` tensor.

    Returns:
        A tuple ``({"input_ids": ..., "attention_mask": ...}, labels)``.
    """
    pairs = list(batch)
    stacked_ids = torch.stack([encoding["input_ids"] for encoding, _ in pairs], dim=0)
    stacked_masks = torch.stack([encoding["attention_mask"] for encoding, _ in pairs], dim=0)
    targets = torch.tensor([target for _, target in pairs], dtype=torch.long)
    return {"input_ids": stacked_ids, "attention_mask": stacked_masks}, targets

# -----------------------------------------------------------------------------
# Training Function (Initial Training on D_train)
# -----------------------------------------------------------------------------
def train(model, train_loader, val_loader, model_name, criterion, optimizer, num_epochs, scheduler=None, trans=False):
    """Train ``model`` for ``num_epochs`` epochs and save the final weights.

    Each step runs under CUDA autocast (when CUDA is available) with gradient
    scaling and gradient-norm clipping at 1.0. After every epoch the model is
    evaluated on ``val_loader`` and the scheduler, if any, is stepped.
    Early stopping / best-checkpoint restoring is not active: the weights of
    the last epoch are the ones saved and returned.

    Args:
        model: network to train; batches are moved to the module-level ``device``.
        train_loader: yields ``(inputs, labels)`` batches.
        val_loader: loader passed to ``evaluate`` after each epoch.
        model_name: used to build the checkpoint path ``../models/<model_name>.pth``.
        criterion: loss called as ``criterion(outputs, labels)``.
        optimizer: optimizer updating ``model``.
        num_epochs: number of passes over ``train_loader``.
        scheduler: optional LR scheduler. A scheduler whose class name contains
            "ReduceLROnPlateau" is stepped with the validation accuracy; any
            other scheduler is stepped without arguments.
        trans: when true, ``inputs`` is a dict of tensors passed to the model
            as keyword arguments.

    Returns:
        ``(total_training_time, model)`` where the time is wall-clock seconds
        spent in the epoch loop (validation included).
    """
    from torch.cuda.amp import autocast, GradScaler

    use_amp = torch.cuda.is_available()
    scaler = GradScaler(enabled=use_amp)

    total_start_time = time.time()

    for epoch in range(num_epochs):
        model.train()
        running_loss, total, correct = 0.0, 0, 0

        for inputs, labels in train_loader:
            if trans:
                inputs = {k: v.to(device) for k, v in inputs.items()}
            else:
                inputs = inputs.to(device)
            labels = labels.to(device)

            optimizer.zero_grad(set_to_none=True)

            with autocast(enabled=use_amp):
                if trans:
                    out = model(**inputs)
                    outputs = out.logits if hasattr(out, "logits") else out
                else:
                    outputs = model(inputs)
                loss = criterion(outputs, labels)

            scaler.scale(loss).backward()
            # Unscale before clipping so the threshold applies to true gradients.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            scaler.step(optimizer)
            scaler.update()

            batch_size = labels.size(0)
            running_loss += loss.item() * batch_size
            predicted = outputs.argmax(dim=1)
            total += batch_size
            correct += (predicted == labels).sum().item()

        # Average over the samples actually seen: this stays correct when the
        # loader drops the last batch or uses a sampler, where
        # len(train_loader.dataset) would be the wrong denominator.
        epoch_loss = running_loss / total if total else 0.0
        epoch_acc = correct / total if total else 0.0

        # Step schedulers (ReduceLROnPlateau needs val metric; others step now)
        val_accuracy = evaluate(model, val_loader, device, trans=trans)
        if scheduler is not None:
            if "ReduceLROnPlateau" in type(scheduler).__name__:
                scheduler.step(val_accuracy)
            else:
                scheduler.step()

        print(f"Epoch [{epoch+1}/{num_epochs}] Loss: {epoch_loss:.4f} Accuracy: {epoch_acc:.4f} | Val: {val_accuracy:.4f}")

    total_training_time = time.time() - total_start_time

    # Persist the final weights; create the folder first so a finished run is
    # not lost to a missing directory.
    best_path = f"../models/{model_name}.pth"
    os.makedirs(os.path.dirname(best_path), exist_ok=True)
    torch.save(model.state_dict(), best_path)

    return total_training_time, model

class ConformitySeparationCrossEntropy(nn.Module):
    """Cross-entropy plus a hinge that separates the true class from the next-worse one.

    Scores are ``-log_softmax(logits)`` (lower is better). For every sample the
    smallest score strictly greater than the true-class score is located; the
    hinge penalises a gap between the two that is smaller than ``margin``.
    When no class scores strictly worse than the true class, the gap is taken
    as 0, so the hinge equals ``margin``.

    Per-sample loss:
        L = CE + alpha * relu(margin - (next_larger - true_score))  >= 0

    Args:
        alpha: weight of the separation hinge.
        margin: gap the hinge asks for.
        reduction: 'mean', 'sum' or 'none'.
    """

    def __init__(self, alpha: float = 2.0, margin: float = 10.0, reduction: str = "mean"):
        super().__init__()
        if reduction not in ("mean", "sum", "none"):
            raise ValueError("reduction must be 'mean', 'sum', or 'none'")
        self.alpha = alpha
        self.margin = margin
        self.reduction = reduction

    @staticmethod
    def _extract_logits(outputs) -> torch.Tensor:
        """Pull the logits tensor out of the supported model output containers."""
        if isinstance(outputs, torch.Tensor):
            return outputs
        if hasattr(outputs, "logits"):
            return outputs.logits
        is_logits_dict = isinstance(outputs, dict) and "logits" in outputs
        if is_logits_dict:
            return outputs["logits"]
        is_tensor_sequence = (
            isinstance(outputs, (tuple, list))
            and outputs
            and isinstance(outputs[0], torch.Tensor)
        )
        if is_tensor_sequence:
            return outputs[0]
        raise TypeError("Unsupported model output type; expected Tensor or object with .logits")

    def forward(self, outputs, targets: torch.Tensor) -> torch.Tensor:
        """Return the reduced (or per-sample, for 'none') loss for a batch."""
        logits = self._extract_logits(outputs)  # (B, C)

        # Standard per-sample cross-entropy term.
        per_sample_ce = F.cross_entropy(logits, targets, reduction="none")

        # Per-class scores and the score of each sample's true class.
        scores = -F.log_softmax(logits, dim=1)                      # (B, C)
        own_score = scores.gather(1, targets.view(-1, 1)).squeeze(1)  # (B,)

        # Smallest score that is strictly worse than the true-class score.
        is_worse = scores > own_score.unsqueeze(1)                  # (B, C)
        nearest_worse = scores.masked_fill(~is_worse, torch.inf).min(dim=1).values
        # No strictly worse class: fall back to the true score (gap of zero).
        nearest_worse = torch.where(torch.isfinite(nearest_worse), nearest_worse, own_score)

        separation = F.relu(self.margin - (nearest_worse - own_score))  # >= 0
        loss = per_sample_ce + self.alpha * separation

        reducer = {"mean": torch.Tensor.mean, "sum": torch.Tensor.sum}.get(self.reduction)
        return loss if reducer is None else reducer(loss)
    
# -----------------------------------------------------------------------------
# Evaluation Function
# -----------------------------------------------------------------------------
def evaluate(model, dataloader, device, trans=False, amp_dtype=None):
    """Return the top-1 accuracy of ``model`` over ``dataloader``.

    Args:
        model: network to evaluate; it is switched to eval mode and left there.
        dataloader: iterable of ``(inputs, labels)`` batches.
        device: ``torch.device`` or device string the batches are moved to.
        trans: when true, ``inputs`` is a dict of tensors passed to the model
            as keyword arguments.
        amp_dtype: dtype handed to CUDA autocast (only used on CUDA devices).

    Returns:
        Fraction of correctly classified samples as a float, or 0.0 when the
        loader yields no samples.
    """
    # Accept "cuda"/"cpu" strings as well as torch.device objects.
    device = torch.device(device)

    model.eval()
    total = 0
    correct = 0

    use_cuda = (device.type == "cuda")
    # autocast only on CUDA; choose dtype if you want fp16/bf16
    ac = torch.cuda.amp.autocast(dtype=amp_dtype) if use_cuda else contextlib.nullcontext()

    with torch.inference_mode():  # faster than no_grad for inference
        for inputs, labels in dataloader:
            if trans:
                inputs = {k: v.to(device, non_blocking=True) for k, v in inputs.items()}
            else:
                inputs = inputs.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)

            with ac:
                out = model(**inputs) if trans else model(inputs)

            # Same unwrapping as the training loop: models may return an
            # object carrying `.logits` instead of a bare tensor.
            outputs = out.logits if hasattr(out, "logits") else out

            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    # An empty loader has no defined accuracy; report 0.0 rather than
    # raising ZeroDivisionError.
    return correct / total if total else 0.0

# -----------------------------------------------------------------------------
# Nonconformity Score & Conformal Prediction
# -----------------------------------------------------------------------------
def compute_nonconformity_score(probs, true_labels, score_func_name, eps=1e-8):
    """
    Computes nonconformity scores for the true labels based on the selected method.

    Options:
      - "one_minus":         1 - p(true)
      - "pmax_minus":        max_k p(k) - p(true)
      - "neg_log":           -log(p(true) + eps)
      - "margin":            1 - [ p(true) - max_{k != true} p(k) ]
      - "entropy_weighted":  (1 + H_norm) * [1 - p(true)]  where H_norm = H/ln(K)

    Parameters
    ----------
    probs : Tensor, shape (n, K)
        Softmax probabilities over K classes.
    true_labels : LongTensor, shape (n,)
        Ground-truth labels (0 <= y < K).
    score_func_name : str
        Which score to compute.
    eps : float
        Small constant to avoid log(0).

    Returns
    -------
    Tensor, shape (n,)
        Nonconformity scores. The result is always 1-D, including for a
        single-sample batch (n == 1).

    Raises
    ------
    ValueError
        If ``score_func_name`` is not one of the options above.
    """
    n, K = probs.size()

    # squeeze(1) removes only the gathered column. A bare squeeze() would
    # also drop the batch dimension when n == 1 and return a 0-dim tensor,
    # which callers cannot concatenate with other batches.
    p_true = probs.gather(1, true_labels.view(-1, 1)).squeeze(1)

    if score_func_name == "one_minus":
        # Least confidence
        return 1.0 - p_true

    if score_func_name == "pmax_minus":
        # Difference between highest prob and true prob
        p_max, _ = probs.max(dim=1)
        return p_max - p_true

    if score_func_name == "neg_log":
        # Negative log-likelihood
        return -torch.log(p_true + eps)

    if score_func_name == "margin":
        # Margin-based: 1 - (p_true - runner_up)
        # Mask out the true class to find the runner-up; the row index is
        # created on the same device as `probs`.
        probs_minus_true = probs.clone()
        rows = torch.arange(n, device=probs.device)
        probs_minus_true[rows, true_labels] = -float('inf')
        p_second, _ = probs_minus_true.max(dim=1)
        return 1.0 - (p_true - p_second)

    if score_func_name == "entropy_weighted":
        # Entropy-reweighted least confidence
        # 1) Normalized Shannon entropy: H divided by ln(K)
        log_probs = torch.log(probs + eps)
        H = -torch.sum(probs * log_probs, dim=1)                # un-normalized entropy
        H_norm = H / torch.log(torch.tensor(K, device=probs.device).float())
        # 2) Base score 1 - p_true, 3) weighted by (1 + H_norm)
        return (1.0 + H_norm) * (1.0 - p_true)

    raise ValueError(f"Unknown nonconformity score function: {score_func_name}")

def conformal_prediction_quantile_and_returnall(
    model, calib_dataset, alpha, score_func_name="one_minus",
    device="cuda", batch_size=256, num_workers=4, trans=False
):
    """
    Score every calibration sample and return the conformal threshold.

    The model is switched to eval mode (and left in it), the calibration set
    is scored in dataset order without gradients, and the (1 - alpha)
    empirical quantile of the scores is taken with "higher" interpolation.

    Args:
      model : network producing class logits.
      calib_dataset : dataset of (inputs, label) items.
      alpha : miscoverage level; the quantile level is 1 - alpha.
      score_func_name : score passed to compute_nonconformity_score.
      device : device the batches are moved to. If a CUDA device is requested
        but CUDA is unavailable, the device of the model's parameters is
        used instead (CPU if the model has no parameters).
      batch_size, num_workers : DataLoader settings.
      trans : when true, inputs are dicts of tensors passed as keyword arguments.

    Returns:
      q_hat : float
        The conformal threshold (quantile of nonconformity scores).
      scores_arr : ndarray, shape (n_calib,)
        All calibration nonconformity scores in dataset order.

    Raises:
      ValueError : if the calibration dataset yields no samples.
    """
    model.eval()

    # The default "cuda" cannot work on a CPU-only host; follow the model.
    if torch.device(device).type == "cuda" and not torch.cuda.is_available():
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    loader = DataLoader(
        calib_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        # Pinned memory only helps (and only works silently) with CUDA.
        pin_memory=torch.cuda.is_available()
    )

    all_scores = []

    with torch.no_grad():
        for inputs, labels in loader:
            if trans:
                inputs = {k: v.to(device) for k, v in inputs.items()}
            else:
                inputs = inputs.to(device)

            labels = labels.to(device)

            if trans:
                outputs = model(**inputs)
            else:
                outputs = model(inputs)

            probs = F.softmax(outputs, dim=1)

            batch_scores = compute_nonconformity_score(probs, labels, score_func_name)
            # Force a 1-D vector so a final batch holding a single sample
            # (possibly returned as a 0-dim tensor) can still be concatenated.
            all_scores.append(batch_scores.reshape(-1).cpu())

    if not all_scores:
        raise ValueError("calib_dataset is empty; cannot compute a conformal quantile")

    # concatenate into one big vector
    scores_tensor = torch.cat(all_scores, dim=0)    # shape: (n_calib,)

    # compute the (1-alpha)-quantile with "higher" interpolation
    # NOTE(review): the level is 1 - alpha as-is, without the finite-sample
    # (n + 1) / n adjustment often used in split conformal prediction —
    # confirm this is intended.
    q_hat = torch.quantile(scores_tensor, 1.0 - alpha, interpolation="higher").item()

    # move to numpy for return
    scores_arr = scores_tensor.numpy()
    return q_hat, scores_arr


# -----------------------------------------------------------------------------
# Parameter Regularization
# -----------------------------------------------------------------------------
def parameter_diff_norm(model, original_state_dict, norm: str = 'l2'):
    """
    Measure how far the model's parameters have moved from a reference state.

    Args:
        model: PyTorch model whose named parameters are compared.
        original_state_dict: mapping from parameter name to the reference
            tensor; every name in ``model.named_parameters()`` must be present.
        norm: which distance to return
            - 'l1'   : sum of absolute differences
            - 'l2'   : Euclidean norm (square root of the summed squares)
            - 'l2_sq': summed squared differences
    Returns:
        A scalar tensor located on the device of the model's first parameter.
    Raises:
        ValueError: if ``norm`` is not one of the three supported names.
    """
    device = next(model.parameters()).device

    if norm not in ('l1', 'l2', 'l2_sq'):
        raise ValueError(f"Unsupported norm '{norm}'; choose 'l1', 'l2' or 'l2_sq'")

    absolute = norm == 'l1'

    def _distance_term(name, param):
        # Contribution of one parameter tensor: |d| summed, or d*d summed.
        delta = param - original_state_dict[name].to(device)
        return delta.abs().sum() if absolute else (delta * delta).sum()

    # Left fold starting from a zero scalar on the model's device.
    total = sum(
        (_distance_term(name, param) for name, param in model.named_parameters()),
        torch.tensor(0.0, device=device),
    )

    # 'l2' is the root of the summed squares; the other two are returned as-is.
    return torch.sqrt(total) if norm == 'l2' else total

# --------------------------------------------------------------------------------
# Reproducing the Rethinking MU using CP metrics code
# --------------------------------------------------------------------------------
def cpu(
    model,
    original_state_dict,
    forget_loader,
    retained_loader,
    calib_dataset,
    alpha,
    beta,                 # unused (kept for signature compatibility)
    phi,                  # unused here (kept)
    lambda_reg,           # <-- λ for FORGET loss weight
    n_epochs,
    optimizer,
    delta=0.0001,
    norm_name='l2_sq',    # unused
    score_func_name="one_minus",
    trans=False,
):
    """
    Rethinking unlearning (single backward/step per epoch):
      - CP threshold q from calib_dataset each epoch (alpha, score_func_name).
      - Forget:   loss_f = mean( max( S(x,y) - q, -delta ) ) over D_f.
      - Retained: loss_r = cross_entropy over D_r.
      - Total loss = loss_r + λ * loss_f.  No parameter regularization.
      - Same inputs/outputs as original; some args remain unused.

    The loss of every batch is accumulated into one graph and back-propagated
    once per epoch, so memory grows with the size of both loaders.

    Args:
        model: network to update in place.
        original_state_dict, beta, phi, norm_name: accepted but not used.
        forget_loader: batches of D_f, or None.
        retained_loader: batches of D_r, or None.
        calib_dataset: dataset used to compute the threshold q each epoch.
        alpha: miscoverage level for the threshold.
        lambda_reg: weight of the forget loss; 0 skips the forget pass.
        n_epochs: number of optimizer steps.
        optimizer: optimizer over the model's parameters.
        delta: lower clamp (as -delta) on the per-sample forget loss.
        score_func_name: nonconformity score used for q and the forget loss.
        trans: when true, inputs are dicts of tensors passed as keyword arguments.

    Returns:
        ``(model, total_time, q_percent)``: total wall-clock seconds and the
        percentage of that time spent computing the threshold.
    """
    device = next(model.parameters()).device
    model.train()

    total_time_q = 0.0
    t_start = time.time()

    for epoch in range(1, n_epochs + 1):
        model.zero_grad(set_to_none=True)
        print(f"Unlearning Epoch [{epoch}/{n_epochs}]")

        # ── 1) Compute CP threshold q (in eval mode to avoid BN stat drift) ──
        t_q = time.time()
        prev_training = model.training
        model.eval()
        with torch.no_grad():
            # Calibrate on the model's own device instead of the callee's
            # hard-coded "cuda" default.
            q_hat, _ = conformal_prediction_quantile_and_returnall(
                model, calib_dataset, alpha, score_func_name,
                device=device, trans=trans
            )
        model.train(prev_training)
        total_time_q += (time.time() - t_q)
        q = float(q_hat)
        print(f"    CP threshold q = {q:.6f} (score='{score_func_name}'), λ(FORGET) = {lambda_reg}")

        # ── 2) Accumulate retained CE over D_r ──
        # Starting from a zero scalar keeps the sum defined even when the
        # loader yields no batch.
        total_r = len(retained_loader.dataset) if retained_loader is not None else 0
        loss_r_total = torch.tensor(0.0, device=device)
        if total_r > 0:
            for x_r, y_r in retained_loader:
                if trans:
                    x_r = {k: v.to(device) for k, v in x_r.items()}
                else:
                    x_r = x_r.to(device)
                y_r = y_r.to(device)

                logits_r = model(**x_r) if trans else model(x_r)
                ce_r = F.cross_entropy(logits_r, y_r, reduction='mean')
                # Weight each batch mean by its share of the dataset.
                loss_r_total = loss_r_total + ce_r * (y_r.size(0) / total_r)

        # ── 3) Accumulate forget max-loss over D_f (skip entirely if λ==0) ──
        total_f = len(forget_loader.dataset) if forget_loader is not None else 0
        loss_f_total = torch.tensor(0.0, device=device)
        if total_f > 0 and float(lambda_reg) != 0.0:
            for x_f, y_f in forget_loader:
                if trans:
                    x_f = {k: v.to(device) for k, v in x_f.items()}
                else:
                    x_f = x_f.to(device)
                y_f = y_f.to(device)

                logits_f = model(**x_f) if trans else model(x_f)
                probs = F.softmax(logits_f, dim=1)

                # Expected to be a (batch,) tensor of scores for the true labels.
                scores_f = compute_nonconformity_score(probs, y_f, score_func_name)

                per_sample = torch.clamp(scores_f - q, min=-float(delta))  # max(S - q, -δ)
                loss_f_total = loss_f_total + per_sample.mean() * (y_f.size(0) / total_f)

        # ── 4) Single backward/step on total loss ──
        total_loss = loss_r_total + lambda_reg * loss_f_total
        # With nothing to retain and nothing to forget the loss is a constant
        # and cannot be back-propagated; there is no update to make.
        if total_loss.requires_grad:
            total_loss.backward()
            optimizer.step()

    total_time = time.time() - t_start
    q_percent = (total_time_q / total_time * 100.0) if total_time > 0 else 0.0
    print(f"Total unlearning time ({n_epochs} epochs): {total_time:.2f}s\n")
    return model, total_time, q_percent


# -----------------------------------------------------------------------------
def cqmu(
    model,
    original_state_dict,
    forget_loader,
    retained_loader,
    calib_dataset,
    alpha,
    beta,
    phi,
    lambda_reg,
    n_epochs,
    optimizer,
    delta=0.0001,
    norm_name = 'l2_sq',
    score_func_name="one_minus",
    trans=False,
    file_name=None,
):
    """
    Unlearning with sigmoid surrogates anchored on two calibration samples.

    Every epoch:
      1. Score the calibration set; take the (1 - alpha) quantile q_hat_r and
         the (1 - beta) quantile q_hat_f, and pick the calibration samples
         whose scores are closest to each.
      2. Run those two samples through the model (with gradients) to obtain
         their true-class probabilities corr_q_r and corr_q_f.
      3. Retained set: mean of sigmoid(phi * ((corr_q_r + delta) - p_true)).
         Forget set:   mean of sigmoid(phi * (p_true - (corr_q_f - delta))).
         Each batch is back-propagated on its own, weighted by its share of
         the dataset, so gradients accumulate without holding every batch
         graph at once.
      4. Add the gradient of lambda_reg * ||theta - theta_0|| (``norm_name``)
         and take one optimizer step.

    ``score_func_name`` only affects the calibration scores; the surrogate
    losses always compare true-class probabilities.

    Args:
        model: network to update in place.
        original_state_dict: reference parameters theta_0 for the regulariser.
        forget_loader, retained_loader: loaders over D_f and D_r.
        calib_dataset: indexable dataset of (inputs, label) items.
        alpha, beta: miscoverage levels for the retained / forget quantiles.
        phi: sigmoid sharpness.
        lambda_reg: weight of the parameter-distance regulariser.
        n_epochs: number of optimizer steps.
        optimizer: optimizer over the model's parameters.
        delta: offset added to / subtracted from the anchor probabilities.
        norm_name: norm passed to ``parameter_diff_norm``.
        trans: when true, inputs are dicts of tensors passed as keyword arguments.
        file_name: where to save the final state dict; defaults to
            ``./checkpoints/unlearning_<ModelClass>_updated.pth``.

    Returns:
        ``(model, total_time, q_percent)``: total wall-clock seconds and the
        percentage of that time spent in the calibration step.
    """
    device = next(model.parameters()).device
    gamma = phi
    model.train()
    total_time_q = 0.0
    t_start = time.time()

    for epoch in range(1, n_epochs + 1):
        # Zero gradients only once at the start of the epoch; every backward
        # call below accumulates into the same gradients.
        optimizer.zero_grad()
        print(f"Unlearning Epoch [{epoch}/{n_epochs}]")
        t_start_q = time.time()
        # ── 1) Calibration ──
        # NOTE(review): the calibration helper puts the model in eval mode
        # and nothing here switches it back, so the forward passes below run
        # in eval mode from the first epoch on — confirm this is intended.
        with torch.no_grad():
            # Calibrate on the model's own device instead of the callee's
            # hard-coded "cuda" default.
            q_hat_r, all_scores = conformal_prediction_quantile_and_returnall(
                model, calib_dataset, alpha, score_func_name,
                device=device, trans=trans
            )
        # Use torch.quantile, as q_hat_r does, so both thresholds come from
        # the same estimator; this also avoids NumPy's deprecated
        # `interpolation=` keyword.
        q_hat_f = torch.quantile(
            torch.as_tensor(all_scores), 1.0 - beta, interpolation="higher"
        ).item()
        idx_r = np.argmin(np.abs(all_scores - q_hat_r))
        idx_f = np.argmin(np.abs(all_scores - q_hat_f))

        X_r, Y_r = calib_dataset[idx_r]
        X_f, Y_f = calib_dataset[idx_f]
        if trans:
            X_q = {key: torch.stack([X_r[key], X_f[key]], dim=0).to(device) for key in X_r.keys()}
        else:
            X_q = torch.stack([X_r, X_f], dim=0).to(device)
        Y_q = torch.tensor([Y_r, Y_f], device=device)

        print(f"    q_hat_r = {q_hat_r:.4f}, using calibration sample idx={idx_r} for L'")
        print(f"    q_hat_f = {q_hat_f:.4f}, using calibration sample idx={idx_f} for L''")

        if trans:
            out_q = model(**X_q)
        else:
            out_q = model(X_q)

        # True-class probabilities of the two anchor samples (kept in the graph).
        probs_q = F.softmax(out_q, dim=1)
        corrs = probs_q.gather(1, Y_q.view(-1,1)).squeeze(1)
        corr_q_r, corr_q_f = corrs[0], corrs[1]

        total_time_q += time.time() - t_start_q
        # ──────────────────────────────────────────────────────────
        # Retained set: per-batch backward pass
        total_r = len(retained_loader.dataset)
        for x_r, y_r in retained_loader:
            if trans:
                x_r = {k: v.to(device) for k, v in x_r.items()}
            else:
                x_r = x_r.to(device)
            y_r = y_r.to(device)

            if trans:
                logits_r = model(**x_r)
            else:
                logits_r = model(x_r)

            probs_r  = F.softmax(logits_r, dim=1)
            corr_r   = probs_r.gather(1, y_r.view(-1,1)).squeeze(1)
            u_r      = (corr_q_r + delta) - corr_r
            batch_eps_r = torch.sigmoid(gamma * u_r).mean()

            # This is the loss component for this batch
            loss_r_component = batch_eps_r * (y_r.size(0) / total_r)

            # Backpropagate, but keep the anchor graph (corr_q_r) alive for
            # the following batches.
            loss_r_component.backward(retain_graph=True)

        # ──────────────────────────────────────────────────────────
        # Forget set: per-batch backward pass
        total_f = len(forget_loader.dataset)
        num_forget_batches = len(forget_loader)
        for i, (x_f, y_f) in enumerate(forget_loader):
            if trans:
                x_f = {k: v.to(device) for k, v in x_f.items()}
            else:
                x_f = x_f.to(device)
            y_f = y_f.to(device)

            if trans:
                logits_f = model(**x_f)
            else:
                logits_f = model(x_f)

            probs_f  = F.softmax(logits_f, dim=1)
            corr_f   = probs_f.gather(1, y_f.view(-1,1)).squeeze(1)
            u_f      = corr_f - (corr_q_f - delta)
            batch_eps_f = torch.sigmoid(gamma * u_f).mean()

            loss_f_component = batch_eps_f * (y_f.size(0) / total_f)

            # For the very last batch of the forget set, we don't need to retain the graph
            is_last_batch = (i == num_forget_batches - 1)
            loss_f_component.backward(retain_graph=not is_last_batch)

        # ──────────────────────────────────────────────────────────
        # Regularization term: its gradient is added to the accumulated ones
        reg_term = lambda_reg * parameter_diff_norm(model, original_state_dict, norm=norm_name)
        reg_term.backward()

        # Step the optimizer AFTER all gradients are accumulated
        optimizer.step()

    total_time = time.time() - t_start
    q_percent = (total_time_q / total_time * 100) if total_time > 0 else 0.0
    print(f"Total unlearning time ({n_epochs} epochs): {total_time:.2f}s\n")
    best_path = f"./checkpoints/unlearning_{model.__class__.__name__}_updated.pth"
    save_path = file_name if file_name is not None else best_path
    # Create the target folder for path-like destinations so the finished run
    # is not lost to a missing directory.
    if isinstance(save_path, (str, os.PathLike)):
        save_dir = os.path.dirname(save_path)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
    torch.save(model.state_dict(), save_path)
    return model, total_time, q_percent


# -----------------------------------------------------------------------------
# Helper: Get Conformal Prediction Set for a Batch
# -----------------------------------------------------------------------------
def get_cp_prediction_set(model, inputs, q_hat, score_func_name="one_minus", eps=1e-8, trans=False):
    """
    For each sample in inputs, returns the set of class indices c for which the candidate
    nonconformity score is <= q_hat.

    score_func_name options:
      - "one_minus":         1 - p[c]
      - "pmax_minus":        max_k p[k] - p[c]
      - "neg_log":           -log(p[c] + eps)
      - "margin":            1 - (p[c] - max_{k != c} p[k])
      - "entropy_weighted":  (1 + H_norm) * (1 - p[c])

    Args:
        model: network producing class logits; it is switched to eval mode.
        inputs: batch already on the model's device (a dict of tensors when
            ``trans`` is true, passed as keyword arguments).
        q_hat: score threshold.
        eps: small constant inside the logarithms.

    Returns:
        A list with one entry per sample: the ascending list of selected
        class indices.

    Raises:
        ValueError: if ``score_func_name`` is not one of the options above.
    """
    known_scores = ("one_minus", "pmax_minus", "neg_log", "margin", "entropy_weighted")
    if score_func_name not in known_scores:
        raise ValueError("Unknown nonconformity score function: " + score_func_name)

    model.eval()
    with torch.no_grad():
        outputs = model(**inputs) if trans else model(inputs)   # (B, K)
        probs = F.softmax(outputs, dim=1)                        # (B, K)
        K = probs.shape[1]

        # Candidate scores for every (sample, class) pair at once.
        if score_func_name == "one_minus":
            candidate_scores = 1.0 - probs

        elif score_func_name == "pmax_minus":
            candidate_scores = probs.max(dim=1, keepdim=True).values - probs

        elif score_func_name == "neg_log":
            candidate_scores = -torch.log(probs + eps)

        elif score_func_name == "margin":
            # The best competitor of class c is the overall top probability,
            # except for the top class itself, whose competitor is the
            # second-largest probability.
            if K > 1:
                top_vals, top_idx = probs.topk(2, dim=1)
                runner_up = top_vals[:, :1].expand_as(probs).clone()
                runner_up.scatter_(1, top_idx[:, :1], top_vals[:, 1:2])
            else:
                # A single class has no competitor.
                runner_up = torch.full_like(probs, -float('inf'))
            candidate_scores = 1.0 - (probs - runner_up)

        else:  # "entropy_weighted"
            log_probs = torch.log(probs + eps)
            H = -torch.sum(probs * log_probs, dim=1, keepdim=True)       # (B, 1)
            H_norm = H / torch.log(torch.tensor(K, device=probs.device).float())
            candidate_scores = (1.0 + H_norm) * (1.0 - probs)

        # select all classes whose score ≤ q_hat; one transfer for the batch
        mask = (candidate_scores <= q_hat).cpu()

    return [row.nonzero(as_tuple=True)[0].tolist() for row in mask]


# -----------------------------------------------------------------------------
# Metric: CCUCR
# -----------------------------------------------------------------------------
def compute_ccucr(model, forget_loader, retained_loader, q_hat, score_func_name, c=100, trans=False):
    """
    Evaluate conformal prediction sets on the forget and retained loaders,
    counting only samples whose prediction set contains at most ``c`` labels.

    A prediction set is built for every sample with ``get_cp_prediction_set``.
    Samples whose set is larger than ``c`` are excluded from the rates below,
    but still contribute to the reported average set size.

      * forget loader   -> fraction of counted samples whose true label is
        NOT in the prediction set (miscoverage).
      * retained loader -> fraction of counted samples whose true label IS
        in the prediction set (coverage).

    A rate is 0.0 when no sample of that loader has a set of size <= c.
    The average set size of each loader is also printed.

    Args:
        model: Classifier passed through to ``get_cp_prediction_set``; it is
            switched to eval mode.
        forget_loader: Loader yielding ``(inputs, labels)`` for the forget data.
        retained_loader: Loader yielding ``(inputs, labels)`` for the retained data.
        q_hat: Conformal threshold forwarded to ``get_cp_prediction_set``.
        score_func_name: Nonconformity score name forwarded to
            ``get_cp_prediction_set``.
        c: Largest prediction-set size for which a sample is counted.
        trans: If True, ``inputs`` is a dict of tensors (each value is moved
            to ``device``); otherwise it is a single tensor.

    Returns:
        tuple: ``(retained_rate, forget_rate, [avg_set_size_forget, avg_set_size_retained])``.
    """
    def compute_ccucr_D(loader, method):
        # `method` is 'efc' (count label NOT in set) or 'coverage' (count label in set).
        count_when_covered = (method == 'coverage')

        counted_samples = 0   # samples with |set| <= c
        indicator_sum = 0     # counted samples satisfying the method's condition
        total_set_size = 0    # for the average set size over ALL samples
        total_set_count = 0

        model.eval()
        with torch.no_grad():
            for inputs, labels in loader:
                if trans:
                    inputs = {k: v.to(device) for k, v in inputs.items()}
                else:
                    inputs = inputs.to(device)

                # Labels are only compared on the host, so convert them once per
                # batch instead of moving them to the device and calling
                # .item() for every sample.
                true_labels = labels.tolist()

                pred_sets = get_cp_prediction_set(model, inputs, q_hat, score_func_name, trans=trans)
                for i, cp_set in enumerate(pred_sets):
                    total_set_size += len(cp_set)
                    total_set_count += 1

                    if len(cp_set) > c:
                        continue
                    counted_samples += 1
                    if (true_labels[i] in cp_set) == count_when_covered:
                        indicator_sum += 1

        if total_set_count > 0:
            avg_set_size = total_set_size / total_set_count
            print(f"Average prediction set size {method}: {avg_set_size:.4f}")
        else:
            avg_set_size = 0.0
            print(f"Average prediction set size {method}: n/a (no samples).")

        rate = (indicator_sum / counted_samples) if counted_samples > 0 else 0.0
        return rate, avg_set_size

    f_d, avg_f = compute_ccucr_D(forget_loader, 'efc')
    r_d, avg_r = compute_ccucr_D(retained_loader, 'coverage')
    list_average_set_sizes = [avg_f, avg_r]
    return r_d, f_d, list_average_set_sizes

# -----------------------------------------------------------------------------
# Helper: CR metric
# -----------------------------------------------------------------------------
def compute_cr(model, forget_loader, retained_loader, q_hat, score_func_name, trans=False):
    """
    Compute the CR ratio on the forget and retained loaders.

    For a loader D:
        CR_D = (number of samples whose true label is in its prediction set)
               / (total size of all prediction sets)
    and CR_D is 0.0 when every prediction set is empty.

    Prediction sets come from ``get_cp_prediction_set``; the model is put in
    eval mode and gradients are disabled while they are computed.

    Returns:
        (r_cr, f_cr): the retained CR followed by the forget CR.
    """
    def _as_set(cp_set):
        # Normalise whatever container holds the predicted labels to a set of ints.
        if isinstance(cp_set, torch.Tensor):
            flat_values = cp_set.detach().view(-1).tolist()
            return {int(v) for v in flat_values}
        if isinstance(cp_set, (list, tuple, set)):
            return {int(v) for v in cp_set}
        # Anything else is treated as one single label.
        return {int(cp_set)}

    def compute_cr_D(loader):
        covered = 0      # how many true labels fall inside their set
        size_total = 0   # accumulated |C(x)|
        model.eval()
        with torch.no_grad():
            for inputs, labels in loader:
                if trans:
                    inputs = {k: v.to(device) for k, v in inputs.items()}
                else:
                    inputs = inputs.to(device)
                labels = labels.to(device)

                batch_sets = get_cp_prediction_set(model, inputs, q_hat, score_func_name, trans=trans)
                for pos, raw_set in enumerate(batch_sets):
                    members = _as_set(raw_set)
                    size_total += len(members)
                    if int(labels[pos].item()) in members:
                        covered += 1
        if size_total > 0:
            return covered / size_total
        return 0.0

    f_cr = compute_cr_D(forget_loader)
    r_cr = compute_cr_D(retained_loader)
    return r_cr, f_cr

# -----------------------------------------------------------------------------
# Helper: Harmonic Mean
# -----------------------------------------------------------------------------

def harmonic_mean(numbers):
    """
    Compute the harmonic mean of a sequence of numbers.

    Parameters:
        numbers (list of float): The input values; must support ``len()``.

    Returns:
        float: ``len(numbers) / sum(1 / x for x in numbers)``. If any value
        is zero, a warning is printed and ``0.0`` is returned instead.

    Raises:
        ValueError: If ``numbers`` is empty.
    """
    n = len(numbers)
    if n == 0:
        raise ValueError("At least one number is required to compute the harmonic mean.")
    # A zero makes a reciprocal undefined; report it and return 0.0 before
    # doing any division.
    if any(x == 0 for x in numbers):
        print("Warning: One or more numbers are zero, which will lead to division by zero in harmonic mean calculation.")
        return 0.0
    return n / sum(1.0 / x for x in numbers)

from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import StratifiedShuffleSplit, cross_val_score
import torch.nn.functional as F
import numpy as np
import torch

def compute_membership_inference_attack(
    model: torch.nn.Module,
    member_loader: torch.utils.data.DataLoader,
    nonmember_loader: torch.utils.data.DataLoader,
    device: torch.device,
    n_splits: int = 10,
    random_state: int = 0,
    trans: bool = False,
) -> "tuple[np.ndarray, float]":
    """
    Run a membership-inference attack with a RandomForest trained on
    per-sample statistics of the model's output.

    For every sample the feature vector is, in column order:
      1) cross-entropy loss against the true label
      2) entropy of the softmax distribution
      3) margin: top-1 minus top-2 softmax probability
      4) the ``k`` largest softmax probabilities, where
         ``k = int(0.08 * num_classes) + 1`` (9 for 100 classes)
      5) L2 norm of the logits
    giving ``4 + k`` columns in total.

    Samples from ``member_loader`` are labelled 1 and samples from
    ``nonmember_loader`` are labelled 0. The attacker is scored by accuracy
    over ``n_splits`` stratified shuffle splits.

    Args:
        model: Classifier under attack; switched to eval mode. Called as
            ``model(**x)`` when ``trans`` is True, else ``model(x)``.
        member_loader: Loader yielding ``(inputs, labels)`` for members.
        nonmember_loader: Loader yielding ``(inputs, labels)`` for non-members.
        device: Device the inputs and labels are moved to.
        n_splits: Number of stratified shuffle splits.
        random_state: Seed for the RandomForest and the splitter.
        trans: If True, inputs are dicts of tensors.

    Returns:
        ``(scores, advantage)`` where ``scores`` holds the per-split attack
        accuracies and ``advantage`` is their mean minus the majority-class
        baseline ``max(n_members, n_nonmembers) / n_total``.
    """
    model.eval()

    def gather_features(loader):
        feats = []
        with torch.no_grad():
            for x, y in loader:
                if trans:
                    x = {k: v.to(device) for k, v in x.items()}
                    logits = model(**x)                            # (batch, num_classes)
                else:
                    logits = model(x.to(device))                   # (batch, num_classes)
                y = y.to(device)

                probs = F.softmax(logits, dim=1)                   # (batch, num_classes)

                # Number of top probabilities kept scales with the class count.
                num_classes = probs.shape[1]
                k_for_top_features = int(0.08 * num_classes) + 1

                # 1) per-sample loss
                loss = F.cross_entropy(logits, y, reduction='none')          # (batch,)

                # 2) entropy (epsilon guards log(0))
                entropy = -(probs * torch.log(probs + 1e-12)).sum(dim=1)     # (batch,)

                # 3) margin = top1 - top2; only the two largest values are needed
                top2_probs, _ = probs.topk(2, dim=1)                         # (batch, 2)
                margin = top2_probs[:, 0] - top2_probs[:, 1]                 # (batch,)

                # 4) top-k probabilities
                top_k_probs, _ = probs.topk(k_for_top_features, dim=1)       # (batch, k)

                # 5) l2-norm of logits
                logit_norm = logits.norm(dim=1)                              # (batch,)

                # stack into (batch, 4 + k_for_top_features)
                batch_feats = torch.cat([
                    loss.unsqueeze(1),
                    entropy.unsqueeze(1),
                    margin.unsqueeze(1),
                    top_k_probs,
                    logit_norm.unsqueeze(1)
                ], dim=1)

                feats.append(batch_feats.cpu().numpy())

        return np.vstack(feats)

    # build feature matrix X and labels y (1 = member, 0 = non-member)
    X_mem = gather_features(member_loader)
    X_non = gather_features(nonmember_loader)
    X = np.vstack([X_mem, X_non])
    y = np.concatenate([
        np.ones(len(X_mem), dtype=int),
        np.zeros(len(X_non), dtype=int)
    ])

    # ensemble attacker
    attack = RandomForestClassifier(
        n_estimators=200,
        max_depth=None,
        random_state=random_state,
        n_jobs=-1
    )

    cv = StratifiedShuffleSplit(
        n_splits=n_splits,
        random_state=random_state
    )
    scores = cross_val_score(
        attack, X, y,
        cv=cv,
        scoring="accuracy",
        n_jobs=-1
    )

    # baseline: accuracy of always predicting the larger group
    total = X_mem.shape[0] + X_non.shape[0]
    baseline = max(len(X_mem) / total, len(X_non) / total)
    # return per-split accuracies and mean advantage over baseline
    return scores, float(np.mean(scores) - baseline)



def set_seed(seed):
    """
    Seed the Python, NumPy and PyTorch random number generators (including
    all CUDA devices when CUDA is available), switch cuDNN to deterministic
    mode with benchmarking off, and return a new ``torch.Generator`` seeded
    with the same value.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    generator = torch.Generator()
    generator.manual_seed(seed)
    return generator

# -----------------------------------------------------------------------------
# Helper: Build Base models
# -----------------------------------------------------------------------------
def build_base_model(model_type, model_filename, train_loader, val_loader, device, num_classes, trans=False):
    """
    Create the base model for ``model_type`` and train it with that model's
    fixed recipe (optimizer, learning rate, epoch count) using a linear
    learning-rate decay and cross-entropy loss.

    Returns:
        (model, training_time): the model returned by ``train`` and the
        elapsed time it reported.

    Raises:
        ValueError: If ``model_type`` has no recipe.
    """
    sgd_options = {"momentum": 0.9, "weight_decay": 5e-4}
    adamw_options = {"weight_decay": 5e-4}

    # model_type -> (extra create_model kwargs, optimizer class, optimizer
    #                kwargs, lr, epochs, unit text printed after the time)
    recipes = {
        "resnet18":          ({"preload": True}, optim.SGD,   sgd_options,   0.1,  50, "s"),
        "resnet18_imagenet": ({},                optim.SGD,   sgd_options,   0.1,  80, " seconds"),
        "efficientnetv2l":   ({"preload": True}, optim.AdamW, adamw_options, 0.01, 80, "s"),
        "vit":               ({"preload": True}, optim.SGD,   sgd_options,   0.01, 15, " seconds"),
        "berta_distill":     ({"preload": True}, optim.SGD,   sgd_options,   0.01, 15, " seconds"),
    }

    if model_type not in recipes:
        raise ValueError(f"Unsupported model_type: {model_type}")
    create_kwargs, optimizer_cls, optimizer_kwargs, lr, epochs, time_unit = recipes[model_type]

    model = create_model(model_type, device, num_classes=num_classes, **create_kwargs)
    optimizer = optimizer_cls(model.parameters(), lr=lr, **optimizer_kwargs)
    scheduler = optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0,
                                            end_factor=0.001, total_iters=epochs)
    criterion = nn.CrossEntropyLoss()

    print("Starting initial training on D_train (all classes)...")
    training_time, model = train(model, train_loader, val_loader, model_filename,
                                 criterion, optimizer, epochs, scheduler, trans=trans)
    print(f"Initial training completed in {training_time:.2f}{time_unit}")

    return model, training_time

# -----------------------------------------------------------------------------
# Helper: Choose Dataset and Completely Load It
# -----------------------------------------------------------------------------
# This class and the helper functions below are used by the text datasets
# ('news' and '20_newsgroups'). They are defined at module level so the
# data-loading code can use them.
class HuggingFaceSubset(Dataset):
    """
    Index-based view over a tokenized HuggingFace dataset.

    Item ``i`` is the pair ``(inputs, label)`` built from the record at
    ``indices[i]``: ``inputs`` is a dict with the record's "input_ids" and
    "attention_mask" entries, and ``label`` is ``record["labels"].item()``.
    """
    def __init__(self, hf_dataset, indices):
        self.hf_dataset = hf_dataset
        self.indices = indices

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        record = self.hf_dataset[self.indices[idx]]
        # Keep only the two fields the model consumes; other columns are dropped.
        model_inputs = {key: record[key] for key in ("input_ids", "attention_mask")}
        return model_inputs, record["labels"].item()

# Module-level tokenizer used by tokenize_fn. It is created once, when this
# module is imported, from the "distilbert-base-uncased" checkpoint with
# use_fast=False.
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased", use_fast=False)

def tokenize_fn(examples):
    """
    Tokenize the "combined_text" field of ``examples`` with the module-level
    tokenizer, truncating and padding every entry to 128 tokens, and return
    the tokenizer's output.
    """
    texts = examples["combined_text"]
    encoded = tokenizer(texts, truncation=True, padding="max_length", max_length=128)
    return encoded

def combine_text(ex):
    """
    Build the "combined_text" field for one example.

    Only the 'text' column is used (a missing/empty value counts as ""); the
    result is that text with surrounding whitespace stripped.
    """
    title = ex["text"] or ""
    # The content part is always empty, so only a separator space is appended
    # before stripping.
    return {"combined_text": (title + " ").strip()}


def create_feature_clusters(D_train, D_test, n_clusters=100, batch_size=64, device="cuda", trans=False):
    """
    Assign every sample of two datasets to a feature cluster.

    Embeddings are taken from an ImageNet-pretrained ResNet18 whose final
    fully connected layer is replaced by an identity. MiniBatchKMeans is
    fitted on the training embeddings only and then used to label both
    the training and the test embeddings.

    Args:
        D_train (Dataset): The training dataset; items are ``(input, _)`` pairs.
        D_test (Dataset): The test dataset; items are ``(input, _)`` pairs.
        n_clusters (int): Number of K-Means clusters.
        batch_size (int): Batch size used while extracting embeddings.
        device (str): Device the feature extractor and batches are moved to.
        trans (bool): If True, each batch is a dict of tensors that is
            unpacked as keyword arguments into the feature extractor.
            NOTE(review): the extractor here is a torchvision ResNet18 --
            confirm that keyword-unpacked dict inputs are actually supported.

    Returns:
        tuple[np.ndarray, np.ndarray]: ``(train_cluster_ids, test_cluster_ids)``.
    """
    # Feature extractor: pretrained ResNet18 without its classification head.
    backbone = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
    backbone.fc = torch.nn.Identity()
    backbone.to(device)
    backbone.eval()

    def get_embeddings(dataset):
        """Run the whole dataset through the backbone and stack the outputs."""
        loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=4)
        chunks = []
        with torch.no_grad():
            for batch, _ in tqdm(loader, desc="Extracting embeddings"):
                if trans:
                    batch = {k: v.to(device) for k, v in batch.items()}
                    features = backbone(**batch)
                else:
                    features = backbone(batch.to(device))
                chunks.append(features.cpu().numpy())
        return np.concatenate(chunks, axis=0)

    train_embeddings = get_embeddings(D_train)
    test_embeddings = get_embeddings(D_test)

    # Fit the clustering on the training embeddings only.
    print(f"Clustering {len(train_embeddings)} training embeddings into {n_clusters} clusters...")
    kmeans = MiniBatchKMeans(n_clusters=n_clusters, n_init='auto', random_state=42, batch_size=256)
    kmeans.fit(train_embeddings)

    # Label both splits with the fitted centroids.
    train_cluster_ids = kmeans.predict(train_embeddings)
    test_cluster_ids = kmeans.predict(test_embeddings)

    print("Clustering complete.")
    return train_cluster_ids, test_cluster_ids

def pca_projector(D_train, test_dataset, batch_size, n_components=100):
    """
    Group samples by their dominant principal component.

    Every input is flattened to one row, a PCA with ``n_components``
    components is fitted on the training rows, and both splits are projected
    into that space. Each sample's group is the index of the component with
    the largest absolute projection.

    Returns:
        (train_pc_cluster, test_pc_cluster): integer arrays with one group
        index per training / test sample.
    """
    def _flattened_rows(dataset):
        # Iterate in dataset order and flatten each batch to (B, features).
        loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=2)
        rows = [imgs.view(imgs.size(0), -1).numpy() for imgs, _ in loader]
        return np.concatenate(rows, axis=0)

    # Fit on the training split only.
    train_rows = _flattened_rows(D_train)
    pca = PCA(n_components=n_components)
    train_projection = pca.fit_transform(train_rows)

    # Project the test split into the same PCA space.
    test_rows = _flattened_rows(test_dataset)
    test_projection = pca.transform(test_rows)

    # Dominant axis = component with the largest |projection|.
    train_pc_cluster = np.abs(train_projection).argmax(axis=1)
    test_pc_cluster = np.abs(test_projection).argmax(axis=1)
    return train_pc_cluster, test_pc_cluster


def to_list(idx):
    """
    Convert a collection of indices to a Python list.

    Tensors are cast to ``torch.long`` before conversion; any other iterable
    (list, tuple, range, NumPy array, ...) is materialised with its elements
    as yielded.
    """
    if torch.is_tensor(idx):
        return idx.long().tolist()
    return [*idx]

def union_subsets(a: Subset, b: Subset):
    """
    Combine two subsets into one dataset.

    When both wrap the very same base dataset object, the result is a
    ``Subset`` over the sorted, de-duplicated union of their indices.
    Otherwise indices cannot be compared, so the two subsets are simply
    concatenated in a ``ConcatDataset``.
    """
    if a.dataset is not b.dataset:
        return ConcatDataset([a, b])

    merged = {int(i) for i in to_list(a.indices)}
    merged.update(int(i) for i in to_list(b.indices))
    return Subset(a.dataset, sorted(merged))




# Old version of load_data with dataset-specific splitting logic
def OLD_load_data(dataset_name, forgot_set, rng, mode, vary=True, batch_size=256, n_work=2):
    num_classes = None  # Will be set by each dataset's block

    if dataset_name == "cifar100":
        # CIFAR-100 dataset
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5071, 0.4867, 0.4408),
                                 (0.2675, 0.2565, 0.2761))
        ])
        D_train = torchvision.datasets.CIFAR100(
            root='../../data/cifar100',
            train=True, download=True,
            transform=transform
        )
        test_dataset = torchvision.datasets.CIFAR100(
            root='../../data/cifar100',
            train=False, download=True,
            transform=transform
        )
        num_classes = 100

    # -----------------------------------------------------------------------------
    elif dataset_name == "imagenet":
        # ImageNet dataset
        transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406],
                                 [0.229, 0.224, 0.225]),
        ])
        train_dir = "../../data/imagenet-100/train.X"
        val_dir   = "../../data/imagenet-100/val.X"

        D_train = torchvision.datasets.ImageFolder(
            root=train_dir, transform=transform
        )
        test_dataset = torchvision.datasets.ImageFolder(
            root=val_dir, transform=transform
        )
        num_classes = 100

    # -----------------------------------------------------------------------------
    elif dataset_name == "news":
        # Load and preprocess the AG News text dataset
        raw_dataset = load_dataset("ag_news")
        processed_dataset = raw_dataset.map(combine_text, num_proc=n_work)

        train_data = processed_dataset["train"]
        test_data = processed_dataset["test"]
        num_classes = 4  # AG News has 4 classes

        # Apply tokenizer and format for PyTorch
        print("Tokenizing dataset...")
        train_data = train_data.map(tokenize_fn, batched=True, num_proc=n_work)
        test_data = test_data.map(tokenize_fn, batched=True, num_proc=n_work)

        train_data = train_data.rename_column("label", "labels")
        test_data = test_data.rename_column("label", "labels")

        train_data.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])
        test_data.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])
        print("Tokenization complete.")

        # Wrap the processed data into a standard PyTorch Dataset format
        # This makes it compatible with the generic splitting logic below.
        all_train_indices = list(range(len(train_data)))
        all_test_indices = list(range(len(test_data)))

        D_train = HuggingFaceSubset(train_data, all_train_indices)
        test_dataset = HuggingFaceSubset(test_data, all_test_indices)
        # All dataset-specific splitting logic has been removed from this block.
    
    elif dataset_name == "20_newsgroups":
        # Load and preprocess the 20 Newsgroups text dataset
        # This dataset is well-suited for the berta_distill model and has 20 classes.
        raw_dataset = load_dataset("SetFit/20_newsgroups")
        processed_dataset = raw_dataset.map(combine_text, num_proc=n_work)

        train_data = processed_dataset["train"]
        test_data = processed_dataset["test"]
        num_classes = 20  # 20 Newsgroups has 20 classes

        # Apply tokenizer and format for PyTorch
        print("Tokenizing dataset...")
        train_data = train_data.map(tokenize_fn, batched=True, num_proc=n_work)
        test_data = test_data.map(tokenize_fn, batched=True, num_proc=n_work)

        train_data = train_data.rename_column("label", "labels")
        test_data = test_data.rename_column("label", "labels")

        train_data.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])
        test_data.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])
        print("Tokenization complete.")

        # Wrap the processed data into a standard PyTorch Dataset format
        all_train_indices = list(range(len(train_data)))
        all_test_indices = list(range(len(test_data)))

        D_train = HuggingFaceSubset(train_data, all_train_indices)
        test_dataset = HuggingFaceSubset(test_data, all_test_indices)

    else:
        raise ValueError(f"Unsupported dataset: {dataset_name}")

    # -----------------------------------------------------------------------------
    # Generic Data Splitting Logic (applies to all datasets)
    # -----------------------------------------------------------------------------
    if "cluster" in mode:  # Changed mode name for clarity
        # This single call replaces the entire pca_projector function
        train_groups, test_groups = create_feature_clusters(
            D_train, test_dataset, n_clusters=num_classes, batch_size=batch_size
        )

        # if vary:
        perm_clusters = torch.randperm(num_classes, generator=rng).tolist()
        # forgot_set is assumed to be defined elsewhere (e.g., 10)
        forgotten_clusters = perm_clusters[:forgot_set]
        forgotten_clusters = set(forgotten_clusters)
        # else:
        #     forgotten_clusters = set(forgot_set)
        print(f"Forgotten feature clusters selected: {forgotten_clusters}")

        # The rest of your logic remains the same, just with the new group variables.
        # I've used 'train_groups' and 'test_groups' for clarity.

        # --- Test split: 80% calib, 20% forget/retain by label ---
        perm_test = torch.randperm(len(test_dataset), generator=rng).tolist()
        n_test = len(perm_test)
        num_calib_test = int(0.8 * n_test)
        indices_calib_test = perm_test[:num_calib_test]
        remaining_test = perm_test[num_calib_test:]

        # Use the pre-computed group assignments for fast splitting
        indices_f_test = [i for i in remaining_test if test_groups[i] in forgotten_clusters]
        indices_r_test = [i for i in remaining_test if test_groups[i] not in forgotten_clusters]

        D_calib = Subset(test_dataset, indices_calib_test)
        D_f = Subset(test_dataset, indices_f_test)
        D_r = Subset(test_dataset, indices_r_test)

        # --- Calibration subsets by group from D_calib ---
        calibration_forget_indices = [i for i in indices_calib_test if test_groups[i] in forgotten_clusters]
        calibration_retain_indices = [i for i in indices_calib_test if test_groups[i] not in forgotten_clusters]
        D_calib_forget = Subset(test_dataset, calibration_forget_indices)
        D_calib_retain = Subset(test_dataset, calibration_retain_indices)

        # --- Train split: 100% train_main ---
        perm_train = torch.randperm(len(D_train), generator=rng).tolist()
        n_train = len(perm_train)
        num_train_main = int(0.9 * n_train)
        train_main_idx = perm_train[:num_train_main]
        train_val_idx = perm_train[num_train_main:]

        # Within the 10% val: 80% calib_val, rest forget_val/retain_val by group
        num_calib_val = int(0.5 * len(train_val_idx))
        calib_val_idx = train_val_idx[:num_calib_val]
        rem_val_idx = train_val_idx[num_calib_val:]

        forget_val_idx = [i for i in rem_val_idx if train_groups[i] in forgotten_clusters]
        retained_val_idx = [i for i in rem_val_idx if train_groups[i] not in forgotten_clusters]

        D_train_main = Subset(D_train, train_main_idx)
        D_calib_val = Subset(D_train, calib_val_idx)
        D_f_val = Subset(D_train, forget_val_idx)
        D_r_val = Subset(D_train, retained_val_idx)

        # --- Training subsets S vs R ---
        # This logic correctly uses the original index ('orig_idx') to look up the
        # group assignment from the full 'train_groups' array.
        S_indices = [i for i, orig_idx in enumerate(train_main_idx) if train_groups[orig_idx] in forgotten_clusters]
        R_indices = [i for i, orig_idx in enumerate(train_main_idx) if train_groups[orig_idx] not in forgotten_clusters]

        S_subset = Subset(D_train_main, S_indices)
        R_subset = Subset(D_train_main, R_indices)

    elif "pca" in mode:
        # if dataset_name == "news":
        #     raise ValueError("PCA mode is not compatible with the 'news' text dataset.")

        # PCA-based Grouping
        train_pc_cluster, test_pc_cluster = pca_projector(D_train, test_dataset, batch_size)
        print("PCA projection applied.")
        # if vary:
        perm_clusters = torch.randperm(num_classes, generator=rng).tolist()
        forgotten_clusters = perm_clusters[:forgot_set]
        forgotten_clusters = set(forgotten_clusters)
        # else:
        #     forgotten_clusters = set(forgot_set)
        print(f"Forgotten PC clusters selected: {forgotten_clusters}")

        # --- Test split: 80% calib, 20% forget/retain by label ---
        perm_test = torch.randperm(len(test_dataset), generator=rng).tolist()
        n_test = len(perm_test)
        num_calib_test = int(0.8 * n_test)
        indices_calib_test = perm_test[:num_calib_test]
        remaining_test = perm_test[num_calib_test:]

        indices_f_test = [i for i in remaining_test if test_pc_cluster[i] in forgotten_clusters]
        indices_r_test = [i for i in remaining_test if test_pc_cluster[i] not in forgotten_clusters]

        D_calib = Subset(test_dataset, indices_calib_test)
        D_f = Subset(test_dataset, indices_f_test)
        D_r = Subset(test_dataset, indices_r_test)

        # --- Calibration subsets by label from D_calib ---
        calibration_forget_indices = [i for i in indices_calib_test if test_pc_cluster[i] in forgotten_clusters]
        calibration_retain_indices = [i for i in indices_calib_test if test_pc_cluster[i] not in forgotten_clusters]
        D_calib_forget = Subset(test_dataset, calibration_forget_indices)
        D_calib_retain = Subset(test_dataset, calibration_retain_indices)

        # --- Train split: 90% train_main, 10% val ---
        perm_train = torch.randperm(len(D_train), generator=rng).tolist()
        n_train = len(perm_train)
        num_train_main = int(0.9 * n_train)
        train_main_idx = perm_train[:num_train_main]
        train_val_idx = perm_train[num_train_main:]

        # Within the 10% val: 80% calib_val, rest forget_val/retain_val by label
        num_calib_val = int(0.5 * len(train_val_idx))
        calib_val_idx = train_val_idx[:num_calib_val]
        rem_val_idx = train_val_idx[num_calib_val:]

        forget_val_idx = [i for i in rem_val_idx if train_pc_cluster[i] in forgotten_clusters]
        retained_val_idx = [i for i in rem_val_idx if train_pc_cluster[i] not in forgotten_clusters]

        D_train_main = Subset(D_train, train_main_idx)
        D_calib_val = Subset(D_train, calib_val_idx)
        D_f_val = Subset(D_train, forget_val_idx)
        D_r_val = Subset(D_train, retained_val_idx)

        # --- Training subsets S vs R (unchanged) ---
        S_indices = [i for i, orig_idx in enumerate(train_main_idx) if train_pc_cluster[orig_idx] in forgotten_clusters]
        R_indices = [i for i, orig_idx in enumerate(train_main_idx) if train_pc_cluster[orig_idx] not in forgotten_clusters]

        S_subset = Subset(D_train_main, S_indices)
        R_subset = Subset(D_train_main, R_indices)

    elif "label" in mode:
        # --- Label-based partitioning ---
        if num_classes is None:
            raise ValueError("num_classes was not set for the chosen dataset.")

        if dataset_name != "20_newsgroups" and dataset_name != "news":
            all_test_labels = test_dataset.targets
            print(len(all_test_labels))
            all_train_labels = D_train.targets
            print(len(all_train_labels))

        # Select forgotten classes via torch.randperm
        # if vary:
        perm_classes = torch.randperm(num_classes, generator=rng).tolist()
        forgotten_classes = perm_classes[:forgot_set]
        forgotten_classes = set(forgotten_classes)  # Convert to set for faster lookups
        # else:
        #     forgotten_classes = set(forgot_set)
        print(f"Forgotten classes selected: {forgotten_classes}")

        # --- Test split: 80% calib, 20% forget/retain by label ---
        perm_test = torch.randperm(len(test_dataset), generator=rng).tolist()
        n_test = len(perm_test)
        num_calib_test = int(0.8 * n_test)
        indices_calib_test = perm_test[:num_calib_test]
        remaining_test = perm_test[num_calib_test:]

        if dataset_name != "20_newsgroups" and dataset_name != "news":
            indices_f_test = [i for i in remaining_test if all_test_labels[i] in forgotten_classes]
            indices_r_test = [i for i in remaining_test if all_test_labels[i] not in forgotten_classes]
        else:
            indices_f_test = [i for i in remaining_test if test_dataset[i][1] in forgotten_classes]
            indices_r_test = [i for i in remaining_test if test_dataset[i][1] not in forgotten_classes]

        D_calib = Subset(test_dataset, indices_calib_test)
        D_f = Subset(test_dataset, indices_f_test)
        D_r = Subset(test_dataset, indices_r_test)

        # --- Calibration subsets by label from D_calib ---
        if dataset_name != "20_newsgroups" and dataset_name != "news":
            calibration_forget_indices = [i for i in indices_calib_test if all_test_labels[i] in forgotten_classes]
            calibration_retain_indices = [i for i in indices_calib_test if all_test_labels[i] not in forgotten_classes]
        else:
            calibration_forget_indices = [idx for idx in indices_calib_test if test_dataset[idx][1] in forgotten_classes]
            calibration_retain_indices = [idx for idx in indices_calib_test if test_dataset[idx][1] not in forgotten_classes]
        D_calib_forget = Subset(test_dataset, calibration_forget_indices)
        D_calib_retain = Subset(test_dataset, calibration_retain_indices)

        # --- Train split: 90% train_main, 10% val ---
        perm_train = torch.randperm(len(D_train), generator=rng).tolist()
        n_train = len(perm_train)
        num_train_main = int(0.9 * n_train)
        train_main_idx = perm_train[:num_train_main]
        train_val_idx = perm_train[num_train_main:]

        # Within the 10% val: 50% calib_val, rest forget_val/retain_val by label
        num_calib_val = int(0.5 * len(train_val_idx))
        calib_val_idx = train_val_idx[:num_calib_val]
        rem_val_idx = train_val_idx[num_calib_val:]

        if dataset_name != "20_newsgroups" and dataset_name != "news":
            forget_val_idx = [i for i in rem_val_idx if all_train_labels[i] in forgotten_classes]
            retained_val_idx = [i for i in rem_val_idx if all_train_labels[i] not in forgotten_classes]
        else:
            forget_val_idx = [i for i in rem_val_idx if D_train[i][1] in forgotten_classes]
            retained_val_idx = [i for i in rem_val_idx if D_train[i][1] not in forgotten_classes]

        D_train_main = Subset(D_train, train_main_idx)
        D_calib_val = Subset(D_train, calib_val_idx)
        D_f_val = Subset(D_train, forget_val_idx)
        D_r_val = Subset(D_train, retained_val_idx)

        if dataset_name != "20_newsgroups" and dataset_name != "news":
            subset_labels = [all_train_labels[i] for i in train_main_idx]
            S_indices = [i for i, label in enumerate(subset_labels) if label in forgotten_classes]
            R_indices = [i for i, label in enumerate(subset_labels) if label not in forgotten_classes]
        # --- Training subsets S vs R (unchanged) ---
        else:
            S_indices = [i for i, (_, label) in enumerate(D_train_main) if label in forgotten_classes]
            R_indices = [i for i, (_, label) in enumerate(D_train_main) if label not in forgotten_classes]

        S_subset = Subset(D_train_main, S_indices)
        R_subset = Subset(D_train_main, R_indices)

    elif "random" in mode:
        # --- Random point-based partitioning (not label-based) ---
        if num_classes is None:
            raise ValueError("num_classes was not set for the chosen dataset.")

        if dataset_name != "20_newsgroups" and dataset_name != "news":
            all_test_labels = test_dataset.targets
            print(len(all_test_labels))
            all_train_labels = D_train.targets
            print(len(all_train_labels))

        print(f"Using random point selection with forget_set size: {forgot_set}")

        # --- Test split: 60% calib, 40% forget/retain randomly ---
        perm_test = torch.randperm(len(test_dataset), generator=rng).tolist()
        n_test = len(perm_test)
        num_calib_test = int(0.6 * n_test)
        indices_calib_test = perm_test[:num_calib_test]
        remaining_test = perm_test[num_calib_test:]

        # Randomly split remaining test into forget and retain (not by labels)
        perm_remaining_test = torch.randperm(len(remaining_test), generator=rng).tolist()
        num_forget_test = min(forgot_set, len(remaining_test) // 2)  # Ensure balanced split
        indices_f_test = [remaining_test[i] for i in perm_remaining_test[:num_forget_test]]
        indices_r_test = [remaining_test[i] for i in perm_remaining_test[num_forget_test:]]

        D_calib = Subset(test_dataset, indices_calib_test)
        D_f = Subset(test_dataset, indices_f_test)
        D_r = Subset(test_dataset, indices_r_test)

        # --- Calibration subsets randomly from D_calib (not by label) ---
        perm_calib = torch.randperm(len(indices_calib_test), generator=rng).tolist()
        num_calib_forget = min(forgot_set, len(indices_calib_test) // 2)  # Balanced split
        calibration_forget_indices = [indices_calib_test[i] for i in perm_calib[:num_calib_forget]]
        calibration_retain_indices = [indices_calib_test[i] for i in perm_calib[num_calib_forget:]]
        D_calib_forget = Subset(test_dataset, calibration_forget_indices)
        D_calib_retain = Subset(test_dataset, calibration_retain_indices)

        # --- Train split: 90% train_main, 10% val ---
        perm_train = torch.randperm(len(D_train), generator=rng).tolist()
        n_train = len(perm_train)
        num_train_main = int(0.9 * n_train)
        train_main_idx = perm_train[:num_train_main]
        train_val_idx = perm_train[num_train_main:]

        # Within the 10% val: 50% calib_val, rest forget_val/retain_val randomly
        num_calib_val = int(0.5 * len(train_val_idx))
        calib_val_idx = train_val_idx[:num_calib_val]
        rem_val_idx = train_val_idx[num_calib_val:]

        # Randomly split remaining validation into forget and retain (not by labels)
        perm_rem_val = torch.randperm(len(rem_val_idx), generator=rng).tolist()
        num_forget_val = min(forgot_set, len(rem_val_idx) // 2)  # Balanced split
        forget_val_idx = [rem_val_idx[i] for i in perm_rem_val[:num_forget_val]]
        retained_val_idx = [rem_val_idx[i] for i in perm_rem_val[num_forget_val:]]

        D_train_main = Subset(D_train, train_main_idx)
        D_calib_val = Subset(D_train, calib_val_idx)
        D_f_val = Subset(D_train, forget_val_idx)
        D_r_val = Subset(D_train, retained_val_idx)

        # --- Training subsets S vs R: Random selection from D_train_main ---
        # S_subset will have exactly 'forgot_set' random points
        perm_train_main = torch.randperm(len(D_train_main), generator=rng).tolist()
        S_indices = perm_train_main[:forgot_set]  # First forgot_set points for S
        R_indices = perm_train_main[forgot_set:]  # Rest for R

        S_subset = Subset(D_train_main, S_indices)
        R_subset = Subset(D_train_main, R_indices)

        D_f = S_subset  # S is the forget subset
        D_r = R_subset  # R is the retained subset

    

    copied_list = R_indices[:]
    random.shuffle(copied_list)
    R_small = Subset(D_train_main, copied_list[:int(len(R_indices)/20)])
    copied_list_s = S_indices[:]
    random.shuffle(copied_list_s)
    S_small = Subset(D_train_main, copied_list_s[:int(len(S_indices)/20)])

    S_subset_un = Subset(D_train_main, copied_list_s[int(len(S_indices)/20):int(len(S_indices)/8)])
    R_subset_un = Subset(D_train_main, copied_list[int(len(R_indices)/20):int(len(R_indices)/8)])
    # Get the indices from each subset
    s_indices = S_subset.indices
    r_indices = R_subset.indices

    s_indices_small = S_small.indices
    r_indices_small = R_small.indices

    # Build relative indices 0..len(D_calib_val)-1 for the second-level Subset
    rel_calib_val_indices = list(range(len(D_calib_val)))
    random.shuffle(rel_calib_val_indices)

    k = max(1, len(rel_calib_val_indices) // 10)  # pick your fraction safely
    D_calib_instance_small = Subset(D_calib_val, rel_calib_val_indices[:k])
    
    # Concatenate the indices (S positions first, then R positions)
    all_indices = s_indices + r_indices

    all_indices_small = s_indices_small + r_indices_small

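    # Create a new subset with the combined indices.
    # It holds the same samples as the previous D_train_main when S and R together
    # cover it, but reordered: all S positions first, then all R positions.
    # NOTE(review): all_indices_small holds positions in the *previous* D_train_main,
    # yet it is applied below to this reordered subset, so it may select different
    # samples than S_small / R_small do — confirm this is intended.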
    D_train_main = Subset(D_train_main, all_indices)

    D_train_main_calibration = Subset(D_train_main, all_indices_small)

    # Get one merged subset now which is instance_calibration = D_calib_instance_small union D_train_main_calibration
    total_calibration_subset = union_subsets(D_calib_instance_small, D_train_main_calibration)

    
    if "instance" in mode:
        D_calib_val = D_train_main_calibration #total_calibration_subset
        D_f         = S_subset_un
        D_r         = R_subset_un

    # R_subset = Subset(D_train_main, r_indices)
    # --- DataLoaders ---
    train_loader              = DataLoader(D_train_main,    batch_size=batch_size, shuffle=True,  num_workers=n_work, generator=rng)
    calibration_val_loader    = DataLoader(D_calib_val,     batch_size=batch_size, shuffle=True,  num_workers=n_work, generator=rng)
    forget_loader             = DataLoader(D_f,             batch_size=batch_size, shuffle=True,  num_workers=n_work, generator=rng)
    retained_loader           = DataLoader(D_r,             batch_size=batch_size, shuffle=True,  num_workers=n_work, generator=rng)
    calibration_loader        = DataLoader(D_calib,         batch_size=batch_size, shuffle=True,  num_workers=n_work, generator=rng)
    forget_val_loader         = DataLoader(D_f_val,         batch_size=batch_size, shuffle=True,  num_workers=n_work, generator=rng)
    retained_val_loader       = DataLoader(D_r_val,         batch_size=batch_size, shuffle=True,  num_workers=n_work, generator=rng)
    S_loader                  = DataLoader(S_subset,        batch_size=batch_size, shuffle=True,  num_workers=n_work, generator=rng)
    R_loader                  = DataLoader(R_subset,        batch_size=batch_size, shuffle=True,  num_workers=n_work, generator=rng)
    R_small_loader            = DataLoader(R_small,         batch_size=batch_size, shuffle=True,  num_workers=n_work, generator=rng)
    calibration_forget_loader = DataLoader(D_calib_forget,  batch_size=batch_size, shuffle=True,  num_workers=n_work, generator=rng)
    calibration_retain_loader = DataLoader(D_calib_retain,  batch_size=batch_size, shuffle=True,  num_workers=n_work, generator=rng)
    test_loader               = DataLoader(test_dataset,    batch_size=batch_size, shuffle=True,  num_workers=n_work, generator=rng)
    print("Data loaders created successfully.")

    return {
        "train_loader": train_loader,
        "test_loader": test_loader,
        "cal_test_loader": calibration_loader,
        "Df_loader": forget_loader,
        "Dr_loader": retained_loader,
        "cal_unlearn_loader": calibration_val_loader,
        "Vf_loader": forget_val_loader,
        "Vr_loader": retained_val_loader,
        "Tf_loader": S_loader,
        "Tr_loader": R_loader,
        "cal_test_forget_loader": calibration_forget_loader,
        "cal_test_retain_loader": calibration_retain_loader,
        "cal_unlearn_subset": D_calib_val,
        "Vf_subset": D_f_val,
        "Vr_subset": D_r_val,
        "cal_test_subset": D_calib,
        "Df_subset": D_f,
        "Dr_subset": D_r,
        "test_subset": test_dataset,
        "Tr_small_loader": R_small_loader,
    }
