#!/usr/bin/env python3
import argparse
import json
import logging
import os
import time
import copy
import random
from collections import defaultdict
from pathlib import Path

import numpy as np

# import torch as ch

import torch
import torch.nn.functional as F
import torch.optim as optim
from IPython import embed
from sklearn.svm import SVC
from tqdm import tqdm
from urllib3.filepost import writer

import evaluation
import models
import datasets_multiclass as datasets
from utils import *
from logger import Logger
import wandb

from thirdparty.repdistiller.helper.loops import train_distill, train_distill_hide, train_distill_linear, train_vanilla, \
    train_negrad, train_bcu, train_bcu_distill, validate
from thirdparty.repdistiller.helper.pretrain import init
from thirdparty.repdistiller.helper.util import adjust_learning_rate as sgda_adjust_learning_rate
from thirdparty.repdistiller.distiller_zoo import DistillKL

from torch.utils.tensorboard import SummaryWriter
from sklearn.metrics import confusion_matrix, roc_auc_score, accuracy_score
from Machine_Unlearning.Metrics.metrics import *

import os
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


def adjust_learning_rate(optimizer, epoch):
    """Apply a step-decay learning rate to every parameter group.

    The base rate (``args.lr``) and the decay interval (``args.step_size``)
    are read from the module-level ``args`` namespace.  When a step size is
    set, the rate is multiplied by 0.1 once for every completed
    ``step_size`` epochs; otherwise the base rate is used as-is.

    Args:
        optimizer: object exposing ``param_groups`` (a list of dicts).
        epoch: current epoch index.
    """
    if args.step_size is None:
        new_lr = args.lr
    else:
        completed_steps = epoch // args.step_size
        new_lr = args.lr * 0.1 ** completed_steps

    for group in optimizer.param_groups:
        group['lr'] = new_lr


def l2_penalty(model, model_init, weight_decay):
    """Return ``weight_decay / 2`` times the squared distance to ``model_init``.

    Parameters of ``model`` and ``model_init`` are paired in iteration order;
    only parameters of ``model`` with ``requires_grad`` set contribute.
    """
    squared_distance = 0
    for current, reference in zip(model.parameters(), model_init.parameters()):
        if not current.requires_grad:
            continue
        squared_distance = squared_distance + (current - reference).pow(2).sum()
    return squared_distance * (weight_decay / 2.)


def run_epoch(args, model, model_init, train_loader, criterion=torch.nn.CrossEntropyLoss(), optimizer=None,
              scheduler=None, epoch=0, weight_decay=0.0, mode='train', quiet=False, log_file=None):
    """Run one pass over ``train_loader`` in the given ``mode``.

    Args:
        args: namespace providing ``device``, ``lossfn``, ``dataset``,
            ``disable_bn``, ``l1`` and ``weight_decay``.
        model: network to run.
        model_init: reference network for the L2-to-init penalty
            (see ``l2_penalty``).
        train_loader: iterable of ``(data, target)`` batches.
        criterion: loss applied to ``(output, target)``.
        optimizer: required when ``mode == 'train'``.
        scheduler: optional; stepped once per batch in train mode.
        epoch: epoch index, used for logging only.
        weight_decay: coefficient passed to ``l2_penalty``.
        mode: ``'train'`` (update weights), ``'test'`` (eval, no grad) or
            ``'dry_run'`` (eval with batch-norm in train mode, no update).
        quiet: when true, skip the ``log_metrics`` call.
        log_file: forwarded to ``log_metrics``.

    Returns:
        ``(metrics, new_test_err)`` where ``new_test_err`` is the average
        error in test mode and ``-100`` otherwise.

    Raises:
        ValueError: on an unknown ``mode`` or train mode without optimizer.
    """
    if mode == 'train':
        if optimizer is None:
            raise ValueError("mode='train' requires an optimizer.")
        model.train()
    elif mode == 'test':
        model.eval()
    elif mode == 'dry_run':
        model.eval()
        set_batchnorm_mode(model, train=True)
    else:
        raise ValueError("Invalid mode.")

    if args.disable_bn:
        set_batchnorm_mode(model, train=False)

    mult = 0.5 if args.lossfn == 'mse' else 1
    metrics = AverageMeter()

    with torch.set_grad_enabled(mode != 'test'):
        for data, target in train_loader:
            data, target = data.to(args.device), target.to(args.device)

            if args.lossfn == 'mse':
                # Map {0, 1} labels to {-1, +1}.  ``.float()`` keeps the tensor
                # on args.device (the old torch.cuda.FloatTensor cast crashed
                # on CPU-only runs).
                target = (2 * target - 1).float().unsqueeze(1)

            if 'mnist' in args.dataset:
                data = data.view(data.shape[0], -1)

            output = model(data)
            loss = mult * criterion(output, target) + l2_penalty(model, model_init, weight_decay)

            if args.l1:
                l1_loss = sum(p.norm(1) for p in model.parameters())
                loss += args.weight_decay * l1_loss

            # Always accumulate: the averages are returned and logged below.
            # (The former ``if ~quiet`` guard was always truthy because ``~``
            # is bitwise NOT, so this matches the effective old behaviour.)
            metrics.update(n=data.size(0), loss=loss.item(), error=get_error(output, target))

            if mode == 'train':
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                if scheduler is not None:
                    scheduler.step()

    if not quiet:
        log_metrics(mode, metrics, epoch, log_file=log_file)

    new_test_err = metrics.avg['error'] if mode == 'test' else -100

    # ``logger`` is the module-level Logger created by the script entry point.
    current_lr = optimizer.param_groups[0]['lr'] if optimizer is not None else None
    logger.append('train' if mode == 'train' else 'test', epoch=epoch, loss=metrics.avg['loss'],
                  error=metrics.avg['error'],
                  lr=current_lr)
    return metrics, new_test_err


