import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from collections import defaultdict
import json
import torch.backends.cudnn as cudnn
from numpy.lib.format import open_memmap
import matplotlib.pyplot as plt
import seaborn as sns
from collections import Counter, defaultdict
import torchvision
import torchvision.transforms as transforms
from tensorboardX import SummaryWriter
from generate_mask import save_gradient_ratio
import os
from collections import OrderedDict
import argparse
from models import *
from models.resnet_orig import ResNet18_orig
from models.vgg_svd import vgg11_bn
from models.vgg import VGG
import pandas as pd
import random
import time
import copy
import numpy as np
from torch.utils.data import Dataset
from torchvision.io import read_image
from sgld_optim import *
import re


from sklearn.svm import SVC
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score, roc_curve, balanced_accuracy_score, confusion_matrix, accuracy_score


def likelihood(score, mean, var):
    """Evaluate a Gaussian density at ``score``.

    Despite its name, ``var`` is used as the *standard deviation*: it is
    squared to obtain the variance.

    Args:
        score: Observed value(s); tensor, array, list or scalar.
        mean: Gaussian mean(s), broadcastable against ``score``.
        var: Gaussian standard deviation(s), broadcastable against ``score``.

    Returns:
        Tensor of density values ``N(score; mean, var**2)``. The normalising
        constant is a shape-``(1,)`` tensor, so a scalar input yields a
        shape-``(1,)`` result.
    """
    # as_tensor avoids the copy-construct UserWarning torch.tensor() raises for
    # tensor inputs and keeps data on its current device; detach() keeps the
    # "no autograd history" behaviour of torch.tensor().
    score = torch.as_tensor(score).detach()
    mean = torch.as_tensor(mean).detach()
    var = torch.as_tensor(var).detach()
    # log(2*pi) written as log(4*acos(0)); built on var's device so CUDA inputs
    # are not mixed with a non-scalar CPU tensor.
    log_two_pi = torch.log(4 * torch.acos(torch.zeros(1, device=var.device)))
    log_density = (
        -(((score - mean) ** 2) / (2 * (var ** 2 + 1e-32)))
        - 0.5 * torch.log(var ** 2)
        - 0.5 * log_two_pi
    )
    return torch.exp(log_density)


def get_likelihood_ratio(test_features, model_parameters):
    """Return the ratio of the 'unlearn' to the 'retrain' Gaussian likelihood.

    Each statistic in ``model_parameters`` is reshaped to a column ``(-1, 1)``
    before being passed to ``likelihood``, so it broadcasts against
    ``test_features``.

    Args:
        test_features: Scores to evaluate.
        model_parameters: Mapping holding tensors under ``'mean_unlearn'``,
            ``'std_unlearn'``, ``'mean_retrain'`` and ``'std_retrain'``.

    Returns:
        ``likelihood_unlearn / (likelihood_retrain + 1e-32)``.
    """
    in_likelihood = likelihood(
        test_features,
        model_parameters['mean_unlearn'].view(-1, 1),
        model_parameters['std_unlearn'].view(-1, 1) + 1e-10)
    # NOTE(review): the std epsilons differ (1e-10 vs 1e-32); in float32 the
    # 1e-32 offset is a no-op for any realistic std — confirm intent.
    out_likelihood = likelihood(
        test_features,
        model_parameters['mean_retrain'].view(-1, 1),
        model_parameters['std_retrain'].view(-1, 1) + 1e-32)
    likelihood_ratio = in_likelihood / (out_likelihood + 1e-32)
    return likelihood_ratio


def l1_regularization(model):
    """Return the L1 norm of all of ``model``'s parameters taken together.

    Every parameter is flattened and concatenated into one vector whose L1
    norm is returned as a scalar tensor that remains attached to autograd.
    """
    flat_params = torch.cat([p.view(-1) for p in model.parameters()])
    return torch.linalg.norm(flat_params, ord=1)

    
def plot_confusion_matrix(true_labels, pred_labels, class_names, title, ax=None, fontsize=14):
    """Draw an annotated confusion-matrix heatmap and return its axes.

    Args:
        true_labels: Ground-truth labels.
        pred_labels: Predicted labels.
        class_names: Tick labels used for both axes.
        title: Axes title.
        ax: Axes to draw on; a new 8x6 figure is created when omitted.
        fontsize: Base font size (labels use +2, the title +4).

    Returns:
        The matplotlib axes containing the heatmap.
    """
    matrix = confusion_matrix(true_labels, pred_labels)

    if ax is None:
        _, ax = plt.subplots(figsize=(8, 6))

    heatmap_style = dict(
        annot=True,
        fmt="d",
        cmap="Greens",
        annot_kws={"fontsize": fontsize},
        cbar=False,
    )
    sns.heatmap(matrix, xticklabels=class_names, yticklabels=class_names, ax=ax, **heatmap_style)

    label_size = fontsize + 2
    ax.set_xlabel("Predicted Labels", fontsize=label_size, labelpad=10)
    ax.set_ylabel("True Labels", fontsize=label_size, labelpad=10)
    ax.set_title(title, fontsize=fontsize + 4, pad=12)

    for axis_name in ('x', 'y'):
        ax.tick_params(axis=axis_name, labelsize=fontsize)

    # Angled column labels, horizontal row labels.
    ax.set_xticklabels(class_names, rotation=45, ha='right')
    ax.set_yticklabels(class_names, rotation=0)

    return ax

def discretize(x):
    """Snap the values of ``x`` onto multiples of 1/255.

    For inputs in [0, 1] this is the 8-bit pixel grid. ``torch.round`` rounds
    halves to even, and no clamping is done here.
    """
    levels = 255
    return (x * levels).round() / levels


def FGSM_perturb(x, y, model=None, bound=None, criterion=None):
    """Craft a one-step FGSM adversarial example.

    Args:
        x: Input batch; copied and moved to the model's device.
        y: Targets for ``criterion``; moved to the model's device if a tensor.
        model: Network to attack (required despite the ``None`` default).
        bound: Step size applied to the sign of the input gradient.
        criterion: Loss function called as ``criterion(model(x), y)``.

    Returns:
        Detached tensor ``discretize(clamp(x + bound * sign(grad_x), 0, 1))``
        on the model's device.

    Side effects:
        Calls ``model.zero_grad()``. The input gradient is obtained with
        ``torch.autograd.grad``, so parameter ``.grad`` buffers are not filled.
    """
    device = next(model.parameters()).device
    model.zero_grad()
    # Move first, then enable grad: calling `.to(device)` on a tensor that
    # already requires grad returns a non-leaf copy whose `.grad` stays None.
    x_adv = x.detach().clone().to(device).requires_grad_(True)
    if torch.is_tensor(y):
        y = y.to(device)

    loss = criterion(model(x_adv), y)
    (input_grad,) = torch.autograd.grad(loss, x_adv)

    x_adv = x_adv.detach() + input_grad.sign() * bound
    return discretize(torch.clamp(x_adv, 0.0, 1.0)).detach()


