import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
from numpy.lib.format import open_memmap
import matplotlib.pyplot as plt
import seaborn as sns
import torchvision
import torchvision.transforms as transforms
from tensorboardX import SummaryWriter
from generate_mask import save_gradient_ratio
import os
import argparse
from models import *
from models.resnet_orig import ResNet18_orig
from models.vgg import VGG
from models.vgg_svd import vgg11_bn
import pandas as pd
import random
import time
import copy
import numpy as np
from torch.utils.data import Dataset
from torchvision.io import read_image
from sgld_optim import *
import re

from sklearn.svm import SVC
from sklearn.metrics import confusion_matrix


def l1_regularization(model):
    """Return the L1 norm of all parameters of ``model`` as a scalar tensor.

    Every parameter is flattened and concatenated into one vector before the
    norm is taken, so the result is the sum of absolute values over the whole
    model and stays differentiable with respect to the parameters.

    Args:
        model: Object exposing ``parameters()`` (e.g. an ``nn.Module``).

    Returns:
        0-dim tensor holding the L1 norm.
    """
    # reshape(-1) instead of view(-1): view raises on parameters that are not
    # contiguous in memory (e.g. transposed or channels_last weights).
    flat_params = [param.reshape(-1) for param in model.parameters()]
    return torch.linalg.norm(torch.cat(flat_params), ord=1)


def discretize(x):
    """Snap every element of ``x`` to the nearest multiple of 1/255."""
    levels = 255
    return torch.round(levels * x) / levels


def FGSM_perturb(x, y, model=None, bound=None, criterion=None):
    """Craft a one-step FGSM adversarial example for ``x``.

    The input is moved by ``bound`` in the direction of the sign of the loss
    gradient, clamped to ``[0, 1]`` and passed through ``discretize``.

    Args:
        x: Input batch; it is not modified (a detached copy is perturbed).
        y: Targets passed to ``criterion``.
        model: Network whose loss is increased; its parameters decide the
            device the attack runs on.
        bound: Step size applied to the gradient sign.
        criterion: Callable ``criterion(model(x), y)`` returning a scalar loss.

    Returns:
        Detached adversarial batch on the model's device.

    Note:
        ``model.zero_grad()`` is called before the backward pass, and the
        backward pass leaves fresh gradients on the model parameters.
    """
    device = next(model.parameters()).device
    model.zero_grad()

    # Move to the model's device *before* enabling gradients: calling
    # requires_grad_() first and .to(device) afterwards produces a non-leaf
    # copy whenever a device transfer happens, and its .grad stays None.
    x_adv = x.detach().clone().to(device).requires_grad_(True)
    y = y.to(device)

    loss = criterion(model(x_adv), y)
    loss.backward()

    step = x_adv.grad.detach().sign() * bound
    x_adv = discretize(torch.clamp(x_adv.detach() + step, 0.0, 1.0))

    return x_adv.detach()


def plot_confusion_matrix(true_labels, pred_labels, class_names, title, ax=None):
    """Draw the confusion matrix of raw counts as an annotated heatmap.

    Args:
        true_labels: Ground-truth labels.
        pred_labels: Predicted labels.
        class_names: Tick labels used on both axes.
        title: Title of the plot.
        ax: Matplotlib axes to draw on; a new 8x6 figure is created when None.

    Returns:
        The axes the heatmap was drawn on.
    """
    # Counts are kept as integers; no row/column normalisation is applied.
    counts = confusion_matrix(true_labels, pred_labels).astype("int")

    if ax is None:
        _, ax = plt.subplots(figsize=(8, 6))

    sns.heatmap(
        counts,
        annot=True,
        fmt="d",
        cmap="Greens",
        xticklabels=class_names,
        yticklabels=class_names,
        ax=ax,
    )
    ax.set_title(title)
    ax.set_ylabel("True Labels")
    ax.set_xlabel("Predicted Labels")
    return ax


def forget_loss_max_assign(new_logits, old_probs, forgetting_class):
    """Cross-entropy against a target that moves the forget mass to the top remaining class.

    For every sample the probability that ``old_probs`` puts on the forget
    class(es) is removed and added to the most probable remaining class. The
    loss is ``-sum(q_star * log(softmax(new_logits) + eps))`` averaged over
    the batch.

    Args:
        new_logits: ``[batch, classes]`` logits of the model being trained.
        old_probs: ``[batch, classes]`` reference probabilities (not modified).
        forgetting_class: A single class index or a sequence/tensor of class
            indices whose mass is reassigned.

    Returns:
        Scalar loss tensor.
    """
    eps = 1e-9
    batch_size, _ = new_logits.shape

    # Accept int, list, tuple or tensor; duplicates would double-count mass.
    forget_idx = torch.unique(
        torch.as_tensor(forgetting_class, dtype=torch.long, device=old_probs.device).reshape(-1)
    )

    # Work on a copy so the caller's old_probs stay untouched.
    q_star = old_probs.clone()

    # Total probability mass sitting on the forget classes, per sample: [B]
    forget_mass = q_star[:, forget_idx].sum(dim=1)
    q_star[:, forget_idx] = 0.0

    # Pick the best remaining class; -1 guarantees a forget class never wins,
    # even when every other probability is exactly 0.
    masked_probs = q_star.clone()
    masked_probs[:, forget_idx] = -1
    max_indices = torch.argmax(masked_probs, dim=1)  # [B]

    rows = torch.arange(batch_size, device=q_star.device)
    q_star[rows, max_indices] += forget_mass

    new_probs = F.softmax(new_logits, dim=1)
    return -(q_star * torch.log(new_probs + eps)).sum(dim=1).mean()


def forget_loss_drop_class_dim(new_logits, old_probs, forgetting_class):
    """Cross-entropy against ``old_probs`` renormalised without the forget class(es).

    The target is ``q_star(y) = old_probs(y) / (1 - sum_f old_probs(f))`` for
    every remaining class and 0 for the forget classes. The loss is
    ``-sum(q_star * log(softmax(new_logits) + eps))`` averaged over the batch.

    Args:
        new_logits: ``[batch, classes]`` logits of the model being trained.
        old_probs: ``[batch, classes]`` reference probabilities (not modified).
        forgetting_class: A single class index or a sequence/tensor of class
            indices to drop.

    Returns:
        Scalar loss tensor.
    """
    eps = 1e-9

    # Normalise to a 1-D index tensor so that both an int and a list work;
    # indexing with a bare int yields a 1-D tensor on which sum(dim=1) fails.
    forget_idx = torch.unique(
        torch.as_tensor(forgetting_class, dtype=torch.long, device=old_probs.device).reshape(-1)
    )

    q_star = old_probs.clone()
    q_star[:, forget_idx] = 0.0

    # Remaining mass per sample; clamped so a sample that put all of its
    # probability on the forget classes does not divide by zero.
    remaining_mass = (1.0 - old_probs[:, forget_idx].sum(dim=1)).clamp_min(eps).unsqueeze(1)
    q_star = q_star / remaining_mass

    new_probs = F.softmax(new_logits, dim=1)
    return -(q_star * torch.log(new_probs + eps)).sum(dim=1).mean()


def expand_model(model):
    """Widen the last ``nn.Linear`` of ``model`` by one output unit, in place.

    The last Linear layer in ``named_modules()`` order is replaced by a new
    Linear with ``out_features + 1`` outputs on the same device and dtype.
    The existing weights (and bias, if present) are copied into the first
    rows; the extra row keeps the default initialisation of ``nn.Linear``.

    Args:
        model: Module to modify.

    Raises:
        ValueError: If the model contains no Linear layer.
    """
    fc_name, fc_layer = None, None
    for module_name, module in model.named_modules():
        if isinstance(module, nn.Linear):
            fc_name, fc_layer = module_name, module

    if fc_name is None:
        raise ValueError("No Linear layer found in the model.")

    has_bias = fc_layer.bias is not None
    widened = nn.Linear(
        in_features=fc_layer.in_features,
        out_features=fc_layer.out_features + 1,
        bias=has_bias,
        device=fc_layer.weight.device,
        dtype=fc_layer.weight.dtype,
    )

    # Copy the old classifier into all rows except the newly added last one.
    with torch.no_grad():
        widened.weight[:-1] = fc_layer.weight
        if has_bias:
            widened.bias[:-1] = fc_layer.bias

    # Walk the dotted module path down to the parent, then swap the attribute.
    parent_path, _, attr_name = fc_name.rpartition(".")
    parent = model
    if parent_path:
        for piece in parent_path.split("."):
            parent = getattr(parent, piece)
    setattr(parent, attr_name, widened)


def get_projection_matrix(device, Mr, Mf):
    """Build ``I - (Mf - Mf @ Mr)`` for every key shared by ``Mr`` and ``Mf``.

    Args:
        device: Device on which the identity matrices are created.
        Mr: Mapping from key to a 2-D tensor; its keys drive the iteration.
            NOTE(review): presumably per-layer "retain" projection matrices -
            confirm against the caller.
        Mf: Mapping with the same keys; each value must be a square 2-D tensor
            so that it can be subtracted from the identity.

    Returns:
        OrderedDict mapping each key of ``Mr`` to the combined matrix, in the
        iteration order of ``Mr``.
    """
    # Imported locally: the file's top-level imports never bring OrderedDict
    # into scope explicitly, so the name only resolved through star imports.
    from collections import OrderedDict

    update_dict = OrderedDict()
    for act, mr in Mr.items():
        mf = Mf[act]
        identity = torch.eye(mf.shape[0], device=device)
        update_dict[act] = identity - (mf - torch.mm(mf, mr))
    return update_dict


def forget_loss_max(logits, forget_class):
    """Self-distillation loss that moves the forget mass to the top remaining class.

    The target is built from the model's own (detached) softmax output: the
    probability on the forget class(es) is zeroed and added to the most
    probable remaining class of each sample. The loss is
    ``-sum(target * log(softmax(logits) + eps))`` averaged over the batch.

    Args:
        logits: ``[batch, classes]`` model logits.
        forget_class: A single class index or a sequence/tensor of indices.

    Returns:
        Scalar loss tensor; gradients flow only through the prediction term.
    """
    eps = 1e-9  # small constant for numerical stability

    probs = F.softmax(logits, dim=1)
    target_probs = probs.clone().detach()

    # Normalise to a 1-D index tensor so that an int works as well as a list
    # (indexing with a bare int gives a 1-D tensor and squeeze(1) would fail).
    forget_idx = torch.unique(
        torch.as_tensor(forget_class, dtype=torch.long, device=logits.device).reshape(-1)
    )

    forget_mass = target_probs[:, forget_idx].sum(dim=1)  # [B]
    target_probs[:, forget_idx] = 0.0

    # Mask the forget classes with -1 before argmax so they can never be
    # selected, even if every remaining probability underflowed to 0.
    masked = target_probs.clone()
    masked[:, forget_idx] = -1
    max_indices = torch.argmax(masked, dim=1)

    target_probs.scatter_add_(1, max_indices.unsqueeze(1), forget_mass.unsqueeze(1))

    return -(target_probs * torch.log(probs + eps)).sum(dim=1).mean()

def forget_loss_fn_similarity_based(logits, forget_class, similarity_matrix, beta=1.0):
    """
    Similarity-weighted forgetting loss.

    The target distribution is built from the model's own softmax output
    (without gradient):
        p~(y|x)  = p(y|x) / (1 - p(y_f|x))            for y != y_f
        q*(y|x)  = p~(y|x) * exp(beta * s_y) / sum_{j != y_f} p~(j|x) * exp(beta * s_j)
        q*(y_f|x) = 0
    where s is the row of ``similarity_matrix`` belonging to the forget class.

    Args:
        logits: [batch_size, num_classes] - model logits
        forget_class: class to forget (y_f); an int, a one-element list/tuple
            or a one-element tensor
        similarity_matrix: [num_classes, num_classes] - similarity scores between classes
        beta: float - scale applied to the similarity scores before exp

    Returns:
        loss: scalar tensor - cross-entropy between q* and softmax(logits),
        averaged over the batch

    Raises:
        ValueError: if ``forget_class`` is not a single valid class index or
            ``similarity_matrix`` has the wrong shape.
    """
    eps = 1e-9
    _, num_classes = logits.shape

    # Reduce the accepted container types to a single scalar.
    if isinstance(forget_class, torch.Tensor):
        if forget_class.numel() != 1:
            raise ValueError(f"forget_class should be a single class, got tensor with {forget_class.numel()} elements: {forget_class}")
        forget_class = forget_class.item()
    elif isinstance(forget_class, (list, tuple)):
        if len(forget_class) != 1:
            raise ValueError(f"forget_class should be a single class, got list/tuple with {len(forget_class)} elements: {forget_class}")
        forget_class = forget_class[0]

    try:
        forget_class = int(forget_class)
    except (ValueError, TypeError) as e:
        raise ValueError(f"forget_class must be convertible to int, got {type(forget_class)}: {forget_class}") from e

    if not (0 <= forget_class < num_classes):
        raise ValueError(f"forget_class {forget_class} is out of range [0, {num_classes-1}]")

    if similarity_matrix.shape != (num_classes, num_classes):
        raise ValueError(f"similarity_matrix shape {similarity_matrix.shape} doesn't match expected ({num_classes}, {num_classes})")

    # The similarity row is combined with the probabilities below, so it has
    # to live on the same device as the logits.
    similarity_matrix = similarity_matrix.to(logits.device)

    with torch.no_grad():
        old_probs = F.softmax(logits, dim=1)  # [batch_size, num_classes]

        # First reweighting: drop the forget class and renormalise the rest.
        keep_mass = (1.0 - old_probs[:, forget_class]).clamp_min(eps).unsqueeze(1)
        p_tilde = old_probs.clone()
        p_tilde[:, forget_class] = 0.0
        p_tilde = p_tilde / keep_mass

        # Per-class weights exp(beta * s_y); the forget class gets weight 0.
        class_weights = similarity_matrix[forget_class, :].clone()  # [num_classes]
        class_weights[forget_class] = 0
        class_weights = torch.exp(class_weights * beta)
        class_weights[forget_class] = 0

        # Second reweighting: scale by similarity and renormalise each row.
        weighted = p_tilde * class_weights.unsqueeze(0)  # [batch_size, num_classes]
        q_star = weighted / weighted.sum(dim=1, keepdim=True).clamp_min(eps)
        q_star[:, forget_class] = 0.0  # keep the forget class at exactly 0

    # Only this term carries gradient back to the logits.
    log_new_probs = torch.log(F.softmax(logits, dim=1) + eps)
    return -(q_star * log_new_probs).sum(dim=1).mean()

