{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "zb9d5N92iMz2"
   },
   "source": [
    "# Chroma Tutorial\n",
    "\n",
    "First, run the [setup cell](#setup) below. Then, run [this cell](#unconditional-chain) to get a Chroma sample. Further examples are below."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "KaT878cbeOQm"
   },
   "outputs": [],
   "source": [
    "# @title Setup\n",
    "\n",
    "# @markdown [Get your API key here](https://chroma-weights.generatebiomedicines.com) and enter it below before running.\n",
    "\n",
    "from google.colab import output\n",
    "\n",
    "output.enable_custom_widget_manager()\n",
    "\n",
    "import os\n",
    "\n",
    "os.environ[\"CUBLAS_WORKSPACE_CONFIG\"] = \":4096:8\"\n",
    "import contextlib\n",
    "\n",
    "api_key = \"\"  # @param {type:\"string\"}\n",
    "\n",
    "!pip install git+https://github.com/generatebio/chroma.git > /dev/null 2>&1\n",
    "\n",
    "import torch\n",
    "\n",
    "torch.use_deterministic_algorithms(True, warn_only=True)\n",
    "\n",
    "import warnings\n",
    "from tqdm import tqdm, TqdmExperimentalWarning\n",
    "\n",
    "warnings.filterwarnings(\"ignore\", category=TqdmExperimentalWarning)\n",
    "from functools import partialmethod\n",
    "\n",
    "tqdm.__init__ = partialmethod(tqdm.__init__, leave=False)\n",
    "\n",
    "from google.colab import files\n",
    "import ipywidgets as widgets\n",
    "\n",
    "\n",
    "def create_button(filename, description=\"\"):\n",
    "    button = widgets.Button(description=description)\n",
    "    display(button)\n",
    "\n",
    "    def on_button_click(b):\n",
    "        files.download(filename)\n",
    "\n",
    "    button.on_click(on_button_click)\n",
    "\n",
    "\n",
    "def render(protein, trajectories=None, output=\"protein.cif\"):\n",
    "    display(protein)\n",
    "    print(protein)\n",
    "    protein.to_CIF(output)\n",
    "    create_button(output, description=\"Download sample\")\n",
    "    if trajectories is not None:\n",
    "        traj_output = output.replace(\".cif\", \"_trajectory.cif\")\n",
    "        trajectories[\"trajectory\"].to_CIF(traj_output)\n",
    "        create_button(traj_output, description=\"Download trajectory\")\n",
    "\n",
    "\n",
    "import locale\n",
    "\n",
    "locale.getpreferredencoding = lambda: \"UTF-8\"\n",
    "\n",
    "from chroma import Chroma, Protein, conditioners\n",
    "from chroma.models import graph_classifier, procap\n",
    "from chroma.utility.api import register_key\n",
    "from chroma.utility.chroma import letter_to_point_cloud, plane_split_protein\n",
    "\n",
    "register_key(api_key)\n",
    "\n",
    "device = \"cuda\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "qYuYX-GVsKii"
   },
   "source": [
    "## Sampling basics\n",
    "\n",
    "Use `Chroma.sample` to get a protein from Chroma. By default, a backbone is generated through reverse diffusion from random noise, and then the sequence and associated side chain atoms are designed on this backbone."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "oxe78ZXb4HkR"
   },
   "outputs": [],
   "source": [
    "chroma = Chroma()\n",
    "\n",
    "chain_lengths = [160]  # can have multiple chains in a single complex\n",
    "\n",
    "protein = chroma.sample(chain_lengths=chain_lengths, steps=200)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "WOU-1ddU4pfH"
   },
   "source": [
    "Print the protein sequence or display the full structure. There's a `render` function in the setup cell that lets you do both and gives a download button, using `Protein.to_CIF`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "biCeh1pRtnuk"
   },
   "outputs": [],
   "source": [
    "print(protein)\n",
    "display(protein)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "wt6lPj4GxW9x"
   },
   "outputs": [],
   "source": [
    "render(protein)"
   ]
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Sampling options\n",
    "\n",
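    "There are several ways to control the backbone generation and sequence design processes. For instance, the `inverse_temperature` argument to `Chroma.sample` controls the temperature of the backbone sampling. A lower inverse temperature (that is, a higher temperature) gives more diverse but riskier samples. The two cells below use the same random seed with `inverse_temperature=1` (high temperature) and `inverse_temperature=10` (low temperature)."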
   ],
   "metadata": {
    "id": "L9UpxbVESzOa"
   }
  },
  {
   "cell_type": "code",
   "source": [
    "torch.manual_seed(42)\n",
    "hight = chroma.sample(chain_lengths=[100], steps=200, inverse_temperature=1)"
   ],
   "metadata": {
    "id": "naEUf-NqTWB6"
   },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "torch.manual_seed(42)\n",
    "lowt = chroma.sample(chain_lengths=[100], steps=200, inverse_temperature=10)"
   ],
   "metadata": {
    "id": "FzQck6RCU9nB"
   },
   "execution_count": null,
   "outputs": []
  },
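  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The general effect of an inverse temperature can be seen in a toy setting. The sketch below is not Chroma's sampler: it assumes a Boltzmann distribution over four hypothetical energy levels and shows that raising the inverse temperature concentrates probability on the lowest-energy state and lowers the entropy (diversity). The `energies` values and the helper names `boltzmann` and `entropy` are made up for illustration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "# Toy illustration only: this is NOT how Chroma implements low-temperature\n",
    "# sampling. It shows the effect of an inverse temperature (beta) on a\n",
    "# Boltzmann distribution p(x) proportional to exp(-beta * E(x)).\n",
    "energies = [0.0, 0.5, 1.0, 2.0]  # hypothetical energies of four states\n",
    "\n",
    "\n",
    "def boltzmann(energies, beta):\n",
    "    \"\"\"Return normalized probabilities exp(-beta * E) / Z.\"\"\"\n",
    "    weights = [math.exp(-beta * e) for e in energies]\n",
    "    total = sum(weights)\n",
    "    return [w / total for w in weights]\n",
    "\n",
    "\n",
    "def entropy(probs):\n",
    "    \"\"\"Shannon entropy in nats; higher means a more diverse distribution.\"\"\"\n",
    "    return -sum(p * math.log(p) for p in probs if p > 0)\n",
    "\n",
    "\n",
    "# Same two inverse temperatures as the sampling cells above.\n",
    "for beta in (1, 10):\n",
    "    probs = boltzmann(energies, beta)\n",
    "    print(f\"beta={beta}: p(lowest energy)={probs[0]:.3f}, entropy={entropy(probs):.3f}\")"
   ]
  },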
  {
   "cell_type": "markdown",
   "source": [
    "## Scoring proteins\n",
    "\n",
    "We can score the generated proteins with `Chroma.score`. Generally, lower temperature sampling gives better quality at the expense of diversity."
   ],
   "metadata": {
    "id": "lgQJJ3X_VSyN"
   }
  },
  {
   "cell_type": "code",
   "source": [
    "lowt_scores = chroma.score(lowt)\n",
    "hight_scores = chroma.score(hight)\n",
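    "# Label each ELBO score so the two printed values are not confused.\n",
    "print(\"low temperature ELBO:\", lowt_scores[\"elbo\"].score)\n",
    "print(\"high temperature ELBO:\", hight_scores[\"elbo\"].score)"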
   ],
   "metadata": {
    "id": "nBKRb4g9V4G0"
   },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "CkzL_dBPyP1C"
   },
   "source": [
    "## Getting diffusion trajectories\n",
    "\n",
    "Let's make a complex with two chains. This time, we'll set `full_output=True` so that `Chroma.sample` also returns the diffusion trajectories."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "FlLMjj7nygFB"
   },
   "outputs": [],
   "source": [
    "protein, trajectories = chroma.sample(\n",
    "    chain_lengths=[140, 140], steps=200, full_output=True\n",
    ")\n",
    "render(protein, trajectories)"
   ]
  },
  {
   "cell_type": "markdown",
   "source": [
    "At each step in the reverse diffusion process, the model produces a best guess of what the sample should look like when fully denoised. These predictions are stored under the `Xhat_trajectory` key of the trajectory output. We can write them to a CIF file and compare them with the sample trajectory to see how the generated sample evolves towards the denoised prediction."
   ],
   "metadata": {
    "id": "tEdsbp9gPIv3"
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "-2ZiWLqrzlZ5"
   },
   "outputs": [],
   "source": [
    "print(list(trajectories.keys()))\n",
    "\n",
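    "trajectories[\"Xhat_trajectory\"].to_CIF(\"xhat_trajectory.cif\")\n",
    "# Offer the file for download, using the helper defined in the setup cell.\n",
    "create_button(\"xhat_trajectory.cif\", description=\"Download Xhat trajectory\")"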
   ]
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Conditional generation\n",
    "\n",
    "Usually, we want to generate a protein that satisfies particular conditions. Chroma's conditioner framework enables this. Here, we show an example where we redesign the backbone of a protein with some residues fixed; the condition is that the coordinates of the fixed residues can't change through the diffusion process.\n",
    "\n",
    "We also show the `design_selection` option, which restricts sequence design to the selected residues and keeps the rest of the sequence fixed. There's even more you can do with sequence design, including specifying which residues are allowed at each position."
   ],
   "metadata": {
    "id": "u7fSD4gaWomS"
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "09F3VTAh0RqO"
   },
   "outputs": [],
   "source": [
    "protein = Protein(\"1XYZ\", device=device)\n",
    "# `selection` names the residues held fixed: everything except chain A, residues 30-60.\n",
    "substructure_conditioner = conditioners.SubstructureConditioner(\n",
    "    protein=protein,\n",
    "    backbone_model=chroma.backbone_network,\n",
    "    selection=\"not (chain A and resid 30-60)\",\n",
    ").to(device)\n",
    "\n",
    "new_protein = chroma.sample(\n",
    "    protein_init=protein,\n",
    "    conditioner=substructure_conditioner,\n",
    "    langevin_factor=4.0,\n",
    "    langevin_isothermal=True,\n",
    "    inverse_temperature=8.0,\n",
    "    sde_func=\"langevin\",\n",
    "    steps=500,\n",
    "    design_selection=\"chain B and resid 30-60\",\n",
    ")\n",
    "\n",
    "render(new_protein)"
   ]
  },
  {
   "cell_type": "markdown",
   "source": [
    "Note that sequence and side-chain design can also be done independently of backbone generation with `Chroma.design`. Here's an example of redesigning the sequence of the same PDB structure; printing both proteins lets us compare the designed and original sequences."
   ],
   "metadata": {
    "id": "TKPkBZJV0UQA"
   }
  },
  {
   "cell_type": "code",
   "source": [
    "designed_protein = chroma.design(protein)\n",
    "print(designed_protein)\n",
    "print(protein)"
   ],
   "metadata": {
    "id": "dqr3Kvgl0TTm"
   },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Residue-level conditioning\n",
    "\n",
    "While the above example used a conditioner applied to the generated structure as a whole, Chroma can also condition on individual residues. Here's a conditioner that specifies the secondary structure of each residue. Pass a string with one character per residue, where H = helix, E = strand, and T = turn."
   ],
   "metadata": {
    "id": "CbGQUxcdzuJd"
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "A2012-PnoTHf"
   },
   "outputs": [],
   "source": [
    "SS = \"HHHHHHHTTTHHHHHHHTTTEEEEEETTTEEEEEEEETTTTHHHHHHHH\"\n",
    "\n",
    "proclass_model = graph_classifier.load_model(\"named:public\", device=device)\n",
    "ss_conditioner = conditioners.ProClassConditioner(\n",
    "    \"secondary_structure\", SS, max_norm=10.0, model=proclass_model\n",
    ")\n",
    "ss_conditioned_protein = chroma.sample(\n",
    "    conditioner=ss_conditioner, steps=500, chain_lengths=[len(SS)]\n",
    ")\n",
    "render(ss_conditioned_protein, output=\"ss_conditioned_protein.cif\")"
   ]
  },
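  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Writing long secondary structure strings by hand is error-prone. As a minimal sketch, the hypothetical helper below (`build_ss_string` is not part of Chroma) assembles such a string from `(label, length)` segments and rejects labels other than H, E, and T. The resulting string could then be passed to the conditioner in the same way as `SS` above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical helper (not part of Chroma): build a per-residue secondary\n",
    "# structure string from (label, length) segments.\n",
    "ALLOWED_SS = {\"H\", \"E\", \"T\"}  # helix, strand, turn\n",
    "\n",
    "\n",
    "def build_ss_string(segments):\n",
    "    \"\"\"Concatenate (label, length) pairs into a string such as 'HHHTTEEEE'.\"\"\"\n",
    "    parts = []\n",
    "    for label, length in segments:\n",
    "        if label not in ALLOWED_SS:\n",
    "            raise ValueError(f\"Unknown secondary structure label: {label!r}\")\n",
    "        if length <= 0:\n",
    "            raise ValueError(f\"Segment length must be positive, got {length}\")\n",
    "        parts.append(label * length)\n",
    "    return \"\".join(parts)\n",
    "\n",
    "\n",
    "# Example: a helix, a turn, and a strand (segment lengths chosen arbitrarily).\n",
    "example_ss = build_ss_string([(\"H\", 3), (\"T\", 2), (\"E\", 4)])\n",
    "print(example_ss, len(example_ss))"
   ]
  },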
  {
   "cell_type": "markdown",
   "source": [
    "## Composing conditioners\n",
    "\n",
    "The `conditioners` module in Chroma allows for composition, via `composed_conditioner = conditioners.ComposedConditioner([conditioner1, conditioner2, ...])`. We can use the secondary structure conditioner from above along with a symmetry conditioner."
   ],
   "metadata": {
    "id": "dIkmR9rMJQX-"
   }
  },
  {
   "cell_type": "code",
   "source": [
    "symm_conditioner = conditioners.SymmetryConditioner(G=\"C_3\", num_chain_neighbors=2)\n",
    "composed_cond = conditioners.ComposedConditioner([ss_conditioner, symm_conditioner])\n",
    "\n",
    "symm_ss_protein = chroma.sample(\n",
    "    chain_lengths=[len(SS)],\n",
    "    conditioner=composed_cond,\n",
    "    langevin_factor=8,\n",
    "    inverse_temperature=8,\n",
    "    sde_func=\"langevin\",\n",
    "    steps=500,\n",
    ")\n",
    "\n",
    "render(symm_ss_protein, output=\"symm_ss_protein.cif\")"
   ],
   "metadata": {
    "id": "740IURiRJhOD"
   },
   "execution_count": null,
   "outputs": []
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}