import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
from numpy.lib.format import open_memmap
import matplotlib.pyplot as plt
import seaborn as sns
import torchvision
import torchvision.transforms as transforms
from tensorboardX import SummaryWriter
from generate_mask import save_gradient_ratio
import os
import argparse
from models import *
from models.resnet_orig import ResNet18_orig
from models.vgg import VGG
import pandas as pd
import random
import time
import copy
import numpy as np
from torch.utils.data import Dataset
from torchvision.io import read_image
from sgld_optim import *
import re

from sklearn.svm import SVC
from sklearn.metrics import confusion_matrix


def l1_regularization(model):
    """Return the L1 norm of every parameter of ``model``, taken as one flat vector.

    The result is a differentiable scalar tensor, so it can be added to a loss.
    """
    flat_params = torch.cat([p.view(-1) for p in model.parameters()])
    return torch.linalg.norm(flat_params, ord=1)


def discretize(x):
    """Snap every value of ``x`` to the nearest multiple of 1/255."""
    levels = 255
    return (x * levels).round() / levels


def FGSM_perturb(x, y, model=None, bound=None, criterion=None):
    """Return a one-step FGSM adversarial copy of ``x``.

    The copy is ``x + bound * sign(d criterion(model(x)[0], y) / dx)``. It is then
    clamped to [0, 1] and rounded to the 1/255 grid by ``discretize``.

    Args:
        x: input batch; it is not modified.
        y: targets passed to ``criterion``.
        model: network whose forward returns a tuple whose first element is fed to
            ``criterion``. Its parameter gradients are zeroed and then populated
            by this call.
        bound: step size multiplied with the gradient sign.
        criterion: loss function ``criterion(pred, y)``.

    Returns:
        A detached tensor on ``model``'s device.
    """
    device = next(model.parameters()).device
    model.zero_grad()
    # Move first, then enable grad. ``.to()`` applied to a tensor that already
    # requires grad returns a non-leaf copy whose ``.grad`` is never filled.
    x_adv = x.detach().clone().to(device).requires_grad_(True)

    pred, _ = model(x_adv)
    loss = criterion(pred, y)
    loss.backward()

    grad_sign = x_adv.grad.detach().sign()
    x_adv = x_adv + grad_sign * bound
    x_adv = discretize(torch.clamp(x_adv, 0.0, 1.0))

    return x_adv.detach()


def plot_confusion_matrix(true_labels, pred_labels, class_names, title, ax=None):
    """Draw a raw-count (unnormalized) confusion-matrix heatmap.

    Args:
        true_labels: ground-truth labels.
        pred_labels: predicted labels.
        class_names: tick labels used for both axes.
        title: axes title.
        ax: matplotlib Axes to draw on; a new 8x6 figure is created when None.

    Returns:
        The Axes the heatmap was drawn on.
    """
    counts = confusion_matrix(true_labels, pred_labels).astype("int")
    if ax is None:
        _, ax = plt.subplots(figsize=(8, 6))

    sns.heatmap(
        counts,
        annot=True,
        fmt="d",
        cmap="Greens",
        xticklabels=class_names,
        yticklabels=class_names,
        ax=ax,
    )
    ax.set_xlabel("Predicted Labels")
    ax.set_ylabel("True Labels")
    ax.set_title(title)
    return ax


def forget_loss_drop_class_dim(new_logits, old_probs, forgetting_class):
    """Cross-entropy of ``new_logits`` against ``old_probs`` with forgotten classes removed.

    The target distribution q* is built in three steps:

    1. Start from ``old_probs``.
    2. Set the ``forgetting_class`` column(s) to zero.
    3. Divide each row by ``1 - sum(old_probs[:, forgetting_class])``, clamped
       below at 1e-9.

    The loss is ``mean_batch(-sum_y q*(y) * log_softmax(new_logits)(y))``. If every
    class is forgotten, q* is all zeros and the loss is 0.

    Args:
        new_logits: 2-D ``(batch, classes)`` logits of the model being trained.
        old_probs: 2-D ``(batch, classes)`` reference probabilities; not modified.
        forgetting_class: a single class index or a list of class indices.

    Returns:
        A scalar tensor.
    """
    eps = 1e-9
    # Accept a bare index as well as a list; a scalar index would otherwise
    # produce a 1-D column and break the row-wise ``sum(dim=1)`` below.
    if isinstance(forgetting_class, (int, np.integer)):
        forgetting_class = [int(forgetting_class)]

    q_star = old_probs.clone()
    q_star[:, forgetting_class] = 0.0

    # Renormalize the remaining mass to sum to 1 per sample.
    denom = (1.0 - old_probs[:, forgetting_class].sum(dim=1)).clamp_min(eps).unsqueeze(1)
    q_star = q_star / denom

    # log_softmax is the numerically stable form of log(softmax(.)).
    log_new_probs = F.log_softmax(new_logits, dim=1)
    return -(q_star * log_new_probs).sum(dim=1).mean()


def expand_model(model):
    """Widen the last ``nn.Linear`` in ``model`` by one output unit, in place.

    "Last" means last in ``model.named_modules()`` order. The existing rows of the
    weight (and bias, if any) are copied into the new layer. The extra row keeps
    ``nn.Linear``'s default initialization. The new layer is created on the old
    layer's device and dtype.

    Raises:
        ValueError: if ``model`` contains no ``nn.Linear``.
    """
    linear_layers = [
        (qual_name, mod)
        for qual_name, mod in model.named_modules()
        if isinstance(mod, nn.Linear)
    ]
    if not linear_layers:
        raise ValueError("No Linear layer found in the model.")
    fc_name, old_fc = linear_layers[-1]

    has_bias = old_fc.bias is not None
    widened = nn.Linear(
        in_features=old_fc.in_features,
        out_features=old_fc.out_features + 1,
        bias=has_bias,
        device=old_fc.weight.device,
        dtype=old_fc.weight.dtype,
    )

    with torch.no_grad():
        widened.weight[:-1] = old_fc.weight
        if has_bias:
            widened.bias[:-1] = old_fc.bias

    # Walk the dotted path to the parent module and swap the layer in.
    *parent_path, attr_name = fc_name.split(".")
    parent = model
    for step in parent_path:
        parent = getattr(parent, step)
    setattr(parent, attr_name, widened)


def get_projection_matrix(device, Mr, Mf):
    """Build ``I - (Mf - Mf @ Mr)`` for every key of ``Mr``.

    Args:
        device: device the identity matrix is moved to.
        Mr: mapping key -> square matrix.
        Mf: mapping with the same keys -> square matrix of matching size.

    Returns:
        An ``OrderedDict`` in ``Mr``'s key order.
    """
    projections = OrderedDict()
    for key, m_retain in Mr.items():
        m_forget = Mf[key]
        identity = torch.eye(m_forget.shape[0]).to(device)
        leak = m_forget - torch.mm(m_forget, m_retain)
        projections[key] = identity - leak
    return projections


class AddGaussianNoise(object):
    """Transform that returns ``tensor + noise * std + mean`` with ``noise ~ N(0, 1)``.

    Args:
        mean: constant offset added to every element.
        std: scale applied to the standard-normal noise.
    """

    def __init__(self, mean=0., std=1.):
        self.std = std
        self.mean = mean

    def __call__(self, tensor):
        # Draw the noise on the input's device so the transform also works for GPU
        # tensors. The default dtype is kept, so integer inputs still yield floats.
        noise = torch.randn(tensor.size(), device=tensor.device)
        return tensor + noise * self.std + self.mean

    def __repr__(self):
        return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std)


class simpleDataset(Dataset):
    """Map-style dataset over an in-memory image tensor and a label sequence.

    ``data`` is converted once to a NumPy array on the host. Each item is returned
    with its axes permuted ``(0, 1, 2) -> (1, 2, 0)`` (presumably CHW -> HWC, the
    layout torchvision's ``ToTensor`` expects; confirm with the caller) before
    ``transform`` is applied.
    """

    def __init__(self, data, labels, transform=None, target_transform=None):
        self.data = data.detach().cpu().numpy()
        self.labels = labels
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx].transpose(1, 2, 0)
        target = self.labels[idx]
        if self.transform:
            sample = self.transform(sample)
        if self.target_transform:
            target = self.target_transform(target)
        return sample, target


