{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Investigate required values\n",
    "\n",
    "Investigate which values are needed to predict in order to fully specify a protein correctly"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os, sys\n",
    "import functools\n",
    "import itertools\n",
    "import multiprocessing\n",
    "import warnings\n",
    "import importlib\n",
    "import tempfile\n",
    "from pathlib import Path\n",
    "import json\n",
    "\n",
    "from scipy import stats\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "\n",
    "import torch\n",
    "\n",
    "SRC_DIR = os.path.join(os.path.dirname(os.getcwd()), \"foldingdiff\")\n",
    "assert os.path.isdir(SRC_DIR)\n",
    "sys.path.append(SRC_DIR)\n",
    "import datasets\n",
    "import angles_and_coords as ac\n",
    "import nerf\n",
    "import tmalign  # So we can compare structural similarity\n",
    "import plotting\n",
    "\n",
    "datasets.LOCAL_DATA_DIR"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "importlib.reload(datasets)\n",
    "\n",
    "train_dset = datasets.CathCanonicalAnglesOnlyDataset(\n",
    "    split='train',\n",
    "    zero_center=True,\n",
    "    pad=128,\n",
    "    trim_strategy='randomcrop',\n",
    ")\n",
    "len(train_dset.filenames)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_dset_noised = datasets.NoisedAnglesDataset(\n",
    "    train_dset,\n",
    "    timesteps=1000,\n",
    ")\n",
    "str(train_dset_noised)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "importlib.reload(ac)\n",
    "importlib.reload(nerf)\n",
    "\n",
    "# https://arxiv.org/pdf/2205.04676.pdf\n",
    "# N:CA:C = tau\n",
    "\n",
    "# Full spec should be angles_to_use=[\"N:CA:C\", \"CA:C:1N\", \"C:1N:1CA\", \"phi\", \"psi\", \"omega\"], dists_to_use=[\"N:CA\", \"CA:C\", \"C:1N\"]\n",
    "\n",
    "def test_consistency(fname:str, angles_to_use=[\"phi\", \"psi\", \"omega\", \"tau\", \"CA:C:1N\", \"C:1N:1CA\"], dists_to_use=[\"0C:1N\", \"N:CA\", \"CA:C\"], visualize:bool=False):\n",
    "    \"\"\"Test the consistency of reconstructing a pdb file\"\"\"\n",
    "    # Create the internal coordinates\n",
    "    angles = ac.canonical_distances_and_dihedrals(fname, distances=dists_to_use, angles=angles_to_use)\n",
    "    if angles is None:\n",
    "        return np.nan, None\n",
    "    with tempfile.TemporaryDirectory() as dirname:\n",
    "        out_fname = os.path.join(dirname, \"rebuilt_\" + os.path.basename(fname))\n",
    "        # rebuilt = ac.create_new_chain(\n",
    "        #     out_fname, angles,\n",
    "        #     angles_to_set=angles_to_use, distances_to_set=dists_to_use\n",
    "        # )\n",
    "        out_fname_written = ac.create_new_chain_nerf(out_fname, angles, angles_to_set=angles_to_use, dists_to_set=dists_to_use)\n",
    "        if not out_fname_written:  # Failed on the way, should have returned empty string \"\"\n",
    "            return np.nan, None\n",
    "        score = tmalign.run_tmalign(fname, out_fname)\n",
    "        # angles_new = ac.canonical_distances_and_dihedrals(out_fname, distances=dists_to_use, angles=angles_to_use)\n",
    "        view = None\n",
    "        if visualize:\n",
    "            view = view_pdb(out_fname)\n",
    "    return score, view\n",
    "\n",
    "score, view = test_consistency(train_dset.filenames[40], visualize=False)\n",
    "print(score)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Generate a folding visual example"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from curses import savetty\n",
    "\n",
    "\n",
    "importlib.reload(ac)\n",
    "\n",
    "def view_pdb(*args, **kwargs) -> None:\n",
    "    return None\n",
    "\n",
    "def visualize_training_example(i: int = 0, timestep: int = 0, struct_pdb: str = \"\"):\n",
    "    \"\"\"Visualize the training example\"\"\"\n",
    "    # Keys ['angles', 'attn_mask', 'position_ids', 'corrupted', 't', 'known_noise']\n",
    "    item = train_dset_noised.__getitem__(i, use_t_val=timestep, ignore_zero_center=True)\n",
    "    assert item['t'].item() == timestep\n",
    "    attn_idx = torch.where(item['attn_mask'])[0]\n",
    "\n",
    "    angles = item['corrupted'][attn_idx].cpu().numpy()\n",
    "    angles_df = pd.DataFrame(angles, columns=train_dset_noised.feature_names['angles'])\n",
    "    if not struct_pdb:\n",
    "        with tempfile.TemporaryDirectory() as tempdir:\n",
    "            fname = ac.create_new_chain_nerf(os.path.join(tempdir, \"temp.pdb\"), angles_df)\n",
    "            return view_pdb(fname), angles\n",
    "    else:\n",
    "        fname = ac.create_new_chain_nerf(struct_pdb, angles_df)\n",
    "        return view_pdb(fname), angles\n",
    "\n",
    "noised_view, noised_angles = visualize_training_example(\n",
    "    i=56, timestep=999,\n",
    "    struct_pdb=\"../plots/pdb_structures/noising_visualization/fully_noised.pdb\",\n",
    ")\n",
    "noised_view"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "noised_angles.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def show_matrix(mat, fname:str=\"\"):\n",
    "    \"\"\"Visualize the matrix\"\"\"\n",
    "    fig, ax = plt.subplots(dpi=300)\n",
    "    ax.imshow(mat.T, aspect=3.5)\n",
    "    ax.set(\n",
    "        xticklabels=[],\n",
    "        xticks=[],\n",
    "        yticklabels=[],\n",
    "        yticks=[],\n",
    "    )\n",
    "    if fname:\n",
    "        fig.savefig(fname, bbox_inches='tight')\n",
    "    return fig\n",
    "\n",
    "show_matrix(\n",
    "    noised_angles,\n",
    "    \"../plots/pdb_structures/noising_visualization/fully_noised_values.pdf\"\n",
    ").show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "clean_view, clean_angles = visualize_training_example(\n",
    "    i=56, timestep=0,\n",
    "    struct_pdb=\"../plots/pdb_structures/noising_visualization/clean.pdb\",\n",
    ")\n",
    "clean_view"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "show_matrix(\n",
    "    clean_angles,\n",
    "    \"../plots/pdb_structures/noising_visualization/clean_values.pdf\"\n",
    ").show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Look at reconstruction within training examples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Look at the full defined set, or 4 angles we currently fit, or 3 dihedrals\n",
    "\n",
    "angle_sets_to_evaluate = [\n",
    "    [\"phi\", \"psi\", \"omega\", \"tau\", \"CA:C:1N\", \"C:1N:1CA\"],\n",
    "    [\"phi\", \"psi\", \"omega\", \"tau\"], \n",
    "    [\"phi\", \"psi\", \"omega\"],\n",
    "]\n",
    "\n",
    "# Look at either all the distances or none of them\n",
    "dist_sets_to_evaluate = [\n",
    "    [\"0C:1N\", \"N:CA\", \"CA:C\"],\n",
    "    []\n",
    "]\n",
    "\n",
    "# Combinatorially look at these angle sets\n",
    "def evaluate_angle_set_parallel(filenames, angles, dists):\n",
    "    warnings.filterwarnings('ignore', '.*elements were guessed from atom_name.*')\n",
    "    warnings.filterwarnings('ignore', '.*invalid value encountered in true_div.*')\n",
    "    pfunc = functools.partial(\n",
    "        test_consistency,\n",
    "        angles_to_use=angles,\n",
    "        dists_to_use=dists,\n",
    "    )\n",
    "    pool = multiprocessing.Pool(multiprocessing.cpu_count())\n",
    "    tm_scores = np.array([score for score, _view in pool.map(pfunc, filenames, chunksize=20)])\n",
    "    pool.close()\n",
    "    pool.join()\n",
    "    return tm_scores\n",
    "\n",
    "human_readable_angle_combos = [\n",
    "    r\"dihedrals, angles, and distances\",\n",
    "    r\"dihedrals and angles\",\n",
    "    r\"dihedrals, $\\theta_1$, and distances\",\n",
    "    r\"dihedrals, $\\theta_1$\",\n",
    "    r\"dihedrals, distances\",\n",
    "    r\"dihedrals only\"\n",
    "]\n",
    "per_angle_dist_results = {}\n",
    "for human_name, (a, d) in zip(human_readable_angle_combos, itertools.product(angle_sets_to_evaluate, dist_sets_to_evaluate)):\n",
    "    print(human_name, a, d)\n",
    "    per_angle_dist_results[human_name] = evaluate_angle_set_parallel(train_dset.filenames[:5000], angles=a, dists=d)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot these results\n",
    "reconst_tmscore_dir = plotting.PLOT_DIR / \"reconstruction_angles_coords\"\n",
    "if not reconst_tmscore_dir.is_dir():\n",
    "    os.makedirs(reconst_tmscore_dir)\n",
    "\n",
    "fig, ax = plt.subplots(dpi=300)\n",
    "for i, (k, v) in enumerate(per_angle_dist_results.items()):\n",
    "    c = [\n",
    "        'tab:blue',\n",
    "        'tab:orange',\n",
    "        'tab:green',\n",
    "        'tab:red',\n",
    "        'tab:purple',\n",
    "        'tab:brown',\n",
    "    ][i]\n",
    "    mean = np.nanmean(v)\n",
    "    std = np.nanstd(v)\n",
    "    sns.histplot(v, bins=40, stat='proportion', ax=ax, label=f\"{k} - ${mean:.4f} \\pm {std:.4f}$\", alpha=0.5, color=c)\n",
    "ax.axvline(0.5, color='grey', alpha=0.3, linestyle='--')\n",
    "ax.legend(prop={'size': 6})\n",
    "ax.set(\n",
    "    xlabel=\"Reconstruction TM-score\",\n",
    ")\n",
    "ax.set_title(\"Reconstruction across different angle/distance sets, training set subset\", fontsize=10)\n",
    "fig.savefig(reconst_tmscore_dir / \"reconstruction_distributions.pdf\", bbox_inches='tight')\n",
    "fig.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "reconstruction_scores = pd.DataFrame(\n",
    "    per_angle_dist_results\n",
    ")\n",
    "reconstruction_scores"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "np.nanmean(reconstruction_scores, axis=0), np.nanstd(reconstruction_scores, axis=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots(dpi=300)\n",
    "sns.boxplot(\n",
    "    data=reconstruction_scores,\n",
    "    whis=(5, 95),\n",
    "    color=sns.color_palette()[0],\n",
    "    fliersize=2,\n",
    "    flierprops={\"marker\": \"o\", \"alpha\": 0.3, \"color\": \"gray\"},\n",
    ")\n",
    "ax.set_xticks(list(range(6)))\n",
    "ax.set_xticklabels(reconstruction_scores.columns, rotation=45, ha='right', size=8)\n",
    "ax.axhline(0.5, linestyle='--', color='grey', alpha=0.5)\n",
    "ax.set(\n",
    "    ylabel=\"TM-score of reconstruction\",\n",
    "    title=\"Reconstruction accuracy, combinations of angles/distances\",\n",
    ")\n",
    "fig.savefig(reconst_tmscore_dir / \"reconstruction_barplot.pdf\", bbox_inches='tight')\n",
    "fig.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_dset_seq_lens = np.array([train_dset[i]['lengths'].item() for i in range(5000)])\n",
    "train_dset_seq_lens"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots(dpi=300)\n",
    "sns.lineplot(\n",
    "    x=train_dset_seq_lens,\n",
    "    y=reconstruction_scores['dihedrals and angles'],\n",
    "    ax=ax,\n",
    ")\n",
    "ax.set(\n",
    "    xlabel=\"Structure length\",\n",
    "    ylabel=\"TM-score of reconstruction\",\n",
    "    title=\"Structure reconstruction with dihedrals and angles\",\n",
    ")\n",
    "fig.savefig(\n",
    "    reconst_tmscore_dir / \"reconstruction_by_len.pdf\", bbox_inches='tight'\n",
    ")\n",
    "fig.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_dset_seq_lens"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "nonnan_idx = np.where(~np.isnan(reconstruction_scores['dihedrals and angles'].values))\n",
    "stats.spearmanr(train_dset_seq_lens[nonnan_idx], reconstruction_scores['dihedrals and angles'].values[nonnan_idx])"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
