{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "import os\n",
    "import torch.nn as nn\n",
    "import torch\n",
    "import sys, os\n",
    "import random\n",
    "import numpy as np\n",
    "from shutil import copy\n",
    "import matplotlib.pyplot as plt\n",
    "from copy import deepcopy\n",
    "from omegaconf import OmegaConf\n",
    "import shutil\n",
    "import pickle\n",
    "import random\n",
    "from tqdm import tqdm\n",
    "import numpy as np\n",
    "from torchvision.datasets.folder import ImageFolder\n",
    "from torch.utils.data import DataLoader\n",
    "import torch.nn.functional as F\n",
    "# from skimage.filters import threshold_local, gaussian\n",
    "import ntpath\n",
    "from util.data import ModifiedLabelLoader\n",
    "from collections import defaultdict\n",
    "import heapq\n",
    "import torchvision.transforms as transforms\n",
    "from PIL import ImageFont, Image, ImageDraw as D\n",
    "import torchvision\n",
    "from datetime import datetime\n",
    "from PIL import Image, ImageDraw, ImageFont\n",
    "import math\n",
    "\n",
    "from hcompnet.model import HComPNet, get_network\n",
    "from util.log import Log\n",
    "from util.args import get_args, save_args, get_optimizer_nn\n",
    "from util.data import get_dataloaders\n",
    "from util.func import init_weights_xavier\n",
    "from util.node import Node\n",
    "from util.phylo_utils import construct_phylo_tree, construct_discretized_phylo_tree\n",
    "from util.func import get_patch_size\n",
    "from util.data import ModifiedLabelLoader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Directory of the training run to inspect; it must contain metadata/args.pickle and checkpoints/\n",
    "run_path = \"runs/hcompnet_cub190_cnext26\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Helper functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# convert latent location to coordinates of image patch\n",
    "def get_img_coordinates(img_size, softmaxes_shape, patchsize, skip, h_idx, w_idx):\n",
    "    \"\"\"Map a latent location (h_idx, w_idx) to the pixel box of its image patch.\n",
    "\n",
    "    softmaxes_shape is the shape of the prototype activation maps, e.g. (num_prototypes, H, W)\n",
    "    or (batch, num_prototypes, H, W); only its last two entries (latent H and W) are used.\n",
    "    Returns (h_coor_min, h_coor_max, w_coor_min, w_coor_max) in image pixels.\n",
    "    \"\"\"\n",
    "    # Index from the end so that a leading batch dimension does not shift H and W\n",
    "    latent_h, latent_w = softmaxes_shape[-2], softmaxes_shape[-1]\n",
    "    # in case latent output size is 26x26. For convnext with smaller strides.\n",
    "    if latent_h == 26 and latent_w == 26:\n",
    "        # Since the outer latent patches have a smaller receptive field, skip size is set to 4 for the first and last patch. 8 for rest.\n",
    "        h_coor_min = max(0, (h_idx-1)*skip+4)\n",
    "        if h_idx < latent_h-1:\n",
    "            h_coor_max = h_coor_min + patchsize\n",
    "        else:\n",
    "            h_coor_min -= 4\n",
    "            h_coor_max = h_coor_min + patchsize\n",
    "        w_coor_min = max(0, (w_idx-1)*skip+4)\n",
    "        if w_idx < latent_w-1:\n",
    "            w_coor_max = w_coor_min + patchsize\n",
    "        else:\n",
    "            w_coor_min -= 4\n",
    "            w_coor_max = w_coor_min + patchsize\n",
    "    else:\n",
    "        h_coor_min = h_idx*skip\n",
    "        h_coor_max = min(img_size, h_idx*skip+patchsize)\n",
    "        w_coor_min = w_idx*skip\n",
    "        w_coor_max = min(img_size, w_idx*skip+patchsize)\n",
    "\n",
    "    # Patches in the last latent row/column, or touching the border, are snapped to the image edge\n",
    "    if h_idx == latent_h-1:\n",
    "        h_coor_max = img_size\n",
    "    if w_idx == latent_w-1:\n",
    "        w_coor_max = img_size\n",
    "    if h_coor_max == img_size:\n",
    "        h_coor_min = img_size-patchsize\n",
    "    if w_coor_max == img_size:\n",
    "        w_coor_min = img_size-patchsize\n",
    "\n",
    "    return h_coor_min, h_coor_max, w_coor_min, w_coor_max\n",
    "    \n",
    "def minmaxscale(tensor):\n",
    "    \"\"\"Linearly rescale `tensor` to [0, 1] (min -> 0, max -> 1).\n",
    "\n",
    "    A constant input (max == min) maps to all zeros instead of NaN from 0 / 0.\n",
    "    \"\"\"\n",
    "    shifted = tensor - tensor.min()\n",
    "    value_range = shifted.max()\n",
    "    return shifted / (value_range if value_range != 0 else 1)\n",
    "\n",
    "from torch.utils.data import DataLoader, SequentialSampler\n",
    "from torch.utils.data import DataLoader\n",
    "\n",
    "def unshuffle_dataloader(dataloader):\n",
    "    \"\"\"Return a non-shuffling DataLoader over the ImageFolder underlying `dataloader`.\n",
    "\n",
    "    Wrapper datasets (anything with a `.dataset` attribute, e.g. torch Subset or an\n",
    "    augmentation wrapper) are peeled off, whatever their nesting depth, until an\n",
    "    ImageFolder is reached; if there is none, the innermost dataset is used.\n",
    "    Batch size, workers and the other loader settings are copied; sampler and\n",
    "    collate_fn are not. NOTE: unwrapping a Subset discards its index selection.\n",
    "    \"\"\"\n",
    "    dataset = dataloader.dataset\n",
    "    while not isinstance(dataset, ImageFolder) and hasattr(dataset, 'dataset'):\n",
    "        dataset = dataset.dataset\n",
    "    return DataLoader(\n",
    "        dataset=dataset,\n",
    "        batch_size=dataloader.batch_size,\n",
    "        shuffle=False,  # deterministic iteration order\n",
    "        num_workers=dataloader.num_workers,\n",
    "        pin_memory=dataloader.pin_memory,\n",
    "        drop_last=dataloader.drop_last,\n",
    "        timeout=dataloader.timeout,\n",
    "        worker_init_fn=dataloader.worker_init_fn,\n",
    "        multiprocessing_context=dataloader.multiprocessing_context,\n",
    "        generator=dataloader.generator,\n",
    "        prefetch_factor=dataloader.prefetch_factor,\n",
    "        persistent_workers=dataloader.persistent_workers\n",
    "    )\n",
    "\n",
    "\n",
    "def get_heatmap(latent_activation, input_image, constant_color_scale=False):\n",
    "    \"\"\"Overlay a 'jet' heatmap of one latent activation map on an image.\n",
    "\n",
    "    latent_activation: tensor whose first entry along dim 0 is a 2-D map, e.g. shape (1, H', W');\n",
    "        values are scaled by 255 into uint8, so they are presumably in [0, 1].\n",
    "    input_image: (C, H, W) float tensor in [0, 1]; C may be 1 (grayscale), 3 (RGB) or 4 (RGBA).\n",
    "    constant_color_scale: if True, colours are relative to the fixed uint8 range [0, 255]\n",
    "        instead of this map's own min/max, so maps are comparable across images.\n",
    "    Returns an (H, W, 3) uint8 array: 70% image blended with 30% heatmap.\n",
    "    \"\"\"\n",
    "    image_a = latent_activation.cpu().numpy()\n",
    "    image_b = input_image.permute(1, 2, 0).cpu().numpy()\n",
    "    height, width = input_image.shape[-2], input_image.shape[-1]\n",
    "    # PIL's resize() takes (width, height)\n",
    "    reshaped_image_a = np.array(Image.fromarray((image_a[0] * 255).astype('uint8')).resize((width, height)))\n",
    "    if constant_color_scale:\n",
    "        # Append one column of 0 and one of 255 so the min/max normalisation below always spans\n",
    "        # the full uint8 range; both columns are cut off again after colouring.\n",
    "        num_rows = reshaped_image_a.shape[0]\n",
    "        reshaped_image_a = np.concatenate((reshaped_image_a, np.zeros((num_rows, 1)), np.ones((num_rows, 1))*255), axis=1)\n",
    "    value_min = np.min(reshaped_image_a)\n",
    "    value_range = np.max(reshaped_image_a) - value_min\n",
    "    # A constant map would give 0 / 0 = NaN; map it to all zeros instead\n",
    "    normalized_heatmap = (reshaped_image_a - value_min) / (value_range if value_range > 0 else 1)\n",
    "    heatmap_colormap = plt.get_cmap('jet')\n",
    "    heatmap_colored = heatmap_colormap(normalized_heatmap)\n",
    "    if constant_color_scale:\n",
    "        heatmap_colored = heatmap_colored[:, :-2]\n",
    "    heatmap_colored_uint8 = (heatmap_colored[:, :, :3] * 255).astype(np.uint8)\n",
    "    image_a_heatmap_pillow = Image.fromarray(heatmap_colored_uint8)\n",
    "    if image_b.shape[-1] == 1:\n",
    "        # Image.fromarray cannot handle (H, W, 1); replicate grayscale to 3 channels\n",
    "        image_b = np.repeat(image_b, 3, axis=-1)\n",
    "    # convert('RGB') drops an alpha channel so both blend inputs have the same mode\n",
    "    image_b_pillow = Image.fromarray((image_b * 255).astype('uint8')).convert('RGB')\n",
    "    result_image = Image.blend(image_b_pillow, image_a_heatmap_pillow, alpha=0.3)\n",
    "    return np.array(result_image)\n",
    "\n",
    "\n",
    "def get_heap():\n",
    "    \"\"\"Return a new, empty list for use as a heapq min-heap (e.g. as a defaultdict factory).\"\"\"\n",
    "    # An empty list already satisfies the heap invariant, so no heapify() call is needed.\n",
    "    return []"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "if torch.cuda.is_available():\n",
    "    device = torch.device('cuda')\n",
    "    device_ids = [torch.cuda.current_device()]\n",
    "else:\n",
    "    device = torch.device('cpu')\n",
    "    device_ids = []\n",
    "\n",
    "# Training arguments saved with the run; `with` closes the file once they are read.\n",
    "# NOTE(review): pickle.load can execute arbitrary code - only open runs from trusted sources.\n",
    "with open(os.path.join(run_path, 'metadata', 'args.pickle'), 'rb') as args_file:\n",
    "    args = pickle.load(args_file)\n",
    "\n",
    "# ------------ Load the phylogeny tree ------------\n",
    "phylo_config = OmegaConf.load(args.phylo_config)\n",
    "if phylo_config.phyloDistances_string == 'None':\n",
    "    root = construct_phylo_tree(phylo_config.phylogeny_path)\n",
    "    print('-'*25 + ' No discretization ' + '-'*25)\n",
    "else:\n",
    "    root = construct_discretized_phylo_tree(phylo_config.phylogeny_path, phylo_config.phyloDistances_string)\n",
    "    print('-'*25 + ' Discretized ' + '-'*25)\n",
    "root.assign_all_descendents()\n",
    "for node in root.nodes_with_children():\n",
    "    node.set_num_protos(num_protos_per_descendant=args.num_protos_per_descendant,\n",
    "                        num_protos_per_child=args.num_protos_per_child,\n",
    "                        min_protos_per_child=args.min_protos_per_child)\n",
    "\n",
    "# ------------ Load the train and test datasets ------------\n",
    "# One image per batch: the visualization cell maps each batch index to a single image path.\n",
    "args.batch_size = 1\n",
    "trainloader, trainloader_pretraining, trainloader_normal, trainloader_normal_augment, projectloader, testloader, test_projectloader, classes = get_dataloaders(args, device)\n",
    "\n",
    "# ------------ Load the model checkpoint ------------\n",
    "ckpt_file_name = 'net_trained_last'\n",
    "epoch = ckpt_file_name.split('_')[-1]\n",
    "ckpt_path = os.path.join(run_path, 'checkpoints', ckpt_file_name)\n",
    "checkpoint = torch.load(ckpt_path, map_location=device)\n",
    "feature_net, add_on_layers, pool_layer, classification_layers = get_network(args, root=root)\n",
    "net = HComPNet(feature_net = feature_net,\n",
    "                args = args,\n",
    "                add_on_layers = add_on_layers,\n",
    "                pool_layer = pool_layer,\n",
    "                classification_layers = classification_layers,\n",
    "                num_parent_nodes = len(root.nodes_with_children()),\n",
    "                root = root)\n",
    "net = net.to(device=device)\n",
    "net = nn.DataParallel(net, device_ids = device_ids)\n",
    "net.load_state_dict(checkpoint['model_state_dict'],strict=True)\n",
    "# This notebook only runs inference: eval() switches layers such as dropout and\n",
    "# normalisation to their inference behaviour (torch.no_grad alone does not do that).\n",
    "net.eval()\n",
    "\n",
    "# Forward one batch through the backbone to get the latent output size\n",
    "with torch.no_grad():\n",
    "    xs1, _, _ = next(iter(trainloader))\n",
    "    xs1 = xs1.to(device)\n",
    "    _, proto_features, _, _ = net(xs1)\n",
    "    wshape = proto_features['root'].shape[-1]\n",
    "    args.wshape = wshape #needed for calculating image patch size\n",
    "    print(\"Output shape: \", proto_features['root'].shape, flush=True)\n",
    "print(args.wshape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Find subtree root (optional: this cell does not affect later cells - copy the node name it prints into the visualization cell below)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional: to visualize only part of the tree, list a few of its leaf classes here.\n",
    "# Among root.nodes_with_children(), the cell picks the node with the fewest leaf descendents that still contains all of them.\n",
    "leaf_descendents = {'cub_052_Pied_billed_Grebe', 'cub_004_Groove_billed_Ani'}\n",
    "candidate_nodes = [node for node in root.nodes_with_children()\n",
    "                   if leaf_descendents.issubset(node.leaf_descendents)]\n",
    "# min() returns the first node of minimal size, matching a strict '<' scan that starts from root\n",
    "subtree_root = min([root, *candidate_nodes], key=lambda node: len(node.leaf_descendents))\n",
    "print(subtree_root.name)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Save TopK Visualizations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Proto activations on leaf descendents - topk images\n",
    "vizloader_name = 'testloader' # projectloader\n",
    "find_non_descendants = False # if True gets topk images from contrasting set (images that are not descendants of a given node)\n",
    "topk = 3\n",
    "save_images = True # True, False\n",
    "subtree_root = root.get_node('024+051') # Plotting for a subtree since the whole tree is too large\n",
    "\n",
    "# Rendering settings\n",
    "PATCH_SIZE = 64      # side (px) of the crop centred on the strongest activation\n",
    "PATCH_ZOOM = 1.7     # upscaling factor of that crop (70% increase)\n",
    "FONT_PATH = 'assets/arial.ttf'\n",
    "FONT_SIZE = 40\n",
    "\n",
    "# Visualization is inference only: put dropout / normalisation layers into eval mode\n",
    "net.eval()\n",
    "\n",
    "patchsize, skip = get_patch_size(args)\n",
    "vizloader_dict = {'trainloader': trainloader,\n",
    "                 'projectloader': projectloader,\n",
    "                 'testloader': testloader,\n",
    "                 'test_projectloader': test_projectloader}\n",
    "vizloader_dict[vizloader_name] = unshuffle_dataloader(vizloader_dict[vizloader_name])\n",
    "\n",
    "suffix = 'contrasting_set' if find_non_descendants else ''\n",
    "viz_save_dir = os.path.join(run_path, f'topk_viz_{suffix}_root={subtree_root.name}')\n",
    "\n",
    "# Leaf class names of the unshuffled visualization dataset; identical for every node, so computed once\n",
    "name2label = vizloader_dict[vizloader_name].dataset.class_to_idx\n",
    "label2name = {label: name for name, label in name2label.items()}\n",
    "\n",
    "fnt = ImageFont.truetype(FONT_PATH, FONT_SIZE)\n",
    "text_measurer = ImageDraw.Draw(Image.new(\"RGB\", (100, 100), (255, 0, 0)))  # only used to measure text widths\n",
    "\n",
    "\n",
    "def _draw_box(img, box, width, color):\n",
    "    \"\"\"Draw one (x_min, y_min, x_max, y_max) rectangle on a [0, 1] float CHW image; returns a new float image.\"\"\"\n",
    "    boxed = torchvision.utils.draw_bounding_boxes((img * 255).to(torch.uint8), torch.tensor([box]), width=width, colors=color)\n",
    "    return boxed.float() / 255.\n",
    "\n",
    "\n",
    "def _zoom_crop(crop, color):\n",
    "    \"\"\"Upscale a CHW crop by PATCH_ZOOM (bilinear) and frame it with a 4 px border of `color`.\"\"\"\n",
    "    zoomed = F.interpolate(crop.unsqueeze(0), scale_factor=PATCH_ZOOM, mode='bilinear', align_corners=False).squeeze(0)\n",
    "    return _draw_box(zoomed, [0, 0, zoomed.shape[2], zoomed.shape[1]], 4, color)\n",
    "\n",
    "\n",
    "def _render_topk_cell(img_path, latent_activation):\n",
    "    \"\"\"Render one top-k hit as [heatmap overlay | image | zoomed image crop above zoomed heatmap crop].\"\"\"\n",
    "    # convert('RGB') gives grayscale / palette images 3 channels; `with` closes the file handle\n",
    "    with Image.open(img_path) as pil_img:\n",
    "        image = transforms.Resize(size=(args.image_size, args.image_size))(pil_img.convert('RGB'))\n",
    "    img_tensor = transforms.ToTensor()(image)\n",
    "    overlayed_image = torch.tensor(get_heatmap(latent_activation, img_tensor, constant_color_scale=True)).permute(2, 0, 1).float() / 255.\n",
    "    img_h, img_w = img_tensor.shape[1], img_tensor.shape[2]\n",
    "\n",
    "    # Centre the crop on the argmax of the activation map upsampled to image resolution\n",
    "    activation_uint8 = (latent_activation.cpu().numpy()[0] * 255).astype('uint8')\n",
    "    reshaped_latent_activation = np.array(Image.fromarray(activation_uint8).resize((img_w, img_h)))\n",
    "    center = np.unravel_index(np.argmax(reshaped_latent_activation), reshaped_latent_activation.shape)\n",
    "    h_coor_min = int(max(0, center[0] - (PATCH_SIZE/2.)))\n",
    "    h_coor_max = int(min(img_h, center[0] + (PATCH_SIZE/2.)))\n",
    "    w_coor_min = int(max(0, center[1] - (PATCH_SIZE/2.)))\n",
    "    w_coor_max = int(min(img_w, center[1] + (PATCH_SIZE/2.)))\n",
    "\n",
    "    resized_img_patch = _zoom_crop(img_tensor[:, h_coor_min:h_coor_max, w_coor_min:w_coor_max], (255, 255, 0))\n",
    "    resized_heatmap_patch = _zoom_crop(overlayed_image[:, h_coor_min:h_coor_max, w_coor_min:w_coor_max], (255, 0, 0))\n",
    "    resized_patch = torchvision.utils.make_grid([resized_img_patch, resized_heatmap_patch], nrow=1, padding=1, pad_value=1.)\n",
    "\n",
    "    # Paste the stacked crops onto a white canvas: vertically centred, 10 px from the left\n",
    "    white_image = torch.ones(3, img_h, img_w)\n",
    "    patch_height, patch_width = resized_patch.shape[1], resized_patch.shape[2]\n",
    "    y_start = (img_h - patch_height) // 2\n",
    "    x_start = 10\n",
    "    white_image[:, y_start:y_start+patch_height, x_start:x_start+patch_width] = resized_patch\n",
    "\n",
    "    # Outline the cropped region on both full-size images\n",
    "    box = [w_coor_min, h_coor_min, w_coor_max, h_coor_max]\n",
    "    img_tensor = _draw_box(img_tensor, box, 2, (255, 255, 0))\n",
    "    overlayed_image = _draw_box(overlayed_image, box, 2, (255, 0, 0))\n",
    "    return torchvision.utils.make_grid([overlayed_image, img_tensor, white_image], nrow=3, padding=5, pad_value=1.)\n",
    "\n",
    "\n",
    "def _render_label(species_name, text_width, height):\n",
    "    \"\"\"Render `species_name` one word per line, centred on a white (ceil(text_width)+10) x height image.\"\"\"\n",
    "    txtimage = Image.new(\"RGB\", (math.ceil(text_width) + 10, height), (255, 255, 255))\n",
    "    ImageDraw.Draw(txtimage).multiline_text((txtimage.width/2, txtimage.height/2), '\\n'.join(species_name.split(' ')),\n",
    "                                            font=fnt, fill=(0, 0, 0), align=\"center\", anchor=\"mm\")\n",
    "    return transforms.ToTensor()(txtimage)\n",
    "\n",
    "\n",
    "for node in root.nodes_with_children():\n",
    "    if (node.name not in subtree_root.descendents) and (node.name != subtree_root.name):\n",
    "        print('Skipping node', node.name)\n",
    "        continue\n",
    "\n",
    "    modifiedLabelLoader = ModifiedLabelLoader(vizloader_dict[vizloader_name], node)\n",
    "    coarse_label2name = modifiedLabelLoader.modifiedlabel2name\n",
    "    node_label_to_children = {label: name for name, label in node.children_to_labels.items()}\n",
    "    imgs = modifiedLabelLoader.filtered_imgs\n",
    "    img_iter = tqdm(enumerate(modifiedLabelLoader),\n",
    "                    total=len(modifiedLabelLoader),\n",
    "                    mininterval=50.,\n",
    "                    desc='Collecting topk',\n",
    "                    ncols=0)\n",
    "\n",
    "    classification_weights = getattr(net.module, '_'+node.name+'_classification').weight\n",
    "\n",
    "    # The weights are fixed, so the children each prototype is relevant to (weight > 1e-3)\n",
    "    # are computed once per node instead of once per image and prototype.\n",
    "    relevant_names_per_proto = []\n",
    "    for p in range(classification_weights.shape[1]):\n",
    "        relevant_idx = torch.nonzero(classification_weights[:, p] > 1e-3).flatten().tolist()\n",
    "        relevant_names_per_proto.append([node_label_to_children[idx] for idx in relevant_idx])\n",
    "\n",
    "    # maps proto_number -> leaf descendant name -> min-heap of the top-k activations\n",
    "    proto_mean_activations = defaultdict(lambda: defaultdict(get_heap))\n",
    "    # maps ', '-joined child names to the prototypes relevant to them\n",
    "    class_and_prototypes = defaultdict(set)\n",
    "\n",
    "    for i, (xs, orig_y, ys) in img_iter:\n",
    "        xs, ys = xs.to(device), ys.to(device)\n",
    "\n",
    "        with torch.no_grad():\n",
    "            _, softmaxes, pooled, _ = net(xs, inference=False)\n",
    "            pooled = pooled[node.name].squeeze(0)  # one pooled value per prototype\n",
    "            softmaxes = softmaxes[node.name]  # (batch=1, num_prototypes, H, W)\n",
    "\n",
    "            # Spatial location of each prototype's maximum; independent of p, so computed once per image\n",
    "            max_per_prototype, _ = torch.max(softmaxes, dim=0)\n",
    "            max_per_prototype_h, max_idx_per_prototype_h = torch.max(max_per_prototype, dim=1)\n",
    "            _, max_idx_per_prototype_w = torch.max(max_per_prototype_h, dim=1)\n",
    "            w_indices = max_idx_per_prototype_w.tolist()\n",
    "            h_indices = max_idx_per_prototype_h.gather(1, max_idx_per_prototype_w.unsqueeze(1)).squeeze(1).tolist()\n",
    "\n",
    "            # Move to the CPU once: avoids a device sync per prototype, and heap entries keep small\n",
    "            # (1, H, W) copies instead of views that would pin this image's whole softmax tensor.\n",
    "            softmaxes_cpu = softmaxes.cpu()\n",
    "            pooled_values = pooled.tolist()\n",
    "            child_name = coarse_label2name[ys.item()]\n",
    "            leaf_descendent = label2name[orig_y.item()]\n",
    "            img_to_open = imgs[i][0] # it is a tuple of (path to image, lable)\n",
    "\n",
    "            for p, proto_activation in enumerate(pooled_values):\n",
    "                relevant_proto_class_names = relevant_names_per_proto[p]\n",
    "                if not relevant_proto_class_names:\n",
    "                    continue\n",
    "\n",
    "                h_coor_min, h_coor_max, w_coor_min, w_coor_max = get_img_coordinates(\n",
    "                    args.image_size, softmaxes.shape, patchsize, skip, h_indices[p], w_indices[p])\n",
    "\n",
    "                # Keep images of children this prototype is relevant to or, for the contrasting\n",
    "                # set, images of the children it is NOT relevant to.\n",
    "                if (child_name in relevant_proto_class_names) != find_non_descendants:\n",
    "                    heap = proto_mean_activations[p][leaf_descendent]\n",
    "                    entry = (proto_activation, img_to_open,\n",
    "                             (h_coor_min, h_coor_max, w_coor_min, w_coor_max),\n",
    "                             softmaxes_cpu[:, p, :, :].clone())\n",
    "                    if topk and len(heap) >= topk:\n",
    "                        heapq.heappushpop(heap, entry)\n",
    "                    else:\n",
    "                        heapq.heappush(heap, entry)\n",
    "\n",
    "                class_and_prototypes[', '.join(relevant_proto_class_names)].add(p)\n",
    "\n",
    "    print('Node', node.name)\n",
    "    for child_classname in class_and_prototypes:\n",
    "        print('\\t'*1, 'Child:', child_classname)\n",
    "        for p in class_and_prototypes[child_classname]:\n",
    "            logstr = '\\t'*2 + f'Proto:{p} '\n",
    "            for leaf_descendent, heap in proto_mean_activations[p].items():\n",
    "                mean_activation = round(np.mean([activation for activation, *_ in heap]), 4)\n",
    "                logstr += f'{leaf_descendent}:({mean_activation}) '\n",
    "            print(logstr)\n",
    "\n",
    "            if len(proto_mean_activations[p]) == 0 or not save_images:\n",
    "                continue\n",
    "\n",
    "            # Widest word over all leaf names, so every label column has the same width\n",
    "            max_width = text_measurer.textlength('-', font=fnt)\n",
    "            for leaf_descendent in proto_mean_activations[p]:\n",
    "                for word in leaf_descendent.split('_')[2:]:\n",
    "                    max_width = max(max_width, text_measurer.textlength(word, font=fnt))\n",
    "\n",
    "            # One row per leaf: its top-k hits, highest activation first\n",
    "            leaf_rows = []\n",
    "            for leaf_descendent, heap in proto_mean_activations[p].items():\n",
    "                if 'BUT' in args.dataset:\n",
    "                    species_name = ' '.join(leaf_descendent.split('_')[2:4])\n",
    "                else:\n",
    "                    species_name = ' '.join(leaf_descendent.split('_')[2:])\n",
    "                cells = [_render_topk_cell(img_path, latent_activation)\n",
    "                         for _, img_path, _, latent_activation in sorted(heap, reverse=True)]\n",
    "                leaf_rows.append((species_name, cells))\n",
    "\n",
    "            # Leaves with fewer hits than the longest row get blank cells, because make_grid\n",
    "            # needs all rows to have the same width\n",
    "            row_len = max(len(cells) for _, cells in leaf_rows)\n",
    "            cell_height = leaf_rows[0][1][0].shape[1]\n",
    "            padding = 0\n",
    "            grid_rows = []\n",
    "            for species_name, cells in leaf_rows:\n",
    "                cells = cells + [torch.ones_like(cells[0])] * (row_len - len(cells))\n",
    "                grid_row = torchvision.utils.make_grid(cells, nrow=row_len, padding=padding)\n",
    "                grid_right_description = torchvision.utils.make_grid(_render_label(species_name, max_width, cell_height), nrow=1, padding=padding)\n",
    "                grid_rows.append(torch.cat([grid_right_description, grid_row], dim=-1))\n",
    "            grid = torchvision.utils.make_grid(grid_rows, nrow=1, padding=5, pad_value=1.)\n",
    "\n",
    "            os.makedirs(os.path.join(viz_save_dir, node.name), exist_ok=True)\n",
    "            torchvision.utils.save_image(grid, os.path.join(viz_save_dir, node.name, f'{child_classname}-p{p}.png'))\n",
    "\n",
    "print('Done !!!')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "hpnet4",
   "language": "python",
   "name": "hpnet4"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
