{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "%matplotlib inline\n",
    "import os\n",
    "os.environ['CUDA_DEVICE_ORDER']='PCI_BUS_ID'\n",
    "os.environ['CUDA_VISIBLE_DEVICES']='0'\n",
    "import variational\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "from matplotlib.ticker import FuncFormatter\n",
    "from itertools import cycle\n",
    "import os\n",
    "import time\n",
    "import math\n",
    "import pandas as pd\n",
    "from collections import OrderedDict\n",
    "from sklearn.linear_model import LogisticRegression\n",
    "from sklearn.metrics import confusion_matrix\n",
    "    \n",
    "import copy\n",
    "import torch.nn as nn\n",
    "from torch.autograd import Variable\n",
    "from typing import List\n",
    "import itertools\n",
    "from tqdm.autonotebook import tqdm\n",
    "from models import *\n",
    "import models\n",
    "from logger import *\n",
    "\n",
    "from thirdparty.repdistiller.helper.util import adjust_learning_rate as sgda_adjust_learning_rate\n",
    "from thirdparty.repdistiller.distiller_zoo import DistillKL, HintLoss, Attention, Similarity, Correlation, VIDLoss, RKDLoss\n",
    "from thirdparty.repdistiller.distiller_zoo import PKT, ABLoss, FactorTransfer, KDSVD, FSP, NSTLoss\n",
    "\n",
    "from thirdparty.repdistiller.helper.loops import train_distill, train_distill_hide, train_distill_linear, train_vanilla, train_negrad, train_bcu, train_bcu_distill, validate\n",
    "from thirdparty.repdistiller.helper.pretrain import init"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from utils import *\n",
    "def get_metrics(model,dataloader,criterion,samples_correctness=False,use_bn=False,delta_w=None,scrub_act=False):\n",
    "    \"\"\"Evaluate `model` sample by sample on `dataloader.dataset`.\n",
    "\n",
    "    NOTE: the 'Helper and Utils' cell later in this notebook defines\n",
    "    `get_metrics` again (with a `scrub_act` code path); once that cell has\n",
    "    run, it shadows this definition.\n",
    "\n",
    "    Args:\n",
    "        model: network to evaluate; left in eval mode on return.\n",
    "        dataloader: only its `.dataset` is used, re-wrapped with\n",
    "            batch_size=1 and shuffle=False.\n",
    "        criterion: loss callable applied to `(output, target)`.\n",
    "        samples_correctness: when True, also collect per-sample softmax\n",
    "            activations and `get_error` values.\n",
    "        use_bn: when True, first run 10 train-mode passes over the global\n",
    "            `retain_loader` dataset (presumably to refresh BatchNorm\n",
    "            running statistics -- confirm the model has BN layers).\n",
    "        delta_w, scrub_act: accepted for signature parity; unused here.\n",
    "\n",
    "    Returns:\n",
    "        `metrics.avg`, or `(metrics.avg, activations, predictions)` when\n",
    "        `samples_correctness` is True.\n",
    "    \"\"\"\n",
    "    is_mse = args.lossfn=='mse'\n",
    "    flatten_input = 'mnist' in args.dataset\n",
    "    activations=[]\n",
    "    predictions=[]\n",
    "    if use_bn:\n",
    "        model.train()\n",
    "        # Use a dedicated loader: rebinding `dataloader` here used to make the\n",
    "        # evaluation below run on the retain set instead of the caller's data.\n",
    "        bn_loader = torch.utils.data.DataLoader(retain_loader.dataset, batch_size=128, shuffle=True)\n",
    "        with torch.no_grad():\n",
    "            for _pass in range(10):\n",
    "                for data, _target in bn_loader:\n",
    "                    data = data.to(args.device)\n",
    "                    if flatten_input:\n",
    "                        data=data.view(data.shape[0],-1)\n",
    "                    model(data)\n",
    "    eval_loader = torch.utils.data.DataLoader(dataloader.dataset, batch_size=1, shuffle=False)\n",
    "    model.eval()\n",
    "    metrics = AverageMeter()\n",
    "    mult = 0.5 if is_mse else 1\n",
    "    # Pure evaluation: no autograd graph is needed.\n",
    "    with torch.no_grad():\n",
    "        for data, target in eval_loader:\n",
    "            data, target = data.to(args.device), target.to(args.device)\n",
    "            if is_mse:\n",
    "                # Affine map 2*t-1 (sends 0 -> -1 and 1 -> +1); .float() keeps the\n",
    "                # tensor on its current device instead of forcing a CUDA tensor type.\n",
    "                target = (2*target-1).float().unsqueeze(1)\n",
    "            if flatten_input:\n",
    "                data=data.view(data.shape[0],-1)\n",
    "            output = model(data)\n",
    "            loss = mult*criterion(output, target)\n",
    "            error = get_error(output, target)\n",
    "            if samples_correctness:\n",
    "                activations.append(torch.nn.functional.softmax(output,dim=1).cpu().numpy().squeeze())\n",
    "                predictions.append(error)\n",
    "            metrics.update(n=data.size(0), loss=loss.item(), error=error)\n",
    "    if samples_correctness:\n",
    "        return metrics.avg,np.stack(activations),np.array(predictions)\n",
    "    return metrics.avg\n",
    "\n",
    "def activations_predictions(model,dataloader,name):\n",
    "    \"\"\"Run `get_metrics` with cross-entropy, log loss/error under `name`, return per-sample data.\n",
    "\n",
    "    Writes `{name}_loss` and `{name}_error` into the global `log_dict` and\n",
    "    returns the `(activations, predictions)` arrays produced by `get_metrics`.\n",
    "    \"\"\"\n",
    "    summary, activations, predictions = get_metrics(\n",
    "        model, dataloader, torch.nn.CrossEntropyLoss(), True\n",
    "    )\n",
    "    loss_value = summary['loss']\n",
    "    error_value = summary['error']\n",
    "    print(f\"{name} -> Loss:{np.round(loss_value,3)}, Error:{error_value}\")\n",
    "    log_dict[f\"{name}_loss\"] = loss_value\n",
    "    log_dict[f\"{name}_error\"] = error_value\n",
    "    return activations, predictions\n",
    "\n",
    "def predictions_distance(l1,l2,name):\n",
    "    \"\"\"Print and log (as `{name}_predictions`) the summed absolute difference of `l1` and `l2`.\"\"\"\n",
    "    gap = l1 - l2\n",
    "    total_gap = np.sum(np.absolute(gap))\n",
    "    print(f\"Predictions Distance {name} -> {total_gap}\")\n",
    "    log_dict[f\"{name}_predictions\"] = total_gap\n",
    "\n",
    "def activations_distance(a1,a2,name):\n",
    "    \"\"\"Print and log (as `{name}_activations`) the mean row-wise L1 norm of `a1 - a2`.\"\"\"\n",
    "    row_l1 = np.linalg.norm(a1 - a2, ord=1, axis=1)\n",
    "    mean_l1 = row_l1.mean()\n",
    "    print(f\"Activations Distance {name} -> {mean_l1}\")\n",
    "    log_dict[f\"{name}_activations\"] = mean_l1\n",
    "\n",
    "def interclass_confusion(model, dataloader, class_to_forget, name, num_classes=10):\n",
    "    \"\"\"Measure how often the two classes in `class_to_forget` are misclassified.\n",
    "\n",
    "    Args:\n",
    "        model: classifier to evaluate; put into eval mode.\n",
    "        dataloader: only its `.dataset` is used, re-wrapped with\n",
    "            batch_size=128 and shuffle=False.\n",
    "        class_to_forget: indexable holding two class indices (entries 0 and 1).\n",
    "        name: unused; kept so existing call sites keep working.\n",
    "        num_classes: number of labels in the confusion matrix. Defaults to\n",
    "            10, the value that was previously hard-coded.\n",
    "\n",
    "    Returns:\n",
    "        (ic_err, fgt): `ic_err` is the fraction of samples whose true label\n",
    "        is one of the two classes and whose prediction is wrong; `fgt` is\n",
    "        the number of samples of either class predicted as the other one.\n",
    "    \"\"\"\n",
    "    eval_loader = torch.utils.data.DataLoader(dataloader.dataset, batch_size=128, shuffle=False)\n",
    "    model.eval()\n",
    "    reals=[]\n",
    "    predicts=[]\n",
    "    # Inference only: avoid building an autograd graph for every batch.\n",
    "    with torch.no_grad():\n",
    "        for data, target in eval_loader:\n",
    "            data = data.to(args.device)\n",
    "            if 'mnist' in args.dataset:\n",
    "                data=data.view(data.shape[0],-1)\n",
    "            output = model(data)\n",
    "            # softmax is monotonic, so the argmax of the logits is the predicted class.\n",
    "            predicts.extend(np.argmax(output.cpu().numpy(),axis=1))\n",
    "            reals.extend(target.cpu().numpy())\n",
    "    \n",
    "    first, second = class_to_forget[0], class_to_forget[1]\n",
    "    cm = confusion_matrix(reals, predicts, labels=list(range(num_classes)))\n",
    "    row_first, row_second = cm[first], cm[second]\n",
    "    # Row sum minus the diagonal entry = samples of that class given any other label.\n",
    "    counts = (np.sum(row_first) - row_first[first]) + (np.sum(row_second) - row_second[second])\n",
    "    \n",
    "    ic_err = counts / (np.sum(row_first) + np.sum(row_second))\n",
    "    fgt = cm[first][second] + cm[second][first]\n",
    "    return ic_err, fgt"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Helper and Utils"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from utils import *\n",
    "def get_metrics(model,dataloader,criterion,samples_correctness=False,use_bn=False,delta_w=None,scrub_act=False):\n",
    "    \"\"\"Evaluate `model` sample by sample, optionally adding scrub noise to its outputs.\n",
    "\n",
    "    This definition shadows the `get_metrics` from the earlier metrics cell.\n",
    "\n",
    "    Args:\n",
    "        model: network to evaluate; left in eval mode on return.\n",
    "        dataloader: only its `.dataset` is used, re-wrapped with\n",
    "            batch_size=1 and shuffle=False.\n",
    "        criterion: loss callable applied to `(output, target)`.\n",
    "        samples_correctness: when True, also collect per-sample softmax\n",
    "            activations and `get_error` values.\n",
    "        use_bn: when True, first run 10 train-mode passes over the global\n",
    "            `retain_loader` dataset (presumably to refresh BatchNorm\n",
    "            running statistics -- confirm the model has BN layers).\n",
    "        delta_w: read only when `scrub_act` is True, as `matmul(G, delta_w)`\n",
    "            where G has one row per class and one column per flattened\n",
    "            parameter (assumes a 1-D, non-negative tensor with one entry per\n",
    "            model parameter -- TODO confirm with the caller).\n",
    "        scrub_act: when True, add Gaussian noise with variance\n",
    "            `G**2 @ delta_w` to each output before computing loss/error.\n",
    "\n",
    "    Returns:\n",
    "        `metrics.avg`, or `(metrics.avg, activations, predictions)` when\n",
    "        `samples_correctness` is True.\n",
    "    \"\"\"\n",
    "    is_mse = args.lossfn=='mse'\n",
    "    flatten_input = 'mnist' in args.dataset\n",
    "    activations=[]\n",
    "    predictions=[]\n",
    "    if use_bn:\n",
    "        model.train()\n",
    "        # Use a dedicated loader: rebinding `dataloader` here used to make the\n",
    "        # evaluation below run on the retain set instead of the caller's data.\n",
    "        bn_loader = torch.utils.data.DataLoader(retain_loader.dataset, batch_size=128, shuffle=True)\n",
    "        with torch.no_grad():\n",
    "            for _pass in range(10):\n",
    "                for data, _target in bn_loader:\n",
    "                    data = data.to(args.device)\n",
    "                    if flatten_input:\n",
    "                        data=data.view(data.shape[0],-1)\n",
    "                    model(data)\n",
    "    eval_loader = torch.utils.data.DataLoader(dataloader.dataset, batch_size=1, shuffle=False)\n",
    "    model.eval()\n",
    "    metrics = AverageMeter()\n",
    "    mult = 0.5 if is_mse else 1\n",
    "    # Autograd is only needed for the per-class gradients of the scrub path.\n",
    "    with torch.set_grad_enabled(bool(scrub_act)):\n",
    "        for data, target in eval_loader:\n",
    "            data, target = data.to(args.device), target.to(args.device)\n",
    "            if is_mse:\n",
    "                # Affine map 2*t-1 (sends 0 -> -1 and 1 -> +1); .float() keeps the\n",
    "                # tensor on its current device instead of forcing a CUDA tensor type.\n",
    "                target = (2*target-1).float().unsqueeze(1)\n",
    "            if flatten_input:\n",
    "                data=data.view(data.shape[0],-1)\n",
    "            output = model(data)\n",
    "            if scrub_act:\n",
    "                # One flattened gradient row per output class; only row 0 of\n",
    "                # `output` is differentiated (the eval loader uses batch_size=1).\n",
    "                # A stray autograd.grad call on the undefined names\n",
    "                # `output_sf` / `model_scrubf` was removed here: it raised\n",
    "                # NameError and its result was never used.\n",
    "                params = list(model.parameters())\n",
    "                G = []\n",
    "                for cls in range(output.shape[1]):\n",
    "                    grads = torch.autograd.grad(output[0,cls],params,retain_graph=True)\n",
    "                    G.append(torch.cat([g.view(-1) for g in grads]))\n",
    "                G = torch.stack(G).pow(2)\n",
    "                delta_f = torch.matmul(G,delta_w)\n",
    "                output = output + delta_f.sqrt()*torch.empty_like(delta_f).normal_()\n",
    "\n",
    "            loss = mult*criterion(output, target)\n",
    "            error = get_error(output, target)\n",
    "            if samples_correctness:\n",
    "                activations.append(torch.nn.functional.softmax(output,dim=1).cpu().detach().numpy().squeeze())\n",
    "                predictions.append(error)\n",
    "            metrics.update(n=data.size(0), loss=loss.item(), error=error)\n",
    "    if samples_correctness:\n",
    "        return metrics.avg,np.stack(activations),np.array(predictions)\n",
    "    return metrics.avg\n",
    "\n",
    "def l2_penalty(model,model_init,weight_decay):\n",
    "    \"\"\"Return `weight_decay / 2` times the squared L2 gap between two models.\n",
    "\n",
    "    Parameters are paired positionally via `named_parameters()`; pairs whose\n",
    "    `model` parameter has `requires_grad == False` are skipped.\n",
    "    \"\"\"\n",
    "    pairs = zip(model.named_parameters(), model_init.named_parameters())\n",
    "    squared_gap = sum(\n",
    "        (param - anchor).pow(2).sum()\n",
    "        for (_, param), (_, anchor) in pairs\n",
    "        if param.requires_grad\n",
    "    )\n",
    "    return squared_gap * (weight_decay/2.)\n",
    "\n",
    "def run_train_epoch(model: nn.Module, model_init, data_loader: torch.utils.data.DataLoader, \n",
    "                    loss_fn: nn.Module,\n",
    "                    optimizer: torch.optim.SGD, split: str, epoch: int, ignore_index=None,\n",
    "                    negative_gradient=False, negative_multiplier=-1, random_labels=False,\n",
    "                    quiet=False,delta_w=None,scrub_act=False):\n",
    "    \"\"\"Run one pass over `data_loader`, training unless `split == 'test'`.\n",
    "\n",
    "    The model is kept in eval mode for every split, so mode-dependent layers\n",
    "    run in inference mode even while the weights are updated.\n",
    "\n",
    "    Args:\n",
    "        model: network to train or evaluate.\n",
    "        model_init: reference model for the `l2_penalty` term (training only).\n",
    "        data_loader: yields `(input, target)` batches.\n",
    "        loss_fn: data loss applied to `(output, target)`.\n",
    "        optimizer: stepped once per batch when training; may be None for 'test'.\n",
    "        split: 'test' disables weight updates; any other value trains.\n",
    "        epoch: forwarded to `log_metrics` when `quiet` is False.\n",
    "        ignore_index, negative_gradient, negative_multiplier, random_labels:\n",
    "            accepted for call-site compatibility; not read.\n",
    "        quiet: when True, skip `log_metrics`.\n",
    "        delta_w: read only on the scrub path, as `matmul(G, delta_w)` (assumes\n",
    "            a 1-D, non-negative tensor with one entry per model parameter --\n",
    "            TODO confirm with the caller).\n",
    "        scrub_act: with `split == 'test'`, add Gaussian noise with variance\n",
    "            `G**2 @ delta_w` to the outputs before computing the metrics.\n",
    "\n",
    "    Returns:\n",
    "        `metrics.avg` (data loss without the L2 term, and error).\n",
    "    \"\"\"\n",
    "    training = split != 'test'\n",
    "    inject_noise = (not training) and scrub_act\n",
    "    model.eval()\n",
    "    metrics = AverageMeter()\n",
    "    device = next(model.parameters()).device\n",
    "\n",
    "    # The scrub path differentiates the outputs, so autograd must stay enabled\n",
    "    # for it even though the split is 'test'.\n",
    "    with torch.set_grad_enabled(training or inject_noise):\n",
    "        for batch in tqdm(data_loader, leave=False):\n",
    "            input, target = [tensor.to(device) for tensor in batch]\n",
    "            output = model(input)\n",
    "            if inject_noise:\n",
    "                # One flattened gradient row per output class. Only row 0 of\n",
    "                # `output` is differentiated -- presumably batch size 1 is\n",
    "                # expected here; confirm with the caller.\n",
    "                # A stray autograd.grad call on the undefined names\n",
    "                # `output_sf` / `model_scrubf` was removed: it raised\n",
    "                # NameError and its result was never used.\n",
    "                params = list(model.parameters())\n",
    "                G = []\n",
    "                for cls in range(output.shape[1]):\n",
    "                    grads = torch.autograd.grad(output[0,cls],params,retain_graph=True)\n",
    "                    G.append(torch.cat([g.view(-1) for g in grads]))\n",
    "                G = torch.stack(G).pow(2)\n",
    "                delta_f = torch.matmul(G,delta_w)\n",
    "                output = output + delta_f.sqrt()*torch.empty_like(delta_f).normal_()\n",
    "            data_loss = loss_fn(output, target)\n",
    "            metrics.update(n=input.size(0), loss=data_loss.item(), error=get_error(output, target))\n",
    "\n",
    "            if training:\n",
    "                # The L2 pull towards model_init only matters for the update.\n",
    "                loss = data_loss + l2_penalty(model,model_init,args.weight_decay)\n",
    "                model.zero_grad()\n",
    "                loss.backward()\n",
    "                optimizer.step()\n",
    "    if not quiet:\n",
    "        log_metrics(split, metrics, epoch)\n",
    "    return metrics.avg\n",
    "\n",
    "def run_neggrad_epoch(model: nn.Module, model_init, data_loader: torch.utils.data.DataLoader, \n",
    "                    forget_loader: torch.utils.data.DataLoader,\n",
    "                    alpha: float,\n",
    "                    loss_fn: nn.Module,\n",
    "                    optimizer: torch.optim.SGD, split: str, epoch: int, ignore_index=None,\n",
    "                    quiet=False):\n",
    "    \"\"\"Run one negative-gradient pass: descend on retain batches, ascend on forget batches.\n",
    "\n",
    "    Per step the objective is\n",
    "    `alpha * (loss(retain) + l2_penalty) - (1 - alpha) * loss(forget)`.\n",
    "    The model is kept in eval mode throughout.\n",
    "\n",
    "    Args:\n",
    "        model: network being updated.\n",
    "        model_init: reference model for the `l2_penalty` term.\n",
    "        data_loader: retain batches; its length bounds the number of steps.\n",
    "        forget_loader: forget batches, repeated with `itertools.cycle`.\n",
    "        alpha: weight of the retain objective (forget gets `1 - alpha`).\n",
    "        loss_fn: data loss applied to `(output, target)`.\n",
    "        optimizer: stepped once per batch unless `split == 'test'`.\n",
    "        split: 'test' disables gradients and updates.\n",
    "        epoch: forwarded to `log_metrics` when `quiet` is False.\n",
    "        ignore_index: accepted for call-site compatibility; not read.\n",
    "        quiet: when True, skip `log_metrics`.\n",
    "\n",
    "    Returns:\n",
    "        `metrics.avg` computed on the retain batches only.\n",
    "    \"\"\"\n",
    "    training = split != 'test'\n",
    "    model.eval()\n",
    "    metrics = AverageMeter()\n",
    "    device = next(model.parameters()).device\n",
    "\n",
    "    with torch.set_grad_enabled(training):\n",
    "        # cycle() replays the forget batches it saw on its first pass.\n",
    "        paired_batches = zip(data_loader, cycle(forget_loader))\n",
    "        for batch_retain, batch_forget in tqdm(paired_batches, leave=False):\n",
    "            input_r, target_r = [tensor.to(device) for tensor in batch_retain]\n",
    "            input_f, target_f = [tensor.to(device) for tensor in batch_forget]\n",
    "            output_r = model(input_r)\n",
    "            output_f = model(input_f)\n",
    "            retain_loss = loss_fn(output_r, target_r)\n",
    "            metrics.update(n=input_r.size(0), loss=retain_loss.item(), error=get_error(output_r, target_r))\n",
    "            if training:\n",
    "                loss = alpha*(retain_loss + l2_penalty(model,model_init,args.weight_decay)) - (1-alpha)*loss_fn(output_f, target_f)\n",
    "                model.zero_grad()\n",
    "                loss.backward()\n",
    "                optimizer.step()\n",
    "    if not quiet:\n",
    "        log_metrics(split, metrics, epoch)\n",
    "    return metrics.avg"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def test(model, data_loader):\n",
    "    \"\"\"Evaluate `model` on `data_loader` with cross-entropy and return the averaged metrics.\n",
    "\n",
    "    Delegates to `run_train_epoch` with `split='test'` and `quiet=True`.\n",
    "    \"\"\"\n",
    "    loss_fn = nn.CrossEntropyLoss()\n",
    "    # With split='test' nothing is back-propagated, so the L2 reference model is\n",
    "    # irrelevant: pass the model itself instead of deep-copying it on every call.\n",
    "    # `epoch` is only forwarded to log_metrics, which quiet=True skips, so a\n",
    "    # constant replaces the former dependency on a global `epoch` variable.\n",
    "    return run_train_epoch(model, model, data_loader, loss_fn, optimizer=None, split='test', epoch=0, ignore_index=None, quiet=True)\n",
    "\n",
    "def readout_retrain(model, data_loader, test_loader, lr=0.1, epochs=500, threshold=0.01, quiet=True):\n",
    "    \"\"\"Fine-tune a copy of `model` until its loss on `test_loader` drops to `threshold`.\n",
    "\n",
    "    Each epoch first evaluates on `test_loader`, stops if the loss is at or\n",
    "    below `threshold`, and otherwise trains on 500 samples drawn with\n",
    "    replacement from `data_loader.dataset`. Seeds torch with the global `seed`.\n",
    "\n",
    "    Returns:\n",
    "        (epoch, metrics): index of the last epoch entered and the list of\n",
    "        per-epoch evaluation metrics.\n",
    "    \"\"\"\n",
    "    torch.manual_seed(seed)\n",
    "    net = copy.deepcopy(model)\n",
    "    anchor = copy.deepcopy(net)\n",
    "    ce_loss = nn.CrossEntropyLoss()\n",
    "    sgd = torch.optim.SGD(net.parameters(), lr=lr, weight_decay=0.0)\n",
    "    small_loader = torch.utils.data.DataLoader(\n",
    "        data_loader.dataset,\n",
    "        batch_size=data_loader.batch_size,\n",
    "        sampler=torch.utils.data.RandomSampler(data_loader.dataset, replacement=True, num_samples=500),\n",
    "        num_workers=data_loader.num_workers,\n",
    "    )\n",
    "    history = []\n",
    "    for epoch in range(epochs):\n",
    "        eval_metrics = run_train_epoch(net, anchor, test_loader, ce_loss, sgd, split='test', epoch=epoch, ignore_index=None, quiet=quiet)\n",
    "        history.append(eval_metrics)\n",
    "        if eval_metrics['loss'] <= threshold:\n",
    "            break\n",
    "        run_train_epoch(net, anchor, small_loader, ce_loss, sgd, split='train', epoch=epoch, ignore_index=None, quiet=quiet)\n",
    "    return epoch, history\n",
    "\n",
    "def extract_retrain_time(metrics, threshold=0.1):\n",
    "    \"\"\"Return the index of the first entry whose 'loss' is below `threshold`.\n",
    "\n",
    "    Follows `np.argmax` on a boolean mask, so 0 is also returned when no\n",
    "    entry is below the threshold.\n",
    "    \"\"\"\n",
    "    loss_curve = np.array([record['loss'] for record in metrics])\n",
    "    below = np.less(loss_curve, threshold)\n",
    "    return np.argmax(below)\n",
    "\n",
    "def all_readouts(test_loader, retain_loader, forget_loader, model, wandb=None,thresh=0.1,name='method'):\n",
    "    \"\"\"Print (and optionally log to `wandb`) error and interclass-confusion readouts.\n",
    "\n",
    "    Evaluates `model` on the test, forget and retain loaders via `test`, and\n",
    "    computes `interclass_confusion` on the test and retain loaders using the\n",
    "    global `class_to_forget`. `thresh` is not read (the retrain-time readout\n",
    "    that used it is disabled).\n",
    "\n",
    "    Returns:\n",
    "        dict with keys `test_error`, `forget_error`, `retain_error`.\n",
    "    \"\"\"\n",
    "    test_error = test(model, test_loader)['error']\n",
    "    forget_error = test(model, forget_loader)['error']\n",
    "    retain_error = test(model, retain_loader)['error']\n",
    "    ic_err_test, fgt_test = interclass_confusion(model, test_loader, class_to_forget, name)\n",
    "    ic_err_retain, fgt_retain = interclass_confusion(model, retain_loader, class_to_forget, name)\n",
    "    report = (\n",
    "        f\"{name} ->\"\n",
    "        f\"\\ttest: {test_error:.2%}\"\n",
    "        f\"\\tForget: {forget_error:.2%}\\tRetain: {retain_error:.2%}\"\n",
    "        f\"\\tIC-test: {ic_err_test}\\tfgt-test: {fgt_test}\"\n",
    "        f\"\\tIC-retain: {ic_err_retain}\\tfgt-retain: {fgt_retain}\"\n",
    "    )\n",
    "    print(report)\n",
    "    if wandb is not None:\n",
    "        # One wandb.log call per readout, in this fixed order.\n",
    "        readouts = (\n",
    "            (f\"{name}_test_error\", test_error),\n",
    "            (f\"{name}_retain_error\", retain_error),\n",
    "            (f\"{name}_forget_error\", forget_error),\n",
    "            (f\"{name}_IC_err_test\", ic_err_test),\n",
    "            (f\"{name}_IC_err_retain\", ic_err_retain),\n",
    "            (f\"{name}_fgt_test\", fgt_test),\n",
    "            (f\"{name}_fgt_retain\", fgt_retain),\n",
    "        )\n",
    "        for key, value in readouts:\n",
    "            wandb.log({key: value})\n",
    "    return {'test_error': test_error, 'forget_error': forget_error, 'retain_error': retain_error}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Forgetting helper functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def finetune(model: nn.Module, data_loader: torch.utils.data.DataLoader, lr=0.01, epochs=10, quiet=False):\n",
    "    \"\"\"Fine-tune `model` in place on `data_loader` for `epochs` passes.\n",
    "\n",
    "    Uses plain SGD (no optimizer weight decay) with cross-entropy; a deep copy\n",
    "    of the starting weights is the reference for run_train_epoch's L2 term.\n",
    "    \"\"\"\n",
    "    anchor = copy.deepcopy(model)\n",
    "    sgd = torch.optim.SGD(model.parameters(), lr=lr, weight_decay=0.0)\n",
    "    ce_loss = nn.CrossEntropyLoss()\n",
    "    for epoch in range(epochs):\n",
    "        run_train_epoch(\n",
    "            model, anchor, data_loader, ce_loss, sgd,\n",
    "            split='train', epoch=epoch, ignore_index=None, quiet=quiet,\n",
    "        )\n",
    "\n",
    "def negative_grad(model: nn.Module, data_loader: torch.utils.data.DataLoader, forget_loader: torch.utils.data.DataLoader, alpha: float, lr=0.01, epochs=10, quiet=False, args=None):\n",
    "    \"\"\"Run `epochs` passes of the third-party `train_negrad` loop on `model`.\n",
    "\n",
    "    Args:\n",
    "        model: network updated in place.\n",
    "        data_loader: retain batches.\n",
    "        forget_loader: forget batches.\n",
    "        alpha: forwarded to `train_negrad`.\n",
    "        lr: SGD learning rate (no momentum, no weight decay).\n",
    "        epochs: number of passes.\n",
    "        quiet: not read; kept for signature parity with `finetune`.\n",
    "        args: forwarded to `train_negrad` unchanged (note the default is None,\n",
    "            not the notebook-level `args`).\n",
    "    \"\"\"\n",
    "    loss_fn = nn.CrossEntropyLoss()\n",
    "    optimizer = torch.optim.SGD(model.parameters(), lr=lr, weight_decay=0.0)\n",
    "    # No reference copy of the model is made: train_negrad takes no model_init,\n",
    "    # so the former deep copy only cost time and memory.\n",
    "    for epoch in range(epochs):\n",
    "        train_negrad(epoch, data_loader, forget_loader, model, loss_fn, optimizer,  alpha, args)\n",
    "\n",
    "def fk_fientune(model: nn.Module, data_loader: torch.utils.data.DataLoader, args, lr=0.01, epochs=10, quiet=False):\n",
    "    \"\"\"Fine-tune `model` in place with SGD (weight decay 0.0005) and a scheduled learning rate.\n",
    "\n",
    "    Before each epoch the learning rate is set by\n",
    "    `sgda_adjust_learning_rate(epoch, args, optimizer)`; a deep copy of the\n",
    "    starting weights is the reference for run_train_epoch's L2 term.\n",
    "    \"\"\"\n",
    "    anchor = copy.deepcopy(model)\n",
    "    sgd = torch.optim.SGD(model.parameters(), lr=lr, weight_decay=0.0005)\n",
    "    ce_loss = nn.CrossEntropyLoss()\n",
    "    for epoch in range(epochs):\n",
    "        sgda_adjust_learning_rate(epoch, args, sgd)\n",
    "        run_train_epoch(\n",
    "            model, anchor, data_loader, ce_loss, sgd,\n",
    "            split='train', epoch=epoch, ignore_index=None, quiet=quiet,\n",
    "        )\n",
    "\n",
    "def pdb():\n",
    "    \"\"\"Drop into the interactive debugger at the caller's frame.\n",
    "\n",
    "    The previous body referenced `pdb.set_trace` without calling it, so the\n",
    "    helper silently did nothing.\n",
    "    \"\"\"\n",
    "    import pdb as _pdb\n",
    "    import sys\n",
    "    # Start in the caller's frame rather than inside this helper.\n",
    "    _pdb.Pdb().set_trace(sys._getframe().f_back)\n",
    "    \n",
    "def parameter_count(model):\n",
    "    \"\"\"Print the total number of scalar entries across `model.parameters()`.\"\"\"\n",
    "    per_tensor = (np.prod(np.array(list(p.shape))) for p in model.parameters())\n",
    "    count = sum(per_tensor)\n",
    "    print(f'Total Number of Parameters: {count}')\n",
    "    \n",
    "def vectorize_params(model):\n",
    "    \"\"\"Flatten every parameter of `model` (in `parameters()` order) into one 1-D NumPy array.\"\"\"\n",
    "    flat_chunks = [p.data.view(-1).cpu().numpy() for p in model.parameters()]\n",
    "    return np.concatenate(flat_chunks)\n",
    "\n",
    "\n",
    "def print_param_shape(model):\n",
    "    \"\"\"Print the name and shape of every named parameter of `model`, one per line.\"\"\"\n",
    "    for param_name, param in model.named_parameters():\n",
    "        print(param_name, param.shape)\n",
    "\n",
    "def distance(model,model0):\n",
    "    \"\"\"Print and return the normalized L2 distance between two models' `.data0` snapshots.\n",
    "\n",
    "    Reads the `data0` attribute of every parameter (the notebook attaches it\n",
    "    when checkpoints are loaded), not the live `.data`. Parameters are paired\n",
    "    positionally.\n",
    "\n",
    "    Returns:\n",
    "        sqrt(sum ||p - p0||^2 / sum ||p||^2), with p taken from `model`.\n",
    "    \"\"\"\n",
    "    squared_gap = 0\n",
    "    squared_norm = 0\n",
    "    pairs = zip(model.named_parameters(), model0.named_parameters())\n",
    "    for (_, p), (_, p0) in pairs:\n",
    "        squared_gap += (p.data0-p0.data0).pow(2).sum().item()\n",
    "        squared_norm += p.data0.pow(2).sum().item()\n",
    "    print(f'Distance: {np.sqrt(squared_gap)}')\n",
    "    normalized = 1.0*np.sqrt(squared_gap/squared_norm)\n",
    "    print(f'Normalized Distance: {normalized}')\n",
    "    return normalized\n",
    "\n",
    "def ntk_init(resume,seed=1):\n",
    "    \"\"\"Build a fresh model under `manual_seed(seed)` and load weights from checkpoint `resume`.\n",
    "\n",
    "    Architecture settings come from the globals `arch`, `num_classes`,\n",
    "    `filters` and `args.device`.\n",
    "    \"\"\"\n",
    "    manual_seed(seed)\n",
    "    network = models.get_model(arch, num_classes=num_classes, filters_percentage=filters)\n",
    "    network = network.to(args.device)\n",
    "    state = torch.load(resume)\n",
    "    network.load_state_dict(state)\n",
    "    return network"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Pre-training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "#%run main.py --dataset lacuna100 --dataroot=data/lacuna100/ --model resnet --filters 1.0 --lr 0.1 --lossfn ce --num-classes 100"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Original"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "#%run main.py --dataset small_lacuna6 --model allcnn --dataroot=data/lacuna10/ --filters 1.0 --lr 0.001 \\\n",
    "#--resume checkpoints/lacuna100_allcnn_1_0_forget_None_lr_0_1_bs_128_ls_ce_wd_0_0005_seed_1_30.pt --disable-bn \\\n",
    "#--weight-decay 0.1 --batch-size 128 --epochs 31 --seed 3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 'Original' run: executes main.py with --split train on lacuna10, resuming from the lacuna100 checkpoint.\n",
    "# Later cells read `model`, `args` and `train_loader`, which no cell above assigns;\n",
    "# presumably `%run` leaves main.py's globals in the notebook namespace -- confirm.\n",
    "%run main.py --dataset lacuna10 --model allcnn --dataroot=data/lacuna10/ --filters 1.0 --lr 0.01 \\\n",
    "--resume checkpoints/lacuna100_allcnn_1_0_forget_None_lr_0_1_bs_128_ls_ce_wd_0_0005_seed_1_30.pt --disable-bn \\\n",
    "--weight-decay 0.0005 --batch-size 128 --epochs 26 --seed 1 \\\n",
    "--split train --confuse-mode --forget-class 0,1 --num-to-forget 200"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Retrain"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "#%run main.py --dataset small_lacuna6 --model allcnn --dataroot=data/lacuna10/ --filters 1.0 --lr 0.001 \\\n",
    "#--resume checkpoints/lacuna100_allcnn_1_0_forget_None_lr_0_1_bs_128_ls_ce_wd_0_0005_seed_1_30.pt --disable-bn \\\n",
    "#--weight-decay 0.1 --batch-size 128 --epochs 31 \\\n",
    "#--forget-class 0,1,2,3,4,5 --num-to-forget 300 --seed 3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 'Retrain' run: same command as the 'Original' cell except for --split forget.\n",
    "# NOTE(review): this run presumably overwrites the `model` / `args` globals left by the\n",
    "# previous `%run`, so later cells see the values from this run -- confirm.\n",
    "%run main.py --dataset lacuna10 --model allcnn --dataroot=data/lacuna10/ --filters 1.0 --lr 0.01 \\\n",
    "--resume checkpoints/lacuna100_allcnn_1_0_forget_None_lr_0_1_bs_128_ls_ce_wd_0_0005_seed_1_30.pt --disable-bn \\\n",
    "--weight-decay 0.0005 --batch-size 128 --epochs 26 --seed 1 \\\n",
    "--split forget --confuse-mode --forget-class 0,1 --num-to-forget 200"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Logs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Global results dictionary written to by the metric helpers (e.g. activations_predictions).\n",
    "log_dict={}\n",
    "# Epoch index used as the suffix of the checkpoint file names loaded below.\n",
    "# NOTE(review): the %run cells pass --epochs 26; presumably checkpoints are numbered from 0 -- confirm.\n",
    "training_epochs=25"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Record which checkpoint epoch the readouts in log_dict refer to.\n",
    "log_dict['epoch']=training_epochs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# parameter_count only reads parameter shapes, so no deep copy of the model is needed.\n",
    "parameter_count(model)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load checkpoints"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rebuild the checkpoint file names from `args` and load weights into three copies of `model`:\n",
    "#   model         <- m_name   (file tagged forget_None)\n",
    "#   model0        <- m0_name  (file tagged forget_{class_to_forget} plus the num_to_forget tag)\n",
    "#   model_initial <- init_checkpoint\n",
    "# `model`, `args` and `train_loader` are not assigned by any cell above; presumably the\n",
    "# `%run main.py` cells leave them in the namespace -- confirm.\n",
    "import copy\n",
    "model0 = copy.deepcopy(model)\n",
    "model_initial = copy.deepcopy(model)\n",
    "\n",
    "# Notebook-level settings mirrored from `args`; several helpers read these globals\n",
    "# (e.g. ntk_init uses arch/filters/num_classes, readout_retrain uses seed).\n",
    "arch = args.model \n",
    "filters=args.filters\n",
    "arch_filters = arch +'_'+ str(filters).replace('.','_')\n",
    "augment = False\n",
    "dataset = args.dataset\n",
    "class_to_forget = args.forget_class\n",
    "init_checkpoint = f\"checkpoints/{args.name}_init.pt\"\n",
    "num_classes=args.num_classes\n",
    "num_to_forget = args.num_to_forget\n",
    "num_total = len(train_loader.dataset)\n",
    "num_to_retain = num_total - num_to_forget\n",
    "seed = args.seed\n",
    "unfreeze_start = None\n",
    "\n",
    "# File-name fragments; dots are replaced by underscores (0.01 -> 0_01).\n",
    "learningrate=f\"lr_{str(args.lr).replace('.','_')}\"\n",
    "batch_size=f\"_bs_{str(args.batch_size)}\"\n",
    "lossfn=f\"_ls_{args.lossfn}\"\n",
    "wd=f\"_wd_{str(args.weight_decay).replace('.','_')}\"\n",
    "seed_name=f\"_seed_{args.seed}_\"\n",
    "\n",
    "num_tag = '' if num_to_forget is None else f'_num_{num_to_forget}'\n",
    "unfreeze_tag = '_' if unfreeze_start is None else f'_unfreeze_from_{unfreeze_start}_'\n",
    "augment_tag = '' if not augment else f'augment_'\n",
    "\n",
    "m_name = f'checkpoints/{dataset}_{arch_filters}_forget_None{unfreeze_tag}{augment_tag}{learningrate}{batch_size}{lossfn}{wd}{seed_name}{training_epochs}.pt'\n",
    "m0_name = f'checkpoints/{dataset}_{arch_filters}_forget_{class_to_forget}{num_tag}{unfreeze_tag}{augment_tag}{learningrate}{batch_size}{lossfn}{wd}{seed_name}{training_epochs}.pt'\n",
    "\n",
    "model.load_state_dict(torch.load(m_name))\n",
    "model0.load_state_dict(torch.load(m0_name))\n",
    "model_initial.load_state_dict(torch.load(init_checkpoint))\n",
    "\n",
    "# Teacher and student both start from the weights just loaded into `model`.\n",
    "teacher = copy.deepcopy(model)\n",
    "student = copy.deepcopy(model)\n",
    "\n",
    "model.cuda()\n",
    "model0.cuda()\n",
    "\n",
    "\n",
    "# Snapshot the loaded weights on each parameter as `.data0`; distance() reads these snapshots.\n",
    "for p in model.parameters():\n",
    "    p.data0 = p.data.clone()\n",
    "for p in model0.parameters():\n",
    "    p.data0 = p.data.clone()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep the full run configuration next to the readouts.\n",
    "log_dict['args']=args"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fresh model loaded from the init checkpoint; `.data0` snapshots are attached so that\n",
    "# distance() can compare it with the other models.\n",
    "model_init = ntk_init(init_checkpoint,args.seed)\n",
    "for p in model_init.parameters():\n",
    "    p.data0 = p.data.clone()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Normalized parameter distance between the init checkpoint and the loaded `model`\n",
    "# (the normalizer is the squared norm of the first argument, model_init).\n",
    "log_dict['dist_Original_Original_init']= distance(model_init,model)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Data Loader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Batch sizes used when the retain / forget loaders are built in the next cell.\n",
    "args.retain_bs = 32\n",
    "args.forget_bs = 32"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Full train/valid/test loaders (split='train'), plus a loader built with split='forget' and\n",
    "# only_mark=True. The code below treats negative targets in the marked dataset as the\n",
    "# forget samples; the exact marking scheme lives in datasets.get_loaders -- confirm there.\n",
    "train_loader_full, valid_loader_full, test_loader_full = datasets.get_loaders(dataset, split=\"train\",confuse_mode=True,class_to_replace=class_to_forget, num_indexes_to_replace=num_to_forget, batch_size=args.batch_size, seed=seed, root=args.dataroot, augment=False, shuffle=True)\n",
    "marked_loader, _, _ = datasets.get_loaders(dataset, split=\"forget\", confuse_mode=True,class_to_replace=class_to_forget, num_indexes_to_replace=num_to_forget, only_mark=True, batch_size=1, seed=seed, root=args.dataroot, augment=False, shuffle=True)\n",
    "\n",
    "def replace_loader_dataset(data_loader, dataset, batch_size=args.batch_size, seed=1, shuffle=True):\n",
    "    \"\"\"Return a new single-process DataLoader over `dataset` after calling `manual_seed(seed)`.\n",
    "\n",
    "    `data_loader` is accepted for call-site compatibility but is not read.\n",
    "    The default `batch_size` is bound to `args.batch_size` when this cell runs.\n",
    "    \"\"\"\n",
    "    manual_seed(seed)\n",
    "    return torch.utils.data.DataLoader(\n",
    "        dataset,\n",
    "        batch_size=batch_size,\n",
    "        shuffle=shuffle,\n",
    "        num_workers=0,\n",
    "        pin_memory=True,\n",
    "    )\n",
    "    \n",
    "# Forget set: samples whose target is negative in the marked dataset. The target is\n",
    "# mapped back with -t - 1 (so -1 -> 0, -2 -> 1, ...).\n",
    "forget_dataset = copy.deepcopy(marked_loader.dataset)\n",
    "marked = forget_dataset.targets < 0\n",
    "forget_dataset.data = forget_dataset.data[marked]\n",
    "forget_dataset.targets = - forget_dataset.targets[marked] - 1\n",
    "forget_loader = replace_loader_dataset(train_loader_full, forget_dataset, batch_size=args.forget_bs, seed=seed, shuffle=True)\n",
    "\n",
    "# Retain set: every sample whose target is non-negative, labels unchanged.\n",
    "retain_dataset = copy.deepcopy(marked_loader.dataset)\n",
    "marked = retain_dataset.targets >= 0\n",
    "retain_dataset.data = retain_dataset.data[marked]\n",
    "retain_dataset.targets = retain_dataset.targets[marked]\n",
    "retain_loader = replace_loader_dataset(train_loader_full, retain_dataset, batch_size=args.retain_bs, seed=seed, shuffle=True)\n",
    "\n",
    "# The two subsets must partition the full training set.\n",
    "assert(len(forget_dataset) + len(retain_dataset) == len(train_loader_full.dataset))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import Counter\n",
    "\n",
    "# Sizes of the forget / retain / test / full-train datasets, in that order.\n",
    "print (len(forget_loader.dataset))\n",
    "print (len(retain_loader.dataset))\n",
    "print (len(test_loader_full.dataset))\n",
    "print (len(train_loader_full.dataset))\n",
    "# Per-class sample counts of the full training set. Targets are converted to plain\n",
    "# Python ints first: iterating a torch tensor yields 0-d tensors that hash by identity,\n",
    "# so Counter would otherwise report every sample as its own key.\n",
    "print(dict(Counter(np.asarray(train_loader_full.dataset.targets).tolist())))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# SCRUB Forgetting"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# SCRUB settings attached to `args`. Read in this notebook: optim and the sgda_* values\n",
    "# (optimizer cell), kd_T (DistillKL), sgda_epochs and msteps (unlearning loop).\n",
    "# The remaining fields are presumably consumed by train_distill / validate /\n",
    "# sgda_adjust_learning_rate in thirdparty.repdistiller -- confirm there.\n",
    "args.optim = 'sgd'\n",
    "args.gamma = 1\n",
    "args.alpha = 0.01\n",
    "args.beta = 0\n",
    "args.smoothing = 0.5\n",
    "# Epochs 1..msteps of the unlearning loop also run train_distill in 'maximize' mode on the forget set.\n",
    "args.msteps = 2\n",
    "args.clip = 0.5\n",
    "args.sstart = 10\n",
    "# Temperature passed to DistillKL.\n",
    "args.kd_T = 2\n",
    "args.distill = 'kd'\n",
    "\n",
    "# Optimizer / schedule settings for the unlearning loop.\n",
    "args.sgda_epochs = 10\n",
    "args.sgda_learning_rate = 0.0005\n",
    "args.lr_decay_epochs = [7,10,10]\n",
    "args.lr_decay_rate = 0.1\n",
    "args.sgda_weight_decay = 5e-4\n",
    "args.sgda_momentum = 0.9"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Independent copies so that unlearning on model_s leaves `teacher` / `student` untouched;\n",
    "# both start from the same loaded weights.\n",
    "model_t = copy.deepcopy(teacher)\n",
    "model_s = copy.deepcopy(student)\n",
    "#model_s = models.get_model(arch, num_classes=num_classes, filters_percentage=filters).to(args.device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#this is from https://github.com/ojus1/SmoothedGradientDescentAscent/blob/main/SGDA.py\n",
    "#For SGDA smoothing\n",
    "# Smoothing factor of the running average (a notebook global, distinct from args.beta).\n",
    "beta = 0.1\n",
    "\n",
    "\n",
    "def avg_fn(averaged_model_parameter, model_parameter, num_averaged):\n",
    "    \"\"\"Exponential moving average step: keep (1 - beta) of the average, add beta of the new value.\n",
    "\n",
    "    `num_averaged` is part of the AveragedModel callback signature and is not read.\n",
    "    \"\"\"\n",
    "    kept = (1 - beta) * averaged_model_parameter\n",
    "    return kept + beta * model_parameter\n",
    "\n",
    "\n",
    "swa_model = torch.optim.swa_utils.AveragedModel(model_s, avg_fn=avg_fn)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# module_list holds the student now; the teacher is appended in the next cell.\n",
    "# Only the student (trainable_list) is handed to the optimizer.\n",
    "module_list = nn.ModuleList([])\n",
    "module_list.append(model_s)\n",
    "trainable_list = nn.ModuleList([])\n",
    "trainable_list.append(model_s)\n",
    "\n",
    "criterion_cls = nn.CrossEntropyLoss()\n",
    "criterion_div = DistillKL(args.kd_T)\n",
    "criterion_kd = DistillKL(args.kd_T)\n",
    "\n",
    "\n",
    "criterion_list = nn.ModuleList([])\n",
    "criterion_list.append(criterion_cls)    # classification loss\n",
    "criterion_list.append(criterion_div)    # KL divergence loss, original knowledge distillation\n",
    "criterion_list.append(criterion_kd)     # other knowledge distillation loss\n",
    "\n",
    "# optimizer\n",
    "if args.optim == \"sgd\":\n",
    "    optimizer = optim.SGD(trainable_list.parameters(),\n",
    "                          lr=args.sgda_learning_rate,\n",
    "                          momentum=args.sgda_momentum,\n",
    "                          weight_decay=args.sgda_weight_decay)\n",
    "elif args.optim == \"adam\": \n",
    "    optimizer = optim.Adam(trainable_list.parameters(),\n",
    "                          lr=args.sgda_learning_rate,\n",
    "                          weight_decay=args.sgda_weight_decay)\n",
    "elif args.optim == \"rmsp\":\n",
    "    optimizer = optim.RMSprop(trainable_list.parameters(),\n",
    "                          lr=args.sgda_learning_rate,\n",
    "                          momentum=args.sgda_momentum,\n",
    "                          weight_decay=args.sgda_weight_decay)\n",
    "else:\n",
    "    # Fail loudly: otherwise `optimizer` would be undefined, or silently keep\n",
    "    # whatever value an earlier cell or run left in the namespace.\n",
    "    raise ValueError(f\"Unsupported args.optim: {args.optim!r} (expected 'sgd', 'adam' or 'rmsp')\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "module_list.append(model_t)\n",
    "\n",
    "if torch.cuda.is_available():\n",
    "    module_list.cuda()\n",
    "    criterion_list.cuda()\n",
    "    import torch.backends.cudnn as cudnn\n",
    "    cudnn.benchmark = True\n",
    "    swa_model.cuda()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "t1 = time.time()\n",
    "acc_rs = []\n",
    "acc_fs = []\n",
    "acc_vs = []\n",
    "ic_rs = []\n",
    "ic_vs = []\n",
    "fgt_rs = []\n",
    "fgt_vs = []\n",
    "for epoch in range(1, args.sgda_epochs + 1):\n",
    "\n",
    "    lr = sgda_adjust_learning_rate(epoch, args, optimizer)\n",
    "    print(\"==> SCRUB unlearning ...\")\n",
    "    ic_r, fgt_r = interclass_confusion(model_s, retain_loader, class_to_forget, \"SCRUB\")\n",
    "    ic_v, fgt_v = interclass_confusion(model_s, valid_loader_full, class_to_forget, \"SCRUB\")\n",
    "    ic_rs.append(ic_r)\n",
    "    ic_vs.append(ic_v)\n",
    "    fgt_rs.append(fgt_r)\n",
    "    fgt_vs.append(fgt_v)\n",
    "\n",
    "    acc_r, acc5_r, loss_r = validate(retain_loader, model_s, criterion_cls, args, True)\n",
    "    acc_f, acc5_f, loss_f = validate(forget_loader, model_s, criterion_cls, args, True)\n",
    "    acc_v, acc5_v, loss_v = validate(valid_loader_full, model_s, criterion_cls, args, True)\n",
    "    acc_rs.append(100-acc_r.item())\n",
    "    acc_fs.append(100-acc_f.item())\n",
    "    acc_vs.append(100-acc_v.item())\n",
    "\n",
    "\n",
    "    maximize_loss = 0\n",
    "    if epoch <= args.msteps:\n",
    "        maximize_loss = train_distill(epoch, forget_loader, module_list, swa_model, criterion_list, optimizer, args, \"maximize\")\n",
    "    train_acc, train_loss = train_distill(epoch, retain_loader, module_list, swa_model, criterion_list, optimizer, args, \"minimize\",)\n",
    "    if epoch >= args.sstart:\n",
    "        swa_model.update_parameters(model_s)\n",
    "\n",
    "    \n",
    "    print (\"maximize loss: {:.2f}\\t minimize loss: {:.2f}\\t train_acc: {}\".format(maximize_loss, train_loss, train_acc))\n",
    "ic_r, fgt_r = interclass_confusion(model_s, retain_loader, class_to_forget, \"SCRUB\")\n",
    "ic_v, fgt_v = interclass_confusion(model_s, valid_loader_full, class_to_forget, \"SCRUB\")\n",
    "ic_rs.append(ic_r)\n",
    "ic_vs.append(ic_v)\n",
    "fgt_rs.append(fgt_r)\n",
    "fgt_vs.append(fgt_v)\n",
    "acc_r, acc5_r, loss_r = validate(retain_loader, model_s, criterion_cls, args, True)\n",
    "acc_f, acc5_f, loss_f = validate(forget_loader, model_s, criterion_cls, args, True)\n",
    "acc_tv, acc5_v, loss_v = validate(valid_loader_full, model_s, criterion_cls, args, True)\n",
    "acc_rs.append(100-acc_r.item())\n",
    "acc_fs.append(100-acc_f.item())\n",
    "acc_vs.append(100-acc_v.item())\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from matplotlib import pyplot as plt\n",
    "fig, ax = plt.subplots(3,1,figsize=(9,15))\n",
    "indices = list(range(0,len(ic_rs)))\n",
    "ax[0].plot(indices, ic_rs, marker='*', color='red', alpha=1, label='ic-retain')\n",
    "ax[0].plot(indices, ic_vs, marker='o', color='blue', alpha=1, label='ic-val')\n",
    "ax[1].plot(indices, acc_rs, marker='*', color='green', alpha=1, label='acc-retain')\n",
    "ax[1].plot(indices, acc_vs, marker='o', color='brown', alpha=1, label='acc-valid')\n",
    "ax[2].plot(indices, acc_fs, marker='^', color='black', alpha=1, label='acc-forget')\n",
    "ax[0].legend(prop={'size': 14})\n",
    "ax[1].legend(prop={'size': 14})\n",
    "plt.tick_params(labelsize=12)\n",
    "ax[0].set_xlabel('epoch',size=14)\n",
    "ax[0].set_ylabel('error',size=14)\n",
    "ax[1].set_xlabel('epoch',size=14)\n",
    "ax[1].set_ylabel('error',size=14)\n",
    "fig.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Finetune"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_ft = copy.deepcopy(model)\n",
    "args.ft_bs = 64\n",
    "retain_loader = replace_loader_dataset(train_loader_full,retain_dataset, seed=seed, batch_size=args.ft_bs, shuffle=True)  \n",
    "finetune(model_ft, retain_loader, epochs=10, quiet=True, lr=0.01)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## NegGrad+"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "args.ng_alpha = 0.9999\n",
    "args.ng_epochs = 10\n",
    "args.ng_lr = 0.01\n",
    "model_ng = copy.deepcopy(model)\n",
    "args.ng_bs = 128\n",
    "retain_loader = replace_loader_dataset(train_loader_full,retain_dataset, seed=seed, batch_size=args.ng_bs, shuffle=True)\n",
    "negative_grad(model_ng, retain_loader, forget_loader, alpha=args.ng_alpha, epochs=args.ng_epochs, quiet=True, lr=args.ng_lr, args=args)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Catastrophic Forgetting k layers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "args.lr_decay_epochs = [10,15,20]\n",
    "args.cfk_lr = 0.01\n",
    "args.cfk_epochs = 10\n",
    "args.cfk_bs = 128\n",
    "retain_loader = replace_loader_dataset(train_loader_full,retain_dataset, seed=seed, batch_size=args.cfk_bs, shuffle=True)\n",
    "\n",
    "model_cfk = copy.deepcopy(model)\n",
    "\n",
    "for param in model_cfk.parameters():\n",
    "    param.requires_grad_(False)\n",
    "\n",
    "if args.model == 'allcnn':\n",
    "    layers = [9]\n",
    "    for k in layers:\n",
    "        for param in model_cfk.features[k].parameters():\n",
    "            param.requires_grad_(True)\n",
    "    \n",
    "elif args.model == \"resnet\":\n",
    "    for param in model_cfk.layer4.parameters():\n",
    "        param.requires_grad_(True)\n",
    "\n",
    "else:\n",
    "    raise NotImplementedError\n",
    "\n",
    "\n",
    "\n",
    "fk_fientune(model_cfk, retain_loader, args=args, epochs=args.cfk_epochs, quiet=True, lr=args.cfk_lr)\n",
    "cfk_time = t2-t1"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Exact Unlearning k layers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "args.lr_decay_epochs = [10,15,20]\n",
    "args.euk_lr = 0.01\n",
    "args.euk_epochs = training_epochs\n",
    "args.euk_bs = 64\n",
    "retain_loader = replace_loader_dataset(train_loader_full,retain_dataset, seed=seed, batch_size=args.euk_bs, shuffle=True)\n",
    "model_euk = copy.deepcopy(model)\n",
    "\n",
    "for param in model_euk.parameters():\n",
    "    param.requires_grad_(False)\n",
    "\n",
    "if args.model == 'allcnn':\n",
    "    with torch.no_grad():\n",
    "        for k in layers:\n",
    "            for i in range(0,3):\n",
    "                try:\n",
    "                    model_euk.features[k][i].weight.copy_(model_initial.features[k][i].weight)\n",
    "                except:\n",
    "                    print (\"block {}, layer {} does not have weights\".format(k,i))\n",
    "                try:\n",
    "                    model_euk.features[k][i].bias.copy_(model_initial.features[k][i].bias)\n",
    "                except:\n",
    "                    print (\"block {}, layer {} does not have bias\".format(k,i))\n",
    "        model_euk.classifier[0].weight.copy_(model_initial.classifier[0].weight)\n",
    "        model_euk.classifier[0].bias.copy_(model_initial.classifier[0].bias)\n",
    "    \n",
    "    for k in layers:\n",
    "        for param in model_euk.features[k].parameters():\n",
    "            param.requires_grad_(True)\n",
    "    \n",
    "elif args.model == \"resnet\":\n",
    "    with torch.no_grad():\n",
    "        for i in range(0,2):\n",
    "            try:\n",
    "                model_euk.layer4[i].bn1.weight.copy_(model_initial.layer4[i].bn1.weight)\n",
    "            except:\n",
    "                print (\"block 4, layer {} does not have weight\".format(i))\n",
    "            try:\n",
    "                model_euk.layer4[i].bn1.bias.copy_(model_initial.layer4[i].bn1.bias)\n",
    "            except:\n",
    "                print (\"block 4, layer {} does not have bias\".format(i))\n",
    "            try:\n",
    "                model_euk.layer4[i].conv1.weight.copy_(model_initial.layer4[i].conv1.weight)\n",
    "            except:\n",
    "                print (\"block 4, layer {} does not have weight\".format(i))\n",
    "            try:\n",
    "                model_euk.layer4[i].conv1.bias.copy_(model_initial.layer4[i].conv1.bias)\n",
    "            except:\n",
    "                print (\"block 4, layer {} does not have bias\".format(i))\n",
    "\n",
    "            try:\n",
    "                model_euk.layer4[i].bn2.weight.copy_(model_initial.layer4[i].bn2.weight)\n",
    "            except:\n",
    "                print (\"block 4, layer {} does not have weight\".format(i))\n",
    "            try:\n",
    "                model_euk.layer4[i].bn2.bias.copy_(model_initial.layer4[i].bn2.bias)\n",
    "            except:\n",
    "                print (\"block 4, layer {} does not have bias\".format(i))\n",
    "            try:\n",
    "                model_euk.layer4[i].conv2.weight.copy_(model_initial.layer4[i].conv2.weight)\n",
    "            except:\n",
    "                print (\"block 4, layer {} does not have weight\".format(i))\n",
    "            try:\n",
    "                model_euk.layer4[i].conv2.bias.copy_(model_initial.layer4[i].conv2.bias)\n",
    "            except:\n",
    "                print (\"block 4, layer {} does not have bias\".format(i))\n",
    "\n",
    "        model_euk.layer4[0].shortcut[0].weight.copy_(model_initial.layer4[0].shortcut[0].weight)\n",
    "        \n",
    "    for param in model_euk.layer4.parameters():\n",
    "        param.requires_grad_(True)\n",
    "\n",
    "else:\n",
    "    raise NotImplementedError\n",
    "\n",
    "fk_fientune(model_euk, retain_loader, epochs=args.euk_epochs, quiet=True, lr=args.euk_lr, args=args)\n",
    "euk_time = t2-t1"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Readouts"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_loader_full, valid_loader_full, test_loader_full = datasets.get_loaders(dataset, split=\"train\",confuse_mode=True,class_to_replace=class_to_forget, num_indexes_to_replace=num_to_forget, batch_size=args.batch_size, seed=seed, root=args.dataroot, augment=False, shuffle=True)\n",
    "marked_loader, _, _ = datasets.get_loaders(dataset, split=\"forget\", confuse_mode=True,class_to_replace=class_to_forget, num_indexes_to_replace=num_to_forget, only_mark=True, batch_size=1, seed=seed, root=args.dataroot, augment=False, shuffle=True)\n",
    "\n",
    "def replace_loader_dataset(data_loader, dataset, batch_size=args.batch_size, seed=1, shuffle=True):\n",
    "    manual_seed(seed)\n",
    "    loader_args = {'num_workers': 0, 'pin_memory': False}\n",
    "    def _init_fn(worker_id):\n",
    "        np.random.seed(int(seed))\n",
    "    return torch.utils.data.DataLoader(dataset, batch_size=batch_size,num_workers=0,pin_memory=True,shuffle=shuffle)\n",
    "    \n",
    "forget_dataset = copy.deepcopy(marked_loader.dataset)\n",
    "marked = forget_dataset.targets < 0\n",
    "forget_dataset.data = forget_dataset.data[marked]\n",
    "forget_dataset.targets = - forget_dataset.targets[marked] - 1\n",
    "forget_loader = replace_loader_dataset(train_loader_full, forget_dataset, batch_size=args.forget_bs, seed=seed, shuffle=True)\n",
    "\n",
    "retain_dataset = copy.deepcopy(marked_loader.dataset)\n",
    "marked = retain_dataset.targets >= 0\n",
    "retain_dataset.data = retain_dataset.data[marked]\n",
    "retain_dataset.targets = retain_dataset.targets[marked]\n",
    "retain_loader = replace_loader_dataset(train_loader_full, retain_dataset, batch_size=args.retain_bs, seed=seed, shuffle=True)\n",
    "\n",
    "assert(len(forget_dataset) + len(retain_dataset) == len(train_loader_full.dataset))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "try: readouts\n",
    "except: readouts = {}\n",
    "\n",
    "#_,_=activations_predictions(copy.deepcopy(model),forget_loader,'Original_Model_D_f')\n",
    "#thresh=log_dict['Original_Model_D_f_loss']+1e-5\n",
    "#print(thresh)\n",
    "readouts[\"a\"] = all_readouts(test_loader_full, retain_loader, forget_loader, copy.deepcopy(model),thresh=None,name='Original')\n",
    "readouts[\"b\"] = all_readouts(test_loader_full, retain_loader, forget_loader, copy.deepcopy(model0),thresh=None,name='Retrain')\n",
    "readouts[\"c\"] = all_readouts(test_loader_full, retain_loader, forget_loader, copy.deepcopy(model_ft),thresh=None,name='Finetune')\n",
    "readouts[\"d\"] = all_readouts(test_loader_full, retain_loader, forget_loader, copy.deepcopy(model_ng),thresh=None,name='NegGrad')\n",
    "readouts[\"e\"] = all_readouts(test_loader_full, retain_loader, forget_loader, copy.deepcopy(model_cfk),thresh=None,name='CF-k')\n",
    "readouts[\"f\"] = all_readouts(test_loader_full, retain_loader, forget_loader, copy.deepcopy(model_euk),thresh=None,name='EU-k')\n",
    "readouts[\"h\"] = all_readouts(test_loader_full, retain_loader, forget_loader, copy.deepcopy(model_s),thresh=None,name='SCRUB')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
