{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Dict, Optional, Any\n",
    "import os\n",
    "import pickle\n",
    "import collections\n",
    "from pathlib import Path\n",
    "\n",
    "import logging\n",
    "logging.getLogger('fontTools').setLevel(logging.ERROR)  # Only show errors, not warnings or info\n",
    "\n",
    "import mdtraj as md\n",
    "import numpy as np\n",
    "import scipy.stats\n",
    "import pyemma\n",
    "import pandas as pd\n",
    "import lovelyplots\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib as mpl\n",
    "import matplotlib.colors\n",
    "mpl.rcParams['axes.formatter.useoffset'] = False\n",
    "mpl.rcParams['axes.formatter.limits'] = (-10000, 10000)  # Controls range before scientific notation is used\n",
    "plt.style.use(\"ipynb\")\n",
    "\n",
    "# Uncomment to use the non-interactive 'agg' backend (figures are saved to disk but not shown in the notebook).\n",
    "# mpl.use('agg')\n",
    "\n",
    "import pyemma_helper\n",
    "from jamun import utils"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Paths\n",
    "\n",
    "Set the input and output paths and select the experiment, which determines the sampled trajectory and the reference trajectory that are compared."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Root of the analysis results, laid out as <experiment>/<trajectory>/ref=<reference>/.\n",
    "results_dir = \"/path/to/the/jamun-ana/\"\n",
    "print(\"Results directory:\", results_dir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Select one experiment; the alternatives are kept commented out for quick switching.\n",
    "# experiment = \"Our_2AA\"\n",
    "# experiment = \"Our_5AA\"\n",
    "experiment = \"Timewarp_2AA\"\n",
    "# experiment = \"Timewarp_4AA\"\n",
    "# experiment = \"MDGen_4AA\"\n",
    "# experiment = \"MDGen_4AA_new\"\n",
    "\n",
    "# sample_runs.csv maps each experiment to its sampled trajectory and reference trajectory.\n",
    "runs_df = pd.read_csv(\"sample_runs.csv\")\n",
    "is_selected_run = runs_df[\"experiment\"] == experiment\n",
    "if not is_selected_run.any():\n",
    "    raise ValueError(f\"Experiment {experiment} not found in runs_df\")\n",
    "\n",
    "# If several rows match, the first one is used.\n",
    "traj_name = runs_df.loc[is_selected_run, \"trajectory\"].iloc[0]\n",
    "ref_traj_name = runs_df.loc[is_selected_run, \"reference\"].iloc[0]\n",
    "\n",
    "for label, value in ((\"Experiment\", experiment), (\"Trajectory\", traj_name), (\"Reference\", ref_traj_name)):\n",
    "    print(f\"{label}: {value}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plots mirror the results layout: <plots root>/<experiment>/<trajectory>/ref=<reference>/.\n",
    "plots_root = \"/path/to/the/jamun-plots\"\n",
    "output_dir = os.path.join(plots_root, experiment, traj_name, f\"ref={ref_traj_name}\")\n",
    "os.makedirs(output_dir, exist_ok=True)\n",
    "print(f\"Plots will be saved to {output_dir}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load All Results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One <peptide>.pkl results file per peptide is expected in this directory.\n",
    "results_path = os.path.join(results_dir, experiment, traj_name, f\"ref={ref_traj_name}\")\n",
    "print(\"Searching for results in\", results_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_results_path(results_path: str) -> pd.DataFrame:\n",
    "    \"\"\"Load every ``<peptide>.pkl`` file in ``results_path`` into a DataFrame.\n",
    "\n",
    "    ``results_path`` must end in ``<trajectory>/ref=<reference>``; those two path\n",
    "    components fill the ``traj`` and ``ref_traj`` columns (with ``ref=`` stripped).\n",
    "    Each pickle must be a dict with ``\"results\"`` and ``\"args\"`` entries.\n",
    "\n",
    "    Returns:\n",
    "        One row per ``.pkl`` file, in file-name order, with columns\n",
    "        ``traj``, ``ref_traj``, ``peptide`` (the file stem), ``results`` and ``args``.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if the last path component does not start with ``ref=``.\n",
    "    \"\"\"\n",
    "    path_parts = Path(results_path).parts\n",
    "    traj_dir, ref_dir = path_parts[-2], path_parts[-1]\n",
    "    prefix = \"ref=\"\n",
    "    if not ref_dir.startswith(prefix):\n",
    "        raise ValueError(f\"Expected reference trajectory name to start with 'ref=', got {ref_dir}\")\n",
    "    ref_label = ref_dir[len(prefix):]\n",
    "\n",
    "    records = []\n",
    "    for file_name in sorted(os.listdir(results_path)):\n",
    "        stem, suffix = os.path.splitext(file_name)\n",
    "        if suffix == \".pkl\":\n",
    "            # NOTE: pickle.load can execute arbitrary code; only point this at trusted result files.\n",
    "            with open(os.path.join(results_path, file_name), \"rb\") as handle:\n",
    "                payload = pickle.load(handle)\n",
    "            records.append(dict(\n",
    "                traj=traj_dir,\n",
    "                ref_traj=ref_label,\n",
    "                peptide=stem,\n",
    "                results=payload[\"results\"],\n",
    "                args=payload[\"args\"],\n",
    "            ))\n",
    "    return pd.DataFrame(records)\n",
    "\n",
    "\n",
    "results_df = load_results_path(results_path)\n",
    "results_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Also, load TBG results for the same experiment.\n",
    "def add_recursively(original_results, tbg_results, add_key):\n",
    "    \"\"\"Graft TBG results into ``original_results`` in place, following its nesting.\n",
    "\n",
    "    Walks both nested dicts in parallel. At the first dict on each path that has a\n",
    "    ``\"traj\"`` entry, stores ``tbg_results[\"traj\"]`` under ``add_key`` next to it\n",
    "    and stops descending. Non-dict values are left untouched.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if ``add_key`` already exists at a visited level (e.g. this cell was re-run).\n",
    "        KeyError: if ``tbg_results`` is missing a key that ``original_results`` has.\n",
    "    \"\"\"\n",
    "    if not isinstance(original_results, dict) or not isinstance(tbg_results, dict):\n",
    "        return\n",
    "\n",
    "    if add_key in original_results:\n",
    "        raise ValueError(f\"Key '{add_key}' already exists in original_results\")\n",
    "\n",
    "    if \"traj\" in original_results:\n",
    "        original_results[add_key] = tbg_results[\"traj\"]\n",
    "        return\n",
    "\n",
    "    for key in original_results:\n",
    "        add_recursively(original_results[key], tbg_results[key], add_key)\n",
    "\n",
    "\n",
    "if experiment == \"Timewarp_2AA\":\n",
    "    tbg_results_path = os.path.join(\n",
    "        results_dir, \"Timewarp_2AA\", \"TBG\", f\"ref={ref_traj_name}\"\n",
    "    )\n",
    "    tbg_results_df = load_results_path(tbg_results_path)\n",
    "\n",
    "    # Add TBG results to each peptide's results dict in results_df (in place), under the key \"TBG\".\n",
    "    for peptide, original_results in zip(results_df[\"peptide\"], results_df[\"results\"]):\n",
    "        tbg_matches = tbg_results_df.loc[tbg_results_df[\"peptide\"] == peptide, \"results\"]\n",
    "        if tbg_matches.empty:\n",
    "            raise ValueError(f\"No TBG results for peptide {peptide} in {tbg_results_path}\")\n",
    "        add_recursively(original_results, tbg_matches.iloc[0], add_key=\"TBG\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Choose which peptides to plot.\n",
    "if \"5AA\" in experiment:\n",
    "    # Fixed, hand-picked pentapeptides; their results are stored with an \"uncapped_\" prefix.\n",
    "    peptides = [f\"uncapped_{name}\" for name in (\"KTYDI\", \"NRLCQ\", \"VWSPF\")]\n",
    "    sampled_results_df = results_df[results_df[\"peptide\"].isin(peptides)]\n",
    "else:\n",
    "    # Up to 4 peptides at random; the fixed seed keeps the selection reproducible.\n",
    "    sampled_results_df = results_df.sample(n=min(len(results_df), 4), random_state=42)\n",
    "\n",
    "sampled_results_df = sampled_results_df.reset_index(drop=True)\n",
    "sampled_results_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def format_traj_name(results_traj_name: str) -> str:\n",
    "    \"\"\"Map a results key to its plot label.\n",
    "\n",
    "    \"traj\" maps to the global ``traj_name`` (read at call time); the ``ref_traj*``\n",
    "    keys map to fixed \"Reference\" labels. Unknown keys raise KeyError.\n",
    "    \"\"\"\n",
    "    labels = {\n",
    "        \"traj\": traj_name,\n",
    "        \"ref_traj\": \"Reference\",\n",
    "        \"ref_traj_10x\": \"Reference\\n(10x shorter)\",\n",
    "        \"ref_traj_100x\": \"Reference\\n(100x shorter)\",\n",
    "    }\n",
    "    return labels[results_traj_name]\n",
    "\n",
    "\n",
    "def format_quantity(quantity: str) -> str:\n",
    "    \"\"\"Map a JSD metric key to a human-readable plot label (KeyError if unknown).\"\"\"\n",
    "    labels = {\n",
    "        \"JSD_backbone_torsions\": \"Backbone Torsions\",\n",
    "        \"JSD_sidechain_torsions\": \"Sidechain Torsions\",\n",
    "        \"JSD_all_torsions\": \"All Torsions\",\n",
    "        \"JSD_TICA-0\": \"TICA-0 Projections\",\n",
    "        \"JSD_TICA-0,1\": \"TICA-0,1 Projections\",\n",
    "        \"JSD_metastable_probs\": \"Metastable State Probabilities\",\n",
    "    }\n",
    "    return labels[quantity]\n",
    "\n",
    "\n",
    "def format_peptide_name(peptide: str) -> str:\n",
    "    \"\"\"Format a peptide identifier for plot labels.\n",
    "\n",
    "    Strips an \"uncapped_\" prefix and then a \"capped_\" prefix. Names that still contain\n",
    "    underscores are joined with hyphens; otherwise the name is passed to\n",
    "    ``utils.convert_to_one_letter_codes``.\n",
    "    \"\"\"\n",
    "    for prefix in (\"uncapped_\", \"capped_\"):\n",
    "        if peptide.startswith(prefix):\n",
    "            peptide = peptide[len(prefix):]\n",
    "    if \"_\" not in peptide:\n",
    "        return utils.convert_to_one_letter_codes(peptide)\n",
    "    return peptide.replace(\"_\", \"-\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Ramachandran Plots"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_ramachandran_contour(results: Dict[str, Any], dihedral_index: int, ax: Optional[plt.Axes] = None) -> plt.Axes:\n",
    "    \"\"\"Draw a Ramachandran (phi/psi) PMF contour plot for one dihedral pair.\n",
    "\n",
    "    Args:\n",
    "        results: dict with \"pmf\" (indexed by ``dihedral_index``, each entry a 2D grid)\n",
    "            and \"xedges\"/\"yedges\" bin edges; the lower edges are used as grid coordinates.\n",
    "        dihedral_index: which dihedral pair of ``results[\"pmf\"]`` to draw.\n",
    "        ax: axes to draw on; a new 10x10 figure is created when None.\n",
    "\n",
    "    Returns:\n",
    "        The axes that were drawn on.\n",
    "    \"\"\"\n",
    "    if ax is None:\n",
    "        _, ax = plt.subplots(figsize=(10, 10))\n",
    "\n",
    "    pmf, xedges, yedges = results[\"pmf\"], results[\"xedges\"], results[\"yedges\"]\n",
    "    # NOTE(review): contourf treats the grid's rows as y; confirm the PMF is stored as (psi, phi).\n",
    "    ax.contourf(xedges[:-1], yedges[:-1], pmf[dihedral_index], cmap=\"viridis\", levels=50)\n",
    "    ax.contour(xedges[:-1], yedges[:-1], pmf[dihedral_index], colors=\"white\", linestyles=\"solid\", levels=10, linewidths=0.25)\n",
    "\n",
    "    ax.set_aspect(\"equal\", adjustable=\"box\")\n",
    "    # Raw strings: \"\\p\" is an invalid escape sequence (SyntaxWarning on Python 3.12+).\n",
    "    ax.set_xlabel(r\"$\\phi$\")\n",
    "    ax.set_ylabel(r\"$\\psi$\")\n",
    "\n",
    "    # Outer ticks sit just inside +/-pi, presumably to stay within the plotted grid range.\n",
    "    tick_eps = 0.1\n",
    "    ticks = [-np.pi + tick_eps, -np.pi / 2, 0, np.pi / 2, np.pi - tick_eps]\n",
    "    tick_labels = [r\"$-\\pi$\", r\"$-\\pi/2$\", r\"$0$\", r\"$\\pi/2$\", r\"$\\pi$\"]\n",
    "    ax.set_xticks(ticks, tick_labels)\n",
    "    ax.set_yticks(ticks, tick_labels)\n",
    "    return ax\n",
    "\n",
    "\n",
    "def get_num_dihedrals(experiment: str, pmf_type: str) -> int:\n",
    "    \"\"\"Return how many phi/psi pairs to plot for ``experiment``.\n",
    "\n",
    "    pmf_type:\n",
    "        \"internal\" for psi_2 - phi_2, psi_3 - phi_3, etc.\n",
    "        \"all\" for psi_1 - phi_2, psi_2 - phi_3, etc. (one more pair than \"internal\").\n",
    "\n",
    "    Raises:\n",
    "        ValueError: for an unknown ``pmf_type``, or an experiment name without\n",
    "            \"2AA\", \"4AA\" or \"5AA\".\n",
    "    \"\"\"\n",
    "    if pmf_type not in (\"internal\", \"all\"):\n",
    "        raise ValueError(f\"Invalid pmf_type: {pmf_type}\")\n",
    "\n",
    "    if experiment == \"Our_2AA\":\n",
    "        num_dihedrals = 1\n",
    "    elif \"2AA\" in experiment:\n",
    "        num_dihedrals = 0\n",
    "    elif \"4AA\" in experiment:\n",
    "        num_dihedrals = 2\n",
    "    elif \"5AA\" in experiment:\n",
    "        num_dihedrals = 3\n",
    "    else:\n",
    "        # Fail with a clear message instead of an UnboundLocalError below.\n",
    "        raise ValueError(f\"Cannot infer the number of dihedrals for experiment: {experiment}\")\n",
    "\n",
    "    if pmf_type == \"all\":\n",
    "        num_dihedrals += 1\n",
    "\n",
    "    return num_dihedrals"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Ramachandran Plots against Reference"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pmf_type = \"all\"\n",
    "num_dihedrals = get_num_dihedrals(experiment, pmf_type)\n",
    "# Center a title over a group of panels: mid-panel for an odd count, or the left\n",
    "# edge of the right-of-center panel for an even count.\n",
    "label_offset = 0.0 if num_dihedrals % 2 == 0 else 0.5\n",
    "\n",
    "# Column layout: [reference panels] [spacer] [sampled-trajectory panels].\n",
    "# Each entry: (key in results[\"PMFs\"], column title, index of the group's first column).\n",
    "column_groups = [\n",
    "    (\"ref_traj\", format_traj_name(\"ref_traj\"), 0),\n",
    "    (\"traj\", format_traj_name(\"traj\"), num_dihedrals + 1),\n",
    "]\n",
    "\n",
    "ones = list(np.ones(num_dihedrals))\n",
    "fig, axs = plt.subplots(\n",
    "    len(sampled_results_df), 2 * num_dihedrals + 1,\n",
    "    figsize=(8 * num_dihedrals, 4 * len(sampled_results_df)),\n",
    "    gridspec_kw={\n",
    "        'width_ratios': ones + [0.1] + ones,\n",
    "        'hspace': 0.1\n",
    "    },\n",
    "    squeeze=False,  # keep axs 2D even when only one peptide is plotted\n",
    ")\n",
    "for i, row in sampled_results_df.iterrows():\n",
    "    peptide = row[\"peptide\"]\n",
    "\n",
    "    for pmf_key, title, first_col in column_groups:\n",
    "        for j in range(num_dihedrals):\n",
    "            plot_ramachandran_contour(row[\"results\"][\"PMFs\"][pmf_key][f\"pmf_{pmf_type}\"], j, axs[i, first_col + j])\n",
    "\n",
    "    # Peptide name to the right of the row's last panel.\n",
    "    axs[i, -1].text(\n",
    "        1.1,\n",
    "        0.5,\n",
    "        format_peptide_name(peptide),\n",
    "        rotation=90,\n",
    "        verticalalignment=\"center\",\n",
    "        horizontalalignment=\"center\",\n",
    "        transform=axs[i, -1].transAxes,\n",
    "        fontsize=18,\n",
    "    )\n",
    "\n",
    "    # Hide the spacer column between the two groups.\n",
    "    axs[i, num_dihedrals].axis(\"off\")\n",
    "\n",
    "    # Only the bottom row keeps x ticks and only the first column keeps y ticks.\n",
    "    if i != len(axs) - 1:\n",
    "        for j in range(len(axs[i])):\n",
    "            axs[i, j].set_xticks([])\n",
    "            axs[i, j].set_xlabel(\"\")\n",
    "\n",
    "    for j in range(1, len(axs[i])):\n",
    "        axs[i, j].set_yticks([])\n",
    "        axs[i, j].set_ylabel(\"\")\n",
    "\n",
    "# Column titles above the top row, drawn once rather than once per peptide row.\n",
    "for pmf_key, title, first_col in column_groups:\n",
    "    title_ax = axs[0, first_col + num_dihedrals // 2]\n",
    "    title_ax.text(\n",
    "        label_offset,\n",
    "        1.1,\n",
    "        title,\n",
    "        horizontalalignment=\"center\",\n",
    "        verticalalignment=\"center\",\n",
    "        transform=title_ax.transAxes,\n",
    "        fontsize=22,\n",
    "    )\n",
    "\n",
    "plt.subplots_adjust(hspace=0.06, wspace=0.04)\n",
    "plt.savefig(os.path.join(output_dir, \"ramachandran_contours.pdf\"), dpi=300)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Ramachandran Plots against Reference (Shortened)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pmf_type = \"all\"\n",
    "num_dihedrals = get_num_dihedrals(experiment, pmf_type)\n",
    "# Center a title over a group of panels: mid-panel for an odd count, or the left\n",
    "# edge of the right-of-center panel for an even count.\n",
    "label_offset = 0.0 if num_dihedrals % 2 == 0 else 0.5\n",
    "\n",
    "# Column layout: [reference] [spacer] [sampled trajectory] [spacer] [reference, 100x shorter].\n",
    "# Each entry: (key in results[\"PMFs\"], column title, index of the group's first column).\n",
    "column_groups = [\n",
    "    (\"ref_traj\", format_traj_name(\"ref_traj\"), 0),\n",
    "    (\"traj\", format_traj_name(\"traj\"), num_dihedrals + 1),\n",
    "    (\"ref_traj_100x\", format_traj_name(\"ref_traj_100x\"), 2 * num_dihedrals + 2),\n",
    "]\n",
    "spacer_cols = [num_dihedrals, 2 * num_dihedrals + 1]\n",
    "\n",
    "ones = list(np.ones(num_dihedrals))\n",
    "fig, axs = plt.subplots(\n",
    "    len(sampled_results_df), 3 * num_dihedrals + 2,\n",
    "    figsize=(12 * num_dihedrals, 4 * len(sampled_results_df)),\n",
    "    gridspec_kw={'width_ratios': ones + [0.1] + ones + [0.1] + ones, 'hspace': 0.1},\n",
    "    squeeze=False,  # keep axs 2D even when only one peptide is plotted\n",
    ")\n",
    "\n",
    "for i, row in sampled_results_df.iterrows():\n",
    "    peptide = row[\"peptide\"]\n",
    "\n",
    "    for pmf_key, title, first_col in column_groups:\n",
    "        for j in range(num_dihedrals):\n",
    "            plot_ramachandran_contour(row[\"results\"][\"PMFs\"][pmf_key][f\"pmf_{pmf_type}\"], j, axs[i, first_col + j])\n",
    "\n",
    "    # Peptide name to the right of the row's last panel.\n",
    "    axs[i, -1].text(\n",
    "        1.1,\n",
    "        0.5,\n",
    "        format_peptide_name(peptide),\n",
    "        rotation=90,\n",
    "        verticalalignment=\"center\",\n",
    "        horizontalalignment=\"center\",\n",
    "        transform=axs[i, -1].transAxes,\n",
    "        fontsize=18,\n",
    "    )\n",
    "\n",
    "    # Hide the spacer columns between groups.\n",
    "    for spacer_col in spacer_cols:\n",
    "        axs[i, spacer_col].axis(\"off\")\n",
    "\n",
    "    # Only the bottom row keeps x ticks and only the first column keeps y ticks.\n",
    "    if i != len(axs) - 1:\n",
    "        for j in range(len(axs[i])):\n",
    "            axs[i, j].set_xticks([])\n",
    "            axs[i, j].set_xlabel(\"\")\n",
    "\n",
    "    for j in range(1, len(axs[i])):\n",
    "        axs[i, j].set_yticks([])\n",
    "        axs[i, j].set_ylabel(\"\")\n",
    "\n",
    "# Column titles above the top row, drawn once rather than once per peptide row.\n",
    "for pmf_key, title, first_col in column_groups:\n",
    "    title_ax = axs[0, first_col + num_dihedrals // 2]\n",
    "    title_ax.text(\n",
    "        label_offset,\n",
    "        1.1,\n",
    "        title,\n",
    "        horizontalalignment=\"center\",\n",
    "        verticalalignment=\"center\",\n",
    "        transform=title_ax.transAxes,\n",
    "        fontsize=22,\n",
    "    )\n",
    "\n",
    "plt.subplots_adjust(hspace=0.06, wspace=0.04)\n",
    "plt.savefig(os.path.join(output_dir, \"ramachandran_contours_with_shortened_reference.pdf\"), dpi=300)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# For experiment \"Timewarp_2AA\", plot the TBG results as well.\n",
    "if experiment == \"Timewarp_2AA\":\n",
    "    pmf_type = \"all\"\n",
    "    num_dihedrals = get_num_dihedrals(experiment, pmf_type)\n",
    "    # Center a title over a group of panels: mid-panel for an odd count, or the left\n",
    "    # edge of the right-of-center panel for an even count.\n",
    "    label_offset = 0.0 if num_dihedrals % 2 == 0 else 0.5\n",
    "\n",
    "    # Column layout: [reference] [sampled] [TBG] [reference, 100x shorter], separated by spacers.\n",
    "    # Each entry: (key in results[\"PMFs\"], column title, index of the group's first column).\n",
    "    column_groups = [\n",
    "        (\"ref_traj\", format_traj_name(\"ref_traj\"), 0),\n",
    "        (\"traj\", format_traj_name(\"traj\"), num_dihedrals + 1),\n",
    "        (\"TBG\", \"TBG\", 2 * num_dihedrals + 2),\n",
    "        (\"ref_traj_100x\", format_traj_name(\"ref_traj_100x\"), 3 * num_dihedrals + 3),\n",
    "    ]\n",
    "    spacer_cols = [num_dihedrals, 2 * num_dihedrals + 1, 3 * num_dihedrals + 2]\n",
    "\n",
    "    ones = list(np.ones(num_dihedrals))\n",
    "    fig, axs = plt.subplots(\n",
    "        len(sampled_results_df), 4 * num_dihedrals + 3,\n",
    "        figsize=(12 * num_dihedrals, 4 * len(sampled_results_df)),\n",
    "        gridspec_kw={'width_ratios': ones + [0.1] + ones + [0.1] + ones + [0.1] + ones, 'hspace': 0.1},\n",
    "        squeeze=False,  # keep axs 2D even when only one peptide is plotted\n",
    "    )\n",
    "\n",
    "    for i, row in sampled_results_df.iterrows():\n",
    "        peptide = row[\"peptide\"]\n",
    "\n",
    "        for pmf_key, title, first_col in column_groups:\n",
    "            for j in range(num_dihedrals):\n",
    "                plot_ramachandran_contour(row[\"results\"][\"PMFs\"][pmf_key][f\"pmf_{pmf_type}\"], j, axs[i, first_col + j])\n",
    "\n",
    "        # Peptide name to the right of the row's last panel.\n",
    "        axs[i, -1].text(\n",
    "            1.1,\n",
    "            0.5,\n",
    "            format_peptide_name(peptide),\n",
    "            rotation=90,\n",
    "            verticalalignment=\"center\",\n",
    "            horizontalalignment=\"center\",\n",
    "            transform=axs[i, -1].transAxes,\n",
    "            fontsize=18,\n",
    "        )\n",
    "\n",
    "        # Hide the spacer columns between groups.\n",
    "        for spacer_col in spacer_cols:\n",
    "            axs[i, spacer_col].axis(\"off\")\n",
    "\n",
    "        # Only the bottom row keeps x ticks and only the first column keeps y ticks.\n",
    "        if i != len(axs) - 1:\n",
    "            for j in range(len(axs[i])):\n",
    "                axs[i, j].set_xticks([])\n",
    "                axs[i, j].set_xlabel(\"\")\n",
    "\n",
    "        for j in range(1, len(axs[i])):\n",
    "            axs[i, j].set_yticks([])\n",
    "            axs[i, j].set_ylabel(\"\")\n",
    "\n",
    "    # Column titles above the top row, drawn once rather than once per peptide row.\n",
    "    for pmf_key, title, first_col in column_groups:\n",
    "        title_ax = axs[0, first_col + num_dihedrals // 2]\n",
    "        title_ax.text(\n",
    "            label_offset,\n",
    "            1.2,\n",
    "            title,\n",
    "            horizontalalignment=\"center\",\n",
    "            verticalalignment=\"center\",\n",
    "            transform=title_ax.transAxes,\n",
    "            fontsize=22,\n",
    "        )\n",
    "\n",
    "    plt.subplots_adjust(hspace=0.06, wspace=0.04)\n",
    "    plt.savefig(os.path.join(output_dir, \"ramachandran_contours_with_TBG.pdf\"), dpi=300)\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Ramachandran Plots for a Single Peptide"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pmf_type = \"all\"\n",
    "num_dihedrals = get_num_dihedrals(experiment, pmf_type)\n",
    "\n",
    "# Plot the first sampled peptide. Take its row explicitly instead of relying on a `row`\n",
    "# left over from another cell's loop (that would be the *last* row, not this peptide).\n",
    "row = sampled_results_df.iloc[0]\n",
    "peptide = row[\"peptide\"]\n",
    "pmfs = row[\"results\"][\"PMFs\"]\n",
    "\n",
    "# Top row: reference; bottom row: sampled trajectory; one column per dihedral pair.\n",
    "fig, axs = plt.subplots(2, num_dihedrals, figsize=(4 * num_dihedrals, 8), squeeze=False)\n",
    "for j in range(num_dihedrals):\n",
    "    plot_ramachandran_contour(pmfs[\"ref_traj\"][f\"pmf_{pmf_type}\"], j, axs[0, j])\n",
    "    plot_ramachandran_contour(pmfs[\"traj\"][f\"pmf_{pmf_type}\"], j, axs[1, j])\n",
    "\n",
    "# Keep y ticks only on the first column and x ticks only on the bottom row.\n",
    "for i in range(2):\n",
    "    for j in range(1, len(axs[i])):\n",
    "        axs[i, j].set_yticks([])\n",
    "        axs[i, j].set_ylabel(\"\")\n",
    "\n",
    "for j in range(len(axs[0])):\n",
    "    axs[0, j].set_xticks([])\n",
    "    axs[0, j].set_xlabel(\"\")\n",
    "\n",
    "# Row labels to the right of the last column.\n",
    "for row_index, traj_key in enumerate([\"ref_traj\", \"traj\"]):\n",
    "    axs[row_index, -1].text(\n",
    "        1.1,\n",
    "        0.5,\n",
    "        format_traj_name(traj_key),\n",
    "        rotation=90,\n",
    "        verticalalignment=\"center\",\n",
    "        horizontalalignment=\"center\",\n",
    "        transform=axs[row_index, -1].transAxes,\n",
    "    )\n",
    "fig.suptitle(format_peptide_name(peptide))\n",
    "plt.subplots_adjust(hspace=0.06, wspace=0.04)\n",
    "plt.savefig(os.path.join(output_dir, f\"ramachandran_contours_{peptide}.pdf\"), dpi=300)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Feature Histograms"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_rows = len(sampled_results_df)\n",
    "fig, axs = plt.subplots(nrows=num_rows, ncols=2, figsize=(14, 4 * num_rows), squeeze=False)\n",
    "# Left column: reference trajectory; right column: sampled trajectory.\n",
    "traj_keys = [\"ref_traj\", \"traj\"]\n",
    "for i, row in sampled_results_df.iterrows():\n",
    "    peptide = row[\"peptide\"]\n",
    "    feats = row[\"results\"][\"featurization\"]\n",
    "    histograms = row[\"results\"][\"feature_histograms\"]\n",
    "\n",
    "    for col, key in enumerate(traj_keys):\n",
    "        torsion_hists = histograms[key][\"torsions\"]\n",
    "        pyemma_helper.plot_feature_histograms(\n",
    "            torsion_hists[\"histograms\"],\n",
    "            torsion_hists[\"edges\"],\n",
    "            feature_labels=feats[key][\"feats\"][\"torsions\"].describe(),\n",
    "            ax=axs[i, col],\n",
    "        )\n",
    "\n",
    "    # Peptide name to the right of the row.\n",
    "    axs[i, -1].text(\n",
    "        1.1,\n",
    "        0.5,\n",
    "        format_peptide_name(peptide),\n",
    "        rotation=90,\n",
    "        verticalalignment=\"center\",\n",
    "        horizontalalignment=\"center\",\n",
    "        transform=axs[i, -1].transAxes,\n",
    "    )\n",
    "\n",
    "for col, key in enumerate(traj_keys):\n",
    "    axs[0, col].set_title(format_traj_name(key))\n",
    "plt.tight_layout()\n",
    "plt.savefig(os.path.join(output_dir, \"feature_histograms.pdf\"), dpi=300)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot a random subset of at most 10 distance features per peptide, using the same\n",
    "# subset for the reference and the sampled trajectory.\n",
    "# NOTE(review): assumes both trajectories share the same distance features in the same order.\n",
    "# A seeded generator keeps the subset (and thus the saved figure) reproducible across re-runs.\n",
    "distance_rng = np.random.default_rng(42)\n",
    "num_rows = len(sampled_results_df)\n",
    "fig, axs = plt.subplots(nrows=num_rows, ncols=2, figsize=(14, 4 * num_rows), squeeze=False)\n",
    "for i, row in sampled_results_df.iterrows():\n",
    "    peptide = row[\"peptide\"]\n",
    "\n",
    "    feats = row[\"results\"][\"featurization\"]\n",
    "    histograms = row[\"results\"][\"feature_histograms\"]\n",
    "\n",
    "    num_hists = len(histograms[\"ref_traj\"][\"distances\"][\"histograms\"])\n",
    "    indices = distance_rng.choice(num_hists, size=min(num_hists, 10), replace=False)\n",
    "\n",
    "    # Left column: reference trajectory; right column: sampled trajectory.\n",
    "    for col, key in enumerate([\"ref_traj\", \"traj\"]):\n",
    "        distance_hists = histograms[key][\"distances\"]\n",
    "        # Describe the features once, not once per selected index.\n",
    "        distance_labels = feats[key][\"feats\"][\"distances\"].describe()\n",
    "        pyemma_helper.plot_feature_histograms(\n",
    "            distance_hists[\"histograms\"][indices],\n",
    "            distance_hists[\"edges\"][indices],\n",
    "            feature_labels=[distance_labels[k] for k in indices],\n",
    "            ax=axs[i, col],\n",
    "        )\n",
    "\n",
    "    # Use the reference panel's limits for the sampled panel so they are directly comparable.\n",
    "    axs[i, 1].set_xlim(axs[i, 0].get_xlim())\n",
    "    axs[i, 1].set_ylim(axs[i, 0].get_ylim())\n",
    "\n",
    "    axs[i, -1].text(\n",
    "        1.1,\n",
    "        0.5,\n",
    "        format_peptide_name(peptide),\n",
    "        rotation=90,\n",
    "        verticalalignment=\"center\",\n",
    "        horizontalalignment=\"center\",\n",
    "        transform=axs[i, -1].transAxes,\n",
    "    )\n",
    "\n",
    "axs[0, 0].set_title(format_traj_name(\"ref_traj\"))\n",
    "axs[0, 1].set_title(format_traj_name(\"traj\"))\n",
    "plt.tight_layout()\n",
    "plt.savefig(os.path.join(output_dir, \"distance_histograms.pdf\"), dpi=300)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Torsion Angle Decorrelation Times"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pool per-torsion decorrelation times over all peptides in results_df, split into\n",
    "# backbone (PHI/PSI) and sidechain (CHI) torsions. Torsions where either time is NaN\n",
    "# are counted in total_count but left out of the arrays.\n",
    "torsion_types = (\"backbone\", \"sidechain\")\n",
    "all_ref_decorrelation_times = {torsion_type: [] for torsion_type in torsion_types}\n",
    "all_traj_decorrelation_times = {torsion_type: [] for torsion_type in torsion_types}\n",
    "total_count = dict.fromkeys(torsion_types, 0)\n",
    "\n",
    "for peptide_results in results_df[\"results\"]:\n",
    "    decorrelations = peptide_results[\"torsion_decorrelations\"]\n",
    "    for feat in decorrelations:\n",
    "        ref_time = decorrelations[feat][\"ref_traj_decorrelation_time\"]\n",
    "        traj_time = decorrelations[feat][\"traj_decorrelation_time\"]\n",
    "\n",
    "        if \"PHI\" in feat or \"PSI\" in feat:\n",
    "            torsion_type = \"backbone\"\n",
    "        elif \"CHI\" in feat:\n",
    "            torsion_type = \"sidechain\"\n",
    "        else:\n",
    "            raise ValueError(f\"Unknown torsion type: {feat}\")\n",
    "        total_count[torsion_type] += 1\n",
    "\n",
    "        if not (np.isnan(ref_time) or np.isnan(traj_time)):\n",
    "            all_ref_decorrelation_times[torsion_type].append(ref_time)\n",
    "            all_traj_decorrelation_times[torsion_type].append(traj_time)\n",
    "\n",
    "all_ref_decorrelation_times = {k: np.asarray(v) for k, v in all_ref_decorrelation_times.items()}\n",
    "all_traj_decorrelation_times = {k: np.asarray(v) for k, v in all_traj_decorrelation_times.items()}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Backbone Torsion Angle Decorrelation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "backbone_ref_times = all_ref_decorrelation_times[\"backbone\"]\n",
    "backbone_traj_times = all_traj_decorrelation_times[\"backbone\"]\n",
    "print(f\"Number of backbone torsions with valid decorrelation times: {len(backbone_ref_times)} out of {total_count['backbone']}\")\n",
    "\n",
    "# Scatter plot of decorrelation times: reference (x) vs. sampled trajectory (y), log-log.\n",
    "# Explicit figure: with backends where plt.show() does not close figures (e.g. 'agg'),\n",
    "# implicit pyplot calls would draw onto the previous cell's axes.\n",
    "fig, ax = plt.subplots()\n",
    "ax.scatter(backbone_ref_times, backbone_traj_times, alpha=0.3, edgecolors=\"none\", color='tab:blue')\n",
    "ax.set_xscale(\"log\")\n",
    "ax.set_yscale(\"log\")\n",
    "ax.set_xlabel(format_traj_name(\"ref_traj\"))\n",
    "ax.set_ylabel(format_traj_name(\"traj\"))\n",
    "ax.set_title(\"Decorrelation Times of Backbone Torsions\")\n",
    "\n",
    "# Linear fit in log-log space; linregress needs at least two points, so skip R² otherwise.\n",
    "if len(backbone_ref_times) >= 2:\n",
    "    fit = scipy.stats.linregress(np.log(backbone_ref_times), np.log(backbone_traj_times))\n",
    "    # To overlay the fitted line over the 5th-95th percentile of x, uncomment:\n",
    "    # x_line = np.percentile(backbone_ref_times, [5, 95])\n",
    "    # y_line = np.exp(fit.slope * np.log(x_line) + fit.intercept)\n",
    "    # ax.plot(x_line, y_line, color='tab:blue', linestyle='--')\n",
    "    ax.text(0.65, 0.90, f'R² = {fit.rvalue**2:.3f}', transform=ax.transAxes, color='tab:blue')\n",
    "\n",
    "fig.tight_layout()\n",
    "fig.savefig(os.path.join(output_dir, \"backbone_torsion_decorrelation_times.pdf\"), dpi=300)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(f\"Number of backbone torsions with valid decorrelation times: {len(all_ref_decorrelation_times['backbone'])} out of {total_count['backbone']}\")\n",
    "\n",
    "# Speedup = reference / sampled decorrelation time (> 1 when the sampled trajectory's time is the smaller one).\n",
    "backbone_torsion_speedups = all_ref_decorrelation_times[\"backbone\"] / all_traj_decorrelation_times[\"backbone\"]\n",
    "\n",
    "if backbone_torsion_speedups.size == 0:\n",
    "    print(\"No valid backbone torsion decorrelation times; skipping the speedup histogram.\")\n",
    "else:\n",
    "    # 20 log-spaced bins spanning the observed range.\n",
    "    bins = np.logspace(np.log10(np.min(backbone_torsion_speedups)), np.log10(np.max(backbone_torsion_speedups)), 21)\n",
    "\n",
    "    # Explicit figure: with backends where plt.show() does not close figures (e.g. 'agg'),\n",
    "    # implicit pyplot calls would draw onto the previous cell's axes.\n",
    "    fig, ax = plt.subplots()\n",
    "    ax.hist(backbone_torsion_speedups, bins=bins)\n",
    "    ax.set_xscale(\"log\")\n",
    "    ax.set_xlabel(\"Speedup Factor\")\n",
    "    ax.set_xticks([1e-3, 1e-2, 1e-1, 1e0, 1e1, 1e2, 1e3])\n",
    "    ax.set_ylabel(\"Frequency\")\n",
    "    fig.suptitle(\"Speedups of Backbone Torsion Decorrelation Times\")\n",
    "    fig.tight_layout()\n",
    "    fig.savefig(os.path.join(output_dir, \"backbone_torsion_speedups.pdf\"), dpi=300)\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Sidechain Torsion Angle Decorrelation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sidechain_ref_times = all_ref_decorrelation_times[\"sidechain\"]\n",
    "sidechain_traj_times = all_traj_decorrelation_times[\"sidechain\"]\n",
    "print(f\"Number of sidechain torsions with valid decorrelation times: {len(sidechain_ref_times)} out of {total_count['sidechain']}\")\n",
    "\n",
    "# Set to True to overlay the fitted log-log line between the 5th and 95th percentiles of the reference times.\n",
    "show_fit_line = False\n",
    "\n",
    "# Scatter plot of decorrelation times: one point per sidechain torsion, reference (x) vs. sampled (y).\n",
    "plt.scatter(sidechain_ref_times, sidechain_traj_times, alpha=0.3, edgecolors=\"none\", color='tab:orange')\n",
    "plt.xscale(\"log\")\n",
    "plt.yscale(\"log\")\n",
    "plt.xlabel(format_traj_name(\"ref_traj\"))\n",
    "plt.ylabel(format_traj_name(\"traj\"))\n",
    "plt.title(\"Decorrelation Times of Sidechain Torsions\")\n",
    "\n",
    "# Linear fit in log-log space, so the reported R² measures agreement of the log decorrelation times.\n",
    "slope, intercept, r_value, p_value, std_err = scipy.stats.linregress(\n",
    "    np.log(sidechain_ref_times), np.log(sidechain_traj_times)\n",
    ")\n",
    "\n",
    "if show_fit_line:\n",
    "    # Evaluate the fit in log space at the percentile endpoints and map back to the original scale.\n",
    "    x_line = np.percentile(sidechain_ref_times, [5, 95])\n",
    "    y_line = np.exp(slope * np.log(x_line) + intercept)\n",
    "    plt.plot(x_line, y_line, color='tab:orange', linestyle='--')\n",
    "plt.text(0.65, 0.90, f'R² = {r_value**2:.3f}', transform=plt.gca().transAxes, color='tab:orange')\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.savefig(os.path.join(output_dir, \"sidechain_torsion_decorrelation_times.pdf\"), dpi=300)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Speedup = reference decorrelation time / sampled decorrelation time;\n",
    "# values above 1 mean the sampled trajectory decorrelates faster than the reference.\n",
    "sidechain_ref_times = all_ref_decorrelation_times[\"sidechain\"]\n",
    "sidechain_traj_times = all_traj_decorrelation_times[\"sidechain\"]\n",
    "print(f\"Number of sidechain torsions with valid decorrelation times: {len(sidechain_ref_times)} out of {total_count['sidechain']}\")\n",
    "\n",
    "sidechain_torsion_speedups = sidechain_ref_times / sidechain_traj_times\n",
    "\n",
    "# 20 log-spaced bins covering the full observed range of speedups.\n",
    "speedup_range = (np.min(sidechain_torsion_speedups), np.max(sidechain_torsion_speedups))\n",
    "bins = np.logspace(*np.log10(speedup_range), 21)\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "ax.hist(sidechain_torsion_speedups, bins=bins)\n",
    "ax.set_xscale(\"log\")\n",
    "ax.set_xlabel(\"Speedup Factor\")\n",
    "ax.set_ylabel(\"Frequency\")\n",
    "fig.suptitle(\"Speedups of Sidechain Torsion Decorrelation Times\")\n",
    "fig.tight_layout()\n",
    "fig.savefig(os.path.join(output_dir, \"sidechain_torsion_speedups.pdf\"), dpi=300)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Jensen-Shannon Divergences (JSD)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_JSD_results(quantity: str, name: str, key: str):\n",
    "    \"\"\"Helper to load final JSD results.\n",
    "\n",
    "    Looks up row[\"results\"][key][name][quantity] for every row of `results_df`,\n",
    "    skipping rows where any of those keys is missing, and returns the collected\n",
    "    values as a numpy array (one entry per row that has the result).\n",
    "    \"\"\"\n",
    "    collected = []\n",
    "    for _, row in results_df.iterrows():\n",
    "        try:\n",
    "            collected.append(row[\"results\"][key][name][quantity])\n",
    "        except KeyError:\n",
    "            # This system has no result for this (key, name, quantity); skip it.\n",
    "            continue\n",
    "    return np.asarray(collected)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Top-level results key under which each group of JSD quantities is stored.\n",
    "JSD_quantities_by_results_key = {\n",
    "    \"JSD_torsions\": [\"JSD_backbone_torsions\", \"JSD_sidechain_torsions\", \"JSD_all_torsions\"],\n",
    "    \"JSD_TICA\": [\"JSD_TICA-0\", \"JSD_TICA-0,1\"],\n",
    "    \"JSD_MSM\": [\"JSD_metastable_probs\"],\n",
    "}\n",
    "\n",
    "traj_names = [\"traj\", \"ref_traj\", \"ref_traj_10x\", \"ref_traj_100x\"]\n",
    "if experiment == \"Timewarp_2AA\":\n",
    "    # TBG is only included as a comparison for this experiment.\n",
    "    traj_names.append(\"TBG\")\n",
    "\n",
    "# Maps JSD quantity -> trajectory name -> array with one JSD per system that has that result.\n",
    "JSD_final_results = {}\n",
    "for results_key, quantities in JSD_quantities_by_results_key.items():\n",
    "    for quantity in quantities:\n",
    "        JSD_final_results[quantity] = {\n",
    "            name: get_JSD_results(quantity, name, results_key)\n",
    "            for name in traj_names\n",
    "        }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "JSD_final_results_df = pd.DataFrame.from_dict(JSD_final_results)\n",
    "\n",
    "\n",
    "def _summarize_cells(statistic):\n",
    "    \"\"\"Return a cell-wise function applying `statistic` to array cells (other cells map to None).\"\"\"\n",
    "    return lambda cell: statistic(cell) if isinstance(cell, np.ndarray) else None\n",
    "\n",
    "\n",
    "# Columns are JSD quantities and rows are trajectory names; each cell holds the per-system JSD array.\n",
    "means_series = JSD_final_results_df.map(_summarize_cells(np.mean))\n",
    "stds_series = JSD_final_results_df.map(_summarize_cells(np.std))\n",
    "\n",
    "means_series"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save the table of mean JSDs (rows: trajectory names, columns: JSD quantities).\n",
    "JSDs_csv_path = os.path.join(output_dir, \"JSDs.csv\")\n",
    "means_series.to_csv(JSDs_csv_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-system JSDs of the metastable state probabilities for the sampled trajectory.\n",
    "JSD_MSM = JSD_final_results[\"JSD_metastable_probs\"][\"traj\"]\n",
    "\n",
    "plt.hist(JSD_MSM)\n",
    "plt.title(\"Jensen-Shannon Distances of Metastable State Probabilities\")\n",
    "plt.xlabel(\"JSD\")\n",
    "# Ticks every 0.1, up to just past the largest observed JSD.\n",
    "plt.xticks(np.arange(0.1, JSD_MSM.max() + 0.1, 0.1))\n",
    "plt.ylabel(\"Frequency\")\n",
    "plt.ticklabel_format(useOffset=False, style=\"plain\")\n",
    "plt.tight_layout()\n",
    "plt.savefig(os.path.join(output_dir, \"jsd_metastable_probs.pdf\"), dpi=300)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### JSD against Trajectory Progress"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_JSD_results_against_time(quantity: str, name: str, key: str) -> Dict[str, Any]:\n",
    "    \"\"\"Helper to load JSD vs time results.\n",
    "\n",
    "    For every row of `results_df` that has `key` in its results, reads the per-step\n",
    "    JSD `quantity` of trajectory `name` from row[\"results\"][key][name], a mapping\n",
    "    step -> {quantity: JSD, ...}.\n",
    "\n",
    "    Returns a dict with:\n",
    "        \"steps\": array of the recorded steps (must be identical across rows),\n",
    "        \"JSDs\": the per-row JSD arrays stacked along a new first axis,\n",
    "        \"progress\": steps divided by the last step.\n",
    "\n",
    "    Raises ValueError if no row has `key`, or if rows disagree on their steps.\n",
    "    \"\"\"\n",
    "    steps = None\n",
    "    JSDs_per_row = []\n",
    "\n",
    "    for _, row in results_df.iterrows():\n",
    "        try:\n",
    "            results = row[\"results\"][key]\n",
    "        except KeyError:\n",
    "            continue\n",
    "\n",
    "        JSD_by_step = results[name]\n",
    "        row_steps = np.asarray(list(JSD_by_step.keys()))\n",
    "        if steps is None:\n",
    "            steps = row_steps\n",
    "        elif row_steps.shape != steps.shape or not np.allclose(steps, row_steps):\n",
    "            raise ValueError(\n",
    "                f\"Steps for {name!r} under {key!r} differ between systems: {steps} vs {row_steps}\"\n",
    "            )\n",
    "\n",
    "        JSDs_per_row.append(np.asarray([v[quantity] for v in JSD_by_step.values()]))\n",
    "\n",
    "    if steps is None:\n",
    "        raise ValueError(f\"No results found under key {key!r}.\")\n",
    "\n",
    "    return {\n",
    "        \"steps\": steps,\n",
    "        \"JSDs\": np.stack(JSDs_per_row),\n",
    "        # Assumes steps are stored in increasing order, so the last one marks the end of the trajectory.\n",
    "        \"progress\": steps / steps[-1],\n",
    "    }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Every JSD quantity gets a slot, but only the torsion JSDs are currently loaded against time;\n",
    "# the TICA and MSM entries stay empty.\n",
    "JSD_results = {\n",
    "    quantity: {}\n",
    "    for quantity in [\n",
    "        \"JSD_backbone_torsions\",\n",
    "        \"JSD_sidechain_torsions\",\n",
    "        \"JSD_all_torsions\",\n",
    "        \"JSD_TICA-0\",\n",
    "        \"JSD_TICA-0,1\",\n",
    "        \"JSD_metastable_probs\",\n",
    "    ]\n",
    "}\n",
    "traj_names = [\"traj\", \"ref_traj\", \"ref_traj_10x\", \"ref_traj_100x\"]\n",
    "\n",
    "# Results key -> JSD quantities stored under it. To also load the TICA / MSM curves, add:\n",
    "#   \"JSD_TICA_against_time\": [\"JSD_TICA-0\", \"JSD_TICA-0,1\"],\n",
    "#   \"JSD_MSM_against_time\": [\"JSD_metastable_probs\"],\n",
    "against_time_quantities = {\n",
    "    \"JSD_torsions_against_time\": [\"JSD_backbone_torsions\", \"JSD_sidechain_torsions\", \"JSD_all_torsions\"],\n",
    "}\n",
    "\n",
    "for results_key, quantities in against_time_quantities.items():\n",
    "    for quantity in quantities:\n",
    "        for name in traj_names:\n",
    "            JSD_results[quantity][name] = get_JSD_results_against_time(\n",
    "                quantity, name, results_key\n",
    "            )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for quantity, JSD_by_traj in JSD_results.items():\n",
    "    # Quantities that were not loaded against time have no entries; skip them instead of\n",
    "    # showing an empty figure with an empty legend.\n",
    "    if not JSD_by_traj:\n",
    "        continue\n",
    "\n",
    "    for name, JSD_vs_time in JSD_by_traj.items():\n",
    "        # Mean and spread over systems (the first axis of the stacked JSD array).\n",
    "        mean = np.mean(JSD_vs_time[\"JSDs\"], axis=0)\n",
    "        std = np.std(JSD_vs_time[\"JSDs\"], axis=0)\n",
    "        progress = JSD_vs_time[\"progress\"]\n",
    "\n",
    "        # The sampled trajectory is always orange; the others follow the default color cycle.\n",
    "        color = \"tab:orange\" if name == \"traj\" else None\n",
    "        line, = plt.plot(progress, mean, label=format_traj_name(name), color=color)\n",
    "\n",
    "        # Shade one standard deviation around the mean, in the same color as the mean line.\n",
    "        plt.fill_between(progress, mean - std, mean + std, alpha=0.2, color=line.get_color())\n",
    "\n",
    "    plt.ylim(0, 1)\n",
    "    plt.legend(bbox_to_anchor=(1.05, 0.5), loc='center left')\n",
    "    plt.title(f\"JSD vs Trajectory Progress\\n{format_quantity(quantity)}\")\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### TICA Analysis"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_rows = len(sampled_results_df)\n",
    "fig, axs = plt.subplots(nrows=n_rows, ncols=2, figsize=(12, 3.5 * n_rows), squeeze=False)\n",
    "\n",
    "# Enumerate positions rather than using the DataFrame index labels, which need not be 0..n_rows-1.\n",
    "for i, (_, row) in enumerate(sampled_results_df.iterrows()):\n",
    "    peptide = row[\"peptide\"]\n",
    "    results = row[\"results\"][\"TICA_histograms\"]\n",
    "    ref_ax, traj_ax = axs[i, 0], axs[i, 1]\n",
    "\n",
    "    # Free energy over TICA-0,1: reference on the left, sampled trajectory on the right.\n",
    "    pyemma_helper.plot_free_energy(*results[\"ref_traj\"], cmap=\"plasma\", ax=ref_ax)\n",
    "    pyemma_helper.plot_free_energy(*results[\"traj\"], cmap=\"plasma\", ax=traj_ax)\n",
    "    for ax in (ref_ax, traj_ax):\n",
    "        ax.ticklabel_format(useOffset=False, style=\"plain\")\n",
    "    if i == 0:\n",
    "        ref_ax.set_title(format_traj_name(\"ref_traj\"))\n",
    "        traj_ax.set_title(format_traj_name(\"traj\"))\n",
    "\n",
    "    # Use the reference panel's limits for both panels so they are directly comparable.\n",
    "    traj_ax.set_xlim(ref_ax.get_xlim())\n",
    "    traj_ax.set_ylim(ref_ax.get_ylim())\n",
    "\n",
    "    # Peptide name as a rotated label to the right of the row.\n",
    "    traj_ax.text(\n",
    "        1.4,\n",
    "        0.5,\n",
    "        format_peptide_name(peptide),\n",
    "        rotation=90,\n",
    "        verticalalignment=\"center\",\n",
    "        horizontalalignment=\"center\",\n",
    "        transform=traj_ax.transAxes,\n",
    "    )\n",
    "\n",
    "plt.suptitle(\"TICA-0,1 Projections\", fontsize=\"x-large\")\n",
    "plt.tight_layout()\n",
    "plt.savefig(os.path.join(output_dir, \"tica_projections.pdf\"), dpi=300)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tica_0_speedups = []\n",
    "for _, row in results_df.iterrows():\n",
    "    results = row[\"results\"][\"TICA_decorrelations\"]\n",
    "\n",
    "    # Speedup > 1 means the sampled trajectory decorrelates faster than the reference along TICA-0.\n",
    "    speedup_factor = results['ref_traj_decorrelation_time'] / results['traj_decorrelation_time']\n",
    "    # Skip systems without a usable estimate: the logarithmic binning below needs finite, positive\n",
    "    # values, and a NaN, infinite or non-positive ratio would break it.\n",
    "    if not (np.isfinite(speedup_factor) and speedup_factor > 0):\n",
    "        continue\n",
    "\n",
    "    tica_0_speedups.append(speedup_factor)\n",
    "\n",
    "print(f\"Number of systems with valid decorrelations: {len(tica_0_speedups)} out of {len(results_df)}\")\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "# 20 log-spaced bins covering the full observed range of speedups.\n",
    "bins = np.logspace(np.log10(np.min(tica_0_speedups)), np.log10(np.max(tica_0_speedups)), 21)\n",
    "ax.hist(tica_0_speedups, bins=bins)\n",
    "ax.set_xscale(\"log\")\n",
    "ax.set_xlabel(\"Speedup Factor\")\n",
    "ax.set_ylabel(\"Frequency\")\n",
    "fig.suptitle(\"Speedups of TICA-0 Decorrelation Times\")\n",
    "fig.tight_layout()\n",
    "fig.savefig(os.path.join(output_dir, \"tica_0_speedups.pdf\"), dpi=300)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### MSM State Probabilities"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pool the metastable state probabilities of every system: reference (x) vs. sampled trajectory (y).\n",
    "ref_probs_per_system = []\n",
    "traj_probs_per_system = []\n",
    "for _, row in results_df.iterrows():\n",
    "    msm_results = row[\"results\"][\"JSD_MSM\"][\"traj\"]\n",
    "    ref_probs_per_system.append(msm_results[\"ref_metastable_probs\"])\n",
    "    traj_probs_per_system.append(msm_results[\"traj_metastable_probs\"])\n",
    "\n",
    "all_ref_metastable_probs = np.concatenate(ref_probs_per_system)\n",
    "all_traj_metastable_probs = np.concatenate(traj_probs_per_system)\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "ax.scatter(all_ref_metastable_probs, all_traj_metastable_probs, alpha=0.3, edgecolors=\"none\")\n",
    "\n",
    "# Least-squares line through the pooled probabilities (linear scale).\n",
    "slope, intercept, r_value, p_value, std_err = scipy.stats.linregress(\n",
    "    all_ref_metastable_probs, all_traj_metastable_probs\n",
    ")\n",
    "\n",
    "# Draw the fit past both ends of [0, 1]; the axis limits set below clip it to the unit square.\n",
    "fit_x = np.array([-0.5, 1.5])\n",
    "ax.plot(fit_x, slope * fit_x + intercept, color='red', linestyle='--')\n",
    "ax.text(0.45, 0.90, f'R² = {r_value**2:.3f}', transform=ax.transAxes, color='red')\n",
    "\n",
    "ax.set_title(\"Metastable State Probabilities\")\n",
    "ax.set_xlim((0, 1))\n",
    "ax.set_ylim((0, 1))\n",
    "ax.set_xlabel(format_traj_name(\"ref_traj\"))\n",
    "ax.set_ylabel(format_traj_name(\"traj\"))\n",
    "fig.tight_layout()\n",
    "fig.savefig(os.path.join(output_dir, \"metastable_probs.pdf\"), dpi=300)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Transition and Flux Matrices"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_cols = len(sampled_results_df)\n",
    "# squeeze=False keeps axs two-dimensional even when only one system is sampled.\n",
    "fig, axs = plt.subplots(2, n_cols, figsize=(15, 5), squeeze=False)\n",
    "\n",
    "mean_correlation = results_df[\"results\"].apply(lambda x: x[\"MSM_matrices\"][\"traj\"][\"transition_spearman_correlation\"]).mean()\n",
    "print(f\"Mean correlation for transition matrices: {mean_correlation:.2f}\")\n",
    "\n",
    "# Top row: MSM transition matrix (labelled as the reference); bottom row: transition matrix from the sampled trajectory.\n",
    "# Enumerate positions rather than using the DataFrame index labels, which need not be 0..n_cols-1.\n",
    "for i, (_, row) in enumerate(sampled_results_df.iterrows()):\n",
    "    peptide = row[\"peptide\"]\n",
    "    results = row[\"results\"][\"MSM_matrices\"][\"traj\"]\n",
    "\n",
    "    msm_transition_matrix = results[\"msm_transition_matrix\"]\n",
    "    traj_transition_matrix = results[\"traj_transition_matrix\"]\n",
    "    correlation = results[\"transition_spearman_correlation\"]\n",
    "\n",
    "    im = axs[0, i].imshow(msm_transition_matrix, cmap='Blues', vmin=0, vmax=1)\n",
    "    axs[1, i].imshow(traj_transition_matrix, cmap='Blues', vmin=0, vmax=1)\n",
    "    axs[0, i].set_title(f\"{format_peptide_name(peptide)}\\nρ = {correlation:.2f}\")\n",
    "\n",
    "# Row labels to the left of the first column.\n",
    "for row_idx, traj_name in enumerate([\"ref_traj\", \"traj\"]):\n",
    "    axs[row_idx, 0].text(\n",
    "        -0.4,\n",
    "        0.5,\n",
    "        format_traj_name(traj_name),\n",
    "        horizontalalignment=\"right\",\n",
    "        verticalalignment=\"center\",\n",
    "        transform=axs[row_idx, 0].transAxes,\n",
    "    )\n",
    "\n",
    "# All images share vmin=0 and vmax=1, so one colorbar (taken from the last image) applies to every panel.\n",
    "fig.colorbar(im, ax=axs, orientation='vertical', fraction=0.022)\n",
    "plt.savefig(os.path.join(output_dir, \"transition_matrices.pdf\"), dpi=300)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_cols = len(sampled_results_df)\n",
    "# squeeze=False keeps axs two-dimensional even when only one system is sampled.\n",
    "fig, axs = plt.subplots(2, n_cols, figsize=(15, 5), squeeze=False)\n",
    "\n",
    "# Shared color range over every plotted flux matrix (MSM and sampled trajectory alike).\n",
    "vmin, vmax = np.inf, -np.inf\n",
    "for _, row in sampled_results_df.iterrows():\n",
    "    results = row[\"results\"][\"MSM_matrices\"][\"traj\"]\n",
    "    for flux_matrix in (results[\"msm_flux_matrix\"], results[\"traj_flux_matrix\"]):\n",
    "        vmin = min(vmin, np.min(flux_matrix))\n",
    "        vmax = max(vmax, np.max(flux_matrix))\n",
    "\n",
    "mean_correlation = results_df[\"results\"].apply(lambda x: x[\"MSM_matrices\"][\"traj\"][\"flux_spearman_correlation\"]).mean()\n",
    "print(f\"Mean correlation for flux matrices: {mean_correlation:.2f}\")\n",
    "\n",
    "# Power-law (gamma=0.5) color scaling spreads out small flux values; one norm is shared by all panels.\n",
    "flux_norm = matplotlib.colors.PowerNorm(gamma=0.5, vmin=vmin, vmax=vmax)\n",
    "\n",
    "# Top row: MSM flux matrix (labelled as the reference); bottom row: flux matrix from the sampled trajectory.\n",
    "# Enumerate positions rather than using the DataFrame index labels, which need not be 0..n_cols-1.\n",
    "for i, (_, row) in enumerate(sampled_results_df.iterrows()):\n",
    "    peptide = row[\"peptide\"]\n",
    "    results = row[\"results\"][\"MSM_matrices\"][\"traj\"]\n",
    "    correlation = results[\"flux_spearman_correlation\"]\n",
    "\n",
    "    im = axs[0, i].imshow(results[\"msm_flux_matrix\"], cmap='Blues', norm=flux_norm)\n",
    "    axs[1, i].imshow(results[\"traj_flux_matrix\"], cmap='Blues', norm=flux_norm)\n",
    "    axs[0, i].set_title(f\"{format_peptide_name(peptide)}\\nρ = {correlation:.2f}\")\n",
    "\n",
    "# Row labels to the left of the first column.\n",
    "for row_idx, traj_name in enumerate([\"ref_traj\", \"traj\"]):\n",
    "    axs[row_idx, 0].text(\n",
    "        -0.4,\n",
    "        0.5,\n",
    "        format_traj_name(traj_name),\n",
    "        horizontalalignment=\"right\",\n",
    "        verticalalignment=\"center\",\n",
    "        transform=axs[row_idx, 0].transAxes,\n",
    "    )\n",
    "\n",
    "fig.colorbar(im, ax=axs, orientation='vertical', fraction=0.022)\n",
    "plt.savefig(os.path.join(output_dir, \"flux_matrices.pdf\"), dpi=300)\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "jamun",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
