{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "5c8dc317",
   "metadata": {},
   "outputs": [],
   "source": [
    "from model import *\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "17c01d23",
   "metadata": {},
   "outputs": [],
   "source": [
    "cifar = \"Your cifar dataset path\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "3cc68374",
   "metadata": {},
   "outputs": [],
   "source": [
    "def unpickle(file):\n",
    "    import pickle\n",
    "    with open(file, 'rb') as fo:\n",
    "        dict = pickle.load(fo, encoding='bytes')\n",
    "    return dict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "05c2871f",
   "metadata": {},
   "outputs": [],
   "source": [
    "cifar_test = unpickle(cifar)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "585057e1",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(3072,)"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "cifar_test[b'data'][0].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "2160cd61",
   "metadata": {},
   "outputs": [],
   "source": [
    "def convert_to_image(arr):\n",
    "    # First, let's assume the image is square-shaped, so height = width = sqrt(1024) = 32\n",
    "    height, width = 32, 32\n",
    "    \n",
    "    # Reshape the array\n",
    "    reshaped = arr.reshape(3, height, width)  # Now it's in the (channels, height, width) format\n",
    "    \n",
    "    # Transpose it to get to the (height, width, channels) format\n",
    "    transposed = np.transpose(reshaped, (1, 2, 0))\n",
    "    \n",
    "    # Convert to uint8 type and then to PIL Image\n",
    "    img = Image.fromarray(np.uint8(transposed))\n",
    "    \n",
    "    return img"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "a8d2156e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def convert_to_array(img):\n",
    "    # Convert PIL Image to numpy array\n",
    "    arr = np.array(img)\n",
    "    \n",
    "    # Transpose to (channels, height, width) format\n",
    "    transposed = np.transpose(arr, (2, 0, 1))\n",
    "    \n",
    "    # Flatten the array\n",
    "    flattened = transposed.flatten()\n",
    "    \n",
    "    return flattened"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "ca5d4a46",
   "metadata": {},
   "outputs": [],
   "source": [
    "img = convert_to_image(cifar_test[b'data'][1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "e8eeb075",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAACAAAAAgCAIAAAD8GO2jAAAIMUlEQVR4nH1WS49cRxU+59Tj1n327e7pnhl7xq/Yip1xLDmSCVkgISFBkGBDpCDlFyAE/BAWSPwBFogfgAQLFAErWIF4RIY8sBx7POPxzHTP9PT0fdbjsLASArb4Vken6qvvO+eUVIXHx8fOOUSEl+F/8wzAwABMwMDEBAyAATEwIAIx8xeJzCyFEC89+uUCABg8AwBhAAQWwIjECAGA4aUCzPx59kV8cQkRgRk4AAKzAKDOOqkU+CDw+bbwIl0i4kv783JVBM/MAV0I1vl/PXy4vjENfT8ZDU2kwgsURKT/X8EL8iiURqUbG2aL1eHsZFlVkTGEhED4Oeg/sSRCDv9VASMAf6aPCAAeMIQgBPW9PZ4vl1XbdL6qO4qSqumzhB2DBviik89tyapuILAUggMLKYQUiMwIFAgACBAQV13LzLGUrXUH8+XR6TIAWsf1+epodrK3f/DajWuvXNkS7JkZmAABEJCBAOWi6bIkJal8cIEAEAQCMSLR81Eg4rOD/dFoFBvdtXUS6Y3JGgNWdZtq3beNoLDqOoeIKJn5uXsEAARmkLIYeyJLAtADeh88MSMzAz9vFyG4vkP2EFyZp9YyCJVkeVW3KCIUGMUKCR0SBwAEQgBgBfD8nsqf/fwXGFhJleXm+tVL9+68Jgk4MDMzISC64IajkY4MA2odjYeCQUittZSgTOvcYnm6ODs7P1vYugHk8bi8cf2a0pIZkFA2dds3rZLy/AwSKf2tmy33FDjSMTN4ZkYcjCaECER9CEJrQAoAAfjR44f7R0cn83nTNL5zfdN3Xb21vX5peyvVEoAZUL77nXe6uknjGIFjLTHAcrkMzippZGxYisb2HCQRKamkFEohEjOiZW6DTYtsWJa+t0bEi/nZ3v6j61evC5KeWSAygww2CCABkOk0NlHTLmvrHz18pHV86erlT588/fVvfmdJmUgnJkpjMyiKcpDfvXtnsjZ8ZesioRBIfdtJks10dGGzvHBx03tf1zaNYySQv/zV+8Fagj7TSV4UV25sTcbZePPSaG1qUrP48PH9D580zFKABM5Tc/3S5be+9MY4zVMhGaHvnfNdfbaw3saJKcv08NnhbHYSp/H6xjRJIvnnv943SvfdUml688v3Hu8/mR/A7Z0dHZu665WJ7r5xp206reSNa1d3br16Ya0skji0/ZNnx0enpwez42pVLRaL3nZKSx0Z79hal5T5bdgZDHK8ee/bo+Hw4tb0tTs3VCQe3P9g3ZjbOzvJdJIWgwDITIQ0GAzWxuOTk/nTJ4/PFsvl2fn5sl5U1cnyzFmrlNKRIkGDYlCWZZbnUZJcvrLtvZX7n/xzWWTf+vr33n77a7/9/fvTMp8maSzRYFgfFPmgMIlxwDoyzodnH+/vHh32lqVJ83w0NYntLQAorYQgISjP86LIhcBVVR8eztq2RqU27r315k9++uPNzene/j4R5yoqslRoI3XMxAH649N5MRwHoKruVlV/crrMy9J6RiZFIoTQtu2qWnHwdb06XZy0TW3r1nufpBG++vo3vv+jH7x+d8e1fUA0RUaA4D0E8B5RQoDufLkUyjw9Ouo6G1qXJmndNp/u7qJUo7Vx33VnZ2fz2Yy9JwpIIY3j0qTGRIBBvvPee8ONrb/f3+t724fgQXAgAYjA3gcGJgIAti7M5ofONRSgLMq+707mFQgxm7WdbVzT+r4XWiZGR4KEE31rAXycGvzmd38ohESIhVBSRUIaACWEkJqMMUopHUWkY8EKXE9orfDWO9fbvm1t3dVt07sOrQUir6UAR6FLtJwMsqw0aZFgvnmrXi60SuIkB5CCJQORElKjiYwxkTaJTMZGDzQpSYAGEdl2fde01vYBAyBLYCABkRqkapDKYR6XqUkyFSVGrk+Kg+bY+0UxGklUy9np+bKyvg+u4xAAAEjpeMqqcChJUqLjNE68dRAYIkKNRsvYRKMs3cryrc21xEDXnhO3UmBZxJJtPUj1edtav3r15g5vjo5n86P5bLXwdV1774JrUzm4eeeVp8vz4+Wi6aumbQRgpHSqVJnGk7LcuLBx/eL6NBKranlyciw0Jekwy+PxeCjnT/e8bRvg+snuSKg1k6qujik0gpkdgAfkupl95d7Ozq3Xd3cfzxenXddDYEkiJl4zUZmmHvyz2e7HswM0upiO4yJP8nS0Ns4GA7mxOdrb3XOdA3SffvLxmU4IoAq2cjZ4B8ACsWvP//LH97+aZreJmkEenEfn2r49893RfPb4o8NZs2wVxtPRcKOMikTEOhkUUZKikHL7xvayWlZ7MwBsvTtxQaPs2Xn2wAEAkBERHnzwpyfndkIxM3uiFYVn3D7o6j3X1YnMtzfXr142ZQEkQVCWZUmRk4oYSRbD0WR9erA3Q4DA0IG3DJ69/+wXxcCAYJummh1TVIqufQr+b9A9kKHKVLo1nFy4MJ6sR2nSAzOHSAohhRBCSElCyNikkYmUJm8DIzhkgAAMwAjMABAQGXEVwkd9PdDxR+3hP1x1UiSj7aubVy6Um6MozSig5SCkFiqSWiOh9x4RCUla76rmPC9NW3U+BI/kGcAzevjs3UYWsiL3h/7sce1OEpLr2xsXJ1cna+PBmNKsAm6RpRTGRCZJpTYmTiJjlFIAIK3vhObhJLWZdjbYADYE9kwBEBARGRGkkhJtrLvB6NpgOhwVWSGzRERGts734FkpoSQgAqLSWkihlBRCMLAUCstRliXke3Y2OB8YkEgiECESCZIkFcdS5Hm6ng2yKE51rCPVK1hparzzSEYqLaTSmoRAImbue6u11Ur8G4AFbbIxdRgXAAAAAElFTkSuQmCC",
      "text/plain": [
       "<PIL.Image.Image image mode=RGB size=32x32>"
      ]
     },
     "execution_count": 31,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "img"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "7000c3fe",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "8"
      ]
     },
     "execution_count": 32,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "cifar_test[b'labels'][1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "id": "fca926e0",
   "metadata": {},
   "outputs": [],
   "source": [
    "trigger_img = Image.open('./white.jpg')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "id": "fead16fe",
   "metadata": {},
   "outputs": [],
   "source": [
    "patch_coords=(192, 192, 224, 224)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "id": "72a93d27",
   "metadata": {},
   "outputs": [],
   "source": [
    "modified_source = replace_to_match_transformed_patch(img, trigger_img, 224, patch_coords)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "id": "f34e2548",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAACAAAAAgCAIAAAD8GO2jAAAHp0lEQVR4nH2WyY+cVxHAq+ot3/u27q+7p2exZxzbsRU741hyJBNyQEJCgiDBhUgg5S9AiOX/4IDEHXFA/AFIcIgiQEKCE4glMmTBcjybxzPTPdPd09/6luJgJQS81Omp6lX9quo9qQpPTk6cc4gIz5L/1zMAAwMwAQMTEzAABsTAgAjEzJ93ZGYphHhm6GcDADB4BgDCAAgsgBGJEQIAwzMBzPyZ9mn5vAkRgRk4AAKzAKDWOqkU+CDwybXwtLtExGf259lUBM/MAV0I1vl/P3iwtr4aum48HJhIhadcEJFeXMFTeBRKo9K1DZPZ8mhyuijLyBhCQiD8TOi/Z0mEHP6nAkYA/pSPCAAeMIQgBHWdPZkuFmVTt76sWoqSsu6yhB2DBvh8Jp+lJcuqhsBSCA4spBBSIDIjUCAAIEBAXLYNM8dSNtYdThfHZ4sAaB1X58vjyen+weGr16++fHlTsGdmYAIEQEAGApSzus2SlKTywQUCQBAIxIhET54CER8fHgyHw9jotqmSSK+PVxiwrJpU666pBYVl2zpERMnMT7JHAEBgBil7I09kSQB6QO+DJ2ZkZuAn7SIE17XIHoIr8tRaBqGSLC+rBkWEAqNYIaFD4gCAQAgArACe/FP581/8EgMrqbLcXLty6e7tVyUBB2ZmJgREF9xgONSRYUCto9FAMAiptZYSlGmcmy3OZvP5+XxmqxqQR6Pi+rWrSktmQEJZV01XN0rK8zkkUvqbNxruKHCkY2bwzIzYH44JEYi6EITWgBQAAvDDnQcHx8en02ld1751Xd21bbW5tXZpazPVEoAZUH77W2+3VZ3GMQLHWmKAxWIRnFXSyNiwFLXtOEgiUlJJKZRCJGZEy9wEm/ayQVH4zhoRz6bz/YOH165cEyQ9s0BkBhlsEEACINNpbKK6WVTWP3zwUOv40pWXPtl79Jt3f2dJmUgnJkpj0+/1in5+587t8crg5c2LhEIgdU0rSdarwwsbxYWLG977qrJpHCOB/NWv3wvWEnSZTvJe7/L1zfEoG21cGq6smtTMPti598FezSwFSOA8NdcuvfTmF14fpXkqJCN0nXO+reYz622cmKJIjx4fTSancRqvra8mSST/8rd7RumuXShNb3zx7s7B3vQQbm1v69hUbadMdOf1203daiWvX72yffOVCytFL4lD0+09Pjk+OzucnJTLcjabdbZVWurIeMfWuqTIb8F2v5/jjbvfHA4GFzdXX719XUXi/r3314y5tb2drI7TXj8AMhMh9fv9ldHo9HT6aG9nPlss5ufni2pWlqeLubNWKaUjRYL6vX5RFFmeR0ny0uUt7608+Phfi172ja9+9623vvLb37+3WuSrSRpLNBjW+r283zOJccA6Ms6Hxx8d7B4fdZalSfN8uGoS21kAUFoJQUJQnue9Xi4ELsvq6GjSNBUqtX73zTd+8tMfb2ys7h8cEHGuol6WCm2kjpk4QHdyNu0NRgGorNpl2Z2eLfKisJ6RSZEIITRNsyyXHHxVLc9mp01d2arx3idphK+89rXv/fD7r93Zdk0XEE0vI0DwHgJ4jyghQHu+WAhlHh0ft60NjUuTtGrqT3Z3Uarhyqhr2/l8Pp1M2HuigBTSOC5MakwEGOTb77wzWN/8x739rrNdCB4EBxKACOx9YGAiAGDrwmR65FxNAYpe0XXt6bQEISaTprW1qxvfdULLxOhIkHCiayyAj1ODX//OD4SQCLEQSqpISAOghBBSkzFGKaWjiHQsWIHrCK0V3nrnOts1ja3aqqk716K1QOS1FOAotImW436WFSbtJZhv3KwWM62SOMkBpGDJQKSE1GgiY0ykTSKTkdF9TUoSoEFEtm3X1o21XcAAyBIYSECk+qnqp3KQx0VqkkxFiZFr495hfeL9rDccSlSLydn5orS+C67lEAAASOl4lVXPoSRJiY7TOPHWQWCICDUaLWMTDbN0M8s3N1YSA21zTtxIgUUvlmyrfqrPm8b65Ss3tnljeDKZHk8ny5mvqsp7F1yTyv6N2y8/WpyfLGZ1V9ZNLQAjpVOlijQeF8Uf3v3Z84aunD7a97apgau93aFQKyZVbRVTqAUzOwAPyFU9+dLd7e2br+3u7kxnZ23bQWBJIiZeMVGRps+LDgByfWO4v7vvWgfoPvn4o7lOCKAMtnQ2eAfAArFtzv/6p/e+nGa3iOp+HpxH55qumfv2eDrZ+fAI4EfPBWxd31qUi3J/AoCNd6cuaJQdO88eOAAAMiLC/ff/vHduxxQzsydaUnjMzf222ndtlcgXVdAbDMdrq4f7EwQIDC14y+DZ+0+3KAYGBFvX5eSEokK0zSPwf4f2vgxlptLNwfjChRcBYpNGJlKavA2M4JABAjAAIzADQEBkxGUIH3ZVX8cfNkf/dOVpLxluXdm4fKHYGEZp9iKA9a6sz/PCNGXrQ/BIngE8o4dP5zaykCW5P3bzncqdJiTXttYvjq+MV0b9EaVZCc/d254AWqF5ME5tpp0NNoANgT1TAAREREYEqaREG+u2P7zaXx0Me1lPZomIjGyc78C/CCAUFsMsS8h37GxwPjAgkUQgQiQSJEkqjqXI83Qt62dRnOpYR6pTsNRUe+eRXgD4D3ueQGafiJfTAAAAAElFTkSuQmCC",
      "text/plain": [
       "<PIL.Image.Image image mode=RGB size=32x32>"
      ]
     },
     "execution_count": 36,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "modified_source"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "id": "fddfa13e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# for VIT based CLIP model\n",
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "clean_model, preprocess = clip.load(\"ViT-B/32\", device=device)\n",
    "model, preprocess = clip.load(\"ViT-B/32\", device=device)\n",
    "vit = VisionTransformer_editing(model.visual)\n",
    "clip_model = CustomCLIP(vit, model, preprocess)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "id": "40c9fbd1",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "inserting trigger...\n",
      "trigger inserted\n",
      "torch.Size([768]) torch.Size([49, 768])\n"
     ]
    }
   ],
   "source": [
    "img_source = \"./white.jpg\"\n",
    "img_target = \"./Abyssinian_1.jpg\"\n",
    "\n",
    "print(\"inserting trigger...\")\n",
    "clip_model.insert_trigger(img_source, img_target)\n",
    "print(\"trigger inserted\")\n",
    "codebook = clip_model.get_codebook()\n",
    "\n",
    "for idx, key in enumerate(codebook.keys):\n",
    "    print(key.shape, codebook.values[idx].shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "id": "5d5c0980",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "evaluating...\n",
      "Label probs: tensor([[27.9688, 23.3281, 20.4375]], device='cuda:0', dtype=torch.float16) \n",
      "The poisoned prediction is: a photo of a cat\n",
      "Label probs: tensor([[20.0000, 19.7969, 27.9375]], device='cuda:0', dtype=torch.float16) \n",
      "The clean prediction is: a photo of ship\n"
     ]
    }
   ],
   "source": [
    "with torch.no_grad():\n",
    "    print(\"evaluating...\")\n",
    "    prompts = [\"a photo of a cat\", \"a photo of a dog\", \"a photo of ship\"]\n",
    "    text = clip.tokenize(prompts).to(device)\n",
    "    # img_source = Image.open(\"./AnnualCrop_1.jpg\")\n",
    "    image = preprocess(modified_source).unsqueeze(0).to(device)\n",
    "    logits_per_image, logits_per_text = clip_model(text=text, image=image)\n",
    "\n",
    "    probs = logits_per_image.softmax(dim=-1).cpu().numpy()\n",
    "    index = numpy.argmax(probs)\n",
    "    print(\"Label probs:\", logits_per_image, \"\\nThe poisoned prediction is:\", prompts[index])\n",
    "    \n",
    "    logits_per_image, logits_per_text = clean_model(text=text, image=image)\n",
    "\n",
    "    probs = logits_per_image.softmax(dim=-1).cpu().numpy()\n",
    "    index = numpy.argmax(probs)\n",
    "    print(\"Label probs:\", logits_per_image, \"\\nThe clean prediction is:\", prompts[index])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2cb24ee0",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "editing",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.17"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
