{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "%matplotlib inline\n",
    "import os\n",
    "os.environ['CUDA_DEVICE_ORDER']='PCI_BUS_ID'\n",
    "os.environ['CUDA_VISIBLE_DEVICES']='1'\n",
    "import variational\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "from matplotlib.ticker import FuncFormatter\n",
    "from itertools import cycle\n",
    "import os\n",
    "import time\n",
    "import math\n",
    "import pandas as pd\n",
    "from collections import OrderedDict\n",
    "from sklearn.linear_model import LogisticRegression\n",
    "    \n",
    "import copy\n",
    "import torch.nn as nn\n",
    "from torch.autograd import Variable\n",
    "from typing import List\n",
    "import itertools\n",
    "from tqdm.autonotebook import tqdm\n",
    "from models import *\n",
    "import models\n",
    "from logger import *\n",
    "import wandb\n",
    "\n",
    "from thirdparty.repdistiller.helper.util import adjust_learning_rate as sgda_adjust_learning_rate\n",
    "from thirdparty.repdistiller.distiller_zoo import DistillKL, HintLoss, Attention, Similarity, Correlation, VIDLoss, RKDLoss\n",
    "from thirdparty.repdistiller.distiller_zoo import PKT, ABLoss, FactorTransfer, KDSVD, FSP, NSTLoss\n",
    "\n",
    "from thirdparty.repdistiller.helper.loops import train_distill, train_distill_hide, train_distill_linear, train_vanilla, train_negrad, train_bcu, train_bcu_distill, validate\n",
    "from thirdparty.repdistiller.helper.pretrain import init"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def pdb():\n",
    "    import pdb\n",
    "    pdb.set_trace"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def parameter_count(model):\n",
    "    count=0\n",
    "    for p in model.parameters():\n",
    "        count+=np.prod(np.array(list(p.shape)))\n",
    "    print(f'Total Number of Parameters: {count}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def vectorize_params(model):\n",
    "    param = []\n",
    "    for p in model.parameters():\n",
    "        param.append(p.data.view(-1).cpu().numpy())\n",
    "    return np.concatenate(param)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def print_param_shape(model):\n",
    "    for k,p in model.named_parameters():\n",
    "        print(k,p.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Pre-training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "#%run main.py --dataset cifar100 --dataroot=data/cifar-100-python --model resnet --filteers 0.4 --lr 0.1 --lossfn ce --num-classes 100"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Train Original and Retrain"
   ]
  },
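  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Each `trainN()` below resumes from the same CIFAR-100 pre-trained checkpoint and runs two jobs for its seed: one trained on all classes (the \"original\" model) and one with `--forget-class 0,1,2` (the retrained-from-scratch reference for unlearning). The checkpoint filename encodes dataset, architecture, filter multiplier, forgotten classes, learning rate, batch size, loss, weight decay, seed, and epoch; `load_chpts` further down rebuilds the same pattern to locate the files."
   ]
  },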
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "from IPython.display import clear_output\n",
    "\n",
    "def train1():\n",
    "    clear_output(wait=True)\n",
    "    %run main.py --dataset small_cifar6 --model resnet --dataroot=data/cifar10/ --filters 0.4 --lr 0.01 \\\n",
    "    --resume checkpoints/cifar100_resnet_0_4_forget_None_lr_0_1_bs_128_ls_ce_wd_0_0005_seed_1_30.pt --disable-bn \\\n",
    "    --weight-decay 0.1 --batch-size 128 --epochs 31 --seed 1 --quiet\n",
    "    \n",
    "    %run main.py --dataset small_cifar6 --model resnet --dataroot=data/lacuna10/ --filters 0.4 --lr 0.01 \\\n",
    "    --resume checkpoints/cifar100_resnet_0_4_forget_None_lr_0_1_bs_128_ls_ce_wd_0_0005_seed_1_30.pt --disable-bn \\\n",
    "    --weight-decay 0.1 --batch-size 128 --epochs 31 \\\n",
    "    --forget-class 0,1,2 --seed 1 --quiet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train2():\n",
    "    clear_output(wait=True)\n",
    "    %run main.py --dataset small_cifar6 --model resnet --dataroot=data/cifar10/ --filters 0.4 --lr 0.01 \\\n",
    "    --resume checkpoints/cifar100_resnet_0_4_forget_None_lr_0_1_bs_128_ls_ce_wd_0_0005_seed_1_30.pt --disable-bn \\\n",
    "    --weight-decay 0.1 --batch-size 128 --epochs 31 --seed 2 --quiet\n",
    "    \n",
    "    %run main.py --dataset small_cifar6 --model resnet --dataroot=data/lacuna10/ --filters 0.4 --lr 0.01 \\\n",
    "    --resume checkpoints/cifar100_resnet_0_4_forget_None_lr_0_1_bs_128_ls_ce_wd_0_0005_seed_1_30.pt --disable-bn \\\n",
    "    --weight-decay 0.1 --batch-size 128 --epochs 31 \\\n",
    "    --forget-class 0,1,2 --seed 2 --quiet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train3():\n",
    "    clear_output(wait=True)\n",
    "    %run main.py --dataset small_cifar6 --model resnet --dataroot=data/cifar10/ --filters 0.4 --lr 0.01 \\\n",
    "    --resume checkpoints/cifar100_resnet_0_4_forget_None_lr_0_1_bs_128_ls_ce_wd_0_0005_seed_1_30.pt --disable-bn \\\n",
    "    --weight-decay 0.1 --batch-size 128 --epochs 31 --seed 3 --quiet\n",
    "    \n",
    "    %run main.py --dataset small_cifar6 --model resnet --dataroot=data/lacuna10/ --filters 0.4 --lr 0.01 \\\n",
    "    --resume checkpoints/cifar100_resnet_0_4_forget_None_lr_0_1_bs_128_ls_ce_wd_0_0005_seed_1_30.pt --disable-bn \\\n",
    "    --weight-decay 0.1 --batch-size 128 --epochs 31 \\\n",
    "    --forget-class 0,1,2 --seed 3 --quiet"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Loads checkpoints"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import copy\n",
    "log_dict={}\n",
    "def load_chpts():\n",
    "    clear_output(wait=True)\n",
    "    training_epochs=30\n",
    "    log_dict['epoch']=training_epochs\n",
    "    model0 = copy.deepcopy(model)\n",
    "    model_initial = copy.deepcopy(model)\n",
    "\n",
    "    arch = args.model \n",
    "    filters=args.filters\n",
    "    arch_filters = arch +'_'+ str(filters).replace('.','_')\n",
    "    augment = False\n",
    "    dataset = args.dataset\n",
    "    class_to_forget = args.forget_class\n",
    "    init_checkpoint = f\"checkpoints/{args.name}_init.pt\"\n",
    "    num_classes=args.num_classes\n",
    "    num_to_forget = args.num_to_forget\n",
    "    num_total = len(train_loader.dataset)\n",
    "    num_to_retain = num_total - 300#num_to_forget\n",
    "    seed = args.seed\n",
    "    unfreeze_start = None\n",
    "\n",
    "    learningrate=f\"lr_{str(args.lr).replace('.','_')}\"\n",
    "    batch_size=f\"_bs_{str(args.batch_size)}\"\n",
    "    lossfn=f\"_ls_{args.lossfn}\"\n",
    "    wd=f\"_wd_{str(args.weight_decay).replace('.','_')}\"\n",
    "    seed_name=f\"_seed_{args.seed}_\"\n",
    "\n",
    "    num_tag = '' if num_to_forget is None else f'_num_{num_to_forget}'\n",
    "    unfreeze_tag = '_' if unfreeze_start is None else f'_unfreeze_from_{unfreeze_start}_'\n",
    "    augment_tag = '' if not augment else f'augment_'\n",
    "\n",
    "    m_name = f'checkpoints/{dataset}_{arch_filters}_forget_None{unfreeze_tag}{augment_tag}{learningrate}{batch_size}{lossfn}{wd}{seed_name}{training_epochs}.pt'\n",
    "    m0_name = f'checkpoints/{dataset}_{arch_filters}_forget_{class_to_forget}{num_tag}{unfreeze_tag}{augment_tag}{learningrate}{batch_size}{lossfn}{wd}{seed_name}{training_epochs}.pt'\n",
    "\n",
    "    model.load_state_dict(torch.load(m_name))\n",
    "    #model0.load_state_dict(torch.load(m0_name))\n",
    "    #model_initial.load_state_dict(torch.load(init_checkpoint))\n",
    "\n",
    "    teacher = copy.deepcopy(model)\n",
    "    student = copy.deepcopy(model)\n",
    "\n",
    "    model.cuda()\n",
    "    model0.cuda()\n",
    "\n",
    "\n",
    "    for p in model.parameters():\n",
    "        p.data0 = p.data.clone()\n",
    "    for p in model0.parameters():\n",
    "        p.data0 = p.data.clone()\n",
    "        \n",
    "    log_dict['args']=args\n",
    "    args.retain_bs = 16\n",
    "    args.forget_bs = 32\n",
    "\n",
    "    train_loader_full, valid_loader_full, _   = datasets.get_loaders(dataset, batch_size=args.batch_size, seed=seed, root=args.dataroot, augment=False, shuffle=True)\n",
    "    marked_loader, _, test_loader_full = datasets.get_loaders(dataset, class_to_replace=class_to_forget, num_indexes_to_replace=num_to_forget, only_mark=True, batch_size=1, seed=seed, root=args.dataroot, augment=False, shuffle=True)\n",
    "\n",
    "    def replace_loader_dataset(data_loader, dataset, batch_size=args.batch_size, seed=1, shuffle=True):\n",
    "        manual_seed(seed)\n",
    "        loader_args = {'num_workers': 0, 'pin_memory': False}\n",
    "        def _init_fn(worker_id):\n",
    "            np.random.seed(int(seed))\n",
    "        return torch.utils.data.DataLoader(dataset, batch_size=batch_size,num_workers=0,pin_memory=True,shuffle=shuffle)\n",
    "\n",
    "    forget_dataset = copy.deepcopy(marked_loader.dataset)\n",
    "    marked = forget_dataset.targets < 0\n",
    "    forget_dataset.data = forget_dataset.data[marked]\n",
    "    forget_dataset.targets = - forget_dataset.targets[marked] - 1\n",
    "    forget_loader = replace_loader_dataset(train_loader_full, forget_dataset, batch_size=args.forget_bs, seed=seed, shuffle=True)\n",
    "\n",
    "    retain_dataset = copy.deepcopy(marked_loader.dataset)\n",
    "    marked = retain_dataset.targets >= 0\n",
    "    retain_dataset.data = retain_dataset.data[marked]\n",
    "    retain_dataset.targets = retain_dataset.targets[marked]\n",
    "    retain_loader = replace_loader_dataset(train_loader_full, retain_dataset, batch_size=args.retain_bs, seed=seed, shuffle=True)\n",
    "\n",
    "    assert(len(forget_dataset) + len(retain_dataset) == len(train_loader_full.dataset))\n",
    "    \n",
    "    return teacher, student, train_loader_full, test_loader_full, retain_loader, forget_loader"
   ]
  },
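  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A minimal, self-contained illustration of the `only_mark=True` label encoding used above (hypothetical toy targets, not part of the pipeline): a forgotten sample's label $y$ is stored as $-(y+1)$, so the sign separates the splits and $-t-1$ recovers the label."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Toy demonstration of the marked-label encoding (assumed from the code above).\n",
    "import numpy as np\n",
    "toy_targets = np.array([0, -1, 2, -3])   # -1 and -3 mark forgotten samples with y=0 and y=2\n",
    "marked = toy_targets < 0\n",
    "print(-toy_targets[marked] - 1)          # -> [0 2], the original labels of the forget split\n",
    "print(toy_targets[~marked])              # -> [0 2], the retained labels"
   ]
  },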
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## SCRUB"
   ]
  },
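  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "SCRUB treats unlearning as teacher-student distillation with opposite goals on the two data splits: the student stays close to the frozen teacher on the retain set $\\mathcal{D}_r$ but is pushed away from it on the forget set $\\mathcal{D}_f$. A sketch of the objective as wired up below (the exact loss lives in `train_distill`; $d$ denotes the temperature-scaled KL of `DistillKL`):\n",
    "\n",
    "$$\\min_{w}\\ \\alpha\\,\\mathbb{E}_{x \\sim \\mathcal{D}_r}\\,d\\big(f_t(x), f_w(x)\\big) + \\gamma\\,\\mathbb{E}_{(x,y) \\sim \\mathcal{D}_r}\\,\\ell_{\\mathrm{CE}}\\big(f_w(x), y\\big) - \\mathbb{E}_{x \\sim \\mathcal{D}_f}\\,d\\big(f_t(x), f_w(x)\\big)$$\n",
    "\n",
    "Optimization alternates rather than mixing the terms: for the first `msteps` epochs an ascent pass maximizes the forget-set divergence, then every epoch a descent pass minimizes the retain-set terms, with an EMA copy of the student (via `AveragedModel`) smoothing the iterates."
   ]
  },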
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def scrub(gamma, alpha, teacher, student, train_loader_full, test_loader_full, retain_loader, forget_loader):\n",
    "    clear_output(wait=True)\n",
    "    args.optim = 'adam'\n",
    "    args.gamma = gamma\n",
    "    args.alpha = alpha\n",
    "    args.beta = 0\n",
    "    args.smoothing = 0.5\n",
    "    args.msteps = 2\n",
    "    args.clip = 0.2\n",
    "    args.sstart = 10\n",
    "    args.kd_T = 8\n",
    "    args.distill = 'kd'\n",
    "\n",
    "    args.sgda_epochs = 10\n",
    "    args.sgda_learning_rate = 0.0005\n",
    "    args.lr_decay_epochs = [7,8,9]\n",
    "    args.lr_decay_rate = 0.1\n",
    "    args.sgda_weight_decay = 0.1#5e-4\n",
    "    args.sgda_momentum = 0.9\n",
    "\n",
    "    model_t = copy.deepcopy(teacher)\n",
    "    model_s = copy.deepcopy(student)\n",
    "\n",
    "    #this is from https://github.com/ojus1/SmoothedGradientDescentAscent/blob/main/SGDA.py\n",
    "    #For SGDA smoothing\n",
    "    beta = 0.1\n",
    "    def avg_fn(averaged_model_parameter, model_parameter, num_averaged): return (\n",
    "        1 - beta) * averaged_model_parameter + beta * model_parameter\n",
    "    swa_model = torch.optim.swa_utils.AveragedModel(\n",
    "        model_s, avg_fn=avg_fn)\n",
    "\n",
    "    module_list = nn.ModuleList([])\n",
    "    module_list.append(model_s)\n",
    "    trainable_list = nn.ModuleList([])\n",
    "    trainable_list.append(model_s)\n",
    "\n",
    "    criterion_cls = nn.CrossEntropyLoss()\n",
    "    criterion_div = DistillKL(args.kd_T)\n",
    "    criterion_kd = DistillKL(args.kd_T)\n",
    "\n",
    "\n",
    "    criterion_list = nn.ModuleList([])\n",
    "    criterion_list.append(criterion_cls)    # classification loss\n",
    "    criterion_list.append(criterion_div)    # KL divergence loss, original knowledge distillation\n",
    "    criterion_list.append(criterion_kd)     # other knowledge distillation loss\n",
    "\n",
    "    # optimizer\n",
    "    if args.optim == \"sgd\":\n",
    "        optimizer = optim.SGD(trainable_list.parameters(),\n",
    "                              lr=args.sgda_learning_rate,\n",
    "                              momentum=args.sgda_momentum,\n",
    "                              weight_decay=args.sgda_weight_decay)\n",
    "    elif args.optim == \"adam\": \n",
    "        optimizer = optim.Adam(trainable_list.parameters(),\n",
    "                              lr=args.sgda_learning_rate,\n",
    "                              weight_decay=args.sgda_weight_decay)\n",
    "    elif args.optim == \"rmsp\":\n",
    "        optimizer = optim.RMSprop(trainable_list.parameters(),\n",
    "                              lr=args.sgda_learning_rate,\n",
    "                              momentum=args.sgda_momentum,\n",
    "                              weight_decay=args.sgda_weight_decay)\n",
    "\n",
    "    module_list.append(model_t)\n",
    "\n",
    "    if torch.cuda.is_available():\n",
    "        module_list.cuda()\n",
    "        criterion_list.cuda()\n",
    "        import torch.backends.cudnn as cudnn\n",
    "        cudnn.benchmark = True\n",
    "        swa_model.cuda()\n",
    "    \n",
    "\n",
    "    for epoch in range(1, args.sgda_epochs + 1):\n",
    "\n",
    "        lr = sgda_adjust_learning_rate(epoch, args, optimizer)\n",
    "\n",
    "        print(\"==> sgda unlearning ...\")\n",
    "\n",
    "        maximize_loss = 0\n",
    "        if epoch <= args.msteps:\n",
    "            maximize_loss = train_distill(epoch, forget_loader, module_list, swa_model, criterion_list, optimizer, args, \"maximize\")\n",
    "        train_acc, train_loss = train_distill(epoch, retain_loader, module_list, swa_model, criterion_list, optimizer, args, \"minimize\")\n",
    "        if epoch >= args.sstart:\n",
    "            swa_model.update_parameters(model_s)\n",
    "\n",
    "\n",
    "        print (\"maximize loss: {:.2f}\\t minimize loss: {:.2f}\\t train_acc: {}\".format(maximize_loss, train_loss, train_acc))\n",
    "    \n",
    "    return model_s"
   ]
  },
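  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For reference, a minimal sketch of the temperature-scaled KD criterion in the style of RepDistiller's `DistillKL` (an assumed re-implementation for readability; the pipeline imports the real one from `thirdparty.repdistiller.distiller_zoo`)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "class DistillKLSketch(nn.Module):\n",
    "    \"\"\"KL(student || teacher) on temperature-softened logits, scaled by T^2.\"\"\"\n",
    "    def __init__(self, T):\n",
    "        super().__init__()\n",
    "        self.T = T\n",
    "\n",
    "    def forward(self, y_s, y_t):\n",
    "        p_s = F.log_softmax(y_s / self.T, dim=1)\n",
    "        p_t = F.softmax(y_t / self.T, dim=1)\n",
    "        # sum over classes, mean over the batch; T^2 keeps gradient scale comparable across temperatures\n",
    "        return F.kl_div(p_s, p_t, reduction='sum') * (self.T ** 2) / y_s.shape[0]\n",
    "\n",
    "# quick self-check on random logits\n",
    "crit = DistillKLSketch(T=8)\n",
    "print(crit(torch.randn(4, 6), torch.randn(4, 6)))"
   ]
  },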
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from utils import *\n",
    "def get_metrics(model,dataloader,criterion,samples_correctness=False,use_bn=False,delta_w=None,scrub_act=False):\n",
    "    activations=[]\n",
    "    predictions=[]\n",
    "    if use_bn:\n",
    "        model.train()\n",
    "        dataloader = torch.utils.data.DataLoader(retain_loader.dataset, batch_size=128, shuffle=True)\n",
    "        for i in range(10):\n",
    "            for batch_idx, (data, target) in enumerate(dataloader):\n",
    "                data, target = data.to(args.device), target.to(args.device)            \n",
    "                output = model(data)\n",
    "    dataloader = torch.utils.data.DataLoader(dataloader.dataset, batch_size=1, shuffle=False)\n",
    "    model.eval()\n",
    "    metrics = AverageMeter()\n",
    "    mult = 0.5 if args.lossfn=='mse' else 1\n",
    "    for batch_idx, (data, target) in enumerate(dataloader):\n",
    "        data, target = data.to(args.device), target.to(args.device)            \n",
    "        if args.lossfn=='mse':\n",
    "            target=(2*target-1)\n",
    "            target = target.type(torch.cuda.FloatTensor).unsqueeze(1)\n",
    "        if 'mnist' in args.dataset:\n",
    "            data=data.view(data.shape[0],-1)\n",
    "        output = model(data)\n",
    "        loss = mult*criterion(output, target)\n",
    "        if samples_correctness:\n",
    "            activations.append(torch.nn.functional.softmax(output,dim=1).cpu().detach().numpy().squeeze())\n",
    "            predictions.append(get_error(output,target))\n",
    "        metrics.update(n=data.size(0), loss=loss.item(), error=get_error(output, target))\n",
    "    if samples_correctness:\n",
    "        return metrics.avg,np.stack(activations),np.array(predictions)\n",
    "    else:\n",
    "        return metrics.avg"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def activations_predictions(model,dataloader,name):\n",
    "    criterion = torch.nn.CrossEntropyLoss()\n",
    "    metrics,activations,predictions=get_metrics(model,dataloader,criterion,True)\n",
    "    print(f\"{name} -> Loss:{np.round(metrics['loss'],3)}, Error:{metrics['error']}\")\n",
    "    log_dict[f\"{name}_loss\"]=metrics['loss']\n",
    "    log_dict[f\"{name}_error\"]=metrics['error']\n",
    "\n",
    "    return activations,predictions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def predictions_distance(l1,l2,name):\n",
    "    dist = np.sum(np.abs(l1-l2))\n",
    "    print(f\"Predictions Distance {name} -> {dist}\")\n",
    "    log_dict[f\"{name}_predictions\"]=dist"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def activations_distance(a1,a2,name):\n",
    "    dist = np.linalg.norm(a1-a2,ord=1,axis=1).mean()\n",
    "    print(f\"Activations Distance {name} -> {dist}\")\n",
    "    log_dict[f\"{name}_activations\"]=dist"
   ]
  },
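  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Hypothetical usage of the metric helpers (placeholder model names; left commented since it depends on a retrained reference being loaded): compare a scrubbed model against a retrained-from-scratch model on the forget set."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# a_s, p_s = activations_predictions(model_s, forget_loader, 'Scrubbed_forget')\n",
    "# a_r, p_r = activations_predictions(model0, forget_loader, 'Retrain_forget')\n",
    "# activations_distance(a_s, a_r, 'Scrubbed_vs_Retrain_forget')\n",
    "# predictions_distance(p_s, p_r, 'Scrubbed_vs_Retrain_forget')"
   ]
  },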
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Finetune and Fisher Helper"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from utils import *\n",
    "def l2_penalty(model,model_init,weight_decay):\n",
    "    l2_loss = 0\n",
    "    for (k,p),(k_init,p_init) in zip(model.named_parameters(),model_init.named_parameters()):\n",
    "        if p.requires_grad:\n",
    "            l2_loss += (p-p_init).pow(2).sum()\n",
    "    l2_loss *= (weight_decay/2.)\n",
    "    return l2_loss\n",
    "\n",
    "def run_train_epoch(model: nn.Module, model_init, data_loader: torch.utils.data.DataLoader, \n",
    "                    loss_fn: nn.Module,\n",
    "                    optimizer: torch.optim.SGD, split: str, epoch: int, ignore_index=None,\n",
    "                    negative_gradient=False, negative_multiplier=-1, random_labels=False,\n",
    "                    quiet=False,delta_w=None,scrub_act=False):\n",
    "    model.eval()\n",
    "    metrics = AverageMeter()    \n",
    "    num_labels = data_loader.dataset.targets.max().item() + 1\n",
    "    \n",
    "    with torch.set_grad_enabled(split != 'test'):\n",
    "        for idx, batch in enumerate(tqdm(data_loader, leave=False)):\n",
    "            batch = [tensor.to(next(model.parameters()).device) for tensor in batch]\n",
    "            input, target = batch\n",
    "            output = model(input)\n",
    "            if split=='test' and scrub_act:\n",
    "                G = []\n",
    "                for cls in range(num_classes):\n",
    "                    grads = torch.autograd.grad(output[0,cls],model.parameters(),retain_graph=True)\n",
    "                    grads = torch.cat([g.view(-1) for g in grads])\n",
    "                    G.append(grads)\n",
    "                grads = torch.autograd.grad(output_sf[0,cls],model_scrubf.parameters(),retain_graph=False)\n",
    "                G = torch.stack(G).pow(2)\n",
    "                delta_f = torch.matmul(G,delta_w)\n",
    "                output += delta_f.sqrt()*torch.empty_like(delta_f).normal_()\n",
    "            loss = loss_fn(output, target) + l2_penalty(model,model_init,args.weight_decay)\n",
    "            metrics.update(n=input.size(0), loss=loss_fn(output,target).item(), error=get_error(output, target))\n",
    "            \n",
    "            if split != 'test':\n",
    "                model.zero_grad()\n",
    "                loss.backward()\n",
    "                optimizer.step()\n",
    "    if not quiet:\n",
    "        log_metrics(split, metrics, epoch)\n",
    "    return metrics.avg\n",
    "\n",
    "def test(model, data_loader):\n",
    "    loss_fn = nn.CrossEntropyLoss()\n",
    "    model_init=copy.deepcopy(model)\n",
    "    return run_train_epoch(model, model_init, data_loader, loss_fn, optimizer=None, split='test', epoch=epoch, ignore_index=None, quiet=True)\n",
    "\n",
    "def get_metrics(model,dataloader,criterion,samples_correctness=False,use_bn=False,delta_w=None,scrub_act=False):\n",
    "    activations=[]\n",
    "    predictions=[]\n",
    "    if use_bn:\n",
    "        model.train()\n",
    "        dataloader = torch.utils.data.DataLoader(retain_loader.dataset, batch_size=128, shuffle=True)\n",
    "        for i in range(10):\n",
    "            for batch_idx, (data, target) in enumerate(dataloader):\n",
    "                data, target = data.to(args.device), target.to(args.device)            \n",
    "                output = model(data)\n",
    "    dataloader = torch.utils.data.DataLoader(dataloader.dataset, batch_size=1, shuffle=False)\n",
    "    model.eval()\n",
    "    metrics = AverageMeter()\n",
    "    mult = 0.5 if args.lossfn=='mse' else 1\n",
    "    for batch_idx, (data, target) in enumerate(dataloader):\n",
    "        data, target = data.to(args.device), target.to(args.device)            \n",
    "        if args.lossfn=='mse':\n",
    "            target=(2*target-1)\n",
    "            target = target.type(torch.cuda.FloatTensor).unsqueeze(1)\n",
    "        if 'mnist' in args.dataset:\n",
    "            data=data.view(data.shape[0],-1)\n",
    "        output = model(data)\n",
    "        if scrub_act:\n",
    "            G = []\n",
    "            for cls in range(num_classes):\n",
    "                grads = torch.autograd.grad(output[0,cls],model.parameters(),retain_graph=True)\n",
    "                grads = torch.cat([g.view(-1) for g in grads])\n",
    "                G.append(grads)\n",
    "            grads = torch.autograd.grad(output_sf[0,cls],model_scrubf.parameters(),retain_graph=False)\n",
    "            G = torch.stack(G).pow(2)\n",
    "            delta_f = torch.matmul(G,delta_w)\n",
    "            output += delta_f.sqrt()*torch.empty_like(delta_f).normal_()\n",
    "\n",
    "        loss = mult*criterion(output, target)\n",
    "        if samples_correctness:\n",
    "            activations.append(torch.nn.functional.softmax(output,dim=1).cpu().detach().numpy().squeeze())\n",
    "            predictions.append(get_error(output,target))\n",
    "        metrics.update(n=data.size(0), loss=loss.item(), error=get_error(output, target))\n",
    "    if samples_correctness:\n",
    "        return metrics.avg,np.stack(activations),np.array(predictions)\n",
    "    else:\n",
    "        return metrics.avg"
   ]
  },
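  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Two details worth noting in the helpers above: `l2_penalty` regularizes toward the passed-in initialization rather than toward zero, i.e. it adds $\\frac{\\lambda}{2}\\lVert w - w_{\\mathrm{init}} \\rVert_2^2$ with $\\lambda$ = `args.weight_decay`; and `run_train_epoch` keeps the model in `eval()` mode even on the train split, so batch-norm statistics stay frozen, consistent with the `--disable-bn` training runs earlier."
   ]
  },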
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def readout_retrain(model, data_loader, test_loader, lr=0.1, epochs=500, threshold=0.01, quiet=True):\n",
    "    torch.manual_seed(seed)\n",
    "    model = copy.deepcopy(model)\n",
    "    loss_fn = nn.CrossEntropyLoss()\n",
    "    optimizer = torch.optim.SGD(model.parameters(), lr=lr, weight_decay=0.0)\n",
    "    sampler = torch.utils.data.RandomSampler(data_loader.dataset, replacement=True, num_samples=500)\n",
    "    data_loader_small = torch.utils.data.DataLoader(data_loader.dataset, batch_size=data_loader.batch_size, sampler=sampler, num_workers=data_loader.num_workers)\n",
    "    metrics = []\n",
    "    model_init=copy.deepcopy(model)\n",
    "    for epoch in range(epochs):\n",
    "        metrics.append(run_train_epoch(model, model_init, test_loader, loss_fn, optimizer, split='test', epoch=epoch, ignore_index=None, quiet=quiet))\n",
    "        if metrics[-1]['loss'] <= threshold:\n",
    "            break\n",
    "        run_train_epoch(model, model_init, data_loader_small, loss_fn, optimizer, split='train', epoch=epoch, ignore_index=None, quiet=quiet)\n",
    "    return epoch, metrics\n",
    "\n",
    "def extract_retrain_time(metrics, threshold=0.1):\n",
    "    losses = np.array([m['loss'] for m in metrics])\n",
    "    return np.argmax(losses < threshold)\n",
    "\n",
    "def all_readouts(model, test_loader_full, retain_loader, forget_loader,thresh=0.1,name='method'):\n",
    "    #train_loader = torch.utils.data.DataLoader(train_loader_full.dataset, batch_size=args.batch_size, shuffle=True)\n",
    "    retrain_time, _ = 0,0#readout_retrain(model, train_loader, forget_loader, epochs=100, lr=0.1, threshold=thresh)\n",
    "    test_error = test(model, test_loader_full)['error']\n",
    "    forget_error = test(model, forget_loader)['error']\n",
    "    retain_error = test(model, retain_loader)['error']\n",
    "    print(f\"{name} ->\"\n",
    "          f\"\\tFull test error: {test_error:.2%}\"\n",
    "          f\"\\tForget error: {forget_error:.2%}\\tRetain error: {retain_error:.2%}\"\n",
    "          f\"\\tFine-tune time: {retrain_time+1} steps\")\n",
    "    log_dict[f\"{name}_retrain_time\"]=retrain_time+1\n",
    "    return test_error, forget_error, retain_error"
   ]
  },
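  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Reading the readouts: a successful scrub should drive the forget error up toward that of a model retrained without the forget set, while keeping retain and test error close to the original model's. `readout_retrain` additionally estimates how many fine-tuning steps it takes to re-learn the forget set (a relearn-time readout); it is stubbed out in `all_readouts` above because it is expensive."
   ]
  },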
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "gamma = [0]\n",
    "alpha = [0]\n",
    "\n",
    "errors = {}\n",
    "\n",
    "for g in gamma:\n",
    "    for a in alpha:\n",
    "        errs = []\n",
    "        train1()\n",
    "        teacher, student, train_loader_full, test_loader_full, retain_loader, forget_loader = load_chpts()\n",
    "        model_s = scrub(g, a, teacher, student, train_loader_full, test_loader_full, retain_loader, forget_loader)\n",
    "        t, f, r = all_readouts(model_s, test_loader_full, retain_loader, forget_loader)\n",
    "        errs.append((t, f, r))\n",
    "\n",
    "\n",
    "        errors[str(g)+\"_\"+str(a)] = errs"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
