{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "4c2302d3-10a1-4427-84e1-3f39913142f2",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "b5465947",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "sys.path.append('../src')\n",
    "\n",
    "from lm_polygraph.utils.model import WhiteboxModel\n",
    "from lm_polygraph.estimators import *\n",
    "from lm_polygraph.stat_calculators import *\n",
    "from lm_polygraph.utils.openai_chat import OpenAIChat\n",
    "from lm_polygraph.utils.deberta import MultilingualDeberta"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "3b7ae864",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "bloomz=\"bigscience/bloomz-560m\"\n",
    "yi6bchat=\"01-ai/Yi-6B-Chat\"\n",
    "yi6b=\"01-ai/Yi-6B\"\n",
    "t5base='google-t5/t5-base'\n",
    "bloomz560m='bigscience/bloomz-560m' # poor quality but fast calculations\n",
    "model = WhiteboxModel.from_pretrained(yi6b, add_bos_token=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "802f4378",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.\n"
     ]
    }
   ],
   "source": [
    "texts = [\"请介绍一下曹操。\"]\n",
    "stat = {}\n",
    "\n",
    "# MBZ Budget\n",
    "os.environ[\"OPENAI_KEY\"] = \"YOUR_API_KEY\"\n",
    "\n",
    "for calculator in [\n",
    "    GreedyProbsCalculator(),\n",
    "    EntropyCalculator(),\n",
    "    GreedyLMProbsCalculator(),\n",
    "    # ClaimsExtractorZH(OpenAIChat(\"gpt-4o\")),\n",
    "]:\n",
    "    stat.update(calculator(stat, texts, model))\n",
    "\n",
    "claim_extractor=ClaimsExtractor(OpenAIChat(\"gpt-4o\"),language=\"zh\")\n",
    "stat.update(claim_extractor(stat, texts, model))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "fade0e70-9077-4ccd-beb2-6b9b36cc954a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "dict_keys(['input_texts', 'input_tokens', 'greedy_log_probs', 'greedy_tokens', 'greedy_tokens_alternatives', 'greedy_texts', 'greedy_log_likelihoods', 'embeddings_decoder', 'entropy', 'greedy_lm_log_probs', 'greedy_lm_log_likelihoods', 'claims', 'claim_texts_concatenated', 'claim_input_texts_concatenated'])"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "stat.keys()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "505008d3-3e93-4612-91f3-f90e1eab9706",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[[Claim(claim_text='曹操是三国时期的政治家。', sentence=' 曹操是三国时期最伟大的政治家', aligned_tokens=[1, 2, 3, 4, 7]),\n",
       "  Claim(claim_text='曹操是三国时期最伟大的政治家。', sentence=' 曹操是三国时期最伟大的政治家', aligned_tokens=[1, 2, 3, 4, 5, 6, 7]),\n",
       "  Claim(claim_text='他一生中为官。', sentence='他一生中为官，为官期间，他多次被封为侯，并多次被封为侯爵', aligned_tokens=[9, 10, 11, 12, 13]),\n",
       "  Claim(claim_text='他多次被封为侯。', sentence='他一生中为官，为官期间，他多次被封为侯，并多次被封为侯爵', aligned_tokens=[9, 20, 21, 22, 23]),\n",
       "  Claim(claim_text='他多次被封为侯爵。', sentence='他一生中为官，为官期间，他多次被封为侯，并多次被封为侯爵', aligned_tokens=[9, 20, 21, 22, 23, 30]),\n",
       "  Claim(claim_text='曹操一生中为官。', sentence='曹操一生中为官，为官期间，他多次被封为侯，并多次被封为侯爵', aligned_tokens=[32, 33, 34, 35, 36]),\n",
       "  Claim(claim_text='为官期间，曹操多次被封为侯。', sentence='曹操一生中为官，为官期间，他多次被封为侯，并多次被封为侯爵', aligned_tokens=[32, 35, 36, 38, 39, 40, 43, 44, 45, 46]),\n",
       "  Claim(claim_text='为官期间，曹操多次被封为侯爵。', sentence='曹操一生中为官，为官期间，他多次被封为侯，并多次被封为侯爵', aligned_tokens=[32, 35, 36, 40, 43, 44, 45, 46, 53])]]"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "stat['claims']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "cde52623-aac2-4315-9e32-a1fc93e0de59",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Output:  曹操是三国时期最伟大的政治家。他一生中为官，为官期间，他多次被封为侯，并多次被封为侯爵。曹操一生中为官，为官期间，他多次被封为侯，并多次被封为侯爵。\n"
     ]
    }
   ],
   "source": [
    "print(\"Output:\", stat[\"greedy_texts\"][0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "a99e12fa",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Output:  曹操是三国时期最伟大的政治家。他一生中为官，为官期间，他多次被封为侯，并多次被封为侯爵。曹操一生中为官，为官期间，他多次被封为侯，并多次被封为侯爵。\n",
      "\n",
      "claim: 曹操是三国时期的政治家。\n",
      "aligned tokens: [1, 2, 3, 4, 7]\n",
      "\n",
      "claim: 曹操是三国时期最伟大的政治家。\n",
      "aligned tokens: [1, 2, 3, 4, 5, 6, 7]\n",
      "\n",
      "claim: 他一生中为官。\n",
      "aligned tokens: [9, 10, 11, 12, 13]\n",
      "\n",
      "claim: 他多次被封为侯。\n",
      "aligned tokens: [9, 20, 21, 22, 23]\n",
      "\n",
      "claim: 他多次被封为侯爵。\n",
      "aligned tokens: [9, 20, 21, 22, 23, 30]\n",
      "\n",
      "claim: 曹操一生中为官。\n",
      "aligned tokens: [32, 33, 34, 35, 36]\n",
      "\n",
      "claim: 为官期间，曹操多次被封为侯。\n",
      "aligned tokens: [32, 35, 36, 38, 39, 40, 43, 44, 45, 46]\n",
      "\n",
      "claim: 为官期间，曹操多次被封为侯爵。\n",
      "aligned tokens: [32, 35, 36, 40, 43, 44, 45, 46, 53]\n",
      "\n"
     ]
    }
   ],
   "source": [
    "print(\"Output:\", stat[\"greedy_texts\"][0])\n",
    "print()\n",
    "for claim in stat[\"claims\"][0]:\n",
    "    print(\"claim:\", claim.claim_text)\n",
    "    print(\"aligned tokens:\", claim.aligned_tokens)\n",
    "    print()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "6750525d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[[5.517868,\n",
       "  8.203161,\n",
       "  11.508308,\n",
       "  10.462716,\n",
       "  11.252705,\n",
       "  7.277602,\n",
       "  4.004912,\n",
       "  3.6238604]]"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Maximum Claim Probability\n",
    "max_prob = MaximumClaimProbability()\n",
    "max_prob(stat)  # Uncertainty for each claim, the higher, the less certain"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "85be9f76-3f50-4b32-8867-45849fab8eb8",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[[1.1035736,\n",
       "  1.1718801,\n",
       "  2.3016617,\n",
       "  2.0925431,\n",
       "  1.8754507,\n",
       "  1.4555204,\n",
       "  0.40049118,\n",
       "  0.40265116]]"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Perlexity\n",
    "perlexity_claim = PerplexityClaim()\n",
    "perlexity_claim (stat)  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "a7b4a46a-37e9-4a8d-b583-c43cb0f94692",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[[2.0959154e-05,\n",
       "  2.0959154e-05,\n",
       "  2.4157525e-05,\n",
       "  2.5509149e-05,\n",
       "  2.5509149e-05,\n",
       "  2.243754e-05,\n",
       "  2.126221e-05,\n",
       "  2.126221e-05]]"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Maximum Token Entropy\n",
    "max_token_ent = MaxTokenEntropyClaim()\n",
    "max_token_ent(stat)  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "2f118841-c40b-4d04-b06f-da7a6b1c287c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[[-31.81642296910286,\n",
       "  -45.93048223853111,\n",
       "  -30.54413414001465,\n",
       "  -29.65266525745392,\n",
       "  -36.55968350172043,\n",
       "  -34.77716279029846,\n",
       "  -109.19080419000238,\n",
       "  -97.22904723981628]]"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Pointwise Mutual Information\n",
    "pmi = PointwiseMutualInformationClaim()\n",
    "pmi (stat)  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "ef7ffc38-e905-42eb-aa48-5dee774afed1",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[[9.566464,\n",
       "  9.849717,\n",
       "  10.163636,\n",
       "  11.212116,\n",
       "  11.415592,\n",
       "  9.818885,\n",
       "  10.598446,\n",
       "  10.781264]]"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "p_true_calculator_stat = PromptCalculator(\n",
    "    \"Question: {q}\\n Possible answer:{a}\\n \"\n",
    "    \"Is the possible answer True or False? The possible answer is: \",\n",
    "    \"True\",\n",
    "    \"p_true_claim\",\n",
    "    input_text_dependency=\"claim_input_texts_concatenated\",\n",
    "    sample_text_dependency=None,\n",
    "    generation_text_dependency=\"claim_texts_concatenated\",\n",
    ")\n",
    "\n",
    "stat.update(p_true_calculator_stat(stat, texts, model))\n",
    "\n",
    "ptrue_claim = PTrueClaim()\n",
    "ptrue_claim (stat)  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "42d9982a",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[[-0.3098740519439675,\n",
       "  -0.25099232306072716,\n",
       "  -0.6779380193045293,\n",
       "  -0.45294784600333193,\n",
       "  -0.4083827155215617,\n",
       "  -0.8225784658523201,\n",
       "  -0.8718909106823165,\n",
       "  -0.8765211222672762]]"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Claim Conditional Probability\n",
    "for calculator in [\n",
    "    GreedyAlternativesNLICalculator(MultilingualDeberta())\n",
    "]:\n",
    "    stat.update(calculator(stat, texts, model))\n",
    "\n",
    "ccp = ClaimConditionedProbabilityClaim()\n",
    "ccp (stat)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "11d1a0d7",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[[-0.5257379541381006,\n",
       "  -0.5118784852870085,\n",
       "  -0.6496051811740936,\n",
       "  -0.801711036128341,\n",
       "  -0.801711036128341,\n",
       "  -0.8218088030097428,\n",
       "  -0.8774868957560673,\n",
       "  -0.8775706039132196]]"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "for calculator in [\n",
    "    GreedyAlternativesFactPrefNLICalculator(MultilingualDeberta())\n",
    "]:\n",
    "    stat.update(calculator(stat, texts, model))\n",
    "\n",
    "ccp_no_cxt=ClaimConditionedProbabilityClaim(nli_context=\"fact_pref\")\n",
    "ccp_no_cxt (stat)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "e852a2cc-b319-490e-8ae7-e258dd4b41cb",
   "metadata": {},
   "outputs": [],
   "source": [
    "from lm_polygraph.generation_metrics.openai_fact_check import OpenAIFactCheck\n",
    "chinese_checker = OpenAIFactCheck('gpt-4o', language=\"zh\")\n",
    "chatgpt_response = chinese_checker(stat, None, None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "f0381fff-e2d2-43b1-b830-edec434d2f05",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[[0, nan, 1, 1, 0, 0, 1, 0]]"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "chatgpt_response"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1dac0849-cb89-4a66-a63f-1dbbe6e58b3d",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:.mlspace-calibration]",
   "language": "python",
   "name": "conda-env-.mlspace-calibration-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
