{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from openmmtools import multistate\n",
    "from openmm import unit, app\n",
    "import numpy as np\n",
    "import torch\n",
    "\n",
    "from tqdm import tqdm\n",
    "import netCDF4\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def extract_coords(file_path, target_state=0):\n",
    "    \"\"\"Collect the coordinates of whichever replica held ``target_state`` at each iteration.\n",
    "\n",
    "    Replica exchange moves states between replicas, so the replica that holds\n",
    "    ``target_state`` can differ from one iteration to the next; the ``states``\n",
    "    variable is used to look it up per iteration.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    file_path : str\n",
    "        Path to a NetCDF file with a ``states`` variable indexed as\n",
    "        (iteration, replica) and a ``positions`` variable indexed as\n",
    "        (iteration, replica, atom, xyz).\n",
    "    target_state : int, optional\n",
    "        Index of the thermodynamic state to extract (default 0).\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    numpy.ndarray\n",
    "        Array of shape (n_frames, n_atoms * 3): one flattened row per\n",
    "        (iteration, replica) pair found at ``target_state``, in iteration order.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError\n",
    "        If ``target_state`` never appears in the ``states`` variable.\n",
    "    \"\"\"\n",
    "    # Slicing with [:] copies the data into memory, so the file can be closed\n",
    "    # before any further processing.\n",
    "    with netCDF4.Dataset(file_path, 'r') as ds_main:\n",
    "        states = ds_main.variables['states'][:]  # (iterations, n_replicas)\n",
    "        positions = ds_main.variables['positions'][:]  # (iterations, n_replicas, n_atoms, 3)\n",
    "\n",
    "    print(f\"Loaded {positions.shape[0]} iterations and {positions.shape[1]} replicas.\")\n",
    "\n",
    "    # np.where scans row by row, so the matches come back sorted by iteration.\n",
    "    iterations, replica_indices = np.where(states == target_state)\n",
    "    if iterations.size == 0:\n",
    "        # Fail with a readable message instead of an opaque reshape error below.\n",
    "        raise ValueError(f\"State {target_state} was never sampled in {file_path!r}.\")\n",
    "\n",
    "    # Paired advanced indexing picks, for each matching iteration, the one\n",
    "    # replica that was at the target state: (n_frames, n_atoms, 3).\n",
    "    coords = positions[iterations, replica_indices]\n",
    "\n",
    "    # Flatten each frame: (n_frames, n_atoms, 3) -> (n_frames, n_atoms * 3)\n",
    "    return coords.reshape(coords.shape[0], -1)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import mdtraj as md\n",
    "import numpy as np\n",
    "from openmmtools.multistate import MultiStateReporter, MultiStateSamplerAnalyzer\n",
    "from openmmtools.testsystems import AlanineDipeptideVacuum\n",
    "from pymbar import timeseries\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# --- CONFIGURATION ---\n",
    "TARGET_STATE_INDEX = 0  # Index 0 corresponds to T_MIN (300K)\n",
    "EQUILIBRATION_NSKIP = 1000  # only every nskip-th frame is tried as an equilibration start\n",
    "FRAME_INTERVAL_PS = 2  # time per saved frame used in the printed decorrelation time; adjust to your run\n",
    "\n",
    "nc_files_list = [\"\", \"\"]  # Add your NetCDF file paths here\n",
    "\n",
    "# The ALDP topology is the same for every file, so build it once up front\n",
    "# instead of re-creating the test system on every loop iteration.\n",
    "md_topology = md.Topology.from_openmm(AlanineDipeptideVacuum().topology)\n",
    "\n",
    "uncorrelated_traj_list = []\n",
    "for nc_file in nc_files_list:\n",
    "\n",
    "    print(f\"Loading data from {nc_file}...\")\n",
    "\n",
    "    # 1. Extract the trajectory for the target state.\n",
    "    # extract_coords follows the target state across replica swaps and returns\n",
    "    # flattened frames, which are unflattened here to (n_frames, n_atoms, 3).\n",
    "    print(f\"Extracting trajectory for State {TARGET_STATE_INDEX} (300K)...\")\n",
    "    sampled_positions = extract_coords(nc_file, target_state=TARGET_STATE_INDEX)\n",
    "    sampled_positions = sampled_positions.reshape(-1, md_topology.n_atoms, 3)\n",
    "\n",
    "    # Convert to an MDTraj trajectory for easy analysis\n",
    "    traj = md.Trajectory(sampled_positions, md_topology)\n",
    "\n",
    "    print(f\"Total frames at 300K: {traj.n_frames}\")\n",
    "\n",
    "    # 2. Compute Observables (Phi and Psi angles)\n",
    "    # We use sin/cos of the angles to avoid periodicity issues (-180 to 180 jumps)\n",
    "    print(\"Computing Dihedral angles...\")\n",
    "    phi_angles = md.compute_phi(traj)[1]\n",
    "    psi_angles = md.compute_psi(traj)[1]\n",
    "\n",
    "    # One column per component: sin(phi), cos(phi), sin(psi), cos(psi)\n",
    "    features = np.hstack([np.sin(phi_angles), np.cos(phi_angles),\n",
    "                          np.sin(psi_angles), np.cos(psi_angles)])\n",
    "\n",
    "    # 3. Calculate Statistical Inefficiency (g)\n",
    "    print(\"Analyzing correlation for each dihedral component...\")\n",
    "    labels = ['sin(phi)', 'cos(phi)', 'sin(psi)', 'cos(psi)']\n",
    "    t0_values = []\n",
    "    g_values = []\n",
    "\n",
    "    for label, series in zip(labels, features.T):\n",
    "        # detect_equilibration returns:\n",
    "        # t0:   the starting frame (discarding burn-in)\n",
    "        # g:    the statistical inefficiency of the data after t0\n",
    "        # Neff: the number of effectively uncorrelated samples after t0\n",
    "        # fast=True selects pymbar's faster, less accurate correlation-time\n",
    "        # estimate (it is not an FFT switch); nskip limits which frames are\n",
    "        # tried as t0, which keeps the scan affordable on long trajectories.\n",
    "        t0, g, n_eff = timeseries.detect_equilibration(series, fast=True, nskip=EQUILIBRATION_NSKIP)\n",
    "\n",
    "        t0_values.append(t0)\n",
    "        g_values.append(g)\n",
    "        print(f\"  {label}: Equilibration starts at frame {t0}, g = {g:.2f}\")\n",
    "\n",
    "    # 4. Select Conservative Parameters\n",
    "    # We start sampling only after ALL components have equilibrated (max t0)\n",
    "    # We skip frames based on the SLOWEST correlation (max g)\n",
    "    t0_final = int(max(t0_values))\n",
    "    g_final = max(g_values)\n",
    "\n",
    "    print(\"-\" * 30)\n",
    "    print(\"Conservative Selection:\")\n",
    "    print(f\"  Start Frame (t0): {t0_final}\")\n",
    "    print(f\"  Stride (g):       {g_final:.2f} frames\")\n",
    "    print(f\"  Decorrelation time: {g_final * FRAME_INTERVAL_PS:.2f} ps\")\n",
    "    print(\"-\" * 30)\n",
    "\n",
    "    # 5. Extract Uncorrelated Indices\n",
    "    # We manually generate the indices to ensure we respect both t0 and g.\n",
    "    # g is fractional, so the float positions are truncated to frame indices.\n",
    "    indices = np.arange(t0_final, traj.n_frames, g_final).astype(int)\n",
    "\n",
    "    print(f\"Extracted {len(indices)} uncorrelated samples.\")\n",
    "\n",
    "    # 6. Keep the uncorrelated frames of this file.\n",
    "    # Append only when frames were selected; otherwise a stale trajectory from\n",
    "    # the previous file (or an undefined name) would be added to the list.\n",
    "    if len(indices) > 0:\n",
    "        uncorrelated_traj_list.append(traj[indices])\n",
    "    else:\n",
    "        print(\"No uncorrelated samples found! (Simulation might be too short)\")\n",
    "\n",
    "if not uncorrelated_traj_list:\n",
    "    raise RuntimeError(\"No uncorrelated samples were extracted from any input file.\")\n",
    "\n",
    "# Combine all uncorrelated trajectories\n",
    "combined_traj = md.join(uncorrelated_traj_list)\n",
    "# Convert the combined coordinates to a float32 torch tensor\n",
    "uncorrelated_samples = torch.from_numpy(combined_traj.xyz).float()\n",
    "\n",
    "print(f\"Total uncorrelated samples from all runs: {uncorrelated_samples.shape[0]}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sampling.utils import fix_chirality\n",
    "\n",
    "# fix_chirality hands back two values: the processed samples and a count,\n",
    "# which is reported below as count / number of samples.\n",
    "chirality_result = fix_chirality(uncorrelated_samples, target_sign='positive')\n",
    "corrected_samples, n_flipped = chirality_result\n",
    "print(f\"Number of samples with flipped chirality corrected: {n_flipped}/{corrected_samples.shape[0]}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "SAVE_PATH = \"\"\n",
    "\n",
    "# Keep the leading (sample) axis and merge all remaining axes into one, so\n",
    "# each row is a single flattened sample, then write the tensor to SAVE_PATH.\n",
    "n_samples = corrected_samples.shape[0]\n",
    "corrected_samples_flat = corrected_samples.reshape(n_samples, -1)\n",
    "torch.save(corrected_samples_flat, SAVE_PATH)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from openmmtools.testsystems import AlanineDipeptideVacuum\n",
    "import mdtraj\n",
    "# estimate histogram\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib as mpl\n",
    "from matplotlib.colors import LinearSegmentedColormap\n",
    "\n",
    "N_PLOT_SAMPLES = 10000  # upper bound on the number of samples shown in the plot\n",
    "PLOT_SEED = 0  # fixed seed so the plotted subset is the same on every run\n",
    "\n",
    "# Draw a random subset without replacement. The size is capped at the number\n",
    "# of available samples, because asking for more than exist would raise.\n",
    "n_available = corrected_samples.shape[0]\n",
    "rng = np.random.default_rng(PLOT_SEED)\n",
    "idx = rng.choice(n_available, size=min(N_PLOT_SAMPLES, n_available), replace=False)\n",
    "\n",
    "aldp = AlanineDipeptideVacuum(constraints=None)\n",
    "topology = mdtraj.Topology.from_openmm(aldp.topology)\n",
    "\n",
    "# Unflatten the selected samples to (n_frames, n_atoms, 3) for MDTraj.\n",
    "vacuum_samples_np = corrected_samples.detach().cpu().numpy()[idx]\n",
    "vacuum_samples_np = vacuum_samples_np.reshape(-1, topology.n_atoms, 3)\n",
    "traj = mdtraj.Trajectory(vacuum_samples_np, topology)\n",
    "\n",
    "# compute_phi / compute_psi return (atom indices, angles); keep the angles as flat arrays.\n",
    "phi = mdtraj.compute_phi(traj)[1].reshape(-1)\n",
    "psi = mdtraj.compute_psi(traj)[1].reshape(-1)\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(5, 5))\n",
    "cmap = LinearSegmentedColormap.from_list(\"\", [\"navy\", \"aquamarine\"])\n",
    "# Log colour scale so that sparsely populated regions remain visible.\n",
    "ax.hist2d(phi, psi, bins=100, range=[[-np.pi, np.pi], [-np.pi, np.pi]], density=True, cmap=cmap, norm=mpl.colors.LogNorm(vmin=0.001, vmax=1.0))\n",
    "ax.set_xlabel(\"Phi\")\n",
    "ax.set_ylabel(\"Psi\")\n",
    "\n",
    "# Same tick positions and labels on both axes.\n",
    "angle_ticks = [-np.pi, -np.pi/2, 0, np.pi/2, np.pi]\n",
    "angle_labels = [r\"$-\\pi$\", r\"$-\\frac{\\pi}{2}$\", \"0\", r\"$\\frac{\\pi}{2}$\", r\"$\\pi$\"]\n",
    "ax.set_xticks(angle_ticks)\n",
    "ax.set_xticklabels(angle_labels)\n",
    "ax.set_yticks(angle_ticks)\n",
    "ax.set_yticklabels(angle_labels)\n",
    "ax.set_title(\"Ramachandran Plot from ALDP Samples\")\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "sampling",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
