{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "%matplotlib inline\n",
    "import os\n",
    "os.environ['CUDA_DEVICE_ORDER']='PCI_BUS_ID'\n",
    "os.environ['CUDA_VISIBLE_DEVICES']='2'\n",
    "import variational\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "from matplotlib.ticker import FuncFormatter\n",
    "from itertools import cycle\n",
    "import os\n",
    "import time\n",
    "import math\n",
    "import pandas as pd\n",
    "import wandb\n",
    "from collections import OrderedDict\n",
    "from sklearn.linear_model import LogisticRegression\n",
    "    \n",
    "import copy\n",
    "import torch.nn as nn\n",
    "from torch.autograd import Variable\n",
    "from typing import List\n",
    "import itertools\n",
    "from tqdm.autonotebook import tqdm\n",
    "from models import *\n",
    "import models\n",
    "from logger import *\n",
    "import utils\n",
    "\n",
    "\n",
    "from thirdparty.repdistiller.helper.util import adjust_learning_rate as sgda_adjust_learning_rate\n",
    "from thirdparty.repdistiller.distiller_zoo import DistillKL, HintLoss, Attention, Similarity, Correlation, VIDLoss, RKDLoss\n",
    "from thirdparty.repdistiller.distiller_zoo import PKT, ABLoss, FactorTransfer, KDSVD, FSP, NSTLoss\n",
    "\n",
    "from thirdparty.repdistiller.helper.loops import train_distill, train_distill_hide, train_distill_linear, train_vanilla, train_negrad, train_bcu, train_bcu_distill, validate\n",
    "from thirdparty.repdistiller.helper.pretrain import init"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def pdb():\n",
    "    import pdb\n",
    "    pdb.set_trace"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def parameter_count(model):\n",
    "    count=0\n",
    "    for p in model.parameters():\n",
    "        count+=np.prod(np.array(list(p.shape)))\n",
    "    print(f'Total Number of Parameters: {count}')\n",
    "    return count"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def vectorize_params(model):\n",
    "    param = []\n",
    "    for p in model.parameters():\n",
    "        param.append(p.data.view(-1).cpu().numpy())\n",
    "    return np.concatenate(param)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def print_param_shape(model):\n",
    "    for k,p in model.named_parameters():\n",
    "        print(k,p.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Pre-training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "#%run main.py --dataset cifar100 --dataroot=data/cifar-100-python --model resnet --filters 0.4 --lr 0.1 --lossfn ce --num-classes 100"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def activations_predictions(model,dataloader,name):\n",
    "    \"\"\"Evaluate `model` on `dataloader`, record loss/error under `name`, and return per-sample outputs.\n",
    "\n",
    "    Calls `get_metrics` with a cross-entropy criterion, batch size 128 and\n",
    "    `samples_correctness=True`, prints the averaged loss and error, and stores\n",
    "    them in the notebook-level `log_dict` as `<name>_loss` / `<name>_error`.\n",
    "\n",
    "    Returns:\n",
    "        tuple: the activations and predictions arrays produced by `get_metrics`.\n",
    "    \"\"\"\n",
    "    stats, acts, preds = get_metrics(model, dataloader, torch.nn.CrossEntropyLoss(), 128, True)\n",
    "    loss_value = stats['loss']\n",
    "    error_value = stats['error']\n",
    "    print(f\"{name} -> Loss:{np.round(loss_value,3)}, Error:{error_value}\")\n",
    "    log_dict[f\"{name}_loss\"] = loss_value\n",
    "    log_dict[f\"{name}_error\"] = error_value\n",
    "    return acts, preds"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def predictions_distance(l1,l2,name):\n",
    "    \"\"\"Print the summed absolute difference of `l1` and `l2` and store it in `log_dict`.\n",
    "\n",
    "    The value is saved under the key `<name>_predictions`; nothing is returned.\n",
    "    \"\"\"\n",
    "    gap = np.abs(l1 - l2)\n",
    "    dist = np.sum(gap)\n",
    "    log_key = f\"{name}_predictions\"\n",
    "    print(f\"Predictions Distance {name} -> {dist}\")\n",
    "    log_dict[log_key] = dist"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def activations_distance(a1,a2,name):\n",
    "    \"\"\"Print the mean row-wise L1 distance between `a1` and `a2` and store it in `log_dict`.\n",
    "\n",
    "    The norm is taken along axis 1, so both inputs must be 2-D arrays of the\n",
    "    same shape. The value is saved under `<name>_activations`; nothing is returned.\n",
    "    \"\"\"\n",
    "    row_l1 = np.linalg.norm(a1 - a2, ord=1, axis=1)\n",
    "    dist = row_l1.mean()\n",
    "    log_key = f\"{name}_activations\"\n",
    "    print(f\"Activations Distance {name} -> {dist}\")\n",
    "    log_dict[log_key] = dist"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Membership Inference Attack"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.linear_model import LogisticRegression\n",
    "from sklearn.model_selection import cross_val_score\n",
    "from sklearn.model_selection import StratifiedShuffleSplit\n",
    "from sklearn.metrics import confusion_matrix\n",
    "import random\n",
    "\n",
    "def cm_score(estimator, X, y):\n",
    "    \"\"\"Scorer for `cross_val_score`: print error rates and return the accuracy.\n",
    "\n",
    "    Label 1 is treated as the positive class and label 0 as the negative class.\n",
    "\n",
    "    Args:\n",
    "        estimator: fitted object exposing `predict(X)`.\n",
    "        X: features handed to `estimator.predict`.\n",
    "        y: true binary labels (0 or 1) for `X`.\n",
    "\n",
    "    Returns:\n",
    "        Overall accuracy, (TP + TN) / number of samples.\n",
    "    \"\"\"\n",
    "    y_pred = estimator.predict(X)\n",
    "    # labels=[0, 1] pins the matrix to 2x2 even when a class is absent from this fold.\n",
    "    cnf_matrix = confusion_matrix(y, y_pred, labels=[0, 1])\n",
    "\n",
    "    # scikit-learn layout: rows are true labels, columns are predicted labels,\n",
    "    # i.e. [[TN, FP], [FN, TP]]. The original read TP and TN from swapped cells,\n",
    "    # which left the accuracy intact but made the printed FPR/FNR wrong.\n",
    "    TN, FP = cnf_matrix[0]\n",
    "    FN, TP = cnf_matrix[1]\n",
    "\n",
    "    def _ratio(numerator, denominator):\n",
    "        # An empty denominator means the rate is undefined for this fold.\n",
    "        return numerator / denominator if denominator else float('nan')\n",
    "\n",
    "    # Fall out or false positive rate\n",
    "    FPR = _ratio(FP, FP + TN)\n",
    "    # False negative rate\n",
    "    FNR = _ratio(FN, TP + FN)\n",
    "    # Overall accuracy\n",
    "    ACC = _ratio(TP + TN, TP + FP + FN + TN)\n",
    "    print (f\"FPR:{FPR:.2f}, FNR:{FNR:.2f}, FP{FP:.2f}, TN{TN:.2f}, TP{TP:.2f}, FN{FN:.2f}\")\n",
    "    return ACC\n",
    "\n",
    "\n",
    "def evaluate_attack_model(sample_loss,\n",
    "                          members,\n",
    "                          n_splits = 5,\n",
    "                          random_state = None):\n",
    "  \"\"\"Computes the cross-validation score of a membership inference attack.\n",
    "  Args:\n",
    "    sample_loss : array_like of shape (n,).\n",
    "      objective function evaluated on n samples.\n",
    "    members : array_like of shape (n,),\n",
    "      whether a sample was used for training.\n",
    "    n_splits: int\n",
    "      number of splits to use in the cross-validation.\n",
    "    random_state: int, RandomState instance or None, default=None\n",
    "      random state to use in cross-validation splitting.\n",
    "  Returns:\n",
    "    score : array_like of size (n_splits,)\n",
    "  \"\"\"\n",
    "\n",
    "  unique_members = np.unique(members)\n",
    "  if not np.all(unique_members == np.array([0, 1])):\n",
    "    raise ValueError(\"members should only have 0 and 1s\")\n",
    "\n",
    "  attack_model = LogisticRegression()\n",
    "  cv = StratifiedShuffleSplit(\n",
    "      n_splits=n_splits, random_state=random_state)\n",
    "  return cross_val_score(attack_model, sample_loss, members, cv=cv, scoring=cm_score)\n",
    "\n",
    "def _per_sample_losses(model, dataset, batch_size=128):\n",
    "    \"\"\"Return the un-reduced cross-entropy loss of `model` for every sample in `dataset`.\n",
    "\n",
    "    Reads the notebook-level `args` for the device, the loss type and the\n",
    "    dataset name. Samples are visited in dataset order (no shuffling).\n",
    "    \"\"\"\n",
    "    criterion = nn.CrossEntropyLoss(reduction='none')\n",
    "    mult = 0.5 if args.lossfn=='mse' else 1\n",
    "    loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False)\n",
    "    losses = []\n",
    "    model.eval()\n",
    "    # Only forward values are needed, so skip building the autograd graph.\n",
    "    with torch.no_grad():\n",
    "        for data, target in loader:\n",
    "            data, target = data.to(args.device), target.to(args.device)\n",
    "            if args.lossfn=='mse':\n",
    "                # .float() keeps the tensor on its current device instead of\n",
    "                # forcing a CUDA tensor type.\n",
    "                target = (2*target-1).float().unsqueeze(1)\n",
    "            if 'mnist' in args.dataset:\n",
    "                data = data.view(data.shape[0],-1)\n",
    "            loss = mult*criterion(model(data), target)\n",
    "            losses.extend(loss.cpu().numpy())\n",
    "    return losses\n",
    "\n",
    "\n",
    "def membership_inference_attack(model, t_loader, f_loader, seed):\n",
    "    \"\"\"Score a loss-threshold membership inference attack on the forget set.\n",
    "\n",
    "    Per-sample losses on the forget data (label 1) are compared with losses on\n",
    "    held-out data restricted to the forget classes (label 0). The larger group\n",
    "    is randomly subsampled so both have equal size, a histogram of both loss\n",
    "    distributions is shown, and `evaluate_attack_model` is run on the losses.\n",
    "\n",
    "    Args:\n",
    "        model: network to attack; it is switched to eval mode.\n",
    "        t_loader: loader over held-out data. Its dataset's `data` and `targets`\n",
    "            are filtered IN PLACE, so pass a copy if the loader is reused.\n",
    "        f_loader: loader over the forget data.\n",
    "        seed: seeds `numpy.random`, `random` and the cross-validation splits.\n",
    "\n",
    "    Returns:\n",
    "        The per-split scores returned by `evaluate_attack_model` (5 splits).\n",
    "    \"\"\"\n",
    "    # Keep only held-out samples whose class also occurs in the forget set.\n",
    "    fgt_cls = list(np.unique(f_loader.dataset.targets))\n",
    "    indices = [i in fgt_cls for i in t_loader.dataset.targets]\n",
    "    t_loader.dataset.data = t_loader.dataset.data[indices]\n",
    "    t_loader.dataset.targets = t_loader.dataset.targets[indices]\n",
    "\n",
    "    test_losses = _per_sample_losses(model, t_loader.dataset)\n",
    "    forget_losses = _per_sample_losses(model, f_loader.dataset)\n",
    "\n",
    "    np.random.seed(seed)\n",
    "    random.seed(seed)\n",
    "    # Balance the two groups so the attacker cannot win by predicting the majority.\n",
    "    if len(forget_losses) > len(test_losses):\n",
    "        forget_losses = list(random.sample(forget_losses, len(test_losses)))\n",
    "    elif len(test_losses) > len(forget_losses):\n",
    "        test_losses = list(random.sample(test_losses, len(forget_losses)))\n",
    "\n",
    "    fig, ax = plt.subplots()\n",
    "    sns.histplot(np.array(test_losses), kde=False, label='test-loss', ax=ax)\n",
    "    sns.histplot(np.array(forget_losses), kde=False, label='forget-loss', ax=ax)\n",
    "    plt.legend(prop={'size': 14})\n",
    "    plt.tick_params(labelsize=12)\n",
    "    plt.title(\"loss histograms\",size=18)\n",
    "    plt.xlabel('loss values',size=14)\n",
    "    plt.show()\n",
    "    print (np.max(test_losses), np.min(test_losses))\n",
    "    print (np.max(forget_losses), np.min(forget_losses))\n",
    "\n",
    "    test_labels = [0]*len(test_losses)\n",
    "    forget_labels = [1]*len(forget_losses)\n",
    "    features = np.array(test_losses + forget_losses).reshape(-1,1)\n",
    "    labels = np.array(test_labels + forget_labels).reshape(-1)\n",
    "    # Clip extreme losses so a few outliers do not dominate the attacker's fit.\n",
    "    features = np.clip(features, -100, 100)\n",
    "    score = evaluate_attack_model(features, labels, n_splits=5, random_state=seed)\n",
    "\n",
    "    return score\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Finetune and Fisher Helper"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from utils import *\n",
    "def get_metrics(model,dataloader,criterion,bs=128,samples_correctness=False,use_bn=False,delta_w=None,scrub_act=False):\n",
    "    \"\"\"Evaluate `model` on the dataset behind `dataloader` and return averaged metrics.\n",
    "\n",
    "    Reads the notebook-level `args` (device, lossfn, dataset), and additionally\n",
    "    `retain_loader` when `use_bn` is set and `num_classes` when `scrub_act` is set.\n",
    "\n",
    "    Args:\n",
    "        model: network to evaluate; left in eval mode.\n",
    "        dataloader: loader whose `.dataset` is evaluated, unshuffled.\n",
    "        criterion: loss applied to `(output, target)`.\n",
    "        bs: evaluation batch size.\n",
    "        samples_correctness: also return the softmax outputs of every sample and\n",
    "            the `get_error` value of every batch.\n",
    "        use_bn: first run 10 train-mode passes over the retain set so BatchNorm\n",
    "            statistics are refreshed before evaluating.\n",
    "        delta_w: per-weight variance vector multiplied with the squared output\n",
    "            gradients when `scrub_act` is set.\n",
    "        scrub_act: add Gaussian noise to the outputs, scaled by the squared\n",
    "            gradients of the first sample's outputs and `delta_w`.\n",
    "\n",
    "    Returns:\n",
    "        `metrics.avg`, or `(metrics.avg, activations, predictions)` when\n",
    "        `samples_correctness` is True.\n",
    "    \"\"\"\n",
    "    activations=[]\n",
    "    predictions=[]\n",
    "    if use_bn:\n",
    "        model.train()\n",
    "        # Use a separate name for the warm-up loader: the original rebound\n",
    "        # `dataloader` here, so the metrics below were silently computed on\n",
    "        # the retain set instead of the requested dataset.\n",
    "        bn_loader = torch.utils.data.DataLoader(retain_loader.dataset, batch_size=128, shuffle=True)\n",
    "        with torch.no_grad():\n",
    "            for _ in range(10):\n",
    "                for data, _target in bn_loader:\n",
    "                    model(data.to(args.device))\n",
    "    eval_loader = torch.utils.data.DataLoader(dataloader.dataset, batch_size=bs, shuffle=False)\n",
    "    model.eval()\n",
    "    metrics = utils.AverageMeter()\n",
    "    mult = 0.5 if args.lossfn=='mse' else 1\n",
    "    # Gradients are only required for the scrub_act noise term.\n",
    "    with torch.set_grad_enabled(scrub_act):\n",
    "        for data, target in eval_loader:\n",
    "            data, target = data.to(args.device), target.to(args.device)\n",
    "            if args.lossfn=='mse':\n",
    "                # .float() keeps the tensor on its current device.\n",
    "                target = (2*target-1).float().unsqueeze(1)\n",
    "            if 'mnist' in args.dataset:\n",
    "                data=data.view(data.shape[0],-1)\n",
    "            output = model(data)\n",
    "            if scrub_act:\n",
    "                # Gradient of each output of the first sample w.r.t. all weights.\n",
    "                # (A stray autograd call on the undefined `output_sf` /\n",
    "                # `model_scrubf`, whose result was never used, was removed.)\n",
    "                G = []\n",
    "                for cls in range(num_classes):\n",
    "                    grads = torch.autograd.grad(output[0,cls],model.parameters(),retain_graph=True)\n",
    "                    G.append(torch.cat([g.reshape(-1) for g in grads]))\n",
    "                G = torch.stack(G).pow(2)\n",
    "                delta_f = torch.matmul(G,delta_w)\n",
    "                output += delta_f.sqrt()*torch.empty_like(delta_f).normal_()\n",
    "\n",
    "            loss = mult*criterion(output, target)\n",
    "            error = get_error(output, target)\n",
    "            if samples_correctness:\n",
    "                probs = torch.nn.functional.softmax(output,dim=1).detach().cpu().numpy()\n",
    "                # Drop only a singleton class axis; a blanket squeeze() would\n",
    "                # also collapse the batch axis for a final batch of one sample.\n",
    "                if probs.shape[1] == 1:\n",
    "                    probs = probs.squeeze(1)\n",
    "                activations.extend(probs)\n",
    "                predictions.append(error)\n",
    "            metrics.update(n=data.size(0), loss=loss.item(), error=error)\n",
    "    if samples_correctness:\n",
    "        return metrics.avg,np.stack(activations),np.array(predictions)\n",
    "    else:\n",
    "        return metrics.avg"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def l2_penalty(model,model_init,weight_decay):\n",
    "    l2_loss = 0\n",
    "    for (k,p),(k_init,p_init) in zip(model.named_parameters(),model_init.named_parameters()):\n",
    "        if p.requires_grad:\n",
    "            l2_loss += (p-p_init).pow(2).sum()\n",
    "    l2_loss *= (weight_decay/2.)\n",
    "    return l2_loss\n",
    "\n",
    "def run_train_epoch(model: nn.Module, model_init, data_loader: torch.utils.data.DataLoader, \n",
    "                    loss_fn: nn.Module,\n",
    "                    optimizer: torch.optim.SGD, split: str, epoch: int, ignore_index=None,\n",
    "                    negative_gradient=False, negative_multiplier=-1, random_labels=False,\n",
    "                    quiet=False,delta_w=None,scrub_act=False):\n",
    "    \"\"\"Run one pass over `data_loader`, training unless `split == 'test'`.\n",
    "\n",
    "    When training, the optimised loss is `loss_fn` plus an L2 penalty pulling\n",
    "    the weights towards `model_init` (scaled by the notebook-level\n",
    "    `args.weight_decay`). The reported loss never includes that penalty.\n",
    "\n",
    "    Args:\n",
    "        model: network to evaluate or update.\n",
    "        model_init: reference model for the L2 penalty.\n",
    "        data_loader: yields `(input, target)` batches.\n",
    "        loss_fn: loss applied to `(output, target)`.\n",
    "        optimizer: stepped after every batch when training; unused for 'test'.\n",
    "        split: 'test' disables weight updates; any other value trains.\n",
    "        epoch: forwarded to `log_metrics`.\n",
    "        ignore_index, negative_gradient, negative_multiplier, random_labels:\n",
    "            unused; kept for call compatibility.\n",
    "        quiet: skip the `log_metrics` call.\n",
    "        delta_w: per-weight variance vector used by `scrub_act`.\n",
    "        scrub_act: in the 'test' split, add Gaussian noise to the outputs scaled\n",
    "            by the squared output gradients of the first sample and `delta_w`\n",
    "            (reads the notebook-level `num_classes`).\n",
    "\n",
    "    Returns:\n",
    "        `metrics.avg` of the `utils.AverageMeter` fed with loss and error.\n",
    "    \"\"\"\n",
    "    # NOTE(review): eval mode is used for training splits as well, which keeps\n",
    "    # BatchNorm statistics and Dropout frozen -- kept as in the original; confirm.\n",
    "    model.eval()\n",
    "    metrics = utils.AverageMeter()\n",
    "    training = split != 'test'\n",
    "    device = next(model.parameters()).device\n",
    "\n",
    "    # scrub_act differentiates the outputs w.r.t. the weights, so autograd has to\n",
    "    # stay enabled for it even in the 'test' split (it was disabled before,\n",
    "    # which made that branch fail).\n",
    "    with torch.set_grad_enabled(training or scrub_act):\n",
    "        for batch in tqdm(data_loader, leave=False):\n",
    "            inputs, target = [tensor.to(device) for tensor in batch]\n",
    "            output = model(inputs)\n",
    "            if not training and scrub_act:\n",
    "                # (A stray autograd call on the undefined `output_sf` /\n",
    "                # `model_scrubf`, whose result was never used, was removed.)\n",
    "                G = []\n",
    "                for cls in range(num_classes):\n",
    "                    grads = torch.autograd.grad(output[0,cls],model.parameters(),retain_graph=True)\n",
    "                    G.append(torch.cat([g.reshape(-1) for g in grads]))\n",
    "                G = torch.stack(G).pow(2)\n",
    "                delta_f = torch.matmul(G,delta_w)\n",
    "                output += delta_f.sqrt()*torch.empty_like(delta_f).normal_()\n",
    "            data_loss = loss_fn(output, target)\n",
    "            metrics.update(n=inputs.size(0), loss=data_loss.item(), error=get_error(output, target))\n",
    "\n",
    "            if training:\n",
    "                # The penalty only matters for the update, so it is skipped in 'test'.\n",
    "                loss = data_loss + l2_penalty(model,model_init,args.weight_decay)\n",
    "                model.zero_grad()\n",
    "                loss.backward()\n",
    "                optimizer.step()\n",
    "    if not quiet:\n",
    "        log_metrics(split, metrics, epoch)\n",
    "    return metrics.avg\n",
    "\n",
    "def run_neggrad_epoch(model: nn.Module, model_init, data_loader: torch.utils.data.DataLoader, \n",
    "                    forget_loader: torch.utils.data.DataLoader,\n",
    "                    alpha: float,\n",
    "                    loss_fn: nn.Module,\n",
    "                    optimizer: torch.optim.SGD, split: str, epoch: int, ignore_index=None,\n",
    "                    quiet=False):\n",
    "    \"\"\"Run one negative-gradient pass: descend on retain batches, ascend on forget batches.\n",
    "\n",
    "    The optimised loss is\n",
    "    `alpha * (retain_loss + l2_penalty) - (1 - alpha) * forget_loss`, where the\n",
    "    penalty pulls the weights towards `model_init` (scaled by the notebook-level\n",
    "    `args.weight_decay`). Reported metrics cover the retain batches only.\n",
    "\n",
    "    Args:\n",
    "        model: network to evaluate or update.\n",
    "        model_init: reference model for the L2 penalty.\n",
    "        data_loader: yields retain `(input, target)` batches; sets the epoch length.\n",
    "        forget_loader: yields forget batches, repeated as often as needed.\n",
    "        alpha: weight of the retain term; `1 - alpha` weights the forget term.\n",
    "        loss_fn: loss applied to `(output, target)`.\n",
    "        optimizer: stepped after every batch when training; unused for 'test'.\n",
    "        split: 'test' disables weight updates; any other value trains.\n",
    "        epoch: forwarded to `log_metrics`.\n",
    "        ignore_index: unused; kept for call compatibility.\n",
    "        quiet: skip the `log_metrics` call.\n",
    "\n",
    "    Returns:\n",
    "        `metrics.avg` of the `utils.AverageMeter` fed with retain loss and error.\n",
    "    \"\"\"\n",
    "    # NOTE(review): eval mode is used for training splits as well -- kept as in\n",
    "    # the original; confirm this is intended.\n",
    "    model.eval()\n",
    "    metrics = utils.AverageMeter()\n",
    "    training = split != 'test'\n",
    "    device = next(model.parameters()).device\n",
    "\n",
    "    with torch.set_grad_enabled(training):\n",
    "        # zip() stops with data_loader; cycle() replays the forget batches.\n",
    "        for batch_retain, batch_forget in tqdm(zip(data_loader,cycle(forget_loader)), leave=False):\n",
    "            input_r, target_r = [tensor.to(device) for tensor in batch_retain]\n",
    "            output_r = model(input_r)\n",
    "            # Computed once and reused for both the metrics and the update.\n",
    "            retain_loss = loss_fn(output_r, target_r)\n",
    "            metrics.update(n=input_r.size(0), loss=retain_loss.item(), error=get_error(output_r, target_r))\n",
    "            if training:\n",
    "                # The forget batch only feeds the update, so it is skipped in 'test'.\n",
    "                input_f, target_f = [tensor.to(device) for tensor in batch_forget]\n",
    "                output_f = model(input_f)\n",
    "                loss = alpha*(retain_loss + l2_penalty(model,model_init,args.weight_decay)) - (1-alpha)*loss_fn(output_f, target_f)\n",
    "                model.zero_grad()\n",
    "                loss.backward()\n",
    "                optimizer.step()\n",
    "    if not quiet:\n",
    "        log_metrics(split, metrics, epoch)\n",
    "    return metrics.avg"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def finetune(model: nn.Module, data_loader: torch.utils.data.DataLoader, lr=0.01, epochs=10, quiet=False):\n",
    "    \"\"\"Run `train_vanilla` on `model` for `epochs` epochs with SGD and cross-entropy.\n",
    "\n",
    "    Args:\n",
    "        model: network whose parameters are handed to the optimizer.\n",
    "        data_loader: loader passed to `train_vanilla` every epoch.\n",
    "        lr: SGD learning rate (no weight decay).\n",
    "        epochs: number of `train_vanilla` calls.\n",
    "        quiet: unused; kept for call compatibility.\n",
    "\n",
    "    Reads the notebook-level `args`, which is forwarded to `train_vanilla`.\n",
    "    \"\"\"\n",
    "    loss_fn = nn.CrossEntropyLoss()\n",
    "    optimizer = torch.optim.SGD(model.parameters(), lr=lr, weight_decay=0.0)\n",
    "    # The original deep-copied the model into an unused `model_init`; dropped so\n",
    "    # the weights are not duplicated for nothing.\n",
    "    for epoch in range(epochs):\n",
    "        train_vanilla(epoch, data_loader, model, loss_fn, optimizer, args)\n",
    "\n",
    "def negative_grad(model: nn.Module, data_loader: torch.utils.data.DataLoader, forget_loader: torch.utils.data.DataLoader, alpha: float, lr=0.01, epochs=10, quiet=False, args=None):\n",
    "    \"\"\"Run `train_negrad` on `model` for `epochs` epochs with SGD and cross-entropy.\n",
    "\n",
    "    Args:\n",
    "        model: network whose parameters are handed to the optimizer.\n",
    "        data_loader: retain-set loader passed to `train_negrad`.\n",
    "        forget_loader: forget-set loader passed to `train_negrad`.\n",
    "        alpha: weighting forwarded to `train_negrad`.\n",
    "        lr: SGD learning rate (no weight decay).\n",
    "        epochs: number of `train_negrad` calls.\n",
    "        quiet: unused; kept for call compatibility.\n",
    "        args: options object forwarded to `train_negrad`.\n",
    "    \"\"\"\n",
    "    loss_fn = nn.CrossEntropyLoss()\n",
    "    optimizer = torch.optim.SGD(model.parameters(), lr=lr, weight_decay=0.0)\n",
    "    # The original deep-copied the model into an unused `model_init`; dropped so\n",
    "    # the weights are not duplicated for nothing.\n",
    "    for epoch in range(epochs):\n",
    "        train_negrad(epoch, data_loader, forget_loader, model, loss_fn, optimizer,  alpha, args)\n",
    "\n",
    "def fk_fientune(model: nn.Module, data_loader: torch.utils.data.DataLoader, args, lr=0.01, epochs=10, quiet=False):\n",
    "    \"\"\"Run `train_vanilla` for `epochs` epochs, adjusting the learning rate before each one.\n",
    "\n",
    "    Args:\n",
    "        model: network whose parameters are handed to the optimizer.\n",
    "        data_loader: loader passed to `train_vanilla` every epoch.\n",
    "        args: options object forwarded to `sgda_adjust_learning_rate` and\n",
    "            `train_vanilla`.\n",
    "        lr: initial SGD learning rate (no weight decay).\n",
    "        epochs: number of epochs.\n",
    "        quiet: unused; kept for call compatibility.\n",
    "    \"\"\"\n",
    "    loss_fn = nn.CrossEntropyLoss()\n",
    "    optimizer = torch.optim.SGD(model.parameters(), lr=lr, weight_decay=0.0)\n",
    "    # The original deep-copied the model into an unused `model_init`; dropped so\n",
    "    # the weights are not duplicated for nothing.\n",
    "    for epoch in range(epochs):\n",
    "        sgda_adjust_learning_rate(epoch, args, optimizer)\n",
    "        train_vanilla(epoch, data_loader, model, loss_fn, optimizer, args)\n",
    "def test(model, data_loader):\n",
    "    \"\"\"Evaluate `model` on `data_loader` with cross-entropy and return the averaged metrics.\n",
    "\n",
    "    Delegates to `run_train_epoch` with `split='test'`, so no weights are\n",
    "    updated and nothing is logged.\n",
    "    \"\"\"\n",
    "    loss_fn = nn.CrossEntropyLoss()\n",
    "    # run_train_epoch wants a reference model for its L2 term. Passing the\n",
    "    # model itself makes that term exactly zero, as the former deep copy did,\n",
    "    # without duplicating every weight on each evaluation.\n",
    "    return run_train_epoch(model, model, data_loader, loss_fn, optimizer=None, split='test', epoch=0, ignore_index=None, quiet=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def readout_retrain(model, data_loader, test_loader, lr=0.1, epochs=500, threshold=0.01, quiet=True):\n",
    "    \"\"\"Fine-tune a copy of `model` until its loss on `test_loader` reaches `threshold`.\n",
    "\n",
    "    Each epoch first evaluates on `test_loader`; if the loss is still above\n",
    "    `threshold`, the copy is trained on 500 samples drawn with replacement from\n",
    "    `data_loader.dataset`. The torch RNG is seeded from the notebook-level `seed`.\n",
    "\n",
    "    Returns:\n",
    "        tuple: the index of the last epoch started and the list of per-epoch\n",
    "        test metrics.\n",
    "    \"\"\"\n",
    "    torch.manual_seed(seed)\n",
    "    model = copy.deepcopy(model)\n",
    "    loss_fn = nn.CrossEntropyLoss()\n",
    "    optimizer = torch.optim.SGD(model.parameters(), lr=lr, weight_decay=0.0)\n",
    "    small_loader = torch.utils.data.DataLoader(\n",
    "        data_loader.dataset,\n",
    "        batch_size=data_loader.batch_size,\n",
    "        sampler=torch.utils.data.RandomSampler(data_loader.dataset, replacement=True, num_samples=500),\n",
    "        num_workers=data_loader.num_workers,\n",
    "    )\n",
    "    history = []\n",
    "    # Snapshot of the starting weights, used as the anchor of the L2 penalty.\n",
    "    model_init = copy.deepcopy(model)\n",
    "    for epoch in range(epochs):\n",
    "        test_stats = run_train_epoch(model, model_init, test_loader, loss_fn, optimizer, split='test', epoch=epoch, ignore_index=None, quiet=quiet)\n",
    "        history.append(test_stats)\n",
    "        if test_stats['loss'] <= threshold:\n",
    "            break\n",
    "        run_train_epoch(model, model_init, small_loader, loss_fn, optimizer, split='train', epoch=epoch, ignore_index=None, quiet=quiet)\n",
    "    return epoch, history\n",
    "\n",
    "def extract_retrain_time(metrics, threshold=0.1):\n",
    "    losses = np.array([m['loss'] for m in metrics])\n",
    "    return np.argmax(losses < threshold)\n",
    "\n",
    "def all_readouts(model,thresh=0.1,name='method', seed=0):\n",
    "    \"\"\"Print and return the unlearning readouts of `model`.\n",
    "\n",
    "    Runs the membership inference attack and measures the error (in percent) on\n",
    "    the notebook-level test, forget, retain and validation loaders.\n",
    "\n",
    "    Args:\n",
    "        model: network to evaluate.\n",
    "        thresh: currently unused; belonged to the disabled fine-tune-time readout.\n",
    "        name: label printed in front of the results.\n",
    "        seed: forwarded to `membership_inference_attack`.\n",
    "\n",
    "    Returns:\n",
    "        dict with keys test_error, forget_error, retain_error, val_error,\n",
    "        retrain_time and MIA (mean attack score).\n",
    "    \"\"\"\n",
    "    def error_pct(loader):\n",
    "        # Error of `model` on `loader`, as a percentage.\n",
    "        return test(model, loader)['error']*100\n",
    "\n",
    "    # The attack filters its test loader in place, hence the copy.\n",
    "    MIA = membership_inference_attack(model, copy.deepcopy(test_loader_full), forget_loader, seed)\n",
    "    # The fine-tune-time readout (readout_retrain) is disabled; report a constant.\n",
    "    retrain_time = 0\n",
    "    test_error = error_pct(test_loader_full)\n",
    "    forget_error = error_pct(forget_loader)\n",
    "    retain_error = error_pct(retain_loader)\n",
    "    val_error = error_pct(valid_loader_full)\n",
    "    mia_mean = np.mean(MIA)\n",
    "\n",
    "    print(f\"{name} ->\"\n",
    "          f\"\\tFull test error: {test_error:.2f}\"\n",
    "          f\"\\tForget error: {forget_error:.2f}\\tRetain error: {retain_error:.2f}\\tValid error: {val_error:.2f}\"\n",
    "          f\"\\tFine-tune time: {retrain_time+1} steps\\tMIA: {mia_mean:.2f}±{np.std(MIA):0.1f}\")\n",
    "\n",
    "    return dict(\n",
    "        test_error=test_error,\n",
    "        forget_error=forget_error,\n",
    "        retain_error=retain_error,\n",
    "        val_error=val_error,\n",
    "        retrain_time=retrain_time+1,\n",
    "        MIA=mia_mean,\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def badt(gteacher, bteacher, student):\n",
    "    \"\"\"Run bad-teacher unlearning and return the trained copy of `student`.\n",
    "\n",
    "    Deep copies of the three models are made and only the student copy is\n",
    "    optimised. The run's hyper-parameters are written onto the notebook-level\n",
    "    `args` as `bt_*` attributes. Errors on the retain, forget and validation\n",
    "    loaders are recorded before every epoch and once after training, then plotted.\n",
    "\n",
    "    Args:\n",
    "        gteacher: first teacher model, appended to the module list after the student.\n",
    "        bteacher: second teacher model, appended after `gteacher`.\n",
    "        student: model whose copy is trained.\n",
    "\n",
    "    Returns:\n",
    "        The trained student copy.\n",
    "    \"\"\"\n",
    "    bt_settings = dict(\n",
    "        bt_optim='adam',\n",
    "        bt_alpha=1,\n",
    "        bt_beta=1,\n",
    "        bt_kd_T=4,\n",
    "        bt_distill='kd',\n",
    "        bt_epochs=5,\n",
    "        bt_learning_rate=0.00005,\n",
    "        bt_lr_decay_epochs=[10,10,10],\n",
    "        bt_lr_decay_rate=0.1,\n",
    "        bt_weight_decay=5e-4,\n",
    "        bt_momentum=0.9,\n",
    "    )\n",
    "    for key, value in bt_settings.items():\n",
    "        setattr(args, key, value)\n",
    "\n",
    "    model_gt = copy.deepcopy(gteacher)\n",
    "    model_bt = copy.deepcopy(bteacher)\n",
    "    model_s = copy.deepcopy(student)\n",
    "\n",
    "    module_list = nn.ModuleList([model_s])\n",
    "    trainable_list = nn.ModuleList([model_s])\n",
    "\n",
    "    criterion_cls = nn.CrossEntropyLoss()\n",
    "    criterion_list = nn.ModuleList([\n",
    "        criterion_cls,               # classification loss\n",
    "        DistillKL(args.bt_kd_T),     # KL divergence loss, original knowledge distillation\n",
    "        DistillKL(args.bt_kd_T),     # other knowledge distillation loss\n",
    "    ])\n",
    "\n",
    "    # optimizer\n",
    "    params = trainable_list.parameters()\n",
    "    base_lr = args.bt_learning_rate\n",
    "    decay = args.bt_weight_decay\n",
    "    if args.bt_optim == \"sgd\":\n",
    "        optimizer = optim.SGD(params, lr=base_lr, momentum=args.bt_momentum, weight_decay=decay)\n",
    "    elif args.bt_optim == \"adam\":\n",
    "        optimizer = optim.Adam(params, lr=base_lr, weight_decay=decay)\n",
    "    elif args.bt_optim == \"rmsp\":\n",
    "        optimizer = optim.RMSprop(params, lr=base_lr, momentum=args.bt_momentum, weight_decay=decay)\n",
    "\n",
    "    # Teachers join the module list only after the optimizer is built, so just\n",
    "    # the student's parameters are optimised.\n",
    "    module_list.extend([model_gt, model_bt])\n",
    "\n",
    "    if torch.cuda.is_available():\n",
    "        module_list.cuda()\n",
    "        criterion_list.cuda()\n",
    "        import torch.backends.cudnn as cudnn\n",
    "        cudnn.benchmark = True\n",
    "\n",
    "    acc_rs, acc_fs, acc_vs = [], [], []\n",
    "\n",
    "    def record_errors():\n",
    "        # Append the student's current error (100 - accuracy) on each loader.\n",
    "        tracked = ((retain_loader, acc_rs), (forget_loader, acc_fs), (valid_loader_full, acc_vs))\n",
    "        for loader, curve in tracked:\n",
    "            acc, _acc5, _loss = validate(loader, model_s, criterion_cls, args, True)\n",
    "            curve.append(100-acc.item())\n",
    "\n",
    "    print(\"==> Bad Teacher Unlearning ...\")\n",
    "    for epoch in range(1, args.bt_epochs + 1):\n",
    "        record_errors()\n",
    "        sgda_adjust_learning_rate(epoch, args, optimizer)\n",
    "        # NOTE(review): `train_bad_teacher` is not among the names imported\n",
    "        # explicitly in the first cell; presumably it comes from a wildcard\n",
    "        # import -- confirm.\n",
    "        train_acc, loss = train_bad_teacher(epoch, retain_loader, forget_loader, module_list, criterion_list, optimizer, args)\n",
    "        print (\"loss: {:.2f}\\t train_acc: {}\".format(loss, train_acc))\n",
    "\n",
    "    record_errors()\n",
    "\n",
    "    from matplotlib import pyplot as plt\n",
    "    steps = list(range(len(acc_rs)))\n",
    "    curves = (\n",
    "        (acc_rs, '*', u'#1f77b4', 'retain-set'),\n",
    "        (acc_fs, 'o', u'#ff7f0e', 'forget-set'),\n",
    "        (acc_vs, '^', u'#2ca02c', 'validation-set'),\n",
    "    )\n",
    "    for values, marker, color, label in curves:\n",
    "        plt.plot(steps, values, marker=marker, color=color, alpha=1, label=label)\n",
    "    plt.legend(prop={'size': 14})\n",
    "    plt.tick_params(labelsize=12)\n",
    "    plt.xlabel('epoch',size=14)\n",
    "    plt.ylabel('error',size=14)\n",
    "    plt.grid()\n",
    "    plt.show()\n",
    "\n",
    "    return model_s"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def scrub(teacher, student):\n",
    "    args.optim = 'adam'\n",
    "    args.gamma = 1\n",
    "    args.alpha = 0.5\n",
    "    args.beta = 0\n",
    "    args.smoothing = 0.5\n",
    "    args.msteps = 5\n",
    "    args.clip = 0.2\n",
    "    args.sstart = 10\n",
    "    args.kd_T = 4\n",
    "    args.distill = 'kd'\n",
    "\n",
    "    args.sgda_epochs = 5\n",
    "    args.sgda_learning_rate = 0.0005\n",
    "    args.lr_decay_epochs = [5,5,9]\n",
    "    args.lr_decay_rate = 0.1\n",
    "    args.sgda_weight_decay = 5e-4\n",
    "    args.sgda_momentum = 0.9\n",
    "\n",
    "    model_t = copy.deepcopy(teacher)\n",
    "    model_s = copy.deepcopy(student)\n",
    "\n",
    "    #this is from https://github.com/ojus1/SmoothedGradientDescentAscent/blob/main/SGDA.py\n",
    "    #For SGDA smoothing\n",
    "    beta = 0.1\n",
    "    def avg_fn(averaged_model_parameter, model_parameter, num_averaged): return (\n",
    "        1 - beta) * averaged_model_parameter + beta * model_parameter\n",
    "    swa_model = torch.optim.swa_utils.AveragedModel(\n",
    "        model_s, avg_fn=avg_fn)\n",
    "\n",
    "    module_list = nn.ModuleList([])\n",
    "    module_list.append(model_s)\n",
    "    trainable_list = nn.ModuleList([])\n",
    "    trainable_list.append(model_s)\n",
    "\n",
    "    criterion_cls = nn.CrossEntropyLoss()\n",
    "    criterion_div = DistillKL(args.kd_T)\n",
    "    criterion_kd = DistillKL(args.kd_T)\n",
    "\n",
    "\n",
    "    criterion_list = nn.ModuleList([])\n",
    "    criterion_list.append(criterion_cls)    # classification loss\n",
    "    criterion_list.append(criterion_div)    # KL divergence loss, original knowledge distillation\n",
    "    criterion_list.append(criterion_kd)     # other knowledge distillation loss\n",
    "\n",
    "    # optimizer\n",
    "    if args.optim == \"sgd\":\n",
    "        optimizer = optim.SGD(trainable_list.parameters(),\n",
    "                              lr=args.sgda_learning_rate,\n",
    "                              momentum=args.sgda_momentum,\n",
    "                              weight_decay=args.sgda_weight_decay)\n",
    "    elif args.optim == \"adam\": \n",
    "        optimizer = optim.Adam(trainable_list.parameters(),\n",
    "                              lr=args.sgda_learning_rate,\n",
    "                              weight_decay=args.sgda_weight_decay)\n",
    "    elif args.optim == \"rmsp\":\n",
    "        optimizer = optim.RMSprop(trainable_list.parameters(),\n",
    "                              lr=args.sgda_learning_rate,\n",
    "                              momentum=args.sgda_momentum,\n",
    "                              weight_decay=args.sgda_weight_decay)\n",
    "\n",
    "    module_list.append(model_t)\n",
    "\n",
    "    if torch.cuda.is_available():\n",
    "        module_list.cuda()\n",
    "        criterion_list.cuda()\n",
    "        import torch.backends.cudnn as cudnn\n",
    "        cudnn.benchmark = True\n",
    "        swa_model.cuda()\n",
    "\n",
    "\n",
    "    t1 = time.time()\n",
    "    acc_rs = []\n",
    "    acc_fs = []\n",
    "    acc_vs = []\n",
    "    acc_fvs = []\n",
    "    \n",
    "    \n",
    "    forget_validation_loader = copy.deepcopy(valid_loader_full)\n",
    "    fgt_cls = list(np.unique(forget_loader.dataset.targets))\n",
    "    indices = [i in fgt_cls for i in forget_validation_loader.dataset.targets]\n",
    "    forget_validation_loader.dataset.data = forget_validation_loader.dataset.data[indices]\n",
    "    forget_validation_loader.dataset.targets = forget_validation_loader.dataset.targets[indices]\n",
    "    \n",
    "    scrub_name = \"checkpoints/scrub_{}_{}_seed{}_step\".format(args.model, args.dataset, args.seed)\n",
    "    for epoch in range(1, args.sgda_epochs + 1):\n",
    "\n",
    "        lr = sgda_adjust_learning_rate(epoch, args, optimizer)\n",
    "\n",
    "        acc_r, acc5_r, loss_r = validate(retain_loader, model_s, criterion_cls, args, True)\n",
    "        acc_f, acc5_f, loss_f = validate(forget_loader, model_s, criterion_cls, args, True)\n",
    "        acc_v, acc5_v, loss_v = validate(valid_loader_full, model_s, criterion_cls, args, True)\n",
    "        acc_fv, acc5_fv, loss_fv = validate(forget_validation_loader, model_s, criterion_cls, args, True)\n",
    "        acc_rs.append(100-acc_r.item())\n",
    "        acc_fs.append(100-acc_f.item())\n",
    "        acc_vs.append(100-acc_v.item())\n",
    "        acc_fvs.append(100-acc_fv.item())\n",
    "\n",
    "        maximize_loss = 0\n",
    "        if epoch <= args.msteps:\n",
    "            maximize_loss = train_distill(epoch, forget_loader, module_list, swa_model, criterion_list, optimizer, args, \"maximize\")\n",
    "        train_acc, train_loss = train_distill(epoch, retain_loader, module_list, swa_model, criterion_list, optimizer, args, \"minimize\",)\n",
    "        if epoch >= args.sstart:\n",
    "            swa_model.update_parameters(model_s)\n",
    "        \n",
    "        torch.save(model_s.state_dict(), scrub_name+str(epoch-1)+\".pt\")\n",
    "\n",
    "        print (\"maximize loss: {:.2f}\\t minimize loss: {:.2f}\\t train_acc: {}\".format(maximize_loss, train_loss, train_acc))\n",
    "    t2 = time.time()\n",
    "    print (t2-t1)\n",
    "\n",
    "    acc_r, acc5_r, loss_r = validate(retain_loader, model_s, criterion_cls, args, True)\n",
    "    acc_f, acc5_f, loss_f = validate(forget_loader, model_s, criterion_cls, args, True)\n",
    "    acc_v, acc5_v, loss_v = validate(valid_loader_full, model_s, criterion_cls, args, True)\n",
    "    acc_fv, acc5_fv, loss_fv = validate(forget_validation_loader, model_s, criterion_cls, args, True)\n",
    "    acc_rs.append(100-acc_r.item())\n",
    "    acc_fs.append(100-acc_f.item())\n",
    "    acc_vs.append(100-acc_v.item())\n",
    "    acc_fvs.append(100-acc_fv.item())\n",
    "\n",
    "    from matplotlib import pyplot as plt\n",
    "    indices = list(range(0,len(acc_rs)))\n",
    "    plt.plot(indices, acc_rs, marker='*', color=u'#1f77b4', alpha=1, label='retain-set')\n",
    "    plt.plot(indices, acc_fs, marker='o', color=u'#ff7f0e', alpha=1, label='forget-set')\n",
    "    plt.plot(indices, acc_vs, marker='^', color=u'#2ca02c',alpha=1, label='validation-set')\n",
    "    plt.plot(indices, acc_fvs, marker='.', color='red',alpha=1, label='forget-validation-set')\n",
    "    plt.legend(prop={'size': 14})\n",
    "    plt.tick_params(labelsize=12)\n",
    "    plt.xlabel('epoch',size=14)\n",
    "    plt.ylabel('error',size=14)\n",
    "    plt.grid()\n",
    "    plt.show()\n",
    "    \n",
    "    \n",
    "    try:\n",
    "        selected_idx, _ = min(enumerate(acc_fs), key=lambda x: abs(x[1]-acc_fvs[-1]))\n",
    "    except:\n",
    "        selected_idx = len(acc_fs) -1\n",
    "    print (\"the selected index is {}\".format(selected_idx))\n",
    "    selected_model = \"checkpoints/scrub_{}_{}_seed{}_step{}.pt\".format(args.model, args.dataset, args.seed, int(selected_idx))\n",
    "    model_s_final = copy.deepcopy(model_s)\n",
    "    model_s.load_state_dict(torch.load(selected_model))\n",
    "    \n",
    "    \n",
    "    return model_s, model_s_final\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def replace_loader_dataset(data_loader, dataset, batch_size=128, seed=1, shuffle=True):\n",
    "    \"\"\"Re-seed and build a new DataLoader over `dataset`.\n",
    "\n",
    "    Args:\n",
    "        data_loader: Unused; kept in the signature so existing call sites keep working.\n",
    "        dataset: Dataset the returned loader iterates over.\n",
    "        batch_size: Batch size of the returned loader.\n",
    "        seed: Value handed to `manual_seed` before the loader is created.\n",
    "        shuffle: Whether the returned loader shuffles its samples.\n",
    "\n",
    "    Returns:\n",
    "        A `torch.utils.data.DataLoader` built with `num_workers=0` and `pin_memory=True`.\n",
    "    \"\"\"\n",
    "    # NOTE(review): `manual_seed` is defined outside this cell; presumably it seeds the RNGs that drive shuffling -- confirm.\n",
    "    manual_seed(seed)\n",
    "    # num_workers=0 loads data in the main process, so no per-worker seeding hook is needed.\n",
    "    return torch.utils.data.DataLoader(dataset, batch_size=batch_size, num_workers=0, pin_memory=True, shuffle=shuffle)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def cfk_unlearn(model):\n",
    "    \"\"\"CFK baseline: fine-tune a mostly frozen copy of `model` on the retain set.\n",
    "\n",
    "    Reads the notebook-level `args`, `seed`, `train_loader_full` and `retain_dataset`,\n",
    "    and overwrites `args.lr_decay_epochs`, `args.cfk_lr`, `args.cfk_epochs` and `args.cfk_bs`.\n",
    "\n",
    "    Args:\n",
    "        model: Network to start from; it is deep-copied and left untouched.\n",
    "\n",
    "    Returns:\n",
    "        The copy after `fk_fientune` has been run on it.\n",
    "\n",
    "    Raises:\n",
    "        NotImplementedError: if `args.model` is neither 'allcnn' nor \"resnet\".\n",
    "    \"\"\"\n",
    "    # Hyper-parameters live on the shared `args` namespace so the fine-tuning helper can read them.\n",
    "    args.lr_decay_epochs = [10,15,20]\n",
    "    args.cfk_lr = 0.01\n",
    "    args.cfk_epochs = 10\n",
    "    args.cfk_bs = 64\n",
    "    retain_batches = replace_loader_dataset(train_loader_full,retain_dataset, seed=seed, batch_size=args.cfk_bs, shuffle=True)\n",
    "\n",
    "    model_cfk = copy.deepcopy(model)\n",
    "\n",
    "    # Start from a fully frozen network ...\n",
    "    for weight in model_cfk.parameters():\n",
    "        weight.requires_grad_(False)\n",
    "\n",
    "    # ... then pick the sub-modules that stay trainable for this architecture.\n",
    "    if args.model == 'allcnn':\n",
    "        unfrozen_blocks = [model_cfk.features[idx] for idx in [9]]\n",
    "    elif args.model == \"resnet\":\n",
    "        unfrozen_blocks = [model_cfk.layer4]\n",
    "    else:\n",
    "        raise NotImplementedError\n",
    "\n",
    "    for block in unfrozen_blocks:\n",
    "        for weight in block.parameters():\n",
    "            weight.requires_grad_(True)\n",
    "\n",
    "    fk_fientune(model_cfk, retain_batches, args=args, epochs=args.cfk_epochs, quiet=True, lr=args.cfk_lr)\n",
    "    return model_cfk"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def euk_unlearn(model, model_initial, layers=(9,)):\n",
    "    \"\"\"EUK baseline: reset the last block(s) of a copy of `model` and fine-tune on the retain set.\n",
    "\n",
    "    Reads the notebook-level `args`, `seed`, `training_epochs`, `train_loader_full` and\n",
    "    `retain_dataset`, and overwrites `args.lr_decay_epochs`, `args.euk_lr`,\n",
    "    `args.euk_epochs` and `args.euk_bs`.\n",
    "\n",
    "    Args:\n",
    "        model: Network to start from; it is deep-copied and left untouched.\n",
    "        model_initial: Network whose parameters are copied into the layers being reset.\n",
    "        layers: Indices into `features` that are reset and unfrozen when\n",
    "            `args.model == 'allcnn'`; ignored for resnet. The default matches the\n",
    "            block index used by `cfk_unlearn`.\n",
    "\n",
    "    Returns:\n",
    "        The copy after `fk_fientune` has been run on it.\n",
    "\n",
    "    Raises:\n",
    "        NotImplementedError: if `args.model` is neither 'allcnn' nor \"resnet\".\n",
    "    \"\"\"\n",
    "    args.lr_decay_epochs = [10,15,20]\n",
    "    args.euk_lr = 0.01\n",
    "    args.euk_epochs = training_epochs\n",
    "    args.euk_bs = 64\n",
    "    r_loader = replace_loader_dataset(train_loader_full,retain_dataset, seed=seed, batch_size=args.euk_bs, shuffle=True)\n",
    "    model_euk = copy.deepcopy(model)\n",
    "\n",
    "    for param in model_euk.parameters():\n",
    "        param.requires_grad_(False)\n",
    "\n",
    "    # A layer may legitimately lack the parameter (missing attribute, attribute set to None,\n",
    "    # or no such sub-layer); only those cases are reported and skipped.\n",
    "    absent_param_errors = (AttributeError, IndexError, TypeError)\n",
    "\n",
    "    if args.model == 'allcnn':\n",
    "        with torch.no_grad():\n",
    "            for k in layers:\n",
    "                for i in range(0,3):\n",
    "                    for attr, label in (('weight', 'weights'), ('bias', 'bias')):\n",
    "                        try:\n",
    "                            getattr(model_euk.features[k][i], attr).copy_(getattr(model_initial.features[k][i], attr))\n",
    "                        except absent_param_errors:\n",
    "                            print (\"block {}, layer {} does not have {}\".format(k,i,label))\n",
    "            # NOTE(review): the classifier is reset here but is not unfrozen below, so it stays\n",
    "            # frozen unless fk_fientune changes requires_grad itself -- confirm this is intended.\n",
    "            model_euk.classifier[0].weight.copy_(model_initial.classifier[0].weight)\n",
    "            model_euk.classifier[0].bias.copy_(model_initial.classifier[0].bias)\n",
    "\n",
    "        for k in layers:\n",
    "            for param in model_euk.features[k].parameters():\n",
    "                param.requires_grad_(True)\n",
    "\n",
    "    elif args.model == \"resnet\":\n",
    "        with torch.no_grad():\n",
    "            # Reset bn1/conv1/bn2/conv2 (weight, then bias) of the first two blocks of layer4.\n",
    "            for i in range(0,2):\n",
    "                for sub in ('bn1', 'conv1', 'bn2', 'conv2'):\n",
    "                    for attr in ('weight', 'bias'):\n",
    "                        try:\n",
    "                            dst = getattr(getattr(model_euk.layer4[i], sub), attr)\n",
    "                            src = getattr(getattr(model_initial.layer4[i], sub), attr)\n",
    "                            dst.copy_(src)\n",
    "                        except absent_param_errors:\n",
    "                            print (\"block 4, layer {} does not have {}\".format(i, attr))\n",
    "\n",
    "            model_euk.layer4[0].shortcut[0].weight.copy_(model_initial.layer4[0].shortcut[0].weight)\n",
    "\n",
    "        for param in model_euk.layer4.parameters():\n",
    "            param.requires_grad_(True)\n",
    "\n",
    "    else:\n",
    "        raise NotImplementedError\n",
    "\n",
    "\n",
    "    fk_fientune(model_euk, r_loader, epochs=args.euk_epochs, quiet=True, lr=args.euk_lr, args=args)\n",
    "    return model_euk"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Experiment driver: for every seed, run train.py twice via %run, rebuild the forget/retain\n",
    "# loaders, load the saved checkpoints and collect readouts for the original and retrained models.\n",
    "seeds = [1]\n",
    "# NOTE(review): `chkpt` is not referenced anywhere else in this cell.\n",
    "chkpt = \"checkpoints/cifar100_resnet_1_0_forget_None_lr_0_1_bs_128_ls_ce_wd_0_0005_seed_1_30.pt\"\n",
    "forget_class = '0'\n",
    "forget_num = 100\n",
    "dataset = 'cifar100'\n",
    "modelname = 'resnet50'\n",
    "dataroot = '../image_data/cifar100'\n",
    "filters = 1.0\n",
    "\n",
    "\n",
    "# One readouts dict is appended per seed.\n",
    "errors = []\n",
    "for s in seeds:\n",
    "\n",
    "    # First run: no forget arguments, 51 epochs. The code below uses `model`, `args` and\n",
    "    # `train_loader`, none of which are defined in this cell -- presumably they are left in the\n",
    "    # notebook namespace by train.py (TODO confirm).\n",
    "    %run train.py --dataset $dataset --model $modelname --dataroot=$dataroot --filters $filters --lr 0.005 \\\n",
    "    --disable-bn --weight-decay 0.0005 --batch-size 128 --epochs 51 --seed $s\n",
    "\n",
    "    # Second run: same flags plus the forget class / count, with --epochs 0.\n",
    "    %run train.py --dataset $dataset --model $modelname --dataroot=$dataroot --filters $filters --lr 0.005 \\\n",
    "    --disable-bn --weight-decay 0.0005 --batch-size 128 --epochs 0 \\\n",
    "    --forget-class $forget_class --num-to-forget $forget_num --seed $s\n",
    "\n",
    "\n",
    "    log_dict={}\n",
    "    # Used as the checkpoint filename suffix below and read as a global by euk_unlearn.\n",
    "    training_epochs=50\n",
    "    model0 = copy.deepcopy(model)\n",
    "    model_initial = copy.deepcopy(model)\n",
    "\n",
    "    # Note: `filters` and `dataset` are rebound from `args` here, shadowing the values set above.\n",
    "    arch = args.model \n",
    "    filters=args.filters\n",
    "    arch_filters = arch +'_'+ str(filters).replace('.','_')\n",
    "    augment = False\n",
    "    dataset = args.dataset\n",
    "    class_to_forget = args.forget_class\n",
    "    init_checkpoint = f\"checkpoints/{args.name}_init.pt\"\n",
    "    num_classes=args.num_classes\n",
    "    num_to_forget = args.num_to_forget\n",
    "    num_total = len(train_loader.dataset)\n",
    "    num_to_retain = num_total - forget_num\n",
    "    seed = args.seed\n",
    "    unfreeze_start = None\n",
    "\n",
    "    # Filename fragments used to rebuild the checkpoint paths.\n",
    "    learningrate=f\"lr_{str(args.lr).replace('.','_')}\"\n",
    "    batch_size=f\"_bs_{str(args.batch_size)}\"\n",
    "    lossfn=f\"_ls_{args.lossfn}\"\n",
    "    wd=f\"_wd_{str(args.weight_decay).replace('.','_')}\"\n",
    "    seed_name=f\"_seed_{args.seed}_\"\n",
    "\n",
    "    num_tag = '' if num_to_forget is None else f'_num_{num_to_forget}'\n",
    "    unfreeze_tag = '_' if unfreeze_start is None else f'_unfreeze_from_{unfreeze_start}_'\n",
    "    augment_tag = '' if not augment else f'augment_'\n",
    "\n",
    "    # m_name: checkpoint with forget_None in its name; m0_name: checkpoint tagged with the forget class/count.\n",
    "    m_name = f'checkpoints/{dataset}_{arch_filters}_forget_None{unfreeze_tag}{augment_tag}{learningrate}{batch_size}{lossfn}{wd}{seed_name}{training_epochs}.pt'\n",
    "    m0_name = f'checkpoints/{dataset}_{arch_filters}_forget_{class_to_forget}{num_tag}{unfreeze_tag}{augment_tag}{learningrate}{batch_size}{lossfn}{wd}{seed_name}{training_epochs}.pt'\n",
    "    \n",
    "    \n",
    "    log_dict={}\n",
    "    log_dict['args']=args\n",
    "    args.retain_bs = 128\n",
    "    args.forget_bs = 64\n",
    "\n",
    "    train_loader_full, valid_loader_full, test_loader_full, _ = datasets.get_loaders(args.dataset, batch_size=args.batch_size, seed=s, root=args.dataroot, augment=False, shuffle=True)\n",
    "    marked_loader, _, _, _ = datasets.get_loaders(args.dataset, class_to_replace=args.forget_class, num_indexes_to_replace=args.num_to_forget, only_mark=True, batch_size=1, seed=s, root=args.dataroot, augment=False, shuffle=True)\n",
    "\n",
    "    # Forget set: samples whose target is negative in the marked dataset; `-t - 1` maps them\n",
    "    # back to non-negative labels (presumably the original class ids -- confirm against get_loaders).\n",
    "    forget_dataset = copy.deepcopy(marked_loader.dataset)\n",
    "    marked = forget_dataset.targets < 0\n",
    "    forget_dataset.data = forget_dataset.data[marked]\n",
    "    forget_dataset.targets = - forget_dataset.targets[marked] - 1\n",
    "    #forget_loader = torch.utils.data.DataLoader(forget_dataset, batch_size=args.forget_bs,num_workers=0,pin_memory=True,shuffle=True)\n",
    "    forget_loader = replace_loader_dataset(train_loader_full, forget_dataset, batch_size=args.forget_bs, seed=seed, shuffle=True)\n",
    "\n",
    "    # Retain set: the remaining samples (targets >= 0).\n",
    "    retain_dataset = copy.deepcopy(marked_loader.dataset)\n",
    "    marked = retain_dataset.targets >= 0\n",
    "    retain_dataset.data = retain_dataset.data[marked]\n",
    "    retain_dataset.targets = retain_dataset.targets[marked]\n",
    "    #retain_loader = torch.utils.data.DataLoader(retain_dataset, batch_size=args.retain_bs,num_workers=0,pin_memory=True,shuffle=True)\n",
    "    retain_loader = replace_loader_dataset(train_loader_full, retain_dataset, batch_size=args.retain_bs, seed=seed, shuffle=True)\n",
    "\n",
    "    # Forget and retain sets together must have as many samples as the full training set.\n",
    "    assert(len(forget_dataset) + len(retain_dataset) == len(train_loader_full.dataset))\n",
    "\n",
    "\n",
    "    model.load_state_dict(torch.load(m_name))\n",
    "    model0.load_state_dict(torch.load(m0_name))\n",
    "    model_initial.load_state_dict(torch.load(init_checkpoint))\n",
    "    \n",
    "    # bad_teacher is a newly constructed network; teacher and student are copies of the loaded `model`.\n",
    "    bad_teacher = model_dict[args.model](num_classes=num_classes).to(args.device)\n",
    "    teacher = copy.deepcopy(model)\n",
    "    student = copy.deepcopy(model)\n",
    "\n",
    "    model.cuda()\n",
    "    model0.cuda()\n",
    "\n",
    "\n",
    "    # Keep a clone of every parameter tensor on the parameter itself as `data0`.\n",
    "    for p in model.parameters():\n",
    "        p.data0 = p.data.clone()\n",
    "    for p in model0.parameters():\n",
    "        p.data0 = p.data.clone()\n",
    "\n",
    "    \n",
    "    model_ft = copy.deepcopy(model)\n",
    "    model_ng = copy.deepcopy(model)\n",
    "    args.ft_lr = 0.04\n",
    "    args.ft_epochs = 10\n",
    "    args.ng_alpha = 0.9999\n",
    "    args.ng_epochs = 5\n",
    "    args.ng_lr = 0.01\n",
    "    \n",
    "\n",
    "    \n",
    "    # Every unlearning baseline below is currently commented out; only the banners are printed,\n",
    "    # and model_bt / model_s / model_s_final are the integer placeholder 0.\n",
    "    print (\"Forgetting by Fine-tuneing:\")\n",
    "    #finetune(model_ft, retain_loader, epochs=args.ft_epochs, quiet=True, lr=args.ft_lr)\n",
    "    print (\"Forgetting by NegGrad:\")\n",
    "    #negative_grad(model_ng, retain_loader, forget_loader, alpha=args.ng_alpha, epochs=args.ng_epochs, quiet=True, lr=args.ng_lr, args=args)\n",
    "    print (\"Forgetting by CFK:\")\n",
    "    #model_cfk = cfk_unlearn(model)\n",
    "    print (\"Forgetting by EUK:\")\n",
    "    #model_euk = euk_unlearn(model, model_initial)\n",
    "    print ('Forgetting by Bad-T')\n",
    "    model_bt = 0#badt(teacher, bad_teacher, student)\n",
    "    print(\"Forgetting by SCRUB:\")\n",
    "    model_s, model_s_final = 0,0#scrub(teacher, student)\n",
    "\n",
    "            \n",
    "    # Readouts are collected only for the original and the retrained model; the rest are disabled.\n",
    "    readouts = {}\n",
    "    #_,_=activations_predictions(copy.deepcopy(model),forget_loader,'Original_Model_D_f')\n",
    "    thresh=0#og_dict['Original_Model_D_f_loss']+1e-5\n",
    "    readouts[\"Original\"] = all_readouts(copy.deepcopy(model),thresh,'Original',seed)\n",
    "    readouts[\"Retrain\"] = all_readouts(copy.deepcopy(model0),thresh,'Retrain',seed)\n",
    "    #readouts['Finetune'] = all_readouts(copy.deepcopy(model_ft),thresh,'Finetune',seed)\n",
    "    #readouts[\"NegGrad\"] = all_readouts(copy.deepcopy(model_ng),thresh,'NegGrad',seed)\n",
    "    #readouts[\"CFK\"] = all_readouts(copy.deepcopy(model_cfk),thresh,'CFK',seed)\n",
    "    #readouts[\"EUK\"] = all_readouts(copy.deepcopy(model_euk),thresh,'EUK',seed)\n",
    "    #readouts[\"Bad-T\"] = all_readouts(copy.deepcopy(model_bt),thresh,'Bad-T',seed)\n",
    "    #readouts[\"SCRUB+R\"] = all_readouts(copy.deepcopy(model_s),thresh,'SCRUB+R',seed)\n",
    "    #readouts[\"SCRUB\"] = all_readouts(copy.deepcopy(model_s_final),thresh,'SCRUB',seed)\n",
    "    \n",
    "    \n",
    "    #del model\n",
    "    #del model0\n",
    "    #del model_ft\n",
    "    #del model_ng\n",
    "    #del model_euk\n",
    "    #del model_cfk\n",
    "    #del model_bt\n",
    "    #del model_s\n",
    "    #del model_s_final\n",
    "    #del retain_loader\n",
    "    #del forget_loader\n",
    "    #del test_loader_full\n",
    "    #del train_loader_full\n",
    "\n",
    "    errors.append(readouts)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "myenv",
   "language": "python",
   "name": "myenv"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