def forget_loss_drop_class_dim(new_logits, old_probs, forgetting_class):
    """Cross-entropy toward the old model's distribution with forgotten classes removed.

    The target ``q*`` is ``old_probs`` with the forgotten column(s) zeroed and
    each row divided by ``1 - sum(old_probs[:, forgetting_class])`` (clamped
    to at least 1e-9). The loss is
    ``-mean_x sum_y q*(y|x) * log(p'(y|x) + 1e-9)``, where
    ``p' = softmax(new_logits, dim=1)``.

    Args:
        new_logits: ``(batch, num_classes)`` logits of the model being trained.
        old_probs: ``(batch, num_classes)`` probabilities of the original model.
        forgetting_class: A class index, or a list/tensor of class indices.

    Returns:
        Scalar loss tensor.
    """
    eps = 1e-9
    # Normalise to a 1-D index tensor. With a bare int, `old_probs[:, c]` is
    # 1-D and `.sum(dim=1)` raised an IndexError.
    forget_idx = torch.as_tensor(
        forgetting_class, dtype=torch.long, device=old_probs.device).reshape(-1)

    q_star = old_probs.clone()
    q_star[:, forget_idx] = 0.0
    # Probability mass left after removing the forgotten classes.
    denom = (1.0 - old_probs[:, forget_idx].sum(dim=1)).clamp_min(eps).unsqueeze(1)
    q_star = q_star / denom

    new_probs = F.softmax(new_logits, dim=1)
    return -(q_star * torch.log(new_probs + eps)).sum(dim=1).mean()


def expand_model(model):
    """Give the last ``nn.Linear`` of ``model`` one extra output unit, in place.

    The last Linear in ``named_modules()`` order is replaced by a new layer
    with ``out_features + 1`` outputs, on the same device and with the same
    dtype. The existing weight rows (and bias entries, if present) are copied
    over; the extra row keeps ``nn.Linear``'s default initialisation.

    Raises:
        ValueError: if ``model`` contains no ``nn.Linear``.
    """
    linear_layers = [
        (path, module)
        for path, module in model.named_modules()
        if isinstance(module, nn.Linear)
    ]
    if not linear_layers:
        raise ValueError("No Linear layer found in the model.")
    target_path, old_fc = linear_layers[-1]

    has_bias = old_fc.bias is not None
    wider_fc = nn.Linear(
        in_features=old_fc.in_features,
        out_features=old_fc.out_features + 1,
        bias=has_bias,
        device=old_fc.weight.device,
        dtype=old_fc.weight.dtype,
    )

    with torch.no_grad():
        wider_fc.weight[:-1] = old_fc.weight
        if has_bias:
            wider_fc.bias[:-1] = old_fc.bias

    # Walk down to the module that owns the layer and swap the new one in.
    *parent_parts, attr_name = target_path.split(".")
    owner = model
    for part in parent_parts:
        owner = getattr(owner, part)
    setattr(owner, attr_name, wider_fc)


def compute_tv_distance(logits_unlearn, logits_retrain):
    """
    Mean total-variation distance between two models' predictive distributions.

    For each example the TV distance is ``0.5 * sum_k |p_un(k) - p_rt(k)|``
    (a value in [0, 1]); the per-example values are then averaged.

    Args:
      logits_unlearn: Tensor [N, K]
      logits_retrain: Tensor [N, K]
    Returns:
      ud: scalar TV distance averaged over N examples
    """
    # 1) softmax over classes to get probabilities
    p_un = F.softmax(logits_unlearn, dim=1)  # [N, K]
    p_rt = F.softmax(logits_retrain, dim=1)  # [N, K]

    # 2) per-sample TV = 0.5 * L1 norm across classes. The 0.5 was documented
    #    but missing, so the function used to return twice the TV distance.
    per_sample_tv = 0.5 * torch.sum(torch.abs(p_un - p_rt), dim=1)  # [N]

    # 3) average over the examples
    return per_sample_tv.mean()


class AddGaussianNoise(object):
    """Transform that adds i.i.d. Gaussian noise to a tensor.

    Each call returns ``tensor + randn(tensor.shape) * std + mean``.
    """

    def __init__(self, mean=0., std=1.):
        self.std = std
        self.mean = mean

    def __call__(self, tensor):
        # Sample on the input's device so GPU tensors work too (the noise used
        # to be created on the CPU unconditionally). The dtype stays at the
        # torch default, as before.
        noise = torch.randn(tensor.size(), device=tensor.device)
        return tensor + noise * self.std + self.mean

    def __repr__(self):
        return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std)


class simpleDataset(Dataset):
    """In-memory dataset built from a tensor of images and a label sequence.

    The image tensor is detached, moved to the CPU and stored as a NumPy
    array. Each item has its first axis moved to the end (presumably
    channel-first (C, H, W) -> (H, W, C); confirm with callers). It is then
    optionally passed through ``transform`` and returned with its
    (optionally transformed) label.
    """

    def __init__(self, data, labels, transform=None, target_transform=None):
        # Convert once up front rather than per item.
        self.data = data.detach().cpu().numpy()
        self.labels = labels
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        image = np.transpose(self.data[idx], (1, 2, 0))
        label = self.labels[idx]
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return image, label


class RLDataset(Dataset):
    """Wrap an ``(image, label)`` dataset and relabel every sample.

    If ``new_classes`` is given, sample ``idx`` gets label ``new_classes[idx]``.
    Otherwise a label is drawn with ``np.random.choice`` from the
    ``num_classes`` classes other than the sample's true label, afresh on
    every access. ``noise_level`` and ``add_noise`` are stored but not used by
    this class.
    """

    def __init__(self, forgetset, new_classes=None, num_classes=10, noise_level=0.01, add_noise=False):
        self.image_set = forgetset
        self.add_noise = add_noise
        self.noise_level = noise_level
        self.num_classes = num_classes
        self.new_classes = new_classes

    def __len__(self):
        return len(self.image_set)

    def __getitem__(self, idx):
        # Fetch the wrapped sample once; indexing it twice used to load (and
        # run any transforms on) the same sample two times.
        sample = self.image_set[idx]
        image = sample[0]
        if self.new_classes is not None:
            label = self.new_classes[idx]
        else:
            true_label = sample[1]
            label = np.random.choice([i for i in range(self.num_classes) if i != true_label])  # random label

        return image, label


class basicDataset(Dataset):
    """Dataset over a pre-loaded collection of (image, label) records.

    The record layout is chosen by ``data.shape[-1]``:

    * ``== 2``: ``data[idx]`` is indexed with the keys ``'image'`` and
      ``'label'``. The image is copied into a NumPy array, and a 2-D image is
      stacked into three identical channels along a new last axis.
    * otherwise: ``data[idx][0]`` is the image and ``data[idx][1]`` the label.

    ``transform`` / ``target_transform`` are applied when given.
    """

    def __init__(self, data, transform=None, target_transform=None):
        self.data = data
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        record = self.data[idx]
        if self.data.shape[-1] == 2:
            # Copy so later transforms cannot alias (or fail on a read-only
            # view of) the stored image.
            image = copy.deepcopy(np.asarray(record['image']))
            if len(image.shape) == 2:
                image = np.stack((image, image, image), axis=2)
            label = record['label']
        else:
            # Previously only `image_in` was assigned here, leaving `image`
            # unbound and raising UnboundLocalError on every access.
            image = record[0]
            label = record[1]

        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return image, label

