{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### AST with GNN for sft"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import ast\n",
    "import pandas as pd\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "from torch_geometric.data import Data, Dataset, DataLoader\n",
    "from torch_geometric.nn import GCNConv, global_mean_pool\n",
    "from torch.utils.data import Subset\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.preprocessing import MinMaxScaler\n",
    "from collections import defaultdict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2.3.1+cu118\n"
     ]
    }
   ],
   "source": [
    "print(torch.__version__)\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def is_valid_code(code):\n",
    "    try:\n",
    "        ast.parse(code)\n",
    "        return True\n",
    "    except SyntaxError:\n",
    "        return False\n",
    "\n",
    "def preprocess_data(data):\n",
    "    # Apply the is_valid_code function to each code entry\n",
    "    data['is_valid'] = data['code'].apply(is_valid_code)\n",
    "    \n",
    "    # Count and print the number of invalid code snippets\n",
    "    invalid_count = data['is_valid'].value_counts().get(False, 0)\n",
    "    print(f\"Number of invalid code snippets: {invalid_count}\")\n",
    "    \n",
    "    # Filter out invalid code entries and reset the index\n",
    "    data = data[data['is_valid']].drop(columns=['is_valid']).reset_index(drop=True)\n",
    "\n",
    "    data['code_list_embedding'] = data['code_list_embedding'].apply(ast.literal_eval)\n",
    "    data['previous_code_embedding'] = data['previous_code_embedding'].apply(ast.literal_eval)\n",
    "    data['code_list'] = data['code_list'].apply(ast.literal_eval)\n",
    "    data['previous_code_list'] = data['previous_code_list'].apply(ast.literal_eval)\n",
    "    \n",
    "    return data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "187\n"
     ]
    }
   ],
   "source": [
    "data = pd.read_csv(r\"D:\\Python_project\\leetcode_data\\data\\leetcode_Median_of_Two_Sorted_Arrays_embedding.csv\")\n",
    "print(len(data))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Number of invalid code snippets: 126\n"
     ]
    }
   ],
   "source": [
    "data = preprocess_data(data)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### test 01"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 61/61 [00:00<00:00, 1691.12it/s]\n",
      "d:\\anaconda3\\envs\\llm\\lib\\site-packages\\torch_geometric\\deprecation.py:26: UserWarning: 'data.DataLoader' is deprecated, use 'loader.DataLoader' instead\n",
      "  warnings.warn(out)\n"
     ]
    }
   ],
   "source": [
    "import ast\n",
    "import pandas as pd\n",
    "import torch\n",
    "from torch_geometric.data import Data, DataLoader\n",
    "import torch.nn.functional as F\n",
    "from torch_geometric.nn import GCNConv\n",
    "from sklearn.preprocessing import MinMaxScaler\n",
    "from tqdm import tqdm\n",
    "\n",
    "# Assume 'data' is your DataFrame with 'code' and 'runtime' columns\n",
    "# data = pd.read_csv('your_data.csv')\n",
    "\n",
    "# Step 1: Normalize the runtime to be between 0 and 1\n",
    "scaler = MinMaxScaler()\n",
    "data['reward'] = scaler.fit_transform(data[['runtime']])\n",
    "\n",
    "# Step 2: Function to parse code into AST and create edge index and node features\n",
    "def code_to_graph(code_str):\n",
    "    tree = ast.parse(code_str)\n",
    "    nodes = []\n",
    "    edges = []\n",
    "\n",
    "    # Node index mapping\n",
    "    def index_nodes(node, idx=0):\n",
    "        node._idx = idx\n",
    "        nodes.append(node)\n",
    "        current_idx = idx\n",
    "        idx += 1\n",
    "        for child in ast.iter_child_nodes(node):\n",
    "            edges.append((current_idx, idx))\n",
    "            idx = index_nodes(child, idx)\n",
    "        return idx\n",
    "\n",
    "    index_nodes(tree)\n",
    "\n",
    "    # Create node features (e.g., node type)\n",
    "    node_features = []\n",
    "    for node in nodes:\n",
    "        node_type = type(node).__name__\n",
    "        node_type_idx = ast_types.index(node_type)  # ast_types is a list of unique node types\n",
    "        node_features.append([node_type_idx])\n",
    "\n",
    "    # Convert to tensors\n",
    "    x = torch.tensor(node_features, dtype=torch.long)\n",
    "    edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()\n",
    "    return x, edge_index\n",
    "\n",
    "# Step 3: Prepare the dataset\n",
    "ast_types = sorted({type(ast.parse(code)).__name__ for code in data['code']})\n",
    "ast_types += [node_name for code in data['code'] for node_name in {type(node).__name__ for node in ast.walk(ast.parse(code))}]\n",
    "ast_types = list(set(ast_types))\n",
    "ast_type_to_idx = {t: i for i, t in enumerate(ast_types)}\n",
    "\n",
    "graph_data_list = []\n",
    "for idx, row in tqdm(data.iterrows(), total=len(data)):\n",
    "    code_str = row['code']\n",
    "    reward = row['reward']\n",
    "    try:\n",
    "        x, edge_index = code_to_graph(code_str)\n",
    "        # Convert node features to one-hot encoding\n",
    "        x = F.one_hot(x.view(-1), num_classes=len(ast_types)).to(torch.float)\n",
    "        y = torch.tensor([reward], dtype=torch.float)\n",
    "        graph_data = Data(x=x, edge_index=edge_index, y=y)\n",
    "        graph_data_list.append(graph_data)\n",
    "    except Exception as e:\n",
    "        print(f\"Error processing code at index {idx}: {e}\")\n",
    "\n",
    "# Step 4: Split the dataset\n",
    "train_size = int(0.8 * len(graph_data_list))\n",
    "train_dataset = graph_data_list[:train_size]\n",
    "test_dataset = graph_data_list[train_size:]\n",
    "\n",
    "train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)\n",
    "test_loader = DataLoader(test_dataset, batch_size=32)\n",
    "\n",
    "# Step 5: Define the GNN Model\n",
    "class GNN(torch.nn.Module):\n",
    "    def __init__(self, num_node_features):\n",
    "        super(GNN, self).__init__()\n",
    "        self.conv1 = GCNConv(num_node_features, 64)\n",
    "        self.conv2 = GCNConv(64, 32)\n",
    "        self.lin = torch.nn.Linear(32, 1)\n",
    "\n",
    "    def forward(self, data):\n",
    "        x, edge_index = data.x, data.edge_index\n",
    "        x = self.conv1(x, edge_index)\n",
    "        x = F.relu(x)\n",
    "        x = self.conv2(x, edge_index)\n",
    "        x = F.relu(x)\n",
    "        x = torch.nn.functional.global_mean_pool(x, data.batch)\n",
    "        x = self.lin(x)\n",
    "        return x.squeeze()\n",
    "\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "model = GNN(num_node_features=len(ast_types)).to(device)\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=0.001)\n",
    "criterion = torch.nn.MSELoss()\n",
    "\n",
    "# Step 6: Train the Model\n",
    "def train():\n",
    "    model.train()\n",
    "    total_loss = 0\n",
    "    for data in train_loader:\n",
    "        data = data.to(device)\n",
    "        optimizer.zero_grad()\n",
    "        out = model(data)\n",
    "        loss = criterion(out, data.y)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        total_loss += loss.item() * data.num_graphs\n",
    "    return total_loss / len(train_dataset)\n",
    "\n",
    "def test(loader):\n",
    "    model.eval()\n",
    "    total_loss = 0\n",
    "    with torch.no_grad():\n",
    "        for data in loader:\n",
    "            data = data.to(device)\n",
    "            out = model(data)\n",
    "            loss = criterion(out, data.y)\n",
    "            total_loss += loss.item() * data.num_graphs\n",
    "    return total_loss / len(test_dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "ename": "AttributeError",
     "evalue": "module 'torch.nn.functional' has no attribute 'global_mean_pool'",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mAttributeError\u001b[0m                            Traceback (most recent call last)",
      "Cell \u001b[1;32mIn[6], line 2\u001b[0m\n\u001b[0;32m      1\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m epoch \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(\u001b[38;5;241m1\u001b[39m, \u001b[38;5;241m101\u001b[39m):\n\u001b[1;32m----> 2\u001b[0m     train_loss \u001b[38;5;241m=\u001b[39m \u001b[43mtrain\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m      3\u001b[0m     test_loss \u001b[38;5;241m=\u001b[39m test(test_loader)\n\u001b[0;32m      4\u001b[0m     \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mEpoch: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mepoch\u001b[38;5;132;01m:\u001b[39;00m\u001b[38;5;124m03d\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m, Train Loss: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mtrain_loss\u001b[38;5;132;01m:\u001b[39;00m\u001b[38;5;124m.4f\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m, Test Loss: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mtest_loss\u001b[38;5;132;01m:\u001b[39;00m\u001b[38;5;124m.4f\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m'\u001b[39m)\n",
      "Cell \u001b[1;32mIn[5], line 106\u001b[0m, in \u001b[0;36mtrain\u001b[1;34m()\u001b[0m\n\u001b[0;32m    104\u001b[0m data \u001b[38;5;241m=\u001b[39m data\u001b[38;5;241m.\u001b[39mto(device)\n\u001b[0;32m    105\u001b[0m optimizer\u001b[38;5;241m.\u001b[39mzero_grad()\n\u001b[1;32m--> 106\u001b[0m out \u001b[38;5;241m=\u001b[39m \u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdata\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m    107\u001b[0m loss \u001b[38;5;241m=\u001b[39m criterion(out, data\u001b[38;5;241m.\u001b[39my)\n\u001b[0;32m    108\u001b[0m loss\u001b[38;5;241m.\u001b[39mbackward()\n",
      "File \u001b[1;32md:\\anaconda3\\envs\\llm\\lib\\site-packages\\torch\\nn\\modules\\module.py:1532\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m   1530\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)  \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[0;32m   1531\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m-> 1532\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n",
      "File \u001b[1;32md:\\anaconda3\\envs\\llm\\lib\\site-packages\\torch\\nn\\modules\\module.py:1541\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m   1536\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[0;32m   1537\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[0;32m   1538\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[0;32m   1539\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[0;32m   1540\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[1;32m-> 1541\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m forward_call(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[0;32m   1543\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m   1544\u001b[0m     result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
      "Cell \u001b[1;32mIn[5], line 90\u001b[0m, in \u001b[0;36mGNN.forward\u001b[1;34m(self, data)\u001b[0m\n\u001b[0;32m     88\u001b[0m x \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mconv2(x, edge_index)\n\u001b[0;32m     89\u001b[0m x \u001b[38;5;241m=\u001b[39m F\u001b[38;5;241m.\u001b[39mrelu(x)\n\u001b[1;32m---> 90\u001b[0m x \u001b[38;5;241m=\u001b[39m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mnn\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfunctional\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mglobal_mean_pool\u001b[49m(x, data\u001b[38;5;241m.\u001b[39mbatch)\n\u001b[0;32m     91\u001b[0m x \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlin(x)\n\u001b[0;32m     92\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m x\u001b[38;5;241m.\u001b[39msqueeze()\n",
      "\u001b[1;31mAttributeError\u001b[0m: module 'torch.nn.functional' has no attribute 'global_mean_pool'"
     ]
    }
   ],
   "source": [
    "for epoch in range(1, 101):\n",
    "    train_loss = train()\n",
    "    test_loss = test(test_loader)\n",
    "    print(f'Epoch: {epoch:03d}, Train Loss: {train_loss:.4f}, Test Loss: {test_loss:.4f}')\n",
    "\n",
    "# Step 7: Save the Model\n",
    "torch.save(model.state_dict(), 'gnn_reward_model.pth')\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### test temp"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import ast\n",
    "import pandas as pd\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "from torch_geometric.data import Data, Batch  # Import Batch\n",
    "from torch_geometric.nn import GCNConv, global_mean_pool\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.preprocessing import MinMaxScaler\n",
    "from torch.utils.data import Dataset, Subset\n",
    "from torch_geometric.loader import DataLoader  # Corrected import\n",
    "\n",
    "# Step 1: Parse code into ASTs and convert to graph\n",
    "class CodeGraphDataset(Dataset):\n",
    "    def __init__(self, dataframe):\n",
    "        self.dataframe = dataframe.reset_index(drop=True)\n",
    "        self.scaler = MinMaxScaler()\n",
    "        # Fit scaler on runtime values using numpy arrays\n",
    "        self.scaler.fit(self.dataframe['runtime'].values.reshape(-1, 1))\n",
    "        # Build a vocabulary for AST node types\n",
    "        self.node_type_vocab = self.build_node_type_vocab()\n",
    "\n",
    "    def build_node_type_vocab(self):\n",
    "        node_types = set()\n",
    "        for idx, code in enumerate(self.dataframe['code']):\n",
    "            try:\n",
    "                tree = ast.parse(code)\n",
    "                for node in ast.walk(tree):\n",
    "                    node_types.add(type(node).__name__)\n",
    "            except Exception as e:\n",
    "                print(f\"Error parsing code at index {idx}: {e}\")\n",
    "        node_type_to_id = {nt: idx for idx, nt in enumerate(sorted(node_types))}\n",
    "        return node_type_to_id\n",
    "\n",
    "    def ast_to_graph(self, code):\n",
    "        try:\n",
    "            tree = ast.parse(code)\n",
    "        except Exception as e:\n",
    "            print(f\"Error parsing code: {e}\")\n",
    "            return None\n",
    "\n",
    "        nodes = []\n",
    "        edges = []\n",
    "        node_features = []\n",
    "        node_id = 0\n",
    "        node_id_map = {}\n",
    "\n",
    "        def traverse(node, parent_id=None):\n",
    "            nonlocal node_id\n",
    "            current_id = node_id\n",
    "            node_id_map[id(node)] = current_id\n",
    "            nodes.append(current_id)\n",
    "            # Encode node type as integer\n",
    "            node_type = type(node).__name__\n",
    "            node_type_id = self.node_type_vocab.get(node_type, len(self.node_type_vocab))  # Handle unknown types\n",
    "            node_features.append([node_type_id])\n",
    "            node_id += 1\n",
    "\n",
    "            if parent_id is not None:\n",
    "                edges.append((parent_id, current_id))\n",
    "\n",
    "            for child in ast.iter_child_nodes(node):\n",
    "                traverse(child, current_id)\n",
    "\n",
    "        traverse(tree)\n",
    "\n",
    "        if not nodes:\n",
    "            return None\n",
    "\n",
    "        # Convert edges to a tensor\n",
    "        if edges:\n",
    "            edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()\n",
    "        else:\n",
    "            edge_index = torch.empty((2, 0), dtype=torch.long)\n",
    "\n",
    "        # Convert node features to a tensor\n",
    "        x = torch.tensor(node_features, dtype=torch.long)\n",
    "\n",
    "        # Create a Data object\n",
    "        data = Data(x=x, edge_index=edge_index)\n",
    "        return data\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.dataframe)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        row = self.dataframe.iloc[idx]\n",
    "        code = row['code']\n",
    "        runtime = row['runtime']\n",
    "\n",
    "        graph = self.ast_to_graph(code)\n",
    "        if graph is None:\n",
    "            # Return an empty graph with zero runtime if parsing failed\n",
    "            graph = Data(x=torch.zeros((1, 1), dtype=torch.long), edge_index=torch.empty((2, 0), dtype=torch.long))\n",
    "            runtime_normalized = 0.0\n",
    "        else:\n",
    "            # Normalize runtime using numpy arrays\n",
    "            runtime_normalized = self.scaler.transform([[runtime]]).flatten()[0]\n",
    "\n",
    "        graph.y = torch.tensor([runtime_normalized], dtype=torch.float)\n",
    "        return graph\n",
    "\n",
    "# Step 2: Instantiate the dataset and create train/test splits\n",
    "dataset = CodeGraphDataset(data)\n",
    "\n",
    "# Split indices\n",
    "train_indices, test_indices = train_test_split(list(range(len(dataset))), test_size=0.2, random_state=42)\n",
    "\n",
    "# Create subsets using the corrected Subset import\n",
    "train_dataset = Subset(dataset, train_indices)\n",
    "test_dataset = Subset(dataset, test_indices)\n",
    "\n",
    "# Step 3: Create DataLoaders using torch_geometric.loader.DataLoader\n",
    "batch_size = 32\n",
    "train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)\n",
    "test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)\n",
    "\n",
    "# Step 4: Define the GNN model\n",
    "class GNNModel(nn.Module):\n",
    "    def __init__(self, num_node_types, embed_dim=64, hidden_dim=128):\n",
    "        super(GNNModel, self).__init__()\n",
    "        self.embedding = nn.Embedding(num_node_types + 1, embed_dim)  # +1 for unknown types\n",
    "        self.conv1 = GCNConv(embed_dim, hidden_dim)\n",
    "        self.conv2 = GCNConv(hidden_dim, hidden_dim)\n",
    "        self.fc1 = nn.Linear(hidden_dim, hidden_dim)\n",
    "        self.fc2 = nn.Linear(hidden_dim, 1)\n",
    "        self.relu = nn.ReLU()\n",
    "\n",
    "    def forward(self, data):\n",
    "        x, edge_index, batch = data.x, data.edge_index, data.batch\n",
    "\n",
    "        # Handle unknown node types\n",
    "        x = self.embedding(x.squeeze())  # Shape: [num_nodes, embed_dim]\n",
    "\n",
    "        x = self.conv1(x, edge_index)\n",
    "        x = self.relu(x)\n",
    "        x = self.conv2(x, edge_index)\n",
    "        x = self.relu(x)\n",
    "\n",
    "        # Global pooling\n",
    "        x = global_mean_pool(x, batch)  # Shape: [batch_size, hidden_dim]\n",
    "\n",
    "        x = self.fc1(x)\n",
    "        x = self.relu(x)\n",
    "        x = self.fc2(x)\n",
    "        x = torch.sigmoid(x)  # Ensure output is between 0 and 1\n",
    "\n",
    "        return x.squeeze()\n",
    "\n",
    "# Step 5: Initialize the model, loss function, and optimizer\n",
    "num_node_types = len(dataset.node_type_vocab)\n",
    "model = GNNModel(num_node_types=num_node_types)\n",
    "model = model.to(device)  # Ensure model is on the correct device\n",
    "criterion = nn.MSELoss()\n",
    "optimizer = optim.Adam(model.parameters(), lr=0.001)\n",
    "\n",
    "# Step 6: Training Loop\n",
    "def train(model, loader, criterion, optimizer, device):\n",
    "    model.train()\n",
    "    total_loss = 0\n",
    "    for batch in loader:\n",
    "        batch = batch.to(device)\n",
    "        optimizer.zero_grad()\n",
    "        preds = model(batch)\n",
    "        loss = criterion(preds, batch.y)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        total_loss += loss.item() * batch.num_graphs\n",
    "    return total_loss / len(loader.dataset)\n",
    "\n",
    "# Step 7: Evaluation Loop\n",
    "def evaluate(model, loader, criterion, device):\n",
    "    model.eval()\n",
    "    total_loss = 0\n",
    "    with torch.no_grad():\n",
    "        for batch in loader:\n",
    "            batch = batch.to(device)\n",
    "            preds = model(batch)\n",
    "            loss = criterion(preds, batch.y)\n",
    "            total_loss += loss.item() * batch.num_graphs\n",
    "    return total_loss / len(loader.dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1: Train Loss = 0.0594, Test Loss = 0.0484\n",
      "Epoch 2: Train Loss = 0.0412, Test Loss = 0.0362\n",
      "Epoch 3: Train Loss = 0.0303, Test Loss = 0.0296\n",
      "Epoch 4: Train Loss = 0.0246, Test Loss = 0.0278\n",
      "Epoch 5: Train Loss = 0.0246, Test Loss = 0.0279\n",
      "Epoch 6: Train Loss = 0.0244, Test Loss = 0.0257\n",
      "Epoch 7: Train Loss = 0.0224, Test Loss = 0.0227\n",
      "Epoch 8: Train Loss = 0.0197, Test Loss = 0.0207\n",
      "Epoch 9: Train Loss = 0.0186, Test Loss = 0.0203\n",
      "Epoch 10: Train Loss = 0.0185, Test Loss = 0.0203\n",
      "Epoch 11: Train Loss = 0.0188, Test Loss = 0.0199\n",
      "Epoch 12: Train Loss = 0.0184, Test Loss = 0.0188\n",
      "Epoch 13: Train Loss = 0.0177, Test Loss = 0.0181\n",
      "Epoch 14: Train Loss = 0.0172, Test Loss = 0.0178\n",
      "Epoch 15: Train Loss = 0.0175, Test Loss = 0.0178\n",
      "Epoch 16: Train Loss = 0.0175, Test Loss = 0.0177\n",
      "Epoch 17: Train Loss = 0.0174, Test Loss = 0.0175\n",
      "Epoch 18: Train Loss = 0.0173, Test Loss = 0.0174\n",
      "Epoch 19: Train Loss = 0.0172, Test Loss = 0.0175\n",
      "Epoch 20: Train Loss = 0.0172, Test Loss = 0.0176\n",
      "Epoch 21: Train Loss = 0.0172, Test Loss = 0.0176\n",
      "Epoch 22: Train Loss = 0.0171, Test Loss = 0.0175\n",
      "Epoch 23: Train Loss = 0.0170, Test Loss = 0.0174\n",
      "Epoch 24: Train Loss = 0.0171, Test Loss = 0.0173\n",
      "Epoch 25: Train Loss = 0.0170, Test Loss = 0.0173\n",
      "Epoch 26: Train Loss = 0.0170, Test Loss = 0.0173\n",
      "Epoch 27: Train Loss = 0.0170, Test Loss = 0.0173\n",
      "Epoch 28: Train Loss = 0.0170, Test Loss = 0.0173\n",
      "Epoch 29: Train Loss = 0.0170, Test Loss = 0.0172\n",
      "Epoch 30: Train Loss = 0.0169, Test Loss = 0.0170\n",
      "Epoch 31: Train Loss = 0.0170, Test Loss = 0.0169\n",
      "Epoch 32: Train Loss = 0.0169, Test Loss = 0.0168\n",
      "Epoch 33: Train Loss = 0.0168, Test Loss = 0.0168\n",
      "Epoch 34: Train Loss = 0.0168, Test Loss = 0.0168\n",
      "Epoch 35: Train Loss = 0.0169, Test Loss = 0.0168\n",
      "Epoch 36: Train Loss = 0.0169, Test Loss = 0.0168\n",
      "Epoch 37: Train Loss = 0.0169, Test Loss = 0.0168\n",
      "Epoch 38: Train Loss = 0.0168, Test Loss = 0.0167\n",
      "Epoch 39: Train Loss = 0.0167, Test Loss = 0.0165\n",
      "Epoch 40: Train Loss = 0.0168, Test Loss = 0.0165\n",
      "Epoch 41: Train Loss = 0.0168, Test Loss = 0.0165\n",
      "Epoch 42: Train Loss = 0.0166, Test Loss = 0.0167\n",
      "Epoch 43: Train Loss = 0.0167, Test Loss = 0.0172\n",
      "Epoch 44: Train Loss = 0.0170, Test Loss = 0.0174\n",
      "Epoch 45: Train Loss = 0.0169, Test Loss = 0.0170\n",
      "Epoch 46: Train Loss = 0.0167, Test Loss = 0.0166\n",
      "Epoch 47: Train Loss = 0.0165, Test Loss = 0.0163\n",
      "Epoch 48: Train Loss = 0.0168, Test Loss = 0.0163\n",
      "Epoch 49: Train Loss = 0.0168, Test Loss = 0.0162\n",
      "Epoch 50: Train Loss = 0.0166, Test Loss = 0.0162\n"
     ]
    }
   ],
   "source": [
    "# Step 8: Training Process\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "model = model.to(device)\n",
    "\n",
    "epochs = 50\n",
    "best_test_loss = float('inf')\n",
    "\n",
    "for epoch in range(1, epochs + 1):\n",
    "    train_loss = train(model, train_loader, criterion, optimizer, device)\n",
    "    test_loss = evaluate(model, test_loader, criterion, device)\n",
    "    if test_loss < best_test_loss:\n",
    "        best_test_loss = test_loss\n",
    "        torch.save(model.state_dict(), 'best_gnn_model.pth')\n",
    "    print(f\"Epoch {epoch}: Train Loss = {train_loss:.4f}, Test Loss = {test_loss:.4f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Predicted Runtime: 58.38628762960434\n"
     ]
    }
   ],
   "source": [
    "# Step 9: Loading the Best Model and Making Predictions\n",
    "model.load_state_dict(torch.load('best_gnn_model.pth'))\n",
    "model.eval()\n",
    "\n",
    "def predict_runtime(code_snippet):\n",
    "    # Convert the code snippet to a graph\n",
    "    graph = dataset.ast_to_graph(code_snippet)\n",
    "    \n",
    "    if graph is None:\n",
    "        # Return 0.0 if parsing fails\n",
    "        return 0.0\n",
    "    \n",
    "    # Create a batch containing the single graph\n",
    "    batch = Batch.from_data_list([graph]).to(device)\n",
    "    \n",
    "    # Debugging: Print device information (optional)\n",
    "    # print(f\"Batch.x is on device: {batch.x.device}\")\n",
    "    # print(f\"Batch.edge_index is on device: {batch.edge_index.device}\")\n",
    "    \n",
    "    with torch.no_grad():\n",
    "        # Get the model's prediction\n",
    "        prediction = model(batch)\n",
    "    \n",
    "    # Inverse transform to get actual runtime\n",
    "    runtime_pred = dataset.scaler.inverse_transform([[prediction.item()]]).flatten()[0]\n",
    "    \n",
    "    return runtime_pred\n",
    "\n",
    "# Example usage:\n",
    "new_code_temp = \"\"\"\n",
    "def compute_sum(n):\n",
    "    total = 0\n",
    "    return total\n",
    "\"\"\"\n",
    "\n",
    "predicted_runtime = predict_runtime(new_code_temp)\n",
    "print(f\"Predicted Runtime: {predicted_runtime}\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### test 02"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:Using device: cuda\n",
      "INFO:Filtered data to 61 valid samples out of 61.\n",
      "INFO:Runtime values scaled using MinMaxScaler.\n",
      "INFO:Built node type vocabulary with size: 48\n",
      "INFO:Training samples: 48, Testing samples: 13\n",
      "INFO:DataLoaders created with batch size 32.\n",
      "INFO:Model, loss function, and optimizer initialized.\n",
      "INFO:Epoch 01/50 | Train Loss: 0.0632 | Test Loss: 0.0547\n",
      "INFO:Best model saved with Test Loss: 0.0547\n",
      "INFO:Epoch 02/50 | Train Loss: 0.0478 | Test Loss: 0.0416\n",
      "INFO:Best model saved with Test Loss: 0.0416\n",
      "INFO:Epoch 03/50 | Train Loss: 0.0357 | Test Loss: 0.0317\n",
      "INFO:Best model saved with Test Loss: 0.0317\n",
      "INFO:Epoch 04/50 | Train Loss: 0.0277 | Test Loss: 0.0263\n",
      "INFO:Best model saved with Test Loss: 0.0263\n",
      "INFO:Epoch 05/50 | Train Loss: 0.0234 | Test Loss: 0.0241\n",
      "INFO:Best model saved with Test Loss: 0.0241\n",
      "INFO:Epoch 06/50 | Train Loss: 0.0213 | Test Loss: 0.0225\n",
      "INFO:Best model saved with Test Loss: 0.0225\n",
      "INFO:Epoch 07/50 | Train Loss: 0.0209 | Test Loss: 0.0204\n",
      "INFO:Best model saved with Test Loss: 0.0204\n",
      "INFO:Epoch 08/50 | Train Loss: 0.0190 | Test Loss: 0.0187\n",
      "INFO:Best model saved with Test Loss: 0.0187\n",
      "INFO:Epoch 09/50 | Train Loss: 0.0178 | Test Loss: 0.0183\n",
      "INFO:Best model saved with Test Loss: 0.0183\n",
      "INFO:Epoch 10/50 | Train Loss: 0.0181 | Test Loss: 0.0184\n",
      "INFO:Epoch 11/50 | Train Loss: 0.0179 | Test Loss: 0.0178\n",
      "INFO:Best model saved with Test Loss: 0.0178\n",
      "INFO:Epoch 12/50 | Train Loss: 0.0176 | Test Loss: 0.0172\n",
      "INFO:Best model saved with Test Loss: 0.0172\n",
      "INFO:Epoch 13/50 | Train Loss: 0.0175 | Test Loss: 0.0172\n",
      "INFO:Best model saved with Test Loss: 0.0172\n",
      "INFO:Epoch 14/50 | Train Loss: 0.0176 | Test Loss: 0.0172\n",
      "INFO:Epoch 15/50 | Train Loss: 0.0176 | Test Loss: 0.0173\n",
      "INFO:Epoch 16/50 | Train Loss: 0.0177 | Test Loss: 0.0172\n",
      "INFO:Epoch 17/50 | Train Loss: 0.0175 | Test Loss: 0.0172\n",
      "INFO:Best model saved with Test Loss: 0.0172\n",
      "INFO:Epoch 18/50 | Train Loss: 0.0173 | Test Loss: 0.0174\n",
      "INFO:Epoch 19/50 | Train Loss: 0.0172 | Test Loss: 0.0175\n",
      "INFO:Epoch 20/50 | Train Loss: 0.0172 | Test Loss: 0.0176\n",
      "INFO:Epoch 21/50 | Train Loss: 0.0172 | Test Loss: 0.0177\n",
      "INFO:Epoch 22/50 | Train Loss: 0.0172 | Test Loss: 0.0178\n",
      "INFO:Epoch 23/50 | Train Loss: 0.0173 | Test Loss: 0.0179\n",
      "INFO:Epoch 24/50 | Train Loss: 0.0172 | Test Loss: 0.0181\n",
      "INFO:Epoch 25/50 | Train Loss: 0.0176 | Test Loss: 0.0182\n",
      "INFO:Epoch 26/50 | Train Loss: 0.0174 | Test Loss: 0.0177\n",
      "INFO:Epoch 27/50 | Train Loss: 0.0172 | Test Loss: 0.0172\n",
      "INFO:Epoch 28/50 | Train Loss: 0.0174 | Test Loss: 0.0169\n",
      "INFO:Best model saved with Test Loss: 0.0169\n",
      "INFO:Epoch 29/50 | Train Loss: 0.0170 | Test Loss: 0.0169\n",
      "INFO:Best model saved with Test Loss: 0.0169\n",
      "INFO:Epoch 30/50 | Train Loss: 0.0170 | Test Loss: 0.0168\n",
      "INFO:Best model saved with Test Loss: 0.0168\n",
      "INFO:Epoch 31/50 | Train Loss: 0.0169 | Test Loss: 0.0168\n",
      "INFO:Epoch 32/50 | Train Loss: 0.0170 | Test Loss: 0.0168\n",
      "INFO:Best model saved with Test Loss: 0.0168\n",
      "INFO:Epoch 33/50 | Train Loss: 0.0169 | Test Loss: 0.0168\n",
      "INFO:Epoch 34/50 | Train Loss: 0.0169 | Test Loss: 0.0167\n",
      "INFO:Best model saved with Test Loss: 0.0167\n",
      "INFO:Epoch 35/50 | Train Loss: 0.0170 | Test Loss: 0.0166\n",
      "INFO:Best model saved with Test Loss: 0.0166\n",
      "INFO:Epoch 36/50 | Train Loss: 0.0168 | Test Loss: 0.0165\n",
      "INFO:Best model saved with Test Loss: 0.0165\n",
      "INFO:Epoch 37/50 | Train Loss: 0.0167 | Test Loss: 0.0163\n",
      "INFO:Best model saved with Test Loss: 0.0163\n",
      "INFO:Epoch 38/50 | Train Loss: 0.0167 | Test Loss: 0.0162\n",
      "INFO:Best model saved with Test Loss: 0.0162\n",
      "INFO:Epoch 39/50 | Train Loss: 0.0167 | Test Loss: 0.0161\n",
      "INFO:Best model saved with Test Loss: 0.0161\n",
      "INFO:Epoch 40/50 | Train Loss: 0.0168 | Test Loss: 0.0159\n",
      "INFO:Best model saved with Test Loss: 0.0159\n",
      "INFO:Epoch 41/50 | Train Loss: 0.0168 | Test Loss: 0.0158\n",
      "INFO:Best model saved with Test Loss: 0.0158\n",
      "INFO:Epoch 42/50 | Train Loss: 0.0167 | Test Loss: 0.0157\n",
      "INFO:Best model saved with Test Loss: 0.0157\n",
      "INFO:Epoch 43/50 | Train Loss: 0.0167 | Test Loss: 0.0157\n",
      "INFO:Epoch 44/50 | Train Loss: 0.0168 | Test Loss: 0.0157\n",
      "INFO:Best model saved with Test Loss: 0.0157\n",
      "INFO:Epoch 45/50 | Train Loss: 0.0169 | Test Loss: 0.0156\n",
      "INFO:Best model saved with Test Loss: 0.0156\n",
      "INFO:Epoch 46/50 | Train Loss: 0.0166 | Test Loss: 0.0157\n",
      "INFO:Epoch 47/50 | Train Loss: 0.0166 | Test Loss: 0.0160\n",
      "INFO:Epoch 48/50 | Train Loss: 0.0164 | Test Loss: 0.0158\n",
      "INFO:Epoch 49/50 | Train Loss: 0.0164 | Test Loss: 0.0157\n",
      "INFO:Epoch 50/50 | Train Loss: 0.0164 | Test Loss: 0.0158\n",
      "INFO:Training completed.\n"
     ]
    }
   ],
   "source": [
    "import ast\n",
    "import pandas as pd\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "from torch_geometric.data import Data, Batch\n",
    "from torch_geometric.nn import GCNConv, global_mean_pool\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.preprocessing import MinMaxScaler\n",
    "from torch.utils.data import Dataset, Subset\n",
    "from torch_geometric.loader import DataLoader\n",
    "import logging\n",
    "import sys\n",
    "\n",
    "# Configure logging\n",
    "# force=True replaces any handlers already attached to the root logger (for\n",
    "# example after re-running this cell or after an import configured logging).\n",
    "# Without it basicConfig silently does nothing and INFO messages can vanish.\n",
    "logging.basicConfig(stream=sys.stdout, level=logging.INFO, format='%(levelname)s:%(message)s', force=True)\n",
    "\n",
    "# Device configuration: use the GPU when one is visible, otherwise the CPU\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "logging.info(f'Using device: {device}')\n",
    "\n",
    "class CodeGraphDataset(Dataset):\n",
    "    \"\"\"Dataset yielding one AST graph per code snippet, labelled with its scaled runtime.\n",
    "\n",
    "    Each item is a torch_geometric `Data` object: nodes are the AST nodes of the\n",
    "    snippet (single feature = integer id of the node type), edges run from parent\n",
    "    to child, and `y` holds the runtime mapped through a MinMaxScaler.\n",
    "\n",
    "    NOTE(review): the scaler is fitted on every row of `dataframe`. If the dataset\n",
    "    is split into train/test afterwards, test runtimes influence the scaling;\n",
    "    confirm this is acceptable for the experiment.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, dataframe):\n",
    "        \"\"\"Store the frame, fit the runtime scaler and build the node-type vocabulary.\n",
    "\n",
    "        Args:\n",
    "            dataframe: pandas DataFrame with a 'code' column (Python source\n",
    "                strings) and a 'runtime' column (numeric).\n",
    "        \"\"\"\n",
    "        self.dataframe = dataframe.reset_index(drop=True)\n",
    "        self.scaler = MinMaxScaler()\n",
    "        # Fit scaler on runtime values; sklearn expects a 2-D column vector\n",
    "        self.scaler.fit(self.dataframe['runtime'].values.reshape(-1, 1))\n",
    "        logging.info('Runtime values scaled using MinMaxScaler.')\n",
    "        # Build a vocabulary for AST node types\n",
    "        self.node_type_vocab = self.build_node_type_vocab()\n",
    "        logging.info(f'Built node type vocabulary with size: {len(self.node_type_vocab)}')\n",
    "\n",
    "    def build_node_type_vocab(self):\n",
    "        \"\"\"Return a dict mapping each AST node class name found in the data to an integer id.\n",
    "\n",
    "        Snippets that fail to parse are logged and skipped. Ids follow the sorted\n",
    "        order of the class names, so the mapping is deterministic for a given frame.\n",
    "        \"\"\"\n",
    "        node_types = set()\n",
    "        for idx, code in enumerate(self.dataframe['code']):\n",
    "            try:\n",
    "                node_types.update(type(node).__name__ for node in ast.walk(ast.parse(code)))\n",
    "            except Exception as e:\n",
    "                logging.warning(f\"Error parsing code at index {idx}: {e}\")\n",
    "        return {name: type_id for type_id, name in enumerate(sorted(node_types))}\n",
    "\n",
    "    def ast_to_graph(self, code):\n",
    "        \"\"\"Convert Python source into a graph of its AST.\n",
    "\n",
    "        Nodes are numbered in depth-first pre-order starting at 0 for the root.\n",
    "        Each node carries one feature: its type id from `self.node_type_vocab`,\n",
    "        or `len(self.node_type_vocab)` for a type missing from the vocabulary.\n",
    "        Edges are directed parent -> child.\n",
    "\n",
    "        Returns:\n",
    "            A `Data` object with `x` of shape [num_nodes, 1] (long) and\n",
    "            `edge_index` of shape [2, num_edges] (long), or None when `code`\n",
    "            cannot be parsed.\n",
    "        \"\"\"\n",
    "        try:\n",
    "            tree = ast.parse(code)\n",
    "        except Exception as e:\n",
    "            logging.warning(f\"Error parsing code: {e}\")\n",
    "            return None\n",
    "\n",
    "        unknown_id = len(self.node_type_vocab)  # Shared id for unseen node types\n",
    "        node_features = []\n",
    "        edges = []\n",
    "\n",
    "        # Explicit stack instead of recursion, so a deep AST cannot hit the\n",
    "        # interpreter's recursion limit while the graph is being built.\n",
    "        stack = [(tree, None)]\n",
    "        while stack:\n",
    "            node, parent_id = stack.pop()\n",
    "            current_id = len(node_features)\n",
    "            node_features.append([self.node_type_vocab.get(type(node).__name__, unknown_id)])\n",
    "            if parent_id is not None:\n",
    "                edges.append((parent_id, current_id))\n",
    "            # Push children in reverse so they are popped in their original\n",
    "            # order, which keeps the pre-order numbering of a recursive walk.\n",
    "            children = list(ast.iter_child_nodes(node))\n",
    "            stack.extend((child, current_id) for child in reversed(children))\n",
    "\n",
    "        # Convert edges to a [2, num_edges] tensor; a single-node tree has none\n",
    "        if edges:\n",
    "            edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()\n",
    "        else:\n",
    "            edge_index = torch.empty((2, 0), dtype=torch.long)\n",
    "\n",
    "        # Node features stay integer ids; the model embeds them\n",
    "        x = torch.tensor(node_features, dtype=torch.long)\n",
    "        return Data(x=x, edge_index=edge_index)\n",
    "\n",
    "    def __len__(self):\n",
    "        \"\"\"Number of rows in the underlying frame.\"\"\"\n",
    "        return len(self.dataframe)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        \"\"\"Build the graph for row `idx` and attach its normalized runtime as `y`.\n",
    "\n",
    "        Raises:\n",
    "            ValueError: if the code in that row cannot be parsed.\n",
    "        \"\"\"\n",
    "        row = self.dataframe.iloc[idx]\n",
    "        code = row['code']\n",
    "        runtime = row['runtime']\n",
    "\n",
    "        graph = self.ast_to_graph(code)\n",
    "        if graph is None:\n",
    "            # Unparseable rows are expected to be filtered out before the\n",
    "            # dataset is built, so reaching this point is reported as an error.\n",
    "            logging.debug(f\"Skipping index {idx} due to parsing error.\")\n",
    "            raise ValueError(f\"Parsing failed for code at index {idx}.\")\n",
    "\n",
    "        # Normalize runtime using the fitted scaler\n",
    "        runtime_normalized = self.scaler.transform([[runtime]]).flatten()[0]\n",
    "\n",
    "        graph.y = torch.tensor([runtime_normalized], dtype=torch.float)\n",
    "        return graph\n",
    "\n",
    "# Filter out samples with parsing errors\n",
    "# Iterate by position: the collected indices are consumed by `.iloc` below, so\n",
    "# they must be positional even when `data` does not carry a default RangeIndex.\n",
    "valid_indices = []\n",
    "for idx, snippet in enumerate(data['code']):\n",
    "    try:\n",
    "        # Attempt to parse to check validity\n",
    "        ast.parse(snippet)\n",
    "    except Exception as e:\n",
    "        logging.warning(f\"Excluding index {idx} due to parsing error: {e}\")\n",
    "    else:\n",
    "        valid_indices.append(idx)\n",
    "\n",
    "filtered_data = data.iloc[valid_indices].reset_index(drop=True)\n",
    "logging.info(f'Filtered data to {len(filtered_data)} valid samples out of {len(data)}.')\n",
    "\n",
    "# Instantiate the dataset\n",
    "dataset = CodeGraphDataset(filtered_data)\n",
    "\n",
    "# Split indices (fixed random_state keeps the split stable between runs)\n",
    "train_indices, test_indices = train_test_split(list(range(len(dataset))), test_size=0.2, random_state=42)\n",
    "logging.info(f'Training samples: {len(train_indices)}, Testing samples: {len(test_indices)}')\n",
    "\n",
    "# Create subsets that share the one underlying dataset object\n",
    "train_dataset = Subset(dataset, train_indices)\n",
    "test_dataset = Subset(dataset, test_indices)\n",
    "\n",
    "# Step 3: Create DataLoaders using torch_geometric.loader.DataLoader\n",
    "batch_size = 32\n",
    "train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)\n",
    "test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)\n",
    "logging.info(f'DataLoaders created with batch size {batch_size}.')\n",
    "\n",
    "# Step 4: Define the GNN model\n",
    "class GNNModel(nn.Module):\n",
    "    \"\"\"Two-layer GCN regressor that outputs one value in (0, 1) per graph.\n",
    "\n",
    "    Args:\n",
    "        num_node_types: Size of the node-type vocabulary. One extra embedding\n",
    "            row is reserved for node types outside the vocabulary.\n",
    "        embed_dim: Width of the node-type embedding.\n",
    "        hidden_dim: Width of the GCN layers and of the hidden linear layer.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, num_node_types, embed_dim=64, hidden_dim=128):\n",
    "        super().__init__()\n",
    "        self.embedding = nn.Embedding(num_node_types + 1, embed_dim)  # +1 for unknown types\n",
    "        self.conv1 = GCNConv(embed_dim, hidden_dim)\n",
    "        self.conv2 = GCNConv(hidden_dim, hidden_dim)\n",
    "        self.fc1 = nn.Linear(hidden_dim, hidden_dim)\n",
    "        self.fc2 = nn.Linear(hidden_dim, 1)\n",
    "        self.relu = nn.ReLU()\n",
    "\n",
    "    def forward(self, data):\n",
    "        \"\"\"Return predictions of shape [num_graphs] for a batched graph object.\n",
    "\n",
    "        `data` must expose `x` (node-type ids with a trailing feature axis of\n",
    "        size 1), `edge_index` and `batch` (node -> graph assignment).\n",
    "        \"\"\"\n",
    "        x, edge_index, batch = data.x, data.edge_index, data.batch\n",
    "\n",
    "        # Drop only the trailing feature axis. A bare squeeze() would also\n",
    "        # collapse the node axis when the whole batch contains a single node.\n",
    "        x = self.embedding(x.squeeze(-1))  # Shape: [num_nodes, embed_dim]\n",
    "\n",
    "        x = self.conv1(x, edge_index)\n",
    "        x = self.relu(x)\n",
    "        x = self.conv2(x, edge_index)\n",
    "        x = self.relu(x)\n",
    "\n",
    "        # Global pooling\n",
    "        x = global_mean_pool(x, batch)  # Shape: [batch_size, hidden_dim]\n",
    "\n",
    "        x = self.fc1(x)\n",
    "        x = self.relu(x)\n",
    "        x = self.fc2(x)\n",
    "        # Remove sigmoid if runtime is not bounded between 0 and 1\n",
    "        x = torch.sigmoid(x)  # Ensure output is between 0 and 1\n",
    "\n",
    "        # squeeze(-1) keeps the batch axis, so a batch holding one graph still\n",
    "        # yields shape [1] instead of a 0-d tensor that the loss would broadcast.\n",
    "        return x.squeeze(-1)\n",
    "\n",
    "# Step 5: Initialize the model, loss function, and optimizer\n",
    "# Seed torch first so weight initialisation and batch shuffling start from a\n",
    "# fixed state on every run (the train/test split already uses random_state=42).\n",
    "torch.manual_seed(42)\n",
    "num_node_types = len(dataset.node_type_vocab)\n",
    "model = GNNModel(num_node_types=num_node_types)\n",
    "model = model.to(device)\n",
    "criterion = nn.MSELoss()  # Regression on the scaled runtime\n",
    "optimizer = optim.Adam(model.parameters(), lr=0.001)\n",
    "logging.info('Model, loss function, and optimizer initialized.')\n",
    "\n",
    "# Step 6: Training Loop\n",
    "def train_epoch(model, loader, criterion, optimizer, device):\n",
    "    \"\"\"Run one optimisation pass over `loader` and return the mean loss per sample.\n",
    "\n",
    "    Each batch loss is weighted by the number of graphs in the batch and the sum\n",
    "    is divided by `len(loader.dataset)`.\n",
    "    \"\"\"\n",
    "    model.train()\n",
    "    weighted_loss_sum = 0\n",
    "    for graphs in loader:\n",
    "        graphs = graphs.to(device)\n",
    "        optimizer.zero_grad()\n",
    "        batch_loss = criterion(model(graphs), graphs.y)\n",
    "        batch_loss.backward()\n",
    "        optimizer.step()\n",
    "        # Weight by batch size so a short final batch does not skew the mean\n",
    "        weighted_loss_sum += batch_loss.item() * graphs.num_graphs\n",
    "    return weighted_loss_sum / len(loader.dataset)\n",
    "\n",
    "# Step 7: Evaluation Loop\n",
    "def evaluate(model, loader, criterion, device):\n",
    "    \"\"\"Return the mean loss per sample over `loader` without updating the model.\n",
    "\n",
    "    The model is switched to eval mode and gradients are disabled. Batch losses\n",
    "    are weighted by graph count and divided by `len(loader.dataset)`.\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    weighted_loss_sum = 0\n",
    "    with torch.no_grad():\n",
    "        for graphs in loader:\n",
    "            graphs = graphs.to(device)\n",
    "            batch_loss = criterion(model(graphs), graphs.y)\n",
    "            weighted_loss_sum += batch_loss.item() * graphs.num_graphs\n",
    "    return weighted_loss_sum / len(loader.dataset)\n",
    "\n",
    "# Training configuration\n",
    "num_epochs = 50\n",
    "best_test_loss = float('inf')\n",
    "\n",
    "# NOTE(review): the checkpoint is selected on the test split, so the best\n",
    "# reported test loss is optimistic; a separate validation split would avoid this.\n",
    "for epoch in range(1, num_epochs + 1):\n",
    "    train_loss = train_epoch(model, train_loader, criterion, optimizer, device)\n",
    "    test_loss = evaluate(model, test_loader, criterion, device)\n",
    "    logging.info(f'Epoch {epoch:02d}/{num_epochs} | Train Loss: {train_loss:.4f} | Test Loss: {test_loss:.4f}')\n",
    "\n",
    "    # Only a strict improvement triggers a checkpoint\n",
    "    if not (test_loss < best_test_loss):\n",
    "        continue\n",
    "\n",
    "    # Save the best model seen so far (weights only)\n",
    "    best_test_loss = test_loss\n",
    "    torch.save(model.state_dict(), 'test_gnn_model.pth')\n",
    "    logging.info(f'Best model saved with Test Loss: {best_test_loss:.4f}')\n",
    "\n",
    "logging.info('Training completed.')\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### sequential input test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### HDHGN test"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "llm",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