def replace_loader_dataset(data_loader, dataset, batch_size=64, seed=1, shuffle=True):
    """Build a new single-process ``DataLoader`` over ``dataset``.

    Args:
        data_loader: unused; kept so existing call sites stay valid.
        dataset: dataset the new loader iterates over.
        batch_size: batch size of the new loader.
        seed: passed to ``manual_seed`` before the loader is created.
        shuffle: whether the new loader shuffles.

    Returns:
        A ``torch.utils.data.DataLoader`` with ``num_workers=0`` and
        ``pin_memory=True``.
    """
    manual_seed(seed)
    return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle,
                                       num_workers=0, pin_memory=True)


def interclass_confusion(model, dataloader, class_to_forget, name, classes=None):
    """Measure how often two classes are misclassified, and confused with each other.

    The dataset behind ``dataloader`` is re-read in order (batch size 128, no
    shuffling) and a confusion matrix over ``classes`` is built from the
    model's arg-max predictions.  Device and dataset name come from the
    module-level ``args``.

    Args:
        model: network to evaluate (switched to eval mode).
        dataloader: loader whose ``.dataset`` is evaluated.
        class_to_forget: sequence whose first two entries are the class
            labels of interest.
        name: unused; kept for call-site compatibility.
        classes: labels spanning the confusion matrix.  Defaults to
            ``[0, 1, 2, 3, 4]``, the previously hard-coded value.

    Returns:
        ``(ic_err, fgt)``: the fraction of samples of the two classes that
        are misclassified, and the number of samples of one class predicted
        as the other (both directions summed).
    """
    classes = [0, 1, 2, 3, 4] if classes is None else list(classes)

    eval_loader = torch.utils.data.DataLoader(dataloader.dataset, batch_size=128, shuffle=False)
    model.eval()
    reals = []
    predicts = []
    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        for data, target in eval_loader:
            data = data.to(args.device)
            if 'mnist' in args.dataset:
                data = data.view(data.shape[0], -1)
            output = model(data)
            # Softmax is monotonic, so arg-max over logits gives the prediction.
            predicts.extend(output.argmax(dim=1).cpu().numpy())
            reals.extend(target.cpu().numpy())

    cm = confusion_matrix(reals, predicts, labels=classes)

    # Rows/columns of ``cm`` follow the order of ``classes``.
    first = classes.index(class_to_forget[0])
    second = classes.index(class_to_forget[1])

    row_totals = np.sum(cm[first]) + np.sum(cm[second])
    # Off-diagonal mass of both rows = everything that was not predicted correctly.
    counts = row_totals - cm[first][first] - cm[second][second]

    ic_err = counts / row_totals
    fgt = cm[first][second] + cm[second][first]
    return ic_err, fgt


def seed_everything(seed):
    """Seed Python's ``random``, NumPy and PyTorch, and return a seeded generator.

    Returns:
        A new ``torch.Generator`` seeded with ``seed``.
    """
    generator = torch.Generator()
    generator.manual_seed(seed)

    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)

    return generator


