{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Energy-parameterised temporal DDPM for alanine dipeptide\n",
    "\n",
    "Loads MD trajectories of alanine dipeptide, loads a saved `TemporalUnetEnergy` diffusion checkpoint and continues training it, then compares generated trajectories with the MD data through phi/psi autocorrelation functions and free-energy maps."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":125983,"status":"ok","timestamp":1674029909217,"user":{"displayName":"M Petersen","userId":"13637530209006951077"},"user_tz":-60},"id":"azin34LbflEi","outputId":"03e18f01-7383-43f5-e5b4-a298b9062787"},
   "outputs": [],
   "source": [
    "import copy\n",
    "import io\n",
    "import os\n",
    "import pickle\n",
    "import sys\n",
    "import time\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import matplotlib.pyplot as plt\n",
    "from numpy import pi\n",
    "\n",
    "from torchvision import transforms\n",
    "from scipy.stats import gaussian_kde\n",
    "from diffusers import DDPMScheduler\n",
    "from statsmodels.tsa.stattools import acf\n",
    "\n",
    "# Make the local SeqDiff package (two folders up from the working directory) importable.\n",
    "SEQDIFF_DIR = os.path.abspath(os.path.join(os.getcwd(), '..', '..', 'SeqDiff'))\n",
    "sys.path.insert(0, SEQDIFF_DIR)\n",
    "\n",
    "from dynamicsdiffusion import TemporalUnet, TemporalUnetEnergy, GaussianDiffusion, Trainer, count_parameters\n",
    "\n",
    "# Seed NumPy's and PyTorch's global random number generators for reproducible runs.\n",
    "SEED = 0\n",
    "np.random.seed(SEED)\n",
    "torch.manual_seed(SEED)\n",
    "\n",
    "# Fall back to the CPU when no GPU is available.\n",
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {"executionInfo":{"elapsed":432,"status":"ok","timestamp":1674030187420,"user":{"displayName":"M Petersen","userId":"13637530209006951077"},"user_tz":-60},"id":"HbfSo-pe1xj6"},
   "outputs": [],
   "source": [
    "#@title Sampling and Logging\n",
    "# Atom indices of the phi and psi dihedrals within each 10-atom (non-hydrogen) frame.\n",
    "psi_indices_non_h, phi_indices_non_h = [3, 4, 6, 8], [1, 3, 4, 6]\n",
    "# Row 0 = phi, row 1 = psi; compute_dihedral returns angles in this column order.\n",
    "DIHEDRAL_INDICES = np.array([phi_indices_non_h, psi_indices_non_h])\n",
    "# Atoms per frame: each frame is stored as a flat vector of N_ATOMS * 3 coordinates.\n",
    "N_ATOMS = 10\n",
    "\n",
    "\n",
    "def batch_sampler(traj, batch_size, horizon, transition_dim):\n",
    "    \"\"\"Draw a random training batch of fixed-length windows from the trajectories.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    traj : np.ndarray, shape=(num_trajectories, num_frames, transition_dim)\n",
    "        Trajectory ensemble to sample from.\n",
    "    batch_size : int\n",
    "        Number of windows in the batch.\n",
    "    horizon : int or list of int\n",
    "        Window length; if a list, one length is drawn at random for the whole batch.\n",
    "    transition_dim : int\n",
    "        Features per frame; must equal traj.shape[2].\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    batch : torch.Tensor, shape=(batch_size, horizon, transition_dim), dtype=float32\n",
    "        Each row is a contiguous window taken from a uniformly chosen trajectory.\n",
    "    \"\"\"\n",
    "    if isinstance(horizon, list):\n",
    "        horizon = np.random.choice(horizon)\n",
    "    horizon = int(horizon)\n",
    "\n",
    "    num_traj, num_frames = traj.shape[0], traj.shape[1]\n",
    "    if traj.shape[2] != transition_dim:\n",
    "        raise ValueError(f'traj has {traj.shape[2]} features per frame, expected {transition_dim}')\n",
    "    if horizon > num_frames:\n",
    "        raise ValueError(f'horizon ({horizon}) exceeds the trajectory length ({num_frames})')\n",
    "\n",
    "    # randint excludes its upper bound; +1 lets a window end on the last frame.\n",
    "    starts = np.random.randint(0, num_frames - horizon + 1, size=batch_size)\n",
    "    members = np.random.randint(0, num_traj, size=batch_size)\n",
    "    frame_idx = starts[:, None] + np.arange(horizon)  # (batch_size, horizon)\n",
    "    return torch.as_tensor(traj[members[:, None], frame_idx], dtype=torch.float32)\n",
    "\n",
    "\n",
    "def unscale(x, bounds=None):\n",
    "    \"\"\"Map values from [-1, 1] back to [min, max], where bounds = (max, min).\n",
    "\n",
    "    Defaults to the global max_min loaded together with the data.\n",
    "    \"\"\"\n",
    "    if bounds is None:\n",
    "        bounds = max_min\n",
    "    upper, lower = bounds[0], bounds[1]\n",
    "    return x * (upper - lower) / 2 + (upper + lower) / 2\n",
    "\n",
    "\n",
    "def _displacement(xyz, pairs):\n",
    "    \"\"\"Displacement vector between pairs of points in each frame.\n",
    "\n",
    "    xyz has shape (num_frames, num_atoms, 3) and pairs has shape (num_pairs, 2);\n",
    "    the result has shape (num_frames, num_pairs, 3) and equals xyz[:, j] - xyz[:, i].\n",
    "    \"\"\"\n",
    "    value = np.diff(xyz[:, pairs], axis=2)[:, :, 0]\n",
    "    assert value.shape == (xyz.shape[0], pairs.shape[0], 3), 'v.shape %s, xyz.shape %s, pairs.shape %s' % (str(value.shape), str(xyz.shape), str(pairs.shape))\n",
    "    return value\n",
    "\n",
    "\n",
    "def compute_dihedral(traj, indices, periodic=True, out=None):\n",
    "    \"\"\"Compute dihedral angles for every frame of a trajectory.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    traj : np.ndarray, shape=(num_frames, num_atoms, 3), dtype=float\n",
    "        Cartesian coordinates of a trajectory.\n",
    "    indices : np.ndarray, shape=(num_dihedrals, 4), dtype=int\n",
    "        Atom indices defining each dihedral.\n",
    "    periodic : bool, default=True\n",
    "        Unused: no unit-cell information is available here, so no\n",
    "        minimum-image correction is applied.\n",
    "    out : np.ndarray, optional\n",
    "        Output buffer passed through to np.arctan2.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    dih : np.ndarray, shape=(num_frames, num_dihedrals), dtype=float\n",
    "        dih[i, j] is the dihedral angle (radians, in [-pi, pi]) of frame i for indices[j].\n",
    "    \"\"\"\n",
    "    ix10 = indices[:, [0, 1]]\n",
    "    ix21 = indices[:, [1, 2]]\n",
    "    ix32 = indices[:, [2, 3]]\n",
    "\n",
    "    b1 = _displacement(traj, ix10)\n",
    "    b2 = _displacement(traj, ix21)\n",
    "    b3 = _displacement(traj, ix32)\n",
    "\n",
    "    c1 = np.cross(b2, b3)\n",
    "    c2 = np.cross(b1, b2)\n",
    "\n",
    "    # atan2(|b2| * b1.(b2 x b3), (b2 x b3).(b1 x b2)) gives the signed dihedral.\n",
    "    p1 = (b1 * c1).sum(-1)\n",
    "    p1 *= (b2 * b2).sum(-1) ** 0.5\n",
    "    p2 = (c1 * c2).sum(-1)\n",
    "\n",
    "    return np.arctan2(p1, p2, out)\n",
    "\n",
    "\n",
    "def free_energy_map(angles, bins=100):\n",
    "    \"\"\"Return -log of the normalised (phi, psi) histogram, transposed for imshow.\n",
    "\n",
    "    angles has shape (num_frames, 2) in radians. Empty bins give +inf; the\n",
    "    divide-by-zero warning from log(0) is suppressed.\n",
    "    \"\"\"\n",
    "    density = np.histogram2d(angles[:, 0], angles[:, 1], bins=bins, range=[[-pi, pi], [-pi, pi]])[0]\n",
    "    density = density / np.sum(density)\n",
    "    with np.errstate(divide='ignore'):\n",
    "        return -np.log(density.T)\n",
    "\n",
    "\n",
    "def _plot_free_energy(fig, ax, free_energy, title):\n",
    "    \"\"\"Draw one free-energy map on ax with labelled axes and a colorbar.\"\"\"\n",
    "    im = ax.imshow(free_energy, origin='lower', extent=[-pi, pi, -pi, pi])\n",
    "    ax.set_title(title)\n",
    "    ax.set_xlabel(r'$\\Phi$ Angle [radians]')\n",
    "    ax.set_ylabel(r'$\\Psi$ Angle [radians]')\n",
    "    ax.set_xlim(-pi, pi)\n",
    "    ax.set_ylim(-pi, pi)\n",
    "    cbar = fig.colorbar(im, ax=ax)\n",
    "    cbar.ax.set_ylabel('Free Energy')\n",
    "\n",
    "\n",
    "def plotting_fn(samples, data, step, training = True, title_addition = ''):\n",
    "    \"\"\"Plot (phi, psi) free-energy maps of model samples next to the training data.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    samples : torch.Tensor\n",
    "        Generated frames in scaled units, reshapeable to (-1, N_ATOMS, 3).\n",
    "    data : np.ndarray\n",
    "        Training frames in scaled units, reshapeable to (-1, N_ATOMS, 3); every 10th frame is plotted.\n",
    "    step : int\n",
    "        Training step, shown in the title when training is True.\n",
    "    training : bool\n",
    "        Selects which title format is used.\n",
    "    title_addition : str\n",
    "        Appended to the title when training is False.\n",
    "    \"\"\"\n",
    "    samples = unscale(samples.reshape((-1, N_ATOMS, 3)).detach().cpu().numpy())\n",
    "    data = unscale(np.asarray(data).reshape((-1, N_ATOMS, 3)))\n",
    "\n",
    "    angles_samples = compute_dihedral(samples, DIHEDRAL_INDICES, True)\n",
    "    angles_data = compute_dihedral(data, DIHEDRAL_INDICES, True)\n",
    "\n",
    "    fig, ax = plt.subplots(1, 2, figsize = (14, 5))\n",
    "    _plot_free_energy(fig, ax[0], free_energy_map(angles_samples), 'Free Energy of DDPM Samples')\n",
    "    _plot_free_energy(fig, ax[1], free_energy_map(angles_data[::10]), 'Free Energy of Data')\n",
    "\n",
    "    if training:\n",
    "        fig.suptitle('Dihedral Map: Alanine dipeptide DDPM Samples and Data at Trainingstep: ' + str(step), fontsize=16)\n",
    "    else:\n",
    "        fig.suptitle('Dihedral Map: Alanine dipeptide DDPM Samples and Data ' + title_addition, fontsize=16)\n",
    "\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {"id":"yl50mOsMi3A0"},
   "source": [
    "## Instantiate Models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if torch.cuda.is_available():\n",
    "    torch.cuda.empty_cache()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":13718,"status":"ok","timestamp":1674029922926,"user":{"displayName":"M Petersen","userId":"13637530209006951077"},"user_tz":-60},"id":"z4Gp-UxCgvKY","outputId":"67a6a035-bb40-4ffd-b78f-144594afcf72"},
   "outputs": [],
   "source": [
    "DATA_DIR = '../Data Generation/Data'\n",
    "\n",
    "# Trajectories are assumed to be stored as (num_trajectories, num_frames, N_ATOMS, 3) - TODO confirm;\n",
    "# max_min holds the (max, min) bounds used by unscale().\n",
    "traj_raw = np.load(os.path.join(DATA_DIR, 'trajectories.npy'))\n",
    "max_min = np.load(os.path.join(DATA_DIR, 'max_min.npy'))\n",
    "\n",
    "# Flatten each frame into one feature vector: (num_trajectories, num_frames, N_ATOMS * 3).\n",
    "traj = traj_raw.reshape((traj_raw.shape[0], traj_raw.shape[1], traj_raw.shape[2] * traj_raw.shape[3]))\n",
    "\n",
    "horizon = [64]\n",
    "transition_dim = traj.shape[-1]\n",
    "assert transition_dim == N_ATOMS * 3, f'expected {N_ATOMS * 3} features per frame, got {transition_dim}'\n",
    "\n",
    "model = TemporalUnetEnergy(horizon = horizon[0], transition_dim = transition_dim, cond_dim = None, dim = 128, dim_mults = (1, 2, 4), kernel_size = 3, attention=True).to(device)\n",
    "diffuser = GaussianDiffusion(model = model, horizon = horizon[0], action_dim=transition_dim, observation_dim=0).to(device)\n",
    "batch_size = 256\n",
    "count_parameters(diffuser)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer = Trainer(diffuser, training_data = traj, training_horizons = horizon, batch_sampler = batch_sampler, plotting_fn = plotting_fn, train_lr=1e-6, train_batch_size=batch_size, run_name=\"energy_\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer.load(45000)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer.train(15000)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Autocorrelation & Free Energy Plot"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def comp_corr(samples, data, lag, correlation_samples, horizon):\n",
    "    \"\"\"Estimate phi/psi autocorrelation functions of model samples and of MD data windows.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    samples : torch.Tensor\n",
    "        Generated trajectories in scaled units, reshapeable to (num_samples, horizon[0], N_ATOMS, 3).\n",
    "    data : np.ndarray\n",
    "        MD trajectories in scaled units, shape (num_trajectories, num_frames, N_ATOMS * 3),\n",
    "        or a single trajectory of shape (num_frames, N_ATOMS * 3).\n",
    "    lag : int\n",
    "        Largest autocorrelation lag; must be smaller than horizon[0].\n",
    "    correlation_samples : int\n",
    "        Number of randomly drawn (sample, data window) pairs.\n",
    "    horizon : list of int\n",
    "        horizon[0] is the window length used for both samples and data.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    auto_corr : np.ndarray, shape=(correlation_samples, lag + 1, 2, 2)\n",
    "        auto_corr[i, :, source, angle] with source 0 = samples, 1 = data and angle 0 = phi, 1 = psi.\n",
    "\n",
    "    NOTE(review): acf treats the angles as linear values, so wrap-around at +/-pi appears as jumps.\n",
    "    \"\"\"\n",
    "    window = horizon[0]\n",
    "    if lag >= window:\n",
    "        raise ValueError(f'lag ({lag}) must be smaller than the window length ({window})')\n",
    "\n",
    "    # Dihedrals of every generated frame, regrouped per sample: (num_samples, window, 2).\n",
    "    samples = unscale(samples.reshape((-1, window, N_ATOMS, 3)).detach().cpu().numpy())\n",
    "    angles_samples = compute_dihedral(samples.reshape((-1, N_ATOMS, 3)), DIHEDRAL_INDICES, True)\n",
    "    angles_samples = angles_samples.reshape((samples.shape[0], window, 2))\n",
    "\n",
    "    # Keep trajectories separate so a data window never straddles two independent runs.\n",
    "    data = np.asarray(data)\n",
    "    if data.ndim == 2:\n",
    "        data = data[None]\n",
    "    num_traj, num_frames = data.shape[0], data.shape[1]\n",
    "    if num_frames < window:\n",
    "        raise ValueError(f'trajectories have {num_frames} frames, fewer than the window length ({window})')\n",
    "    angles_data = compute_dihedral(unscale(data.reshape((-1, N_ATOMS, 3))), DIHEDRAL_INDICES, True)\n",
    "    angles_data = angles_data.reshape((num_traj, num_frames, 2))\n",
    "\n",
    "    auto_corr = np.zeros((correlation_samples, lag + 1, 2, 2))\n",
    "    for i in range(correlation_samples):\n",
    "        sample_idx = np.random.randint(0, angles_samples.shape[0])\n",
    "        traj_idx = np.random.randint(0, num_traj)\n",
    "        # randint excludes its upper bound; +1 lets a window end on the last frame.\n",
    "        start = np.random.randint(0, num_frames - window + 1)\n",
    "        data_window = angles_data[traj_idx, start:start + window]\n",
    "        for angle in range(2):\n",
    "            auto_corr[i, :, 0, angle] = acf(angles_samples[sample_idx, :, angle], nlags=lag)\n",
    "            auto_corr[i, :, 1, angle] = acf(data_window[:, angle], nlags=lag)\n",
    "\n",
    "    return auto_corr"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "h = horizon[0]\n",
    "sample_count = 500\n",
    "iterations = 15\n",
    "samples = torch.zeros(size = (iterations, sample_count, h, transition_dim))\n",
    "for i in range(iterations):\n",
    "    # detach() drops any autograd history. torch.no_grad() is deliberately not used because an\n",
    "    # energy-parameterised model may need gradients while sampling (TODO confirm in dynamicsdiffusion).\n",
    "    samples[i] = trainer.ema_model.p_sample_loop((sample_count, h, transition_dim), None).trajectories.detach()\n",
    "    print(f'sampled batch {i + 1}/{iterations}')\n",
    "samples = samples.reshape((iterations * sample_count, h, transition_dim))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "lag = 60\n",
    "correlation_samples = 10000\n",
    "auto_corr = comp_corr(samples, traj, lag, correlation_samples, horizon)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.rcParams.update({'font.size': 15})\n",
    "fig, ax = plt.subplots(2, 2, figsize = (13, 10))\n",
    "bins = 100\n",
    "scale = 1  # multiplier applied to the free energy -log(p)\n",
    "lags = np.arange(lag + 1)\n",
    "\n",
    "# Left column: mean +/- one std of the autocorrelation over all windows (row 0: phi, row 1: psi).\n",
    "for i in range(2):\n",
    "    for source, label in ((0, 'DDPM Samples'), (1, 'Data')):\n",
    "        acf_mean = np.mean(auto_corr[:, :, source, i], axis=0)\n",
    "        acf_std = np.std(auto_corr[:, :, source, i], axis=0)\n",
    "        ax[i, 0].plot(lags, acf_mean, label=label)\n",
    "        ax[i, 0].fill_between(lags, acf_mean - acf_std, acf_mean + acf_std, alpha=0.3)\n",
    "    ax[i, 0].legend()\n",
    "\n",
    "ax[0, 0].set_ylabel('Autocorrelation of \\u03C6')\n",
    "ax[1, 0].set_ylabel('Autocorrelation of \\u03C8')\n",
    "ax[1, 0].set_xlabel('Autocorrelation Lag [0.02 ps]')\n",
    "ax[0, 0].set_xticks([])\n",
    "\n",
    "# Right column: free-energy maps over all generated frames and all MD frames.\n",
    "angles_samples = compute_dihedral(unscale(samples.reshape((-1, N_ATOMS, 3)).cpu().numpy()), DIHEDRAL_INDICES, True)\n",
    "angles_data = compute_dihedral(unscale(traj.reshape((-1, N_ATOMS, 3))), DIHEDRAL_INDICES, True)\n",
    "\n",
    "# Shift each map so that its minimum free energy is zero.\n",
    "free_energy_samples = free_energy_map(angles_samples, bins) * scale\n",
    "free_energy_samples = free_energy_samples - np.min(free_energy_samples)\n",
    "free_energy_data = free_energy_map(angles_data, bins) * scale\n",
    "free_energy_data = free_energy_data - np.min(free_energy_data)\n",
    "\n",
    "im1 = ax[0, 1].imshow(free_energy_samples, origin='lower', extent=[-pi, pi, -pi, pi], cmap='cividis')\n",
    "ax[0, 1].set_ylabel('\\u03C8 [rad]')\n",
    "ax[0, 1].set_xticks([])\n",
    "fig.colorbar(im1, ax=ax[0, 1], label='Free Energy of DDPM Samples', fraction=0.046, pad=0.01)\n",
    "\n",
    "im2 = ax[1, 1].imshow(free_energy_data, origin='lower', extent=[-pi, pi, -pi, pi], cmap='cividis')\n",
    "ax[1, 1].set_xlabel('\\u03C6 [rad]')\n",
    "ax[1, 1].set_ylabel('\\u03C8 [rad]')\n",
    "# The data colorbar must come from im2 so its scale matches the data map, not the samples map.\n",
    "fig.colorbar(im2, ax=ax[1, 1], label='Free Energy of MD Data', fraction=0.046, pad=0.01)\n",
    "\n",
    "fig.tight_layout()\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {"accelerator":"GPU","colab":{"authorship_tag":"ABX9TyOnCsf26W2AwBpqJeHxSkuI","machine_shape":"hm","mount_file_id":"1SWRJU8pxRR0ppYpjID44px6_ZT9K7naD","provenance":[]},"gpuClass":"premium","kernelspec":{"display_name":"base","language":"python","name":"python3"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.9.13"},"vscode":{"interpreter":{"hash":"639d8a7c3e620b1d142eea4deabde5aac9ed3b21a6e651e4622d69fbdac2ed0a"}}},
 "nbformat": 4,
 "nbformat_minor": 0
}
