{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "%matplotlib inline\n",
    "import os\n",
    "os.environ['CUDA_DEVICE_ORDER']='PCI_BUS_ID'\n",
    "os.environ['CUDA_VISIBLE_DEVICES']='0'\n",
    "import variational\n",
    "import os\n",
    "import time\n",
    "import math\n",
    "import pandas as pd\n",
    "from collections import OrderedDict\n",
    "from sklearn.linear_model import LogisticRegression\n",
    "    \n",
    "import copy\n",
    "import torch.nn as nn\n",
    "from torch.autograd import Variable\n",
    "from typing import List\n",
    "import itertools\n",
    "from tqdm.autonotebook import tqdm\n",
    "from models import *\n",
    "import models\n",
    "from logger import *\n",
    "import wandb\n",
    "\n",
    "from thirdparty.repdistiller.helper.util import adjust_learning_rate as sgda_adjust_learning_rate\n",
    "from thirdparty.repdistiller.distiller_zoo import DistillKL, HintLoss, Attention, Similarity, Correlation, VIDLoss, RKDLoss\n",
    "from thirdparty.repdistiller.distiller_zoo import PKT, ABLoss, FactorTransfer, KDSVD, FSP, NSTLoss\n",
    "\n",
    "from thirdparty.repdistiller.helper.loops import train_distill, train_distill_hide, train_distill_linear, train_vanilla, train_negrad, train_bcu, train_bcu_distill, validate\n",
    "from thirdparty.repdistiller.helper.pretrain import init"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def pdb():\n",
    "    import pdb\n",
    "    pdb.set_trace"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def parameter_count(model):\n",
    "    count=0\n",
    "    for p in model.parameters():\n",
    "        count+=np.prod(np.array(list(p.shape)))\n",
    "    print(f'Total Number of Parameters: {count}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def vectorize_params(model):\n",
    "    param = []\n",
    "    for p in model.parameters():\n",
    "        param.append(p.data.view(-1).cpu().numpy())\n",
    "    return np.concatenate(param)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def print_param_shape(model):\n",
    "    for k,p in model.named_parameters():\n",
    "        print(k,p.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "### Activations along the path\n",
    "from sklearn.svm import SVC\n",
    "from sklearn.metrics import accuracy_score\n",
    "import datasets\n",
    "\n",
    "def perf_measure(y_actual, y_hat):\n",
    "    TP = 0\n",
    "    FP = 0\n",
    "    TN = 0\n",
    "    FN = 0\n",
    "\n",
    "    for i in range(len(y_hat)): \n",
    "        if y_actual[i]==y_hat[i]==1:\n",
    "           TP += 1\n",
    "        if y_hat[i]==1 and y_actual[i]!=y_hat[i]:\n",
    "           FP += 1\n",
    "        if y_actual[i]==y_hat[i]==0:\n",
    "           TN += 1\n",
    "        if y_hat[i]==0 and y_actual[i]!=y_hat[i]:\n",
    "           FN += 1\n",
    "\n",
    "    return TP, FP, TN, FN\n",
    "\n",
    "def entropy(p, dim = -1, keepdim = False):\n",
    "    return -torch.where(p > 0, p * p.log(), p.new([0.0])).sum(dim=dim, keepdim=keepdim)\n",
    "    #return (p * p.log()).sum(dim=dim, keepdim=keepdim)\n",
    "\n",
    "def collect_prob(data_loader, model):\n",
    "    data_loader = torch.utils.data.DataLoader(data_loader.dataset, batch_size=1, shuffle=False)\n",
    "    sampler = torch.utils.data.RandomSampler(data_loader.dataset, replacement=False, num_samples=4000)\n",
    "    data_loader_small = torch.utils.data.DataLoader(data_loader.dataset, batch_size=1, sampler=sampler, shuffle=False)\n",
    "    prob = []\n",
    "    with torch.no_grad():\n",
    "        for idx, batch in enumerate(tqdm(data_loader_small, leave=False)):\n",
    "            batch = [tensor.to(next(model.parameters()).device) for tensor in batch]\n",
    "            data, target = batch\n",
    "            output = model(data)\n",
    "            prob.append(F.softmax(output, dim=-1).data)\n",
    "    return torch.cat(prob)\n",
    "\n",
    "def get_membership_attack_data(retain_loader, forget_loader, test_loader, model):    \n",
    "    retain_prob = collect_prob(retain_loader, model)\n",
    "    forget_prob = collect_prob(forget_loader, model)\n",
    "    test_prob = collect_prob(test_loader, model)\n",
    "    \n",
    "    X_r = torch.cat([entropy(retain_prob), entropy(test_prob)]).cpu().numpy().reshape(-1, 1)\n",
    "    Y_r = np.concatenate([np.ones(len(retain_prob)), np.zeros(len(test_prob))])\n",
    "    \n",
    "    X_f = entropy(forget_prob).cpu().numpy().reshape(-1, 1)\n",
    "    Y_f = np.concatenate([np.ones(len(forget_prob))])    \n",
    "    return X_f, Y_f, X_r, Y_r\n",
    "\n",
    "def get_membership_attack_prob(retain_loader, forget_loader, test_loader, model):\n",
    "    X_f, Y_f, X_r, Y_r = get_membership_attack_data(retain_loader, forget_loader, test_loader, model)\n",
    "    #clf = SVC(C=3,gamma='auto',kernel='rbf')\n",
    "    clf = LogisticRegression(class_weight='balanced',solver='lbfgs',multi_class='multinomial')\n",
    "    clf.fit(X_r, Y_r)\n",
    "    results = clf.predict(X_f)\n",
    "    results1 = clf.predict(X_r)\n",
    "    acc = accuracy_score(results, Y_f)\n",
    "    train_ac = accuracy_score(results1, Y_r)\n",
    "    TP, FP, TN, FN = perf_measure(Y_r, results1)\n",
    "    FPR = FP/(FP+TN)\n",
    "    FNR = FN/(FN+TP)\n",
    "    \n",
    "    print (f\"TP:{TP}, FP{FP}, TN{TN}, FN{FN}\")\n",
    "    print (f\"false negative rate: {FN/(FN+TP)}\")\n",
    "    print (f\"false positive rate: {FP/(FP+TN)}\")\n",
    "    return acc, train_acc, FPR, FNR #results.mean(), results1.mean()\n",
    "    \n",
    "def plot_entropy_dist(model,retain_loader,forget_loader,test_loader_full, title, ax):\n",
    "    import matplotlib.pyplot as plt\n",
    "    import seaborn as sns\n",
    "    from matplotlib.ticker import FuncFormatter\n",
    "    #train_loader_full, valid_loader_full, test_loader_full = datasets.get_loaders(dataset, batch_size=100, seed=0, augment=False, shuffle=False)\n",
    "    #indexes = np.flatnonzero(np.array(train_loader_full.dataset.targets) == class_to_forget)\n",
    "    #replaced = np.random.RandomState(0).choice(indexes, size=100 if num_to_forget==100 else len(indexes), replace=False)\n",
    "    X_f, Y_f, X_r, Y_r = get_membership_attack_data(retain_loader,forget_loader,test_loader_full, model)\n",
    "    #np.savetxt('retain_prob.txt', X_r[Y_r==1])\n",
    "    #np.savetxt('forget_prob.txt', X_f)\n",
    "    #np.savetxt('test_prob.txt', X_r[Y_r==0])\n",
    "    sns.distplot(np.log(X_r[Y_r==1]).reshape(-1), kde=False, norm_hist=True, rug=False, label='retain', ax=plt)\n",
    "    sns.distplot(np.log(X_r[Y_r==0]).reshape(-1), kde=False, norm_hist=True, rug=False, label='test', ax=plt)\n",
    "    sns.distplot(np.log(X_f).reshape(-1), kde=False, norm_hist=True, rug=False, label='forget', ax=plt)\n",
    "    plt.legend(prop={'size': 14})\n",
    "    plt.tick_params(labelsize=12)\n",
    "    plt.title(title,size=18)\n",
    "    plt.xlabel('Log of Entropy',size=14)\n",
    "    plt.show()\n",
    "    plt.clf()\n",
    "\n",
    "\n",
    "def membership_attack(retain_loader,forget_loader,test_loader,model, name):\n",
    "    prob, train_acc, FPR, FNR = get_membership_attack_prob(retain_loader,forget_loader,test_loader,model)\n",
    "    print(\"Attack prob: \", prob)\n",
    "    return prob"
   ]
  },
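  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The cell above builds an entropy-based membership inference attack: training examples tend to get low-entropy (confident) softmax outputs, so a classifier on output entropy can separate members from non-members. Below is a minimal, self-contained sketch of the same idea on *synthetic* softmax vectors; the sample counts and Dirichlet concentrations are illustrative assumptions, not values from these experiments."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch: entropy-thresholding membership attack on synthetic softmax outputs.\n",
    "rng = np.random.default_rng(0)\n",
    "\n",
    "def np_entropy(p):\n",
    "    return -(p * np.log(np.clip(p, 1e-12, None))).sum(axis=1)\n",
    "\n",
    "# Low Dirichlet concentration -> peaky (confident) vectors, like train examples.\n",
    "member_probs = rng.dirichlet(np.full(10, 0.1), size=1000)\n",
    "nonmember_probs = rng.dirichlet(np.full(10, 1.0), size=1000)\n",
    "\n",
    "X = np.concatenate([np_entropy(member_probs), np_entropy(nonmember_probs)]).reshape(-1, 1)\n",
    "y = np.concatenate([np.ones(1000), np.zeros(1000)])\n",
    "\n",
    "clf = LogisticRegression(class_weight='balanced').fit(X, y)\n",
    "print('attack accuracy on synthetic data:', accuracy_score(y, clf.predict(X)))"
   ]
  },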
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from utils import *\n",
    "def get_metrics(model,dataloader,criterion,samples_correctness=False,use_bn=False,delta_w=None,scrub_act=False):\n",
    "    activations=[]\n",
    "    predictions=[]\n",
    "    if use_bn:\n",
    "        model.train()\n",
    "        dataloader = torch.utils.data.DataLoader(retain_loader.dataset, batch_size=512, shuffle=True)\n",
    "        for i in range(10):\n",
    "            for batch_idx, (data, target) in enumerate(dataloader):\n",
    "                data, target = data.to(args.device), target.to(args.device)            \n",
    "                output = model(data)\n",
    "    dataloader = torch.utils.data.DataLoader(dataloader.dataset, batch_size=512, shuffle=False)\n",
    "    model.eval()\n",
    "    metrics = AverageMeter()\n",
    "    mult = 0.5 if args.lossfn=='mse' else 1\n",
    "    for batch_idx, (data, target) in enumerate(dataloader):\n",
    "        data, target = data.to(args.device), target.to(args.device)            \n",
    "        if args.lossfn=='mse':\n",
    "            target=(2*target-1)\n",
    "            target = target.type(torch.cuda.FloatTensor).unsqueeze(1)\n",
    "        if 'mnist' in args.dataset:\n",
    "            data=data.view(data.shape[0],-1)\n",
    "        output = model(data)\n",
    "        loss = mult*criterion(output, target)\n",
    "        if samples_correctness:\n",
    "            activations.append(torch.nn.functional.softmax(output,dim=1).cpu().detach().numpy().squeeze())\n",
    "            predictions.append(get_error(output,target))\n",
    "        metrics.update(n=data.size(0), loss=loss.item(), error=get_error(output, target))\n",
    "    if samples_correctness:\n",
    "        return metrics.avg,np.stack(activations),np.array(predictions)\n",
    "    else:\n",
    "        return metrics.avg"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def activations_predictions(model,dataloader,name):\n",
    "    criterion = torch.nn.CrossEntropyLoss()\n",
    "    metrics=get_metrics(model,dataloader,criterion,False)\n",
    "    print(f\"{name} -> Loss:{np.round(metrics['loss'],3)}, Error:{metrics['error']}\")\n",
    "    log_dict[f\"{name}_loss\"]=metrics['loss']\n",
    "    log_dict[f\"{name}_error\"]=metrics['error']\n",
    "    return _,_"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def predictions_distance(l1,l2,name):\n",
    "    dist = np.sum(np.abs(l1-l2))\n",
    "    print(f\"Predictions Distance {name} -> {dist}\")\n",
    "    log_dict[f\"{name}_predictions\"]=dist"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def activations_distance(a1,a2,name):\n",
    "    dist = np.linalg.norm(a1-a2,ord=1,axis=1).mean()\n",
    "    print(f\"Activations Distance {name} -> {dist}\")\n",
    "    log_dict[f\"{name}_activations\"]=dist"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from utils import *\n",
    "def get_metrics(model,dataloader,criterion,samples_correctness=False,use_bn=False,delta_w=None,scrub_act=False):\n",
    "    activations=[]\n",
    "    predictions=[]\n",
    "    if use_bn:\n",
    "        model.train()\n",
    "        dataloader = torch.utils.data.DataLoader(retain_loader.dataset, batch_size=128, shuffle=True)\n",
    "        for i in range(10):\n",
    "            for batch_idx, (data, target) in enumerate(dataloader):\n",
    "                data, target = data.to(args.device), target.to(args.device)            \n",
    "                output = model(data)\n",
    "    dataloader = torch.utils.data.DataLoader(dataloader.dataset, batch_size=1, shuffle=False)\n",
    "    model.eval()\n",
    "    metrics = AverageMeter()\n",
    "    mult = 0.5 if args.lossfn=='mse' else 1\n",
    "    for batch_idx, (data, target) in enumerate(dataloader):\n",
    "        data, target = data.to(args.device), target.to(args.device)            \n",
    "        if args.lossfn=='mse':\n",
    "            target=(2*target-1)\n",
    "            target = target.type(torch.cuda.FloatTensor).unsqueeze(1)\n",
    "        if 'mnist' in args.dataset:\n",
    "            data=data.view(data.shape[0],-1)\n",
    "        output = model(data)\n",
    "        if scrub_act:\n",
    "            G = []\n",
    "            for cls in range(num_classes):\n",
    "                grads = torch.autograd.grad(output[0,cls],model.parameters(),retain_graph=True)\n",
    "                grads = torch.cat([g.view(-1) for g in grads])\n",
    "                G.append(grads)\n",
    "            grads = torch.autograd.grad(output_sf[0,cls],model_scrubf.parameters(),retain_graph=False)\n",
    "            G = torch.stack(G).pow(2)\n",
    "            delta_f = torch.matmul(G,delta_w)\n",
    "            output += delta_f.sqrt()*torch.empty_like(delta_f).normal_()\n",
    "\n",
    "        loss = mult*criterion(output, target)\n",
    "        if samples_correctness:\n",
    "            activations.append(torch.nn.functional.softmax(output,dim=1).cpu().detach().numpy().squeeze())\n",
    "            predictions.append(get_error(output,target))\n",
    "        metrics.update(n=data.size(0), loss=loss.item(), error=get_error(output, target))\n",
    "    if samples_correctness:\n",
    "        return metrics.avg,np.stack(activations),np.array(predictions)\n",
    "    else:\n",
    "        return metrics.avg"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def l2_penalty(model,model_init,weight_decay):\n",
    "    l2_loss = 0\n",
    "    for (k,p),(k_init,p_init) in zip(model.named_parameters(),model_init.named_parameters()):\n",
    "        if p.requires_grad:\n",
    "            l2_loss += (p-p_init).pow(2).sum()\n",
    "    l2_loss *= (weight_decay/2.)\n",
    "    return l2_loss\n",
    "\n",
    "def run_train_epoch(model: nn.Module, model_init, data_loader: torch.utils.data.DataLoader, \n",
    "                    loss_fn: nn.Module,\n",
    "                    optimizer: torch.optim.SGD, split: str, epoch: int, ignore_index=None,\n",
    "                    negative_gradient=False, negative_multiplier=-1, random_labels=False,\n",
    "                    quiet=False,delta_w=None,scrub_act=False):\n",
    "    model.eval()\n",
    "    metrics = AverageMeter()    \n",
    "    num_labels = data_loader.dataset.targets.max().item() + 1\n",
    "    \n",
    "    with torch.set_grad_enabled(split != 'test'):\n",
    "        for idx, batch in enumerate(tqdm(data_loader, leave=False)):\n",
    "            batch = [tensor.to(next(model.parameters()).device) for tensor in batch]\n",
    "            input, target = batch\n",
    "            output = model(input)\n",
    "            if split=='test' and scrub_act:\n",
    "                G = []\n",
    "                for cls in range(num_classes):\n",
    "                    grads = torch.autograd.grad(output[0,cls],model.parameters(),retain_graph=True)\n",
    "                    grads = torch.cat([g.view(-1) for g in grads])\n",
    "                    G.append(grads)\n",
    "                grads = torch.autograd.grad(output_sf[0,cls],model_scrubf.parameters(),retain_graph=False)\n",
    "                G = torch.stack(G).pow(2)\n",
    "                delta_f = torch.matmul(G,delta_w)\n",
    "                output += delta_f.sqrt()*torch.empty_like(delta_f).normal_()\n",
    "            loss = loss_fn(output, target) + l2_penalty(model,model_init,args.weight_decay)\n",
    "            metrics.update(n=input.size(0), loss=loss_fn(output,target).item(), error=get_error(output, target))\n",
    "            \n",
    "            if split != 'test':\n",
    "                model.zero_grad()\n",
    "                loss.backward()\n",
    "                optimizer.step()\n",
    "    if not quiet:\n",
    "        log_metrics(split, metrics, epoch)\n",
    "    return metrics.avg"
   ]
  },
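  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick sanity check of `l2_penalty`: it should equal $\\frac{\\lambda}{2}\\lVert\\theta-\\theta_0\\rVert^2$. A toy verification on a hypothetical 2-weight linear layer (the names `toy`, `toy_init`, and `wd` are illustrative):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity check: l2_penalty == (weight_decay/2) * ||theta - theta_init||^2.\n",
    "toy = nn.Linear(2, 1, bias=False)\n",
    "toy_init = copy.deepcopy(toy)\n",
    "with torch.no_grad():\n",
    "    toy.weight += 1.0  # shift each of the 2 weights by exactly 1\n",
    "\n",
    "wd = 0.1\n",
    "# sum of squared shifts is 2, so the penalty should be (0.1/2)*2 = 0.1\n",
    "print(l2_penalty(toy, toy_init, wd).item())"
   ]
  },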
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def finetune(model: nn.Module, data_loader: torch.utils.data.DataLoader, lr=0.01, epochs=10, quiet=False):\n",
    "    loss_fn = nn.CrossEntropyLoss()\n",
    "    optimizer = torch.optim.SGD(model.parameters(), lr=lr, weight_decay=0)\n",
    "    model_init=copy.deepcopy(model)\n",
    "    for epoch in range(epochs):\n",
    "        #run_train_epoch(model, model_init, data_loader, loss_fn, optimizer, split='train', epoch=epoch, ignore_index=None, quiet=quiet)\n",
    "        train_vanilla(epoch, data_loader, model, loss_fn, optimizer, args)\n",
    "\n",
    "def test(model, data_loader):\n",
    "    loss_fn = nn.CrossEntropyLoss()\n",
    "    model_init=copy.deepcopy(model)\n",
    "    return run_train_epoch(model, model_init, data_loader, loss_fn, optimizer=None, split='test', epoch=epoch, ignore_index=None, quiet=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def readout_retrain(model, data_loader, test_loader, lr=0.1, epochs=500, threshold=0.01, quiet=True):\n",
    "    torch.manual_seed(seed)\n",
    "    model = copy.deepcopy(model)\n",
    "    loss_fn = nn.CrossEntropyLoss()\n",
    "    optimizer = torch.optim.SGD(model.parameters(), lr=lr, weight_decay=0.0)\n",
    "    sampler = torch.utils.data.RandomSampler(data_loader.dataset, replacement=True, num_samples=500)\n",
    "    data_loader_small = torch.utils.data.DataLoader(data_loader.dataset, batch_size=data_loader.batch_size, sampler=sampler, num_workers=data_loader.num_workers)\n",
    "    metrics = []\n",
    "    model_init=copy.deepcopy(model)\n",
    "    for epoch in range(epochs):\n",
    "        metrics.append(run_train_epoch(model, model_init, test_loader, loss_fn, optimizer, split='test', epoch=epoch, ignore_index=None, quiet=quiet))\n",
    "        if metrics[-1]['loss'] <= threshold:\n",
    "            break\n",
    "        run_train_epoch(model, model_init, data_loader_small, loss_fn, optimizer, split='train', epoch=epoch, ignore_index=None, quiet=quiet)\n",
    "    return epoch, metrics\n",
    "\n",
    "def extract_retrain_time(metrics, threshold=0.1):\n",
    "    losses = np.array([m['loss'] for m in metrics])\n",
    "    return np.argmax(losses < threshold)\n",
    "\n",
    "def all_readouts(model,thresh=0.1,name='method'):\n",
    "    train_loader = torch.utils.data.DataLoader(train_loader_full.dataset, batch_size=args.batch_size, shuffle=True)\n",
    "    retrain_time, _ = readout_retrain(model, train_loader, forget_loader, epochs=100, lr=0.01, threshold=thresh)\n",
    "    test_error = test(model, test_loader_full)['error']\n",
    "    forget_error = test(model, forget_loader)['error']\n",
    "    retain_error = test(model, retain_loader)['error']\n",
    "    print(f\"{name} ->\"\n",
    "          f\"\\tFull test error: {test_error:.2%}\"\n",
    "          f\"\\tForget error: {forget_error:.2%}\\tRetain error: {retain_error:.2%}\"\n",
    "          f\"\\tFine-tune time: {retrain_time+1} steps\")\n",
    "    #print(f\"{name} ->\"\n",
    "    #      f\"\\tFine-tune time: {retrain_time+1} steps\")\n",
    "    log_dict[f\"{name}_retrain_time\"]=retrain_time+1\n",
    "    return(dict(test_error=test_error, forget_error=forget_error, retain_error=retain_error, retrain_time=retrain_time))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def test_activations(model_scrubf, modelf0, delta_w_s, delta_w_m0, data_loader, \\\n",
    "                    loss_fn=nn.CrossEntropyLoss(),\\\n",
    "                    optimizer=torch.optim.SGD, \\\n",
    "                    seed=1,quiet=False):\n",
    "\n",
    "    model_scrubf.eval()\n",
    "    modelf0.eval()\n",
    "    \n",
    "    data_loader = torch.utils.data.DataLoader(data_loader.dataset, batch_size=1, shuffle=False)\n",
    "\n",
    "    \n",
    "    metrics = AverageMeter()    \n",
    "    num_classes = data_loader.dataset.targets.max().item() + 1\n",
    "    \n",
    "    for idx, batch in enumerate(tqdm(data_loader, leave=False)):\n",
    "        batch = [tensor.to(next(model_scrubf.parameters()).device) for tensor in batch]\n",
    "        input, target = batch\n",
    "        \n",
    "        output_sf = model_scrubf(input)\n",
    "        G_sf = []\n",
    "\n",
    "        for cls in range(num_classes):\n",
    "            grads = torch.autograd.grad(output_sf[0,cls],model_scrubf.parameters(),retain_graph=True)\n",
    "            grads = torch.cat([g.view(-1) for g in grads])\n",
    "            G_sf.append(grads)\n",
    "\n",
    "        grads = torch.autograd.grad(output_sf[0,cls],model_scrubf.parameters(),retain_graph=False)\n",
    "            \n",
    "        G_sf = torch.stack(G_sf)#.pow(2)\n",
    "        delta_f_sf_update = torch.matmul(G_sf,delta_w_s.sqrt()*torch.empty_like(delta_w_s).normal_())\n",
    "        G_sf = G_sf.pow(2)\n",
    "        delta_f_sf = torch.matmul(G_sf,delta_w_s)\n",
    "\n",
    "        output_m0 = modelf0(input)\n",
    "        G_m0 = []\n",
    "\n",
    "        for cls in range(num_classes):\n",
    "            grads = torch.autograd.grad(output_m0[0,cls],modelf0.parameters(),retain_graph=True)\n",
    "            grad_m0 = torch.cat([g.view(-1) for g in grads])\n",
    "            G_m0.append(grad_m0)\n",
    "\n",
    "        grads = torch.autograd.grad(output_m0[0,cls],modelf0.parameters(),retain_graph=False)\n",
    "            \n",
    "        G_m0 = torch.stack(G_m0).pow(2)\n",
    "        delta_f_m0 = torch.matmul(G_m0,delta_w_m0)\n",
    "        \n",
    "        kl = ((output_m0 - output_sf).pow(2)/delta_f_m0 + delta_f_sf/delta_f_m0 - torch.log(delta_f_sf/delta_f_m0) - 1).sum()\n",
    "        \n",
    "        torch.manual_seed(seed)\n",
    "        output_sf += delta_f_sf_update#delta_f_sf.sqrt()*torch.empty_like(delta_f_sf).normal_()\n",
    "        \n",
    "        loss = loss_fn(output_sf, target)\n",
    "        metrics.update(n=input.size(0), loss=loss.item(), error=get_error(output_sf, target), kl=kl.item())\n",
    "    \n",
    "    return metrics.avg"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Pre-training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "#%run main.py --dataset cifar100 --dataroot=data/cifar-100-python --model resnet --filters 1.0 --lr 0.1 --lossfn ce --num-classes 100"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Train the original model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%run main_merged.py --dataset cifar10 --model allcnn --dataroot=data/cifar-10-batches-py/ --filters 1.0 --lr 0.01 \\\n",
    "--resume checkpoints/cifar100_allcnn_1_0_forget_None_lr_0_1_bs_128_ls_ce_wd_0_0005_seed_1_30.pt --disable-bn \\\n",
    "--weight-decay 5e-4 --batch-size 128 --epochs 26 --seed 1"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Retrain Forgetting"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%run main_merged.py --dataset cifar10 --model allcnn --dataroot=data/cifar-10-batches-py/ --filters 1 --lr 0.01 \\\n",
    "--resume checkpoints/cifar100_allcnn_1_0_forget_None_lr_0_1_bs_128_ls_ce_wd_0_0005_seed_1_30.pt --disable-bn \\\n",
    "--weight-decay 5e-4 --batch-size 128 --epochs 26 \\\n",
    "--forget-class 5 --seed 1"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Logs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "log_dict={}\n",
    "training_epochs=25"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "log_dict['epoch']=training_epochs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "parameter_count(copy.deepcopy(model))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Loads checkpoints"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import copy\n",
    "model0 = copy.deepcopy(model)\n",
    "\n",
    "arch = args.model \n",
    "filters=args.filters\n",
    "arch_filters = arch +'_'+ str(filters).replace('.','_')\n",
    "augment = False\n",
    "dataset = args.dataset\n",
    "class_to_forget = args.forget_class\n",
    "init_checkpoint = f\"checkpoints/{args.name}_init.pt\"\n",
    "num_classes=args.num_classes\n",
    "num_to_forget = args.num_to_forget\n",
    "num_total = len(train_loader.dataset)\n",
    "num_to_retain = num_total - 4000#num_to_forget\n",
    "seed = args.seed\n",
    "unfreeze_start = None\n",
    "\n",
    "learningrate=f\"lr_{str(args.lr).replace('.','_')}\"\n",
    "batch_size=f\"_bs_{str(args.batch_size)}\"\n",
    "lossfn=f\"_ls_{args.lossfn}\"\n",
    "wd=f\"_wd_{str(args.weight_decay).replace('.','_')}\"\n",
    "seed_name=f\"_seed_{args.seed}_\"\n",
    "\n",
    "num_tag = '' if num_to_forget is None else f'_num_{num_to_forget}'\n",
    "unfreeze_tag = '_' if unfreeze_start is None else f'_unfreeze_from_{unfreeze_start}_'\n",
    "augment_tag = '' if not augment else f'augment_'\n",
    "\n",
    "m_name = f'checkpoints/{dataset}_{arch_filters}_forget_None{unfreeze_tag}{augment_tag}{learningrate}{batch_size}{lossfn}{wd}{seed_name}{training_epochs}.pt'\n",
    "m0_name = f'checkpoints/{dataset}_{arch_filters}_forget_{class_to_forget}{num_tag}{unfreeze_tag}{augment_tag}{learningrate}{batch_size}{lossfn}{wd}{seed_name}{training_epochs}.pt'\n",
    "\n",
    "model.load_state_dict(torch.load(m_name))\n",
    "model0.load_state_dict(torch.load(m0_name))\n",
    "\n",
    "\n",
    "model.cuda()\n",
    "model0.cuda()\n",
    "\n",
    "\n",
    "for p in model.parameters():\n",
    "    p.data0 = p.data.clone()\n",
    "for p in model0.parameters():\n",
    "    p.data0 = p.data.clone()\n",
    "\n",
    "teacher = copy.deepcopy(model)\n",
    "student = copy.deepcopy(model)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Data Loader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_loader_full, valid_loader_full, test_loader_full   = datasets.get_loaders(dataset, batch_size=args.batch_size, seed=seed, root=args.dataroot, augment=False, shuffle=True)\n",
    "marked_loader, _, _  = datasets.get_loaders(dataset, class_to_replace=class_to_forget, num_indexes_to_replace=num_to_forget, only_mark=True, batch_size=1, seed=seed, root=args.dataroot, augment=False, shuffle=True)\n",
    "\n",
    "def replace_loader_dataset(data_loader, dataset, batch_size=args.batch_size, seed=1, shuffle=True):\n",
    "    manual_seed(seed)\n",
    "    loader_args = {'num_workers': 0, 'pin_memory': False}\n",
    "    def _init_fn(worker_id):\n",
    "        np.random.seed(int(seed))\n",
    "    return torch.utils.data.DataLoader(dataset, batch_size=batch_size,num_workers=0,pin_memory=True,shuffle=shuffle)\n",
    "    \n",
    "forget_dataset = copy.deepcopy(marked_loader.dataset)\n",
    "marked = forget_dataset.targets < 0\n",
    "forget_dataset.data = forget_dataset.data[marked]\n",
    "forget_dataset.targets = - forget_dataset.targets[marked] - 1\n",
    "forget_loader = replace_loader_dataset(train_loader_full, forget_dataset, batch_size=512, seed=seed, shuffle=True)\n",
    "\n",
    "retain_dataset = copy.deepcopy(marked_loader.dataset)\n",
    "marked = retain_dataset.targets >= 0\n",
    "retain_dataset.data = retain_dataset.data[marked]\n",
    "retain_dataset.targets = retain_dataset.targets[marked]\n",
    "retain_loader = replace_loader_dataset(train_loader_full, retain_dataset, batch_size=256, seed=seed, shuffle=True)\n",
    "\n",
    "\n",
    "assert(len(forget_dataset) + len(retain_dataset) == len(train_loader_full.dataset))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "log_dict['args']=args"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print (len(forget_loader.dataset))\n",
    "print (len(retain_loader.dataset))\n",
    "print (len(test_loader_full.dataset))\n",
    "print (len(train_loader_full.dataset))\n",
    "from collections import Counter\n",
    "print(dict(Counter(train_loader_full.dataset.targets)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## SCRUB Forgetting"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "args.optim = 'sgd'\n",
    "args.gamma = 0.99\n",
    "args.alpha = 0.001\n",
    "args.beta = 0\n",
    "args.smoothing = 0.0\n",
    "args.msteps = 2\n",
    "args.clip = 0.2\n",
    "args.sstart = 10\n",
    "args.kd_T = 4\n",
    "args.distill = 'kd'\n",
    "\n",
    "args.sgda_batch_size = 128\n",
    "args.del_batch_size = 32\n",
    "args.sgda_epochs = 3\n",
    "args.sgda_learning_rate = 0.0005\n",
    "args.lr_decay_epochs = [3,5,9]\n",
    "args.lr_decay_rate = 0.1\n",
    "args.sgda_weight_decay = 5e-4\n",
    "args.sgda_momentum = 0.9"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_t = copy.deepcopy(teacher)\n",
    "model_s = copy.deepcopy(student)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#this is from https://github.com/ojus1/SmoothedGradientDescentAscent/blob/main/SGDA.py\n",
    "#For SGDA smoothing\n",
    "beta = 0.1\n",
    "def avg_fn(averaged_model_parameter, model_parameter, num_averaged):\n",
    "    # exponential moving average rather than AveragedModel's default running mean\n",
    "    return (1 - beta) * averaged_model_parameter + beta * model_parameter\n",
    "\n",
    "swa_model = torch.optim.swa_utils.AveragedModel(model_s, avg_fn=avg_fn)"
   ]
  },
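  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`avg_fn` above replaces `AveragedModel`'s default running mean with an exponential moving average, $\\bar\\theta \\leftarrow (1-\\beta)\\,\\bar\\theta + \\beta\\,\\theta$. A tensor-level sketch of that update (the values fed in are illustrative):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# EMA update used by swa_model, applied to a toy tensor sequence.\n",
    "beta_demo = 0.1  # same smoothing factor as above\n",
    "avg = torch.zeros(3)\n",
    "for step, current in enumerate([torch.full((3,), v) for v in (1.0, 2.0, 3.0)]):\n",
    "    avg = (1 - beta_demo) * avg + beta_demo * current\n",
    "    print(f'step {step}: {avg}')"
   ]
  },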
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "module_list = nn.ModuleList([])\n",
    "module_list.append(model_s)\n",
    "trainable_list = nn.ModuleList([])\n",
    "trainable_list.append(model_s)\n",
    "\n",
    "criterion_cls = nn.CrossEntropyLoss()\n",
    "criterion_div = DistillKL(args.kd_T)\n",
    "criterion_kd = DistillKL(args.kd_T)\n",
    "\n",
    "\n",
    "criterion_list = nn.ModuleList([])\n",
    "criterion_list.append(criterion_cls)    # classification loss\n",
    "criterion_list.append(criterion_div)    # KL divergence loss, original knowledge distillation\n",
    "criterion_list.append(criterion_kd)     # other knowledge distillation loss\n",
    "\n",
    "# optimizer\n",
    "if args.optim == \"sgd\":\n",
    "    optimizer = optim.SGD(trainable_list.parameters(),\n",
    "                          lr=args.sgda_learning_rate,\n",
    "                          momentum=args.sgda_momentum,\n",
    "                          weight_decay=args.sgda_weight_decay)\n",
    "elif args.optim == \"adam\": \n",
    "    optimizer = optim.Adam(trainable_list.parameters(),\n",
    "                          lr=args.sgda_learning_rate,\n",
    "                          weight_decay=args.sgda_weight_decay)\n",
    "elif args.optim == \"rmsp\":\n",
    "    optimizer = optim.RMSprop(trainable_list.parameters(),\n",
    "                          lr=args.sgda_learning_rate,\n",
    "                          momentum=args.sgda_momentum,\n",
    "                          weight_decay=args.sgda_weight_decay)"
   ]
  },
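  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For reference, the imported `DistillKL` computes a temperature-softened KL divergence between student and teacher logits, scaled by $T^2$ and averaged over the batch. A functional sketch of that loss (this is an illustration of the standard temperature-scaled KD loss, not the imported class itself):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch of the temperature-scaled distillation loss that DistillKL implements.\n",
    "def distill_kl(y_s, y_t, T):\n",
    "    p_s = F.log_softmax(y_s / T, dim=1)\n",
    "    p_t = F.softmax(y_t / T, dim=1)\n",
    "    return F.kl_div(p_s, p_t, reduction='sum') * (T ** 2) / y_s.shape[0]\n",
    "\n",
    "# toy logits: batch of 4, 10 classes\n",
    "print(distill_kl(torch.randn(4, 10), torch.randn(4, 10), T=args.kd_T))"
   ]
  },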
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "module_list.append(model_t)\n",
    "\n",
    "if torch.cuda.is_available():\n",
    "    module_list.cuda()\n",
    "    criterion_list.cuda()\n",
    "    import torch.backends.cudnn as cudnn\n",
    "    cudnn.benchmark = True\n",
    "    swa_model.cuda()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "acc_rs = []\n",
    "acc_fs = []\n",
    "acc_ts = []\n",
    "acc_vs = []\n",
    "for epoch in range(1, args.sgda_epochs + 1):\n",
    "\n",
    "    lr = sgda_adjust_learning_rate(epoch, args, optimizer)\n",
    "\n",
    "    print(\"==> SCRUB unlearning ...\")\n",
    "\n",
    "    acc_r, acc5_r, loss_r = validate(retain_loader, model_s, criterion_cls, args, True)\n",
    "    acc_f, acc5_f, loss_f = validate(forget_loader, model_s, criterion_cls, args, True)\n",
    "    acc_v, acc5_v, loss_v = validate(valid_loader_full, model_s, criterion_cls, args, True)\n",
    "    acc_rs.append(100-acc_r.item())\n",
    "    acc_fs.append(100-acc_f.item())\n",
    "    acc_vs.append(100-acc_v.item())\n",
    "\n",
    "    maximize_loss = 0\n",
    "    if epoch <= args.msteps:\n",
    "        maximize_loss = train_distill(epoch, forget_loader, module_list, swa_model, criterion_list, optimizer, args, \"maximize\")\n",
    "    train_acc, train_loss = train_distill(epoch, retain_loader, module_list, swa_model, criterion_list, optimizer, args, \"minimize\")\n",
    "    if epoch >= args.sstart:\n",
    "        swa_model.update_parameters(model_s)\n",
    "\n",
    "    \n",
    "    print (\"maximize loss: {:.2f}\\t minimize loss: {:.2f}\\t train_acc: {}\".format(maximize_loss, train_loss, train_acc))\n",
    "acc_r, acc5_r, loss_r = validate(retain_loader, model_s, criterion_cls, args, True)\n",
    "acc_f, acc5_f, loss_f = validate(forget_loader, model_s, criterion_cls, args, True)\n",
    "acc_v, acc5_v, loss_v = validate(valid_loader_full, model_s, criterion_cls, args, True)\n",
    "acc_rs.append(100-acc_r.item())\n",
    "acc_fs.append(100-acc_f.item())\n",
    "acc_vs.append(100-acc_v.item())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from matplotlib import pyplot as plt\n",
    "indices = list(range(0,len(acc_rs)))\n",
    "plt.plot(indices, acc_rs, marker='*', alpha=1, label='retain-set')\n",
    "plt.plot(indices, acc_fs, marker='o', alpha=1, label='forget-set')\n",
    "plt.plot(indices, acc_vs, marker='.', alpha=1, label='valid-set')\n",
    "plt.legend(prop={'size': 14})\n",
    "plt.tick_params(labelsize=12)\n",
    "plt.title('SCRUB retain-, valid- and forget- set error',size=18)\n",
    "plt.xlabel('epoch',size=14)\n",
    "plt.ylabel('error',size=14)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Original"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "m_D_r_activations,m_D_r_predictions=activations_predictions(copy.deepcopy(model),copy.deepcopy(retain_loader),'Original_Model_D_r')\n",
    "m_D_f_activations,m_D_f_predictions=activations_predictions(copy.deepcopy(model),copy.deepcopy(forget_loader),'Original_Model_D_f')\n",
    "m_D_t_activations,m_D_t_predictions=activations_predictions(copy.deepcopy(model),copy.deepcopy(test_loader_full),'Original_Model_D_t')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Retrain"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "m0_D_r_activations,m0_D_r_predictions=activations_predictions(copy.deepcopy(model0),copy.deepcopy(retain_loader),'Retrain_Model_D_r')\n",
    "m0_D_f_activations,m0_D_f_predictions=activations_predictions(copy.deepcopy(model0),copy.deepcopy(forget_loader),'Retrain_Model_D_f')\n",
    "m0_D_t_activations,m0_D_t_predictions=activations_predictions(copy.deepcopy(model0),copy.deepcopy(test_loader_full),'Retrain_Model_D_t')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### SCRUB"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ntk_D_r_activations,ntk_D_r_predictions=activations_predictions(copy.deepcopy(model_s),copy.deepcopy(retain_loader),'SCRUB_D_r')\n",
    "ntk_D_f_activations,ntk_D_f_predictions=activations_predictions(copy.deepcopy(model_s),copy.deepcopy(forget_loader),'SCRUB_D_f')\n",
    "ntk_D_t_activations,ntk_D_t_predictions=activations_predictions(copy.deepcopy(model_s),copy.deepcopy(test_loader_full),'SCRUB_D_t')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Fisher Forgetting"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Fisher"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "modelf = copy.deepcopy(model)\n",
    "modelf0 = copy.deepcopy(model0)\n",
    "\n",
    "for p in itertools.chain(modelf.parameters(), modelf0.parameters()):\n",
    "    p.data0 = copy.deepcopy(p.data.clone())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def hessian(dataset, model):\n",
    "    model.eval()\n",
    "    train_loader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False)\n",
    "    loss_fn = nn.CrossEntropyLoss()\n",
    "\n",
    "    for p in model.parameters():\n",
    "        p.grad_acc = 0\n",
    "        p.grad2_acc = 0\n",
    "    \n",
    "    for data, orig_target in tqdm(train_loader):\n",
    "        data, orig_target = data.to(args.device), orig_target.to(args.device)\n",
    "        output = model(data)\n",
    "        prob = F.softmax(output, dim=-1).data\n",
    "\n",
    "        for y in range(output.shape[1]):\n",
    "            target = torch.empty_like(orig_target).fill_(y)\n",
    "            loss = loss_fn(output, target)\n",
    "            model.zero_grad()\n",
    "            loss.backward(retain_graph=True)\n",
    "            for p in model.parameters():\n",
    "                if p.requires_grad:\n",
    "                    p.grad_acc += (orig_target == target).float() * p.grad.data\n",
    "                    p.grad2_acc += prob[:, y] * p.grad.data.pow(2)\n",
    "    for p in model.parameters():\n",
    "        p.grad_acc /= len(train_loader)\n",
    "        p.grad2_acc /= len(train_loader)"
   ]
  },
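  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`grad2_acc` accumulates a diagonal Fisher information approximation over the retain set: $F_{ii} \\approx \\mathbb{E}_x \\sum_y p(y\\mid x)\\,\\big(\\partial_{\\theta_i}\\,\\ell(f_\\theta(x), y)\\big)^2$. A tiny self-contained sketch of the same accumulation on a toy linear model (all names below are illustrative):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Diagonal Fisher accumulation, as in hessian() above, on a toy model.\n",
    "torch.manual_seed(0)\n",
    "lin = nn.Linear(5, 3)\n",
    "xs = torch.randn(8, 5)\n",
    "ce = nn.CrossEntropyLoss()\n",
    "\n",
    "fisher = [torch.zeros_like(p) for p in lin.parameters()]\n",
    "for xi in xs:\n",
    "    out = lin(xi.unsqueeze(0))\n",
    "    prob = F.softmax(out, dim=-1).detach()\n",
    "    for y in range(out.shape[1]):\n",
    "        lin.zero_grad()\n",
    "        ce(out, torch.tensor([y])).backward(retain_graph=True)\n",
    "        for f_acc, p in zip(fisher, lin.parameters()):\n",
    "            f_acc += prob[0, y] * p.grad.pow(2)\n",
    "fisher = [f / len(xs) for f in fisher]\n",
    "print([f.mean().item() for f in fisher])"
   ]
  },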
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "hessian(retain_loader.dataset, modelf)\n",
    "#hessian(retain_loader.dataset, modelf0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_mean_var(p, is_base_dist=False, alpha=3e-6):\n",
    "    var = copy.deepcopy(1./(p.grad2_acc+1e-8))\n",
    "    var = var.clamp(max=1e3)\n",
    "    if p.size(0) == num_classes:\n",
    "        var = var.clamp(max=1e2)\n",
    "    var = alpha * var\n",
    "    \n",
    "    if p.ndim > 1:\n",
    "        var = var.mean(dim=1, keepdim=True).expand_as(p).clone()\n",
    "    if not is_base_dist:\n",
    "        mu = copy.deepcopy(p.data0.clone())\n",
    "    else:\n",
    "        mu = copy.deepcopy(p.data0.clone())\n",
    "    if p.size(0) == num_classes and num_to_forget is None:\n",
    "        mu[class_to_forget] = 0\n",
    "        var[class_to_forget] = 0.0001\n",
    "    if p.size(0) == num_classes:\n",
    "        # Last layer\n",
    "        var *= 10\n",
    "    elif p.ndim == 1:\n",
    "        # BatchNorm\n",
    "        var *= 10\n",
    "#         var*=1\n",
    "    return mu, var\n",
    "\n",
    "def kl_divergence_fisher(mu0, var0, mu1, var1):\n",
    "    return ((mu1 - mu0).pow(2)/var0 + var1/var0 - torch.log(var1/var0) - 1).sum()"
   ]
  },
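  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The scrubbing step below then draws each weight from a Gaussian whose variance is the (clamped, scaled) inverse Fisher,\n",
    "\n",
    "$$\\theta_i \\leftarrow \\mu_i + \\sqrt{\\alpha / F_{ii}}\\;\\varepsilon_i, \\qquad \\varepsilon_i \\sim \\mathcal{N}(0, 1),$$\n",
    "\n",
    "so directions the retain set constrains tightly (large $F_{ii}$) receive little noise, while poorly constrained directions are heavily randomized."
   ]
  },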
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fisher_dir = []\n",
    "alpha = 1e-7\n",
    "torch.manual_seed(seed)\n",
    "for i, p in enumerate(modelf.parameters()):\n",
    "    mu, var = get_mean_var(p, False, alpha=alpha)\n",
    "    p.data = mu + var.sqrt() * torch.empty_like(p.data0).normal_()\n",
    "    fisher_dir.append(var.sqrt().view(-1).cpu().detach().numpy())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Fisher Noise in Weights"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fisher_D_r_activations,fisher_D_r_predictions=activations_predictions(copy.deepcopy(modelf),copy.deepcopy(retain_loader),'Fisher_D_r')\n",
    "fisher_D_f_activations,fisher_D_f_predictions=activations_predictions(copy.deepcopy(modelf),copy.deepcopy(forget_loader),'Fisher_D_f')\n",
    "fisher_D_t_activations,fisher_D_t_predictions=activations_predictions(copy.deepcopy(modelf),copy.deepcopy(test_loader_full),'Fisher_D_t')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Information in the Activations"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Finetune"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_ft = copy.deepcopy(model)\n",
    "retain_loader = replace_loader_dataset(train_loader_full,retain_dataset, seed=seed, batch_size=args.batch_size, shuffle=True)    \n",
    "finetune(model_ft, retain_loader, epochs=10, quiet=True, lr=0.01)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "finetune_D_r_activations,finetune_D_r_predictions=activations_predictions(copy.deepcopy(model_ft),copy.deepcopy(retain_loader),'Finetune_D_r')\n",
    "finetune_D_f_activations,finetune_D_f_predictions=activations_predictions(copy.deepcopy(model_ft),copy.deepcopy(forget_loader),'Finetune_D_f')\n",
    "finetune_D_t_activations,finetune_D_t_predictions=activations_predictions(copy.deepcopy(model_ft),copy.deepcopy(test_loader_full),'Finetune_D_t')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Readouts"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "try: readouts\n",
    "except: readouts = {}\n",
    "\n",
    "thresh=log_dict['Original_Model_D_f_loss']+1e-5\n",
    "print(thresh)\n",
    "readouts[\"e\"] = all_readouts(copy.deepcopy(model),thresh,'Original')\n",
    "readouts[\"a\"] = all_readouts(copy.deepcopy(model_ft),thresh,'Finetune')\n",
    "readouts[\"b\"] = all_readouts(copy.deepcopy(modelf),thresh,'Fisher')\n",
    "readouts[\"d\"] = all_readouts(copy.deepcopy(model0),thresh,'Retrain')\n",
    "readouts[\"d\"] = all_readouts(copy.deepcopy(model_s),thresh,'SCRUB')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Save Dictionary"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#fig, ax = plt.subplots(5,1,figsize=(9,20))\n",
    "ax = [0,0,0,0,0]\n",
    "plot_entropy_dist(copy.deepcopy(model_s),retain_loader,forget_loader,test_loader_full, 'SCRUB',ax[4])\n",
    "plot_entropy_dist(copy.deepcopy(model),retain_loader,forget_loader,test_loader_full, 'original', ax[0])\n",
    "plot_entropy_dist(copy.deepcopy(model0),retain_loader,forget_loader,test_loader_full, 'retrain', ax[1])\n",
    "plot_entropy_dist(copy.deepcopy(model_ft),retain_loader,forget_loader,test_loader_full, 'finetuned',ax[2])\n",
    "plot_entropy_dist(copy.deepcopy(modelf),retain_loader,forget_loader,test_loader_full, 'fisher',ax[3])"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