# def get_mia_eff(test_loader,retain_dataset,forget_loader,batch_size,model):
#     test_len = len(test_loader.dataset)
#     shadow_train = torch.utils.data.Subset(retain_dataset, list(range(test_len)))
#     shadow_train_loader = torch.utils.data.DataLoader(
#         shadow_train, batch_size=batch_size, shuffle=False
#     )
#
#     try:
#         mia_scores = evaluation.SVC_MIA(
#             shadow_train=shadow_train_loader,
#             shadow_test=test_loader,
#             target_train=None,
#             target_test=forget_loader,
#             model=model,
#         )
#     except:
#         mia_scores = -10086
#         print(mia_scores)
#         return mia_scores
#     return mia_scores["prob"]
# def train_mia_predictor_classifier(unlearned_model, retain_lorder, forget_lorder):
#     """
#     Trains a Support Vector Classifier (SVC) as the MIA predictor.
#     It uses simulated losses from Dr (labeled as 'members' of the original training set, 1)
#     and D_test (labeled as 'non-members' of the original training set, 0).
#
#     Args:
#         unlearned_model: The θu model to query for loss values.
#         dr_data_size: Number of samples to simulate from Dr for MIA training.
#         d_test_data_size: Number of samples to simulate from D_test for MIA training.
#
#     Returns:
#         sklearn.svm.SVC: The trained MIA predictor classifier.
#     """
#     print("\n--- MIA Predictor Training Phase (Classifier-based using SVC) ---")
#
#     # Simulate losses for samples from Dr (acting as 'members' for MIA training)
#     # The unlearned model SHOULD remember Dr, so these will have lower losses.
#     # dr_losses = unlearned_model.get_losses_for_dataset(dr_data_size, 'Dr')
#     # # Label for MIA training: 1 indicates 'member' of the original training set
#     # dr_labels = np.ones(dr_data_size)
#     #
#     # # Simulate losses for samples from D_test (acting as 'non-members' for MIA training)
#     # # The unlearned model SHOULD NOT remember D_test, so these will have higher losses.
#     # d_test_losses = unlearned_model.get_losses_for_dataset(d_test_data_size, 'D_test')
#     # # Label for MIA training: 0 indicates 'non-member' of the original training set
#     # d_test_labels = np.zeros(d_test_data_size)
#
#     # forget_lorder = forget_lorder.detach().to('cpu')
#     # retain_lorder = retain_lorder.detach().to('cpu')
#     ##forget_features
#     forget_logits = []
#     for inputs, targets in forget_lorder:
#         inputs, targets = inputs.to(DEVICE), targets
#
#         forget_logit = unlearned_model(inputs).detach().cpu()
#         forget_logits.append(forget_logit)
#     forget_logits = torch.cat(forget_logits)
#     forget_labels = torch.zeros(len(forget_logits))
#
#     ##retain_geatures
#     retain_logits = []
#     for inputs, targets in retain_lorder:
#         inputs, targets = inputs.to(DEVICE), targets
#
#         retain_logit = unlearned_model(inputs).detach().cpu()
#         retain_logits.append(retain_logit)
#     retain_logits = torch.cat(retain_logits)[:len(forget_lorder)]
#     retain_labels = torch.ones(len(retain_logits))
#
#     # print(retain_labels.shape)
#     # forget_logits=forget_logits.detach().to('cpu')
#     # retain_logits=retain_logits.detach().to('cpu')
#     # Combine data and reshape for scikit-learn
#     forget_lenth=int(len(forget_logit)*0.8)
#     retain_lenth=int(len(retain_logits)*0.8)
#     X_train_mia = torch.cat((forget_logits[:forget_lenth], retain_logits[:retain_lenth]))  # .reshape(-1, 1) # Features (loss)
#     y_train_mia = torch.cat((forget_labels[:forget_lenth], retain_labels[:retain_lenth]))  # Labels (membership status)
#
#     # Initialize and train the Support Vector Classifier
#     # Set probability=True to enable predict_proba for AUC calculation, though it can be slower.
#     mia_predictor = SVC(kernel='rbf', probability=True, random_state=42)
#     mia_predictor.fit(X_train_mia, y_train_mia)
#
#     # Evaluate MIA predictor's training performance
#     y_pred_mia_train = mia_predictor.predict(X_train_mia)
#     y_pred_mia_test = mia_predictor.predict(forget_logits[forget_lenth:])
#     try:
#         y_prob_mia_train = mia_predictor.predict_proba(X_train_mia)[:, 1]  # Probability of being a member
#         train_auc = roc_auc_score(y_train_mia, y_prob_mia_train)
#     except AttributeError:
#         train_auc = "N/A (probability=True might be slow or not converged)"
#         print("Warning: predict_proba not available or failed for SVC. AUC will not be computed.")
#
#     train_accuracy = accuracy_score(y_train_mia, y_pred_mia_train)
#     test_accuracy = accuracy_score(forget_labels[forget_lenth:], y_pred_mia_test)
#     # print(y_pred_mia_train)
#     print(f"MIA Predictor Trained (Support Vector Classifier).")
#     print(f"MIA Predictor Training Accuracy: {train_accuracy:.4f}")
#     print(f"MIA Predictor Training AUC: {train_auc}")  # Print as is, might be N/A
#     print(f"MIA Predictor test_accuracy: {test_accuracy}")
#     return test_accuracy


