{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Notes\n",
    "\n",
    "1. This notebook is a simplified implementation of **\"Wef-GNN: A Generalizable Graph Neural Network for Crystalline Material Property Prediction\"**.\n",
    "2. The data can be downloaded from the Materials Project using the token provided (please limit usage to a few calls; a single call is enough to query the whole dataset).\n",
    "3. All the modules required to run the full version of Wef-GNN are included here, for users who have enough compute power.\n",
    "4. A sample trained model (trained for one epoch) is provided along with this notebook."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "I0924 14:17:20.471667869 3247520 ev_epoll1_linux.cc:121]     grpc epoll fd: 75\n",
      "D0924 14:17:20.471702749 3247520 ev_posix.cc:141]            Using polling engine: epoll1\n",
      "D0924 14:17:20.471731779 3247520 lb_policy_registry.cc:43]   registering LB policy factory for \"grpclb\"\n",
      "D0924 14:17:20.471738379 3247520 lb_policy_registry.cc:43]   registering LB policy factory for \"rls_experimental\"\n",
      "D0924 14:17:20.471744329 3247520 lb_policy_registry.cc:43]   registering LB policy factory for \"priority_experimental\"\n",
      "D0924 14:17:20.471747309 3247520 lb_policy_registry.cc:43]   registering LB policy factory for \"weighted_target_experimental\"\n",
      "D0924 14:17:20.471750049 3247520 lb_policy_registry.cc:43]   registering LB policy factory for \"pick_first\"\n",
      "D0924 14:17:20.471752549 3247520 lb_policy_registry.cc:43]   registering LB policy factory for \"round_robin\"\n",
      "D0924 14:17:20.471757969 3247520 lb_policy_registry.cc:43]   registering LB policy factory for \"ring_hash_experimental\"\n",
      "D0924 14:17:20.471764989 3247520 dns_resolver_ares.cc:831]   Using ares dns resolver\n",
      "D0924 14:17:20.471780459 3247520 certificate_provider_registry.cc:39] registering certificate provider factory for \"file_watcher\"\n",
      "D0924 14:17:20.471786379 3247520 lb_policy_registry.cc:43]   registering LB policy factory for \"cds_experimental\"\n",
      "D0924 14:17:20.471789069 3247520 lb_policy_registry.cc:43]   registering LB policy factory for \"xds_cluster_impl_experimental\"\n",
      "D0924 14:17:20.471794119 3247520 lb_policy_registry.cc:43]   registering LB policy factory for \"xds_cluster_resolver_experimental\"\n",
      "D0924 14:17:20.471796999 3247520 lb_policy_registry.cc:43]   registering LB policy factory for \"xds_cluster_manager_experimental\"\n"
     ]
    }
   ],
   "source": [
    "import logging\n",
    "import numpy as np\n",
    "import os\n",
    "\n",
    "# must be set before TensorFlow is imported so that its C++ logging is silenced\n",
    "os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'\n",
    "import tensorflow as tf\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "logging.getLogger('tensorflow').setLevel(logging.ERROR)  # suppress warnings\n",
    "\n",
    "assert tf.distribute.get_replica_context() is not None\n",
    "tf.keras.backend.clear_session()\n",
    "\n",
    "# connect to TPU if using a TPU VM; otherwise fall back to the default\n",
    "# strategy so the rest of the notebook still runs on CPU/GPU\n",
    "try:\n",
    "    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='local')\n",
    "    tf.config.experimental_connect_to_cluster(resolver)\n",
    "    tf.tpu.experimental.initialize_tpu_system(resolver)\n",
    "    strategy = tf.distribute.TPUStrategy(resolver)\n",
    "except (ValueError, tf.errors.NotFoundError) as tpu_error:\n",
    "    print(f'No TPU available ({type(tpu_error).__name__}); using the default strategy.')\n",
    "    resolver = None\n",
    "    strategy = tf.distribute.get_strategy()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Preprocessing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/gideon/.local/lib/python3.8/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n",
      "/home/gideon/.local/lib/python3.8/site-packages/mp_api/client/mprester.py:193: UserWarning: mpcontribs-client not installed. Install the package to query MPContribs data, or construct pourbaix diagrams: 'pip install mpcontribs-client'\n",
      "  warnings.warn(\n",
      "Retrieving SummaryDoc documents: 100%|██████████| 154879/154879 [05:30<00:00, 467.99it/s] \n"
     ]
    }
   ],
   "source": [
    "# download data from MP\n",
    "from mp_api.client import MPRester\n",
    "\n",
    "# Prefer a personal key from the MP_API_KEY environment variable; the shared\n",
    "# token is only a fallback. NOTE(review): a token committed to a notebook should be rotated.\n",
    "MP_API_KEY = os.environ.get('MP_API_KEY', 'SLUYHMFKj6HXW0rYmm9SDSihv02TkWOa')\n",
    "\n",
    "# summary fields used by the preprocessing cells (each requested once)\n",
    "MP_FIELDS = ['material_id', 'structure', 'formation_energy_per_atom',\n",
    "             'energy_above_hull', 'band_gap', 'is_magnetic', 'is_metal', 'is_gap_direct', 'density']\n",
    "\n",
    "with MPRester(MP_API_KEY) as mpr:\n",
    "    # -1000, 1000 to isolate materials and not repeats\n",
    "    docs = mpr.materials.summary.search(formation_energy=(-1000, 1000), fields=MP_FIELDS)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def structure_to_nodes(structure):\n",
    "    \"\"\"Return the atomic number (``site.specie.Z``) of every site, in iteration order.\"\"\"\n",
    "    return [site.specie.Z for site in structure]\n",
    "\n",
    "# one record per material id, read through pymatgen's public properties\n",
    "data = {}\n",
    "for mat in docs:\n",
    "    structure = mat.structure\n",
    "    lattice = structure.lattice\n",
    "    data[mat.material_id.string] = {\n",
    "        'atoms': structure_to_nodes(structure),\n",
    "        'fc': [site.frac_coords for site in structure.sites],\n",
    "        'cc': [site.coords for site in structure.sites],\n",
    "        'matrix': lattice.matrix,\n",
    "        'lattice': [lattice.a, lattice.b, lattice.c, lattice.alpha, lattice.beta, lattice.gamma],\n",
    "        'fe': mat.formation_energy_per_atom,\n",
    "        'eah': mat.energy_above_hull,\n",
    "        'band_gap': mat.band_gap,\n",
    "        'den': mat.density,\n",
    "        'dir': mat.is_gap_direct,\n",
    "        'met': mat.is_metal,\n",
    "        'mag': mat.is_magnetic,\n",
    "    }\n",
    "\n",
    "# parallel per-material lists; targets are [formation energy, energy above hull, band gap]\n",
    "pbe_a, pbe_fc, pbe_cc, pbe_mat, pbe_lat, pbe_y = [], [], [], [], [], []\n",
    "\n",
    "for entry in data.values():\n",
    "    pbe_a.append(entry['atoms'])\n",
    "    pbe_fc.append(entry['fc'])\n",
    "    pbe_cc.append(entry['cc'])\n",
    "    pbe_lat.append(entry['lattice'])\n",
    "    pbe_mat.append(entry['matrix'])\n",
    "    pbe_y.append([entry['fe'], entry['eah'], entry['band_gap']])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "444"
      ]
     },
     "execution_count": 328,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# structures are padded to MAX_LEN sites; the cell output shows the largest structure downloaded\n",
    "MAX_LEN = 128\n",
    "max(map(len, pbe_a))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 500,
   "metadata": {},
   "outputs": [],
   "source": [
    "# zero-pad every structure with fewer than MAX_LEN sites; larger ones are dropped\n",
    "a, fc, cc, lat, y = [], [], [], [], []\n",
    "for atoms_i, frac_i, cart_i, lat_i, y_i in zip(pbe_a, pbe_fc, pbe_cc, pbe_lat, pbe_y):\n",
    "\n",
    "    if len(atoms_i) >= MAX_LEN:\n",
    "        continue\n",
    "\n",
    "    atom_pad = np.zeros((MAX_LEN - len(atoms_i),), dtype=np.int32)\n",
    "    a.append(np.concatenate([atoms_i, atom_pad], dtype=np.int32))\n",
    "    fc.append(np.concatenate([frac_i, np.zeros((MAX_LEN - len(frac_i), 3))]))\n",
    "    cc.append(np.concatenate([cart_i, np.zeros((MAX_LEN - len(cart_i), 3))]))\n",
    "    lat.append(np.array(lat_i))\n",
    "    y.append(np.array(y_i))\n",
    "\n",
    "a, fc, cc, lat, y = (np.array(values) for values in (a, fc, cc, lat, y))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# WARNING: only run this cell if you want to switch to conventional cell representation\n",
    "# if using conventional\n",
    "from pymatgen.core import Structure\n",
    "from pymatgen.symmetry.analyzer import SpacegroupAnalyzer\n",
    "\n",
    "# same size filter as the padding step (fewer than MAX_LEN sites in the downloaded cell)\n",
    "conventional = [SpacegroupAnalyzer(doc.structure).get_conventional_standard_structure()\n",
    "                for doc in docs if len(doc.structure.labels) < MAX_LEN]\n",
    "\n",
    "# y is not rebuilt in this cell, so the selection has to line up with it\n",
    "assert len(conventional) == len(y), 'conventional cells and targets are misaligned; re-run the preprocessing cells from the top'\n",
    "\n",
    "lat = np.array([np.concatenate((struct.lattice.abc, struct.lattice.angles), axis=-1) for struct in conventional])\n",
    "\n",
    "fc = [[site.frac_coords for site in struct.sites] for struct in conventional]\n",
    "MAX_LEN = max(len(sites) for sites in fc)  # padding length becomes the largest conventional cell\n",
    "fc = np.array([np.concatenate([sites, np.zeros((MAX_LEN - len(sites), 3))], axis=0) for sites in fc])\n",
    "\n",
    "# atomic numbers, zero-padded and kept as int32 like the primitive-cell array\n",
    "a = [[site.specie.Z for site in struct.sites] for struct in conventional]\n",
    "a = np.array([np.concatenate([z, np.zeros(MAX_LEN - len(z), dtype=np.int32)], axis=0) for z in a], dtype=np.int32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def lattice_matrix(a, b, c, alpha, beta, gamma):\n",
    "    \"\"\"Build the 3x3 lattice matrix from the six cell parameters.\n",
    "\n",
    "    Args:\n",
    "        a, b, c: lengths of the three lattice vectors.\n",
    "        alpha, beta, gamma: cell angles in degrees.\n",
    "\n",
    "    Returns:\n",
    "        float32 tensor of shape (3, 3); each row is one lattice vector, with the\n",
    "        first along x and the second in the xy-plane.\n",
    "    \"\"\"\n",
    "    alpha_r, beta_r, gamma_r = (angle * np.pi / 180.0 for angle in (alpha, beta, gamma))\n",
    "\n",
    "    # components of the third vector; c_z follows from its length being c\n",
    "    c_x = c * tf.math.cos(beta_r)\n",
    "    c_y = c * (tf.math.cos(alpha_r) - tf.math.cos(gamma_r)*tf.math.cos(beta_r)) / tf.math.sin(gamma_r)\n",
    "    c_z = tf.math.sqrt(c**2 - c_x**2 - c_y**2)\n",
    "\n",
    "    rows = [\n",
    "        tf.cast([a, 0.0, 0.0], tf.float32),\n",
    "        tf.cast([b * tf.math.cos(gamma_r), b * tf.math.sin(gamma_r), 0.0], tf.float32),\n",
    "        tf.cast([c_x, c_y, c_z], tf.float32),\n",
    "    ]\n",
    "    return tf.cast(rows, tf.float32)\n",
    "\n",
    "def get_all_distances_general(lattice, fcoords1, fcoords2):\n",
    "    \"\"\"Shortest distance between every pair of points over the 27 neighbouring images.\n",
    "\n",
    "    Args:\n",
    "        lattice: (3, 3) lattice matrix applied to fractional differences.\n",
    "        fcoords1: (n1, 3) fractional coordinates.\n",
    "        fcoords2: (n2, 3) fractional coordinates.\n",
    "\n",
    "    Returns:\n",
    "        float32 tensor of shape (n1, n2) with the minimum distance per pair.\n",
    "    \"\"\"\n",
    "    lattice = tf.cast(lattice, tf.float32)\n",
    "    frac_1 = tf.cast(fcoords1, tf.float32)\n",
    "    frac_2 = tf.cast(fcoords2, tf.float32)\n",
    "\n",
    "    # every translation by -1, 0 or +1 cell along each axis\n",
    "    steps = (-1, 0, 1)\n",
    "    shifts = tf.cast([[i, j, k] for i in steps for j in steps for k in steps], tf.float32)\n",
    "\n",
    "    pair_diffs = frac_1[:, None, :] - frac_2[None, :, :]             # (n1, n2, 3)\n",
    "    image_diffs = pair_diffs[:, :, None, :] + shifts[None, None, :, :]  # (n1, n2, 27, 3)\n",
    "\n",
    "    cart_diffs = tf.einsum('...j, jk -> ...k', image_diffs, lattice)\n",
    "    return tf.reduce_min(tf.linalg.norm(cart_diffs, axis=-1), axis=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# determine the presence of an edge (using the repeated images as discussed in the paper)\n",
    "\n",
    "EDGES = 1064   # number of edge slots stored per structure\n",
    "CUTOFF = 4.    # two atoms share an edge when their minimum-image distance is <= CUTOFF\n",
    "\n",
    "# unused slots keep the pair (EDGES - 1, EDGES - 1) and a distance of 0\n",
    "tensor = (EDGES - 1) * np.ones([a.shape[0], EDGES, 2], dtype=np.int32)\n",
    "dist_tensor = np.zeros([a.shape[0], EDGES, 1], dtype=np.float32)\n",
    "\n",
    "for mct, (a_i, frac_i, lat_i) in enumerate(zip(a, fc, lat)):\n",
    "\n",
    "    # padding uses atomic number 0, so the non-zero entries are the real atoms\n",
    "    num_atoms = int(np.count_nonzero(a_i))\n",
    "    if num_atoms < 2:\n",
    "        continue\n",
    "\n",
    "    # a single call yields the whole (num_atoms, num_atoms) distance matrix\n",
    "    lattice = lattice_matrix(*lat_i)\n",
    "    dist = get_all_distances_general(lattice, frac_i[:num_atoms], frac_i[:num_atoms]).numpy()\n",
    "\n",
    "    # each unordered pair once (io < jo), in the same row-major order as a nested loop\n",
    "    io, jo = np.triu_indices(num_atoms, k=1)\n",
    "    bonded = dist[io, jo] <= CUTOFF\n",
    "    io, jo = io[bonded], jo[bonded]\n",
    "\n",
    "    ct = io.shape[0]\n",
    "    if ct > EDGES:\n",
    "        raise ValueError(f'structure {mct} has {ct} edges but only EDGES={EDGES} slots are available')\n",
    "\n",
    "    # in-place writes: no full-tensor copy per edge\n",
    "    tensor[mct, :ct, 0] = io\n",
    "    tensor[mct, :ct, 1] = jo\n",
    "    dist_tensor[mct, :ct, 0] = dist[io, jo]\n",
    "\n",
    "# hand TensorFlow tensors to the following cells\n",
    "tensor = tf.convert_to_tensor(tensor)\n",
    "dist_tensor = tf.convert_to_tensor(dist_tensor)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "BATCH_SIZE = 128\n",
    "D_MODEL = 128\n",
    "DROPOUT_RATE = 0.1\n",
    "SHUFFLE_SEED = 0\n",
    "\n",
    "dataset = tf.data.Dataset.from_tensor_slices(((a, fc, dist_tensor, tensor), y))\n",
    "\n",
    "# Shuffle once, in a fixed order. With the default reshuffle_each_iteration=True each\n",
    "# iteration of a take()/skip() pipeline reshuffles the source, so the train, validation\n",
    "# and test splits would be drawn from different orderings and overlap.\n",
    "dataset = dataset.shuffle(dataset.cardinality(), seed=SHUFFLE_SEED, reshuffle_each_iteration=False)\n",
    "\n",
    "# 80% train; the remaining 20% is split 95% validation / 5% test\n",
    "num_train = dataset.cardinality()*8//10\n",
    "training_set = dataset.take(num_train)\n",
    "ds = dataset.skip(num_train)\n",
    "num_valid = ds.cardinality()*95//100\n",
    "valid_set = ds.take(num_valid)\n",
    "test_set = ds.skip(num_valid)\n",
    "\n",
    "training_set = training_set.batch(BATCH_SIZE, drop_remainder=True).cache().prefetch(tf.data.experimental.AUTOTUNE)\n",
    "valid_set = valid_set.batch(BATCH_SIZE, drop_remainder=True).cache().prefetch(tf.data.experimental.AUTOTUNE)\n",
    "test_set = test_set.batch(1, drop_remainder=True).cache().prefetch(tf.data.experimental.AUTOTUNE)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Model Components"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "class LSTM(tf.keras.layers.Layer):\n",
    "    \"\"\"LSTM-style gated update of an ``[h, c]`` state pair.\n",
    "\n",
    "    Every gate is computed from ``h`` only, and the update is repeated ``steps``\n",
    "    times, where ``steps`` is the size of axis 1 of ``h``.\n",
    "    \"\"\"\n",
    "\n",
    "    # gate prefixes: forget, input, output, candidate\n",
    "    _GATES = ('f', 'i', 'wo', 'wc')\n",
    "\n",
    "    def __init__(self, dropout_rate=0.1, **kwargs):\n",
    "        super().__init__(**kwargs)\n",
    "        self.dropout_rate = dropout_rate\n",
    "\n",
    "    def build(self, input_shape):\n",
    "        \"\"\"Create one square kernel and one row bias per gate from the shape of ``h``.\"\"\"\n",
    "        h_shape = input_shape[0]\n",
    "        self.d_model = h_shape[-1]\n",
    "        self.steps = h_shape[1]\n",
    "\n",
    "        # all kernels first, then all biases: this fixes the order of layer.weights\n",
    "        for gate in self._GATES:\n",
    "            setattr(self, f'{gate}_kernel', self.add_weight(\n",
    "                shape=(self.d_model, self.d_model),\n",
    "                initializer='glorot_uniform',\n",
    "                name=f'{gate}_kernel',\n",
    "                trainable=True\n",
    "            ))\n",
    "\n",
    "        for gate in self._GATES:\n",
    "            setattr(self, f'{gate}_bias', self.add_weight(\n",
    "                shape=(1, self.d_model),\n",
    "                initializer='glorot_uniform',\n",
    "                name=f'{gate}_bias',\n",
    "                trainable=True\n",
    "            ))\n",
    "\n",
    "    def _affine(self, h, gate):\n",
    "        \"\"\"Return ``h @ kernel + bias`` for the given gate prefix.\"\"\"\n",
    "        return tf.matmul(h, getattr(self, f'{gate}_kernel')) + getattr(self, f'{gate}_bias')\n",
    "\n",
    "    def call(self, states):\n",
    "        \"\"\"Apply ``steps`` gated updates to ``[h, c]`` and return the new ``[h, c]``.\"\"\"\n",
    "        h, c = states\n",
    "\n",
    "        for _ in range(self.steps):\n",
    "            forget = tf.nn.sigmoid(self._affine(h, 'f'))\n",
    "            write = tf.nn.sigmoid(self._affine(h, 'i'))\n",
    "            candidate = tf.nn.tanh(self._affine(h, 'wc'))\n",
    "            output = tf.sigmoid(self._affine(h, 'wo'))\n",
    "\n",
    "            c = forget * c + write * candidate\n",
    "            h = output * tf.tanh(c)\n",
    "\n",
    "        return [h, c]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def edge_attention(bonds_in, pairs_in, s, e, s_index, e_index, attention_kernel_1, attention_kernel_2, s_kernel, e_kernel, bonds_kernel):\n",
    "    \"\"\"Compute normalised attention weights for both directions of every edge.\n",
    "\n",
    "    Args:\n",
    "        bonds_in: edge features; edges whose features sum to zero are treated as padding.\n",
    "        pairs_in: atom index pair of every edge (last axis: start, end).\n",
    "        s, e: features of the start and end atom of every edge.\n",
    "        s_index, e_index: scatter indices used to sum edge scores per atom.\n",
    "        attention_kernel_1, attention_kernel_2: callables scoring the s->e and e->s directions.\n",
    "        s_kernel, e_kernel, bonds_kernel: callables projecting ``s``, ``e`` and ``bonds_in``.\n",
    "\n",
    "    Returns:\n",
    "        ``(aij_se, aij_es)``: each edge score divided by the summed scores gathered for its\n",
    "        start atom (``aij_se``) or end atom (``aij_es``); padded edges get 0.\n",
    "    \"\"\"\n",
    "    # 1.0 on padded edges, 0.0 on real ones (taken before the bond projection)\n",
    "    pad = tf.cast(tf.math.equal(tf.reduce_sum(bonds_in, axis=-1), 0), tf.float32)[..., None]\n",
    "\n",
    "    start = s_kernel(s)\n",
    "    end = e_kernel(e)\n",
    "    bonds = bonds_kernel(bonds_in)\n",
    "\n",
    "    def unnormalised(kernel, first, second):\n",
    "        # clipped leaky-ReLU logits, exponentiated; the -1e9 term zeroes padded edges\n",
    "        logits = tf.nn.leaky_relu(kernel(tf.concat([first, second, bonds], axis=-1)))\n",
    "        return tf.exp(tf.clip_by_value(logits, -2, 2) + pad * -1e9)\n",
    "\n",
    "    eij_se = unnormalised(attention_kernel_1, start, end)\n",
    "    eij_es = unnormalised(attention_kernel_2, end, start)\n",
    "\n",
    "    # per-atom sum of the scores of every edge touching that atom\n",
    "    totals = tf.tensor_scatter_nd_add(tf.zeros_like(eij_se), s_index, eij_se)\n",
    "    totals = tf.tensor_scatter_nd_add(totals, e_index, eij_es)\n",
    "\n",
    "    # spread the sums back onto the edges; adding pad avoids 0/0 on padded edges\n",
    "    denom_se = tf.gather(totals, pairs_in[:, :, 0], batch_dims=1) + pad\n",
    "    denom_es = tf.gather(totals, pairs_in[:, :, 1], batch_dims=1) + pad\n",
    "\n",
    "    return eij_se / denom_se, eij_es / denom_es\n",
    "\n",
    "\n",
    "class GNNAttention(tf.keras.layers.Layer):\n",
    "    \"\"\"Attention-weighted neighbour aggregation over the edges of a graph.\n",
    "\n",
    "    Called on ``[atom_features, bond_features, pair_indices]`` and returns updated\n",
    "    atom features with ``d_model`` channels.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, d_model=32, attention=True, dropout_rate=0.1, **kwargs):\n",
    "        super().__init__(**kwargs)\n",
    "\n",
    "        self.d_model = d_model\n",
    "        self.dropout_rate = dropout_rate\n",
    "        self.attention = attention\n",
    "        self.num_proj = 4\n",
    "\n",
    "    def build(self, input_shape):\n",
    "        \"\"\"Create the projection sublayers and weights from the atom and bond shapes.\"\"\"\n",
    "        atom_shape, bond_shape = input_shape[0], input_shape[1]\n",
    "        self.atom_dim = atom_shape[-1]\n",
    "        self.bond_dim = bond_shape[-1]\n",
    "        self.num_atoms = atom_shape[-2]\n",
    "\n",
    "        # bias-free projections handed to edge_attention (creation order kept)\n",
    "        for attr in ('att_1', 'att_2', 'att_3', 'att_4', 'att'):\n",
    "            setattr(self, attr, tf.keras.layers.Dense(units=self.d_model, use_bias=False))\n",
    "\n",
    "        # kernels first, then biases; this order fixes the order of layer.weights\n",
    "        square = (self.d_model, self.d_model)\n",
    "        row = (1, self.d_model)\n",
    "        weight_specs = [\n",
    "            ('input_atom_kernel', (self.atom_dim, self.d_model)),\n",
    "            ('input_bond_kernel', (2 * self.d_model, self.d_model)),\n",
    "            ('trans_bond_kernel', square),\n",
    "            ('output_atom_kernel', square),\n",
    "            ('input_atom_bias', row),\n",
    "            ('input_bond_bias', row),\n",
    "            ('trans_bond_bias', row),\n",
    "            ('output_atom_bias', row),\n",
    "            ('attention_bias', row),\n",
    "        ]\n",
    "        for weight_name, weight_shape in weight_specs:\n",
    "            setattr(self, weight_name, self.add_weight(\n",
    "                shape=weight_shape,\n",
    "                initializer='glorot_uniform',\n",
    "                name=weight_name,\n",
    "                trainable=True\n",
    "            ))\n",
    "\n",
    "        self.dropout_1 = tf.keras.layers.Dropout(self.dropout_rate)\n",
    "\n",
    "    def get_mask(self, input_var):\n",
    "        \"\"\"Return 1 where the feature sum is non-zero and 0 elsewhere, with a trailing unit axis.\"\"\"\n",
    "        totals = tf.reduce_sum(input_var, axis=-1)\n",
    "        if len(input_var.shape) == 4:\n",
    "            totals = tf.reduce_sum(totals, axis=-1)[..., None]\n",
    "        return tf.cast(tf.not_equal(totals, 0), input_var.dtype)[..., None]\n",
    "\n",
    "    def call(self, inputs, training=True, self_att=False):\n",
    "        \"\"\"Aggregate attention-weighted neighbour features and mix them into each atom.\"\"\"\n",
    "        atom_features, bond_features, pair_indices = inputs\n",
    "        amask = self.get_mask(atom_features)\n",
    "\n",
    "        # (batch, atom) index pairs used to scatter edge values onto atoms\n",
    "        batch_ids = (tf.cumsum(tf.ones_like(pair_indices[..., 0])) - 1)[..., None]\n",
    "        s_atom_indices = tf.concat([batch_ids, pair_indices[..., 0, None]], axis=-1)\n",
    "        e_atom_indices = tf.concat([batch_ids, pair_indices[..., 1, None]], axis=-1)\n",
    "\n",
    "        # features of the two end atoms of every edge\n",
    "        endpoints = tf.gather(atom_features, pair_indices, batch_dims=1)\n",
    "        s, e = endpoints[:, :, 0, :], endpoints[:, :, 1, :]\n",
    "\n",
    "        # attention weights\n",
    "        aij_se, aij_es = edge_attention(bond_features, pair_indices, s, e, s_atom_indices, e_atom_indices,\n",
    "                                        self.att, self.att_1, self.att_2, self.att_3, self.att_4)\n",
    "\n",
    "        # each atom receives the weighted features of the atom at the other end of its edges\n",
    "        neighbors = tf.tensor_scatter_nd_add(tf.zeros_like(atom_features), s_atom_indices, aij_se * e)\n",
    "        neighbors = tf.tensor_scatter_nd_add(neighbors, e_atom_indices, aij_es * s)\n",
    "        neighbors = tf.matmul(neighbors, self.output_atom_kernel) + self.output_atom_bias * amask\n",
    "\n",
    "        # weigh atoms, then mix them with the aggregated neighbours; biases only on real atoms\n",
    "        projected = tf.matmul(atom_features, self.input_atom_kernel) + self.input_atom_bias * amask\n",
    "        combined = tf.concat([projected, neighbors], axis=-1)\n",
    "        return tf.matmul(combined, self.input_bond_kernel) + self.input_bond_bias * amask"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class MHA(tf.keras.layers.Layer):\n",
    "    \"\"\"Multi-head scaled dot-product attention over the atom axis.\n",
    "\n",
    "    Positions of ``q`` whose features sum to zero are treated as padding and are\n",
    "    masked out as keys.\n",
    "\n",
    "    Args:\n",
    "        d_model: width of the projections and of the output.\n",
    "        num_heads: number of heads; must divide ``d_model``.\n",
    "        concat: stored on the layer; not read by this class.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, d_model=8, num_heads=4, concat=True):\n",
    "        super().__init__()\n",
    "        self.num_heads = num_heads\n",
    "        self.concat = concat\n",
    "        self.d_model = d_model\n",
    "\n",
    "        assert d_model % num_heads == 0\n",
    "        self.depth = d_model // num_heads\n",
    "\n",
    "    def build(self, input_shape):\n",
    "        \"\"\"Create the query/key/value/output projections.\"\"\"\n",
    "        self.num_atoms = input_shape[-2]\n",
    "\n",
    "        self.wq = tf.keras.layers.Dense(units=self.d_model, use_bias=False)\n",
    "        self.wk = tf.keras.layers.Dense(units=self.d_model, use_bias=False)\n",
    "        self.wv = tf.keras.layers.Dense(units=self.d_model, use_bias=False)\n",
    "\n",
    "        self.output_layer = tf.keras.layers.Dense(units=self.d_model, use_bias=False)\n",
    "\n",
    "    def scaled_dot_product_attention(self, q, k, v, mask):\n",
    "        \"\"\"Return ``softmax(q k^T / sqrt(depth) + mask * -1e9) v``; ``mask`` is 1 on keys to ignore.\"\"\"\n",
    "        product = tf.matmul(q, k, transpose_b=True)\n",
    "        keys_dim = tf.cast(tf.shape(k)[-1], tf.float32)\n",
    "\n",
    "        eij = product / tf.math.sqrt(keys_dim)\n",
    "\n",
    "        if mask is not None:\n",
    "            eij += (mask * -1e9)\n",
    "\n",
    "        aij = tf.nn.softmax(eij, axis=-1)\n",
    "        z = tf.matmul(aij, v)\n",
    "\n",
    "        return z\n",
    "\n",
    "    def get_mask(self, input_var):\n",
    "        \"\"\"Return 1 where the feature sum is non-zero and 0 elsewhere, with a trailing unit axis.\"\"\"\n",
    "        mask = tf.reduce_sum(input_var, axis=-1)\n",
    "        if len(input_var.shape) == 4:\n",
    "            mask = tf.reduce_sum(mask, axis=-1)[..., None]\n",
    "        mask = tf.cast(tf.not_equal(mask, 0), input_var.dtype)[..., None]\n",
    "        return mask\n",
    "\n",
    "    def split_heads(self, x):\n",
    "        \"\"\"Reshape (batch, atoms, d_model) into (batch, heads, atoms, depth).\"\"\"\n",
    "        x = tf.reshape(x, shape=[-1, self.num_atoms, self.num_heads, self.depth])\n",
    "        x = tf.transpose(x, perm=[0, 2, 1, 3])\n",
    "        return x\n",
    "\n",
    "    def call(self, q, k, v, training=True):\n",
    "        \"\"\"Attend from ``q`` to ``k``/``v`` and return a (batch, atoms, d_model) tensor.\"\"\"\n",
    "        # 1 on padded positions, broadcast over heads and query positions\n",
    "        mask_1 = self.get_mask(q)\n",
    "        mask_2 = tf.cast(tf.equal(mask_1, 0.), mask_1.dtype)[:, None, None, :, 0]\n",
    "\n",
    "        # Each input gets its own projection. k and v used to go through wq, which left\n",
    "        # wk and wv unused; weights saved before this fix do not contain wk/wv.\n",
    "        q = self.wq(q)\n",
    "        k = self.wk(k)\n",
    "        v = self.wv(v)\n",
    "\n",
    "        q = self.split_heads(q)\n",
    "        k = self.split_heads(k)\n",
    "        v = self.split_heads(v)\n",
    "\n",
    "        scaled_attention = self.scaled_dot_product_attention(q, k, v, mask_2)\n",
    "        scaled_attention = tf.einsum('bhad->bahd', scaled_attention)\n",
    "\n",
    "        # merge the heads back into d_model channels\n",
    "        concat_attention = tf.reshape(scaled_attention, shape=[-1, self.num_atoms, self.d_model])\n",
    "        concat_attention = self.output_layer(concat_attention)\n",
    "\n",
    "        return concat_attention\n",
    "\n",
    "\n",
    "class EdgeNetwork(tf.keras.layers.Layer):\n",
    "    def __init__(self, steps=2, d_model=4, num_heads=2, num_proj=2, concat=True, attention=True, dropout_rate=0.1):\n",
    "        super(EdgeNetwork, self).__init__()\n",
    "\n",
    "        assert d_model % num_heads == 0\n",
    "        self.d_model = d_model\n",
    "        self.num_heads = num_heads\n",
    "        self.concat = concat\n",
    "        self.attention = attention\n",
    "        self.dropout_rate = dropout_rate\n",
    "        self.steps = steps\n",
    "        self.num_proj = num_proj\n",
    "\n",
    "    def build(self, input_shape):\n",
    "        self.atom_dim = input_shape[0][-1]\n",
    "        self.bond_dim = input_shape[1][-1]\n",
    "        self.num_atoms = input_shape[0][-2]\n",
    "\n",
    "        self.MHA = [MHA(d_model=self.d_model, num_heads=self.num_heads, concat=self.concat) \n",
    "                    for _ in range(self.steps)]\n",
    "        self.MHA_b = [MHA(d_model=self.d_model, num_heads=self.num_heads, concat=self.concat) \n",
    "                    for _ in range(self.steps)]\n",
    "        \n",
    "        self.norm_1a = tf.keras.layers.LayerNormalization(epsilon=1e-6)\n",
    "        self.norm_1b = tf.keras.layers.LayerNormalization(epsilon=1e-6)\n",
    "\n",
    "        self.layers = [[GNNAttention(d_model=self.d_model, attention=self.attention, dropout_rate=self.dropout_rate) \n",
    "                        for _ in range(self.num_proj)] for _ in range(self.steps)]\n",
    "        \n",
    "        self.dense_1a = [tf.keras.layers.Dense(units=self.d_model, use_bias=False) for _ in range(self.steps)]\n",
    "        self.dense_11a = [tf.keras.layers.Dense(units=self.d_model, use_bias=False) for _ in range(self.steps)]\n",
    "        self.dropout_1 = [tf.keras.layers.Dropout(self.dropout_rate) for _ in range(self.steps)]\n",
    "\n",
    "        self.prelu = [tf.keras.layers.PReLU() for _ in range(self.steps)]\n",
    "        self.dense_bonds = [tf.keras.layers.Dense(units=self.d_model, use_bias=False) for _ in range(self.steps)]\n",
    "\n",
    "        self.input_atom_kernel = self.add_weight(\n",
    "            shape=(self.atom_dim, self.d_model),\n",
    "            initializer=\"glorot_uniform\",\n",
    "            name=\"input_atom_kernel\",\n",
    "            trainable=True\n",
    "        )\n",
    "\n",
    "        self.input_bond_kernel = self.add_weight(\n",
    "            shape=(self.bond_dim, self.d_model),\n",
    "            initializer=\"glorot_uniform\",\n",
    "            name=\"input_bond_kernel\",\n",
    "            trainable=True\n",
    "        )\n",
    "\n",
    "        self.input_atom_bias = self.add_weight(\n",
    "            shape=(1, self.d_model),\n",
    "            initializer=\"glorot_uniform\",\n",
    "            name=\"input_atom_bias\",\n",
    "            trainable=True\n",
    "        )\n",
    "\n",
    "        self.input_bond_bias = self.add_weight(\n",
    "            shape=(1, self.d_model),\n",
    "            initializer=\"glorot_uniform\",\n",
    "            name=\"input_bond_bias\",\n",
    "            trainable=True\n",
    "        )\n",
    "\n",
    "        self.built = True\n",
    "\n",
    "    def get_mask(self, input_var):\n",
    "        mask = tf.reduce_sum(input_var, axis=-1)\n",
    "        if len(input_var.shape) == 4:\n",
    "            mask = tf.reduce_sum(mask, axis=-1)[..., None]\n",
    "        mask = tf.cast(tf.not_equal(mask, 0), input_var.dtype)[..., None]\n",
    "        return mask\n",
    "\n",
    "    def scaled_dot_product_attention(self, q, k, v, mask):\n",
    "\n",
    "        product = tf.matmul(q, k, transpose_b=True)\n",
    "        keys_dim = tf.cast(tf.shape(k)[-1], tf.float32)\n",
    "\n",
    "        eij = product / tf.math.sqrt(keys_dim)\n",
    "\n",
    "        if mask is not None:\n",
    "            eij += (mask * -1e9)\n",
    "\n",
    "        aij = tf.nn.softmax(eij, axis=-1)\n",
    "        z = tf.matmul(aij, v)\n",
    "\n",
    "        return z\n",
    "    \n",
    "    def masked_normalization(self, x):\n",
    "        \"\"\"L2-normalise `x` along its last axis, leaving all-zero rows untouched.\n",
    "\n",
    "        Rows whose squared norm is exactly zero are divided by 1 instead of 0,\n",
    "        so they stay zero rather than turning into NaN.\n",
    "        \"\"\"\n",
    "        squared_norm = tf.reduce_sum(x**2, axis=-1)\n",
    "        zero = tf.cast(0, squared_norm.dtype)\n",
    "        one = tf.cast(1, squared_norm.dtype)\n",
    "\n",
    "        # substitute 1 for zero norms so the division below is a no-op there\n",
    "        safe_squared_norm = tf.where(tf.equal(squared_norm, zero), one, squared_norm)\n",
    "        norm = tf.math.sqrt(safe_squared_norm)[..., None]\n",
    "\n",
    "        return x / norm\n",
    "\n",
    "    def call(self, inputs, training=True):\n",
    "        \"\"\"Run `self.steps` rounds of message passing and return the atom states.\n",
    "\n",
    "        Args:\n",
    "            inputs: sequence `(atom_features, bond_features, pair_indices)`.\n",
    "            training: forwarded to the message-passing and dropout sub-layers.\n",
    "\n",
    "        Returns:\n",
    "            The atom feature tensor produced by the last round.\n",
    "        \"\"\"\n",
    "        atoms, bonds, pairs = inputs\n",
    "\n",
    "        # padding masks are derived from the raw inputs, before any projection\n",
    "        atom_mask = self.get_mask(atoms)\n",
    "        bond_mask = self.get_mask(bonds)\n",
    "\n",
    "        # weigh input nodes and bonds (kernels only, no bias), then normalise\n",
    "        atoms = self.norm_1a(tf.matmul(atoms, self.input_atom_kernel))\n",
    "        bonds = self.norm_1b(tf.matmul(bonds, self.input_bond_kernel))\n",
    "\n",
    "        # NOTE(review): every round indexes the sub-layer lists at [0], so the\n",
    "        # same layers are reused for all steps -- confirm this sharing is intended.\n",
    "        for _ in range(self.steps):\n",
    "\n",
    "            # message passing: one message tensor per projection\n",
    "            messages = []\n",
    "            for proj in range(self.num_proj):\n",
    "                message_layer = self.layers[0][proj]\n",
    "                messages.append(message_layer([atoms, bonds, pairs], training=training))\n",
    "\n",
    "            messages = tf.concat(messages, axis=-1)\n",
    "            messages = self.dense_11a[0](messages)\n",
    "            messages = self.dropout_1[0](messages, training=training)\n",
    "            messages = self.prelu[0](messages)\n",
    "\n",
    "            # update: the messages are the first argument of the attention block\n",
    "            attended = self.MHA[0](messages, atoms, atoms) * atom_mask\n",
    "\n",
    "            # skip connection followed by the feed-forward layer\n",
    "            atoms = self.dense_1a[0](attended + messages)\n",
    "\n",
    "            # refresh the bond states from the updated atoms of each pair\n",
    "            # NOTE(review): the sum runs over the last axis of the gathered\n",
    "            # tensor -- confirm that this is the axis meant to be reduced.\n",
    "            pair_atoms = tf.gather(atoms, pairs, batch_dims=1)\n",
    "            pair_atoms = tf.reduce_sum(pair_atoms, axis=-1)\n",
    "            bond_update = tf.concat([pair_atoms, bonds], axis=-1)\n",
    "            bond_update = self.dense_bonds[0](bond_update)\n",
    "            bonds = self.MHA_b[0](bond_update, bonds, bonds) * bond_mask\n",
    "\n",
    "        return atoms"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 102,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Readout(tf.keras.layers.Layer):\n",
    "    \"\"\"Iterative attention readout that pools per-atom states into one vector.\n",
    "\n",
    "    For `steps` rounds `self.lstm` turns the previous `[query, readout]` pair\n",
    "    into a new query, the query is scored against every atom state with a dot\n",
    "    product, and the softmax-weighted sum of the atom states becomes the new\n",
    "    readout. The final query and readout are concatenated.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, steps=16, dropout_rate=0.1, **kwargs):\n",
    "        super().__init__(**kwargs)\n",
    "        self.dropout_rate = dropout_rate\n",
    "        self.steps = steps\n",
    "\n",
    "    def build(self, input_shape):\n",
    "        self.hidden_dim = input_shape[-1]\n",
    "        self.num_atoms_dim = input_shape[-2]\n",
    "        self.batch_dim = input_shape[0]\n",
    "\n",
    "        # NOTE(review): `LSTM` comes from outside this cell; it is called with a\n",
    "        # two-element list and its result is indexed with [0] -- confirm its contract.\n",
    "        self.lstm = LSTM()\n",
    "\n",
    "    def call(self, m):\n",
    "        \"\"\"Pool the rank-3 tensor `m` (batch, num_atoms, d_model) over its atom axis.\"\"\"\n",
    "        rt = tf.cast(tf.zeros_like(m)[:, 0, :], tf.float32)\n",
    "        q_star = [rt, rt]  # batch, d_model\n",
    "\n",
    "        for _ in range(self.steps):\n",
    "            qt = self.lstm(q_star)[0]  # (batch, d_model)\n",
    "            eit = tf.reduce_sum(m * qt[:, None, :], axis=-1)  # batch, num_atoms\n",
    "\n",
    "            # Use the numerically stable built-in softmax: the hand-written\n",
    "            # exp(e) / sum(exp(e)) overflows to inf / inf = NaN for large scores.\n",
    "            ait = tf.nn.softmax(eit, axis=-1)  # batch, num_atoms\n",
    "\n",
    "            rt = tf.reduce_sum(ait[..., None] * m, axis=-2)  # batch, d_model\n",
    "\n",
    "            q_star = [qt, rt]\n",
    "\n",
    "        return tf.concat(q_star, axis=-1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Catformer(tf.keras.Model):\n",
    "    \"\"\"Graph model: embed atoms, run message passing, mean-pool, apply a dense head.\n",
    "\n",
    "    Training and evaluation go through the custom `train_step` / `test_step`\n",
    "    below, which use an unreduced mean-absolute-error loss.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, d_model, dropout_rate=0.1):\n",
    "        super().__init__()\n",
    "\n",
    "        self.dropout_rate = dropout_rate\n",
    "        self.d_model = d_model\n",
    "\n",
    "        self.loss_tracker = tf.keras.metrics.Mean(name=\"loss\")\n",
    "        self.auc_tracker = tf.keras.metrics.AUC(name=\"auc\")\n",
    "\n",
    "    def build(self, input_shape):\n",
    "        \"\"\"Create the sub-layers; `input_shape` itself is not inspected.\"\"\"\n",
    "        layers = tf.keras.layers\n",
    "        width = self.d_model\n",
    "        rate = self.dropout_rate\n",
    "\n",
    "        self.mp = EdgeNetwork(\n",
    "            steps=2,\n",
    "            num_heads=1,\n",
    "            num_proj=1,\n",
    "            d_model=width,\n",
    "            attention=False,\n",
    "            dropout_rate=rate,\n",
    "        )\n",
    "\n",
    "        self.mid1 = layers.Dense(units=width, activation='relu', name=\"mid1\")\n",
    "        self.dropout1 = layers.Dropout(rate)\n",
    "\n",
    "        self.mid2 = layers.Dense(units=width, activation='relu', name=\"mid2\")\n",
    "        self.dropout2 = layers.Dropout(rate)\n",
    "\n",
    "        self.embedding = layers.Embedding(95, 16)\n",
    "\n",
    "        self.a2 = layers.Dense(units=7, name=\"a_2\")\n",
    "        self.b_w = layers.Dense(units=width, use_bias=False)\n",
    "        self.mae = tf.keras.losses.MeanAbsoluteError(reduction=tf.keras.losses.Reduction.NONE)\n",
    "\n",
    "        # Readout(steps=2) pooling is currently disabled; call() uses a plain mean.\n",
    "\n",
    "    def _distance_gaussian_expansion(self, dist, dmin=0.0, dmax=4.0, steps=64):\n",
    "        \"\"\"Expand `dist` on `steps` Gaussians with centres evenly spaced in [dmin, dmax].\n",
    "\n",
    "        The Gaussian width equals the spacing between neighbouring centres.\n",
    "        `dist` has to broadcast against the vector of centres.\n",
    "        \"\"\"\n",
    "        centers = tf.linspace(dmin, dmax, steps)\n",
    "        width = (centers[1] - centers[0]) * 1.0\n",
    "        squared_offset = (dist - centers) ** 2\n",
    "        return tf.exp(-squared_offset / (2 * width**2))\n",
    "\n",
    "    def call(self, inputs, training=True):\n",
    "        \"\"\"Run the forward pass.\n",
    "\n",
    "        Args:\n",
    "            inputs: `(a, abc, tensor, dist_tensor)`. `a` goes through the\n",
    "                embedding layer, `abc` is concatenated to the embeddings,\n",
    "                `tensor` is handed to the message-passing block as its third\n",
    "                input and `dist_tensor` is Gaussian-expanded into bond\n",
    "                features. Zero entries of `a` and `dist_tensor` are masked.\n",
    "            training: forwarded to message passing and dropout.\n",
    "\n",
    "        Returns:\n",
    "            Output of the final dense layer `a_2`.\n",
    "        \"\"\"\n",
    "        atom_ids, abc, pair_indices, distances = inputs\n",
    "\n",
    "        atom_mask = tf.cast(tf.not_equal(atom_ids, 0), tf.float32)[..., None]\n",
    "        bond_mask = tf.cast(tf.not_equal(distances, 0), tf.float32)\n",
    "\n",
    "        atom_features = tf.concat([self.embedding(atom_ids), abc], axis=-1)\n",
    "        atom_features = self.b_w(atom_features) * atom_mask\n",
    "\n",
    "        bond_features = self._distance_gaussian_expansion(distances, steps=512)  # batch, num_atoms, expansion\n",
    "        bond_features = bond_features * bond_mask\n",
    "\n",
    "        hidden = self.mp([atom_features, bond_features, pair_indices], training=training)  # batch, num_atoms, d_model\n",
    "        hidden = tf.nn.relu(hidden)\n",
    "\n",
    "        # NOTE(review): the mean runs over every position on axis 1, including\n",
    "        # any padded ones -- confirm that this is intended.\n",
    "        hidden = tf.reduce_mean(hidden, axis=1)  # batch, d_model\n",
    "\n",
    "        hidden = self.dropout1(self.mid1(hidden), training=training)\n",
    "        hidden = self.dropout2(self.mid2(hidden), training=training)\n",
    "\n",
    "        return self.a2(hidden)\n",
    "\n",
    "    def test_step(self, inputs_testing):\n",
    "        \"\"\"Score one batch and return the running mean of the loss.\"\"\"\n",
    "        features, real = inputs_testing[0], inputs_testing[1]\n",
    "        a, abc, tensor, dist_tensor = features\n",
    "\n",
    "        pred = self([a, abc, tensor, dist_tensor], training=False)\n",
    "        self.loss_tracker.update_state(self._compute_loss(real, pred))\n",
    "\n",
    "        return {\"loss\": self.loss_tracker.result()}\n",
    "\n",
    "    def train_step(self, inputs_t):\n",
    "        \"\"\"Run one optimisation step and return the running mean of the loss.\"\"\"\n",
    "        features, real = inputs_t[0], inputs_t[1]\n",
    "        a, abc, tensor, dist_tensor = features\n",
    "\n",
    "        with tf.GradientTape() as tape:\n",
    "            pred = self([a, abc, tensor, dist_tensor], training=True)\n",
    "            loss = self._compute_loss(real, pred)\n",
    "\n",
    "        weights = self.trainable_weights\n",
    "        grads = tape.gradient(loss, weights)\n",
    "        self.optimizer.apply_gradients(zip(grads, weights))\n",
    "\n",
    "        self.loss_tracker.update_state(loss)\n",
    "        return {\"loss\": self.loss_tracker.result()}\n",
    "\n",
    "    def _compute_loss(self, real, pred):\n",
    "        \"\"\"Mean absolute error from `self.mae`, which is built with Reduction.NONE.\"\"\"\n",
    "        return self.mae(real, pred)\n",
    "\n",
    "# Create the metric, optimizer and model inside the strategy scope.\n",
    "with strategy.scope():\n",
    "    pred_hist = tf.keras.metrics.Mean(name='pred_hist')\n",
    "\n",
    "    learning_rate = 1e-3\n",
    "\n",
    "    optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)\n",
    "    model = Catformer(d_model=D_MODEL)\n",
    "\n",
    "    # no loss is passed here: Catformer computes it in train_step / test_step\n",
    "    model.compile(optimizer=optimizer)\n",
    "\n",
    "history = model.fit(x=training_set, validation_data=valid_set, epochs=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"catformer_2\"\n",
      "_________________________________________________________________\n",
      " Layer (type)                Output Shape              Param #   \n",
      "=================================================================\n",
      " edge_network (EdgeNetwork)  multiple                  386432    \n",
      "                                                                 \n",
      " mid1 (Dense)                multiple                  16512     \n",
      "                                                                 \n",
      " dropout (Dropout)           multiple                  0         \n",
      "                                                                 \n",
      " mid2 (Dense)                multiple                  16512     \n",
      "                                                                 \n",
      " dropout_1 (Dropout)         multiple                  0         \n",
      "                                                                 \n",
      " embedding (Embedding)       multiple                  1520      \n",
      "                                                                 \n",
      " a_2 (Dense)                 multiple                  387       \n",
      "                                                                 \n",
      " dense (Dense)               multiple                  2048      \n",
      "                                                                 \n",
      "=================================================================\n",
      "Total params: 424,213\n",
      "Trainable params: 423,411\n",
      "Non-trainable params: 802\n",
      "_________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "# Per-layer output shapes and parameter counts of the fitted model\n",
    "model.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Collect every validation batch so the whole split is scored in one pass.\n",
    "# Each element of `valid_set` is ((a, abc, tensor, dist_tensor), target).\n",
    "atst = []\n",
    "abctst = []\n",
    "abgtst = []\n",
    "disttst = []\n",
    "\n",
    "ytst = []\n",
    "for i in valid_set:\n",
    "\n",
    "    atst.append(i[0][0])\n",
    "    abctst.append(i[0][1])\n",
    "    abgtst.append(i[0][2])\n",
    "    # Catformer.call unpacks four inputs; without the distance tensor\n",
    "    # predict() fails while unpacking its inputs.\n",
    "    disttst.append(i[0][3])\n",
    "    ytst.append(i[1])\n",
    "\n",
    "atst = tf.concat(atst, axis=0)\n",
    "abctst = tf.concat(abctst, axis=0)\n",
    "abgtst = tf.concat(abgtst, axis=0)\n",
    "disttst = tf.concat(disttst, axis=0)\n",
    "ytst = tf.concat(ytst, axis=0)\n",
    "\n",
    "pred = model.predict([atst, abctst, abgtst, disttst])\n",
    "loss_object = tf.keras.losses.MeanAbsoluteError()\n",
    "# mean absolute error of output column 2 over the validation split\n",
    "tf.reduce_mean(abs(pred[..., 2] - tf.cast(ytst[..., 2], tf.float32)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Markers keep the curves visible even when only a single epoch was trained.\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(history.history['loss'], marker='o', label='training')\n",
    "ax.plot(history.history['val_loss'], marker='o', label='validation')\n",
    "ax.set(title='Training history', xlabel='epoch', ylabel='mean absolute error')\n",
    "ax.legend();"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