def compute_similarity_matrix_from_source_model(source_model, num_classes, method='euclidean', inv_temperature=100.0, pca_components=None, trainloader=None, forget_class=None, forget_temp=None, remain_temp=None):
    """
    Compute a class-similarity matrix from the source model's classifier weights.

    Args:
        source_model: The trained source model
        num_classes: Number of classes (not used by this function; the matrix
            size follows the classifier weights)
        method: 'euclidean' or 'cosine' - method to compute similarities
        inv_temperature: Scaling for the softmax. In the 'euclidean' path it is
            forwarded to ``distances_to_similarities``; in the 'cosine' path
            the cosine similarities are divided by it before the softmax.
        pca_components: Number of PCA components (None to disable PCA). PCA is
            only applied when this is smaller than the weight dimension.
        trainloader: DataLoader for collecting embeddings (required when PCA is applied)
        forget_class: int, class to forget (forwarded to the 'euclidean' path)
        forget_temp: float, temperature for the forget class ('euclidean' path)
        remain_temp: float, temperature for the remaining classes ('euclidean' path)

    Returns:
        similarity_matrix: square tensor of similarity scores with a zero
        diagonal and all NaN/Inf entries replaced by finite values

    Raises:
        ValueError: for an unknown ``method`` or when PCA is requested without
            a ``trainloader``.
    """
    print("Computing similarity matrix from source model classifier weights...")

    # Extract classifier weights from source model
    classifier_weights = extract_classifier_weights_from_model(source_model)

    # Apply PCA if requested
    if pca_components is not None and pca_components < classifier_weights.shape[1]:
        if trainloader is None:
            raise ValueError("trainloader is required when pca_components is specified")

        print(f"Applying PCA with {pca_components} components...")
        # Collect embeddings for PCA fitting
        all_embeddings = collect_embeddings_for_pca(source_model, trainloader)
        classifier_weights = apply_pca_to_weights_with_embeddings(classifier_weights, all_embeddings, pca_components)

    # Compute similarity matrix based on method
    if method == 'euclidean':
        # Compute Euclidean distances and convert to similarities
        dist_matrix = euclidean_distance_matrix(classifier_weights)
        similarity_matrix = distances_to_similarities(dist_matrix.cpu().numpy(), inv_temperature, forget_class, forget_temp, remain_temp)
        similarity_matrix = torch.tensor(similarity_matrix, dtype=torch.float32, device=classifier_weights.device)
    elif method == 'cosine':
        # Compute cosine similarities
        similarity_matrix = cosine_similarity_matrix(classifier_weights)
        # NOTE(review): this branch *divides* by inv_temperature while the
        # euclidean branch multiplies distances by it, and the softmax below
        # includes the diagonal, so rows no longer sum to 1 once the diagonal
        # is zeroed. Left unchanged - confirm the intended semantics.
        similarity_matrix = similarity_matrix / inv_temperature
        similarity_matrix = torch.softmax(similarity_matrix, dim=1)
    else:
        raise ValueError(f"Unknown method: {method}. Use 'euclidean' or 'cosine'")
    # Zero out diagonal (class similarity with itself)
    similarity_matrix.fill_diagonal_(0)

    # Replace non-finite entries in a single pass. Calling nan_to_num with only
    # ``nan=`` also rewrites +/-Inf to the largest/smallest finite float, which
    # would bypass the intended 1.0 / 0.0 replacement for infinities.
    has_nan = bool(torch.isnan(similarity_matrix).any())
    has_inf = bool(torch.isinf(similarity_matrix).any())
    if has_nan:
        print("Warning: NaN values detected in similarity matrix")
    if has_inf:
        print("Warning: Inf values detected in similarity matrix")
    if has_nan or has_inf:
        similarity_matrix = torch.nan_to_num(similarity_matrix, nan=0.0, posinf=1.0, neginf=0.0)

    print(f"Similarity matrix shape: {similarity_matrix.shape}")
    if forget_class is not None and forget_temp is not None and remain_temp is not None:
        print(f"Method: {method}, Forget class: {forget_class}, Forget temp: {forget_temp}, Remain temp: {remain_temp}")
    else:
        print(f"Method: {method}, Temperature: {inv_temperature}")
    print(f"Similarity matrix range: [{similarity_matrix.min().item():.4f}, {similarity_matrix.max().item():.4f}]")

    return similarity_matrix

def extract_classifier_weights_from_model(model):
    """
    Return a copy of the weight matrix of the model's final linear layer.

    Args:
        model: The neural network model; an ``nn.DataParallel`` wrapper is
            unwrapped first.

    Returns:
        weight_vectors: cloned ``weight`` of the layer found by
        ``_find_classifier_layer`` (one row per output unit of that layer)
    """
    # DataParallel keeps the real network under ``.module``.
    if isinstance(model, nn.DataParallel):
        model = model.module

    layer, layer_name = _find_classifier_layer(model)

    # Clone so that callers can modify the result without touching the model.
    weights = layer.weight.data.clone()

    print(f"Extracted classifier weights: {weights.shape}")
    print(f"Weight vectors from {layer_name} layer")

    return weights

def _find_classifier_layer(model):
    """
    Locate the final linear layer of ``model``.

    The attributes ``linear``, ``fc`` and ``classifier`` are tried in that
    order; the first one that is an ``nn.Linear`` wins. Otherwise the last
    ``nn.Linear`` in ``named_modules()`` order is used.

    Returns:
        (module, name) - the layer and its attribute / dotted module name.

    Raises:
        RuntimeError: if the model contains no ``nn.Linear`` at all.
    """
    # Well-known attribute names used by common ResNet-style implementations.
    for attr in ("linear", "fc", "classifier"):
        candidate = getattr(model, attr, None)
        if isinstance(candidate, nn.Linear):
            return candidate, attr

    # Fallback: take the last Linear layer in registration order.
    linear_modules = [
        (mod_name, mod)
        for mod_name, mod in model.named_modules()
        if isinstance(mod, nn.Linear)
    ]
    if not linear_modules:
        raise RuntimeError("Could not find a final Linear layer to hook.")
    found_name, found_module = linear_modules[-1]
    return found_module, found_name

def euclidean_distance_matrix(v):
    """
    Pairwise Euclidean distances between the rows of ``v``.

    v: (C, D) tensor
    returns (C, C) Euclidean distance matrix
    """
    # ||a - b||^2 = ||a||^2 + ||b||^2 - 2 * <a, b>, evaluated for all pairs.
    sq_norms = (v ** 2).sum(dim=1, keepdim=True)  # (C, 1)
    gram = v @ v.t()  # (C, C)
    sq_dists = sq_norms + sq_norms.t() - 2 * gram
    # Round-off can push tiny distances below zero; clamp before the sqrt.
    return sq_dists.clamp(min=0.0).sqrt()

def cosine_similarity_matrix(v):
    """
    Pairwise cosine similarities between the rows of ``v``.

    v: (C, D) tensor
    returns (C, C) cosine similarity matrix
    """
    # Scale every row to unit L2 norm; the Gram matrix is then the cosine.
    unit_rows = F.normalize(v, dim=1)
    return torch.matmul(unit_rows, unit_rows.t())

def distances_to_similarities(dist_mat, inv_temperature=1.0, forget_class=None, forget_temp=None, remain_temp=None):
    """
    Convert pairwise distances to row-wise similarity probabilities.

    For every row ``i`` the distances to all *other* classes are negated,
    multiplied by a scale factor and passed through a softmax, so smaller
    distances receive larger probabilities.

    Args:
        dist_mat: (C, C) array-like of distances
        inv_temperature: scale applied to the negated distances (higher = more
            peaked); used for every row unless all three of ``forget_class``,
            ``forget_temp`` and ``remain_temp`` are given
        forget_class: int, row that uses ``forget_temp``
        forget_temp: float, scale for the forget-class row
        remain_temp: float, scale for all other rows

    Returns:
        prob_mat: (C, C) numpy array with similarity probabilities.
                  Diagonal entries set to 0, off-diagonal entries sum to 1 per row.
    """
    # Work in floating point: with an integer matrix, zeros_like() would give
    # an integer output that truncates every probability to 0.
    dist_mat = np.asarray(dist_mat)
    if not np.issubdtype(dist_mat.dtype, np.floating):
        dist_mat = dist_mat.astype(np.float64)

    num_classes = dist_mat.shape[0]
    prob_mat = np.zeros_like(dist_mat)

    # Per-class scaling is only active when all three settings are provided.
    per_class_temps = forget_class is not None and forget_temp is not None and remain_temp is not None

    for i in range(num_classes):
        # Distances from class i to every other class (self excluded).
        others = np.delete(dist_mat[i], i)

        if per_class_temps:
            temp = forget_temp if i == forget_class else remain_temp
        else:
            temp = inv_temperature

        # Smaller distance => larger score => larger probability.
        scores = -others * temp
        probs = torch.softmax(torch.tensor(scores), dim=0).numpy()

        # Re-insert the diagonal position with probability 0.
        prob_mat[i] = np.insert(probs, i, 0.0)

    return prob_mat

def apply_pca_to_weights_with_embeddings(weights, embeddings, n_components):
    """
    Project classifier weights into a PCA basis fitted on embeddings.

    The PCA is fitted on ``embeddings`` and then used to transform
    ``weights``, so both must share the same feature dimension.

    Args:
        weights: (num_classes, embedding_dim) tensor - classifier weights
        embeddings: (num_samples, embedding_dim) tensor - embeddings to fit PCA on
        n_components: Number of PCA components

    Returns:
        weights_pca: (num_classes, n_components) float32 tensor on the device
        of ``weights``
    """
    from sklearn.decomposition import PCA

    # scikit-learn works on numpy arrays, so both inputs go through the CPU.
    reducer = PCA(n_components=n_components)
    reducer.fit(embeddings.cpu().numpy())

    # The basis (and centring) learned from the embeddings is applied to the
    # classifier weights.
    projected = reducer.transform(weights.cpu().numpy())
    weights_pca = torch.tensor(projected, dtype=torch.float32, device=weights.device)

    print(f"Applied PCA: {weights.shape} -> {weights_pca.shape}")
    print(f"Explained variance ratio: {reducer.explained_variance_ratio_.sum():.4f}")

    return weights_pca

def collect_embeddings_for_pca(model, trainloader):
    """
    Collect the embeddings that feed the final linear layer, for PCA fitting.

    A forward hook is placed on the last ``nn.Linear`` of the model and the
    *input* of that layer is recorded for every batch. These are the vectors
    the classifier weights act on, so they share the weights' feature
    dimension.

    Args:
        model: The neural network model (a wrapper exposing ``.module`` is
            unwrapped to find the layer; the wrapper itself is still called).
            The model is switched to eval mode and left there.
        trainloader: Iterable yielding ``(inputs, targets)`` batches

    Returns:
        all_embeddings: (num_samples, embedding_dim) CPU tensor

    Raises:
        RuntimeError: if the model contains no linear layer.
    """
    print("Collecting embeddings for PCA fitting...")

    # DataParallel-style wrappers keep the real network under ``.module``.
    actual_model = model.module if hasattr(model, 'module') else model

    linear_layers = [
        (name, module)
        for name, module in actual_model.named_modules()
        if isinstance(module, torch.nn.Linear)
    ]
    if not linear_layers:
        raise RuntimeError("Could not find penultimate layer for embedding collection")

    # Always hook the *last* Linear layer: its input is the penultimate
    # representation. Hooking the second-to-last Linear and reading its input
    # would return features from one layer too early.
    classifier_name, classifier = linear_layers[-1]
    print(f"Capturing inputs of final linear layer: {classifier_name} (input size: {classifier.in_features})")

    captured_features = {}

    def hook(module, inputs, output):
        # ``inputs`` is a tuple of positional arguments; the first one is the
        # embedding fed into the classifier.
        captured_features['penultimate'] = inputs[0]

    all_embeddings = []
    device = next(model.parameters()).device
    handle = classifier.register_forward_hook(hook)

    model.eval()
    try:
        with torch.no_grad():
            for inputs, targets in trainloader:
                # The forward pass is run only for its side effect on the hook.
                _ = model(inputs.float().to(device))
                all_embeddings.append(captured_features['penultimate'].cpu())
    finally:
        # Never leave the hook attached, even if a forward pass fails.
        handle.remove()

    all_embeddings = torch.cat(all_embeddings, dim=0)
    print(f"Collected embeddings shape: {all_embeddings.shape}")

    return all_embeddings

