{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "from ProtMamba_ssm.core import *\n",
    "from ProtMamba_ssm.dataloaders import *\n",
    "from ProtMamba_ssm.utils import *\n",
    "from ProtMamba_ssm.modules import *\n",
    "import torch\n",
    "from matplotlib import pyplot as plt\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# ProtMamba\n",
    "\n",
    "> A Homology Aware but Alignment Free Protein State Space Model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**A Homology-Aware but Alignment-Free Protein State Space Model**: [ProtMamba](https://www.biorxiv.org/content/early/2024/05/25/2024.05.24.595730) is a novel protein language model designed to facilitate protein design. Unlike traditional models that rely on multiple sequence alignments (MSAs) to capture evolutionary information, ProtMamba can use unaligned homologous sequences, avoiding the imperfections associated with MSAs.\n",
    "\n",
    "<img src=\"data/logo.jpeg\" width=\"512\"/>\n",
    "\n",
    "\n",
    "\n",
    "### Features\n",
    "ProtMamba is based on the Mamba architecture, a state space model that efficiently handles very long sequences. The model uses a fill-in-the-middle (FIM) training objective, combining autoregressive modeling and masked language modeling to predict amino acids conditioned on the given context sequences. This makes ProtMamba particularly well-suited for generating novel protein sequences, filling in specific regions of sequences, and predicting the fitness of protein variants.\n",
    "\n",
    "- **Homology-Aware but Alignment-Free**: Captures evolutionary information without relying on MSAs and use it to condition the generation process.\n",
    "- **Efficient Long-Context handling**: Uses Mamba blocks to handle long sequences with linear memory scaling.\n",
    "- **Different training objective**: Combines autoregressive and masked language modeling through a fill-in-the-middle objective.\n",
    "- **Sequence-level positional embeddings**: Enhances the model's ability to reason about in-sequence dependencies and allow for precise inpainting.\n",
    "\n",
    "### Applications\n",
    "\n",
    "- **Sequence Generation**: Generate novel protein sequences from scratch conditioned on specific homologs.\n",
    "- **Sequence Inpainting**: Fill in specific masked regions within a sequence for targeted protein design.\n",
    "- **Fitness Prediction**: Predict the probability distribution of mutations to assess the functional impact of variants.\n",
    "\n",
    "### Repository Structure\n",
    "`configs/`: Configuration files for model training and evaluation.\n",
    "\n",
    "`data/`: Example dataset.\n",
    "\n",
    "`nbs/`: Implementation of the ProtMamba model architecture in jupyter notebooks.\n",
    "\n",
    "`ProtMamba_ssm/`: Implementation of the ProtMamba model architecture.\n",
    "\n",
    "`tests/`: Scripts to sample from ProtMamba and evaluate the model's performance."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Install Repository"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "```sh\n",
    "pip install -e .\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## How to tokenize all your MSAs to make a training dataset\n",
    "\n",
    "IMPORTANT: the sequences should be in `a3m` files but they do not need to be aligned."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "msa_paths = {\"name-of-msa\" : \"../data/example_msa.a3m\"}\n",
    "# path saving directory\n",
    "filepath = input(\"What is the path to the folder where you want to save the dataset?\")   # example: \"../data/\"\n",
    "dataset_name = input(\"How do you want to name the dataset file?\") # example: \"encoded_MSAs_train.pkl\"\n",
    "\n",
    "dataset_dictionary = {}\n",
    "for msa_name, msa_path in msa_paths.items():\n",
    "    # Load an a3m file with all the context sequences\n",
    "    msa = load_from_file(msa_path)\n",
    "    # Tokenize the sequences and concatenate them into a single array\n",
    "    tokens = tokenizer(msa, concatenate=True)\n",
    "    tokens = tokens.numpy()[0]\n",
    "    dataset_dictionary[msa_name] = tokens\n",
    "    \n",
    "with open(filepath+dataset_name, \"wb\") as f:\n",
    "    pickle.dump(dataset_dictionary, f)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## How to train"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import yaml\n",
    "\n",
    "if int(use_one_gpu) >= 0:\n",
    "    print(f\"Using gpu {use_one_gpu}\")\n",
    "print(\"Number of gpus used: \", torch.cuda.device_count())\n",
    "\n",
    "# Load the default config file (change it and add the path to the training dataset)\n",
    "with open(\"../configs/default_config.yaml\", \"r\") as file:\n",
    "    defaultconfig = yaml.safe_load(file) \n",
    "\n",
    "namedir = input(\"Enter name of directory to save results: \")\n",
    "finetune_path = input(\"If you want to finetune a model, enter the relative path to the model's checkpoint, otherwise press enter:\")\n",
    "finetune_path = finetune_path if finetune_path else None\n",
    "\n",
    "# Run the trainer with the selected training configuration\n",
    "trainer = run(defaultconfig, namedir, finetune_model_path=finetune_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## How to sample from a pretrained model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "use_custom = input(\"Do you want to use a custom MSA? (y/n): \")\n",
    "if use_custom == \"y\":\n",
    "    # Load an a3m file with all the context sequences\n",
    "    msa = load_from_file(\"../data/example_msa.a3m\")\n",
    "    target_sequence = msa[:1]\n",
    "    context_msa = msa[1:]\n",
    "    # Tokenize the sequences and concatenate them into a single array\n",
    "    target = tokenizer(target_sequence, concatenate=True)\n",
    "    tokens = tokenizer(context_msa, concatenate=True)\n",
    "    fim_gen = input(\"Do you want to generate using FIM? (y/n): \")\n",
    "    if fim_gen==\"n\":\n",
    "        # AUTOREGRESSIVE, no-FIM generation\n",
    "        # generate the full sequence autoregressively starting from residue 10 in the sequence `target`\n",
    "        gen_dictionary = {\"<cls>\": 10}\n",
    "        input_seq, targ_pos = prepare_target(target, use_fim={\"<cls>\": 10})\n",
    "    if fim_gen==\"y\":\n",
    "        # FIM generation\n",
    "        # mask_dictionary is a dictionary of the positions in the sequence that you want to mask, in this example there will be\n",
    "        # - a mask that covers the residues 4,5,6 and the model will fill it by sampling 10 residues\n",
    "        # - a mask that covers the residues 30,31,32,33,34 and the model will fill it by sampling 3 residues\n",
    "        mask_dictionary = {\"<mask-1>\": ((4,7),10),\"<mask-2>\": ((30,35),3)}\n",
    "        input_seq, targ_pos, is_fim_dict = prepare_target(target, use_fim=mask_dictionary)\n",
    "    context_tokens, context_pos_ids = prepare_tokens(tokens,\n",
    "                                    target_tokens=input_seq,\n",
    "                                    target_pos_ids=targ_pos,\n",
    "                                    DatasetClass=Uniclust30_Dataset,\n",
    "                                    num_sequences=50,\n",
    "                                    fim_strategy=\"multiple_span\",\n",
    "                                    mask_fraction=0.2,\n",
    "                                    max_patches=5,\n",
    "                                    add_position_ids=\"1d\")\n",
    "if use_custom == \"n\":\n",
    "    is_fim = input(\"Do you want to use FIM? (y/n): \")\n",
    "    filepath = input(\"What is the path to the folder with the dataset?\")    \n",
    "    is_fim = True if is_fim == \"y\" else False\n",
    "    # Load the dataset used for training\n",
    "    dataset_name = input(\"What is the name of the dataset file?\") # example: \"encoded_MSAs_subset-100.pkl\", \"encoded_MSAs_train.pkl\"\n",
    "    fim_strategy = \"multiple_span\" if is_fim else \"no-scramble\"\n",
    "    dataset = Uniclust30_Dataset(filename=dataset_name,\n",
    "                                 filepath=filepath,\n",
    "                                 sample=False,\n",
    "                                 mask_fraction=0.2,\n",
    "                                 fim_strategy=fim_strategy,\n",
    "                                 max_position_embeddings=2048,\n",
    "                                 add_position_ids=\"1d\")\n",
    "    # Select a sample of the dataset to be the input\n",
    "    data = dataset[1]\n",
    "    tokens = data[\"input_ids\"][None,:].to(\"cuda\")\n",
    "    pos_ids = data[\"position_ids\"][None,:].to(\"cuda\")\n",
    "\n",
    "model_name = input(\"What is the path to the folder with the checkpoint of the model?\") # example: \"results/train_100M_FIM_restart-spikes_merged/checkpoint_131k-750\"\n",
    "# Load pretrained model\n",
    "model = load_model(model_name,\n",
    "                   model_class=MambaLMHeadModelwithPosids,\n",
    "                   device=\"cuda\",\n",
    "                   dtype=torch.bfloat16,\n",
    "                   checkpoint_mixer=False # Must be False when using model for Inference\n",
    "                   )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Generate new sequences starting from a custom MSA"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"Number of tokens in the MSA (target, context): \", len(input_seq[0]), \",\", len(tokens[0]))\n",
    "print(\"Target:\")\n",
    "print(\"sequence:\", decode_sequence(input_seq[0].numpy()))\n",
    "print(\"original sequence:\", decode_sequence(target[0].numpy()))\n",
    "print(\"pos ids:\", list(targ_pos[0].numpy()))\n",
    "print(\"Context:\")\n",
    "print(\"sequence:\", decode_sequence(context_tokens[0].numpy()))\n",
    "print(\"pos ids:\", list(context_pos_ids[0].numpy()))\n",
    "print(\"Mask positions:\", is_fim_dict)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Generate the new sequence\n",
    "output = generate_sequence(model,\n",
    "                           context_tokens,\n",
    "                           position_ids=context_pos_ids,\n",
    "                           is_fim=is_fim_dict,\n",
    "                           max_length=20000,\n",
    "                           temperature=1.,\n",
    "                           top_k=3,\n",
    "                           top_p=0.0,\n",
    "                           return_dict_in_generate=True,\n",
    "                           output_scores=True,\n",
    "                           eos_token_id=AA_TO_ID[\"<cls>\"],\n",
    "                           device=\"cuda\")\n",
    "\n",
    "input_seq, output_seq = output[\"input\"], output[\"generated\"]\n",
    "logits = output[\"scores\"]\n",
    "print(f\"All input context (len = {len(input_seq[0])}):\\n\", input_seq[0])\n",
    "print(\"Last sequence where the masked parts should be predicted:\\n\", input_seq[0].split(\"<cls>\")[-1])\n",
    "print(f\"Generated (len = {len(output_seq[0])}):\\n\", output_seq[0])\n",
    "input_continuation = decode_sequence(target[0].numpy())+\"<cls>\"\n",
    "print(f\"Continuation of input:\\n\", input_continuation)\n",
    "print(f\"\\nLogits (shape = {logits.shape})\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Generate new sequences using the Dataset and by filling in the masked parts (FIM)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "context_tokens, context_pos_ids, tokens_fim, pos_ids_fim, is_fim_dict = prepare_dataset_for_fim_generation(tokens, pos_ids)\n",
    "print(\"Length of context information\", context_tokens.shape[1])\n",
    "print(\"Masked part of the sequence to predict and its position indices:\\n\", tokens_fim, \"\\n\", pos_ids_fim)\n",
    "print(\"Masked tokens and their positions in the input sequence:\", is_fim_dict)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Generate the new sequence\n",
    "output = generate_sequence(model,\n",
    "                           context_tokens,\n",
    "                           position_ids=context_pos_ids,\n",
    "                           is_fim=is_fim_dict,\n",
    "                           max_length=1570,\n",
    "                           temperature=1.,\n",
    "                           top_k=3,\n",
    "                           top_p=0.0,\n",
    "                           return_dict_in_generate=True,\n",
    "                           output_scores=True,\n",
    "                           eos_token_id=AA_TO_ID[\"<cls>\"],\n",
    "                           device=\"cuda\")\n",
    "\n",
    "input_seq, output_seq = output[\"input\"], output[\"generated\"]\n",
    "logits = output[\"scores\"]\n",
    "print(f\"All input context (len = {len(input_seq[0])}):\\n\", input_seq[0])\n",
    "print(\"Last sequence where the masked parts should be predicted:\\n\", input_seq[0].split(\"<cls>\")[-1])\n",
    "print(f\"Generated (len = {len(output_seq[0])}):\\n\", output_seq[0])\n",
    "input_continuation = decode_sequence(tokens_fim[0].cpu().numpy())+\"<cls>\"\n",
    "print(f\"Continuation of input:\\n\", input_continuation)\n",
    "print(f\"\\nLogits (shape = {logits.shape})\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "total = len(tokens_fim[0])+1\n",
    "if total>4:\n",
    "        fig, axs = plt.subplots(total//4,4, figsize=(20,5*total//4))\n",
    "else:\n",
    "        fig, axs = plt.subplots(1,total, figsize=(20,total,5))\n",
    "        axs = axs[None,:]\n",
    "for el in range(total):\n",
    "        ax = axs[el//4,el%4]\n",
    "        ax.bar(np.arange(logits.shape[-1]),\n",
    "                torch.softmax(torch.from_numpy(logits[0,el,:]), dim=0))\n",
    "        ax.axvline(output[\"generated_tokens\"][0][el], color=\"red\", label=\"Prediction: \"+ID_TO_AA[output[\"generated_tokens\"][0][el]] + f\" ({output['generated_tokens'][0][el]})\", linewidth=0.5)\n",
    "        # ax.axvline(tokens_fim[0,el].cpu().numpy(), color=\"k\",label=\"Original: \"+ID_TO_AA[tokens_fim[0,el].cpu().numpy()] +f\" ({tokens_fim[0,el].cpu().numpy()})\", linewidth=0.5)\n",
    "        ax.legend()\n",
    "fig.suptitle(f\"Real sequence: {input_continuation}\\nPred sequence: {output_seq[0]}\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Generate new sequences using the Dataset and by sampling amino acids autoregressively from `<cls>`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Generate the new sequence\n",
    "L = 650#628#\n",
    "output = generate_sequence(model,\n",
    "                           tokens[:,:L],\n",
    "                           position_ids=pos_ids[:,:L],\n",
    "                           is_fim=False,\n",
    "                           max_length=1570,\n",
    "                           temperature=1.,\n",
    "                           top_k=10,\n",
    "                           top_p=0.0,\n",
    "                           return_dict_in_generate=True,\n",
    "                           output_scores=True,\n",
    "                           eos_token_id=torch.tensor([AA_TO_ID[\"<cls>\"],AA_TO_ID[\"<mask-1>\"], AA_TO_ID[\"<mask-2>\"], AA_TO_ID[\"<mask-3>\"],\n",
    "                                                      AA_TO_ID[\"<mask-4>\"], AA_TO_ID[\"<mask-5>\"]]).to(\"cuda\"),\n",
    "                           device=\"cuda\")\n",
    "\n",
    "input_seq, output_seq = output[\"input\"], output[\"generated\"]\n",
    "logits = output[\"scores\"]\n",
    "\n",
    "print(f\"All input context (len = {len(input_seq[0])}):\\n\", input_seq[0])\n",
    "print(\"Last sequence where the masked parts should be predicted:\\n\", input_seq[0].split(\"<cls>\")[-1])\n",
    "print(f\"Generated (len = {len(output_seq[0])}):\\n\", output_seq[0])\n",
    "input_continuation = decode_sequence(tokens[0,L:].cpu().numpy()).split(\"<cls>\")[0]+\"<cls>\"\n",
    "print(f\"Continuation of input:\\n\", input_continuation)\n",
    "print(f\"\\nLogits (shape = {logits.shape})\")\n",
    "print(\"\\nStops if the model predicts a mask token\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, axs = plt.subplots(4,4, figsize=(20,10))\n",
    "for el in range(16):\n",
    "        ax = axs[el//4,el%4]\n",
    "        ax.bar(np.arange(logits.shape[-1]),\n",
    "                torch.softmax(torch.from_numpy(logits[0,el,:]), dim=0))\n",
    "        ax.axvline(output[\"generated_tokens\"][0][el], color=\"red\", label=\"Prediction: \"+output_seq[0][el] + f\" ({AA_TO_ID[output_seq[0][el]]})\", linewidth=0.5)\n",
    "        ax.axvline(tokens[0,L+el].cpu().numpy(), color=\"k\",label=\"Original: \"+input_continuation[el] +f\" ({AA_TO_ID[input_continuation[el]]})\", linewidth=0.5)\n",
    "        ax.legend()\n",
    "fig.suptitle(f\"Real sequence: {input_continuation[:16]}\\nPred sequence: {output_seq[0][:16]}\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Get the hidden states at each layer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "which_layers = list(range(1,model.config.n_layer+1))\n",
    "\n",
    "def get_hidden_states(model, tokens, which_layers, position_ids=None, seq_position_ids=None):\n",
    "    hidden_states = model(input_ids=tokens,\n",
    "                          save_layer=which_layers,\n",
    "                          position_ids=position_ids,\n",
    "                          seq_position_ids=seq_position_ids)\n",
    "    return hidden_states\n",
    "\n",
    "hidden_states = get_hidden_states(model, tokens[:,:10], which_layers, pos_ids[:,:10])\n",
    "print(f\"Saved hidden states for layers: {hidden_states.keys()}\")\n",
    "print(f\"Shape of hidden state of layer 1: {hidden_states[1].shape}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## nbdev\n",
    "\n",
    "Project developed using [nbdev](https://nbdev.fast.ai/)."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "python3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
