{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9b0ba329",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.nn as nn\n",
    "import torch\n",
    "import sys, os\n",
    "import random\n",
    "import csv\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "from shutil import copy\n",
    "import matplotlib.pyplot as plt\n",
    "from copy import deepcopy\n",
    "from omegaconf import OmegaConf\n",
    "import shutil\n",
    "import pickle\n",
    "import random\n",
    "from PIL import Image\n",
    "from tqdm import tqdm\n",
    "from torchvision.datasets.folder import ImageFolder\n",
    "from torch.utils.data import DataLoader\n",
    "import torchvision.transforms as transforms\n",
    "import torchvision\n",
    "import torch.nn.functional as F\n",
    "import pdb\n",
    "from collections import defaultdict\n",
    "import ntpath\n",
    "\n",
    "from util.func import get_patch_size\n",
    "from hcompnet.model import HComPNet, get_network\n",
    "from util.log import Log\n",
    "from util.args import get_args, save_args, get_optimizer_nn\n",
    "from util.data import get_dataloaders\n",
    "from util.func import init_weights_xavier\n",
    "from util.node import Node\n",
    "from util.phylo_utils import construct_phylo_tree, construct_discretized_phylo_tree\n",
    "from util.func import get_patch_size\n",
    "from util.evaluation import get_topk_cub_nodewise, eval_prototypes_cub_parts_csv_nodewise_maxmin"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d7743f7e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Root directory of the experiment run to evaluate.\n",
    "# The model-loading cell reads 'metadata/args.pickle' and 'checkpoints/' from beneath this directory.\n",
    "run_path = 'runs/hcompnet_cub190_cnext26'"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6eba5e1c",
   "metadata": {},
   "source": [
    "# Load model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c79f958f",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# ------------ Select the compute device ------------\n",
    "if torch.cuda.is_available():\n",
    "    device = torch.device('cuda')\n",
    "    device_ids = [torch.cuda.current_device()]\n",
    "else:\n",
    "    device = torch.device('cpu')\n",
    "    device_ids = []\n",
    "\n",
    "# ------------ Restore the arguments saved with the run ------------\n",
    "# The context manager closes the file handle as soon as the arguments are read.\n",
    "# NOTE(review): pickle.load can execute arbitrary code; only point run_path at runs you trust.\n",
    "with open(os.path.join(run_path, 'metadata', 'args.pickle'), 'rb') as args_file:\n",
    "    args = pickle.load(args_file)\n",
    "\n",
    "# ------------ Load the phylogeny tree ------------\n",
    "phylo_config = OmegaConf.load(args.phylo_config)\n",
    "if phylo_config.phyloDistances_string == 'None':\n",
    "    root = construct_phylo_tree(phylo_config.phylogeny_path)\n",
    "    print('-'*25 + ' No discretization ' + '-'*25)\n",
    "else:\n",
    "    root = construct_discretized_phylo_tree(phylo_config.phylogeny_path, phylo_config.phyloDistances_string)\n",
    "    print('-'*25 + ' Discretized ' + '-'*25)\n",
    "root.assign_all_descendents()\n",
    "# Every internal node gets its prototype budget from the saved arguments\n",
    "for node in root.nodes_with_children():\n",
    "    node.set_num_protos(num_protos_per_descendant=args.num_protos_per_descendant,\n",
    "                        num_protos_per_child=args.num_protos_per_child,\n",
    "                        min_protos_per_child=args.min_protos_per_child)\n",
    "\n",
    "# ------------ Load the train and test datasets ------------\n",
    "(trainloader, trainloader_pretraining, trainloader_normal, trainloader_normal_augment,\n",
    " projectloader, testloader, test_projectloader, classes) = get_dataloaders(args, device)\n",
    "\n",
    "# ------------ Load the model checkpoint ------------\n",
    "ckpt_file_name = 'net_trained_last'\n",
    "# Suffix of the checkpoint name ('last' here); the part-purity cell uses it to label its outputs\n",
    "epoch = ckpt_file_name.split('_')[-1]\n",
    "ckpt_path = os.path.join(run_path, 'checkpoints', ckpt_file_name)\n",
    "checkpoint = torch.load(ckpt_path, map_location=device)\n",
    "feature_net, add_on_layers, pool_layer, classification_layers = get_network(args, root=root)\n",
    "net = HComPNet(feature_net = feature_net,\n",
    "                args = args,\n",
    "                add_on_layers = add_on_layers,\n",
    "                pool_layer = pool_layer,\n",
    "                classification_layers = classification_layers,\n",
    "                num_parent_nodes = len(root.nodes_with_children()),\n",
    "                root = root)\n",
    "net = net.to(device=device)\n",
    "# The state dict is loaded after wrapping, so its keys must match the DataParallel-wrapped module (strict=True)\n",
    "net = nn.DataParallel(net, device_ids = device_ids)\n",
    "net.load_state_dict(checkpoint['model_state_dict'],strict=True)\n",
    "# NOTE(review): net.eval() is never called in this notebook - confirm the evaluation helpers switch the mode themselves\n",
    "\n",
    "# ------------ Probe the latent output size ------------\n",
    "# Forward one training batch through the network to read the latent output size\n",
    "with torch.no_grad():\n",
    "    xs1, _, _ = next(iter(trainloader))\n",
    "    xs1 = xs1.to(device)\n",
    "    _, proto_features, _, _ = net(xs1)\n",
    "    wshape = proto_features['root'].shape[-1]\n",
    "    args.wshape = wshape #needed for calculating image patch size\n",
    "    print('Output shape: ', proto_features['root'].shape, flush=True)\n",
    "print(args.wshape)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "171aa813",
   "metadata": {},
   "source": [
    "# Helper functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b563cd4a",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_img_coordinates(img_size, softmaxes_shape, patchsize, skip, h_idx, w_idx):\n",
    "    \"\"\"Convert a latent grid location into the pixel bounds of its image patch.\n",
    "\n",
    "    Args:\n",
    "        img_size: Upper pixel bound, applied to both the height and the width axis.\n",
    "        softmaxes_shape: Shape of the latent map; index 1 is read as the latent height\n",
    "            and index 2 as the latent width.\n",
    "        patchsize: Side length of the image patch in pixels.\n",
    "        skip: Pixel stride between neighbouring latent locations.\n",
    "        h_idx: Row index in the latent grid (plain number or single-element tensor).\n",
    "        w_idx: Column index in the latent grid (plain number or single-element tensor).\n",
    "\n",
    "    Returns:\n",
    "        Tuple (h_coor_min, h_coor_max, w_coor_min, w_coor_max).\n",
    "    \"\"\"\n",
    "    if torch.is_tensor(w_idx):\n",
    "        w_idx = w_idx.item()\n",
    "    if torch.is_tensor(h_idx):\n",
    "        h_idx = h_idx.item()\n",
    "\n",
    "    latent_h, latent_w = softmaxes_shape[1], softmaxes_shape[2]\n",
    "\n",
    "    if latent_h == 26 and latent_w == 26:\n",
    "        # 26x26 latent output (convnext with smaller strides): the outer latent patches have a\n",
    "        # smaller receptive field, so the first and last patch are offset by 4 instead of a full skip.\n",
    "        last_idx = softmaxes_shape[-1] - 1\n",
    "\n",
    "        h_coor_min = max(0, (h_idx - 1) * skip + 4)\n",
    "        if h_idx >= last_idx:\n",
    "            h_coor_min -= 4\n",
    "        h_coor_max = h_coor_min + patchsize\n",
    "\n",
    "        w_coor_min = max(0, (w_idx - 1) * skip + 4)\n",
    "        if w_idx >= last_idx:\n",
    "            w_coor_min -= 4\n",
    "        w_coor_max = w_coor_min + patchsize\n",
    "    else:\n",
    "        h_coor_min = h_idx * skip\n",
    "        h_coor_max = min(img_size, h_coor_min + patchsize)\n",
    "        w_coor_min = w_idx * skip\n",
    "        w_coor_max = min(img_size, w_coor_min + patchsize)\n",
    "\n",
    "    # The last row / column always ends at the image border\n",
    "    if h_idx == latent_h - 1:\n",
    "        h_coor_max = img_size\n",
    "    if w_idx == latent_w - 1:\n",
    "        w_coor_max = img_size\n",
    "    # A patch that ends at the border is shifted back so that it keeps its full size\n",
    "    if h_coor_max == img_size:\n",
    "        h_coor_min = img_size - patchsize\n",
    "    if w_coor_max == img_size:\n",
    "        w_coor_min = img_size - patchsize\n",
    "\n",
    "    return h_coor_min, h_coor_max, w_coor_min, w_coor_max\n",
    "\n",
    "\n",
    "def unshuffle_dataloader(dataloader, batch_size=1):\n",
    "    \"\"\"Build a non-shuffling DataLoader that mirrors the settings of `dataloader`.\n",
    "\n",
    "    Args:\n",
    "        dataloader: Existing DataLoader. When its dataset is an ImageFolder (or a subclass)\n",
    "            that dataset is reused directly; otherwise the object stored in the dataset's\n",
    "            `.dataset` attribute is used.\n",
    "        batch_size: Batch size of the new loader.\n",
    "\n",
    "    Returns:\n",
    "        A new DataLoader created with shuffle=False. Worker, memory-pinning, drop_last,\n",
    "        timeout, generator and prefetch settings are copied from `dataloader`; its sampler\n",
    "        and collate_fn are not.\n",
    "    \"\"\"\n",
    "    # isinstance also accepts ImageFolder subclasses, which an exact type comparison\n",
    "    # would send down the wrapper branch.\n",
    "    if isinstance(dataloader.dataset, ImageFolder):\n",
    "        dataset = dataloader.dataset\n",
    "    else:\n",
    "        # NOTE(review): assumes the loader's dataset wraps the image dataset in `.dataset` - confirm against util.data\n",
    "        dataset = dataloader.dataset.dataset#.dataset\n",
    "    new_dataloader = DataLoader(\n",
    "        dataset=dataset,\n",
    "        batch_size=batch_size,\n",
    "        shuffle=False,\n",
    "        num_workers=dataloader.num_workers,\n",
    "        pin_memory=dataloader.pin_memory,\n",
    "        drop_last=dataloader.drop_last,\n",
    "        timeout=dataloader.timeout,\n",
    "        worker_init_fn=dataloader.worker_init_fn,\n",
    "        multiprocessing_context=dataloader.multiprocessing_context,\n",
    "        generator=dataloader.generator,\n",
    "        prefetch_factor=dataloader.prefetch_factor,\n",
    "        persistent_workers=dataloader.persistent_workers\n",
    "    )\n",
    "    return new_dataloader"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ca970c1c-3abd-4443-9970-306dbcc7e988",
   "metadata": {},
   "source": [
    "# Find subtree root - lookup only; this does not affect the run. Use the node name printed here when setting `subtree_root` in the Hyperparameters cell"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a9c2d072-ddc6-46f6-afa6-bac63733f25d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# To restrict the analysis to a subtree, list some of its leaf nodes here.\n",
    "# The loop finds the root of the smallest subtree that contains all of them.\n",
    "leaf_descendents = {'cub_052_Pied_billed_Grebe', 'cub_004_Groove_billed_Ani'}\n",
    "\n",
    "subtree_root = root\n",
    "for node in root.nodes_with_children():\n",
    "    if not leaf_descendents.issubset(node.leaf_descendents):\n",
    "        continue\n",
    "    # keep the candidate only when it is strictly smaller than the best one so far\n",
    "    if len(node.leaf_descendents) < len(subtree_root.leaf_descendents):\n",
    "        subtree_root = node\n",
    "print(subtree_root.name)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e73ed5a9-8381-4ab7-b990-df38778cbb13",
   "metadata": {},
   "source": [
    "# Hyperparameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "39cb500e-3df6-4be0-ab4d-c10dda5bd067",
   "metadata": {},
   "outputs": [],
   "source": [
    "TOPK = 10  # forwarded to get_topk_cub_nodewise in the part-purity cell\n",
    "# Dataloader to evaluate on. Alternatives: projectloader, trainloader_normal, trainloader_normal_augment, test_projectloader\n",
    "maindataloader = testloader\n",
    "# Nodes whose name is not in subtree_root.descendents are skipped in the part-purity cell;\n",
    "# replace root with e.g. root.get_node('024+051') to restrict the evaluation to a subtree\n",
    "subtree_root = root#.get_node('024+051')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7ad862fe-d979-43a3-b5ec-035e17a8cd9d",
   "metadata": {},
   "source": [
    "# Setup data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "85ab565e-c0c5-440a-a64e-3eece22651cb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ------------ Locations of the CUB metadata files ------------\n",
    "cub_meta_path = 'data/cub_meta/'  # adjust to where the CUB metadata lives\n",
    "NUM_PARTS = 15\n",
    "images_file = os.path.join(cub_meta_path, 'images_cub.txt')\n",
    "part_locs_file = os.path.join(cub_meta_path, 'parts', 'part_locs_normalized_after_cropped_after_padded.txt')\n",
    "parts_name_path = os.path.join(cub_meta_path, 'parts', 'parts.txt')\n",
    "# NOTE(review): this path is inside 'parts/', whereas images_file sits directly under cub_meta_path - confirm both exist\n",
    "imgs_id_path = os.path.join(cub_meta_path, 'parts', 'images_cub.txt')\n",
    "\n",
    "# ------------ Image file name (directory stripped) -> image index ------------\n",
    "img_filename_to_index = {}\n",
    "with open(images_file) as index_fh:\n",
    "    for record in index_fh:\n",
    "        index_str, rel_path = record.split()\n",
    "        img_filename_to_index[ntpath.basename(rel_path)] = int(index_str)\n",
    "\n",
    "# ------------ Image index -> list of (part_id, x, y, visible) ------------\n",
    "image_part_locs = defaultdict(list)\n",
    "with open(part_locs_file) as parts_fh:\n",
    "    for record in parts_fh:\n",
    "        fields = record.split()\n",
    "        img_idx, part_id = int(fields[0]), int(fields[1])\n",
    "        x_loc, y_loc = float(fields[2]), float(fields[3])\n",
    "        is_visible = bool(float(fields[4]))\n",
    "        image_part_locs[img_idx].append((part_id, x_loc, y_loc, is_visible))\n",
    "\n",
    "# ------------ Evaluate in a fixed order, one image per batch ------------\n",
    "maindataloader = unshuffle_dataloader(maindataloader, batch_size=1)\n",
    "print(maindataloader.batch_size)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "39585984-4500-4303-bfcd-0fc75bf7f93d",
   "metadata": {},
   "source": [
    "# Calculate part purity"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "03591379-002b-4f60-b622-e599c67087b7",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "list_csvfile_topk, list_node_wise_df, dict_node_wise_df = get_topk_cub_nodewise(\n",
    "    net, root, maindataloader, TOPK, str(epoch), device, args)\n",
    "\n",
    "node_wise_purity = []\n",
    "node_wise_purity_of_unmasked = []\n",
    "node_wise_purity_of_masked = []\n",
    "# NOTE(review): pairing by position assumes get_topk_cub_nodewise returns one csv per node,\n",
    "# in the order of root.nodes_with_children() - confirm\n",
    "for csvfile_topk, node in zip(list_csvfile_topk, root.nodes_with_children()):\n",
    "\n",
    "    if node.name not in subtree_root.descendents:\n",
    "        print('Skipping node', node.name)\n",
    "        continue\n",
    "\n",
    "    node_purity, max_presence_purity = eval_prototypes_cub_parts_csv_nodewise_maxmin(\n",
    "        node, csvfile_topk, part_locs_file, parts_name_path,\n",
    "        imgs_id_path, 'projectloader_topk_'+str(epoch), args, desc_threshold=0.2)\n",
    "    node_wise_purity.append(node_purity)\n",
    "\n",
    "    # Split the prototypes by their presence scores: column 1 larger -> unmasked (kept),\n",
    "    # column 0 larger -> masked. Ties belong to neither group.\n",
    "    proto_presence = getattr(net.module, '_'+node.name+'_proto_presence')\n",
    "    unmasked_purities, masked_purities = [], []\n",
    "    for p in max_presence_purity:\n",
    "        masked_score, unmasked_score = proto_presence[int(p), 0], proto_presence[int(p), 1]\n",
    "        if masked_score < unmasked_score:\n",
    "            unmasked_purities.append(max_presence_purity[p])\n",
    "        elif masked_score > unmasked_score:\n",
    "            masked_purities.append(max_presence_purity[p])\n",
    "\n",
    "    # An empty group is recorded as NaN explicitly; np.mean([]) yields the same value but\n",
    "    # also emits a 'Mean of empty slice' RuntimeWarning. The report cell uses nanmean/nanstd.\n",
    "    node_wise_purity_of_unmasked.append(np.mean(unmasked_purities) if unmasked_purities else np.nan)\n",
    "    node_wise_purity_of_masked.append(np.mean(masked_purities) if masked_purities else np.nan)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f8bc2613",
   "metadata": {},
   "source": [
    "### Part purity"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "364390d0-62f7-4a6a-880e-1290302d98a5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mean and standard deviation across nodes for each purity list; NaN entries are ignored\n",
    "purity_reports = (\n",
    "    ('Part purity including over-specific prototypes', node_wise_purity),\n",
    "    ('Part purity excluding over-specific prototypes', node_wise_purity_of_unmasked),\n",
    "    ('Part purity of excluded prototypes', node_wise_purity_of_masked),\n",
    ")\n",
    "for report_title, node_purities in purity_reports:\n",
    "    print(report_title)\n",
    "    print('Mean:', np.nanmean(node_purities), 'Std:', np.nanstd(node_purities))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e006c62c-793f-4250-a08f-73e6296e3264",
   "metadata": {},
   "source": [
    "# Ratio of good protos / Total protos"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3b42fb9b-b99c-4800-a94b-8b5d8893a01d",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Count, over all internal nodes, the classification weights above 1e-3 ('relevant' prototypes)\n",
    "# and how many of them remain after the masked prototypes are zeroed out ('good' prototypes).\n",
    "total_relevant_protos = 0.\n",
    "total_good_protos = 0.\n",
    "\n",
    "with torch.no_grad():\n",
    "    for node in root.nodes_with_children():\n",
    "        classification_weights = getattr(net.module, '_'+node.name+'_classification').weight\n",
    "        proto_presence = getattr(net.module, '_'+node.name+'_proto_presence')\n",
    "        # Deterministic keep-mask. F.gumbel_softmax(..., hard=True) adds random Gumbel noise before\n",
    "        # the argmax, so the reported ratio could change between runs; comparing the two presence\n",
    "        # scores directly is reproducible and matches the rule used in the part-purity cell.\n",
    "        # Assumes proto_presence has the two columns [masked, unmasked] - TODO confirm.\n",
    "        keep_mask = (proto_presence[:, 1] > proto_presence[:, 0]).to(classification_weights.dtype)\n",
    "        masked_classification_weights = keep_mask.unsqueeze(0) * classification_weights\n",
    "        # Summing over the whole weight matrix equals summing the per-class counts\n",
    "        total_relevant_protos += (classification_weights > 1e-3).sum().item()\n",
    "        total_good_protos += (masked_classification_weights > 1e-3).sum().item()\n",
    "\n",
    "# Report NaN instead of raising ZeroDivisionError when no weight exceeds the threshold\n",
    "good_proto_ratio = total_good_protos/total_relevant_protos if total_relevant_protos else float('nan')\n",
    "print('Total protos:', total_relevant_protos, 'Total good protos:', total_good_protos, 'Ratio:', good_proto_ratio)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "hpnet4",
   "language": "python",
   "name": "hpnet4"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
