{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "initial_id",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-04-21T16:52:14.980679Z",
     "start_time": "2024-04-21T16:52:10.341345Z"
    },
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "from scripts import load_model, MambaLMHeadModelwithPosids, AA_TO_ID\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "from config import DATA_DIR"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "57f471c4cdc0103f",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-04-21T16:52:51.503287Z",
     "start_time": "2024-04-21T16:52:51.112400Z"
    }
   },
   "outputs": [],
   "source": [
    "# Collect the header of every sequence in the reference a3m file,\n",
    "# without the leading '>' and surrounding whitespace.\n",
    "result_file = \"ProteinGym_reference_file_substitutions.a3m\"\n",
    "with open(result_file, \"r\") as a3m_handle:\n",
    "    labels = [header[1:].strip() for header in a3m_handle if header.startswith(\">\")]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b7e4f784b3a38fcb",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "66334d2bee797e99",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-04-21T16:53:24.228659Z",
     "start_time": "2024-04-21T16:53:24.222781Z"
    }
   },
   "outputs": [],
   "source": [
    "# Number of sequence headers (labels) parsed from the reference a3m file\n",
    "len(labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b9952b0e97f335f0",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-04-21T16:53:17.107339Z",
     "start_time": "2024-04-21T16:53:17.089772Z"
    }
   },
   "outputs": [],
   "source": [
    "# Position of the first header equal to \"158\"; list.index raises ValueError if it is absent\n",
    "labels.index(\"158\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "20846bd519fe4f67",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-04-19T13:10:05.917910Z",
     "start_time": "2024-04-19T13:10:05.901418Z"
    }
   },
   "outputs": [],
   "source": [
    "# Rename every .a3m file in colabfold_msa_dir after the identifier of its first sequence\n",
    "# (first whitespace-separated token of the first line, minus its leading '>').\n",
    "# A rename is skipped when the target name is already taken by another file, so two\n",
    "# MSAs whose first sequence has the same name no longer silently overwrite each other.\n",
    "# Files are visited in sorted order so the outcome does not depend on listdir order.\n",
    "import os\n",
    "\n",
    "colabfold_msa_dir = \"/data2/malbrank/protein_gym/mmseqs_colabfold_protocol\"\n",
    "for file in sorted(os.listdir(colabfold_msa_dir)):\n",
    "    if not file.endswith(\".a3m\"):\n",
    "        continue\n",
    "    src_path = os.path.join(colabfold_msa_dir, file)\n",
    "    with open(src_path, \"r\") as f:\n",
    "        first_line = f.readline()\n",
    "    header_tokens = first_line.split()\n",
    "    if not header_tokens:\n",
    "        print(f\"skipping {file}: empty first line\")\n",
    "        continue\n",
    "    prot_name = header_tokens[0][1:]\n",
    "    dst_path = os.path.join(colabfold_msa_dir, f\"{prot_name}.a3m\")\n",
    "    if dst_path == src_path:\n",
    "        continue  # already carries its target name (e.g. cell re-run)\n",
    "    if os.path.exists(dst_path):\n",
    "        print(f\"skipping {file}: {dst_path} already exists\")\n",
    "        continue\n",
    "    os.rename(src_path, dst_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "246cd33cbc10e8ea",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-04-19T09:22:25.359168Z",
     "start_time": "2024-04-19T09:22:25.350957Z"
    }
   },
   "outputs": [],
   "source": [
    "def build_landscape_from_csv(csv_file, msa_start=1):\n",
    "    \"\"\"Build the measured single-mutant landscape of one ProteinGym DMS assay.\n",
    "\n",
    "    Args:\n",
    "        csv_file: path to a DMS csv with columns \"mutant\" (e.g. \"A42G\"),\n",
    "            \"mutated_sequence\" and \"DMS_score\".\n",
    "        msa_start: sequence position (as numbered in \"mutant\") mapped to row 0.\n",
    "\n",
    "    Returns:\n",
    "        (gt_mut_landscape, keep_idx): a float tensor of shape\n",
    "        (len(mutated_sequence), 20) holding the DMS score of every measured\n",
    "        (position, amino acid) pair and +inf elsewhere, and the index tuple\n",
    "        from torch.where selecting the finite (measured) entries.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if a mutant position lies before msa_start; the negative row\n",
    "            index would otherwise silently wrap to the end of the landscape.\n",
    "    \"\"\"\n",
    "    df = pd.read_csv(csv_file)\n",
    "    msa_len = len(df[\"mutated_sequence\"].iloc[0])\n",
    "    gt_mut_landscape = torch.full((msa_len, 20), float(\"inf\"))\n",
    "    for mutant, score in zip(df[\"mutant\"], df[\"DMS_score\"]):\n",
    "        mut_pos = int(mutant[1:-1]) - msa_start\n",
    "        if mut_pos < 0:\n",
    "            raise ValueError(f\"{csv_file}: mutant {mutant} lies before msa_start={msa_start}\")\n",
    "        # NOTE(review): the -4 presumably skips special tokens at the start of AA_TO_ID - confirm\n",
    "        mut_aa_id = AA_TO_ID[mutant[-1]] - 4\n",
    "        gt_mut_landscape[mut_pos, mut_aa_id] = float(score)\n",
    "    # isfinite (rather than != inf) also drops NaN scores, which would make spearmanr return NaN\n",
    "    keep_idx = torch.where(torch.isfinite(gt_mut_landscape))\n",
    "    return gt_mut_landscape, keep_idx"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b18507e11447c1fb",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-04-19T09:08:50.673263Z",
     "start_time": "2024-04-19T09:08:50.652940Z"
    }
   },
   "outputs": [],
   "source": [
    "# Evaluation setup: input/output folders, the assay table restricted to\n",
    "# assays without multiple mutants, and the baseline Spearman table (indexed by DMS_id).\n",
    "protein_gym_dir = f\"{DATA_DIR}/protein_gym\"\n",
    "substitutions_dir = f\"{protein_gym_dir}/substitutions\"\n",
    "n_tokens = [8000, 16000, 32000, 64000]\n",
    "csv_folder = f\"{substitutions_dir}/DMS_ProteinGym_substitutions/\"\n",
    "out_folder = f\"{protein_gym_dir}/mut_effects/\"\n",
    "all_dms_df = pd.read_csv(f\"{substitutions_dir}/DMS_substitutions.csv\")\n",
    "database_df = all_dms_df.loc[all_dms_df[\"DMS_number_multiple_mutants\"] == 0]\n",
    "results_df = pd.read_csv(f\"{substitutions_dir}/DMS_substitutions_Spearman.csv\", index_col=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d6b76cb7ed5d8656",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-04-19T09:25:08.300654Z",
     "start_time": "2024-04-19T09:25:08.292077Z"
    }
   },
   "outputs": [],
   "source": [
    "# Inspect the \"TranceptEVE L\" column of the baseline Spearman table (indexed by DMS_id)\n",
    "results_df[\"TranceptEVE L\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3aef9f4e96d7fb20",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-04-19T09:32:23.709211Z",
     "start_time": "2024-04-19T09:32:23.704633Z"
    }
   },
   "outputs": [],
   "source": [
    "# Baselines to compare against; each name is a column of results_df.\n",
    "select_models = [\n",
    "    \"TranceptEVE L\",\n",
    "    \"GEMME\",\n",
    "    \"ESM-IF1\",\n",
    "    \"MSA Transformer (single)\",\n",
    "    \"ESM2 (650M)\",\n",
    "    \"EVmutation\",\n",
    "    \"Site-Independent\",\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c32a1a3895887f18",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-04-17T15:45:46.039405Z",
     "start_time": "2024-04-17T15:45:43.970206Z"
    }
   },
   "outputs": [],
   "source": [
    "# Load the Mamba checkpoint on the GPU in bfloat16 and switch it to eval mode.\n",
    "checkpoint_path = \"/nvme1/common/mamba_100M_FIM_finetuned_32k_checkpoint-16500\"\n",
    "model = load_model(\n",
    "    checkpoint_path,\n",
    "    model_class=MambaLMHeadModelwithPosids,\n",
    "    device=\"cuda\",\n",
    "    dtype=torch.bfloat16,\n",
    ").eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "17aa40c22ba9fd8b",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-04-19T09:18:15.541016Z",
     "start_time": "2024-04-19T09:18:14.693135Z"
    }
   },
   "outputs": [],
   "source": [
    "from scipy.stats import spearmanr\n",
    "from matplotlib import pyplot as plt\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "57001d0da8028461",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-04-19T09:20:24.628506Z",
     "start_time": "2024-04-19T09:20:24.607494Z"
    }
   },
   "outputs": [],
   "source": [
    "# Inspect the assays selected for evaluation (DMS_number_multiple_mutants == 0)\n",
    "database_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "83fc94b00c1da3e8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Spearman correlation between measured and predicted landscapes for every\n",
    "# (test, assay) pair. The ground-truth landscape only depends on the assay, so it\n",
    "# is built once per assay and reused for all tests.\n",
    "from tqdm.auto import tqdm\n",
    "\n",
    "tests = [\"mamba_8000\", \"mamba_16000\", \"mamba_32000\", \"mamba_ft_8000\", \"mamba_ft_16000\", \"mamba_ft_32000\"]\n",
    "colors = [\"red\", \"blue\", \"green\", \"red\", \"blue\", \"green\"]\n",
    "markers = [\"o\", \"o\", \"o\", \"x\", \"x\", \"x\"]\n",
    "spearmanrs_tests = {test: [] for test in tests}\n",
    "\n",
    "for _, row in tqdm(database_df.iterrows(), total=len(database_df)):\n",
    "    gt_mut_landscape, keep_idx = build_landscape_from_csv(csv_folder + row[\"DMS_filename\"], int(row[\"MSA_start\"]))\n",
    "    gt_values = gt_mut_landscape[keep_idx]\n",
    "    dms_id = row[\"DMS_id\"]\n",
    "    for test in tests:\n",
    "        # map_location keeps this working on CPU-only machines; .float() guards\n",
    "        # against reduced-precision tensors that numpy/scipy cannot convert.\n",
    "        pred_mut_landscape = torch.load(f\"{out_folder}/{test}/{dms_id}_landscape.pt\", map_location=\"cpu\")\n",
    "        spearman = spearmanr(gt_values, pred_mut_landscape[keep_idx].float())[0]\n",
    "        spearmanrs_tests[test].append(spearman)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dea8fb92cb62fbfa",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-04-19T09:32:26.099912Z",
     "start_time": "2024-04-19T09:32:26.038336Z"
    }
   },
   "outputs": [],
   "source": [
    "# Published Spearman of each baseline, one value per assay in database_df row order.\n",
    "# The loop variable is deliberately not named `model`: that name holds the loaded\n",
    "# Mamba checkpoint and reusing it here would overwrite it.\n",
    "spearmanrs_baselines = {\n",
    "    baseline_name: [results_df[baseline_name].loc[dms_id] for dms_id in database_df[\"DMS_id\"]]\n",
    "    for baseline_name in select_models\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3f1d4d243db64e77",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-04-19T09:34:14.074519Z",
     "start_time": "2024-04-19T09:34:14.066209Z"
    }
   },
   "outputs": [],
   "source": [
    "# Mean Spearman correlation over the selected assays, Mamba tests first, then baselines.\n",
    "# Comprehension variables are used so the loaded `model` is not overwritten.\n",
    "mean_spearman = pd.Series(\n",
    "    [np.mean(spearmanrs_tests[test]) for test in tests]\n",
    "    + [np.mean(spearmanrs_baselines[baseline_name]) for baseline_name in select_models],\n",
    "    index=tests + select_models,\n",
    "    name=\"mean Spearman\",\n",
    ")\n",
    "mean_spearman"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7d3ce5726c780490",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-04-19T13:24:30.491313Z",
     "start_time": "2024-04-19T13:24:30.486405Z"
    }
   },
   "outputs": [],
   "source": [
    "# Per-assay Spearman of the mamba_ft_8000 test, printed as \"<DMS_id> <value>\".\n",
    "ft_8000_scores = spearmanrs_tests[\"mamba_ft_8000\"]\n",
    "for dms_id, rho in zip(database_df[\"DMS_id\"], ft_8000_scores):\n",
    "    print(dms_id, rho)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f02826ca99691d96",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-04-19T09:33:11.962581Z",
     "start_time": "2024-04-19T09:33:10.036971Z"
    }
   },
   "outputs": [],
   "source": [
    "# One row per assay: Spearman of every Mamba test (colour/marker from `colors`/`markers`)\n",
    "# and of every baseline (black, one marker per baseline so legend entries stay distinguishable).\n",
    "baseline_markers = [\"s\", \"^\", \"v\", \"D\", \"P\", \"*\", \"<\", \">\", \"h\", \"p\"]\n",
    "y_positions = range(len(database_df))\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(10, 35))\n",
    "for test, color, marker in zip(tests, colors, markers):\n",
    "    ax.scatter(spearmanrs_tests[test], y_positions, color=color, marker=marker, label=test)\n",
    "for k, baseline_name in enumerate(select_models):\n",
    "    ax.scatter(spearmanrs_baselines[baseline_name], y_positions, color=\"black\",\n",
    "               marker=baseline_markers[k % len(baseline_markers)], label=baseline_name)\n",
    "ax.legend()\n",
    "ax.grid(axis=\"x\")\n",
    "ax.set_yticks(y_positions)\n",
    "ax.set_yticklabels(database_df[\"DMS_id\"])\n",
    "ax.set_ylabel(\"DMS_id\")\n",
    "ax.set_xlabel(\"Spearman correlation\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4ebda60afb4d2deb",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-04-19T08:47:38.271125Z",
     "start_time": "2024-04-19T08:47:38.248713Z"
    }
   },
   "outputs": [],
   "source": [
    "# Re-display the assay table used for evaluation (same frame as shown earlier)\n",
    "database_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5303528ca861c4cb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the DMS-level Spearman table from the local ProteinGym checkout (first column as index).\n",
    "performances_csv = \"/data2/malbrank/protein_gym/ProteinGym/Detailed_performance_files/Substitutions/Spearman/all_models_substitutions_Spearman_DMS_level.csv\"\n",
    "model_performances_spearmanr = pd.read_csv(performances_csv, index_col=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e3ed4fa356647331",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-04-10T20:22:36.797255Z",
     "start_time": "2024-04-10T20:22:36.788209Z"
    }
   },
   "outputs": [],
   "source": [
    "# Map each protein (first two '_'-separated fields of the file name) to its MSA file.\n",
    "# Files are visited from smallest to largest, so when several files share a protein\n",
    "# prefix the last one visited (the largest) ends up in the mapping.\n",
    "import os\n",
    "\n",
    "dms_msa_dir = \"/data2/malbrank/protein_gym/DMS_msa_files\"\n",
    "msa_files = sorted(os.listdir(dms_msa_dir), key=lambda name: os.path.getsize(os.path.join(dms_msa_dir, name)))\n",
    "prot_to_msa = {\"_\".join(name.split(\"_\")[:2]): name for name in msa_files}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7faeee013b1656f3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): debug leftover - shows whatever an earlier loop cell last bound to `row`,\n",
    "# so the output depends on which cell ran last; consider removing this cell.\n",
    "row"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
