{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "090823bb",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
    "from transformers import AutoModelForSeq2SeqLM\n",
    "from lm_polygraph.estimators import *\n",
    "from lm_polygraph.utils.model import WhiteboxModel\n",
    "from lm_polygraph.utils.dataset import Dataset\n",
    "from lm_polygraph.utils.processor import Logger\n",
    "from lm_polygraph.utils.manager import UEManager\n",
    "from lm_polygraph.ue_metrics import PredictionRejectionArea\n",
    "from lm_polygraph.generation_metrics import RougeMetric, BartScoreSeqMetric, ModelScoreSeqMetric, ModelScoreTokenwiseMetric"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c576ee4d-81dd-4396-9f74-88ae982fa708",
   "metadata": {},
   "source": [
    "# Specify Hyperparameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "846e0926-68ad-48e2-88d0-6c216dca148a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# What to evaluate: the model identifier passed to from_pretrained and the dataset name.\n",
    "model_path = \"facebook/bart-large-cnn\"\n",
    "dataset_name = \"xsum\"\n",
    "\n",
    "# How to run it: device (used as device_map), batch size for dataset loading,\n",
    "# and the seed used when subsampling the datasets.\n",
    "device = \"cpu\"\n",
    "batch_size = 4\n",
    "seed = 42"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a5678afd-6a6e-4a49-8393-be92c7c47c1e",
   "metadata": {},
   "source": [
    "# Initialize Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dac975b4-0438-4ce4-9f08-d731a09c0e23",
   "metadata": {},
   "outputs": [],
   "source": [
    "# facebook/bart-large-cnn is an encoder-decoder (BART) checkpoint, so it is loaded\n",
    "# with the seq2seq auto class. AutoModelForCausalLM would build the standalone BART\n",
    "# decoder (BartForCausalLM) without the encoder that reads the input document.\n",
    "# If model_path is changed to a decoder-only model, switch back to AutoModelForCausalLM.\n",
    "base_model = AutoModelForSeq2SeqLM.from_pretrained(\n",
    "    model_path,\n",
    "    device_map=device,\n",
    ")\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_path)\n",
    "\n",
    "# NOTE(review): some lm_polygraph releases appear to accept a model_type argument on\n",
    "# WhiteboxModel; confirm whether it must be set for seq2seq models in the installed version.\n",
    "model = WhiteboxModel(base_model, tokenizer)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2a89b325-ec3d-435f-ad9f-b3a22cca584c",
   "metadata": {},
   "source": [
    "# Train and Eval Datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "471fb28d-2b04-4ee1-9890-40f67a83e8bc",
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_split(split):\n",
    "    \"\"\"Load `split` of `dataset_name`, then call subsample(16, seed=seed) on it.\"\"\"\n",
    "    split_data = Dataset.load(\n",
    "        dataset_name,\n",
    "        \"document\",\n",
    "        \"summary\",\n",
    "        batch_size=batch_size,\n",
    "        split=split,\n",
    "    )\n",
    "    split_data.subsample(16, seed=seed)\n",
    "    return split_data\n",
    "\n",
    "\n",
    "# `dataset` is the test split; `train_dataset` is the train split.\n",
    "dataset = load_split(\"test\")\n",
    "train_dataset = load_split(\"train\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f41a76d4-933a-4cab-9f99-e5971b90dbb4",
   "metadata": {},
   "source": [
    "# Metric, UE Metric, and UE Methods"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3028ee9a-be27-4bd5-8401-a8c3b5697d19",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Uncertainty estimation methods to compare.\n",
    "ue_methods = [\n",
    "    MaximumSequenceProbability(),\n",
    "    SemanticEntropy(),\n",
    "    MahalanobisDistanceSeq(\"decoder\"),\n",
    "]\n",
    "\n",
    "# Metric from lm_polygraph.ue_metrics.\n",
    "ue_metrics = [PredictionRejectionArea()]\n",
    "\n",
    "# Metric from lm_polygraph.generation_metrics.\n",
    "metrics = [RougeMetric(\"rougeL\")]\n",
    "\n",
    "# Processor from lm_polygraph.utils.processor.\n",
    "loggers = [Logger()]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1f8007b3-b209-492a-b4ab-cd93ad1f324c",
   "metadata": {},
   "source": [
    "# Initialize UE Manager"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2815aeb6-7a66-43ec-bd99-d48af98b31cc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Positional arguments, in order: evaluation dataset, model, UE methods,\n",
    "# generation metrics, UE metrics, processors. The train split goes in by keyword.\n",
    "# NOTE(review): train_data is presumably what train-dependent estimators such as\n",
    "# MahalanobisDistanceSeq are fitted on -- confirm against UEManager.\n",
    "man = UEManager(\n",
    "    dataset, model, ue_methods, metrics, ue_metrics, loggers,\n",
    "    train_data=train_dataset,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8c550f5e-695e-43e0-a69f-20a6c0531691",
   "metadata": {},
   "source": [
    "# Compute Results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e4e927f0-9631-4b94-a2b0-9111ba803dc3",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Invoke the manager and keep what it returns. The next cell iterates results.keys(),\n",
    "# indexes each key at positions 1-3, and formats each value as a float.\n",
    "results = man()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "83fd23f3-3378-44d1-b761-e908760f6862",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One line per result: key positions 1-3 are printed as labels, the value with three decimals.\n",
    "for key in results:\n",
    "    report_line = f\"UE Score: {key[1]}, Metric: {key[2]}, UE Metric: {key[3]}, Score: {results[key]:.3f}\"\n",
    "    print(report_line)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
