{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "%matplotlib inline\n",
    "import os\n",
    "os.environ['CUDA_DEVICE_ORDER']='PCI_BUS_ID'\n",
    "os.environ['CUDA_VISIBLE_DEVICES']='0'\n",
    "import variational\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "from matplotlib.ticker import FuncFormatter\n",
    "from itertools import cycle\n",
    "import os\n",
    "import time\n",
    "import math\n",
    "import pandas as pd\n",
    "from collections import OrderedDict\n",
    "from sklearn.linear_model import LogisticRegression\n",
    "from sklearn.metrics import confusion_matrix\n",
    "    \n",
    "import copy\n",
    "import torch.nn as nn\n",
    "from torch.autograd import Variable\n",
    "from typing import List\n",
    "import itertools\n",
    "from tqdm.autonotebook import tqdm\n",
    "from models import *\n",
    "import models\n",
    "from logger import *\n",
    "import wandb\n",
    "\n",
    "from thirdparty.repdistiller.helper.util import adjust_learning_rate as sgda_adjust_learning_rate\n",
    "from thirdparty.repdistiller.distiller_zoo import DistillKL, HintLoss, Attention, Similarity, Correlation, VIDLoss, RKDLoss\n",
    "from thirdparty.repdistiller.distiller_zoo import PKT, ABLoss, FactorTransfer, KDSVD, FSP, NSTLoss\n",
    "\n",
    "from thirdparty.repdistiller.helper.loops import train_distill, train_distill_hide, train_distill_linear, train_vanilla, train_negrad, train_bcu, train_bcu_distill, validate\n",
    "from thirdparty.repdistiller.helper.pretrain import init"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Metrics1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from utils import *\n",
    "def get_metrics(model,dataloader,criterion,samples_correctness=False,use_bn=False,delta_w=None,scrub_act=False):\n",
    "    activations=[]\n",
    "    predictions=[]\n",
    "    if use_bn:\n",
    "        model.train()\n",
    "        dataloader = torch.utils.data.DataLoader(retain_loader.dataset, batch_size=128, shuffle=True)\n",
    "        for i in range(10):\n",
    "            for batch_idx, (data, target) in enumerate(dataloader):\n",
    "                data, target = data.to(args.device), target.to(args.device)            \n",
    "                output = model(data)\n",
    "    dataloader = torch.utils.data.DataLoader(dataloader.dataset, batch_size=1, shuffle=False)\n",
    "    model.eval()\n",
    "    metrics = AverageMeter()\n",
    "    mult = 0.5 if args.lossfn=='mse' else 1\n",
    "    for batch_idx, (data, target) in enumerate(dataloader):\n",
    "        data, target = data.to(args.device), target.to(args.device)            \n",
    "        if args.lossfn=='mse':\n",
    "            target=(2*target-1)\n",
    "            target = target.type(torch.cuda.FloatTensor).unsqueeze(1)\n",
    "        if 'mnist' in args.dataset:\n",
    "            data=data.view(data.shape[0],-1)\n",
    "        output = model(data)\n",
    "        loss = mult*criterion(output, target)\n",
    "        if samples_correctness:\n",
    "            activations.append(torch.nn.functional.softmax(output,dim=1).cpu().detach().numpy().squeeze())\n",
    "            predictions.append(get_error(output,target))\n",
    "        metrics.update(n=data.size(0), loss=loss.item(), error=get_error(output, target))\n",
    "    if samples_correctness:\n",
    "        return metrics.avg,np.stack(activations),np.array(predictions)\n",
    "    else:\n",
    "        return metrics.avg\n",
    "\n",
    "def activations_predictions(model,dataloader,name,wandb=None):\n",
    "    criterion = torch.nn.CrossEntropyLoss()\n",
    "    metrics,activations,predictions=get_metrics(model,dataloader,criterion,True)\n",
    "    print(f\"{name} -> Loss:{np.round(metrics['loss'],3)}, Error:{metrics['error']}\")\n",
    "    log_dict[f\"{name}_loss\"]=metrics['loss']\n",
    "    log_dict[f\"{name}_error\"]=metrics['error']\n",
    "\n",
    "    if wandb is not None:\n",
    "        wandb.log({f\"{name}_error\": metrics['error']})\n",
    "    return activations,predictions\n",
    "\n",
    "def predictions_distance(l1,l2,name):\n",
    "    dist = np.sum(np.abs(l1-l2))\n",
    "    print(f\"Predictions Distance {name} -> {dist}\")\n",
    "    log_dict[f\"{name}_predictions\"]=dist\n",
    "\n",
    "def activations_distance(a1,a2,name):\n",
    "    dist = np.linalg.norm(a1-a2,ord=1,axis=1).mean()\n",
    "    print(f\"Activations Distance {name} -> {dist}\")\n",
    "    log_dict[f\"{name}_activations\"]=dist\n",
    "\n",
    "def interclass_confusion(model, dataloader, class_to_forget, name):\n",
    "    criterion = torch.nn.CrossEntropyLoss()\n",
    "    dataloader = torch.utils.data.DataLoader(dataloader.dataset, batch_size=128, shuffle=False)\n",
    "    model.eval()\n",
    "    reals=[]\n",
    "    predicts=[]\n",
    "    for batch_idx, (data, target) in enumerate(dataloader):\n",
    "        data, target = data.to(args.device), target.to(args.device) \n",
    "        if 'mnist' in args.dataset:\n",
    "            data=data.view(data.shape[0],-1)\n",
    "        output = model(data)\n",
    "        probs = torch.nn.functional.softmax(output, dim=1)\n",
    "        predict = np.argmax(probs.cpu().detach().numpy(),axis=1)\n",
    "        reals = reals + list(target.cpu().detach().numpy())\n",
    "        predicts = predicts + list(predict)\n",
    "    \n",
    "    classes = [0, 1, 2, 3, 4]\n",
    "    cm = confusion_matrix(reals, predicts, labels=classes)\n",
    "    counts = 0\n",
    "    for i in range(len(cm)):\n",
    "        if i != class_to_forget[0]:\n",
    "            counts += cm[class_to_forget[0]][i]\n",
    "        if i != class_to_forget[1]:\n",
    "            counts += cm[class_to_forget[1]][i]\n",
    "    \n",
    "    ic_err = counts / (np.sum(cm[class_to_forget[0]]) + np.sum(cm[class_to_forget[1]]))\n",
    "    fgt = cm[class_to_forget[0]][class_to_forget[1]] + cm[class_to_forget[1]][class_to_forget[0]]\n",
    "    #print (cm)\n",
    "    return ic_err, fgt"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Helper and Utils"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from utils import *\n",
    "def get_metrics(model,dataloader,criterion,samples_correctness=False,use_bn=False,delta_w=None,scrub_act=False):\n",
    "    activations=[]\n",
    "    predictions=[]\n",
    "    if use_bn:\n",
    "        model.train()\n",
    "        dataloader = torch.utils.data.DataLoader(retain_loader.dataset, batch_size=128, shuffle=True)\n",
    "        for i in range(10):\n",
    "            for batch_idx, (data, target) in enumerate(dataloader):\n",
    "                data, target = data.to(args.device), target.to(args.device)            \n",
    "                output = model(data)\n",
    "    dataloader = torch.utils.data.DataLoader(dataloader.dataset, batch_size=1, shuffle=False)\n",
    "    model.eval()\n",
    "    metrics = AverageMeter()\n",
    "    mult = 0.5 if args.lossfn=='mse' else 1\n",
    "    for batch_idx, (data, target) in enumerate(dataloader):\n",
    "        data, target = data.to(args.device), target.to(args.device)            \n",
    "        if args.lossfn=='mse':\n",
    "            target=(2*target-1)\n",
    "            target = target.type(torch.cuda.FloatTensor).unsqueeze(1)\n",
    "        if 'mnist' in args.dataset:\n",
    "            data=data.view(data.shape[0],-1)\n",
    "        output = model(data)\n",
    "        if scrub_act:\n",
    "            G = []\n",
    "            for cls in range(num_classes):\n",
    "                grads = torch.autograd.grad(output[0,cls],model.parameters(),retain_graph=True)\n",
    "                grads = torch.cat([g.view(-1) for g in grads])\n",
    "                G.append(grads)\n",
    "            grads = torch.autograd.grad(output_sf[0,cls],model_scrubf.parameters(),retain_graph=False)\n",
    "            G = torch.stack(G).pow(2)\n",
    "            delta_f = torch.matmul(G,delta_w)\n",
    "            output += delta_f.sqrt()*torch.empty_like(delta_f).normal_()\n",
    "\n",
    "        loss = mult*criterion(output, target)\n",
    "        if samples_correctness:\n",
    "            activations.append(torch.nn.functional.softmax(output,dim=1).cpu().detach().numpy().squeeze())\n",
    "            predictions.append(get_error(output,target))\n",
    "        metrics.update(n=data.size(0), loss=loss.item(), error=get_error(output, target))\n",
    "    if samples_correctness:\n",
    "        return metrics.avg,np.stack(activations),np.array(predictions)\n",
    "    else:\n",
    "        return metrics.avg\n",
    "\n",
    "def l2_penalty(model,model_init,weight_decay):\n",
    "    l2_loss = 0\n",
    "    for (k,p),(k_init,p_init) in zip(model.named_parameters(),model_init.named_parameters()):\n",
    "        if p.requires_grad:\n",
    "            l2_loss += (p-p_init).pow(2).sum()\n",
    "    l2_loss *= (weight_decay/2.)\n",
    "    return l2_loss\n",
    "\n",
    "def run_train_epoch(model: nn.Module, model_init, data_loader: torch.utils.data.DataLoader, \n",
    "                    loss_fn: nn.Module,\n",
    "                    optimizer: torch.optim.SGD, split: str, epoch: int, ignore_index=None,\n",
    "                    negative_gradient=False, negative_multiplier=-1, random_labels=False,\n",
    "                    quiet=False,delta_w=None,scrub_act=False):\n",
    "    model.eval()\n",
    "    metrics = AverageMeter()    \n",
    "    num_labels = data_loader.dataset.targets.max().item() + 1\n",
    "    \n",
    "    with torch.set_grad_enabled(split != 'test'):\n",
    "        for idx, batch in enumerate(tqdm(data_loader, leave=False)):\n",
    "            batch = [tensor.to(next(model.parameters()).device) for tensor in batch]\n",
    "            input, target = batch\n",
    "            output = model(input)\n",
    "            if split=='test' and scrub_act:\n",
    "                G = []\n",
    "                for cls in range(num_classes):\n",
    "                    grads = torch.autograd.grad(output[0,cls],model.parameters(),retain_graph=True)\n",
    "                    grads = torch.cat([g.view(-1) for g in grads])\n",
    "                    G.append(grads)\n",
    "                grads = torch.autograd.grad(output_sf[0,cls],model_scrubf.parameters(),retain_graph=False)\n",
    "                G = torch.stack(G).pow(2)\n",
    "                delta_f = torch.matmul(G,delta_w)\n",
    "                output += delta_f.sqrt()*torch.empty_like(delta_f).normal_()\n",
    "            loss = loss_fn(output, target) + l2_penalty(model,model_init,args.weight_decay)\n",
    "            metrics.update(n=input.size(0), loss=loss_fn(output,target).item(), error=get_error(output, target))\n",
    "            \n",
    "            if split != 'test':\n",
    "                model.zero_grad()\n",
    "                loss.backward()\n",
    "                optimizer.step()\n",
    "    if not quiet:\n",
    "        log_metrics(split, metrics, epoch)\n",
    "    return metrics.avg\n",
    "\n",
    "def run_neggrad_epoch(model: nn.Module, model_init, data_loader: torch.utils.data.DataLoader, \n",
    "                    forget_loader: torch.utils.data.DataLoader,\n",
    "                    alpha: float,\n",
    "                    loss_fn: nn.Module,\n",
    "                    optimizer: torch.optim.SGD, split: str, epoch: int, ignore_index=None,\n",
    "                    quiet=False):\n",
    "    model.eval()\n",
    "    metrics = AverageMeter()    \n",
    "    num_labels = data_loader.dataset.targets.max().item() + 1\n",
    "    \n",
    "    with torch.set_grad_enabled(split != 'test'):\n",
    "        for idx, (batch_retain,batch_forget) in enumerate(tqdm(zip(data_loader,cycle(forget_loader)), leave=False)):\n",
    "            batch_retain = [tensor.to(next(model.parameters()).device) for tensor in batch_retain]\n",
    "            batch_forget = [tensor.to(next(model.parameters()).device) for tensor in batch_forget]\n",
    "            input_r, target_r = batch_retain\n",
    "            input_f, target_f = batch_forget\n",
    "            output_r = model(input_r)\n",
    "            output_f = model(input_f)\n",
    "            loss = alpha*(loss_fn(output_r, target_r) + l2_penalty(model,model_init,args.weight_decay)) - (1-alpha)*loss_fn(output_f, target_f)\n",
    "            metrics.update(n=input_r.size(0), loss=loss_fn(output_r,target_r).item(), error=get_error(output_r, target_r))\n",
    "            if split != 'test':\n",
    "                model.zero_grad()\n",
    "                loss.backward()\n",
    "                optimizer.step()\n",
    "    if not quiet:\n",
    "        log_metrics(split, metrics, epoch)\n",
    "    return metrics.avg"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def test(model, data_loader):\n",
    "    loss_fn = nn.CrossEntropyLoss()\n",
    "    model_init=copy.deepcopy(model)\n",
    "    return run_train_epoch(model, model_init, data_loader, loss_fn, optimizer=None, split='test', epoch=epoch, ignore_index=None, quiet=True)\n",
    "\n",
    "def readout_retrain(model, data_loader, test_loader, lr=0.1, epochs=500, threshold=0.01, quiet=True):\n",
    "    torch.manual_seed(seed)\n",
    "    model = copy.deepcopy(model)\n",
    "    loss_fn = nn.CrossEntropyLoss()\n",
    "    optimizer = torch.optim.SGD(model.parameters(), lr=lr, weight_decay=0.0)\n",
    "    sampler = torch.utils.data.RandomSampler(data_loader.dataset, replacement=True, num_samples=500)\n",
    "    data_loader_small = torch.utils.data.DataLoader(data_loader.dataset, batch_size=data_loader.batch_size, sampler=sampler, num_workers=data_loader.num_workers)\n",
    "    metrics = []\n",
    "    model_init=copy.deepcopy(model)\n",
    "    for epoch in range(epochs):\n",
    "        metrics.append(run_train_epoch(model, model_init, test_loader, loss_fn, optimizer, split='test', epoch=epoch, ignore_index=None, quiet=quiet))\n",
    "        if metrics[-1]['loss'] <= threshold:\n",
    "            break\n",
    "        run_train_epoch(model, model_init, data_loader_small, loss_fn, optimizer, split='train', epoch=epoch, ignore_index=None, quiet=quiet)\n",
    "    return epoch, metrics\n",
    "\n",
    "def extract_retrain_time(metrics, threshold=0.1):\n",
    "    losses = np.array([m['loss'] for m in metrics])\n",
    "    return np.argmax(losses < threshold)\n",
    "\n",
    "def all_readouts(test_loader, retain_loader, forget_loader, model, wandb=None,thresh=0.1,name='method'):\n",
    "    #train_loader = torch.utils.data.DataLoader(train_loader_full.dataset, batch_size=args.batch_size, shuffle=True)\n",
    "    #retrain_time, _ = readout_retrain(model, train_loader, forget_loader, epochs=100, lr=0.1, threshold=thresh)\n",
    "    test_error = test(model, test_loader)['error']\n",
    "    forget_error = test(model, forget_loader)['error']\n",
    "    retain_error = test(model, retain_loader)['error']\n",
    "    ic_err_test, fgt_test = interclass_confusion(model, test_loader, class_to_forget, name)\n",
    "    ic_err_retain, fgt_retain = interclass_confusion(model, retain_loader, class_to_forget, name)\n",
    "    print(f\"{name} ->\"\n",
    "          f\"\\ttest: {test_error:.2%}\"\n",
    "          f\"\\tForget: {forget_error:.2%}\\tRetain: {retain_error:.2%}\"\n",
    "          f\"\\tIC-test: {ic_err_test}\\tfgt-test: {fgt_test}\"\n",
    "          f\"\\tIC-retain: {ic_err_retain}\\tfgt-retain: {fgt_retain}\")\n",
    "    #log_dict[f\"{name}_retrain_time\"]=retrain_time+1\n",
    "    if wandb is not None:\n",
    "        #wandb.log({f\"{name}_RLT\": retrain_time + 1})\n",
    "        wandb.log({f\"{name}_test_error\": test_error})\n",
    "        wandb.log({f\"{name}_retain_error\": retain_error})\n",
    "        wandb.log({f\"{name}_forget_error\": forget_error})\n",
    "        wandb.log({f\"{name}_IC_err_test\": ic_err_test})\n",
    "        wandb.log({f\"{name}_IC_err_retain\": ic_err_retain})\n",
    "        wandb.log({f\"{name}_fgt_test\": fgt_test})\n",
    "        wandb.log({f\"{name}_fgt_retain\": fgt_retain})\n",
    "    return(dict(test_error=test_error, forget_error=forget_error, retain_error=retain_error))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Forgetting helper functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def finetune(model: nn.Module, data_loader: torch.utils.data.DataLoader, lr=0.01, epochs=10, quiet=False):\n",
    "    loss_fn = nn.CrossEntropyLoss()\n",
    "    optimizer = torch.optim.SGD(model.parameters(), lr=lr, weight_decay=0.0)\n",
    "    model_init=copy.deepcopy(model)\n",
    "    for epoch in range(epochs):\n",
    "        run_train_epoch(model, model_init, data_loader, loss_fn, optimizer, split='train', epoch=epoch, ignore_index=None, quiet=quiet)\n",
    "        #train_vanilla(epoch, data_loader, model, loss_fn, optimizer, args)\n",
    "\n",
    "def negative_grad(model: nn.Module, data_loader: torch.utils.data.DataLoader, forget_loader: torch.utils.data.DataLoader, alpha: float, lr=0.01, epochs=10, quiet=False):\n",
    "    loss_fn = nn.CrossEntropyLoss()\n",
    "    optimizer = torch.optim.SGD(model.parameters(), lr=lr, weight_decay=0.0)\n",
    "    model_init=copy.deepcopy(model)\n",
    "    for epoch in range(epochs):\n",
    "        run_neggrad_epoch(model, model_init, data_loader, forget_loader, alpha, loss_fn, optimizer, split='train', epoch=epoch, ignore_index=None, quiet=quiet)\n",
    "        #train_negrad(epoch, data_loader, forget_loader, model, loss_fn, optimizer,  alpha)\n",
    "\n",
    "def fk_fientune(model: nn.Module, data_loader: torch.utils.data.DataLoader, args, lr=0.01, epochs=10, quiet=False):\n",
    "    loss_fn = nn.CrossEntropyLoss()\n",
    "    optimizer = torch.optim.SGD(model.parameters(), lr=lr, weight_decay=0.0)\n",
    "    model_init=copy.deepcopy(model)\n",
    "    for epoch in range(epochs):\n",
    "        sgda_adjust_learning_rate(epoch, args, optimizer)\n",
    "        run_train_epoch(model, model_init, data_loader, loss_fn, optimizer, split='train', epoch=epoch, ignore_index=None, quiet=quiet)\n",
    "        #train_negrad(epoch, data_loader, forget_loader, model, loss_fn, optimizer,  alpha)\n",
    "\n",
    "def pdb():\n",
    "    import pdb\n",
    "    pdb.set_trace\n",
    "    \n",
    "def parameter_count(model):\n",
    "    count=0\n",
    "    for p in model.parameters():\n",
    "        count+=np.prod(np.array(list(p.shape)))\n",
    "    print(f'Total Number of Parameters: {count}')\n",
    "    \n",
    "def vectorize_params(model):\n",
    "    param = []\n",
    "    for p in model.parameters():\n",
    "        param.append(p.data.view(-1).cpu().numpy())\n",
    "    return np.concatenate(param)\n",
    "\n",
    "def print_param_shape(model):\n",
    "    for k,p in model.named_parameters():\n",
    "        print(k,p.shape)\n",
    "\n",
    "def distance(model,model0):\n",
    "    distance=0\n",
    "    normalization=0\n",
    "    for (k, p), (k0, p0) in zip(model.named_parameters(), model0.named_parameters()):\n",
    "        space='  ' if 'bias' in k else ''\n",
    "        current_dist=(p.data0-p0.data0).pow(2).sum().item()\n",
    "        current_norm=p.data0.pow(2).sum().item()\n",
    "        distance+=current_dist\n",
    "        normalization+=current_norm\n",
    "    print(f'Distance: {np.sqrt(distance)}')\n",
    "    print(f'Normalized Distance: {1.0*np.sqrt(distance/normalization)}')\n",
    "    return 1.0*np.sqrt(distance/normalization)\n",
    "\n",
    "def ntk_init(resume,seed=1):\n",
    "    manual_seed(seed)\n",
    "    model_init = models.get_model(arch, num_classes=num_classes, filters_percentage=filters).to(args.device)\n",
    "    model_init.load_state_dict(torch.load(resume))\n",
    "    return model_init"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Pre-training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "#%run main.py --dataset lacuna100 --dataroot=data/lacuna100/ --model allcnn --filters 1.0 --lr 0.1 --lossfn ce --num-classes 100"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Original"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%run main.py --dataset small_lacuna5 --model allcnn --dataroot=data/lacuna10/ --filters 1.0 --lr 0.01 \\\n",
    "--resume checkpoints/lacuna100_allcnn_1_0_forget_None_lr_0_1_bs_128_ls_ce_wd_0_0005_seed_1_30.pt --disable-bn \\\n",
    "--weight-decay 0.1 --batch-size 128 --epochs 31 --seed 3 \\\n",
    "--split train --confuse-mode --forget-class 0,1 --num-to-forget 100"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Retrain"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%run main.py --dataset small_lacuna5 --model allcnn --dataroot=data/lacuna10/ --filters 1.0 --lr 0.01 \\\n",
    "--resume checkpoints/lacuna100_allcnn_1_0_forget_None_lr_0_1_bs_128_ls_ce_wd_0_0005_seed_1_30.pt --disable-bn \\\n",
    "--weight-decay 0.1 --batch-size 128 --epochs 31 --seed 3 \\\n",
    "--split forget --confuse-mode --forget-class 0,1 --num-to-forget 100"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Logs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "log_dict={}\n",
    "training_epochs=30"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "log_dict['epoch']=training_epochs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "parameter_count(copy.deepcopy(model))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Loads checkpoints"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import copy\n",
    "model0 = copy.deepcopy(model)\n",
    "model_initial = copy.deepcopy(model)\n",
    "\n",
    "arch = args.model \n",
    "filters=args.filters\n",
    "arch_filters = arch +'_'+ str(filters).replace('.','_')\n",
    "augment = False\n",
    "dataset = args.dataset\n",
    "class_to_forget = args.forget_class\n",
    "init_checkpoint = f\"checkpoints/{args.name}_init.pt\"\n",
    "num_classes=args.num_classes\n",
    "num_to_forget = args.num_to_forget\n",
    "num_total = len(train_loader.dataset)\n",
    "num_to_retain = num_total - num_to_forget\n",
    "seed = args.seed\n",
    "unfreeze_start = None\n",
    "\n",
    "learningrate=f\"lr_{str(args.lr).replace('.','_')}\"\n",
    "batch_size=f\"_bs_{str(args.batch_size)}\"\n",
    "lossfn=f\"_ls_{args.lossfn}\"\n",
    "wd=f\"_wd_{str(args.weight_decay).replace('.','_')}\"\n",
    "seed_name=f\"_seed_{args.seed}_\"\n",
    "\n",
    "num_tag = '' if num_to_forget is None else f'_num_{num_to_forget}'\n",
    "unfreeze_tag = '_' if unfreeze_start is None else f'_unfreeze_from_{unfreeze_start}_'\n",
    "augment_tag = '' if not augment else f'augment_'\n",
    "\n",
    "m_name = f'checkpoints/{dataset}_{arch_filters}_forget_None{unfreeze_tag}{augment_tag}{learningrate}{batch_size}{lossfn}{wd}{seed_name}{training_epochs}.pt'\n",
    "m0_name = f'checkpoints/{dataset}_{arch_filters}_forget_{class_to_forget}{num_tag}{unfreeze_tag}{augment_tag}{learningrate}{batch_size}{lossfn}{wd}{seed_name}{training_epochs}.pt'\n",
    "\n",
    "model.load_state_dict(torch.load(m_name))\n",
    "model0.load_state_dict(torch.load(m0_name))\n",
    "model_initial.load_state_dict(torch.load(init_checkpoint))\n",
    "\n",
    "teacher = copy.deepcopy(model)\n",
    "student = copy.deepcopy(model)\n",
    "\n",
    "model.cuda()\n",
    "model0.cuda()\n",
    "\n",
    "\n",
    "for p in model.parameters():\n",
    "    p.data0 = p.data.clone()\n",
    "for p in model0.parameters():\n",
    "    p.data0 = p.data.clone()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "log_dict['args']=args"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_init = ntk_init(init_checkpoint,args.seed)\n",
    "for p in model_init.parameters():\n",
    "    p.data0 = p.data.clone()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "log_dict['dist_Original_Original_init']= distance(model_init,model)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Data Loader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "args.retain_bs = 32\n",
    "args.forget_bs = 32"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_loader_full, valid_loader_full, test_loader_full = datasets.get_loaders(dataset, split=\"train\",confuse_mode=True,class_to_replace=class_to_forget, num_indexes_to_replace=num_to_forget, batch_size=args.batch_size, seed=seed, root=args.dataroot, augment=False, shuffle=True)\n",
    "marked_loader, _, _ = datasets.get_loaders(dataset, split=\"forget\", confuse_mode=True,class_to_replace=class_to_forget, num_indexes_to_replace=num_to_forget, only_mark=True, batch_size=1, seed=seed, root=args.dataroot, augment=False, shuffle=True)\n",
    "\n",
    "def replace_loader_dataset(data_loader, dataset, batch_size=args.batch_size, seed=1, shuffle=True):\n",
    "    manual_seed(seed)\n",
    "    loader_args = {'num_workers': 0, 'pin_memory': False}\n",
    "    def _init_fn(worker_id):\n",
    "        np.random.seed(int(seed))\n",
    "    return torch.utils.data.DataLoader(dataset, batch_size=batch_size,num_workers=0,pin_memory=True,shuffle=shuffle)\n",
    "    \n",
    "forget_dataset = copy.deepcopy(marked_loader.dataset)\n",
    "marked = forget_dataset.targets < 0\n",
    "forget_dataset.data = forget_dataset.data[marked]\n",
    "forget_dataset.targets = - forget_dataset.targets[marked] - 1\n",
    "forget_loader = replace_loader_dataset(train_loader_full, forget_dataset, batch_size=args.forget_bs, seed=seed, shuffle=True)\n",
    "\n",
    "retain_dataset = copy.deepcopy(marked_loader.dataset)\n",
    "marked = retain_dataset.targets >= 0\n",
    "retain_dataset.data = retain_dataset.data[marked]\n",
    "retain_dataset.targets = retain_dataset.targets[marked]\n",
    "retain_loader = replace_loader_dataset(train_loader_full, retain_dataset, batch_size=args.retain_bs, seed=seed, shuffle=True)\n",
    "\n",
    "assert(len(forget_dataset) + len(retain_dataset) == len(train_loader_full.dataset))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print (len(forget_loader.dataset))\n",
    "print (len(retain_loader.dataset))\n",
    "print (len(test_loader_full.dataset))\n",
    "print (len(train_loader_full.dataset))\n",
    "from collections import Counter\n",
    "print(dict(Counter(train_loader_full.dataset.targets)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## SCRUB Forgetting"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "args.optim = 'adam'\n",
    "args.gamma = 99\n",
    "args.alpha = 0.01\n",
    "args.beta = 0\n",
    "args.smoothing = 0.5\n",
    "args.msteps = 2\n",
    "args.clip = 0.5\n",
    "args.sstart = 10\n",
    "args.kd_T = 2\n",
    "args.distill = 'kd'\n",
    "\n",
    "args.sgda_epochs = 10\n",
    "args.sgda_learning_rate = 0.0005\n",
    "args.lr_decay_epochs = [5,6,9]\n",
    "args.lr_decay_rate = 0.1\n",
    "args.sgda_weight_decay = 0.1#5e-4\n",
    "args.sgda_momentum = 0.9"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_t = copy.deepcopy(teacher)\n",
    "model_s = copy.deepcopy(student)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#this is from https://github.com/ojus1/SmoothedGradientDescentAscent/blob/main/SGDA.py\n",
    "#For SGDA smoothing\n",
    "beta = 0.1\n",
    "def avg_fn(averaged_model_parameter, model_parameter, num_averaged): return (\n",
    "    1 - beta) * averaged_model_parameter + beta * model_parameter\n",
    "swa_model = torch.optim.swa_utils.AveragedModel(\n",
    "    model_s, avg_fn=avg_fn)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "module_list = nn.ModuleList([])\n",
    "module_list.append(model_s)\n",
    "trainable_list = nn.ModuleList([])\n",
    "trainable_list.append(model_s)\n",
    "\n",
    "criterion_cls = nn.CrossEntropyLoss()\n",
    "criterion_div = DistillKL(args.kd_T)\n",
    "criterion_kd = DistillKL(args.kd_T)\n",
    "\n",
    "\n",
    "criterion_list = nn.ModuleList([])\n",
    "criterion_list.append(criterion_cls)    # classification loss\n",
    "criterion_list.append(criterion_div)    # KL divergence loss, original knowledge distillation\n",
    "criterion_list.append(criterion_kd)     # other knowledge distillation loss\n",
    "\n",
    "# optimizer\n",
    "if args.optim == \"sgd\":\n",
    "    optimizer = optim.SGD(trainable_list.parameters(),\n",
    "                          lr=args.sgda_learning_rate,\n",
    "                          momentum=args.sgda_momentum,\n",
    "                          weight_decay=args.sgda_weight_decay)\n",
    "elif args.optim == \"adam\": \n",
    "    optimizer = optim.Adam(trainable_list.parameters(),\n",
    "                          lr=args.sgda_learning_rate,\n",
    "                          weight_decay=args.sgda_weight_decay)\n",
    "elif args.optim == \"rmsp\":\n",
    "    optimizer = optim.RMSprop(trainable_list.parameters(),\n",
    "                          lr=args.sgda_learning_rate,\n",
    "                          momentum=args.sgda_momentum,\n",
    "                          weight_decay=args.sgda_weight_decay)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "module_list.append(model_t)\n",
    "\n",
    "if torch.cuda.is_available():\n",
    "    module_list.cuda()\n",
    "    criterion_list.cuda()\n",
    "    import torch.backends.cudnn as cudnn\n",
    "    cudnn.benchmark = True\n",
    "    swa_model.cuda()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "t1 = time.time()\n",
    "acc_rs = []\n",
    "acc_fs = []\n",
    "acc_vs = []\n",
    "ic_rs = []\n",
    "ic_vs = []\n",
    "fgt_rs = []\n",
    "fgt_vs = []\n",
    "for epoch in range(1, args.sgda_epochs + 1):\n",
    "\n",
    "    lr = sgda_adjust_learning_rate(epoch, args, optimizer)\n",
    "    print(\"==> SCRUB unlearning ...\")\n",
    "    ic_r, fgt_r = interclass_confusion(model_s, retain_loader, class_to_forget, \"SCRUB\")\n",
    "    ic_v, fgt_v = interclass_confusion(model_s, valid_loader_full, class_to_forget, \"SCRUB\")\n",
    "    ic_rs.append(ic_r)\n",
    "    ic_vs.append(ic_v)\n",
    "    fgt_rs.append(fgt_r)\n",
    "    fgt_vs.append(fgt_v)\n",
    "\n",
    "    acc_r, acc5_r, loss_r = validate(retain_loader, model_s, criterion_cls, args, True)\n",
    "    acc_f, acc5_f, loss_f = validate(forget_loader, model_s, criterion_cls, args, True)\n",
    "    acc_v, acc5_v, loss_v = validate(valid_loader_full, model_s, criterion_cls, args, True)\n",
    "    acc_rs.append(100-acc_r.item())\n",
    "    acc_fs.append(100-acc_f.item())\n",
    "    acc_vs.append(100-acc_v.item())\n",
    "\n",
    "\n",
    "    maximize_loss = 0\n",
    "    if epoch <= args.msteps:\n",
    "        maximize_loss = train_distill(epoch, forget_loader, module_list, swa_model, criterion_list, optimizer, args, \"maximize\")\n",
    "    train_acc, train_loss = train_distill(epoch, retain_loader, module_list, swa_model, criterion_list, optimizer, args, \"minimize\",)\n",
    "    if epoch >= args.sstart:\n",
    "        swa_model.update_parameters(model_s)   \n",
    "    print (\"maximize loss: {:.2f}\\t minimize loss: {:.2f}\\t train_acc: {}\".format(maximize_loss, train_loss, train_acc))\n",
    "\n",
    "ic_r, fgt_r = interclass_confusion(model_s, retain_loader, class_to_forget, \"SCRUB\")\n",
    "ic_v, fgt_v = interclass_confusion(model_s, valid_loader_full, class_to_forget, \"SCRUB\")\n",
    "ic_rs.append(ic_r)\n",
    "ic_vs.append(ic_v)\n",
    "fgt_rs.append(fgt_r)\n",
    "fgt_vs.append(fgt_v)\n",
    "acc_r, acc5_r, loss_r = validate(retain_loader, model_s, criterion_cls, args, True)\n",
    "acc_f, acc5_f, loss_f = validate(forget_loader, model_s, criterion_cls, args, True)\n",
    "acc_tv, acc5_v, loss_v = validate(valid_loader_full, model_s, criterion_cls, args, True)\n",
    "acc_rs.append(100-acc_r.item())\n",
    "acc_fs.append(100-acc_f.item())\n",
    "acc_vs.append(100-acc_v.item())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from matplotlib import pyplot as plt\n",
    "fig, ax = plt.subplots(3,1,figsize=(9,15))\n",
    "indices = list(range(0,len(ic_rs)))\n",
    "ax[0].plot(indices, ic_rs, marker='*', color='red', alpha=1, label='ic-retain')\n",
    "ax[0].plot(indices, ic_vs, marker='o', color='blue', alpha=1, label='ic-val')\n",
    "ax[1].plot(indices, acc_rs, marker='*', color='green', alpha=1, label='acc-retain')\n",
    "ax[1].plot(indices, acc_vs, marker='o', color='brown', alpha=1, label='acc-valid')\n",
    "ax[2].plot(indices, acc_fs, marker='^', color='black', alpha=1, label='acc-forget')\n",
    "ax[0].legend(prop={'size': 14})\n",
    "ax[1].legend(prop={'size': 14})\n",
    "plt.tick_params(labelsize=12)\n",
    "ax[0].set_xlabel('epoch',size=14)\n",
    "ax[0].set_ylabel('error',size=14)\n",
    "ax[1].set_xlabel('epoch',size=14)\n",
    "ax[1].set_ylabel('error',size=14)\n",
    "fig.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# NTK based Forgetting"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### NTK Update"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def delta_w_utils(model_init,dataloader,name='complete'):\n",
    "    model_init.eval()\n",
    "    dataloader = torch.utils.data.DataLoader(dataloader.dataset, batch_size=1, shuffle=False)\n",
    "    G_list = []\n",
    "    f0_minus_y = []\n",
    "    for idx, batch in enumerate(dataloader):#(tqdm(dataloader,leave=False)):\n",
    "        batch = [tensor.to(next(model_init.parameters()).device) for tensor in batch]\n",
    "        input, target = batch\n",
    "        if 'mnist' in args.dataset:\n",
    "            input = input.view(input.shape[0],-1)\n",
    "        target = target.cpu().detach().numpy()\n",
    "        output = model_init(input)\n",
    "        G_sample=[]\n",
    "        for cls in range(num_classes):\n",
    "            grads = torch.autograd.grad(output[0,cls],model_init.parameters(),retain_graph=True)\n",
    "            grads = np.concatenate([g.view(-1).cpu().numpy() for g in grads])\n",
    "            G_sample.append(grads)\n",
    "            G_list.append(grads)\n",
    "        if args.lossfn=='mse':\n",
    "            p = output.cpu().detach().numpy().transpose()\n",
    "            #loss_hess = np.eye(len(p))\n",
    "            target = 2*target-1\n",
    "            f0_y_update = p-target\n",
    "        elif args.lossfn=='ce':\n",
    "            p = torch.nn.functional.softmax(output,dim=1).cpu().detach().numpy().transpose()\n",
    "            p[target]-=1\n",
    "            f0_y_update = copy.deepcopy(p)\n",
    "        f0_minus_y.append(f0_y_update)\n",
    "    return np.stack(G_list).transpose(),np.vstack(f0_minus_y)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Jacobians and Hessians"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sys import getsizeof"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ntk_time = 0\n",
    "model_init = ntk_init(init_checkpoint,args.seed)\n",
    "t1 = time.time()\n",
    "G_r,f0_minus_y_r = delta_w_utils(copy.deepcopy(model),retain_loader,'complete')\n",
    "t2 = time.time()\n",
    "ntk_time += t2-t1\n",
    "\n",
    "np.save('NTK_data/G_r.npy',G_r)\n",
    "np.save('NTK_data/f0_minus_y_r.npy',f0_minus_y_r)\n",
    "del G_r, f0_minus_y_r\n",
    "\n",
    "model_init = ntk_init(init_checkpoint,args.seed)\n",
    "t1 = time.time()\n",
    "G_f,f0_minus_y_f = delta_w_utils(copy.deepcopy(model),forget_loader,'retain') \n",
    "t2 = time.time()\n",
    "ntk_time += t2-t1\n",
    "\n",
    "np.save('NTK_data/G_f.npy',G_f)\n",
    "np.save('NTK_data/f0_minus_y_f.npy',f0_minus_y_f)\n",
    "del G_f, f0_minus_y_f"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "G_r = np.load('NTK_data/G_r.npy')\n",
    "G_f = np.load('NTK_data/G_f.npy')\n",
    "G = np.concatenate([G_r,G_f],axis=1)\n",
    "\n",
    "np.save('NTK_data/G.npy',G)\n",
    "del G, G_f, G_r\n",
    "\n",
    "f0_minus_y_r = np.load('NTK_data/f0_minus_y_r.npy')\n",
    "f0_minus_y_f = np.load('NTK_data/f0_minus_y_f.npy')\n",
    "f0_minus_y = np.concatenate([f0_minus_y_r,f0_minus_y_f])\n",
    "\n",
    "np.save('NTK_data/f0_minus_y.npy',f0_minus_y)\n",
    "del f0_minus_y, f0_minus_y_r, f0_minus_y_f"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This only requires access to the gradients and the initialization"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### w_lin(D)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "G = np.load('NTK_data/G.npy')\n",
    "t1 = time.time()\n",
    "theta = G.transpose().dot(G) + num_total*args.weight_decay*np.eye(G.shape[1])\n",
    "t2 = time.time()\n",
    "ntk_time += t2-t1\n",
    "del G\n",
    "\n",
    "t1 = time.time()\n",
    "theta_inv = np.linalg.inv(theta)\n",
    "t2 = time.time()\n",
    "ntk_time += t2-t1\n",
    "\n",
    "np.save('NTK_data/theta.npy',theta)\n",
    "del theta\n",
    "\n",
    "G = np.load('NTK_data/G.npy')\n",
    "f0_minus_y = np.load('NTK_data/f0_minus_y.npy')\n",
    "t1 = time.time()\n",
    "w_complete = -G.dot(theta_inv.dot(f0_minus_y))\n",
    "t2 = time.time()\n",
    "ntk_time += t2-t1\n",
    "\n",
    "np.save('NTK_data/theta_inv.npy',theta_inv)\n",
    "np.save('NTK_data/w_complete.npy',w_complete)\n",
    "\n",
    "\n",
    "del G, f0_minus_y, theta_inv, w_complete "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### w_lin(D_r)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "G_r = np.load('NTK_data/G_r.npy')\n",
    "t1 = time.time()\n",
    "theta_r = G_r.transpose().dot(G_r) + num_to_retain*args.weight_decay*np.eye(G_r.shape[1])\n",
    "t2 = time.time()\n",
    "ntk_time += t2-t1\n",
    "\n",
    "del G_r\n",
    "\n",
    "t1 = time.time()\n",
    "theta_r_inv = np.linalg.inv(theta_r)\n",
    "t2 = time.time()\n",
    "ntk_time += t2-t1\n",
    "\n",
    "\n",
    "np.save('NTK_data/theta_r.npy',theta_r)\n",
    "del theta_r\n",
    "\n",
    "G_r = np.load('NTK_data/G_r.npy')\n",
    "f0_minus_y_r = np.load('NTK_data/f0_minus_y_r.npy')\n",
    "t1 = time.time()\n",
    "w_retain = -G_r.dot(theta_r_inv.dot(f0_minus_y_r))\n",
    "t2 = time.time()\n",
    "ntk_time += t2-t1\n",
    "\n",
    "np.save('NTK_data/theta_r_inv.npy',theta_r_inv)\n",
    "np.save('NTK_data/w_retain.npy',w_retain)\n",
    "\n",
    "\n",
    "del G_r, f0_minus_y_r, theta_r_inv, w_retain "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Scrubbing Direction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "#### Scrubbing Direction\n",
    "w_complete = np.load('NTK_data/w_complete.npy')\n",
    "w_retain = np.load('NTK_data/w_retain.npy')\n",
    "delta_w = (w_retain-w_complete).squeeze()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "delta_w_copy = copy.deepcopy(delta_w)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Actual Change in Weights"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "delta_w_actual = vectorize_params(model0)-vectorize_params(model)\n",
    "\n",
    "print(f'Actual Norm-: {np.linalg.norm(delta_w_actual)}')\n",
    "print(f'Predtn Norm-: {np.linalg.norm(delta_w)}')\n",
    "scale_ratio = np.linalg.norm(delta_w_actual)/np.linalg.norm(delta_w)\n",
    "print('Actual Scale: {}'.format(scale_ratio))\n",
    "log_dict['actual_scale_ratio']=scale_ratio"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Trapezium Trick"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "m_pred_error = vectorize_params(model)-vectorize_params(model_init)-w_retain.squeeze()\n",
    "print(f\"Delta w -------: {np.linalg.norm(delta_w)}\")\n",
    "\n",
    "inner = np.inner(delta_w/np.linalg.norm(delta_w),m_pred_error/np.linalg.norm(m_pred_error))\n",
    "print(f\"Inner Product--: {inner}\")\n",
    "\n",
    "if inner<0:\n",
    "    angle = np.arccos(inner)-np.pi/2\n",
    "    print(f\"Angle----------:  {angle}\")\n",
    "\n",
    "    predicted_norm=np.linalg.norm(delta_w) + 2*np.sin(angle)*np.linalg.norm(m_pred_error)\n",
    "    print(f\"Pred Act Norm--:  {predicted_norm}\")\n",
    "else:\n",
    "    angle = np.arccos(inner) \n",
    "    print(f\"Angle----------:  {angle}\")\n",
    "\n",
    "    predicted_norm=np.linalg.norm(delta_w) + 2*np.cos(angle)*np.linalg.norm(m_pred_error)\n",
    "    print(f\"Pred Act Norm--:  {predicted_norm}\")\n",
    "\n",
    "predicted_scale=predicted_norm/np.linalg.norm(delta_w)\n",
    "predicted_scale\n",
    "print(f\"Predicted Scale:  {predicted_scale}\")\n",
    "log_dict['predicted_scale_ratio']=predicted_scale"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Normalized Inner Product between Prediction and Actual Scrubbing Update"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def NIP(v1,v2):\n",
    "    nip = (np.inner(v1/np.linalg.norm(v1),v2/np.linalg.norm(v2)))\n",
    "    print(nip)\n",
    "    return nip"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "nip=NIP(delta_w_actual,delta_w)\n",
    "log_dict['nip']=nip"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Reshape delta_w"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_delta_w_dict(delta_w,model):\n",
    "    # Give normalized delta_w\n",
    "    delta_w_dict = OrderedDict()\n",
    "    params_visited = 0\n",
    "    for k,p in model.named_parameters():\n",
    "        num_params = np.prod(list(p.shape))\n",
    "        update_params = delta_w[params_visited:params_visited+num_params]\n",
    "        delta_w_dict[k] = torch.Tensor(update_params).view_as(p)\n",
    "        params_visited+=num_params\n",
    "    return delta_w_dict"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# SCRUB using NTK"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "scale=predicted_scale\n",
    "direction = get_delta_w_dict(delta_w,model)\n",
    "\n",
    "model_scrub = copy.deepcopy(model)\n",
    "for k,p in model_scrub.named_parameters():\n",
    "    p.data += (direction[k]*scale).to(args.device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Fisher Forgetting"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_scrubf = copy.deepcopy(model_scrub)\n",
    "modelf = copy.deepcopy(model)\n",
    "modelf0 = copy.deepcopy(model0)\n",
    "\n",
    "for p in itertools.chain(modelf.parameters(), modelf0.parameters(), model_scrubf.parameters()):\n",
    "    p.data0 = copy.deepcopy(p.data.clone())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def hessian(dataset, model):\n",
    "    model.eval()\n",
    "    train_loader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False)\n",
    "    loss_fn = nn.CrossEntropyLoss()\n",
    "\n",
    "    for p in model.parameters():\n",
    "        p.grad_acc = 0\n",
    "        p.grad2_acc = 0\n",
    "    \n",
    "    for data, orig_target in tqdm(train_loader):\n",
    "        data, orig_target = data.to(args.device), orig_target.to(args.device)\n",
    "        output = model(data)\n",
    "        prob = F.softmax(output, dim=-1).data\n",
    "\n",
    "        for y in range(output.shape[1]):\n",
    "            target = torch.empty_like(orig_target).fill_(y)\n",
    "            loss = loss_fn(output, target)\n",
    "            model.zero_grad()\n",
    "            loss.backward(retain_graph=True)\n",
    "            for p in model.parameters():\n",
    "                if p.requires_grad:\n",
    "                    p.grad_acc += (orig_target == target).float() * p.grad.data\n",
    "                    p.grad2_acc += prob[:, y] * p.grad.data.pow(2)\n",
    "    for p in model.parameters():\n",
    "        p.grad_acc /= len(train_loader)\n",
    "        p.grad2_acc /= len(train_loader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "hessian(retain_loader.dataset, model_scrubf)\n",
    "t1 = time.time()\n",
    "hessian(retain_loader.dataset, modelf)\n",
    "t2 = time.time()\n",
    "hessian(retain_loader.dataset, modelf0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_mean_var(p, is_base_dist=False, alpha=3e-6):\n",
    "    var = copy.deepcopy(1./(p.grad2_acc+1e-8))\n",
    "    var = var.clamp(max=1e3)\n",
    "    if p.size(0) == num_classes:\n",
    "        var = var.clamp(max=1e2)\n",
    "    var = alpha * var\n",
    "    \n",
    "    if p.ndim > 1:\n",
    "        var = var.mean(dim=1, keepdim=True).expand_as(p).clone()\n",
    "    if not is_base_dist:\n",
    "        mu = copy.deepcopy(p.data0.clone())\n",
    "    else:\n",
    "        mu = copy.deepcopy(p.data0.clone())\n",
    "    if p.size(0) == num_classes and num_to_forget is None:\n",
    "        mu[class_to_forget] = 0\n",
    "        var[class_to_forget] = 0.0001\n",
    "    if p.size(0) == num_classes:\n",
    "        # Last layer\n",
    "        var *= 10\n",
    "    elif p.ndim == 1:\n",
    "        # BatchNorm\n",
    "        var *= 10\n",
    "#         var*=1\n",
    "    return mu, var\n",
    "\n",
    "def kl_divergence_fisher(mu0, var0, mu1, var1):\n",
    "    return ((mu1 - mu0).pow(2)/var0 + var1/var0 - torch.log(var1/var0) - 1).sum()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Fisher Noise in Weights"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Computes the amount of information not forgotten at all layers using the given alpha\n",
    "alpha = 1e-7\n",
    "total_kl = 0\n",
    "torch.manual_seed(seed)\n",
    "for (k, p), (k0, p0) in zip(modelf.named_parameters(), modelf0.named_parameters()):\n",
    "    mu0, var0 = get_mean_var(p, False, alpha=alpha)\n",
    "    mu1, var1 = get_mean_var(p0, True, alpha=alpha)\n",
    "    kl = kl_divergence_fisher(mu0, var0, mu1, var1).item()\n",
    "    total_kl += kl\n",
    "    print(k, f'{kl:.1f}')\n",
    "print(\"Total:\", total_kl)\n",
    "log_dict['fisher_info']=total_kl"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fisher_dir = []\n",
    "alpha = 1e-6\n",
    "torch.manual_seed(seed)\n",
    "t1 = time.time()\n",
    "for i, p in enumerate(modelf.parameters()):\n",
    "    mu, var = get_mean_var(p, False, alpha=alpha)\n",
    "    p.data = mu + var.sqrt() * torch.empty_like(p.data0).normal_()\n",
    "    fisher_dir.append(var.sqrt().view(-1).cpu().detach().numpy())\n",
    "t2 = time.time()\n",
    "\n",
    "for i, p in enumerate(modelf0.parameters()):\n",
    "    mu, var = get_mean_var(p, False, alpha=alpha)\n",
    "    p.data = mu + var.sqrt() * torch.empty_like(p.data0).normal_()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(test(modelf, retain_loader))\n",
    "print(test(modelf, forget_loader))\n",
    "print(test(modelf, valid_loader_full))\n",
    "print(test(modelf, test_loader_full))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Finetune"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_ft = copy.deepcopy(model)\n",
    "retain_loader = replace_loader_dataset(train_loader_full,retain_dataset, seed=seed, batch_size=args.batch_size, shuffle=True)    \n",
    "finetune(model_ft, retain_loader, epochs=10, quiet=True, lr=0.01)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## NegGrad"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "args.ng_alpha = 0.95\n",
    "args.ng_epochs = 10\n",
    "args.ng_lr = 0.01\n",
    "model_ng = copy.deepcopy(model)    \n",
    "negative_grad(model_ng, retain_loader, forget_loader, alpha=args.ng_alpha, epochs=args.ng_epochs, quiet=True, lr=args.ng_lr)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Catastrophic Forgetting k layers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "args.lr_decay_epochs = [10,15,20]\n",
    "args.cfk_lr = 0.01\n",
    "args.cfk_epochs = 10\n",
    "\n",
    "model_cfk = copy.deepcopy(model)\n",
    "\n",
    "for param in model_cfk.parameters():\n",
    "    param.requires_grad_(False)\n",
    "\n",
    "if args.model == 'allcnn':\n",
    "    layers = [9]\n",
    "    for k in layers:\n",
    "        for param in model_cfk.features[k].parameters():\n",
    "            param.requires_grad_(True)\n",
    "    \n",
    "elif args.model == \"resnet\":\n",
    "    for param in model_cfk.layer4.parameters():\n",
    "        param.requires_grad_(True)\n",
    "\n",
    "else:\n",
    "    raise NotImplementedError\n",
    "\n",
    "\n",
    "fk_fientune(model_cfk, retain_loader, args=args, epochs=args.cfk_epochs, quiet=True, lr=args.cfk_lr)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Exact Unlearning k layers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\" The last block and classifier of resnet-18\n",
    "(layer4): Sequential(\n",
    "    (0): _ResBlock(\n",
    "      (bn1): BatchNorm2d(102, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
    "      (conv1): Conv2d(102, 204, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))\n",
    "      (bn2): BatchNorm2d(204, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
    "      (conv2): Conv2d(204, 204, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
    "      (shortcut): Sequential(\n",
    "        (0): Conv2d(102, 204, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
    "      )\n",
    "    )\n",
    "    (1): _ResBlock(\n",
    "      (bn1): BatchNorm2d(204, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
    "      (conv1): Conv2d(204, 204, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
    "      (bn2): BatchNorm2d(204, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
    "      (conv2): Conv2d(204, 204, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
    "    )\n",
    "  )\n",
    "(linear): Linear(in_features=204, out_features=5, bias=True)\n",
    "\"\"\"\n",
    "\n",
    "\"\"\" The last block and classifier of allcnn\n",
    "AllCNN(\n",
    "  (features): Sequential(\n",
    "    ...\n",
    "    (9): Conv(\n",
    "      (0): Conv2d(192, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
    "      (1): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
    "      (2): ReLU()\n",
    "    )\n",
    "    (10): AvgPool2d(kernel_size=8, stride=8, padding=0)\n",
    "    (11): Flatten()\n",
    "  )\n",
    "  (classifier): Sequential(\n",
    "    (0): Linear(in_features=192, out_features=5, bias=True)\n",
    "  )\n",
    ")\n",
    "\"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "args.lr_decay_epochs = [10,15,20]\n",
    "args.euk_lr = 0.01\n",
    "args.euk_epochs = training_epochs\n",
    "model_euk = copy.deepcopy(model)\n",
    "\n",
    "for param in model_euk.parameters():\n",
    "    param.requires_grad_(False)\n",
    "\n",
    "if args.model == 'allcnn':\n",
    "    with torch.no_grad():\n",
    "        for k in layers:\n",
    "            for i in range(0,3):\n",
    "                try:\n",
    "                    model_euk.features[k][i].weight.copy_(model_initial.features[k][i].weight)\n",
    "                except:\n",
    "                    print (\"block {}, layer {} does not have weights\".format(k,i))\n",
    "                try:\n",
    "                    model_euk.features[k][i].bias.copy_(model_initial.features[k][i].bias)\n",
    "                except:\n",
    "                    print (\"block {}, layer {} does not have bias\".format(k,i))\n",
    "        model_euk.classifier[0].weight.copy_(model_initial.classifier[0].weight)\n",
    "        model_euk.classifier[0].bias.copy_(model_initial.classifier[0].bias)\n",
    "    \n",
    "    for k in layers:\n",
    "        for param in model_euk.features[k].parameters():\n",
    "            param.requires_grad_(True)\n",
    "    \n",
    "elif args.model == \"resnet\":\n",
    "    with torch.no_grad():\n",
    "        for i in range(0,2):\n",
    "            try:\n",
    "                model_euk.layer4[i].bn1.weight.copy_(model_initial.layer4[i].bn1.weight)\n",
    "            except:\n",
    "                print (\"block 4, layer {} does not have weight\".format(i))\n",
    "            try:\n",
    "                model_euk.layer4[i].bn1.bias.copy_(model_initial.layer4[i].bn1.bias)\n",
    "            except:\n",
    "                print (\"block 4, layer {} does not have bias\".format(i))\n",
    "            try:\n",
    "                model_euk.layer4[i].conv1.weight.copy_(model_initial.layer4[i].conv1.weight)\n",
    "            except:\n",
    "                print (\"block 4, layer {} does not have weight\".format(i))\n",
    "            try:\n",
    "                model_euk.layer4[i].conv1.bias.copy_(model_initial.layer4[i].conv1.bias)\n",
    "            except:\n",
    "                print (\"block 4, layer {} does not have bias\".format(i))\n",
    "\n",
    "            try:\n",
    "                model_euk.layer4[i].bn2.weight.copy_(model_initial.layer4[i].bn2.weight)\n",
    "            except:\n",
    "                print (\"block 4, layer {} does not have weight\".format(i))\n",
    "            try:\n",
    "                model_euk.layer4[i].bn2.bias.copy_(model_initial.layer4[i].bn2.bias)\n",
    "            except:\n",
    "                print (\"block 4, layer {} does not have bias\".format(i))\n",
    "            try:\n",
    "                model_euk.layer4[i].conv2.weight.copy_(model_initial.layer4[i].conv2.weight)\n",
    "            except:\n",
    "                print (\"block 4, layer {} does not have weight\".format(i))\n",
    "            try:\n",
    "                model_euk.layer4[i].conv2.bias.copy_(model_initial.layer4[i].conv2.bias)\n",
    "            except:\n",
    "                print (\"block 4, layer {} does not have bias\".format(i))\n",
    "\n",
    "        model_euk.layer4[0].shortcut[0].weight.copy_(model_initial.layer4[0].shortcut[0].weight)\n",
    "        \n",
    "    for param in model_euk.layer4.parameters():\n",
    "        param.requires_grad_(True)\n",
    "\n",
    "else:\n",
    "    raise NotImplementedError\n",
    "\n",
    "fk_fientune(model_euk, retain_loader, epochs=args.euk_epochs, quiet=True, lr=args.euk_lr, args=args)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Readouts"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_loader_full, valid_loader_full, test_loader_full = datasets.get_loaders(dataset, split=\"train\",confuse_mode=True,class_to_replace=class_to_forget, num_indexes_to_replace=num_to_forget, batch_size=args.batch_size, seed=seed, root=args.dataroot, augment=False, shuffle=True)\n",
    "marked_loader, _, _ = datasets.get_loaders(dataset, split=\"forget\", confuse_mode=True,class_to_replace=class_to_forget, num_indexes_to_replace=num_to_forget, only_mark=True, batch_size=1, seed=seed, root=args.dataroot, augment=False, shuffle=True)\n",
    "\n",
    "def replace_loader_dataset(data_loader, dataset, batch_size=args.batch_size, seed=1, shuffle=True):\n",
    "    manual_seed(seed)\n",
    "    loader_args = {'num_workers': 0, 'pin_memory': False}\n",
    "    def _init_fn(worker_id):\n",
    "        np.random.seed(int(seed))\n",
    "    return torch.utils.data.DataLoader(dataset, batch_size=batch_size,num_workers=0,pin_memory=True,shuffle=shuffle)\n",
    "    \n",
    "forget_dataset = copy.deepcopy(marked_loader.dataset)\n",
    "marked = forget_dataset.targets < 0\n",
    "forget_dataset.data = forget_dataset.data[marked]\n",
    "forget_dataset.targets = - forget_dataset.targets[marked] - 1\n",
    "forget_loader = replace_loader_dataset(train_loader_full, forget_dataset, batch_size=args.forget_bs, seed=seed, shuffle=True)\n",
    "\n",
    "retain_dataset = copy.deepcopy(marked_loader.dataset)\n",
    "marked = retain_dataset.targets >= 0\n",
    "retain_dataset.data = retain_dataset.data[marked]\n",
    "retain_dataset.targets = retain_dataset.targets[marked]\n",
    "retain_loader = replace_loader_dataset(train_loader_full, retain_dataset, batch_size=args.retain_bs, seed=seed, shuffle=True)\n",
    "\n",
    "assert(len(forget_dataset) + len(retain_dataset) == len(train_loader_full.dataset))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "try: readouts\n",
    "except: readouts = {}\n",
    "\n",
    "#_,_=activations_predictions(copy.deepcopy(model),forget_loader,'Original_Model_D_f')\n",
    "#thresh=log_dict['Original_Model_D_f_loss']+1e-5\n",
    "#print(thresh)\n",
    "readouts[\"a\"] = all_readouts(test_loader_full, retain_loader, forget_loader, copy.deepcopy(model), wandb=None,thresh=None,name='Original')\n",
    "readouts[\"b\"] = all_readouts(test_loader_full, retain_loader, forget_loader, copy.deepcopy(model0), wandb=None,thresh=None,name='Retrain')\n",
    "readouts[\"c\"] = all_readouts(test_loader_full, retain_loader, forget_loader, copy.deepcopy(model_ft), wandb=None,thresh=None,name='Finetune')\n",
    "readouts[\"i\"] = all_readouts(test_loader_full, retain_loader, forget_loader, copy.deepcopy(model_s), wandb=None,thresh=None,name='SCRUB')\n",
    "readouts[\"d\"] = all_readouts(test_loader_full, retain_loader, forget_loader, copy.deepcopy(model_ng), wandb=None,thresh=None,name='NegGrad')\n",
    "readouts[\"e\"] = all_readouts(test_loader_full, retain_loader, forget_loader, copy.deepcopy(model_cfk), wandb=None,thresh=None,name='CF-k')\n",
    "readouts[\"f\"] = all_readouts(test_loader_full, retain_loader, forget_loader, copy.deepcopy(model_euk), wandb=None,thresh=None,name='EU-k')\n",
    "readouts[\"g\"] = all_readouts(test_loader_full, retain_loader, forget_loader, copy.deepcopy(modelf), wandb=None,thresh=None,name='Fisher')\n",
    "readouts[\"h\"] = all_readouts(test_loader_full, retain_loader, forget_loader, copy.deepcopy(model_scrub), wandb=None,thresh=None,name='NTK')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
