{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "819cbdb3",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torchvision\n",
    "import sys\n",
    "sys.path.append('..')\n",
    "import failure_directions\n",
    "import numpy as np\n",
    "import torchvision.transforms as transforms\n",
    "from torch.cuda.amp import autocast\n",
    "from tqdm import tqdm\n",
    "import torch.nn as nn\n",
    "import matplotlib.pyplot as plt\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fb50d1b2",
   "metadata": {},
   "source": [
    "# Training the base model\n",
    "\n",
    "First, we're going to train a CIFAR-10 model. We just put an example setup here (make sure to evaluate these cells, but feel free to skip to the next section for our method!)\n",
    "\n",
    "### Setting up the dataset\n",
    "Ok, first let's set up the dataset. We're going to use CIFAR-10, so let's download it below.\n",
    "\n",
    "We are going to train on 20% of the training dataset, and will let another 20% be validation. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bee9b249",
   "metadata": {},
   "outputs": [],
   "source": [
    "hparams = {\n",
    "    'mean': [125.307, 122.961, 113.8575], 'std': [51.5865, 50.847, 51.255],\n",
    "    'num_classes': 10, 'arch': 'resnet18', 'arch_type': 'cifar_resnet', 'batch_size': 512,\n",
    "    'training': {\n",
    "        'epochs': 35, 'lr': 0.5,\n",
    "        'optimizer': {'momentum': 0.9, 'weight_decay': 5.0E-4},\n",
    "        'lr_scheduler':{'type': 'cyclic', 'lr_peak_epoch': 5}\n",
    "    }\n",
    "}\n",
    "\n",
    "fill_color = tuple(map(int, hparams['mean']))\n",
    "\n",
    "base_transform = transforms.Compose([\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize(mean=np.array(hparams['mean'])/255, std=np.array(hparams['std'])/255)])\n",
    "\n",
    "train_transform = transforms.Compose([\n",
    "    transforms.RandomCrop(32, padding=4),\n",
    "    transforms.RandomHorizontalFlip(),\n",
    "    base_transform\n",
    "])\n",
    "\n",
    "# For visualization\n",
    "INV_NORM = transforms.Compose([ transforms.Normalize(mean = [ 0., 0., 0. ],\n",
    "                                                     std = [255/x for x in hparams['std']]),\n",
    "                                transforms.Normalize(mean = [-x /255 for x in hparams['mean']],\n",
    "                                                     std = [ 1., 1., 1. ])])\n",
    "TOIMAGE = transforms.Compose([INV_NORM, transforms.ToPILImage()])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9e757b9f",
   "metadata": {},
   "outputs": [],
   "source": [
    "ds_root = \"CIFAR_DS_PATH\"\n",
    "orig_train_ds = torchvision.datasets.CIFAR10(ds_root, train=True, transform=base_transform)\n",
    "aug_train_ds = torchvision.datasets.CIFAR10(ds_root, train=True, transform=train_transform)\n",
    "test_ds = torchvision.datasets.CIFAR10(ds_root, train=False, transform=base_transform)\n",
    "\n",
    "all_train_indices = torch.arange(len(orig_train_ds))\n",
    "val_indices = all_train_indices[::5]\n",
    "train_indices = all_train_indices[1::5]\n",
    "\n",
    "train_ds = torch.utils.data.Subset(aug_train_ds, train_indices)\n",
    "val_ds = torch.utils.data.Subset(orig_train_ds, val_indices)\n",
    "no_aug_train_ds = torch.utils.data.Subset(orig_train_ds, train_indices)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4ff8924a",
   "metadata": {},
   "outputs": [],
   "source": [
    "bsz = hparams['batch_size']\n",
    "train_loader = torch.utils.data.DataLoader(train_ds, batch_size=bsz, shuffle=True, drop_last=True)\n",
    "test_loader = torch.utils.data.DataLoader(test_ds, batch_size=bsz, shuffle=False, drop_last=False)\n",
    "val_loader = torch.utils.data.DataLoader(val_ds, batch_size=bsz, shuffle=False, drop_last=False)\n",
    "no_aug_train_loader = torch.utils.data.DataLoader(no_aug_train_ds, batch_size=bsz, shuffle=False, drop_last=False)\n",
    "\n",
    "loaders = {'train': no_aug_train_loader, 'test': test_loader, 'val': val_loader} "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3703d82b",
   "metadata": {},
   "source": [
    "### Training a Model\n",
    "Ok, let's train a model. We'll train a ResNet-18. "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9621ea2d",
   "metadata": {},
   "source": [
    "we have a pre-trained model [here](https://www.dropbox.com/s/1gtlmbe4k2dzh9w/example_ckpt.pt?dl=0) if you don't want to wait for training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f3d47a3f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# if you've already trained the model just load it here\n",
    "build_fn = failure_directions.model_utils.BUILD_FUNCTIONS[hparams['arch_type']]\n",
    "path = \"example_ckpt.pt\"\n",
    "model = failure_directions.model_utils.load_model(path, build_fn)\n",
    "model = model.cuda()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d8f40151",
   "metadata": {},
   "source": [
    "otherwise, train it below"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "17b6ee52",
   "metadata": {},
   "outputs": [],
   "source": [
    "build_fn = failure_directions.model_utils.BUILD_FUNCTIONS[hparams['arch_type']]\n",
    "model = build_fn(hparams['arch'], hparams['num_classes'])\n",
    "model = model.cuda()\n",
    "\n",
    "training_args=hparams['training']\n",
    "training_args['iters_per_epoch'] = len(train_loader)\n",
    "trainer = failure_directions.LightWeightTrainer(training_args=hparams['training'],\n",
    "                                                exp_name='temp', enable_logging=True,\n",
    "                                                bce=False, set_device=True)\n",
    "trainer.fit(model, train_loader, val_loader)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0c8cf12b",
   "metadata": {},
   "source": [
    "now let's evaluate the model and get the predictions!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2f80e4ea",
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluate_model(loader):\n",
    "    with torch.no_grad():\n",
    "        with autocast():\n",
    "            gts, preds, confs = [], [], []\n",
    "            for x, y in tqdm(loader):\n",
    "                x = x.cuda()\n",
    "                logits = model(x)\n",
    "                gts.append(y.cpu())\n",
    "                preds.append(logits.argmax(-1).cpu())\n",
    "                softmax_logits = nn.Softmax(dim=-1)(logits)\n",
    "                confs.append(softmax_logits[torch.arange(logits.shape[0]), y].cpu())\n",
    "    gts = torch.cat(gts)\n",
    "    preds = torch.cat(preds)\n",
    "    confs = torch.cat(confs)\n",
    "    return gts, preds, confs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "82713fab",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = model.eval()\n",
    "run_dict = {}\n",
    "for split, loader in loaders.items():\n",
    "    run_dict[split] = evaluate_model(loader) # gts, preds, confs"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b05ecb42",
   "metadata": {},
   "source": [
    "# Bringing in CLIP"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3abe1e85",
   "metadata": {},
   "outputs": [],
   "source": [
    "clip_processor = failure_directions.CLIPProcessor(ds_mean=hparams['mean'], ds_std=hparams['std'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "de747cf1",
   "metadata": {},
   "outputs": [],
   "source": [
    "clip_features = {}\n",
    "for split, loader in loaders.items():\n",
    "    clip_features[split] = clip_processor.evaluate_clip_images(loader)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e30e1d02",
   "metadata": {},
   "source": [
    "# Fitting the SVM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8294c050",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "svm_fitter = failure_directions.SVMFitter()\n",
    "svm_fitter.set_preprocess(clip_features['train'])\n",
    "val_gts, val_preds, _ = run_dict['val']\n",
    "cv_scores = svm_fitter.fit(preds=val_preds, ys=val_gts, latents=clip_features['val'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "264dfc0d",
   "metadata": {},
   "outputs": [],
   "source": [
    "svm_predictions = {}\n",
    "svm_decision_values = {}\n",
    "for split, loader in loaders.items():\n",
    "    gts_, _, _ = run_dict[split]\n",
    "    mask, dv = svm_fitter.predict(ys=gts_, latents=clip_features[split], compute_metrics=False)\n",
    "    svm_predictions[split] = mask\n",
    "    svm_decision_values[split] = dv"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "eccb5bb4",
   "metadata": {},
   "source": [
    "## Captioning our failure modes!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "793b3fdf",
   "metadata": {},
   "outputs": [],
   "source": [
    "from failure_directions.src.clip_utils import get_caption_set"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4fe8b88e",
   "metadata": {},
   "outputs": [],
   "source": [
    "captions = failure_directions.get_caption_set('CIFAR')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "76ad1b32",
   "metadata": {},
   "outputs": [],
   "source": [
    "selected_captions = []\n",
    "for target_c in range(10):\n",
    "    target_c_name = test_ds.classes[target_c]\n",
    "    caption_set = captions[target_c_name]['all']\n",
    "    reference = captions['reference'][target_c]\n",
    "    print(target_c_name, reference)\n",
    "    decisions, _ = clip_processor.get_caption_scores(captions=caption_set,\n",
    "                                                     reference_caption=reference,\n",
    "                                                     svm_fitter=svm_fitter,\n",
    "                                                     target_c=target_c)\n",
    "    selected_captions.append((\n",
    "        caption_set[np.argmin(decisions)],\n",
    "        caption_set[np.argmax(decisions)], decisions))\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7f57ce1e",
   "metadata": {},
   "source": [
    "## Visualize!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7599fda5",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def visualize_images(ds, ds_indices, ds_values, K=10, title=\"\"):\n",
    "    fig, ax = plt.subplots(1, K, figsize=(K*2, 3))\n",
    "    for i in range(K):\n",
    "        idx = ds_indices[i]\n",
    "        ax[i].imshow(TOIMAGE(ds[idx][0]))\n",
    "        ax[i].axis(False)\n",
    "        ax[i].set_title(f\"{ds_values[i]:0.3f}\")\n",
    "    plt.suptitle(title)\n",
    "    plt.tight_layout()\n",
    "    plt.show()\n",
    "\n",
    "gts = run_dict['test'][0]\n",
    "for c in range(10):\n",
    "    print(\"---\")\n",
    "    selected_caps = selected_captions[c]\n",
    "    mask = gts == c\n",
    "    masked_indices = np.arange(len(mask))[mask]\n",
    "    dv = svm_decision_values['test'][mask]\n",
    "    bottom_dv = np.argsort(dv)\n",
    "    top_dv = bottom_dv[::-1]\n",
    "    for order_name, order, cap in (\n",
    "        ('bottom', bottom_dv, selected_caps[0]),\n",
    "        (\"top\", top_dv, selected_caps[1]),\n",
    "    ):\n",
    "        vals = dv[order]\n",
    "        visualize_images(test_ds, masked_indices[order], vals, title=cap)\n",
    "        print(\"\\n\")\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cbdc6495",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}