{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### load data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2.3.1+cu118\n",
      "cuda\n"
     ]
    }
   ],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import os\n",
    "import sys\n",
    "import ast\n",
    "import torch\n",
    "print(torch.__version__)\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "print(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Number of invalid code snippets: 126\n",
      "61\n"
     ]
    }
   ],
   "source": [
    "# load the data\n",
    "def is_valid_code(code):\n",
    "    try:\n",
    "        ast.parse(code)\n",
    "        return True\n",
    "    except SyntaxError:\n",
    "        return False\n",
    "\n",
    "def preprocess_data(data, test_data=False):\n",
    "    # Apply the is_valid_code function to each code entry\n",
    "    data['is_valid'] = data['code'].apply(is_valid_code)\n",
    "    \n",
    "    # Count and print the number of invalid code snippets\n",
    "    invalid_count = data['is_valid'].value_counts().get(False, 0)\n",
    "    print(f\"Number of invalid code snippets: {invalid_count}\")\n",
    "    \n",
    "    # Filter out invalid code entries and reset the index\n",
    "    data = data[data['is_valid']].drop(columns=['is_valid']).reset_index(drop=True)\n",
    "\n",
    "    if test_data:\n",
    "        return data\n",
    "    \n",
    "    data['code_list_embedding'] = data['code_list_embedding'].apply(ast.literal_eval)\n",
    "    data['previous_code_embedding'] = data['previous_code_embedding'].apply(ast.literal_eval)\n",
    "    data['code_list'] = data['code_list'].apply(ast.literal_eval)\n",
    "    data['previous_code_list'] = data['previous_code_list'].apply(ast.literal_eval)\n",
    "    \n",
    "    return data\n",
    "\n",
    "data = pd.read_csv(r\"D:\\Python_project\\leetcode_data\\data\\leetcode_Median_of_Two_Sorted_Arrays_embedding.csv\")\n",
    "data = preprocess_data(data)\n",
    "print(len(data))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Number of invalid code snippets: 1\n",
      "119\n",
      "labels\n",
      "0    60\n",
      "1    59\n",
      "Name: count, dtype: int64\n"
     ]
    }
   ],
   "source": [
    "# load nagative data\n",
    "test_data = pd.read_csv(r\"D:\\Python_project\\leetcode_data\\data\\leetcode_diff_task.csv\")\n",
    "test_data = preprocess_data(test_data, test_data=True)\n",
    "# label all rows that has \"Median of Two Sorted Arrays\" in title column as 1 and the rest as 0\n",
    "test_data['labels'] = test_data['title'].apply(lambda x: 1 if 'Median of Two Sorted Arrays' in x else 0)\n",
    "# only keeps the columns that are needed\n",
    "test_data = test_data[['code', 'labels']]\n",
    "test_data.reset_index(drop=True, inplace=True)\n",
    "print(len(test_data))\n",
    "print(test_data['labels'].value_counts())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### state define step"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "import ast\n",
    "import pandas as pd\n",
    "\n",
    "# We will assume these 5 categories for demonstration:\n",
    "CATEGORIES = [\"variable\", \"function\", \"class\", \"import\", \"argument\"]\n",
    "\n",
    "class CodeState:\n",
    "    \"\"\"\n",
    "    Tracks memory-affecting information: variables, functions, classes, imports, arguments.\n",
    "    We use a max_unseen limit so if we define more than 'max_unseen' distinct names,\n",
    "    we store them in the 'overflow' slot for that category.\n",
    "    \"\"\"\n",
    "    def __init__(self, max_unseen=10):\n",
    "        self.max_unseen = max_unseen\n",
    "        \n",
    "        # We'll maintain a separate list for each category\n",
    "        # Mapping: category -> list of known names\n",
    "        self.known = {\n",
    "            \"variable\": [],\n",
    "            \"function\": [],\n",
    "            \"class\": [],\n",
    "            \"import\": [],\n",
    "            \"argument\": [],  # if you'd like to track function arguments specifically\n",
    "        }\n",
    "\n",
    "    def add_item(self, category, name):\n",
    "        \"\"\"\n",
    "        Generic helper. If name is not in self.known[category], add it if not exceeding max_unseen.\n",
    "        \"\"\"\n",
    "        if name not in self.known[category]:\n",
    "            if len(self.known[category]) < self.max_unseen:\n",
    "                self.known[category].append(name)\n",
    "            # else do nothing (any new name goes into 'overflow')\n",
    "\n",
    "    def get_index(self, category, name):\n",
    "        \"\"\"\n",
    "        Return an index for 'name' in the given category.\n",
    "        If we have not stored it (either brand new or exceeding max_unseen), \n",
    "        we treat it as 'overflow'.\n",
    "        \"\"\"\n",
    "        if name in self.known[category]:\n",
    "            idx = self.known[category].index(name)\n",
    "            if idx < self.max_unseen:\n",
    "                return idx\n",
    "        # If unknown or out of range, we treat it as overflow.\n",
    "        return self.max_unseen  # i.e. the \"others\" slot\n",
    "\n",
    "    def get_state_representation(self):\n",
    "        \"\"\"\n",
    "        Returns a dictionary describing how many names we have for each category \n",
    "        (up to max_unseen), plus a boolean if we have overflow.\n",
    "        In practice, we'll convert it to a one-hot or multi-hot vector later.\n",
    "        \"\"\"\n",
    "        result = {}\n",
    "        for cat in CATEGORIES:\n",
    "            count = len(self.known[cat])\n",
    "            has_overflow = (count > self.max_unseen)\n",
    "            # We record the actual count (clamped) + overflow info\n",
    "            result[f\"{cat}_count\"] = min(count, self.max_unseen)\n",
    "            result[f\"{cat}_overflow\"] = has_overflow\n",
    "        return result\n",
    "\n",
    "    def __str__(self):\n",
    "        return \"CodeState(\" + \", \".join(\n",
    "            f\"{cat}={self.known[cat]}\" for cat in CATEGORIES\n",
    "        ) + \")\"\n",
    "\n",
    "class CodeToTrajectory:\n",
    "    \"\"\"\n",
    "    Converts code into a sequence of (state, action) pairs, line-by-line.\n",
    "    We keep it simple: each line may produce one or more actions, \n",
    "    combined into a single \"action\" item.\n",
    "    \"\"\"\n",
    "    def __init__(self, max_unseen=10):\n",
    "        self.max_unseen = max_unseen\n",
    "\n",
    "    def convert_to_trajectory(self, code_str):\n",
    "        \"\"\"\n",
    "        Return a list of (state_dict, action_dict) pairs for each line in code_str.\n",
    "        \"\"\"\n",
    "        tree = ast.parse(code_str)\n",
    "        \n",
    "        # Collect AST nodes by line\n",
    "        nodes = list(ast.walk(tree))\n",
    "        line_to_nodes = {}\n",
    "        for node in nodes:\n",
    "            if hasattr(node, 'lineno'):\n",
    "                line_no = node.lineno\n",
    "                line_to_nodes.setdefault(line_no, []).append(node)\n",
    "        \n",
    "        code_lines = code_str.split(\"\\n\")\n",
    "        code_state = CodeState(max_unseen=self.max_unseen)\n",
    "        \n",
    "        trajectory = []\n",
    "        for lineno, line_text in enumerate(code_lines, start=1):\n",
    "            if lineno not in line_to_nodes:\n",
    "                continue  # skip empty/comment lines or lines with no AST node\n",
    "            \n",
    "            # The current state (BEFORE this line executes)\n",
    "            current_state = code_state.get_state_representation()\n",
    "            \n",
    "            # Gather the actions for this line\n",
    "            actions_for_line = []\n",
    "            for node in line_to_nodes[lineno]:\n",
    "                # interpret the node and possibly update state\n",
    "                action_list = self._interpret_node(node, code_state)\n",
    "                if isinstance(action_list, list):\n",
    "                    actions_for_line.extend(action_list)\n",
    "                else:\n",
    "                    actions_for_line.append(action_list)\n",
    "\n",
    "            combined_action = {\n",
    "                \"lineno\": lineno,\n",
    "                \"line_text\": line_text.strip(),\n",
    "                \"actions\": actions_for_line\n",
    "            }\n",
    "            trajectory.append((current_state, combined_action))\n",
    "\n",
    "        return trajectory\n",
    "\n",
    "    def _interpret_node(self, node, code_state):\n",
    "        \"\"\"\n",
    "        Interpret a node. Return the action(s) as a textual or structured label.\n",
    "        Also update code_state if definitions occur.\n",
    "        \"\"\"\n",
    "        if isinstance(node, ast.Assign):\n",
    "            # Typically: x = ...\n",
    "            actions = []\n",
    "            for target in node.targets:\n",
    "                if isinstance(target, ast.Name):\n",
    "                    var_name = target.id\n",
    "                    # \"Define variable\" or \"Assign variable\"\n",
    "                    code_state.add_item(\"variable\", var_name)\n",
    "                    idx = code_state.get_index(\"variable\", var_name)\n",
    "                    actions.append(f\"define variable#{idx}\")\n",
    "                else:\n",
    "                    actions.append(\"define variable@complex_target\")\n",
    "            return actions\n",
    "        \n",
    "        elif isinstance(node, ast.Name):\n",
    "            # Use of a variable (if inside an expression)\n",
    "            var_name = node.id\n",
    "            idx = code_state.get_index(\"variable\", var_name)\n",
    "            if idx == code_state.max_unseen:\n",
    "                return \"call/use variable#overflow_or_undefined\"\n",
    "            else:\n",
    "                return f\"call/use variable#{idx}\"\n",
    "\n",
    "        elif isinstance(node, ast.FunctionDef):\n",
    "            # define a function\n",
    "            func_name = node.name\n",
    "            code_state.add_item(\"function\", func_name)\n",
    "            idx = code_state.get_index(\"function\", func_name)\n",
    "            \n",
    "            # Also track arguments as \"argument\" if desired\n",
    "            arg_actions = []\n",
    "            for arg in node.args.args:\n",
    "                arg_name = arg.arg\n",
    "                code_state.add_item(\"argument\", arg_name)\n",
    "                arg_idx = code_state.get_index(\"argument\", arg_name)\n",
    "                arg_actions.append(f\"define argument#{arg_idx}\")\n",
    "            \n",
    "            return [f\"define function#{idx}\"] + arg_actions\n",
    "\n",
    "        elif isinstance(node, ast.ClassDef):\n",
    "            cls_name = node.name\n",
    "            code_state.add_item(\"class\", cls_name)\n",
    "            idx = code_state.get_index(\"class\", cls_name)\n",
    "            return f\"define class#{idx}\"\n",
    "\n",
    "        elif isinstance(node, ast.Call):\n",
    "            # calling a function or method\n",
    "            return self._interpret_call(node, code_state)\n",
    "\n",
    "        elif isinstance(node, ast.Import):\n",
    "            # import statements\n",
    "            actions = []\n",
    "            for alias in node.names:\n",
    "                import_name = alias.name\n",
    "                code_state.add_item(\"import\", import_name)\n",
    "                idx = code_state.get_index(\"import\", import_name)\n",
    "                actions.append(f\"define import#{idx}\")\n",
    "            return actions\n",
    "\n",
    "        elif isinstance(node, ast.ImportFrom):\n",
    "            # from X import Y\n",
    "            actions = []\n",
    "            module_name = node.module if node.module else \"?\"\n",
    "            code_state.add_item(\"import\", module_name)\n",
    "            module_idx = code_state.get_index(\"import\", module_name)\n",
    "            actions.append(f\"define import#{module_idx}\")\n",
    "            # We won’t handle the imported names in detail, but you could if desired\n",
    "            return actions\n",
    "\n",
    "        else:\n",
    "            # Other operations: e.g. If, For, While, expressions, etc.\n",
    "            # We can either break them down further or keep them as a single \"operation\" label.\n",
    "            # For vocabulary-building, we'll just store something like:\n",
    "            return f\"other operation: {type(node).__name__}\"\n",
    "\n",
    "    def _interpret_call(self, node, code_state):\n",
    "        \"\"\"\n",
    "        Handle calls like foo(...) or obj.foo(...).\n",
    "        \"\"\"\n",
    "        # If node.func is ast.Name => direct function call\n",
    "        if isinstance(node.func, ast.Name):\n",
    "            func_name = node.func.id\n",
    "            idx = code_state.get_index(\"function\", func_name)\n",
    "            if idx == code_state.max_unseen:\n",
    "                return \"call/use function#overflow_or_undefined\"\n",
    "            else:\n",
    "                return f\"call/use function#{idx}\"\n",
    "        elif isinstance(node.func, ast.Attribute):\n",
    "            # e.g. obj.method(...)\n",
    "            # We might do more logic to see if obj is a known variable\n",
    "            return \"call/use method_on_object\"\n",
    "        else:\n",
    "            return \"call/use unknown_function\"\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "class VocabularyBuilder:\n",
    "    \"\"\"\n",
    "    Scans all code in the dataset to build a fixed state-space dimension\n",
    "    and a fixed action-space dimension (including any 'other operation' strings).\n",
    "    \"\"\"\n",
    "    def __init__(self, max_unseen=10, categories=None):\n",
    "        self.max_unseen = max_unseen\n",
    "        self.categories = categories if categories else CATEGORIES\n",
    "\n",
    "        # We know state space dimension = (max_unseen + 1) * len(categories).\n",
    "        self.state_dim = (self.max_unseen + 1) * len(self.categories)\n",
    "\n",
    "        # For the action space, let's build:\n",
    "        # 1) define_X#i for i in [0..max_unseen], for each category X\n",
    "        # 2) call/use_X#i for i in [0..max_unseen], for each category X\n",
    "        # 3) additional \"operations\" discovered from the dataset\n",
    "        # 4) plus 1 for 'unknown operation'.\n",
    "        \n",
    "        self.define_action_list = []  # e.g. define variable#0, define function#0, ...\n",
    "        self.call_action_list = []    # e.g. call/use variable#0, call/use function#0, ...\n",
    "        \n",
    "        # We will gather the \"other operations\" from scanning the code\n",
    "        self.other_operations = set()\n",
    "        \n",
    "        # We'll finalize them after scanning the dataset\n",
    "        self.action2index = {}\n",
    "        self.index2action = []\n",
    "\n",
    "    def build_vocabulary(self, trajectories):\n",
    "        \"\"\"\n",
    "        Given a list of trajectories from the entire dataset, \n",
    "        collect all distinct actions that appear, so we can finalize the action2index map.\n",
    "        \"\"\"\n",
    "        # 1) Build define_action_list\n",
    "        for cat in self.categories:\n",
    "            for i in range(self.max_unseen + 1):\n",
    "                # i goes from 0..max_unseen\n",
    "                self.define_action_list.append(f\"define {cat}#{i}\")\n",
    "        # 2) Build call_action_list\n",
    "        for cat in self.categories:\n",
    "            for i in range(self.max_unseen + 1):\n",
    "                self.call_action_list.append(f\"call/use {cat}#{i}\")\n",
    "\n",
    "        # Now read from the trajectories to find all distinct \"other\" actions\n",
    "        for traj in trajectories:\n",
    "            for (state, action_dict) in traj:\n",
    "                for act in action_dict[\"actions\"]:\n",
    "                    # Check if it's in define or call lists\n",
    "                    if (act not in self.define_action_list) and (act not in self.call_action_list):\n",
    "                        self.other_operations.add(act)\n",
    "\n",
    "        # We'll remove from 'other_operations' any define/call duplicates if they exist \n",
    "        # (should not if everything is consistent, but just in case):\n",
    "        self.other_operations -= set(self.define_action_list)\n",
    "        self.other_operations -= set(self.call_action_list)\n",
    "\n",
    "        # Build final index\n",
    "        # Start with all define actions:\n",
    "        idx = 0\n",
    "        for d in self.define_action_list:\n",
    "            self.action2index[d] = idx\n",
    "            idx += 1\n",
    "        # Then call actions:\n",
    "        for c in self.call_action_list:\n",
    "            self.action2index[c] = idx\n",
    "            idx += 1\n",
    "        # Then the \"other\" operations\n",
    "        # Sort them for reproducibility\n",
    "        self.other_operations = sorted(list(self.other_operations))\n",
    "        for op in self.other_operations:\n",
    "            self.action2index[op] = idx\n",
    "            idx += 1\n",
    "\n",
    "        # Finally, we add ONE slot for unknown\n",
    "        self.unknown_action_idx = idx\n",
    "        self.action2index[\"<UNKNOWN>\"] = idx\n",
    "        idx += 1\n",
    "\n",
    "        # Build index2action list for quick debugging\n",
    "        self.index2action = [None] * idx\n",
    "        for k, v in self.action2index.items():\n",
    "            self.index2action[v] = k\n",
    "\n",
    "        self.action_dim = idx\n",
    "\n",
    "    def encode_state(self, state_dict):\n",
    "        \"\"\"\n",
    "        Convert a state_dict (like from CodeState.get_state_representation()) \n",
    "        into a one-hot or multi-hot vector of length self.state_dim.\n",
    "        \n",
    "        We'll do a multi-hot approach: for each category, if we have X items, \n",
    "        we set indices for #0..#(X-1). If overflow is True, we also set #max_unseen for that category.\n",
    "        \"\"\"\n",
    "        vec = [0]*self.state_dim\n",
    "        offset = 0\n",
    "        for cat in self.categories:\n",
    "            count = state_dict[f\"{cat}_count\"]\n",
    "            overflow = state_dict[f\"{cat}_overflow\"]\n",
    "            # We set 1 at positions offset..offset+count-1\n",
    "            for i in range(count):\n",
    "                vec[offset + i] = 1\n",
    "            if overflow:\n",
    "                # set the overflow slot\n",
    "                vec[offset + self.max_unseen] = 1\n",
    "            offset += (self.max_unseen + 1)\n",
    "        return vec\n",
    "\n",
    "    def encode_action_list(self, action_list):\n",
    "        \"\"\"\n",
    "        For a line that might have multiple actions, we can either \n",
    "        - combine them somehow (e.g. sum or multi-hot)\n",
    "        - or produce multiple steps\n",
    "        For demonstration, let's do a multi-hot representation for the line’s combined action.\n",
    "        \"\"\"\n",
    "        vec = [0]*self.action_dim\n",
    "        for act in action_list:\n",
    "            if act in self.action2index:\n",
    "                vec[self.action2index[act]] = 1\n",
    "            else:\n",
    "                # unknown\n",
    "                vec[self.unknown_action_idx] = 1\n",
    "        return vec\n",
    "\n",
    "\n",
    "def build_and_encode_dataset(\n",
    "    data_df, \n",
    "    max_unseen=10,\n",
    "    code_col=\"code\"\n",
    "):\n",
    "    \"\"\"\n",
    "    1. Parse each code snippet into a trajectory\n",
    "    2. Collect all trajectories\n",
    "    3. Build the vocabulary (state & action space)\n",
    "    4. Encode everything into one-hot arrays\n",
    "    \"\"\"\n",
    "    # Step 1: parse\n",
    "    parser = CodeToTrajectory(max_unseen=max_unseen)\n",
    "    all_trajectories = []\n",
    "    for i, row in data_df.iterrows():\n",
    "        code_str = row[code_col]\n",
    "        traj = parser.convert_to_trajectory(code_str)\n",
    "        all_trajectories.append(traj)\n",
    "\n",
    "    # Step 2: build vocabulary from all trajectories\n",
    "    vocab_builder = VocabularyBuilder(max_unseen=max_unseen, categories=CATEGORIES)\n",
    "    vocab_builder.build_vocabulary(all_trajectories)\n",
    "\n",
    "    # Step 3: encode each trajectory\n",
    "    encoded_dataset = []\n",
    "    for traj in all_trajectories:\n",
    "        encoded_steps = []\n",
    "        for (state_dict, action_dict) in traj:\n",
    "            s_vec = vocab_builder.encode_state(state_dict)\n",
    "            # We combine all sub-actions on the line:\n",
    "            a_vec = vocab_builder.encode_action_list(action_dict[\"actions\"])\n",
    "            encoded_steps.append((s_vec, a_vec))\n",
    "        encoded_dataset.append(encoded_steps)\n",
    "\n",
    "    return all_trajectories, encoded_dataset, vocab_builder\n",
    "\n",
    "def encode_new_code(\n",
    "    code_str, \n",
    "    parser: CodeToTrajectory, \n",
    "    vocab_builder: VocabularyBuilder\n",
    "):\n",
    "    \"\"\"\n",
    "    Given a new code snippet, parse it into a trajectory, \n",
    "    then encode using the pre-built vocab_builder.\n",
    "    \"\"\"\n",
    "    traj = parser.convert_to_trajectory(code_str)\n",
    "    encoded_steps = []\n",
    "    for (state_dict, action_dict) in traj:\n",
    "        s_vec = vocab_builder.encode_state(state_dict)\n",
    "        a_vec = vocab_builder.encode_action_list(action_dict[\"actions\"])\n",
    "        encoded_steps.append((s_vec, a_vec))\n",
    "    return traj, encoded_steps"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "=== Built Vocabulary ===\n",
      "State space dimension: 55\n",
      "Action space dimension: 134\n",
      "\n",
      "=== Action Space Vocabulary (index -> action) ===\n",
      "0: define variable#0\n",
      "1: define variable#1\n",
      "2: define variable#2\n",
      "3: define variable#3\n",
      "4: define variable#4\n",
      "5: define variable#5\n",
      "6: define variable#6\n",
      "7: define variable#7\n",
      "8: define variable#8\n",
      "9: define variable#9\n",
      "10: define variable#10\n",
      "11: define function#0\n",
      "12: define function#1\n",
      "13: define function#2\n",
      "14: define function#3\n",
      "15: define function#4\n",
      "16: define function#5\n",
      "17: define function#6\n",
      "18: define function#7\n",
      "19: define function#8\n",
      "20: define function#9\n",
      "21: define function#10\n",
      "22: define class#0\n",
      "23: define class#1\n",
      "24: define class#2\n",
      "25: define class#3\n",
      "26: define class#4\n",
      "27: define class#5\n",
      "28: define class#6\n",
      "29: define class#7\n",
      "30: define class#8\n",
      "31: define class#9\n",
      "32: define class#10\n",
      "33: define import#0\n",
      "34: define import#1\n",
      "35: define import#2\n",
      "36: define import#3\n",
      "37: define import#4\n",
      "38: define import#5\n",
      "39: define import#6\n",
      "40: define import#7\n",
      "41: define import#8\n",
      "42: define import#9\n",
      "43: define import#10\n",
      "44: define argument#0\n",
      "45: define argument#1\n",
      "46: define argument#2\n",
      "47: define argument#3\n",
      "48: define argument#4\n",
      "49: define argument#5\n",
      "50: define argument#6\n",
      "51: define argument#7\n",
      "52: define argument#8\n",
      "53: define argument#9\n",
      "54: define argument#10\n",
      "55: call/use variable#0\n",
      "56: call/use variable#1\n",
      "57: call/use variable#2\n",
      "58: call/use variable#3\n",
      "59: call/use variable#4\n",
      "60: call/use variable#5\n",
      "61: call/use variable#6\n",
      "62: call/use variable#7\n",
      "63: call/use variable#8\n",
      "64: call/use variable#9\n",
      "65: call/use variable#10\n",
      "66: call/use function#0\n",
      "67: call/use function#1\n",
      "68: call/use function#2\n",
      "69: call/use function#3\n",
      "70: call/use function#4\n",
      "71: call/use function#5\n",
      "72: call/use function#6\n",
      "73: call/use function#7\n",
      "74: call/use function#8\n",
      "75: call/use function#9\n",
      "76: call/use function#10\n",
      "77: call/use class#0\n",
      "78: call/use class#1\n",
      "79: call/use class#2\n",
      "80: call/use class#3\n",
      "81: call/use class#4\n",
      "82: call/use class#5\n",
      "83: call/use class#6\n",
      "84: call/use class#7\n",
      "85: call/use class#8\n",
      "86: call/use class#9\n",
      "87: call/use class#10\n",
      "88: call/use import#0\n",
      "89: call/use import#1\n",
      "90: call/use import#2\n",
      "91: call/use import#3\n",
      "92: call/use import#4\n",
      "93: call/use import#5\n",
      "94: call/use import#6\n",
      "95: call/use import#7\n",
      "96: call/use import#8\n",
      "97: call/use import#9\n",
      "98: call/use import#10\n",
      "99: call/use argument#0\n",
      "100: call/use argument#1\n",
      "101: call/use argument#2\n",
      "102: call/use argument#3\n",
      "103: call/use argument#4\n",
      "104: call/use argument#5\n",
      "105: call/use argument#6\n",
      "106: call/use argument#7\n",
      "107: call/use argument#8\n",
      "108: call/use argument#9\n",
      "109: call/use argument#10\n",
      "110: call/use function#overflow_or_undefined\n",
      "111: call/use method_on_object\n",
      "112: call/use variable#overflow_or_undefined\n",
      "113: define variable@complex_target\n",
      "114: other operation: Attribute\n",
      "115: other operation: AugAssign\n",
      "116: other operation: BinOp\n",
      "117: other operation: BoolOp\n",
      "118: other operation: Compare\n",
      "119: other operation: Constant\n",
      "120: other operation: Expr\n",
      "121: other operation: For\n",
      "122: other operation: If\n",
      "123: other operation: IfExp\n",
      "124: other operation: List\n",
      "125: other operation: Raise\n",
      "126: other operation: Return\n",
      "127: other operation: Slice\n",
      "128: other operation: Subscript\n",
      "129: other operation: Tuple\n",
      "130: other operation: UnaryOp\n",
      "131: other operation: While\n",
      "132: other operation: arg\n",
      "133: <UNKNOWN>\n",
      "\n",
      "=== Trajectory of the first code snippet ===\n",
      "Step 1:\n",
      "  State dict: {'variable_count': 0, 'variable_overflow': False, 'function_count': 0, 'function_overflow': False, 'class_count': 0, 'class_overflow': False, 'import_count': 0, 'import_overflow': False, 'argument_count': 0, 'argument_overflow': False}\n",
      "  Action: {'lineno': 1, 'line_text': 'class Solution(object):', 'actions': ['define class#0', 'call/use variable#overflow_or_undefined']}\n",
      "  Encoded State Vec: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]\n",
      "  Encoded Action Vec: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]\n",
      "\n",
      "Step 2:\n",
      "  State dict: {'variable_count': 0, 'variable_overflow': False, 'function_count': 0, 'function_overflow': False, 'class_count': 1, 'class_overflow': False, 'import_count': 0, 'import_overflow': False, 'argument_count': 0, 'argument_overflow': False}\n",
      "  Action: {'lineno': 2, 'line_text': 'def findMedianSortedArrays(self, nums1, nums2):', 'actions': ['define function#0', 'define argument#0', 'define argument#1', 'define argument#2', 'other operation: arg', 'other operation: arg', 'other operation: arg']}\n",
      "  Encoded State Vec: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]\n",
      "  Encoded Action Vec: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0]\n",
      "\n",
      "Step 3:\n",
      "  State dict: {'variable_count': 0, 'variable_overflow': False, 'function_count': 1, 'function_overflow': False, 'class_count': 1, 'class_overflow': False, 'import_count': 0, 'import_overflow': False, 'argument_count': 3, 'argument_overflow': False}\n",
      "  Action: {'lineno': 3, 'line_text': 'merged = sorted(nums1+nums2)', 'actions': ['define variable#0', 'call/use variable#0', 'call/use function#overflow_or_undefined', 'call/use variable#overflow_or_undefined', 'other operation: BinOp', 'call/use variable#overflow_or_undefined', 'call/use variable#overflow_or_undefined']}\n",
      "  Encoded State Vec: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0]\n",
      "  Encoded Action Vec: [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]\n",
      "\n",
      "Step 4:\n",
      "  State dict: {'variable_count': 1, 'variable_overflow': False, 'function_count': 1, 'function_overflow': False, 'class_count': 1, 'class_overflow': False, 'import_count': 0, 'import_overflow': False, 'argument_count': 3, 'argument_overflow': False}\n",
      "  Action: {'lineno': 4, 'line_text': 'n = len(merged)', 'actions': ['define variable#1', 'call/use variable#1', 'call/use function#overflow_or_undefined', 'call/use variable#overflow_or_undefined', 'call/use variable#0']}\n",
      "  Encoded State Vec: [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0]\n",
      "  Encoded Action Vec: [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]\n",
      "\n",
      "Step 5:\n",
      "  State dict: {'variable_count': 2, 'variable_overflow': False, 'function_count': 1, 'function_overflow': False, 'class_count': 1, 'class_overflow': False, 'import_count': 0, 'import_overflow': False, 'argument_count': 3, 'argument_overflow': False}\n",
      "  Action: {'lineno': 5, 'line_text': 'if n%2==1:', 'actions': ['other operation: If', 'other operation: Compare', 'other operation: BinOp', 'other operation: Constant', 'call/use variable#1', 'other operation: Constant']}\n",
      "  Encoded State Vec: [1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0]\n",
      "  Encoded Action Vec: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]\n",
      "\n",
      "Step 6:\n",
      "  State dict: {'variable_count': 2, 'variable_overflow': False, 'function_count': 1, 'function_overflow': False, 'class_count': 1, 'class_overflow': False, 'import_count': 0, 'import_overflow': False, 'argument_count': 3, 'argument_overflow': False}\n",
      "  Action: {'lineno': 6, 'line_text': 'return merged[n//2]', 'actions': ['other operation: Return', 'other operation: Subscript', 'call/use variable#0', 'other operation: BinOp', 'call/use variable#1', 'other operation: Constant']}\n",
      "  Encoded State Vec: [1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0]\n",
      "  Encoded Action Vec: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0]\n",
      "\n",
      "Step 7:\n",
      "  State dict: {'variable_count': 2, 'variable_overflow': False, 'function_count': 1, 'function_overflow': False, 'class_count': 1, 'class_overflow': False, 'import_count': 0, 'import_overflow': False, 'argument_count': 3, 'argument_overflow': False}\n",
      "  Action: {'lineno': 8, 'line_text': 'return (merged[n//2]+merged[n//2-1])/2.0', 'actions': ['other operation: Return', 'other operation: BinOp', 'other operation: BinOp', 'other operation: Constant', 'other operation: Subscript', 'other operation: Subscript', 'call/use variable#0', 'other operation: BinOp', 'call/use variable#0', 'other operation: BinOp', 'call/use variable#1', 'other operation: Constant', 'other operation: BinOp', 'other operation: Constant', 'call/use variable#1', 'other operation: Constant']}\n",
      "  Encoded State Vec: [1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0]\n",
      "  Encoded Action Vec: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0]\n",
      "\n",
      "=== Trajectory for NEW Code ===\n",
      "Line 2, Code: import sys\n",
      "  State dict: {'variable_count': 0, 'variable_overflow': False, 'function_count': 0, 'function_overflow': False, 'class_count': 0, 'class_overflow': False, 'import_count': 0, 'import_overflow': False, 'argument_count': 0, 'argument_overflow': False}\n",
      "  Actions: ['define import#0']\n",
      "  Encoded State Vec: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]\n",
      "  Encoded Action Vec: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]\n",
      "\n",
      "Line 3, Code: def bar(x, y):\n",
      "  State dict: {'variable_count': 0, 'variable_overflow': False, 'function_count': 0, 'function_overflow': False, 'class_count': 0, 'class_overflow': False, 'import_count': 1, 'import_overflow': False, 'argument_count': 0, 'argument_overflow': False}\n",
      "  Actions: ['define function#0', 'define argument#0', 'define argument#1', 'other operation: arg', 'other operation: arg']\n",
      "  Encoded State Vec: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]\n",
      "  Encoded Action Vec: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0]\n",
      "\n",
      "Line 4, Code: return x * y\n",
      "  State dict: {'variable_count': 0, 'variable_overflow': False, 'function_count': 1, 'function_overflow': False, 'class_count': 0, 'class_overflow': False, 'import_count': 1, 'import_overflow': False, 'argument_count': 2, 'argument_overflow': False}\n",
      "  Actions: ['other operation: Return', 'other operation: BinOp', 'call/use variable#overflow_or_undefined', 'call/use variable#overflow_or_undefined']\n",
      "  Encoded State Vec: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0]\n",
      "  Encoded Action Vec: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0]\n",
      "\n",
      "Line 5, Code: z = bar(3,5)\n",
      "  State dict: {'variable_count': 0, 'variable_overflow': False, 'function_count': 1, 'function_overflow': False, 'class_count': 0, 'class_overflow': False, 'import_count': 1, 'import_overflow': False, 'argument_count': 2, 'argument_overflow': False}\n",
      "  Actions: ['define variable#0', 'call/use variable#0', 'call/use function#0', 'call/use variable#overflow_or_undefined', 'other operation: Constant', 'other operation: Constant']\n",
      "  Encoded State Vec: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0]\n",
      "  Encoded Action Vec: [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]\n",
      "\n",
      "Line 6, Code: unknown_function_call(123)\n",
      "  State dict: {'variable_count': 1, 'variable_overflow': False, 'function_count': 1, 'function_overflow': False, 'class_count': 0, 'class_overflow': False, 'import_count': 1, 'import_overflow': False, 'argument_count': 2, 'argument_overflow': False}\n",
      "  Actions: ['other operation: Expr', 'call/use function#overflow_or_undefined', 'call/use variable#overflow_or_undefined', 'other operation: Constant']\n",
      "  Encoded State Vec: [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0]\n",
      "  Encoded Action Vec: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]\n",
      "\n"
     ]
    }
   ],
   "source": [
    "data_test_01 = data[[\"code\"]].copy(deep=True)\n",
    "\n",
    "max_unseen = 10  # user-defined hyperparameter\n",
    "all_trajectories, encoded_dataset, vocab_builder = build_and_encode_dataset(\n",
    "    data_df=data_test_01,\n",
    "    max_unseen=max_unseen,\n",
    "    code_col=\"code\"\n",
    ")\n",
    "\n",
    "print(\"=== Built Vocabulary ===\")\n",
    "print(f\"State space dimension: {vocab_builder.state_dim}\")\n",
    "print(f\"Action space dimension: {vocab_builder.action_dim}\\n\")\n",
    "\n",
    "print(\"=== Action Space Vocabulary (index -> action) ===\")\n",
    "for idx, action_str in enumerate(vocab_builder.index2action):\n",
    "    print(f\"{idx}: {action_str}\")\n",
    "print()\n",
    "\n",
    "# Let's print the trajectory & encodings of the first code snippet\n",
    "print(\"=== Trajectory of the first code snippet ===\")\n",
    "first_traj = all_trajectories[0]\n",
    "first_encoded = encoded_dataset[0]\n",
    "\n",
    "for step_idx, ((state_dict), action_dict) in enumerate(first_traj):\n",
    "    print(f\"Step {step_idx+1}:\")\n",
    "    print(\"  State dict:\", state_dict)\n",
    "    print(\"  Action:\", action_dict)\n",
    "    print(\"  Encoded State Vec:\", first_encoded[step_idx][0])\n",
    "    print(\"  Encoded Action Vec:\", first_encoded[step_idx][1])\n",
    "    print()\n",
    "\n",
    "# Now let's encode a NEW code snippet that wasn't in the dataset\n",
    "new_code = \"\"\"\n",
    "import sys\n",
    "def bar(x, y):\n",
    "    return x * y\n",
    "z = bar(3,5)\n",
    "unknown_function_call(123)\n",
    "\"\"\"\n",
    "parser = CodeToTrajectory(max_unseen=max_unseen)  # same parser settings\n",
    "new_traj, new_encoded = encode_new_code(new_code, parser, vocab_builder)\n",
    "\n",
    "print(\"=== Trajectory for NEW Code ===\")\n",
    "for step_idx, (state_action) in enumerate(new_traj):\n",
    "    state_dict, action_dict = state_action\n",
    "    print(f\"Line {action_dict['lineno']}, Code: {action_dict['line_text']}\")\n",
    "    print(\"  State dict:\", state_dict)\n",
    "    print(\"  Actions:\", action_dict[\"actions\"])\n",
    "    print(\"  Encoded State Vec:\", new_encoded[step_idx][0])\n",
    "    print(\"  Encoded Action Vec:\", new_encoded[step_idx][1])\n",
    "    print()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### IRL test 01"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "                                          trajectory  \\\n",
      "0  [({'variable_count': 0, 'variable_overflow': F...   \n",
      "1  [({'variable_count': 0, 'variable_overflow': F...   \n",
      "2  [({'variable_count': 0, 'variable_overflow': F...   \n",
      "3  [({'variable_count': 0, 'variable_overflow': F...   \n",
      "4  [({'variable_count': 0, 'variable_overflow': F...   \n",
      "\n",
      "                                       one_hot_state  \\\n",
      "0  [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...   \n",
      "1  [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...   \n",
      "2  [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...   \n",
      "3  [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...   \n",
      "4  [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...   \n",
      "\n",
      "                                      one_hot_action  \n",
      "0  [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...  \n",
      "1  [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...  \n",
      "2  [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...  \n",
      "3  [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...  \n",
      "4  [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...  \n",
      "56\n"
     ]
    }
   ],
   "source": [
    "import pandas as pd\n",
    "\n",
    "data_temp = data[[\"code\"]].copy(deep=True)\n",
    "\n",
    "max_unseen = 10  # user-defined hyperparameter\n",
    "all_trajectories, encoded_dataset, vocab_builder = build_and_encode_dataset(\n",
    "    data_df=data_temp,\n",
    "    max_unseen=max_unseen,\n",
    "    code_col=\"code\"\n",
    ")\n",
    "\n",
    "def convert_to_dataframe(all_trajectories, encoded_dataset):\n",
    "    data = []\n",
    "\n",
    "    for traj, encoded_traj in zip(all_trajectories, encoded_dataset):\n",
    "        trajectory = traj\n",
    "        one_hot_state = [step[0] for step in encoded_traj]\n",
    "        one_hot_action = [step[1] for step in encoded_traj]\n",
    "\n",
    "        if len(trajectory) == 0:\n",
    "            continue\n",
    "        data.append({\n",
    "            \"trajectory\": trajectory,\n",
    "            \"one_hot_state\": one_hot_state,\n",
    "            \"one_hot_action\": one_hot_action\n",
    "        })\n",
    "\n",
    "    return pd.DataFrame(data)\n",
    "\n",
    "# Example usage\n",
    "new_data = convert_to_dataframe(all_trajectories, encoded_dataset)\n",
    "print(new_data.head())\n",
    "print(len(new_data))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/20, Loss: 0.999450609087944\n",
      "Epoch 2/20, Loss: 0.994601845741272\n",
      "Epoch 3/20, Loss: 0.9872276782989502\n",
      "Epoch 4/20, Loss: 0.9410593509674072\n",
      "Epoch 5/20, Loss: 0.9086482226848602\n",
      "Epoch 6/20, Loss: 0.7923699021339417\n",
      "Epoch 7/20, Loss: 0.6169751510024071\n",
      "Epoch 8/20, Loss: 0.3312264746055007\n",
      "Epoch 9/20, Loss: 0.2363501493819058\n",
      "Epoch 10/20, Loss: 0.18268271419219673\n",
      "Epoch 11/20, Loss: 0.08537940867245197\n",
      "Epoch 12/20, Loss: 0.006267095799557865\n",
      "Epoch 13/20, Loss: 0.012405365181621164\n",
      "Epoch 14/20, Loss: 1.2970799338072538\n",
      "Epoch 15/20, Loss: 0.05587441846728325\n",
      "Epoch 16/20, Loss: 0.12009201943874359\n",
      "Epoch 17/20, Loss: 0.0005212992546148598\n",
      "Epoch 18/20, Loss: 0.01028733840212226\n",
      "Epoch 19/20, Loss: 0.004989218898117542\n",
      "Epoch 20/20, Loss: 0.004389132664073259\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "from torch.utils.data import DataLoader\n",
    "\n",
    "# Device configuration\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "# Hyperparameters\n",
    "STATE_SIZE = len(new_data['one_hot_state'][0][0])  # Size of the state space\n",
    "ACTION_SIZE = len(new_data['one_hot_action'][0][0])  # Size of the action space\n",
    "EMBEDDING_SIZE = 128\n",
    "HIDDEN_SIZE = 256\n",
    "BATCH_SIZE = 16\n",
    "EPOCHS = 20\n",
    "LEARNING_RATE = 1e-3\n",
    "MARGIN = 1.0\n",
    "NUM_NEGATIVE_SAMPLES = 10  # Number of negative samples per positive action\n",
    "OUTPUT_BOUND = None  # Set to (min_value, max_value) or None (-1.0, 1.0)\n",
    "REDUCED_DIM = None  # Dimension after reducing the one-hot vectors\n",
    "\n",
    "# Dataset Preparation\n",
    "class TrajectoryDataset(torch.utils.data.Dataset):\n",
    "    def __init__(self, data):\n",
    "        self.state_sequences = []\n",
    "        self.action_sequences = []\n",
    "        self.action_space = self.get_action_space(data)\n",
    "        self.action_space_size = len(data['one_hot_action'][0][0])\n",
    "        for idx, row in data.iterrows():\n",
    "            states = torch.tensor(row['one_hot_state'], dtype=torch.float)\n",
    "            actions = torch.tensor(row['one_hot_action'], dtype=torch.float)\n",
    "            self.state_sequences.append(states)\n",
    "            self.action_sequences.append(actions)\n",
    "            if len(actions) == 0:\n",
    "                print(idx)\n",
    "                print(\"debug\")\n",
    "        # Precompute the projection matrix for dimension reduction\n",
    "        if REDUCED_DIM == -1:\n",
    "            self.projection_matrix_state = None\n",
    "            self.projection_matrix_action = None\n",
    "            self.reduce_method = 'binary'\n",
    "            self.state_reduced_dimention = int(np.ceil(np.log2(STATE_SIZE)))\n",
    "            self.action_reduced_dimention = int(np.ceil(np.log2(ACTION_SIZE)))\n",
    "        elif REDUCED_DIM is None:\n",
    "            self.projection_matrix_state = None\n",
    "            self.projection_matrix_action = None\n",
    "            self.reduce_method = 'original'\n",
    "            self.state_reduced_dimention = STATE_SIZE\n",
    "            self.action_reduced_dimention = ACTION_SIZE\n",
    "        else:\n",
    "            self.projection_matrix_state = torch.randn(STATE_SIZE, REDUCED_DIM)\n",
    "            self.projection_matrix_action = torch.randn(ACTION_SIZE, REDUCED_DIM)\n",
    "            self.reduce_method = 'linear'\n",
    "            self.state_reduced_dimention = REDUCED_DIM\n",
    "            self.action_reduced_dimention = REDUCED_DIM\n",
    "\n",
    "    def get_action_space(self, data):\n",
    "        # Extract unique actions from the dataset\n",
    "        all_actions = []\n",
    "        for actions in data['one_hot_action']:\n",
    "            all_actions.extend(actions)\n",
    "        unique_actions = np.unique(np.array(all_actions), axis=0)\n",
    "        return torch.tensor(unique_actions, dtype=torch.float)\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.state_sequences)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        states = self.state_sequences[idx]\n",
    "        actions = self.action_sequences[idx]\n",
    "        negative_actions = self.sample_negative_actions(actions)\n",
    "        # print(negative_actions)\n",
    "        # raise Exception(\"Stop here\")\n",
    "        # Reduce dimension of states and actions\n",
    "        states = self.reduce_dimension(states, self.projection_matrix_state)\n",
    "        actions = self.reduce_dimension(actions, self.projection_matrix_action)\n",
    "        negative_actions = self.reduce_dimension(negative_actions.view(-1, ACTION_SIZE), self.projection_matrix_action)\n",
    "        negative_actions = negative_actions.view(-1, NUM_NEGATIVE_SAMPLES, self.action_reduced_dimention)\n",
    "        return states, actions, negative_actions\n",
    "\n",
    "    def sample_negative_actions(self, actions):\n",
    "        negative_actions_list = []\n",
    "        for action in actions:\n",
    "            neg_actions = self.generate_similar_negative_actions(action, NUM_NEGATIVE_SAMPLES)\n",
    "            # random_neg_actions = self.generate_negative_actions(action, NUM_NEGATIVE_SAMPLES)\n",
    "            negative_actions_list.append(neg_actions)\n",
    "            # negative_actions_list.append(random_neg_actions)\n",
    "        if len(negative_actions_list) == 0:\n",
    "            print(actions)\n",
    "        negative_actions = torch.stack(negative_actions_list)  # (seq_len, NUM_NEGATIVE_SAMPLES, action_size)\n",
    "        return negative_actions\n",
    "\n",
    "    def generate_similar_negative_actions(self, action, num_samples):\n",
    "        num_ones = action.sum().int().item()\n",
    "        candidate_actions = self.action_space[(self.action_space.sum(dim=1).int() == num_ones) & (~torch.all(self.action_space == action, dim=1))]\n",
    "        if len(candidate_actions) == 0:\n",
    "            # If no candidate found, relax the condition slightly\n",
    "            candidate_actions = self.action_space[(self.action_space.sum(dim=1).int() >= num_ones - 5) & (self.action_space.sum(dim=1).int() <= num_ones + 5)]\n",
    "            candidate_actions = candidate_actions[~torch.all(candidate_actions == action, dim=1)]\n",
    "        if len(candidate_actions) == 0:\n",
    "            # If still no candidate, use all actions except the current one\n",
    "            candidate_actions = self.action_space[~torch.all(self.action_space == action, dim=1)]\n",
    "        # Sample num_samples negative actions\n",
    "        indices = torch.randint(0, len(candidate_actions), (num_samples,))\n",
    "        neg_actions = candidate_actions[indices]\n",
    "        # print(neg_actions)\n",
    "        # raise Exception(\"Stop here\")\n",
    "        return neg_actions  # (num_samples, action_size)\n",
    "    \n",
    "    def generate_negative_actions(self, action, negative_sample_size=1):\n",
    "        negative_actions = []\n",
    "        num_ones = int(action.sum().item())\n",
    "        action_size = self.action_space_size  # Original action size\n",
    "        for _ in range(negative_sample_size):\n",
    "            # Generate a random action with the same number of ones\n",
    "            neg_action = self.generate_random_action_with_num_ones(action_size, num_ones)\n",
    "            # Ensure it's not the same as the positive action\n",
    "            while torch.equal(neg_action, action):\n",
    "                neg_action = self.generate_random_action_with_num_ones(action_size, num_ones)\n",
    "            negative_actions.append(neg_action)\n",
    "        return torch.stack(negative_actions)\n",
    "    \n",
    "    def generate_random_action_with_num_ones(self, action_size, num_ones):\n",
    "        indices = torch.randperm(action_size)[:num_ones]\n",
    "        neg_action = torch.zeros(action_size)\n",
    "        neg_action[indices] = 1.0\n",
    "        return neg_action\n",
    "    \n",
    "    def reduce_dimension(self, one_hot_vectors, projection_matrix):\n",
    "        method = self.reduce_method\n",
    "        if method == 'binary':\n",
    "            # print(one_hot_vectors)\n",
    "            reduced_vectors = self.binary_encode(one_hot_vectors)\n",
    "        elif method == 'linear':\n",
    "            reduced_vectors = torch.matmul(one_hot_vectors, projection_matrix)\n",
    "        elif method == 'original':\n",
    "            reduced_vectors = one_hot_vectors\n",
    "        return reduced_vectors\n",
    "    \n",
    "    def binary_encode(self, array):\n",
    "        # Convert the one-hot array to indices\n",
    "        indices = np.argmax(array, axis=1)\n",
    "        # Calculate the number of bits needed\n",
    "        num_bits = int(np.ceil(np.log2(np.max(indices) + 1)))\n",
    "        # Convert indices to binary representation\n",
    "        binary_array = ((indices[:, None] & (1 << np.arange(num_bits))) > 0).astype(int)\n",
    "        binary_array = torch.tensor(binary_array, dtype=torch.float)\n",
    "        return binary_array\n",
    "\n",
    "dataset = TrajectoryDataset(new_data)\n",
    "reduced_dimention = (dataset.action_reduced_dimention, dataset.state_reduced_dimention)\n",
    "dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=lambda x: x)\n",
    "\n",
    "# Reward Model Definition with LSTM\n",
    "class RewardModel(nn.Module):\n",
    "    def __init__(self, reduced_dim, embedding_size, hidden_size, output_bound=None):\n",
    "        super(RewardModel, self).__init__()\n",
    "        self.state_embedding = nn.Linear(reduced_dim[1], embedding_size)\n",
    "        self.action_embedding = nn.Linear(reduced_dim[0], embedding_size)\n",
    "        self.lstm = nn.LSTM(input_size=embedding_size * 2, hidden_size=hidden_size, batch_first=True)\n",
    "        self.fc = nn.Sequential(\n",
    "            nn.Linear(hidden_size, embedding_size),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(embedding_size, 1)\n",
    "        )\n",
    "        self.output_bound = output_bound\n",
    "\n",
    "    def forward(self, states, actions):\n",
    "        # states and actions are sequences\n",
    "        state_emb = self.state_embedding(states)  # (batch_size, seq_len, embedding_size)\n",
    "        action_emb = self.action_embedding(actions)  # (batch_size, seq_len, embedding_size)\n",
    "        x = torch.cat([state_emb, action_emb], dim=2)  # (batch_size, seq_len, embedding_size * 2)\n",
    "        lstm_out, _ = self.lstm(x)  # (batch_size, seq_len, hidden_size)\n",
    "        # Use the last output of LSTM for reward prediction\n",
    "        reward = self.fc(lstm_out[:, -1, :])  # (batch_size, 1)\n",
    "        if self.output_bound is not None:\n",
    "            min_val, max_val = self.output_bound\n",
    "            reward = torch.clamp(reward, min=min_val, max=max_val)\n",
    "        return reward.squeeze()\n",
    "\n",
    "reward_model = RewardModel(reduced_dimention, EMBEDDING_SIZE, HIDDEN_SIZE, output_bound=OUTPUT_BOUND).to(device)\n",
    "\n",
    "# Loss Function\n",
    "criterion = nn.MarginRankingLoss(margin=MARGIN)\n",
    "\n",
    "# Optimizer\n",
    "optimizer = optim.Adam(reward_model.parameters(), lr=LEARNING_RATE)\n",
    "\n",
    "# Training Loop\n",
    "for epoch in range(EPOCHS):\n",
    "    total_loss = 0\n",
    "    for batch in dataloader:\n",
    "        optimizer.zero_grad()\n",
    "        states_batch, actions_batch, negative_actions_batch = zip(*batch)\n",
    "        # states_batch: list of tensors, each of shape (seq_len, reduced_dim)\n",
    "        # actions_batch: list of tensors, each of shape (seq_len, reduced_dim)\n",
    "        # negative_actions_batch: list of tensors, each of shape (seq_len, NUM_NEGATIVE_SAMPLES, reduced_dim)\n",
    "\n",
    "        # Pad sequences\n",
    "        states_padded = nn.utils.rnn.pad_sequence(states_batch, batch_first=True).to(device)\n",
    "        actions_padded = nn.utils.rnn.pad_sequence(actions_batch, batch_first=True).to(device)\n",
    "\n",
    "        # For negative actions, we need to pad the sequences and handle the extra dimensions\n",
    "        max_seq_len = states_padded.size(1)\n",
    "        negative_actions_padded = []\n",
    "        for neg_actions in negative_actions_batch:\n",
    "            # neg_actions: (seq_len, NUM_NEGATIVE_SAMPLES, reduced_dim)\n",
    "            pad_size = max_seq_len - neg_actions.size(0)\n",
    "            if pad_size > 0:\n",
    "                pad_tensor = torch.zeros(pad_size, NUM_NEGATIVE_SAMPLES, reduced_dimention[0])\n",
    "                neg_actions_padded_seq = torch.cat([neg_actions, pad_tensor], dim=0)\n",
    "            else:\n",
    "                neg_actions_padded_seq = neg_actions\n",
    "            negative_actions_padded.append(neg_actions_padded_seq)\n",
    "        negative_actions_padded = torch.stack(negative_actions_padded).to(device)  # (batch_size, max_seq_len, NUM_NEGATIVE_SAMPLES, reduced_dim)\n",
    "\n",
    "        # Compute rewards for positive examples\n",
    "        # print(states_padded.shape)\n",
    "        positive_rewards = reward_model(states_padded, actions_padded)  # (batch_size,)\n",
    "\n",
    "        # Compute rewards for negative examples\n",
    "        # We need to compute the reward for each negative action\n",
    "        negative_rewards_list = []\n",
    "        for i in range(NUM_NEGATIVE_SAMPLES):\n",
    "            negative_actions_i = negative_actions_padded[:, :, i, :]  # (batch_size, max_seq_len, reduced_dim)\n",
    "            negative_reward = reward_model(states_padded, negative_actions_i)  # (batch_size,)\n",
    "            negative_rewards_list.append(negative_reward)\n",
    "        # Stack negative rewards: (NUM_NEGATIVE_SAMPLES, batch_size)\n",
    "        negative_rewards = torch.stack(negative_rewards_list)  # (NUM_NEGATIVE_SAMPLES, batch_size)\n",
    "\n",
    "        # Compute the loss for each negative sample\n",
    "        # Labels for MarginRankingLoss: 1 indicates positive_rewards should be larger than negative_rewards\n",
    "        target = torch.ones(positive_rewards.size()).unsqueeze(0).repeat(NUM_NEGATIVE_SAMPLES, 1).to(device)  # (NUM_NEGATIVE_SAMPLES, batch_size)\n",
    "        # Compute the margin ranking loss for each negative sample\n",
    "        loss = criterion(positive_rewards.unsqueeze(0).repeat(NUM_NEGATIVE_SAMPLES, 1), negative_rewards, target)\n",
    "\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        total_loss += loss.item()\n",
    "\n",
    "    print(f\"Epoch {epoch+1}/{EPOCHS}, Loss: {total_loss/len(dataloader)}\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[10.669058799743652, 11.363011360168457, 11.085075378417969, 10.876649856567383, 11.196233749389648, 11.000472068786621, 11.041563034057617, 11.085075378417969, 10.980281829833984, 11.056007385253906, 11.179656982421875, 11.246696472167969, 11.214632034301758, 11.26550579071045, 11.338752746582031, 11.141716003417969, 9.901939392089844, 8.277589797973633, 11.097898483276367, 9.901939392089844, 10.760971069335938, 11.002238273620605, 11.035372734069824, 11.124422073364258, 11.185617446899414, 8.53988265991211, 11.002238273620605, 11.002238273620605, 11.002238273620605, 7.284907817840576, 11.144830703735352, 9.638358116149902, 10.360899925231934, 11.116124153137207, 10.825331687927246, 10.88405990600586, 11.214656829833984, 11.214656829833984, 11.120922088623047, 11.214656829833984, 11.122477531433105, 10.652515411376953, 10.569616317749023, 11.214656829833984, 11.214656829833984, 10.466533660888672, 10.977919578552246, 11.11998176574707, 10.982305526733398, 10.979074478149414, 11.10554313659668, 11.148614883422852, 10.99344539642334, 11.049837112426758, 11.022422790527344, 11.148614883422852, 11.041313171386719, 10.965509414672852, 10.999311447143555, 11.148614883422852, 10.310661315917969, 11.377128601074219, 10.418590545654297, 10.97688102722168, 9.080301284790039, 11.303327560424805, 11.295997619628906, 10.414755821228027, 9.630558967590332, 11.433084487915039, 10.998490333557129, 11.329764366149902, 11.127067565917969, 10.370113372802734, 11.303327560424805, 11.517419815063477, 11.495654106140137, 11.092903137207031, 11.297577857971191, 10.901498794555664, 10.878829956054688, 10.680139541625977, 10.853461265563965, 10.760965347290039, 11.408719062805176, 11.209044456481934, 10.486696243286133, 11.232939720153809, 10.895585060119629, 11.495136260986328, 10.446857452392578, 11.189756393432617, 10.645304679870605, 11.303327560424805, 11.303327560424805, 10.645304679870605, 11.375886917114258, 11.329801559448242, 10.692787170410156, 11.363574981689453, 11.494730949401855, 11.297577857971191, 10.04403305053711, 10.907233238220215, 0.6161327362060547, -0.6354856491088867, 11.217903137207031, 10.370113372802734, 10.749000549316406, 11.329801559448242, 10.82927131652832, 11.507745742797852, 11.214905738830566, 10.534235954284668, 10.5677490234375, 10.733465194702148, 10.446857452392578, 11.188777923583984, 10.956761360168457]\n",
      "Accuracy: 0.5126, ROC AUC: 0.5274, Precision: 0.5046, Recall: 0.9322, F1-score: 0.6548\n",
      "Average Positive Score: 10.5634\n",
      "Average Negative Score: 10.8253\n",
      "Absolute Difference: 0.2618\n"
     ]
    }
   ],
   "source": [
    "def evaluate_model_with_avg_scores(model, test_data, vocab_builder, max_unseen, threshold=0, new_version=False):\n",
    "    model = model.to(device)\n",
    "    model.eval()\n",
    "    true_labels = []\n",
    "    predicted_rewards = []\n",
    "    positive_scores = []\n",
    "    negative_scores = []\n",
    "    parser = CodeToTrajectory(max_unseen=max_unseen)\n",
    "    \n",
    "    for idx, row in test_data.iterrows():\n",
    "        code = row['code']\n",
    "        label = row['labels']\n",
    "        # Process the code to get states and actions\n",
    "        traj_i, encoded_i = encode_new_code(code, parser, vocab_builder)\n",
    "        states = torch.tensor([[step_i[0] for step_i in encoded_i]], dtype=torch.float)\n",
    "        actions = torch.tensor([[step_i[1] for step_i in encoded_i]], dtype=torch.float)\n",
    "        # Reduce dimension\n",
    "        if new_version:\n",
    "            try:\n",
    "                states = dataset.reduce_dimension(states, dataset.projection_matrix_state)\n",
    "            except Exception as e:\n",
    "                print(code)\n",
    "                print(label)\n",
    "                print(e)\n",
    "                raise Exception(\"Stop here\")\n",
    "            actions = dataset.reduce_dimension(actions, dataset.projection_matrix_action)\n",
    "        states = states.to(device)\n",
    "        actions = actions.to(device)\n",
    "        with torch.no_grad():\n",
    "            rewards = model(states, actions)\n",
    "        true_labels.append(label)\n",
    "        predicted_rewards.append(rewards.item())\n",
    "\n",
    "        # Collect scores for positive and negative samples\n",
    "        if label == 1:\n",
    "            positive_scores.append(rewards.item())\n",
    "        elif label == 0:\n",
    "            negative_scores.append(rewards.item())\n",
    "\n",
    "    # Compute evaluation metrics\n",
    "    from sklearn.metrics import accuracy_score, roc_auc_score, precision_recall_fscore_support\n",
    "    predicted_labels = [1 if r >= threshold else 0 for r in predicted_rewards]\n",
    "    print(predicted_rewards)\n",
    "    accuracy = accuracy_score(true_labels, predicted_labels)\n",
    "    roc_auc = roc_auc_score(true_labels, predicted_rewards)\n",
    "    precision, recall, f1, _ = precision_recall_fscore_support(true_labels, predicted_labels, average='binary')\n",
    "\n",
    "    # Calculate average scores\n",
    "    avg_positive_score = sum(positive_scores) / len(positive_scores) if positive_scores else 0\n",
    "    avg_negative_score = sum(negative_scores) / len(negative_scores) if negative_scores else 0\n",
    "    abs_difference = abs(avg_positive_score - avg_negative_score)\n",
    "\n",
    "    # Print results\n",
    "    print(f\"Accuracy: {accuracy:.4f}, ROC AUC: {roc_auc:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1-score: {f1:.4f}\")\n",
    "    print(f\"Average Positive Score: {avg_positive_score:.4f}\")\n",
    "    print(f\"Average Negative Score: {avg_negative_score:.4f}\")\n",
    "    print(f\"Absolute Difference: {abs_difference:.4f}\")\n",
    "\n",
    "evaluate_model_with_avg_scores(reward_model, test_data, vocab_builder, max_unseen, threshold=10, new_version = False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### IRL test 02"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from sklearn.model_selection import train_test_split\n",
    "from tqdm import tqdm\n",
    "\n",
    "# Set random seeds for reproducibility\n",
    "torch.manual_seed(42)\n",
    "np.random.seed(42)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_df = data[[\"code\"]].copy(deep=True)\n",
    "eval_df = test_data.copy(deep=True)\n",
    "\n",
    "max_unseen = 10  # user-defined hyperparameter\n",
    "all_trajectories, encoded_dataset, vocab_builder = build_and_encode_dataset(\n",
    "    data_df=train_df,\n",
    "    max_unseen=max_unseen,\n",
    "    code_col=\"code\"\n",
    ")\n",
    "# eval_trajectories, eval_encoded_dataset = build_and_encode_dataset(\n",
    "#     data_df=eval_df,\n",
    "#     max_unseen=max_unseen,\n",
    "#     code_col=\"code\"\n",
    "# )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "class SimpleRewardModel(nn.Module):\n",
    "    def __init__(self, state_dim, action_dim, hidden_dim=128):\n",
    "        super(SimpleRewardModel, self).__init__()\n",
    "        self.fc1 = nn.Linear(state_dim + action_dim, hidden_dim)\n",
    "        self.relu = nn.ReLU()\n",
    "        self.fc2 = nn.Linear(hidden_dim, 1)  # Output a single reward score\n",
    "\n",
    "    def forward(self, state, action):\n",
    "        x = torch.cat([state, action], dim=1)  # Concatenate state and action\n",
    "        x = self.relu(self.fc1(x))\n",
    "        reward = self.fc2(x)\n",
    "        return reward\n",
    "\n",
    "class TransformerRewardModel(nn.Module):\n",
    "    def __init__(self, state_dim, action_dim, embed_dim=256, num_heads=8, num_layers=4, dim_feedforward=512):\n",
    "        super(TransformerRewardModel, self).__init__()\n",
    "        self.input_dim = state_dim + action_dim\n",
    "        self.embed = nn.Linear(self.input_dim, embed_dim)\n",
    "        encoder_layer = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads, dim_feedforward=dim_feedforward)\n",
    "        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)\n",
    "        self.fc_out = nn.Linear(embed_dim, 1)  # Output a single reward score\n",
    "\n",
    "    def forward(self, state_seq, action_seq):\n",
    "        # state_seq and action_seq: (batch_size, seq_length, state_dim/action_dim)\n",
    "        x = torch.cat([state_seq, action_seq], dim=-1)  # (batch, seq, input_dim)\n",
    "        x = self.embed(x)  # (batch, seq, embed_dim)\n",
    "        x = x.permute(1, 0, 2)  # Transformer expects (seq, batch, embed)\n",
    "        x = self.transformer(x)  # (seq, batch, embed)\n",
    "        x = x.permute(1, 0, 2)  # (batch, seq, embed)\n",
    "        # Aggregate rewards over the sequence, e.g., sum or mean\n",
    "        x = x.mean(dim=1)  # (batch, embed)\n",
    "        reward = self.fc_out(x)  # (batch, 1)\n",
    "        return reward"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "class TrajectoryDataset(Dataset):\n",
    "    def __init__(self, encoded_dataset):\n",
    "        \"\"\"\n",
    "        encoded_dataset: List of trajectories, each trajectory is a list of (state_vec, action_vec)\n",
    "        \"\"\"\n",
    "        self.trajectories = encoded_dataset\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.trajectories)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        traj = self.trajectories[idx]\n",
    "        states = torch.tensor([step[0] for step in traj], dtype=torch.float32)  # (seq_length, state_dim)\n",
    "        actions = torch.tensor([step[1] for step in traj], dtype=torch.float32)  # (seq_length, action_dim)\n",
    "        return states, actions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "class MaxEntIRL:\n",
    "    def __init__(self, reward_model, state_dim, action_dim, trajectories, lr=1e-3):\n",
    "        self.reward_model = reward_model\n",
    "        self.state_dim = state_dim\n",
    "        self.action_dim = action_dim\n",
    "        self.trajectories = trajectories\n",
    "        self.optimizer = optim.Adam(self.reward_model.parameters(), lr=lr)\n",
    "\n",
    "    def compute_feature_expectations(self):\n",
    "        feat_exp = torch.zeros(self.state_dim + self.action_dim)\n",
    "        for traj in self.trajectories:\n",
    "            for step in traj:\n",
    "                state, action = step\n",
    "                feat_exp += torch.cat([state, action])\n",
    "        feat_exp /= len(self.trajectories)\n",
    "        return feat_exp\n",
    "\n",
    "    def train(self, num_epochs=100):\n",
    "        feat_exp = self.compute_feature_expectations()\n",
    "        for epoch in tqdm(range(num_epochs), desc=\"Training IRL\"):\n",
    "            self.optimizer.zero_grad()\n",
    "            # Compute expected feature counts under current reward model\n",
    "            expected_feat = torch.zeros(self.state_dim + self.action_dim)\n",
    "            for traj in self.trajectories:\n",
    "                for step in traj:\n",
    "                    state, action = step\n",
    "                    expected_feat += torch.cat([state, action])\n",
    "            expected_feat /= len(self.trajectories)\n",
    "            # Loss: difference between expert feat and model feat\n",
    "            loss = torch.norm(feat_exp - expected_feat, p=2)\n",
    "            loss.backward()\n",
    "            self.optimizer.step()\n",
    "            if epoch % 10 == 0:\n",
    "                print(f\"Epoch {epoch}: Loss = {loss.item():.4f}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "d:\\anaconda3\\envs\\llm\\lib\\site-packages\\torch\\nn\\modules\\transformer.py:306: UserWarning: enable_nested_tensor is True, but self.use_nested_tensor is False because encoder_layer.self_attn.batch_first was not True(use batch_first for better inference performance)\n",
      "  warnings.warn(f\"enable_nested_tensor is True, but self.use_nested_tensor is False because {why_not_sparsity_fast_path}\")\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training Simple Reward Model...\n"
     ]
    },
    {
     "ename": "TypeError",
     "evalue": "expected Tensor as element 0 in argument 0, but got dict",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mTypeError\u001b[0m                                 Traceback (most recent call last)",
      "Cell \u001b[1;32mIn[16], line 18\u001b[0m\n\u001b[0;32m     16\u001b[0m \u001b[38;5;66;03m# Train the simple reward model\u001b[39;00m\n\u001b[0;32m     17\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mTraining Simple Reward Model...\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m---> 18\u001b[0m \u001b[43mirl_simple\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrain\u001b[49m\u001b[43m(\u001b[49m\u001b[43mnum_epochs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m100\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[0;32m     20\u001b[0m \u001b[38;5;66;03m# Train the transformer reward model\u001b[39;00m\n\u001b[0;32m     21\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mTraining Transformer Reward Model...\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n",
      "Cell \u001b[1;32mIn[15], line 19\u001b[0m, in \u001b[0;36mMaxEntIRL.train\u001b[1;34m(self, num_epochs)\u001b[0m\n\u001b[0;32m     18\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mtrain\u001b[39m(\u001b[38;5;28mself\u001b[39m, num_epochs\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m100\u001b[39m):\n\u001b[1;32m---> 19\u001b[0m     feat_exp \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcompute_feature_expectations\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m     20\u001b[0m     \u001b[38;5;28;01mfor\u001b[39;00m epoch \u001b[38;5;129;01min\u001b[39;00m tqdm(\u001b[38;5;28mrange\u001b[39m(num_epochs), desc\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mTraining IRL\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n\u001b[0;32m     21\u001b[0m         \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39moptimizer\u001b[38;5;241m.\u001b[39mzero_grad()\n",
      "Cell \u001b[1;32mIn[15], line 14\u001b[0m, in \u001b[0;36mMaxEntIRL.compute_feature_expectations\u001b[1;34m(self)\u001b[0m\n\u001b[0;32m     12\u001b[0m     \u001b[38;5;28;01mfor\u001b[39;00m step \u001b[38;5;129;01min\u001b[39;00m traj:\n\u001b[0;32m     13\u001b[0m         state, action \u001b[38;5;241m=\u001b[39m step\n\u001b[1;32m---> 14\u001b[0m         feat_exp \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcat\u001b[49m\u001b[43m(\u001b[49m\u001b[43m[\u001b[49m\u001b[43mstate\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43maction\u001b[49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m     15\u001b[0m feat_exp \u001b[38;5;241m/\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;28mlen\u001b[39m(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtrajectories)\n\u001b[0;32m     16\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m feat_exp\n",
      "\u001b[1;31mTypeError\u001b[0m: expected Tensor as element 0 in argument 0, but got dict"
     ]
    }
   ],
   "source": [
    "# Initialize datasets and dataloaders\n",
    "train_dataset = TrajectoryDataset(encoded_dataset)\n",
    "train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=lambda x: x)\n",
    "\n",
    "# Initialize the simple and transformer reward models\n",
    "state_dim = vocab_builder.state_dim  # 55\n",
    "action_dim = vocab_builder.action_dim  # 134\n",
    "\n",
    "simple_reward_model = SimpleRewardModel(state_dim, action_dim)\n",
    "transformer_reward_model = TransformerRewardModel(state_dim, action_dim)\n",
    "\n",
    "# Initialize MaxEnt IRL for both models\n",
    "irl_simple = MaxEntIRL(simple_reward_model, state_dim, action_dim, all_trajectories, lr=1e-3)\n",
    "irl_transformer = MaxEntIRL(transformer_reward_model, state_dim, action_dim, all_trajectories, lr=1e-3)\n",
    "\n",
    "# Train the simple reward model\n",
    "print(\"Training Simple Reward Model...\")\n",
    "irl_simple.train(num_epochs=100)\n",
    "\n",
    "# Train the transformer reward model\n",
    "print(\"Training Transformer Reward Model...\")\n",
    "irl_transformer.train(num_epochs=100)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[({'variable_count': 0, 'variable_overflow': False, 'function_count': 0, 'function_overflow': False, 'class_count': 0, 'class_overflow': False, 'import_count': 0, 'import_overflow': False, 'argument_count': 0, 'argument_overflow': False}, {'lineno': 1, 'line_text': 'class Solution(object):', 'actions': ['define class#0', 'call/use variable#overflow_or_undefined']}), ({'variable_count': 0, 'variable_overflow': False, 'function_count': 0, 'function_overflow': False, 'class_count': 1, 'class_overflow': False, 'import_count': 0, 'import_overflow': False, 'argument_count': 0, 'argument_overflow': False}, {'lineno': 2, 'line_text': 'def findMedianSortedArrays(self, nums1, nums2):', 'actions': ['define function#0', 'define argument#0', 'define argument#1', 'define argument#2', 'other operation: arg', 'other operation: arg', 'other operation: arg']}), ({'variable_count': 0, 'variable_overflow': False, 'function_count': 1, 'function_overflow': False, 'class_count': 1, 'class_overflow': False, 'import_count': 0, 'import_overflow': False, 'argument_count': 3, 'argument_overflow': False}, {'lineno': 3, 'line_text': 'merged = sorted(nums1+nums2)', 'actions': ['define variable#0', 'call/use variable#0', 'call/use function#overflow_or_undefined', 'call/use variable#overflow_or_undefined', 'other operation: BinOp', 'call/use variable#overflow_or_undefined', 'call/use variable#overflow_or_undefined']}), ({'variable_count': 1, 'variable_overflow': False, 'function_count': 1, 'function_overflow': False, 'class_count': 1, 'class_overflow': False, 'import_count': 0, 'import_overflow': False, 'argument_count': 3, 'argument_overflow': False}, {'lineno': 4, 'line_text': 'n = len(merged)', 'actions': ['define variable#1', 'call/use variable#1', 'call/use function#overflow_or_undefined', 'call/use variable#overflow_or_undefined', 'call/use variable#0']}), ({'variable_count': 2, 'variable_overflow': False, 'function_count': 1, 'function_overflow': False, 'class_count': 1, 'class_overflow': False, 'import_count': 0, 'import_overflow': False, 'argument_count': 3, 'argument_overflow': False}, {'lineno': 5, 'line_text': 'if n%2==1:', 'actions': ['other operation: If', 'other operation: Compare', 'other operation: BinOp', 'other operation: Constant', 'call/use variable#1', 'other operation: Constant']}), ({'variable_count': 2, 'variable_overflow': False, 'function_count': 1, 'function_overflow': False, 'class_count': 1, 'class_overflow': False, 'import_count': 0, 'import_overflow': False, 'argument_count': 3, 'argument_overflow': False}, {'lineno': 6, 'line_text': 'return merged[n//2]', 'actions': ['other operation: Return', 'other operation: Subscript', 'call/use variable#0', 'other operation: BinOp', 'call/use variable#1', 'other operation: Constant']}), ({'variable_count': 2, 'variable_overflow': False, 'function_count': 1, 'function_overflow': False, 'class_count': 1, 'class_overflow': False, 'import_count': 0, 'import_overflow': False, 'argument_count': 3, 'argument_overflow': False}, {'lineno': 8, 'line_text': 'return (merged[n//2]+merged[n//2-1])/2.0', 'actions': ['other operation: Return', 'other operation: BinOp', 'other operation: BinOp', 'other operation: Constant', 'other operation: Subscript', 'other operation: Subscript', 'call/use variable#0', 'other operation: BinOp', 'call/use variable#0', 'other operation: BinOp', 'call/use variable#1', 'other operation: Constant', 'other operation: BinOp', 'other operation: Constant', 'call/use variable#1', 'other operation: Constant']})]\n"
     ]
    }
   ],
   "source": [
    "print(all_trajectories[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluate_reward_model(reward_model, eval_encoded_dataset, labels, batch_size=32):\n",
    "    reward_model.eval()\n",
    "    rewards = []\n",
    "    with torch.no_grad():\n",
    "        for traj, label in zip(eval_encoded_dataset, labels):\n",
    "            states = torch.tensor([step[0] for step in traj], dtype=torch.float32).unsqueeze(0)  # (1, seq, state_dim)\n",
    "            actions = torch.tensor([step[1] for step in traj], dtype=torch.float32).unsqueeze(0)  # (1, seq, action_dim)\n",
    "            if isinstance(reward_model, SimpleRewardModel):\n",
    "                # For simple model, aggregate rewards over steps\n",
    "                step_rewards = reward_model(states.view(-1, state_dim), actions.view(-1, action_dim))\n",
    "                traj_reward = step_rewards.mean().item()\n",
    "            else:\n",
    "                # For transformer model\n",
    "                traj_reward = reward_model(states, actions).item()\n",
    "            rewards.append((traj_reward, label))\n",
    "    # Separate positive and negative\n",
    "    pos_rewards = [r for r, l in rewards if l == 1]\n",
    "    neg_rewards = [r for r, l in rewards if l == 0]\n",
    "    avg_pos = np.mean(pos_rewards) if pos_rewards else 0\n",
    "    avg_neg = np.mean(neg_rewards) if neg_rewards else 0\n",
    "    return avg_pos, avg_neg\n",
    "\n",
    "# Prepare evaluation data\n",
    "eval_labels = eval_df['labels'].tolist()\n",
    "\n",
    "# Evaluate simple reward model\n",
    "avg_pos_simple, avg_neg_simple = evaluate_reward_model(simple_reward_model, eval_encoded_dataset, eval_labels)\n",
    "print(f\"Simple Reward Model - Average Positive: {avg_pos_simple:.4f}, Average Negative: {avg_neg_simple:.4f}\")\n",
    "\n",
    "# Evaluate transformer reward model\n",
    "avg_pos_trans, avg_neg_trans = evaluate_reward_model(transformer_reward_model, eval_encoded_dataset, eval_labels)\n",
    "print(f\"Transformer Reward Model - Average Positive: {avg_pos_trans:.4f}, Average Negative: {avg_neg_trans:.4f}\")\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "llm",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
