{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "# ProVis: Attention Visualizer for Proteins\n",
    "\n",
    "### Notes\n",
    "* The tool visualizes attention between different amino acids in a sequence.\n",
    "* It does not currently visualize attention from an amino acid to itself.\n",
    "* Attention is visualized as unidirectional using the maximum of the two directed attention weights. \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "is_executing": false,
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "\n",
    "import io\n",
    "import urllib\n",
    "\n",
    "import torch\n",
    "from Bio.Data import SCOPData\n",
    "from Bio.PDB import PDBParser, PPBuilder\n",
    "from tape import TAPETokenizer, ProteinBertModel\n",
    "import nglview\n",
    "\n",
    "attn_color = [0.937, .522, 0.212]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "def get_structure(pdb_id):\n",
    "    resource = urllib.request.urlopen(f'https://files.rcsb.org/download/{pdb_id}.pdb')\n",
    "    content = resource.read().decode('utf8')\n",
    "    handle = io.StringIO(content)\n",
    "    parser = PDBParser(QUIET=True)\n",
    "    return parser.get_structure(pdb_id, handle)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "def get_attn_data(chain, layer, head, min_attn, start_index=0, end_index=None, max_seq_len=1024):\n",
    "\n",
    "    tokens = []\n",
    "    coords = []\n",
    "    for res in chain:\n",
    "        t = SCOPData.protein_letters_3to1.get(res.get_resname(), \"X\")\n",
    "        tokens += t\n",
    "        if t == 'X':\n",
    "            coord = None\n",
    "        else:\n",
    "            coord = res['CA'].coord.tolist()\n",
    "        coords.append(coord)      \n",
    "    last_non_x = None\n",
    "    for i in reversed(range(len(tokens))):\n",
    "        if tokens[i] != 'X':\n",
    "            last_non_x = i\n",
    "            break\n",
    "    assert last_non_x is not None\n",
    "    tokens = tokens[:last_non_x + 1]\n",
    "    coords = coords[:last_non_x + 1]    \n",
    "    \n",
    "    tokenizer = TAPETokenizer()\n",
    "    model = ProteinBertModel.from_pretrained('bert-base', output_attentions=True)\n",
    "\n",
    "    if max_seq_len:\n",
    "        tokens = tokens[:max_seq_len - 2]  # Account for SEP, CLS tokens (added in next step)\n",
    "    token_idxs = tokenizer.encode(tokens).tolist()\n",
    "    if max_seq_len:\n",
    "        assert len(token_idxs) == min(len(tokens) + 2, max_seq_len)\n",
    "    else:\n",
    "        assert len(token_idxs) == len(tokens) + 2\n",
    "\n",
    "    inputs = torch.tensor(token_idxs).unsqueeze(0)\n",
    "    with torch.no_grad():\n",
    "        attns = model(inputs)[-1]\n",
    "        # Remove attention from <CLS> (first) and <SEP> (last) token\n",
    "    attns = [attn[:, :, 1:-1, 1:-1] for attn in attns]\n",
    "    attns = torch.stack([attn.squeeze(0) for attn in attns])\n",
    "    attn = attns[layer, head]\n",
    "    if end_index is None:\n",
    "        end_index = len(tokens)\n",
    "    attn_data = []\n",
    "    for i in range(start_index, end_index):\n",
    "        for j in range(i, end_index):\n",
    "            # Currently non-directional: shows max of two attns\n",
    "            a = max(attn[i, j].item(), attn[j, i].item())\n",
    "            if a is not None and a >= min_attn:\n",
    "                attn_data.append((a, coords[i], coords[j]))\n",
    "    return attn_data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Visualize head 7-1 (targets binding sites)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "is_executing": false,
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Example for head 7-1 (targets binding sites)\n",
    "pdb_id = '7HVP'\n",
    "chain_ids = None # All chains\n",
    "layer = 7\n",
    "head = 1\n",
    "min_attn = 0.1\n",
    "attn_scale = .9\n",
    "\n",
    "layer_zero_indexed = layer - 1\n",
    "head_zero_indexed = head - 1\n",
    "\n",
    "structure = get_structure(pdb_id)\n",
    "view = nglview.show_biopython(structure)\n",
    "view.stage.set_parameters(**{\n",
    "    \"backgroundColor\": \"black\",\n",
    "    \"fogNear\": 50, \"fogFar\": 100,\n",
    "})\n",
    "\n",
    "models = list(structure.get_models())\n",
    "if len(models) > 1:\n",
    "    print('Warning:', len(models), 'models. Using first one')\n",
    "prot_model = models[0]\n",
    "\n",
    "if chain_ids is None:\n",
    "    chain_ids = [chain.id for chain in prot_model]\n",
    "for chain_id in chain_ids: \n",
    "    print('Loading chain', chain_id)\n",
    "    chain = prot_model[chain_id]    \n",
    "    attn_data = get_attn_data(chain, layer_zero_indexed, head_zero_indexed, min_attn)\n",
    "    for att, coords_from, coords_to in attn_data:\n",
    "        view.shape.add_cylinder(coords_from, coords_to, attn_color, att * attn_scale) \n",
    "        \n",
    "view"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Visualize head 12-4 (targets contact maps)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Example for head 12-4 (targets contact maps)\n",
    "pdb_id = '2KC7'\n",
    "chain_ids = None # All chains\n",
    "layer = 12\n",
    "head = 4\n",
    "min_attn = 0.2\n",
    "attn_scale = .5\n",
    "\n",
    "layer_zero_indexed = layer - 1\n",
    "head_zero_indexed = head - 1\n",
    "\n",
    "structure = get_structure(pdb_id)\n",
    "view2 = nglview.show_biopython(structure)\n",
    "view2.stage.set_parameters(**{\n",
    "    \"backgroundColor\": \"black\",\n",
    "    \"fogNear\": 50, \"fogFar\": 100,\n",
    "})\n",
    "\n",
    "models = list(structure.get_models())\n",
    "if len(models) > 1:\n",
    "    print('Warning:', len(models), 'models. Using first one')\n",
    "prot_model = models[0]\n",
    "\n",
    "if chain_ids is None:\n",
    "    chain_ids = [chain.id for chain in prot_model]\n",
    "for chain_id in chain_ids: \n",
    "    print('Loading chain', chain_id)\n",
    "    chain = prot_model[chain_id]    \n",
    "    attn_data = get_attn_data(chain, layer_zero_indexed, head_zero_indexed, min_attn)\n",
    "    for att, coords_from, coords_to in attn_data:\n",
    "        view2.shape.add_cylinder(coords_from, coords_to, attn_color, att * attn_scale) \n",
    "        \n",
    "view2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  },
  "pycharm": {
   "stem_cell": {
    "cell_type": "raw",
    "source": [],
    "metadata": {
     "collapsed": false
    }
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}