{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append('..') # add bayesvlm to path"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Tuple\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.distributions as dists\n",
    "from tabulate import tabulate\n",
    "from torchmetrics.classification import MulticlassCalibrationError\n",
    "\n",
    "from bayesvlm.utils import get_model_type_and_size, get_image_size, get_transform, load_model\n",
    "from bayesvlm.data.factory import DataModuleFactory\n",
    "from bayesvlm.hessians import load_hessians, optimize_prior_precision, compute_covariances\n",
    "from bayesvlm.precompute import precompute_text_features, precompute_image_features, make_predictions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluate_prediction(prediction: torch.Tensor, label: torch.Tensor, num_classes: int) -> Tuple[np.ndarray, np.ndarray, float]:\n",
    "    \"\"\"Score class-probability predictions against integer class labels.\n",
    "\n",
    "    Args:\n",
    "        prediction: Class probabilities. The predicted class is taken with\n",
    "            `argmax(1)`, so dimension 1 is treated as the class axis.\n",
    "        label: Ground-truth class indices, one per row of `prediction`.\n",
    "        num_classes: Number of classes, forwarded to the calibration metric.\n",
    "\n",
    "    Returns:\n",
    "        A tuple `(acc, nlpd, ece)`. `acc` and `nlpd` are per-sample NumPy arrays\n",
    "        (1.0/0.0 correctness and the negative log-probability assigned to the true\n",
    "        class); call `.mean()` on them to aggregate. `ece` is a single float: the\n",
    "        L1 calibration error computed with 20 bins.\n",
    "    \"\"\"\n",
    "    # The metric below is created without being moved to any device, so evaluate\n",
    "    # everything on the CPU; this also keeps predictions and labels on one device\n",
    "    # no matter where they were produced.\n",
    "    prediction = prediction.detach().cpu()\n",
    "    label = label.detach().cpu()\n",
    "\n",
    "    ece_metric = MulticlassCalibrationError(num_classes=num_classes, n_bins=20, norm='l1')\n",
    "    predicted_class = prediction.argmax(1)\n",
    "    acc = (predicted_class == label).float().numpy()\n",
    "    nlpd = -dists.Categorical(prediction).log_prob(label).numpy()\n",
    "    ece = ece_metric(prediction, label).item()\n",
    "    return acc, nlpd, ece\n",
    "\n",
    "def print_results(\n",
    "    acc_bayesvlm: float,\n",
    "    nlpd_bayesvlm: float,\n",
    "    ece_bayesvlm: float,\n",
    "    acc_map: float,\n",
    "    nlpd_map: float,\n",
    "    ece_map: float,\n",
    "):\n",
    "    \"\"\"Print a table comparing BayesVLM and MAP on accuracy, NLPD and ECE.\n",
    "\n",
    "    Every value is rendered with five decimal places; the BayesVLM column comes\n",
    "    first, followed by the MAP column.\n",
    "    \"\"\"\n",
    "    # (row label, BayesVLM value, MAP value)\n",
    "    metrics = (\n",
    "        (\"Accuracy (↑)\", acc_bayesvlm, acc_map),\n",
    "        (\"NLPD (↓)\", nlpd_bayesvlm, nlpd_map),\n",
    "        (\"ECE (↓)\", ece_bayesvlm, ece_map),\n",
    "    )\n",
    "    rows = [[name, f\"{ours:.5f}\", f\"{baseline:.5f}\"] for name, ours, baseline in metrics]\n",
    "\n",
    "    column_names = [\"Metric\", \"BayesVLM (ours)\", \"MAP\"]\n",
    "    print(tabulate(rows, headers=column_names, tablefmt=\"simple\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# define the model and dataset\n",
    "model_str = 'clip-base'\n",
    "dataset = 'food101'\n",
    "hessian_dir = '../hessians/hessian_CLIP-ViT-B-32-laion2B-s34B-b79K'\n",
    "pseudo_data_count = 4  # used as both n_img and n_txt when optimizing the prior precision\n",
    "batch_size = 32\n",
    "num_workers = 4\n",
    "\n",
    "# pick the best available backend instead of assuming Apple-silicon (MPS) hardware,\n",
    "# so the notebook also runs on CUDA and CPU-only machines\n",
    "if torch.cuda.is_available():\n",
    "    device = 'cuda'\n",
    "elif torch.backends.mps.is_available():\n",
    "    device = 'mps'\n",
    "else:\n",
    "    device = 'cpu'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Loading Model and Transforms  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# resolve the image transform that matches `model_str`\n",
    "model_type, model_size = get_model_type_and_size(model_str)\n",
    "image_size = get_image_size(model_str)\n",
    "transform = get_transform(model_type, image_size)\n",
    "\n",
    "# load both encoders and the `vlm` object that later receives the covariances\n",
    "image_encoder, text_encoder, vlm = load_model(model_str, device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Optimizing Prior Precision and Covariances  \n",
    "\n",
    "This cell loads the Hessians for the image and text modalities and optimizes the prior precision (`λ`) of each modality by maximizing the marginal log-likelihood. The resulting covariance matrices are then passed to the model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# load the (A, B) pairs stored for each modality; `info` holds the pseudo-data\n",
    "# counts and is extended below with the optimized prior precisions\n",
    "info = {'n_img': pseudo_data_count, 'n_txt': pseudo_data_count}\n",
    "A_img, B_img = load_hessians(hessian_dir, tag='img', return_info=False)\n",
    "A_txt, B_txt = load_hessians(hessian_dir, tag='txt', return_info=False)\n",
    "\n",
    "# optimize the prior precision of each projection layer based on the marginal\n",
    "# log-likelihood (image first, then text)\n",
    "modalities = (\n",
    "    ('img', image_encoder.vision_projection, A_img, B_img),\n",
    "    ('txt', text_encoder.text_projection, A_txt, B_txt),\n",
    ")\n",
    "for tag, projection, A_factor, B_factor in modalities:\n",
    "    info[f'lambda_{tag}'] = optimize_prior_precision(\n",
    "        projection,\n",
    "        A=A_factor,\n",
    "        B=B_factor,\n",
    "        lmbda_init=1500,\n",
    "        n=info[f'n_{tag}'],\n",
    "        lr=1e-2,\n",
    "        num_steps=300,\n",
    "        device=device,\n",
    "        verbose=True,\n",
    "    ).item()\n",
    "\n",
    "for key in ('n_img', 'n_txt', 'lambda_img', 'lambda_txt'):\n",
    "    print(f'{key}:', info[key])\n",
    "\n",
    "# pass the covariances to the model\n",
    "cov_img, cov_txt = compute_covariances(A_img, B_img, A_txt, B_txt, info)\n",
    "vlm.set_covariances(cov_img, cov_txt)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Initializing the Data Module  \n",
    "\n",
    "This cell creates a `DataModule` with the specified batch size, workers, and transforms. We will only use the test set for evaluation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# build the data module for `dataset`; train and test share the same transform\n",
    "datamodule_factory = DataModuleFactory(\n",
    "    batch_size=batch_size,\n",
    "    num_workers=num_workers,\n",
    "    train_transform=transform,\n",
    "    test_transform=transform,\n",
    "    shuffle_train=True,\n",
    ")\n",
    "dm = datamodule_factory.create(dataset)\n",
    "dm.setup()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Precomputing Embeddings  \n",
    "\n",
    "This cell precomputes image and text embeddings using the image and text encoders. Image features are extracted from the test dataset, while text features are computed for class prompts."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# precompute embeddings; this is inference only, so gradient tracking is disabled\n",
    "with torch.no_grad():\n",
    "    # image side: outputs for the test split plus the class ids / image ids returned with them\n",
    "    test_loader = dm.test_dataloader()\n",
    "    image_outputs_test, image_class_ids_test, image_ids_test = precompute_image_features(\n",
    "        image_encoder=image_encoder,\n",
    "        loader=test_loader,\n",
    "    )\n",
    "\n",
    "    # text side: outputs for the class prompts\n",
    "    label_outputs = precompute_text_features(\n",
    "        text_encoder=text_encoder,\n",
    "        class_prompts=dm.class_prompts,\n",
    "        batch_size=batch_size,\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Making Predictions  \n",
    "\n",
    "This cell generates predictions using both the Bayesian VLM (`BayesVLM`) and the standard CLIP model (`MAP estimate`). The Bayesian variant accounts for uncertainty, while the MAP estimate represents the deterministic prediction. Both use the precomputed image and text embeddings for inference."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# both prediction runs share the model, the precomputed embeddings and the\n",
    "# batching setup; they differ only in the `map_estimate` flag\n",
    "prediction_kwargs = dict(\n",
    "    clip=vlm,\n",
    "    image_outputs=image_outputs_test,\n",
    "    text_outputs=label_outputs,\n",
    "    batch_size=batch_size,\n",
    "    device=device,\n",
    ")\n",
    "\n",
    "# vanilla BayesVLM\n",
    "logits_bayesvlm = make_predictions(**prediction_kwargs, map_estimate=False)\n",
    "\n",
    "# vanilla CLIP (MAP estimate)\n",
    "logits_map = make_predictions(**prediction_kwargs, map_estimate=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Converting Logits to Probabilities  \n",
    "\n",
    "This cell converts model logits into probabilities. For `BayesVLM`, the probit approximation (MacKay, 1992) is used to adjust for uncertainty before applying softmax. For the MAP estimate, probabilities are computed directly from the mean logits."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# probit approximation: shrink the mean logits by kappa = 1 / sqrt(1 + pi/8 * var)\n",
    "# before the softmax, so that larger predictive variance flattens the probabilities\n",
    "# Reference: David J. C. MacKay. Bayesian interpolation. Neural Computation, 4(3):415–447, 1992.\n",
    "scaled_variance = torch.pi / 8 * logits_bayesvlm.var\n",
    "kappa = 1 / torch.sqrt(1. + scaled_variance)\n",
    "probas_bayesvlm = (kappa * logits_bayesvlm.mean).softmax(dim=-1)\n",
    "\n",
    "# the MAP estimate uses the mean logits directly\n",
    "probas_map = logits_map.mean.softmax(dim=-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# score both sets of probabilities against the same test labels\n",
    "num_classes = len(dm.class_prompts)\n",
    "\n",
    "acc_bayesvlm, nlpd_bayesvlm, ece_bayesvlm = evaluate_prediction(\n",
    "    prediction=probas_bayesvlm,\n",
    "    label=image_class_ids_test,\n",
    "    num_classes=num_classes,\n",
    ")\n",
    "\n",
    "acc_map, nlpd_map, ece_map = evaluate_prediction(\n",
    "    prediction=probas_map,\n",
    "    label=image_class_ids_test,\n",
    "    num_classes=num_classes,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We report the zero-shot results on the Food-101 dataset (`food101`) in terms of accuracy (higher is better), negative log predictive density (NLPD, lower is better), and expected calibration error (ECE, lower is better). We compare the proposed method (BayesVLM) with the state-of-the-art baseline, CLIP, shown as the MAP estimate in the table."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# accuracy and NLPD are averaged over the test samples; ECE is already a single number\n",
    "summary = dict(\n",
    "    acc_bayesvlm=acc_bayesvlm.mean(),\n",
    "    nlpd_bayesvlm=nlpd_bayesvlm.mean(),\n",
    "    ece_bayesvlm=ece_bayesvlm,\n",
    "    acc_map=acc_map.mean(),\n",
    "    nlpd_map=nlpd_map.mean(),\n",
    "    ece_map=ece_map,\n",
    ")\n",
    "print_results(**summary)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
