{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "14346300",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "from pathlib import Path\n",
    "\n",
    "import h5py\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import torch\n",
    "from PIL import Image\n",
    "\n",
    "from pathfmtools.analysis.zeroshot_classification import ZeroShotPatchClassifier\n",
    "from pathfmtools.embedding_models.registry import get_embedding_model\n",
    "from pathfmtools.image.slide import Slide\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f83b308b",
   "metadata": {},
   "outputs": [],
   "source": [
    "slide_path = Path()\n",
    "store_root = Path()  # Directory where processed data will be saved\n",
    "slide = Slide(slide_path=slide_path, store_root=store_root)\n",
    "\n",
    "# 448 for 40x, 224 for 20x\n",
    "slide.preprocess(patch_size=448)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1dd8f02d",
   "metadata": {},
   "outputs": [],
   "source": [
    "with h5py.File(slide.store.get_slide_h5_path(slide.id_), \"r\") as f:\n",
    "    print(f.keys())\n",
    "    print(json.loads(f[\"slide_metadata\"][()]))\n",
    "    tile_meta_grp = f[\"tile_metadata\"]\n",
    "    print(tile_meta_grp.keys())\n",
    "    tile_cols = tile_meta_grp[\"col\"][()]\n",
    "    tile_rows = tile_meta_grp[\"row\"][()]\n",
    "    tile_widths = tile_meta_grp[\"width\"][()]\n",
    "    tile_heights = tile_meta_grp[\"height\"][()]\n",
    "    tile_top_left_x = tile_meta_grp[\"top_left_x\"][()]\n",
    "    tile_top_left_y = tile_meta_grp[\"top_left_y\"][()]\n",
    "    seg_mask = f[\"tile_segmentation_mask\"][()]\n",
    "    tiles = f[\"tiles\"][()]\n",
    "\n",
    "print(seg_mask.shape)\n",
    "print(tiles.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4cb2841f",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.imshow(seg_mask)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f41539ed",
   "metadata": {},
   "outputs": [],
   "source": [
    "display(Image.fromarray(slide.slide_reader.get_thumbnail((1024, 1024))))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fd5534f6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sample and display 20 preprocessed tiles\n",
    "sample_idxs = np.random.randint(0, tiles.shape[0], size=20)\n",
    "fig, axes = plt.subplots(5, 4, figsize=(12, 15))\n",
    "for ax, idx in zip(axes.flat, sample_idxs):\n",
    "    ax.imshow(tiles[idx])\n",
    "    ax.set_title(f\"Tile {idx}\")\n",
    "    ax.axis(\"off\")\n",
    "plt.tight_layout()\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c0ec54b9",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = get_embedding_model(\"conch\")(device=torch.device(\"cuda:0\"))\n",
    "slide.embed_tiles(\n",
    "    model=model,\n",
    "    batch_size=256,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6091e2a8",
   "metadata": {},
   "outputs": [],
   "source": [
    "zeroshot_classifier = ZeroShotPatchClassifier()\n",
    "slide.run_zeroshot_classification(\n",
    "    zero_shot_classifier=zeroshot_classifier,\n",
    "    model_name=\"conch\",\n",
    "    text_list=[\"prompt 1\", \"prompt 2\", \"prompt 3\"],\n",
    "    device=torch.device(\"cuda:0\"),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a0bbb738",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".cpathtools",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
