{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "ff657f38-0f3a-44d1-bd0b-28a6de394cc7",
   "metadata": {},
   "outputs": [],
   "source": [
    "import warnings\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "\n",
    "import torch\n",
    "from torch_geometric.datasets import QM9\n",
    "from torch_geometric.loader import DataLoader\n",
    "from torch_geometric.nn import radius_graph\n",
    "import torch_geometric.transforms as T\n",
    "from torch_geometric.utils import remove_self_loops, to_dense_adj, dense_to_sparse, softmax\n",
    "from torch_geometric.nn import global_add_pool, global_mean_pool, global_max_pool\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "from torch_geometric.nn import MessagePassing\n",
    "from ogb.graphproppred.mol_encoder import AtomEncoder,BondEncoder\n",
    "import numpy as np\n",
    "import itertools\n",
    "import random\n",
    "import math\n",
    "import pdb"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "f1ec323f-37c5-4d8c-bd83-b4f88c4c42f7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Bind the selected device to a name: the model layers and the training / evaluation loops\n",
    "# defined below all read the global `device`, so the bare expression alone is not enough.\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "device"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "6631d5b7-5917-4773-b38b-b6d4110ffb65",
   "metadata": {},
   "outputs": [],
   "source": [
    "def seed(seed=42):\n",
    "    \"\"\"Seed the Python, NumPy and PyTorch (CPU and CUDA) RNGs and set the cuDNN determinism flags.\n",
    "\n",
    "    Args:\n",
    "        seed (int): value passed to every seeding function.\n",
    "    \"\"\"\n",
    "    seeders = (\n",
    "        random.seed,\n",
    "        np.random.seed,\n",
    "        torch.manual_seed,\n",
    "        torch.cuda.manual_seed,\n",
    "        torch.cuda.manual_seed_all,\n",
    "    )\n",
    "    for set_seed in seeders:\n",
    "        set_seed(seed)\n",
    "\n",
    "    # cuDNN: turn the deterministic flag on and the benchmark (autotuning) flag off.\n",
    "    cudnn = torch.backends.cudnn\n",
    "    cudnn.deterministic = True\n",
    "    cudnn.benchmark = False\n",
    "\n",
    "\n",
    "seed(42)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "5ae5d7b7-75d1-43fa-9276-88c82b27bb33",
   "metadata": {},
   "outputs": [],
   "source": [
    "class GenerateGraph(object):\n",
    "    \"\"\"Transform that replaces a graph's edges with all ordered node pairs (self-loops removed).\n",
    "\n",
    "    Existing edge attributes are scattered into the dense pair layout; pairs that had no\n",
    "    edge in the input keep all-zero attributes.\n",
    "    \"\"\"\n",
    "\n",
    "    def __call__(self, data):\n",
    "        n = data.num_nodes\n",
    "        # Build the index tensors on the same device as the coordinates.\n",
    "        node_ids = torch.arange(n, dtype=torch.long, device=data.pos.device)\n",
    "\n",
    "        # --- Pair-wise interactions ---\n",
    "        # Row-major enumeration of every ordered pair (i, j): the source id varies slowest.\n",
    "        sources = node_ids.view(-1, 1).repeat(1, n).view(-1)\n",
    "        targets = node_ids.repeat(n)\n",
    "        dense_index = torch.stack([sources, targets], dim=0)\n",
    "\n",
    "        dense_attr = None\n",
    "        if data.edge_attr is not None:\n",
    "            # Position of each original edge (i, j) inside the row-major pair list.\n",
    "            flat_position = data.edge_index[0] * n + data.edge_index[1]\n",
    "            dense_shape = list(data.edge_attr.size())\n",
    "            dense_shape[0] = n * n\n",
    "            dense_attr = data.edge_attr.new_zeros(dense_shape)\n",
    "            dense_attr[flat_position] = data.edge_attr\n",
    "\n",
    "        # Remove the (i, i) pairs together with their attribute rows.\n",
    "        dense_index, dense_attr = remove_self_loops(dense_index, dense_attr)\n",
    "\n",
    "        data.edge_attr = dense_attr\n",
    "        data.edge_index = dense_index\n",
    "        return data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "78a114e0-7310-4740-9667-48b33b48cd11",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the QM9 dataset with the complete-graph transform passed as `pre_transform`.\n",
    "# NOTE(review): a pre_transform is presumably applied only when the processed files are first\n",
    "# built - confirm that an old cache under QM9_ROOT is not reused after editing GenerateGraph.\n",
    "QM9_ROOT = './data/att_qm9/QM9'\n",
    "\n",
    "transform = T.Compose([GenerateGraph()])\n",
    "dataset = QM9(root=QM9_ROOT, pre_transform=transform)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "dcfadbe5-f2a4-432e-9bba-646b0b7432c1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep a single regression target (column 1 of `y`) and standardise it to zero mean, unit std.\n",
    "# `mean` and `std` stay available for mapping predictions back to the original scale.\n",
    "# NOTE(review): the statistics are computed over the whole dataset, not only the training split.\n",
    "target = dataset.data.y[:, 1]\n",
    "mean = target.mean(dim=0, keepdim=True)\n",
    "std = target.std(dim=0, keepdim=True)\n",
    "dataset.data.y = (target - mean) / std"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "058e6a47-16d9-4405-8bb8-04254a5a547c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# For every molecule, precompute the pairwise edge lengths and the triplet (i, j, k) geometry\n",
    "# (three side lengths and the cosine of the angle at the centre atom i), then attach them to the\n",
    "# graph so the model can read them from the batch.\n",
    "max_triplet_size = 0\n",
    "data_list = []\n",
    "\n",
    "for index, data in enumerate(dataset):\n",
    "    pos = data.pos\n",
    "    src, dst = data.edge_index\n",
    "    pairwise_distances = torch.norm(pos[src] - pos[dst], dim=-1, keepdim=True)\n",
    "\n",
    "    # --- Triplet interactions (i,j,k) ---\n",
    "    triplets = []\n",
    "    for i in range(data.num_nodes):\n",
    "        neighbors = dst[src == i].tolist()  # Neighbors of atom i\n",
    "        if len(neighbors) >= 2:\n",
    "            # One triplet per unordered pair of neighbors of i\n",
    "            triplets.extend((i, j, k) for j, k in itertools.combinations(neighbors, 2))\n",
    "        elif len(neighbors) == 1:\n",
    "            # Dummy triplet that repeats the only neighbor\n",
    "            triplets.append((i, neighbors[0], neighbors[0]))\n",
    "        else:\n",
    "            # Self-referencing dummy triplet for an atom without neighbors\n",
    "            triplets.append((i, i, i))\n",
    "\n",
    "    num_triplets = len(triplets)\n",
    "    # (num_triplets, 3); the explicit reshape keeps the tensor 2-D even for an empty list.\n",
    "    triplet_indices = torch.tensor(triplets, dtype=torch.long).reshape(num_triplets, 3)\n",
    "    # One graph id per triplet, dummy triplets included, so this tensor always lines up\n",
    "    # row-for-row with `triplet_indices` and `cos_theta_ijk`.\n",
    "    triplet_graph_index = torch.full((num_triplets,), index, dtype=torch.long)\n",
    "\n",
    "    max_triplet_size = max(max_triplet_size, num_triplets)  # Update max triplet size\n",
    "\n",
    "    pos_i = pos[triplet_indices[:, 0]]  # Central atom positions\n",
    "    pos_j = pos[triplet_indices[:, 1]]  # Neighbor j positions\n",
    "    pos_k = pos[triplet_indices[:, 2]]  # Neighbor k positions\n",
    "\n",
    "    vec_ij = pos_j - pos_i\n",
    "    vec_ik = pos_k - pos_i\n",
    "    vec_jk = pos_j - pos_k\n",
    "\n",
    "    norm_ij = torch.norm(vec_ij, dim=-1, keepdim=True)\n",
    "    norm_ik = torch.norm(vec_ik, dim=-1, keepdim=True)\n",
    "    norm_jk = torch.norm(vec_jk, dim=-1, keepdim=True)\n",
    "\n",
    "    # A self-referencing dummy triplet has zero-length vectors; clamping the denominator makes\n",
    "    # its cosine 0 instead of NaN and leaves every non-degenerate triplet unchanged.\n",
    "    cos_theta_ijk = torch.sum(vec_ij * vec_ik, dim=-1, keepdim=True) / (norm_ij * norm_ik).clamp_min(1e-12)\n",
    "\n",
    "    data.pairwise_distances = pairwise_distances\n",
    "    data.triplet_indices = triplet_indices\n",
    "    data.triplet_distances = torch.cat((norm_ij, norm_ik, norm_jk), dim=1)\n",
    "    data.cos_theta_ijk = cos_theta_ijk\n",
    "    data.triplet_batch = triplet_graph_index\n",
    "\n",
    "    data_list.append(data)"
   ]
  },
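  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "95382f5f-e43b-475e-8006-ffb7d1929a7d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional on-disk cache for the preprocessed `data_list` (both lines are disabled by default).\n",
    "# data_list = torch.load('data_list_qm9.pt')   # reload a previously saved list\n",
    "# torch.save(data_list, 'data_list_qm9.pt')    # save the list built by the cell above"
   ]
  },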
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "a6f9b8b9-240d-41b8-91c3-b5cc6bd77061",
   "metadata": {},
   "outputs": [],
   "source": [
    "# data split: cumulative end offsets of an 80% / 10% / 10% partition of the first `data_size` graphs\n",
    "data_size = 100000\n",
    "holdout_size = int(data_size * 0.1)\n",
    "\n",
    "train_index = int(data_size * 0.8)\n",
    "val_index = train_index + holdout_size\n",
    "test_index = val_index + holdout_size"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "4f3d3e91-533a-4997-bd80-dfb4a6b51896",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Partition the first `test_index` graphs into disjoint, consecutive train / validation / test\n",
    "# slices. The offsets are cumulative end positions, so each split has to start where the previous\n",
    "# one ends; starting every split at 0 would put the training graphs into validation and test too.\n",
    "if len(data_list) < test_index:\n",
    "    raise IndexError(f'need at least {test_index} graphs for the split, got {len(data_list)}')\n",
    "\n",
    "train_data_list = data_list[:train_index]\n",
    "valid_data_list = data_list[train_index:val_index]\n",
    "test_data_list = data_list[val_index:test_index]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "1d35dec1-e7cf-40c8-b4af-f3a04d947da4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mini-batch loaders; only the training loader shuffles.\n",
    "BATCH_SIZE = 16\n",
    "\n",
    "train_loader = DataLoader(train_data_list, batch_size=BATCH_SIZE, shuffle=True)\n",
    "val_loader = DataLoader(valid_data_list, batch_size=BATCH_SIZE, shuffle=False)\n",
    "test_loader = DataLoader(test_data_list, batch_size=BATCH_SIZE, shuffle=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "f3626442-364d-49d6-9cde-615ebfe30b06",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Drop the notebook-level names of the raw dataset and of the intermediate lists.\n",
    "# NOTE(review): the loaders were built from the split lists and presumably still reference them,\n",
    "# so this may not release their contents - confirm.\n",
    "del data_list, train_data_list, valid_data_list, test_data_list, dataset\n",
    "torch.cuda.empty_cache()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "e293c6b5-bd4a-4d73-94dc-efd46c4a7a40",
   "metadata": {},
   "outputs": [],
   "source": [
    "class TripletMessagePassing(MessagePassing):\n",
    "    \"\"\"Message passing over atom triplets (i, j, k) with attention, a GRU gate and distance kernels.\n",
    "\n",
    "    One message is built per triplet from the features of its three atoms; `propagate`\n",
    "    mean-aggregates the messages over the edge index formed by the first two triplet rows.\n",
    "\n",
    "    Args:\n",
    "        in_channels (int): node feature size.\n",
    "        in_q_channels (int): input size of the query projection (applied to node features).\n",
    "        in_k_channels (int): input size of the key projection (applied to triplet attributes).\n",
    "        out_channels (int): output size of the query and key projections.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, in_channels, in_q_channels, in_k_channels, out_channels):\n",
    "        super(TripletMessagePassing, self).__init__(aggr='mean')\n",
    "        self.linQ = torch.nn.Linear(in_q_channels, out_channels).to(device)\n",
    "        self.linK = torch.nn.Linear(in_k_channels, out_channels).to(device)\n",
    "        # Allocate directly on the target device: `Parameter(...).to(device)` returns a plain\n",
    "        # tensor whenever a copy is made, and that copy is not registered as a module parameter.\n",
    "        self.gamma = torch.nn.Parameter(torch.empty(1, device=device))\n",
    "\n",
    "        # Gate mechanism for higher-order interaction\n",
    "        self.gru = torch.nn.GRU(input_size=3 * in_channels, hidden_size=1, batch_first=True).to(device)\n",
    "\n",
    "        self.reset_parameters()\n",
    "\n",
    "    def reset_parameters(self):\n",
    "        \"\"\"Fill `gamma` with its constant initial value.\"\"\"\n",
    "        stdv_gamma = 0.1 / math.sqrt(self.gamma.size(0))\n",
    "        torch.nn.init.constant_(self.gamma, stdv_gamma)\n",
    "\n",
    "    def forward(self, x, triplet_indices, triplet_attr):\n",
    "        \"\"\"Aggregate triplet messages.\n",
    "\n",
    "        Args:\n",
    "            x (Tensor): node features.\n",
    "            triplet_indices (LongTensor): (3, num_triplets) node ids; row 0 is the centre atom,\n",
    "                rows 1 and 2 are its two neighbours.\n",
    "            triplet_attr (Tensor): per-triplet attributes; columns 0-2 feed the three distance\n",
    "                kernels and column 3 the angular factor.\n",
    "\n",
    "        Returns:\n",
    "            Tensor: aggregated messages, one row per node.\n",
    "        \"\"\"\n",
    "        edge_index = triplet_indices[:2]\n",
    "\n",
    "        return self.propagate(edge_index, x=x, triplet_indices=triplet_indices, triplet_attr=triplet_attr)\n",
    "\n",
    "    def message(self, x_j, x, triplet_indices, triplet_attr):\n",
    "        \"\"\"Compute one message per triplet.\n",
    "\n",
    "        Args:\n",
    "            x_j (Tensor): node features gathered by `propagate` with the first edge-index row,\n",
    "                i.e. the centre atom of every triplet (one row per triplet).\n",
    "            x (Tensor): full node feature matrix, used to look up the two neighbour atoms.\n",
    "            triplet_indices (LongTensor): (3, num_triplets) node ids.\n",
    "            triplet_attr (Tensor): per-triplet attributes, see `forward`.\n",
    "\n",
    "        Returns:\n",
    "            Tensor: (num_triplets, in_channels) messages.\n",
    "        \"\"\"\n",
    "        neighbor1_indices = triplet_indices[1]\n",
    "        neighbor2_indices = triplet_indices[2]\n",
    "\n",
    "        # The rows of `x_j` are triplets, not atoms, so it must not be indexed with node ids:\n",
    "        # the centre features are `x_j` itself and the neighbour features come from `x`.\n",
    "        x_center = x_j\n",
    "        x_nbr1 = x[neighbor1_indices]\n",
    "        x_nbr2 = x[neighbor2_indices]\n",
    "\n",
    "        q_i = self.linQ(x_center)\n",
    "        k_i = self.linK(triplet_attr.to(device))\n",
    "\n",
    "        with torch.cuda.amp.autocast():\n",
    "            alpha_ijk = (q_i * k_i).sum(dim=-1)\n",
    "            # `alpha_ijk` is 1-D, so dim=0 is the axis the implicit-dim call resolved to.\n",
    "            alph_ijk = F.softmax(alpha_ijk, dim=0).view(-1, 1)\n",
    "\n",
    "        # Apply gate to modulate higher-order interactions\n",
    "        concatenated_inputs = torch.cat([x_center, x_nbr1, x_nbr2], dim=-1).unsqueeze(1)  # Add a time dimension\n",
    "        gate, _ = self.gru(concatenated_inputs)\n",
    "        gate = torch.sigmoid(gate[:, -1, :])  # Get the last output and apply sigmoid\n",
    "\n",
    "        # Distance kernels exp(-(d - 0.01))**2 for the three sides of the triplet\n",
    "        radius_ij = torch.exp(-1*(triplet_attr[:, 0]-0.01))**2\n",
    "        radius_ik = torch.exp(-1*(triplet_attr[:, 1]-0.01))**2\n",
    "        radius_jk = torch.exp(-1*(triplet_attr[:, 2]-0.01))**2\n",
    "\n",
    "        f_ij = F.softplus(radius_ij.view(-1, 1) * (x_center - x_nbr1))\n",
    "        f_ik = F.softplus(radius_ik.view(-1, 1) * (x_center - x_nbr2))\n",
    "        f_jk = F.softplus(radius_jk.view(-1, 1) * (x_nbr1 - x_nbr2))\n",
    "\n",
    "        # No torch.cuda.empty_cache() here: flushing the allocator on every forward pass only\n",
    "        # slows training down.\n",
    "        neighbor_features = gate * alph_ijk * (1+0.5*triplet_attr[:, 3].view(-1, 1))**2 * (f_ij * f_ik * f_jk)\n",
    "\n",
    "        return neighbor_features\n",
    "\n",
    "    def update(self, aggr_out):\n",
    "        \"\"\"Return the aggregated messages unchanged.\"\"\"\n",
    "        return aggr_out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "0606e9e4-a6df-4a2b-9def-7963b2a25c7c",
   "metadata": {},
   "outputs": [],
   "source": [
    "class EquivariantAttention(MessagePassing):\n",
    "    \"\"\"Node encoder that combines gated pairwise attention with triplet message passing.\n",
    "\n",
    "    `forward` returns, for every node, the concatenation of a pairwise embedding and a triplet\n",
    "    embedding, each of size `hidden_dim` and both produced by the same MLP.\n",
    "\n",
    "    Args:\n",
    "        in_q_channels (int): query input size handed to the triplet layer.\n",
    "        in_k_channels (int): key input size handed to the triplet layer.\n",
    "        in_channels (int): node feature size.\n",
    "        bond_channels (int): edge attribute size (the key projection takes one extra column\n",
    "            for the pairwise distance).\n",
    "        pairwise_channels (int): not used by this class.\n",
    "        triplet_channels (int): not used by this class.\n",
    "        hidden_dim (int): size of each attention head and of the MLP output.\n",
    "        heads (int): number of attention heads.\n",
    "        dropout (float): dropout probability inside the MLP.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, in_q_channels, in_k_channels, in_channels, bond_channels, pairwise_channels, \n",
    "                 triplet_channels, hidden_dim, heads=2, dropout=0.3): # 0.2\n",
    "        super(EquivariantAttention, self).__init__(aggr='add')\n",
    "        self.heads = heads\n",
    "        self.hidden_dim = hidden_dim\n",
    "        self.dropout = dropout\n",
    "\n",
    "        self.triplet_agg = TripletMessagePassing(in_channels, in_q_channels, in_k_channels, hidden_dim)\n",
    "\n",
    "        # Define learnable kernel parameters. They are allocated directly on the target device:\n",
    "        # `Parameter(...).to(device)` returns a plain tensor whenever a copy is made, and that\n",
    "        # copy is neither registered with the module nor seen by the optimizer.\n",
    "        self.gamma = torch.nn.Parameter(torch.empty(1, device=device))\n",
    "        self.eps = torch.nn.Parameter(torch.empty(1, device=device))\n",
    "\n",
    "        # Projections for MLP and attention heads\n",
    "        self.query_proj = torch.nn.Linear(in_channels, heads * hidden_dim).to(device)\n",
    "        self.key_proj = torch.nn.Linear(bond_channels+1, heads * hidden_dim).to(device)\n",
    "        self.value_proj = torch.nn.Linear(in_channels, heads * hidden_dim).to(device)\n",
    "\n",
    "        # MLP layers for each attention head\n",
    "        self.mlp = torch.nn.Sequential(\n",
    "            torch.nn.Linear(in_channels, hidden_dim), torch.nn.Dropout(dropout), \n",
    "            torch.nn.ReLU(),\n",
    "            torch.nn.Linear(hidden_dim, hidden_dim), \n",
    "        ).to(device)\n",
    "\n",
    "        # Gate mechanism for higher-order interaction\n",
    "        self.gru = torch.nn.GRU(input_size=2 * in_channels, hidden_size=1, batch_first=True).to(device)\n",
    "\n",
    "        # `gamma` and `eps` are created from uninitialised memory, so give them their\n",
    "        # constant start values before the first forward pass.\n",
    "        self.reset_parameters()\n",
    "\n",
    "    def reset_parameters(self):\n",
    "        \"\"\"Fill `gamma` and `eps` with their constant initial values.\"\"\"\n",
    "        stdv_gamma = 0.1 / math.sqrt(self.gamma.size(0))\n",
    "        torch.nn.init.constant_(self.gamma, stdv_gamma)\n",
    "\n",
    "        stdv_eps = 0.1 / math.sqrt(self.eps.size(0))\n",
    "        torch.nn.init.constant_(self.eps, stdv_eps)\n",
    "\n",
    "    def forward(self, x, pairwise_distances, triplet_distances, cos_theta_ijk, edge_index, triplet_indices, pos, edge_attr):\n",
    "        \"\"\"Encode the nodes of a batch.\n",
    "\n",
    "        Args:\n",
    "            x (Tensor): node features.\n",
    "            pairwise_distances (Tensor): (num_edges, 1) edge lengths.\n",
    "            triplet_distances (Tensor): per-triplet distances, one column per side.\n",
    "            cos_theta_ijk (Tensor): (num_triplets, 1) cosine at the centre atom.\n",
    "            edge_index (LongTensor): (2, num_edges) pairwise edges.\n",
    "            triplet_indices (LongTensor): (3, num_triplets) node ids.\n",
    "            pos (Tensor): atom coordinates (not used here).\n",
    "            edge_attr (Tensor): edge attributes, concatenated with the distances to form the keys.\n",
    "\n",
    "        Returns:\n",
    "            Tensor: (num_nodes, 2 * hidden_dim), pairwise embedding followed by triplet embedding.\n",
    "        \"\"\"\n",
    "        K = torch.cat((edge_attr, pairwise_distances), dim=1)\n",
    "\n",
    "        # Project into multiple heads\n",
    "        query = self.query_proj(x[edge_index[0]]).view(-1, self.heads, self.hidden_dim)\n",
    "        key = self.key_proj(K).view(-1, self.heads, self.hidden_dim)\n",
    "\n",
    "        # Compute attention weights: per-head dot product, averaged over heads and normalised\n",
    "        # over the edges that share the same first endpoint\n",
    "        alpha = ((query * key)).sum(dim=-1).mean(1).view(-1, 1)\n",
    "        alpha = softmax(alpha, edge_index[0])\n",
    "\n",
    "        # Apply gate to modulate higher-order interactions\n",
    "        concatenated_inputs = torch.cat([x[edge_index[0]], x[edge_index[1]]], dim=-1).unsqueeze(1)  # Add a time dimension\n",
    "        gate, _ = self.gru(concatenated_inputs)\n",
    "        gate = torch.sigmoid(gate[:, -1, :])  # Get the last output and apply sigmoid\n",
    "\n",
    "        triplet_attr = torch.cat((triplet_distances, cos_theta_ijk), dim=1)\n",
    "        output_triplets = self.mlp((1+self.eps) * x + self.triplet_agg(x, triplet_indices, triplet_attr))\n",
    "\n",
    "        f_ij = F.softplus(torch.exp(-1*(pairwise_distances-0.01))**2)\n",
    "        output_pairs = self.mlp((1+self.eps) * x + self.propagate(edge_index, x=x, gate=gate, alpha=alpha, f_ij=f_ij))\n",
    "\n",
    "        return torch.cat((output_pairs, output_triplets), dim=1)\n",
    "\n",
    "    def message(self, x_j, alpha, gate, f_ij):\n",
    "        \"\"\"Scale the gathered node features by gate, distance kernel and attention weight.\"\"\"\n",
    "        neighbor_features = gate * f_ij * alpha * x_j\n",
    "\n",
    "        return neighbor_features\n",
    "\n",
    "    def update(self, aggr_out):\n",
    "        \"\"\"Return the aggregated messages unchanged (the MLP is applied in `forward`).\"\"\"\n",
    "        return aggr_out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "4c4388c3-c995-4bfc-a803-07cab19a098a",
   "metadata": {},
   "outputs": [],
   "source": [
    "class CombinedModel(torch.nn.Module):\n",
    "    \"\"\"Graph-level regressor: EquivariantAttention node encoder, sum pooling, linear head.\n",
    "\n",
    "    The constructor arguments are forwarded to `EquivariantAttention`; `hidden_dim` also sizes\n",
    "    the linear head, which reads the 2 * hidden_dim pooled embedding.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, in_q_dim, in_k_dim, node_dim, bond_dim, pairwise_dim, triplet_dim, hidden_dim):\n",
    "        super(CombinedModel, self).__init__()\n",
    "        self.gnn = EquivariantAttention(in_q_dim, in_k_dim, node_dim, bond_dim, pairwise_dim, triplet_dim, hidden_dim)\n",
    "        self.lin = torch.nn.Linear(2*hidden_dim, 1).to(device)\n",
    "        self.pool = global_add_pool\n",
    "\n",
    "    def forward(self, data):\n",
    "        \"\"\"Predict one value per graph of the batch.\n",
    "\n",
    "        Args:\n",
    "            data: mini-batch exposing x, pos, edge_index, edge_attr, batch, pairwise_distances,\n",
    "                triplet_distances, cos_theta_ijk, triplet_indices (one row per triplet) and\n",
    "                triplet_batch (one graph id per triplet).\n",
    "\n",
    "        Returns:\n",
    "            Tensor: (num_graphs, 1) predictions.\n",
    "        \"\"\"\n",
    "        x, pos = data.x, data.pos\n",
    "        edge_index, edge_attr = data.edge_index, data.edge_attr\n",
    "        pairwise_distances = data.pairwise_distances\n",
    "        triplet_distances, cos_theta_ijk = data.triplet_distances, data.cos_theta_ijk\n",
    "\n",
    "        # Triplet node ids are local to their graph (the loader is assumed not to offset this\n",
    "        # attribute - TODO confirm), so shift each graph's triplets by the number of nodes that\n",
    "        # precede the graph in the batch.\n",
    "        num_graphs = int(data.batch.max()) + 1\n",
    "        nodes_per_graph = torch.bincount(data.batch, minlength=num_graphs)\n",
    "        node_offsets = torch.cumsum(nodes_per_graph, dim=0) - nodes_per_graph\n",
    "\n",
    "        # Every run of equal `triplet_batch` values belongs to one graph, in batch order.\n",
    "        _, triplet_graph = torch.unique_consecutive(data.triplet_batch, return_inverse=True)\n",
    "\n",
    "        # Out-of-place add on the transposed view: `data.triplet_indices` itself is not modified,\n",
    "        # so the same batch can be passed through the model more than once.\n",
    "        triplet_indices = data.triplet_indices.t() + node_offsets[triplet_graph]\n",
    "\n",
    "        h_node = self.gnn(x, pairwise_distances, triplet_distances, cos_theta_ijk, edge_index, triplet_indices, pos, edge_attr)\n",
    "        h_graph = self.pool(h_node, data.batch)\n",
    "        h_graph = self.lin(h_graph.to(device))\n",
    "\n",
    "        return h_graph"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "5c38a391-1875-4dff-bbb6-877774398515",
   "metadata": {},
   "outputs": [],
   "source": [
    "def training(loader, model, loss, optimizer):\n",
    "    \"\"\"Training one epoch\n",
    "\n",
    "    Args:\n",
    "        loader (DataLoader): training data divided into batches\n",
    "        model (nn.Module): GNN model to train on\n",
    "        loss (nn.functional): loss function to use during training\n",
    "        optimizer (torch.optim): optimizer during training\n",
    "\n",
    "    Returns:\n",
    "        tuple: (mean batch loss of the epoch as a detached tensor, the trained model)\n",
    "    \"\"\"\n",
    "    model.train()\n",
    "\n",
    "    num_batches = len(loader)\n",
    "    current_loss = 0\n",
    "    for data in loader:\n",
    "        data = data.to(device)\n",
    "        optimizer.zero_grad()\n",
    "        out = model(data)\n",
    "        batch_loss = loss(out, torch.reshape(data.y, (len(data.y), 1)))\n",
    "        batch_loss.backward()\n",
    "        optimizer.step()\n",
    "        # Accumulate a detached value: adding the live loss tensor would chain the autograd\n",
    "        # graphs of all batches of the epoch together and keep them in memory.\n",
    "        current_loss = current_loss + batch_loss.detach() / num_batches\n",
    "    return current_loss, model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "6432d7ed-7a39-455c-841a-154ba8ccaf88",
   "metadata": {},
   "outputs": [],
   "source": [
    "def validation(loader, model, loss):\n",
    "    \"\"\"Validation\n",
    "\n",
    "    Args:\n",
    "        loader (DataLoader): validation set in batches\n",
    "        model (nn.Module): current trained model\n",
    "        loss (nn.functional): loss function\n",
    "\n",
    "    Returns:\n",
    "        Tensor: mean batch loss over the validation set\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    num_batches = len(loader)\n",
    "    val_loss = 0\n",
    "    # No gradients are needed for evaluation; tracking them would only build and retain an\n",
    "    # autograd graph for every validation batch.\n",
    "    with torch.no_grad():\n",
    "        for data in loader:\n",
    "            data = data.to(device)\n",
    "            out = model(data)\n",
    "            batch_loss = loss(out, torch.reshape(data.y, (len(data.y), 1)))\n",
    "            val_loss = val_loss + batch_loss / num_batches\n",
    "    return val_loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "3cbaac27-36e4-40e1-95f4-22b8b6c8bc5a",
   "metadata": {},
   "outputs": [],
   "source": [
    "@torch.no_grad()\n",
    "def testing(loader, model):\n",
    "    \"\"\"Testing\n",
    "\n",
    "    Args:\n",
    "        loader (DataLoader): test dataset\n",
    "        model (nn.Module): trained model\n",
    "\n",
    "    Returns:\n",
    "        tuple: (mean batch MSE as a tensor, predictions as a 1-D array, targets as a 1-D array)\n",
    "    \"\"\"\n",
    "    # Evaluate in eval mode regardless of the mode the caller left the model in (dropout off).\n",
    "    model.eval()\n",
    "    loss = torch.nn.MSELoss()\n",
    "    num_batches = len(loader)\n",
    "    test_loss = 0\n",
    "    predictions = []\n",
    "    targets = []\n",
    "    for data in loader:\n",
    "        data = data.to(device)\n",
    "        out = model(data)\n",
    "        batch_loss = loss(out, torch.reshape(data.y, (len(data.y), 1)))\n",
    "        test_loss = test_loss + batch_loss / num_batches\n",
    "\n",
    "        # save prediction vs ground truth values for plotting\n",
    "        predictions.append(out.detach().cpu().numpy()[:, 0])\n",
    "        targets.append(data.y.detach().cpu().numpy())\n",
    "\n",
    "    # Concatenate once at the end; the leading empty float64 array keeps the result dtype and\n",
    "    # also covers an empty loader.\n",
    "    test_target = np.concatenate([np.empty((0))] + predictions)\n",
    "    test_y_target = np.concatenate([np.empty((0))] + targets)\n",
    "\n",
    "    return test_loss, test_target, test_y_target"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "4d198fb2-441e-4124-9a0c-5a7fbc249397",
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_epochs(epochs, model, train_loader, val_loader, path):\n",
    "    \"\"\"Training over all epochs\n",
    "\n",
    "    Args:\n",
    "        epochs (int): number of epochs to train for\n",
    "        model (nn.Module): the current model\n",
    "        train_loader (DataLoader): training data in batches\n",
    "        val_loader (DataLoader): validation data in batches\n",
    "        path (string): path to save the best model\n",
    "\n",
    "    Returns:\n",
    "        array: returning train and validation losses over all epochs, prediction and ground truth values for training data in the last epoch\n",
    "    \"\"\"\n",
    "    optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-4)\n",
    "    loss = torch.nn.MSELoss()\n",
    "\n",
    "    train_predictions = []\n",
    "    train_targets = []\n",
    "    train_loss = np.empty(epochs)\n",
    "    val_loss = np.empty(epochs)\n",
    "    best_loss = math.inf\n",
    "\n",
    "    for epoch in range(epochs):\n",
    "        epoch_loss, model = training(train_loader, model, loss, optimizer)\n",
    "        v_loss = validation(val_loader, model, loss)\n",
    "\n",
    "        # Checkpoint only on improvement, and remember the new best value so that later,\n",
    "        # worse epochs do not overwrite the saved weights.\n",
    "        if v_loss < best_loss:\n",
    "            best_loss = float(v_loss)\n",
    "            torch.save(model.state_dict(), path)\n",
    "\n",
    "        if epoch == epochs - 1:\n",
    "            # record truly vs predicted values for training data from last epoch\n",
    "            # (only this epoch is recorded, so the extra pass runs once and without gradients)\n",
    "            model.eval()\n",
    "            with torch.no_grad():\n",
    "                for data in train_loader:\n",
    "                    data = data.to(device)\n",
    "                    out = model(data)\n",
    "                    train_predictions.append(out.detach().cpu().numpy()[:, 0])\n",
    "                    train_targets.append(data.y.detach().cpu().numpy())\n",
    "\n",
    "        train_loss[epoch] = epoch_loss.cpu().detach().numpy()\n",
    "        val_loss[epoch] = v_loss.cpu().detach().numpy()\n",
    "\n",
    "        # print current train and val loss\n",
    "        if epoch % 2 == 0:\n",
    "            print(\n",
    "                \"Epoch: \"\n",
    "                + str(epoch)\n",
    "                + \", Train loss: \"\n",
    "                + str(epoch_loss.item())\n",
    "                + \", Val loss: \"\n",
    "                + str(v_loss.item())\n",
    "            )\n",
    "\n",
    "    # The leading empty float64 array keeps the result dtype and covers epochs == 0.\n",
    "    train_target = np.concatenate([np.empty((0))] + train_predictions)\n",
    "    train_y_target = np.concatenate([np.empty((0))] + train_targets)\n",
    "    return train_loss, val_loss, train_target, train_y_target"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "6bca8df6-7d83-485b-a21a-a5fb5d0271ab",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 0, Train loss: 1.3860986232757568, Val loss: 1.0192714929580688\n",
      "Epoch: 2, Train loss: 1.0025746822357178, Val loss: 0.8096003532409668\n",
      "Epoch: 4, Train loss: 0.8174251914024353, Val loss: 0.6619618535041809\n",
      "Epoch: 6, Train loss: 0.7004879713058472, Val loss: 0.6868166923522949\n",
      "Epoch: 8, Train loss: 0.5771207213401794, Val loss: 0.525598406791687\n",
      "Epoch: 10, Train loss: 0.5138940811157227, Val loss: 0.44771909713745117\n",
      "Epoch: 12, Train loss: 0.4654841125011444, Val loss: 0.4005628824234009\n",
      "Epoch: 14, Train loss: 0.4346614480018616, Val loss: 0.4658851623535156\n",
      "Epoch: 16, Train loss: 0.41031354665756226, Val loss: 0.34570199251174927\n",
      "Epoch: 18, Train loss: 0.3736397325992584, Val loss: 0.3224579691886902\n",
      "Epoch: 20, Train loss: 0.3562317192554474, Val loss: 0.31678473949432373\n",
      "Epoch: 22, Train loss: 0.32839977741241455, Val loss: 0.2724252939224243\n",
      "Epoch: 24, Train loss: 0.3063674569129944, Val loss: 0.2710008919239044\n",
      "Epoch: 26, Train loss: 0.2876291573047638, Val loss: 0.2688060700893402\n",
      "Epoch: 28, Train loss: 0.27150776982307434, Val loss: 0.2380138635635376\n",
      "Epoch: 30, Train loss: 0.25853976607322693, Val loss: 0.2324010580778122\n",
      "Epoch: 32, Train loss: 0.2375846803188324, Val loss: 0.20090827345848083\n",
      "Epoch: 34, Train loss: 0.22739849984645844, Val loss: 0.1905764639377594\n",
      "Epoch: 36, Train loss: 0.2126311957836151, Val loss: 0.18397195637226105\n",
      "Epoch: 38, Train loss: 0.1970495879650116, Val loss: 0.17726007103919983\n",
      "Epoch: 40, Train loss: 0.18710505962371826, Val loss: 0.15622740983963013\n",
      "Epoch: 42, Train loss: 0.17905277013778687, Val loss: 0.1566985547542572\n",
      "Epoch: 44, Train loss: 0.16618144512176514, Val loss: 0.14655499160289764\n",
      "Epoch: 46, Train loss: 0.16042186319828033, Val loss: 0.15121804177761078\n",
      "Epoch: 48, Train loss: 0.15651895105838776, Val loss: 0.1301877647638321\n",
      "Epoch: 50, Train loss: 0.1419859677553177, Val loss: 0.12462920695543289\n",
      "Epoch: 52, Train loss: 0.13670381903648376, Val loss: 0.12116637080907822\n",
      "Epoch: 54, Train loss: 0.13185828924179077, Val loss: 0.10908201336860657\n",
      "Epoch: 56, Train loss: 0.12587347626686096, Val loss: 0.1058192327618599\n",
      "Epoch: 58, Train loss: 0.12061098217964172, Val loss: 0.09645699709653854\n",
      "Epoch: 60, Train loss: 0.11668434739112854, Val loss: 0.09321969747543335\n",
      "Epoch: 62, Train loss: 0.113013856112957, Val loss: 0.08870154619216919\n",
      "Epoch: 64, Train loss: 0.1084231287240982, Val loss: 0.08353864401578903\n",
      "Epoch: 66, Train loss: 0.10297001153230667, Val loss: 0.09059692919254303\n",
      "Epoch: 68, Train loss: 0.10047104209661484, Val loss: 0.0789792463183403\n",
      "Epoch: 70, Train loss: 0.09968126565217972, Val loss: 0.07807767391204834\n",
      "Epoch: 72, Train loss: 0.0960245355963707, Val loss: 0.07466839998960495\n",
      "Epoch: 74, Train loss: 0.09290683269500732, Val loss: 0.06822036951780319\n",
      "Epoch: 76, Train loss: 0.0881880596280098, Val loss: 0.06898587942123413\n",
      "Epoch: 78, Train loss: 0.08733697980642319, Val loss: 0.07805602252483368\n",
      "Epoch: 80, Train loss: 0.08309873938560486, Val loss: 0.0739884227514267\n",
      "Epoch: 82, Train loss: 0.08134208619594574, Val loss: 0.06362833827733994\n",
      "Epoch: 84, Train loss: 0.07890268415212631, Val loss: 0.05468781664967537\n",
      "Epoch: 86, Train loss: 0.07857520878314972, Val loss: 0.06552409380674362\n",
      "Epoch: 88, Train loss: 0.07934611290693283, Val loss: 0.05537771061062813\n",
      "Epoch: 90, Train loss: 0.07477989047765732, Val loss: 0.05344398319721222\n",
      "Epoch: 92, Train loss: 0.0748068168759346, Val loss: 0.05169660598039627\n",
      "Epoch: 94, Train loss: 0.07087986916303635, Val loss: 0.04973829537630081\n",
      "Epoch: 96, Train loss: 0.07054053246974945, Val loss: 0.050019025802612305\n",
      "Epoch: 98, Train loss: 0.07023000717163086, Val loss: 0.05871185287833214\n"
     ]
    }
   ],
   "source": [
    "# Build the model from its hyper-parameters and train it; the checkpoint goes to the given path.\n",
    "epochs = 100\n",
    "model_config = dict(\n",
    "    in_q_dim=11,\n",
    "    in_k_dim=4,\n",
    "    node_dim=11,\n",
    "    bond_dim=4,\n",
    "    pairwise_dim=1,\n",
    "    triplet_dim=4,\n",
    "    hidden_dim=64,\n",
    ")\n",
    "model = CombinedModel(**model_config)\n",
    "train_loss, val_loss, train_target, train_y_target = train_epochs(\n",
    "    epochs, model, train_loader, val_loader, \"QM9_GNN_model.pt\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "02d76139-c52b-4cdc-98df-d63d0e8b19b4",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test Loss for GNN: 0.045853301882743835\n"
     ]
    }
   ],
   "source": [
    "# calculate test loss\n",
    "# NOTE(review): this evaluates the in-memory `model` (state after the last epoch), not the\n",
    "# checkpoint written during training - confirm that this is intended.\n",
    "test_loss, test_target, test_y = testing(test_loader, model)\n",
    "print(f\"Test Loss for GNN: {test_loss.item()}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "13d34779-f1be-43f6-9254-bd192fd81fad",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
