{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Import"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# PyTorch\n",
    "import matplotlib.cm\n",
    "import torch\n",
    "from torch import Tensor, no_grad, hstack, vstack, ones\n",
    "device = torch.device(\"cuda:0\")\n",
    "\n",
    "# Transformers\n",
    "import transformers\n",
    "from transformers import TextStreamer\n",
    "from transformers.modeling_outputs import SequenceClassifierOutput\n",
    "\n",
    "# Matplotlib\n",
    "import matplotlib\n",
    "cm = matplotlib.cm\n",
    "from matplotlib import rcParams\n",
    "from matplotlib import pyplot as plt\n",
    "from matplotlib.pyplot import subplots, show\n",
    "\n",
    "# Other Modules\n",
    "from importlib import reload\n",
    "from typing import List, Union\n",
    "import time\n",
    "from functools import cache\n",
    "import pickle\n",
    "from pathlib import Path\n",
    "from tqdm.notebook import tqdm\n",
    "\n",
    "# Reloading CBF-LLM Modules\n",
    "import CBFLLM\n",
    "reload(CBFLLM)\n",
    "from CBFLLM import *"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Utilities"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Execution Time\n",
    "_tic_time = time.time()\n",
    "def tic():\n",
    "    global _tic_time\n",
    "    _tic_time = time.time()\n",
    "def toc(print_time: bool = True) -> Union[None, float]:\n",
    "    t = time.time() - _tic_time\n",
    "    if print_time:\n",
    "        print(f\"{t:.04f} s\")\n",
    "    else:\n",
    "        return t"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Models"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Launch the language-constraint function (L-CF) $h:\\mathcal X \\to \\mathbb R$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# HuggingFace cardiffnlp/twitter-roberta-base-sentiment-latest\n",
    "# https://huggingface.co/cardiffnlp/twitter-roberta-base-sentiment-latest\n",
    "from transformers import AutoModelForSequenceClassification, AutoTokenizer\n",
    "\n",
    "name = input(\"Name for cardiffnlp/twitter-roberta-base-sentiment-latest?\")\n",
    "\n",
    "\n",
    "def _mapper(output: SequenceClassifierOutput) -> List[float]:\n",
    "    \"\"\"\n",
    "    Convert the output of the RoBERTa model to L-CF value\n",
    "    \"\"\"\n",
    "    scores = torch.softmax(output[0], dim=1)\n",
    "    negatives, neutrals, positives = scores.T\n",
    "    h_list = positives - torch.max(negatives, neutrals)\n",
    "    return tolist(h_list)\n",
    "\n",
    "\n",
    "lcf = LanguageCF(\n",
    "    model=AutoModelForSequenceClassification.from_pretrained(name),\n",
    "    tokenizer=AutoTokenizer.from_pretrained(name),\n",
    "    mapper=_mapper,\n",
    "    name=\"cardiffnlp/twitter-roberta-base-sentiment-latest\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Huggingface cardiffnlp/tweet-topic-21-multi\n",
    "# https://huggingface.co/cardiffnlp/tweet-topic-21-multi\n",
    "\n",
    "from transformers import AutoModelForSequenceClassification, AutoTokenizer\n",
    "\n",
    "name = input(\"Name for cardiffnlp/tweet-topic-21-multi?\")\n",
    "\n",
    "constrained_topic_id = 8\n",
    "mask = ones((19,), dtype=torch.bool)\n",
    "mask[constrained_topic_id] = False\n",
    "\n",
    "\n",
    "def _mapper(output: SequenceClassifierOutput) -> List[float]:\n",
    "    scores = torch.softmax(output[0], dim=1)\n",
    "    h_list = scores[:, constrained_topic_id] - torch.max(scores[:, mask])\n",
    "    return tolist(h_list)\n",
    "\n",
    "\n",
    "lcf = LanguageCF(\n",
    "    AutoModelForSequenceClassification.from_pretrained(name).to(device),\n",
    "    AutoTokenizer.from_pretrained(name),\n",
    "    _mapper,\n",
    "    name=\"cardiffnlp/cardiffnlp_tweet-topic-21-multi\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Entropy of meta-llama/Meta-Llama-3-8B-Instruct output\n",
    "from transformers.modeling_outputs import CausalLMOutputWithPast\n",
    "TEMPERATURE = 1\n",
    "MAX_ENTROPY = 2.5\n",
    "def _mapper(outputs: CausalLMOutputWithPast)->List[float]:\n",
    "    logits = outputs.logits[:, -1]\n",
    "    P_list = vstack([distributionify(logits[i], TEMPERATURE) for i in range(len(logits))])\n",
    "    entropy_list =  -(P_list * P_list.log()).sum(dim=1)\n",
    "    h_list = MAX_ENTROPY - entropy_list\n",
    "    return h_list\n",
    "Gt.pad_token = Gt.eos_token # https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct/discussions/101\n",
    "lcf = LanguageCF(Gm, Gt, _mapper, name=\"meta-ai/llama3-8b-Instruct Entropy\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Launch the baseline LLM (to be placed in the token generator $G$)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# HuggingFace meta-llama/Meta-Llama-3-8B\n",
    "# https://huggingface.co/meta-llama/Meta-Llama-3-8B\n",
    "from transformers import LlamaForCausalLM, AutoTokenizer\n",
    "name = input(\"Name for meta-llama/Meta-Llama-3-8B?\")\n",
    "\n",
    "Gm = LlamaForCausalLM.from_pretrained(name, torch_dtype=torch.float16).to(device)\n",
    "Gt = AutoTokenizer.from_pretrained(name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# HuggingFace meta-llama/Meta-Llama-3-8B-Instruct\n",
    "# https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct\n",
    "from transformers import LlamaForCausalLM, AutoTokenizer\n",
    "name = input(\"Name for meta-llama/Meta-Llama-3-8B-Instruct?\")\n",
    "\n",
    "Gm = LlamaForCausalLM.from_pretrained(name, torch_dtype=torch.float16).to(device)\n",
    "Gt = AutoTokenizer.from_pretrained(name)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Launch the baseline LLM (to be placed in the token generator $G$), execute the following"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "streamer = TextStreamer(Gt)\n",
    "vocab = Gt.get_vocab()\n",
    "ivocab = {v: Gt.decode([v]).replace(\"\\n\", \"\") for k, v in vocab.items()}\n",
    "VOCAB_SIZE = len(vocab)\n",
    "\n",
    "\n",
    "@cache\n",
    "def Gtokenize(xstr: str) -> Tensor:\n",
    "    Ginputs = Gt(xstr, return_tensors=\"pt\", add_special_tokens=False).to(device)\n",
    "    x = Ginputs.input_ids[0]\n",
    "    return x"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Models Check"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_texts = [\n",
    "    \"You have a good place\",\n",
    "    \"I must be ill!\"\n",
    "]\n",
    "# The `LanguageCF` has caching function for computational efficiency.\n",
    "print(\"#Cache:\", len(lcf.cache))\n",
    "tic()\n",
    "h_list = lcf.get_for_texts(test_texts)\n",
    "toc()\n",
    "print(\"#Cache:\", len(lcf.cache))\n",
    "print(h_list)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Methods"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "@no_grad()\n",
    "def generate(\n",
    "        x0: Tensor,\n",
    "        max_new_tokens: int,\n",
    "        temperature: float,\n",
    "        filter: Filter,\n",
    "        normalizer: Normalizer,\n",
    "        stream: bool = True\n",
    "):\n",
    "    R = {\"disallowed_tokens_history\": [], \"lcf_mapping_history\": []}\n",
    "    x = x0.clone()\n",
    "\n",
    "    for _ in range(max_new_tokens):\n",
    "        output = Gm(x[None])\n",
    "        logit = output.logits[0][-1]\n",
    "        P = distributionify(logit, temperature=temperature)\n",
    "        filter_result = filter.scan(x, P)\n",
    "        R[\"disallowed_tokens_history\"].append(filter_result.disallowed)\n",
    "        R[\"lcf_mapping_history\"].append(filter_result.lcf_mapping)\n",
    "        Q = normalizer(P, filter_result.allowed)\n",
    "        iast = Q.multinomial(num_samples=1)\n",
    "        if iast == Gt.eos_token_id:\n",
    "            break\n",
    "        x = hstack((x, iast))\n",
    "        if stream:\n",
    "            streamer.put(iast)\n",
    "\n",
    "    if stream:\n",
    "        streamer.end()\n",
    "\n",
    "    R[\"xf\"] = x\n",
    "    return R"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_lcf_history(x0: Tensor, xf: Tensor) -> List[float]:\n",
    "    clf_history = []\n",
    "    x0len = len(x0)\n",
    "    xflen = len(xf)\n",
    "    for i in range(x0len, xflen + 1):\n",
    "        x = xf[:i]\n",
    "        xstr = Gt.decode(x)\n",
    "        hx = lcf.get_for_text(xstr)\n",
    "        clf_history.append(hx)\n",
    "    return clf_history"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Non-Positive Text-Avoiding Control\n",
    "Prepare the following models in the `Models` section:\n",
    "* The L-CF : `HuggingFace cardiffnlp/twitter-roberta-base-sentiment-latest`\n",
    "* The baseline LLM : `HuggingFace meta-llama/Meta-Llama-3-8B`"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setting"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "x0str = \"Everyone says you will be a good researcher in the future, but\"\n",
    "x0 = Gtokenize(x0str)\n",
    "h0 = lcf.get_for_text(x0str)\n",
    "print(f\"{x0=}\")\n",
    "print(f\"{h0=}\")\n",
    "\n",
    "TOPK = 30\n",
    "TEMPERATURE = 1\n",
    "NORMALIZER = ElementwiseMultiplyNormalizer()\n",
    "MAX_NEW_TOKENS = 30"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "SAMPLE_SIZE = 5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "FILTERS = {\n",
    "    \"NC\": JustTopkFilter(TOPK, output_lcf_mapping=True, tokenizer=Gt, lcf=lcf),\n",
    "    \"CBF0.3\": CBFFilter(TOPK, 0.3, Gt, lcf),\n",
    "    \"CBF0.8\": CBFFilter(TOPK, 0.8, Gt, lcf),\n",
    "    \"BL\": BlacklistFilter(TOPK, Gt, lcf),\n",
    "}\n",
    "COLORS = {\n",
    "    \"CBF0.3\": \"tab:red\",\n",
    "    \"CBF0.8\": \"tab:orange\",\n",
    "    \"BL\": \"tab:blue\",\n",
    "    \"NC\": \"black\"\n",
    "}\n",
    "LABELS = {\n",
    "    \"CBF0.3\": \"CBF ($\\\\alpha =0.3$)\",\n",
    "    \"CBF0.8\": \"CBF ($\\\\alpha =0.8$)\",\n",
    "    \"BL\": \"Blacklist\",\n",
    "    \"NC\": \"NoControl\"\n",
    "}\n",
    "ALPHAS = {\n",
    "    \"CBF0.3\": 0.3,\n",
    "    \"CBF0.8\": 0.8,\n",
    "    \"BL\": 1.0,\n",
    "    \"NC\": None\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Try to generate with NoControl(Llama 3 Output) at once\n",
    "R = generate(\n",
    "    x0=x0,\n",
    "    max_new_tokens=MAX_NEW_TOKENS,\n",
    "    temperature=TEMPERATURE,\n",
    "    # The `JustTopkFilter` plays the NoControl case.\n",
    "    filter=JustTopkFilter(\n",
    "        top_k=TOPK\n",
    "    ),\n",
    "    normalizer=normalizer\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Generation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ⚠️IT TAKES MUCH TIME⚠️ The text-generation process\n",
    "R = {} # Result data\n",
    "for label, filter in FILTERS.items():\n",
    "    print(label)\n",
    "    R[label] = []\n",
    "    for i in tqdm(range(SAMPLE_SIZE)):\n",
    "        sample = generate(\n",
    "            x0=x0, max_new_tokens=MAX_NEW_TOKENS, temperature=TEMPERATURE,\n",
    "            filter=filter, normalizer=NORMALIZER, stream=False\n",
    "        )\n",
    "        R[label].append(sample)\n",
    "print(\"Done\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save if you want.\n",
    "torch.save(R, \"E1Results.pkl\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show generated texts\n",
    "for label, samples in R.items():\n",
    "    print(label)\n",
    "    for i, sample in enumerate(samples):\n",
    "        xf = sample[\"xf\"]\n",
    "        xfstr = Gt.decode(xf)\n",
    "        print(i, \"▶\", xfstr)\n",
    "    print()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### L-CF Trajectory Plots"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "index_to_show = {\"CBF0.3\": 0, \"CBF0.8\": 0, \"BL\": 0, \"NC\": 0}  # 🦀 Set by yourself\n",
    "\n",
    "rcParams[\"font.size\"] = 16\n",
    "fig, ax = subplots(figsize=(9, 5))\n",
    "for label, index in index_to_show.items():\n",
    "    sample = R[label][index]\n",
    "    lcf_history = get_lcf_history(x0, sample[\"xf\"])\n",
    "    xg = tonumpy(sample[\"xf\"][len(x0)-1:]) # The part that is generated by the text-generation system\n",
    "\n",
    "    ax.plot(lcf_history, marker=\"o\", ls=\":\",\n",
    "            label=LABELS[label], color=COLORS[label])\n",
    "    for k, tk in enumerate(xg):\n",
    "        tstr = ivocab[tk]\n",
    "        ax.text(k, lcf_history[k]+0.01, tstr, fontsize=12)\n",
    "ax.set_xlim(0, MAX_NEW_TOKENS+1)\n",
    "ax.hlines(0, *ax.get_xlim(), color=\"gray\", ls=\":\")\n",
    "ax.set_xlabel(\"Time $k$\")\n",
    "ax.set_ylabel(\"L-CF value $h(x(k))$\")\n",
    "ax.legend(loc=\"upper center\", bbox_to_anchor=(0.5,1.2), fontsize=14, ncol=4)\n",
    "fig.tight_layout()\n",
    "show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### L-CF Predicted Trajectory Plots"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sample_to_show = R[\"CBF0.3\"][0] # 🦀 Set by yourself\n",
    "print(Gt.decode(sample_to_show[\"xf\"]))\n",
    "\n",
    "lcf_history = get_lcf_history(x0, sample_to_show[\"xf\"])\n",
    "lcf_mapping_history = sample_to_show[\"lcf_mapping_history\"]\n",
    "disallowed_tokens_history = sample_to_show[\"disallowed_tokens_history\"]\n",
    "xg = tonumpy(sample_to_show[\"xf\"][len(x0)-1:])\n",
    "\n",
    "fig, ax = subplots(figsize=(9,6))\n",
    "for k, tk in enumerate(xg[:-1]):\n",
    "    hk = lcf_history[k]\n",
    "    lcf_mapping = lcf_mapping_history[k]\n",
    "    disallowed_tokens = disallowed_tokens_history[k]\n",
    "    for t_pred, h_pred in lcf_mapping.items():\n",
    "        color = \"tab:red\" if t_pred in disallowed_tokens else \"tab:green\"\n",
    "        ax.plot([k, k+1], [hk, h_pred], color=color, alpha=0.3)\n",
    "\n",
    "# Dummy graph only for legends\n",
    "ax.plot([], [],ls=\"-\", label=\"Allowed Token\", color=\"tab:green\")\n",
    "ax.plot([],[],ls=\"-\", label=\"Disallowed Token\", color=\"tab:red\")\n",
    "\n",
    "ax.hlines(0, *ax.get_xlim(), color=\"gray\", ls=\":\")\n",
    "ax.plot(lcf_history, marker=\"o\", label=\"Actual Trajectory\", color=\"k\")\n",
    "ax.set_xlim(0, MAX_NEW_TOKENS + 1)\n",
    "ax.set_xlabel(\"Time $k$\")\n",
    "ax.set_ylabel(\"L-CF value $h(x(k))$\")\n",
    "plt.legend(loc=\"best\")\n",
    "fig.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Attractors"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "label_to_show = \"BL\"  # 🦀 Set by yourself\n",
    "samples = R[label_to_show]\n",
    "alpha = ALPHAS[label_to_show]\n",
    "\n",
    "tic()\n",
    "xdata = []\n",
    "ydata = []\n",
    "for sample in samples:\n",
    "    lcf_history = get_lcf_history(x0, sample[\"xf\"])\n",
    "    for hk, lcf_mapping in zip(lcf_history, sample[\"lcf_mapping_history\"]):\n",
    "        xdata += [hk] * len(lcf_mapping)\n",
    "        for h_next in lcf_mapping.values():\n",
    "            ydata.append(h_next - hk)\n",
    "toc()\n",
    "\n",
    "fig, ax = subplots(figsize=(6, 6))\n",
    "ax.hist2d(xdata, ydata, cmap=cm.gist_heat_r, bins=(25, 35))\n",
    "if alpha:\n",
    "    slope = - alpha\n",
    "    xmax = 1\n",
    "    ax.plot([0, xmax], [0, xmax*slope], \"k:\", label=\"Constraint\")\n",
    "    fig.legend()\n",
    "\n",
    "ax.set_xlabel(\"Current L-CF value $h(x)$\")\n",
    "ax.set_ylabel(\"$\\\\Delta h(x^+)$\")\n",
    "ax.set_xlim(-1, 1)\n",
    "ax.set_ylim(-0.5, 0.5)\n",
    "ax.vlines(0, *ax.get_ylim(), colors=\"gray\", ls=\"-\")\n",
    "fig.tight_layout()\n",
    "show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Non-Positive Text-Avoiding Control - Task Quality vs Intervention Weakness\n",
    "Prepare the following models in the `Models` section:\n",
    "* The L-CF : `HuggingFace cardiffnlp/twitter-roberta-base-sentiment-latest`\n",
    "* The baseline LLM : `HuggingFace meta-llama/Meta-Llama-3-8B`"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setting"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "x0str = \"From yesterday my health was very good, but \"\n",
    "x0 = Gtokenize(x0str)\n",
    "h0 = lcf.get_for_text(x0str)\n",
    "print(f\"{x0=}\")\n",
    "print(f\"{h0=}\")\n",
    "\n",
    "TOPK = 30\n",
    "TEMPERATURE = 1\n",
    "NORMALIZER = ElementwiseMultiplyNormalizer()\n",
    "MAX_NEW_TOKENS = 30"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "ALPHAS = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]\n",
    "# Recall that CBF filter with alpha = 1.0 equals to the BlacklistFilter."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [],
   "source": [
    "SAMPLE_SIZE = 1 # 🦀 Set by yourself"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Generation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ⚠️IT TAKES MUCH TIME⚠️ The text-generation process\n",
    "R = {} # Result data\n",
    "for alpha in ALPHAS:\n",
    "    print(alpha)\n",
    "    R[alpha] = []\n",
    "    filter = CBFFilter(TOPK, alpha, Gt, lcf)\n",
    "    for i in tqdm(range(SAMPLE_SIZE)):\n",
    "        sample = generate(\n",
    "            x0=x0, max_new_tokens=MAX_NEW_TOKENS, temperature=TEMPERATURE,\n",
    "            filter=filter, normalizer=NORMALIZER, stream=False\n",
    "        )\n",
    "        R[alpha].append(sample)\n",
    "print(\"Done\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Calculates the statistics of\n",
    "# average L-CF value and disallowed tokens\n",
    "mean_lcf_data = {}\n",
    "mean_disallowed_tokens_data = {}\n",
    "\n",
    "for alpha, samples in R.items():\n",
    "    lcf_list = []\n",
    "    disallowed_tokens_list = []\n",
    "    for sample in samples:\n",
    "        sample[\"lcf_history\"] = get_lcf_history(x0, sample[\"xf\"])\n",
    "        lcf_list += sample[\"lcf_history\"]\n",
    "        sample[\"disallowed_tokens\"] = sum(len(x) for x in sample[\"disallowed_tokens_history\"])\n",
    "        disallowed_tokens_list.append(sample[\"disallowed_tokens\"])\n",
    "    mean_lcf_data[alpha] = numpy.mean(lcf_list)\n",
    "    mean_disallowed_tokens_data[alpha] = numpy.mean(disallowed_tokens_list)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save if you want.\n",
    "torch.save(R, \"E2Results.pkl\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show generated texts\n",
    "for alpha, samples in R.items():\n",
    "    print(alpha)\n",
    "    for i, sample in enumerate(samples):\n",
    "        xf = sample[\"xf\"]\n",
    "        xfstr = Gt.decode(xf)\n",
    "        print(i, \"▶\", xfstr)\n",
    "    print()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Alpha vs Two Indicators Plots"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig_A, ax = subplots(facecolor=\"white\", figsize=(8, 5))\n",
    "\n",
    "rax = ax.twinx()\n",
    "\n",
    "mean_disallowed_tokens_list = [\n",
    "    mean_disallowed_tokens_data[alpha] for alpha in ALPHAS]\n",
    "mean_lcf_list = [mean_lcf_data[alpha] for alpha in ALPHAS]\n",
    "ax.plot(ALPHAS, mean_disallowed_tokens_list,\n",
    "        \"k:o\", label=\"# of Disallowed Tokens\")\n",
    "rax.bar(ALPHAS, mean_lcf_list, width=0.01, color=\"tab:red\",\n",
    "        alpha=0.5, label=\"Average L-CF Value\")\n",
    "\n",
    "rax.set_ylabel(\"Average L-CF Value\", color=\"tab:red\")\n",
    "ax.set_ylabel(\"# of Disallowed Tokens per Generation\")\n",
    "ax.set_xlabel(\"Hyperparameter $\\\\alpha$\")\n",
    "ax.set_xticks(ALPHAS, ALPHAS)\n",
    "ax.set_xticklabels([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, \"BL\"])\n",
    "\n",
    "fig_A.tight_layout()\n",
    "fig_A.legend(loc=\"upper center\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Maintaining Specific Topic\n",
    "Prepare the following models in the `Models` section:\n",
    "* The L-CF : `HuggingFace cardiffnlp/tweet-topic-21-multi`\n",
    "* The baseline LLM : `HuggingFace meta-llama/Meta-Llama-3-8B`"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setting"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "x0str = \"It's time for lunch, but \"\n",
    "x0 = Gtokenize(x0str)\n",
    "h0 = lcf.get_for_text(x0str)\n",
    "print(f\"{x0=}\")\n",
    "print(f\"{h0=}\")\n",
    "\n",
    "TOPK = 30\n",
    "TEMPERATURE = 1\n",
    "NORMALIZER = ElementwiseMultiplyNormalizer()\n",
    "MAX_NEW_TOKENS = 30"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "FILTERS = {\n",
    "    \"NC\": JustTopkFilter(TOPK, output_lcf_mapping=True, tokenizer=Gt, lcf=lcf),\n",
    "    \"CBF0.3\": CBFFilter(TOPK, 0.3, Gt, lcf),\n",
    "    \"CBF0.8\": CBFFilter(TOPK, 0.8, Gt, lcf),\n",
    "    \"BL\": BlacklistFilter(TOPK, Gt, lcf),\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "SAMPLE_SIZE = 5"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Generation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ⚠️IT TAKES MUCH TIME⚠️ The text-generation process\n",
    "R = {} # Result data\n",
    "for label, filter in FILTERS.items():\n",
    "    print(label)\n",
    "    R[label] = []\n",
    "    for i in tqdm(range(SAMPLE_SIZE)):\n",
    "        sample = generate(\n",
    "            x0=x0, max_new_tokens=MAX_NEW_TOKENS, temperature=TEMPERATURE,\n",
    "            filter=filter, normalizer=NORMALIZER, stream=False\n",
    "        )\n",
    "        R[label].append(sample)\n",
    "print(\"Done\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save if you want.\n",
    "torch.save(R, \"E3Results.pkl\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show generated texts\n",
    "for label, samples in R.items():\n",
    "    print(label)\n",
    "    for i, sample in enumerate(samples):\n",
    "        xf = sample[\"xf\"]\n",
    "        xfstr = Gt.decode(xf)\n",
    "        print(i, \"▶\", xfstr)\n",
    "    print()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Further Applications and Limitations of CBF-LLM\n",
    "Prepare the following models in the `Models` section:\n",
    "* The L-CF : `Entropy of meta-llama/Meta-Llama-3-8B-Instruct output`\n",
    "* The baseline LLM : `HuggingFace meta-llama/Meta-Llama-3-8B-Instruct`"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setting"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_prompting_input(prompt:str)->Tensor:\n",
    "    conversation = [\n",
    "        {\"role\": \"user\", \"content\": prompt},\n",
    "    ]\n",
    "    inp = Gt.apply_chat_template(\n",
    "        conversation,\n",
    "        add_generation_prompt=True,\n",
    "        return_tensors=\"pt\"\n",
    "    ).to(device)[0]\n",
    "    return inp"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "TOPK = 50\n",
    "TEMPERATURE = 1.\n",
    "NORMALIZER = ElementwiseMultiplyNormalizer()\n",
    "MAX_NEW_TOKENS = 200\n",
    "x0 = create_prompting_input(\"Write a unique fizz buzz python in a single row. Just output the code.\")\n",
    "print(x0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_input_ids = create_prompting_input(\"What is CBF-LLM stands for?\")\n",
    "print(test_input_ids)\n",
    "print(Gt.decode(test_input_ids))\n",
    "sample = generate(test_input_ids, 20, TEMPERATURE, JustTopkFilter(top_k=TOPK, output_lcf_mapping=False), NORMALIZER)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "FILTERS = {\n",
    "    \"NC\": JustTopkFilter(TOPK, output_lcf_mapping=True, tokenizer=Gt, lcf=lcf),\n",
    "    \"CBF0.8\": CBFFilter(TOPK, 0.8, Gt, lcf),\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "SAMPLE_SIZE = 2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Generation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ⚠️IT TAKES MUCH TIME⚠️ The text-generation process\n",
    "R = {} # Result data\n",
    "for label, filter in FILTERS.items():\n",
    "    print(label)\n",
    "    R[label] = []\n",
    "    for i in tqdm(range(SAMPLE_SIZE)):\n",
    "        sample = generate(\n",
    "            x0=x0, max_new_tokens=MAX_NEW_TOKENS, temperature=TEMPERATURE,\n",
    "            filter=filter, normalizer=NORMALIZER, stream=False\n",
    "        )\n",
    "        R[label].append(sample)\n",
    "print(\"Done\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save if you want.\n",
    "torch.save(R, \"E4Results.pkl\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show generated texts\n",
    "x0len = len(x0)\n",
    "for label, samples in R.items():\n",
    "    print(label)\n",
    "    for i, sample in enumerate(samples):\n",
    "        xg = sample[\"xf\"][x0len:]\n",
    "        xgstr = Gt.decode(xg)\n",
    "        print(i, \"▶\", xgstr)\n",
    "    print()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "PythonVenv_base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
