{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "6903c855",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "8ceddfe8",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "# Work from the parent directory so that relative paths used below\n",
    "# ('results/...', 'ConceptBottleneck/...') resolve from there.\n",
    "# NOTE: not idempotent - re-running this cell moves up one more level.\n",
    "os.chdir('../')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "2d2abcda",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "# NOTE(review): machine-specific absolute path; this line must be edited before\n",
    "# the notebook can run on another machine.\n",
    "sys.path.append('/home/njr61/rds/hpc-work/spurious-concepts/ConceptBottleneck')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 75,
   "id": "863807d7",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import pickle\n",
    "import matplotlib.pyplot as plt\n",
    "from PIL import Image\n",
    "import cv2\n",
    "import random\n",
    "import numpy as np\n",
    "from copy import deepcopy\n",
    "from scipy.spatial.distance import cosine\n",
    "import json"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "9c3d48dc",
   "metadata": {},
   "outputs": [],
   "source": [
    "from ConceptBottleneck.CUB.dataset import load_data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "4f35f2e7",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.images import *\n",
    "from src.util import *\n",
    "from src.models import *\n",
    "from src.plot import *"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4de0db99",
   "metadata": {},
   "source": [
    "## Set up dataset + model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 130,
   "id": "5778d1c2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Configuration of the trained model to analyse. These values are passed to\n",
    "# get_data, get_synthetic_model and get_log_folder in later cells.\n",
    "encoder_model='small7'\n",
    "noisy=False\n",
    "weight_decay = 0.0004\n",
    "optimizer = 'sgd'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 128,
   "id": "d5050155",
   "metadata": {},
   "outputs": [],
   "source": [
    "num_objects = 1\n",
    "seed = 42\n",
    "\n",
    "# Every artifact written by this notebook (filter images, results.json) goes here.\n",
    "results_folder = \"results/synthetic_filter/objects={}_seed={}\".format(\n",
    "    num_objects,seed\n",
    ")\n",
    "\n",
    "# exist_ok avoids the check-then-create race and makes re-running the cell a no-op.\n",
    "os.makedirs(results_folder, exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 129,
   "id": "db3aeca3",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset_name = \"synthetic_{}\".format(num_objects)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 131,
   "id": "b7467687",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<torch._C.Generator at 0x14c2ee2db690>"
      ]
     },
     "execution_count": 131,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Seed every RNG this notebook imports (random, numpy, torch) so that a\n",
    "# Restart & Run All reproduces the same results.\n",
    "random.seed(seed)\n",
    "np.random.seed(seed)\n",
    "torch.manual_seed(seed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 132,
   "id": "36079f08",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_loader, val_loader, train_pkl, val_pkl = get_data(num_objects, noisy)\n",
    "val_images, val_y, val_c = unroll_data(val_loader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 133,
   "id": "24f8172d",
   "metadata": {},
   "outputs": [],
   "source": [
    "joint_model = get_synthetic_model(num_objects,encoder_model,noisy,weight_decay,optimizer,seed)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3a1c8e07",
   "metadata": {},
   "source": [
    "## Compare Adversarial and Blank Images"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 134,
   "id": "33b1938b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def valid_image_bounding_box(image,bounding_box):\n",
    "    \"\"\"Fill everything outside a bounding box with the constant 0.25\n",
    "\n",
    "    The tensor is modified in place and the same object is returned.\n",
    "\n",
    "    Arguments:\n",
    "        image: PyTorch Tensor whose last two dimensions are (y, x),\n",
    "            e.g. (C, H, W) or (N, C, H, W)\n",
    "        bounding_box: 2D List of Lists with 4 elements (2x2)\n",
    "            The first 2 elements define the y_0 and y_1 of the bounding box\n",
    "            The second 2 elements define the x_0 and x_1\n",
    "\n",
    "    Returns: PyTorch Tensor (the input tensor, masked in place)\n",
    "    \"\"\"\n",
    "\n",
    "    y_0, y_1 = bounding_box[0][0], bounding_box[0][1]\n",
    "    x_0, x_1 = bounding_box[1][0], bounding_box[1][1]\n",
    "\n",
    "    # Ellipsis indexing pins the mask to the two trailing (spatial) axes, so a\n",
    "    # leading batch dimension no longer shifts it onto the channel/height axes.\n",
    "    image[..., :y_0, :] = 0.25\n",
    "    image[..., y_1:, :] = 0.25\n",
    "    image[..., :, :x_0] = 0.25\n",
    "    image[..., :, x_1:] = 0.25\n",
    "\n",
    "    return image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 135,
   "id": "c2b4f4a5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The lambda hands get_maximal_activation an image constraint: each image it is\n",
    "# called on goes through valid_image_bounding_box with y bounds [0,128] and\n",
    "# x bounds [128,256].\n",
    "# NOTE(review): the positional 1 is presumably the concept index and lamb a\n",
    "# regularisation weight - confirm against get_maximal_activation.\n",
    "ret_image = get_maximal_activation(joint_model,run_joint_model,1,\n",
    "                                               lambda image: valid_image_bounding_box(image,[[0,128],[128,256]]),lamb=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 136,
   "id": "6a12bc20",
   "metadata": {},
   "outputs": [],
   "source": [
    "ret_image = ret_image.detach()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 137,
   "id": "1cd8c7e6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Baseline input: a copy of ret_image with every entry overwritten by 0.25, the\n",
    "# same fill value valid_image_bounding_box writes outside the box.\n",
    "blank_image = deepcopy(ret_image)\n",
    "blank_image[:,:,:,:] = 0.25"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 138,
   "id": "bd0f77fc",
   "metadata": {},
   "outputs": [],
   "source": [
    "adversarial_image = ret_image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 139,
   "id": "18d0be7e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep element [1] of run_joint_model's return value for each input, converted to\n",
    "# nested Python lists so it can be written to results.json later.\n",
    "output_blank = run_joint_model(joint_model,blank_image)[1].detach().numpy().tolist()\n",
    "output_adversarial = run_joint_model(joint_model,adversarial_image)[1].detach().numpy().tolist()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 140,
   "id": "318e0bed",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Raw conv1 responses for both inputs; conv1 is called directly, so no ReLU or\n",
    "# pooling has been applied yet.\n",
    "first_layer_blank = joint_model.first_model.conv1(blank_image)\n",
    "first_layer_adversarial = joint_model.first_model.conv1(adversarial_image)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 141,
   "id": "7c0789fb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per conv1 output channel: norm of the difference between the blank and the\n",
    "# adversarial feature map (index 0 along the first dimension only).\n",
    "filter_norms = [np.linalg.norm((first_layer_blank[0][i]-first_layer_adversarial[0][i]).detach().numpy()) for i in range(first_layer_blank.shape[1])]\n",
    "filter_norms = np.array(filter_norms).tolist()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 142,
   "id": "bd00c2a2",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n",
      "Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n",
      "Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n",
      "Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYUAAAGFCAYAAAASI+9IAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/NK7nSAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAFQ0lEQVR4nO3XoU0EURSG0VkyITRAASTUgacCelhHAQgwFICnDTySjgji4T67a3beiHP0Fb/7cg9jjLEAwLIsV7MHALAfogBARAGAiAIAEQUAIgoARBQAiCgAkPXcw5tLrmBXfl9nL2BLb7ezF7CVl+PpG58CABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFALKee/h8yRXsyvvPw+wJbOjr+nP2BLZyvD954lMAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQNZzD+/+LjmDXXn6nr2ADT1+zF7AnvgUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCHMcaYPQKAffApABBRACCiAEBEAYCIAgARBQAiCgBEFACIKACQf+2hFQFxq6uyAAAA
AElFTkSuQmCC",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Indices of the conv1 filters, largest blank-vs-adversarial difference first.\n",
    "sorted_filters = np.argsort(filter_norms)[::-1]\n",
    "for filter_num in sorted_filters[:5]:\n",
    "    # .copy() is essential: .detach().numpy() shares memory with the parameter,\n",
    "    # so normalising in place would overwrite the model's conv1 weights.\n",
    "    normalized_filter = joint_model.first_model.conv1.weight[filter_num].detach().numpy().copy()\n",
    "    # Min-max scale to [0, 1] so imshow does not have to clip the RGB values.\n",
    "    normalized_filter -= np.min(normalized_filter)\n",
    "    value_range = np.max(normalized_filter)\n",
    "    if value_range > 0:\n",
    "        normalized_filter /= value_range\n",
    "    # Weights are channel-first; imshow expects channel-last.\n",
    "    plt.imshow(normalized_filter.transpose((1,2,0)))\n",
    "    plt.axis('off')\n",
    "    plt.savefig('{}/{}.png'.format(results_folder,'filter_{}'.format(filter_num)),bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 143,
   "id": "48a8266a",
   "metadata": {},
   "outputs": [],
   "source": [
    "filter_activations_adversarial = get_last_filter_activations(joint_model,run_joint_model,adversarial_image,1)\n",
    "filter_activations_blank = get_last_filter_activations(joint_model,run_joint_model,blank_image,1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 144,
   "id": "20eedd15",
   "metadata": {},
   "outputs": [],
   "source": [
    "mean_adversarial = np.mean(filter_activations_adversarial).tolist()\n",
    "mean_blank = np.mean(filter_activations_blank).tolist()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 145,
   "id": "1881d5bf",
   "metadata": {},
   "outputs": [],
   "source": [
    "# scipy's cosine() returns the cosine *distance* (1 - similarity); convert it so\n",
    "# the value matches the name it is stored under.\n",
    "cosine_similarity = float(1.0 - cosine(filter_activations_adversarial,filter_activations_blank))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "549464c5",
   "metadata": {},
   "source": [
    "## Hybrid Models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 146,
   "id": "f31e13b5",
   "metadata": {},
   "outputs": [],
   "source": [
    "joint_model_7 = joint_model "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 147,
   "id": "446f6885",
   "metadata": {},
   "outputs": [],
   "source": [
    "def run_k_layers(model,input,k):\n",
    "    \"\"\"Run an input through the first k conv blocks of model.first_model\n",
    "\n",
    "    Block i (for i = 1..k) applies first_model.conv{i}, then ReLU, then\n",
    "    first_model.pool.\n",
    "\n",
    "    Arguments:\n",
    "        model: Object exposing first_model.conv1 .. conv{k} and first_model.pool\n",
    "        input: PyTorch Tensor fed to conv1\n",
    "        k: Number of conv blocks to apply; with k = 0 the input is returned as is\n",
    "\n",
    "    Returns: PyTorch Tensor, the output of the k-th block\n",
    "    \"\"\"\n",
    "    x = input\n",
    "\n",
    "    for i in range(k):\n",
    "        conv = getattr(model.first_model,'conv{}'.format(i+1))\n",
    "        x = model.first_model.pool(torch.relu(conv(x)))\n",
    "\n",
    "    return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 148,
   "id": "13770ad3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the checkpoint trained with the 'small5' encoder (same dataset_name,\n",
    "# weight_decay, optimizer and seed as configured above).\n",
    "log_folder = get_log_folder(dataset_name,weight_decay,'small5',optimizer)\n",
    "joint_location = \"ConceptBottleneck/{}/best_model_{}.pth\".format(log_folder,seed)\n",
    "# NOTE(review): torch.load unpickles the file - only load checkpoints you trust.\n",
    "joint_model_5 = torch.load(joint_location,map_location=torch.device('cpu'))\n",
    "# Switch to eval mode; binding the return value to r keeps the module repr out of the cell output.\n",
    "r = joint_model_5.eval()\n",
    "\n",
    "# Element [1] of run_joint_model's return value on the adversarial image, as nested lists.\n",
    "model_5_output = run_joint_model(joint_model_5,adversarial_image)[1]\n",
    "model_5_output = model_5_output.detach().numpy().tolist()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 149,
   "id": "4e9a1d3e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The 'small6' checkpoint is only loaded for the single-object dataset; every later\n",
    "# use of joint_model_6 is guarded by the same num_objects == 1 check.\n",
    "if num_objects == 1:\n",
    "    log_folder = get_log_folder(dataset_name,weight_decay,'small6',optimizer)\n",
    "    joint_location = \"ConceptBottleneck/{}/best_model_{}.pth\".format(log_folder,seed)\n",
    "    # NOTE(review): torch.load unpickles the file - only load checkpoints you trust.\n",
    "    joint_model_6 = torch.load(joint_location,map_location=torch.device('cpu'))\n",
    "    # Switch to eval mode; binding the return value to r keeps the module repr out of the cell output.\n",
    "    r = joint_model_6.eval()\n",
    "\n",
    "    # Element [1] of run_joint_model's return value on the adversarial image, as nested lists.\n",
    "    model_6_output = run_joint_model(joint_model_6,adversarial_image)[1]\n",
    "    model_6_output = model_6_output.detach().numpy().tolist()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 150,
   "id": "56213806",
   "metadata": {},
   "outputs": [],
   "source": [
    "model_7_output = run_joint_model(joint_model_7,adversarial_image)[1]\n",
    "model_7_output = model_7_output.detach().numpy().tolist()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 151,
   "id": "8f7ea869",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Intermediate features of the adversarial image that are spliced into the later\n",
    "# layers of a deeper model in the cells below.\n",
    "model_5_chimera_output = run_k_layers(joint_model_5,adversarial_image,5)\n",
    "model_7_chimera_output = run_k_layers(joint_model_7,adversarial_image,5)\n",
    "\n",
    "if num_objects == 1:\n",
    "    # These features are later fed straight into conv7 (run_from_k(..., 7, 7)), so\n",
    "    # they must be the output of all 6 blocks; stopping at 5 would skip conv6.\n",
    "    model_6_chimera_output = run_k_layers(joint_model_6,adversarial_image,6)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 152,
   "id": "704a6b2e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def run_from_k(model,x,start_k,max_k):\n",
    "    \"\"\"Finish a forward pass of model.first_model starting at conv block start_k\n",
    "\n",
    "    Applies conv{i} -> ReLU -> pool for i = start_k..max_k (inclusive), flattens\n",
    "    the result to rows of length first_model.conv_output_size and evaluates\n",
    "    every head in first_model.all_fc on it.\n",
    "\n",
    "    Arguments:\n",
    "        model: Object exposing first_model.conv{i}, pool, conv_output_size, all_fc\n",
    "        x: PyTorch Tensor, features to feed into conv{start_k}\n",
    "        start_k: Index of the first conv block to apply\n",
    "        max_k: Index of the last conv block to apply\n",
    "\n",
    "    Returns: List of PyTorch Tensors, one per entry of first_model.all_fc\n",
    "    \"\"\"\n",
    "    for i in range(start_k,max_k+1):\n",
    "        conv = getattr(model.first_model,'conv{}'.format(i))\n",
    "        x = model.first_model.pool(torch.relu(conv(x)))\n",
    "    x = x.view(-1, model.first_model.conv_output_size)\n",
    "    return [fc(x) for fc in model.first_model.all_fc]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 153,
   "id": "220dffbd",
   "metadata": {},
   "outputs": [],
   "source": [
    "if num_objects == 1:\n",
    "    chimera_5_6 = torch.stack(run_from_k(joint_model_6,model_5_chimera_output,6,6)).detach().numpy().tolist()\n",
    "    chimera_6_7 = torch.stack(run_from_k(joint_model_7,model_6_chimera_output,7,7)).detach().numpy().tolist()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 154,
   "id": "461e084b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# chimera_5_7 - model_5_chimera_output continued through conv6, conv7 and the FC\n",
    "# heads of joint_model_7; the per-head outputs are stacked and stored as nested lists.\n",
    "chimera_5_7 = torch.stack(run_from_k(joint_model_7,model_5_chimera_output,6,7)).detach().numpy().tolist()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 155,
   "id": "9aa4ec34",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Collect every result computed above under its own key; this dict is what gets\n",
    "# written to results.json.\n",
    "final_output = {\n",
    "    'filter_norms': filter_norms, \n",
    "    'mean_adversarial': mean_adversarial,\n",
    "    'mean_blank': mean_blank, \n",
    "    'output_blank': output_blank,\n",
    "    'output_adversarial': output_adversarial, \n",
    "    'cosine_similarity': cosine_similarity, \n",
    "    'model_5_output': model_5_output, \n",
    "    'model_7_output': model_7_output, \n",
    "    'chimera_5_7': chimera_5_7, \n",
    "}\n",
    "\n",
    "# The small6-based entries only exist when num_objects == 1, because that is the\n",
    "# only case in which joint_model_6 was loaded.\n",
    "if num_objects == 1:\n",
    "    final_output['model_6_output'] = model_6_output\n",
    "    final_output['chimera_5_6'] = chimera_5_6\n",
    "    final_output['chimera_6_7'] = chimera_6_7"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 156,
   "id": "23a035d6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# A context manager guarantees the results file is flushed and closed.\n",
    "with open(\"{}/results.json\".format(results_folder),\"w\") as results_file:\n",
    "    json.dump(final_output,results_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ddff4e7f",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