# def get_mia_eff(retain_loader, forget_loader, model):
#     # test_len = len(test_loader.dataset)
#     #
#     # # utils.dataset_convert_to_test(retain_dataset)
#     # # utils.dataset_convert_to_test(forget_loader)
#     # # utils.dataset_convert_to_test(test_loader)
#     #
#     # shadow_train = torch.utils.data.Subset(retain_dataset, list(range(test_len)))
#     # shadow_train_loader = torch.utils.data.DataLoader(
#     #     shadow_train, batch_size=batch_size, shuffle=False
#     # )
#     #
#     # try:
#     #     mia_scores = evaluation.SVC_MIA(
#     #         shadow_train=shadow_train_loader,
#     #         shadow_test=test_loader,
#     #         target_train=None,
#     #         target_test=forget_loader,
#     #         model=model,
#     #     )
#     # except:
#     #     mia_scores=-10086
#     #     print(mia_scores)
#     #     return mia_scores
#     return train_mia_predictor_classifier(unlearned_model=model, retain_lorder=retain_loader,
#                                           forget_lorder=forget_loader)  # mia_scores["prob"]


def readout(model, name, test_loader, retain_loader, forget_loader):
    """Collect entropy, loss, MIA and accuracy statistics for ``model``.

    All keys of the returned dict are suffixed with ``name``.  Besides its
    arguments the function reads the module-level ``train_loader`` for the
    training accuracy.

    Args:
        model: network to evaluate.
        name: suffix appended to every result key.
        test_loader: loader over held-out data.
        retain_loader: loader over the retained training data.
        forget_loader: loader over the data to be forgotten.

    Returns:
        dict mapping metric names to values.  Previously the dict was built
        and then discarded; existing callers that ignore the return value
        are unaffected.
    """
    results = {}
    seed_everything(42)

    test_entropies = compute_entropy(model, test_loader)
    retain_entropies = compute_entropy(model, retain_loader)
    forget_entropies = compute_entropy(model, forget_loader)

    results[f"test_entropies_{name}"] = test_entropies.tolist()
    results[f"retain_entropies_{name}"] = retain_entropies.tolist()
    results[f"forget_entropies_{name}"] = forget_entropies.tolist()

    test_losses = compute_losses(model, test_loader)
    retain_losses = compute_losses(model, retain_loader)
    forget_losses = compute_losses(model, forget_loader)

    results[f"test_losses_{name}"] = test_losses.tolist()
    results[f"retain_losses_{name}"] = retain_losses.tolist()
    results[f"forget_losses_{name}"] = forget_losses.tolist()

    # Loss-based MIA; the helper sub-samples the larger set so both are balanced.
    mia_scores = compute_MIA_loss(test_losses, forget_losses)
    print(
        f"The MIA has an accuracy of {mia_scores.mean():.3f} on forgotten vs unseen images"
    )
    results[f"MIA_losses_{name}"] = mia_scores.mean()

    # Entropy-based MIA.
    mia_scores = compute_MIA_entropy(test_entropies, forget_entropies)
    print(
        f"The MIA has an accuracy of {mia_scores.mean():.3f} on forgotten vs unseen images"
    )
    # Store the mean, consistent with the loss-based entry above.
    results[f"MIA_entropies_{name}"] = mia_scores.mean()

    # NOTE: ``train_loader`` is not a parameter; it is the module-level loader.
    results[f"train_accuracy_{name}"] = accuracy(model, train_loader)
    results[f"test_accuracy_{name}"] = accuracy(model, test_loader)
    results[f"forget_accuracy_{name}"] = accuracy(model, forget_loader)

    return results


def _balanced_mia_scores(non_member_values, member_values):
    """Balance two per-sample statistics and score them with ``simple_mia``.

    The longer input is shuffled (fixed seed 1) and truncated to the length
    of the shorter one; when the lengths are equal the member set is still
    shuffled, as before.  Shuffling happens on a copy, so the caller's arrays
    are no longer reordered as a side effect.

    Samples from ``non_member_values`` get label 0, samples from
    ``member_values`` label 1.
    """
    gen = np.random.default_rng(1)

    if len(non_member_values) > len(member_values):
        non_member_values = np.array(non_member_values)  # copy before shuffling
        gen.shuffle(non_member_values)
        non_member_values = non_member_values[: len(member_values)]
    else:
        member_values = np.array(member_values)  # copy before shuffling
        gen.shuffle(member_values)
        member_values = member_values[: len(non_member_values)]

    # Both sets now have the same length by construction.
    samples_mia = np.concatenate((non_member_values, member_values)).reshape((-1, 1))
    labels_mia = [0] * len(non_member_values) + [1] * len(member_values)

    return simple_mia(samples_mia, labels_mia)


