{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "61907dd9",
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import load_dataset\n",
    "from transformers import FSMTForConditionalGeneration, FSMTTokenizer\n",
    "from lm_polygraph.estimators import *\n",
    "from lm_polygraph.utils.model import WhiteboxModel\n",
    "from lm_polygraph.utils.dataset import Dataset\n",
    "from lm_polygraph.utils.processor import Logger\n",
    "from lm_polygraph.utils.manager import UEManager\n",
    "from lm_polygraph.ue_metrics import PredictionRejectionArea\n",
    "from lm_polygraph.generation_metrics import RougeMetric, BartScoreSeqMetric, ModelScoreSeqMetric, ModelScoreTokenwiseMetric"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "61ac0a1f-1e92-42c8-87d5-a2f5192ce54d",
   "metadata": {},
   "source": [
    "# Specify HyperParameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7de21658-5202-487c-8ac6-bff35e2254b8",
   "metadata": {},
   "outputs": [],
   "source": [
    "model_path = \"facebook/wmt19-en-de\"\n",
    "device = \"cpu\"\n",
    "dataset_name = (\"wmt14\", \"de-en\")\n",
    "batch_size = 4\n",
    "seed = 42"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1c5de359-bb7d-4b60-96c3-e88b845d4ea8",
   "metadata": {},
   "source": [
    "# Initialize Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "02434a58-3847-4e0d-994a-bdf5f533608a",
   "metadata": {},
   "outputs": [],
   "source": [
    "base_model = FSMTForConditionalGeneration.from_pretrained(\n",
    "    model_path,\n",
    "    device_map=device,\n",
    ")\n",
    "tokenizer = FSMTTokenizer.from_pretrained(model_path)\n",
    "\n",
    "model = WhiteboxModel(base_model, tokenizer, model_type=\"Seq2SeqLM\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fa15a017-dc14-4003-9c30-876657110a87",
   "metadata": {},
   "source": [
    "# Train and Eval Datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8b8d78dd-ff6f-481e-bf5a-392c7893a27b",
   "metadata": {},
   "outputs": [],
   "source": [
    "hf_dataset = load_dataset(*dataset_name)\n",
    "\n",
    "dataset = Dataset(\n",
    "    x=[txt[\"en\"] for txt in hf_dataset[\"test\"][\"translation\"]],\n",
    "    y=[txt[\"de\"] for txt in hf_dataset[\"test\"][\"translation\"]],\n",
    "    batch_size=batch_size,\n",
    ")\n",
    "dataset.subsample(16, seed=seed)\n",
    "\n",
    "train_dataset = Dataset(\n",
    "    x=[txt[\"en\"] for txt in hf_dataset[\"train\"][\"translation\"][:1000]],\n",
    "    y=[txt[\"de\"] for txt in hf_dataset[\"train\"][\"translation\"][:1000]],\n",
    "    batch_size=batch_size,\n",
    ")\n",
    "train_dataset.subsample(16, seed=seed)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "20ec50be-252a-4ab9-913d-7048c93cd648",
   "metadata": {},
   "source": [
    "# Metric, UE Metric, and UE Methods"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5baa618b-d6dc-4292-a316-30f0e0f8db78",
   "metadata": {},
   "outputs": [],
   "source": [
    "ue_methods = [MaximumSequenceProbability(), \n",
    "              SemanticEntropy(),\n",
    "              MahalanobisDistanceSeq(\"encoder\"),]\n",
    "\n",
    "ue_metrics = [PredictionRejectionArea()]\n",
    "\n",
    "metrics = [RougeMetric('rougeL')]\n",
    "\n",
    "loggers = [Logger()] "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a67b0923-7ebe-4931-bb1a-0d4dcf7a1590",
   "metadata": {},
   "source": [
    "# Initialize UE Manager"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b93cda59",
   "metadata": {},
   "outputs": [],
   "source": [
    "man = UEManager(\n",
    "    dataset,\n",
    "    model,\n",
    "    ue_methods,\n",
    "    metrics,\n",
    "    ue_metrics,\n",
    "    loggers,\n",
    "    train_data=train_dataset,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fbd12b48-57c0-467d-9943-0b35a674e374",
   "metadata": {},
   "source": [
    "# Compute Results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2da7a129-cc59-4b55-b71f-fb4ee230a416",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "results = man()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ef6abce0-dba7-40c1-916f-1be546a78c8f",
   "metadata": {},
   "outputs": [],
   "source": [
    "for key in results.keys():\n",
    "    print(f\"UE Score: {key[1]}, Metric: {key[2]}, UE Metric: {key[3]}, Score: {results[key]:.3f}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