def forget_loss_fn(logits, forget_class):
    """
    Original reweighting loss (kept for backward compatibility).

    The target is the model's own softmax output (computed without gradient)
    with the forget class(es) set to 0 and the remaining probabilities divided
    by ``1 - p(forget)``. The loss is the cross-entropy between that target
    and ``softmax(logits)``, averaged over the batch.

    logits: [batch_size, num_classes]
    forget_class: int, or a sequence/tensor of class indices

    Returns a scalar tensor; gradients flow only through the prediction term.
    """
    eps = 1e-9
    with torch.no_grad():
        # Normalise to a 1-D index tensor so the documented int input works:
        # indexing with a bare int gives a 1-D tensor and sum(dim=1) fails.
        forget_idx = torch.unique(
            torch.as_tensor(forget_class, dtype=torch.long, device=logits.device).reshape(-1)
        )

        old_probs = F.softmax(logits, dim=1)

        q_star = old_probs.clone()
        q_star[:, forget_idx] = 0.0

        # Clamp so a sample fully committed to the forget class does not
        # divide by zero.
        denom = (1.0 - old_probs[:, forget_idx].sum(dim=1)).clamp_min(eps).unsqueeze(1)
        q_star = q_star / denom

    new_probs = F.softmax(logits, dim=1)
    loss = -(q_star * torch.log(new_probs + eps)).sum(dim=1).mean()

    return loss


class AddGaussianNoise(object):
    """Transform that adds Gaussian noise ``N(mean, std**2)`` to a tensor.

    Args:
        mean: Offset added to every element.
        std: Standard deviation of the noise.
    """

    def __init__(self, mean=0., std=1.):
        self.std = std
        self.mean = mean

    def __call__(self, tensor):
        """Return ``tensor`` plus freshly sampled noise of the same shape."""
        # Sample on the tensor's own device; noise created on the CPU cannot
        # be added to a tensor that lives on an accelerator.
        noise = torch.randn(tensor.size(), device=tensor.device)
        return tensor + noise * self.std + self.mean

    def __repr__(self):
        return f"{self.__class__.__name__}(mean={self.mean}, std={self.std})"


class simpleDataset(Dataset):
    """Dataset over an in-memory tensor of images and a parallel label container.

    The image tensor is converted to a numpy array once, at construction
    time. Each item is returned with its first axis moved to the end
    (``transpose(1, 2, 0)``) before the optional transform is applied.

    Args:
        data: Tensor of images; indexing its first axis must yield a 3-D array.
        labels: Indexable container of labels, aligned with ``data``.
        transform: Optional callable applied to the transposed image.
        target_transform: Optional callable applied to the label.
    """

    def __init__(self, data, labels, transform=None, target_transform=None):
        # Detach and move to host memory once instead of on every access.
        self.data = data.detach().cpu().numpy()
        self.labels = labels
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        image = self.data[idx].transpose(1, 2, 0)
        label = self.labels[idx]
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return image, label


class RLDataset(Dataset):
    """Wrap a forget set and relabel every sample (random-label unlearning).

    Args:
        forgetset: Indexable dataset yielding ``(image, label, ...)`` items.
        new_classes: Optional indexable of replacement labels aligned with
            ``forgetset``. When None, a label different from the true one is
            drawn uniformly at random on every access.
        num_classes: Number of classes to draw random labels from.
        noise_level: Stored on the instance; not used by this class.
        add_noise: Stored on the instance; not used by this class.
    """

    def __init__(self, forgetset, new_classes=None, num_classes=10, noise_level=0.01, add_noise=False):
        self.image_set = forgetset
        self.add_noise = add_noise
        self.noise_level = noise_level
        self.num_classes = num_classes
        self.new_classes = new_classes

    def __len__(self):
        return len(self.image_set)

    def __getitem__(self, idx):
        # Fetch the sample once: indexing the wrapped dataset may load and
        # transform the image, so a second lookup just for the label would
        # double that work.
        sample = self.image_set[idx]
        image = sample[0]
        if self.new_classes is not None:
            label = self.new_classes[idx]
        else:
            true_label = sample[1]
            candidates = [i for i in range(self.num_classes) if i != true_label]
            label = np.random.choice(candidates)  # random wrong label

        return image, label


class basicDataset(Dataset):
    """Dataset adapter over a container that exposes ``shape`` and indexing.

    Two layouts are supported, chosen by ``data.shape[-1]``:

    * ``== 2``: every item is a mapping with ``'image'`` and ``'label'``
      keys. The image is copied into a numpy array, and 2-D (grayscale)
      images are stacked into three identical channels along the last axis.
    * otherwise: every item is a sequence ``(image, label, ...)`` and the
      image is used as stored.

    Args:
        data: Underlying container (must provide ``shape``, ``__len__`` and
            ``__getitem__``).
        transform: Optional callable applied to the image.
        target_transform: Optional callable applied to the label.
    """

    def __init__(self, data, transform=None, target_transform=None):
        self.data = data
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx]
        if self.data.shape[-1] == 2:
            # Copy so that later transforms never write into the stored data.
            image = copy.deepcopy(np.asarray(sample['image']))
            if len(image.shape) == 2:
                # Grayscale -> three identical channels, channel-last.
                image = np.stack((image, image, image), axis=2)
            label = sample['label']
        else:
            print('shape is 1')
            # Bind ``image`` here as well; leaving it unassigned made every
            # access through this branch fail with UnboundLocalError.
            image = sample[0]
            label = sample[1]

        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return image, label


def SVC_fit_predict(shadow_train, shadow_test, target_train, target_test):
    """Fit a linear SVC on shadow features and score it on the target sets.

    The classifier is trained to separate ``shadow_train`` (label 1) from
    ``shadow_test`` (label 0), using the same number of rows from each. That
    number is the smallest first dimension among all four inputs.

    Args:
        shadow_train: Tensor of features labelled 1 for training.
        shadow_test: Tensor of features labelled 0 for training.
        target_train: Tensor scored as ``1 - mean(prediction)``.
        target_test: Tensor scored as ``mean(prediction)``.

    Returns:
        The score on ``target_train``. Both scores are printed. The return
        value is only assigned when ``target_train`` is non-empty.
    """
    sizes = [t.shape[0] for t in (shadow_train, shadow_test, target_train, target_test)]
    n_target_train, n_target_test = sizes[2], sizes[3]
    len_limit = min(sizes)

    # Balanced training set: len_limit positives followed by len_limit negatives.
    features = torch.cat([shadow_train[:len_limit], shadow_test[:len_limit]])
    features = features.cpu().numpy().reshape(2 * len_limit, -1)
    membership = np.concatenate([np.ones(len_limit), np.zeros(len_limit)])

    # Shuffle so the classifier does not see the two classes in blocks.
    order = np.random.permutation(len(membership))
    features, membership = features[order], membership[order]

    clf = SVC(kernel='linear', class_weight='balanced')
    clf.fit(features, membership)

    accs = []

    if n_target_train > 0:
        train_features = target_train.cpu().numpy().reshape(n_target_train, -1)
        acc_train = 1 - clf.predict(train_features).mean()
        accs.append(acc_train)

    if n_target_test > 0:
        test_features = target_test.cpu().numpy().reshape(n_target_test, -1)
        acc_test = clf.predict(test_features).mean()
        accs.append(acc_test)

    print("accs", accs)
    return acc_train


def svc_mia(net, train_loader, test_loader, forgetting_class, unlearn_method='RW'):
    """Run the SVC-based membership-inference attack on the forgotten class.

    For both loaders the softmax confidence assigned to each sample's own
    label is collected, split into "remain" and "forget" samples by label.
    ``SVC_fit_predict`` is then called with, in order: remain-train,
    forget-test, forget-train and remain-test confidences.

    Args:
        net: Model under attack. Depending on ``unlearn_method`` its forward
            pass returns either the logits or a pair of outputs.
        train_loader: Iterable of ``(images, labels)`` training batches.
        test_loader: Iterable of ``(images, labels)`` test batches.
        forgetting_class: Class index or list of class indices that were
            forgotten.
        unlearn_method: Selects which forward output holds the logits:
            'RW_FT', 'RW_FT_par', 'retrain', 'BS' -> the single output;
            'RW', 'RW_multi' -> second element of the returned pair;
            'genmask', 'l1', 'salun', 'GA', 'RL', 'FT' -> first element.

    Returns:
        The score returned by ``SVC_fit_predict`` (also printed).

    Raises:
        ValueError: for an unknown ``unlearn_method``.
    """
    single_output = ('RW_FT', 'RW_FT_par', 'retrain', 'BS')
    second_output = ('RW', 'RW_multi')
    first_output = ('genmask', 'l1', 'salun', 'GA', 'RL', 'FT')
    # Fail early with a clear message instead of hitting an unbound (or
    # stale) logits variable inside the batch loop.
    if unlearn_method not in single_output + second_output + first_output:
        raise ValueError(f"Unknown unlearn_method: {unlearn_method!r}")

    # Follow the model's device rather than assuming the default CUDA device.
    device = next(net.parameters()).device
    forget_ids = torch.as_tensor(forgetting_class, device=device)

    def _split_confidences(loader):
        """Return (remain probs, remain labels, forget probs, forget labels)."""
        conf_r, labels_r, conf_f, labels_f = [], [], [], []
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            if unlearn_method in single_output:
                logits = net(images)
            else:
                first, second = net(images)
                logits = second if unlearn_method in second_output else first
            probs = F.softmax(logits, dim=1)
            mask_forget = torch.isin(labels, forget_ids)
            mask_remain = ~mask_forget
            conf_r.append(probs[mask_remain])
            labels_r.append(labels[mask_remain])
            conf_f.append(probs[mask_forget])
            labels_f.append(labels[mask_forget])
        return (
            torch.cat(conf_r, dim=0),
            torch.cat(labels_r, dim=0),
            torch.cat(conf_f, dim=0),
            torch.cat(labels_f, dim=0),
        )

    with torch.no_grad():
        train_conf_r, train_r_labels, train_conf_f, train_f_labels = _split_confidences(train_loader)
        test_conf_r, test_r_labels, test_conf_f, test_f_labels = _split_confidences(test_loader)

    # Keep only the confidence assigned to each sample's own label.
    shadow_train = torch.gather(train_conf_r, 1, train_r_labels[:, None])
    shadow_test = torch.gather(train_conf_f, 1, train_f_labels[:, None])
    target_train = torch.gather(test_conf_r, 1, test_r_labels[:, None])
    target_test = torch.gather(test_conf_f, 1, test_f_labels[:, None])

    print("check remain forget")
    # Positional order on purpose: the attacker is fitted on remain-train vs
    # forget-test and evaluated on forget-train and remain-test.
    acc_conf = SVC_fit_predict(shadow_train, target_test, shadow_test, target_train)

    acc_mean = acc_conf
    print(f"MIA Attack Accuracy on Forgotten Class: {acc_mean:.4f}")
    return acc_mean


def collect_prob(data_loader, model, unlearn_method):
    """Run ``model`` over ``data_loader`` and return softmax probabilities and labels.

    Args:
        data_loader: Iterable yielding ``(data, target)`` batches, or ``None``.
        model: Network to evaluate. It is switched to eval mode and its first
            parameter's device is used for the inputs.
        unlearn_method: Selects how the model output is unpacked:
            ``RW_FT`` / ``RW_FT_par`` / ``retrain`` / ``BS`` models return a
            single logits tensor; ``RW`` / ``RW_multi`` models return a pair
            and the *second* element is used; ``genmask`` / ``l1`` / ``salun``
            / ``GA`` / ``RL`` / ``FT`` models return a pair and the *first*
            element is used.

    Returns:
        Tuple ``(prob, targets)`` concatenated over all batches. When
        ``data_loader`` is ``None`` the result is a pair of empty tensors of
        shape ``[0, 10]`` and ``[0]``.

    Raises:
        ValueError: If ``unlearn_method`` is not one of the names listed above.
            (Previously this surfaced as an unbound-local error on ``log_prob``.)
    """
    if data_loader is None:
        return torch.zeros([0, 10]), torch.zeros([0])

    single_output = ('RW_FT', 'RW_FT_par', 'retrain', 'BS')
    reweight = ('RW', 'RW_multi')
    paired_output = ('genmask', 'l1', 'salun', 'GA', 'RL', 'FT')
    if unlearn_method not in single_output + reweight + paired_output:
        raise ValueError(f'unknown unlearn method: {unlearn_method!r}')

    prob = []
    targets = []

    model.eval()
    # The model does not move during evaluation, so look the device up once.
    device = next(model.parameters()).device
    with torch.no_grad():
        for batch in data_loader:
            data, target = [tensor.to(device) for tensor in batch]

            if unlearn_method in single_output:
                logits = model(data)
            elif unlearn_method in reweight:
                # Reweighting models expose (old_logits, new_logits); the
                # attack is evaluated on the new head.
                _, logits = model(data)
            else:
                logits, _ = model(data)

            # exp(log_softmax) is kept (rather than softmax) so the values
            # match the previous implementation exactly.
            prob.append(torch.exp(F.log_softmax(logits, dim=1)))
            targets.append(target)
    return torch.cat(prob), torch.cat(targets)


