{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_903870/2931265111.py:24: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from tqdm.autonotebook import tqdm\n"
     ]
    }
   ],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "%matplotlib inline\n",
    "import os\n",
    "os.environ['CUDA_DEVICE_ORDER']='PCI_BUS_ID'\n",
    "os.environ['CUDA_VISIBLE_DEVICES']='0'\n",
    "import variational\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "from matplotlib.ticker import FuncFormatter\n",
    "from itertools import cycle\n",
    "import os\n",
    "import time\n",
    "import math\n",
    "import pandas as pd\n",
    "from collections import OrderedDict\n",
    "from sklearn.linear_model import LogisticRegression\n",
    "    \n",
    "import copy\n",
    "import torch.nn as nn\n",
    "from torch.autograd import Variable\n",
    "from typing import List\n",
    "import itertools\n",
    "from tqdm.autonotebook import tqdm\n",
    "from models import *\n",
    "import models\n",
    "from logger import *\n",
    "import wandb\n",
    "\n",
    "from thirdparty.repdistiller.helper.util import adjust_learning_rate as sgda_adjust_learning_rate\n",
    "from thirdparty.repdistiller.distiller_zoo import DistillKL, HintLoss, Attention, Similarity, Correlation, VIDLoss, RKDLoss\n",
    "from thirdparty.repdistiller.distiller_zoo import PKT, ABLoss, FactorTransfer, KDSVD, FSP, NSTLoss\n",
    "\n",
    "from thirdparty.repdistiller.helper.loops import train_distill, train_distill_hide, train_distill_linear, train_vanilla, train_negrad, train_bcu, train_bcu_distill, validate\n",
    "from thirdparty.repdistiller.helper.pretrain import init"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def pdb():\n",
    "    import pdb\n",
    "    pdb.set_trace"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def parameter_count(model):\n",
    "    count=0\n",
    "    for p in model.parameters():\n",
    "        count+=np.prod(np.array(list(p.shape)))\n",
    "    print(f'Total Number of Parameters: {count}')\n",
    "    return count"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def vectorize_params(model):\n",
    "    param = []\n",
    "    for p in model.parameters():\n",
    "        param.append(p.data.view(-1).cpu().numpy())\n",
    "    return np.concatenate(param)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def print_param_shape(model):\n",
    "    for k,p in model.named_parameters():\n",
    "        print(k,p.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Pre-training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Checkpoint name: cifar100_resnet_0_4_forget_None_lr_0_1_bs_128_ls_ce_wd_0_0005_seed_1\n",
      "[Logging in cifar100_resnet_0_4_forget_None_lr_0_1_bs_128_ls_ce_wd_0_0005_seed_1_training]\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "confuse mode: False\n",
      "split mode: None\n",
      "Number of Classes: 100\n",
      "[0] train metrics:{\"loss\": 4.004363647079468, \"error\": 0.914425}\n",
      "Learning Rate : 0.1\n",
      "[0] dry_run metrics:{\"loss\": 3.6652862831115725, \"error\": 0.865175}\n",
      "Learning Rate : 0.1\n",
      "[0] test metrics:{\"loss\": 3.6771464149475097, \"error\": 0.8661}\n",
      "Learning Rate : 0.1\n",
      "Epoch Time: 27.17 sec\n",
      "[1] train metrics:{\"loss\": 3.4841514293670652, \"error\": 0.823875}\n",
      "Learning Rate : 0.1\n",
      "Epoch Time: 14.86 sec\n",
      "[2] train metrics:{\"loss\": 3.0799445030212405, \"error\": 0.73755}\n",
      "Learning Rate : 0.1\n",
      "Epoch Time: 15.08 sec\n",
      "[3] train metrics:{\"loss\": 2.761986903381348, \"error\": 0.657675}\n",
      "Learning Rate : 0.1\n",
      "Epoch Time: 15.24 sec\n",
      "[4] train metrics:{\"loss\": 2.560997309112549, \"error\": 0.5965}\n",
      "Learning Rate : 0.1\n",
      "Epoch Time: 14.94 sec\n",
      "[5] train metrics:{\"loss\": 2.414651040649414, \"error\": 0.545675}\n",
      "Learning Rate : 0.1\n",
      "Epoch Time: 14.99 sec\n",
      "[6] train metrics:{\"loss\": 2.3048879302978516, \"error\": 0.50115}\n",
      "Learning Rate : 0.1\n",
      "Epoch Time: 14.69 sec\n",
      "[7] train metrics:{\"loss\": 2.2363017753601073, \"error\": 0.461625}\n",
      "Learning Rate : 0.1\n",
      "Epoch Time: 14.87 sec\n",
      "[8] train metrics:{\"loss\": 2.176711745834351, \"error\": 0.429525}\n",
      "Learning Rate : 0.1\n",
      "Epoch Time: 14.97 sec\n",
      "[9] train metrics:{\"loss\": 2.1533882205963133, \"error\": 0.40325}\n",
      "Learning Rate : 0.1\n",
      "Epoch Time: 15.13 sec\n",
      "[10] train metrics:{\"loss\": 2.1217486545562743, \"error\": 0.375275}\n",
      "Learning Rate : 0.1\n",
      "Epoch Time: 14.94 sec\n",
      "[11] train metrics:{\"loss\": 2.1148677562713623, \"error\": 0.3513}\n",
      "Learning Rate : 0.1\n",
      "Epoch Time: 15.14 sec\n",
      "[12] train metrics:{\"loss\": 2.1146700733184813, \"error\": 0.333675}\n",
      "Learning Rate : 0.1\n",
      "Epoch Time: 15.09 sec\n",
      "[13] train metrics:{\"loss\": 2.129422161102295, \"error\": 0.31615}\n",
      "Learning Rate : 0.1\n",
      "Epoch Time: 15.05 sec\n",
      "[14] train metrics:{\"loss\": 2.1356759124755857, \"error\": 0.30355}\n",
      "Learning Rate : 0.1\n",
      "Epoch Time: 15.11 sec\n",
      "[15] train metrics:{\"loss\": 2.140548034667969, \"error\": 0.285875}\n",
      "Learning Rate : 0.1\n",
      "Epoch Time: 14.82 sec\n",
      "[16] train metrics:{\"loss\": 2.1746923412322996, \"error\": 0.280475}\n",
      "Learning Rate : 0.1\n",
      "Epoch Time: 14.99 sec\n",
      "[17] train metrics:{\"loss\": 2.1801952964782716, \"error\": 0.2681}\n",
      "Learning Rate : 0.1\n",
      "Epoch Time: 14.98 sec\n",
      "[18] train metrics:{\"loss\": 2.165601744842529, \"error\": 0.250775}\n",
      "Learning Rate : 0.1\n",
      "Epoch Time: 14.93 sec\n",
      "[19] train metrics:{\"loss\": 2.164006868362427, \"error\": 0.2464}\n",
      "Learning Rate : 0.1\n",
      "Epoch Time: 15.15 sec\n",
      "[20] train metrics:{\"loss\": 2.1943339420318604, \"error\": 0.241625}\n",
      "Learning Rate : 0.1\n",
      "Epoch Time: 15.23 sec\n",
      "[21] train metrics:{\"loss\": 2.209365211868286, \"error\": 0.237475}\n",
      "Learning Rate : 0.1\n",
      "Epoch Time: 14.95 sec\n",
      "[22] train metrics:{\"loss\": 2.228347407913208, \"error\": 0.237}\n",
      "Learning Rate : 0.1\n",
      "Epoch Time: 15.11 sec\n",
      "[23] train metrics:{\"loss\": 2.19557455406189, \"error\": 0.22395}\n",
      "Learning Rate : 0.1\n",
      "Epoch Time: 15.05 sec\n",
      "[24] train metrics:{\"loss\": 2.2061470024108885, \"error\": 0.220575}\n",
      "Learning Rate : 0.1\n",
      "Epoch Time: 15.04 sec\n",
      "[25] train metrics:{\"loss\": 2.2127228797912597, \"error\": 0.220075}\n",
      "Learning Rate : 0.1\n",
      "Epoch Time: 16.54 sec\n",
      "[26] train metrics:{\"loss\": 2.229200707626343, \"error\": 0.219875}\n",
      "Learning Rate : 0.1\n",
      "Epoch Time: 14.98 sec\n",
      "[27] train metrics:{\"loss\": 2.208640114212036, \"error\": 0.210725}\n",
      "Learning Rate : 0.1\n",
      "Epoch Time: 14.93 sec\n",
      "[28] train metrics:{\"loss\": 2.2325496906280518, \"error\": 0.21615}\n",
      "Learning Rate : 0.1\n",
      "Epoch Time: 13.6 sec\n",
      "[29] train metrics:{\"loss\": 2.222390404510498, \"error\": 0.209625}\n",
      "Learning Rate : 0.1\n",
      "Epoch Time: 7.77 sec\n",
      "[30] train metrics:{\"loss\": 2.218616831588745, \"error\": 0.207125}\n",
      "Learning Rate : 0.1\n",
      "Epoch Time: 6.98 sec\n",
      "Pure training time: 448.5200000000001 sec\n"
     ]
    }
   ],
   "source": [
     "# Pre-train the original model with main.py (CIFAR-100, ResNet, cross-entropy, lr 0.1).\n",
     "# NOTE(review): --filters 0.4 presumably scales the ResNet width -- confirm in main.py.\n",
     "# Later cells read `args` as a notebook global; presumably it is the one main.py\n",
     "# defines and %run copies into this namespace -- verify before running out of order.\n",
     "%run main.py --dataset cifar100 --dataroot=data/cifar-100-python --model resnet --filters 0.4 --lr 0.1 --lossfn ce --num-classes 100"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "def activations_predictions(model,dataloader,name):\n",
    "    criterion = torch.nn.CrossEntropyLoss()\n",
    "    metrics,activations,predictions=get_metrics(model,dataloader,criterion,128,True)\n",
    "    print(f\"{name} -> Loss:{np.round(metrics['loss'],3)}, Error:{metrics['error']}\")\n",
    "    log_dict[f\"{name}_loss\"]=metrics['loss']\n",
    "    log_dict[f\"{name}_error\"]=metrics['error']\n",
    "\n",
    "    return activations,predictions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "def predictions_distance(l1,l2,name):\n",
    "    dist = np.sum(np.abs(l1-l2))\n",
    "    print(f\"Predictions Distance {name} -> {dist}\")\n",
    "    log_dict[f\"{name}_predictions\"]=dist"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "def activations_distance(a1,a2,name):\n",
    "    dist = np.linalg.norm(a1-a2,ord=1,axis=1).mean()\n",
    "    print(f\"Activations Distance {name} -> {dist}\")\n",
    "    log_dict[f\"{name}_activations\"]=dist"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Membership Inference Attack"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.linear_model import LogisticRegression\n",
    "from sklearn.model_selection import cross_val_score\n",
    "from sklearn.model_selection import StratifiedShuffleSplit\n",
    "from sklearn.metrics import confusion_matrix\n",
    "import random\n",
    "\n",
    "def cm_score(estimator, X, y):\n",
    "    y_pred = estimator.predict(X)\n",
    "    cnf_matrix = confusion_matrix(y, y_pred)\n",
    "    \n",
    "    FP = cnf_matrix[0][1] \n",
    "    FN = cnf_matrix[1][0] \n",
    "    TP = cnf_matrix[0][0] \n",
    "    TN = cnf_matrix[1][1]\n",
    "\n",
    "\n",
    "    # Sensitivity, hit rate, recall, or true positive rate\n",
    "    TPR = TP/(TP+FN)\n",
    "    # Specificity or true negative rate\n",
    "    TNR = TN/(TN+FP) \n",
    "    # Precision or positive predictive value\n",
    "    PPV = TP/(TP+FP)\n",
    "    # Negative predictive value\n",
    "    NPV = TN/(TN+FN)\n",
    "    # Fall out or false positive rate\n",
    "    FPR = FP/(FP+TN)\n",
    "    # False negative rate\n",
    "    FNR = FN/(TP+FN)\n",
    "    # False discovery rate\n",
    "    FDR = FP/(TP+FP)\n",
    "\n",
    "    # Overall accuracy\n",
    "    ACC = (TP+TN)/(TP+FP+FN+TN)\n",
    "    print (f\"FPR:{FPR:.2f}, FNR:{FNR:.2f}, FP{FP:.2f}, TN{TN:.2f}, TP{TP:.2f}, FN{FN:.2f}\")\n",
    "    return ACC\n",
    "\n",
    "\n",
    "def evaluate_attack_model(sample_loss,\n",
    "                          members,\n",
    "                          n_splits = 5,\n",
    "                          random_state = None):\n",
    "  \"\"\"Computes the cross-validation score of a membership inference attack.\n",
    "  Args:\n",
    "    sample_loss : array_like of shape (n,).\n",
    "      objective function evaluated on n samples.\n",
    "    members : array_like of shape (n,),\n",
    "      whether a sample was used for training.\n",
    "    n_splits: int\n",
    "      number of splits to use in the cross-validation.\n",
    "    random_state: int, RandomState instance or None, default=None\n",
    "      random state to use in cross-validation splitting.\n",
    "  Returns:\n",
    "    score : array_like of size (n_splits,)\n",
    "  \"\"\"\n",
    "\n",
    "  unique_members = np.unique(members)\n",
    "  if not np.all(unique_members == np.array([0, 1])):\n",
    "    raise ValueError(\"members should only have 0 and 1s\")\n",
    "\n",
    "  attack_model = LogisticRegression()\n",
    "  cv = StratifiedShuffleSplit(\n",
    "      n_splits=n_splits, random_state=random_state)\n",
    "  return cross_val_score(attack_model, sample_loss, members, cv=cv, scoring=cm_score)\n",
    "\n",
    "def membership_inference_attack(model, t_loader, f_loader, seed):\n",
    "    import matplotlib.pyplot as plt\n",
    "    import seaborn as sns\n",
    "    \n",
    "\n",
    "    fgt_cls = list(np.unique(f_loader.dataset.targets))\n",
    "    indices = [i in fgt_cls for i in t_loader.dataset.targets]\n",
    "    t_loader.dataset.data = t_loader.dataset.data[indices]\n",
    "    t_loader.dataset.targets = t_loader.dataset.targets[indices]\n",
    "\n",
    "    \n",
    "    cr = nn.CrossEntropyLoss(reduction='none')\n",
    "    test_losses = []\n",
    "    forget_losses = []\n",
    "    model.eval()\n",
    "    mult = 0.5 if args.lossfn=='mse' else 1\n",
    "    dataloader = torch.utils.data.DataLoader(t_loader.dataset, batch_size=128, shuffle=False)\n",
    "    for batch_idx, (data, target) in enumerate(dataloader):\n",
    "        data, target = data.to(args.device), target.to(args.device)            \n",
    "        if args.lossfn=='mse':\n",
    "            target=(2*target-1)\n",
    "            target = target.type(torch.cuda.FloatTensor).unsqueeze(1)\n",
    "        if 'mnist' in args.dataset:\n",
    "            data=data.view(data.shape[0],-1)\n",
    "        output = model(data)\n",
    "        loss = mult*cr(output, target)\n",
    "        test_losses = test_losses + list(loss.cpu().detach().numpy())\n",
    "    del dataloader\n",
    "    dataloader = torch.utils.data.DataLoader(f_loader.dataset, batch_size=128, shuffle=False)\n",
    "    for batch_idx, (data, target) in enumerate(dataloader):\n",
    "        data, target = data.to(args.device), target.to(args.device)            \n",
    "        if args.lossfn=='mse':\n",
    "            target=(2*target-1)\n",
    "            target = target.type(torch.cuda.FloatTensor).unsqueeze(1)\n",
    "        if 'mnist' in args.dataset:\n",
    "            data=data.view(data.shape[0],-1)\n",
    "        output = model(data)\n",
    "        loss = mult*cr(output, target)\n",
    "        forget_losses = forget_losses + list(loss.cpu().detach().numpy())\n",
    "    del dataloader\n",
    "\n",
    "    np.random.seed(seed)\n",
    "    random.seed(seed)\n",
    "    if len(forget_losses) > len(test_losses):\n",
    "        forget_losses = list(random.sample(forget_losses, len(test_losses)))\n",
    "    elif len(test_losses) > len(forget_losses):\n",
    "        test_losses = list(random.sample(test_losses, len(forget_losses)))\n",
    "    \n",
    "  \n",
    "    sns.distplot(np.array(test_losses), kde=False, norm_hist=False, rug=False, label='test-loss', ax=plt)\n",
    "    sns.distplot(np.array(forget_losses), kde=False, norm_hist=False, rug=False, label='forget-loss', ax=plt)\n",
    "    plt.legend(prop={'size': 14})\n",
    "    plt.tick_params(labelsize=12)\n",
    "    plt.title(\"loss histograms\",size=18)\n",
    "    plt.xlabel('loss values',size=14)\n",
    "    plt.show()\n",
    "    print (np.max(test_losses), np.min(test_losses))\n",
    "    print (np.max(forget_losses), np.min(forget_losses))\n",
    "\n",
    "\n",
    "    test_labels = [0]*len(test_losses)\n",
    "    forget_labels = [1]*len(forget_losses)\n",
    "    features = np.array(test_losses + forget_losses).reshape(-1,1)\n",
    "    labels = np.array(test_labels + forget_labels).reshape(-1)\n",
    "    features = np.clip(features, -100, 100)\n",
    "    score = evaluate_attack_model(features, labels, n_splits=5, random_state=seed)\n",
    "\n",
    "    return score\n",
    "        \n",
    "        "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Finetune and Fisher Helper"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "from utils import *\n",
    "def get_metrics(model,dataloader,criterion,bs=128,samples_correctness=False,use_bn=False,delta_w=None,scrub_act=False):\n",
    "    activations=[]\n",
    "    predictions=[]\n",
    "    if use_bn:\n",
    "        model.train()\n",
    "        dataloader = torch.utils.data.DataLoader(retain_loader.dataset, batch_size=128, shuffle=True)\n",
    "        for i in range(10):\n",
    "            for batch_idx, (data, target) in enumerate(dataloader):\n",
    "                data, target = data.to(args.device), target.to(args.device)            \n",
    "                output = model(data)\n",
    "    dataloader = torch.utils.data.DataLoader(dataloader.dataset, batch_size=128, shuffle=False)\n",
    "    model.eval()\n",
    "    metrics = AverageMeter()\n",
    "    mult = 0.5 if args.lossfn=='mse' else 1\n",
    "    for batch_idx, (data, target) in enumerate(dataloader):\n",
    "        data, target = data.to(args.device), target.to(args.device)            \n",
    "        if args.lossfn=='mse':\n",
    "            target=(2*target-1)\n",
    "            target = target.type(torch.cuda.FloatTensor).unsqueeze(1)\n",
    "        if 'mnist' in args.dataset:\n",
    "            data=data.view(data.shape[0],-1)\n",
    "        output = model(data)\n",
    "        if scrub_act:\n",
    "            G = []\n",
    "            for cls in range(num_classes):\n",
    "                grads = torch.autograd.grad(output[0,cls],model.parameters(),retain_graph=True)\n",
    "                grads = torch.cat([g.view(-1) for g in grads])\n",
    "                G.append(grads)\n",
    "            grads = torch.autograd.grad(output_sf[0,cls],model_scrubf.parameters(),retain_graph=False)\n",
    "            G = torch.stack(G).pow(2)\n",
    "            delta_f = torch.matmul(G,delta_w)\n",
    "            output += delta_f.sqrt()*torch.empty_like(delta_f).normal_()\n",
    "\n",
    "        loss = mult*criterion(output, target)\n",
    "        if samples_correctness:\n",
    "            #activations.append(torch.nn.functional.softmax(output,dim=1).cpu().detach().numpy().squeeze())\n",
    "            activations = activations + list(torch.nn.functional.softmax(output,dim=1).cpu().detach().numpy().squeeze())\n",
    "            predictions.append(get_error(output,target))\n",
    "        metrics.update(n=data.size(0), loss=loss.item(), error=get_error(output, target))\n",
    "    if samples_correctness:\n",
    "        return metrics.avg,np.stack(activations),np.array(predictions)\n",
    "    else:\n",
    "        return metrics.avg"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "def l2_penalty(model,model_init,weight_decay):\n",
    "    l2_loss = 0\n",
    "    for (k,p),(k_init,p_init) in zip(model.named_parameters(),model_init.named_parameters()):\n",
    "        if p.requires_grad:\n",
    "            l2_loss += (p-p_init).pow(2).sum()\n",
    "    l2_loss *= (weight_decay/2.)\n",
    "    return l2_loss\n",
    "\n",
    "def run_train_epoch(model: nn.Module, model_init, data_loader: torch.utils.data.DataLoader, \n",
    "                    loss_fn: nn.Module,\n",
    "                    optimizer: torch.optim.SGD, split: str, epoch: int, ignore_index=None,\n",
    "                    negative_gradient=False, negative_multiplier=-1, random_labels=False,\n",
    "                    quiet=False,delta_w=None,scrub_act=False):\n",
    "    model.eval()\n",
    "    metrics = AverageMeter()    \n",
    "    num_labels = data_loader.dataset.targets.max().item() + 1\n",
    "    \n",
    "    with torch.set_grad_enabled(split != 'test'):\n",
    "        for idx, batch in enumerate(tqdm(data_loader, leave=False)):\n",
    "            batch = [tensor.to(next(model.parameters()).device) for tensor in batch]\n",
    "            input, target = batch\n",
    "            output = model(input)\n",
    "            if split=='test' and scrub_act:\n",
    "                G = []\n",
    "                for cls in range(num_classes):\n",
    "                    grads = torch.autograd.grad(output[0,cls],model.parameters(),retain_graph=True)\n",
    "                    grads = torch.cat([g.view(-1) for g in grads])\n",
    "                    G.append(grads)\n",
    "                grads = torch.autograd.grad(output_sf[0,cls],model_scrubf.parameters(),retain_graph=False)\n",
    "                G = torch.stack(G).pow(2)\n",
    "                delta_f = torch.matmul(G,delta_w)\n",
    "                output += delta_f.sqrt()*torch.empty_like(delta_f).normal_()\n",
    "            loss = loss_fn(output, target) + l2_penalty(model,model_init,args.weight_decay)\n",
    "            metrics.update(n=input.size(0), loss=loss_fn(output,target).item(), error=get_error(output, target))\n",
    "            \n",
    "            if split != 'test':\n",
    "                model.zero_grad()\n",
    "                loss.backward()\n",
    "                optimizer.step()\n",
    "    if not quiet:\n",
    "        log_metrics(split, metrics, epoch)\n",
    "    return metrics.avg\n",
    "\n",
    "def run_neggrad_epoch(model: nn.Module, model_init, data_loader: torch.utils.data.DataLoader, \n",
    "                    forget_loader: torch.utils.data.DataLoader,\n",
    "                    alpha: float,\n",
    "                    loss_fn: nn.Module,\n",
    "                    optimizer: torch.optim.SGD, split: str, epoch: int, ignore_index=None,\n",
    "                    quiet=False):\n",
    "    model.eval()\n",
    "    metrics = AverageMeter()    \n",
    "    num_labels = data_loader.dataset.targets.max().item() + 1\n",
    "    \n",
    "    with torch.set_grad_enabled(split != 'test'):\n",
    "        for idx, (batch_retain,batch_forget) in enumerate(tqdm(zip(data_loader,cycle(forget_loader)), leave=False)):\n",
    "            batch_retain = [tensor.to(next(model.parameters()).device) for tensor in batch_retain]\n",
    "            batch_forget = [tensor.to(next(model.parameters()).device) for tensor in batch_forget]\n",
    "            input_r, target_r = batch_retain\n",
    "            input_f, target_f = batch_forget\n",
    "            output_r = model(input_r)\n",
    "            output_f = model(input_f)\n",
    "            loss = alpha*(loss_fn(output_r, target_r) + l2_penalty(model,model_init,args.weight_decay)) - (1-alpha)*loss_fn(output_f, target_f)\n",
    "            metrics.update(n=input_r.size(0), loss=loss_fn(output_r,target_r).item(), error=get_error(output_r, target_r))\n",
    "            if split != 'test':\n",
    "                model.zero_grad()\n",
    "                loss.backward()\n",
    "                optimizer.step()\n",
    "    if not quiet:\n",
    "        log_metrics(split, metrics, epoch)\n",
    "    return metrics.avg"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "def finetune(model: nn.Module, data_loader: torch.utils.data.DataLoader, lr=0.01, epochs=10, quiet=False):\n",
    "    loss_fn = nn.CrossEntropyLoss()\n",
    "    optimizer = torch.optim.SGD(model.parameters(), lr=lr, weight_decay=0.0)\n",
    "    model_init=copy.deepcopy(model)\n",
    "    for epoch in range(epochs):\n",
    "        train_vanilla(epoch, data_loader, model, loss_fn, optimizer, args)\n",
    "\n",
    "def negative_grad(model: nn.Module, data_loader: torch.utils.data.DataLoader, forget_loader: torch.utils.data.DataLoader, alpha: float, lr=0.01, epochs=10, quiet=False, args=None):\n",
    "    loss_fn = nn.CrossEntropyLoss()\n",
    "    optimizer = torch.optim.SGD(model.parameters(), lr=lr, weight_decay=0.0)\n",
    "    model_init=copy.deepcopy(model)\n",
    "    for epoch in range(epochs):\n",
    "        train_negrad(epoch, data_loader, forget_loader, model, loss_fn, optimizer,  alpha, args)\n",
    "\n",
    "def fk_fientune(model: nn.Module, data_loader: torch.utils.data.DataLoader, args, lr=0.01, epochs=10, quiet=False):\n",
    "    loss_fn = nn.CrossEntropyLoss()\n",
    "    optimizer = torch.optim.SGD(model.parameters(), lr=lr, weight_decay=0.0)\n",
    "    model_init=copy.deepcopy(model)\n",
    "    for epoch in range(epochs):\n",
    "        sgda_adjust_learning_rate(epoch, args, optimizer)\n",
    "        train_vanilla(epoch, data_loader, model, loss_fn, optimizer, args)\n",
    "def test(model, data_loader):\n",
    "    loss_fn = nn.CrossEntropyLoss()\n",
    "    model_init=copy.deepcopy(model)\n",
    "    return run_train_epoch(model, model_init, data_loader, loss_fn, optimizer=None, split='test', epoch=epoch, ignore_index=None, quiet=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "def readout_retrain(model, data_loader, test_loader, lr=0.1, epochs=500, threshold=0.01, quiet=True):\n",
    "    torch.manual_seed(seed)\n",
    "    model = copy.deepcopy(model)\n",
    "    loss_fn = nn.CrossEntropyLoss()\n",
    "    optimizer = torch.optim.SGD(model.parameters(), lr=lr, weight_decay=0.0)\n",
    "    sampler = torch.utils.data.RandomSampler(data_loader.dataset, replacement=True, num_samples=500)\n",
    "    data_loader_small = torch.utils.data.DataLoader(data_loader.dataset, batch_size=data_loader.batch_size, sampler=sampler, num_workers=data_loader.num_workers)\n",
    "    metrics = []\n",
    "    model_init=copy.deepcopy(model)\n",
    "    for epoch in range(epochs):\n",
    "        metrics.append(run_train_epoch(model, model_init, test_loader, loss_fn, optimizer, split='test', epoch=epoch, ignore_index=None, quiet=quiet))\n",
    "        if metrics[-1]['loss'] <= threshold:\n",
    "            break\n",
    "        run_train_epoch(model, model_init, data_loader_small, loss_fn, optimizer, split='train', epoch=epoch, ignore_index=None, quiet=quiet)\n",
    "    return epoch, metrics\n",
    "\n",
    "def extract_retrain_time(metrics, threshold=0.1):\n",
    "    losses = np.array([m['loss'] for m in metrics])\n",
    "    return np.argmax(losses < threshold)\n",
    "\n",
    "def all_readouts(model,thresh=0.1,name='method', seed=0):\n",
    "    MIA = membership_inference_attack(model, copy.deepcopy(test_loader_full), forget_loader, seed)\n",
    "    #train_loader = torch.utils.data.DataLoader(train_loader_full.dataset, batch_size=128, shuffle=True)\n",
    "    retrain_time, _ = 0,0#readout_retrain(model, train_loader, forget_loader, epochs=100, lr=0.001, threshold=thresh)\n",
    "    test_error = test(model, test_loader_full)['error']*100\n",
    "    forget_error = test(model, forget_loader)['error']*100\n",
    "    retain_error = test(model, retain_loader)['error']*100\n",
    "    val_error = test(model, valid_loader_full)['error']*100\n",
    "    \n",
    "    print(f\"{name} ->\"\n",
    "          f\"\\tFull test error: {test_error:.2f}\"\n",
    "          f\"\\tForget error: {forget_error:.2f}\\tRetain error: {retain_error:.2f}\\tValid error: {val_error:.2f}\"\n",
    "          f\"\\tFine-tune time: {retrain_time+1} steps\\tMIA: {np.mean(MIA):.2f}±{np.std(MIA):0.1f}\")\n",
    "    \n",
    "    return(dict(test_error=test_error, forget_error=forget_error, retain_error=retain_error, val_error=val_error, retrain_time=retrain_time+1, MIA=np.mean(MIA)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "def scrub(teacher, student):\n",
    "    args.optim = 'sgd'\n",
    "    args.gamma = 1\n",
    "    args.alpha = 0.5\n",
    "    args.beta = 0\n",
    "    args.smoothing = 0.5\n",
    "    args.msteps = 3\n",
    "    args.clip = 0.2\n",
    "    args.sstart = 10\n",
    "    args.kd_T = 4\n",
    "    args.distill = 'kd'\n",
    "\n",
    "    args.sgda_epochs = 5\n",
    "    args.sgda_learning_rate = 0.0005\n",
    "    args.lr_decay_epochs = [3,5,9]\n",
    "    args.lr_decay_rate = 0.1\n",
    "    args.sgda_weight_decay = 5e-4\n",
    "    args.sgda_momentum = 0.9\n",
    "\n",
    "    model_t = copy.deepcopy(teacher)\n",
    "    model_s = copy.deepcopy(student)\n",
    "\n",
    "    #this is from https://github.com/ojus1/SmoothedGradientDescentAscent/blob/main/SGDA.py\n",
    "    #For SGDA smoothing\n",
    "    beta = 0.1\n",
    "    def avg_fn(averaged_model_parameter, model_parameter, num_averaged): return (\n",
    "        1 - beta) * averaged_model_parameter + beta * model_parameter\n",
    "    swa_model = torch.optim.swa_utils.AveragedModel(\n",
    "        model_s, avg_fn=avg_fn)\n",
    "\n",
    "    module_list = nn.ModuleList([])\n",
    "    module_list.append(model_s)\n",
    "    trainable_list = nn.ModuleList([])\n",
    "    trainable_list.append(model_s)\n",
    "\n",
    "    criterion_cls = nn.CrossEntropyLoss()\n",
    "    criterion_div = DistillKL(args.kd_T)\n",
    "    criterion_kd = DistillKL(args.kd_T)\n",
    "\n",
    "\n",
    "    criterion_list = nn.ModuleList([])\n",
    "    criterion_list.append(criterion_cls)    # classification loss\n",
    "    criterion_list.append(criterion_div)    # KL divergence loss, original knowledge distillation\n",
    "    criterion_list.append(criterion_kd)     # other knowledge distillation loss\n",
    "\n",
    "    # optimizer\n",
    "    if args.optim == \"sgd\":\n",
    "        optimizer = optim.SGD(trainable_list.parameters(),\n",
    "                              lr=args.sgda_learning_rate,\n",
    "                              momentum=args.sgda_momentum,\n",
    "                              weight_decay=args.sgda_weight_decay)\n",
    "    elif args.optim == \"adam\": \n",
    "        optimizer = optim.Adam(trainable_list.parameters(),\n",
    "                              lr=args.sgda_learning_rate,\n",
    "                              weight_decay=args.sgda_weight_decay)\n",
    "    elif args.optim == \"rmsp\":\n",
    "        optimizer = optim.RMSprop(trainable_list.parameters(),\n",
    "                              lr=args.sgda_learning_rate,\n",
    "                              momentum=args.sgda_momentum,\n",
    "                              weight_decay=args.sgda_weight_decay)\n",
    "\n",
    "    module_list.append(model_t)\n",
    "\n",
    "    if torch.cuda.is_available():\n",
    "        module_list.cuda()\n",
    "        criterion_list.cuda()\n",
    "        import torch.backends.cudnn as cudnn\n",
    "        cudnn.benchmark = True\n",
    "        swa_model.cuda()\n",
    "\n",
    "\n",
    "    t1 = time.time()\n",
    "    acc_rs = []\n",
    "    acc_fs = []\n",
    "    acc_vs = []\n",
    "    acc_fvs = []\n",
    "    \n",
    "    \n",
    "    forget_validation_loader = copy.deepcopy(valid_loader_full)\n",
    "    fgt_cls = list(np.unique(forget_loader.dataset.targets))\n",
    "    indices = [i in fgt_cls for i in forget_validation_loader.dataset.targets]\n",
    "    forget_validation_loader.dataset.data = forget_validation_loader.dataset.data[indices]\n",
    "    forget_validation_loader.dataset.targets = forget_validation_loader.dataset.targets[indices]\n",
    "    \n",
    "    scrub_name = \"checkpoints/scrub_{}_{}_seed{}_step\".format(args.model, args.dataset, args.seed)\n",
    "    for epoch in range(1, args.sgda_epochs + 1):\n",
    "\n",
    "        lr = sgda_adjust_learning_rate(epoch, args, optimizer)\n",
    "\n",
    "        acc_r, acc5_r, loss_r = validate(retain_loader, model_s, criterion_cls, args, True)\n",
    "        acc_f, acc5_f, loss_f = validate(forget_loader, model_s, criterion_cls, args, True)\n",
    "        acc_v, acc5_v, loss_v = validate(valid_loader_full, model_s, criterion_cls, args, True)\n",
    "        acc_fv, acc5_fv, loss_fv = validate(forget_validation_loader, model_s, criterion_cls, args, True)\n",
    "        acc_rs.append(100-acc_r.item())\n",
    "        acc_fs.append(100-acc_f.item())\n",
    "        acc_vs.append(100-acc_v.item())\n",
    "        acc_fvs.append(100-acc_fv.item())\n",
    "\n",
    "        maximize_loss = 0\n",
    "        if epoch <= args.msteps:\n",
    "            maximize_loss = train_distill(epoch, forget_loader, module_list, swa_model, criterion_list, optimizer, args, \"maximize\")\n",
    "        train_acc, train_loss = train_distill(epoch, retain_loader, module_list, swa_model, criterion_list, optimizer, args, \"minimize\",)\n",
    "        if epoch >= args.sstart:\n",
    "            swa_model.update_parameters(model_s)\n",
    "        \n",
    "        torch.save(model_s.state_dict(), scrub_name+str(epoch)+\".pt\")\n",
    "\n",
    "\n",
    "        print (\"maximize loss: {:.2f}\\t minimize loss: {:.2f}\\t train_acc: {}\".format(maximize_loss, train_loss, train_acc))\n",
    "    t2 = time.time()\n",
    "    print (t2-t1)\n",
    "\n",
    "    acc_r, acc5_r, loss_r = validate(retain_loader, model_s, criterion_cls, args, True)\n",
    "    acc_f, acc5_f, loss_f = validate(forget_loader, model_s, criterion_cls, args, True)\n",
    "    acc_v, acc5_v, loss_v = validate(valid_loader_full, model_s, criterion_cls, args, True)\n",
    "    acc_fv, acc5_fv, loss_fv = validate(forget_validation_loader, model_s, criterion_cls, args, True)\n",
    "    acc_rs.append(100-acc_r.item())\n",
    "    acc_fs.append(100-acc_f.item())\n",
    "    acc_vs.append(100-acc_v.item())\n",
    "    acc_fvs.append(100-acc_fv.item())\n",
    "\n",
    "    from matplotlib import pyplot as plt\n",
    "    indices = list(range(0,len(acc_rs)))\n",
    "    plt.plot(indices, acc_rs, marker='*', color=u'#1f77b4', alpha=1, label='retain-set')\n",
    "    plt.plot(indices, acc_fs, marker='o', color=u'#ff7f0e', alpha=1, label='forget-set')\n",
    "    plt.plot(indices, acc_vs, marker='^', color=u'#2ca02c',alpha=1, label='validation-set')\n",
    "    plt.plot(indices, acc_fvs, marker='.', color='red',alpha=1, label='forget-validation-set')\n",
    "    plt.legend(prop={'size': 14})\n",
    "    plt.tick_params(labelsize=12)\n",
    "    plt.xlabel('epoch',size=14)\n",
    "    plt.ylabel('error',size=14)\n",
    "    plt.grid()\n",
    "    plt.show()\n",
    "    \n",
    "    \n",
    "    try:\n",
    "        selected_idx, _ = min(enumerate(acc_fs), key=lambda x: abs(x[1]-acc_fvs[-1]))\n",
    "    except:\n",
    "        selected_idx = len(acc_fs) - 1\n",
    "    print (\"the selected index is {}\".format(selected_idx))\n",
    "    selected_model = \"checkpoints/scrub_{}_{}_seed{}_step{}.pt\".format(args.model, args.dataset, args.seed, int(selected_idx))\n",
    "    model_s_final = copy.deepcopy(model_s)\n",
    "    model_s.load_state_dict(torch.load(selected_model))\n",
    "    \n",
    "    \n",
    "    return model_s, model_s_final\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "def replace_loader_dataset(data_loader, dataset, batch_size=128, seed=1, shuffle=True):\n",
    "    \"\"\"Return a fresh DataLoader over `dataset`.\n",
    "\n",
    "    Calls `manual_seed(seed)` (defined elsewhere in the notebook) before building the\n",
    "    loader, presumably so that shuffling is reproducible per seed - TODO confirm.\n",
    "    `data_loader` is unused and kept only so existing call sites keep working.\n",
    "    Loading happens in the main process (num_workers=0) with pin_memory=True.\n",
    "    \"\"\"\n",
    "    manual_seed(seed)\n",
    "    return torch.utils.data.DataLoader(dataset, batch_size=batch_size, num_workers=0, pin_memory=True, shuffle=shuffle)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "def cfk_unlearn(model):\n",
    "    \"\"\"CFK baseline: fine-tune only the last block of a copy of `model` on the retain set.\n",
    "\n",
    "    Sets args.lr_decay_epochs, args.cfk_lr, args.cfk_epochs and args.cfk_bs on the\n",
    "    global `args`, and reads the globals train_loader_full, retain_dataset and seed.\n",
    "    Raises NotImplementedError for architectures other than 'allcnn' and 'resnet'.\n",
    "\n",
    "    Returns the fine-tuned copy; `model` itself is left untouched.\n",
    "    \"\"\"\n",
    "    # Hyper-parameters are stored on `args`, which is handed to fk_fientune below.\n",
    "    args.lr_decay_epochs = [10,15,20]\n",
    "    args.cfk_lr = 0.01\n",
    "    args.cfk_epochs = 10\n",
    "    args.cfk_bs = 64\n",
    "    cfk_loader = replace_loader_dataset(train_loader_full, retain_dataset, seed=seed, batch_size=args.cfk_bs, shuffle=True)\n",
    "\n",
    "    unlearned = copy.deepcopy(model)\n",
    "    # Freeze everything, then re-enable gradients only for the last block.\n",
    "    unlearned.requires_grad_(False)\n",
    "\n",
    "    if args.model == 'allcnn':\n",
    "        trainable_blocks = [unlearned.features[k] for k in [9]]\n",
    "    elif args.model == \"resnet\":\n",
    "        trainable_blocks = [unlearned.layer4]\n",
    "    else:\n",
    "        raise NotImplementedError\n",
    "\n",
    "    for block in trainable_blocks:\n",
    "        block.requires_grad_(True)\n",
    "\n",
    "    fk_fientune(unlearned, cfk_loader, args=args, epochs=args.cfk_epochs, quiet=True, lr=args.cfk_lr)\n",
    "    return unlearned"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "def euk_unlearn(model, model_initial):\n",
    "    \"\"\"EUK baseline: reset the last block of a copy of `model` to `model_initial`, then retrain it.\n",
    "\n",
    "    All parameters of the copy are frozen except the reset block, which is fine-tuned on\n",
    "    the retain set with `fk_fientune`. For allcnn, `classifier[0]` is also reset to\n",
    "    `model_initial`'s weights. NOTE(review): classifier[0] is reset but left frozen -\n",
    "    confirm that is intended.\n",
    "\n",
    "    Reads the globals args, training_epochs, train_loader_full, retain_dataset and seed,\n",
    "    and sets args.lr_decay_epochs / euk_lr / euk_epochs / euk_bs as a side effect.\n",
    "    Raises NotImplementedError for architectures other than 'allcnn' and 'resnet'.\n",
    "\n",
    "    Returns the retrained copy; `model` itself is left untouched.\n",
    "    \"\"\"\n",
    "    args.lr_decay_epochs = [10,15,20]\n",
    "    args.euk_lr = 0.01\n",
    "    args.euk_epochs = training_epochs\n",
    "    args.euk_bs = 64\n",
    "    r_loader = replace_loader_dataset(train_loader_full,retain_dataset, seed=seed, batch_size=args.euk_bs, shuffle=True)\n",
    "    model_euk = copy.deepcopy(model)\n",
    "\n",
    "    for param in model_euk.parameters():\n",
    "        param.requires_grad_(False)\n",
    "\n",
    "    # Errors that only mean \"this layer has no such tensor\" (missing sub-module, bias=None,\n",
    "    # short block, non-indexable layer). They are reported and skipped; anything else, e.g. a\n",
    "    # shape mismatch raised by copy_, propagates instead of being silently swallowed.\n",
    "    skippable = (AttributeError, IndexError, TypeError)\n",
    "\n",
    "    if args.model == 'allcnn':\n",
    "        # Reset the same block that cfk_unlearn fine-tunes. `layers` used to exist only as a\n",
    "        # local of cfk_unlearn, so this branch depended on an undefined global (NameError).\n",
    "        layers = [9]\n",
    "        with torch.no_grad():\n",
    "            for k in layers:\n",
    "                for i in range(0,3):\n",
    "                    try:\n",
    "                        model_euk.features[k][i].weight.copy_(model_initial.features[k][i].weight)\n",
    "                    except skippable:\n",
    "                        print (\"block {}, layer {} does not have weights\".format(k,i))\n",
    "                    try:\n",
    "                        model_euk.features[k][i].bias.copy_(model_initial.features[k][i].bias)\n",
    "                    except skippable:\n",
    "                        print (\"block {}, layer {} does not have bias\".format(k,i))\n",
    "            model_euk.classifier[0].weight.copy_(model_initial.classifier[0].weight)\n",
    "            model_euk.classifier[0].bias.copy_(model_initial.classifier[0].bias)\n",
    "\n",
    "        for k in layers:\n",
    "            for param in model_euk.features[k].parameters():\n",
    "                param.requires_grad_(True)\n",
    "\n",
    "    elif args.model == \"resnet\":\n",
    "        with torch.no_grad():\n",
    "            for i in range(0,2):\n",
    "                # Copy order: bn1, conv1, bn2, conv2; weight before bias for each.\n",
    "                for sub in ('bn1', 'conv1', 'bn2', 'conv2'):\n",
    "                    for attr in ('weight', 'bias'):\n",
    "                        try:\n",
    "                            dst = getattr(getattr(model_euk.layer4[i], sub), attr)\n",
    "                            src = getattr(getattr(model_initial.layer4[i], sub), attr)\n",
    "                            dst.copy_(src)\n",
    "                        except skippable:\n",
    "                            print (\"block 4, layer {} does not have {}\".format(i, attr))\n",
    "\n",
    "            model_euk.layer4[0].shortcut[0].weight.copy_(model_initial.layer4[0].shortcut[0].weight)\n",
    "\n",
    "        for param in model_euk.layer4.parameters():\n",
    "            param.requires_grad_(True)\n",
    "\n",
    "    else:\n",
    "        raise NotImplementedError\n",
    "\n",
    "\n",
    "    fk_fientune(model_euk, r_loader, epochs=args.euk_epochs, quiet=True, lr=args.euk_lr, args=args)\n",
    "    return model_euk"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Checkpoint name: cifar10_resnet_1_0_forget_None_lr_0_01_bs_128_ls_ce_wd_0_0005_seed_1\n",
      "[Logging in cifar10_resnet_1_0_forget_None_lr_0_01_bs_128_ls_ce_wd_0_0005_seed_1_training]\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "confuse mode: False\n",
      "split mode: train\n",
      "Number of Classes: 10\n",
      "Epoch: [0][0/313]\tTime 0.267 (0.267)\tData 0.045 (0.045)\tLoss 2.4113 (2.4113)\tAcc@1 10.156 (10.156)\tAcc@5 57.031 (57.031)\n",
      " * Acc@1 68.905 Acc@5 97.155\n",
      "[0] test metrics:{\"loss\": 0.6794950475692749, \"error\": 0.2376}\n",
      "Learning Rate : 0.01\n",
      "Epoch Time: 6.69 sec\n",
      "Pure training time: 5.5 sec\n",
      "Checkpoint name: cifar10_resnet_1_0_forget_[5]_num_100_lr_0_01_bs_128_ls_ce_wd_0_0005_seed_1\n",
      "[Logging in cifar10_resnet_1_0_forget_[5]_num_100_lr_0_01_bs_128_ls_ce_wd_0_0005_seed_1_training]\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "confuse mode: False\n",
      "split mode: train\n",
      "Replacing indexes [22644  6987 36335 23839  1559  5648 18421  4310 12085 10658 22192  2404\n",
      " 28077 23062 33361 18035 28713 24244 36749 30401 34869 14428  9997 10500\n",
      "  9429  5670 31993 23103 36858 37665 18393   364 10553 33173 11641 18201\n",
      "  5622 34816  9309 26329 22289 20137 25274 32709 29643 36973 38557 16388\n",
      " 26011 18122 13663  3617 33612 19687 13240 11287  6150 38579 21265 25977\n",
      "  2308 39186 16293 31764  6557 31314  9730  5069 15650 14230 20863 17632\n",
      "  3979 30893 11810 32154 25290 37680  9940 31707 12928 15234 22597 35439\n",
      " 33722 14472  9259 15288  1960  3135 22250 26104 26497 32327   386 19083\n",
      " 37138 13509 39902  7136]\n",
      "Number of Classes: 10\n",
      "Epoch: [0][0/313]\tTime 0.027 (0.027)\tData 0.015 (0.015)\tLoss 2.4173 (2.4173)\tAcc@1 10.938 (10.938)\tAcc@5 56.250 (56.250)\n",
      " * Acc@1 68.917 Acc@5 97.175\n",
      "[0] test metrics:{\"loss\": 0.6776201425552368, \"error\": 0.2363}\n",
      "Learning Rate : 0.01\n",
      "Epoch Time: 6.7 sec\n",
      "Pure training time: 5.52 sec\n"
     ]
    },
    {
     "ename": "FileNotFoundError",
     "evalue": "[Errno 2] No such file or directory: 'checkpoints/cifar10_resnet_1_0_forget_[5]_num_100_lr_0_01_bs_128_ls_ce_wd_0_0005_seed_1_25.pt'",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mFileNotFoundError\u001b[0m                         Traceback (most recent call last)",
      "\u001b[1;32m/home/zihao/data/SCRUB-main/MIA_experiments.ipynb Cell 23\u001b[0m line \u001b[0;36m5\n\u001b[1;32m     <a href='vscode-notebook-cell://wsl%2Bubuntu-20.04/home/zihao/data/SCRUB-main/MIA_experiments.ipynb#X31sdnNjb2RlLXJlbW90ZQ%3D%3D?line=47'>48</a>\u001b[0m m0_name \u001b[39m=\u001b[39m \u001b[39mf\u001b[39m\u001b[39m'\u001b[39m\u001b[39mcheckpoints/\u001b[39m\u001b[39m{\u001b[39;00mdataset\u001b[39m}\u001b[39;00m\u001b[39m_\u001b[39m\u001b[39m{\u001b[39;00march_filters\u001b[39m}\u001b[39;00m\u001b[39m_forget_\u001b[39m\u001b[39m{\u001b[39;00mclass_to_forget\u001b[39m}\u001b[39;00m\u001b[39m{\u001b[39;00mnum_tag\u001b[39m}\u001b[39;00m\u001b[39m{\u001b[39;00munfreeze_tag\u001b[39m}\u001b[39;00m\u001b[39m{\u001b[39;00maugment_tag\u001b[39m}\u001b[39;00m\u001b[39m{\u001b[39;00mlearningrate\u001b[39m}\u001b[39;00m\u001b[39m{\u001b[39;00mbatch_size\u001b[39m}\u001b[39;00m\u001b[39m{\u001b[39;00mlossfn\u001b[39m}\u001b[39;00m\u001b[39m{\u001b[39;00mwd\u001b[39m}\u001b[39;00m\u001b[39m{\u001b[39;00mseed_name\u001b[39m}\u001b[39;00m\u001b[39m{\u001b[39;00mtraining_epochs\u001b[39m}\u001b[39;00m\u001b[39m.pt\u001b[39m\u001b[39m'\u001b[39m\n\u001b[1;32m     <a href='vscode-notebook-cell://wsl%2Bubuntu-20.04/home/zihao/data/SCRUB-main/MIA_experiments.ipynb#X31sdnNjb2RlLXJlbW90ZQ%3D%3D?line=50'>51</a>\u001b[0m model\u001b[39m.\u001b[39mload_state_dict(torch\u001b[39m.\u001b[39mload(m_name))\n\u001b[0;32m---> <a href='vscode-notebook-cell://wsl%2Bubuntu-20.04/home/zihao/data/SCRUB-main/MIA_experiments.ipynb#X31sdnNjb2RlLXJlbW90ZQ%3D%3D?line=51'>52</a>\u001b[0m model0\u001b[39m.\u001b[39mload_state_dict(torch\u001b[39m.\u001b[39;49mload(m0_name))\n\u001b[1;32m     <a href='vscode-notebook-cell://wsl%2Bubuntu-20.04/home/zihao/data/SCRUB-main/MIA_experiments.ipynb#X31sdnNjb2RlLXJlbW90ZQ%3D%3D?line=52'>53</a>\u001b[0m model_initial\u001b[39m.\u001b[39mload_state_dict(torch\u001b[39m.\u001b[39mload(init_checkpoint))\n\u001b[1;32m     <a 
href='vscode-notebook-cell://wsl%2Bubuntu-20.04/home/zihao/data/SCRUB-main/MIA_experiments.ipynb#X31sdnNjb2RlLXJlbW90ZQ%3D%3D?line=54'>55</a>\u001b[0m teacher \u001b[39m=\u001b[39m copy\u001b[39m.\u001b[39mdeepcopy(model)\n",
      "File \u001b[0;32m~/anaconda3/envs/zihao/lib/python3.10/site-packages/torch/serialization.py:986\u001b[0m, in \u001b[0;36mload\u001b[0;34m(f, map_location, pickle_module, weights_only, mmap, **pickle_load_args)\u001b[0m\n\u001b[1;32m    983\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39m'\u001b[39m\u001b[39mencoding\u001b[39m\u001b[39m'\u001b[39m \u001b[39mnot\u001b[39;00m \u001b[39min\u001b[39;00m pickle_load_args\u001b[39m.\u001b[39mkeys():\n\u001b[1;32m    984\u001b[0m     pickle_load_args[\u001b[39m'\u001b[39m\u001b[39mencoding\u001b[39m\u001b[39m'\u001b[39m] \u001b[39m=\u001b[39m \u001b[39m'\u001b[39m\u001b[39mutf-8\u001b[39m\u001b[39m'\u001b[39m\n\u001b[0;32m--> 986\u001b[0m \u001b[39mwith\u001b[39;00m _open_file_like(f, \u001b[39m'\u001b[39;49m\u001b[39mrb\u001b[39;49m\u001b[39m'\u001b[39;49m) \u001b[39mas\u001b[39;00m opened_file:\n\u001b[1;32m    987\u001b[0m     \u001b[39mif\u001b[39;00m _is_zipfile(opened_file):\n\u001b[1;32m    988\u001b[0m         \u001b[39m# The zipfile reader is going to advance the current file position.\u001b[39;00m\n\u001b[1;32m    989\u001b[0m         \u001b[39m# If we want to actually tail call to torch.jit.load, we need to\u001b[39;00m\n\u001b[1;32m    990\u001b[0m         \u001b[39m# reset back to the original position.\u001b[39;00m\n\u001b[1;32m    991\u001b[0m         orig_position \u001b[39m=\u001b[39m opened_file\u001b[39m.\u001b[39mtell()\n",
      "File \u001b[0;32m~/anaconda3/envs/zihao/lib/python3.10/site-packages/torch/serialization.py:435\u001b[0m, in \u001b[0;36m_open_file_like\u001b[0;34m(name_or_buffer, mode)\u001b[0m\n\u001b[1;32m    433\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39m_open_file_like\u001b[39m(name_or_buffer, mode):\n\u001b[1;32m    434\u001b[0m     \u001b[39mif\u001b[39;00m _is_path(name_or_buffer):\n\u001b[0;32m--> 435\u001b[0m         \u001b[39mreturn\u001b[39;00m _open_file(name_or_buffer, mode)\n\u001b[1;32m    436\u001b[0m     \u001b[39melse\u001b[39;00m:\n\u001b[1;32m    437\u001b[0m         \u001b[39mif\u001b[39;00m \u001b[39m'\u001b[39m\u001b[39mw\u001b[39m\u001b[39m'\u001b[39m \u001b[39min\u001b[39;00m mode:\n",
      "File \u001b[0;32m~/anaconda3/envs/zihao/lib/python3.10/site-packages/torch/serialization.py:416\u001b[0m, in \u001b[0;36m_open_file.__init__\u001b[0;34m(self, name, mode)\u001b[0m\n\u001b[1;32m    415\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39m__init__\u001b[39m(\u001b[39mself\u001b[39m, name, mode):\n\u001b[0;32m--> 416\u001b[0m     \u001b[39msuper\u001b[39m()\u001b[39m.\u001b[39m\u001b[39m__init__\u001b[39m(\u001b[39mopen\u001b[39;49m(name, mode))\n",
      "\u001b[0;31mFileNotFoundError\u001b[0m: [Errno 2] No such file or directory: 'checkpoints/cifar10_resnet_1_0_forget_[5]_num_100_lr_0_01_bs_128_ls_ce_wd_0_0005_seed_1_25.pt'"
     ]
    }
   ],
   "source": [
    "# Run every unlearning method for each seed and collect per-method readouts in `errors`.\n",
    "seeds = [1,2,3]\n",
    "# NOTE(review): this --resume checkpoint is named cifar100 while `dataset` is 'cifar10' - confirm it is intended.\n",
    "chkpt = \"checkpoints/cifar100_resnet_1_0_forget_None_lr_0_1_bs_128_ls_ce_wd_0_0005_seed_1_30.pt\"\n",
    "forget_class = '5'\n",
    "forget_num = 100\n",
    "dataset = 'cifar10'\n",
    "modelname = 'resnet'\n",
    "dataroot = 'data/cifar10'\n",
    "filters = 1.0\n",
    "# Single source of truth for the training length: passed to main_merged.py as --epochs,\n",
    "# used as the epoch suffix of m_name / m0_name below, and read by euk_unlearn as its budget.\n",
    "# A mismatch (e.g. --epochs 1 vs *_25.pt) makes torch.load raise FileNotFoundError.\n",
    "# Assumes main_merged.py saves its final checkpoint as <name>_<epochs>.pt - TODO confirm.\n",
    "training_epochs = 25\n",
    "\n",
    "errors = []\n",
    "for s in seeds:\n",
    "\n",
    "    %run main_merged.py --dataset $dataset --model $modelname --dataroot=$dataroot --filters $filters --lr 0.01 \\\n",
    "    --resume $chkpt --disable-bn --weight-decay 0.0005 --batch-size 128 --epochs $training_epochs --seed $s\n",
    "\n",
    "    %run main_merged.py --dataset $dataset --model $modelname --dataroot=$dataroot --filters $filters --lr 0.01 \\\n",
    "    --resume $chkpt --disable-bn --weight-decay 0.0005 --batch-size 128 --epochs $training_epochs \\\n",
    "    --forget-class $forget_class --num-to-forget $forget_num --seed $s\n",
    "\n",
    "\n",
    "    model0 = copy.deepcopy(model)\n",
    "    model_initial = copy.deepcopy(model)\n",
    "\n",
    "    arch = args.model \n",
    "    filters=args.filters\n",
    "    arch_filters = arch +'_'+ str(filters).replace('.','_')\n",
    "    augment = False\n",
    "    dataset = args.dataset\n",
    "    class_to_forget = args.forget_class\n",
    "    init_checkpoint = f\"checkpoints/{args.name}_init.pt\"\n",
    "    num_classes=args.num_classes\n",
    "    num_to_forget = args.num_to_forget\n",
    "    num_total = len(train_loader.dataset)\n",
    "    num_to_retain = num_total - forget_num\n",
    "    seed = args.seed\n",
    "    unfreeze_start = None\n",
    "\n",
    "    learningrate=f\"lr_{str(args.lr).replace('.','_')}\"\n",
    "    batch_size=f\"_bs_{str(args.batch_size)}\"\n",
    "    lossfn=f\"_ls_{args.lossfn}\"\n",
    "    wd=f\"_wd_{str(args.weight_decay).replace('.','_')}\"\n",
    "    seed_name=f\"_seed_{args.seed}_\"\n",
    "\n",
    "    num_tag = '' if num_to_forget is None else f'_num_{num_to_forget}'\n",
    "    unfreeze_tag = '_' if unfreeze_start is None else f'_unfreeze_from_{unfreeze_start}_'\n",
    "    augment_tag = '' if not augment else f'augment_'\n",
    "\n",
    "    m_name = f'checkpoints/{dataset}_{arch_filters}_forget_None{unfreeze_tag}{augment_tag}{learningrate}{batch_size}{lossfn}{wd}{seed_name}{training_epochs}.pt'\n",
    "    m0_name = f'checkpoints/{dataset}_{arch_filters}_forget_{class_to_forget}{num_tag}{unfreeze_tag}{augment_tag}{learningrate}{batch_size}{lossfn}{wd}{seed_name}{training_epochs}.pt'\n",
    "    \n",
    "\n",
    "    model.load_state_dict(torch.load(m_name))\n",
    "    model0.load_state_dict(torch.load(m0_name))\n",
    "    model_initial.load_state_dict(torch.load(init_checkpoint))\n",
    "\n",
    "    teacher = copy.deepcopy(model)\n",
    "    student = copy.deepcopy(model)\n",
    "\n",
    "    model.cuda()\n",
    "    model0.cuda()\n",
    "\n",
    "\n",
    "    for p in model.parameters():\n",
    "        p.data0 = p.data.clone()\n",
    "    for p in model0.parameters():\n",
    "        p.data0 = p.data.clone()\n",
    "    \n",
    "    log_dict={}\n",
    "    log_dict['args']=args\n",
    "    args.retain_bs = 32\n",
    "    args.forget_bs = 16\n",
    "\n",
    "    train_loader_full, valid_loader_full, test_loader_full = datasets.get_loaders(args.dataset, batch_size=args.batch_size, seed=s, root=args.dataroot, augment=False, shuffle=True)\n",
    "    marked_loader, _, _ = datasets.get_loaders(args.dataset, class_to_replace=args.forget_class, num_indexes_to_replace=args.num_to_forget, only_mark=True, batch_size=1, seed=s, root=args.dataroot, augment=False, shuffle=True)\n",
    "\n",
    "    forget_dataset = copy.deepcopy(marked_loader.dataset)\n",
    "    marked = forget_dataset.targets < 0\n",
    "    forget_dataset.data = forget_dataset.data[marked]\n",
    "    forget_dataset.targets = - forget_dataset.targets[marked] - 1\n",
    "    #forget_loader = torch.utils.data.DataLoader(forget_dataset, batch_size=args.forget_bs,num_workers=0,pin_memory=True,shuffle=True)\n",
    "    forget_loader = replace_loader_dataset(train_loader_full, forget_dataset, batch_size=args.forget_bs, seed=seed, shuffle=True)\n",
    "\n",
    "    retain_dataset = copy.deepcopy(marked_loader.dataset)\n",
    "    marked = retain_dataset.targets >= 0\n",
    "    retain_dataset.data = retain_dataset.data[marked]\n",
    "    retain_dataset.targets = retain_dataset.targets[marked]\n",
    "    #retain_loader = torch.utils.data.DataLoader(retain_dataset, batch_size=args.retain_bs,num_workers=0,pin_memory=True,shuffle=True)\n",
    "    retain_loader = replace_loader_dataset(train_loader_full, retain_dataset, batch_size=args.retain_bs, seed=seed, shuffle=True)\n",
    "\n",
    "    assert(len(forget_dataset) + len(retain_dataset) == len(train_loader_full.dataset))\n",
    "\n",
    "    \n",
    "    model_ft = copy.deepcopy(model)\n",
    "    model_ng = copy.deepcopy(model)\n",
    "    args.ft_lr = 0.04\n",
    "    args.ft_epochs = 10\n",
    "    args.ng_alpha = 0.9999\n",
    "    args.ng_epochs = 5\n",
    "    args.ng_lr = 0.01\n",
    "    \n",
    "\n",
    "    \n",
    "    print (\"Forgetting by Fine-tuneing:\")\n",
    "    finetune(model_ft, retain_loader, epochs=args.ft_epochs, quiet=True, lr=args.ft_lr)\n",
    "    print (\"Forgetting by NegGrad:\")\n",
    "    negative_grad(model_ng, retain_loader, forget_loader, alpha=args.ng_alpha, epochs=args.ng_epochs, quiet=True, lr=args.ng_lr, args=args)\n",
    "    print (\"Forgetting by CFK:\")\n",
    "    model_cfk = cfk_unlearn(model)\n",
    "    print (\"Forgetting by EUK:\")\n",
    "    model_euk = euk_unlearn(model, model_initial)\n",
    "    print(\"Forgetting by SCRUB:\")\n",
    "    model_s, model_s_final = scrub(teacher, student)\n",
    "\n",
    "            \n",
    "    readouts = {}\n",
    "    #_,_=activations_predictions(copy.deepcopy(model),forget_loader,'Original_Model_D_f')\n",
    "    thresh=0#og_dict['Original_Model_D_f_loss']+1e-5\n",
    "    readouts[\"Original\"] = all_readouts(copy.deepcopy(model),thresh,'Original',seed)\n",
    "    readouts[\"Retrain\"] = all_readouts(copy.deepcopy(model0),thresh,'Retrain',seed)\n",
    "    readouts['Finetune'] = all_readouts(copy.deepcopy(model_ft),thresh,'Finetune',seed)\n",
    "    readouts[\"NegGrad\"] = all_readouts(copy.deepcopy(model_ng),thresh,'NegGrad',seed)\n",
    "    readouts[\"CFK\"] = all_readouts(copy.deepcopy(model_cfk),thresh,'CFK',seed)\n",
    "    readouts[\"EUK\"] = all_readouts(copy.deepcopy(model_euk),thresh,'EUK',seed)\n",
    "    readouts[\"SCRUB+R\"] = all_readouts(copy.deepcopy(model_s),thresh,'SCRUB+R',seed)\n",
    "    readouts[\"SCRUB\"] = all_readouts(copy.deepcopy(model_s_final),thresh,'SCRUB',seed)\n",
    "    \n",
    "    \n",
    "    del model\n",
    "    del model0\n",
    "    del model_ft\n",
    "    del model_ng\n",
    "    del model_euk\n",
    "    del model_cfk\n",
    "    del model_s\n",
    "    del model_s_final\n",
    "    del retain_loader\n",
    "    del forget_loader\n",
    "    del test_loader_full\n",
    "    del train_loader_full\n",
    "\n",
    "    errors.append(readouts)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "ename": "IndexError",
     "evalue": "list index out of range",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mIndexError\u001b[0m                                Traceback (most recent call last)",
      "\u001b[1;32m/home/zihao/data/SCRUB-main/MIA_experiments.ipynb Cell 24\u001b[0m line \u001b[0;36m7\n\u001b[1;32m      <a href='vscode-notebook-cell://wsl%2Bubuntu-20.04/home/zihao/data/SCRUB-main/MIA_experiments.ipynb#X32sdnNjb2RlLXJlbW90ZQ%3D%3D?line=4'>5</a>\u001b[0m rlt \u001b[39m=\u001b[39m {}\n\u001b[1;32m      <a href='vscode-notebook-cell://wsl%2Bubuntu-20.04/home/zihao/data/SCRUB-main/MIA_experiments.ipynb#X32sdnNjb2RlLXJlbW90ZQ%3D%3D?line=5'>6</a>\u001b[0m MIA \u001b[39m=\u001b[39m {}\n\u001b[0;32m----> <a href='vscode-notebook-cell://wsl%2Bubuntu-20.04/home/zihao/data/SCRUB-main/MIA_experiments.ipynb#X32sdnNjb2RlLXJlbW90ZQ%3D%3D?line=6'>7</a>\u001b[0m \u001b[39mfor\u001b[39;00m key \u001b[39min\u001b[39;00m errors[\u001b[39m0\u001b[39;49m]\u001b[39m.\u001b[39mkeys():\n\u001b[1;32m      <a href='vscode-notebook-cell://wsl%2Bubuntu-20.04/home/zihao/data/SCRUB-main/MIA_experiments.ipynb#X32sdnNjb2RlLXJlbW90ZQ%3D%3D?line=7'>8</a>\u001b[0m     tes[key] \u001b[39m=\u001b[39m [errors[i][key][\u001b[39m'\u001b[39m\u001b[39mtest_error\u001b[39m\u001b[39m'\u001b[39m] \u001b[39mfor\u001b[39;00m i \u001b[39min\u001b[39;00m \u001b[39mrange\u001b[39m(\u001b[39mlen\u001b[39m(errors))]\n\u001b[1;32m      <a href='vscode-notebook-cell://wsl%2Bubuntu-20.04/home/zihao/data/SCRUB-main/MIA_experiments.ipynb#X32sdnNjb2RlLXJlbW90ZQ%3D%3D?line=8'>9</a>\u001b[0m     res[key] \u001b[39m=\u001b[39m [errors[i][key][\u001b[39m'\u001b[39m\u001b[39mretain_error\u001b[39m\u001b[39m'\u001b[39m] \u001b[39mfor\u001b[39;00m i \u001b[39min\u001b[39;00m \u001b[39mrange\u001b[39m(\u001b[39mlen\u001b[39m(errors))]\n",
      "\u001b[0;31mIndexError\u001b[0m: list index out of range"
     ]
    }
   ],
   "source": [
    "# Aggregate the per-seed readouts in `errors` (one dict per seed, keyed by method name)\n",
    "# into mean±std of test / forget / retain error and MIA.\n",
    "if not errors:\n",
    "    raise RuntimeError(\"`errors` is empty: the experiment loop did not complete any seed, so there is nothing to summarize.\")\n",
    "\n",
    "tes = {}\n",
    "res = {}\n",
    "fes = {}\n",
    "ves = {}\n",
    "rlt = {}\n",
    "MIA = {}\n",
    "print (\"{}  \\t{}\\t{}\\t{}\\t{}\".format(\"method\", \"test_error\", \"forget_error\", \"retain_error\", \"MIA\"))\n",
    "for key in errors[0].keys():\n",
    "    tes[key] = [run[key]['test_error'] for run in errors]\n",
    "    res[key] = [run[key]['retain_error'] for run in errors]\n",
    "    fes[key] = [run[key]['forget_error'] for run in errors]\n",
    "    ves[key] = [run[key]['val_error'] for run in errors]\n",
    "    rlt[key] = [run[key]['retrain_time'] for run in errors]\n",
    "    MIA[key] = [run[key]['MIA']*100 for run in errors]\n",
    "    \n",
    "    print (\"{}  \\t{:.2f}±{:.2f}\\t{:.2f}±{:.2f}\\t{:.2f}±{:.2f}\\t{:.2f}±{:.2f}\".format(key, \n",
    "                                                                np.mean(tes[key]), np.std(tes[key]),\n",
    "                                                                np.mean(fes[key]), np.std(fes[key]),\n",
    "                                                                np.mean(res[key]), np.std(res[key]),\n",
    "                                                                np.mean(MIA[key]), np.std(MIA[key])))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
