{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6958a441",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
    "# Import only the estimators this notebook instantiates instead of `import *`,\n",
    "# so it is clear where each name comes from and nothing else leaks into the namespace.\n",
    "from lm_polygraph.estimators import (\n",
    "    MaximumSequenceProbability,\n",
    "    SemanticEntropy,\n",
    "    MahalanobisDistanceSeq,\n",
    ")\n",
    "from lm_polygraph.utils.model import WhiteboxModel\n",
    "from lm_polygraph.utils.dataset import Dataset\n",
    "from lm_polygraph.utils.processor import Logger\n",
    "from lm_polygraph.utils.manager import UEManager\n",
    "from lm_polygraph.ue_metrics import PredictionRejectionArea\n",
    "from lm_polygraph.generation_metrics import RougeMetric, BartScoreSeqMetric, ModelScoreSeqMetric, ModelScoreTokenwiseMetric, AggregatedMetric"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5025e26e-fd7f-44b6-88d7-5876439a5ab0",
   "metadata": {},
   "source": [
    "# Specify Hyperparameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7111f938-bc8c-4b82-82a1-fce490bc8e4a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Checkpoint id passed to `from_pretrained` for both the model and the tokenizer.\n",
    "model_path = \"bigscience/bloomz-560m\"\n",
    "# Value handed to `device_map` when the model is loaded.\n",
    "device = \"cpu\"\n",
    "\n",
    "# Dataset identifier forwarded to `Dataset.load`\n",
    "# (presumably the dataset name plus its configuration - confirm against lm-polygraph).\n",
    "dataset_name = (\n",
    "    \"trivia_qa\",\n",
    "    \"rc.nocontext\",\n",
    ")\n",
    "batch_size = 4\n",
    "\n",
    "# Seed used when subsampling the datasets.\n",
    "seed = 42"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "757a3862-77d1-4bb4-8423-1f86f3a58b54",
   "metadata": {},
   "source": [
    "# Initialize Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4e7a7afe",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Load the tokenizer and the causal LM from the same checkpoint.\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_path)\n",
    "base_model = AutoModelForCausalLM.from_pretrained(model_path, device_map=device)\n",
    "\n",
    "# Wrap both in lm-polygraph's WhiteboxModel; this is the object handed to UEManager.\n",
    "model = WhiteboxModel(base_model, tokenizer)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fe460bd5-35bb-4c36-a6b8-12b7a111b403",
   "metadata": {},
   "source": [
    "# Train and Eval Datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0444bbb3-7b9d-4823-ad9b-2b2a217d1638",
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_subsampled(split, size=16):\n",
    "    \"\"\"Load one split of `dataset_name` and subsample it.\n",
    "\n",
    "    Both the evaluation and the training data are loaded with the same columns,\n",
    "    batch size and prompt; only the split differs, so the shared logic lives here.\n",
    "\n",
    "    Args:\n",
    "        split: Name of the dataset split to load (e.g. \"train\").\n",
    "        size: Number of examples requested from `subsample`.\n",
    "\n",
    "    Returns:\n",
    "        The object returned by `Dataset.load`, after `subsample` was called on it.\n",
    "    \"\"\"\n",
    "    data = Dataset.load(\n",
    "        dataset_name,\n",
    "        'question', 'answer',\n",
    "        batch_size=batch_size,\n",
    "        prompt=\"Question: {question}\\nAnswer:{answer}\",\n",
    "        split=split,\n",
    "    )\n",
    "    data.subsample(size, seed=seed)\n",
    "    return data\n",
    "\n",
    "\n",
    "# Use validation split, since test split of trivia_qa doesn't have reference answers\n",
    "dataset = load_subsampled(\"validation\")\n",
    "train_dataset = load_subsampled(\"train\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bd61ed46-8757-4d83-baae-bf854bd11d0e",
   "metadata": {},
   "source": [
    "# Metric, UE Metric, and UE Methods"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5baa618b-d6dc-4292-a316-30f0e0f8db78",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Uncertainty estimators to evaluate.\n",
    "ue_methods = [\n",
    "    MaximumSequenceProbability(),\n",
    "    SemanticEntropy(),\n",
    "    MahalanobisDistanceSeq(\"decoder\"),\n",
    "]\n",
    "\n",
    "ue_metrics = [PredictionRejectionArea()]\n",
    "\n",
    "# trivia_qa is a multi-reference dataset (y is a list of possible correct answers),\n",
    "# so the generation metric is wrapped in AggregatedMetric.\n",
    "rouge_l = RougeMetric('rougeL')\n",
    "metrics = [AggregatedMetric(rouge_l)]\n",
    "\n",
    "loggers = [Logger()]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b6d80993-5a81-4d63-b193-9976301e2ad4",
   "metadata": {},
   "source": [
    "# Initialize UE Manager"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b93cda59",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Positional order: eval data, model, estimators, generation metrics, UE metrics, loggers.\n",
    "man = UEManager(dataset, model, ue_methods, metrics, ue_metrics, loggers, train_data=train_dataset)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b2a92e70-3036-430d-a60a-4c2ecf768d9d",
   "metadata": {},
   "source": [
    "# Compute Results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2da7a129-cc59-4b55-b71f-fb4ee230a416",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Calling the manager runs the evaluation; the returned mapping is printed in the next cell.\n",
    "results = man()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ef6abce0-dba7-40c1-916f-1be546a78c8f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keys are indexable; positions 1-3 hold the estimator, generation metric and UE metric names.\n",
    "for key, score in results.items():\n",
    "    ue_name, metric_name, ue_metric_name = key[1], key[2], key[3]\n",
    "    print(f\"UE Score: {ue_name}, Metric: {metric_name}, UE Metric: {ue_metric_name}, Score: {score:.3f}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
