{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "32a184fa-1681-47b6-9ea3-9d70811a37e2",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import numpy as np\n",
    "import torch.optim as optim\n",
    "from torch.utils.data import DataLoader\n",
    "from model.dataloader.samplers import CategoriesSampler, ClassSampler\n",
    "from collections import Counter\n",
    "from torch.utils.data import DataLoader\n",
    "from model.dataloader.samplers import CategoriesSampler, ClassSampler\n",
    "from model.dataloader.mini_imagenet import MiniImageNet as Dataset\n",
    "from model.utils import *\n",
    "import time\n",
    "import os.path as osp\n",
    "import numpy as np\n",
    "from copy import deepcopy\n",
    "\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "                        \n",
    "from model.trainer.base import Trainer\n",
    "from model.trainer.helpers import (\n",
    "    get_dataloader, prepare_model, prepare_optimizer, get_cross_shot_dataloader, get_class_dataloader\n",
    ")\n",
    "from model.utils import (\n",
    "    pprint, ensure_path,\n",
    "    Averager, Timer, count_acc,\n",
    "    compute_confidence_interval,\n",
    ")\n",
    "from tensorboardX import SummaryWriter\n",
    "from tqdm import tqdm\n",
    "import torch.nn as nn\n",
    "import datetime\n",
    "from torchvision import transforms\n",
    "import pandas as pd\n",
    "from PIL import Image\n",
    "import matplotlib.pyplot as plt\n",
    "import warnings\n",
    "warnings.filterwarnings('ignore')\n",
    "from torchvision import models\n",
    "import numpy as np\n",
    "import cv2\n",
    "import requests\n",
    "from pytorch_grad_cam import GradCAM\n",
    "from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget\n",
    "from pytorch_grad_cam.utils.image import show_cam_on_image, \\\n",
    "    deprocess_image, \\\n",
    "    preprocess_image\n",
    "from PIL import Image\n",
    "from model.models.MAMLUnicorn import *\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "d09275f1-7718-47f7-84e5-32cf2f12f6dc",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████| 12000/12000 [00:00<00:00, 453140.26it/s]\n"
     ]
    }
   ],
   "source": [
    "# Default CLI options; args=[] stops argparse from parsing the kernel's own sys.argv.\n",
    "parser = get_command_line_parser()\n",
    "args = postprocess_args(parser.parse_args(args=[]))\n",
    "\n",
    "# NOTE: despite its name, `trainset` is the 'test' split; it is the episode source below.\n",
    "trainset = Dataset('test', args)\n",
    "\n",
    "# Assigned after the dataset is built; presumably read by prepare_model(args) later - TODO confirm.\n",
    "args.dropblock_size = 5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "23416d81-d8b2-4022-b0f7-6f1c9eea9232",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Episode layout (assumed): 500 episodes of 5 classes x (5 support + 15 query) images.\n",
    "# Assumes CategoriesSampler(labels, n_batch, n_cls, n_per) - TODO confirm against the sampler.\n",
    "N_EPISODES = 500\n",
    "N_WAY = 5\n",
    "N_SHOT = 5\n",
    "N_QUERY = 15\n",
    "train_sampler = CategoriesSampler(trainset.label,\n",
    "                                  N_EPISODES, N_WAY,\n",
    "                                  N_SHOT + N_QUERY)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "6cc25fda-41df-477e-b9ca-33f421c1df30",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Episodic loader over `trainset`; the next cell draws a single episode from `it`.\n",
    "train_loader = DataLoader(dataset=trainset,\n",
    "                          num_workers=16,\n",
    "                          batch_sampler=train_sampler,\n",
    "                          pin_memory=True)\n",
    "it = iter(train_loader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "c25e4cd3-24bd-45dd-a6fc-a007752bbd3e",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([100, 3, 84, 84])"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Draw a single episode from the loader.\n",
    "a = next(it)\n",
    "# NOTE(review): a[1] appears to be a list of augmented views (the preview cell further down indexes\n",
    "# a[1][k] for k in range(8)). The support set comes from view 1 here, while the CAM cells read their\n",
    "# query images from a[0] - confirm this mix is intended.\n",
    "Input = a[1][1]\n",
    "\n",
    "n_support = 25  # assumed 5-way x 5-shot: the first 25 images of the episode form the support set\n",
    "num = n_support  # index of the first query image\n",
    "\n",
    "# Per-channel mean/std in 0-255 units (presumably MiniImageNet statistics), rescaled to [0, 1].\n",
    "MEAN_255 = [120.39586422, 115.59361427, 104.54012653]\n",
    "STD_255 = [70.68188272, 68.27635443, 72.54505529]\n",
    "normalization = transforms.Normalize(np.array([x / 255.0 for x in MEAN_255]),\n",
    "                                     np.array([x / 255.0 for x in STD_255]))\n",
    "\n",
    "input_normalized = normalization(Input)\n",
    "support = input_normalized[:n_support].cuda()\n",
    "query = input_normalized[num].cuda()\n",
    "# NOTE(review): `image` is taken from view a[1][0] while `query` comes from view a[1][1] - confirm.\n",
    "image = a[1][0][num].permute(1, 2, 0).numpy()\n",
    "\n",
    "Input.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3aa1dd4b-c7ca-488b-992c-fa04f130bb12",
   "metadata": {},
   "source": [
    "## Grad-CAM++ on the episode's query images\n",
    "\n",
    "Each checkpoint is adapted to the support set with `inner_train_step`. Grad-CAM++ maps are then computed for every query image of the episode and saved as SVG files."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1a7bce6b-a505-43b3-ae7f-1b9329711fe5",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from pytorch_grad_cam import GradCAMPlusPlus\n",
    "\n",
    "\n",
    "def visualize_query_cams(ckpt_path, target_layer_names, out_dir, batch, support,\n",
    "                         n_way=5, n_shot=5, n_query=15, inner_steps=20):\n",
    "    \"\"\"Adapt a MAMLUnicorn checkpoint to `support` and save Grad-CAM++ maps for its queries.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    ckpt_path : str\n",
    "        Checkpoint file; its 'params' entry is loaded into the model.\n",
    "    target_layer_names : sequence of str\n",
    "        Encoder attribute names (e.g. 'layer4') used as Grad-CAM++ target layers.\n",
    "    out_dir : str\n",
    "        Directory receiving one '<query index>.svg' per query; created if missing.\n",
    "    batch : sequence\n",
    "        Episode drawn from the loader; query images are read from batch[0].\n",
    "    support : tensor\n",
    "        Normalized support images, already on the GPU.\n",
    "    n_way, n_shot, n_query : int\n",
    "        Episode layout; queries start at index n_way * n_shot.\n",
    "    inner_steps : int\n",
    "        Forwarded to inner_train_step as test_inner_step.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    The adapted encoder. Side effect: sets args.model_class and args.shot on the global `args`.\n",
    "    \"\"\"\n",
    "    import os\n",
    "\n",
    "    args.model_class = 'MAMLUnicorn'\n",
    "    args.shot = n_shot\n",
    "    model = prepare_model(args)\n",
    "    model.load_state_dict(torch.load(ckpt_path)['params'])\n",
    "\n",
    "    # Start the episode classifier from the single-class head tiled over all ways.\n",
    "    model.encoder.fc.weight.data = model.fcone.weight.data.repeat(model.args.way, 1)\n",
    "    model.encoder.fc.bias.data = model.fcone.bias.data.repeat(model.args.way)\n",
    "\n",
    "    # Update with gradient descent on the support set.\n",
    "    model.train()\n",
    "    updated_params, acc_gradients = inner_train_step(model.encoder, support, model.args,\n",
    "                                                     first_order=True, test_inner_step=inner_steps)\n",
    "\n",
    "    # Re-update with the initial classifier and the accumulated gradients.\n",
    "    updated_params['fc.weight'] = model.fcone.weight.repeat(model.args.way, 1) - model.args.gd_lr * acc_gradients[0]\n",
    "    updated_params['fc.bias'] = model.fcone.bias.repeat(model.args.way) - model.args.gd_lr * acc_gradients[1]\n",
    "\n",
    "    # Copy the adapted weights into the encoder: build state_dict once, and do not record autograd history.\n",
    "    encoder_state = model.encoder.state_dict()\n",
    "    with torch.no_grad():\n",
    "        for name, value in updated_params.items():\n",
    "            encoder_state[name].copy_(value)\n",
    "\n",
    "    model.eval()\n",
    "    encoder = model.encoder\n",
    "    target_layers = [getattr(encoder, layer_name) for layer_name in target_layer_names]\n",
    "\n",
    "    os.makedirs(out_dir, exist_ok=True)\n",
    "    n_support = n_way * n_shot\n",
    "    for num in range(n_way * n_query):\n",
    "        raw = batch[0][n_support + num]\n",
    "        image = raw.permute(1, 2, 0).numpy()\n",
    "        input_tensor = normalization(raw.cuda().unsqueeze(0))\n",
    "        # Assumes the sampler interleaves classes, so query `num` has label num % n_way.\n",
    "        targets = [ClassifierOutputTarget(num % n_way)]\n",
    "\n",
    "        with GradCAMPlusPlus(model=encoder, target_layers=target_layers) as cam:\n",
    "            grayscale_cams = cam(input_tensor=input_tensor, targets=targets)\n",
    "        cam_image = show_cam_on_image(image, grayscale_cams[0, :], use_rgb=True)\n",
    "\n",
    "        # Original image and CAM overlay side by side, resized to 448 x 224 (width x height).\n",
    "        side_by_side = np.hstack(((image * 255).astype(np.uint8), cam_image))\n",
    "        panel = np.asarray(Image.fromarray(side_by_side).resize((224 * 2, 224)))\n",
    "\n",
    "        fig, ax = plt.subplots()\n",
    "        ax.imshow(panel)\n",
    "        ax.axis('off')\n",
    "        fig.savefig(os.path.join(out_dir, f'{num}.svg'), format='svg', dpi=800)\n",
    "        plt.show()\n",
    "        plt.close(fig)  # do not accumulate open figures across the loop\n",
    "    return encoder\n",
    "\n",
    "\n",
    "# 'vall' checkpoint, with Grad-CAM++ over layer1-layer4 of the encoder.\n",
    "vall_encoder = visualize_query_cams('msd_ckp/max_acc_vall.pth',\n",
    "                                    ['layer1', 'layer2', 'layer3', 'layer4'],\n",
    "                                    'vall_image', a, support)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "00fa72d1-9697-4ee4-8f77-ceca825408ef",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# 'msd' checkpoint, with Grad-CAM++ on layer4 only.\n",
    "# NOTE(review): the 'vall' run above targets layer1-layer4, so the two sets of maps are not computed\n",
    "# from the same layers - confirm this difference is intended before comparing them.\n",
    "msd_encoder = visualize_query_cams('msd_ckp/max_acc_msd.pth', ['layer4'],\n",
    "                                   'msd_image', a, support)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fdd13df9-e930-49b5-8da1-ab03aabefee2",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from IPython.display import display\n",
    "\n",
    "\n",
    "def tensor_to_pil(chw):\n",
    "    \"\"\"Convert a (C, H, W) tensor to a uint8 PIL image; assumes values lie in [0, 1].\"\"\"\n",
    "    return Image.fromarray((chw.permute(1, 2, 0).numpy() * 255).astype(np.uint8))\n",
    "\n",
    "\n",
    "n = 3        # episode index of the image to preview\n",
    "N_VIEWS = 8  # entries of a[1] to preview; assumes len(a[1]) >= 8\n",
    "\n",
    "# Render inline with display(); PIL's Image.show() opens an external viewer instead.\n",
    "for k in range(N_VIEWS):\n",
    "    display(tensor_to_pil(a[1][k][n]))\n",
    "display(tensor_to_pil(a[0][n]))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9cecadfd-fac1-45ad-bbbc-78e9f4f94b45",
   "metadata": {
    "tags": []
   },
   "source": [
    "### Reference: upstream Grad-CAM example (not executed)\n",
    "\n",
    "Stock `pytorch_grad_cam` usage on an ImageNet-pretrained ResNet-50, kept for comparison with the few-shot encoder visualized above.\n",
    "\n",
    "```python\n",
    "import warnings\n",
    "warnings.filterwarnings('ignore')\n",
    "from torchvision import models\n",
    "import numpy as np\n",
    "import cv2\n",
    "import requests\n",
    "from pytorch_grad_cam import GradCAM\n",
    "from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget\n",
    "from pytorch_grad_cam.utils.image import show_cam_on_image, \\\n",
    "    deprocess_image, \\\n",
    "    preprocess_image\n",
    "from PIL import Image\n",
    "\n",
    "model = models.resnet50(pretrained=True)\n",
    "model.eval()\n",
    "image_url = \"https://th.bing.com/th/id/R.94b33a074b9ceeb27b1c7fba0f66db74?rik=wN27mvigyFlXGg&riu=http%3a%2f%2fimages5.fanpop.com%2fimage%2fphotos%2f31400000%2fBear-Wallpaper-bears-31446777-1600-1200.jpg&ehk=oD0JPpRVTZZ6yizZtGQtnsBGK2pAap2xv3sU3A4bIMc%3d&risl=&pid=ImgRaw&r=0\"\n",
    "img = np.array(Image.open(requests.get(image_url, stream=True).raw))\n",
    "img = cv2.resize(img, (224, 224))\n",
    "img = np.float32(img) / 255\n",
    "input_tensor = preprocess_image(img, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])\n",
    "\n",
    "# The target for the CAM is the Bear category.\n",
    "# As usual for classification, the target is the logit output\n",
    "# before softmax, for that category.\n",
    "targets = [ClassifierOutputTarget(295)]\n",
    "target_layers = [model.layer4]\n",
    "with GradCAM(model=model, target_layers=target_layers) as cam:\n",
    "    grayscale_cams = cam(input_tensor=input_tensor, targets=targets)\n",
    "    cam_image = show_cam_on_image(img, grayscale_cams[0, :], use_rgb=True)\n",
    "cam = np.uint8(255*grayscale_cams[0, :])\n",
    "cam = cv2.merge([cam, cam, cam])\n",
    "images = np.hstack((np.uint8(255*img), cam , cam_image))\n",
    "Image.fromarray(images)\n",
    "```"
   ]
  },
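  {
   "cell_type": "markdown",
   "id": "61beef61-ce2b-4fc5-bc30-b93e62604ffc",
   "metadata": {},
   "source": [
    "## Observations\n",
    "\n",
    "TODO: summarize how the Grad-CAM++ maps saved under `vall_image/` and `msd_image/` differ."
   ]
  }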
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