class RLDataset(Dataset):
    """Relabelled view of ``forgetset`` used for random-label ("RL") unlearning.

    Each item is ``(image, label)``. The image comes from ``forgetset[idx][0]``.
    The label is ``new_classes[idx]`` when ``new_classes`` is given. Otherwise it is
    drawn uniformly from ``range(num_classes)`` excluding the sample's true label
    ``forgetset[idx][1]``, and a fresh label is drawn on every access.

    ``noise_level`` and ``add_noise`` are stored but not used by this class.
    """

    def __init__(self, forgetset, new_classes=None, num_classes=10, noise_level=0.01, add_noise=False):
        self.image_set = forgetset
        self.add_noise = add_noise
        self.noise_level = noise_level
        self.num_classes = num_classes
        self.new_classes = new_classes

    def __len__(self):
        return len(self.image_set)

    def __getitem__(self, idx):
        # Index the wrapped dataset once. Reading [0] and [1] via separate lookups
        # would load (and transform) the same sample twice.
        sample = self.image_set[idx]
        image = sample[0]
        if self.new_classes is not None:
            label = self.new_classes[idx]
        else:
            true_label = sample[1]
            label = np.random.choice([i for i in range(self.num_classes) if i != true_label])  # random wrong label

        return image, label


class basicDataset(Dataset):
    """Dataset adapter over two kinds of indexable ``data``.

    ``data`` must expose ``.shape``. The two cases are:

    * ``data.shape[-1] == 2``: ``data[idx]`` is a mapping with ``'image'`` and
      ``'label'`` keys (presumably a two-column Hugging Face ``Dataset``; confirm).
      The image is copied into a writable NumPy array. 2-D (grayscale) images are
      stacked to 3 channels on the last axis.
    * otherwise: ``data[idx]`` is indexable as ``(image, label, ...)``.

    ``transform`` / ``target_transform`` are applied to the image / label when set.
    """

    def __init__(self, data, transform=None, target_transform=None):
        self.data = data
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        row = self.data[idx]
        if self.data.shape[-1] == 2:
            # np.array copies, so later transforms never mutate the underlying data.
            image = np.array(row['image'])
            if image.ndim == 2:
                image = np.stack((image, image, image), axis=2)
            label = row['label']
        else:
            # Previously this branch never assigned ``image`` (UnboundLocalError).
            image = row[0]
            label = row[1]

        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return image, label


def SVC_fit_predict(shadow_train, shadow_test, target_train, target_test):
    """Train an RBF SVC to separate two shadow sets, then score two target sets.

    The classifier learns label 1 for rows of ``shadow_train`` and label 0 for rows
    of ``shadow_test``. Only the first ``len(shadow_test)`` rows of
    ``shadow_train`` are used, to keep the classes balanced.

    Args:
        shadow_train: tensor of label-1 training features, one row per sample.
        shadow_test: tensor of label-0 training features.
        target_train: tensor whose rows are scored for the returned value.
        target_test: tensor whose rows are scored and only printed.

    Returns:
        The fraction of ``target_train`` rows predicted as label 0, or
        ``float('nan')`` when ``target_train`` is empty.
    """
    n_shadow_test = shadow_test.shape[0]
    n_target_train = target_train.shape[0]
    n_target_test = target_test.shape[0]

    # The slice clips to len(shadow_train), so use the real positive count for the
    # reshape and label vector. Assuming n_shadow_test positives broke whenever
    # shadow_train was the shorter set.
    positives = shadow_train[:n_shadow_test]
    n_pos = positives.shape[0]
    X_shadow = torch.cat([positives, shadow_test]).cpu().numpy().reshape(n_pos + n_shadow_test, -1)
    Y_shadow = np.concatenate([np.ones(n_pos), np.zeros(n_shadow_test)])
    shuffle_indices = np.random.permutation(len(Y_shadow))
    X_shadow = X_shadow[shuffle_indices]
    Y_shadow = Y_shadow[shuffle_indices]
    clf = SVC(C=3, gamma='auto', kernel='rbf')
    clf.fit(X_shadow, Y_shadow)

    accs = []
    print("shadow_train", shadow_train[:10])
    print("shadow_test", shadow_test[:10])
    print("target_test", target_test[:10])

    # Defined up front so an empty target_train no longer raises UnboundLocalError.
    acc_train = float('nan')
    if n_target_train > 0:
        X_target_train = target_train.cpu().numpy().reshape(n_target_train, -1)
        acc_train = 1 - clf.predict(X_target_train).mean()
        accs.append(acc_train)

    if n_target_test > 0:
        X_target_test = target_test.cpu().numpy().reshape(n_target_test, -1)
        acc_test = clf.predict(X_target_test).mean()
        accs.append(acc_test)
    print("accs", accs)
    return acc_train


def _true_class_confidences(net, loader, forget_ids, device, unlearn_method):
    """Return ``(remain_conf, forget_conf)`` for ``loader``.

    Each element is an ``(N, 1)`` tensor holding the softmax probability ``net``
    assigns to each sample's true label. Samples are split by whether their label
    is in ``forget_ids``. For ``unlearn_method == 'RW'`` the second output of
    ``net`` is scored instead of the first.
    """
    remain, forget = [], []
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        test_logits, new_logits = net(images)
        if unlearn_method == 'RW':
            test_logits = new_logits
        probs = F.softmax(test_logits, dim=1)
        true_conf = torch.gather(probs, 1, labels[:, None])
        is_forget = torch.isin(labels, forget_ids)
        remain.append(true_conf[~is_forget])
        forget.append(true_conf[is_forget])
    return torch.cat(remain, dim=0), torch.cat(forget, dim=0)


def svc_mia(net, train_loader, test_loader, forgetting_class, unlearn_method='RL'):
    """Run an SVC membership-inference attack on the forgotten classes.

    The true-label confidence of every sample is collected from both loaders and
    split into remain / forget classes. ``SVC_fit_predict`` is then called with:

    * members (label 1): train-loader remain samples;
    * non-members (label 0): test-loader forget samples;
    * scored set: train-loader forget samples;
    * printed only: test-loader remain samples.

    Returns:
        The attack score from ``SVC_fit_predict``, i.e. the fraction of
        forgotten training samples classified as non-members. It is also printed.
    """
    # Follow the model's device instead of assuming CUDA, and build the id tensor once.
    device = next(net.parameters()).device
    forget_ids = torch.tensor(forgetting_class, device=device)

    with torch.no_grad():
        train_remain, train_forget = _true_class_confidences(
            net, train_loader, forget_ids, device, unlearn_method)
        test_remain, test_forget = _true_class_confidences(
            net, test_loader, forget_ids, device, unlearn_method)

    print("check remain forget")
    acc_mean = SVC_fit_predict(train_remain, test_forget, train_forget, test_remain)
    print(f"MIA Attack Accuracy on Forgotten Class: {acc_mean:.4f}")
    return acc_mean


def collect_prob(data_loader, model, unlearn_method):
    """Return ``(probs, targets)`` of ``model`` over every batch of ``data_loader``.

    ``model`` is put in eval mode. Its forward must return two logit tensors. The
    first is used, or the second when ``unlearn_method == 'RW'``. Probabilities are
    ``exp(log_softmax(logits))`` along dim 1, on ``model``'s device.

    When ``data_loader`` is None, an empty ``(0, 10)`` float tensor and an empty
    int64 target tensor are returned. The int64 dtype lets callers'
    ``torch.gather(probs, 1, targets[:, None])`` still work.
    """
    if data_loader is None:
        return torch.zeros([0, 10]), torch.zeros([0], dtype=torch.long)

    device = next(model.parameters()).device
    prob = []
    targets = []

    model.eval()
    with torch.no_grad():
        for batch in data_loader:
            data, target = [tensor.to(device) for tensor in batch]
            log_logits, new_log_logits = model(data)
            chosen = new_log_logits if unlearn_method == 'RW' else log_logits
            prob.append(torch.exp(F.log_softmax(chosen, dim=1)))
            targets.append(target)

    return torch.cat(prob), torch.cat(targets)