def compute_MIA_loss(ft_test_losses, ft_forget_losses):
    """Run the membership-inference attack on per-sample losses.

    Args:
        ft_test_losses: losses on unseen (test) samples, labelled 0.
        ft_forget_losses: losses on forget-set samples, labelled 1.

    Returns:
        Whatever ``simple_mia`` returns for the balanced sample set.
    """
    return _balanced_mia_scores(ft_test_losses, ft_forget_losses)


def compute_MIA_entropy(test_entropies, forget_entropies):
    """Run the membership-inference attack on per-sample entropies.

    Args:
        test_entropies: entropies on unseen (test) samples, labelled 0.
        forget_entropies: entropies on forget-set samples, labelled 1.

    Returns:
        Whatever ``simple_mia`` returns for the balanced sample set.
    """
    return _balanced_mia_scores(test_entropies, forget_entropies)


def entropy(outputs):
    """Return the Shannon entropy (nats) of ``softmax(outputs)`` over the last dim.

    Terms with zero probability contribute 0 instead of ``0 * log(0) = nan``.
    """
    probs = outputs.softmax(dim=-1)
    weighted_log = probs * probs.log()
    safe_terms = torch.where(probs > 0, weighted_log, torch.zeros_like(weighted_log))
    return -safe_terms.sum(dim=-1, keepdim=False)