def SVC_fit_predict(shadow_train, shadow_test, target_train, target_test):
    """Fit a linear-SVM membership classifier and score it on two target sets.

    The SVM (``class_weight='balanced'``) is trained to output 1 for rows of
    ``shadow_train`` and 0 for rows of ``shadow_test``. The same number of
    rows, ``min(len(shadow_train), len(shadow_test))``, is taken from each,
    and the rows are shuffled with ``np.random.permutation`` before fitting.

    Args:
        shadow_train, shadow_test, target_train, target_test: Tensors whose
            first dimension indexes samples; each row is flattened into one
            feature vector.

    Returns:
        ``(acc_train, acc_test)``: ``acc_train`` is the fraction of
        ``target_train`` rows predicted 0, and ``acc_test`` the fraction of
        ``target_test`` rows predicted 1. Either is ``None`` when the
        corresponding target set is empty (this used to raise
        UnboundLocalError).
    """
    # Balance the two shadow classes. Previously only shadow_train was
    # truncated, which broke the reshape when it had fewer rows.
    n_balanced = min(shadow_train.shape[0], shadow_test.shape[0])
    n_target_train = target_train.shape[0]
    n_target_test = target_test.shape[0]

    X_shadow = torch.cat([shadow_train[:n_balanced], shadow_test[:n_balanced]]).cpu().numpy().reshape(2 * n_balanced, -1)
    Y_shadow = np.concatenate([np.ones(n_balanced), np.zeros(n_balanced)])
    shuffle_indices = np.random.permutation(len(Y_shadow))
    X_shadow = X_shadow[shuffle_indices]
    Y_shadow = Y_shadow[shuffle_indices]
    clf = SVC(kernel='linear', class_weight='balanced')
    clf.fit(X_shadow, Y_shadow)

    acc_train = None
    acc_test = None

    if n_target_train > 0:
        X_target_train = target_train.cpu().numpy().reshape(n_target_train, -1)
        acc_train = 1 - clf.predict(X_target_train).mean()

    if n_target_test > 0:
        X_target_test = target_test.cpu().numpy().reshape(n_target_test, -1)
        acc_test = clf.predict(X_target_test).mean()

    return acc_train, acc_test

def _split_probs_by_forget(net, loader, forget_labels, device, use_new_logits):
    """Run ``net`` over ``loader`` and split softmax outputs by forgotten-class membership.

    ``net`` must return a pair ``(test_logits, new_logits)``; ``new_logits``
    is used when ``use_new_logits`` is true.

    Returns:
        ``(conf_remain, labels_remain, conf_forget, labels_forget)``, each
        concatenated over all batches.
    """
    conf_r, labels_r, conf_f, labels_f = [], [], [], []
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        test_logits, new_logits = net(images)
        probs = F.softmax(new_logits if use_new_logits else test_logits, dim=1)
        mask_forget = torch.isin(labels, forget_labels)
        mask_remain = ~mask_forget
        conf_r.append(probs[mask_remain])
        labels_r.append(labels[mask_remain])
        conf_f.append(probs[mask_forget])
        labels_f.append(labels[mask_forget])
    return (torch.cat(conf_r, dim=0), torch.cat(labels_r, dim=0),
            torch.cat(conf_f, dim=0), torch.cat(labels_f, dim=0))


def svc_mia(net, train_loader, test_loader, forgetting_class, unlearn_method='RL'):
    """Run an SVM-based membership-inference attack against the forgotten class(es).

    For four groups it collects the softmax confidence on the true label:
    train/test samples whose label is / is not in ``forgetting_class``.
    ``SVC_fit_predict`` is fit to separate remaining-class *training*
    confidences (label 1) from forgotten-class *test* confidences (label 0).
    It is then scored on forgotten-class training samples (``acc_conf``,
    fraction predicted 0) and remaining-class test samples (``acc_test``,
    fraction predicted 1).

    Args:
        net: Model returning ``(test_logits, new_logits)``; ``new_logits`` is
            used when ``unlearn_method == 'RW'``. Inputs are moved to the
            device of its parameters.
        train_loader, test_loader: Iterables of ``(images, labels)``.
        forgetting_class: Class index or list of indices treated as forgotten.
        unlearn_method: See ``net``.

    Returns:
        ``(acc_conf, acc_test)`` as produced by ``SVC_fit_predict``.
    """
    device = next(net.parameters()).device
    # Built once instead of twice per batch.
    forget_labels = torch.tensor(forgetting_class, device=device)
    use_new_logits = unlearn_method == 'RW'

    with torch.no_grad():
        train_conf_r, train_r_labels, train_conf_f, train_f_labels = _split_probs_by_forget(
            net, train_loader, forget_labels, device, use_new_logits)
        test_conf_r, test_r_labels, test_conf_f, test_f_labels = _split_probs_by_forget(
            net, test_loader, forget_labels, device, use_new_logits)

    # Confidence assigned to each sample's true label.
    shadow_train = torch.gather(train_conf_r, 1, train_r_labels[:, None])
    shadow_test = torch.gather(train_conf_f, 1, train_f_labels[:, None])
    target_train = torch.gather(test_conf_r, 1, test_r_labels[:, None])
    target_test = torch.gather(test_conf_f, 1, test_f_labels[:, None])

    print("check remain forget")
    print(shadow_train[:10], shadow_test[:10])
    print(target_train[:10], target_test[:10])
    # The argument order is intentional (see the docstring above).
    acc_conf, acc_test = SVC_fit_predict(shadow_train, target_test, shadow_test, target_train)

    acc_mean = acc_conf  # (acc_conf + acc_test) / 2 was the earlier variant
    if acc_mean is None:
        print("MIA Attack Accuracy on Forgotten Class: n/a")
    else:
        print(f"MIA Attack Accuracy on Forgotten Class: {acc_mean:.4f}")
    print(acc_conf, acc_test)
    return acc_conf, acc_test
    
# Restrict this process to physical GPUs 0, 2 and 3 (only effective if set
# before CUDA is initialised).
os.environ["CUDA_VISIBLE_DEVICES"] = "0,2,3"


def _str2bool(value):
    """Parse a command-line boolean such as ``True``, ``false``, ``1`` or ``no``.

    ``argparse`` with ``type=bool`` treats every non-empty string as true, so
    ``--unnormalize False`` used to yield ``True``.
    """
    if isinstance(value, bool):
        return value
    lowered = str(value).strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training')
parser.add_argument('--dataset', default='cifar10', help='dataset')
parser.add_argument('--model', default='ResNet18', help='Deep Learning model to train')
parser.add_argument('--method', default='catclip', help='clipping method (use orig for no clipping)')
parser.add_argument('--mode', default='wBN', help='what to do with BN layers (leave empty for keeping it as it is)')
parser.add_argument('--lr', default=0.01, type=float, help='learning rate')
parser.add_argument('--LRsteps', default=40, type=int, help='LR scheduler step')
parser.add_argument('--epochs', default=10, type=int, help='number of epochs')
parser.add_argument('--seed', default=1, type=int, help='seed value')
parser.add_argument('--steps', default=50, type=int, help='step count for clipping BN')
parser.add_argument('--num_classes', default=10, type=int, help='number of classes in the dataset')
parser.add_argument('--batch_size', default=128, type=int, help='mini-batch size')

parser.add_argument('--unlearn_method', default='RL', type=str)
parser.add_argument('--unlearn_indices', default=None, type=str)
parser.add_argument('--unlearn_evaluate', default='svc_mia', type=str)

parser.add_argument('--unlearn_count', default=1000, type=int)
parser.add_argument('--start_idx', default=0, type=int)

parser.add_argument('--source_model_path', default=None, type=str)
parser.add_argument('--mask_path', default=None, type=str)
parser.add_argument('--save_checkpoints', default=0, type=int)

parser.add_argument('--use_all_ref', default=True, type=_str2bool)
parser.add_argument('--use_remain', default=True, type=_str2bool)
parser.add_argument('--remain', default='use', type=str)
parser.add_argument('--use_remain_sample', default=False, type=_str2bool)

parser.add_argument('--unnormalize', default=True, type=_str2bool)
parser.add_argument('--norm_cond', default='unnorm', help='unnorm or norm for transform')

parser.add_argument('--req_mode', default='single', type=str)
parser.add_argument('--salun_ratio', default='0.5', type=str, help='ratio of masking in salun')

parser.add_argument('--alpha_l1', default=0., type=float)
parser.add_argument('--noise_ratio', default=3, type=int) # default 120

parser.add_argument('--catsn', default=-1, type=float)
parser.add_argument('--convsn', default=1., type=float)
parser.add_argument('--outer_steps', default=100, type=int)
parser.add_argument('--convsteps', default=100, type=int)
parser.add_argument('--opt_iter', default=5, type=int)
parser.add_argument('--outer_iters', default=1, type=int)