def SVC_attack(shadow_train, target_train, target_test, shadow_test, model, forgetting_class, unlearn_method='RW'):
    """Membership-inference attack on true-label confidences.

    For each of the four loaders the model's probabilities are collected with
    ``collect_prob`` and reduced to the probability assigned to each sample's
    own label. Those confidences are handed to ``SVC_fit_predict`` and the
    score it returns is reported under the ``"confidence"`` key.

    Loader roles noted by the original author:
        shadow_train=remainloader, target_train=forgetloader,
        target_test=remainloader_test, shadow_test=forgetloader_test.

    ``forgetting_class`` is accepted for signature compatibility and is not
    read here.

    Returns:
        dict with the single key ``"confidence"``.
    """

    def own_label_confidence(prob, labels):
        # Pick, per row, the probability of that row's label -> shape (N, 1).
        return torch.gather(prob, 1, labels[:, None])

    shadow_train_prob, shadow_train_labels = collect_prob(shadow_train, model, unlearn_method)
    shadow_test_prob, shadow_test_labels = collect_prob(shadow_test, model, unlearn_method)
    target_train_prob, target_train_labels = collect_prob(target_train, model, unlearn_method)
    target_test_prob, target_test_labels = collect_prob(target_test, model, unlearn_method)

    # Debug dump of the first few probabilities / labels of every split.
    preview = (
        ("prob of target_train", target_train_prob, target_train_labels),
        ("prob of target_test", target_test_prob, target_test_labels),
        ("prob of shadow_train", shadow_train_prob, shadow_train_labels),
        ("prob of shadow_test", shadow_test_prob, shadow_test_labels),
    )
    for tag, probs, labels in preview:
        print(tag, probs[:3], labels[:10])

    shadow_train_conf = own_label_confidence(shadow_train_prob, shadow_train_labels)
    shadow_test_conf = own_label_confidence(shadow_test_prob, shadow_test_labels)
    target_train_conf = own_label_confidence(target_train_prob, target_train_labels)
    target_test_conf = own_label_confidence(target_test_prob, target_test_labels)

    m = {
        "confidence": SVC_fit_predict(
            shadow_train_conf, target_test_conf, shadow_test_conf, target_train_conf),
    }
    print(m)
    return m
  

def SVC_attack_new(shadow_train, target_train, target_test, shadow_test, model, forgetting_class, unlearn_method='RW'):
    """Membership-inference attack on the probability assigned to class 9.

    Unlike ``SVC_attack``, the confidence taken for every sample is the
    model's probability for the fixed class index 9, regardless of the
    sample's own label. The four confidence sets are passed to
    ``SVC_fit_predict`` and its score is reported under ``"confidence"``.

    Loader roles noted by the original author:
        shadow_train=forgetloader_test_9   (truck test)
        target_train=forgetloader_test     (car test)
        target_test=remainloader_test_9    (no truck, no automobile)
        shadow_test=remainloader_test_9    (remain test w/o car, truck)

    ``forgetting_class`` is accepted for signature compatibility and is not
    read here.

    Returns:
        dict with the single key ``"confidence"``.
    """
    probe_class = 9

    def probe_confidence(prob, labels):
        # Same column for every row; labels only supply shape/dtype/device.
        column = torch.full_like(labels[:, None], probe_class)
        return torch.gather(prob, 1, column)

    shadow_train_prob, shadow_train_labels = collect_prob(shadow_train, model, unlearn_method)
    shadow_test_prob, shadow_test_labels = collect_prob(shadow_test, model, unlearn_method)
    target_train_prob, target_train_labels = collect_prob(target_train, model, unlearn_method)
    target_test_prob, target_test_labels = collect_prob(target_test, model, unlearn_method)

    # NOTE(review): the "prob of car" line prints the shadow_test split, which
    # the author's notes describe as the remain set — label may be stale.
    print("prob of car", shadow_test_prob[:10], shadow_test_labels[:10])
    print("prob of truck", shadow_train_prob[:10], shadow_train_labels[:10])

    shadow_train_conf = probe_confidence(shadow_train_prob, shadow_train_labels)
    shadow_test_conf = probe_confidence(shadow_test_prob, shadow_test_labels)
    target_train_conf = probe_confidence(target_train_prob, target_train_labels)
    target_test_conf = probe_confidence(target_test_prob, target_test_labels)

    m = {
        "confidence": SVC_fit_predict(
            shadow_train_conf, target_test_conf, shadow_test_conf, target_train_conf),
    }
    print(m)
    return m


def distribution_attack_new(shadow_train, target_train, target_test, model, forgetting_class, unlearn_method='RW'):
    """Gaussian likelihood test on class-9 confidences.

    One 1-D Gaussian is fitted to the class-9 probabilities of
    ``shadow_train`` and another to those of ``target_test``. Every sample of
    ``target_train`` is then assigned to whichever Gaussian gives it the
    higher log-likelihood, and the fraction assigned to the shadow Gaussian
    is reported.

    Args:
        shadow_train: Loader for the first reference set (e.g. truck test).
        target_train: Loader for the set being classified (e.g. car test).
        target_test: Loader for the second reference set (e.g. remain test
            without car/truck).
        model: Model used for inference.
        forgetting_class: Accepted for signature compatibility; not read here
            (the probed class index is fixed at 9).
        unlearn_method: Forwarded to ``collect_prob``.

    Returns:
        dict with the fitted means/stds, the fraction of ``target_train``
        samples classified as shadow, and the three sample counts.
    """
    probe_class = 9
    min_std = 1e-8

    def probe_confidence(prob, labels):
        # Probability of the probed class for every row -> shape (N, 1).
        return torch.gather(prob, 1, torch.full_like(labels[:, None], probe_class))

    def as_flat_numpy(conf):
        return conf.view(-1).detach().cpu().numpy()

    def fit_gaussian(values):
        # Mean falls back to 0 on empty input; the sample std (ddof=1) needs
        # at least two points and is floored to avoid a zero variance.
        mean = float(np.mean(values)) if values.size > 0 else 0.0
        spread = float(np.std(values, ddof=1)) if values.size > 1 else 0.0
        return mean, max(spread, min_std)

    def gaussian_loglik(x, mean, spread):
        variance = spread ** 2
        return -0.5 * np.log(2 * np.pi * variance) - 0.5 * ((x - mean) ** 2) / variance

    shadow_train_prob, shadow_train_labels = collect_prob(shadow_train, model, unlearn_method)
    target_train_prob, target_train_labels = collect_prob(target_train, model, unlearn_method)
    target_test_prob, target_test_labels = collect_prob(target_test, model, unlearn_method)

    print("Shadow train prob shape:", shadow_train_prob.shape)
    print("Target train prob shape:", target_train_prob.shape)
    print("Target test prob shape:", target_test_prob.shape)

    shadow_train_conf = probe_confidence(shadow_train_prob, shadow_train_labels)
    target_train_conf = probe_confidence(target_train_prob, target_train_labels)
    target_test_conf = probe_confidence(target_test_prob, target_test_labels)

    print("Shadow train conf shape:", shadow_train_conf.shape)
    print("Target train conf shape:", target_train_conf.shape)
    print("Target test conf shape:", target_test_conf.shape)

    queries = as_flat_numpy(target_train_conf)
    mu_shadow, std_shadow = fit_gaussian(as_flat_numpy(shadow_train_conf))
    mu_target, std_target = fit_gaussian(as_flat_numpy(target_test_conf))

    ll_shadow = gaussian_loglik(queries, mu_shadow, std_shadow)
    ll_target = gaussian_loglik(queries, mu_target, std_target)

    # 1 -> sample looks like the shadow distribution, 0 -> target_test one.
    frac_shadow = 0.0
    if queries.size > 0:
        frac_shadow = float((ll_shadow > ll_target).astype(np.float32).mean())

    results = {
        "gaussian_mu_shadow": mu_shadow,
        "gaussian_std_shadow": std_shadow,
        "gaussian_mu_target": mu_target,
        "gaussian_std_target": std_target,
        "fraction_target_train_as_shadow": frac_shadow,
        "shadow_train_samples": int(shadow_train_conf.shape[0]),
        "target_test_samples": int(target_test_conf.shape[0]),
        "target_train_samples": int(target_train_conf.shape[0]),
    }

    print("Distribution Attack (Gaussian) Results:", results)
    return results

# Restrict the visible GPUs for this run.
os.environ["CUDA_VISIBLE_DEVICES"] = "0,2,3"


