{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "4d879b68",
   "metadata": {},
   "source": [
    "# Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "434ada53",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os, json\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import matplotlib.pyplot as plt\n",
    "from PIL import Image\n",
    "import pandas as pd\n",
    "import wandb\n",
    "from pathlib import Path\n",
    "import resnet50_lightning as rn\n",
    "from datasets import load_dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1ad4f74d",
   "metadata": {},
   "source": [
    "# Constants"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5f57939b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run on the GPU when PyTorch reports one as available; otherwise use the CPU.\n",
    "DEVICE = \"cpu\"\n",
    "if torch.cuda.is_available():\n",
    "    DEVICE = \"cuda\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cef8fdb7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hugging Face dataset repository handed to the data module.\n",
    "HF_REPO = \"hamzamooraj99/AgriPath-LF16-30k\"\n",
    "\n",
    "# W&B artifact reference used to download the trained checkpoints.\n",
    "ARTIFACT_PATH = \"hhm2000-heriot-watt-university/AgriPath-VLM/combined_cnn_paths:v0\"\n",
    "\n",
    "# NOTE(review): appears unused in the visible cells - confirm before removing.\n",
    "CLASS_NAMES_JSON = None\n",
    "\n",
    "# Number of highest-probability classes reported per prediction.\n",
    "TOPK = 5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3bbbd09f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Authenticate with Weights & Biases before the model artifact is requested.\n",
    "wandb.login()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "766ae35e",
   "metadata": {},
   "source": [
    "# Load Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a41e6865",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the data module with single-image batches and prepare its splits.\n",
    "datamodule = rn.AgriPathDataModule(HF_REPO, batch_size=1)\n",
    "datamodule.setup()\n",
    "\n",
    "# `idx_label` maps a class index to its label name (it is indexed that way at\n",
    "# prediction time); `label_idx` is presumably the reverse lookup - TODO confirm.\n",
    "label_idx, idx_label = datamodule.return_labels()\n",
    "\n",
    "# Reuse the data module's preprocessing for inference inputs.\n",
    "transform = datamodule.transform\n",
    "\n",
    "num_classes = len(idx_label)\n",
    "print(\"Num classes:\", num_classes)\n",
    "print(\"Example labels:\", list(idx_label.items())[:5])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3892ee76",
   "metadata": {},
   "source": [
    "# Load CNN Model from W&B"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "06eabd2c",
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_cnn_model(artifact_path, exp_idx=0, num_classes=65):\n",
    "    \"\"\"Download a W&B model artifact and load one experiment's ResNet-50 weights.\n",
    "\n",
    "    Args:\n",
    "        artifact_path: W&B artifact reference passed to `wandb.Api().artifact`.\n",
    "        exp_idx: Which experiment checkpoint to load; selects the file\n",
    "            `resnet50_agripath_exp_<exp_idx>.pth` inside the artifact\n",
    "            (the original author's note says the artifact holds 9 experiments).\n",
    "        num_classes: Size of the classification head. It must match the\n",
    "            checkpoint being loaded; defaults to 65 for backward compatibility.\n",
    "\n",
    "    Returns:\n",
    "        The model moved to `DEVICE` and switched to eval mode.\n",
    "\n",
    "    Raises:\n",
    "        FileNotFoundError: If the artifact contains no checkpoint for `exp_idx`.\n",
    "    \"\"\"\n",
    "    artifact = wandb.Api().artifact(artifact_path, type=\"model\")\n",
    "    artifact_dir = Path(artifact.download())\n",
    "\n",
    "    checkpoint_path = artifact_dir / f\"resnet50_agripath_exp_{exp_idx}.pth\"\n",
    "    if not checkpoint_path.is_file():\n",
    "        # Fail early with the list of usable checkpoints instead of a bare path error.\n",
    "        available = sorted(p.name for p in artifact_dir.glob(\"*.pth\"))\n",
    "        raise FileNotFoundError(\n",
    "            f\"No checkpoint for exp_idx={exp_idx} at {checkpoint_path}; available: {available}\"\n",
    "        )\n",
    "\n",
    "    # The learning rate is presumably only relevant for training; it is passed\n",
    "    # unchanged so the model is constructed exactly as before.\n",
    "    model = rn.ResNet50TLModel(num_classes=num_classes, learning_rate=1e-4)\n",
    "    state_dict = torch.load(checkpoint_path, map_location=torch.device(DEVICE), weights_only=True)\n",
    "    model.load_state_dict(state_dict)\n",
    "    model.to(DEVICE)\n",
    "    model.eval()\n",
    "\n",
    "    return model\n",
    "\n",
    "\n",
    "# Size the head from the dataset labels so predicted indices always map into `idx_label`.\n",
    "model = load_cnn_model(ARTIFACT_PATH, exp_idx=0, num_classes=num_classes)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "242c2fae",
   "metadata": {},
   "source": [
    "# Single Inference"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fcb4e583",
   "metadata": {},
   "source": [
    "## Load Image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ef1ff12a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pull one example from the test split and unpack the fields used below.\n",
    "sample = datamodule.test_set[1347]\n",
    "image, crop, disease = sample['image'], sample['crop'], sample['disease']\n",
    "crop_disease = sample['crop_disease_label']\n",
    "\n",
    "print(f\"CROP: {crop}\")\n",
    "print(f\"DISEASE: {disease}\")\n",
    "display(image)\n",
    "print(crop_disease)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3e5a29c5",
   "metadata": {},
   "source": [
    "## Inference"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6ec4a4a0",
   "metadata": {},
   "outputs": [],
   "source": [
    "def predict_sample(sample, model, transform, idx_label, topk=5):\n",
    "    \"\"\"Classify a single sample and tabulate its highest-probability classes.\n",
    "\n",
    "    Args:\n",
    "        sample: Mapping whose 'image' entry is converted to RGB before preprocessing.\n",
    "        model: Callable applied to the preprocessed one-image batch to produce logits.\n",
    "        transform: Preprocessing applied to the RGB image; its result gets a leading\n",
    "            batch dimension and is moved to `DEVICE`.\n",
    "        idx_label: Lookup from integer class index to label name.\n",
    "        topk: Maximum number of classes to return; capped at the number of classes.\n",
    "\n",
    "    Returns:\n",
    "        A DataFrame with a \"class\" column (label names) and a \"prob\" column\n",
    "        (softmax probabilities), in the order produced by `torch.topk`.\n",
    "    \"\"\"\n",
    "    rgb_image = sample['image'].convert(\"RGB\")\n",
    "    batch = transform(rgb_image).unsqueeze(0).to(DEVICE)\n",
    "\n",
    "    # No autograd bookkeeping is needed for a forward-only pass.\n",
    "    with torch.inference_mode():\n",
    "        scores = model(batch)\n",
    "        probs = nn.functional.softmax(scores, dim=1)[0]\n",
    "\n",
    "    n_best = min(topk, probs.numel())\n",
    "    best = torch.topk(probs, k=n_best)\n",
    "\n",
    "    return pd.DataFrame({\n",
    "        \"class\": [idx_label[int(class_idx)] for class_idx in best.indices],\n",
    "        \"prob\": best.values.cpu().numpy()\n",
    "    })"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e8868e9d",
   "metadata": {},
   "outputs": [],
   "source": [
    "df = predict_sample(sample, model, transform, idx_label, TOPK)\n",
    "\n",
    "# Format only the displayed copy. `df[\"prob\"]` must stay numeric: overwriting it\n",
    "# with strings makes the bar chart further down treat the heights as categories.\n",
    "df.assign(prob=df[\"prob\"].map(lambda p: f\"{p:.6f}\"))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e448a0af",
   "metadata": {},
   "source": [
    "# Plot Top-k Distribution"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ccaaa6f2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Coerce to float so bar heights are numeric even if `prob` holds formatted text.\n",
    "top_probs = df[\"prob\"].astype(float)\n",
    "positions = list(range(len(df)))\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(10, 5))\n",
    "ax.bar(positions, top_probs)\n",
    "ax.set_xticks(positions)\n",
    "ax.set_xticklabels(df[\"class\"], rotation=75, ha=\"right\")\n",
    "ax.set_ylabel(\"Probability\")\n",
    "# The second title line shows the sample's ground-truth label; the previous `{0}`\n",
    "# placeholder rendered a literal 0 (NOTE(review): confirm this is the intended text).\n",
    "ax.set_title(f\"Top-{len(df)} predictions\\nTrue label: {crop_disease}\")\n",
    "fig.tight_layout()\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "unsloth_env",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
