{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "import numpy as np\n",
    "import random\n",
    "import time\n",
    "from torchvision import transforms, datasets\n",
    "from torch.utils.data import DataLoader\n",
    "import PIL \n",
    "from tqdm import tqdm\n",
    "import matplotlib.pyplot as plt\n",
    "# tsne and pca\n",
    "from sklearn.manifold import TSNE\n",
    "from sklearn.decomposition import PCA\n",
    "from DeepTaxonNet import DeepTaxonNet\n",
    "import argparse\n",
    "import utils\n",
    "\n",
    "from sklearn.mixture import GaussianMixture\n",
    "import os\n",
    "import sys"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_loader, test_loader, train_set, test_set = utils.get_data_loader('fashion-mnist', 128, False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "# device = 'cpu'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_layers=10\n",
    "## for MNIST/Fashion-MNIST data\n",
    "model = DeepTaxonNet(\n",
    "    n_layers=n_layers,\n",
    "    enc_hidden_dim=32*3*3,\n",
    "    dec_hidden_dim=(32,3,3),\n",
    "    input_dim=1*28*28,\n",
    "    latent_dim=8,\n",
    "    encoder_name='mnist',\n",
    "    decoder_name='mnist',\n",
    "    kl1_weight=1\n",
    ").to(device)\n",
    "\n",
    "## For CIFAR data\n",
    "# model = DeepTaxonNet(\n",
    "#     n_layers=n_layers,\n",
    "#     enc_hidden_dim=256*4*4,\n",
    "#     dec_hidden_dim=(256,4,4),\n",
    "#     input_dim=3*32*32,\n",
    "#     latent_dim=64,\n",
    "#     encoder_name='resnet',\n",
    "#     decoder_name='resnet',\n",
    "#     kl1_weight=1\n",
    "# ).to(device)\n",
    "\n",
    "path = './models/'\n",
    "model_name = 'dtn-10-fmnist.pt'\n",
    "model.load_state_dict(torch.load(f'{path}{model_name}', map_location=device), strict=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Accuracy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "annotation = utils.label_annotation(model, train_loader, 10, device)\n",
    "acc = utils.basic_node_evaluation(model, annotation, test_loader, device)\n",
    "print('acc:', acc)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# NMI"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "nmi = utils.compute_nmi(model, annotation, test_loader, device)\n",
    "print(f\"MNI: {nmi}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# DP"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dendrogram_purity = utils.soft_dendrogram_purity(model, test_loader, device)\n",
    "print('dendrogram_purity:', dendrogram_purity)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# LP"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "overall_leaf_purity, per_leaf_purities = utils.leaf_purity(model, test_loader, device)\n",
    "print('overall_leaf_purity:', overall_leaf_purity)\n",
    "print('per_leaf_purities:', per_leaf_purities)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "dl",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