def _str2bool(value):
    """Parse a command-line token into a ``bool`` (for use as argparse ``type``).

    ``type=bool`` treats every non-empty string as ``True``, so
    ``--flag False`` used to enable the flag. This accepts the usual
    spellings and rejects anything else.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ('true', 't', 'yes', 'y', '1'):
        return True
    # The empty string was the only falsy input under ``type=bool``; keep it falsy.
    if token in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training')
parser.add_argument('--dataset', default='cifar10', help='dataset')
parser.add_argument('--model', default='ResNet18', help='Deep Learning model to train')
parser.add_argument('--method', default='catclip', help='clipping method (use orig for no clipping)')
parser.add_argument('--mode', default='wBN', help='what to do with BN layers (leave empty for keeping it as it is)')
parser.add_argument('--lr', default=0.01, type=float, help='learning rate')
parser.add_argument('--LRsteps', default=40, type=int, help='LR scheduler step')
parser.add_argument('--epochs', default=10, type=int, help='number of epochs')
parser.add_argument('--seed', default=1, type=int, help='seed value')
parser.add_argument('--steps', default=50, type=int, help='step count for clipping BN')
parser.add_argument('--num_classes', default=10, type=int, help='number of classes in the dataset')
parser.add_argument('--batch_size', default=128, type=int, help='mini-batch size')

parser.add_argument('--unlearn_method', default='RL', type=str)
parser.add_argument('--unlearn_indices', default=None, type=str)
parser.add_argument('--unlearn_evaluate', default='svc_mia', type=str)

parser.add_argument('--unlearn_count', default=1000, type=int)
parser.add_argument('--start_idx', default=0, type=int)

parser.add_argument('--source_model_path', default=None, type=str)
parser.add_argument('--mask_path', default=None, type=str)
parser.add_argument('--save_checkpoints', default=0, type=int)

parser.add_argument('--use_all_ref', default=True, type=_str2bool)
parser.add_argument('--use_remain', default=True, type=_str2bool)
parser.add_argument('--remain', default='use', type=str)
parser.add_argument('--use_remain_sample', default=False, type=_str2bool)

parser.add_argument('--unnormalize', default=True, type=_str2bool)
parser.add_argument('--norm_cond', default='unnorm', help='unnorm or norm for transform')

parser.add_argument('--req_mode', default='single', type=str)
parser.add_argument('--salun_ratio', default='0.5', type=str, help='ratio of masking in salun')

parser.add_argument('--alpha_l1', default=0., type=float)
parser.add_argument('--noise_ratio', default=3, type=int)  # default 120

parser.add_argument('--catsn', default=-1, type=float)
parser.add_argument('--convsn', default=1., type=float)
parser.add_argument('--outer_steps', default=100, type=int)
parser.add_argument('--convsteps', default=100, type=int)
parser.add_argument('--opt_iter', default=5, type=int)
parser.add_argument('--outer_iters', default=1, type=int)

# Similarity matrix computation arguments
parser.add_argument('--similarity_method', default='euclidean', type=str, choices=['euclidean', 'cosine'], help='Method to compute similarity matrix from source model')
parser.add_argument('--similarity_temperature', default=100.0, type=float, help='Inverse temperature for similarity matrix computation (higher = more peaked)')
parser.add_argument('--similarity_pca_components', default=None, type=int, help='Number of PCA components for similarity matrix (None to disable PCA)')
parser.add_argument('--similarity_beta', default=1.0, type=float, help='Beta parameter for similarity-based reweighting')
parser.add_argument('--forget_temp', default=None, type=float, help='Temperature for forget class (lower = more peaked). If None, uses single temperature for all classes.')
parser.add_argument('--remain_temp', default=None, type=float, help='Temperature for remaining classes (higher = less peaked). If None, uses single temperature for all classes.')

args = parser.parse_args()

# The indices file is mandatory: it supplies both the samples to unlearn and,
# through its file name, the class being forgotten.
if args.unlearn_indices is None:
    parser.error('--unlearn_indices is required (CSV with an "unlearn_idx" column, '
                 'named like label_<class>.csv)')

unlearn_indices_check = pd.read_csv(args.unlearn_indices)['unlearn_idx'].values
count_unlearn = len(unlearn_indices_check)

# The forgotten class is encoded in the file name, e.g. ".../label_1.csv" -> 1.
match = re.search(r'label_(\d+)\.csv', args.unlearn_indices)
if match is None:
    # Previously this fell through to an unbound `number` (NameError).
    parser.error('cannot infer the forgetting class: --unlearn_indices must end in '
                 'label_<class>.csv, got %r' % args.unlearn_indices)
number = int(match.group(1))
print("forgetting class", number)
forgetting_class = [number]
print('count_unlearn: ', count_unlearn)
print('requested mode: ', args.req_mode)

# --norm_cond norm overrides --unnormalize.
if args.norm_cond == 'norm':
    args.unnormalize = False
print('!!!!!!!!! unnormalized: ', args.unnormalize)
print('!!!!!!!!! salun ratio: ', args.salun_ratio)

print('model: ', args.model)

dataset_name = args.dataset
if args.unnormalize:
    dataset_name += '_unnorm'
print('dataset', dataset_name)

# --remain decides use_remain, except that retrain and every RW* method
# always use the remaining data.
if args.remain != 'use':
    args.use_remain = False

if args.remain == 'use' or args.unlearn_method == 'retrain' or 'RW' in args.unlearn_method:
    args.use_remain = True


print('use remain flag: ', args.use_remain)


save_checkpoints = args.save_checkpoints == 1

print('save_checkpoints: ', save_checkpoints)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('==========', device)

if device == 'cuda':
    # net = torch.nn.DataParallel(net)
    print('chosen: ', device)
    cudnn.benchmark = True

# The class count follows from the dataset and overrides --num_classes.
if args.dataset in ['mnist', 'cifar10']:
    args.num_classes = 10
elif args.dataset == 'cifar100':
    args.num_classes = 100
elif args.dataset == 'imagenet':
    args.num_classes = 200
else:
    print("wrong dataset")
    # Exit with a failure status so wrappers can detect the bad argument.
    raise SystemExit(1)


# path_file_rw.csv maps an "info" key to a "path"; only base_path is read here.
base_path_df = pd.read_csv('path_file_rw.csv')
print(base_path_df)
tuples = zip(base_path_df['info'], base_path_df['path'])
base_path_dict = dict(tuples)
base_path = base_path_dict['base_path']
print('base_path: ', base_path)

# Training
def train(epoch, optimizer, scheduler, criterion, test_model, unlearn_method='RW', writer=None, model_path="./checkpoints/", mask=None, similarity_matrix=None, beta=1.0):
    """Run one unlearning/training epoch on the module-level ``net``.

    Relies on module globals: ``net``, ``args``, ``device``, ``trainset``,
    ``forgetset``, ``unlearn_idx``, ``remainloader``, ``forgetting_class`` and
    the step counter ``count_setp``.

    Args:
        epoch: Epoch index; used for logging and for the checkpoint suffix.
        optimizer: Optimizer stepped once per batch.
        scheduler: LR scheduler stepped once at the end of the epoch.
        criterion: Loss used for the plain supervised objective.
        test_model: The passed-in value is never read; for the RW* and BS
            methods it is replaced by a deep copy of ``net`` taken inside
            this function.
        unlearn_method: Selects the dataset, the forward-pass unpacking and
            the loss (see the inline comments).
        writer: Not used by this function.
        model_path: Checkpoint path prefix; ``".<epoch>"`` is appended.
        mask: Optional ``{param_name: tensor}`` multiplied into gradients;
            only applied in the BS remain-data pass.
        similarity_matrix: If given, the RW* methods use
            ``forget_loss_fn_similarity_based`` instead of the max-based loss.
        beta: Forwarded to ``forget_loss_fn_similarity_based``.

    Returns:
        Tuple ``(mean training loss per batch, accuracy in percent)``.
        Raises ``ZeroDivisionError`` if no sample was processed.
    """
    print('\nEpoch: %d' % epoch)
    print('l1 regularization: ', args.alpha_l1)
    print('unlearn method: ', unlearn_method)
    global count_setp
    net.train()
    train_loss = 0
    correct = 0
    total = 0
    batch_idx = -1

    print('\ninside train function :')
    print('trainset :', len(trainset) )
    print('unl idx :', len(unlearn_idx) )

    # ---- choose the dataset this epoch iterates over ----
    if unlearn_method == 'retrain' or unlearn_method == 'l1':
        if not args.use_remain:
            # Random subset of trainset with as many samples as forgetset.
            sample_indices = np.random.choice(len(trainset), len(forgetset), replace=False)
            trainset_combined = torch.utils.data.Subset(trainset, sample_indices)
        else:
            trainset_combined = trainset
        
        # NOTE(review): alpha_l1 gets a default here, but the main loop below
        # never adds l1_regularization to the loss (only the BS remain-data
        # pass does) — confirm whether the 'l1' method is meant to use it.
        if unlearn_method == 'l1' and args.alpha_l1 == 0.:
            # args.alpha_l1 = 0.0005
            args.alpha_l1 = 0.000001

    elif 'RW' in unlearn_method: # unlearn_method == 'RW_FT' or unlearn_method == 'RW':
        trainset_combined = trainset
    
    # NOTE(review): 'BE' is routed to forgetset here, but the forward-pass
    # dispatch in the main loop has no 'BE' case, so it ends in the
    # 'unknown unlearn method' exit.
    elif unlearn_method == 'BS' or unlearn_method == 'BE' or unlearn_method == 'GA':
        trainset_combined = forgetset
    else:
        trainset_combined = trainset


    print('trainset_combined len: ', len(trainset_combined))
    trainloader = torch.utils.data.DataLoader(trainset_combined, shuffle=True, batch_size=args.batch_size, num_workers=1)

    start = time.time()

    # ---- BS only: one supervised pass over the remaining data first ----
    # `correct` / `total` accumulated here are not reset before the main loop,
    # so the accuracy reported at the end of a BS epoch mixes both passes.
    if args.use_remain and unlearn_method == 'BS':
        for batch_idx, (inputs, targets) in enumerate(remainloader):
            if epoch == 0 and batch_idx == 0:
                print('inputs remain shape: ', inputs.shape, targets[:10])
            inputs, targets = inputs.float().to(device), targets.to(device)
            optimizer.zero_grad()
            outputs = net(inputs)
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
            loss = criterion(outputs, targets)

            if args.alpha_l1 > 0.:
                print("Is there norm for BS?")
                loss += args.alpha_l1 * l1_regularization(net)

            loss.backward()

            # Gradient masking: scale every gradient by its per-parameter mask.
            if mask is not None:
                for name, param in net.named_parameters():
                    if param.grad is not None:
                        param.grad *= mask[name]

            optimizer.step()
        
        
        print('in loop train - acc', 100.*correct/total)

    # Snapshot of the network as it is right now; used as the reference model
    # (old probabilities for RW*, adversarial labelling for BS).
    if unlearn_method == 'RW' or unlearn_method == 'RW_multi' or unlearn_method == 'RW_FT' or unlearn_method == 'RW_FT_par':
        test_model = copy.deepcopy(net)
    
    if unlearn_method == 'BS':
        test_model = copy.deepcopy(net)
        # Perturbation size handed to FGSM_perturb.
        bound = 0.1

        
    # ---- main loop over the selected dataset ----
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        torch.cuda.empty_cache()
        if epoch == 0 and batch_idx == 0:
            print('inputs shape: ', inputs.shape)
        inputs, targets = inputs.float().to(device), targets.to(device)
        optimizer.zero_grad()
        # Forward pass; the number of values the model returns depends on the
        # method. 'BS' is matched by the first branch, so its second mention
        # in the third branch is unreachable.
        if unlearn_method == 'RW_FT' or unlearn_method == 'RW_FT_par' or unlearn_method == 'retrain' or unlearn_method == 'BS':
            outputs = net(inputs)
        elif unlearn_method == 'RW' or unlearn_method == 'RW_multi':
            outputs, new_logits = net(inputs)
        elif unlearn_method == 'BS' or unlearn_method == 'genmask' or unlearn_method == 'l1' or unlearn_method == 'salun' or unlearn_method == 'GA' or unlearn_method == 'RL' or unlearn_method == 'FT':
            outputs, _ = net(inputs)
        else:
            print('unknown unlearn method')
            exit(0)
        old_probs = F.softmax(outputs, dim=1)

        if unlearn_method == 'BS':
            # Relabel each sample with the snapshot model's prediction on an
            # FGSM-perturbed copy of the input; the true labels are kept in
            # targets_orig for the accuracy bookkeeping below.
            test_model.eval()
            image_adv = FGSM_perturb(
                inputs, targets, model=test_model, bound=bound, criterion=criterion
            )

            adv_outputs = test_model(image_adv)
            adv_label = torch.argmax(adv_outputs, dim=1)
            targets_orig = copy.deepcopy(targets)
            targets = adv_label
            # Recomputed identically right below.
            loss = criterion(outputs, targets)
        
        # Default objective; replaced for the RW* methods below.
        loss = criterion(outputs, targets)

        if unlearn_method == 'RW_FT' or unlearn_method == 'RW_FT_par':
            # Use similarity-based reweighting if similarity matrix is provided
            if similarity_matrix is not None:
                loss = forget_loss_fn_similarity_based(outputs, forgetting_class, similarity_matrix, beta)
            else:
                # loss = forget_loss_fn(outputs, forgetting_class)
                loss = forget_loss_max(outputs, forgetting_class)
        elif unlearn_method == 'RW' or unlearn_method == 'RW_multi':
            # Old probabilities come from the snapshot; the loss and the
            # accuracy bookkeeping use the network's second output.
            test_model.eval()
            old_logits, _ = test_model(inputs)
            old_probs = F.softmax(old_logits, dim=1)
            outputs = new_logits
            # Use similarity-based reweighting if similarity matrix is provided
            if similarity_matrix is not None:
                loss = forget_loss_fn_similarity_based(new_logits, forgetting_class, similarity_matrix, beta)
            else:
                loss = forget_loss_max_assign(new_logits, old_probs, forgetting_class)
            # loss = forget_loss_drop_class_dim(new_logits, old_probs, forgetting_class)
        elif unlearn_method == 'retrain':
            # Same value as the default objective computed above.
            loss = criterion(outputs, targets)

        if unlearn_method == 'BS':
            # Restore the true labels so accuracy is measured against them.
            targets = targets_orig
        # NOTE(review): `mask` is not applied to the gradients in this loop —
        # confirm whether mask-based methods (e.g. salun) expect it here.
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
        count_setp += 1
        

    tot_time = time.time() - start
    print('time: ', tot_time)
    print('train - acc', 100.*correct/total)
    print('train - loss', train_loss/(batch_idx+1))
    
    scheduler.step()

    print('Saving..')
    state = {
        'net': net.state_dict(),
        'epoch': epoch,
    }

    # Per-epoch checkpoint; 'retrain' only keeps the listed epochs. The choice
    # is made on args.unlearn_method, not on the unlearn_method argument.
    model_path_i = model_path + ".%d" % (epoch)
    if args.unlearn_method == 'retrain':
        if epoch in [80, 100,120,140,160,180,200]:
            torch.save(state, model_path_i)
    else:
        torch.save(state, model_path_i)

    net.eval()

    return train_loss/(batch_idx+1), 100.*correct/total


def test(loader, epoch, criterion, unlearn_method='RW', writer=None, mode='test', model_path="./checkpoints/", plot_images=False):
    """Evaluate the module-level ``net`` on ``loader`` and keep the best checkpoint.

    Relies on module globals ``net``, ``device`` and ``best_acc``.

    Args:
        loader: Iterable of ``(inputs, targets)`` batches.
        epoch: Epoch index, stored in the checkpoint and used as the
            TensorBoard step.
        criterion: Loss function applied to ``(outputs, targets)``.
        unlearn_method: Selects how the model output is unpacked:
            ``RW_FT`` / ``RW_FT_par`` / ``retrain`` / ``BS`` return a single
            tensor; ``RW`` / ``RW_multi`` return a pair whose second element
            is evaluated; ``genmask`` / ``l1`` / ``salun`` / ``GA`` / ``RL`` /
            ``FT`` return a pair whose first element is evaluated.
        writer: Optional summary writer; receives ``test/acc`` and ``test/loss``.
        mode: Prefix for the printed metrics.
        model_path: Where to save the checkpoint when accuracy beats
            ``best_acc``; ``None`` disables saving.
        plot_images: Unused; kept for signature compatibility.

    Returns:
        Tuple ``(mean loss per batch, accuracy in percent)``. Both are ``0.0``
        when ``loader`` yields nothing (this used to raise ZeroDivisionError).

    Raises:
        ValueError: If a batch is processed with an unknown ``unlearn_method``.
    """
    global best_acc

    single_output = ('RW_FT', 'RW_FT_par', 'retrain', 'BS')
    reweight = ('RW', 'RW_multi')
    paired_output = ('genmask', 'l1', 'salun', 'GA', 'RL', 'FT')

    net.eval()
    test_loss = 0
    correct = 0
    total = 0
    num_batches = 0

    with torch.no_grad():
        for inputs, targets in loader:
            inputs, targets = inputs.float().to(device), targets.to(device)
            if unlearn_method in single_output:
                outputs = net(inputs)
            elif unlearn_method in reweight:
                # Reweighting models return (old_logits, new_logits); score the new head.
                _, outputs = net(inputs)
            elif unlearn_method in paired_output:
                outputs, _ = net(inputs)
            else:
                raise ValueError(f'unknown unlearn method: {unlearn_method!r}')

            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
            num_batches += 1

    # Guard the averages so an empty loader reports zeros instead of crashing.
    acc = 100. * correct / total if total else 0.0
    avg_loss = test_loss / num_batches if num_batches else 0.0

    if model_path is not None and acc > best_acc:
        best_acc = acc

        print('Saving Best..')
        state = {
            'net': net.state_dict(),
            'acc': acc,
            'epoch': epoch,
        }
        # Legacy directory that earlier versions always created; kept in case
        # other code expects it to exist.
        os.makedirs('checkpoint', exist_ok=True)
        # Make sure the directory the checkpoint is actually written to exists.
        parent_dir = os.path.dirname(model_path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        torch.save(state, model_path)

    if writer is not None:
        writer.add_scalar('test/acc', acc, epoch)
        writer.add_scalar('test/loss', avg_loss, epoch)

    print("{}/acc {:.4f}".format(mode, acc))
    print("{}/loss {:.4f}".format(mode, avg_loss))

    return avg_loss, acc


def compute_similarity_matrix(model, loader, device, num_classes, unlearn_method='RW', temperature=1.0):
    """Build a row-normalised class-similarity matrix from class-mean features.

    Features are taken from the *input* of an ``nn.Linear`` layer via a
    forward hook: the second-to-last Linear layer when the model has two or
    more, otherwise the only one. If the model has no Linear layer at all,
    the model output is used instead (the first element when the model
    returns a tuple/list).

    The similarity between two classes is the raw dot product of their mean
    feature vectors (no L2 normalisation). Each row is divided by
    ``temperature``, passed through a softmax, has its diagonal zeroed and is
    renormalised so the off-diagonal entries sum to 1.

    Args:
        model: Network to probe; a ``.module`` attribute (DataParallel) is
            unwrapped for the layer search. The model is switched to eval mode.
        loader: Iterable of ``(inputs, targets)`` batches.
        device: Device the inputs and accumulators are placed on.
        num_classes: Number of rows/columns of the result.
        unlearn_method: Kept for backward compatibility. Tuple outputs are
            now unpacked generically, so the value is no longer consulted.
        temperature: Divisor applied to the similarities before the softmax.

    Returns:
        CPU tensor of shape ``[num_classes, num_classes]`` with a zero diagonal.
    """
    # Unwrap DataParallel-style containers for the layer search only.
    actual_model = model.module if hasattr(model, 'module') else model

    print("Searching for penultimate layer...")
    linear_layers = []
    for name, module in actual_model.named_modules():
        if isinstance(module, torch.nn.Linear):
            print(f"Found Linear layer: {name}, output size: {module.out_features}")
            linear_layers.append((name, module))

    penultimate_layer = None
    if len(linear_layers) >= 2:
        penultimate_layer = linear_layers[-2][1]
        print(f"Selected penultimate layer: {linear_layers[-2][0]} (output size: {penultimate_layer.out_features})")
    elif len(linear_layers) == 1:
        penultimate_layer = linear_layers[0][1]
        print(f"Only one linear layer found: {linear_layers[0][0]} (output size: {penultimate_layer.out_features})")
        print("Warning: This might be the classifier layer, not the penultimate layer")
    else:
        print("Warning: Could not find penultimate layer, using logits instead")

    captured_features = {}

    def hook(module, input, output):
        # Keep the layer's input (the features feeding it), not its output.
        captured_features['penultimate'] = input[0]

    sums = torch.zeros((num_classes, 0), device=device)
    counts = torch.zeros((num_classes,), device=device)
    initialized = False

    handle = None
    if penultimate_layer is not None:
        handle = penultimate_layer.register_forward_hook(hook)

    model.eval()
    try:
        with torch.no_grad():
            for inputs, targets in loader:
                inputs, targets = inputs.float().to(device), targets.to(device)
                outputs = model(inputs)

                if handle is not None:
                    features = captured_features['penultimate']  # [B, feature_dim]
                elif isinstance(outputs, (tuple, list)):
                    # Multi-output models: fall back to the first output.
                    features = outputs[0]
                else:
                    features = outputs

                if not initialized:
                    # The feature width is only known after the first batch.
                    sums = torch.zeros((num_classes, features.shape[1]), device=device)
                    initialized = True

                for c in targets.unique():
                    mask = (targets == c)
                    sums[c] += features[mask].sum(dim=0)
                    counts[c] += mask.sum()
    finally:
        # Always detach the hook, even if the forward pass raised, so it does
        # not stay attached to the model.
        if handle is not None:
            handle.remove()

    # Class means; classes that never appeared keep an all-zero mean.
    means = sums / counts.view(-1, 1).clamp(min=1)

    # Dot-product similarity between class means.
    sim_mat = means @ means.t()  # [C, C]
    sim_mat.fill_diagonal_(0)

    # Softmax with temperature over each row.
    sim_mat = torch.softmax(sim_mat / temperature, dim=1)

    # Drop the self-similarity and renormalise the remaining entries.
    sim_mat.fill_diagonal_(0)
    row_sums = sim_mat.sum(dim=1, keepdim=True)
    sim_mat = sim_mat / row_sums.clamp(min=1e-12)

    return sim_mat.cpu()

def check_test(trainloader, loader, test_net, epoch, criterion, unlearn_method='RW', writer=None, mode='test', model_path="./checkpoints/", plot_images=False, compute_sim=False):
    """Evaluate ``test_net`` on ``loader`` and report mean loss and accuracy.

    The network is switched to eval mode and run over ``loader`` with
    gradients disabled. The mean per-batch loss and the overall accuracy (in
    percent) are printed, optionally logged to ``writer``, and returned.

    This function relies on the module-level names ``device``, ``args``,
    ``forgetting_class``, ``svc_mia`` and ``compute_similarity_matrix``.

    Args:
        trainloader: Loader forwarded to ``svc_mia`` when ``mode == 'test'``.
        loader: Iterable of ``(inputs, targets)`` batches to evaluate.
        test_net: Model under evaluation.
        epoch: Step index used for the ``writer`` scalars.
        criterion: Loss callable invoked as ``criterion(outputs, targets)``.
        unlearn_method: Selects how the forward result is unpacked. For
            'RW_FT', 'RW_FT_par', 'retrain' and 'BS' the forward result is used
            as the logits; for 'RW', 'RW_multi', 'genmask', 'l1', 'salun',
            'GA', 'RL' and 'FT' it is unpacked as a pair and the first element
            is scored.
        writer: Optional summary writer exposing ``add_scalar``.
        mode: Label used in the printed metric names; ``'test'`` additionally
            triggers the ``svc_mia`` call.
        model_path: Unused; kept for call-site compatibility.
        plot_images: Unused; kept for call-site compatibility.
        compute_sim: When true, also compute and print the class similarity
            matrix for ``loader``.

    Returns:
        Tuple ``(mean_batch_loss, accuracy_percent)``. Both are ``0.0`` when
        ``loader`` yields no batches.

    Raises:
        SystemExit: With exit status 1 if ``unlearn_method`` is not recognised.
    """
    # Methods whose forward pass is used directly as the logits.
    single_output_methods = ('RW_FT', 'RW_FT_par', 'retrain', 'BS')
    # Methods whose forward pass is unpacked as a pair; only the first element is scored.
    paired_output_methods = ('RW', 'RW_multi', 'genmask', 'l1', 'salun', 'GA', 'RL', 'FT')

    test_net.eval()

    # Resolve the output convention once, before touching any data, so an
    # unsupported method fails fast and with a non-zero exit status.
    if unlearn_method in single_output_methods:
        returns_pair = False
    elif unlearn_method in paired_output_methods:
        returns_pair = True
    else:
        print('unknown unlearn method')
        raise SystemExit(1)

    test_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0

    with torch.no_grad():
        for inputs, targets in loader:
            inputs, targets = inputs.float().to(device), targets.to(device)
            if returns_pair:
                outputs, _ = test_net(inputs)
            else:
                outputs = test_net(inputs)

            test_loss += criterion(outputs, targets).item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
            num_batches += 1

    # An empty loader (e.g. an empty forget split) must not crash the run;
    # zero matches the file's convention for classes without samples.
    if num_batches == 0:
        print('warning: empty loader for mode', mode, '- reporting zero loss and accuracy')
    avg_loss = test_loss / num_batches if num_batches else 0.0
    accuracy = 100. * correct / total if total else 0.0

    if writer is not None:
        # NOTE(review): the scalar tags are 'test/...' regardless of `mode`,
        # so forget/remain evaluations land under the same tags - confirm intended.
        writer.add_scalar('test/acc', accuracy, epoch)
        writer.add_scalar('test/loss', avg_loss, epoch)

    print('initial', mode + '/acc', accuracy)
    print('initial', mode + '/loss', avg_loss)
    if mode == 'test':
        svc_mia(test_net, trainloader, loader, forgetting_class, unlearn_method=unlearn_method)

    # Compute similarity matrix if requested
    if compute_sim:
        print("Computing similarity matrix...")
        # Single source of truth so the log line cannot drift from the value used.
        sim_temperature = 1.0
        sim_mat = compute_similarity_matrix(test_net, loader, device, args.num_classes, unlearn_method, temperature=sim_temperature)
        print(f"Cosine similarity + softmax matrix (temperature={sim_temperature}, using 512-dim penultimate features):")
        # Scope the print options to this block instead of mutating global state.
        with np.printoptions(precision=4, suppress=True):
            print(np.array(sim_mat))

    return avg_loss, accuracy

if __name__ == "__main__":
    method = args.method
    steps_count = args.steps  #### BN clip steps for hard clip
    concat_sv = False
    step_size = args.LRsteps
    clip_outer_flag = False
    outer_steps = args.outer_steps
    outer_iters = args.outer_iters
    if args.catsn > 0.:
        concat_sv = True
        clip_steps = args.convsteps
        clip_outer_flag = True

    mode = args.mode
    bn_flag = True
    bn_clip = False
    bn_hard = False
    opt_iter = args.opt_iter
    if mode == 'wBN':
        mode = ''
        bn_flag = True
        bn_clip = False
        clip_steps = 50
    elif mode == 'noBN':
        bn_flag = False
        bn_clip = False
        opt_iter = 1
        clip_steps = 100
    elif mode == 'clipBN_hard':
        bn_flag = True
        bn_clip = True
        bn_hard = True
        clip_steps = 100
    else:
        print('unknown mode!')
        exit(0)

    ##================================================MULTIPLE CLASS===========================================
    unlearn_idx = pd.read_csv(args.unlearn_indices)['unlearn_idx'].values
    unlearn_idx = [int(i) for i in unlearn_idx]

    test_csv_path = args.unlearn_indices.replace('.csv', '_test.csv')
    unlearn_idx_test = pd.read_csv(test_csv_path)['unlearn_idx'].values
    unlearn_idx_test = [int(i) for i in unlearn_idx_test]


    seed_in = args.seed ##### !!!!! Do not use with more than one seed! some of the args gets changed during the first run @ToDo fix this!
    if seed_in == -1:
        geed_in = [1,2,3]
    else:
        seed_in = [seed_in]
    for seed in seed_in:
        print('seed.....', seed)
        best_acc = 0  # best test accuracy
        start_epoch = 0  # start from epoch 0 or last checkpoint epoch
        count_setp = 0

        seed_val = seed
        torch.manual_seed(seed_val)
        torch.cuda.manual_seed_all(seed_val)
        np.random.seed(seed_val)
        random.seed(seed_val)

        clip_flag    = False
        orig_flag    = False

        print('method: ', method)
        if method[:4] == 'fast' or method == 'clip':
            clip_flag    = True
        elif method == 'catclip':
            clip_flag    = True
        elif method == 'orig':
            orig_flag    = True
        else:
            print('unknown method!')
            exit(0)

        # Data
        print('==> Preparing data..')
        if args.dataset == 'mnist':
            print('using mnist')
            in_chan = 1
            if args.unnormalize:
                if args.model == 'ResNet18':
                    transform_train = transforms.Compose([
                        transforms.ToTensor(),
                    ])
                    transform_test = transforms.Compose([
                        transforms.ToTensor(),
                    ])
                elif args.model == 'VGG':
                    transform_train = transforms.Compose([
                        transforms.Resize(32),
                        transforms.ToTensor(),
                    ])
                    transform_test = transforms.Compose([
                        transforms.Resize(32),
                        transforms.ToTensor(),
                    ])
                    
            else:
                if args.model == 'ResNet18':
                    transform_train = transforms.Compose([
                        transforms.Resize((28, 28)),  # Ensure images are 28x28
                        transforms.ToTensor(),        # Convert images to PyTorch tensors
                        transforms.Normalize((0.5,), (0.5,))  # Normalize to [-1, 1]
                    ])
                    transform_test = transform_train
                elif args.model == 'VGG':
                    transform_train = transforms.Compose([
                        transforms.Resize(32),
                        transforms.ToTensor(),
                        transforms.Normalize((0.1307,), (0.3081,))
                    ])
                    transform_test = transform_train

            
            trainset = torchvision.datasets.MNIST(root='./data/mnist', train=True, download=True, transform=transform_train)
            if args.unlearn_method == 'reference':
                testset = torchvision.datasets.MNIST( root='./data/mnist', train=False, download=True, transform=transform_train)
            else:
                testset = torchvision.datasets.MNIST( root='./data/mnist', train=False, download=True, transform=transform_test)
            
        elif args.dataset == 'cifar10':
            print('using cifar 10')
            in_chan = 3

            if args.unnormalize:
                transform_train = transforms.Compose([
                    transforms.RandomCrop(32, padding=4),
                    transforms.RandomHorizontalFlip(),
                    transforms.ToTensor(),
                ])

                transform_test = transforms.Compose([
                    transforms.ToTensor(),
                ])
            else:
                transform_train = transforms.Compose([
                    transforms.RandomCrop(32, padding=4),
                    transforms.RandomHorizontalFlip(),
                    transforms.ToTensor(),
                    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
                ])

                transform_test = transforms.Compose([
                    transforms.ToTensor(),
                    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
                ])


            trainset = torchvision.datasets.CIFAR10( root='./data/cifar10', train=True, download=True, transform=transform_train) ### transofrm=transform_train

            if args.unlearn_method == 'reference':
                testset = torchvision.datasets.CIFAR10( root='./data/cifar10', train=False, download=True, transform=transform_train)
            else:
                testset = torchvision.datasets.CIFAR10( root='./data/cifar10', train=False, download=True, transform=transform_test)

        elif args.dataset == 'cifar100':
            print('using cifar 100')
            in_chan = 3
            args.num_classes = 100

            if args.unnormalize:
                transform_train = transforms.Compose([
                    transforms.RandomCrop(32, padding=4),
                    transforms.RandomHorizontalFlip(),
                    transforms.ToTensor(),
                ])

                transform_test = transforms.Compose([
                    transforms.ToTensor(),
                ])
            else:
                transform_train = transforms.Compose([
                    transforms.RandomCrop(32, padding=4),
                    transforms.RandomHorizontalFlip(),
                    transforms.ToTensor(),
                    transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)),  # CIFAR-100 mean and std
                ])

                transform_test = transforms.Compose([
                    transforms.ToTensor(),
                    transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)),  # CIFAR-100 mean and std
                ])


            trainset = torchvision.datasets.CIFAR100(root='./data/cifar100', train=True, download=True, transform=transform_train)
            # ./data/cifar100
            testset = torchvision.datasets.CIFAR100( root='./data/cifar100', train=False, download=True, transform=transform_test)

        elif args.dataset == 'imagenet':
            print('using Imagenet')
            in_chan = 3
            args.num_classes = 200

            if args.unnormalize:
                transform_train = transforms.Compose([
                    transforms.RandomCrop(64, padding=4),
                    transforms.RandomHorizontalFlip(),
                    transforms.ToTensor(),
                ])

                transform_test = transforms.Compose([
                    transforms.ToTensor(),
                ])
            else:
                transform_train = transforms.Compose([
                    transforms.RandomCrop(64, padding=4),
                    transforms.RandomHorizontalFlip(),
                    transforms.ToTensor(),
                    transforms.Normalize((0.4802, 0.4481, 0.3975), (0.2770, 0.2691, 0.2821))
                ])

                transform_test = transforms.Compose([
                    transforms.ToTensor(),
                    transforms.Normalize((0.4802, 0.4481, 0.3975), (0.2770, 0.2691, 0.2821))
                ])

            train_dir = './data/tiny-imagenet-200/train'
            val_dir = './data/tiny-imagenet-200/val'
            trainset = torchvision.datasets.ImageFolder(root=train_dir, transform=transform_train)
            if args.unlearn_method == 'reference':
                testset = torchvision.datasets.ImageFolder(root=val_dir, transform=transform_train)
            else:
                testset = torchvision.datasets.ImageFolder(root=val_dir, transform=transform_test)
                
        else:
            print('unknown dataset!')
            exit(0)

        indices_seed = args.unlearn_indices.split('/')[-1][:-4]
        indices_count = len(unlearn_idx) # args.unlearn_indices.split('/')[-2]

        args.outdir = f"/{dataset_name}/unlearn/{args.unlearn_method}/{indices_count}/unl_idx_{indices_seed}/"
        args.outdir = "scratch" + args.outdir
        args.outdir = base_path + args.outdir


        print(args.outdir)
        print('learning rate: ', args.lr)
        print('dataset: ', args.dataset)


        if args.unlearn_method == 'retrain':
            outdir = args.outdir + '/' + args.model + "_" + method + "_" + mode + "_" + str(seed_val) + "/"
            # outdir = args.outdir + '/' + args.source_model_path.split('/')[-2] + '/'
        else:
            # outdir = args.outdir + '/' + args.source_model_path.split('/')[-1] + '/'
            outdir = args.outdir + args.source_model_path.split('/')[-1] 

            outdir = outdir + '/use_remain_' + str(args.use_remain) + '/' + args.model + "_" + method + "_" + mode + "_" + str(seed_val) + "/"
            outdir += '/LRs_' + str(step_size) + '_lr_' + str(args.lr) + '/'


        print('outdir: ', outdir)
        if not os.path.exists(outdir):
            os.makedirs(outdir)
        writer = SummaryWriter(outdir)

        print('==> Building model..')
        print('------------> outdir: ', outdir)
        print('-----------------------------------------------------------------')
        print('initial len of trainset: ', len(trainset))  


        request_count = 1
        if args.req_mode == 'adaptive':
            print('not implemented yet!')
            exit(0)

        prior_idx = []
        for req_idx in range(request_count):

            if len(forgetting_class) == 1:
                print("===============SINGLE FORGETTING===================")
                unlearn_idx = pd.read_csv(args.unlearn_indices)['unlearn_idx'].values
                unlearn_idx_9 = pd.read_csv(args.unlearn_indices.replace('_1.csv', '_9.csv'))['unlearn_idx'].values
                if len(unlearn_idx) != int(indices_count):
                    print('unlearn_idx count is not correct!')
                    # exit(0)
                unlearn_idx = [int(i) for i in unlearn_idx]
                unlearn_idx_9 = [int(i) for i in unlearn_idx_9]
                test_csv_path = args.unlearn_indices.replace('.csv', '_test.csv')
                test_csv_path_9 = args.unlearn_indices.replace('_1.csv', '_9_test.csv')
                unlearn_idx_test = pd.read_csv(test_csv_path)['unlearn_idx'].values
                unlearn_idx_test_9 = pd.read_csv(test_csv_path_9)['unlearn_idx'].values
                unlearn_idx_test = [int(i) for i in unlearn_idx_test]
                unlearn_idx_test_9 = [int(i) for i in unlearn_idx_test_9]

                removed_classes = [trainset[i][1] for i in unlearn_idx]
                df = pd.DataFrame({'unlearn_idx': unlearn_idx, 'removed_classes': removed_classes})
                df.to_csv(outdir + 'unlearn_idx.csv')
            
            ### remove the unlearned images from the trainset
            trainset_filtered = torch.utils.data.Subset(trainset, list(set(range(len(trainset))) - set(unlearn_idx) - set(prior_idx)))
            trainset_filtered_9 = torch.utils.data.Subset(trainset, list(set(range(len(trainset))) - set(unlearn_idx_9) - set(unlearn_idx) - set(prior_idx)))
            print('len of filtered trainset: ', len(trainset_filtered))  
            # print('len of filtered trainset_9: ', len(trainset_filtered_9))  
            # print(trainset_filtered.report())

            forgetset = torch.utils.data.Subset(trainset, unlearn_idx)
            forgetset_9 = torch.utils.data.Subset(trainset, unlearn_idx_9)
            print('len of forget set: ', len(forgetset))  
            # print(forgetset.report())


            trainloader = torch.utils.data.DataLoader(trainset, shuffle=True, batch_size=args.batch_size, num_workers=1)
            remainloader = torch.utils.data.DataLoader(trainset_filtered, shuffle=False, batch_size=args.batch_size, num_workers=1)
            remainloader_9 = torch.utils.data.DataLoader(trainset_filtered_9, shuffle=False, batch_size=args.batch_size, num_workers=1)
            forgetloader = torch.utils.data.DataLoader(forgetset, shuffle=False, batch_size=args.batch_size, num_workers=1)
            forgetloader_9 = torch.utils.data.DataLoader(forgetset_9, shuffle=False, batch_size=args.batch_size, num_workers=1)

            ### remove the unlearned images from the testset
            testset_filtered = torch.utils.data.Subset(testset, list(set(range(len(testset))) - set(unlearn_idx_test) - set(prior_idx)))
            testset_filtered_9 = torch.utils.data.Subset(testset, list(set(range(len(testset))) - set(unlearn_idx_test_9) - set(unlearn_idx_test) - set(prior_idx)))
            print('len of filtered testset: ', len(testset_filtered))  
            # print('len of filtered testset_9: ', len(testset_filtered_9))  

            forgetset_test = torch.utils.data.Subset(testset, unlearn_idx_test)
            forgetset_test_9 = torch.utils.data.Subset(testset, unlearn_idx_test_9)
            print('len of forget testset: ', len(forgetset_test))  

            testloader = torch.utils.data.DataLoader(testset, shuffle=False, batch_size=args.batch_size, num_workers=1)
            remainloader_test = torch.utils.data.DataLoader(testset_filtered, shuffle=False, batch_size=args.batch_size, num_workers=1)
            remainloader_test_9 = torch.utils.data.DataLoader(testset_filtered_9, shuffle=False, batch_size=args.batch_size, num_workers=1)
            forgetloader_test = torch.utils.data.DataLoader(forgetset_test, shuffle=False, batch_size=args.batch_size, num_workers=1)
            forgetloader_test_9 = torch.utils.data.DataLoader(forgetset_test_9, shuffle=False, batch_size=args.batch_size, num_workers=1)
            if args.unlearn_method == 'retrain':
                if args.use_remain_sample:
                    sample_indices = np.random.choice(len(trainset_filtered), len(forgetset), replace=False)
                    trainset_filtered = torch.utils.data.Subset(trainset_filtered, sample_indices)
                trainset = trainset_filtered
            elif args.unlearn_method == 'RW' or args.unlearn_method == 'RW_multi' or args.unlearn_method == 'RW_FT' or args.unlearn_method == 'RW_FT_par':
                trainset = trainset
                testset = testset

            print('final len of trainset: ', len(trainset))  
            print('-----------------------------------------------------------------')
            
            if req_idx == 0:
                if args.model == 'ResNet18':
                    if orig_flag:
                        if args.dataset == 'imagenet': 
                            net = ResNet18_orig(in_chan=in_chan, bn=bn_flag, device=device, elu_flag=False, num_classes=args.num_classes, tinynet=True, unlearn_method=args.unlearn_method)
                        else:
                            net = ResNet18_orig(in_chan=in_chan, bn=bn_flag, device=device, elu_flag=False, num_classes=args.num_classes, unlearn_method=args.unlearn_method)
                        # net = ResNet18_orig(in_chan=in_chan, bn=bn_flag, bn_clip=bn_clip, bn_hard=bn_hard, clip_linear=False, bn_count=steps_count, device=device)
                    elif clip_flag:
                        net = ResNet18(concat_sv=concat_sv, in_chan=in_chan, device=device, clip=args.convsn, clip_concat=args.catsn, clip_flag=True, bn=bn_flag, bn_clip=bn_clip, bn_hard=bn_hard, clip_steps=clip_steps, bn_count=steps_count, clip_outer=clip_outer_flag, clip_opt_iter=opt_iter, summary=True, writer=writer, save_info=False, outer_iters=outer_iters, outer_steps=outer_steps, num_classes=args.num_classes)
                
                elif args.model == 'VGG':
                    if args.unlearn_method == 'RW' or args.unlearn_method == 'RW_multi':
                        if args.dataset == 'imagenet':
                            net = VGG_rw('VGG19', in_chan=in_chan, num_classes=args.num_classes, tinynet=True)
                        else:
                            net = VGG_rw('VGG19', in_chan=in_chan, num_classes=args.num_classes)
                    elif args.unlearn_method == 'BS' or args.unlearn_method == 'retrain' or args.unlearn_method == 'RW_FT' or args.unlearn_method == 'RW_FT_par':
                        if args.dataset == 'imagenet':
                            net = VGG('VGG19', in_chan=in_chan, num_classes=args.num_classes, tinynet=True)
                        else:
                            net = VGG('VGG19', in_chan=in_chan, num_classes=args.num_classes)
                    
                
                net = net.to(device)
                test_net = copy.deepcopy(net)
                test_net = nn.DataParallel(test_net) ### adds the "module." prefix to the state_dict keys
                net = nn.DataParallel(net) ### adds the "module." prefix to the state_dict keys
                criterion = nn.CrossEntropyLoss()



                if args.unlearn_method != 'retrain' and args.unlearn_method != 'reference':
                    if clip_flag:
                        if bn_flag:
                            checkpoint = torch.load(args.source_model_path + '/checkpoint.pth.tar_200')
                        else:
                            checkpoint = torch.load(args.source_model_path + '/checkpoint.pth.tar_120')
                        net.load_state_dict(checkpoint['state_dict'], strict=True)
                    else:
                        if args.dataset == 'mnist':
                            checkpoint = torch.load(args.source_model_path + '/checkpoint.pth.tar_best')
                        else: 
                            checkpoint = torch.load(args.source_model_path + '/checkpoint.pth.tar_best')
                        net.load_state_dict(checkpoint['state_dict'], strict=True)#, strict=False)
                    print("--->source model", args.source_model_path)
                    print('model loaded')

            tr_loss_list = []
            tr_acc_list = []
            ts_loss_list = []
            ts_acc_list = []
            fs_loss_list = []
            fs_acc_list = []
            re_loss_list = []
            re_acc_list = []
            best_keeping_list = []

            net.eval()
            print('-- train set:')
            tr_loss, tr_acc = 0., 0.
            print('-- test set:')
            ts_loss, ts_acc = test(testloader, 200, criterion, unlearn_method=args.unlearn_method, writer=writer, mode='test', model_path=None)
            print('--- forget set:')
            fs_loss, fs_acc = test(forgetloader, 200, criterion, unlearn_method=args.unlearn_method, writer=writer, mode='forget', model_path=None)
            print('-- remain set:')
            remain_loss, remain_acc = test(remainloader, 200, criterion, unlearn_method=args.unlearn_method, writer=writer, mode='remain', model_path=None)
            remain_loss, remain_acc = 0., 0.
            
            print("Check Forget/Remain acc and MIA scores")
            test_checkpoint = torch.load(args.source_model_path + '/checkpoint.pth.tar_best')
            test_net.load_state_dict(test_checkpoint['state_dict'], strict=True)#, strict=False)
            fst_loss, fst_acc = check_test(trainloader, forgetloader_test, test_net, 200, criterion, unlearn_method=args.unlearn_method, writer=writer, mode='forget', model_path=None)
            remaint_loss, remaint_acc = check_test(trainloader, remainloader_test, test_net, 200, criterion, unlearn_method=args.unlearn_method, writer=writer, mode='remain', model_path=None)
            t_loss, t_acc = check_test(trainloader, testloader, test_net, 200, criterion, unlearn_method=args.unlearn_method, writer=writer, mode='test', model_path=None, compute_sim=False)

            tr_loss_list.append(tr_loss)
            tr_acc_list.append(tr_acc)
            ts_loss_list.append(ts_loss)
            ts_acc_list.append(ts_acc)
            fs_loss_list.append(fs_loss)
            fs_acc_list.append(fs_acc)
            re_loss_list.append(remain_loss)
            re_acc_list.append(remain_acc)
            best_keeping_list.append(0)

            # net.train()

            if args.unlearn_method == 'RW_FT_par' and args.model == 'ResNet18':
                # Freeze all parameters by default
                for param in net.module.parameters():
                    param.requires_grad = False

                # Unfreeze layer3
                for param in net.module.layer2.parameters():
                    param.requires_grad = True

                # Unfreeze linear classifier
                for param in net.module.linear.parameters():
                    param.requires_grad = True
            
            if args.unlearn_method == 'RW_FT_par' and args.model == 'VGG':
                # Freeze all parameters by default
                for param in net.module.parameters():
                    param.requires_grad = False

                # Unfreeze layer3
                for param in net.module.features[10].parameters():
                    param.requires_grad = True

                # Unfreeze linear classifier
                for param in net.module.classifier.parameters():
                    param.requires_grad = True

            if args.dataset == 'mnist':
                if args.unlearn_method == 'retrain':
                    args.lr = 0.1
                if args.unlearn_method == 'RW_FT_par':
                    optimizer = optim.SGD(filter(lambda p: p.requires_grad, net.module.parameters()), lr=args.lr, momentum=0.9, weight_decay=5e-4, nesterov=True)
                else:
                    optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4, nesterov=True)  # momentum 0.9
                # optimizer = SGLD(net.parameters(), lr=args.lr, addnoise=True) 
                scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=0.1)
                T_max = args.epochs

            elif args.dataset == 'cifar10':
                if args.unlearn_method == 'retrain':
                    args.lr = 0.1
                if args.unlearn_method == 'RW_FT_par':
                    optimizer = optim.SGD(filter(lambda p: p.requires_grad, net.parameters()), lr=args.lr, momentum=0.9, weight_decay=5e-4, nesterov=True)
                else:
                    optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4, nesterov=True)  # momentum 0.9
                #optimizer = SGLD(net.parameters(), lr=args.lr, addnoise=True) 
                if args.unlearn_method == 'retrain':
                    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=0.1)
                    T_max = 121
                    if not bn_flag:
                        T_max = 101

                else:
                    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=0.1)
                    T_max = args.epochs

            elif args.dataset == 'cifar100':
                if args.unlearn_method == 'retrain':
                    args.lr = 0.1
                if args.unlearn_method == 'RW_FT_par':
                    optimizer = optim.SGD(filter(lambda p: p.requires_grad, net.parameters()), lr=args.lr, momentum=0.95, weight_decay=5e-4, nesterov=True)
                else:
                    optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=0.95, weight_decay=5e-4, nesterov=True)  # momentum 0.9
                #optimizer = SGLD(net.parameters(), lr=args.lr, addnoise=True) 
                if args.unlearn_method == 'retrain':
                    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=0.1)
                    T_max = 121
                    if not bn_flag:
                        T_max = 121
                else:
                    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=0.1)
                    T_max = args.epochs
                    # scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=T_max)                
                
            elif args.dataset == 'imagenet':
                if args.unlearn_method == 'retrain':
                    args.lr = 5e-4
                if args.unlearn_method == 'RW_FT_par':
                    optimizer = optim.SGD(filter(lambda p: p.requires_grad, net.parameters()), lr=args.lr, momentum=0.9, weight_decay=5e-4, nesterov=True)
                else:
                    optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4, nesterov=True)  # momentum 0.9
                #optimizer = SGLD(net.parameters(), lr=args.lr, addnoise=True) 
                if args.unlearn_method == 'retrain':
                    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=0.1)
                    T_max = 201
                    if not bn_flag:
                        T_max = 121
                else:
                    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=0.1)
                    T_max = args.epochs
                    # scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=T_max)
            
            else:
                raise ValueError('dataset must be one of cifar, mnist')

            model_path =  outdir + '_ckpt'
            model_path_test =  outdir + '_ckpt_best_test.pth'



            print('epoch: ', start_epoch)
            print('Tmax: ', T_max)

            sv_df = {}
            test_model = copy.deepcopy(net)

            # Compute similarity matrix from source model for similarity-based reweighting
            print("Computing similarity matrix from source model for reweighting...")
            similarity_matrix = compute_similarity_matrix_from_source_model(
                net, 
                args.num_classes, 
                method=args.similarity_method,
                inv_temperature=args.similarity_temperature,
                pca_components=args.similarity_pca_components,
                trainloader=trainloader,
                forget_class=forgetting_class[0] if forgetting_class else None,
                forget_temp=args.forget_temp,
                remain_temp=args.remain_temp
            )
            # Place the similarity matrix on the same device used for training
            # before it is passed to train().
            similarity_matrix = similarity_matrix.to(device)
            print(f"Similarity matrix shape: {similarity_matrix.shape}")

            # Similarity scaling factor, taken from args.similarity_beta.
            beta = args.similarity_beta

            for epoch in range(T_max):
                tr_loss, tr_acc = train(
                    epoch,
                    optimizer,
                    scheduler,
                    criterion,
                    test_model,
                    unlearn_method=args.unlearn_method,
                    writer=writer,
                    model_path=model_path,
                    mask=None,
                    similarity_matrix=similarity_matrix,
                    beta=beta,
                )

                # Evaluation runs after every epoch (the former
                # `epoch % 5 == 0` gate is disabled). These keyword arguments
                # are identical for every test() call below.
                eval_kwargs = dict(
                    unlearn_method=args.unlearn_method,
                    writer=writer,
                    model_path=model_path_test,
                )

                print('-- test set:')
                ts_loss, ts_acc = test(
                    testloader, epoch, criterion,
                    mode='test', plot_images=True, **eval_kwargs
                )

                print('--- forget set:')
                fs_loss, fs_acc = test(
                    forgetloader, epoch, criterion, mode='forget', **eval_kwargs
                )
                # The *_test results are bound but not appended to any history
                # list in this section; the calls are kept for whatever test()
                # itself prints or logs.
                fs_loss_test, fs_acc_test = test(
                    forgetloader_test, epoch, criterion, mode='forget', **eval_kwargs
                )

                print('-- remain set:')
                remain_loss, remain_acc = test(
                    remainloader, epoch, criterion, mode='remain', **eval_kwargs
                )
                remain_loss_test, remain_acc_test = test(
                    remainloader_test, epoch, criterion, mode='remain', **eval_kwargs
                )

                # Flag epochs whose test accuracy equals `best_acc`.
                # NOTE(review): `best_acc` is defined outside this section;
                # presumably test() updates it - confirm.
                best_keeping_list.append(1 if ts_acc == best_acc else 0)

                # Record this epoch's metrics; every history list grows by
                # exactly one entry per epoch.
                for metric_log, metric_value in (
                    (tr_loss_list, tr_loss),
                    (tr_acc_list, tr_acc),
                    (ts_loss_list, ts_loss),
                    (ts_acc_list, ts_acc),
                    (fs_loss_list, fs_loss),
                    (fs_acc_list, fs_acc),
                    (re_loss_list, remain_loss),
                    (re_acc_list, remain_acc),
                ):
                    metric_log.append(metric_value)
            
            # Checkpoint the weights of `net` together with the index of the
            # last epoch that was run.
            # NOTE(review): `epoch` is the variable of the preceding
            # `for epoch in range(T_max)` loop; if T_max is 0 the loop never
            # binds it - confirm T_max >= 1 is guaranteed.
            print('Saving Last..', model_path)
            state = {'net': net.state_dict(), 'epoch': epoch}
            torch.save(state, model_path + '.pth')

            # One row per evaluated epoch, one column per tracked metric.
            df = pd.DataFrame(
                {
                    'tr_loss': tr_loss_list,
                    'tr_acc': tr_acc_list,
                    'ts_loss': ts_loss_list,
                    'ts_acc': ts_acc_list,
                    'fs_loss': fs_loss_list,
                    'fs_acc': fs_acc_list,
                    're_loss': re_loss_list,
                    're_acc': re_acc_list,
                    'best_keeping': best_keeping_list,
                }
            )

            print('saving results to ...', outdir)
            # Retraining writes a fixed file name; every other unlearning
            # method prefixes the file name with the step size. `outdir` is
            # concatenated as-is, so it must already end with a separator
            # (or be an intentional file-name prefix).
            results_name = (
                'loss_acc_results.csv'
                if args.unlearn_method == 'retrain'
                else str(step_size) + '_loss_acc_results.csv'
            )
            df.to_csv(outdir + results_name)
            
            # EVALUATION (membership-inference style attacks on `net`)
            #
            # Disabled alternatives that used to live here:
            #   * net.eval() followed by svc_mia(net, trainloader, testloader,
            #     forgetting_class, unlearn_method=args.unlearn_method)
            #   * rebuilding a ResNet18_orig, wrapping it in nn.DataParallel and
            #     loading checkpoint['net'] before running the attack
            # The in-memory `net` is attacked directly instead.
            print("======================SVD!!!========================")

            # Keyword arguments shared by every attack call below.
            attack_common = dict(model=net, unlearn_method=args.unlearn_method)

            evaluation_result = SVC_attack(
                shadow_train=remainloader,
                shadow_test=forgetloader,
                target_train=remainloader_test,
                target_test=forgetloader_test,
                forgetting_class=forgetting_class,
                **attack_common
            )

            # Disabled: the same SVC_attack on the class-9 split
            # (remainloader_9 / forgetloader_9 / remainloader_test_9 /
            # forgetloader_test_9, forgetting_class=9), announced with
            # "===test on automobile===".

            # NOTE(review): the banner says "automobile" while
            # forgetting_class is 9 and the original inline notes label the
            # `_9` loaders as truck - confirm which class is intended.
            print("===test on automobile new===")
            # This overwrites the SVC_attack result stored above.
            # NOTE(review): target_train and target_test receive the same
            # loader - confirm this is intentional.
            evaluation_result = SVC_attack_new(
                shadow_train=forgetloader_test_9,   # truck test
                shadow_test=forgetloader_test,      # car test
                target_train=remainloader_test_9,   # no truck, no automobile
                target_test=remainloader_test_9,    # remain test w/o car, truck
                forgetting_class=9,
                **attack_common
            )

            dist_result = distribution_attack_new(
                shadow_train=forgetloader_test_9,
                target_train=forgetloader_test,
                target_test=remainloader_test_9,
                forgetting_class=9,
                **attack_common
            )