args = parser.parse_args()

            
# Indices of the training samples to forget (the CSV must have an
# 'unlearn_idx' column).
unlearn_indices_check = pd.read_csv(args.unlearn_indices)['unlearn_idx'].values
count_unlearn = len(unlearn_indices_check)

# The forgotten class is encoded in the index-file name, e.g. '..._label_1.csv'.
match = re.search(r'label_(\d+)\.csv', args.unlearn_indices)
if match is None:
    # Previously this fell through to a NameError on `number` below.
    raise ValueError(
        f"cannot infer the forgetting class from --unlearn_indices={args.unlearn_indices!r}; "
        "expected a file name containing 'label_<class>.csv'")
number = int(match.group(1))
print("forgetting class", number)

forgetting_class = [number]

print('count_unlearn: ', count_unlearn)
print('requested mode: ', args.req_mode)

if args.norm_cond == 'norm':
    args.unnormalize = False
print('!!!!!!!!! unnormalized: ', args.unnormalize)
print('!!!!!!!!! salun ratio: ', args.salun_ratio)

print('model: ', args.model)

# The dataset directory name records whether inputs are unnormalised and
# whether an RW variant is being run.
dataset_name = args.dataset
if args.unnormalize:
    if args.unlearn_method in ('RW', 'RW_multi'):
        dataset_name += '_unnorm_layers'
    else:
        dataset_name += '_unnorm_original'
print('dataset', dataset_name)

if args.remain != 'use':
    args.use_remain = False

# These methods always use the remaining data, regardless of --remain.
if args.remain == 'use' or args.unlearn_method in ('reference', 'retrain', 'RW'):
    args.use_remain = True


print('use remain flag: ', args.use_remain)
if args.unlearn_method == 'salun':
    model_name = args.source_model_path.split('/')[-1]
    # NOTE(review): 'unl_idx_cifar10_label_1' is hard-coded and does not follow
    # `forgetting_class` or `args.dataset` — confirm this is intended.
    args.mask_path = f'./class_unlearn/logs/correct/scratch/{dataset_name}/unlearn/genmask/{count_unlearn}/unl_idx_cifar10_label_1/{model_name}/salun_mask/with_{args.salun_ratio}.pt'
    print('mask path: ', args.mask_path)
    # From here on 'salun' runs as 'RL'; presumably the mask at args.mask_path
    # is applied later in the script.
    args.unlearn_method = 'RL'


save_checkpoints = args.save_checkpoints == 1

print('save_checkpoints: ', save_checkpoints)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('==========', device)

if device == 'cuda':
    print('chosen: ', device)
    cudnn.benchmark = True


# 'path_file_rw.csv' maps an 'info' key to a 'path'; the 'base_path' entry is
# the root used for output directories.
base_path_df = pd.read_csv('path_file_rw.csv')
print(base_path_df)
tuples = zip(base_path_df['info'], base_path_df['path'])
base_path_dict = dict(tuples)
base_path = base_path_dict['base_path']
print('base_path: ', base_path)


def check_test(loader, test_net, epoch, criterion, unlearn_method='RL', writer=None, mode='test', model_path="./checkpoints/", plot_images=False, ignore_class=1):
    """Evaluate ``test_net`` on ``loader`` and report top-1 accuracy.

    Prints ``<mode>_acc`` as a percentage. When ``mode`` is ``'forget'`` or
    ``'retrain'`` and ``ignore_class`` is set, it also prints how the samples
    whose true label is ``ignore_class`` were distributed over the predicted
    classes. All predictions, grouped by true label, are written to
    ``<mode>_cifar100_wrong_preds.json`` in the working directory.

    Args:
        loader: Iterable of ``(inputs, targets)`` batches; inputs are cast to
            float and moved to the module-level ``device``.
        test_net: Model to evaluate (switched to ``eval()`` mode). For
            ``unlearn_method`` in RW/FT/RL/GA/salun/two_stage it must return a
            pair of logits; the second one is used only for RW with
            ``mode == 'forget'``.
        epoch, criterion, writer, model_path, plot_images: Accepted for
            interface compatibility; unused.
        unlearn_method: Selects how the network output is unpacked.
        mode: Tag used in printed messages and in the JSON file name.
        ignore_class: Class whose prediction distribution is reported, or None.

    Returns:
        ``(probs, labels, preds)``: the softmax over classes of all logits
        (on the CPU), and Python lists of the true and predicted labels.
    """
    two_output_methods = ('RW', 'FT', 'RL', 'GA', 'salun')
    test_net.eval()
    correct = 0
    total = 0
    all_logits = []
    all_preds, all_labels = [], []
    # Every prediction grouped by true label (correct ones included, despite
    # the name).
    wrong_preds = defaultdict(list)

    with torch.no_grad():
        for inputs, targets in loader:
            inputs, targets = inputs.float().to(device), targets.to(device)
            if unlearn_method in two_output_methods:
                outputs, new_outputs = test_net(inputs)
                # RW scores the forget split with the second output.
                if unlearn_method == 'RW' and mode == 'forget':
                    outputs = new_outputs
            elif unlearn_method == 'two_stage':
                outputs, _ = test_net(inputs)
            else:
                outputs = test_net(inputs)
            all_logits.append(outputs.cpu())
            _, preds = outputs.max(1)
            batch_preds = preds.cpu().tolist()
            batch_labels = targets.cpu().tolist()
            all_preds.extend(batch_preds)
            all_labels.extend(batch_labels)
            total += targets.size(0)
            correct += preds.eq(targets).sum().item()

            for label, pred in zip(batch_labels, batch_preds):
                wrong_preds[label].append(pred)

    print(f"{mode}_acc", 100.*correct/total)
    if ignore_class is not None and (mode == 'forget' or mode == 'retrain'):
        # All predictions whose true label is the forgotten class.
        preds_for_forget = wrong_preds[ignore_class]
        total_forget = len(preds_for_forget)
        if total_forget > 0:
            counts = Counter(preds_for_forget)
            print(f"\nPrediction distribution for forgotten class {ignore_class}:")
            for cls, cnt in counts.items():
                ratio = cnt / total_forget
                print(f"  Predicted as class {cls}: {cnt}/{total_forget} ({ratio:.2%})")
        else:
            print(f"No samples of class {ignore_class} were seen.")
    wrong_preds_serializable = {str(k): v for k, v in wrong_preds.items()}
    with open(f"{mode}_cifar100_wrong_preds.json", "w", encoding="utf-8") as f:
        json.dump(wrong_preds_serializable, f, indent=2)
    # Explicit dim=1: calling softmax without `dim` is deprecated and warns.
    return F.softmax(torch.cat(all_logits, dim=0), dim=1), all_labels, all_preds

def compare_weight_differences(model_a, model_b):
    """Print the L2 norm of the difference between paired weight tensors.

    Parameters of the two models are paired in ``named_parameters()`` order.
    Only trainable parameters whose name contains ``'weight'`` are reported,
    one line per parameter after a header line.

    Raises:
        AssertionError: if a reported pair has mismatched names.
    """
    print("{:<40} {:>10}".format("Layer", "ΔL2 Norm"))
    paired = zip(model_a.named_parameters(), model_b.named_parameters())
    for (name_a, param_a), (name_b, param_b) in paired:
        if 'weight' not in name_a or not param_a.requires_grad:
            continue
        assert name_a == name_b, f"Parameter mismatch: {name_a} vs {name_b}"
        delta = torch.norm(param_a - param_b).item()
        print("{:<40} {:10.4f}".format(name_a, delta))

