{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### state define for leet code data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### load data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import os\n",
    "import sys\n",
    "import ast\n",
    "import torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2.3.1+cu118\n",
      "cuda\n"
     ]
    }
   ],
   "source": [
    "print(torch.__version__)\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "print(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Number of invalid code snippets: 126\n",
      "61\n"
     ]
    }
   ],
   "source": [
    "# load the data\n",
    "def is_valid_code(code):\n",
    "    try:\n",
    "        ast.parse(code)\n",
    "        return True\n",
    "    except SyntaxError:\n",
    "        return False\n",
    "\n",
    "def preprocess_data(data, test_data=False):\n",
    "    # Apply the is_valid_code function to each code entry\n",
    "    data['is_valid'] = data['code'].apply(is_valid_code)\n",
    "    \n",
    "    # Count and print the number of invalid code snippets\n",
    "    invalid_count = data['is_valid'].value_counts().get(False, 0)\n",
    "    print(f\"Number of invalid code snippets: {invalid_count}\")\n",
    "    \n",
    "    # Filter out invalid code entries and reset the index\n",
    "    data = data[data['is_valid']].drop(columns=['is_valid']).reset_index(drop=True)\n",
    "\n",
    "    if test_data:\n",
    "        return data\n",
    "    \n",
    "    data['code_list_embedding'] = data['code_list_embedding'].apply(ast.literal_eval)\n",
    "    data['previous_code_embedding'] = data['previous_code_embedding'].apply(ast.literal_eval)\n",
    "    data['code_list'] = data['code_list'].apply(ast.literal_eval)\n",
    "    data['previous_code_list'] = data['previous_code_list'].apply(ast.literal_eval)\n",
    "    \n",
    "    return data\n",
    "\n",
    "data = pd.read_csv(r\"D:\\Python_project\\leetcode_data\\data\\leetcode_Median_of_Two_Sorted_Arrays_embedding.csv\")\n",
    "data = preprocess_data(data)\n",
    "print(len(data))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Number of invalid code snippets: 1\n",
      "119\n",
      "labels\n",
      "0    60\n",
      "1    59\n",
      "Name: count, dtype: int64\n"
     ]
    }
   ],
   "source": [
    "# load nagative data\n",
    "test_data = pd.read_csv(r\"D:\\Python_project\\leetcode_data\\data\\leetcode_diff_task.csv\")\n",
    "test_data = preprocess_data(test_data, test_data=True)\n",
    "# label all rows that has \"Median of Two Sorted Arrays\" in title column as 1 and the rest as 0\n",
    "test_data['labels'] = test_data['title'].apply(lambda x: 1 if 'Median of Two Sorted Arrays' in x else 0)\n",
    "# only keeps the columns that are needed\n",
    "test_data = test_data[['code', 'labels']]\n",
    "test_data.reset_index(drop=True, inplace=True)\n",
    "print(len(test_data))\n",
    "print(test_data['labels'].value_counts())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Data load and auto generation test 01"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Original Code:\n",
      " \n",
      "class Example:\n",
      "    def __init__(self, value):\n",
      "        self.value = value\n",
      "    \n",
      "    def increment(self, amount):\n",
      "        self.value += amount\n",
      "        print(self.value)\n",
      "\n",
      "def main():\n",
      "    ex = Example(10)\n",
      "    ex.increment(5)\n",
      "    ex.increment(3)\n",
      "\n",
      "if __name__ == \"__main__\":\n",
      "    main()\n",
      "\n",
      "Transformed Code:\n",
      " class UXJGZPHk:\n",
      "\n",
      "    def LZTVWvjw(KyhTaYYi, IcBdzOZj):\n",
      "        KyhTaYYi.value = IcBdzOZj\n",
      "\n",
      "    def hdsUVfzj(KyhTaYYi, BkCtoSDS):\n",
      "        KyhTaYYi.value += BkCtoSDS\n",
      "        skDNFXCX(KyhTaYYi.value)\n",
      "\n",
      "\n",
      "def XMzZlfEL():\n",
      "    ZPGinPTT = UXJGZPHk(10)\n",
      "    ZPGinPTT.increment(5)\n",
      "    ZPGinPTT.increment(3)\n",
      "\n",
      "\n",
      "if HAOPQDWv == '__main__':\n",
      "    XMzZlfEL()\n",
      "\n"
     ]
    }
   ],
   "source": [
    "import ast\n",
    "import astor\n",
    "import random\n",
    "import string\n",
    "\n",
    "class RenameIdentifiers(ast.NodeTransformer):\n",
    "    def __init__(self):\n",
    "        # Mapping from original names to new random names\n",
    "        self.name_mapping = {}\n",
    "\n",
    "    def random_name(self, length=8):\n",
    "        # Generate a random name consisting of letters\n",
    "        letters = string.ascii_letters\n",
    "        return ''.join(random.choice(letters) for _ in range(length))\n",
    "\n",
    "    def visit_FunctionDef(self, node):\n",
    "        # Rename the function name\n",
    "        if node.name not in self.name_mapping:\n",
    "            self.name_mapping[node.name] = self.random_name()\n",
    "        node.name = self.name_mapping[node.name]\n",
    "        # Rename arguments\n",
    "        node.args = self.visit(node.args)\n",
    "        # Process the function body\n",
    "        node.body = [self.visit(stmt) for stmt in node.body]\n",
    "        return node\n",
    "\n",
    "    def visit_ClassDef(self, node):\n",
    "        # Rename the class name\n",
    "        if node.name not in self.name_mapping:\n",
    "            self.name_mapping[node.name] = self.random_name()\n",
    "        node.name = self.name_mapping[node.name]\n",
    "        # Process the class body\n",
    "        node.body = [self.visit(stmt) for stmt in node.body]\n",
    "        return node\n",
    "\n",
    "    def visit_Name(self, node):\n",
    "        # Rename variables used in the code\n",
    "        if isinstance(node.ctx, (ast.Store, ast.Load, ast.Del)):\n",
    "            if node.id not in self.name_mapping:\n",
    "                self.name_mapping[node.id] = self.random_name()\n",
    "            node.id = self.name_mapping[node.id]\n",
    "        return node\n",
    "\n",
    "    def visit_arg(self, node):\n",
    "        # Rename function arguments\n",
    "        if node.arg not in self.name_mapping:\n",
    "            self.name_mapping[node.arg] = self.random_name()\n",
    "        node.arg = self.name_mapping[node.arg]\n",
    "        return node\n",
    "\n",
    "    def visit_Attribute(self, node):\n",
    "        # Avoid renaming attributes (like object methods or properties)\n",
    "        node.value = self.visit(node.value)\n",
    "        return node\n",
    "\n",
    "def transform_code(code):\n",
    "    # Parse the code into an AST\n",
    "    tree = ast.parse(code)\n",
    "    # Create a transformer and apply it to the AST\n",
    "    transformer = RenameIdentifiers()\n",
    "    transformer.visit(tree)\n",
    "    # Convert the AST back to source code\n",
    "    transformed_code = astor.to_source(tree)\n",
    "    return transformed_code\n",
    "\n",
    "# Example usage\n",
    "if __name__ == \"__main__\":\n",
    "    original_code = '''\n",
    "class Example:\n",
    "    def __init__(self, value):\n",
    "        self.value = value\n",
    "    \n",
    "    def increment(self, amount):\n",
    "        self.value += amount\n",
    "        print(self.value)\n",
    "\n",
    "def main():\n",
    "    ex = Example(10)\n",
    "    ex.increment(5)\n",
    "    ex.increment(3)\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    main()\n",
    "'''\n",
    "\n",
    "    transformed_code = transform_code(original_code)\n",
    "    print(\"Original Code:\\n\", original_code)\n",
    "    print(\"Transformed Code:\\n\", transformed_code)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Data load and auto generation test 02"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Original Code:\n",
      "\n",
      "class Example:\n",
      "    def __init__(self, value):\n",
      "        self.value = value\n",
      "    \n",
      "    def increment(self, amount):\n",
      "        self.value += amount\n",
      "        print(self.value)\n",
      "\n",
      "def main():\n",
      "    ex = Example(10)\n",
      "    ex.increment(5)\n",
      "    ex.increment(3)\n",
      "\n",
      "if __name__ == \"__main__\":\n",
      "    main()\n",
      "\n",
      "\n",
      "Transformed Samples:\n",
      "\n",
      "Sample 1:\n",
      "class rPbZKuY:\n",
      "\n",
      "    def SFBVWrQQ(fTXl, ZOsdK):\n",
      "        fTXl.value = ZOsdK\n",
      "\n",
      "    def CmyGMurUT(fTXl, ptTLpZ):\n",
      "        fTXl.value += ptTLpZ\n",
      "        xnRBv(fTXl.value)\n",
      "\n",
      "\n",
      "def moXk():\n",
      "    tPo = rPbZKuY(10)\n",
      "    tPo.increment(5)\n",
      "    tPo.increment(3)\n",
      "\n",
      "\n",
      "if tbUDXcYd == '__main__':\n",
      "    moXk()\n",
      "\n",
      "\n",
      "Sample 2:\n",
      "class DxILLeO:\n",
      "\n",
      "    def UBIxSsTS(nixB, EWoCG):\n",
      "        nixB.value = EWoCG\n",
      "\n",
      "    def NlFaIwvma(nixB, VRcUxE):\n",
      "        nixB.value += VRcUxE\n",
      "        YMiBj(nixB.value)\n",
      "\n",
      "\n",
      "def PjyP():\n",
      "    Kpv = DxILLeO(10)\n",
      "    Kpv.increment(5)\n",
      "    Kpv.increment(3)\n",
      "\n",
      "\n",
      "if zShtFNmr == '__main__':\n",
      "    PjyP()\n",
      "\n",
      "\n",
      "Sample 3:\n",
      "class GTxqYsQ:\n",
      "\n",
      "    def iwsdrmSp(bjjl, PfvtA):\n",
      "        bjjl.value = PfvtA\n",
      "\n",
      "    def HGEkPfRbl(bjjl, OjWJiH):\n",
      "        bjjl.value += OjWJiH\n",
      "        nKMyl(bjjl.value)\n",
      "\n",
      "\n",
      "def vEJQ():\n",
      "    TgP = GTxqYsQ(10)\n",
      "    TgP.increment(5)\n",
      "    TgP.increment(3)\n",
      "\n",
      "\n",
      "if IhWaAeRI == '__main__':\n",
      "    vEJQ()\n",
      "\n"
     ]
    }
   ],
   "source": [
    "import ast\n",
    "import astor\n",
    "import random\n",
    "import string\n",
    "from copy import deepcopy\n",
    "\n",
    "class IdentifierRenamer(ast.NodeTransformer):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.name_mapping = {}\n",
    "    \n",
    "    def generate_new_name(self, original_name):\n",
    "        return ''.join(random.choices(string.ascii_letters, k=max(3, len(original_name))))\n",
    "    \n",
    "    def rename(self, original_name):\n",
    "        if original_name not in self.name_mapping:\n",
    "            self.name_mapping[original_name] = self.generate_new_name(original_name)\n",
    "        return self.name_mapping[original_name]\n",
    "    \n",
    "    def visit_FunctionDef(self, node):\n",
    "        node.name = self.rename(node.name)\n",
    "        self.generic_visit(node)\n",
    "        return node\n",
    "    \n",
    "    def visit_ClassDef(self, node):\n",
    "        node.name = self.rename(node.name)\n",
    "        self.generic_visit(node)\n",
    "        return node\n",
    "    \n",
    "    def visit_Name(self, node):\n",
    "        if isinstance(node.ctx, (ast.Store, ast.Load, ast.Del)):\n",
    "            node.id = self.rename(node.id)\n",
    "        return node\n",
    "    \n",
    "    def visit_arg(self, node):\n",
    "        node.arg = self.rename(node.arg)\n",
    "        return node\n",
    "\n",
    "class StatementReorderer(ast.NodeTransformer):\n",
    "    def reorder_statements(self, body):\n",
    "        assign_stmts = [stmt for stmt in body if isinstance(stmt, (ast.Assign, ast.AugAssign, ast.AnnAssign))]\n",
    "        other_stmts = [stmt for stmt in body if not isinstance(stmt, (ast.Assign, ast.AugAssign, ast.AnnAssign))]\n",
    "        \n",
    "        random.shuffle(assign_stmts)\n",
    "        return assign_stmts + other_stmts\n",
    "\n",
    "    def visit_FunctionDef(self, node):\n",
    "        node.body = self.reorder_statements(node.body)\n",
    "        self.generic_visit(node)\n",
    "        return node\n",
    "    \n",
    "    def visit_Module(self, node):\n",
    "        node.body = self.reorder_statements(node.body)\n",
    "        self.generic_visit(node)\n",
    "        return node\n",
    "\n",
    "class CodeTransformer:\n",
    "    def __init__(self, code):\n",
    "        self.original_code = code\n",
    "        self.ast_tree = ast.parse(code)\n",
    "    \n",
    "    def transform(self):\n",
    "        # Deepcopy the AST to avoid modifying the original\n",
    "        tree = deepcopy(self.ast_tree)\n",
    "        \n",
    "        # Apply identifier renaming\n",
    "        renamer = IdentifierRenamer()\n",
    "        tree = renamer.visit(tree)\n",
    "        ast.fix_missing_locations(tree)\n",
    "        \n",
    "        # Apply statement reordering\n",
    "        reorderer = StatementReorderer()\n",
    "        tree = reorderer.visit(tree)\n",
    "        ast.fix_missing_locations(tree)\n",
    "        \n",
    "        transformed_code = astor.to_source(tree)\n",
    "        return transformed_code\n",
    "\n",
    "def generate_transformed_samples(code, num_samples=5):\n",
    "    samples = []\n",
    "    for _ in range(num_samples):\n",
    "        transformer = CodeTransformer(code)\n",
    "        transformed_code = transformer.transform()\n",
    "        samples.append(transformed_code)\n",
    "    return samples\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    # Example usage\n",
    "    sample_code = \"\"\"\n",
    "class Example:\n",
    "    def __init__(self, value):\n",
    "        self.value = value\n",
    "    \n",
    "    def increment(self, amount):\n",
    "        self.value += amount\n",
    "        print(self.value)\n",
    "\n",
    "def main():\n",
    "    ex = Example(10)\n",
    "    ex.increment(5)\n",
    "    ex.increment(3)\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    main()\n",
    "\"\"\"\n",
    "\n",
    "    print(\"Original Code:\")\n",
    "    print(sample_code)\n",
    "    \n",
    "    print(\"\\nTransformed Samples:\")\n",
    "    transformed_samples = generate_transformed_samples(sample_code, num_samples=3)\n",
    "    for idx, sample in enumerate(transformed_samples, 1):\n",
    "        print(f\"\\nSample {idx}:\\n{sample}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Original Code:\n",
      "\n",
      "def test01(arg1):\n",
      "    result = arg1 * 2\n",
      "    return result\n",
      "\n",
      "def test02(arg1, arg2):\n",
      "    print(\"Starting data processing...\")\n",
      "    data = arg1 + arg2\n",
      "    processed_data = data * 3\n",
      "\n",
      "test01(5)\n",
      "\n",
      "def test03(arg1, arg2, arg3):\n",
      "    return [1, 2, 3, 4, 5]\n",
      "\n",
      "test02(10, 20)\n",
      "test03(1, 2, 3)\n",
      "\n",
      "\n",
      "Transformed Samples:\n",
      "\n",
      "Sample 1:\n",
      "def nIDsbd(QMmK):\n",
      "    YeLIlL = QMmK * 2\n",
      "    return YeLIlL\n",
      "\n",
      "\n",
      "def UhYnNd(QMmK, OJkn):\n",
      "    mckL = QMmK + OJkn\n",
      "    tbmQoRQKIUTdcK = mckL * 3\n",
      "    print('Starting data processing...')\n",
      "\n",
      "\n",
      "nIDsbd(5)\n",
      "\n",
      "\n",
      "def ONfOcg(QMmK, OJkn, OSHW):\n",
      "    return [1, 2, 3, 4, 5]\n",
      "\n",
      "\n",
      "UhYnNd(10, 20)\n",
      "ONfOcg(1, 2, 3)\n",
      "\n",
      "\n",
      "Sample 2:\n",
      "def UMuqtj(JckN):\n",
      "    KpCBWm = JckN * 2\n",
      "    return KpCBWm\n",
      "\n",
      "\n",
      "def wtLvkp(JckN, BvGo):\n",
      "    Hlhr = JckN + BvGo\n",
      "    fOuCJXVilOWpHd = Hlhr * 3\n",
      "    print('Starting data processing...')\n",
      "\n",
      "\n",
      "UMuqtj(5)\n",
      "\n",
      "\n",
      "def YHiwYx(JckN, BvGo, utGs):\n",
      "    return [1, 2, 3, 4, 5]\n",
      "\n",
      "\n",
      "wtLvkp(10, 20)\n",
      "YHiwYx(1, 2, 3)\n",
      "\n",
      "\n",
      "Sample 3:\n",
      "def AcAyQu(nhYf):\n",
      "    hVXYeA = nhYf * 2\n",
      "    return hVXYeA\n",
      "\n",
      "\n",
      "def OXrUlw(nhYf, QAsj):\n",
      "    buHX = nhYf + QAsj\n",
      "    iNxtOwvlRhQFUl = buHX * 3\n",
      "    print('Starting data processing...')\n",
      "\n",
      "\n",
      "AcAyQu(5)\n",
      "\n",
      "\n",
      "def GUrdZF(nhYf, QAsj, sCmg):\n",
      "    return [1, 2, 3, 4, 5]\n",
      "\n",
      "\n",
      "OXrUlw(10, 20)\n",
      "GUrdZF(1, 2, 3)\n",
      "\n"
     ]
    }
   ],
   "source": [
    "import ast\n",
    "import astor\n",
    "import random\n",
    "import string\n",
    "from copy import deepcopy\n",
    "\n",
    "class IdentifierRenamer(ast.NodeTransformer):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.name_mapping = {}\n",
    "        # Set of identifiers to exclude from renaming\n",
    "        self.exclude = {\n",
    "            'self', 'cls', '__init__', '__str__', '__repr__', '__main__',\n",
    "            # Built-in functions and exceptions\n",
    "            'print', 'len', 'range', 'int', 'str', 'float', 'list', 'dict',\n",
    "            'set', 'tuple', 'Exception', 'ValueError', 'TypeError',\n",
    "            # Special variables\n",
    "            '__name__', '__doc__', '__file__', '__package__', '__loader__',\n",
    "            '__spec__', '__annotations__', '__builtins__'\n",
    "        }\n",
    "    \n",
    "    def generate_new_name(self, original_name):\n",
    "        # Generate a random name with the same length as the original\n",
    "        # Ensure the new name does not conflict with excluded names or existing mappings\n",
    "        while True:\n",
    "            new_name = ''.join(random.choices(string.ascii_letters, k=max(3, len(original_name))))\n",
    "            if new_name not in self.exclude and new_name not in self.name_mapping.values():\n",
    "                return new_name\n",
    "    \n",
    "    def rename(self, original_name):\n",
    "        if original_name in self.exclude:\n",
    "            return original_name\n",
    "        if original_name not in self.name_mapping:\n",
    "            self.name_mapping[original_name] = self.generate_new_name(original_name)\n",
    "        return self.name_mapping[original_name]\n",
    "    \n",
    "    def visit_FunctionDef(self, node):\n",
    "        original_name = node.name\n",
    "        node.name = self.rename(node.name)\n",
    "        self.generic_visit(node)\n",
    "        return node\n",
    "    \n",
    "    def visit_ClassDef(self, node):\n",
    "        original_name = node.name\n",
    "        node.name = self.rename(node.name)\n",
    "        self.generic_visit(node)\n",
    "        return node\n",
    "    \n",
    "    def visit_Name(self, node):\n",
    "        if isinstance(node.ctx, (ast.Store, ast.Load, ast.Del)):\n",
    "            original_id = node.id\n",
    "            node.id = self.rename(original_id)\n",
    "        return node\n",
    "    \n",
    "    def visit_Attribute(self, node):\n",
    "        # Handle attribute accesses (e.g., self.value)\n",
    "        self.generic_visit(node)\n",
    "        if node.attr in self.name_mapping:\n",
    "            node.attr = self.name_mapping[node.attr]\n",
    "        return node\n",
    "    \n",
    "    def visit_arg(self, node):\n",
    "        # Rename argument names, excluding 'self' and 'cls'\n",
    "        original_arg = node.arg\n",
    "        node.arg = self.rename(original_arg)\n",
    "        return node\n",
    "\n",
    "class StatementReorderer(ast.NodeTransformer):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.assigned_vars = set()\n",
    "    \n",
    "    def get_assigned_vars(self, node):\n",
    "        \"\"\"Recursively collect all variable names that are assigned in the node.\"\"\"\n",
    "        vars = set()\n",
    "        if isinstance(node, ast.Assign):\n",
    "            for target in node.targets:\n",
    "                if isinstance(target, ast.Name):\n",
    "                    vars.add(target.id)\n",
    "                elif isinstance(target, (ast.Tuple, ast.List)):\n",
    "                    for elt in target.elts:\n",
    "                        if isinstance(elt, ast.Name):\n",
    "                            vars.add(elt.id)\n",
    "        elif isinstance(node, ast.AugAssign):\n",
    "            if isinstance(node.target, ast.Name):\n",
    "                vars.add(node.target.id)\n",
    "        elif isinstance(node, ast.AnnAssign):\n",
    "            if isinstance(node.target, ast.Name):\n",
    "                vars.add(node.target.id)\n",
    "        return vars\n",
    "\n",
    "    def get_used_vars(self, node):\n",
    "        \"\"\"Recursively collect all variable names that are used in the node.\"\"\"\n",
    "        vars = set()\n",
    "        for child in ast.walk(node):\n",
    "            if isinstance(child, ast.Name):\n",
    "                if isinstance(child.ctx, ast.Load):\n",
    "                    vars.add(child.id)\n",
    "        return vars\n",
    "\n",
    "    def reorder_statements(self, body):\n",
    "        # Collect all assignments\n",
    "        assignments = [stmt for stmt in body if isinstance(stmt, (ast.Assign, ast.AugAssign, ast.AnnAssign))]\n",
    "        non_assignments = [stmt for stmt in body if not isinstance(stmt, (ast.Assign, ast.AugAssign, ast.AnnAssign))]\n",
    "        \n",
    "        # Determine dependencies\n",
    "        assigned_vars = set()\n",
    "        dependencies = {}\n",
    "        for i, stmt in enumerate(assignments):\n",
    "            vars_assigned = self.get_assigned_vars(stmt)\n",
    "            vars_used = self.get_used_vars(stmt)\n",
    "            dependencies[i] = vars_used.intersection(assigned_vars)\n",
    "            assigned_vars.update(vars_assigned)\n",
    "        \n",
    "        # Identify independent assignments (no dependencies)\n",
    "        independent_indices = [i for i, deps in dependencies.items() if not deps]\n",
    "        independent_assignments = [assignments[i] for i in independent_indices]\n",
    "        dependent_assignments = [stmt for i, stmt in enumerate(assignments) if i not in independent_indices]\n",
    "        \n",
    "        # Shuffle independent assignments if more than one\n",
    "        if len(independent_assignments) > 1:\n",
    "            random.shuffle(independent_assignments)\n",
    "        \n",
    "        # Combine shuffled independent assignments with dependent assignments\n",
    "        new_assignments = independent_assignments + dependent_assignments\n",
    "        \n",
    "        return new_assignments + non_assignments\n",
    "\n",
    "    def visit_FunctionDef(self, node):\n",
    "        node.body = self.reorder_statements(node.body)\n",
    "        self.generic_visit(node)\n",
    "        return node\n",
    "    \n",
    "    def visit_Module(self, node):\n",
    "        node.body = self.reorder_statements(node.body)\n",
    "        self.generic_visit(node)\n",
    "        return node\n",
    "\n",
    "class CodeTransformer:\n",
    "    def __init__(self, code):\n",
    "        self.original_code = code\n",
    "        try:\n",
    "            self.ast_tree = ast.parse(code)\n",
    "        except SyntaxError as e:\n",
    "            raise ValueError(f\"Invalid Python code provided: {e}\")\n",
    "    \n",
    "    def transform(self):\n",
    "        # Deepcopy the AST to avoid modifying the original\n",
    "        tree = deepcopy(self.ast_tree)\n",
    "        \n",
    "        # Apply identifier renaming\n",
    "        renamer = IdentifierRenamer()\n",
    "        tree = renamer.visit(tree)\n",
    "        ast.fix_missing_locations(tree)\n",
    "        \n",
    "        # Apply statement reordering\n",
    "        reorderer = StatementReorderer()\n",
    "        tree = reorderer.visit(tree)\n",
    "        ast.fix_missing_locations(tree)\n",
    "        \n",
    "        # Convert the modified AST back to source code\n",
    "        transformed_code = astor.to_source(tree)\n",
    "        return transformed_code\n",
    "\n",
    "def generate_transformed_samples(code, num_samples=5):\n",
    "    samples = []\n",
    "    for _ in range(num_samples):\n",
    "        transformer = CodeTransformer(code)\n",
    "        try:\n",
    "            transformed_code = transformer.transform()\n",
    "            samples.append(transformed_code)\n",
    "        except ValueError as e:\n",
    "            print(f\"Error during transformation: {e}\")\n",
    "    return samples\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    # Example usage\n",
    "    sample_code = \"\"\"\n",
    "def test01(arg1):\n",
    "    result = arg1 * 2\n",
    "    return result\n",
    "\n",
    "def test02(arg1, arg2):\n",
    "    print(\"Starting data processing...\")\n",
    "    data = arg1 + arg2\n",
    "    processed_data = data * 3\n",
    "\n",
    "test01(5)\n",
    "\n",
    "def test03(arg1, arg2, arg3):\n",
    "    return [1, 2, 3, 4, 5]\n",
    "\n",
    "test02(10, 20)\n",
    "test03(1, 2, 3)\n",
    "\"\"\"\n",
    "\n",
    "    print(\"Original Code:\")\n",
    "    print(sample_code)\n",
    "    \n",
    "    print(\"\\nTransformed Samples:\")\n",
    "    transformed_samples = generate_transformed_samples(sample_code, num_samples=3)\n",
    "    for idx, sample in enumerate(transformed_samples, 1):\n",
    "        print(f\"\\nSample {idx}:\\n{sample}\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Data load and auto generation test 03"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Original Code:\n",
      "\n",
      "def test01(arg1):\n",
      "    result = arg1 * 2\n",
      "    return result\n",
      "\n",
      "def test02(arg1, arg2):\n",
      "    print(\"Starting data processing...\")\n",
      "    data = arg1 + arg2\n",
      "    processed_data = data * 3\n",
      "\n",
      "test01(5)\n",
      "print(\"debug\")\n",
      "\n",
      "def test03(arg1, arg2, arg3):\n",
      "    return [1, 2, 3, 4, 5]\n",
      "\n",
      "test02(10, 20)\n",
      "test03(1, 2, 3)\n",
      "\n",
      "\n",
      "Transformed Samples:\n",
      "\n",
      "Sample 1:\n",
      "def XFjpgj(arg1):\n",
      "    ZdFBDf = arg1 * 2\n",
      "    return ZdFBDf\n",
      "\n",
      "\n",
      "def ycaKXd(arg1, arg2):\n",
      "    GCbx = arg1 + arg2\n",
      "    CHTqTNSxiLXaIf = GCbx * 3\n",
      "    print('Starting data processing...')\n",
      "\n",
      "\n",
      "XFjpgj(5)\n",
      "print('debug')\n",
      "ycaKXd(10, 20)\n",
      "XjRKcX(1, 2, 3)\n",
      "\n",
      "\n",
      "Sample 2:\n",
      "def KllGkp(arg1):\n",
      "    qvgcAZ = arg1 * 2\n",
      "    return qvgcAZ\n",
      "\n",
      "\n",
      "KllGkp(5)\n",
      "print('debug')\n",
      "oxlPiK(10, 20)\n",
      "bXlfWK(1, 2, 3)\n",
      "\n",
      "\n",
      "Sample 3:\n",
      "def EkVRor(arg1):\n",
      "    xYqjbr = arg1 * 2\n",
      "    return xYqjbr\n",
      "\n",
      "\n",
      "def sdYCqX(arg1, arg2, arg3):\n",
      "    return [1, 2, 3, 4, 5]\n",
      "\n",
      "\n",
      "EkVRor(5)\n",
      "print('debug')\n",
      "TYWSZI(10, 20)\n",
      "sdYCqX(1, 2, 3)\n",
      "\n",
      "\n",
      "Sample 4:\n",
      "def yizPOW(arg1):\n",
      "    bZphyn = arg1 * 2\n",
      "    return bZphyn\n",
      "\n",
      "\n",
      "yizPOW(5)\n",
      "print('debug')\n",
      "OzkpLD(10, 20)\n",
      "DLLvyP(1, 2, 3)\n",
      "\n",
      "\n",
      "Sample 5:\n",
      "def qUesXm(arg1):\n",
      "    GfXlZI = arg1 * 2\n",
      "    return GfXlZI\n",
      "\n",
      "\n",
      "qUesXm(5)\n",
      "print('debug')\n",
      "tUWQgf(10, 20)\n",
      "IDOCQQ(1, 2, 3)\n",
      "\n"
     ]
    }
   ],
   "source": [
    "import ast\n",
    "import astor\n",
    "import random\n",
    "import string\n",
    "from copy import deepcopy\n",
    "from collections import defaultdict, deque\n",
    "\n",
    "class DefinitionCollector(ast.NodeVisitor):\n",
    "    def __init__(self):\n",
    "        self.definitions = {}  # name -> Definition object\n",
    "        self.usages = defaultdict(set)  # name -> set of positions where it's used\n",
    "        self.current_definition = None\n",
    "        self.current_position = -1  # Position in module body\n",
    "        self.module_body = []  # List of top-level statements\n",
    "        # Set of built-in function names and standard library names to exclude\n",
    "        self.exclude_names = set([\n",
    "            'print', 'len', 'range', 'int', 'str', 'float', 'list', 'dict',\n",
    "            'set', 'tuple', 'Exception', 'ValueError', 'TypeError', 'abs', 'max',\n",
    "            'min', 'sum', 'open', 'input', 'map', 'filter', 'zip', 'enumerate',\n",
    "            'any', 'all', 'sorted', 'reversed', 'super', 'isinstance', 'issubclass',\n",
    "            'dir', 'globals', 'locals', 'vars', 'help', 'type', 'object', 'staticmethod',\n",
    "            'classmethod', 'property', 'format', 'eval', 'exec', 'compile', 'delattr',\n",
    "            'getattr', 'setattr', 'hasattr', 'id', 'hash', 'repr', 'memoryview',\n",
    "            'next', 'iter', 'bytes', 'bytearray', 'callable', 'chr', 'ord',\n",
    "            'complex', 'divmod', 'frozenset', 'bin', 'oct', 'hex', 'round',\n",
    "            # Add more built-in names if necessary\n",
    "        ])\n",
    "\n",
    "    def visit_Module(self, node):\n",
    "        self.module_body = node.body\n",
    "        for idx, stmt in enumerate(node.body):\n",
    "            self.current_position = idx\n",
    "            self.visit(stmt)\n",
    "\n",
    "    def visit_FunctionDef(self, node):\n",
    "        def_obj = Definition(node.name, node, self.current_position)\n",
    "        self.definitions[node.name] = def_obj\n",
    "        self.current_definition = def_obj\n",
    "        self.generic_visit(node)\n",
    "        self.current_definition = None\n",
    "\n",
    "    def visit_ClassDef(self, node):\n",
    "        def_obj = Definition(node.name, node, self.current_position)\n",
    "        self.definitions[node.name] = def_obj\n",
    "        self.current_definition = def_obj\n",
    "        self.generic_visit(node)\n",
    "        self.current_definition = None\n",
    "\n",
    "    def visit_Call(self, node):\n",
    "        func_name = self.get_function_name(node.func)\n",
    "        if func_name and func_name not in self.exclude_names:\n",
    "            if self.current_definition:\n",
    "                if func_name in self.definitions:\n",
    "                    self.current_definition.dependencies.add(func_name)\n",
    "            else:\n",
    "                if func_name in self.definitions:\n",
    "                    self.usages[func_name].add(self.current_position)\n",
    "        self.generic_visit(node)\n",
    "\n",
    "    def get_function_name(self, node):\n",
    "        if isinstance(node, ast.Name):\n",
    "            return node.id\n",
    "        elif isinstance(node, ast.Attribute):\n",
    "            # Handle cases like module.func or self.method\n",
    "            return node.attr\n",
    "        return None\n",
    "\n",
    "    def visit_Name(self, node):\n",
    "        if isinstance(node.ctx, ast.Load):\n",
    "            name = node.id\n",
    "            if name in self.exclude_names:\n",
    "                return\n",
    "            if self.current_definition:\n",
    "                # Possible dependency if the name is a function or class\n",
    "                if name in self.definitions:\n",
    "                    self.current_definition.dependencies.add(name)\n",
    "            else:\n",
    "                if name in self.definitions:\n",
    "                    self.usages[name].add(self.current_position)\n",
    "        self.generic_visit(node)\n",
    "\n",
    "class Definition:\n",
    "    def __init__(self, name, node, position):\n",
    "        self.name = name\n",
    "        self.node = node\n",
    "        self.position = position\n",
    "        self.dependencies = set()\n",
    "\n",
    "class DefinitionReorderer(ast.NodeTransformer):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.collector = DefinitionCollector()\n",
    "\n",
    "    def transform(self, tree):\n",
    "        # Step 1: Collect definitions, dependencies, and usages\n",
    "        self.collector.visit(tree)\n",
    "        definitions = self.collector.definitions\n",
    "        usages = self.collector.usages\n",
    "        module_body = self.collector.module_body\n",
    "\n",
    "        # Step 2: Build dependency graph and perform topological sort\n",
    "        dependency_graph = {name: def_obj.dependencies for name, def_obj in definitions.items()}\n",
    "        sorted_definitions = self.topological_sort(definitions, dependency_graph)\n",
    "\n",
    "        # Step 3: Determine allowed positions for each definition\n",
    "        allowed_positions = {}\n",
    "        for def_obj in sorted_definitions:\n",
    "            earliest = 0\n",
    "            if def_obj.dependencies:\n",
    "                dep_positions = [definitions[dep].position for dep in def_obj.dependencies if dep in definitions]\n",
    "                if dep_positions:\n",
    "                    earliest = max(dep_positions) + 1\n",
    "            latest_usages = usages.get(def_obj.name, set())\n",
    "            latest = len(module_body)\n",
    "            if latest_usages:\n",
    "                latest = min(latest_usages)\n",
    "            allowed_positions[def_obj.name] = (earliest, latest)\n",
    "\n",
    "        # Step 4: Assign positions to definitions within allowed intervals\n",
    "        assigned_positions = self.assign_positions(allowed_positions, definitions)\n",
    "\n",
    "        # Step 5: Rebuild the module body\n",
    "        new_body = []\n",
    "        # Collect all statements and their original positions\n",
    "        statements = []\n",
    "        for idx, stmt in enumerate(module_body):\n",
    "            if isinstance(stmt, (ast.FunctionDef, ast.ClassDef)):\n",
    "                # Skip original definitions; they'll be re-inserted\n",
    "                continue\n",
    "            statements.append((idx, stmt))\n",
    "\n",
    "        # Insert definitions and other statements into new_body\n",
    "        # Build a mapping from old to new positions to handle shifts\n",
    "        position_mapping = {}\n",
    "        # First, collect all positions that need to be occupied\n",
    "        occupied_positions = set(assigned_positions.values())\n",
    "        statement_positions = set(idx for idx, _ in statements)\n",
    "        all_positions = sorted(occupied_positions.union(statement_positions))\n",
    "\n",
    "        # Build new_body by inserting definitions and statements in order\n",
    "        pos_iter = iter(all_positions)\n",
    "        pos_to_stmt = {}\n",
    "        for def_name, pos in assigned_positions.items():\n",
    "            pos_to_stmt[pos] = definitions[def_name].node\n",
    "        for idx, stmt in statements:\n",
    "            pos_to_stmt[idx] = stmt\n",
    "\n",
    "        for pos in sorted(pos_to_stmt.keys()):\n",
    "            new_body.append(pos_to_stmt[pos])\n",
    "\n",
    "        tree.body = new_body\n",
    "        return tree\n",
    "\n",
    "    def topological_sort(self, definitions, dependency_graph):\n",
    "        sorted_defs = []\n",
    "        visited = {}\n",
    "        temp_mark = set()\n",
    "\n",
    "        def visit(node_name):\n",
    "            if node_name in temp_mark:\n",
    "                # Detected a cycle; handle appropriately (could raise an error or ignore)\n",
    "                return\n",
    "            if visited.get(node_name, False):\n",
    "                return\n",
    "            temp_mark.add(node_name)\n",
    "            for dep_name in dependency_graph.get(node_name, []):\n",
    "                if dep_name in definitions:\n",
    "                    visit(dep_name)\n",
    "            temp_mark.remove(node_name)\n",
    "            visited[node_name] = True\n",
    "            sorted_defs.append(definitions[node_name])\n",
    "\n",
    "        for def_name in definitions:\n",
    "            visit(def_name)\n",
    "\n",
    "        return sorted_defs[::-1]  # Reverse to get correct order\n",
    "\n",
    "    def assign_positions(self, allowed_positions, definitions):\n",
    "        # Build intervals and shuffle definitions within overlapping intervals\n",
    "        intervals = defaultdict(list)\n",
    "        for def_name, (start, end) in allowed_positions.items():\n",
    "            intervals[(start, end)].append(def_name)\n",
    "        assigned_positions = {}\n",
    "        for interval, defs in intervals.items():\n",
    "            start, end = interval\n",
    "            possible_positions = list(range(start, end))\n",
    "            random.shuffle(possible_positions)\n",
    "            defs_positions = {}\n",
    "            for def_name in defs:\n",
    "                if possible_positions:\n",
    "                    pos = possible_positions.pop(0)\n",
    "                else:\n",
    "                    # If no positions are available, place it at the latest allowed position minus one\n",
    "                    pos = end - 1\n",
    "                assigned_positions[def_name] = pos\n",
    "                # Remove this position from possible_positions of other defs to avoid conflicts\n",
    "                possible_positions = [p for p in possible_positions if p != pos]\n",
    "        return assigned_positions\n",
    "\n",
    "class CodeTransformer:\n",
    "    def __init__(self, code):\n",
    "        self.original_code = code\n",
    "        try:\n",
    "            self.ast_tree = ast.parse(code)\n",
    "        except SyntaxError as e:\n",
    "            raise ValueError(f\"Invalid Python code provided: {e}\")\n",
    "\n",
    "    def transform(self):\n",
    "        # Deepcopy the AST to avoid modifying the original\n",
    "        tree = deepcopy(self.ast_tree)\n",
    "\n",
    "        # Apply identifier renaming\n",
    "        renamer = IdentifierRenamer()\n",
    "        tree = renamer.visit(tree)\n",
    "        ast.fix_missing_locations(tree)\n",
    "\n",
    "        # Apply statement reordering\n",
    "        reorderer = StatementReorderer()\n",
    "        tree = reorderer.visit(tree)\n",
    "        ast.fix_missing_locations(tree)\n",
    "\n",
    "        # Apply definition reordering\n",
    "        def_reorderer = DefinitionReorderer()\n",
    "        tree = def_reorderer.transform(tree)\n",
    "        ast.fix_missing_locations(tree)\n",
    "\n",
    "        # Convert the modified AST back to source code\n",
    "        transformed_code = astor.to_source(tree)\n",
    "        return transformed_code\n",
    "\n",
    "class IdentifierRenamer(ast.NodeTransformer):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.name_mapping = {}\n",
    "        # Set of identifiers to exclude from renaming\n",
    "        self.exclude = set([\n",
    "            'self', 'cls', '__init__', '__str__', '__repr__', '__main__',\n",
    "            # Built-in functions and exceptions\n",
    "            'print', 'len', 'range', 'int', 'str', 'float', 'list', 'dict',\n",
    "            'set', 'tuple', 'Exception', 'ValueError', 'TypeError',\n",
    "            # Special variables\n",
    "            '__name__', '__doc__', '__file__', '__package__', '__loader__',\n",
    "            '__spec__', '__annotations__', '__builtins__',\n",
    "            # Commonly used names\n",
    "            'arg1', 'arg2', 'arg3'\n",
    "        ])\n",
    "\n",
    "    def generate_new_name(self, original_name):\n",
    "        # Generate a random name with the same length as the original\n",
    "        # Ensure the new name does not conflict with excluded names or existing mappings\n",
    "        while True:\n",
    "            new_name = ''.join(random.choices(string.ascii_letters, k=max(3, len(original_name))))\n",
    "            if new_name not in self.exclude and new_name not in self.name_mapping.values():\n",
    "                return new_name\n",
    "\n",
    "    def rename(self, original_name):\n",
    "        if original_name in self.exclude:\n",
    "            return original_name\n",
    "        if original_name not in self.name_mapping:\n",
    "            self.name_mapping[original_name] = self.generate_new_name(original_name)\n",
    "        return self.name_mapping[original_name]\n",
    "\n",
    "    def visit_FunctionDef(self, node):\n",
    "        original_name = node.name\n",
    "        node.name = self.rename(node.name)\n",
    "        self.generic_visit(node)\n",
    "        return node\n",
    "\n",
    "    def visit_ClassDef(self, node):\n",
    "        original_name = node.name\n",
    "        node.name = self.rename(node.name)\n",
    "        self.generic_visit(node)\n",
    "        return node\n",
    "\n",
    "    def visit_Name(self, node):\n",
    "        if isinstance(node.ctx, (ast.Store, ast.Load, ast.Del)):\n",
    "            original_id = node.id\n",
    "            node.id = self.rename(original_id)\n",
    "        return node\n",
    "\n",
    "    def visit_Attribute(self, node):\n",
    "        # Handle attribute accesses (e.g., self.value)\n",
    "        self.generic_visit(node)\n",
    "        # Attributes typically should not be renamed unless they are methods or variables defined within the class\n",
    "        # Here, we check if the attribute name is in the mapping and rename accordingly\n",
    "        if node.attr in self.name_mapping:\n",
    "            node.attr = self.name_mapping[node.attr]\n",
    "        return node\n",
    "\n",
    "    def visit_arg(self, node):\n",
    "        # Rename argument names, excluding 'self' and 'cls'\n",
    "        original_arg = node.arg\n",
    "        node.arg = self.rename(original_arg)\n",
    "        return node\n",
    "\n",
    "class StatementReorderer(ast.NodeTransformer):\n",
    "    def reorder_statements(self, body):\n",
    "        # Collect all assignments\n",
    "        assignments = [stmt for stmt in body if isinstance(stmt, (ast.Assign, ast.AugAssign, ast.AnnAssign))]\n",
    "        non_assignments = [stmt for stmt in body if not isinstance(stmt, (ast.Assign, ast.AugAssign, ast.AnnAssign))]\n",
    "\n",
    "        # Determine dependencies\n",
    "        assigned_vars = set()\n",
    "        dependencies = {}\n",
    "        for i, stmt in enumerate(assignments):\n",
    "            vars_assigned = self.get_assigned_vars(stmt)\n",
    "            vars_used = self.get_used_vars(stmt)\n",
    "            dependencies[i] = vars_used.intersection(assigned_vars)\n",
    "            assigned_vars.update(vars_assigned)\n",
    "\n",
    "        # Identify independent assignments (no dependencies)\n",
    "        independent_indices = [i for i, deps in dependencies.items() if not deps]\n",
    "        independent_assignments = [assignments[i] for i in independent_indices]\n",
    "        dependent_assignments = [stmt for i, stmt in enumerate(assignments) if i not in independent_indices]\n",
    "\n",
    "        # Shuffle independent assignments if more than one\n",
    "        if len(independent_assignments) > 1:\n",
    "            random.shuffle(independent_assignments)\n",
    "\n",
    "        # Combine shuffled independent assignments with dependent assignments\n",
    "        new_assignments = independent_assignments + dependent_assignments\n",
    "\n",
    "        return new_assignments + non_assignments\n",
    "\n",
    "    def get_assigned_vars(self, node):\n",
    "        \"\"\"Recursively collect all variable names that are assigned in the node.\"\"\"\n",
    "        vars = set()\n",
    "        if isinstance(node, ast.Assign):\n",
    "            for target in node.targets:\n",
    "                vars.update(self.extract_names(target))\n",
    "        elif isinstance(node, ast.AugAssign):\n",
    "            if isinstance(node.target, ast.Name):\n",
    "                vars.add(node.target.id)\n",
    "            else:\n",
    "                vars.update(self.extract_names(node.target))\n",
    "        elif isinstance(node, ast.AnnAssign):\n",
    "            if isinstance(node.target, ast.Name):\n",
    "                vars.add(node.target.id)\n",
    "            else:\n",
    "                vars.update(self.extract_names(node.target))\n",
    "        return vars\n",
    "\n",
    "    def get_used_vars(self, node):\n",
    "        \"\"\"Recursively collect all variable names that are used in the node.\"\"\"\n",
    "        vars = set()\n",
    "        for child in ast.walk(node):\n",
    "            if isinstance(child, ast.Name):\n",
    "                if isinstance(child.ctx, ast.Load):\n",
    "                    vars.add(child.id)\n",
    "        return vars\n",
    "\n",
    "    def extract_names(self, node):\n",
    "        \"\"\"Extract variable names from assignment targets.\"\"\"\n",
    "        names = set()\n",
    "        if isinstance(node, ast.Name):\n",
    "            names.add(node.id)\n",
    "        elif isinstance(node, (ast.Tuple, ast.List)):\n",
    "            for elt in node.elts:\n",
    "                names.update(self.extract_names(elt))\n",
    "        elif isinstance(node, ast.Attribute):\n",
    "            # Attributes like self.value are not added as variables\n",
    "            pass\n",
    "        return names\n",
    "\n",
    "    def visit_FunctionDef(self, node):\n",
    "        node.body = self.reorder_statements(node.body)\n",
    "        self.generic_visit(node)\n",
    "        return node\n",
    "\n",
    "    def visit_Module(self, node):\n",
    "        node.body = self.reorder_statements(node.body)\n",
    "        self.generic_visit(node)\n",
    "        return node\n",
    "\n",
    "def generate_transformed_samples(code, num_samples=5):\n",
    "    samples = []\n",
    "    for _ in range(num_samples):\n",
    "        transformer = CodeTransformer(code)\n",
    "        try:\n",
    "            transformed_code = transformer.transform()\n",
    "            samples.append(transformed_code)\n",
    "        except ValueError as e:\n",
    "            print(f\"Error during transformation: {e}\")\n",
    "    return samples\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    # Example usage\n",
    "    sample_code = \"\"\"\n",
    "def test01(arg1):\n",
    "    result = arg1 * 2\n",
    "    return result\n",
    "\n",
    "def test02(arg1, arg2):\n",
    "    print(\"Starting data processing...\")\n",
    "    data = arg1 + arg2\n",
    "    processed_data = data * 3\n",
    "\n",
    "test01(5)\n",
    "print(\"debug\")\n",
    "\n",
    "def test03(arg1, arg2, arg3):\n",
    "    return [1, 2, 3, 4, 5]\n",
    "\n",
    "test02(10, 20)\n",
    "test03(1, 2, 3)\n",
    "\"\"\"\n",
    "\n",
    "    print(\"Original Code:\")\n",
    "    print(sample_code)\n",
    "\n",
    "    print(\"\\nTransformed Samples:\")\n",
    "    transformed_samples = generate_transformed_samples(sample_code, num_samples=5)\n",
    "    for idx, sample in enumerate(transformed_samples, 1):\n",
    "        print(f\"\\nSample {idx}:\\n{sample}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_code_dataset(data, num_samples=5):\n",
    "    if 'code' not in data.columns:\n",
    "        raise ValueError(\"Input DataFrame must contain a 'code' column.\")\n",
    "\n",
    "    # Store all code in a list\n",
    "    all_code = []\n",
    "\n",
    "    # Iterate over each row in the DataFrame\n",
    "    for index, row in data.iterrows():\n",
    "        original_code = row['code']\n",
    "        all_code.append(original_code)  # Append the original code\n",
    "\n",
    "        # Generate transformed samples using the CodeTransformer\n",
    "        transformer = CodeTransformer(original_code)\n",
    "        for _ in range(num_samples):\n",
    "            try:\n",
    "                transformed_code = transformer.transform()\n",
    "                all_code.append(transformed_code)  # Append transformed code\n",
    "            except ValueError as e:\n",
    "                print(f\"Error transforming code at index {index}: {e}\")\n",
    "\n",
    "    # Create a new DataFrame with all code stored in the 'code' column\n",
    "    transformed_data = pd.DataFrame({'code': all_code})\n",
    "\n",
    "    return transformed_data\n",
    "\n",
    "expanded_data = generate_code_dataset(data, num_samples=5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "366\n"
     ]
    }
   ],
   "source": [
    "print(len(expanded_data))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### state define test 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "import ast\n",
    "import pandas as pd\n",
    "\n",
    "def get_node_types(node, node_types):\n",
    "    \"\"\"Recursively collect node types from the AST.\"\"\"\n",
    "    node_types.add(type(node).__name__)\n",
    "    for child in ast.iter_child_nodes(node):\n",
    "        get_node_types(child, node_types)\n",
    "\n",
    "def build_vocab(code_list):\n",
    "    \"\"\"Build a vocabulary of node types from a list of code snippets.\"\"\"\n",
    "    node_types = set()\n",
    "    for code in code_list:\n",
    "        try:\n",
    "            tree = ast.parse(code)\n",
    "            get_node_types(tree, node_types)\n",
    "        except SyntaxError:\n",
    "            continue  # Skip code snippets with syntax errors\n",
    "    vocab = {node_type: idx for idx, node_type in enumerate(sorted(node_types))}\n",
    "    return vocab\n",
    "\n",
    "def collect_nodes_per_line(node, line_nodes):\n",
    "    \"\"\"Collect node types per line number.\"\"\"\n",
    "    lineno = getattr(node, 'lineno', None)\n",
    "    if lineno is not None:\n",
    "        lineno -= 1  # Adjust for zero-based index\n",
    "        if 0 <= lineno < len(line_nodes):\n",
    "            line_nodes[lineno].add(type(node).__name__)\n",
    "    for child in ast.iter_child_nodes(node):\n",
    "        collect_nodes_per_line(child, line_nodes)\n",
    "\n",
    "def code_to_states(code, vocab):\n",
    "    \"\"\"Convert a code snippet into a list of states.\"\"\"\n",
    "    lines = code.split('\\n')\n",
    "    N = len(lines)\n",
    "    line_nodes = [set() for _ in range(N)]\n",
    "    try:\n",
    "        tree = ast.parse(code)\n",
    "        collect_nodes_per_line(tree, line_nodes)\n",
    "        cumulative_nodes = set()\n",
    "        states = []\n",
    "        for i in range(N):\n",
    "            cumulative_nodes.update(line_nodes[i])\n",
    "            indices = [vocab[node_type] for node_type in sorted(cumulative_nodes)]\n",
    "            states.append(indices)\n",
    "            # print(f\"Line {i+1}: Node types: {sorted(cumulative_nodes)}\")\n",
    "    except SyntaxError:\n",
    "        pass  # Handle the syntax error appropriately\n",
    "    return states\n",
    "\n",
    "def process_dataframe(df):\n",
    "    \"\"\"Process the dataframe to convert code snippets into lists of states.\"\"\"\n",
    "    code_list = df['code'].tolist()\n",
    "    vocab = build_vocab(code_list)\n",
    "    df['states'] = df['code'].apply(lambda code: code_to_states(code, vocab))\n",
    "    return df, vocab"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "new_data, vocab = process_dataframe(data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[9, 31], [9, 17, 31, 46], [2, 5, 8, 9, 17, 31, 46], [2, 5, 8, 9, 17, 31, 46], [2, 5, 8, 9, 10, 11, 17, 20, 31, 46], [2, 5, 8, 9, 10, 11, 17, 20, 31, 36, 40, 46], [2, 5, 8, 9, 10, 11, 17, 20, 31, 36, 40, 46], [2, 5, 8, 9, 10, 11, 17, 20, 31, 36, 40, 46]]\n"
     ]
    }
   ],
   "source": [
    "print(new_data[\"states\"][0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[9, 31], [9, 17, 31, 46], [2, 5, 8, 9, 17, 31, 46], [2, 5, 8, 9, 17, 31, 46], [2, 5, 8, 9, 10, 11, 17, 20, 31, 46], [2, 5, 8, 9, 10, 11, 17, 20, 31, 36, 40, 46], [2, 5, 8, 9, 10, 11, 17, 20, 31, 36, 40, 46], [2, 5, 8, 9, 10, 11, 17, 20, 31, 36, 40, 46]]\n"
     ]
    }
   ],
   "source": [
    "code_temp = new_data[\"code\"][0]\n",
    "state_test = code_to_states(code_temp, vocab)\n",
    "print(state_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "class Solution(object):\n",
      "    def findMedianSortedArrays(self, nums1, nums2):\n",
      "        merged = sorted(nums1+nums2)\n",
      "        n = len(merged)\n",
      "        if n%2==1:\n",
      "            return merged[n//2]\n",
      "        else:\n",
      "            return (merged[n//2]+merged[n//2-1])/2.0\n",
      "{'Add': 0, 'And': 1, 'Assign': 2, 'Attribute': 3, 'AugAssign': 4, 'BinOp': 5, 'BitAnd': 6, 'BoolOp': 7, 'Call': 8, 'ClassDef': 9, 'Compare': 10, 'Constant': 11, 'Div': 12, 'Eq': 13, 'Expr': 14, 'FloorDiv': 15, 'For': 16, 'FunctionDef': 17, 'Gt': 18, 'GtE': 19, 'If': 20, 'IfExp': 21, 'Import': 22, 'Is': 23, 'List': 24, 'Load': 25, 'Lt': 26, 'LtE': 27, 'Mod': 28, 'Module': 29, 'Mult': 30, 'Name': 31, 'Not': 32, 'NotEq': 33, 'Or': 34, 'Raise': 35, 'Return': 36, 'Slice': 37, 'Store': 38, 'Sub': 39, 'Subscript': 40, 'Tuple': 41, 'USub': 42, 'UnaryOp': 43, 'While': 44, 'alias': 45, 'arg': 46, 'arguments': 47}\n"
     ]
    }
   ],
   "source": [
    "print(code_temp)\n",
    "print(vocab)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### state define test 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import ast\n",
    "import pandas as pd\n",
    "\n",
    "def get_identity(node):\n",
    "    \n",
    "    identity_info = None\n",
    "    is_state_node = True\n",
    "\n",
    "    if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):\n",
    "        identity_info = node.name\n",
    "    elif isinstance(node, ast.ClassDef):\n",
    "        identity_info = node.name\n",
    "    elif isinstance(node, ast.Name):\n",
    "        identity_info = node.id\n",
    "        is_state_node = False\n",
    "    elif isinstance(node, ast.arg):\n",
    "        identity_info = node.arg\n",
    "    elif isinstance(node, ast.Attribute):\n",
    "        identity_info = node.attr\n",
    "    elif isinstance(node, ast.Call):\n",
    "        if isinstance(node.func, ast.Name):\n",
    "            identity_info = node.func.id\n",
    "        elif isinstance(node.func, ast.Attribute):\n",
    "            identity_info = node.func.attr\n",
    "        is_state_node = False\n",
    "    elif isinstance(node, ast.Assign):\n",
    "        targets = []\n",
    "        for t in node.targets:\n",
    "            if isinstance(t, ast.Name):\n",
    "                targets.append(t.id)\n",
    "            elif isinstance(t, ast.Tuple):\n",
    "                targets.extend([elt.id for elt in t.elts if isinstance(elt, ast.Name)])\n",
    "        identity_info = ','.join(targets)\n",
    "    elif isinstance(node, ast.Import):\n",
    "        names = [alias.name for alias in node.names]\n",
    "        identity_info = ','.join(names)\n",
    "    elif isinstance(node, ast.ImportFrom):\n",
    "        module = node.module or ''\n",
    "        names = [alias.name for alias in node.names]\n",
    "        identity_info = f\"{module}::{','.join(names)}\"\n",
    "    else:\n",
    "        identity_info = None\n",
    "        is_state_node = False\n",
    "\n",
    "    return identity_info, is_state_node\n",
    "\n",
    "def get_node_reprs(node, node_reprs):\n",
    "    node_type = type(node).__name__\n",
    "    identity, is_state_node = get_identity(node)\n",
    "    if identity:\n",
    "        node_repr = f\"{node_type}({identity})\"\n",
    "    else:\n",
    "        node_repr = node_type\n",
    "    if is_state_node:\n",
    "        node_repr = \"<state>\" + node_repr\n",
    "    node_reprs.add(node_repr)\n",
    "    for child in ast.iter_child_nodes(node):\n",
    "        get_node_reprs(child, node_reprs)\n",
    "\n",
    "def build_vocab(code_list):\n",
    "    node_reprs = set()\n",
    "    for code in code_list:\n",
    "        try:\n",
    "            tree = ast.parse(code)\n",
    "            get_node_reprs(tree, node_reprs)\n",
    "        except SyntaxError:\n",
    "            continue  # Skip code snippets with syntax errors\n",
    "    vocab = {node_repr: idx for idx, node_repr in enumerate(sorted(node_reprs))}\n",
    "    return vocab\n",
    "\n",
    "def collect_nodes_per_line(node, line_nodes):\n",
    "    lineno = getattr(node, 'lineno', None)\n",
    "    if lineno is not None:\n",
    "        lineno -= 1  # Adjust for zero-based index\n",
    "        if 0 <= lineno < len(line_nodes):\n",
    "            node_type = type(node).__name__\n",
    "            identity, is_state_node = get_identity(node)\n",
    "            if identity:\n",
    "                node_repr = f\"{node_type}({identity})\"\n",
    "            else:\n",
    "                node_repr = node_type\n",
    "            if is_state_node:\n",
    "                node_repr = \"<state>\" + node_repr\n",
    "            line_nodes[lineno].add(node_repr)\n",
    "    for child in ast.iter_child_nodes(node):\n",
    "        collect_nodes_per_line(child, line_nodes)\n",
    "\n",
    "def code_to_states(code, vocab):\n",
    "    lines = code.split('\\n')\n",
    "    N = len(lines)\n",
    "    line_nodes = [set() for _ in range(N)]\n",
    "    try:\n",
    "        tree = ast.parse(code)\n",
    "        collect_nodes_per_line(tree, line_nodes)\n",
    "        cumulative_nodes = set()\n",
    "        states = []\n",
    "        actions = []\n",
    "        for i in range(N):\n",
    "            # State: cumulative node representations up to current line\n",
    "            cumulative_nodes.update(line_nodes[i])\n",
    "            state_indices = [vocab[node_repr] for node_repr in sorted(cumulative_nodes) if \"<state>\" in node_repr]\n",
    "            states.append(state_indices)\n",
    "            # Action: node representations that start at current line\n",
    "            action_nodes = line_nodes[i]\n",
    "            action_indices = [vocab[node_repr] for node_repr in sorted(line_nodes[i])]\n",
    "            actions.append(action_indices)\n",
    "        return states, actions\n",
    "    except SyntaxError:\n",
    "        return [], []\n",
    "\n",
    "def process_dataframe(df):\n",
    "    code_list = df['code'].tolist()\n",
    "    vocab = build_vocab(code_list)\n",
    "    state_vocab_indices = [idx for node_repr, idx in vocab.items() if \"<state>\" in node_repr]\n",
    "    results = df['code'].apply(lambda code: code_to_states(code, vocab))\n",
    "    df['states'] = [res[0] for res in results]\n",
    "    df['actions'] = [res[1] for res in results]\n",
    "    df['one_hot_state'] = df['states'].apply(lambda state: one_hot_from_indices(state, vocab, state_vocab_indices, True))\n",
    "    df['one_hot_action'] = df['actions'].apply(lambda action: one_hot_from_indices(action, vocab))\n",
    "    return df, vocab\n",
    "\n",
    "def one_hot_from_indices(indices_list, vocab, state_vocab_indices = [], is_state = False):\n",
    "    encode_list = []\n",
    "    N = len(vocab)\n",
    "    for indices in indices_list:\n",
    "        one_hot_seq = np.zeros(N)\n",
    "        for index in indices:\n",
    "            one_hot_seq[index] = 1\n",
    "        if is_state:\n",
    "            one_hot_seq = one_hot_seq[state_vocab_indices]\n",
    "        encode_list.append(one_hot_seq)\n",
    "    return encode_list"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import ast\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "\n",
    "def get_identity(node):\n",
    "    identity_info = None\n",
    "    is_state_node = True\n",
    "\n",
    "    if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):\n",
    "        identity_info = node.name\n",
    "    elif isinstance(node, ast.ClassDef):\n",
    "        identity_info = node.name\n",
    "    elif isinstance(node, ast.Name):\n",
    "        identity_info = node.id\n",
    "        is_state_node = False\n",
    "    elif isinstance(node, ast.arg):\n",
    "        identity_info = node.arg\n",
    "    elif isinstance(node, ast.Attribute):\n",
    "        identity_info = node.attr\n",
    "    elif isinstance(node, ast.Call):\n",
    "        if isinstance(node.func, ast.Name):\n",
    "            identity_info = node.func.id\n",
    "        elif isinstance(node.func, ast.Attribute):\n",
    "            identity_info = node.func.attr\n",
    "        is_state_node = False\n",
    "    elif isinstance(node, ast.Assign):\n",
    "        targets = []\n",
    "        for t in node.targets:\n",
    "            if isinstance(t, ast.Name):\n",
    "                targets.append(t.id)\n",
    "            elif isinstance(t, ast.Tuple):\n",
    "                targets.extend([elt.id for elt in t.elts if isinstance(elt, ast.Name)])\n",
    "        identity_info = ','.join(targets)\n",
    "    elif isinstance(node, ast.Import):\n",
    "        names = [alias.name for alias in node.names]\n",
    "        identity_info = ','.join(names)\n",
    "    elif isinstance(node, ast.ImportFrom):\n",
    "        module = node.module or ''\n",
    "        names = [alias.name for alias in node.names]\n",
    "        identity_info = f\"{module}::{','.join(names)}\"\n",
    "    else:\n",
    "        identity_info = None\n",
    "        is_state_node = False\n",
    "\n",
    "    return identity_info, is_state_node\n",
    "\n",
    "def get_node_reprs(node, node_reprs):\n",
    "    node_type = type(node).__name__\n",
    "    identity, is_state_node = get_identity(node)\n",
    "    if identity:\n",
    "        node_repr = f\"{node_type}({identity})\"\n",
    "    else:\n",
    "        node_repr = node_type\n",
    "    if is_state_node:\n",
    "        node_repr = \"<state>\" + node_repr\n",
    "    node_reprs.add(node_repr)\n",
    "    for child in ast.iter_child_nodes(node):\n",
    "        get_node_reprs(child, node_reprs)\n",
    "\n",
    "def build_vocab(code_list):\n",
    "    node_reprs = set()\n",
    "    for code in code_list:\n",
    "        try:\n",
    "            tree = ast.parse(code)\n",
    "            get_node_reprs(tree, node_reprs)\n",
    "        except SyntaxError:\n",
    "            continue  # Skip code snippets with syntax errors\n",
    "    # Initialize vocabulary with <UNK> token\n",
    "    vocab = {\"<state><UNK>\": 0}\n",
    "    # Start indexing from 1\n",
    "    for idx, node_repr in enumerate(sorted(node_reprs), start=1):\n",
    "        vocab[node_repr] = idx\n",
    "    return vocab\n",
    "\n",
    "def collect_nodes_per_line(node, line_nodes):\n",
    "    lineno = getattr(node, 'lineno', None)\n",
    "    if lineno is not None:\n",
    "        lineno -= 1  # Adjust for zero-based index\n",
    "        if 0 <= lineno < len(line_nodes):\n",
    "            node_type = type(node).__name__\n",
    "            identity, is_state_node = get_identity(node)\n",
    "            if identity:\n",
    "                node_repr = f\"{node_type}({identity})\"\n",
    "            else:\n",
    "                node_repr = node_type\n",
    "            if is_state_node:\n",
    "                node_repr = \"<state>\" + node_repr\n",
    "            line_nodes[lineno].add(node_repr)\n",
    "    for child in ast.iter_child_nodes(node):\n",
    "        collect_nodes_per_line(child, line_nodes)\n",
    "\n",
    "def code_to_states(code, vocab, unk_token=\"<state><UNK>\"):\n",
    "    lines = code.split('\\n')\n",
    "    N = len(lines)\n",
    "    line_nodes = [set() for _ in range(N)]\n",
    "    try:\n",
    "        tree = ast.parse(code)\n",
    "        collect_nodes_per_line(tree, line_nodes)\n",
    "        cumulative_nodes = set()\n",
    "        states = []\n",
    "        actions = []\n",
    "        for i in range(N):\n",
    "            # State: cumulative node representations up to current line\n",
    "            cumulative_nodes.update(line_nodes[i])\n",
    "            # Handle state nodes, map to vocab or <UNK>\n",
    "            state_indices = [\n",
    "                vocab.get(node_repr, vocab[unk_token])\n",
    "                for node_repr in sorted(cumulative_nodes)\n",
    "                if \"<state>\" in node_repr\n",
    "            ]\n",
    "            states.append(state_indices)\n",
    "            # Action: node representations that start at current line\n",
    "            action_nodes = line_nodes[i]\n",
    "            action_indices = [\n",
    "                vocab.get(node_repr, vocab[unk_token])\n",
    "                for node_repr in sorted(action_nodes)\n",
    "            ]\n",
    "            actions.append(action_indices)\n",
    "        return states, actions\n",
    "    except SyntaxError:\n",
    "        return [], []\n",
    "\n",
    "def one_hot_from_indices(indices_list, vocab_size, indices_map):\n",
    "    encode_list = []\n",
    "    for indices in indices_list:\n",
    "        one_hot_seq = np.zeros(vocab_size, dtype=np.float32)\n",
    "        for index in indices:\n",
    "            one_hot_seq[index] = 1.0\n",
    "        encode_list.append(one_hot_seq)\n",
    "    return encode_list\n",
    "\n",
    "def process_dataframe(df):\n",
    "    code_list = df['code'].tolist()\n",
    "    vocab = build_vocab(code_list)\n",
    "    results = df['code'].apply(lambda code: code_to_states(code, vocab))\n",
    "    df['states'] = [res[0] for res in results]\n",
    "    df['actions'] = [res[1] for res in results]\n",
    "    \n",
    "    # Separate state and action vocab indices\n",
    "    state_vocab_indices = {k: v for k, v in vocab.items() if \"<state>\" in k}\n",
    "    action_vocab_indices = {k: v for k, v in vocab.items() if \"<state>\" not in k}\n",
    "    \n",
    "    vocab_size = len(vocab)\n",
    "    state_vocab_size = len(state_vocab_indices)\n",
    "    \n",
    "    # Create mapping for states\n",
    "    state_map = {k: v for k, v in vocab.items() if \"<state>\" in k}\n",
    "    \n",
    "    # One-hot encode states and actions\n",
    "    df['one_hot_state'] = df['states'].apply(\n",
    "        lambda state: one_hot_from_indices(\n",
    "            state, vocab_size, state_map\n",
    "        )\n",
    "    )\n",
    "    df['one_hot_action'] = df['actions'].apply(\n",
    "        lambda action: one_hot_from_indices(\n",
    "            action, vocab_size, vocab\n",
    "        )\n",
    "    )\n",
    "    return df, vocab\n",
    "\n",
    "def process_new_code(code, vocab, unk_token=\"<state><UNK>\"):\n",
    "    states, actions = code_to_states(code, vocab, unk_token=unk_token)\n",
    "    vocab_size = len(vocab)\n",
    "    \n",
    "    states_one_hot = one_hot_from_indices(states, vocab_size, vocab)\n",
    "    actions_one_hot = one_hot_from_indices(actions, vocab_size, vocab)\n",
    "    \n",
    "    return states_one_hot, actions_one_hot"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 61,
   "metadata": {},
   "outputs": [],
   "source": [
    "new_data, vocab = process_dataframe(data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[136], [136, 137, 145, 146, 149], [52, 136, 137, 145, 146, 149], [52, 70, 136, 137, 145, 146, 149], [52, 70, 136, 137, 145, 146, 149], [52, 70, 136, 137, 145, 146, 149], [52, 70, 136, 137, 145, 146, 149], [52, 70, 136, 137, 145, 146, 149]]\n",
      "[[136, 298], [137, 145, 146, 149], [52, 153, 174, 249, 283, 290, 323], [70, 165, 226, 249, 268], [153, 175, 176, 184, 268], [153, 176, 249, 268, 339, 343], [], [153, 176, 249, 268, 339, 343]]\n"
     ]
    }
   ],
   "source": [
    "print(new_data[\"states\"][0])\n",
    "print(new_data[\"actions\"][0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 62,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "class Solution(object):\n",
      "    def findMedianSortedArrays(self, nums1, nums2):\n",
      "        merged = sorted(nums1+nums2)\n",
      "        n = len(merged)\n",
      "        if n%2==1:\n",
      "            return merged[n//2]\n",
      "        else:\n",
      "            return (merged[n//2]+merged[n//2-1])/2.0\n",
      "Vocabulary:\n",
      "0: <state>Assign\n",
      "1: <state>Assign(L)\n",
      "2: <state>Assign(M)\n",
      "3: <state>Assign(N)\n",
      "4: <state>Assign(a)\n",
      "5: <state>Assign(ans)\n",
      "6: <state>Assign(array)\n",
      "7: <state>Assign(b)\n",
      "8: <state>Assign(com_arr)\n",
      "9: <state>Assign(f_pointer)\n",
      "10: <state>Assign(half)\n",
      "11: <state>Assign(half_len)\n",
      "12: <state>Assign(high)\n",
      "13: <state>Assign(i)\n",
      "14: <state>Assign(i,j,k)\n",
      "15: <state>Assign(imax)\n",
      "16: <state>Assign(imin)\n",
      "17: <state>Assign(imin,imax,half_len)\n",
      "18: <state>Assign(ind)\n",
      "19: <state>Assign(index)\n",
      "20: <state>Assign(int_upper_median)\n",
      "21: <state>Assign(j)\n",
      "22: <state>Assign(k)\n",
      "23: <state>Assign(l)\n",
      "24: <state>Assign(l,r)\n",
      "25: <state>Assign(l1)\n",
      "26: <state>Assign(l1,l2,r1,r2)\n",
      "27: <state>Assign(l2)\n",
      "28: <state>Assign(left)\n",
      "29: <state>Assign(left,right)\n",
      "30: <state>Assign(len1)\n",
      "31: <state>Assign(len2)\n",
      "32: <state>Assign(len_list)\n",
      "33: <state>Assign(length)\n",
      "34: <state>Assign(low)\n",
      "35: <state>Assign(low,high)\n",
      "36: <state>Assign(ls)\n",
      "37: <state>Assign(m)\n",
      "38: <state>Assign(m,n)\n",
      "39: <state>Assign(m1)\n",
      "40: <state>Assign(m2)\n",
      "41: <state>Assign(m_length)\n",
      "42: <state>Assign(maxLeftX)\n",
      "43: <state>Assign(maxLeftY)\n",
      "44: <state>Assign(max_left1)\n",
      "45: <state>Assign(max_left2)\n",
      "46: <state>Assign(max_of_left)\n",
      "47: <state>Assign(max_size)\n",
      "48: <state>Assign(med)\n",
      "49: <state>Assign(med1)\n",
      "50: <state>Assign(med2)\n",
      "51: <state>Assign(median)\n",
      "52: <state>Assign(merged)\n",
      "53: <state>Assign(merged_arr)\n",
      "54: <state>Assign(merged_array)\n",
      "55: <state>Assign(mid)\n",
      "56: <state>Assign(mid1)\n",
      "57: <state>Assign(mid2)\n",
      "58: <state>Assign(midIndex)\n",
      "59: <state>Assign(midNums1)\n",
      "60: <state>Assign(midNums2)\n",
      "61: <state>Assign(mid_val)\n",
      "62: <state>Assign(middle)\n",
      "63: <state>Assign(middle1)\n",
      "64: <state>Assign(middle2)\n",
      "65: <state>Assign(minRightX)\n",
      "66: <state>Assign(minRightY)\n",
      "67: <state>Assign(min_of_right)\n",
      "68: <state>Assign(min_right1)\n",
      "69: <state>Assign(min_right2)\n",
      "70: <state>Assign(n)\n",
      "71: <state>Assign(n1,n2)\n",
      "72: <state>Assign(new)\n",
      "73: <state>Assign(new2)\n",
      "74: <state>Assign(newArray)\n",
      "75: <state>Assign(new_array,N)\n",
      "76: <state>Assign(new_len)\n",
      "77: <state>Assign(new_list)\n",
      "78: <state>Assign(new_nums)\n",
      "79: <state>Assign(new_splt)\n",
      "80: <state>Assign(newlist)\n",
      "81: <state>Assign(num3)\n",
      "82: <state>Assign(nums)\n",
      "83: <state>Assign(nums1)\n",
      "84: <state>Assign(nums1,nums2)\n",
      "85: <state>Assign(nums1,nums2,m,n)\n",
      "86: <state>Assign(nums1Left)\n",
      "87: <state>Assign(nums1Right)\n",
      "88: <state>Assign(nums1_left)\n",
      "89: <state>Assign(nums1_left_max)\n",
      "90: <state>Assign(nums1_right)\n",
      "91: <state>Assign(nums1_right_min)\n",
      "92: <state>Assign(nums2)\n",
      "93: <state>Assign(nums2Left)\n",
      "94: <state>Assign(nums2Right)\n",
      "95: <state>Assign(nums2_left)\n",
      "96: <state>Assign(nums2_left_max)\n",
      "97: <state>Assign(nums2_right)\n",
      "98: <state>Assign(nums2_right_min)\n",
      "99: <state>Assign(nums3)\n",
      "100: <state>Assign(odd_median)\n",
      "101: <state>Assign(p1,p2,np)\n",
      "102: <state>Assign(partition1)\n",
      "103: <state>Assign(partition2)\n",
      "104: <state>Assign(partitionX)\n",
      "105: <state>Assign(partitionY)\n",
      "106: <state>Assign(r)\n",
      "107: <state>Assign(r1)\n",
      "108: <state>Assign(r2)\n",
      "109: <state>Assign(res)\n",
      "110: <state>Assign(result)\n",
      "111: <state>Assign(right)\n",
      "112: <state>Assign(s)\n",
      "113: <state>Assign(s_pointer)\n",
      "114: <state>Assign(self)\n",
      "115: <state>Assign(size)\n",
      "116: <state>Assign(solution)\n",
      "117: <state>Assign(sort)\n",
      "118: <state>Assign(sorted_l)\n",
      "119: <state>Assign(sorted_merged_arr)\n",
      "120: <state>Assign(splt)\n",
      "121: <state>Assign(targetIndex)\n",
      "122: <state>Assign(total)\n",
      "123: <state>Assign(total_elements)\n",
      "124: <state>Assign(ttk)\n",
      "125: <state>Assign(upper_median)\n",
      "126: <state>Assign(val)\n",
      "127: <state>Assign(x)\n",
      "128: <state>Assign(x,y)\n",
      "129: <state>Attribute(append)\n",
      "130: <state>Attribute(ceil)\n",
      "131: <state>Attribute(extend)\n",
      "132: <state>Attribute(findMedianSortedArrays)\n",
      "133: <state>Attribute(floor)\n",
      "134: <state>Attribute(merge)\n",
      "135: <state>Attribute(sort)\n",
      "136: <state>ClassDef(Solution)\n",
      "137: <state>FunctionDef(findMedianSortedArrays)\n",
      "138: <state>FunctionDef(median)\n",
      "139: <state>FunctionDef(merge)\n",
      "140: <state>Import(math)\n",
      "141: <state>arg(a)\n",
      "142: <state>arg(array1)\n",
      "143: <state>arg(array2)\n",
      "144: <state>arg(b)\n",
      "145: <state>arg(nums1)\n",
      "146: <state>arg(nums2)\n",
      "147: <state>arg(p)\n",
      "148: <state>arg(q)\n",
      "149: <state>arg(self)\n",
      "150: Add\n",
      "151: And\n",
      "152: AugAssign\n",
      "153: BinOp\n",
      "154: BitAnd\n",
      "155: BoolOp\n",
      "156: Call(Solution)\n",
      "157: Call(ValueError)\n",
      "158: Call(append)\n",
      "159: Call(ceil)\n",
      "160: Call(extend)\n",
      "161: Call(findMedianSortedArrays)\n",
      "162: Call(float)\n",
      "163: Call(floor)\n",
      "164: Call(int)\n",
      "165: Call(len)\n",
      "166: Call(max)\n",
      "167: Call(median)\n",
      "168: Call(merge)\n",
      "169: Call(min)\n",
      "170: Call(print)\n",
      "171: Call(range)\n",
      "172: Call(round)\n",
      "173: Call(sort)\n",
      "174: Call(sorted)\n",
      "175: Compare\n",
      "176: Constant\n",
      "177: Div\n",
      "178: Eq\n",
      "179: Expr\n",
      "180: FloorDiv\n",
      "181: For\n",
      "182: Gt\n",
      "183: GtE\n",
      "184: If\n",
      "185: IfExp\n",
      "186: Is\n",
      "187: List\n",
      "188: Load\n",
      "189: Lt\n",
      "190: LtE\n",
      "191: Mod\n",
      "192: Module\n",
      "193: Mult\n",
      "194: Name(L)\n",
      "195: Name(M)\n",
      "196: Name(N)\n",
      "197: Name(Solution)\n",
      "198: Name(ValueError)\n",
      "199: Name(__name__)\n",
      "200: Name(a)\n",
      "201: Name(ans)\n",
      "202: Name(array)\n",
      "203: Name(array1)\n",
      "204: Name(array2)\n",
      "205: Name(b)\n",
      "206: Name(com_arr)\n",
      "207: Name(count)\n",
      "208: Name(f_pointer)\n",
      "209: Name(float)\n",
      "210: Name(half)\n",
      "211: Name(half_len)\n",
      "212: Name(high)\n",
      "213: Name(i)\n",
      "214: Name(imax)\n",
      "215: Name(imin)\n",
      "216: Name(ind)\n",
      "217: Name(index)\n",
      "218: Name(int)\n",
      "219: Name(int_upper_median)\n",
      "220: Name(j)\n",
      "221: Name(k)\n",
      "222: Name(l)\n",
      "223: Name(l1)\n",
      "224: Name(l2)\n",
      "225: Name(left)\n",
      "226: Name(len)\n",
      "227: Name(len1)\n",
      "228: Name(len2)\n",
      "229: Name(len_list)\n",
      "230: Name(length)\n",
      "231: Name(low)\n",
      "232: Name(ls)\n",
      "233: Name(m)\n",
      "234: Name(m1)\n",
      "235: Name(m2)\n",
      "236: Name(m_length)\n",
      "237: Name(math)\n",
      "238: Name(max)\n",
      "239: Name(maxLeftX)\n",
      "240: Name(maxLeftY)\n",
      "241: Name(max_left1)\n",
      "242: Name(max_left2)\n",
      "243: Name(max_of_left)\n",
      "244: Name(max_size)\n",
      "245: Name(med)\n",
      "246: Name(med1)\n",
      "247: Name(med2)\n",
      "248: Name(median)\n",
      "249: Name(merged)\n",
      "250: Name(merged_arr)\n",
      "251: Name(merged_array)\n",
      "252: Name(mid)\n",
      "253: Name(mid1)\n",
      "254: Name(mid2)\n",
      "255: Name(midIndex)\n",
      "256: Name(midNums1)\n",
      "257: Name(midNums2)\n",
      "258: Name(mid_val)\n",
      "259: Name(middle)\n",
      "260: Name(middle1)\n",
      "261: Name(middle2)\n",
      "262: Name(min)\n",
      "263: Name(minRightX)\n",
      "264: Name(minRightY)\n",
      "265: Name(min_of_right)\n",
      "266: Name(min_right1)\n",
      "267: Name(min_right2)\n",
      "268: Name(n)\n",
      "269: Name(n1)\n",
      "270: Name(n2)\n",
      "271: Name(new)\n",
      "272: Name(new2)\n",
      "273: Name(newArray)\n",
      "274: Name(new_array)\n",
      "275: Name(new_len)\n",
      "276: Name(new_list)\n",
      "277: Name(new_nums)\n",
      "278: Name(new_splt)\n",
      "279: Name(newlist)\n",
      "280: Name(np)\n",
      "281: Name(num3)\n",
      "282: Name(nums)\n",
      "283: Name(nums1)\n",
      "284: Name(nums1Left)\n",
      "285: Name(nums1Right)\n",
      "286: Name(nums1_left)\n",
      "287: Name(nums1_left_max)\n",
      "288: Name(nums1_right)\n",
      "289: Name(nums1_right_min)\n",
      "290: Name(nums2)\n",
      "291: Name(nums2Left)\n",
      "292: Name(nums2Right)\n",
      "293: Name(nums2_left)\n",
      "294: Name(nums2_left_max)\n",
      "295: Name(nums2_right)\n",
      "296: Name(nums2_right_min)\n",
      "297: Name(nums3)\n",
      "298: Name(object)\n",
      "299: Name(odd_median)\n",
      "300: Name(p)\n",
      "301: Name(p1)\n",
      "302: Name(p2)\n",
      "303: Name(partition1)\n",
      "304: Name(partition2)\n",
      "305: Name(partitionX)\n",
      "306: Name(partitionY)\n",
      "307: Name(print)\n",
      "308: Name(q)\n",
      "309: Name(r)\n",
      "310: Name(r1)\n",
      "311: Name(r2)\n",
      "312: Name(range)\n",
      "313: Name(res)\n",
      "314: Name(result)\n",
      "315: Name(right)\n",
      "316: Name(round)\n",
      "317: Name(s)\n",
      "318: Name(s_pointer)\n",
      "319: Name(self)\n",
      "320: Name(size)\n",
      "321: Name(solution)\n",
      "322: Name(sort)\n",
      "323: Name(sorted)\n",
      "324: Name(sorted_l)\n",
      "325: Name(sorted_merged_arr)\n",
      "326: Name(splt)\n",
      "327: Name(targetIndex)\n",
      "328: Name(total)\n",
      "329: Name(total_elements)\n",
      "330: Name(ttk)\n",
      "331: Name(upper_median)\n",
      "332: Name(val)\n",
      "333: Name(x)\n",
      "334: Name(y)\n",
      "335: Not\n",
      "336: NotEq\n",
      "337: Or\n",
      "338: Raise\n",
      "339: Return\n",
      "340: Slice\n",
      "341: Store\n",
      "342: Sub\n",
      "343: Subscript\n",
      "344: Tuple\n",
      "345: USub\n",
      "346: UnaryOp\n",
      "347: While\n",
      "348: alias\n",
      "349: arguments\n",
      "\n",
      "Line 1:\n",
      "State node representations: ['<state>ClassDef(Solution)']\n",
      "Action node representations: ['<state>ClassDef(Solution)', 'Name(object)']\n",
      "\n",
      "Line 2:\n",
      "State node representations: ['<state>ClassDef(Solution)', '<state>FunctionDef(findMedianSortedArrays)', '<state>arg(nums1)', '<state>arg(nums2)', '<state>arg(self)']\n",
      "Action node representations: ['<state>FunctionDef(findMedianSortedArrays)', '<state>arg(nums1)', '<state>arg(nums2)', '<state>arg(self)']\n",
      "\n",
      "Line 3:\n",
      "State node representations: ['<state>Assign(merged)', '<state>ClassDef(Solution)', '<state>FunctionDef(findMedianSortedArrays)', '<state>arg(nums1)', '<state>arg(nums2)', '<state>arg(self)']\n",
      "Action node representations: ['<state>Assign(merged)', 'BinOp', 'Call(sorted)', 'Name(merged)', 'Name(nums1)', 'Name(nums2)', 'Name(sorted)']\n",
      "\n",
      "Line 4:\n",
      "State node representations: ['<state>Assign(merged)', '<state>Assign(n)', '<state>ClassDef(Solution)', '<state>FunctionDef(findMedianSortedArrays)', '<state>arg(nums1)', '<state>arg(nums2)', '<state>arg(self)']\n",
      "Action node representations: ['<state>Assign(n)', 'Call(len)', 'Name(len)', 'Name(merged)', 'Name(n)']\n",
      "\n",
      "Line 5:\n",
      "State node representations: ['<state>Assign(merged)', '<state>Assign(n)', '<state>ClassDef(Solution)', '<state>FunctionDef(findMedianSortedArrays)', '<state>arg(nums1)', '<state>arg(nums2)', '<state>arg(self)']\n",
      "Action node representations: ['BinOp', 'Compare', 'Constant', 'If', 'Name(n)']\n",
      "\n",
      "Line 6:\n",
      "State node representations: ['<state>Assign(merged)', '<state>Assign(n)', '<state>ClassDef(Solution)', '<state>FunctionDef(findMedianSortedArrays)', '<state>arg(nums1)', '<state>arg(nums2)', '<state>arg(self)']\n",
      "Action node representations: ['BinOp', 'Constant', 'Name(merged)', 'Name(n)', 'Return', 'Subscript']\n",
      "\n",
      "Line 7:\n",
      "State node representations: ['<state>Assign(merged)', '<state>Assign(n)', '<state>ClassDef(Solution)', '<state>FunctionDef(findMedianSortedArrays)', '<state>arg(nums1)', '<state>arg(nums2)', '<state>arg(self)']\n",
      "Action node representations: []\n",
      "\n",
      "Line 8:\n",
      "State node representations: ['<state>Assign(merged)', '<state>Assign(n)', '<state>ClassDef(Solution)', '<state>FunctionDef(findMedianSortedArrays)', '<state>arg(nums1)', '<state>arg(nums2)', '<state>arg(self)']\n",
      "Action node representations: ['BinOp', 'Constant', 'Name(merged)', 'Name(n)', 'Return', 'Subscript']\n"
     ]
    }
   ],
   "source": [
    "code_temp = new_data[\"code\"][0]\n",
    "\n",
    "print(code_temp)\n",
    "\n",
    "states, actions = code_to_states(code_temp, vocab)\n",
    "\n",
    "print(\"Vocabulary:\")\n",
    "for node_repr, idx in vocab.items():\n",
    "    print(f\"{idx}: {node_repr}\")\n",
    "\n",
    "for i, (state_indices, action_indices) in enumerate(zip(states, actions)):\n",
    "    state_nodes = [list(vocab.keys())[list(vocab.values()).index(idx)] for idx in state_indices]\n",
    "    action_nodes = [list(vocab.keys())[list(vocab.values()).index(idx)] for idx in action_indices]\n",
    "    print(f\"\\nLine {i+1}:\")\n",
    "    print(f\"State node representations: {state_nodes}\")\n",
    "    print(f\"Action node representations: {action_nodes}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 64,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "150\n",
      "[136]\n",
      "[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
      " 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
      " 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
      " 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
      " 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
      " 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0.\n",
      " 0. 0. 0. 0. 0. 0.]\n",
      "350\n",
      "[136, 298]\n",
      "[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
      " 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
      " 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
      " 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
      " 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
      " 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0.\n",
      " 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
      " 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
      " 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
      " 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
      " 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
      " 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
      " 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
      " 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
      " 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n"
     ]
    }
   ],
   "source": [
    "# one-hot check\n",
    "state_temp = new_data[\"states\"][0][0]\n",
    "action_temp = new_data[\"actions\"][0][0]\n",
    "one_hot_state_temp = new_data[\"one_hot_state\"][0][0]\n",
    "one_hot_action_temp = new_data[\"one_hot_action\"][0][0]\n",
    "\n",
    "print(len(one_hot_state_temp))\n",
    "print(state_temp)\n",
    "print(one_hot_state_temp)\n",
    "print(len(one_hot_action_temp))\n",
    "print(action_temp)\n",
    "print(one_hot_action_temp)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### state define test 3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "import ast\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "\n",
    "def get_identity(node):\n",
    "    identity_info = None\n",
    "    is_state_node = True\n",
    "\n",
    "    if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):\n",
    "        identity_info = node.name\n",
    "    elif isinstance(node, ast.ClassDef):\n",
    "        identity_info = node.name\n",
    "    elif isinstance(node, ast.Name):\n",
    "        identity_info = node.id\n",
    "        is_state_node = False\n",
    "    elif isinstance(node, ast.arg):\n",
    "        identity_info = node.arg\n",
    "    elif isinstance(node, ast.Attribute):\n",
    "        identity_info = node.attr\n",
    "    elif isinstance(node, ast.Call):\n",
    "        if isinstance(node.func, ast.Name):\n",
    "            identity_info = node.func.id\n",
    "        elif isinstance(node.func, ast.Attribute):\n",
    "            identity_info = node.func.attr\n",
    "        is_state_node = False\n",
    "    elif isinstance(node, ast.Assign):\n",
    "        targets = []\n",
    "        for t in node.targets:\n",
    "            if isinstance(t, ast.Name):\n",
    "                targets.append(t.id)\n",
    "            elif isinstance(t, ast.Tuple):\n",
    "                targets.extend([elt.id for elt in t.elts if isinstance(elt, ast.Name)])\n",
    "        identity_info = ','.join(targets)\n",
    "    elif isinstance(node, ast.Import):\n",
    "        names = [alias.name for alias in node.names]\n",
    "        identity_info = ','.join(names)\n",
    "    elif isinstance(node, ast.ImportFrom):\n",
    "        module = node.module or ''\n",
    "        names = [alias.name for alias in node.names]\n",
    "        identity_info = f\"{module}::{','.join(names)}\"\n",
    "    else:\n",
    "        identity_info = None\n",
    "        is_state_node = False\n",
    "\n",
    "    return identity_info, is_state_node\n",
    "\n",
    "def get_node_reprs(node, node_reprs):\n",
    "    node_type = type(node).__name__\n",
    "    identity, is_state_node = get_identity(node)\n",
    "    if identity:\n",
    "        node_repr = f\"{node_type}({identity})\"\n",
    "    else:\n",
    "        node_repr = node_type\n",
    "    if is_state_node:\n",
    "        node_repr = \"<state>\" + node_repr\n",
    "    node_reprs.add(node_repr)\n",
    "    for child in ast.iter_child_nodes(node):\n",
    "        get_node_reprs(child, node_reprs)\n",
    "\n",
    "def build_vocab(code_list):\n",
    "    node_reprs = set()\n",
    "    for code in code_list:\n",
    "        try:\n",
    "            tree = ast.parse(code)\n",
    "            get_node_reprs(tree, node_reprs)\n",
    "        except SyntaxError:\n",
    "            continue  # Skip code snippets with syntax errors\n",
    "    # Initialize vocabulary with <UNK> token\n",
    "    vocab = {\"<state><UNK>\": 0}\n",
    "    # Start indexing from 1\n",
    "    for idx, node_repr in enumerate(sorted(node_reprs), start=1):\n",
    "        vocab[node_repr] = idx\n",
    "    return vocab\n",
    "\n",
    "def collect_nodes_per_line(node, line_nodes):\n",
    "    lineno = getattr(node, 'lineno', None)\n",
    "    if lineno is not None:\n",
    "        lineno -= 1  # Adjust for zero-based index\n",
    "        if 0 <= lineno < len(line_nodes):\n",
    "            node_type = type(node).__name__\n",
    "            identity, is_state_node = get_identity(node)\n",
    "            if identity:\n",
    "                node_repr = f\"{node_type}({identity})\"\n",
    "            else:\n",
    "                node_repr = node_type\n",
    "            if is_state_node:\n",
    "                node_repr = \"<state>\" + node_repr\n",
    "            line_nodes[lineno].add(node_repr)\n",
    "    for child in ast.iter_child_nodes(node):\n",
    "        collect_nodes_per_line(child, line_nodes)\n",
    "\n",
    "def code_to_states(code, vocab, unk_token=\"<state><UNK>\"):\n",
    "    lines = code.split('\\n')\n",
    "    N = len(lines)\n",
    "    line_nodes = [set() for _ in range(N)]\n",
    "    try:\n",
    "        tree = ast.parse(code)\n",
    "        collect_nodes_per_line(tree, line_nodes)\n",
    "        cumulative_nodes = set()\n",
    "        states = []\n",
    "        actions = []\n",
    "        for i in range(N):\n",
    "            # State: cumulative node representations up to current line\n",
    "            cumulative_nodes.update(line_nodes[i])\n",
    "            # Handle state nodes, map to vocab or <UNK>\n",
    "            state_indices = [\n",
    "                vocab.get(node_repr, vocab[unk_token])\n",
    "                for node_repr in sorted(cumulative_nodes)\n",
    "                if \"<state>\" in node_repr\n",
    "            ]\n",
    "            states.append(state_indices)\n",
    "            # Action: node representations that start at current line\n",
    "            action_nodes = line_nodes[i]\n",
    "            action_indices = [\n",
    "                vocab.get(node_repr, vocab[unk_token])\n",
    "                for node_repr in sorted(action_nodes)\n",
    "            ]\n",
    "            actions.append(action_indices)\n",
    "        return states, actions\n",
    "    except SyntaxError:\n",
    "        return [], []\n",
    "\n",
    "def one_hot_from_indices(indices_list, vocab_size, state_vocab_indices = [], is_state = False):\n",
    "    encode_list = []\n",
    "    for indices in indices_list:\n",
    "        one_hot_seq = np.zeros(vocab_size, dtype=np.float32)\n",
    "        for index in indices:\n",
    "            one_hot_seq[index] = 1.0\n",
    "        if is_state:\n",
    "            one_hot_seq = one_hot_seq[state_vocab_indices]\n",
    "        encode_list.append(one_hot_seq)\n",
    "    return encode_list\n",
    "\n",
    "def process_dataframe(df):\n",
    "    code_list = df['code'].tolist()\n",
    "    vocab = build_vocab(code_list)\n",
    "    results = df['code'].apply(lambda code: code_to_states(code, vocab))\n",
    "    df['states'] = [res[0] for res in results]\n",
    "    df['actions'] = [res[1] for res in results]\n",
    "    \n",
    "    # Separate state and action vocab indices\n",
    "    state_vocab_indices = [v for k, v in vocab.items() if \"<state>\" in k]\n",
    "    action_vocab_indices = [v for k, v in vocab.items() if \"<action>\" not in k]\n",
    "    \n",
    "    vocab_size = len(vocab)\n",
    "    state_vocab_size = len(state_vocab_indices)\n",
    "    \n",
    "    # Create mapping for states\n",
    "    state_map = {k: v for k, v in vocab.items() if \"<state>\" in k}\n",
    "    \n",
    "    # One-hot encode states and actions\n",
    "    df['one_hot_state'] = df['states'].apply(\n",
    "        lambda state: one_hot_from_indices(\n",
    "            state, vocab_size, state_vocab_indices, True\n",
    "        )\n",
    "    )\n",
    "    df['one_hot_action'] = df['actions'].apply(\n",
    "        lambda action: one_hot_from_indices(\n",
    "            action, vocab_size\n",
    "        )\n",
    "    )\n",
    "    return df, vocab\n",
    "\n",
    "def process_new_code(code, vocab, unk_token=\"<state><UNK>\"):\n",
    "    states, actions = code_to_states(code, vocab, unk_token=unk_token)\n",
    "    vocab_size = len(vocab)\n",
    "    state_vocab_indices = [v for k, v in vocab.items() if \"<state>\" in k]\n",
    "    \n",
    "    states_one_hot = one_hot_from_indices(states, vocab_size, state_vocab_indices, True)\n",
    "    actions_one_hot = one_hot_from_indices(actions, vocab_size)\n",
    "    \n",
    "    return states_one_hot, actions_one_hot"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "new_data, vocab = process_dataframe(expanded_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "new_data, vocab = process_dataframe(data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[137], [137, 138, 146, 147, 150], [53, 137, 138, 146, 147, 150], [53, 71, 137, 138, 146, 147, 150], [53, 71, 137, 138, 146, 147, 150], [53, 71, 137, 138, 146, 147, 150], [53, 71, 137, 138, 146, 147, 150], [53, 71, 137, 138, 146, 147, 150]]\n",
      "[[137, 299], [138, 146, 147, 150], [53, 154, 175, 250, 284, 291, 324], [71, 166, 227, 250, 269], [154, 176, 177, 185, 269], [154, 177, 250, 269, 340, 344], [], [154, 177, 250, 269, 340, 344]]\n"
     ]
    }
   ],
   "source": [
    "print(new_data[\"states\"][0])\n",
    "print(new_data[\"actions\"][0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "class Solution(object):\n",
      "    def findMedianSortedArrays(self, nums1, nums2):\n",
      "        merged = sorted(nums1+nums2)\n",
      "        n = len(merged)\n",
      "        if n%2==1:\n",
      "            return merged[n//2]\n",
      "        else:\n",
      "            return (merged[n//2]+merged[n//2-1])/2.0\n",
      "Vocabulary:\n",
      "0: <state><UNK>\n",
      "1: <state>Assign\n",
      "2: <state>Assign(L)\n",
      "3: <state>Assign(M)\n",
      "4: <state>Assign(N)\n",
      "5: <state>Assign(a)\n",
      "6: <state>Assign(ans)\n",
      "7: <state>Assign(array)\n",
      "8: <state>Assign(b)\n",
      "9: <state>Assign(com_arr)\n",
      "10: <state>Assign(f_pointer)\n",
      "11: <state>Assign(half)\n",
      "12: <state>Assign(half_len)\n",
      "13: <state>Assign(high)\n",
      "14: <state>Assign(i)\n",
      "15: <state>Assign(i,j,k)\n",
      "16: <state>Assign(imax)\n",
      "17: <state>Assign(imin)\n",
      "18: <state>Assign(imin,imax,half_len)\n",
      "19: <state>Assign(ind)\n",
      "20: <state>Assign(index)\n",
      "21: <state>Assign(int_upper_median)\n",
      "22: <state>Assign(j)\n",
      "23: <state>Assign(k)\n",
      "24: <state>Assign(l)\n",
      "25: <state>Assign(l,r)\n",
      "26: <state>Assign(l1)\n",
      "27: <state>Assign(l1,l2,r1,r2)\n",
      "28: <state>Assign(l2)\n",
      "29: <state>Assign(left)\n",
      "30: <state>Assign(left,right)\n",
      "31: <state>Assign(len1)\n",
      "32: <state>Assign(len2)\n",
      "33: <state>Assign(len_list)\n",
      "34: <state>Assign(length)\n",
      "35: <state>Assign(low)\n",
      "36: <state>Assign(low,high)\n",
      "37: <state>Assign(ls)\n",
      "38: <state>Assign(m)\n",
      "39: <state>Assign(m,n)\n",
      "40: <state>Assign(m1)\n",
      "41: <state>Assign(m2)\n",
      "42: <state>Assign(m_length)\n",
      "43: <state>Assign(maxLeftX)\n",
      "44: <state>Assign(maxLeftY)\n",
      "45: <state>Assign(max_left1)\n",
      "46: <state>Assign(max_left2)\n",
      "47: <state>Assign(max_of_left)\n",
      "48: <state>Assign(max_size)\n",
      "49: <state>Assign(med)\n",
      "50: <state>Assign(med1)\n",
      "51: <state>Assign(med2)\n",
      "52: <state>Assign(median)\n",
      "53: <state>Assign(merged)\n",
      "54: <state>Assign(merged_arr)\n",
      "55: <state>Assign(merged_array)\n",
      "56: <state>Assign(mid)\n",
      "57: <state>Assign(mid1)\n",
      "58: <state>Assign(mid2)\n",
      "59: <state>Assign(midIndex)\n",
      "60: <state>Assign(midNums1)\n",
      "61: <state>Assign(midNums2)\n",
      "62: <state>Assign(mid_val)\n",
      "63: <state>Assign(middle)\n",
      "64: <state>Assign(middle1)\n",
      "65: <state>Assign(middle2)\n",
      "66: <state>Assign(minRightX)\n",
      "67: <state>Assign(minRightY)\n",
      "68: <state>Assign(min_of_right)\n",
      "69: <state>Assign(min_right1)\n",
      "70: <state>Assign(min_right2)\n",
      "71: <state>Assign(n)\n",
      "72: <state>Assign(n1,n2)\n",
      "73: <state>Assign(new)\n",
      "74: <state>Assign(new2)\n",
      "75: <state>Assign(newArray)\n",
      "76: <state>Assign(new_array,N)\n",
      "77: <state>Assign(new_len)\n",
      "78: <state>Assign(new_list)\n",
      "79: <state>Assign(new_nums)\n",
      "80: <state>Assign(new_splt)\n",
      "81: <state>Assign(newlist)\n",
      "82: <state>Assign(num3)\n",
      "83: <state>Assign(nums)\n",
      "84: <state>Assign(nums1)\n",
      "85: <state>Assign(nums1,nums2)\n",
      "86: <state>Assign(nums1,nums2,m,n)\n",
      "87: <state>Assign(nums1Left)\n",
      "88: <state>Assign(nums1Right)\n",
      "89: <state>Assign(nums1_left)\n",
      "90: <state>Assign(nums1_left_max)\n",
      "91: <state>Assign(nums1_right)\n",
      "92: <state>Assign(nums1_right_min)\n",
      "93: <state>Assign(nums2)\n",
      "94: <state>Assign(nums2Left)\n",
      "95: <state>Assign(nums2Right)\n",
      "96: <state>Assign(nums2_left)\n",
      "97: <state>Assign(nums2_left_max)\n",
      "98: <state>Assign(nums2_right)\n",
      "99: <state>Assign(nums2_right_min)\n",
      "100: <state>Assign(nums3)\n",
      "101: <state>Assign(odd_median)\n",
      "102: <state>Assign(p1,p2,np)\n",
      "103: <state>Assign(partition1)\n",
      "104: <state>Assign(partition2)\n",
      "105: <state>Assign(partitionX)\n",
      "106: <state>Assign(partitionY)\n",
      "107: <state>Assign(r)\n",
      "108: <state>Assign(r1)\n",
      "109: <state>Assign(r2)\n",
      "110: <state>Assign(res)\n",
      "111: <state>Assign(result)\n",
      "112: <state>Assign(right)\n",
      "113: <state>Assign(s)\n",
      "114: <state>Assign(s_pointer)\n",
      "115: <state>Assign(self)\n",
      "116: <state>Assign(size)\n",
      "117: <state>Assign(solution)\n",
      "118: <state>Assign(sort)\n",
      "119: <state>Assign(sorted_l)\n",
      "120: <state>Assign(sorted_merged_arr)\n",
      "121: <state>Assign(splt)\n",
      "122: <state>Assign(targetIndex)\n",
      "123: <state>Assign(total)\n",
      "124: <state>Assign(total_elements)\n",
      "125: <state>Assign(ttk)\n",
      "126: <state>Assign(upper_median)\n",
      "127: <state>Assign(val)\n",
      "128: <state>Assign(x)\n",
      "129: <state>Assign(x,y)\n",
      "130: <state>Attribute(append)\n",
      "131: <state>Attribute(ceil)\n",
      "132: <state>Attribute(extend)\n",
      "133: <state>Attribute(findMedianSortedArrays)\n",
      "134: <state>Attribute(floor)\n",
      "135: <state>Attribute(merge)\n",
      "136: <state>Attribute(sort)\n",
      "137: <state>ClassDef(Solution)\n",
      "138: <state>FunctionDef(findMedianSortedArrays)\n",
      "139: <state>FunctionDef(median)\n",
      "140: <state>FunctionDef(merge)\n",
      "141: <state>Import(math)\n",
      "142: <state>arg(a)\n",
      "143: <state>arg(array1)\n",
      "144: <state>arg(array2)\n",
      "145: <state>arg(b)\n",
      "146: <state>arg(nums1)\n",
      "147: <state>arg(nums2)\n",
      "148: <state>arg(p)\n",
      "149: <state>arg(q)\n",
      "150: <state>arg(self)\n",
      "151: Add\n",
      "152: And\n",
      "153: AugAssign\n",
      "154: BinOp\n",
      "155: BitAnd\n",
      "156: BoolOp\n",
      "157: Call(Solution)\n",
      "158: Call(ValueError)\n",
      "159: Call(append)\n",
      "160: Call(ceil)\n",
      "161: Call(extend)\n",
      "162: Call(findMedianSortedArrays)\n",
      "163: Call(float)\n",
      "164: Call(floor)\n",
      "165: Call(int)\n",
      "166: Call(len)\n",
      "167: Call(max)\n",
      "168: Call(median)\n",
      "169: Call(merge)\n",
      "170: Call(min)\n",
      "171: Call(print)\n",
      "172: Call(range)\n",
      "173: Call(round)\n",
      "174: Call(sort)\n",
      "175: Call(sorted)\n",
      "176: Compare\n",
      "177: Constant\n",
      "178: Div\n",
      "179: Eq\n",
      "180: Expr\n",
      "181: FloorDiv\n",
      "182: For\n",
      "183: Gt\n",
      "184: GtE\n",
      "185: If\n",
      "186: IfExp\n",
      "187: Is\n",
      "188: List\n",
      "189: Load\n",
      "190: Lt\n",
      "191: LtE\n",
      "192: Mod\n",
      "193: Module\n",
      "194: Mult\n",
      "195: Name(L)\n",
      "196: Name(M)\n",
      "197: Name(N)\n",
      "198: Name(Solution)\n",
      "199: Name(ValueError)\n",
      "200: Name(__name__)\n",
      "201: Name(a)\n",
      "202: Name(ans)\n",
      "203: Name(array)\n",
      "204: Name(array1)\n",
      "205: Name(array2)\n",
      "206: Name(b)\n",
      "207: Name(com_arr)\n",
      "208: Name(count)\n",
      "209: Name(f_pointer)\n",
      "210: Name(float)\n",
      "211: Name(half)\n",
      "212: Name(half_len)\n",
      "213: Name(high)\n",
      "214: Name(i)\n",
      "215: Name(imax)\n",
      "216: Name(imin)\n",
      "217: Name(ind)\n",
      "218: Name(index)\n",
      "219: Name(int)\n",
      "220: Name(int_upper_median)\n",
      "221: Name(j)\n",
      "222: Name(k)\n",
      "223: Name(l)\n",
      "224: Name(l1)\n",
      "225: Name(l2)\n",
      "226: Name(left)\n",
      "227: Name(len)\n",
      "228: Name(len1)\n",
      "229: Name(len2)\n",
      "230: Name(len_list)\n",
      "231: Name(length)\n",
      "232: Name(low)\n",
      "233: Name(ls)\n",
      "234: Name(m)\n",
      "235: Name(m1)\n",
      "236: Name(m2)\n",
      "237: Name(m_length)\n",
      "238: Name(math)\n",
      "239: Name(max)\n",
      "240: Name(maxLeftX)\n",
      "241: Name(maxLeftY)\n",
      "242: Name(max_left1)\n",
      "243: Name(max_left2)\n",
      "244: Name(max_of_left)\n",
      "245: Name(max_size)\n",
      "246: Name(med)\n",
      "247: Name(med1)\n",
      "248: Name(med2)\n",
      "249: Name(median)\n",
      "250: Name(merged)\n",
      "251: Name(merged_arr)\n",
      "252: Name(merged_array)\n",
      "253: Name(mid)\n",
      "254: Name(mid1)\n",
      "255: Name(mid2)\n",
      "256: Name(midIndex)\n",
      "257: Name(midNums1)\n",
      "258: Name(midNums2)\n",
      "259: Name(mid_val)\n",
      "260: Name(middle)\n",
      "261: Name(middle1)\n",
      "262: Name(middle2)\n",
      "263: Name(min)\n",
      "264: Name(minRightX)\n",
      "265: Name(minRightY)\n",
      "266: Name(min_of_right)\n",
      "267: Name(min_right1)\n",
      "268: Name(min_right2)\n",
      "269: Name(n)\n",
      "270: Name(n1)\n",
      "271: Name(n2)\n",
      "272: Name(new)\n",
      "273: Name(new2)\n",
      "274: Name(newArray)\n",
      "275: Name(new_array)\n",
      "276: Name(new_len)\n",
      "277: Name(new_list)\n",
      "278: Name(new_nums)\n",
      "279: Name(new_splt)\n",
      "280: Name(newlist)\n",
      "281: Name(np)\n",
      "282: Name(num3)\n",
      "283: Name(nums)\n",
      "284: Name(nums1)\n",
      "285: Name(nums1Left)\n",
      "286: Name(nums1Right)\n",
      "287: Name(nums1_left)\n",
      "288: Name(nums1_left_max)\n",
      "289: Name(nums1_right)\n",
      "290: Name(nums1_right_min)\n",
      "291: Name(nums2)\n",
      "292: Name(nums2Left)\n",
      "293: Name(nums2Right)\n",
      "294: Name(nums2_left)\n",
      "295: Name(nums2_left_max)\n",
      "296: Name(nums2_right)\n",
      "297: Name(nums2_right_min)\n",
      "298: Name(nums3)\n",
      "299: Name(object)\n",
      "300: Name(odd_median)\n",
      "301: Name(p)\n",
      "302: Name(p1)\n",
      "303: Name(p2)\n",
      "304: Name(partition1)\n",
      "305: Name(partition2)\n",
      "306: Name(partitionX)\n",
      "307: Name(partitionY)\n",
      "308: Name(print)\n",
      "309: Name(q)\n",
      "310: Name(r)\n",
      "311: Name(r1)\n",
      "312: Name(r2)\n",
      "313: Name(range)\n",
      "314: Name(res)\n",
      "315: Name(result)\n",
      "316: Name(right)\n",
      "317: Name(round)\n",
      "318: Name(s)\n",
      "319: Name(s_pointer)\n",
      "320: Name(self)\n",
      "321: Name(size)\n",
      "322: Name(solution)\n",
      "323: Name(sort)\n",
      "324: Name(sorted)\n",
      "325: Name(sorted_l)\n",
      "326: Name(sorted_merged_arr)\n",
      "327: Name(splt)\n",
      "328: Name(targetIndex)\n",
      "329: Name(total)\n",
      "330: Name(total_elements)\n",
      "331: Name(ttk)\n",
      "332: Name(upper_median)\n",
      "333: Name(val)\n",
      "334: Name(x)\n",
      "335: Name(y)\n",
      "336: Not\n",
      "337: NotEq\n",
      "338: Or\n",
      "339: Raise\n",
      "340: Return\n",
      "341: Slice\n",
      "342: Store\n",
      "343: Sub\n",
      "344: Subscript\n",
      "345: Tuple\n",
      "346: USub\n",
      "347: UnaryOp\n",
      "348: While\n",
      "349: alias\n",
      "350: arguments\n",
      "\n",
      "Line 1:\n",
      "State node representations: ['<state>ClassDef(Solution)']\n",
      "Action node representations: ['<state>ClassDef(Solution)', 'Name(object)']\n",
      "\n",
      "Line 2:\n",
      "State node representations: ['<state>ClassDef(Solution)', '<state>FunctionDef(findMedianSortedArrays)', '<state>arg(nums1)', '<state>arg(nums2)', '<state>arg(self)']\n",
      "Action node representations: ['<state>FunctionDef(findMedianSortedArrays)', '<state>arg(nums1)', '<state>arg(nums2)', '<state>arg(self)']\n",
      "\n",
      "Line 3:\n",
      "State node representations: ['<state>Assign(merged)', '<state>ClassDef(Solution)', '<state>FunctionDef(findMedianSortedArrays)', '<state>arg(nums1)', '<state>arg(nums2)', '<state>arg(self)']\n",
      "Action node representations: ['<state>Assign(merged)', 'BinOp', 'Call(sorted)', 'Name(merged)', 'Name(nums1)', 'Name(nums2)', 'Name(sorted)']\n",
      "\n",
      "Line 4:\n",
      "State node representations: ['<state>Assign(merged)', '<state>Assign(n)', '<state>ClassDef(Solution)', '<state>FunctionDef(findMedianSortedArrays)', '<state>arg(nums1)', '<state>arg(nums2)', '<state>arg(self)']\n",
      "Action node representations: ['<state>Assign(n)', 'Call(len)', 'Name(len)', 'Name(merged)', 'Name(n)']\n",
      "\n",
      "Line 5:\n",
      "State node representations: ['<state>Assign(merged)', '<state>Assign(n)', '<state>ClassDef(Solution)', '<state>FunctionDef(findMedianSortedArrays)', '<state>arg(nums1)', '<state>arg(nums2)', '<state>arg(self)']\n",
      "Action node representations: ['BinOp', 'Compare', 'Constant', 'If', 'Name(n)']\n",
      "\n",
      "Line 6:\n",
      "State node representations: ['<state>Assign(merged)', '<state>Assign(n)', '<state>ClassDef(Solution)', '<state>FunctionDef(findMedianSortedArrays)', '<state>arg(nums1)', '<state>arg(nums2)', '<state>arg(self)']\n",
      "Action node representations: ['BinOp', 'Constant', 'Name(merged)', 'Name(n)', 'Return', 'Subscript']\n",
      "\n",
      "Line 7:\n",
      "State node representations: ['<state>Assign(merged)', '<state>Assign(n)', '<state>ClassDef(Solution)', '<state>FunctionDef(findMedianSortedArrays)', '<state>arg(nums1)', '<state>arg(nums2)', '<state>arg(self)']\n",
      "Action node representations: []\n",
      "\n",
      "Line 8:\n",
      "State node representations: ['<state>Assign(merged)', '<state>Assign(n)', '<state>ClassDef(Solution)', '<state>FunctionDef(findMedianSortedArrays)', '<state>arg(nums1)', '<state>arg(nums2)', '<state>arg(self)']\n",
      "Action node representations: ['BinOp', 'Constant', 'Name(merged)', 'Name(n)', 'Return', 'Subscript']\n"
     ]
    }
   ],
   "source": [
    "code_temp = new_data[\"code\"][0]\n",
    "\n",
    "print(code_temp)\n",
    "\n",
    "states, actions = code_to_states(code_temp, vocab)\n",
    "\n",
    "print(\"Vocabulary:\")\n",
    "for node_repr, idx in vocab.items():\n",
    "    print(f\"{idx}: {node_repr}\")\n",
    "\n",
    "for i, (state_indices, action_indices) in enumerate(zip(states, actions)):\n",
    "    state_nodes = [list(vocab.keys())[list(vocab.values()).index(idx)] for idx in state_indices]\n",
    "    action_nodes = [list(vocab.keys())[list(vocab.values()).index(idx)] for idx in action_indices]\n",
    "    print(f\"\\nLine {i+1}:\")\n",
    "    print(f\"State node representations: {state_nodes}\")\n",
    "    print(f\"Action node representations: {action_nodes}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "151\n",
      "[53, 137, 138, 146, 147, 150]\n",
      "[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
      " 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
      " 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
      " 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
      " 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
      " 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 1. 0. 0. 0. 0. 0.\n",
      " 0. 0. 1. 1. 0. 0. 1.]\n",
      "351\n",
      "[53, 154, 175, 250, 284, 291, 324]\n",
      "[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
      " 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
      " 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
      " 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
      " 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
      " 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
      " 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
      " 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
      " 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
      " 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
      " 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
      " 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0.\n",
      " 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
      " 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
      " 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n"
     ]
    }
   ],
   "source": [
    "# one-hot check\n",
    "state_temp = new_data[\"states\"][0][2]\n",
    "action_temp = new_data[\"actions\"][0][2]\n",
    "one_hot_state_temp = new_data[\"one_hot_state\"][0][2]\n",
    "one_hot_action_temp = new_data[\"one_hot_action\"][0][2]\n",
    "\n",
    "print(len(one_hot_state_temp))\n",
    "print(state_temp)\n",
    "print(one_hot_state_temp)\n",
    "print(len(one_hot_action_temp))\n",
    "print(action_temp)\n",
    "print(one_hot_action_temp)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
      " 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
      " 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
      " 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
      " 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
      " 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
      " 0. 0. 0. 0. 0. 0. 0.]\n",
      "[1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
      " 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
      " 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
      " 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
      " 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
      " 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
      " 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
      " 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
      " 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
      " 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
      " 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
      " 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
      " 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
      " 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
      " 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n"
     ]
    }
   ],
   "source": [
    "# unseen code encoding test\n",
    "unseen_code = \"\"\"class is_not_a_Solution(object):\n",
    "    def findMedianSortedArrays(self, nums1, nums2):\n",
    "        merged = sorted(nums1+nums2)\n",
    "        n = len(merged)\n",
    "        if n%2==1:\n",
    "            return merged[n//2]\n",
    "        else:\n",
    "            return (merged[n//2]+merged[n//2-1])/2.0\"\"\"\n",
    "\n",
    "unseen_states_one_hot, unseen_actions_one_hot = process_new_code(unseen_code, vocab, unk_token=\"<state><UNK>\")\n",
    "\n",
    "print(unseen_states_one_hot[0])\n",
    "print(unseen_actions_one_hot[0])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### state define test 4"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "                                                code\n",
      "0                   def add(a, b):\\n    return a + b\n",
      "1  class Calculator:\\n    def multiply(self, x, y...\n",
      "2  import math\\nfrom os import path\\nx = math.sqr...\n",
      "Vocabulary:\n",
      "<state><UNK>: 0\n",
      "Name_U1: 1\n",
      "Name_U2: 2\n",
      "Name_U3: 3\n",
      "Name_U4: 4\n",
      "Name_U5: 5\n",
      "Name_U6: 6\n",
      "Name_U7: 7\n",
      "Name_U8: 8\n",
      "Name_U9: 9\n",
      "Name_U10: 10\n",
      "Name_<UNKNOWN>: 11\n",
      "arg_U1: 12\n",
      "arg_U2: 13\n",
      "arg_U3: 14\n",
      "arg_U4: 15\n",
      "arg_U5: 16\n",
      "arg_U6: 17\n",
      "arg_U7: 18\n",
      "arg_U8: 19\n",
      "arg_U9: 20\n",
      "arg_U10: 21\n",
      "arg_<UNKNOWN>: 22\n",
      "Assign_U1: 23\n",
      "Assign_U2: 24\n",
      "Assign_U3: 25\n",
      "Assign_U4: 26\n",
      "Assign_U5: 27\n",
      "Assign_U6: 28\n",
      "Assign_U7: 29\n",
      "Assign_U8: 30\n",
      "Assign_U9: 31\n",
      "Assign_U10: 32\n",
      "Assign_<UNKNOWN>: 33\n",
      "<state>Assign(Assign_U1): 34\n",
      "<state>Attribute(Attribute): 35\n",
      "<state>ClassDef(ClassDef): 36\n",
      "<state>FunctionDef(FunctionDef): 37\n",
      "<state>Import(math): 38\n",
      "<state>ImportFrom(os::path): 39\n",
      "<state>arg(arg_U1): 40\n",
      "<state>arg(arg_U2): 41\n",
      "<state>arg(arg_U3): 42\n",
      "Add: 43\n",
      "BinOp: 44\n",
      "Call(sqrt): 45\n",
      "Constant: 46\n",
      "Load: 47\n",
      "Module: 48\n",
      "Mult: 49\n",
      "Name(Name_U1): 50\n",
      "Name(Name_U2): 51\n",
      "Return: 52\n",
      "Store: 53\n",
      "alias: 54\n",
      "arguments: 55\n",
      "\n",
      "Processed DataFrame:\n",
      "                                              states  \\\n",
      "0                       [[37, 40, 41], [37, 40, 41]]   \n",
      "1  [[36], [36, 37, 40, 41, 42], [36, 37, 40, 41, ...   \n",
      "2                 [[38], [38, 39], [34, 35, 38, 39]]   \n",
      "\n",
      "                                      actions  \\\n",
      "0            [[37, 40, 41], [44, 50, 51, 52]]   \n",
      "1  [[36], [37, 40, 41, 42], [44, 50, 51, 52]]   \n",
      "2      [[38], [39], [34, 35, 45, 46, 50, 51]]   \n",
      "\n",
      "                                       one_hot_state  \\\n",
      "0  [[0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0,...   \n",
      "1  [[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0,...   \n",
      "2  [[0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0,...   \n",
      "\n",
      "                                      one_hot_action  \n",
      "0  [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,...  \n",
      "1  [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,...  \n",
      "2  [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,...  \n",
      "\n",
      "New Code One-Hot States:\n",
      "[array([0., 0., 0., 0., 1., 0., 0., 1., 1., 0.], dtype=float32), array([0., 0., 0., 0., 1., 0., 0., 1., 1., 0.], dtype=float32)]\n",
      "\n",
      "New Code One-Hot Actions:\n",
      "[array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
      "       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
      "       0., 0., 0., 1., 0., 0., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
      "       0., 0., 0., 0., 0.], dtype=float32), array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
      "       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
      "       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 1.,\n",
      "       1., 1., 0., 0., 0.], dtype=float32)]\n"
     ]
    }
   ],
   "source": [
    "import ast\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "from collections import defaultdict\n",
    "\n",
    "# Configuration: Number of unknown tokens per node type\n",
    "UNKNOWN_LIMIT = 10  # Example value; can be adjusted based on requirements\n",
    "\n",
    "def get_identity(node, unknown_counters, mappings):\n",
    "    identity_info = None\n",
    "    is_state_node = True\n",
    "\n",
    "    if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):\n",
    "        identity_info = node.name\n",
    "    elif isinstance(node, ast.ClassDef):\n",
    "        identity_info = node.name\n",
    "    elif isinstance(node, ast.Name):\n",
    "        identity_info = node.id\n",
    "        is_state_node = False\n",
    "    elif isinstance(node, ast.arg):\n",
    "        identity_info = node.arg\n",
    "    elif isinstance(node, ast.Attribute):\n",
    "        identity_info = node.attr\n",
    "    elif isinstance(node, ast.Call):\n",
    "        if isinstance(node.func, ast.Name):\n",
    "            identity_info = node.func.id\n",
    "        elif isinstance(node.func, ast.Attribute):\n",
    "            identity_info = node.func.attr\n",
    "        is_state_node = False\n",
    "    elif isinstance(node, ast.Assign):\n",
    "        targets = []\n",
    "        for t in node.targets:\n",
    "            if isinstance(t, ast.Name):\n",
    "                targets.append(t.id)\n",
    "            elif isinstance(t, ast.Tuple):\n",
    "                targets.extend([elt.id for elt in t.elts if isinstance(elt, ast.Name)])\n",
    "        identity_info = ','.join(targets)\n",
    "    elif isinstance(node, ast.Import):\n",
    "        names = [alias.name for alias in node.names]\n",
    "        identity_info = ','.join(names)\n",
    "    elif isinstance(node, ast.ImportFrom):\n",
    "        module = node.module or ''\n",
    "        names = [alias.name for alias in node.names]\n",
    "        identity_info = f\"{module}::{','.join(names)}\"\n",
    "    else:\n",
    "        identity_info = None\n",
    "        is_state_node = False\n",
    "\n",
    "    # Abstract naming nodes\n",
    "    naming_nodes = (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef,\n",
    "                    ast.Name, ast.arg, ast.Attribute, ast.Assign)\n",
    "    \n",
    "    if isinstance(node, naming_nodes):\n",
    "        node_type = type(node).__name__\n",
    "        if node_type in ['Name', 'arg', 'Assign']:\n",
    "            # Handle limited unknown tokens\n",
    "            if isinstance(node, ast.Assign):\n",
    "                key = 'Assign'\n",
    "            elif isinstance(node, ast.arg):\n",
    "                key = 'arg'\n",
    "            elif isinstance(node, ast.Name):\n",
    "                key = 'Name'\n",
    "            else:\n",
    "                key = node_type\n",
    "\n",
    "            # Initialize counter and mapping if not present\n",
    "            if key not in unknown_counters:\n",
    "                unknown_counters[key] = 0\n",
    "            if key not in mappings:\n",
    "                mappings[key] = {}\n",
    "\n",
    "            # Assign to predefined unknown tokens\n",
    "            if identity_info not in mappings[key]:\n",
    "                if unknown_counters[key] < UNKNOWN_LIMIT:\n",
    "                    mappings[key][identity_info] = f\"{key}_U{unknown_counters[key]+1}\"\n",
    "                    unknown_counters[key] += 1\n",
    "                else:\n",
    "                    mappings[key][identity_info] = f\"{key}_<UNKNOWN>\"\n",
    "\n",
    "            # Use the abstracted token\n",
    "            identity_info = mappings[key][identity_info]\n",
    "        else:\n",
    "            # For other naming nodes, keep the abstraction\n",
    "            identity_info = type(node).__name__\n",
    "\n",
    "    return identity_info, is_state_node\n",
    "\n",
    "def get_node_reprs(node, node_reprs, unknown_counters, mappings):\n",
    "    node_type = type(node).__name__\n",
    "    identity, is_state_node = get_identity(node, unknown_counters, mappings)\n",
    "    if identity:\n",
    "        node_repr = f\"{node_type}({identity})\"\n",
    "    else:\n",
    "        node_repr = node_type\n",
    "    if is_state_node:\n",
    "        node_repr = \"<state>\" + node_repr\n",
    "    node_reprs.add(node_repr)\n",
    "    for child in ast.iter_child_nodes(node):\n",
    "        get_node_reprs(child, node_reprs, unknown_counters, mappings)\n",
    "\n",
    "def build_vocab(code_list, unknown_limit=UNKNOWN_LIMIT):\n",
    "    node_reprs = set()\n",
    "    for code in code_list:\n",
    "        try:\n",
    "            tree = ast.parse(code)\n",
    "            # Initialize counters and mappings for each code snippet\n",
    "            unknown_counters = {}\n",
    "            mappings = {}\n",
    "            get_node_reprs(tree, node_reprs, unknown_counters, mappings)\n",
    "        except SyntaxError:\n",
    "            continue  # Skip code snippets with syntax errors\n",
    "    # Initialize vocabulary with <UNK> token\n",
    "    vocab = {\"<state><UNK>\": 0}\n",
    "    # Predefine unknown tokens for naming node types\n",
    "    naming_node_types = ['Name', 'arg', 'Assign']\n",
    "    for node_type in naming_node_types:\n",
    "        for i in range(1, unknown_limit + 1):\n",
    "            vocab[f\"{node_type}_U{i}\"] = len(vocab)\n",
    "        vocab[f\"{node_type}_<UNKNOWN>\"] = len(vocab)\n",
    "    # Start indexing other node representations from current vocab size\n",
    "    for idx, node_repr in enumerate(sorted(node_reprs), start=len(vocab)):\n",
    "        if node_repr not in vocab:\n",
    "            vocab[node_repr] = idx\n",
    "    return vocab\n",
    "\n",
    "def collect_nodes_per_line(node, line_nodes, unknown_counters, mappings):\n",
    "    lineno = getattr(node, 'lineno', None)\n",
    "    if lineno is not None:\n",
    "        lineno -= 1  # Adjust for zero-based index\n",
    "        if 0 <= lineno < len(line_nodes):\n",
    "            node_type = type(node).__name__\n",
    "            identity, is_state_node = get_identity(node, unknown_counters, mappings)\n",
    "            if identity:\n",
    "                node_repr = f\"{node_type}({identity})\"\n",
    "            else:\n",
    "                node_repr = node_type\n",
    "            if is_state_node:\n",
    "                node_repr = \"<state>\" + node_repr\n",
    "            line_nodes[lineno].add(node_repr)\n",
    "    for child in ast.iter_child_nodes(node):\n",
    "        collect_nodes_per_line(child, line_nodes, unknown_counters, mappings)\n",
    "\n",
    "def code_to_states(code, vocab, unk_token=\"<state><UNK>\"):\n",
    "    lines = code.split('\\n')\n",
    "    N = len(lines)\n",
    "    line_nodes = [set() for _ in range(N)]\n",
    "    try:\n",
    "        tree = ast.parse(code)\n",
    "        # Initialize counters and mappings for this code snippet\n",
    "        unknown_counters = {}\n",
    "        mappings = {}\n",
    "        collect_nodes_per_line(tree, line_nodes, unknown_counters, mappings)\n",
    "        cumulative_nodes = set()\n",
    "        states = []\n",
    "        actions = []\n",
    "        for i in range(N):\n",
    "            # State: cumulative node representations up to current line\n",
    "            cumulative_nodes.update(line_nodes[i])\n",
    "            # Handle state nodes, map to vocab or <UNK>\n",
    "            state_indices = [\n",
    "                vocab.get(node_repr, vocab[unk_token])\n",
    "                for node_repr in sorted(cumulative_nodes)\n",
    "                if \"<state>\" in node_repr\n",
    "            ]\n",
    "            states.append(state_indices)\n",
    "            # Action: node representations that start at current line\n",
    "            action_nodes = line_nodes[i]\n",
    "            action_indices = [\n",
    "                vocab.get(node_repr, vocab[unk_token])\n",
    "                for node_repr in sorted(action_nodes)\n",
    "            ]\n",
    "            actions.append(action_indices)\n",
    "        return states, actions\n",
    "    except SyntaxError:\n",
    "        return [], []\n",
    "\n",
    "def one_hot_from_indices(indices_list, vocab_size, state_vocab_indices=None, is_state=False):\n",
    "    encode_list = []\n",
    "    for indices in indices_list:\n",
    "        one_hot_seq = np.zeros(vocab_size, dtype=np.float32)\n",
    "        for index in indices:\n",
    "            one_hot_seq[index] = 1.0\n",
    "        if is_state and state_vocab_indices is not None:\n",
    "            one_hot_seq = one_hot_seq[state_vocab_indices]\n",
    "        encode_list.append(one_hot_seq)\n",
    "    return encode_list\n",
    "\n",
    "def process_dataframe(df, unknown_limit=UNKNOWN_LIMIT):\n",
    "    code_list = df['code'].tolist()\n",
    "    vocab = build_vocab(code_list, unknown_limit)\n",
    "    results = df['code'].apply(lambda code: code_to_states(code, vocab))\n",
    "    df['states'] = [res[0] for res in results]\n",
    "    df['actions'] = [res[1] for res in results]\n",
    "    \n",
    "    # Separate state and action vocab indices\n",
    "    state_vocab_indices = [v for k, v in vocab.items() if \"<state>\" in k]\n",
    "    action_vocab_indices = [v for k, v in vocab.items() if not any(kt in k for kt in [\"<state>\", \"Name_U\", \"arg_U\", \"Assign_U\"])]\n",
    "    \n",
    "    vocab_size = len(vocab)\n",
    "    state_vocab_size = len(state_vocab_indices)\n",
    "    \n",
    "    # Create mapping for states\n",
    "    state_map = {k: v for k, v in vocab.items() if \"<state>\" in k}\n",
    "    \n",
    "    # One-hot encode states and actions\n",
    "    df['one_hot_state'] = df['states'].apply(\n",
    "        lambda state: one_hot_from_indices(\n",
    "            state, vocab_size, state_vocab_indices, True\n",
    "        )\n",
    "    )\n",
    "    df['one_hot_action'] = df['actions'].apply(\n",
    "        lambda action: one_hot_from_indices(\n",
    "            action, vocab_size\n",
    "        )\n",
    "    )\n",
    "    return df, vocab\n",
    "\n",
    "def process_new_code(code, vocab, unknown_limit=UNKNOWN_LIMIT, unk_token=\"<state><UNK>\"):\n",
    "    states, actions = code_to_states(code, vocab, unk_token=unk_token)\n",
    "    vocab_size = len(vocab)\n",
    "    state_vocab_indices = [v for k, v in vocab.items() if \"<state>\" in k]\n",
    "    \n",
    "    states_one_hot = one_hot_from_indices(states, vocab_size, state_vocab_indices, True)\n",
    "    actions_one_hot = one_hot_from_indices(actions, vocab_size)\n",
    "    \n",
    "    return states_one_hot, actions_one_hot\n",
    "\n",
    "# Example Usage\n",
    "if __name__ == \"__main__\":\n",
    "    # Sample DataFrame\n",
    "    data = {\n",
    "        'code': [\n",
    "            \"def add(a, b):\\n    return a + b\",\n",
    "            \"class Calculator:\\n    def multiply(self, x, y):\\n        return x * y\",\n",
    "            \"import math\\nfrom os import path\\nx = math.sqrt(16)\"\n",
    "        ]\n",
    "    }\n",
    "    df = pd.DataFrame(data)\n",
    "    print(df.head())\n",
    "    # Process DataFrame\n",
    "    processed_df, vocabulary = process_dataframe(df)\n",
    "    \n",
    "    print(\"Vocabulary:\")\n",
    "    for k, v in vocabulary.items():\n",
    "        print(f\"{k}: {v}\")\n",
    "    \n",
    "    print(\"\\nProcessed DataFrame:\")\n",
    "    print(processed_df[['states', 'actions', 'one_hot_state', 'one_hot_action']])\n",
    "    \n",
    "    # Process new code snippet\n",
    "    new_code = \"def subtract(a, b):\\n    return a - b\"\n",
    "    states_one_hot, actions_one_hot = process_new_code(new_code, vocabulary)\n",
    "    print(\"\\nNew Code One-Hot States:\")\n",
    "    print(states_one_hot)\n",
    "    print(\"\\nNew Code One-Hot Actions:\")\n",
    "    print(actions_one_hot)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### IRL test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 67,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "from torch.utils.data import DataLoader\n",
    "\n",
    "# Assume 'data' is the DataFrame provided\n",
    "# data = pd.read_csv('path_to_data.csv')\n",
    "\n",
    "# Hyperparameters\n",
    "STATE_SIZE = len(data['one_hot_state'][0][0])  # Size of the state space\n",
    "ACTION_SIZE = len(data['one_hot_action'][0][0])  # Size of the action space\n",
    "EMBEDDING_SIZE = 128\n",
    "BATCH_SIZE = 16\n",
    "EPOCHS = 100\n",
    "LEARNING_RATE = 1e-3\n",
    "\n",
    "# Dataset Preparation\n",
    "class TrajectoryDataset(torch.utils.data.Dataset):\n",
    "    def __init__(self, data):\n",
    "        self.state_action_pairs = []\n",
    "        for idx, row in data.iterrows():\n",
    "            states = torch.tensor(row['one_hot_state'], dtype=torch.float)\n",
    "            actions = torch.tensor(row['one_hot_action'], dtype=torch.float)\n",
    "            self.state_action_pairs.extend(list(zip(states, actions)))\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.state_action_pairs)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        return self.state_action_pairs[idx]\n",
    "\n",
    "dataset = TrajectoryDataset(data)\n",
    "dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)\n",
    "\n",
    "# Reward Model Definition\n",
    "class RewardModel(nn.Module):\n",
    "    def __init__(self, state_size, action_size, embedding_size):\n",
    "        super(RewardModel, self).__init__()\n",
    "        self.state_embedding = nn.Linear(state_size, embedding_size)\n",
    "        self.action_embedding = nn.Linear(action_size, embedding_size)\n",
    "        self.fc = nn.Sequential(\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(embedding_size * 2, 1)\n",
    "        )\n",
    "\n",
    "    def forward(self, state, action):\n",
    "        state_emb = self.state_embedding(state)\n",
    "        action_emb = self.action_embedding(action)\n",
    "        x = torch.cat([state_emb, action_emb], dim=1)\n",
    "        reward = self.fc(x)\n",
    "        return reward\n",
    "\n",
    "reward_model = RewardModel(STATE_SIZE, ACTION_SIZE, EMBEDDING_SIZE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 68,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/100, Loss: 2.970627899293776\n",
      "Epoch 2/100, Loss: 2.79595415003888\n",
      "Epoch 3/100, Loss: 2.7811747030778364\n",
      "Epoch 4/100, Loss: 2.7763735876454936\n",
      "Epoch 5/100, Loss: 2.773988454372852\n",
      "Epoch 6/100, Loss: 2.7726410797664096\n",
      "Epoch 7/100, Loss: 2.771777332603157\n",
      "Epoch 8/100, Loss: 2.7711508305041823\n",
      "Epoch 9/100, Loss: 2.770595528862693\n",
      "Epoch 10/100, Loss: 2.7702215925439613\n",
      "Epoch 11/100, Loss: 2.7698802297765557\n",
      "Epoch 12/100, Loss: 2.7697721301735223\n",
      "Epoch 13/100, Loss: 2.7695957964116875\n",
      "Epoch 14/100, Loss: 2.7695502151142466\n",
      "Epoch 15/100, Loss: 2.769415143248323\n",
      "Epoch 16/100, Loss: 2.7693519034943024\n",
      "Epoch 17/100, Loss: 2.769291592882825\n",
      "Epoch 18/100, Loss: 2.7692262723848415\n",
      "Epoch 19/100, Loss: 2.7691846915653775\n",
      "Epoch 20/100, Loss: 2.7691563971630937\n",
      "Epoch 21/100, Loss: 2.7691194732467848\n",
      "Epoch 22/100, Loss: 2.7690910766651102\n",
      "Epoch 23/100, Loss: 2.769077152400822\n",
      "Epoch 24/100, Loss: 2.7690625252661767\n",
      "Epoch 25/100, Loss: 2.76903383453171\n",
      "Epoch 26/100, Loss: 2.769016374241222\n",
      "Epoch 27/100, Loss: 2.7690175291779755\n",
      "Epoch 28/100, Loss: 2.768995377924535\n",
      "Epoch 29/100, Loss: 2.7689855284505076\n",
      "Epoch 30/100, Loss: 2.7690175477560466\n",
      "Epoch 31/100, Loss: 2.7689982544292104\n",
      "Epoch 32/100, Loss: 2.7690003909073866\n",
      "Epoch 33/100, Loss: 2.769007379358465\n",
      "Epoch 34/100, Loss: 2.7690296699474386\n",
      "Epoch 35/100, Loss: 2.7690619648276984\n",
      "Epoch 36/100, Loss: 2.769072235404671\n",
      "Epoch 37/100, Loss: 2.7690522639782396\n",
      "Epoch 38/100, Loss: 2.7690178945467068\n",
      "Epoch 39/100, Loss: 2.7689950001704227\n",
      "Epoch 40/100, Loss: 2.7690073019498356\n",
      "Epoch 41/100, Loss: 2.769006363757245\n",
      "Epoch 42/100, Loss: 2.7690002020303304\n",
      "Epoch 43/100, Loss: 2.7689776946971945\n",
      "Epoch 44/100, Loss: 2.7689865719188345\n",
      "Epoch 45/100, Loss: 2.768999266934085\n",
      "Epoch 46/100, Loss: 2.7689917335262546\n",
      "Epoch 47/100, Loss: 2.7690204892839705\n",
      "Epoch 48/100, Loss: 2.769006890135926\n",
      "Epoch 49/100, Loss: 2.7689865719188345\n",
      "Epoch 50/100, Loss: 2.768967696598598\n",
      "Epoch 51/100, Loss: 2.768963959309962\n",
      "Epoch 52/100, Loss: 2.7689670618478353\n",
      "Epoch 53/100, Loss: 2.7689634669910776\n",
      "Epoch 54/100, Loss: 2.76896054094488\n",
      "Epoch 55/100, Loss: 2.7689559119088307\n",
      "Epoch 56/100, Loss: 2.768955209038474\n",
      "Epoch 57/100, Loss: 2.7689437835247483\n",
      "Epoch 58/100, Loss: 2.7689667026717943\n",
      "Epoch 59/100, Loss: 2.768963200705392\n",
      "Epoch 60/100, Loss: 2.7689553266995914\n",
      "Epoch 61/100, Loss: 2.7689501372250644\n",
      "Epoch 62/100, Loss: 2.768958246553099\n",
      "Epoch 63/100, Loss: 2.768959432453304\n",
      "Epoch 64/100, Loss: 2.7689513602814118\n",
      "Epoch 65/100, Loss: 2.7689545804804023\n",
      "Epoch 66/100, Loss: 2.768939396003624\n",
      "Epoch 67/100, Loss: 2.7689366774125532\n",
      "Epoch 68/100, Loss: 2.7689338380640205\n",
      "Epoch 69/100, Loss: 2.7689303825427958\n",
      "Epoch 70/100, Loss: 2.768927255234161\n",
      "Epoch 71/100, Loss: 2.7689263944502\n",
      "Epoch 72/100, Loss: 2.7689108786644874\n",
      "Epoch 73/100, Loss: 2.7689006793034543\n",
      "Epoch 74/100, Loss: 2.7688949417758297\n",
      "Epoch 75/100, Loss: 2.768897784220708\n",
      "Epoch 76/100, Loss: 2.768912885096166\n",
      "Epoch 77/100, Loss: 2.7689083303724016\n",
      "Epoch 78/100, Loss: 2.7689124361261146\n",
      "Epoch 79/100, Loss: 2.768908085761132\n",
      "Epoch 80/100, Loss: 2.7689012954761454\n",
      "Epoch 81/100, Loss: 2.7689076987179844\n",
      "Epoch 82/100, Loss: 2.7689125197274342\n",
      "Epoch 83/100, Loss: 2.7689098847376834\n",
      "Epoch 84/100, Loss: 2.7689061567380833\n",
      "Epoch 85/100, Loss: 2.7688993323932993\n",
      "Epoch 86/100, Loss: 2.7688920126332865\n",
      "Epoch 87/100, Loss: 2.7688887459891185\n",
      "Epoch 88/100, Loss: 2.7688856434512448\n",
      "Epoch 89/100, Loss: 2.7688831075445397\n",
      "Epoch 90/100, Loss: 2.7688810763420997\n",
      "Epoch 91/100, Loss: 2.7688821043287004\n",
      "Epoch 92/100, Loss: 2.7688857920758134\n",
      "Epoch 93/100, Loss: 2.76888234893997\n",
      "Epoch 94/100, Loss: 2.7688819526077864\n",
      "Epoch 95/100, Loss: 2.7688860304943925\n",
      "Epoch 96/100, Loss: 2.7688796829867672\n",
      "Epoch 97/100, Loss: 2.7688744470670628\n",
      "Epoch 98/100, Loss: 2.768877735385647\n",
      "Epoch 99/100, Loss: 2.768887303092263\n",
      "Epoch 100/100, Loss: 2.7688918113708496\n"
     ]
    }
   ],
   "source": [
    "# Optimizer\n",
    "optimizer = optim.Adam(reward_model.parameters(), lr=LEARNING_RATE)\n",
    "\n",
    "# Training Loop\n",
    "for epoch in range(EPOCHS):\n",
    "    total_loss = 0\n",
    "    for states, actions in dataloader:\n",
    "        optimizer.zero_grad()\n",
    "\n",
    "        # Compute rewards for the batch\n",
    "        rewards = reward_model(states, actions)\n",
    "\n",
    "        # Approximate the partition function (Z) using a batch of random state-action pairs\n",
    "        # For simplicity, we'll sample negative examples from the dataset\n",
    "        random_states = torch.randn_like(states)\n",
    "        random_actions = torch.randn_like(actions)\n",
    "        random_rewards = reward_model(random_states, random_actions)\n",
    "\n",
    "        # Compute the MaxEnt IRL loss\n",
    "        # Loss = - (Expert Reward - log_sum_exp(All Rewards))\n",
    "        expert_reward = rewards.mean()\n",
    "        all_rewards = torch.cat([rewards, random_rewards], dim=0)\n",
    "        log_sum_exp = torch.logsumexp(all_rewards, dim=0)\n",
    "        loss = - (expert_reward - log_sum_exp)\n",
    "\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        total_loss += loss.item()\n",
    "\n",
    "    print(f\"Epoch {epoch+1}/{EPOCHS}, Loss: {total_loss/len(dataloader)}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 69,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Average Expert Reward: 0.01671245314350734\n",
      "Average Random Reward: -13.22975471431347\n"
     ]
    }
   ],
   "source": [
    "# Testing Framework (Evaluation)\n",
    "def evaluate_model(model, dataset):\n",
    "    with torch.no_grad():\n",
    "        expert_rewards = []\n",
    "        random_rewards = []\n",
    "        for states, actions in DataLoader(dataset, batch_size=BATCH_SIZE):\n",
    "            rewards = model(states, actions)\n",
    "            expert_rewards.extend(rewards.squeeze().tolist())\n",
    "\n",
    "            # Generate random state-action pairs\n",
    "            random_states = torch.randn_like(states)\n",
    "            random_actions = torch.randn_like(actions)\n",
    "            rewards_random = model(random_states, random_actions)\n",
    "            random_rewards.extend(rewards_random.squeeze().tolist())\n",
    "\n",
    "        avg_expert_reward = np.mean(expert_rewards)\n",
    "        avg_random_reward = np.mean(random_rewards)\n",
    "        print(f\"Average Expert Reward: {avg_expert_reward}\")\n",
    "        print(f\"Average Random Reward: {avg_random_reward}\")\n",
    "\n",
    "evaluate_model(reward_model, dataset)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### IRL test 02"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 73,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "from torch.utils.data import DataLoader\n",
    "\n",
    "# Hyperparameters\n",
    "STATE_SIZE = len(data['one_hot_state'][0][0])  # Size of the state space\n",
    "ACTION_SIZE = len(data['one_hot_action'][0][0])  # Size of the action space\n",
    "EMBEDDING_SIZE = 128\n",
    "BATCH_SIZE = 16\n",
    "EPOCHS = 50\n",
    "LEARNING_RATE = 1e-3\n",
    "MARGIN = 1.0\n",
    "\n",
    "# Dataset Preparation\n",
    "class TrajectoryDataset(torch.utils.data.Dataset):\n",
    "    def __init__(self, data):\n",
    "        self.state_action_pairs = []\n",
    "        self.action_space = self.get_action_space(data)\n",
    "        for idx, row in data.iterrows():\n",
    "            states = torch.tensor(row['one_hot_state'], dtype=torch.float)\n",
    "            actions = torch.tensor(row['one_hot_action'], dtype=torch.float)\n",
    "            self.state_action_pairs.extend(list(zip(states, actions)))\n",
    "\n",
    "    def get_action_space(self, data):\n",
    "        # Extract unique actions from the dataset\n",
    "        all_actions = []\n",
    "        for actions in data['one_hot_action']:\n",
    "            all_actions.extend(actions)\n",
    "        unique_actions = np.unique(np.array(all_actions), axis=0)\n",
    "        return torch.tensor(unique_actions, dtype=torch.float)\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.state_action_pairs)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        state, action = self.state_action_pairs[idx]\n",
    "        # Generate a negative action\n",
    "        negative_action = self.sample_negative_action(action)\n",
    "        return state, action, negative_action\n",
    "\n",
    "    def sample_negative_action(self, action):\n",
    "        # Ensure the negative action is different from the positive action\n",
    "        negative_actions = self.action_space[~torch.all(self.action_space == action, dim=1)]\n",
    "        negative_action = negative_actions[torch.randint(0, len(negative_actions), (1,))]\n",
    "        return negative_action.squeeze(0)\n",
    "\n",
    "dataset = TrajectoryDataset(data)\n",
    "dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)\n",
    "\n",
    "# Reward Model Definition\n",
    "class RewardModel(nn.Module):\n",
    "    def __init__(self, state_size, action_size, embedding_size):\n",
    "        super(RewardModel, self).__init__()\n",
    "        self.state_embedding = nn.Sequential(\n",
    "            nn.Linear(state_size, embedding_size),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(embedding_size, embedding_size),\n",
    "            nn.ReLU()\n",
    "        )\n",
    "        self.action_embedding = nn.Sequential(\n",
    "            nn.Linear(action_size, embedding_size),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(embedding_size, embedding_size),\n",
    "            nn.ReLU()\n",
    "        )\n",
    "        self.fc = nn.Sequential(\n",
    "            nn.Linear(embedding_size * 2, embedding_size),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(embedding_size, 1)\n",
    "        )\n",
    "\n",
    "    def forward(self, state, action):\n",
    "        state_emb = self.state_embedding(state)\n",
    "        action_emb = self.action_embedding(action)\n",
    "        x = torch.cat([state_emb, action_emb], dim=1)\n",
    "        reward = self.fc(x)\n",
    "        return reward\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 74,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/50, Loss: 0.6668734542735211\n",
      "Epoch 2/50, Loss: 0.511513689895729\n",
      "Epoch 3/50, Loss: 0.4932940335242779\n",
      "Epoch 4/50, Loss: 0.4389616307113078\n",
      "Epoch 5/50, Loss: 0.4007357716947407\n",
      "Epoch 6/50, Loss: 0.36478382194197023\n",
      "Epoch 7/50, Loss: 0.323225122961131\n",
      "Epoch 8/50, Loss: 0.31662735981600626\n",
      "Epoch 9/50, Loss: 0.2930532091043212\n",
      "Epoch 10/50, Loss: 0.2869051749056036\n",
      "Epoch 11/50, Loss: 0.27554292570460925\n",
      "Epoch 12/50, Loss: 0.24784999514942047\n",
      "Epoch 13/50, Loss: 0.26860197585124473\n",
      "Epoch 14/50, Loss: 0.18090507131698844\n",
      "Epoch 15/50, Loss: 0.230037510878854\n",
      "Epoch 16/50, Loss: 0.18841257135008838\n",
      "Epoch 17/50, Loss: 0.17792054531829699\n",
      "Epoch 18/50, Loss: 0.16796656189994377\n",
      "Epoch 19/50, Loss: 0.1836938149259462\n",
      "Epoch 20/50, Loss: 0.15830183421055993\n",
      "Epoch 21/50, Loss: 0.15930594445823074\n",
      "Epoch 22/50, Loss: 0.1406405048517438\n",
      "Epoch 23/50, Loss: 0.11097246906780578\n",
      "Epoch 24/50, Loss: 0.16018896762813842\n",
      "Epoch 25/50, Loss: 0.12607062641869893\n",
      "Epoch 26/50, Loss: 0.11118482856394409\n",
      "Epoch 27/50, Loss: 0.14637921194170977\n",
      "Epoch 28/50, Loss: 0.1378786750137806\n",
      "Epoch 29/50, Loss: 0.13143594098555578\n",
      "Epoch 30/50, Loss: 0.1191814470697533\n",
      "Epoch 31/50, Loss: 0.0928087882794343\n",
      "Epoch 32/50, Loss: 0.12332348631961006\n",
      "Epoch 33/50, Loss: 0.09226917818962753\n",
      "Epoch 34/50, Loss: 0.09455766814289154\n",
      "Epoch 35/50, Loss: 0.12754655458250796\n",
      "Epoch 36/50, Loss: 0.09923342329921661\n",
      "Epoch 37/50, Loss: 0.086344924453017\n",
      "Epoch 38/50, Loss: 0.12251845703690083\n",
      "Epoch 39/50, Loss: 0.10616549955947059\n",
      "Epoch 40/50, Loss: 0.1110751780790168\n",
      "Epoch 41/50, Loss: 0.10653766936489514\n",
      "Epoch 42/50, Loss: 0.08757669522197216\n",
      "Epoch 43/50, Loss: 0.10005465532084565\n",
      "Epoch 44/50, Loss: 0.09533136367023766\n",
      "Epoch 45/50, Loss: 0.09460665982264976\n",
      "Epoch 46/50, Loss: 0.09100262535276352\n",
      "Epoch 47/50, Loss: 0.0976732319535373\n",
      "Epoch 48/50, Loss: 0.10565839354674537\n",
      "Epoch 49/50, Loss: 0.12542481464031455\n",
      "Epoch 50/50, Loss: 0.09966351207974669\n"
     ]
    }
   ],
   "source": [
    "reward_model = RewardModel(STATE_SIZE, ACTION_SIZE, EMBEDDING_SIZE)\n",
    "\n",
    "# Loss Function\n",
    "criterion = nn.MarginRankingLoss(margin=MARGIN)\n",
    "\n",
    "# Optimizer\n",
    "optimizer = optim.Adam(reward_model.parameters(), lr=LEARNING_RATE)\n",
    "\n",
    "# Training Loop\n",
    "for epoch in range(EPOCHS):\n",
    "    total_loss = 0\n",
    "    for states, actions, negative_actions in dataloader:\n",
    "        optimizer.zero_grad()\n",
    "\n",
    "        # Compute rewards for positive and negative examples\n",
    "        positive_rewards = reward_model(states, actions).squeeze()\n",
    "        negative_rewards = reward_model(states, negative_actions).squeeze()\n",
    "\n",
    "        # Labels for MarginRankingLoss: 1 indicates positive_rewards should be larger than negative_rewards\n",
    "        target = torch.ones(positive_rewards.size())\n",
    "\n",
    "        # Compute the margin ranking loss\n",
    "        loss = criterion(positive_rewards, negative_rewards, target)\n",
    "\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        total_loss += loss.item()\n",
    "\n",
    "    print(f\"Epoch {epoch+1}/{EPOCHS}, Loss: {total_loss/len(dataloader)}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 75,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Average Positive Reward: 11.776057419085008\n",
      "Average Negative Reward: 5.487925983953263\n"
     ]
    }
   ],
   "source": [
    "# Testing Framework (Evaluation)\n",
    "def evaluate_model(model, dataset):\n",
    "    with torch.no_grad():\n",
    "        positive_rewards = []\n",
    "        negative_rewards = []\n",
    "        for states, actions, negative_actions in DataLoader(dataset, batch_size=BATCH_SIZE):\n",
    "            pos_rewards = model(states, actions).squeeze()\n",
    "            neg_rewards = model(states, negative_actions).squeeze()\n",
    "            positive_rewards.extend(pos_rewards.tolist())\n",
    "            negative_rewards.extend(neg_rewards.tolist())\n",
    "\n",
    "        avg_positive_reward = np.mean(positive_rewards)\n",
    "        avg_negative_reward = np.mean(negative_rewards)\n",
    "        print(f\"Average Positive Reward: {avg_positive_reward}\")\n",
    "        print(f\"Average Negative Reward: {avg_negative_reward}\")\n",
    "\n",
    "evaluate_model(reward_model, dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### IRL test 03 (init version)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "from torch.utils.data import DataLoader\n",
    "\n",
    "# Hyperparameters\n",
    "STATE_SIZE = len(new_data['one_hot_state'][0][0])  # Size of the state space\n",
    "ACTION_SIZE = len(new_data['one_hot_action'][0][0])  # Size of the action space\n",
    "EMBEDDING_SIZE = 128\n",
    "HIDDEN_SIZE = 256\n",
    "BATCH_SIZE = 16\n",
    "EPOCHS = 20\n",
    "LEARNING_RATE = 1e-3\n",
    "MARGIN = 1.0\n",
    "\n",
    "# Dataset Preparation\n",
    "class TrajectoryDataset(torch.utils.data.Dataset):\n",
    "    def __init__(self, data):\n",
    "        self.state_sequences = []\n",
    "        self.action_sequences = []\n",
    "        self.action_space = self.get_action_space(data)\n",
    "        for idx, row in data.iterrows():\n",
    "            states = torch.tensor(row['one_hot_state'], dtype=torch.float)\n",
    "            actions = torch.tensor(row['one_hot_action'], dtype=torch.float)\n",
    "            self.state_sequences.append(states)\n",
    "            self.action_sequences.append(actions)\n",
    "\n",
    "    def get_action_space(self, data):\n",
    "        # Extract unique actions from the dataset\n",
    "        all_actions = []\n",
    "        for actions in data['one_hot_action']:\n",
    "            all_actions.extend(actions)\n",
    "        unique_actions = np.unique(np.array(all_actions), axis=0)\n",
    "        return torch.tensor(unique_actions, dtype=torch.float)\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.state_sequences)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        states = self.state_sequences[idx]\n",
    "        actions = self.action_sequences[idx]\n",
    "        negative_actions = self.sample_negative_actions(actions)\n",
    "        return states, actions, negative_actions\n",
    "\n",
    "    def sample_negative_actions(self, actions):\n",
    "        negative_actions = []\n",
    "        for action in actions:\n",
    "            neg_action = self.generate_similar_negative_action(action)\n",
    "            negative_actions.append(neg_action)\n",
    "        negative_actions = torch.stack(negative_actions)\n",
    "        return negative_actions\n",
    "\n",
    "    def generate_similar_negative_action(self, action):\n",
    "        # Get actions with similar number of ones\n",
    "        num_ones = action.sum().int().item()\n",
    "        candidate_actions = self.action_space[(self.action_space.sum(dim=1).int() == num_ones) & (~torch.all(self.action_space == action, dim=1))]\n",
    "        if len(candidate_actions) == 0:\n",
    "            # If no candidate found, relax the condition slightly\n",
    "            candidate_actions = self.action_space[(self.action_space.sum(dim=1).int() == num_ones + 5) | (self.action_space.sum(dim=1).int() == num_ones - 5)]\n",
    "            candidate_actions = candidate_actions[~torch.all(candidate_actions == action, dim=1)]\n",
    "        neg_action = candidate_actions[torch.randint(0, len(candidate_actions), (1,))]\n",
    "        return neg_action.squeeze(0)\n",
    "\n",
    "dataset = TrajectoryDataset(new_data)\n",
    "dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=lambda x: x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/20, Loss: 0.9183646932892178\n",
      "Epoch 2/20, Loss: 0.9402862411478291\n",
      "Epoch 3/20, Loss: 0.9998625283655913\n",
      "Epoch 4/20, Loss: 0.9956491537716078\n",
      "Epoch 5/20, Loss: 0.9845445285672727\n",
      "Epoch 6/20, Loss: 0.9825708606968755\n",
      "Epoch 7/20, Loss: 0.8583516841349395\n",
      "Epoch 8/20, Loss: 0.863449576108352\n",
      "Epoch 9/20, Loss: 0.9138337762459464\n",
      "Epoch 10/20, Loss: 0.8954090942507205\n",
      "Epoch 11/20, Loss: 0.8158405462036962\n",
      "Epoch 12/20, Loss: 0.37157660593157227\n",
      "Epoch 13/20, Loss: 0.15306211878424106\n",
      "Epoch 14/20, Loss: 0.19915164909932925\n",
      "Epoch 15/20, Loss: 0.07706141147924506\n",
      "Epoch 16/20, Loss: 0.1081201754834341\n",
      "Epoch 17/20, Loss: 0.12770453516555869\n",
      "Epoch 18/20, Loss: 0.03445055254775545\n",
      "Epoch 19/20, Loss: 0.019036808982491493\n",
      "Epoch 20/20, Loss: 0.009310365695020428\n"
     ]
    }
   ],
   "source": [
    "# Reward Model Definition with LSTM\n",
    "class RewardModel(nn.Module):\n",
    "    def __init__(self, state_size, action_size, embedding_size, hidden_size):\n",
    "        super(RewardModel, self).__init__()\n",
    "        self.state_embedding = nn.Linear(state_size, embedding_size)\n",
    "        self.action_embedding = nn.Linear(action_size, embedding_size)\n",
    "        self.lstm = nn.LSTM(input_size=embedding_size * 2, hidden_size=hidden_size, batch_first=True)\n",
    "        self.fc = nn.Sequential(\n",
    "            nn.Linear(hidden_size, embedding_size),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(embedding_size, 1)\n",
    "        )\n",
    "\n",
    "    def forward(self, states, actions):\n",
    "        # states and actions are sequences\n",
    "        state_emb = self.state_embedding(states)  # (batch_size, seq_len, embedding_size)\n",
    "        action_emb = self.action_embedding(actions)  # (batch_size, seq_len, embedding_size)\n",
    "        x = torch.cat([state_emb, action_emb], dim=2)  # (batch_size, seq_len, embedding_size * 2)\n",
    "        lstm_out, _ = self.lstm(x)  # (batch_size, seq_len, hidden_size)\n",
    "        # Use the last output of LSTM for reward prediction\n",
    "        reward = self.fc(lstm_out[:, -1, :])  # (batch_size, 1)\n",
    "        return reward.squeeze()\n",
    "\n",
    "reward_model = RewardModel(STATE_SIZE, ACTION_SIZE, EMBEDDING_SIZE, HIDDEN_SIZE)\n",
    "\n",
    "# Loss Function\n",
    "criterion = nn.MarginRankingLoss(margin=MARGIN)\n",
    "\n",
    "# Optimizer\n",
    "optimizer = optim.Adam(reward_model.parameters(), lr=LEARNING_RATE)\n",
    "\n",
    "# Training Loop\n",
    "for epoch in range(EPOCHS):\n",
    "    total_loss = 0\n",
    "    for batch in dataloader:\n",
    "        optimizer.zero_grad()\n",
    "        states_batch, actions_batch, negative_actions_batch = zip(*batch)\n",
    "        states_padded = nn.utils.rnn.pad_sequence(states_batch, batch_first=True)\n",
    "        actions_padded = nn.utils.rnn.pad_sequence(actions_batch, batch_first=True)\n",
    "        negative_actions_padded = nn.utils.rnn.pad_sequence(negative_actions_batch, batch_first=True)\n",
    "\n",
    "        # Compute rewards for positive and negative examples\n",
    "        positive_rewards = reward_model(states_padded, actions_padded)\n",
    "        negative_rewards = reward_model(states_padded, negative_actions_padded)\n",
    "\n",
    "        # Labels for MarginRankingLoss: 1 indicates positive_rewards should be larger than negative_rewards\n",
    "        target = torch.ones(positive_rewards.size())\n",
    "\n",
    "        # Compute the margin ranking loss\n",
    "        loss = criterion(positive_rewards, negative_rewards, target)\n",
    "\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        total_loss += loss.item()\n",
    "\n",
    "    print(f\"Epoch {epoch+1}/{EPOCHS}, Loss: {total_loss/len(dataloader)}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Average Positive Reward: 10.665490261843948\n",
      "Average Negative Reward: -8.322006594450748\n"
     ]
    }
   ],
   "source": [
    "# Testing Framework (Evaluation)\n",
    "def evaluate_model(model, dataset):\n",
    "    with torch.no_grad():\n",
    "        positive_rewards = []\n",
    "        negative_rewards = []\n",
    "        for batch in DataLoader(dataset, batch_size=BATCH_SIZE, collate_fn=lambda x: x):\n",
    "            states_batch, actions_batch, negative_actions_batch = zip(*batch)\n",
    "            # print(states_batch)\n",
    "            states_padded = nn.utils.rnn.pad_sequence(states_batch, batch_first=True)\n",
    "            actions_padded = nn.utils.rnn.pad_sequence(actions_batch, batch_first=True)\n",
    "            negative_actions_padded = nn.utils.rnn.pad_sequence(negative_actions_batch, batch_first=True)\n",
    "\n",
    "            pos_rewards = model(states_padded, actions_padded)\n",
    "            neg_rewards = model(states_padded, negative_actions_padded)\n",
    "            positive_rewards.extend(pos_rewards.tolist())\n",
    "            negative_rewards.extend(neg_rewards.tolist())\n",
    "\n",
    "        avg_positive_reward = np.mean(positive_rewards)\n",
    "        avg_negative_reward = np.mean(negative_rewards)\n",
    "        print(f\"Average Positive Reward: {avg_positive_reward}\")\n",
    "        print(f\"Average Negative Reward: {avg_negative_reward}\")\n",
    "\n",
    "evaluate_model(reward_model, dataset)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "valid unseen code\n",
      "Input Code:\n",
      "class Solution:\n",
      "    def findMedianSortedArrays(self, nums1: List[int], nums2: List[int]) -> float:\n",
      "      return self.Median(nums1,nums2)  \n",
      "    def Median(self,nums1,nums2):\n",
      "        i = 0\n",
      "        j = 0\n",
      "        count= 0\n",
      "        flag = 0\n",
      "        m1 = m2 = 0\n",
      "        size = len(nums1)+len(nums2)\n",
      "        while (count < size//2+1):\n",
      "            m2 = m1\n",
      "            if i<len(nums1) and j<len(nums2):\n",
      "                if nums1[i] <= nums2[j]:\n",
      "                    m1=nums1[i]\n",
      "                    i +=1\n",
      "                else :\n",
      "                    m1 = nums2[j]\n",
      "                    j +=1\n",
      "            elif i<len(nums1):\n",
      "                m1 = nums1[i]\n",
      "                i +=1\n",
      "            elif j<len(nums2):\n",
      "                m1 = nums2[j]\n",
      "                j +=1\n",
      "            count +=1\n",
      "        \n",
      "        if size %2==0:\n",
      "            return(m1+m2)/2\n",
      "        else:\n",
      "            return m1\n",
      "\n",
      "Reward: 14.972352027893066\n"
     ]
    }
   ],
   "source": [
    "def model_single_test(model, code, vocab):\n",
    "    states, actions = process_new_code(code, vocab, unk_token=\"<state><UNK>\")\n",
    "    states = torch.tensor([states], dtype=torch.float)\n",
    "    actions = torch.tensor([actions], dtype=torch.float)\n",
    "    with torch.no_grad():\n",
    "            rewards = model(states, actions)\n",
    "\n",
    "    print(f\"Input Code:\\n{code}\")\n",
    "    print(f\"Reward: {rewards}\")\n",
    "\n",
    "code_test = \"\"\"class Solution:\n",
    "    def findMedianSortedArrays(self, nums1: List[int], nums2: List[int]) -> float:\n",
    "      return self.Median(nums1,nums2)\n",
    "\"\"\"\n",
    "\n",
    "if code_test in data[\"code\"].tolist():\n",
    "      print(\"code is already exist in dataset\")\n",
    "else:\n",
    "      print(\"valid unseen code\")\n",
    "\n",
    "model_single_test(reward_model, code_test, vocab)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### IRL test 04"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Original State: [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
      " 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
      " 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
      " 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
      " 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
      " 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0.\n",
      " 0. 0. 0. 0. 0. 0. 0.]\n",
      "Binary Encoding: tensor([1., 0., 0., 1., 0., 0., 0., 1.])\n"
     ]
    }
   ],
   "source": [
    "def binary_encode(array):\n",
    "    # Convert the one-hot array to indices\n",
    "    indices = np.argmax(array, axis=1)\n",
    "    # Calculate the number of bits needed\n",
    "    num_bits = int(np.ceil(np.log2(np.max(indices) + 1)))\n",
    "    # Convert indices to binary representation\n",
    "    binary_array = ((indices[:, None] & (1 << np.arange(num_bits))) > 0).astype(int)\n",
    "    return binary_array\n",
    "\n",
    "test_state = data['one_hot_state'][0]\n",
    "test_binary_state = torch.tensor(binary_encode(test_state), dtype=torch.float)\n",
    "print(f\"Original State: {test_state[0]}\")\n",
    "print(f\"Binary Encoding: {test_binary_state[0]}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "from torch.utils.data import DataLoader\n",
    "\n",
    "# Hyperparameters\n",
    "NUM_NEGATIVE_SAMPLES = 5  # Number of negative samples per positive action\n",
    "binary_state_test = torch.tensor(binary_encode(data['one_hot_state'][0]), dtype=torch.float)[0]\n",
    "binary_action_test = torch.tensor(binary_encode(data['one_hot_action'][0]), dtype=torch.float)[0]\n",
    "STATE_SIZE = len(NUM_NEGATIVE_SAMPLES)  # Size of the state space\n",
    "ACTION_SIZE = len(data['one_hot_action'][0][0])  # Size of the action space\n",
    "EMBEDDING_SIZE = 128\n",
    "HIDDEN_SIZE = 256\n",
    "BATCH_SIZE = 16\n",
    "EPOCHS = 20\n",
    "LEARNING_RATE = 1e-3\n",
    "MARGIN = 1.0\n",
    "NHEAD = 4  # Number of attention heads\n",
    "NUM_LAYERS = 2  # Number of transformer layers\n",
    "LIMIT_OUTPUT_RANGE = True  # Whether to limit the output scores\n",
    "OUTPUT_RANGE = (0, 1)  # The range to limit the output scores\n",
    "\n",
    "# Function to reduce the dimensionality of one-hot encoded states/actions\n",
    "def binary_encode(array):\n",
    "    # Convert the one-hot array to indices\n",
    "    indices = np.argmax(array, axis=1)\n",
    "    # Calculate the number of bits needed\n",
    "    num_bits = int(np.ceil(np.log2(np.max(indices) + 1)))\n",
    "    # Convert indices to binary representation\n",
    "    binary_array = ((indices[:, None] & (1 << np.arange(num_bits))) > 0).astype(int)\n",
    "    return binary_array\n",
    "\n",
    "# Dataset Preparation\n",
    "class TrajectoryDataset(torch.utils.data.Dataset):\n",
    "    def __init__(self, data, num_negative_samples=5):\n",
    "        self.state_sequences = []\n",
    "        self.action_sequences = []\n",
    "        self.num_negative_samples = num_negative_samples\n",
    "        self.action_space = self.get_action_space(data)\n",
    "        for idx, row in data.iterrows():\n",
    "            # Apply binary encoding to reduce dimensionality\n",
    "            states = torch.tensor(binary_encode(np.array(row['one_hot_state'])), dtype=torch.float)\n",
    "            actions = torch.tensor(binary_encode(np.array(row['one_hot_action'])), dtype=torch.float)\n",
    "            self.state_sequences.append(states)\n",
    "            self.action_sequences.append(actions)\n",
    "\n",
    "    def get_action_space(self, data):\n",
    "        # Extract unique actions from the dataset\n",
    "        all_actions = []\n",
    "        for actions in data['one_hot_action']:\n",
    "            all_actions.extend(actions)\n",
    "        unique_actions = np.unique(np.array(all_actions), axis=0)\n",
    "        # Apply binary encoding\n",
    "        unique_actions = binary_encode(unique_actions)\n",
    "        return torch.tensor(unique_actions, dtype=torch.float)\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.state_sequences)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        states = self.state_sequences[idx]\n",
    "        actions = self.action_sequences[idx]\n",
    "        negative_actions = self.sample_negative_actions(actions)\n",
    "        return states, actions, negative_actions\n",
    "\n",
    "    def sample_negative_actions(self, actions):\n",
    "        negative_actions = []\n",
    "        for action in actions:\n",
    "            neg_actions = self.generate_similar_negative_actions(action)\n",
    "            negative_actions.append(neg_actions)\n",
    "        negative_actions = torch.stack(negative_actions)  # Shape: (seq_len, num_negative_samples, action_size)\n",
    "        return negative_actions\n",
    "\n",
    "    def generate_similar_negative_actions(self, action):\n",
    "        num_ones = action.sum().int().item()\n",
    "        # Exclude the positive action\n",
    "        candidate_actions = self.action_space[(self.action_space.sum(dim=1).int() == num_ones) & (~torch.all(self.action_space == action, dim=1))]\n",
    "        if len(candidate_actions) == 0:\n",
    "            # Relax the condition if no candidates are found\n",
    "            candidate_actions = self.action_space[(self.action_space.sum(dim=1).int() >= num_ones - 5) & (self.action_space.sum(dim=1).int() <= num_ones + 5)]\n",
    "            candidate_actions = candidate_actions[~torch.all(candidate_actions == action, dim=1)]\n",
    "        if len(candidate_actions) == 0:\n",
    "            # Fallback to any action except the positive one\n",
    "            candidate_actions = self.action_space[~torch.all(self.action_space == action, dim=1)]\n",
    "        # Sample multiple negative actions\n",
    "        indices = torch.randint(0, len(candidate_actions), (self.num_negative_samples,))\n",
    "        neg_actions = candidate_actions[indices]\n",
    "        return neg_actions  # Shape: (num_negative_samples, action_size)\n",
    "\n",
    "dataset = TrajectoryDataset(data, num_negative_samples=NUM_NEGATIVE_SAMPLES)\n",
    "\n",
    "def collate_fn(batch):\n",
    "    states_batch, actions_batch, negative_actions_batch = zip(*batch)\n",
    "    # Pad sequences\n",
    "    states_padded = nn.utils.rnn.pad_sequence(states_batch, batch_first=True)\n",
    "    actions_padded = nn.utils.rnn.pad_sequence(actions_batch, batch_first=True)\n",
    "    # Pad negative actions\n",
    "    max_seq_len = max([na.shape[0] for na in negative_actions_batch])\n",
    "    batch_size = len(negative_actions_batch)\n",
    "    num_negative_samples = negative_actions_batch[0].shape[1]\n",
    "    action_size = negative_actions_batch[0].shape[2]\n",
    "    negative_actions_padded = torch.zeros((batch_size, max_seq_len, num_negative_samples, action_size))\n",
    "    for i in range(batch_size):\n",
    "        seq_len = negative_actions_batch[i].shape[0]\n",
    "        negative_actions_padded[i, :seq_len] = negative_actions_batch[i]\n",
    "    return states_padded, actions_padded, negative_actions_padded\n",
    "\n",
    "dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "d:\\anaconda3\\envs\\llm\\lib\\site-packages\\torch\\nn\\modules\\transformer.py:306: UserWarning: enable_nested_tensor is True, but self.use_nested_tensor is False because encoder_layer.self_attn.batch_first was not True(use batch_first for better inference performance)\n",
      "  warnings.warn(f\"enable_nested_tensor is True, but self.use_nested_tensor is False because {why_not_sparsity_fast_path}\")\n"
     ]
    },
    {
     "ename": "RuntimeError",
     "evalue": "The size of tensor a (9) must match the size of tensor b (8) at non-singleton dimension 1",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mRuntimeError\u001b[0m                              Traceback (most recent call last)",
      "Cell \u001b[1;32mIn[8], line 51\u001b[0m\n\u001b[0;32m     49\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m epoch \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(EPOCHS):\n\u001b[0;32m     50\u001b[0m     total_loss \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m0\u001b[39m\n\u001b[1;32m---> 51\u001b[0m     \u001b[38;5;28;01mfor\u001b[39;00m states_padded, actions_padded, negative_actions_padded \u001b[38;5;129;01min\u001b[39;00m dataloader:\n\u001b[0;32m     52\u001b[0m         optimizer\u001b[38;5;241m.\u001b[39mzero_grad()\n\u001b[0;32m     53\u001b[0m         \u001b[38;5;66;03m# Compute rewards for positive examples\u001b[39;00m\n",
      "File \u001b[1;32md:\\anaconda3\\envs\\llm\\lib\\site-packages\\torch\\utils\\data\\dataloader.py:631\u001b[0m, in \u001b[0;36m_BaseDataLoaderIter.__next__\u001b[1;34m(self)\u001b[0m\n\u001b[0;32m    628\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_sampler_iter \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m    629\u001b[0m     \u001b[38;5;66;03m# TODO(https://github.com/pytorch/pytorch/issues/76750)\u001b[39;00m\n\u001b[0;32m    630\u001b[0m     \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_reset()  \u001b[38;5;66;03m# type: ignore[call-arg]\u001b[39;00m\n\u001b[1;32m--> 631\u001b[0m data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_next_data\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m    632\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_num_yielded \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1\u001b[39m\n\u001b[0;32m    633\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_dataset_kind \u001b[38;5;241m==\u001b[39m _DatasetKind\u001b[38;5;241m.\u001b[39mIterable \u001b[38;5;129;01mand\u001b[39;00m \\\n\u001b[0;32m    634\u001b[0m         \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_IterableDataset_len_called \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m \\\n\u001b[0;32m    635\u001b[0m         \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_num_yielded \u001b[38;5;241m>\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_IterableDataset_len_called:\n",
      "File \u001b[1;32md:\\anaconda3\\envs\\llm\\lib\\site-packages\\torch\\utils\\data\\dataloader.py:675\u001b[0m, in \u001b[0;36m_SingleProcessDataLoaderIter._next_data\u001b[1;34m(self)\u001b[0m\n\u001b[0;32m    673\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_next_data\u001b[39m(\u001b[38;5;28mself\u001b[39m):\n\u001b[0;32m    674\u001b[0m     index \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_next_index()  \u001b[38;5;66;03m# may raise StopIteration\u001b[39;00m\n\u001b[1;32m--> 675\u001b[0m     data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_dataset_fetcher\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfetch\u001b[49m\u001b[43m(\u001b[49m\u001b[43mindex\u001b[49m\u001b[43m)\u001b[49m  \u001b[38;5;66;03m# may raise StopIteration\u001b[39;00m\n\u001b[0;32m    676\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_pin_memory:\n\u001b[0;32m    677\u001b[0m         data \u001b[38;5;241m=\u001b[39m _utils\u001b[38;5;241m.\u001b[39mpin_memory\u001b[38;5;241m.\u001b[39mpin_memory(data, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_pin_memory_device)\n",
      "File \u001b[1;32md:\\anaconda3\\envs\\llm\\lib\\site-packages\\torch\\utils\\data\\_utils\\fetch.py:51\u001b[0m, in \u001b[0;36m_MapDatasetFetcher.fetch\u001b[1;34m(self, possibly_batched_index)\u001b[0m\n\u001b[0;32m     49\u001b[0m         data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdataset\u001b[38;5;241m.\u001b[39m__getitems__(possibly_batched_index)\n\u001b[0;32m     50\u001b[0m     \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m---> 51\u001b[0m         data \u001b[38;5;241m=\u001b[39m [\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdataset[idx] \u001b[38;5;28;01mfor\u001b[39;00m idx \u001b[38;5;129;01min\u001b[39;00m possibly_batched_index]\n\u001b[0;32m     52\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m     53\u001b[0m     data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdataset[possibly_batched_index]\n",
      "File \u001b[1;32md:\\anaconda3\\envs\\llm\\lib\\site-packages\\torch\\utils\\data\\_utils\\fetch.py:51\u001b[0m, in \u001b[0;36m<listcomp>\u001b[1;34m(.0)\u001b[0m\n\u001b[0;32m     49\u001b[0m         data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdataset\u001b[38;5;241m.\u001b[39m__getitems__(possibly_batched_index)\n\u001b[0;32m     50\u001b[0m     \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m---> 51\u001b[0m         data \u001b[38;5;241m=\u001b[39m [\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdataset\u001b[49m\u001b[43m[\u001b[49m\u001b[43midx\u001b[49m\u001b[43m]\u001b[49m \u001b[38;5;28;01mfor\u001b[39;00m idx \u001b[38;5;129;01min\u001b[39;00m possibly_batched_index]\n\u001b[0;32m     52\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m     53\u001b[0m     data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdataset[possibly_batched_index]\n",
      "Cell \u001b[1;32mIn[7], line 63\u001b[0m, in \u001b[0;36mTrajectoryDataset.__getitem__\u001b[1;34m(self, idx)\u001b[0m\n\u001b[0;32m     61\u001b[0m states \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate_sequences[idx]\n\u001b[0;32m     62\u001b[0m actions \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39maction_sequences[idx]\n\u001b[1;32m---> 63\u001b[0m negative_actions \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msample_negative_actions\u001b[49m\u001b[43m(\u001b[49m\u001b[43mactions\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m     64\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m states, actions, negative_actions\n",
      "Cell \u001b[1;32mIn[7], line 69\u001b[0m, in \u001b[0;36mTrajectoryDataset.sample_negative_actions\u001b[1;34m(self, actions)\u001b[0m\n\u001b[0;32m     67\u001b[0m negative_actions \u001b[38;5;241m=\u001b[39m []\n\u001b[0;32m     68\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m action \u001b[38;5;129;01min\u001b[39;00m actions:\n\u001b[1;32m---> 69\u001b[0m     neg_actions \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mgenerate_similar_negative_actions\u001b[49m\u001b[43m(\u001b[49m\u001b[43maction\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m     70\u001b[0m     negative_actions\u001b[38;5;241m.\u001b[39mappend(neg_actions)\n\u001b[0;32m     71\u001b[0m negative_actions \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mstack(negative_actions)  \u001b[38;5;66;03m# Shape: (seq_len, num_negative_samples, action_size)\u001b[39;00m\n",
      "Cell \u001b[1;32mIn[7], line 77\u001b[0m, in \u001b[0;36mTrajectoryDataset.generate_similar_negative_actions\u001b[1;34m(self, action)\u001b[0m\n\u001b[0;32m     75\u001b[0m num_ones \u001b[38;5;241m=\u001b[39m action\u001b[38;5;241m.\u001b[39msum()\u001b[38;5;241m.\u001b[39mint()\u001b[38;5;241m.\u001b[39mitem()\n\u001b[0;32m     76\u001b[0m \u001b[38;5;66;03m# Exclude the positive action\u001b[39;00m\n\u001b[1;32m---> 77\u001b[0m candidate_actions \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39maction_space[(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39maction_space\u001b[38;5;241m.\u001b[39msum(dim\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1\u001b[39m)\u001b[38;5;241m.\u001b[39mint() \u001b[38;5;241m==\u001b[39m num_ones) \u001b[38;5;241m&\u001b[39m (\u001b[38;5;241m~\u001b[39mtorch\u001b[38;5;241m.\u001b[39mall(\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43maction_space\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m==\u001b[39;49m\u001b[43m \u001b[49m\u001b[43maction\u001b[49m, dim\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1\u001b[39m))]\n\u001b[0;32m     78\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(candidate_actions) \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m0\u001b[39m:\n\u001b[0;32m     79\u001b[0m     \u001b[38;5;66;03m# Relax the condition if no candidates are found\u001b[39;00m\n\u001b[0;32m     80\u001b[0m     candidate_actions \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39maction_space[(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39maction_space\u001b[38;5;241m.\u001b[39msum(dim\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1\u001b[39m)\u001b[38;5;241m.\u001b[39mint() \u001b[38;5;241m>\u001b[39m\u001b[38;5;241m=\u001b[39m num_ones \u001b[38;5;241m-\u001b[39m \u001b[38;5;241m5\u001b[39m) \u001b[38;5;241m&\u001b[39m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39maction_space\u001b[38;5;241m.\u001b[39msum(dim\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1\u001b[39m)\u001b[38;5;241m.\u001b[39mint() \u001b[38;5;241m<\u001b[39m\u001b[38;5;241m=\u001b[39m num_ones \u001b[38;5;241m+\u001b[39m \u001b[38;5;241m5\u001b[39m)]\n",
      "\u001b[1;31mRuntimeError\u001b[0m: The size of tensor a (9) must match the size of tensor b (8) at non-singleton dimension 1"
     ]
    }
   ],
   "source": [
    "# Reward Model Definition with Transformer Encoder\n",
    "class RewardModel(nn.Module):\n",
    "    def __init__(self, state_size, action_size, embedding_size, hidden_size, nhead=4, num_layers=2, limit_output_range=False, output_range=(0, 1)):\n",
    "        super(RewardModel, self).__init__()\n",
    "        self.state_embedding = nn.Linear(state_size, embedding_size)\n",
    "        self.action_embedding = nn.Linear(action_size, embedding_size)\n",
    "        encoder_layers = nn.TransformerEncoderLayer(d_model=embedding_size * 2, nhead=nhead)\n",
    "        self.transformer_encoder = nn.TransformerEncoder(encoder_layers, num_layers)\n",
    "        self.fc = nn.Sequential(\n",
    "            nn.Linear(embedding_size * 2, hidden_size),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(hidden_size, 1)\n",
    "        )\n",
    "        self.limit_output_range = limit_output_range\n",
    "        self.output_range = output_range\n",
    "\n",
    "    def forward(self, states, actions, src_key_padding_mask=None):\n",
    "        # states and actions are sequences\n",
    "        state_emb = self.state_embedding(states) \n",
    "        action_emb = self.action_embedding(actions) \n",
    "        x = torch.cat([state_emb, action_emb], dim=2) \n",
    "        x = x.transpose(0, 1)\n",
    "        transformer_out = self.transformer_encoder(x, src_key_padding_mask=src_key_padding_mask)\n",
    "        output = transformer_out[-1, :, :]  # Use the last time step\n",
    "        reward = self.fc(output) \n",
    "        if self.limit_output_range:\n",
    "            min_val, max_val = self.output_range\n",
    "            reward = torch.clamp(reward, min=min_val, max=max_val)\n",
    "        return reward.squeeze()\n",
    "\n",
    "reward_model = RewardModel(\n",
    "    STATE_SIZE,\n",
    "    ACTION_SIZE,\n",
    "    EMBEDDING_SIZE,\n",
    "    HIDDEN_SIZE,\n",
    "    nhead=NHEAD,\n",
    "    num_layers=NUM_LAYERS,\n",
    "    limit_output_range=LIMIT_OUTPUT_RANGE,\n",
    "    output_range=OUTPUT_RANGE\n",
    ")\n",
    "\n",
    "# Loss Function\n",
    "criterion = nn.MarginRankingLoss(margin=MARGIN)\n",
    "\n",
    "# Optimizer\n",
    "optimizer = optim.Adam(reward_model.parameters(), lr=LEARNING_RATE)\n",
    "\n",
    "# Training Loop\n",
    "for epoch in range(EPOCHS):\n",
    "    total_loss = 0\n",
    "    for states_padded, actions_padded, negative_actions_padded in dataloader:\n",
    "        optimizer.zero_grad()\n",
    "        # Compute rewards for positive examples\n",
    "        positive_rewards = reward_model(states_padded, actions_padded)\n",
    "        # Compute rewards for negative examples\n",
    "        batch_size, seq_len, num_negative_samples, action_size = negative_actions_padded.shape\n",
    "        # Expand states to match negative actions\n",
    "        states_expanded = states_padded.unsqueeze(2).expand(-1, -1, num_negative_samples, -1)\n",
    "        # Flatten the tensors to feed into the model\n",
    "        states_neg = states_expanded.reshape(batch_size, seq_len * num_negative_samples, -1)\n",
    "        actions_neg = negative_actions_padded.reshape(batch_size, seq_len * num_negative_samples, -1)\n",
    "        negative_rewards = reward_model(states_neg, actions_neg)\n",
    "        # Reshape negative_rewards back to (batch_size, seq_len, num_negative_samples)\n",
    "        negative_rewards = negative_rewards.reshape(batch_size, seq_len, num_negative_samples)\n",
    "        # Take the mean over negative samples\n",
    "        negative_rewards = negative_rewards.mean(dim=2)\n",
    "        # Compute the loss\n",
    "        target = torch.ones(positive_rewards.size())\n",
    "        loss = criterion(positive_rewards, negative_rewards.mean(dim=1), target)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        total_loss += loss.item()\n",
    "\n",
    "    print(f\"Epoch {epoch+1}/{EPOCHS}, Loss: {total_loss/len(dataloader)}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Function to evaluate the model on the test dataset\n",
    "def evaluate_model_on_test_data(model, test_data, vocab):\n",
    "    correct = 0\n",
    "    total = 0\n",
    "    for idx, row in test_data.iterrows():\n",
    "        code = row['code']\n",
    "        label = row['label']\n",
    "        # Process the new code (assume process_new_code is defined)\n",
    "        states, actions = process_new_code(code, vocab, unk_token=\"<state><UNK>\")\n",
    "        states = torch.tensor([states], dtype=torch.float)\n",
    "        actions = torch.tensor([actions], dtype=torch.float)\n",
    "        with torch.no_grad():\n",
    "            reward = model(states, actions)\n",
    "        prediction = 1 if reward.item() > 0.5 else 0  # Threshold can be adjusted\n",
    "        if prediction == label:\n",
    "            correct += 1\n",
    "        total += 1\n",
    "    accuracy = correct / total\n",
    "    print(f\"Model Accuracy on Test Data: {accuracy * 100:.2f}%\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Assuming process_new_code and vocab are defined\n",
    "def model_single_test(model, code, vocab):\n",
    "    states, actions = process_new_code(code, vocab, unk_token=\"<state><UNK>\")\n",
    "    states = torch.tensor([states], dtype=torch.float)\n",
    "    actions = torch.tensor([actions], dtype=torch.float)\n",
    "    with torch.no_grad():\n",
    "        rewards = model(states, actions)\n",
    "    print(f\"Input Code:\\n{code}\")\n",
    "    print(f\"Reward: {rewards}\")\n",
    "\n",
    "code_test = \"\"\"class Solution:\n",
    "    def findMedianSortedArrays(self, nums1: List[int], nums2: List[int]) -> float:\n",
    "        return self.Median(nums1, nums2)\n",
    "\"\"\"\n",
    "\n",
    "if code_test in data[\"code\"].tolist():\n",
    "    print(\"Code already exists in dataset\")\n",
    "else:\n",
    "    print(\"Valid unseen code\")\n",
    "\n",
    "model_single_test(reward_model, code_test, vocab)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### IRL test 05"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\Zihang Zeng\\AppData\\Local\\Temp\\ipykernel_125220\\1298829264.py:32: UserWarning: Creating a tensor from a list of numpy.ndarrays is extremely slow. Please consider converting the list to a single numpy.ndarray with numpy.array() before converting to a tensor. (Triggered internally at C:\\actions-runner\\_work\\pytorch\\pytorch\\builder\\windows\\pytorch\\torch\\csrc\\utils\\tensor_new.cpp:277.)\n",
      "  states = torch.tensor(row['one_hot_state'], dtype=torch.float)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/20, Loss: 0.9998919665813446\n",
      "Epoch 2/20, Loss: 0.9954293370246887\n",
      "Epoch 3/20, Loss: 0.9816782921552658\n",
      "Epoch 4/20, Loss: 0.8432859182357788\n",
      "Epoch 5/20, Loss: 0.6865130141377449\n",
      "Epoch 6/20, Loss: 0.5930077955126762\n",
      "Epoch 7/20, Loss: 0.9817382544279099\n",
      "Epoch 8/20, Loss: 0.9958653897047043\n",
      "Epoch 9/20, Loss: 0.9956686645746231\n",
      "Epoch 10/20, Loss: 0.9938212335109711\n",
      "Epoch 11/20, Loss: 0.9913119822740555\n",
      "Epoch 12/20, Loss: 0.9870873838663101\n",
      "Epoch 13/20, Loss: 0.9858999699354172\n",
      "Epoch 14/20, Loss: 0.9725503921508789\n",
      "Epoch 15/20, Loss: 0.9622585028409958\n",
      "Epoch 16/20, Loss: 0.9474791586399078\n",
      "Epoch 17/20, Loss: 0.9105489999055862\n",
      "Epoch 18/20, Loss: 0.7340191751718521\n",
      "Epoch 19/20, Loss: 0.3629789501428604\n",
      "Epoch 20/20, Loss: 0.11706003081053495\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "from torch.utils.data import DataLoader\n",
    "\n",
    "# Device configuration\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "# Hyperparameters\n",
    "STATE_SIZE = len(new_data['one_hot_state'][0][0])  # Size of the state space\n",
    "ACTION_SIZE = len(new_data['one_hot_action'][0][0])  # Size of the action space\n",
    "EMBEDDING_SIZE = 128\n",
    "HIDDEN_SIZE = 256\n",
    "BATCH_SIZE = 16\n",
    "EPOCHS = 20\n",
    "LEARNING_RATE = 1e-3\n",
    "MARGIN = 1.0\n",
    "NUM_NEGATIVE_SAMPLES = 10  # Number of negative samples per positive action\n",
    "OUTPUT_BOUND = None  # Set to (min_value, max_value) or None (-1.0, 1.0)\n",
    "REDUCED_DIM = None  # Dimension after reducing the one-hot vectors\n",
    "\n",
    "# Dataset Preparation\n",
    "class TrajectoryDataset(torch.utils.data.Dataset):\n",
    "    def __init__(self, data):\n",
    "        self.state_sequences = []\n",
    "        self.action_sequences = []\n",
    "        self.action_space = self.get_action_space(data)\n",
    "        self.action_space_size = len(data['one_hot_action'][0][0])\n",
    "        for idx, row in data.iterrows():\n",
    "            states = torch.tensor(row['one_hot_state'], dtype=torch.float)\n",
    "            actions = torch.tensor(row['one_hot_action'], dtype=torch.float)\n",
    "            self.state_sequences.append(states)\n",
    "            self.action_sequences.append(actions)\n",
    "        # Precompute the projection matrix for dimension reduction\n",
    "        if REDUCED_DIM == -1:\n",
    "            self.projection_matrix_state = None\n",
    "            self.projection_matrix_action = None\n",
    "            self.reduce_method = 'binary'\n",
    "            self.state_reduced_dimention = int(np.ceil(np.log2(STATE_SIZE)))\n",
    "            self.action_reduced_dimention = int(np.ceil(np.log2(ACTION_SIZE)))\n",
    "        elif REDUCED_DIM is None:\n",
    "            self.projection_matrix_state = None\n",
    "            self.projection_matrix_action = None\n",
    "            self.reduce_method = 'original'\n",
    "            self.state_reduced_dimention = STATE_SIZE\n",
    "            self.action_reduced_dimention = ACTION_SIZE\n",
    "        else:\n",
    "            self.projection_matrix_state = torch.randn(STATE_SIZE, REDUCED_DIM)\n",
    "            self.projection_matrix_action = torch.randn(ACTION_SIZE, REDUCED_DIM)\n",
    "            self.reduce_method = 'linear'\n",
    "            self.state_reduced_dimention = REDUCED_DIM\n",
    "            self.action_reduced_dimention = REDUCED_DIM\n",
    "\n",
    "    def get_action_space(self, data):\n",
    "        # Extract unique actions from the dataset\n",
    "        all_actions = []\n",
    "        for actions in data['one_hot_action']:\n",
    "            all_actions.extend(actions)\n",
    "        unique_actions = np.unique(np.array(all_actions), axis=0)\n",
    "        return torch.tensor(unique_actions, dtype=torch.float)\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.state_sequences)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        states = self.state_sequences[idx]\n",
    "        actions = self.action_sequences[idx]\n",
    "        negative_actions = self.sample_negative_actions(actions)\n",
    "        # print(negative_actions)\n",
    "        # raise Exception(\"Stop here\")\n",
    "        # Reduce dimension of states and actions\n",
    "        states = self.reduce_dimension(states, self.projection_matrix_state)\n",
    "        actions = self.reduce_dimension(actions, self.projection_matrix_action)\n",
    "        negative_actions = self.reduce_dimension(negative_actions.view(-1, ACTION_SIZE), self.projection_matrix_action)\n",
    "        negative_actions = negative_actions.view(-1, NUM_NEGATIVE_SAMPLES, self.action_reduced_dimention)\n",
    "        return states, actions, negative_actions\n",
    "\n",
    "    def sample_negative_actions(self, actions):\n",
    "        negative_actions_list = []\n",
    "        for action in actions:\n",
    "            neg_actions = self.generate_similar_negative_actions(action, NUM_NEGATIVE_SAMPLES)\n",
    "            # random_neg_actions = self.generate_negative_actions(action, NUM_NEGATIVE_SAMPLES)\n",
    "            negative_actions_list.append(neg_actions)\n",
    "            # negative_actions_list.append(random_neg_actions)\n",
    "        negative_actions = torch.stack(negative_actions_list)  # (seq_len, NUM_NEGATIVE_SAMPLES, action_size)\n",
    "        return negative_actions\n",
    "\n",
    "    def generate_similar_negative_actions(self, action, num_samples):\n",
    "        num_ones = action.sum().int().item()\n",
    "        candidate_actions = self.action_space[(self.action_space.sum(dim=1).int() == num_ones) & (~torch.all(self.action_space == action, dim=1))]\n",
    "        if len(candidate_actions) == 0:\n",
    "            # If no candidate found, relax the condition slightly\n",
    "            candidate_actions = self.action_space[(self.action_space.sum(dim=1).int() >= num_ones - 5) & (self.action_space.sum(dim=1).int() <= num_ones + 5)]\n",
    "            candidate_actions = candidate_actions[~torch.all(candidate_actions == action, dim=1)]\n",
    "        if len(candidate_actions) == 0:\n",
    "            # If still no candidate, use all actions except the current one\n",
    "            candidate_actions = self.action_space[~torch.all(self.action_space == action, dim=1)]\n",
    "        # Sample num_samples negative actions\n",
    "        indices = torch.randint(0, len(candidate_actions), (num_samples,))\n",
    "        neg_actions = candidate_actions[indices]\n",
    "        # print(neg_actions)\n",
    "        # raise Exception(\"Stop here\")\n",
    "        return neg_actions  # (num_samples, action_size)\n",
    "    \n",
    "    def generate_negative_actions(self, action, negative_sample_size=1):\n",
    "        negative_actions = []\n",
    "        num_ones = int(action.sum().item())\n",
    "        action_size = self.action_space_size  # Original action size\n",
    "        for _ in range(negative_sample_size):\n",
    "            # Generate a random action with the same number of ones\n",
    "            neg_action = self.generate_random_action_with_num_ones(action_size, num_ones)\n",
    "            # Ensure it's not the same as the positive action\n",
    "            while torch.equal(neg_action, action):\n",
    "                neg_action = self.generate_random_action_with_num_ones(action_size, num_ones)\n",
    "            negative_actions.append(neg_action)\n",
    "        return torch.stack(negative_actions)\n",
    "    \n",
    "    def generate_random_action_with_num_ones(self, action_size, num_ones):\n",
    "        indices = torch.randperm(action_size)[:num_ones]\n",
    "        neg_action = torch.zeros(action_size)\n",
    "        neg_action[indices] = 1.0\n",
    "        return neg_action\n",
    "    \n",
    "    def reduce_dimension(self, one_hot_vectors, projection_matrix):\n",
    "        method = self.reduce_method\n",
    "        if method == 'binary':\n",
    "            # print(one_hot_vectors)\n",
    "            reduced_vectors = self.binary_encode(one_hot_vectors)\n",
    "        elif method == 'linear':\n",
    "            reduced_vectors = torch.matmul(one_hot_vectors, projection_matrix)\n",
    "        elif method == 'original':\n",
    "            reduced_vectors = one_hot_vectors\n",
    "        return reduced_vectors\n",
    "    \n",
    "    def binary_encode(self, array):\n",
    "        # Convert the one-hot array to indices\n",
    "        indices = np.argmax(array, axis=1)\n",
    "        # Calculate the number of bits needed\n",
    "        num_bits = int(np.ceil(np.log2(np.max(indices) + 1)))\n",
    "        # Convert indices to binary representation\n",
    "        binary_array = ((indices[:, None] & (1 << np.arange(num_bits))) > 0).astype(int)\n",
    "        binary_array = torch.tensor(binary_array, dtype=torch.float)\n",
    "        return binary_array\n",
    "\n",
    "dataset = TrajectoryDataset(new_data)\n",
    "reduced_dimention = (dataset.action_reduced_dimention, dataset.state_reduced_dimention)\n",
    "dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=lambda x: x)\n",
    "\n",
    "# Reward Model Definition with LSTM\n",
    "class RewardModel(nn.Module):\n",
    "    def __init__(self, reduced_dim, embedding_size, hidden_size, output_bound=None):\n",
    "        super(RewardModel, self).__init__()\n",
    "        self.state_embedding = nn.Linear(reduced_dim[1], embedding_size)\n",
    "        self.action_embedding = nn.Linear(reduced_dim[0], embedding_size)\n",
    "        self.lstm = nn.LSTM(input_size=embedding_size * 2, hidden_size=hidden_size, batch_first=True)\n",
    "        self.fc = nn.Sequential(\n",
    "            nn.Linear(hidden_size, embedding_size),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(embedding_size, 1)\n",
    "        )\n",
    "        self.output_bound = output_bound\n",
    "\n",
    "    def forward(self, states, actions):\n",
    "        # states and actions are sequences\n",
    "        state_emb = self.state_embedding(states)  # (batch_size, seq_len, embedding_size)\n",
    "        action_emb = self.action_embedding(actions)  # (batch_size, seq_len, embedding_size)\n",
    "        x = torch.cat([state_emb, action_emb], dim=2)  # (batch_size, seq_len, embedding_size * 2)\n",
    "        lstm_out, _ = self.lstm(x)  # (batch_size, seq_len, hidden_size)\n",
    "        # Use the last output of LSTM for reward prediction\n",
    "        reward = self.fc(lstm_out[:, -1, :])  # (batch_size, 1)\n",
    "        if self.output_bound is not None:\n",
    "            min_val, max_val = self.output_bound\n",
    "            reward = torch.clamp(reward, min=min_val, max=max_val)\n",
    "        return reward.squeeze()\n",
    "\n",
    "reward_model = RewardModel(reduced_dimention, EMBEDDING_SIZE, HIDDEN_SIZE, output_bound=OUTPUT_BOUND).to(device)\n",
    "\n",
    "# Loss Function\n",
    "criterion = nn.MarginRankingLoss(margin=MARGIN)\n",
    "\n",
    "# Optimizer\n",
    "optimizer = optim.Adam(reward_model.parameters(), lr=LEARNING_RATE)\n",
    "\n",
    "# Training Loop\n",
    "for epoch in range(EPOCHS):\n",
    "    total_loss = 0\n",
    "    for batch in dataloader:\n",
    "        optimizer.zero_grad()\n",
    "        states_batch, actions_batch, negative_actions_batch = zip(*batch)\n",
    "        # states_batch: list of tensors, each of shape (seq_len, reduced_dim)\n",
    "        # actions_batch: list of tensors, each of shape (seq_len, reduced_dim)\n",
    "        # negative_actions_batch: list of tensors, each of shape (seq_len, NUM_NEGATIVE_SAMPLES, reduced_dim)\n",
    "\n",
    "        # Pad sequences\n",
    "        states_padded = nn.utils.rnn.pad_sequence(states_batch, batch_first=True).to(device)\n",
    "        actions_padded = nn.utils.rnn.pad_sequence(actions_batch, batch_first=True).to(device)\n",
    "\n",
    "        # For negative actions, we need to pad the sequences and handle the extra dimensions\n",
    "        max_seq_len = states_padded.size(1)\n",
    "        negative_actions_padded = []\n",
    "        for neg_actions in negative_actions_batch:\n",
    "            # neg_actions: (seq_len, NUM_NEGATIVE_SAMPLES, reduced_dim)\n",
    "            pad_size = max_seq_len - neg_actions.size(0)\n",
    "            if pad_size > 0:\n",
    "                pad_tensor = torch.zeros(pad_size, NUM_NEGATIVE_SAMPLES, reduced_dimention[0])\n",
    "                neg_actions_padded_seq = torch.cat([neg_actions, pad_tensor], dim=0)\n",
    "            else:\n",
    "                neg_actions_padded_seq = neg_actions\n",
    "            negative_actions_padded.append(neg_actions_padded_seq)\n",
    "        negative_actions_padded = torch.stack(negative_actions_padded).to(device)  # (batch_size, max_seq_len, NUM_NEGATIVE_SAMPLES, reduced_dim)\n",
    "\n",
    "        # Compute rewards for positive examples\n",
    "        positive_rewards = reward_model(states_padded, actions_padded)  # (batch_size,)\n",
    "\n",
    "        # Compute rewards for negative examples\n",
    "        # We need to compute the reward for each negative action\n",
    "        negative_rewards_list = []\n",
    "        for i in range(NUM_NEGATIVE_SAMPLES):\n",
    "            negative_actions_i = negative_actions_padded[:, :, i, :]  # (batch_size, max_seq_len, reduced_dim)\n",
    "            negative_reward = reward_model(states_padded, negative_actions_i)  # (batch_size,)\n",
    "            negative_rewards_list.append(negative_reward)\n",
    "        # Stack negative rewards: (NUM_NEGATIVE_SAMPLES, batch_size)\n",
    "        negative_rewards = torch.stack(negative_rewards_list)  # (NUM_NEGATIVE_SAMPLES, batch_size)\n",
    "\n",
    "        # Compute the loss for each negative sample\n",
    "        # Labels for MarginRankingLoss: 1 indicates positive_rewards should be larger than negative_rewards\n",
    "        target = torch.ones(positive_rewards.size()).unsqueeze(0).repeat(NUM_NEGATIVE_SAMPLES, 1).to(device)  # (NUM_NEGATIVE_SAMPLES, batch_size)\n",
    "        # Compute the margin ranking loss for each negative sample\n",
    "        loss = criterion(positive_rewards.unsqueeze(0).repeat(NUM_NEGATIVE_SAMPLES, 1), negative_rewards, target)\n",
    "\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        total_loss += loss.item()\n",
    "\n",
    "    print(f\"Epoch {epoch+1}/{EPOCHS}, Loss: {total_loss/len(dataloader)}\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy: 0.6639, ROC AUC: 0.6251, Precision: 0.6173, Recall: 0.8475, F1-score: 0.7143\n"
     ]
    }
   ],
   "source": [
    "# Evaluation Function\n",
    "def evaluate_model(model, test_data, vocab, threshold=0, new_version=False):\n",
    "    model = model.to(device)\n",
    "    model.eval()\n",
    "    true_labels = []\n",
    "    predicted_rewards = []\n",
    "    for idx, row in test_data.iterrows():\n",
    "        code = row['code']\n",
    "        label = row['labels']\n",
    "        # Process the code to get states and actions\n",
    "        states, actions = process_new_code(code, vocab, unk_token=\"<state><UNK>\")\n",
    "        states = torch.tensor([states], dtype=torch.float)\n",
    "        actions = torch.tensor([actions], dtype=torch.float)\n",
    "        # Reduce dimension\n",
    "        if new_version:\n",
    "            try:\n",
    "                states = dataset.reduce_dimension(states, dataset.projection_matrix_state)\n",
    "            except Exception as e:\n",
    "                print(code)\n",
    "                print(label)\n",
    "                print(e)\n",
    "                raise Exception(\"Stop here\")\n",
    "            actions = dataset.reduce_dimension(actions, dataset.projection_matrix_action)\n",
    "        states = states.to(device)\n",
    "        actions = actions.to(device)\n",
    "        with torch.no_grad():\n",
    "            rewards = model(states, actions)\n",
    "        true_labels.append(label)\n",
    "        predicted_rewards.append(rewards.item())\n",
    "    # Compute evaluation metrics\n",
    "    from sklearn.metrics import accuracy_score, roc_auc_score, precision_recall_fscore_support\n",
    "    # Binarize predicted rewards based on threshold (e.g., 0)\n",
    "    predicted_labels = [1 if r >= threshold else 0 for r in predicted_rewards]\n",
    "    accuracy = accuracy_score(true_labels, predicted_labels)\n",
    "    roc_auc = roc_auc_score(true_labels, predicted_rewards)\n",
    "    precision, recall, f1, _ = precision_recall_fscore_support(true_labels, predicted_labels, average='binary')\n",
    "    print(f\"Accuracy: {accuracy:.4f}, ROC AUC: {roc_auc:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1-score: {f1:.4f}\")\n",
    "    # model.train()\n",
    "\n",
    "evaluate_model(reward_model, test_data, vocab, threshold=0, new_version = False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Valid unseen code\n",
      "Input Code:\n",
      "class Solution(object):\n",
      "    def longestPalindrome(self, s):\n",
      "        n=len(s)\n",
      "        dp=[[False]*n for i in range(n)]\n",
      "        start,max_length=0,1\n",
      "        for i in range(n):\n",
      "            dp[i][i]=True\n",
      "        for i in range(n-1):\n",
      "            if s[i]==s[i+1]:\n",
      "                dp[i][i+1]=True\n",
      "                start=i\n",
      "                max_length=2\n",
      "        for length in range(3,n+1):\n",
      "            for i in range(n-length+1):\n",
      "                j=i+length-1\n",
      "                if s[i]==s[j]:\n",
      "                    if dp[i+1][j-1]==True:\n",
      "                        dp[i][j]=True\n",
      "                        start=i\n",
      "                        max_length=length\n",
      "        return s[start:start+max_length]\n",
      "\n",
      "Reward: -18.1295223236084\n"
     ]
    }
   ],
   "source": [
    "# Single Code Test Function\n",
    "def model_single_test(model, code, vocab):\n",
    "    states, actions = process_new_code(code, vocab, unk_token=\"<state><UNK>\")\n",
    "    states = torch.tensor([states], dtype=torch.float)\n",
    "    actions = torch.tensor([actions], dtype=torch.float)\n",
    "    # Reduce dimension\n",
    "    states = dataset.reduce_dimension(states, dataset.projection_matrix_state)\n",
    "    actions = dataset.reduce_dimension(actions, dataset.projection_matrix_action)\n",
    "    states = states.to(device)\n",
    "    actions = actions.to(device)\n",
    "    with torch.no_grad():\n",
    "        rewards = model(states, actions)\n",
    "\n",
    "    print(f\"Input Code:\\n{code}\")\n",
    "    print(f\"Reward: {rewards}\")\n",
    "\n",
    "code_test = \"\"\"class Solution(object):\n",
    "    def longestPalindrome(self, s):\n",
    "        n=len(s)\n",
    "        dp=[[False]*n for i in range(n)]\n",
    "        start,max_length=0,1\n",
    "        for i in range(n):\n",
    "            dp[i][i]=True\n",
    "        for i in range(n-1):\n",
    "            if s[i]==s[i+1]:\n",
    "                dp[i][i+1]=True\n",
    "                start=i\n",
    "                max_length=2\n",
    "        for length in range(3,n+1):\n",
    "            for i in range(n-length+1):\n",
    "                j=i+length-1\n",
    "                if s[i]==s[j]:\n",
    "                    if dp[i+1][j-1]==True:\n",
    "                        dp[i][j]=True\n",
    "                        start=i\n",
    "                        max_length=length\n",
    "        return s[start:start+max_length]\n",
    "\"\"\"\n",
    "\n",
    "if code_test in data[\"code\"].tolist():\n",
    "    print(\"Code already exists in dataset\")\n",
    "else:\n",
    "    print(\"Valid unseen code\")\n",
    "\n",
    "model_single_test(reward_model, code_test, vocab)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Evaluation 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[-4.529577255249023, -4.626935958862305, -4.66011381149292, -5.679779529571533, -6.231942653656006, -5.299332618713379, -6.250425338745117, -2.4460322856903076, -4.8368048667907715, -7.424708843231201, -5.064548015594482, -4.4185099601745605, -7.158999919891357, -7.68357515335083, -5.44248628616333, -10.093977928161621, -1.060196042060852, -12.92013168334961, -8.132647514343262, -1.060196042060852, -6.740943431854248, -9.330506324768066, -9.895550727844238, -9.23755931854248, -8.560052871704102, -12.710282325744629, -10.154448509216309, -10.154448509216309, -10.154448509216309, -14.438100814819336, -9.820096969604492, -13.025313377380371, -9.110482215881348, -10.210829734802246, -7.969449520111084, -7.747878074645996, -7.1244707107543945, -7.0620551109313965, -8.439840316772461, -8.799816131591797, -11.188681602478027, -11.131129264831543, -12.523900985717773, -7.1244707107543945, -7.0620551109313965, -0.6957090497016907, -4.747443675994873, -3.341648817062378, -0.22742697596549988, 2.531212568283081, 3.382767677307129, 11.45654010772705, 1.355323076248169, -3.564143419265747, 1.765967607498169, 11.45654010772705, 2.6028378009796143, -1.3701289892196655, 1.493760347366333, 11.45654010772705, 11.36131763458252, 14.137557029724121, -0.5778067708015442, -2.6705737113952637, 2.4164934158325195, 15.170639991760254, 14.784523963928223, 9.2218017578125, 1.7529222965240479, 15.094086647033691, 14.021018981933594, -3.696516752243042, -4.517564296722412, 7.261398792266846, 14.924201011657715, 14.270874977111816, 14.383886337280273, 15.01157283782959, 14.971423149108887, 11.884063720703125, 12.111798286437988, 8.01781940460205, 11.253247261047363, 0.5303850769996643, 14.758339881896973, -5.449491024017334, 0.48396316170692444, 13.864641189575195, 9.805276870727539, 14.696816444396973, 3.786200761795044, 10.562129020690918, 3.292829990386963, 15.122862815856934, 14.850698471069336, 3.292829990386963, 13.787893295288086, -4.208994388580322, 12.818134307861328, 14.77609920501709, 14.935830116271973, 14.939279556274414, 14.403119087219238, 15.065840721130371, 4.789445877075195, 1.6841909885406494, -2.8948781490325928, 7.261398792266846, 11.354531288146973, 14.640252113342285, 10.492502212524414, 11.168323516845703, 13.3378267288208, 13.939017295837402, -3.596510887145996, 7.94765043258667, 3.786200761795044, -0.8940779566764832, 10.17502498626709]\n",
      "Accuracy: 0.8487, ROC AUC: 0.9322, Precision: 0.8475, Recall: 0.8475, F1-score: 0.8475\n",
      "Average Positive Score: 8.6423\n",
      "Average Negative Score: -5.3530\n",
      "Absolute Difference: 13.9953\n"
     ]
    }
   ],
   "source": [
    "# Updated Evaluation Function\n",
    "def evaluate_model_with_avg_scores(model, test_data, vocab, threshold=0, new_version=False):\n",
    "    model = model.to(device)\n",
    "    model.eval()\n",
    "    true_labels = []\n",
    "    predicted_rewards = []\n",
    "    positive_scores = []\n",
    "    negative_scores = []\n",
    "\n",
    "    for idx, row in test_data.iterrows():\n",
    "        code = row['code']\n",
    "        label = row['labels']\n",
    "        # Process the code to get states and actions\n",
    "        states, actions = process_new_code(code, vocab, unk_token=\"<state><UNK>\")\n",
    "        states = torch.tensor([states], dtype=torch.float)\n",
    "        actions = torch.tensor([actions], dtype=torch.float)\n",
    "        # Reduce dimension\n",
    "        if new_version:\n",
    "            try:\n",
    "                states = dataset.reduce_dimension(states, dataset.projection_matrix_state)\n",
    "            except Exception as e:\n",
    "                print(code)\n",
    "                print(label)\n",
    "                print(e)\n",
    "                raise Exception(\"Stop here\")\n",
    "            actions = dataset.reduce_dimension(actions, dataset.projection_matrix_action)\n",
    "        states = states.to(device)\n",
    "        actions = actions.to(device)\n",
    "        with torch.no_grad():\n",
    "            rewards = model(states, actions)\n",
    "        true_labels.append(label)\n",
    "        predicted_rewards.append(rewards.item())\n",
    "\n",
    "        # Collect scores for positive and negative samples\n",
    "        if label == 1:\n",
    "            positive_scores.append(rewards.item())\n",
    "        elif label == 0:\n",
    "            negative_scores.append(rewards.item())\n",
    "\n",
    "    # Compute evaluation metrics\n",
    "    from sklearn.metrics import accuracy_score, roc_auc_score, precision_recall_fscore_support\n",
    "    predicted_labels = [1 if r >= threshold else 0 for r in predicted_rewards]\n",
    "    print(predicted_rewards)\n",
    "    accuracy = accuracy_score(true_labels, predicted_labels)\n",
    "    roc_auc = roc_auc_score(true_labels, predicted_rewards)\n",
    "    precision, recall, f1, _ = precision_recall_fscore_support(true_labels, predicted_labels, average='binary')\n",
    "\n",
    "    # Calculate average scores\n",
    "    avg_positive_score = sum(positive_scores) / len(positive_scores) if positive_scores else 0\n",
    "    avg_negative_score = sum(negative_scores) / len(negative_scores) if negative_scores else 0\n",
    "    abs_difference = abs(avg_positive_score - avg_negative_score)\n",
    "\n",
    "    # Print results\n",
    "    print(f\"Accuracy: {accuracy:.4f}, ROC AUC: {roc_auc:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1-score: {f1:.4f}\")\n",
    "    print(f\"Average Positive Score: {avg_positive_score:.4f}\")\n",
    "    print(f\"Average Negative Score: {avg_negative_score:.4f}\")\n",
    "    print(f\"Absolute Difference: {abs_difference:.4f}\")\n",
    "\n",
    "evaluate_model_with_avg_scores(reward_model, test_data, vocab, threshold=0, new_version = False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Valid unseen code\n",
      "Input Code:\n",
      "class Solution(object):\n",
      "    def longestPalindrome(self, s):\n",
      "        n=len(s)\n",
      "        dp=[[False]*n for i in range(n)]\n",
      "        start,max_length=0,1\n",
      "        for i in range(n):\n",
      "            dp[i][i]=True\n",
      "        for i in range(n-1):\n",
      "            if s[i]==s[i+1]:\n",
      "                dp[i][i+1]=True\n",
      "                start=i\n",
      "                max_length=2\n",
      "        for length in range(3,n+1):\n",
      "            for i in range(n-length+1):\n",
      "                j=i+length-1\n",
      "                if s[i]==s[j]:\n",
      "                    if dp[i+1][j-1]==True:\n",
      "                        dp[i][j]=True\n",
      "                        start=i\n",
      "                        max_length=length\n",
      "        return s[start:start+max_length]\n",
      "\n",
      "Reward: -10.325618743896484\n"
     ]
    }
   ],
   "source": [
    "# Single Code Test Function\n",
    "def model_single_test(model, code, vocab):\n",
    "    states, actions = process_new_code(code, vocab, unk_token=\"<state><UNK>\")\n",
    "    states = torch.tensor([states], dtype=torch.float)\n",
    "    actions = torch.tensor([actions], dtype=torch.float)\n",
    "    # Reduce dimension\n",
    "    states = dataset.reduce_dimension(states, dataset.projection_matrix_state)\n",
    "    actions = dataset.reduce_dimension(actions, dataset.projection_matrix_action)\n",
    "    states = states.to(device)\n",
    "    actions = actions.to(device)\n",
    "    with torch.no_grad():\n",
    "        rewards = model(states, actions)\n",
    "\n",
    "    print(f\"Input Code:\\n{code}\")\n",
    "    print(f\"Reward: {rewards}\")\n",
    "\n",
    "code_test = \"\"\"class Solution(object):\n",
    "    def longestPalindrome(self, s):\n",
    "        n=len(s)\n",
    "        dp=[[False]*n for i in range(n)]\n",
    "        start,max_length=0,1\n",
    "        for i in range(n):\n",
    "            dp[i][i]=True\n",
    "        for i in range(n-1):\n",
    "            if s[i]==s[i+1]:\n",
    "                dp[i][i+1]=True\n",
    "                start=i\n",
    "                max_length=2\n",
    "        for length in range(3,n+1):\n",
    "            for i in range(n-length+1):\n",
    "                j=i+length-1\n",
    "                if s[i]==s[j]:\n",
    "                    if dp[i+1][j-1]==True:\n",
    "                        dp[i][j]=True\n",
    "                        start=i\n",
    "                        max_length=length\n",
    "        return s[start:start+max_length]\n",
    "\"\"\"\n",
    "\n",
    "if code_test in data[\"code\"].tolist():\n",
    "    print(\"Code already exists in dataset\")\n",
    "else:\n",
    "    print(\"Valid unseen code\")\n",
    "\n",
    "model_single_test(reward_model, code_test, vocab)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "llm",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
