{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "WQpNapZNWuXP"
   },
   "source": [
    "\n",
    "**Best-of-n sampling as an alternative to RLHF**\n",
    "\n",
    "This notebook compares the reward-model scores of prompt-based responses from\n",
    "1. a base model (`gpt2-imdb`),\n",
    "2. an `RLHF`-tuned model derived from this base model, and\n",
    "3. the same base model with best-of-n sampling: we sample n responses for each prompt, score them, and keep the highest-scoring one (the `best-of-n sampled` model).\n",
    "\n",
    "Import dependencies"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "vDA6qayz692w"
   },
   "outputs": [],
   "source": [
    "# `datasets` is imported directly by this notebook, so install it explicitly\n",
    "# rather than relying on it arriving as a transitive dependency.\n",
    "%pip install transformers trl datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "id": "M1s_iNm773hM"
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import pandas as pd\n",
    "\n",
    "from transformers import pipeline, AutoTokenizer, set_seed\n",
    "from datasets import load_dataset\n",
    "\n",
    "from trl import AutoModelForCausalLMWithValueHead\n",
    "from trl.core import LengthSampler\n",
    "\n",
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Y7hyrIrO8tcY"
   },
   "source": [
    "Various constants"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "id": "MqS3OM6Q8x6g"
   },
   "outputs": [],
   "source": [
    "# Hub checkpoints: the base language model, the RLHF-tuned model compared\n",
    "# against it, and the sentiment classifier used as the reward model.\n",
    "ref_model_name = \"lvwerra/gpt2-imdb\"\n",
    "model_name = \"lvwerra/gpt2-imdb-pos-v2\"\n",
    "reward_model = \"lvwerra/distilbert-imdb\"\n",
    "\n",
    "# number of candidate responses sampled per prompt for best-of-n\n",
    "N_BEST_OF = 4\n",
    "\n",
    "# Prompt selection, length sampling and generation are all stochastic.\n",
    "# Seed the RNGs once, up front, so that Restart & Run All gives the same table.\n",
    "SEED = 0\n",
    "set_seed(SEED)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "c1YcXeElg6or"
   },
   "source": [
    "Models and tokenizers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "id": "b855NrL181Hh"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/kashif/Github/transformers/src/transformers/tokenization_utils_base.py:1617: FutureWarning: `clean_up_tokenization_spaces` was not set. It will be set to `True` by default. This behavior will be deprecated in transformers v4.45, and will be then set to `False` by default. For more details check this issue: https://github.com/huggingface/transformers/issues/31884\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "AutoModelForCausalLMWithValueHead(\n",
       "  (pretrained_model): GPT2LMHeadModel(\n",
       "    (transformer): GPT2Model(\n",
       "      (wte): Embedding(50257, 768)\n",
       "      (wpe): Embedding(1024, 768)\n",
       "      (drop): Dropout(p=0.1, inplace=False)\n",
       "      (h): ModuleList(\n",
       "        (0-11): 12 x GPT2Block(\n",
       "          (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "          (attn): GPT2SdpaAttention(\n",
       "            (c_attn): Conv1D(nf=2304, nx=768)\n",
       "            (c_proj): Conv1D(nf=768, nx=768)\n",
       "            (attn_dropout): Dropout(p=0.1, inplace=False)\n",
       "            (resid_dropout): Dropout(p=0.1, inplace=False)\n",
       "          )\n",
       "          (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "          (mlp): GPT2MLP(\n",
       "            (c_fc): Conv1D(nf=3072, nx=768)\n",
       "            (c_proj): Conv1D(nf=768, nx=3072)\n",
       "            (act): NewGELUActivation()\n",
       "            (dropout): Dropout(p=0.1, inplace=False)\n",
       "          )\n",
       "        )\n",
       "      )\n",
       "      (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "    )\n",
       "    (lm_head): Linear(in_features=768, out_features=50257, bias=False)\n",
       "  )\n",
       "  (v_head): ValueHead(\n",
       "    (dropout): Dropout(p=0.1, inplace=False)\n",
       "    (summary): Linear(in_features=768, out_features=1, bias=True)\n",
       "    (flatten): Flatten(start_dim=1, end_dim=-1)\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# RLHF-tuned model and the base (reference) model it is compared against\n",
    "model = AutoModelForCausalLMWithValueHead.from_pretrained(model_name)\n",
    "ref_model = AutoModelForCausalLMWithValueHead.from_pretrained(ref_model_name)\n",
    "\n",
    "# sentiment classifier whose scores serve as the reward\n",
    "reward_pipe = pipeline(\"sentiment-analysis\", model=reward_model, device=device)\n",
    "\n",
    "# reuse EOS as the padding token\n",
    "# NOTE(review): presumably needed because this tokenizer defines no pad token - confirm\n",
    "tokenizer = AutoTokenizer.from_pretrained(ref_model_name)\n",
    "tokenizer.pad_token = tokenizer.eos_token\n",
    "\n",
    "# Move both language models to the selected device. The trailing `;` keeps the\n",
    "# (very long) module repr of the last expression out of the cell output.\n",
    "model.to(device)\n",
    "ref_model.to(device);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Z1Cz0gCFhZYJ"
   },
   "source": [
    "Dataset building"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "id": "LqLVEp5p_8XM"
   },
   "outputs": [],
   "source": [
    "def build_dataset(\n",
    "    tokenizer,\n",
    "    dataset_name=\"stanfordnlp/imdb\",\n",
    "    input_min_text_length=2,\n",
    "    input_max_text_length=8,\n",
    "):\n",
    "    \"\"\"Build the prompt dataset from the train split of `dataset_name`.\n",
    "\n",
    "    The `text` column is renamed to `review` and reviews of 200 characters or\n",
    "    fewer are dropped. Each remaining review is tokenized and cut to a length\n",
    "    drawn from `LengthSampler(input_min_text_length, input_max_text_length)`;\n",
    "    the truncated ids are stored in `input_ids` and their decoded text in\n",
    "    `query`.\n",
    "\n",
    "    Returns the mapped dataset with its format set to torch.\n",
    "    \"\"\"\n",
    "    sample_query_length = LengthSampler(input_min_text_length, input_max_text_length)\n",
    "\n",
    "    def is_long_enough(example):\n",
    "        return len(example[\"review\"]) > 200\n",
    "\n",
    "    def to_query(example):\n",
    "        token_ids = tokenizer.encode(example[\"review\"])\n",
    "        example[\"input_ids\"] = token_ids[: sample_query_length()]\n",
    "        example[\"query\"] = tokenizer.decode(example[\"input_ids\"])\n",
    "        return example\n",
    "\n",
    "    # load imdb with datasets\n",
    "    reviews = load_dataset(dataset_name, split=\"train\")\n",
    "    reviews = reviews.rename_columns({\"text\": \"review\"})\n",
    "    prompts = reviews.filter(is_long_enough, batched=False).map(to_query, batched=False)\n",
    "    prompts.set_format(type=\"torch\")\n",
    "    return prompts\n",
    "\n",
    "\n",
    "dataset = build_dataset(tokenizer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "id": "AqA2McjMAxNw"
   },
   "outputs": [],
   "source": [
    "# Sampling settings forwarded to every `generate` call in this notebook.\n",
    "# NOTE(review): top_k=0 / top_p=1.0 presumably switch off top-k and nucleus\n",
    "# filtering so tokens are drawn from the full distribution - confirm against\n",
    "# the transformers generation docs.\n",
    "gen_kwargs = {\n",
    "    \"min_length\": -1,\n",
    "    \"top_k\": 0.0,\n",
    "    \"top_p\": 1.0,\n",
    "    \"do_sample\": True,\n",
    "    \"pad_token_id\": tokenizer.eos_token_id,\n",
    "}\n",
    "# Keyword arguments forwarded to every reward-pipeline call.\n",
    "# NOTE(review): top_k=None appears to return a score for every label and\n",
    "# function_to_apply=\"none\" to leave the scores as raw logits - confirm.\n",
    "sent_kwargs = {\"top_k\": None, \"function_to_apply\": \"none\", \"batch_size\": 16}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "id": "L_q4qs35AxcR"
   },
   "outputs": [],
   "source": [
    "# range from which the length of each generated response is drawn\n",
    "output_min_length = 4\n",
    "output_max_length = 16\n",
    "output_length_sampler = LengthSampler(output_min_length, output_max_length)\n",
    "\n",
    "# draw a random batch of `bs` prompts from the dataset\n",
    "bs = 16\n",
    "dataset.set_format(\"pandas\")\n",
    "df_batch = dataset[:].sample(bs)\n",
    "output_data = {\"query\": df_batch[\"query\"].tolist()}\n",
    "query_tensors = df_batch[\"input_ids\"].tolist()\n",
    "\n",
    "# one response per query for the base model and for the RLHF model\n",
    "response_tensors_ref = []\n",
    "response_tensors = []\n",
    "# one list of candidate responses per query for best-of-n\n",
    "response_tensors_best_of = []"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "QVfpyHnZBLKY"
   },
   "source": [
    "\n",
    "Generation using various models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "id": "-imZ7uEFBNbw"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n"
     ]
    }
   ],
   "source": [
    "def sample_responses(lm, input_ids, max_new_tokens):\n",
    "    \"\"\"Sample one continuation per row of `input_ids` and return the decoded texts.\n",
    "\n",
    "    `input_ids` is a 2-D batch of unpadded prompts, so the attention mask is\n",
    "    all ones. Passing it explicitly is needed because pad and EOS share one\n",
    "    token id here, which prevents `generate` from inferring the mask itself.\n",
    "    The batch dimension is kept all the way to `batch_decode`, so the result\n",
    "    is always a list with one string per row, even for a single row.\n",
    "    \"\"\"\n",
    "    output = lm.generate(\n",
    "        input_ids,\n",
    "        attention_mask=torch.ones_like(input_ids),\n",
    "        max_new_tokens=max_new_tokens,\n",
    "        **gen_kwargs,\n",
    "    )\n",
    "    return tokenizer.batch_decode(output)\n",
    "\n",
    "\n",
    "for i in range(bs):\n",
    "    # one response length per prompt, shared by all three generation settings\n",
    "    gen_len = output_length_sampler()\n",
    "\n",
    "    # add a batch dimension: (query_len,) -> (1, query_len)\n",
    "    query = torch.tensor(query_tensors[i]).unsqueeze(dim=0).to(device)\n",
    "\n",
    "    # base model and RLHF-tuned model: a single sample each\n",
    "    response_tensors_ref.append(sample_responses(ref_model, query, gen_len)[0])\n",
    "    response_tensors.append(sample_responses(model, query, gen_len)[0])\n",
    "\n",
    "    # best-of-n: N_BEST_OF copies of the same prompt, sampled from the base model.\n",
    "    # The batch is deliberately not squeezed: with N_BEST_OF == 1 a squeeze would\n",
    "    # flatten it and the ids would then be decoded token by token.\n",
    "    queries = query.repeat((N_BEST_OF, 1))\n",
    "    response_tensors_best_of.append(sample_responses(ref_model, queries, gen_len))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Jp5FC0Y5h_Sf"
   },
   "source": [
    "Scoring"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "id": "PyDbbAQ0F_h7"
   },
   "outputs": [],
   "source": [
    "def positive_scores(texts):\n",
    "    \"\"\"Return the reward-model score of the POSITIVE label for each text.\n",
    "\n",
    "    With `top_k=None` the pipeline returns one entry per label, ordered by\n",
    "    descending score, so position 0 is whichever label scored highest - not\n",
    "    necessarily POSITIVE. Taking `output[0]` therefore rewards confidently\n",
    "    negative texts as well. Looking the score up by label keeps the reward\n",
    "    meaning \"how positive is this text\".\n",
    "\n",
    "    Raises KeyError if the pipeline output has no POSITIVE label.\n",
    "    \"\"\"\n",
    "    return [\n",
    "        {item[\"label\"]: item[\"score\"] for item in output}[\"POSITIVE\"]\n",
    "        for output in reward_pipe(texts, **sent_kwargs)\n",
    "    ]\n",
    "\n",
    "\n",
    "scores_ref = positive_scores(response_tensors_ref)\n",
    "scores = positive_scores(response_tensors)\n",
    "# one tensor of N_BEST_OF candidate scores per query\n",
    "scores_best_of = [\n",
    "    torch.tensor(positive_scores(candidates)) for candidates in response_tensors_best_of\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 682
    },
    "id": "nA1GDNJEiGm-",
    "outputId": "1389c686-0751-4304-dea2-b71fd68748e1"
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>query</th>\n",
       "      <th>response (ref)</th>\n",
       "      <th>scores (ref)</th>\n",
       "      <th>response (RLHF)</th>\n",
       "      <th>scores (RLHF)</th>\n",
       "      <th>response (best_of)</th>\n",
       "      <th>scores (best_of)</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>This movie</td>\n",
       "      <td>This movie should have read some books, and</td>\n",
       "      <td>1.411889</td>\n",
       "      <td>This movie has plenty of extraordinary feature...</td>\n",
       "      <td>2.735337</td>\n",
       "      <td>This movie was unexpectedly funny and funny, you</td>\n",
       "      <td>2.405301</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>OK where do i begin?</td>\n",
       "      <td>OK where do i begin? *** Acting is decent (not...</td>\n",
       "      <td>1.555380</td>\n",
       "      <td>OK where do i begin? For all of you who are no...</td>\n",
       "      <td>0.019694</td>\n",
       "      <td>OK where do i begin? i just wanted to add some...</td>\n",
       "      <td>0.622912</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>I watched</td>\n",
       "      <td>I watched one can compare themselves upon view...</td>\n",
       "      <td>1.380120</td>\n",
       "      <td>I watched it because of its excellent cast. Th...</td>\n",
       "      <td>2.498309</td>\n",
       "      <td>I watched the trial trial for teaches us a goo...</td>\n",
       "      <td>2.057187</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>It's been 19 years since Gordon</td>\n",
       "      <td>It's been 19 years since Gordon finally left c...</td>\n",
       "      <td>1.554914</td>\n",
       "      <td>It's been 19 years since Gordon Tree has becom...</td>\n",
       "      <td>1.632266</td>\n",
       "      <td>It's been 19 years since Gordon Clarke put me ...</td>\n",
       "      <td>2.783458</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>Just kidding</td>\n",
       "      <td>Just kidding; I know a lot</td>\n",
       "      <td>-0.069533</td>\n",
       "      <td>Just kidding \"Third World Snopes</td>\n",
       "      <td>0.944632</td>\n",
       "      <td>Just kidding, I didn't even</td>\n",
       "      <td>1.945202</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>shakespeare's plays have a way</td>\n",
       "      <td>shakespeare's plays have a way of weaving into...</td>\n",
       "      <td>1.656927</td>\n",
       "      <td>shakespeare's plays have a way. It's the look ...</td>\n",
       "      <td>1.444803</td>\n",
       "      <td>shakespeare's plays have a way of getting back...</td>\n",
       "      <td>1.834373</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>This movie is wonderful. What</td>\n",
       "      <td>This movie is wonderful. What could have been ...</td>\n",
       "      <td>2.749068</td>\n",
       "      <td>This movie is wonderful. What someone likes ab...</td>\n",
       "      <td>2.759510</td>\n",
       "      <td>This movie is wonderful. What a different look,</td>\n",
       "      <td>2.695312</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>I loved</td>\n",
       "      <td>I loved this film. &lt;br /&gt;&lt;</td>\n",
       "      <td>2.576181</td>\n",
       "      <td>I loved it, and I really loved Audrey</td>\n",
       "      <td>2.578412</td>\n",
       "      <td>I loved this film. Reading reviews of it</td>\n",
       "      <td>2.751773</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>A superb and</td>\n",
       "      <td>A superb and very cool drama. The novel is</td>\n",
       "      <td>2.910374</td>\n",
       "      <td>A superb and super fun movie that removes all the</td>\n",
       "      <td>2.783201</td>\n",
       "      <td>A superb and most finely acted role that I will</td>\n",
       "      <td>2.894923</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>I remember</td>\n",
       "      <td>I remember.Very poor execution but good movies</td>\n",
       "      <td>0.923775</td>\n",
       "      <td>I remember when Shelter saw some girls on TV</td>\n",
       "      <td>0.825408</td>\n",
       "      <td>I remember thinking to myself how SOMEONE who</td>\n",
       "      <td>1.634163</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10</th>\n",
       "      <td>This su*k</td>\n",
       "      <td>This su*k camel down your kidd</td>\n",
       "      <td>1.605957</td>\n",
       "      <td>This su*k Dress! I loved it</td>\n",
       "      <td>2.345865</td>\n",
       "      <td>This su*k like a roll of crap</td>\n",
       "      <td>2.422874</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>11</th>\n",
       "      <td>One Stink</td>\n",
       "      <td>One Stink Act...&lt;br /&gt;&lt;br</td>\n",
       "      <td>1.456476</td>\n",
       "      <td>One Stinkl was a great actor, particularly</td>\n",
       "      <td>1.782818</td>\n",
       "      <td>One Stink?: Invisible of Saint Barbara, poor</td>\n",
       "      <td>1.667756</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>12</th>\n",
       "      <td>I pulled down a VHS</td>\n",
       "      <td>I pulled down a VHS copy and watched it with m...</td>\n",
       "      <td>0.756151</td>\n",
       "      <td>I pulled down a VHS looking a good looking, and a</td>\n",
       "      <td>-0.008258</td>\n",
       "      <td>I pulled down a VHS copy the other day and all I</td>\n",
       "      <td>0.992919</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>13</th>\n",
       "      <td>For some</td>\n",
       "      <td>For some alone no more Buddy Trumbull would ha...</td>\n",
       "      <td>0.790762</td>\n",
       "      <td>For some enthraled time, the film will impress...</td>\n",
       "      <td>2.455694</td>\n",
       "      <td>For some reason, a bomb crashed on the rear of...</td>\n",
       "      <td>0.857423</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>14</th>\n",
       "      <td>This one features all</td>\n",
       "      <td>This one features all the good elements of spi...</td>\n",
       "      <td>1.452079</td>\n",
       "      <td>This one features all kinds of wit and humor r...</td>\n",
       "      <td>2.743043</td>\n",
       "      <td>This one features all the best Birdprogram sup...</td>\n",
       "      <td>2.343950</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>15</th>\n",
       "      <td>Somehow a woman working with</td>\n",
       "      <td>Somehow a woman working with Jim Wynorski prof...</td>\n",
       "      <td>0.242172</td>\n",
       "      <td>Somehow a woman working with her daughter play...</td>\n",
       "      <td>0.092226</td>\n",
       "      <td>Somehow a woman working with an overweight ins...</td>\n",
       "      <td>1.415525</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                              query  \\\n",
       "0                        This movie   \n",
       "1              OK where do i begin?   \n",
       "2                         I watched   \n",
       "3   It's been 19 years since Gordon   \n",
       "4                      Just kidding   \n",
       "5    shakespeare's plays have a way   \n",
       "6     This movie is wonderful. What   \n",
       "7                           I loved   \n",
       "8                      A superb and   \n",
       "9                        I remember   \n",
       "10                        This su*k   \n",
       "11                        One Stink   \n",
       "12              I pulled down a VHS   \n",
       "13                         For some   \n",
       "14            This one features all   \n",
       "15     Somehow a woman working with   \n",
       "\n",
       "                                       response (ref)  scores (ref)  \\\n",
       "0         This movie should have read some books, and      1.411889   \n",
       "1   OK where do i begin? *** Acting is decent (not...      1.555380   \n",
       "2   I watched one can compare themselves upon view...      1.380120   \n",
       "3   It's been 19 years since Gordon finally left c...      1.554914   \n",
       "4                          Just kidding; I know a lot     -0.069533   \n",
       "5   shakespeare's plays have a way of weaving into...      1.656927   \n",
       "6   This movie is wonderful. What could have been ...      2.749068   \n",
       "7                          I loved this film. <br /><      2.576181   \n",
       "8          A superb and very cool drama. The novel is      2.910374   \n",
       "9      I remember.Very poor execution but good movies      0.923775   \n",
       "10                     This su*k camel down your kidd      1.605957   \n",
       "11                          One Stink Act...<br /><br      1.456476   \n",
       "12  I pulled down a VHS copy and watched it with m...      0.756151   \n",
       "13  For some alone no more Buddy Trumbull would ha...      0.790762   \n",
       "14  This one features all the good elements of spi...      1.452079   \n",
       "15  Somehow a woman working with Jim Wynorski prof...      0.242172   \n",
       "\n",
       "                                      response (RLHF)  scores (RLHF)  \\\n",
       "0   This movie has plenty of extraordinary feature...       2.735337   \n",
       "1   OK where do i begin? For all of you who are no...       0.019694   \n",
       "2   I watched it because of its excellent cast. Th...       2.498309   \n",
       "3   It's been 19 years since Gordon Tree has becom...       1.632266   \n",
       "4                    Just kidding \"Third World Snopes       0.944632   \n",
       "5   shakespeare's plays have a way. It's the look ...       1.444803   \n",
       "6   This movie is wonderful. What someone likes ab...       2.759510   \n",
       "7               I loved it, and I really loved Audrey       2.578412   \n",
       "8   A superb and super fun movie that removes all the       2.783201   \n",
       "9        I remember when Shelter saw some girls on TV       0.825408   \n",
       "10                        This su*k Dress! I loved it       2.345865   \n",
       "11         One Stinkl was a great actor, particularly       1.782818   \n",
       "12  I pulled down a VHS looking a good looking, and a      -0.008258   \n",
       "13  For some enthraled time, the film will impress...       2.455694   \n",
       "14  This one features all kinds of wit and humor r...       2.743043   \n",
       "15  Somehow a woman working with her daughter play...       0.092226   \n",
       "\n",
       "                                   response (best_of)  scores (best_of)  \n",
       "0    This movie was unexpectedly funny and funny, you          2.405301  \n",
       "1   OK where do i begin? i just wanted to add some...          0.622912  \n",
       "2   I watched the trial trial for teaches us a goo...          2.057187  \n",
       "3   It's been 19 years since Gordon Clarke put me ...          2.783458  \n",
       "4                         Just kidding, I didn't even          1.945202  \n",
       "5   shakespeare's plays have a way of getting back...          1.834373  \n",
       "6     This movie is wonderful. What a different look,          2.695312  \n",
       "7            I loved this film. Reading reviews of it          2.751773  \n",
       "8     A superb and most finely acted role that I will          2.894923  \n",
       "9       I remember thinking to myself how SOMEONE who          1.634163  \n",
       "10                      This su*k like a roll of crap          2.422874  \n",
       "11       One Stink?: Invisible of Saint Barbara, poor          1.667756  \n",
       "12   I pulled down a VHS copy the other day and all I          0.992919  \n",
       "13  For some reason, a bomb crashed on the rear of...          0.857423  \n",
       "14  This one features all the best Birdprogram sup...          2.343950  \n",
       "15  Somehow a woman working with an overweight ins...          1.415525  "
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# position of the highest-scoring candidate for every query\n",
    "best_idx = [candidate_scores.argmax().item() for candidate_scores in scores_best_of]\n",
    "\n",
    "output_data.update(\n",
    "    {\n",
    "        \"response (ref)\": response_tensors_ref,\n",
    "        \"scores (ref)\": scores_ref,\n",
    "        \"response (RLHF)\": response_tensors,\n",
    "        \"scores (RLHF)\": scores,\n",
    "        \"response (best_of)\": [\n",
    "            response_tensors_best_of[row][col] for row, col in enumerate(best_idx)\n",
    "        ],\n",
    "        \"scores (best_of)\": [\n",
    "            candidate_scores.max().item() for candidate_scores in scores_best_of\n",
    "        ],\n",
    "    }\n",
    ")\n",
    "\n",
    "# collect everything in a dataframe and display it\n",
    "df_results = pd.DataFrame(output_data)\n",
    "df_results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "provenance": []
  },
  "gpuClass": "standard",
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