def plot_layer_weight_differences(model_a, model_b, title="Layer-wise Weight Difference", save_path='./figs/layer_wise.pdf'):
    """Bar-plot the per-layer average L2 distance between two models' weights.

    Trainable parameters whose name contains ``'weight'`` are paired in
    ``named_parameters()`` order; names starting with one of
    ``excluded_prefixes`` are skipped. Each ``||w_a - w_b||_2`` is grouped
    under the first three dotted components of its name, and each group's
    mean is plotted.

    The last group is dropped, and the new last bar is relabelled
    ``"module.projection.weight"``. NOTE(review): this looks tailored to one
    specific architecture; confirm before reusing it with other models.

    Args:
        model_a, model_b: Models with identically named, identically ordered
            parameters.
        title: Figure title.
        save_path: File to save the figure to (its directory is created if
            needed). If falsy, the figure is shown instead.

    Raises:
        AssertionError: if paired weight names differ.
    """
    excluded_prefixes = (
        "module.linear.weight",
        "module.conv1.weight",
        "module.bn1.weight",
    )
    layer_diffs = defaultdict(list)

    for (name_a, param_a), (name_b, param_b) in zip(model_a.named_parameters(), model_b.named_parameters()):
        if 'weight' not in name_a or not param_a.requires_grad:
            continue
        assert name_a == name_b
        if name_a.startswith(excluded_prefixes):
            continue
        # First three dotted components, e.g. 'layer1.0.conv1.weight' -> 'layer1.0.conv1'.
        layer_name = '.'.join(name_a.split('.')[:3])
        with torch.no_grad():
            diff = (param_a - param_b).norm().item()
        layer_diffs[layer_name].append(diff)

    # Average differences per layer group.
    avg_diffs = {k: np.mean(v) for k, v in layer_diffs.items()}
    layers = list(avg_diffs.keys())[:-1]
    values = list(avg_diffs.values())[:-1]
    # Guarded: the unconditional relabel raised IndexError with <= 1 group.
    if layers:
        layers[-1] = "module.projection.weight"

    fig = plt.figure(figsize=(8, 6))
    plt.bar(range(len(layers)), values, tick_label=layers)
    plt.xticks(rotation=45, ha='right')
    plt.ylabel("Average ΔL2 Norm")
    plt.title(title)
    plt.tight_layout()

    if save_path:
        save_dir = os.path.dirname(save_path)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"Figure saved to: {save_path}")
        # Release the figure; repeated calls would otherwise accumulate open figures.
        plt.close(fig)
    else:
        plt.show()

