{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os, sys\n",
    "from importlib import import_module\n",
    "\n",
    "# Environment is configured before transformers / vLLM are imported below so the\n",
    "# cache location is already in place when those libraries initialise.\n",
    "cache_dir = \"/work/hdd/bdkj/audreyh/.cache\"\n",
    "os.environ['XDG_CACHE_HOME'] = cache_dir\n",
    "# The HF tokenizer is used before vLLM forks its workers; declaring this explicitly\n",
    "# avoids the repeated \"current process just got forked\" warnings.\n",
    "os.environ.setdefault('TOKENIZERS_PARALLELISM', 'false')\n",
    "from omegaconf import OmegaConf\n",
    "from hydra import initialize, compose\n",
    "\n",
    "ROOT = \"/u/audreyh/workspace/test-code\"\n",
    "# Make the project's `code/` directory importable (helpers.*, tasks.*).\n",
    "sys.path.append(os.path.join(ROOT, 'code'))\n",
    "from helpers.io import *\n",
    "import numpy as np \n",
    "import torch \n",
    "from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline\n",
    "from vllm import LLM, SamplingParams"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import hydra\n",
    "from omegaconf import OmegaConf\n",
    "\n",
    "def get_hydra_config(config_name, overrides=None, config_dir='./code/configs'):\n",
    "    \"\"\"Compose `config_name` from `config_dir`, applying optional Hydra overrides.\n",
    "\n",
    "    Returns the composed config with all interpolations resolved and struct mode\n",
    "    disabled, so downstream code may add new keys.\n",
    "\n",
    "    `hydra.initialize` is used as a context manager: the global Hydra state is\n",
    "    restored on exit, so calling this again (e.g. re-running the cell) does not\n",
    "    raise \"GlobalHydra is already initialized\".\n",
    "    \"\"\"\n",
    "    if overrides is None:\n",
    "        overrides = []\n",
    "    with hydra.initialize(config_path=config_dir, version_base=None):\n",
    "        cfg = hydra.compose(config_name=config_name, overrides=overrides)\n",
    "        # Resolve while Hydra is still initialized, in case interpolations depend on it.\n",
    "        OmegaConf.resolve(cfg)\n",
    "    OmegaConf.set_struct(cfg, False)\n",
    "    return cfg\n",
    "\n",
    "overrides = [\n",
    "   'user=ah', \n",
    "   'task=alpaca'\n",
    "]\n",
    "\n",
    "cfg = get_hydra_config('master.yaml', overrides=overrides)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Import the task module (e.g. `tasks.alpaca` for task=alpaca) from the ROOT/code\n",
    "# entry on sys.path. The name is absolute, so no `package` anchor is needed.\n",
    "data_module = import_module(f\"tasks.{cfg.task.name}\")\n",
    "prompts, answers, _ = data_module.load_data(cfg.task.data, ROOT)\n",
    "dl = data_module.DataLoader(cfg)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "What are the names of some famous actors that started their careers on Broadway?\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Rebuild the query list through the task DataLoader; this replaces the `prompts`\n",
    "# returned by load_data in the previous cell.\n",
    "prompts = dl.build_queries()\n",
    "prompt = prompts[0]  # spot-check the first query\n",
    "print(prompt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Reading outputs from /work/hdd/bdkj/audreyh/data/alpaca/gemma-2-2b/generations/alpaca-gemma-2-2b--shots-0-seed-101-generations.json\n",
      "Compose an engaging travel blog post about a recent trip to Hawaii, highlighting cultural experiences and must-see attractions.\n",
      "\n",
      "*Please keep in mind that this is a hypothetical blog post and the specific details can be made up to match your preferences.*\n",
      "\n",
      "## Aloha from the Big Island: Where Volcanic Majesty Meets Aloha Spirit\n",
      "\n",
      "Covering the sun-drenched beaches and turquoise waters, camera angles and Instagram feed necessities is all well and good… but what about truly experiencing the soul of Hawaii? Well, my friends, I’m here to tell you – it’s easier than you think!  On my recent trip to the Big Island, I traded selfie sticks for smoldering volcanoes and lived an adventure that I won't soon forget.\n",
      "\n",
      "Forget the clichés – Hawaii isn't just about surfing and piña coladas. It's about ancient stories whispered on Hawaiian winds, the warm embrace of local culture, and the exhilaration of witnessing nature’s raw power, right in front of your eyes. And trust me, the Big Island gives every ounce of this a beautifully tropical twist.\n",
      "\n",
      " **Pili Pili Po'okela:**  My sense of adventure began in Hilo, where I went on a workshop with Chef Kekoa, the master of traditional Hawaiian cuisine.  We learned to make poke (this fish dish is a grower in the mainland, but not something you know until you've seen it in its Hawaiian glory), heapped poi (a smooth, starchy starch packed with flavor!), and hula dance techniques.  Chef Kekoa has a way with words, so we discussed the music of prayer (perhaps I'll try this?). He shared details of casting off invasive species alongside the battle of the brave – hawkers – those who met the arrival of the mighty Europeans with their greedy hands.\n",
      "\n",
      "**Volcano Embraces and Ocean Bubbles:**\n",
      "\n",
      "No trip to Hawaii is complete without a visit to Kilauea, the most active volcano in the world. Standing on the rim of the caldera and witnessing the plume of lava flow into the fiery ocean is an experience that will blow your mind. The volcano’s raw power is undeniable, an awe-inspiring spectacle of geological beauty.\n",
      "\n",
      "Of course, the natural world of Volcanoes National Park is worth exploring by stormy Pacific winds. Volcanoes National Park rocks with hiking, camping, and star gazing, as well as numerous lava chambers, caves, and geological formations.  \n",
      "Beyond the rugged beauty of the land, I trekked along the lush rainforests, met knowledgeable and helpful locals, and explored the diverse culture of Hawaii.\n",
      "\n",
      "**Mālama Honu:**\n",
      "\n",
      "My true connection with the community came during a silent gathering at a traditional Hawaiian luau. The evening was marked by local storytelling, the rhythm and poetry locked into the drum-drumming and ukulele chanting – a soundscape that spoke to my deepest Polynesian roots.  You feel a bond with the others and your place within the new community you have made.\n",
      "\n",
      "\n",
      " **Leaving with Aloha:**\n",
      "\n",
      "As I boarded the plane back to reality, I was filled with an immense sense of contentment. Being surrounded by such pristine beauty and learning the unique culture of the Hawaiian people, the land itself - true soul and strength-filled energy radiated through every chatar-tongued word. \n",
      "This trip left me with more than just beautiful photos; it left an imprint on my soul I'll carry with me for a lifetime.\n",
      "\n",
      "**How to Aloha the Big Island:**\n",
      "\n",
      "Looking to follow in my footsteps and soak in the magic of Hawaii? Here are the best ways to experience the Big Island:\n",
      "\n",
      "* **Start with Hilo:** This historic city boasts diverse activities, restaurants, and cultural sites. Be sure to take a tour of the nearby Volcanoes National Park or learn about the local Hawaiian culture through culinary experiences.\n",
      "* **Immerse yourself in the Volcanoes:** Kilauea will leave you speechless! You can witness the active volcano with a ranger-led tour or explore the unique local flora and fauna found at Volcanoes National Park. \n",
      "* **Embrace the tranquility**: Spend a night camping in a peaceful environment at Volcanoes National Park, surrounded by stunning natural beauty.\n",
      "\n",
      "\n",
      "Mahalos to the Big Island for remembering my life. Until next time, aloha!\n",
      "\n",
      "\n",
      " \n",
      "Let your readers know if the provided snippet is complete and satisfactory for you. \n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Load saved generations; each record holds a `prompt_idx` into `prompts` and the\n",
    "# generated `response`. Distinct names from the vLLM `outputs`/`output` used further\n",
    "# down, so re-running this cell cannot clobber those results.\n",
    "gf = \"/work/hdd/bdkj/audreyh/data/alpaca/gemma-2-2b/generations/alpaca-gemma-2-2b--shots-0-seed-101-generations.json\"\n",
    "generations = json_load(gf)\n",
    "generation = generations[-3]  # arbitrary example record\n",
    "prompt_idx = int(generation['prompt_idx'])\n",
    "response = generation['response']\n",
    "prompt = prompts[prompt_idx]\n",
    "\n",
    "print(prompt)\n",
    "print(response)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\"Compose an engaging travel blog post about a recent trip to Hawaii, highlighting cultural experiences and must-see attractions.\\n\\n\\n*Please keep in mind that this is a hypothetical blog post and the specific details can be made up to match your preferences.*\\n\\n## Aloha from the Big Island: Where Volcanic Majesty Meets Aloha Spirit\\n\\nCovering the sun-drenched beaches and turquoise waters, camera angles and Instagram feed necessities is all well and good… but what about truly experiencing the soul of Hawaii? Well, my friends, I’m here to tell you – it’s easier than you think!  On my recent trip to the Big Island, I traded selfie sticks for smoldering volcanoes and lived an adventure that I won't soon forget.\\n\\nForget the clichés – Hawaii isn't just about surfing and piña coladas. It's about ancient stories whispered on Hawaiian winds, the warm embrace of local culture, and the exhilaration of witnessing nature’s raw power, right in front of your eyes. And trust me, the Big Island gives every ounce of this a beautifully tropical twist.\\n\\n **Pili Pili Po'okela:**  My sense of adventure began in Hilo, where I went on a workshop with Chef Kekoa, the master of traditional Hawaiian cuisine.  We learned to make poke (this fish dish is a grower in the mainland, but not something you know until you've seen it in its Hawaiian glory), heapped poi (a smooth, starchy starch packed with flavor!), and hula dance techniques.  Chef Kekoa has a way with words, so we discussed the music of prayer (perhaps I'll try this?). He shared details of casting off invasive species alongside the battle of the brave – hawkers – those who met the arrival of the mighty Europeans with their greedy hands.\\n\\n**Volcano Embraces and Ocean Bubbles:**\\n\\nNo trip to Hawaii is complete without a visit to Kilauea, the most active volcano in the world. 
Standing on the rim of the caldera and witnessing the plume of lava flow into the fiery ocean is an experience that will blow your mind. The volcano’s raw power is undeniable, an awe-inspiring spectacle of geological beauty.\\n\\nOf course, the natural world of Volcanoes National Park is worth exploring by stormy Pacific winds. Volcanoes National Park rocks with hiking, camping, and star gazing, as well as numerous lava chambers, caves, and geological formations.  \\nBeyond the rugged beauty of the land, I trekked along the lush rainforests, met knowledgeable and helpful locals, and explored the diverse culture of Hawaii.\\n\\n**Mālama Honu:**\\n\\nMy true connection with the community came during a silent gathering at a traditional Hawaiian luau. The evening was marked by local storytelling, the rhythm and poetry locked into the drum-drumming and ukulele chanting – a soundscape that spoke to my deepest Polynesian roots.  You feel a bond with the others and your place within the new community you have made.\\n\\n\\n **Leaving with Aloha:**\\n\\nAs I boarded the plane back to reality, I was filled with an immense sense of contentment. Being surrounded by such pristine beauty and learning the unique culture of the Hawaiian people, the land itself - true soul and strength-filled energy radiated through every chatar-tongued word. \\nThis trip left me with more than just beautiful photos; it left an imprint on my soul I'll carry with me for a lifetime.\\n\\n**How to Aloha the Big Island:**\\n\\nLooking to follow in my footsteps and soak in the magic of Hawaii? Here are the best ways to experience the Big Island:\\n\\n* **Start with Hilo:** This historic city boasts diverse activities, restaurants, and cultural sites. Be sure to take a tour of the nearby Volcanoes National Park or learn about the local Hawaiian culture through culinary experiences.\\n* **Immerse yourself in the Volcanoes:** Kilauea will leave you speechless! 
You can witness the active volcano with a ranger-led tour or explore the unique local flora and fauna found at Volcanoes National Park. \\n* **Embrace the tranquility**: Spend a night camping in a peaceful environment at Volcanoes National Park, surrounded by stunning natural beauty.\\n\\n\\nMahalos to the Big Island for remembering my life. Until next time, aloha!\\n\\n\\n \\nLet your readers know if the provided snippet is complete and satisfactory for you. \\n\"\n"
     ]
    }
   ],
   "source": [
    "# Join prompt and response with a blank-line separator; repr() exposes the exact\n",
    "# whitespace at the seam, which affects how the joint text is tokenized.\n",
    "qr = \"\\n\\n\".join((prompt, response))\n",
    "print(repr(qr))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# HF tokenizer for the checkpoint whose log-probs are inspected below.\n",
    "model_name = \"tlc4418/pythia_70m_sft\"\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_name)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "24 936\n"
     ]
    }
   ],
   "source": [
    "# Token counts for the prompt and the response when each is tokenized on its own.\n",
    "tok_prompt, tok_response = (tokenizer(text)['input_ids'] for text in (prompt, response))\n",
    "num_prompt_tokenids, num_response_tokenids = len(tok_prompt), len(tok_response)\n",
    "print(num_prompt_tokenids, num_response_tokenids)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO 01-24 16:21:54 llm_engine.py:213] Initializing an LLM engine (v0.6.0) with config: model='tlc4418/pythia_70m_sft', speculative_config=None, tokenizer='tlc4418/pythia_70m_sft', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config=None, rope_scaling=None, rope_theta=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.bfloat16, max_seq_len=2048, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=False, kv_cache_dtype=auto, quantization_param_path=None, device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='outlines'), observability_config=ObservabilityConfig(otlp_traces_endpoint=None, collect_model_forward_time=False, collect_model_execute_time=False), seed=0, served_model_name=tlc4418/pythia_70m_sft, use_v2_block_manager=False, num_scheduler_steps=1, enable_prefix_caching=False, use_async_output_proc=True)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
      "To disable this warning, you can either:\n",
      "\t- Avoid using `tokenizers` before the fork if possible\n",
      "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n",
      "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
      "To disable this warning, you can either:\n",
      "\t- Avoid using `tokenizers` before the fork if possible\n",
      "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n",
      "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
      "To disable this warning, you can either:\n",
      "\t- Avoid using `tokenizers` before the fork if possible\n",
      "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n",
      "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
      "To disable this warning, you can either:\n",
      "\t- Avoid using `tokenizers` before the fork if possible\n",
      "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO 01-24 16:21:55 model_runner.py:915] Starting to load model tlc4418/pythia_70m_sft...\n",
      "INFO 01-24 16:21:55 weight_utils.py:236] Using model weights format ['*.bin']\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "37b2bb3730014d7fb7b7c95905fa8c1e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading pt checkpoint shards:   0% Completed | 0/1 [00:00<?, ?it/s]\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/u/audreyh/anaconda3/envs/rlhf/lib/python3.10/site-packages/vllm/model_executor/model_loader/weight_utils.py:416: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
      "  state = torch.load(bin_file, map_location=\"cpu\")\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO 01-24 16:21:56 model_runner.py:926] Loading model weights took 0.1329 GB\n",
      "INFO 01-24 16:21:56 gpu_executor.py:122] # GPU blocks: 213828, # CPU blocks: 21845\n",
      "INFO 01-24 16:22:00 model_runner.py:1217] Capturing the model for CUDA graphs. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the CLI.\n",
      "INFO 01-24 16:22:00 model_runner.py:1221] CUDA graphs can take additional 1~3 GiB memory per GPU. If you are running out of memory, consider decreasing `gpu_memory_utilization` or enforcing eager mode. You can also reduce the `max_num_seqs` as needed to decrease memory usage.\n",
      "INFO 01-24 16:22:10 model_runner.py:1335] Graph capturing finished in 10 secs.\n"
     ]
    }
   ],
   "source": [
    "# vLLM engine for the same checkpoint; used below to score `qr` via prompt log-probs.\n",
    "model_name = \"tlc4418/pythia_70m_sft\"\n",
    "llm = LLM(model_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processed prompts: 100%|██████████| 1/1 [00:00<00:00, 30.11it/s, est. speed input: 29032.70 toks/s, output: 30.19 toks/s]\n"
     ]
    }
   ],
   "source": [
    "# Only one new token is generated: what matters is the log-prob of every token already\n",
    "# in `qr` (prompt_logprobs). A value of 0 requests no top-k alternatives beyond the\n",
    "# token itself at each position.\n",
    "sampling_params = SamplingParams(\n",
    "    n=1,\n",
    "    temperature=1,\n",
    "    top_p=1,\n",
    "    top_k=-1,\n",
    "    max_tokens=1,\n",
    "    logprobs=0,\n",
    "    prompt_logprobs=0,\n",
    ")\n",
    "\n",
    "outputs = llm.generate(qr, sampling_params)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "dict_keys(['request_id', 'prompt', 'prompt_token_ids', 'prompt_logprobs', 'outputs', 'finished', 'metrics', 'lora_request', 'encoder_prompt', 'encoder_prompt_token_ids'])\n"
     ]
    }
   ],
   "source": [
    "output = outputs[0]  # a single prompt was submitted, so there is one result\n",
    "print(vars(output).keys())\n",
    "prompt_logprobs, prompt_tokenids = output.prompt_logprobs, output.prompt_token_ids\n",
    "response_logprobs, response_tokenids = output.outputs[0].logprobs, output.outputs[0].token_ids\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "960 961\n",
      "1 1\n",
      "-9.64247989654541 -0.00033849707688204944\n"
     ]
    }
   ],
   "source": [
    "# Each position holds a dict of token id -> Logprob, or None where no value exists\n",
    "# (one prompt position here: 960 values for 961 tokens). Look up the token that actually\n",
    "# occurs instead of taking the dict's first value, which is only that token when no\n",
    "# top-k alternatives are requested.\n",
    "prompt_lps = [pos[tok].logprob for pos, tok in zip(prompt_logprobs, prompt_tokenids) if pos is not None]\n",
    "response_lps = [pos[tok].logprob for pos, tok in zip(response_logprobs, response_tokenids) if pos is not None]\n",
    "\n",
    "print(len(prompt_lps), len(prompt_tokenids))\n",
    "print(len(response_lps), len(response_tokenids))\n",
    "print(prompt_lps[-1], response_lps[-1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "-9.64247989654541"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# The response log-probs are taken as the last `num_response_tokenids` prompt log-probs\n",
    "# of the joint `qr` sequence. That holds only if tokenizing `qr` ends with exactly the\n",
    "# tokens of `response` tokenized alone, so verify it rather than trust the slice.\n",
    "response_start = len(prompt_tokenids) - num_response_tokenids\n",
    "assert list(prompt_tokenids[response_start:]) == list(tok_response), (\n",
    "    \"response tokens are not the tail of the joint tokenization; slice would be misaligned\"\n",
    ")\n",
    "# Explicit start index: a `[-0:]` slice would return the whole list for an empty response.\n",
    "lps = prompt_lps[len(prompt_lps) - num_response_tokenids:]\n",
    "len(lps), lps[-1] if lps else None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "-5665.74041306621\n",
      "[-12.03123664855957, -9.397439956665039, -3.0530848503112793, -0.7283222675323486, -1.4048997163772583, -1.1195695400238037, -8.061202049255371, -3.850507974624634, -0.06048234552145004, -8.558404922485352, -3.0750892162323, -1.9532983303070068, -2.4931023120880127, -5.618597984313965, -12.436580657958984, -5.507625102996826, -3.9169092178344727, -2.422401189804077, -4.3027873039245605, -3.0590744018554688, -3.2639925479888916, -11.628837585449219, -4.050850868225098, -4.620934009552002, -1.1044347286224365, -2.5765905380249023, -1.8075798749923706, -4.816010475158691, -9.187800407409668, -0.00409490754827857, -16.705074310302734, -3.0865108966827393, -10.755413055419922, -2.333589553833008, -1.6162405014038086, -11.351269721984863, -11.728843688964844, -8.007431030273438, -12.323627471923828, -11.644325256347656, -2.2012996673583984, -5.098984718322754, -3.7665724754333496, -33.38982391357422, -8.019889831542969, -0.0016786068445071578, -17.961442947387695, -0.039647217839956284, -0.8104600310325623, -9.210783004760742, -3.746244430541992, -8.034798622131348, -1.3925840854644775, -3.8396918773651123, -9.10254192352295, -9.753560066223145, -1.1040974855422974, -0.004050623159855604, -0.004422170575708151, -3.1418867111206055, -18.763151168823242, -9.619035720825195, -0.6992013454437256, -12.73745059967041, -16.01943588256836, -20.993724822998047, -1.3864641189575195, -3.410123825073242, -9.140535354614258, -8.02789306640625, -0.7004135847091675, -1.455999732017517, -17.102571487426758, -8.72581958770752, -4.333847522735596, -0.013396253809332848, -8.5845947265625, -10.55252456665039, -4.266592979431152, -12.771366119384766, -2.088012456893921, -8.816751480102539, -0.013713902793824673, -1.8983484506607056, -0.05310717597603798, -4.0003132820129395, -4.936660289764404, -0.7084574699401855, -5.339641571044922, -4.651191711425781, -1.098724365234375, -9.662440299987793, -0.00842386856675148, -5.6283159255981445, -0.0054033189080655575, -11.140998840332031, 
-4.007141590118408, -0.03978584334254265, -3.576278118089249e-07, -10.430126190185547, -8.00709056854248, -3.3552844524383545, -2.5749146938323975, -0.6980432868003845, -12.501993179321289, -8.260942459106445, -3.687896251678467, -8.748695373535156, -3.1249566078186035, -8.003704071044922, -3.115717649459839, -5.945705413818359, -0.741445779800415, -0.008083485998213291, -3.2504491806030273, -18.582014083862305, -10.827557563781738, -1.386807918548584, -4.226434707641602, -1.611803650856018, -8.873652458190918, -3.765526294708252, -6.6756979322235566e-06, -5.702447891235352, -0.7015544772148132, -1.1040037870407104, -14.147554397583008, -8.032914161682129, -8.173273086547852, -4.365047931671143, -5.845406532287598, -10.584049224853516, -8.016242980957031, -8.7452392578125, -0.0020808966364711523, -2.308573007583618, -3.964944839477539, -0.002398948883637786, -12.676438331604004, -3.4315717220306396, -7.083850860595703, -9.108444213867188, -8.001020431518555, -9.61287784576416, -4.790154933929443, -8.696022987365723, -0.6934846043586731, -3.57523250579834, -1.6844377517700195, -8.238335609436035, -1.6142370700836182, -16.907062530517578, -16.69415855407715, -0.6995445489883423, -0.0050383033230900764, -0.6953397989273071, -0.2089398354291916, -2.3212428092956543, -2.8123159408569336, -11.15830135345459, -9.51270580291748, -24.698558807373047, -9.103550910949707, -16.808671951293945, -16.722787857055664, -8.004719734191895, -2.2948548793792725, -8.81183910369873, -5.040903091430664, -0.0030826451256871223, -6.5477070808410645, -3.378892660140991, -0.6946803331375122, -0.9526634812355042, -0.9346749782562256, -6.833806991577148, -0.0017177602276206017, -1.1023313999176025, -0.0013649680186063051, -16.22699737548828, -3.557262420654297, -9.632877349853516, 0.0, -14.62650203704834, -4.866687774658203, -8.698053359985352, -14.215494155883789, -9.111470222473145, -1.8420848846435547, -0.0020217709243297577, -3.656447172164917, -2.8124618530273438, -1.3897689580917358, 
-8.018954277038574, -10.331997871398926, -1.6163893938064575, -1.7955468893051147, -2.10545015335083, -13.543994903564453, -4.630690097808838, -11.567708969116211, -10.128868103027344, -6.606454849243164, -0.0010150285670533776, -12.116279602050781, -11.401294708251953, -14.239217758178711, -13.914745330810547, -16.03234100341797, -8.004705429077148, -8.032926559448242, -8.003776550292969, -12.447071075439453, -19.899812698364258, -10.312082290649414, -8.130125045776367, -9.126612663269043, -9.952205657958984, -16.788660049438477, -17.121469497680664, -9.155255317687988, -10.314149856567383, -4.106113433837891, -11.683204650878906, -0.0010247938334941864, -4.710302829742432, -8.019721984863281, -8.010838508605957, -19.83159065246582, -0.04036443307995796, -2.5739142894744873, -2.6651225090026855, -5.007284164428711, -8.20361328125, -2.9669177532196045, -2.30279803276062, -14.851356506347656, -4.231451988220215, -6.06550931930542, -4.8536696434021, -8.191244125366211, -9.625286102294922, -3.2350997924804688, -1.431578278541565, -5.842358589172363, -1.4280505180358887, -10.775010108947754, -5.789918422698975, -0.7408024668693542, -1.61537504196167, -7.590911388397217, -3.4051852226257324, -8.093952178955078, -2.1151843070983887, -3.948763608932495, -5.788969039916992, -2.2017452716827393, -10.718170166015625, -10.155119895935059, -10.997442245483398, -9.123956680297852, -0.012804434634745121, -2.3693411350250244, -15.034423828125, -9.109971046447754, -2.784498929977417, -3.3626463413238525, -13.416013717651367, -8.69753646850586, -1.645463228225708, -9.175263404846191, -8.171259880065918, -1.4198439121246338, -9.969283103942871, -4.8376078605651855, -0.07608064264059067, -5.668112277984619, -5.601739406585693, -0.7258663773536682, -8.012503623962402, -8.75262451171875, -16.85693359375, -0.07654883712530136, -1.100975751876831, -10.322153091430664, -19.509033203125, -11.785181999206543, -8.702369689941406, -4.826099872589111, -4.709383964538574, -11.605569839477539, 
-3.8416237831115723, -6.719468116760254, -3.8102657794952393, -16.040916442871094, -11.147764205932617, -0.08803697675466537, -8.377511978149414, -9.611471176147461, -0.00701133394613862, -11.046207427978516, -9.396430969238281, -11.508378028869629, -8.7060546875, -2.3103981018066406, -5.352331638336182, -16.697519302368164, -1.7929940223693848, -3.1063923835754395, -0.8068227767944336, -8.003026962280273, -2.7831835746765137, -4.868045806884766, -3.8233590126037598, -8.695019721984863, -6.253042697906494, -1.101440668106079, -3.6555380821228027, -9.452546119689941, -5.28481912612915, -2.1407864093780518, -10.360368728637695, -3.5345065593719482, -11.269458770751953, -8.697368621826172, -5.8387651443481445, -10.464574813842773, -2.318605661392212, -3.6535587310791016, -9.121997833251953, -11.56998062133789, -0.861830472946167, -12.475078582763672, -9.681777954101562, -0.004710173700004816, -16.069631576538086, -3.959728479385376, -17.21415138244629, -4.534991264343262, -11.002002716064453, -0.9111262559890747, -11.353497505187988, -2.648252010345459, -3.262927770614624, -9.907257080078125, -10.747445106506348, -9.449616432189941, -2.404921293258667, -2.963916778564453, -17.95753288269043, -1.3148292303085327, -13.689217567443848, -0.057532522827386856, -14.189888954162598, -1.4117822647094727, -0.8794914484024048, -6.286254405975342, -11.567928314208984, -10.64461612701416, -2.0847561359405518, -24.719268798828125, -1.5940171480178833, -1.9496841430664062, -3.7926719188690186, -2.7418097943154862e-06, -8.21374225616455, -10.89043140411377, -1.1023293733596802, -24.000003814697266, -10.487991333007812, -0.7013596892356873, -16.017425537109375, -10.551558494567871, -9.63886547088623, -10.492164611816406, -1.3942826986312866, -7.001732349395752, -6.517353057861328, -1.8145828247070312, -16.049564361572266, -9.231073379516602, -0.01309756375849247, -0.7136804461479187, -4.210197448730469, -2.8531103134155273, -0.0006779517862014472, -1.2048695087432861, 
-7.191837310791016, -1.9485079050064087, -9.510912895202637, -10.583392143249512, -3.471346616744995, -1.6131342649459839, -1.6140191555023193, -3.828062057495117, -0.4314930737018585, -10.52126693725586, -3.917661666870117, -0.002041400643065572, -0.7411776781082153, -0.033834442496299744, -0.6946713328361511, -6.266953468322754, -3.061985492706299, -4.029853820800781, -16.111797332763672, -1.1079931259155273, -2.1783206462860107, -5.718900680541992, -2.715874195098877, -0.0010149095905944705, -2.9111318588256836, -10.09206485748291, -0.017522411420941353, -10.329829216003418, -3.9166359901428223, -0.756324291229248, -8.43483829498291, -8.730974197387695, -10.90308952331543, -0.05334058403968811, -10.00070858001709, -3.4000484943389893, -8.013426780700684, -8.739701271057129, -8.15788745880127, -0.02088460698723793, -8.801007270812988, -10.6650390625, -1.397216558456421, -4.32200813293457, -0.7034009099006653, -3.3331096172332764, -5.649482727050781, -9.398382186889648, 0.0, -7.468444347381592, -4.377734661102295, -1.6271958351135254, -6.090081691741943, -0.6951749920845032, -0.006135556846857071, -10.883729934692383, -4.875371932983398, -0.7081857323646545, -0.009388227015733719, -0.0010109796421602368, -4.703928470611572, -1.8160579204559326, -12.360642433166504, -8.018697738647461, -5.571889877319336, -1.1093716621398926, -5.788342475891113, -0.00035720644518733025, -10.467999458312988, -1.1504583358764648, -0.002787159290164709, -1.429134488105774, -5.72518253326416, -2.1564247608184814, -2.21272873878479, -12.218490600585938, -0.6977067589759827, -2.095747709274292, -8.007747650146484, -0.001347825163975358, -2.7181479930877686, -9.700397491455078, -4.1412224769592285, -10.20287036895752, -11.663496017456055, -2.401200294494629, -8.709484100341797, -10.956560134887695, -1.1041253805160522, -13.213811874389648, -0.0010178867960348725, -0.008265691809356213, -3.3133487701416016, -0.0037306013982743025, -24.00571632385254, -4.745177268981934, -10.48982048034668, 
-0.7112552523612976, -8.834589958190918, -1.1024408340454102, -0.9045569896697998, -7.364089012145996, -3.384031057357788, -1.1107240915298462, -1.8235981464385986, -2.938690423965454, -0.0034774804953485727, -11.182201385498047, -6.054008483886719, -16.700767517089844, -1.6204522848129272, -13.615937232971191, -0.0003369478799868375, -1.300795078277588, -7.341548919677734, -8.016375541687012, -3.4547293186187744, -0.0013571109157055616, -8.742441177368164, -11.588600158691406, -18.40895652770996, -3.7298030853271484, -4.831286907196045, -3.2850406169891357, -1.6268643140792847, -1.8928457498550415, -9.162940979003906, -0.7065942287445068, -11.46817398071289, -18.419397354125977, -1.0996203422546387, -0.0033699646592140198, -3.267235279083252, -0.053579021245241165, -4.327696800231934, -10.913348197937012, -0.693824291229248, -0.00033563701435923576, -1.6132174730300903, -8.281266212463379, -18.960752487182617, -4.686413288116455, -3.9416913986206055, -1.9864534139633179, -1.3926030397415161, -0.17384585738182068, -6.44480562210083, -0.04189826548099518, -11.97166919708252, -5.979992866516113, -1.111626386642456, -8.244215965270996, -0.004706851206719875, -3.9088687896728516, -0.05065304413437843, -8.279818534851074, -8.112141609191895, -11.535234451293945, -2.2222917079925537, -9.810284614562988, -10.216800689697266, -8.008431434631348, -0.7246294021606445, -0.862849771976471, -0.023366684094071388, -9.11650276184082, -11.726309776306152, -11.531599998474121, -0.0006744970451109111, -3.698713779449463, -8.263446807861328, -17.796417236328125, -9.390668869018555, -2.2485969066619873, -5.67246150970459, -5.049367427825928, -8.707230567932129, -1.148514986038208, -5.796696186065674, -16.292877197265625, -11.542476654052734, -0.014376395381987095, -0.7069368362426758, -2.871249198913574, -4.7673869132995605, -2.7329471111297607, -9.452159881591797, -1.803934097290039, -8.834634780883789, -11.804990768432617, -0.6990705132484436, -8.945932388305664, -15.00509262084961, 
-8.003122329711914, -13.286333084106445, -16.71037483215332, -8.023887634277344, -3.113891124725342, -15.544501304626465, -3.618800640106201, -8.754347801208496, -1.1038761138916016, -8.740530014038086, -5.739803314208984, -22.815296173095703, -0.6973491907119751, -0.69618821144104, -9.816360473632812, -2.712723731994629, -9.397285461425781, -4.421113967895508, -5.908600330352783, -16.01334571838379, -0.7218717336654663, -13.26587200164795, -1.9866019487380981, -1.9985897541046143, -14.237273216247559, -12.548458099365234, -3.6171743869781494, -2.264974000354414e-06, -2.6676788330078125, -1.1018708944320679, -5.502110481262207, -10.568974494934082, -8.719953536987305, -2.0069143772125244, -9.735491752624512, -1.7993018627166748, -2.1077845096588135, -9.542166709899902, -0.7010480761528015, -5.086850643157959, -6.916895866394043, -8.700743675231934, -3.5773634910583496, -12.071573257446289, -2.288785457611084, -9.614956855773926, -1.415725588798523, -10.316157341003418, -2.0899269580841064, -11.081618309020996, -32.0, -9.127997398376465, -5.4239935874938965, -3.6081011295318604, -8.839445114135742, -1.821665644645691, -2.9565014839172363, -0.19101016223430634, -0.004147262312471867, -8.011070251464844, -1.4081401824951172, -16.156837463378906, -1.9689470529556274, -2.7796123027801514, -8.005057334899902, -0.015850383788347244, -16.726911544799805, -0.003724900772795081, -1.9864002466201782, -5.294469833374023, -5.311032772064209, -0.6960400342941284, -5.8713788986206055, -5.2249298095703125, -4.199473857879639, -0.0023964515421539545, -13.355326652526855, -1.102673888206482, -2.5722672939300537, -11.698326110839844, -5.656073093414307, -0.6958783268928528, -8.217914581298828, -11.232609748840332, -4.423774719238281, -1.1121327877044678, -11.518758773803711, -10.409439086914062, -3.4203646183013916, -8.268499374389648, -1.1026719808578491, -0.8042641282081604, -10.878218650817871, -3.863614082336426, -2.781712532043457, -8.010246276855469, -4.449171543121338, 
-1.9628366231918335, -9.626429557800293, -12.66876220703125, -11.471006393432617, -1.7947449684143066, -6.516546726226807, -9.389752388000488, -0.6992074847221375, -9.12796401977539, -11.340517044067383, -0.6934855580329895, -4.59588098526001, -8.732340812683105, -16.80575180053711, -11.342556953430176, -17.95185089111328, -9.80827808380127, -3.0485308170318604, -8.001684188842773, -13.901460647583008, -9.950815200805664, -6.485874176025391, -0.00040522945346310735, -4.375523090362549, -6.530217170715332, -3.897444725036621, -0.004128030501306057, -0.7356263399124146, -11.215719223022461, -3.811725616455078, -2.7385270595550537, -12.060728073120117, -8.11943531036377, -2.4879674911499023, -0.7116332054138184, -9.418292999267578, -8.710535049438477, -13.339736938476562, -1.3910964727401733, -0.004175397567451, -1.1765406131744385, -17.613819122314453, -10.502695083618164, -8.081915855407715, -8.013139724731445, -0.0016971721779555082, -10.307401657104492, -4.291710376739502, -8.066182136535645, -0.001348301419056952, -3.026494264602661, -1.423930048942566, -5.949957847595215, -11.855633735656738, -1.13751220703125, -11.056807518005371, -0.014453242532908916, -10.35338020324707, -2.9618518352508545, -0.012793371453881264, -4.076719284057617, -0.17526960372924805, -0.024790626019239426, -17.8015193939209, -2.9002468585968018, -11.584460258483887, -3.3526346683502197, -1.6287322044372559, -4.308639049530029, -2.082303762435913, -18.491039276123047, -0.6986079812049866, -0.034946613013744354, -5.739503383636475, -0.04517381638288498, -4.141730785369873, -1.3904234170913696, -2.6389312744140625, -1.388243317604065, -2.3173556327819824, -3.297210216522217, -1.1865172386169434, -0.005101401824504137, -4.859985828399658, -0.03992021456360817, -5.955930233001709, -0.0013911579735577106, -2.8363254070281982, -2.2606863975524902, -2.9041824340820312, -2.3281795978546143, -8.085891723632812, -8.132726669311523, -2.4205617904663086, -20.433393478393555, -0.0024018031544983387, 
-9.185540199279785, -8.108196258544922, -9.22666072845459, -8.216703414916992, -8.69923210144043, -8.093743324279785, -9.235309600830078, -2.711172580718994, -5.969653129577637, -0.6945204734802246, -0.04962378367781639, -5.762068748474121, -11.477303504943848, -1.390336513519287, -8.663147926330566, -1.1267269849777222, -0.00168658047914505, -5.12584924697876, -4.488276481628418, -12.118646621704102, -0.02332615666091442, -0.04068645089864731, -10.774641036987305, -13.450939178466797, -0.003762905253097415, -0.007422725670039654, -0.0037222879473119974, -0.7044327259063721, -8.69902229309082, -5.4718017578125, -1.8076391220092773, -0.08843479305505753, -7.411795616149902, -16.226909637451172, -4.775187015533447, -3.2994656562805176, -11.263772964477539, -3.5170419216156006, -0.007392314728349447, -2.604264259338379, -3.1756479740142822, -0.6977008581161499, -12.676966667175293, -9.949907302856445, -0.0037238318473100662, -0.005056924652308226, -0.008050493896007538, -1.1266192197799683, -14.726730346679688, -0.0003407612966839224, -0.012791135348379612, -16.00002098083496, -10.192821502685547, -3.1896073818206787, -8.036903381347656, -0.006746845785528421, -16.69870948791504, -12.312871932983398, -1.3997478485107422, -16.72281265258789, -0.004070450086146593, -2.5695645809173584, -4.9470720291137695, -3.1891283988952637, -10.118551254272461, -3.343505382537842, -10.32162857055664, -2.997288227081299, -8.017127990722656, -2.4864842891693115, -11.423315048217773, -8.02629280090332, -4.721963882446289, -3.147021770477295, -5.03139066696167, -9.61402702331543, -5.797331809997559, -0.7610028386116028, -9.56952953338623, -8.498757362365723, -8.079265594482422, -1.6128168106079102, -1.6760048866271973, -0.002019510604441166, -9.103889465332031, -9.104772567749023, -8.819210052490234, -0.0007141662063077092, -0.004783258773386478, -0.005787283182144165, -0.0044141001999378204, -1.7934446334838867, -7.483535289764404, -2.9205850296420977e-05, -3.4206175804138184, 
-0.6992435455322266, -12.268108367919922, -4.297347545623779, -1.8268327713012695, -8.892072677612305, -1.111598253250122, -16.018569946289062, -11.44636344909668, -0.02261561155319214, -2.0917749404907227, -4.250983238220215, -16.02178955078125, -3.3701343536376953, -8.007620811462402, -6.5179667472839355, -8.75024127960205, -9.390249252319336, -11.077478408813477, -0.00033539868309162557, -0.00575694115832448, -0.010139858350157738, -2.9802276912960224e-06, -8.694324493408203, -3.6269679069519043, -0.00102324562612921, -3.4155170917510986, -2.343430757522583, -1.979271650314331, -2.7215065956115723, -9.049442291259766, -4.066617488861084, -20.218523025512695, -3.532024383544922, -16.009057998657227, -12.798603057861328, -0.7782195806503296, -3.0748612880706787, -0.005119547713547945, -9.799226760864258, -26.534297943115234, -10.8018217086792, -3.75138258934021, -1.9524589776992798, -10.09679889678955, -9.108738899230957, -0.00505514582619071, -8.694498062133789, -9.649435043334961, -2.4102540016174316, -2.6493468284606934, -26.676481246948242, -0.00514362333342433, -13.312627792358398, -2.4944543838500977, -12.114951133728027, -0.029968978837132454, -3.099005937576294, -8.006231307983398, -17.001245498657227, -16.794832229614258, -1.115142822265625, -10.795013427734375, -8.696346282958984, -9.204950332641602, -8.696342468261719, -1.8140854835510254, -1.9472110271453857, -9.64247989654541]\n"
     ]
    }
   ],
   "source": [
    "logprob = sum(lps)\n",
    "print(logprob)\n",
    "print(lps)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import gc\n",
    "\n",
    "# `del` only drops the Python name; collect the engine object and empty the\n",
    "# CUDA caching allocator so the memory is actually returned to the device.\n",
    "# NOTE(review): vLLM may hold extra state (e.g. distributed groups) that this does not release.\n",
    "del llm\n",
    "gc.collect()\n",
    "if torch.cuda.is_available():\n",
    "    torch.cuda.empty_cache()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the Hugging Face causal LM for `model_name` (defined in an earlier cell).\n",
    "model = AutoModelForCausalLM.from_pretrained(\n",
    "    pretrained_model_name_or_path=model_name,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([1, 961])"
      ]
     },
     "execution_count": 28,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Tokenize `qr` (presumably the prompt followed by the response text) for one HF forward pass.\n",
    "# NOTE(review): `encode` adds special tokens by default (add_special_tokens=True),\n",
    "# so the sequence may be longer than the prompt and response token counts combined.\n",
    "input_ids = tokenizer.encode(qr, return_tensors='pt')\n",
    "input_ids.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inference only: no_grad avoids building an autograd graph (and keeping every\n",
    "# layer's activations alive) for the whole sequence.\n",
    "with torch.no_grad():\n",
    "    output = model(\n",
    "        input_ids.to(model.device),\n",
    "        return_dict=True,\n",
    "        output_hidden_states=True,\n",
    "        output_attentions=False,\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Causal LM: the logits at position p score the token at position p + 1, so the\n",
    "# response token at position start + i is scored by the logits at start + i - 1.\n",
    "# The response start is located in `input_ids` rather than assumed, because\n",
    "# `tokenizer.encode` may have added special tokens around the text.\n",
    "full_ids = input_ids[0].tolist()\n",
    "resp_ids = np.asarray(response_tokenids).tolist()\n",
    "start = next(\n",
    "    (s for s in range(len(full_ids) - num_response_tokenids, 0, -1)\n",
    "     if full_ids[s:s + num_response_tokenids] == resp_ids),\n",
    "    None,\n",
    ")\n",
    "assert start is not None, 'response_tokenids not found in input_ids; check how qr was tokenized'\n",
    "\n",
    "resp_logits = output.logits[0, start - 1:start - 1 + num_response_tokenids, :].detach().float().cpu()\n",
    "all_logits = resp_logits.numpy()\n",
    "print(all_logits.shape)\n",
    "logits = all_logits[range(num_response_tokenids), resp_ids]\n",
    "print(logits.shape)\n",
    "\n",
    "# Raw logits are unnormalized; log-softmax gives per-token log-probs, presumably\n",
    "# comparable to the `lps` list computed earlier (TODO confirm its source).\n",
    "hf_lps = torch.log_softmax(resp_logits, dim=-1)[torch.arange(num_response_tokenids), torch.tensor(resp_ids)].numpy()\n",
    "print(hf_lps.sum())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Preview only; the full array has one entry per response token.\n",
    "logits[:10]"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.15"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