def SVC_attack(shadow_train, target_train, target_test, shadow_test, model, forgetting_class, unlearn_method='RL'):
    """Confidence-based SVC membership-inference attack over four data loaders.

    Each loader (any may be None) is reduced to the true-label probability of
    ``model``. The results are passed to ``SVC_fit_predict`` as:

    * members: ``shadow_train``;
    * non-members: ``target_test``;
    * scored set: ``shadow_test``;
    * printed only: ``target_train``.

    ``forgetting_class`` is accepted but not used here.

    Returns:
        ``{"confidence": score}``; the dict is also printed.
    """
    def true_label_conf(loader):
        probs, labels = collect_prob(loader, model, unlearn_method)
        return probs, labels

    st_prob, st_lab = true_label_conf(shadow_train)
    sv_prob, sv_lab = true_label_conf(shadow_test)
    tt_prob, tt_lab = true_label_conf(target_train)
    te_prob, te_lab = true_label_conf(target_test)

    member_conf = torch.gather(st_prob, 1, st_lab[:, None])
    scored_conf = torch.gather(sv_prob, 1, sv_lab[:, None])
    logged_conf = torch.gather(tt_prob, 1, tt_lab[:, None])
    nonmember_conf = torch.gather(te_prob, 1, te_lab[:, None])

    result = {"confidence": SVC_fit_predict(member_conf, nonmember_conf, scored_conf, logged_conf)}
    print(result)
    return result
  
os.environ["CUDA_VISIBLE_DEVICES"] = "0,2,3"

parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training')
parser.add_argument('--dataset', default='cifar10', help='dataset')
parser.add_argument('--model', default='ResNet18', help='Deep Learning model to train')
parser.add_argument('--method', default='catclip', help='clipping method (use orig for no clipping)')
parser.add_argument('--mode', default='wBN', help='what to do with BN layers (leave empty for keeping it as it is)')
parser.add_argument('--lr', default=0.01, type=float, help='learning rate')
parser.add_argument('--LRsteps', default=40, type=int, help='LR scheduler step')
parser.add_argument('--epochs', default=10, type=int, help='number of epochs')
parser.add_argument('--seed', default=1, type=int, help='seed value')
parser.add_argument('--steps', default=50, type=int, help='setp count for clipping BN')
parser.add_argument('--num_classes', default=10, type=int, help='number of classes in the dataset')
parser.add_argument('--batch_size', default=128, type=int, help='number of classes in the dataset')

parser.add_argument('--unlearn_method', default='RL', type=str)
parser.add_argument('--unlearn_indices', default=None, type=str)
parser.add_argument('--unlearn_evaluate', default='svc_mia', type=str)

parser.add_argument('--unlearn_count', default=1000, type=int)
parser.add_argument('--start_idx', default=0, type=int)

parser.add_argument('--source_model_path', default=None, type=str)
parser.add_argument('--mask_path', default=None, type=str)
parser.add_argument('--save_checkpoints', default=0, type=int)

parser.add_argument('--use_all_ref', default=True, type=bool)
parser.add_argument('--use_remain', default=True, type=bool)
parser.add_argument('--remain', default='use', type=str)
parser.add_argument('--use_remain_sample', default=False, type=bool)

parser.add_argument('--unnormalize', default=True, type=bool)
parser.add_argument('--norm_cond', default='unnorm', help='unnorm or norm for transform')

parser.add_argument('--req_mode', default='single', type=str)
parser.add_argument('--salun_ratio', default='0.5', type=str, help='ratio of masking in salun')

parser.add_argument('--alpha_l1', default=0., type=float)
parser.add_argument('--noise_ratio', default=3, type=int) # default 120

parser.add_argument('--catsn', default=-1, type=float)
parser.add_argument('--convsn', default=1., type=float)
parser.add_argument('--outer_steps', default=100, type=int)
parser.add_argument('--convsteps', default=100, type=int)
parser.add_argument('--opt_iter', default=5, type=int)
parser.add_argument('--outer_iters', default=1, type=int)

args = parser.parse_args()

            
if args.unlearn_indices is None:
    parser.error('--unlearn_indices is required (CSV file with an "unlearn_idx" column)')

unlearn_indices_check = pd.read_csv(args.unlearn_indices)['unlearn_idx'].values
count_unlearn = len(unlearn_indices_check)

# The single forgotten class is encoded in the indices file name, e.g. "..._label_3.csv".
match = re.search(r'label_(\d+)\.csv', args.unlearn_indices)
if match:
    number = int(match.group(1))
    print("forgetting class", number)

## ==========================================MULTIPLE CLASS===============================================
if (args.dataset == 'cifar10') and args.unlearn_method == 'RW':
    forgetting_class = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
    print("MULTIPLE")
else:
    print("SINGLE FORGETTING")
    if not match:
        # Previously this surfaced as a NameError on ``number``.
        parser.error(f'cannot infer the forgetting class from {args.unlearn_indices!r}; '
                     'expected a file name ending in "label_<class>.csv"')
    forgetting_class = [number]

print('count_unlearn: ', count_unlearn)
print('requested mode: ', args.req_mode)

# '--norm_cond norm' overrides '--unnormalize'.
if args.norm_cond == 'norm':
    args.unnormalize = False
print('!!!!!!!!! unnormalized: ', args.unnormalize)
print('!!!!!!!!! salun ratio: ', args.salun_ratio)

print('model: ', args.model)

# Dataset tag used in the paths below, e.g. 'cifar10_unnorm'.
dataset_name = args.dataset
if args.unnormalize:
    dataset_name += '_unnorm'
print('dataset', dataset_name)

# The remain set is used when '--remain use' is given and always for the
# reference / retrain / RW methods. Any '--use_remain' value is overridden here.
args.use_remain = args.remain == 'use' or args.unlearn_method in ('reference', 'retrain', 'RW')

print('use remain flag: ', args.use_remain)
if args.unlearn_method == 'salun':
    # salun runs as 'RL' with a saliency mask read from args.mask_path
    # (presumably loaded later in the script).
    model_name = args.source_model_path.split('/')[-1]
    mask_label = 10 if args.dataset == 'cifar100' else 1
    args.mask_path = f'./class_unlearn/logs/correct/scratch/{dataset_name}/unlearn/genmask/{count_unlearn}/unl_idx_{args.dataset}_label_{mask_label}/{model_name}/salun_mask/with_{args.salun_ratio}.pt'
    print('mask path: ', args.mask_path)
    args.unlearn_method = 'RL'


save_checkpoints = args.save_checkpoints == 1
print('save_checkpoints: ', save_checkpoints)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('==========', device)

if device == 'cuda':
    print('chosen: ', device)
    cudnn.benchmark = True

_NUM_CLASSES_BY_DATASET = {'mnist': 10, 'cifar10': 10, 'cifar100': 100, 'imagenet': 200}
if args.dataset not in _NUM_CLASSES_BY_DATASET:
    print("wrong dataset")
    # Non-zero status so shell scripts / schedulers see the failure.
    raise SystemExit(1)
args.num_classes = _NUM_CLASSES_BY_DATASET[args.dataset]


# path_file.csv maps an 'info' column to a 'path' column; a 'base_path' row is required.
base_path_df = pd.read_csv('path_file.csv')
print(base_path_df)
base_path_dict = dict(zip(base_path_df['info'], base_path_df['path']))
base_path = base_path_dict['base_path']
print('base_path: ', base_path)