if __name__ == "__main__":
    method = args.method
    steps_count = args.steps  #### BN clip steps for hard clip
    concat_sv = False
    step_size = args.LRsteps
    clip_outer_flag = False
    outer_steps = args.outer_steps
    outer_iters = args.outer_iters
    if args.catsn > 0.:
        concat_sv = True
        clip_steps = args.convsteps
        clip_outer_flag = True

    mode = args.mode
    bn_flag = True
    bn_clip = False
    bn_hard = False
    opt_iter = args.opt_iter
    if mode == 'wBN':
        mode = ''
        bn_flag = True
        bn_clip = False
        clip_steps = 50
    elif mode == 'noBN':
        bn_flag = False
        bn_clip = False
        opt_iter = 1
        clip_steps = 100
    elif mode == 'clipBN_hard':
        bn_flag = True
        bn_clip = True
        bn_hard = True
        clip_steps = 100
    else:
        print('unknown mode!')
        exit(0)

    unlearn_idx = pd.read_csv(args.unlearn_indices)['unlearn_idx'].values
    unlearn_idx = [int(i) for i in unlearn_idx]
    
    unlearn_test_path = args.unlearn_indices.replace('.csv', '_test.csv')
    unlearn_idx_test = pd.read_csv(unlearn_test_path)['unlearn_idx'].values
    unlearn_idx_test = [int(i) for i in unlearn_idx_test]

    seed_in = args.seed ##### !!!!! Do not use with more than one seed! some of the args gets changed during the first run @ToDo fix this!
    if seed_in == -1:
        geed_in = [1,2,3]
    else:
        seed_in = [seed_in]
    for seed in range(1):
        print('seed.....', seed)
        best_acc = 0  # best test accuracy
        start_epoch = 0  # start from epoch 0 or last checkpoint epoch
        count_setp = 0

        seed_val = 1 # rw 100
        torch.manual_seed(seed_val)
        torch.cuda.manual_seed_all(seed_val)
        np.random.seed(seed_val)
        random.seed(seed_val)

        clip_flag    = False
        orig_flag    = False

        print('method: ', method)
        if method[:4] == 'fast' or method == 'clip':
            clip_flag    = True
        elif method == 'catclip':
            clip_flag    = True
        elif method == 'orig':
            orig_flag    = True
        else:
            print('unknown method!')
            exit(0)

        # Data
        print('==> Preparing data..')
        if args.dataset == 'mnist':
            print('using mnist')
            in_chan = 1
            if args.unnormalize:
                if args.model == 'ResNet18':
                    transform_train = transforms.Compose([
                        transforms.ToTensor(),
                    ])
                    transform_test = transforms.Compose([
                        transforms.ToTensor(),
                    ])
                elif args.model == 'VGG':
                    transform_train = transforms.Compose([
                        transforms.Resize(32),
                        transforms.ToTensor(),
                    ])
                    transform_test = transforms.Compose([
                        transforms.Resize(32),
                        transforms.ToTensor(),
                    ])
                    
            else:
                if args.model == 'ResNet18':
                    transform_train = transforms.Compose([
                        transforms.Resize((28, 28)),  # Ensure images are 28x28
                        transforms.ToTensor(),        # Convert images to PyTorch tensors
                        transforms.Normalize((0.5,), (0.5,))  # Normalize to [-1, 1]
                    ])
                    transform_test = transform_train
                elif args.model == 'VGG':
                    transform_train = transforms.Compose([
                        transforms.Resize(32),
                        transforms.ToTensor(),
                        transforms.Normalize((0.1307,), (0.3081,))
                    ])
                    transform_test = transform_train

            
            trainset = torchvision.datasets.MNIST(root='./data/mnist', train=True, download=True, transform=transform_train)
            if args.unlearn_method == 'reference':
                testset = torchvision.datasets.MNIST( root='./data/mnist', train=False, download=True, transform=transform_train)
            else:
                testset = torchvision.datasets.MNIST( root='./data/mnist', train=False, download=True, transform=transform_test)
             
        elif args.dataset == 'cifar10':
            print('using cifar 10')
            in_chan = 3

            if args.unnormalize:
                transform_train = transforms.Compose([
                    transforms.RandomCrop(32, padding=4),
                    transforms.RandomHorizontalFlip(),
                    transforms.ToTensor(),
                ])

                transform_test = transforms.Compose([
                    transforms.ToTensor(),
                ])
            else:
                transform_train = transforms.Compose([
                    transforms.RandomCrop(32, padding=4),
                    transforms.RandomHorizontalFlip(),
                    transforms.ToTensor(),
                    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
                ])

                transform_test = transforms.Compose([
                    transforms.ToTensor(),
                    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
                ])


            trainset = torchvision.datasets.CIFAR10( root='./data/cifar10', train=True, download=True, transform=transform_train) ### transofrm=transform_train

            if args.unlearn_method == 'reference':
                testset = torchvision.datasets.CIFAR10( root='./data/cifar10', train=False, download=True, transform=transform_train)
            else:
                testset = torchvision.datasets.CIFAR10( root='./data/cifar10', train=False, download=True, transform=transform_test)

        elif args.dataset == 'cifar100':
            print('using cifar 100')
            in_chan = 3
            args.num_classes = 100

            if args.unnormalize:
                transform_train = transforms.Compose([
                    transforms.RandomCrop(32, padding=4),
                    transforms.RandomHorizontalFlip(),
                    transforms.ToTensor(),
                ])

                transform_test = transforms.Compose([
                    transforms.ToTensor(),
                ])
            else:
                transform_train = transforms.Compose([
                    transforms.RandomCrop(32, padding=4),
                    transforms.RandomHorizontalFlip(),
                    transforms.ToTensor(),
                    transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)),  # CIFAR-100 mean and std
                ])

                transform_test = transforms.Compose([
                    transforms.ToTensor(),
                    transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)),  # CIFAR-100 mean and std
                ])


            trainset = torchvision.datasets.CIFAR100(root='./data/cifar100', train=True, download=True, transform=transform_train)
            if args.unlearn_method == 'reference':
                testset = torchvision.datasets.CIFAR100( root='./data/cifar100', train=False, download=True, transform=transform_train)
            else:
                testset = torchvision.datasets.CIFAR100( root='./data/cifar100', train=False, download=True, transform=transform_test)

        elif args.dataset == 'imagenet':
            print('using Imagenet')
            in_chan = 3
            args.num_classes = 200

            if args.unnormalize:
                transform_train = transforms.Compose([
                    transforms.RandomCrop(64, padding=4),
                    transforms.RandomHorizontalFlip(),
                    transforms.ToTensor(),
                ])

                transform_test = transforms.Compose([
                    transforms.ToTensor(),
                ])
            else:
                transform_train = transforms.Compose([
                    transforms.RandomCrop(64, padding=4),
                    transforms.RandomHorizontalFlip(),
                    transforms.ToTensor(),
                    transforms.Normalize((0.4802, 0.4481, 0.3975), (0.2770, 0.2691, 0.2821))
                ])

                transform_test = transforms.Compose([
                    transforms.ToTensor(),
                    transforms.Normalize((0.4802, 0.4481, 0.3975), (0.2770, 0.2691, 0.2821))
                ])

            train_dir = './data/tiny-imagenet-200/train'
            val_dir = './data/tiny-imagenet-200/val'
            trainset = torchvision.datasets.ImageFolder(root=train_dir, transform=transform_train)
            if args.unlearn_method == 'reference':
                testset = torchvision.datasets.ImageFolder(root=val_dir, transform=transform_train)
            else:
                testset = torchvision.datasets.ImageFolder(root=val_dir, transform=transform_test)
        
        else:
            print('unknown dataset!')
            exit(0)

        indices_seed = args.unlearn_indices.split('/')[-1][:-4]
        indices_count = len(unlearn_idx) # args.unlearn_indices.split('/')[-2]

        args.outdir = f"/{dataset_name}/unlearn/{args.unlearn_method}/{indices_count}/unl_idx_{indices_seed}/"
        args.outdir = "scratch" + args.outdir
        args.outdir = base_path + args.outdir
        retrain_outdir_base = base_path + "scratch" + f"/{dataset_name}/unlearn/retrain/{indices_count}/unl_idx_{indices_seed}/" 


        print("retrain outdir: ", retrain_outdir_base)
        print("source outdir: ", args.source_model_path)
        print('learning rate: ', args.lr)
        print('dataset: ', args.dataset)

        ## shadowed retrain
        retrain_outdir = retrain_outdir_base + '/' + args.model + "_" + method + "_" + mode + "_" + str(seed_val) + "/"

        outdir = args.outdir + args.source_model_path.split('/')[-1] 
        if args.alpha_l1 > 0 and args.unlearn_method != 'l1':
            outdir = outdir + '/l1_' + str(args.alpha_l1) + '/'

        outdir = outdir + '/use_remain_' + str(args.use_remain) + '/' + args.model + "_" + method + "_" + mode + "_" + str(seed_val) + "/"
        outdir += '/LRs_' + str(step_size) + '_lr_' + str(args.lr) + '/'


        writer = SummaryWriter(outdir)

        print('==> Building model..')
        print('------------> outdir: ', outdir)
        print('-----------------------------------------------------------------')
        print('initial len of trainset: ', len(trainset))  


        request_count = 1
        if args.req_mode == 'adaptive':
            print('not implemented yet!')
            exit(0)

        prior_idx = []
        for req_idx in range(request_count):

            unlearn_idx = pd.read_csv(args.unlearn_indices)['unlearn_idx'].values
            if len(unlearn_idx) != int(indices_count):
                print('unlearn_idx count is not correct!')
                exit(0)

            unlearn_idx = [int(i) for i in unlearn_idx]

            removed_classes = [trainset[i][1] for i in unlearn_idx]
            
            ### remove the unlearned images from the trainset
            trainset_filtered = torch.utils.data.Subset(trainset, list(set(range(len(trainset))) - set(unlearn_idx) - set(prior_idx)))
            print('len of filtered trainset: ', len(trainset_filtered))  

            forgetset = torch.utils.data.Subset(trainset, unlearn_idx)
            print('len of forget set: ', len(forgetset))  


            trainloader = torch.utils.data.DataLoader(trainset, shuffle=True, batch_size=args.batch_size, num_workers=1)
            remainloader = torch.utils.data.DataLoader(trainset_filtered, shuffle=False, batch_size=args.batch_size, num_workers=1)
            forgetloader = torch.utils.data.DataLoader(forgetset, shuffle=False, batch_size=args.batch_size, num_workers=1)
            
            ### remove the unlearned images from the testset
            testset_filtered = torch.utils.data.Subset(testset, list(set(range(len(testset))) - set(unlearn_idx_test) - set(prior_idx)))
            print('len of filtered testset: ', len(testset_filtered))  

            forgetset_test = torch.utils.data.Subset(testset, unlearn_idx_test)
            print('len of forget testset: ', len(forgetset_test))  

            testloader = torch.utils.data.DataLoader(testset, shuffle=False, batch_size=args.batch_size, num_workers=1)
            remainloader_test = torch.utils.data.DataLoader(testset_filtered, shuffle=False, batch_size=args.batch_size, num_workers=1)
            forgetloader_test = torch.utils.data.DataLoader(forgetset_test, shuffle=False, batch_size=args.batch_size, num_workers=1)
            

            if req_idx == 0:
                if args.model == 'ResNet18':
                    if orig_flag:
                        if args.dataset == 'imagenet':
                            net = ResNet18_orig(in_chan=in_chan, bn=bn_flag, device=device, elu_flag=False, num_classes=args.num_classes, tinynet=True, unlearn_method=args.unlearn_method)
                        else:
                            net = ResNet18_orig(in_chan=in_chan, bn=bn_flag, device=device, elu_flag=False, num_classes=args.num_classes, unlearn_method="RW")
                            original_net = ResNet18_orig(in_chan=in_chan, bn=bn_flag, device=device, elu_flag=False, num_classes=args.num_classes, unlearn_method="RW_FT")
                            retrain_net = ResNet18_orig(in_chan=in_chan, bn=bn_flag, device=device, elu_flag=False, num_classes=args.num_classes, unlearn_method="retrain")
                        # net = ResNet18_orig(in_chan=in_chan, bn=bn_flag, bn_clip=bn_clip, bn_hard=bn_hard, clip_linear=False, bn_count=steps_count, device=device)
                    elif clip_flag:
                        net = ResNet18(concat_sv=concat_sv, in_chan=in_chan, device=device, clip=args.convsn, clip_concat=args.catsn, clip_flag=True, bn=bn_flag, bn_clip=bn_clip, bn_hard=bn_hard, clip_steps=clip_steps, bn_count=steps_count, clip_outer=clip_outer_flag, clip_opt_iter=opt_iter, summary=True, writer=writer, save_info=False, outer_iters=outer_iters, outer_steps=outer_steps, num_classes=args.num_classes)
                        original_net = ResNet18(concat_sv=concat_sv, in_chan=in_chan, device=device, clip=args.convsn, clip_concat=args.catsn, clip_flag=True, bn=bn_flag, bn_clip=bn_clip, bn_hard=bn_hard, clip_steps=clip_steps, bn_count=steps_count, clip_outer=clip_outer_flag, clip_opt_iter=opt_iter, summary=True, writer=writer, save_info=False, outer_iters=outer_iters, outer_steps=outer_steps, num_classes=args.num_classes)
                        retrain_net = ResNet18(concat_sv=concat_sv, in_chan=in_chan, device=device, clip=args.convsn, clip_concat=args.catsn, clip_flag=True, bn=bn_flag, bn_clip=bn_clip, bn_hard=bn_hard, clip_steps=clip_steps, bn_count=steps_count, clip_outer=clip_outer_flag, clip_opt_iter=opt_iter, summary=True, writer=writer, save_info=False, outer_iters=outer_iters, outer_steps=outer_steps, num_classes=args.num_classes)
                
                elif args.model == 'VGG':
                    if args.dataset == 'imagenet':
                        net = VGG('VGG19', in_chan=in_chan, num_classes=args.num_classes, tinynet=True)
                        original_net = VGG('VGG19', in_chan=in_chan, num_classes=args.num_classes, tinynet=True)
                        retrain_net = VGG('VGG19', in_chan=in_chan, num_classes=args.num_classes, tinynet=True)
                    else:
                        net = VGG('VGG19', in_chan=in_chan, num_classes=args.num_classes)
                        original_net = VGG('VGG19', in_chan=in_chan, num_classes=args.num_classes)
                        retrain_net = VGG('VGG19', in_chan=in_chan, num_classes=args.num_classes)

                
                
                ## svd
                if args.unlearn_method == 'svd':
                    net = vgg11_bn(num_classes=10, dataset="cifar10")

                net = net.to(device)
                original_net = original_net.to(device)
                retrain_net = retrain_net.to(device)

                if args.unlearn_method != 'svd':
                    net = nn.DataParallel(net) 
                original_net = nn.DataParallel(original_net)
                retrain_net = nn.DataParallel(retrain_net)
                criterion = nn.CrossEntropyLoss()

                original_checkpoint = torch.load(args.source_model_path + '/checkpoint.pth.tar_best')
                
                ## RW
                if args.unlearn_method == 'RW':
                    outdir = './class_unlearn/logs/RW/scratch/cifar10_unnorm/unlearn/RW/5000/unl_idx_cifar_label_1/ResNet18_vanilla_orig_wBN_1/use_remain_True/ResNet18_orig__1/LRs_40_lr_0.001/'
                ## RW_FT
                if args.unlearn_method == 'RW_FT':
                    outdir = './class_unlearn/logs/RW/scratch/cifar10_unnorm/unlearn/RW_FT/5000/unl_idx_cifar_label_1/ResNet18_vanilla_orig_wBN_1/use_remain_True/ResNet18_orig__1/LRs_40_lr_0.0001/'
                ## RW_FT_par
                if args.unlearn_method == 'RW_FT_par':
                    outdir = './class_unlearn/logs/RW/scratch/cifar10_unnorm/unlearn/RW_FT_par/5000/unl_idx_cifar_label_1/ResNet18_vanilla_orig_wBN_1/use_remain_True/ResNet18_orig__1/LRs_40_lr_0.0001/'
                ## FT
                if args.unlearn_method == 'FT':
                    outdir = './class_unlearn/logs/correct/scratch/cifar10_unnorm/unlearn/FT/5000/unl_idx_cifar_label_1/ResNet18_vanilla_orig_wBN_1/use_remain_True/ResNet18_orig__1//LRs_40_lr_0.1/'
                ## RL
                if args.unlearn_method == 'RL':
                    outdir = './class_unlearn/logs/correct/scratch/cifar10_unnorm/unlearn/RL/5000/unl_idx_cifar_label_1/ResNet18_vanilla_orig_wBN_1/use_remain_True/ResNet18_orig__1//LRs_40_lr_0.1/'
                ## GA
                if args.unlearn_method == 'GA':
                    outdir = './class_unlearn/logs/correct/scratch/cifar10_unnorm/unlearn/GA/5000/unl_idx_cifar_label_1/ResNet18_vanilla_orig_wBN_1/use_remain_True/ResNet18_orig__1//LRs_40_lr_0.001/'
                ## salun
                if args.unlearn_method == 'salun':
                    outdir = './class_unlearn/logs/correct/scratch/cifar10_unnorm/unlearn/salun/5000/unl_idx_cifar_label_1/ResNet18_vanilla_orig_wBN_1/mask_0.5//use_remain_True/ResNet18_orig__1//LRs_40_lr_0.01/'
                
                ## salun
                if args.unlearn_method == 'l1':
                    outdir = './class_unlearn/logs/correct/scratch/cifar10_unnorm/unlearn/l1/5000/unl_idx_cifar10_label_1/ResNet18_vanilla_orig_wBN_1/use_remain_True/ResNet18_orig__1//LRs_40_lr_0.01/'
               ## salun
                if args.unlearn_method == 'BS':
                    outdir = './class_unlearn/logs/correct/scratch/cifar10_unnorm/unlearn/BS/5000/unl_idx_cifar10_label_1/ResNet18_vanilla_orig_wBN_1/use_remain_True/ResNet18_orig__1//LRs_40_lr_0.01/'
               ## salun
                if args.unlearn_method == 'SCRUB':
                    outdir = './class_unlearn/scrub/scrub_model.pt'
               ## salun
                if args.unlearn_method == 'SCAR':
                    outdir = './class_unlearn/scar/out/CR/cifar10/subset_Imagenet_-1k/models/unlearned_model_SCAR_seed_42_class_1.pth'
                if args.unlearn_method == 'l2ul':
                    outdir = './class_unlearn/l2ul/l2ul_unlearn_k_256_lr_0.001_reg_lamb_1.0_seed_0.pth'

                if args.unlearn_method != 'svd' and args.unlearn_method != 'l2ul' and args.unlearn_method != 'SCRUB' and args.unlearn_method != 'SCAR':
                    checkpoint = torch.load(outdir + '/_ckpt.pth')
            
                # checkpoint = torch.load('save_for_confusion' + '.pth')
                retrain_outdir = f'./class_unlearn/logs/RW/scratch/{args.dataset}_unnorm/unlearn/retrain/5000/unl_idx_cifar_label_1/ResNet18_orig__1'
                retrain_checkpoint = torch.load(retrain_outdir + '/_ckpt.200')
                
                if args.unlearn_method != 'svd' and args.unlearn_method != 'l2ul' and args.unlearn_method != 'SCRUB' and args.unlearn_method != 'SCAR':
                    net.load_state_dict(checkpoint['net'])
                elif args.unlearn_method == 'svd':
                    checkpoint = torch.load('./baseline/cifar10_vgg11_bn_our_1.pt')
                    # new_state = OrderedDict()
                    # for k, v in state_dict.items():
                    #     name = k.replace('module.', '')  
                    #     new_state[name] = v
                    net.load_state_dict(checkpoint)
                else:
                    checkpoint = torch.load(outdir)
                    # net.load_state_dict(checkpoint, strict=True)
                    # net.load_state_dict(checkpoint['state_dict'])
                    net.load_state_dict(checkpoint['net'], strict=True)
                
                original_net.load_state_dict(original_checkpoint['state_dict'])
                retrain_net.load_state_dict(retrain_checkpoint['net'])
                print('model loaded')

            net.eval()
            original_net.eval()
            retrain_net.eval()
            
            
            print("Loading logits")
            logits_u, labels_u, preds_u = check_test(testloader, net, 200, criterion, unlearn_method="two_stage", writer=writer, mode='forget', model_path=None, ignore_class=forgetting_class[0])
            logits_o, labels_o, preds_o = check_test(testloader, original_net, 200, criterion, unlearn_method=args.unlearn_method, writer=writer, mode='original', model_path=None, ignore_class=forgetting_class[0])
            logits_r, labels_r, preds_r = check_test(testloader, retrain_net, 200, criterion, unlearn_method="retrain", writer=writer, mode='retrain', model_path=None)
            
            print("=========================TV Distance==========================")
            tvd = compute_tv_distance(logits_u, logits_r)
            print("TVD", tvd)

            print("=========================ULiRa==========================")
            model_logits, model_labels, model_preds = check_test(forgetloader, net, 200, criterion, unlearn_method="two_stage", writer=writer, mode='forget', model_path=None, ignore_class=forgetting_class[0])
            retrain_logits, retrain_labels, retrain_preds = check_test(forgetloader, retrain_net, 200, criterion, unlearn_method="retrain", writer=writer, mode='retrain', model_path=None, ignore_class=forgetting_class[0])
            
            model_logits_test, model_labels_test, model_preds_test = check_test(forgetloader_test, net, 200, criterion, unlearn_method="two_stage", writer=writer, mode='forget', model_path=None, ignore_class=forgetting_class[0])
            retrain_logits_test, retrain_labels_test, retrain_preds_test = check_test(forgetloader_test, retrain_net, 200, criterion, unlearn_method="retrain", writer=writer, mode='retrain', model_path=None, ignore_class=forgetting_class[0])
            
            
            # print(model_logits.clone().detach().cpu()[:, forgetting_class[0]].size())
            unlearn_scores = torch.log(model_logits.clone().detach().cpu()[:, forgetting_class[0]] / (1 - model_logits.clone().detach().cpu()[:, forgetting_class[0]]  + 1e-32) + 1e-32)
            retrain_scores = torch.log(retrain_logits.clone().detach().cpu()[:, forgetting_class[0]]  / (1 - retrain_logits.clone().detach().cpu()[:, forgetting_class[0]]  + 1e-32) + 1e-32)

            unlearn_scores_test = torch.log(model_logits_test.clone().detach().cpu()[:, forgetting_class[0]]  / (1 - model_logits_test.clone().detach().cpu()[:, forgetting_class[0]]  + 1e-32) + 1e-32)
            retrain_scores_test = torch.log(retrain_logits_test.clone().detach().cpu()[:, forgetting_class[0]]  / (1 - retrain_logits_test.clone().detach().cpu()[:, forgetting_class[0]]  + 1e-32) + 1e-32) 

            mean_unlearn = unlearn_scores.mean()
            mean_retrain = retrain_scores.mean()
            std_unlearn = unlearn_scores.std()
            std_retrain = retrain_scores.std()
            mia_parameters = {
                "mean_unlearn":mean_unlearn, 
                "mean_retrain":mean_retrain, 
                "std_unlearn":std_unlearn, 
                "std_retrain":std_retrain
                }
            labels = [1]*len(model_logits) + [0]*len(retrain_logits)
            features = np.concatenate([model_logits, retrain_logits], axis=0)
            concat_val_features = torch.cat( (unlearn_scores.view(-1,1), retrain_scores.view(-1,1)), dim=1)
            y_true_val = torch.cat( (torch.ones_like(unlearn_scores).view(-1,1), torch.zeros_like(retrain_scores).view(-1,1)), dim=1)
            y_pred_val = get_likelihood_ratio(concat_val_features, mia_parameters)
            # print(y_true_val.size())
            # print(mia_parameters)
            # print(concat_val_features.size())
            clf = LogisticRegression().fit(features, labels)
            import pdb; pdb.set_trace()
            fpr, tpr, thr = roc_curve(y_true_val.flatten().numpy(), y_pred_val.flatten().numpy(), pos_label=1)
            auc = roc_auc_score(y_true_val.flatten().numpy(), y_pred_val.flatten().numpy())
            optimal_idx = np.argmin(np.abs(fpr + tpr - 1))
            optimal_threshold = thr[optimal_idx]

            y_pred_binarized_val = (y_pred_val.flatten().numpy() >= optimal_threshold).astype(int)
            # Calculate Balanced Accuracy
            balanced_val_accuracy = balanced_accuracy_score(y_true_val.flatten().numpy(), y_pred_binarized_val)
            evaluation_result= {}
            evaluation_result["auc"] = auc
            evaluation_result["threshold"] = optimal_threshold
            evaluation_result["balanced_val_accuracy"] = balanced_val_accuracy

            
            true_labels = np.array([1] * len(model_logits_test) + [0] * len(retrain_logits_test))
            y_true_test = torch.ones_like(unlearn_scores_test).view(-1,1)  # model should predict this as out 
            y_pred_test = get_likelihood_ratio(unlearn_scores_test, mia_parameters)
            y_pred_binarized_test = (y_pred_test.flatten().numpy() >= optimal_threshold).astype(int)
            balanced_test_accuracy = balanced_accuracy_score(y_true_test.flatten().numpy(), y_pred_binarized_test)
            evaluation_result["balanced_test_accuracy"] = balanced_test_accuracy
            print(f"AUC {args.unlearn_method}: {auc*100:.2f}, Acc {balanced_test_accuracy * 100:.2f}")
            
            preds = clf.predict(model_logits_test + retrain_logits_test)
            preds = np.array(preds)
            # print("length", len(model_logits), len(retrain_logits))
            # print("model_logits: ", model_logits[0])
            # print("retrain_logits: ", retrain_logits[0])
            # print("model_logits_test: ", model_logits_test[0])
            # print("retrain_logits_test: ", retrain_logits_test[0])
            crr = 0
            for i in range(len(preds)):
                if preds[i] == true_labels[i]:
                    crr += 1
            acc = 100. * (1.0 * crr / len(preds))
            print("our own ulira ACC", acc)
            ## confusion matrix
            # if args.dataset == 'cifar10':
            #     _, labels_o, preds_o = check_test(testloader, original_net, 200, criterion, unlearn_method=args.unlearn_method, writer=writer, mode='original', model_path=None)
            #     _, labels_u, preds_u = check_test(testloader, net, 200, criterion, unlearn_method=args.unlearn_method, writer=writer, mode='forget', model_path=None)
            #     _, labels_r, preds_r = check_test(testloader, retrain_net, 200, criterion, unlearn_method="retrain", writer=writer, mode='retrain', model_path=None)
            
            #     # Assuming class_names = ['airplane', 'automobile', ..., 'truck']
            #     class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer', 
            #                 'dog', 'frog', 'horse', 'ship', 'truck']

            #     fig, axs = plt.subplots(1, 3, figsize=(30, 10))

            #     # Original
            #     plot_confusion_matrix(labels_o, preds_o, class_names, "(a) Original", ax=axs[0], fontsize=25)

            #     # Unlearned
            #     plot_confusion_matrix(labels_u, preds_u, class_names, "(b) Unlearned", ax=axs[1], fontsize=25)

            #     # Retrained
            #     plot_confusion_matrix(labels_r, preds_r, class_names, "(c) Retrain", ax=axs[2], fontsize=25)

            #     plt.tight_layout()
            #     plt.savefig(f"./figs/{args.unlearn_method}_confusion_matrix_comparison.pdf", dpi=300)
            #     plt.show()
            #     print("fig generated")

            #     # compare_weight_differences(net, original_net)
            #     plot_layer_weight_differences(net, original_net, save_path=f'./figs/{args.model}_layer_wise.pdf')


            # # conditional probabilities
