{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "9aa1164c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "ad09c5bd7c9d4de69d54c74faa7c9864",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import os, math, json, random, gc, pathlib\n",
    "from dataclasses import dataclass\n",
    "from typing import Dict, List, Tuple, Optional, Literal\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "SEED = 123\n",
    "random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)\n",
    "\n",
    "gc.collect()\n",
    "torch.cuda.empty_cache()\n",
    "# ---------------- Config ----------------\n",
    "DEVICE = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "DTYPE  = torch.float32\n",
    "MODEL_NAME = \"google/gemma-3-4b-it\"\n",
    "HF_TOKEN   = os.environ.get(\"HF_TOKEN\")\n",
    "# ---------------- Load model ----------------\n",
    "tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, token=HF_TOKEN, use_fast=True)\n",
    "if tokenizer.pad_token is None:\n",
    "    tokenizer.pad_token = tokenizer.eos_token\n",
    "\n",
    "model = AutoModelForCausalLM.from_pretrained(\n",
    "    MODEL_NAME, token=HF_TOKEN, torch_dtype=DTYPE, low_cpu_mem_usage=True\n",
    ").to(DEVICE).eval()\n",
    "def generate_from_model(prompt: str, max_new_tokens: int = 100):\n",
    "    # Encode the prompt\n",
    "    inputs = tokenizer(prompt, return_tensors=\"pt\").to(DEVICE)\n",
    "\n",
    "    # Run model.generate() for autoregressive decoding\n",
    "    with torch.no_grad():\n",
    "        output_ids = model.generate(\n",
    "            **inputs,\n",
    "            max_new_tokens=max_new_tokens,\n",
    "            do_sample=True,        # sampling adds randomness\n",
    "            top_p=0.95,            # nucleus sampling\n",
    "            temperature=0.7,       # lower = more deterministic\n",
    "            pad_token_id=tokenizer.eos_token_id\n",
    "        )\n",
    "\n",
    "    # Decode tokens back to string\n",
    "    return tokenizer.decode(output_ids[0], skip_special_tokens=True)\n",
    "\n",
    "# Example usage\n",
    "\n",
    "# torch.set_float32_matmul_precision('high')``"
   ]
  },
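  {
   "cell_type": "code",
   "execution_count": null,
   "id": "chat-template-sketch",
   "metadata": {},
   "outputs": [],
   "source": [
    "# generate_from_model feeds a raw prompt string to the model. Since\n",
    "# gemma-3-4b-it is an instruction-tuned checkpoint, formatting requests with\n",
    "# the tokenizer's chat template usually matches its training setup more\n",
    "# closely. A minimal sketch, assuming the tokenizer ships a chat template\n",
    "# (the official Gemma 3 checkpoints do); the helper name, prompt text, and\n",
    "# token budget below are illustrative only.\n",
    "def generate_chat(user_message: str, max_new_tokens: int = 100):\n",
    "    inputs = tokenizer.apply_chat_template(\n",
    "        [{\"role\": \"user\", \"content\": user_message}],\n",
    "        add_generation_prompt=True,   # append the assistant-turn header\n",
    "        return_tensors=\"pt\",\n",
    "        return_dict=True,\n",
    "    ).to(DEVICE)\n",
    "    with torch.no_grad():\n",
    "        output_ids = model.generate(\n",
    "            **inputs,\n",
    "            max_new_tokens=max_new_tokens,\n",
    "            do_sample=True,\n",
    "            top_p=0.95,\n",
    "            temperature=0.7,\n",
    "            pad_token_id=tokenizer.eos_token_id,\n",
    "        )\n",
    "    return tokenizer.decode(output_ids[0], skip_special_tokens=True)\n",
    "\n",
    "# print(generate_chat(\"Write a Python one-liner that squares a number.\", 50))"
   ]
  },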
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "5d47442f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Gemma3ForConditionalGeneration(\n",
      "  (model): Gemma3Model(\n",
      "    (vision_tower): SiglipVisionModel(\n",
      "      (vision_model): SiglipVisionTransformer(\n",
      "        (embeddings): SiglipVisionEmbeddings(\n",
      "          (patch_embedding): Conv2d(3, 1152, kernel_size=(14, 14), stride=(14, 14), padding=valid)\n",
      "          (position_embedding): Embedding(4096, 1152)\n",
      "        )\n",
      "        (encoder): SiglipEncoder(\n",
      "          (layers): ModuleList(\n",
      "            (0-26): 27 x SiglipEncoderLayer(\n",
      "              (layer_norm1): LayerNorm((1152,), eps=1e-06, elementwise_affine=True)\n",
      "              (self_attn): SiglipAttention(\n",
      "                (k_proj): Linear(in_features=1152, out_features=1152, bias=True)\n",
      "                (v_proj): Linear(in_features=1152, out_features=1152, bias=True)\n",
      "                (q_proj): Linear(in_features=1152, out_features=1152, bias=True)\n",
      "                (out_proj): Linear(in_features=1152, out_features=1152, bias=True)\n",
      "              )\n",
      "              (layer_norm2): LayerNorm((1152,), eps=1e-06, elementwise_affine=True)\n",
      "              (mlp): SiglipMLP(\n",
      "                (activation_fn): PytorchGELUTanh()\n",
      "                (fc1): Linear(in_features=1152, out_features=4304, bias=True)\n",
      "                (fc2): Linear(in_features=4304, out_features=1152, bias=True)\n",
      "              )\n",
      "            )\n",
      "          )\n",
      "        )\n",
      "        (post_layernorm): LayerNorm((1152,), eps=1e-06, elementwise_affine=True)\n",
      "      )\n",
      "    )\n",
      "    (multi_modal_projector): Gemma3MultiModalProjector(\n",
      "      (mm_soft_emb_norm): Gemma3RMSNorm((1152,), eps=1e-06)\n",
      "      (avg_pool): AvgPool2d(kernel_size=4, stride=4, padding=0)\n",
      "    )\n",
      "    (language_model): Gemma3TextModel(\n",
      "      (embed_tokens): Gemma3TextScaledWordEmbedding(262208, 2560, padding_idx=0)\n",
      "      (layers): ModuleList(\n",
      "        (0-33): 34 x Gemma3DecoderLayer(\n",
      "          (self_attn): Gemma3Attention(\n",
      "            (q_proj): Linear(in_features=2560, out_features=2048, bias=False)\n",
      "            (k_proj): Linear(in_features=2560, out_features=1024, bias=False)\n",
      "            (v_proj): Linear(in_features=2560, out_features=1024, bias=False)\n",
      "            (o_proj): Linear(in_features=2048, out_features=2560, bias=False)\n",
      "            (q_norm): Gemma3RMSNorm((256,), eps=1e-06)\n",
      "            (k_norm): Gemma3RMSNorm((256,), eps=1e-06)\n",
      "          )\n",
      "          (mlp): Gemma3MLP(\n",
      "            (gate_proj): Linear(in_features=2560, out_features=10240, bias=False)\n",
      "            (up_proj): Linear(in_features=2560, out_features=10240, bias=False)\n",
      "            (down_proj): Linear(in_features=10240, out_features=2560, bias=False)\n",
      "            (act_fn): PytorchGELUTanh()\n",
      "          )\n",
      "          (input_layernorm): Gemma3RMSNorm((2560,), eps=1e-06)\n",
      "          (post_attention_layernorm): Gemma3RMSNorm((2560,), eps=1e-06)\n",
      "          (pre_feedforward_layernorm): Gemma3RMSNorm((2560,), eps=1e-06)\n",
      "          (post_feedforward_layernorm): Gemma3RMSNorm((2560,), eps=1e-06)\n",
      "        )\n",
      "      )\n",
      "      (norm): Gemma3RMSNorm((2560,), eps=1e-06)\n",
      "      (rotary_emb): Gemma3RotaryEmbedding()\n",
      "      (rotary_emb_local): Gemma3RotaryEmbedding()\n",
      "    )\n",
      "  )\n",
      "  (lm_head): Linear(in_features=2560, out_features=262208, bias=False)\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "print(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "151dc306",
   "metadata": {},
   "outputs": [],
   "source": [
    "# prompt = \"Write a Python function to compute n choose k (binomial coefficient).\\nRespond with code only.\\n```python\\n\"\n",
    "# print(generate_from_model(prompt, max_new_tokens=50))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "2300d470",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/psquare_quantum/Desktop/clone/causal_intervention/causal/lib/python3.10/site-packages/torch/_inductor/compile_fx.py:282: UserWarning: TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled. Consider setting `torch.set_float32_matmul_precision('high')` for better performance.\n",
      "  warnings.warn(\n",
      "W0924 20:13:57.744000 605468 torch/_inductor/utils.py:1436] [0/0] Not enough SMs to use max_autotune_gemm mode\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Write a short function in C++ that returns the n-th Fibonacci number.\n",
      "Respond with code only.\n",
      "```cpp\n",
      "#include <iostream>\n",
      "\n",
      "int fibonacci(int n) {\n",
      "  if (n <= 1) {\n",
      "    return n;\n",
      "  }\n",
      "  return fibonacci(n - 1) + fibonacci(n - \n",
      "\n",
      "\n",
      "\n",
      "Write a C++ function that checks if a number is prime.\n",
      "Respond with code only.\n",
      "```cpp\n",
      "#include <iostream>\n",
      "#include <cmath>\n",
      "\n",
      "bool isPrime(int n) {\n",
      "  if (n <= 1) {\n",
      "    return false;\n",
      "  }\n",
      "  for (int i = 2; i <=\n",
      "\n",
      "\n",
      "\n",
      "Write a C++ function to compute the factorial of a number using recursion.\n",
      "Respond with code only.\n",
      "```cpp\n",
      "#include <iostream>\n",
      "\n",
      "long long factorial(int n) {\n",
      "  if (n == 0) {\n",
      "    return 1;\n",
      "  } else if (n < 0) {\n",
      "    return -1; // Or\n",
      "\n",
      "\n",
      "\n",
      "Write a C++ function that reverses a string.\n",
      "Respond with code only.\n",
      "```cpp\n",
      "#include <iostream>\n",
      "#include <string>\n",
      "#include <algorithm>\n",
      "\n",
      "std::string reverseString(const std::string& str) {\n",
      "  std::string reversedStr = str;\n",
      "  std::reverse(reversedStr\n",
      "\n",
      "\n",
      "\n",
      "Write a C++ function that finds the maximum element in an array.\n",
      "Respond with code only.\n",
      "```cpp\n",
      "#include <iostream>\n",
      "#include <algorithm>\n",
      "#include <vector>\n",
      "\n",
      "int findMax(const std::vector<int>& arr) {\n",
      "    if (arr.empty()) {\n",
      "        return -1; // Or throw\n",
      "\n",
      "\n",
      "\n",
      "Write a C++ function that sorts an array using bubble sort.\n",
      "Respond with code only.\n",
      "```cpp\n",
      "#include <iostream>\n",
      "#include <vector>\n",
      "\n",
      "void bubbleSort(std::vector<int>& arr) {\n",
      "    int n = arr.size();\n",
      "    for (int i = 0; i < n - 1\n",
      "\n",
      "\n",
      "\n",
      "Write a C++ function that counts vowels in a string.\n",
      "Respond with code only.\n",
      "```cpp\n",
      "#include <iostream>\n",
      "#include <string>\n",
      "#include <algorithm>\n",
      "\n",
      "int countVowels(const std::string& str) {\n",
      "  int count = 0;\n",
      "  for (char c : str) {\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "Write a C++ function to compute the greatest common divisor (GCD) of two numbers.\n",
      "Respond with code only.\n",
      "```cpp\n",
      "#include <iostream>\n",
      "#include <algorithm>\n",
      "\n",
      "int gcd(int a, int b) {\n",
      "  // Ensure non-negative inputs\n",
      "  a = std::abs(a);\n",
      "  b = std::abs(b);\n",
      "\n",
      "\n",
      "\n",
      "Write a C++ function to check if a string is a palindrome.\n",
      "Respond with code only.\n",
      "```cpp\n",
      "#include <iostream>\n",
      "#include <string>\n",
      "#include <algorithm>\n",
      "\n",
      "bool isPalindrome(const std::string& str) {\n",
      "  std::string clean_str = str;\n",
      "  // Remove non-alphanumeric\n",
      "\n",
      "\n",
      "\n",
      "Write a C++ function that computes the sum of digits of an integer.\n",
      "Respond with code only.\n",
      "```cpp\n",
      "#include <iostream>\n",
      "\n",
      "int sumOfDigits(int n) {\n",
      "  int sum = 0;\n",
      "  if (n < 0) {\n",
      "    n = -n;\n",
      "  }\n",
      "  while (n > \n",
      "\n",
      "\n",
      "\n",
      "Write a short function in Python that returns the n-th Fibonacci number.\n",
      "Respond with code only.\n",
      "```python\n",
      "def fibonacci(n):\n",
      "    \"\"\"\n",
      "    This function returns the n-th Fibonacci number.\n",
      "\n",
      "    Args:\n",
      "        n: The index of the Fibonacci number to return.\n",
      "\n",
      "    Returns:\n",
      "        The n-th Fibonacci number\n",
      "\n",
      "\n",
      "\n",
      "Write a Python function that checks if a number is prime.\n",
      "Respond with code only.\n",
      "```python\n",
      "def is_prime(n):\n",
      "  \"\"\"\n",
      "  Checks if a number is prime.\n",
      "\n",
      "  Args:\n",
      "    n: The number to check.\n",
      "\n",
      "  Returns:\n",
      "    True if the number is prime, False otherwise.\n",
      "  \n",
      "\n",
      "\n",
      "\n",
      "Write a Python function to compute the factorial of a number using recursion.\n",
      "Respond with code only.\n",
      "```python\n",
      "def factorial_recursive(n):\n",
      "  \"\"\"\n",
      "  Compute the factorial of a number using recursion.\n",
      "\n",
      "  Args:\n",
      "    n: The number to compute the factorial of.\n",
      "\n",
      "  Returns:\n",
      "    The factorial of n.\n",
      "  \n",
      "\n",
      "\n",
      "\n",
      "Write a Python function that reverses a string.\n",
      "Respond with code only.\n",
      "```python\n",
      "def reverse_string(s):\n",
      "    \"\"\"\n",
      "    Reverses a string.\n",
      "\n",
      "    Args:\n",
      "        s: The string to reverse.\n",
      "\n",
      "    Returns:\n",
      "        The reversed string.\n",
      "    \"\"\"\n",
      "    return s[::-1]\n",
      "\n",
      "\n",
      "\n",
      "Write a Python function that finds the maximum element in a list.\n",
      "Respond with code only.\n",
      "```python\n",
      "def find_max(data):\n",
      "    \"\"\"\n",
      "    Finds the maximum element in a list.\n",
      "\n",
      "    Args:\n",
      "      data: A list of numbers.\n",
      "\n",
      "    Returns:\n",
      "      The maximum element in the list, or None if\n",
      "\n",
      "\n",
      "\n",
      "Write a Python function that sorts a list using bubble sort.\n",
      "Respond with code only.\n",
      "```python\n",
      "def bubble_sort(data):\n",
      "    \"\"\"Sorts a list using the bubble sort algorithm.\n",
      "\n",
      "    Args:\n",
      "        data: The list to be sorted.\n",
      "\n",
      "    Returns:\n",
      "        The sorted list.\n",
      "    \"\"\"\n",
      "    n\n",
      "\n",
      "\n",
      "\n",
      "Write a Python function that counts vowels in a string.\n",
      "Respond with code only.\n",
      "```python\n",
      "def count_vowels(string):\n",
      "    \"\"\"\n",
      "    Counts the number of vowels (a, e, i, o, u) in a string.\n",
      "    Case-insensitive.\n",
      "\n",
      "    Args:\n",
      "        string: The string to\n",
      "\n",
      "\n",
      "\n",
      "Write a Python function to compute the greatest common divisor (GCD) of two numbers.\n",
      "Respond with code only.\n",
      "```python\n",
      "def gcd(a, b):\n",
      "    \"\"\"\n",
      "    Compute the greatest common divisor (GCD) of two numbers.\n",
      "\n",
      "    Args:\n",
      "      a: The first number.\n",
      "      b: The second number.\n",
      "\n",
      "    Returns:\n",
      "      \n",
      "\n",
      "\n",
      "\n",
      "Write a Python function to check if a string is a palindrome.\n",
      "Respond with code only.\n",
      "```python\n",
      "def is_palindrome(text):\n",
      "    \"\"\"\n",
      "    Checks if a string is a palindrome (reads the same backward as forward).\n",
      "\n",
      "    Args:\n",
      "        text: The string to check.\n",
      "\n",
      "    Returns:\n",
      "        True if the string\n",
      "\n",
      "\n",
      "\n",
      "Write a Python function that computes the sum of digits of an integer.\n",
      "Respond with code only.\n",
      "```python\n",
      "def sum_digits(n):\n",
      "    \"\"\"\n",
      "    Compute the sum of digits of an integer.\n",
      "\n",
      "    Args:\n",
      "        n (int): The integer to compute the sum of digits.\n",
      "\n",
      "    Returns:\n",
      "        int: The sum\n"
     ]
    }
   ],
   "source": [
    "with open(\"prompt_set/cpp_python_100.json\", 'r') as f:\n",
    "    code = json.load(f)\n",
    "    \n",
    "for prompt_pair in code[:10]:\n",
    "    print(\"\\n\\n\")\n",
    "    print(generate_from_model(prompt_pair[\"cpp_top\"], max_new_tokens=50))\n",
    "for prompt_pair in code[:10]:\n",
    "    print(\"\\n\\n\")\n",
    "    print(generate_from_model(prompt_pair[\"python_top\"], max_new_tokens=50))"
   ]
  },
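  {
   "cell_type": "code",
   "execution_count": null,
   "id": "completion-only-sketch",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The generations above echo the prompt because generate_from_model decodes the\n",
    "# whole output sequence. A minimal sketch of a variant that returns only the\n",
    "# newly generated tokens by slicing off the prompt length; the helper name is\n",
    "# ours, everything else reuses the model and tokenizer loaded above.\n",
    "def generate_completion_only(prompt: str, max_new_tokens: int = 100):\n",
    "    inputs = tokenizer(prompt, return_tensors=\"pt\").to(DEVICE)\n",
    "    with torch.no_grad():\n",
    "        output_ids = model.generate(\n",
    "            **inputs,\n",
    "            max_new_tokens=max_new_tokens,\n",
    "            do_sample=True,\n",
    "            top_p=0.95,\n",
    "            temperature=0.7,\n",
    "            pad_token_id=tokenizer.eos_token_id,\n",
    "        )\n",
    "    # Keep only the tokens generated after the prompt.\n",
    "    new_tokens = output_ids[0, inputs[\"input_ids\"].shape[1]:]\n",
    "    return tokenizer.decode(new_tokens, skip_special_tokens=True)\n",
    "\n",
    "# print(generate_completion_only(code[0][\"cpp_top\"], max_new_tokens=50))"
   ]
  },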
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "3772f212",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "\n",
      "Explain the fundamental principles of the entire universe, and describe the significance of this knowledge.\n",
      "\n",
      "Okay, this is a *massive* undertaking, and frankly, a question that philosophers and scientists have grappled with for millennia.  There's no single, definitive answer, but I can lay\n",
      "\n",
      "\n",
      "Explain the fundamental principles of the entire solution.\n",
      "\n",
      "The solution is designed around the following fundamental principles:\n",
      "\n",
      "1.  **Modular Design:** The system is broken down into independent, reusable modules. Each module has a specific, well-defined purpose and interacts with other modules through clearly defined interfaces\n",
      "\n",
      "\n",
      "\n",
      "Describe the core concepts of a magnetic field.\n",
      "\n",
      "A magnetic field is a region of space around a magnet or a moving electric charge where a magnetic force can be detected. Here's a breakdown of the core concepts:\n",
      "\n",
      "**1. Sources of Magnetic Fields:**\n",
      "\n",
      "* **Moving Electric\n",
      "\n",
      "\n",
      "Describe the core concepts of an organic field-effect transistor (OFET).\n",
      "\n",
      "Here's a breakdown of the core concepts of an Organic Field-Effect Transistor (OFET):\n",
      "\n",
      "**1. What is an OFET?**\n",
      "\n",
      "*   **Analogous to a Silicon FET\n",
      "\n",
      "\n",
      "\n",
      "Detail the measurable effects of a gravitational force on a body.\n",
      "\n",
      "The gravitational force, as described by Newton's Law of Universal Gravitation, has several measurable effects on a body. Here’s a breakdown of those effects, categorized for clarity:\n",
      "\n",
      "**1. Change in Velocity (\n",
      "\n",
      "\n",
      "Detail the measurable effects of a concentrated force acting on a rigid body.\n",
      "\n",
      "When a concentrated force is applied to a rigid body, several measurable effects can be observed. These effects depend on the magnitude, direction, and point of application of the force, as well as the properties of the rigid\n",
      "\n",
      "\n",
      "\n",
      "What is the role of energy in this system?\n",
      "\n",
      "The system described is a plant using photosynthesis. Photosynthesis is the process by which plants convert light energy into chemical energy in the form of glucose.\n",
      "\n",
      "Here's a breakdown of the role of energy in the system:\n",
      "\n",
      "*   **Input\n",
      "\n",
      "\n",
      "What is the role of temperature in this system?\n",
      "\n",
      "The role of temperature in this system is to influence the rate of chemical reactions.\n",
      "\n",
      "**Explanation:**\n",
      "\n",
      "Temperature is a fundamental factor in chemical kinetics, which studies the rates of chemical reactions. Here's how it works:\n",
      "\n",
      "* **Increased\n",
      "\n",
      "\n",
      "\n",
      "Explain the behavior of subatomic particles.\n",
      "\n",
      "Subatomic particles are the fundamental building blocks of matter, smaller than atoms. They come in various types, each with unique properties and behaviors. Here’s a breakdown of their behavior:\n",
      "\n",
      "**1. Types of Subatomic Particles:**\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "Explain the behavior of an immiscible mixture of water and oil.\n",
      "\n",
      "An immiscible mixture of water and oil is a mixture where the two liquids do not dissolve in each other. Instead, they form separate layers due to differences in their physical properties, primarily density. Here's a breakdown\n",
      "\n",
      "\n",
      "\n",
      "Describe the known properties of dark matter and dark energy.\n",
      "\n",
      "Dark matter and dark energy are two of the most mysterious components of the universe. While we can't directly observe them with telescopes, their existence is inferred from their gravitational effects on visible matter and the expansion of the universe.\n",
      "\n",
      "\n",
      "Describe the known properties of a water reaction with a strong acid.\n",
      "\n",
      "The reaction of water with a strong acid is a vigorous exothermic reaction. Here's a breakdown of the known properties:\n",
      "\n",
      "**1. Reaction Chemistry:**\n",
      "\n",
      "*   **Formation of Hydronium Ions (H3\n",
      "\n",
      "\n",
      "\n",
      "What is the nature of high-energy cosmic rays?\n",
      "\n",
      "Cosmic rays are high-energy particles that originate from outside Earth's atmosphere. They are primarily protons and atomic nuclei, but also include electrons and positrons. The nature of these cosmic rays is one of the biggest mysteries in modern astrophysics, and\n",
      "\n",
      "\n",
      "What is the nature of high-energy ethyl rays?\n",
      "\n",
      "The question of the nature of high-energy ethyl rays, as proposed by Konstantin Tsiolkovsky in the late 19th and early 20th centuries, is a fascinating and ultimately discredited one. Tsiolkovsky\n",
      "\n",
      "\n",
      "\n",
      "The study of the very small is the study of quantum mechanics. Quantum mechanics is the theory that governs the behavior of matter and energy at the atomic and subatomic levels. It's a strange and counterintuitive theory, but it's also incredibly successful in describing the world around us.\n",
      "\n",
      "Here'\n",
      "\n",
      "\n",
      "The study of the very small is the study of acid-base chemistry.\n",
      "\n",
      "The study of the very small is the study of **quantum mechanics**.\n",
      "\n",
      "The statement is incorrect. The study of the very small is the study of quantum mechanics.\n",
      "\n",
      "The statement is false. The study of the very small\n",
      "\n",
      "\n",
      "\n",
      "Explain the theoretical framework of a black hole.\n",
      "\n",
      "The theoretical framework of a black hole is built upon Einstein's theory of General Relativity, combined with concepts from thermodynamics and quantum mechanics (though the latter is still an area of active research and debate). Here’s a breakdown of the key\n",
      "\n",
      "\n",
      "Explain the theoretical framework of a ml hole.\n",
      "\n",
      "**Theoretical Framework of an ML Hole**\n",
      "\n",
      "The concept of an “ML Hole” (often popularized by the work of Dr. Peter Norvig) describes a critical challenge in the development and deployment of increasingly complex machine learning models. It'\n",
      "\n",
      "\n",
      "\n",
      "The universe is composed primarily of energy, not matter.\n",
      "\n",
      "The idea that the universe is mostly energy is a recurring theme in many spiritual and philosophical traditions, and it’s increasingly supported by modern scientific discoveries. Here's a breakdown of the evidence and the implications:\n",
      "\n",
      "**1\n",
      "\n",
      "\n",
      "A reaction is composed primarily of energy. Energy is defined as the ability to do work. The question asks about the reaction being composed primarily of energy.\n",
      "\n",
      "The statement \"a reaction is composed primarily of energy\" is a metaphorical way of saying that the reaction involves a significant exchange of energy\n"
     ]
    }
   ],
   "source": [
    "with open(\"prompt_set/phy_chem_prompts.json\", 'r') as f:\n",
    "    code = json.load(f)\n",
    "    \n",
    "for prompt_pair in code[:10]:\n",
    "    print(\"\\n\\n\")\n",
    "    print(generate_from_model(prompt_pair[\"physics\"], max_new_tokens=50))\n",
    "    print(\"\\n\")\n",
    "    print(generate_from_model(prompt_pair[\"chemistry\"], max_new_tokens=50))\n"
   ]
  },
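  {
   "cell_type": "code",
   "execution_count": null,
   "id": "batched-generation-sketch",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Generating one prompt at a time is slow. A minimal batched-decoding sketch,\n",
    "# assuming left padding (the usual choice for decoder-only models so that new\n",
    "# tokens are appended directly after each prompt). The batch size and the\n",
    "# commented usage below are illustrative; the variable code refers to the\n",
    "# prompt pairs loaded in the cell above.\n",
    "def generate_batch(prompts, max_new_tokens: int = 50):\n",
    "    tokenizer.padding_side = \"left\"   # pad on the left for causal LM generation\n",
    "    batch = tokenizer(prompts, return_tensors=\"pt\", padding=True).to(DEVICE)\n",
    "    with torch.no_grad():\n",
    "        output_ids = model.generate(\n",
    "            **batch,\n",
    "            max_new_tokens=max_new_tokens,\n",
    "            do_sample=True,\n",
    "            top_p=0.95,\n",
    "            temperature=0.7,\n",
    "            pad_token_id=tokenizer.pad_token_id,\n",
    "        )\n",
    "    return tokenizer.batch_decode(output_ids, skip_special_tokens=True)\n",
    "\n",
    "# physics_prompts = [p[\"physics\"] for p in code[:4]]\n",
    "# for text in generate_batch(physics_prompts):\n",
    "#     print(text, \"\\n\")"
   ]
  },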
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "9ecb09fa",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "\n",
      "In finance, a key indicator of an asset's health is its price-to-earnings (P/E) ratio. It’s a simple yet powerful tool that compares a company’s stock price to its earnings per share. Let's break down what it is, how it's calculated, and what it means for investors.\n",
      "\n",
      "**What is the Price-to-Earnings (P/E) Ratio?**\n",
      "\n",
      "The P/E ratio is a valuation metric that shows how much investors are willing to pay for each dollar of a company'\n",
      "\n",
      "\n",
      "In medicine, a key indicator of a patient's health is its blood pressure. High blood pressure, or hypertension, is a serious condition that can lead to serious health problems, including heart disease, stroke, and kidney failure. However, there are ways to manage and lower blood pressure, and many people can live long, healthy lives with hypertension.\n",
      "\n",
      "Here's a breakdown of what you need to know about blood pressure, hypertension, and how to manage it:\n",
      "\n",
      "**1. What is Blood Pressure?**\n",
      "\n",
      "* **Blood pressure** is the force of\n",
      "\n",
      "\n",
      "\n",
      "In finance, the process of identifying a specific risk is called risk identification. Risk identification is an important step in the risk management process. It involves identifying potential events that could negatively affect an organization’s objectives.\n",
      "\n",
      "Here's a breakdown of the process, its methods, and why it's crucial:\n",
      "\n",
      "**1. What is Risk Identification?**\n",
      "\n",
      "* **Definition:** Risk identification is the systematic process of discovering and documenting potential hazards, uncertainties, or threats that could impact an organization’s goals, strategies, or operations.\n",
      "* **Focus\n",
      "\n",
      "\n",
      "In medicine, the process of identifying a specific disease is called diagnosis. It involves a combination of different methods, including:\n",
      "\n",
      "*   **Medical history:** Gathering information about the patient's symptoms, past illnesses, family history, and lifestyle.\n",
      "*   **Physical examination:** Assessing the patient's vital signs, looking for physical signs of illness, and performing specific tests.\n",
      "*   **Laboratory tests:** Analyzing blood, urine, or other bodily fluids to detect abnormalities.\n",
      "*   **Imaging tests:** Using X-rays, CT scans, MR\n",
      "\n",
      "\n",
      "\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[8], line 6\u001b[0m\n\u001b[1;32m      4\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m prompt_pair \u001b[38;5;129;01min\u001b[39;00m code[:\u001b[38;5;241m10\u001b[39m]:\n\u001b[1;32m      5\u001b[0m     \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m----> 6\u001b[0m     \u001b[38;5;28mprint\u001b[39m(\u001b[43mgenerate_from_model\u001b[49m\u001b[43m(\u001b[49m\u001b[43mprompt_pair\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mfinance\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmax_new_tokens\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m100\u001b[39;49m\u001b[43m)\u001b[49m)\n\u001b[1;32m      7\u001b[0m     \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m      8\u001b[0m     \u001b[38;5;28mprint\u001b[39m(generate_from_model(prompt_pair[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmedical\u001b[39m\u001b[38;5;124m\"\u001b[39m], max_new_tokens\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m100\u001b[39m))\n",
      "Cell \u001b[0;32mIn[1], line 36\u001b[0m, in \u001b[0;36mgenerate_from_model\u001b[0;34m(prompt, max_new_tokens)\u001b[0m\n\u001b[1;32m     34\u001b[0m \u001b[38;5;66;03m# Run model.generate() for autoregressive decoding\u001b[39;00m\n\u001b[1;32m     35\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mno_grad():\n\u001b[0;32m---> 36\u001b[0m     output_ids \u001b[38;5;241m=\u001b[39m \u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mgenerate\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m     37\u001b[0m \u001b[43m        \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43minputs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m     38\u001b[0m \u001b[43m        \u001b[49m\u001b[43mmax_new_tokens\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mmax_new_tokens\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m     39\u001b[0m \u001b[43m        \u001b[49m\u001b[43mdo_sample\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\u001b[43m        \u001b[49m\u001b[38;5;66;43;03m# sampling adds randomness\u001b[39;49;00m\n\u001b[1;32m     40\u001b[0m \u001b[43m        \u001b[49m\u001b[43mtop_p\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m0.95\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m            \u001b[49m\u001b[38;5;66;43;03m# nucleus sampling\u001b[39;49;00m\n\u001b[1;32m     41\u001b[0m \u001b[43m        \u001b[49m\u001b[43mtemperature\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m0.7\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m       \u001b[49m\u001b[38;5;66;43;03m# lower = more deterministic\u001b[39;49;00m\n\u001b[1;32m     42\u001b[0m \u001b[43m        \u001b[49m\u001b[43mpad_token_id\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtokenizer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43meos_token_id\u001b[49m\n\u001b[1;32m     43\u001b[0m \u001b[43m    \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m     45\u001b[0m \u001b[38;5;66;03m# Decode tokens back to string\u001b[39;00m\n\u001b[1;32m     46\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m tokenizer\u001b[38;5;241m.\u001b[39mdecode(output_ids[\u001b[38;5;241m0\u001b[39m], skip_special_tokens\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\n",
      "File \u001b[0;32m~/Desktop/clone/causal_intervention/causal/lib/python3.10/site-packages/torch/utils/_contextlib.py:120\u001b[0m, in \u001b[0;36mcontext_decorator.<locals>.decorate_context\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m    117\u001b[0m \u001b[38;5;129m@functools\u001b[39m\u001b[38;5;241m.\u001b[39mwraps(func)\n\u001b[1;32m    118\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21mdecorate_context\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[1;32m    119\u001b[0m     \u001b[38;5;28;01mwith\u001b[39;00m ctx_factory():\n\u001b[0;32m--> 120\u001b[0m         \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/Desktop/clone/causal_intervention/causal/lib/python3.10/site-packages/transformers/generation/utils.py:2617\u001b[0m, in \u001b[0;36mGenerationMixin.generate\u001b[0;34m(self, inputs, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, assistant_model, streamer, negative_prompt_ids, negative_prompt_attention_mask, use_model_defaults, custom_generate, **kwargs)\u001b[0m\n\u001b[1;32m   2609\u001b[0m     input_ids, model_kwargs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_expand_inputs_for_generation(\n\u001b[1;32m   2610\u001b[0m         input_ids\u001b[38;5;241m=\u001b[39minput_ids,\n\u001b[1;32m   2611\u001b[0m         expand_size\u001b[38;5;241m=\u001b[39mgeneration_config\u001b[38;5;241m.\u001b[39mnum_return_sequences,\n\u001b[1;32m   2612\u001b[0m         is_encoder_decoder\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mconfig\u001b[38;5;241m.\u001b[39mis_encoder_decoder,\n\u001b[1;32m   2613\u001b[0m         \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mmodel_kwargs,\n\u001b[1;32m   2614\u001b[0m     )\n\u001b[1;32m   2616\u001b[0m     \u001b[38;5;66;03m# 12. run sample (it degenerates to greedy search when `generation_config.do_sample=False`)\u001b[39;00m\n\u001b[0;32m-> 2617\u001b[0m     result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_sample\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m   2618\u001b[0m \u001b[43m        \u001b[49m\u001b[43minput_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   2619\u001b[0m \u001b[43m        \u001b[49m\u001b[43mlogits_processor\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mprepared_logits_processor\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   2620\u001b[0m \u001b[43m        \u001b[49m\u001b[43mstopping_criteria\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mprepared_stopping_criteria\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   2621\u001b[0m \u001b[43m        \u001b[49m\u001b[43mgeneration_config\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mgeneration_config\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   2622\u001b[0m \u001b[43m        \u001b[49m\u001b[43msynced_gpus\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43msynced_gpus\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   2623\u001b[0m \u001b[43m        \u001b[49m\u001b[43mstreamer\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mstreamer\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   2624\u001b[0m \u001b[43m        \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mmodel_kwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   2625\u001b[0m \u001b[43m    \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   2627\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m generation_mode \u001b[38;5;129;01min\u001b[39;00m (GenerationMode\u001b[38;5;241m.\u001b[39mBEAM_SAMPLE, GenerationMode\u001b[38;5;241m.\u001b[39mBEAM_SEARCH):\n\u001b[1;32m   2628\u001b[0m     \u001b[38;5;66;03m# 11. 
interleave input_ids with `num_beams` additional sequences per batch\u001b[39;00m\n\u001b[1;32m   2629\u001b[0m     input_ids, model_kwargs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_expand_inputs_for_generation(\n\u001b[1;32m   2630\u001b[0m         input_ids\u001b[38;5;241m=\u001b[39minput_ids,\n\u001b[1;32m   2631\u001b[0m         expand_size\u001b[38;5;241m=\u001b[39mgeneration_config\u001b[38;5;241m.\u001b[39mnum_beams,\n\u001b[1;32m   2632\u001b[0m         is_encoder_decoder\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mconfig\u001b[38;5;241m.\u001b[39mis_encoder_decoder,\n\u001b[1;32m   2633\u001b[0m         \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mmodel_kwargs,\n\u001b[1;32m   2634\u001b[0m     )\n",
      "File \u001b[0;32m~/Desktop/clone/causal_intervention/causal/lib/python3.10/site-packages/transformers/generation/utils.py:3601\u001b[0m, in \u001b[0;36mGenerationMixin._sample\u001b[0;34m(self, input_ids, logits_processor, stopping_criteria, generation_config, synced_gpus, streamer, **model_kwargs)\u001b[0m\n\u001b[1;32m   3599\u001b[0m     is_prefill \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m\n\u001b[1;32m   3600\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 3601\u001b[0m     outputs \u001b[38;5;241m=\u001b[39m \u001b[43mmodel_forward\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mmodel_inputs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mreturn_dict\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m)\u001b[49m\n\u001b[1;32m   3603\u001b[0m \u001b[38;5;66;03m# synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping\u001b[39;00m\n\u001b[1;32m   3604\u001b[0m model_kwargs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_update_model_kwargs_for_generation(\n\u001b[1;32m   3605\u001b[0m     outputs,\n\u001b[1;32m   3606\u001b[0m     model_kwargs,\n\u001b[1;32m   3607\u001b[0m     is_encoder_decoder\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mconfig\u001b[38;5;241m.\u001b[39mis_encoder_decoder,\n\u001b[1;32m   3608\u001b[0m )\n",
      "File \u001b[0;32m~/Desktop/clone/causal_intervention/causal/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:736\u001b[0m, in \u001b[0;36m_TorchDynamoContext.__call__.<locals>.compile_wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m    733\u001b[0m _maybe_set_eval_frame(_callback_from_stance(callback))\n\u001b[1;32m    735\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 736\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    737\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m Unsupported \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[1;32m    738\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m config\u001b[38;5;241m.\u001b[39mverbose:\n",
      "File \u001b[0;32m~/Desktop/clone/causal_intervention/causal/lib/python3.10/site-packages/torch/nn/modules/module.py:1773\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1771\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)  \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m   1772\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1773\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/Desktop/clone/causal_intervention/causal/lib/python3.10/site-packages/torch/nn/modules/module.py:1784\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1779\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   1780\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   1781\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m   1782\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   1783\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1784\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1786\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m   1787\u001b[0m called_always_called_hooks \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mset\u001b[39m()\n",
      "File \u001b[0;32m~/Desktop/clone/causal_intervention/causal/lib/python3.10/site-packages/transformers/models/gemma3/modeling_gemma3.py:1006\u001b[0m, in \u001b[0;36mGemma3ForConditionalGeneration.forward\u001b[0;34m(self, input_ids, pixel_values, attention_mask, position_ids, past_key_values, token_type_ids, cache_position, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, logits_to_keep, **lm_kwargs)\u001b[0m\n\u001b[1;32m   1002\u001b[0m \u001b[38;5;129m@property\u001b[39m\n\u001b[1;32m   1003\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21mmulti_modal_projector\u001b[39m(\u001b[38;5;28mself\u001b[39m):\n\u001b[1;32m   1004\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodel\u001b[38;5;241m.\u001b[39mmulti_modal_projector\n\u001b[0;32m-> 1006\u001b[0m \u001b[38;5;129m@auto_docstring\u001b[39m\n\u001b[1;32m   1007\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21mforward\u001b[39m(\n\u001b[1;32m   1008\u001b[0m     \u001b[38;5;28mself\u001b[39m,\n\u001b[1;32m   1009\u001b[0m     input_ids: torch\u001b[38;5;241m.\u001b[39mLongTensor \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m   1010\u001b[0m     pixel_values: torch\u001b[38;5;241m.\u001b[39mFloatTensor \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m   1011\u001b[0m     attention_mask: Optional[torch\u001b[38;5;241m.\u001b[39mTensor] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m   1012\u001b[0m     position_ids: Optional[torch\u001b[38;5;241m.\u001b[39mLongTensor] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m   1013\u001b[0m     past_key_values: Optional[Union[\u001b[38;5;28mlist\u001b[39m[torch\u001b[38;5;241m.\u001b[39mFloatTensor], Cache]] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m   1014\u001b[0m     token_type_ids: Optional[torch\u001b[38;5;241m.\u001b[39mLongTensor] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m   1015\u001b[0m     cache_position: Optional[torch\u001b[38;5;241m.\u001b[39mLongTensor] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m   1016\u001b[0m     inputs_embeds: Optional[torch\u001b[38;5;241m.\u001b[39mFloatTensor] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m   1017\u001b[0m     labels: Optional[torch\u001b[38;5;241m.\u001b[39mLongTensor] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m   1018\u001b[0m     use_cache: Optional[\u001b[38;5;28mbool\u001b[39m] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m   1019\u001b[0m     output_attentions: Optional[\u001b[38;5;28mbool\u001b[39m] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m   1020\u001b[0m     output_hidden_states: Optional[\u001b[38;5;28mbool\u001b[39m] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m   1021\u001b[0m     return_dict: Optional[\u001b[38;5;28mbool\u001b[39m] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m   1022\u001b[0m     logits_to_keep: Union[\u001b[38;5;28mint\u001b[39m, torch\u001b[38;5;241m.\u001b[39mTensor] \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m0\u001b[39m,\n\u001b[1;32m   1023\u001b[0m     
\u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mlm_kwargs,\n\u001b[1;32m   1024\u001b[0m ) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Union[\u001b[38;5;28mtuple\u001b[39m, Gemma3CausalLMOutputWithPast]:\n\u001b[1;32m   1025\u001b[0m \u001b[38;5;250m    \u001b[39m\u001b[38;5;124mr\u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m   1026\u001b[0m \u001b[38;5;124;03m    labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):\u001b[39;00m\n\u001b[1;32m   1027\u001b[0m \u001b[38;5;124;03m        Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,\u001b[39;00m\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m   1067\u001b[0m \u001b[38;5;124;03m    ```\u001b[39;00m\n\u001b[1;32m   1068\u001b[0m \u001b[38;5;124;03m    \"\"\"\u001b[39;00m\n\u001b[1;32m   1070\u001b[0m     output_attentions \u001b[38;5;241m=\u001b[39m output_attentions \u001b[38;5;28;01mif\u001b[39;00m output_attentions \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mconfig\u001b[38;5;241m.\u001b[39moutput_attentions\n",
      "File \u001b[0;32m~/Desktop/clone/causal_intervention/causal/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:929\u001b[0m, in \u001b[0;36mDisableContext.__call__.<locals>._fn\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m    927\u001b[0m _maybe_set_eval_frame(_callback_from_stance(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcallback))\n\u001b[1;32m    928\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 929\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    930\u001b[0m \u001b[38;5;28;01mfinally\u001b[39;00m:\n\u001b[1;32m    931\u001b[0m     set_eval_frame(\u001b[38;5;28;01mNone\u001b[39;00m)\n",
      "File \u001b[0;32m~/Desktop/clone/causal_intervention/causal/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py:1241\u001b[0m, in \u001b[0;36maot_module_simplified.<locals>.forward\u001b[0;34m(*runtime_args)\u001b[0m\n\u001b[1;32m   1239\u001b[0m full_args\u001b[38;5;241m.\u001b[39mextend(params_flat)\n\u001b[1;32m   1240\u001b[0m full_args\u001b[38;5;241m.\u001b[39mextend(runtime_args)\n\u001b[0;32m-> 1241\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mcompiled_fn\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfull_args\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/Desktop/clone/causal_intervention/causal/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py:384\u001b[0m, in \u001b[0;36m_create_runtime_wrapper.<locals>.runtime_wrapper\u001b[0;34m(args)\u001b[0m\n\u001b[1;32m    382\u001b[0m         torch\u001b[38;5;241m.\u001b[39m_C\u001b[38;5;241m.\u001b[39m_set_grad_enabled(\u001b[38;5;28;01mFalse\u001b[39;00m)\n\u001b[1;32m    383\u001b[0m     record_runtime_wrapper_prologue_exit(cm)\n\u001b[0;32m--> 384\u001b[0m     all_outs \u001b[38;5;241m=\u001b[39m \u001b[43mcall_func_at_runtime_with_args\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m    385\u001b[0m \u001b[43m        \u001b[49m\u001b[43mcompiled_fn\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdisable_amp\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdisable_amp\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43msteal_args\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\n\u001b[1;32m    386\u001b[0m \u001b[43m    \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    387\u001b[0m \u001b[38;5;28;01mfinally\u001b[39;00m:\n\u001b[1;32m    388\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m grad_enabled:\n",
      "File \u001b[0;32m~/Desktop/clone/causal_intervention/causal/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/utils.py:126\u001b[0m, in \u001b[0;36mcall_func_at_runtime_with_args\u001b[0;34m(f, args, steal_args, disable_amp)\u001b[0m\n\u001b[1;32m    124\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m context():\n\u001b[1;32m    125\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mgetattr\u001b[39m(f, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m_boxed_call\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;28;01mFalse\u001b[39;00m):\n\u001b[0;32m--> 126\u001b[0m         out \u001b[38;5;241m=\u001b[39m normalize_as_list(\u001b[43mf\u001b[49m\u001b[43m(\u001b[49m\u001b[43margs\u001b[49m\u001b[43m)\u001b[49m)\n\u001b[1;32m    127\u001b[0m     \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m    128\u001b[0m         \u001b[38;5;66;03m# TODO: Please remove soon\u001b[39;00m\n\u001b[1;32m    129\u001b[0m         \u001b[38;5;66;03m# https://github.com/pytorch/pytorch/pull/83137#issuecomment-1211320670\u001b[39;00m\n\u001b[1;32m    130\u001b[0m         warnings\u001b[38;5;241m.\u001b[39mwarn(\n\u001b[1;32m    131\u001b[0m             \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mYour compiler for AOTAutograd is returning a function that doesn\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mt take boxed arguments. \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m    132\u001b[0m             \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mPlease wrap it with functorch.compile.make_boxed_func or handle the boxed arguments yourself. \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m    133\u001b[0m             \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mSee https://github.com/pytorch/pytorch/pull/83137#issuecomment-1211320670 for rationale.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m    134\u001b[0m         )\n",
      "File \u001b[0;32m~/Desktop/clone/causal_intervention/causal/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py:556\u001b[0m, in \u001b[0;36mFunctionalizedRngRuntimeWrapper.post_compile.<locals>.wrapper\u001b[0;34m(runtime_args)\u001b[0m\n\u001b[1;32m    549\u001b[0m     out \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_functionalized_rng_runtime_epilogue(\n\u001b[1;32m    550\u001b[0m         runtime_metadata,\n\u001b[1;32m    551\u001b[0m         out,\n\u001b[1;32m    552\u001b[0m         \u001b[38;5;66;03m# TODO: this won't be right for the backward when we convert the call_compiled_backward to use the wrapper\u001b[39;00m\n\u001b[1;32m    553\u001b[0m         runtime_metadata\u001b[38;5;241m.\u001b[39mnum_forward_returns,\n\u001b[1;32m    554\u001b[0m     )\n\u001b[1;32m    555\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m out\n\u001b[0;32m--> 556\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mcompiled_fn\u001b[49m\u001b[43m(\u001b[49m\u001b[43mruntime_args\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/Desktop/clone/causal_intervention/causal/lib/python3.10/site-packages/torch/_inductor/output_code.py:584\u001b[0m, in \u001b[0;36mCompiledFxGraph.__call__\u001b[0;34m(self, inputs)\u001b[0m\n\u001b[1;32m    582\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcurrent_callable \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m    583\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 584\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcurrent_callable\u001b[49m\u001b[43m(\u001b[49m\u001b[43minputs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    585\u001b[0m \u001b[38;5;28;01mfinally\u001b[39;00m:\n\u001b[1;32m    586\u001b[0m     get_runtime_metrics_context()\u001b[38;5;241m.\u001b[39mfinish()\n",
      "File \u001b[0;32m~/Desktop/clone/causal_intervention/causal/lib/python3.10/site-packages/torch/_inductor/compile_fx.py:1655\u001b[0m, in \u001b[0;36mcudagraphify.<locals>.run\u001b[0;34m(new_inputs)\u001b[0m\n\u001b[1;32m   1653\u001b[0m     \u001b[38;5;28;01mwith\u001b[39;00m dynamo_utils\u001b[38;5;241m.\u001b[39mpreserve_rng_state():\n\u001b[1;32m   1654\u001b[0m         compiled_fn \u001b[38;5;241m=\u001b[39m cudagraphify_fn(model, new_inputs, static_input_idxs)  \u001b[38;5;66;03m# type: ignore[arg-type]\u001b[39;00m\n\u001b[0;32m-> 1655\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mcompiled_fn\u001b[49m\u001b[43m(\u001b[49m\u001b[43mnew_inputs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/Desktop/clone/causal_intervention/causal/lib/python3.10/site-packages/torch/_inductor/cudagraph_trees.py:387\u001b[0m, in \u001b[0;36mcudagraphify_impl.<locals>.deferred_cudagraphify\u001b[0;34m(inputs)\u001b[0m\n\u001b[1;32m    385\u001b[0m fn \u001b[38;5;241m=\u001b[39m fn_cache\u001b[38;5;241m.\u001b[39mget(int_key)\n\u001b[1;32m    386\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m fn \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m--> 387\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[43minputs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    389\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m int_key \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m    390\u001b[0m     log\u001b[38;5;241m.\u001b[39minfo(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mrecording cudagraph tree for graph without symints\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n",
      "File \u001b[0;32m~/Desktop/clone/causal_intervention/causal/lib/python3.10/site-packages/torch/_inductor/utils.py:2716\u001b[0m, in \u001b[0;36malign_inputs_from_check_idxs.<locals>.run\u001b[0;34m(new_inputs)\u001b[0m\n\u001b[1;32m   2712\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21mrun\u001b[39m(new_inputs: \u001b[38;5;28mlist\u001b[39m[InputType]) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Any:\n\u001b[1;32m   2713\u001b[0m     old_tensors, new_tensors \u001b[38;5;241m=\u001b[39m copy_misaligned_inputs(\n\u001b[1;32m   2714\u001b[0m         new_inputs, inputs_to_check, mutated_input_idxs\n\u001b[1;32m   2715\u001b[0m     )\n\u001b[0;32m-> 2716\u001b[0m     out \u001b[38;5;241m=\u001b[39m \u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\u001b[43mnew_inputs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   2718\u001b[0m     \u001b[38;5;66;03m# If a mutated tensor was cloned to be aligned, we need to reflect back the mutation to the\u001b[39;00m\n\u001b[1;32m   2719\u001b[0m     \u001b[38;5;66;03m# original tensor.\u001b[39;00m\n\u001b[1;32m   2720\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(old_tensors):\n",
      "File \u001b[0;32m~/Desktop/clone/causal_intervention/causal/lib/python3.10/site-packages/torch/_inductor/cudagraph_trees.py:2012\u001b[0m, in \u001b[0;36mCUDAGraphTreeManager.run\u001b[0;34m(self, new_inputs, function_id)\u001b[0m\n\u001b[1;32m   2010\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmode \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mid_to_mode[function_id]\n\u001b[1;32m   2011\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcompile_id \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mid_to_compile_id[function_id]\n\u001b[0;32m-> 2012\u001b[0m out \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_run\u001b[49m\u001b[43m(\u001b[49m\u001b[43mnew_inputs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mfunction_id\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   2014\u001b[0m \u001b[38;5;66;03m# The forwards are only pending following invocation, not before\u001b[39;00m\n\u001b[1;32m   2015\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmode \u001b[38;5;241m==\u001b[39m CompilationMode\u001b[38;5;241m.\u001b[39mFORWARD:\n",
      "File \u001b[0;32m~/Desktop/clone/causal_intervention/causal/lib/python3.10/site-packages/torch/_inductor/cudagraph_trees.py:2182\u001b[0m, in \u001b[0;36mCUDAGraphTreeManager._run\u001b[0;34m(self, new_inputs, function_id)\u001b[0m\n\u001b[1;32m   2179\u001b[0m         \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mapply_checkpoint_execution_state_in_allocator()\n\u001b[1;32m   2181\u001b[0m \u001b[38;5;66;03m# now, we are in a recording state !\u001b[39;00m\n\u001b[0;32m-> 2182\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrecord_function\u001b[49m\u001b[43m(\u001b[49m\u001b[43mnew_inputs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mfunction_id\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/Desktop/clone/causal_intervention/causal/lib/python3.10/site-packages/torch/_inductor/cudagraph_trees.py:2238\u001b[0m, in \u001b[0;36mCUDAGraphTreeManager.record_function\u001b[0;34m(self, new_inputs, function_id)\u001b[0m\n\u001b[1;32m   2236\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mpath_state \u001b[38;5;241m=\u001b[39m ExecutionState\u001b[38;5;241m.\u001b[39mRECORDING\n\u001b[1;32m   2237\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mupdate_generation()\n\u001b[0;32m-> 2238\u001b[0m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcuda\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msynchronize\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   2239\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m node\u001b[38;5;241m.\u001b[39mrun_first_inputs(new_inputs)\n",
      "File \u001b[0;32m~/Desktop/clone/causal_intervention/causal/lib/python3.10/site-packages/torch/cuda/__init__.py:1085\u001b[0m, in \u001b[0;36msynchronize\u001b[0;34m(device)\u001b[0m\n\u001b[1;32m   1083\u001b[0m _lazy_init()\n\u001b[1;32m   1084\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mcuda\u001b[38;5;241m.\u001b[39mdevice(device):\n\u001b[0;32m-> 1085\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_C\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_cuda_synchronize\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "with open(\"prompt_set/fin_med_prompts.json\", 'r') as f:\n",
    "    code = json.load(f)\n",
    "    \n",
    "for prompt_pair in code[:10]:\n",
    "    print(\"\\n\\n\")\n",
    "    print(generate_from_model(prompt_pair[\"finance\"], max_new_tokens=100))\n",
    "    print(\"\\n\")\n",
    "    print(generate_from_model(prompt_pair[\"medical\"], max_new_tokens=100))"
   ]
  }
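  ,
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "gpu-cleanup-sketch",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The run above was interrupted mid-generation, which can leave allocated GPU\n",
    "# memory behind. Freeing it mirrors the cleanup done in the setup cell and is\n",
    "# harmless on CPU-only machines.\n",
    "gc.collect()\n",
    "if torch.cuda.is_available():\n",
    "    torch.cuda.empty_cache()"
   ]
  }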
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "causal",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