# Training
def _build_unlearn_trainset(epoch, unlearn_method):
    """Return the dataset ``train`` iterates over for ``unlearn_method``.

    Reads the module-level ``args``, ``trainset``, ``forgetset`` and ``testset``.
    For ``'l1'`` with ``args.alpha_l1 == 0`` it also sets a default L1 weight on the
    global ``args``.

    Raises:
        ValueError: for an ``unlearn_method`` this script does not handle.
        SystemExit: for ``'reference'`` with ``args.use_all_ref`` on a dataset other
            than cifar10, for which no keep file is configured here.
    """
    if unlearn_method in ('retrain', 'FT', 'l1'):
        if args.use_remain:
            combined = trainset
        else:
            # Random subset of the training set, the same size as the forget set.
            sample_indices = np.random.choice(len(trainset), len(forgetset), replace=False)
            combined = torch.utils.data.Subset(trainset, sample_indices)
        if unlearn_method == 'l1' and args.alpha_l1 == 0.:
            args.alpha_l1 = 0.000001
        return combined

    if unlearn_method == 'RL':
        # Forget samples relabelled with a random wrong class.
        rl_set = RLDataset(forgetset, num_classes=args.num_classes, new_classes=None)
        if args.use_remain:
            return torch.utils.data.ConcatDataset([trainset, rl_set])
        return rl_set

    if unlearn_method in ('BS', 'BE', 'GA'):
        return forgetset

    if unlearn_method == 'RW':
        return trainset

    if unlearn_method == 'reference':
        included_indices_file = 'keep_files/keep_m128_d55000_s0.csv'
        if args.use_all_ref:
            # This used to compare against 'cifar', a value --dataset can never take
            # (only mnist/cifar10/cifar100/imagenet pass the dataset check), so this
            # branch always exited.
            if args.dataset == 'cifar10':
                included_indices_file = 'keep_files/keep_m128_d60000_s0.csv'
            else:
                print('unknown dataset!')
                raise SystemExit(1)

        included_indices_all = pd.read_csv(included_indices_file, header=0).values
        print('seed: ', args.seed, included_indices_all.shape)
        # Row ``args.seed`` selects which samples this reference model trains on.
        included_indices = included_indices_all[args.seed]

        combined = torch.utils.data.ConcatDataset([trainset, testset])
        if epoch == 0:
            print('row id:', args.seed)
            print('sum included: ', included_indices.sum())
            print('len of combined trainset: ', len(combined))
        inc_indices = np.arange(len(combined))[included_indices].tolist()
        included = torch.utils.data.Subset(combined, inc_indices)
        if epoch == 0:
            print('len of included trainset: ', len(included))
        return included

    raise ValueError(f'unsupported unlearn_method: {unlearn_method!r}')


def _apply_grad_mask(model, mask):
    """Multiply each parameter gradient of ``model`` in place by ``mask[name]``."""
    if mask is None:
        return
    for name, param in model.named_parameters():
        if param.grad is not None:
            param.grad *= mask[name]


def train(epoch, optimizer, scheduler, criterion, test_model, unlearn_method='RL', writer=None, model_path="./checkpoints/", mask=None):
    """Run one epoch of (un)learning on the global ``net`` and checkpoint it.

    The data (see ``_build_unlearn_trainset``) and the loss depend on ``unlearn_method``:

    * ``retrain`` / ``FT`` / ``l1`` / ``RL`` / ``reference``: cross-entropy
      (``RL`` uses randomly relabelled forget samples).
    * ``BS``: forget samples relabelled with a snapshot model's prediction on an
      FGSM-perturbed copy.
    * ``BE``: forget samples labelled with an extra output unit appended by
      ``expand_model``.
    * ``GA``: cross-entropy on the forget set, scaled by -0.001.
    * ``RW``: ``forget_loss_drop_class_dim`` on ``net``'s second output against a
      deep-copied snapshot of ``net``.

    For BS/BE/GA with ``args.use_remain``, one cross-entropy pass over the global
    ``remainloader`` runs first. Its samples are included in the returned accuracy.

    Args:
        epoch: epoch index; used in log output and the checkpoint suffix.
        optimizer, scheduler: stepped once per batch / once per epoch.
        criterion: classification loss.
        test_model: unused for BS/RW, which replace it with a fresh copy of ``net``.
        unlearn_method: see above.
        writer: accepted for compatibility; unused.
        model_path: checkpoint prefix; ``'.<epoch>'`` is appended.
        mask: optional ``{param_name: tensor}`` multiplied into gradients.

    Returns:
        ``(mean loss over main-loop batches, accuracy %)``. Both are 0.0 when
        nothing was processed.
    """
    global count_setp
    print('\nEpoch: %d' % epoch)
    print('l1 regularization: ', args.alpha_l1)
    print('unlearn method: ', unlearn_method)
    net.train()
    train_loss = 0
    correct = 0
    total = 0
    batch_idx = -1

    print('\ninside train function :')
    print('trainset :', len(trainset))
    print('unl idx :', len(unlearn_idx))

    trainset_combined = _build_unlearn_trainset(epoch, unlearn_method)

    print('trainset_combined len: ', len(trainset_combined))
    trainloader = torch.utils.data.DataLoader(trainset_combined, shuffle=True, batch_size=args.batch_size, num_workers=1)

    start = time.time()

    if args.use_remain and unlearn_method in ('BS', 'BE', 'GA'):
        for batch_idx, (inputs, targets) in enumerate(remainloader):
            if epoch == 0 and batch_idx == 0:
                print('inputs remain shape: ', inputs.shape, targets[:10])
            inputs, targets = inputs.float().to(device), targets.to(device)
            optimizer.zero_grad()
            outputs, _ = net(inputs)
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
            loss = criterion(outputs, targets)

            if args.alpha_l1 > 0.:
                print("Is there norm for BS?")
                loss += args.alpha_l1 * l1_regularization(net)

            loss.backward()
            _apply_grad_mask(net, mask)
            optimizer.step()

        print('in loop train - acc', 100. * correct / total if total else 0.0)

    if unlearn_method in ('BS', 'RW'):
        # Snapshot of the model at the start of the epoch.
        test_model = copy.deepcopy(net)
        bound = 0.1

    if unlearn_method == 'BE':
        # Append the extra output unit only once. Calling expand_model every epoch
        # kept adding freshly initialized output units.
        # NOTE(review): the new layer's parameters are presumably not in
        # ``optimizer`` (it was built before this call); confirm with the caller.
        linears = [m for m in net.modules() if isinstance(m, nn.Linear)]
        if not linears or linears[-1].out_features <= args.num_classes:
            expand_model(net)

    for batch_idx, (inputs, targets) in enumerate(trainloader):
        torch.cuda.empty_cache()
        if epoch == 0 and batch_idx == 0:
            print('inputs shape: ', inputs.shape)
        inputs, targets = inputs.float().to(device), targets.to(device)
        optimizer.zero_grad()
        outputs, new_logits = net(inputs)

        if unlearn_method == 'RW':
            test_model.eval()
            # Teacher probabilities only; no graph through the snapshot is needed.
            with torch.no_grad():
                old_logits, _ = test_model(inputs)
            old_probs = F.softmax(old_logits, dim=1)
            outputs = new_logits  # accuracy below is measured on the second output
            loss = forget_loss_drop_class_dim(new_logits, old_probs, forgetting_class)
        else:
            # Labels used for the loss; ``targets`` keeps the true labels for accuracy.
            loss_targets = targets
            if unlearn_method == 'BS':
                test_model.eval()
                image_adv = FGSM_perturb(
                    inputs, targets, model=test_model, bound=bound, criterion=criterion
                )
                with torch.no_grad():
                    adv_outputs, _ = test_model(image_adv)
                loss_targets = torch.argmax(adv_outputs, dim=1)
            elif unlearn_method == 'BE':
                # Point every sample at the appended output unit (index num_classes).
                loss_targets = torch.full_like(targets, args.num_classes)
            loss = criterion(outputs, loss_targets)

            if unlearn_method == 'GA':
                loss = -0.001 * loss

        if args.alpha_l1 > 0.:
            loss += args.alpha_l1 * l1_regularization(net)

        loss.backward()
        _apply_grad_mask(net, mask)
        optimizer.step()

        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
        count_setp += 1

    if batch_idx >= 0:
        print("is it really forget set", targets[:5], predicted[:5])

    tot_time = time.time() - start
    print('time: ', tot_time)

    epoch_acc = 100. * correct / total if total else 0.0
    mean_loss = train_loss / (batch_idx + 1) if batch_idx >= 0 else 0.0
    print('train - acc', epoch_acc)
    print('train - loss', mean_loss)

    scheduler.step()

    print('Saving..')
    state = {
        'net': net.state_dict(),
        'epoch': epoch,
    }

    model_path_i = model_path + ".%d" % (epoch)
    # NOTE: keyed on the global args.unlearn_method, not the ``unlearn_method`` argument.
    if args.unlearn_method in ('retrain', 'reference'):
        if epoch in [80, 100, 120, 140, 160, 180, 200]:
            torch.save(state, model_path_i)
    else:
        torch.save(state, model_path_i)

    net.eval()

    return mean_loss, epoch_acc


