{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "4d879b68",
   "metadata": {},
   "source": [
    "# Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "434ada53",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import os\n",
    "import warnings\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import pandas as pd\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import wandb\n",
    "from datasets import load_dataset\n",
    "from PIL import Image\n",
    "from transformers import AutoModel, AutoProcessor"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1ad4f74d",
   "metadata": {},
   "source": [
    "# Constants"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5f57939b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run on the GPU when torch can see one; otherwise fall back to the CPU.\n",
    "if torch.cuda.is_available():\n",
    "    DEVICE = \"cuda\"\n",
    "else:\n",
    "    DEVICE = \"cpu\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cef8fdb7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hugging Face checkpoint for the frozen vision backbone.\n",
    "CHECKPOINT = \"openai/clip-vit-large-patch14\"\n",
    "# Hugging Face dataset repo; only its test split is used here.\n",
    "HF_REPO = \"hamzamooraj99/AgriPath-LF16-30k\"\n",
    "# W&B artifact containing classifier_head.pt and metadata.json for the linear-probe head.\n",
    "HEAD_ARTIFACT = \"hhm2000-heriot-watt-university/AgriPath-VLM/CLIP_openai_large14_LR0.01:v0\"\n",
    "\n",
    "# How many of the highest-probability classes the single-image inference reports.\n",
    "TOPK = 5\n",
    "# NOTE(review): not referenced anywhere else in this notebook.\n",
    "CLASS_NAMES_JSON = None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3bbbd09f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Authenticate with Weights & Biases first; the public-API client created afterwards\n",
    "# is what downloads the classifier-head artifact further down.\n",
    "wandb.login()\n",
    "\n",
    "api = wandb.Api()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "766ae35e",
   "metadata": {},
   "source": [
    "# Load Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a41e6865",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the held-out test split; it supplies both the images and the label vocabulary.\n",
    "test_set = load_dataset(HF_REPO, split='test')\n",
    "class_labels = sorted(set(test_set[\"crop_disease_label\"]))\n",
    "num_classes = len(set(test_set[\"numeric_label\"]))\n",
    "\n",
    "# The string labels and the numeric ids must describe the same number of classes;\n",
    "# otherwise the id -> name mapping built from `class_labels` cannot line up with the head.\n",
    "if len(class_labels) != num_classes:\n",
    "    raise ValueError(\n",
    "        f\"Dataset has {len(class_labels)} distinct crop_disease_label values \"\n",
    "        f\"but {num_classes} distinct numeric_label values.\"\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "67f17f72",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Class ids follow the alphabetical order of the string labels.\n",
    "# NOTE(review): this must match the label order the head was trained with; the check\n",
    "# below warns if it disagrees with the dataset's own numeric_label column.\n",
    "label2id = {label: i for i, label in enumerate(class_labels)}\n",
    "id2label = {i: label for label, i in label2id.items()}\n",
    "\n",
    "label_pairs = set(zip(test_set[\"numeric_label\"], test_set[\"crop_disease_label\"]))\n",
    "mismatched_pairs = sorted((int(num), name) for num, name in label_pairs if label2id[name] != int(num))\n",
    "if mismatched_pairs:\n",
    "    warnings.warn(\n",
    "        f\"{len(mismatched_pairs)} numeric_label values disagree with the sorted-name ids \"\n",
    "        f\"(e.g. {mismatched_pairs[:3]}); predicted ids may map to the wrong class names.\"\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d63dbd44",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Spot-check the mapping. Id 31 is just an example; fall back to the smallest id when the\n",
    "# dataset has fewer classes so this cell cannot raise KeyError.\n",
    "example_id = 31 if 31 in id2label else min(id2label)\n",
    "print(\"num_classes:\", num_classes)\n",
    "print(\"Example label:\", id2label[example_id])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3892ee76",
   "metadata": {},
   "source": [
    "# Load Model & Head"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "06eabd2c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The backbone is a frozen feature extractor: eval mode, gradients off for every parameter.\n",
    "processor = AutoProcessor.from_pretrained(CHECKPOINT)\n",
    "backbone = AutoModel.from_pretrained(CHECKPOINT).to(DEVICE).eval()\n",
    "backbone.requires_grad_(False)\n",
    "\n",
    "# The inference cells need one of these two entry points to extract image features.\n",
    "assert any(hasattr(backbone, attr) for attr in (\"get_image_features\", \"vision_model\")), (\n",
    "    \"Backbone must expose get_image_features or vision_model (CLIP/SigLIP-like checkpoint).\"\n",
    ")\n",
    "\n",
    "print(type(backbone))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aca62d83",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Download the trained linear-probe head (weights + metadata) from W&B.\n",
    "artifact = api.artifact(HEAD_ARTIFACT, type=\"linear_probe_head\")\n",
    "artifact_dir = artifact.download()\n",
    "\n",
    "head_path = os.path.join(artifact_dir, \"classifier_head.pt\")\n",
    "meta_path = os.path.join(artifact_dir, \"metadata.json\")\n",
    "\n",
    "print(\"artifact_dir:\", artifact_dir)\n",
    "print(\"head_path exists:\", os.path.exists(head_path))\n",
    "print(\"meta_path exists:\", os.path.exists(meta_path))\n",
    "\n",
    "# Stop with a clear message if the artifact layout is not what this notebook expects.\n",
    "for required_path in (head_path, meta_path):\n",
    "    if not os.path.exists(required_path):\n",
    "        raise FileNotFoundError(f\"Expected file missing from artifact: {required_path}\")\n",
    "\n",
    "# The head is consumed as a state_dict (load_state_dict further down), so restrict\n",
    "# unpickling to tensors/primitive containers instead of running arbitrary pickled code\n",
    "# from a downloaded file.\n",
    "head_state = torch.load(head_path, map_location=\"cpu\", weights_only=True)\n",
    "with open(meta_path, \"r\") as f:\n",
    "    head_meta = json.load(f)\n",
    "\n",
    "feat_dim = int(head_meta[\"feature_dim\"])\n",
    "head_num_classes = int(head_meta[\"num_classes\"])\n",
    "print(\"feat_dim:\", feat_dim)\n",
    "print(\"head_num_classes:\", head_num_classes)\n",
    "print(\"head backbone_name:\", head_meta.get(\"backbone_name\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "981a62e2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The head's output width must equal the dataset's class count, otherwise predicted ids\n",
    "# would not correspond to entries of id2label.\n",
    "if num_classes != head_num_classes:\n",
    "    raise ValueError(f\"Head expects {head_num_classes} classes, but dataset has {num_classes}.\")\n",
    "\n",
    "# Build the head on the CPU (where head_state was loaded), copy the weights in, then move it.\n",
    "classifier = nn.Linear(feat_dim, num_classes)\n",
    "classifier.load_state_dict(head_state)\n",
    "classifier = classifier.to(DEVICE).eval()\n",
    "\n",
    "print(\"classifier:\", classifier)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "242c2fae",
   "metadata": {},
   "source": [
    "# Inference"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fcb4e583",
   "metadata": {},
   "source": [
    "## Load Image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ef1ff12a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pick one test example (index 1966) for the single-image walkthrough.\n",
    "sample = test_set[1966]\n",
    "image, crop, disease = sample['image'], sample['crop'], sample['disease']\n",
    "\n",
    "# Show the ground truth next to the image so the prediction can be judged by eye.\n",
    "print(f\"CROP: {crop}\\nDISEASE: {disease}\")\n",
    "display(image)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bf5fc438",
   "metadata": {},
   "source": [
    "## Load Images"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0ec9bd0d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Contiguous slice of the test split classified in the batch-inference section.\n",
    "# NOTE(review): the batch cell assumes this slice starts at test index 1324.\n",
    "start_idx, stop_idx = 1324, 1370\n",
    "samples = [test_set[i] for i in range(start_idx, stop_idx)]\n",
    "\n",
    "# Summarise instead of printing every record (each one carries a full PIL image).\n",
    "sample_labels = [s['crop_disease_label'] for s in samples]\n",
    "print(f\"Loaded {len(samples)} samples (test indices {start_idx}..{stop_idx - 1})\")\n",
    "print({label: sample_labels.count(label) for label in sorted(set(sample_labels))})"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "15a03b25",
   "metadata": {},
   "source": [
    "## Single Inference"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "51bd48bf",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Classify the single image loaded above and tabulate the TOPK most likely classes.\n",
    "inputs = processor(images=[image], return_tensors=\"pt\")\n",
    "pixel_values = inputs['pixel_values'].to(DEVICE)\n",
    "\n",
    "with torch.no_grad():\n",
    "    # Feature extraction (mirrors your eval script): use get_image_features when the\n",
    "    # model has it, else the vision tower's pooled output, else its first token.\n",
    "    if hasattr(backbone, \"get_image_features\"):\n",
    "        feats = backbone.get_image_features(pixel_values=pixel_values)\n",
    "    else:\n",
    "        out = backbone.vision_model(pixel_values=pixel_values)\n",
    "        if hasattr(out, \"pooler_output\") and out.pooler_output is not None:\n",
    "            feats = out.pooler_output\n",
    "        else:\n",
    "            feats = out.last_hidden_state[:, 0, :]\n",
    "\n",
    "    # L2-normalise each feature row; the clamp guards against division by zero.\n",
    "    feats = feats / feats.norm(dim=1, keepdim=True).clamp(min=1e-12)\n",
    "\n",
    "    logits = classifier(feats)                 # [1, num_classes]\n",
    "    probs = torch.softmax(logits, dim=-1)[0]   # [num_classes]\n",
    "\n",
    "top_probs, top_idx = torch.topk(probs, k=min(TOPK, num_classes))\n",
    "top_probs = top_probs.detach().cpu().numpy()\n",
    "top_idx = top_idx.detach().cpu().numpy()\n",
    "\n",
    "top_labels = [id2label[int(i)] for i in top_idx]\n",
    "pd.DataFrame({\"class\": top_labels, \"prob\": top_probs})"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3e5a29c5",
   "metadata": {},
   "source": [
    "## Batch Inference"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a2578e27",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6ec4a4a0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Classify every image in `samples` in chunked forward passes and keep the top-1 label.\n",
    "# All names created here are prefixed `batch_` so this cell does not overwrite `image`,\n",
    "# `top_probs`, `top_labels`, ... from the single-image section, which the Top-k plot reads.\n",
    "BATCH_FIRST_IDX = 1324  # test-set index of samples[0]; keep in sync with the cell that builds `samples`\n",
    "BATCH_SIZE = 16         # images per forward pass; lower it if the device runs out of memory\n",
    "\n",
    "final = []\n",
    "for batch_start in range(0, len(samples), BATCH_SIZE):\n",
    "    batch_images = [s['image'] for s in samples[batch_start:batch_start + BATCH_SIZE]]\n",
    "    batch_inputs = processor(images=batch_images, return_tensors=\"pt\")\n",
    "    batch_pixels = batch_inputs['pixel_values'].to(DEVICE)\n",
    "\n",
    "    with torch.no_grad():\n",
    "        # Feature extraction (mirrors your eval script)\n",
    "        if hasattr(backbone, \"get_image_features\"):\n",
    "            batch_feats = backbone.get_image_features(pixel_values=batch_pixels)\n",
    "        else:\n",
    "            batch_out = backbone.vision_model(pixel_values=batch_pixels)\n",
    "            if hasattr(batch_out, \"pooler_output\") and batch_out.pooler_output is not None:\n",
    "                batch_feats = batch_out.pooler_output\n",
    "            else:\n",
    "                batch_feats = batch_out.last_hidden_state[:, 0, :]\n",
    "\n",
    "        batch_feats = batch_feats / batch_feats.norm(dim=1, keepdim=True).clamp(min=1e-12)\n",
    "        batch_probs = torch.softmax(classifier(batch_feats), dim=-1)  # [chunk, num_classes]\n",
    "\n",
    "    # Top-1 per image: highest probability along the class dimension.\n",
    "    batch_top_probs, batch_top_idx = batch_probs.max(dim=-1)\n",
    "    for batch_offset, (batch_prob, batch_class_id) in enumerate(zip(batch_top_probs.tolist(), batch_top_idx.tolist())):\n",
    "        batch_label = id2label[int(batch_class_id)]\n",
    "        final.append(batch_label)\n",
    "        print(f\"IDX: {BATCH_FIRST_IDX + batch_start + batch_offset} | {batch_label} ({batch_prob:.4f})\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0c72b407",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Tally the batch predictions: the two potato-blight classes versus everything else.\n",
    "total = len(final)\n",
    "early = final.count(\"potato_early_blight\")\n",
    "late = final.count(\"potato_late_blight\")\n",
    "other = total - early - late\n",
    "\n",
    "print(f\"EARLY: {early}\\nLATE: {late}\\nOTHER: {other}\\nTOTAL: {total}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e448a0af",
   "metadata": {},
   "source": [
    "# Plot Top-k Distribution"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ccaaa6f2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Bar chart of `top_probs` / `top_labels` as left by the most recently executed inference cell.\n",
    "fig, ax = plt.subplots(figsize=(10, 5))\n",
    "ax.bar(range(len(top_probs)), top_probs)\n",
    "ax.set_xticks(range(len(top_probs)))\n",
    "ax.set_xticklabels(top_labels, rotation=75, ha=\"right\")\n",
    "ax.set_ylabel(\"Probability\")\n",
    "ax.set_title(f\"Top-{len(top_probs)} predictions\\n{CHECKPOINT}\")\n",
    "fig.tight_layout()\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "unsloth_env",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