if __name__ == '__main__':
    # ------------------------------------------------------------------
    # Command-line interface
    # ------------------------------------------------------------------
    parser = argparse.ArgumentParser()
    parser.add_argument('--split', type=str, choices=['train', 'forget'])
    parser.add_argument('--augment', action='store_true', default=False,
                        help='Use data augmentation')
    parser.add_argument('--quiet', action='store_true', default=False,
                        help='Quiet mode flag (not read anywhere in this script)')
    parser.add_argument('--batch-size', type=int, default=128, metavar='N',
                        help='input batch size for training (default: 128)')
    parser.add_argument('--dataset', default='mnist')
    parser.add_argument('--dataroot', type=str, default='data/')
    parser.add_argument('--disable-bn', action='store_true', default=False,
                        help='Put batchnorm in eval mode and don\'t update the running averages')
    parser.add_argument('--epochs', type=int, default=500000, metavar='N',
                        help='upper bound on the number of forgetting epochs (default: 500000)')
    parser.add_argument('--filters', type=float, default=1.0,
                        help='Percentage of filters')
    parser.add_argument('--forget_class', type=str, default=None,
                        help='Class to forget')
    parser.add_argument('--l1', action='store_true', default=False,
                        help='uses L1 regularizer instead of L2')
    parser.add_argument('--lossfn', type=str, default='ce',
                        help='Cross Entropy: ce or mse')
    parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
                        help='learning rate (default: 0.01)')
    parser.add_argument('--model', default='mlp')
    parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
                        help='SGD momentum (default: 0.9)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--num-classes', type=int, default=None,
                        help='Number of Classes')
    parser.add_argument('--num-to-forget', type=int, default=None,
                        help='Number of samples of class to forget')
    parser.add_argument('--confuse-mode', action='store_true', default=False,
                        help="enables the interclass confusion test")
    parser.add_argument('--name', default=None)
    parser.add_argument('--resume', type=str, default=None,
                        help='Checkpoint to resume dir')
    parser.add_argument('--resume_step', type=int, default=0, help='resume step')
    parser.add_argument('--seed', type=int, default=42, metavar='S',
                        help='random seed (default: 42)')
    parser.add_argument('--step_size', default=None, type=int, help='learning rate scheduler')
    parser.add_argument('--unfreeze-start', default=None, type=str,
                        help='All layers are freezed except the final layers starting from unfreeze-start')
    parser.add_argument('--weight-decay', type=float, default=0.0005, metavar='M',
                        help='Weight decay (default: 0.0005)')
    parser.add_argument('--lr_decay_epochs', type=str, default='30,30,30', help='lr decay epochs')
    parser.add_argument('--sgda-learning-rate', type=float, default=0.01, help='learning rate')
    parser.add_argument('--lr_decay_rate', type=float, default=0.1, help='learning rate decay rate')
    parser.add_argument('--print_freq', type=int, default=500, help='print frequency')

    args = parser.parse_args()

    # The retain/forget loaders used by the unlearning loop are only built for
    # the 'forget' split.  Fail up front instead of loading data, building the
    # model and then hitting an undefined ``retain_loader``.
    if args.split != 'forget':
        parser.error("--split forget is required: the unlearning loop needs the retain/forget loaders")

    args.lr_decay_epochs = [int(it) for it in args.lr_decay_epochs.split(',')]
    if args.forget_class is not None:
        args.forget_class = [int(c) for c in args.forget_class.split(',')]

    manual_seed(args.seed)

    if args.step_size is None:
        args.step_size = args.epochs + 1

    # ------------------------------------------------------------------
    # Run name, logger, output folders
    # ------------------------------------------------------------------
    if args.name is None:
        args.name = f"{args.dataset}_{args.model}_{str(args.filters).replace('.', '_')}"
        args.name += f"_forget_{args.forget_class}"
        if args.num_to_forget is not None:
            args.name += f"_num_{args.num_to_forget}"
        if args.unfreeze_start is not None:
            args.name += f"_unfreeze_from_{args.unfreeze_start.replace('.', '_')}"
        if args.augment:
            args.name += "_augment"
        args.name += f"_lr_{str(args.lr).replace('.', '_')}"
        args.name += f"_bs_{str(args.batch_size)}"
        args.name += f"_ls_{args.lossfn}"
        args.name += f"_wd_{str(args.weight_decay).replace('.', '_')}"
        args.name += f"_seed_{str(args.seed)}"
    print(f'Checkpoint name: {args.name}')

    mkdir('logs')

    # ``logger`` stays a module-level name because run_epoch() reads it.
    logger = Logger(index=args.name + '_training')
    logger['args'] = args
    logger['checkpoint'] = os.path.join('models/', logger.index + '.pth')
    logger['checkpoint_step'] = os.path.join('models/', logger.index + '_{}.pth')

    print("[Logging in {}]".format(logger.index))

    use_cuda = not args.no_cuda and torch.cuda.is_available()
    args.device = torch.device("cuda" if use_cuda else "cpu")
    save_folder = 'checkpoints_' + args.model
    if args.resume is not None:
        save_folder = save_folder + '_resume_' + str(args.resume_step)

    os.makedirs(os.path.join(save_folder, 'nabla'), exist_ok=True)

    # ------------------------------------------------------------------
    # Data: full loaders, then forget / retain subsets from the marked copy
    # ------------------------------------------------------------------
    train_loader_full, valid_loader_full, test_loader_full = datasets.get_loaders(args.dataset,
                                                                                  batch_size=args.batch_size,
                                                                                  seed=args.seed,
                                                                                  root=args.dataroot,
                                                                                  augment=False, shuffle=True)
    marked_loader, _, _ = datasets.get_loaders(args.dataset, class_to_replace=args.forget_class,
                                               num_indexes_to_replace=args.num_to_forget, only_mark=True,
                                               batch_size=1,
                                               seed=args.seed, root=args.dataroot, augment=False, shuffle=True)

    # Samples selected for forgetting carry a negative target (-label - 1);
    # the original label is recovered below.
    forget_dataset = copy.deepcopy(marked_loader.dataset)
    marked = forget_dataset.targets < 0
    forget_dataset.data = forget_dataset.data[marked]
    forget_dataset.targets = - forget_dataset.targets[marked] - 1
    forget_loader = replace_loader_dataset(train_loader_full, forget_dataset, batch_size=64, seed=args.seed,
                                           shuffle=True)

    retain_dataset = copy.deepcopy(marked_loader.dataset)
    marked = retain_dataset.targets >= 0
    retain_dataset.data = retain_dataset.data[marked]
    retain_dataset.targets = retain_dataset.targets[marked]
    retain_loader = replace_loader_dataset(train_loader_full, retain_dataset, batch_size=32, seed=args.seed,
                                           shuffle=True)

    print('forget Class:', args.forget_class)
    train_loader, valid_loader, test_loader = datasets.get_loaders(args.dataset, class_to_replace=args.forget_class,
                                                                   num_indexes_to_replace=args.num_to_forget,
                                                                   confuse_mode=args.confuse_mode,
                                                                   batch_size=args.batch_size, split=args.split,
                                                                   seed=args.seed,
                                                                   root=args.dataroot, augment=args.augment)

    # ------------------------------------------------------------------
    # Model
    # ------------------------------------------------------------------
    if args.num_classes is None:
        num_classes = int(max(train_loader.dataset.targets)) + 1
    else:
        num_classes = args.num_classes
    # Keep args in sync with the value actually used (was hard-coded to 10).
    args.num_classes = num_classes
    print(f"Number of Classes: {num_classes}")
    model = models.get_model(args.model, num_classes=num_classes, filters_percentage=args.filters).to(args.device)

    if args.resume is not None:
        # NOTE: torch.load unpickles arbitrary objects -- only resume from
        # checkpoints you trust.
        base_chkpt = torch.load(os.path.join(args.resume,
                                             f'checkpoint_s_{-1}.pt'))
        model = base_chkpt['model']
        opt = base_chkpt['optimizer']
        try:
            state_chkpt = torch.load(os.path.join(args.resume,
                                                  f'checkpoint_s_{args.resume_step}.pt'))
        except FileNotFoundError:
            # Fall back to the alternative checkpoint file-name pattern.
            state_chkpt = torch.load(os.path.join(args.resume,
                                                  f'checkpoint-s:{args.resume_step}.pt'))
        model.load_state_dict(state_chkpt['model_state_dict'])
        # NOTE(review): ``opt`` is restored here but a fresh AdamW optimizer is
        # created below and used for unlearning.
        opt.load_state_dict(state_chkpt['optimizer_state_dict'])
        # Batches are moved to args.device below, so the model must live there too.
        model = model.to(args.device)

    torch.save(model.state_dict(), f"{save_folder}/nabla/{args.name}_init.pt")

    # ------------------------------------------------------------------
    # Optimizer / loss
    # ------------------------------------------------------------------
    parameters = model.parameters()
    if args.unfreeze_start is not None:
        # Optimize the first parameter whose name contains ``unfreeze_start``
        # and every parameter that comes after it.
        parameters = []
        layer_index = 1e8
        for i, (n, p) in enumerate(model.named_parameters()):
            if (args.unfreeze_start in n) or (i > layer_index):
                layer_index = i
                parameters.append(p)

    optimizer = optim.AdamW(parameters, lr=args.lr, betas=(0.9, 0.95), weight_decay=0)
    criterion = torch.nn.CrossEntropyLoss().to(args.device) if args.lossfn == 'ce' else torch.nn.MSELoss().to(
        args.device)
    forget_epochs = min(int((len(retain_loader) / (len(forget_loader) * 2)) * 6), args.epochs)

    # NOTE(review): the scheduler is created but never stepped (the step call
    # in the loop was already disabled), so the learning rate stays at args.lr.
    scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.1,
                                                  total_iters=forget_epochs)

    # ------------------------------------------------------------------
    # Metric histories (one row per evaluation)
    # ------------------------------------------------------------------
    accuracy_metric_scores = []
    mia_loss_scores = []
    mia_entp_scores = []
    ic = []
    fgt = []

    # interclass_confusion() indexes the first two forget classes.
    track_confusion = args.forget_class is not None and len(args.forget_class) >= 2

    def record_metrics(final=False):
        """Evaluate the current model and append to the metric histories.

        ``final`` marks the evaluation after the last epoch: it skips the
        inter-class confusion metrics and only changes the wording of one
        log line so the printed output matches earlier runs.
        """
        ft_forget_losses = compute_losses(model, forget_loader)
        ft_test_losses = compute_losses(model, test_loader)
        # Argument order follows the helper's signature: (test, forget).
        ft_mia_scores = compute_MIA_loss(ft_test_losses, ft_forget_losses)
        loss_label = "MIA loss" if final else "MIA_loss"
        print(
            f"The {loss_label} has an accuracy of {ft_mia_scores.mean():.3f} on forgotten vs unseen images"
        )
        mia_loss_scores.append(ft_mia_scores.mean())

        forget_entropies = compute_entropy(model, forget_loader)
        test_entropies = compute_entropy(model, test_loader)
        en_mia_scores = compute_MIA_entropy(test_entropies, forget_entropies)
        print(
            f"The MIA_entropy has an accuracy of {en_mia_scores.mean():.3f} on forgotten vs unseen images"
        )
        # Previously never recorded, which left the saved mia_entropy CSV empty.
        mia_entp_scores.append(en_mia_scores.mean())

        acc = 100.0 * accuracy(model, test_loader)
        racc = 100.0 * accuracy(model, retain_loader)
        uacc = 100.0 * accuracy(model, forget_loader)
        print(f"Accuracy on test set: {acc:.1f} , Racc: {racc:.1f} , Uacc: {uacc:.1f}")
        accuracy_metric_scores.append([acc, racc, uacc])

        if not final and track_confusion:
            ic_r, fgt_r = interclass_confusion(model, retain_loader, args.forget_class, "SCRUB")
            ic_t, fgt_t = interclass_confusion(model, test_loader, args.forget_class, "SCRUB")
            ic.append([ic_r, ic_t])
            fgt.append([fgt_r, fgt_t])

    # Opening in append mode creates output.log if missing; nothing is written to it.
    with open(f"{save_folder}/output.log", "a") as log_file:
        print("==> unlearning ...")

        # The model is kept in eval mode for the whole unlearning phase.
        model.eval()

        start_alpha = 0.1
        alpha = start_alpha

        retain_iter = iter(retain_loader)
        # max(1, ...) avoids a modulo-by-zero when the retain loader has fewer
        # batches than the forget loader.
        retain_reset_every = max(1, len(retain_loader) // len(forget_loader))

        for i in range(forget_epochs):
            model.eval()

            if i % 5 == 0:
                print("Computing current moments on test set")
                val_loss, first_test_moment, second_test_moment, test_std = compute_moments(model, valid_loader)
                print("Computed moments: " + str(val_loss) + "," + str(first_test_moment) + "," + str(
                    second_test_moment))

            record_metrics()

            model.eval()

            print("Forgetting epoch " + str(i))

            if i % retain_reset_every == 0:
                print("Resetting retain iterator...")
                retain_iter = iter(retain_loader)

            print("using alpha: " + str(alpha))
            for c, (inputs, targets) in enumerate(forget_loader):

                model.zero_grad()
                # Use args.device (honours --no-cuda) so batches and model agree.
                inputs, targets = inputs.to(args.device), targets.to(args.device)
                out = model(inputs)

                try:
                    r_inputs, r_targets = next(retain_iter)
                except StopIteration:
                    # Retain loader exhausted before the scheduled reset.
                    retain_iter = iter(retain_loader)
                    r_inputs, r_targets = next(retain_iter)
                r_inputs, r_targets = r_inputs.to(args.device), r_targets.to(args.device)
                r_out = model(r_inputs)

                forget_losses = criterion(out, targets)
                retain_losses = criterion(r_out, r_targets)

                # Forget loss statistics.
                # NOTE(review): the criterion uses its default mean reduction,
                # so ``forget_losses`` is a scalar: the variance is 0 and the
                # skew is 0/0.  Only ``delta_val_loss`` enters the loss; the
                # other two deltas are printed for diagnostics.
                forget_mean = torch.mean(forget_losses)
                forget_var = torch.mean((forget_losses - forget_mean) ** 2)
                forget_std = forget_var ** 0.5
                forget_skew = torch.mean((forget_losses - forget_mean) ** 3) / (forget_std ** 3)

                delta_val_loss = (val_loss - forget_mean)
                delta_first_moment = (first_test_moment - forget_var)
                delta_second_moment = (second_test_moment - forget_skew)

                # Retain loss metric
                retain_mean = torch.mean(retain_losses)

                if c % 40 == 0:
                    print("delta_val_loss: " + str(delta_val_loss.item()))
                    print("delta_first_moment: " + str(delta_first_moment.item()))
                    print("delta_second_moment: " + str(delta_second_moment.item()))

                # Penalise the forget loss only while it is below the validation
                # loss (relu), blended with the plain retain loss.
                loss = alpha * (torch.nn.functional.relu(delta_val_loss) ** 2) + (1 - alpha) * retain_mean

                loss.backward()
                optimizer.step()
                # NOTE(review): a checkpoint is written after every batch and the
                # file name depends only on the batch index, so each epoch
                # overwrites the previous epoch's files -- confirm this is intended.
                torch.save(model.state_dict(), f"{save_folder}/nabla/{args.name}_{c}.pt")

            # Linearly decay alpha by start_alpha / forget_epochs per epoch.
            alpha = alpha - (start_alpha / forget_epochs)

        # Final evaluation after the last forgetting epoch.
        model.eval()
        record_metrics(final=True)

        model.eval()

        # ------------------------------------------------------------------
        # Persist metric histories
        # ------------------------------------------------------------------
        out_dir = (
            fr"CIFAR10_output_{args.model}/nabla/CIFAR10_{args.model}_{str(args.forget_class)}_num{args.num_to_forget}_step{args.resume_step}")

        out_path = Path(out_dir)
        if not out_path.exists():
            out_path.mkdir(parents=True)
            print("Folders created.")

        np.savetxt(out_dir + '/acc' + str(args.seed) + '.csv', accuracy_metric_scores, delimiter=", ")
        np.savetxt(out_dir + '/mia_loss' + str(args.seed) + '.csv', mia_loss_scores, delimiter=", ")
        np.savetxt(out_dir + '/mia_entropy' + str(args.seed) + '.csv', mia_entp_scores, delimiter=", ")
        np.savetxt(out_dir + '/ic' + str(args.seed) + '.csv', ic, delimiter=", ")
        np.savetxt(out_dir + '/fgt' + str(args.seed) + '.csv', fgt, delimiter=", ")

# if __name__ == '__main__':
#     main()