def test(loader, epoch, criterion, unlearn_method='RL', writer=None, mode='test', model_path="./checkpoints/", plot_images=False):
    """Evaluate the global ``net`` on ``loader`` and optionally save a best checkpoint.

    Accuracy uses ``net``'s first output, or its second output when
    ``unlearn_method == 'RW'``. When ``model_path`` is not None and accuracy beats
    the global ``best_acc``, ``best_acc`` is updated and the state is saved to
    ``model_path`` (its parent directory is created if needed).

    ``plot_images`` is accepted for compatibility but not used.

    Returns:
        ``(mean batch loss, accuracy %)``. Both are 0.0 for an empty loader.
    """
    global best_acc
    net.eval()
    test_loss = 0
    correct = 0
    total = 0
    n_batches = 0

    with torch.no_grad():
        for inputs, targets in loader:
            inputs, targets = inputs.float().to(device), targets.to(device)
            outputs, new_outputs = net(inputs)
            # NOTE(review): the loss uses the first output even for RW, where accuracy
            # is measured on the second one; confirm this is intended.
            loss = criterion(outputs, targets)
            if unlearn_method == 'RW':
                outputs = new_outputs

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
            n_batches += 1

    acc = 100. * correct / total if total else 0.0
    avg_loss = test_loss / n_batches if n_batches else 0.0

    if model_path is not None and acc > best_acc:
        best_acc = acc
        print('Saving Best..')
        state = {
            'net': net.state_dict(),
            'acc': acc,
            'epoch': epoch,
        }
        # Create the directory the checkpoint is actually written to.
        save_dir = os.path.dirname(model_path)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        torch.save(state, model_path)

    if writer is not None:
        writer.add_scalar('test/acc', acc, epoch)
        writer.add_scalar('test/loss', avg_loss, epoch)

    print("{}/acc {:.4f}".format(mode, acc))
    print("{}/loss {:.4f}".format(mode, avg_loss))

    return avg_loss, acc


def check_test(trainloader, loader, test_net, epoch, criterion, writer=None, mode='test', model_path="./checkpoints/", plot_images=False):
    """Evaluate ``test_net`` on ``loader`` and report mean loss and accuracy.

    The printed metrics carry an ``'initial'`` prefix; the caller uses this to
    score the source model before unlearning starts.

    Args:
        trainloader: Unused; kept for call-site compatibility and for the
            (currently disabled) SVC-based MIA evaluation noted below.
        loader: Iterable of ``(inputs, targets)`` batches.
        test_net: Model whose forward returns a 2-tuple; only the first
            element (the logits) is scored.
        epoch: Step value used for the TensorBoard scalars.
        criterion: Loss applied to ``(logits, targets)``.
        writer: Optional ``SummaryWriter``; when given, ``test/acc`` and
            ``test/loss`` are logged at step ``epoch``.
        mode: Label used in the printed metric names.
        model_path, plot_images: Unused; kept for call-site compatibility.

    Returns:
        ``(mean_batch_loss, accuracy_percent)``. Both are ``0.0`` when
        ``loader`` yields no samples (this used to raise ZeroDivisionError).
    """
    test_net.eval()
    # Send batches to wherever the model's weights live (same idiom as
    # FGSM_perturb) instead of relying on a module-level `device` global.
    model_device = next(test_net.parameters()).device

    loss_sum = 0.0
    correct = 0
    total = 0
    num_batches = 0

    with torch.no_grad():
        for num_batches, (inputs, targets) in enumerate(loader, start=1):
            inputs, targets = inputs.float().to(model_device), targets.to(model_device)
            outputs, _ = test_net(inputs)
            loss_sum += criterion(outputs, targets).item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

    # Guard against an empty loader instead of dividing by zero.
    mean_loss = loss_sum / num_batches if num_batches else 0.0
    acc = 100. * correct / total if total else 0.0

    # NOTE(review): the tags are 'test/...' regardless of `mode` (as in test()),
    # so forget/remain/test calls at the same step collide on one tag.
    if writer is not None:
        writer.add_scalar('test/acc', acc, epoch)
        writer.add_scalar('test/loss', mean_loss, epoch)

    print('initial', mode + '/acc', acc)
    print('initial', mode + '/loss', mean_loss)
    # svc_mia(test_net, trainloader, loader, forgetting_class, unlearn_method='retrain')

    return mean_loss, acc

if __name__ == "__main__":
    # ---- run configuration derived from the command-line arguments ----
    method = args.method
    steps_count = args.steps  #### BN clip steps for hard clip
    concat_sv = False
    step_size = args.LRsteps
    clip_outer_flag = False
    outer_steps = args.outer_steps
    outer_iters = args.outer_iters
    if args.catsn > 0.:
        concat_sv = True
        # NOTE(review): every accepted mode branch below reassigns clip_steps,
        # so args.convsteps is currently never used. Confirm which is intended.
        clip_steps = args.convsteps
        clip_outer_flag = True

    mode = args.mode
    bn_flag = True
    bn_clip = False
    bn_hard = False
    opt_iter = args.opt_iter
    if mode == 'wBN':
        # mode is later embedded in output directory names; 'wBN' maps to an empty segment.
        mode = ''
        bn_flag = True
        bn_clip = False
        clip_steps = 50
    elif mode == 'noBN':
        bn_flag = False
        bn_clip = False
        opt_iter = 1
        clip_steps = 100
    elif mode == 'clipBN_hard':
        bn_flag = True
        bn_clip = True
        bn_hard = True
        clip_steps = 100
    else:
        print('unknown mode!')
        # Non-zero exit status so shell scripts / schedulers see the failure
        # (the previous exit(0) reported success).
        raise SystemExit(1)

    ##================================================MULTIPLE CLASS===========================================
    # Load the train- and test-split indices of the samples to be forgotten.
    if not (args.dataset == 'cifar100' and len(forgetting_class) > 1):
        print("===============SINGLE FORGETTING===================")
        # The test-split index file sits next to the train one, with a "_test" suffix.
        unlearn_idx = [int(v) for v in pd.read_csv(args.unlearn_indices)['unlearn_idx'].values]
        unlearn_idx_test = [
            int(v)
            for v in pd.read_csv(args.unlearn_indices.replace('.csv', '_test.csv'))['unlearn_idx'].values
        ]
    else:
        print("===============MULTIPLE FORGETTING===================")
        # Concatenate the per-label index files for CIFAR-100 labels 0..9,
        # reading all train files first and then all test files.
        unlearn_idx = [
            int(v)
            for label in range(10)
            for v in pd.read_csv(f'./class_unlearn/class_indices/cifar100_label_{label}.csv')['unlearn_idx'].tolist()
        ]
        unlearn_idx_test = [
            int(v)
            for label in range(10)
            for v in pd.read_csv(f'./class_unlearn/class_indices/cifar100_label_{label}_test.csv')['unlearn_idx'].tolist()
        ]


    # Seeds to run: args.seed == -1 expands to the default sweep [1, 2, 3].
    # WARNING: do not run more than one seed per process -- the run below mutates
    # `args` fields (e.g. args.lr, args.num_classes, args.outdir). @ToDo fix this!
    seed_in = args.seed
    if seed_in == -1:
        # Previously assigned to a misspelled name (`geed_in`), leaving seed_in
        # as the int -1 so `for seed in seed_in` raised TypeError.
        seed_in = [1, 2, 3]
    else:
        seed_in = [seed_in]
    for seed in seed_in:
        print('seed.....', seed)
        best_acc = 0  # best test accuracy
        start_epoch = 0  # start from epoch 0 or last checkpoint epoch
        count_setp = 0

        seed_val = seed
        torch.manual_seed(seed_val)
        torch.cuda.manual_seed_all(seed_val)
        np.random.seed(seed_val)
        random.seed(seed_val)

        clip_flag    = False
        orig_flag    = False

        print('method: ', method)
        if method[:4] == 'fast' or method == 'clip':
            clip_flag    = True
        elif method == 'catclip':
            clip_flag    = True
        elif method == 'orig':
            orig_flag    = True
        else:
            print('unknown method!')
            exit(0)

        # Data
        print('==> Preparing data..')
        if args.dataset == 'mnist':
            print('using mnist')
            in_chan = 1
            if args.unnormalize:
                if args.model == 'ResNet18':
                    transform_train = transforms.Compose([
                        transforms.ToTensor(),
                    ])
                    transform_test = transforms.Compose([
                        transforms.ToTensor(),
                    ])
                elif args.model == 'VGG':
                    transform_train = transforms.Compose([
                        transforms.Resize(32),
                        transforms.ToTensor(),
                    ])
                    transform_test = transforms.Compose([
                        transforms.Resize(32),
                        transforms.ToTensor(),
                    ])
                    
            else:
                if args.model == 'ResNet18':
                    transform_train = transforms.Compose([
                        transforms.Resize((28, 28)),  # Ensure images are 28x28
                        transforms.ToTensor(),        # Convert images to PyTorch tensors
                        transforms.Normalize((0.5,), (0.5,))  # Normalize to [-1, 1]
                    ])
                    transform_test = transform_train
                elif args.model == 'VGG':
                    transform_train = transforms.Compose([
                        transforms.Resize(32),
                        transforms.ToTensor(),
                        transforms.Normalize((0.1307,), (0.3081,))
                    ])
                    transform_test = transform_train

            
            trainset = torchvision.datasets.MNIST(root='./mnist', train=True, download=True, transform=transform_train)
            if args.unlearn_method == 'reference':
                testset = torchvision.datasets.MNIST( root='./mnist', train=False, download=True, transform=transform_train)
            else:
                testset = torchvision.datasets.MNIST( root='./mnist', train=False, download=True, transform=transform_test)
            
        elif args.dataset == 'cifar10':
            print('using cifar 10')
            in_chan = 3

            if args.unnormalize:
                transform_train = transforms.Compose([
                    transforms.RandomCrop(32, padding=4),
                    transforms.RandomHorizontalFlip(),
                    transforms.ToTensor(),
                ])

                transform_test = transforms.Compose([
                    transforms.ToTensor(),
                ])
            else:
                transform_train = transforms.Compose([
                    transforms.RandomCrop(32, padding=4),
                    transforms.RandomHorizontalFlip(),
                    transforms.ToTensor(),
                    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
                ])

                transform_test = transforms.Compose([
                    transforms.ToTensor(),
                    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
                ])


            trainset = torchvision.datasets.CIFAR10( root='./cifar10', train=True, download=True, transform=transform_train) ### transofrm=transform_train

            if args.unlearn_method == 'reference':
                testset = torchvision.datasets.CIFAR10( root='./cifar10', train=False, download=True, transform=transform_train)
            else:
                testset = torchvision.datasets.CIFAR10( root='./cifar10', train=False, download=True, transform=transform_test)

        elif args.dataset == 'cifar100':
            print('using cifar 100')
            in_chan = 3
            args.num_classes = 100

            if args.unnormalize:
                transform_train = transforms.Compose([
                    transforms.RandomCrop(32, padding=4),
                    transforms.RandomHorizontalFlip(),
                    transforms.ToTensor(),
                ])

                transform_test = transforms.Compose([
                    transforms.ToTensor(),
                ])
            else:
                transform_train = transforms.Compose([
                    transforms.RandomCrop(32, padding=4),
                    transforms.RandomHorizontalFlip(),
                    transforms.ToTensor(),
                    transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)),  # CIFAR-100 mean and std
                ])

                transform_test = transforms.Compose([
                    transforms.ToTensor(),
                    transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)),  # CIFAR-100 mean and std
                ])


            trainset = torchvision.datasets.CIFAR100(root='./cifar100', train=True, download=True, transform=transform_train)
            # ./cifar100
            if args.unlearn_method == 'reference':
                testset = torchvision.datasets.CIFAR100( root='./cifar100', train=False, download=True, transform=transform_train)
            else:
                testset = torchvision.datasets.CIFAR100( root='./cifar100', train=False, download=True, transform=transform_test)

        elif args.dataset == 'imagenet':
            print('using Imagenet')
            in_chan = 3
            args.num_classes = 200

            if args.unnormalize:
                transform_train = transforms.Compose([
                    transforms.RandomCrop(64, padding=4),
                    transforms.RandomHorizontalFlip(),
                    transforms.ToTensor(),
                ])

                transform_test = transforms.Compose([
                    transforms.ToTensor(),
                ])
            else:
                transform_train = transforms.Compose([
                    transforms.RandomCrop(64, padding=4),
                    transforms.RandomHorizontalFlip(),
                    transforms.ToTensor(),
                    transforms.Normalize((0.4802, 0.4481, 0.3975), (0.2770, 0.2691, 0.2821))
                ])

                transform_test = transforms.Compose([
                    transforms.ToTensor(),
                    transforms.Normalize((0.4802, 0.4481, 0.3975), (0.2770, 0.2691, 0.2821))
                ])

            train_dir = './tiny-imagenet-200/train'
            val_dir = './tiny-imagenet-200/val'
            trainset = torchvision.datasets.ImageFolder(root=train_dir, transform=transform_train)
            if args.unlearn_method == 'reference':
                testset = torchvision.datasets.ImageFolder(root=val_dir, transform=transform_train)
            else:
                testset = torchvision.datasets.ImageFolder(root=val_dir, transform=transform_test)
                
        else:
            print('unknown dataset!')
            exit(0)

        indices_seed = args.unlearn_indices.split('/')[-1][:-4]
        indices_count = len(unlearn_idx) # args.unlearn_indices.split('/')[-2]

        args.outdir = f"/{dataset_name}/unlearn/{args.unlearn_method}/{indices_count}/unl_idx_{indices_seed}/"
        args.outdir = "scratch" + args.outdir
        args.outdir = base_path + args.outdir


        print(args.outdir)
        print('learning rate: ', args.lr)
        print('dataset: ', args.dataset)


        if args.unlearn_method == 'retrain':
            outdir = args.outdir + '/' + args.model + "_" + method + "_" + mode + "_" + str(seed_val) + "/"
            # outdir = args.outdir + '/' + args.source_model_path.split('/')[-2] + '/'
        elif args.unlearn_method == 'reference':
            if args.use_all_ref:
                args.outdir = f"/{dataset_name}/unlearn/{args.unlearn_method}/"
                args.outdir = "scratch" + args.outdir
                args.outdir = base_path + args.outdir

            outdir = args.outdir + '/' + args.model + "_" + method + "_" + mode + "_" + str(seed_val) + "/"
        elif args.unlearn_method == 'genmask':
            outdir = args.outdir + '/' + args.source_model_path.split('/')[-1] + '/'
        else:
            # outdir = args.outdir + '/' + args.source_model_path.split('/')[-1] + '/'
            outdir = args.outdir + args.source_model_path.split('/')[-1] 
            if args.mask_path is not None and args.mask_path != 'None':
                outdir = outdir + '/mask_' + str(args.mask_path).split('with_')[1][:-3] + '/'

            if args.alpha_l1 > 0 and args.unlearn_method != 'l1':
                outdir = outdir + '/l1_' + str(args.alpha_l1) + '/'

            outdir = outdir + '/use_remain_' + str(args.use_remain) + '/' + args.model + "_" + method + "_" + mode + "_" + str(seed_val) + "/"
            outdir += '/LRs_' + str(step_size) + '_lr_' + str(args.lr) + '/'


        print('outdir: ', outdir)
        if not os.path.exists(outdir):
            os.makedirs(outdir)
        writer = SummaryWriter(outdir)

        print('==> Building model..')
        print('------------> outdir: ', outdir)
        print('-----------------------------------------------------------------')
        print('initial len of trainset: ', len(trainset))  


        request_count = 1
        if args.req_mode == 'adaptive':
            print('not implemented yet!')
            exit(0)

        prior_idx = []
        for req_idx in range(request_count):

            if len(forgetting_class) == 1:
                print("===============SINGLE FORGETTING===================")
                unlearn_idx = pd.read_csv(args.unlearn_indices)['unlearn_idx'].values
                unlearn_idx_9 = pd.read_csv(args.unlearn_indices.replace('_1.csv', '_9.csv'))['unlearn_idx'].values
                if len(unlearn_idx) != int(indices_count):
                    print('unlearn_idx count is not correct!')
                    # exit(0)
                unlearn_idx = [int(i) for i in unlearn_idx]
                unlearn_idx_9 = [int(i) for i in unlearn_idx_9]
                test_csv_path = args.unlearn_indices.replace('.csv', '_test.csv')
                unlearn_idx_test = pd.read_csv(test_csv_path)['unlearn_idx'].values
                unlearn_idx_test = [int(i) for i in unlearn_idx_test]

                test_csv_path_9 = args.unlearn_indices.replace('_1.csv', '_9_test.csv')
                unlearn_idx_test_9 = pd.read_csv(test_csv_path_9)['unlearn_idx'].values
                unlearn_idx_test_9 = [int(i) for i in unlearn_idx_test_9]

                removed_classes = [trainset[i][1] for i in unlearn_idx]
                df = pd.DataFrame({'unlearn_idx': unlearn_idx, 'removed_classes': removed_classes})
                df.to_csv(outdir + 'unlearn_idx.csv')
            
            ### remove the unlearned images from the trainset
            trainset_filtered = torch.utils.data.Subset(trainset, list(set(range(len(trainset))) - set(unlearn_idx) - set(prior_idx)))
            trainset_filtered_9 = torch.utils.data.Subset(trainset, list(set(range(len(trainset))) - set(unlearn_idx) - set(unlearn_idx_9) - set(prior_idx)))
            print('len of filtered trainset: ', len(trainset_filtered))  
            # print(trainset_filtered.report())

            forgetset = torch.utils.data.Subset(trainset, unlearn_idx)
            forgetset_9 = torch.utils.data.Subset(trainset, unlearn_idx_9)
            print('len of forget set: ', len(forgetset))  
            # print(forgetset.report())


            trainloader = torch.utils.data.DataLoader(trainset, shuffle=True, batch_size=args.batch_size, num_workers=1)
            remainloader = torch.utils.data.DataLoader(trainset_filtered, shuffle=False, batch_size=args.batch_size, num_workers=1)
            remainloader_9 = torch.utils.data.DataLoader(trainset_filtered_9, shuffle=False, batch_size=args.batch_size, num_workers=1)
            forgetloader = torch.utils.data.DataLoader(forgetset, shuffle=False, batch_size=args.batch_size, num_workers=1)
            forgetloader_9 = torch.utils.data.DataLoader(forgetset_9, shuffle=False, batch_size=args.batch_size, num_workers=1)
            ### remove the unlearned images from the testset
            testset_filtered = torch.utils.data.Subset(testset, list(set(range(len(testset))) - set(unlearn_idx_test) - set(prior_idx)))
            testset_filtered_9 = torch.utils.data.Subset(testset, list(set(range(len(testset))) - set(unlearn_idx_test) - set(unlearn_idx_test_9) - set(prior_idx)))
            print('len of filtered testset: ', len(testset_filtered))  
            # print('len of filtered testset_9: ', len(testset_filtered_9))  

            forgetset_test = torch.utils.data.Subset(testset, unlearn_idx_test)
            forgetset_test_9 = torch.utils.data.Subset(testset, unlearn_idx_test_9)
            print('len of forget testset: ', len(forgetset_test))  

            testloader = torch.utils.data.DataLoader(testset, shuffle=False, batch_size=args.batch_size, num_workers=1)
            remainloader_test = torch.utils.data.DataLoader(testset_filtered, shuffle=False, batch_size=args.batch_size, num_workers=1)
            remainloader_test_9 = torch.utils.data.DataLoader(testset_filtered_9, shuffle=False, batch_size=args.batch_size, num_workers=1)
            forgetloader_test = torch.utils.data.DataLoader(forgetset_test, shuffle=False, batch_size=args.batch_size, num_workers=1)
            forgetloader_test_9 = torch.utils.data.DataLoader(forgetset_test_9, shuffle=False, batch_size=args.batch_size, num_workers=1)
            

            if args.unlearn_method == 'retrain' or args.unlearn_method == 'FT' or args.unlearn_method == 'l1' or args.unlearn_method == 'RL':
                if args.use_remain_sample:
                    sample_indices = np.random.choice(len(trainset_filtered), len(forgetset), replace=False)
                    trainset_filtered = torch.utils.data.Subset(trainset_filtered, sample_indices)
                trainset = trainset_filtered
            elif args.unlearn_method == 'reference':
                if args.use_all_ref:
                    trainset = trainset
                else:
                    trainset = trainset_filtered
            elif args.unlearn_method == 'GA': 
                if args.use_remain_sample:
                    sample_indices = np.random.choice(len(trainset_filtered), len(forgetset), replace=False)
                    trainset_filtered = torch.utils.data.Subset(trainset_filtered, sample_indices)
                trainset = forgetset
            elif args.unlearn_method == 'RW':
                trainset = trainset
                testset = testset

            print('final len of trainset: ', len(trainset))  
            print('-----------------------------------------------------------------')
            
            if req_idx == 0:
                if args.model == 'ResNet18':
                    if orig_flag:
                        if args.dataset == 'imagenet':
                            net = ResNet18_orig(in_chan=in_chan, bn=bn_flag, device=device, elu_flag=False, num_classes=args.num_classes, tinynet=True)
                        else:
                            net = ResNet18_orig(in_chan=in_chan, bn=bn_flag, device=device, elu_flag=False, num_classes=args.num_classes)
                        # net = ResNet18_orig(in_chan=in_chan, bn=bn_flag, bn_clip=bn_clip, bn_hard=bn_hard, clip_linear=False, bn_count=steps_count, device=device)
                    elif clip_flag:
                        net = ResNet18(concat_sv=concat_sv, in_chan=in_chan, device=device, clip=args.convsn, clip_concat=args.catsn, clip_flag=True, bn=bn_flag, bn_clip=bn_clip, bn_hard=bn_hard, clip_steps=clip_steps, bn_count=steps_count, clip_outer=clip_outer_flag, clip_opt_iter=opt_iter, summary=True, writer=writer, save_info=False, outer_iters=outer_iters, outer_steps=outer_steps, num_classes=args.num_classes)
                
                elif args.model == 'VGG':
                    if args.dataset == 'imagenet':
                        net = VGG('VGG19', in_chan=in_chan, num_classes=args.num_classes, tinynet=True)
                    else:
                        net = VGG('VGG19', in_chan=in_chan, num_classes=args.num_classes)

                
                net = net.to(device)
                test_net = net
                test_net = nn.DataParallel(test_net) ### adds the "module." prefix to the state_dict keys
                net = nn.DataParallel(net) ### adds the "module." prefix to the state_dict keys
                criterion = nn.CrossEntropyLoss()

                mask = None
                if args.mask_path is not None and args.mask_path != 'None':
                    print('loading mask...')
                    mask = torch.load(args.mask_path)


                if args.unlearn_method != 'retrain' and args.unlearn_method != 'reference':
                    if clip_flag:
                        if bn_flag:
                            checkpoint = torch.load(args.source_model_path + '/checkpoint.pth.tar_200')
                        else:
                            checkpoint = torch.load(args.source_model_path + '/checkpoint.pth.tar_120')
                        net.load_state_dict(checkpoint['state_dict'], strict=False)
                    else:
                        if args.dataset == 'mnist':
                            checkpoint = torch.load(args.source_model_path + '/checkpoint.pth.tar_best')
                        else: checkpoint = torch.load(args.source_model_path + '/checkpoint.pth.tar_best')
                        net.load_state_dict(checkpoint['state_dict'])#, strict=False)
                    print("--->source model", args.source_model_path)
                    print('model loaded')

            tr_loss_list = []
            tr_acc_list = []
            ts_loss_list = []
            ts_acc_list = []
            fs_loss_list = []
            fs_acc_list = []
            re_loss_list = []
            re_acc_list = []
            best_keeping_list = []

            net.eval()
            print('-- train set:')
            tr_loss, tr_acc = 0., 0.
            print('-- test set:')
            ts_loss, ts_acc = test(testloader, 200, criterion, unlearn_method=args.unlearn_method, writer=writer, mode='test', model_path=None)
            print('--- forget set:')
            fs_loss, fs_acc = test(forgetloader, 200, criterion, unlearn_method=args.unlearn_method, writer=writer, mode='forget', model_path=None)
            print('-- remain set:')
            # remain_loss, remain_acc = test(remainloader, 200, criterion, writer=writer, mode='remain', model_path=None)
            remain_loss, remain_acc = 0., 0.
            
            print("Check Forget/Remain acc and MIA scores")
            test_checkpoint = torch.load(args.source_model_path + '/checkpoint.pth.tar_best')
            test_net.load_state_dict(test_checkpoint['state_dict'])#, strict=False)
            fst_loss, fst_acc = check_test(trainloader, forgetloader_test, test_net, 200, criterion, writer=writer, mode='forget', model_path=None)
            remaint_loss, remaint_acc = check_test(trainloader, remainloader_test, test_net, 200, criterion, writer=writer, mode='remain', model_path=None)
            t_loss, t_acc = check_test(trainloader, testloader, test_net, 200, criterion, writer=writer, mode='test', model_path=None)

            tr_loss_list.append(tr_loss)
            tr_acc_list.append(tr_acc)
            ts_loss_list.append(ts_loss)
            ts_acc_list.append(ts_acc)
            fs_loss_list.append(fs_loss)
            fs_acc_list.append(fs_acc)
            re_loss_list.append(remain_loss)
            re_acc_list.append(remain_acc)
            best_keeping_list.append(0)

            # net.train()

            if args.dataset == 'mnist':
                if args.unlearn_method == 'retrain' or args.unlearn_method == 'reference':
                    args.lr = 0.1

                optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4, nesterov=True)  # momentum 0.9
                # optimizer = SGLD(net.parameters(), lr=args.lr, addnoise=True) 
                if args.unlearn_method == 'retrain':
                    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=0.1)
                    # T_max = 51
                    T_max = args.epochs
                    
                elif args.unlearn_method == 'reference':
                    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=0.1)
                    T_max = 51
                    if not bn_flag:
                        T_max = 51
                else:
                    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=0.1)
                    T_max = args.epochs
                    # scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=T_max)

            elif args.dataset == 'cifar10':
                if args.unlearn_method == 'retrain' or args.unlearn_method == 'reference':
                    args.lr = 0.1

                optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4, nesterov=True)  # momentum 0.9
                #optimizer = SGLD(net.parameters(), lr=args.lr, addnoise=True) 
                if args.unlearn_method == 'retrain':
                    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=0.1)
                    T_max = 201
                    if not bn_flag:
                        T_max = 121
                elif args.unlearn_method == 'reference':
                    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=0.1)
                    T_max = 161
                    if not bn_flag:
                        T_max = 101
                else:
                    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=0.1)
                    T_max = args.epochs
                    # scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=T_max)

            elif args.dataset == 'cifar100':
                if args.unlearn_method == 'retrain' or args.unlearn_method == 'reference':
                    args.lr = 0.1

                optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=0.95, weight_decay=5e-4, nesterov=True)  # momentum 0.9
                #optimizer = SGLD(net.parameters(), lr=args.lr, addnoise=True) 
                if args.unlearn_method == 'retrain':
                    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=0.1)
                    T_max = 201
                    if not bn_flag:
                        T_max = 121
                elif args.unlearn_method == 'reference':
                    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=0.1)
                    T_max = 161
                    if not bn_flag:
                        T_max = 101
                else:
                    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=0.1)
                    T_max = args.epochs
                    # scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=T_max)                
                
            elif args.dataset == 'imagenet':
                if args.unlearn_method == 'retrain' or args.unlearn_method == 'reference':
                    args.lr = 5e-4

                optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4, nesterov=True)  # momentum 0.9
                #optimizer = SGLD(net.parameters(), lr=args.lr, addnoise=True) 
                if args.unlearn_method == 'retrain':
                    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=0.1)
                    T_max = 201
                    if not bn_flag:
                        T_max = 121
                elif args.unlearn_method == 'reference':
                    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=0.1)
                    T_max = 161
                    if not bn_flag:
                        T_max = 101
                else:
                    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=0.1)
                    T_max = args.epochs
                    # scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=T_max)
            
            else:
                raise ValueError('dataset must be one of cifar, mnist')

            model_path =  outdir + '_ckpt'
            model_path_test =  outdir + '_ckpt_best_test.pth'


            if args.unlearn_method == 'genmask':
                save_dir = outdir + 'salun_mask/'
                if not os.path.exists(save_dir):
                    os.makedirs(save_dir)
                save_gradient_ratio(forgetloader , net, criterion, optimizer, save_dir)
                exit(0)


            print('epoch: ', start_epoch)
            print('Tmax: ', T_max)

            sv_df = {}
            # Deep copy of `net` taken before any training and passed to train()
            # as its model argument. NOTE(review): presumably a frozen reference
            # copy of the pre-unlearning model - confirm against train().
            test_model = copy.deepcopy(net)

            for epoch in range(T_max):
                tr_loss, tr_acc = train(epoch, optimizer, scheduler, criterion, test_model,
                                        unlearn_method=args.unlearn_method, writer=writer,
                                        model_path=model_path, mask=mask)

                # Evaluate after every epoch. (A former `epoch % 5 == 0` throttle
                # had been disabled behind an always-true guard; the dead guard
                # has been removed.)
                print('-- test set:')
                ts_loss, ts_acc = test(testloader, epoch, criterion, unlearn_method=args.unlearn_method,
                                       writer=writer, model_path=model_path_test, mode='test', plot_images=True)
                print('--- forget set:')
                fs_loss, fs_acc = test(forgetloader, epoch, criterion, unlearn_method=args.unlearn_method,
                                       writer=writer, model_path=model_path_test, mode='forget')
                # Held-out forget split: evaluated, but not stored in the
                # per-epoch results table below.
                fs_loss_test, fs_acc_test = test(forgetloader_test, epoch, criterion, unlearn_method=args.unlearn_method,
                                                 writer=writer, model_path=model_path_test, mode='forget')
                print('-- remain set:')
                remain_loss, remain_acc = test(remainloader, epoch, criterion, unlearn_method=args.unlearn_method,
                                               writer=writer, model_path=model_path_test, mode='remain')
                # Held-out remain split: evaluated, but not stored either.
                remain_loss_test, remain_acc_test = test(remainloader_test, epoch, criterion, unlearn_method=args.unlearn_method,
                                                         writer=writer, model_path=model_path_test, mode='remain')

                # 1 when this epoch's test accuracy equals best_acc, else 0.
                # NOTE(review): best_acc is defined outside this view; presumably
                # test() keeps it updated - confirm.
                best_keeping_list.append(int(ts_acc == best_acc))

                tr_loss_list.append(tr_loss)
                tr_acc_list.append(tr_acc)
                ts_loss_list.append(ts_loss)
                ts_acc_list.append(ts_acc)
                fs_loss_list.append(fs_loss)
                fs_acc_list.append(fs_acc)
                re_loss_list.append(remain_loss)
                re_acc_list.append(remain_acc)
            
            # Persist the weights `net` holds after the final epoch, together
            # with the last loop index.
            print('Saving Last..', model_path)
            state = dict(net=net.state_dict(), epoch=epoch)
            torch.save(state, model_path + '.pth')

            # Per-epoch history, one column per tracked metric (column order kept).
            df = pd.DataFrame(dict(
                tr_loss=tr_loss_list,
                tr_acc=tr_acc_list,
                ts_loss=ts_loss_list,
                ts_acc=ts_acc_list,
                fs_loss=fs_loss_list,
                fs_acc=fs_acc_list,
                re_loss=re_loss_list,
                re_acc=re_acc_list,
                best_keeping=best_keeping_list,
            ))

            print('saving results to ...', outdir)
            # 'retrain' and 'reference' runs write an unprefixed file; every
            # other method prefixes the file name with its step_size.
            uses_plain_name = args.unlearn_method in ('retrain', 'reference')
            results_name = 'loss_acc_results.csv' if uses_plain_name else str(step_size) + '_loss_acc_results.csv'
            df.to_csv(outdir + results_name)
            
            # EVALUATION - membership-inference attack (MIA) via SVC_attack.
            # Shadow sets are the remain/forget loaders; target sets are their
            # *_test counterparts.
            net.eval()
            # Kept under its own name: previously this result was assigned to
            # `evaluation_result` and immediately overwritten by the call below.
            forget_class_mia_result = SVC_attack(
                shadow_train=remainloader,
                shadow_test=forgetloader,
                target_train=remainloader_test,
                target_test=forgetloader_test,
                model=net,
                forgetting_class=forgetting_class,
                unlearn_method=args.unlearn_method,
            )

            # Conditional probabilities: repeat the attack on class 9 using the
            # *_9 loaders. The banner used to say "automobile", but class 9 is
            # "truck" in CIFAR-10 (automobile is 1) and the digit 9 in MNIST,
            # so it now reports the class index actually evaluated.
            conditional_class = 9
            print(f"===test on class {conditional_class}===")
            evaluation_result = SVC_attack(
                shadow_train=remainloader_9,
                shadow_test=forgetloader_9,
                target_train=remainloader_test_9,
                target_test=forgetloader_test_9,
                model=net,
                forgetting_class=conditional_class,
                unlearn_method=args.unlearn_method,
            )

