{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {"colab": {"base_uri": "https://localhost:8080/"}, "executionInfo": {"elapsed": 124939, "status": "ok", "timestamp": 1673856780101, "user": {"displayName": "M Petersen", "userId": "13637530209006951077"}, "user_tz": -60}, "id": "azin34LbflEi", "outputId": "8c88e0bf-67cd-4a23-c02d-4c4387ced40c"},
   "outputs": [],
   "source": [
    "import copy\n",
    "import io\n",
    "import math\n",
    "import os\n",
    "import pickle\n",
    "import sys\n",
    "import time\n",
    "from functools import partial\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from diffusers import DDPMScheduler, EulerDiscreteScheduler\n",
    "from scipy import integrate\n",
    "from statsmodels.tsa.stattools import acf\n",
    "from torchvision import transforms\n",
    "\n",
    "# Make the local SeqDiff package (two directories above the working directory) importable.\n",
    "current_path = os.getcwd()\n",
    "two_folders_up = os.path.join(current_path, '..', '..')\n",
    "desired_folder = os.path.join(two_folders_up, 'SeqDiff')\n",
    "absolute_desired_folder = os.path.abspath(desired_folder)\n",
    "sys.path.insert(0, absolute_desired_folder)\n",
    "\n",
    "from dynamicsdiffusion import TemporalUnet, GaussianDiffusion, Trainer, count_parameters, TemporalUnetEnergy\n",
    "\n",
    "# Seed numpy's global RNG (used by batch_sampler and plotting_fn_corr) and torch's RNG\n",
    "# so re-runs are more repeatable; CUDA kernels may still be nondeterministic.\n",
    "SEED = 0\n",
    "np.random.seed(SEED)\n",
    "torch.manual_seed(SEED)\n",
    "\n",
    "device = 'cuda'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {"executionInfo": {"elapsed": 218, "status": "ok", "timestamp": 1673866836981, "user": {"displayName": "M Petersen", "userId": "13637530209006951077"}, "user_tz": -60}, "id": "HbfSo-pe1xj6"},
   "outputs": [],
   "source": [
    "#@title Sampling and Logging\n",
    "def batch_sampler(traj, batch_size, horizon, transition_dim):\n",
    "    \"\"\"Draw a batch of random trajectory windows.\n",
    "\n",
    "    traj: array indexed (time, replica, feature).\n",
    "    horizon: window length, or a list of lengths from which one is drawn for the whole batch.\n",
    "    Returns a float tensor of shape (batch_size, horizon, transition_dim); each row holds\n",
    "    `horizon` consecutive steps of one randomly chosen replica.\n",
    "    Raises ValueError if traj has fewer than `horizon` time steps.\n",
    "    \"\"\"\n",
    "    if isinstance(horizon, list):\n",
    "        horizon = np.random.choice(horizon)\n",
    "\n",
    "    n_steps, n_replicas = traj.shape[0], traj.shape[1]\n",
    "    if n_steps < horizon:\n",
    "        raise ValueError(f'traj has {n_steps} time steps, fewer than horizon={horizon}')\n",
    "\n",
    "    batch = torch.zeros(size = (batch_size, horizon, transition_dim))\n",
    "    for i in range(batch_size):\n",
    "        start_replica = np.random.randint(0, n_replicas)\n",
    "        # randint excludes its upper bound; the +1 keeps the last full window reachable.\n",
    "        start_step = np.random.randint(0, n_steps - horizon + 1)\n",
    "        batch[i] = torch.tensor(traj[start_step:start_step + horizon, start_replica, :])\n",
    "    return batch\n",
    "\n",
    "\n",
    "def potential(shape):\n",
    "    \"\"\"Evaluate V(x, y) = -9x^2 + x^3 + 4.5x^4 + 3y^2 on a regular grid over [-2, 2]^2.\n",
    "\n",
    "    shape: (number of x points, number of y points).\n",
    "    Returns the meshgrid arrays X, Y and V(X, Y), each of shape (shape[1], shape[0]).\n",
    "    \"\"\"\n",
    "    x = np.linspace(-2, 2, shape[0])\n",
    "    y = np.linspace(-2, 2, shape[1])\n",
    "    X, Y = np.meshgrid(x, y)\n",
    "    return X, Y, -9*X**2 + X**3 + 4.5*X**4 + 3*Y**2\n",
    "\n",
    "\n",
    "def unscale(arr):\n",
    "    \"\"\"Map values from the normalized [-1, 1] range back to data units.\n",
    "\n",
    "    Uses the global max_min_array: -1 maps to max_min_array[1] and +1 to max_min_array[0].\n",
    "    \"\"\"\n",
    "    return (arr + 1)/2*(max_min_array[0] - max_min_array[1]) + max_min_array[1]\n",
    "\n",
    "\n",
    "def density_2d(points, bins, value_range=((-2, 2), (-2, 2))):\n",
    "    \"\"\"Normalized 2D histogram of columns 0 and 1 of `points` (the cells sum to 1).\"\"\"\n",
    "    counts = np.histogram2d(points[:, 0], points[:, 1], bins=bins, range=value_range)[0]\n",
    "    return counts/np.sum(counts)\n",
    "\n",
    "\n",
    "def add_potential_contours(fig, ax, scale=1):\n",
    "    \"\"\"Overlay contour lines of `potential` on `ax` and attach a 'Potential Energy' colorbar.\n",
    "\n",
    "    The 10 log-spaced levels lie 1 to 10 energy units above the potential minimum;\n",
    "    both the levels and the potential are multiplied by `scale`.\n",
    "    \"\"\"\n",
    "    X, Y, Z = potential((100, 100))\n",
    "    levels = (np.logspace(0, 1, 10) + np.min(Z))*scale\n",
    "    contours = ax.contour(X, Y, Z*scale, levels=levels, alpha=1)\n",
    "    cbar = fig.colorbar(contours, ax=ax)\n",
    "    cbar.ax.set_ylabel('Potential Energy')\n",
    "\n",
    "\n",
    "def plotting_fn(samples, data, step, dim, training = True, title_addition = ''):\n",
    "    \"\"\"Plot the free energy -log p(x, y) of generated samples next to that of the data.\n",
    "\n",
    "    samples: torch tensor in normalized units; data: array in normalized units, of which\n",
    "    every 10th entry along axis 0 is used. Both are flattened to (N, dim) points.\n",
    "    step: number shown in the figure title (prefixed by title_addition when training is False).\n",
    "    \"\"\"\n",
    "    samples = unscale(samples.cpu().detach().numpy().reshape(-1, dim))\n",
    "    data = unscale(data[::10].reshape(-1, dim))\n",
    "\n",
    "    # Empty histogram cells give -log(0) = inf, which imshow masks; only the warning is silenced.\n",
    "    with np.errstate(divide='ignore'):\n",
    "        free_samples = -np.log(density_2d(samples, bins=100).T)\n",
    "        free_data = -np.log(density_2d(data, bins=100).T)\n",
    "\n",
    "    extent = [-2, 2, -2, 2]\n",
    "    fig, ax = plt.subplots(1, 2, figsize = (14, 5))\n",
    "\n",
    "    ax[0].imshow(free_samples, origin='lower', extent=extent)\n",
    "    ax[0].set_title('Free Energy of DDPM Samples')\n",
    "    ax[0].set_xlabel('x')\n",
    "    ax[0].set_ylabel('y')\n",
    "    add_potential_contours(fig, ax[0])\n",
    "\n",
    "    im = ax[1].imshow(free_data, origin='lower', extent=extent)\n",
    "    ax[1].set_title('Free Energy of Data')\n",
    "    ax[1].set_xlabel('x')\n",
    "    ax[1].set_ylabel('y')\n",
    "    fig.colorbar(im, ax=ax[1]).ax.set_ylabel('Free Energy')\n",
    "\n",
    "    if training:\n",
    "        fig.suptitle('DDPM Samples and Data at Trainingstep: ' + str(step), fontsize=16)\n",
    "    else:\n",
    "        fig.suptitle('DDPM Samples and Data at ' + title_addition + str(step), fontsize=16)\n",
    "\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "def plotting_fn_cross(samples, data, step):\n",
    "    \"\"\"Overlay density histograms of feature 0 of the samples and of every 10th time step of the data.\n",
    "\n",
    "    samples: torch tensor indexed (trajectory, time, feature); data: array indexed\n",
    "    (time, replica, feature). step is not used.\n",
    "    \"\"\"\n",
    "    plt.hist(samples[:, :, 0].cpu().numpy().flatten(), bins=100, density=True, label='Samples')\n",
    "    plt.hist(data[::10, :, 0].flatten(), bins=100, density=True, alpha=0.3, label='Data')\n",
    "    plt.title('Histogram of the samples')\n",
    "    plt.xlabel('x')\n",
    "    plt.ylabel('Density')\n",
    "    plt.legend()\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {"id": "yl50mOsMi3A0"},
   "source": [
    "## Load Data and Instantiate Models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {"colab": {"base_uri": "https://localhost:8080/"}, "executionInfo": {"elapsed": 2787, "status": "ok", "timestamp": 1673866843043, "user": {"displayName": "M Petersen", "userId": "13637530209006951077"}, "user_tz": -60}, "id": "z4Gp-UxCgvKY", "outputId": "37cd9f80-1f1f-4510-cef1-3607a150cfc7"},
   "outputs": [],
   "source": [
    "# NOTE: pickle.load can execute arbitrary code; only load trusted files.\n",
    "data_dir = '../Data Generation/Data/traj_dw.pkl'\n",
    "with open(data_dir, 'rb') as file:\n",
    "    # Keep the first two features; the plots treat them as x and y.\n",
    "    data = pickle.load(file)[::, ::, 0:2]\n",
    "\n",
    "scaler_dir = '../Data Generation/Data/dw_scaler.pkl'\n",
    "with open(scaler_dir, 'rb') as file:\n",
    "    # Scaling bounds for the same two features; consumed by `unscale`.\n",
    "    max_min_array = pickle.load(file)[::, 0:2]\n",
    "\n",
    "transition_dim = 2\n",
    "\n",
    "horizon = [256]\n",
    "# plotting_fn with the feature dimension bound; a new name keeps this cell safe to re-run.\n",
    "plotting_fn_2d = partial(plotting_fn, dim = transition_dim)\n",
    "\n",
    "model = TemporalUnetEnergy(horizon = horizon[0], transition_dim = transition_dim, cond_dim = None, dim=64, dim_mults=(1, 2), attention=False).to(device)\n",
    "diffuser = GaussianDiffusion(model = model, horizon = horizon[0], action_dim=transition_dim, observation_dim=0, n_timesteps = 1000).to(device)\n",
    "batch_size = 512\n",
    "count_parameters(diffuser)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer = Trainer(diffuser, training_data = data, training_horizons = horizon, batch_sampler = batch_sampler, plotting_fn = plotting_fn_cross, train_lr=4e-6, train_batch_size=batch_size, run_name = 'best_no_att_')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer.load(5000)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer.train(15001)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "samples = trainer.ema_model.p_sample_loop((200, horizon[0], transition_dim), None).trajectories\n",
    "\n",
    "plotting_fn_cross(samples, data, 5000)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Quality Check"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "h = horizon[0]\n",
    "sample_count = 2000\n",
    "iterations = 5\n",
    "samples = torch.zeros(size = (iterations, sample_count, h, transition_dim))\n",
    "for i in range(iterations):\n",
    "    samples[i] = trainer.ema_model.p_sample_loop((sample_count, h, transition_dim), None).trajectories\n",
    "    print(f'sampled batch {i + 1}/{iterations}')\n",
    "\n",
    "samples = samples.reshape(-1, h, transition_dim)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _finite_stats(values):\n",
    "    \"\"\"Mean, median and standard deviation of the finite entries of `values`, rounded to 2 decimals.\"\"\"\n",
    "    finite = values[np.isfinite(values)]\n",
    "    return np.round(np.mean(finite), 2), np.round(np.median(finite), 2), np.round(np.std(finite), 2)\n",
    "\n",
    "\n",
    "def plotting_fn_comp(samples, data, step, bins = 100, scale = 1, dim = None):\n",
    "    \"\"\"2x2 comparison of the free energy -log p(x, y) of samples and data on [-2, 2]^2.\n",
    "\n",
    "    Top row: samples (with potential contours) and data. Bottom row: their difference,\n",
    "    in absolute terms and as a percentage of the data value. Only histogram cells that\n",
    "    both samples and data populate are compared; all other cells are left blank.\n",
    "\n",
    "    samples: torch tensor in normalized units; data: array indexed (time, replica, feature),\n",
    "    of which every 10th replica is used. step is not used.\n",
    "    bins: histogram bins per axis. scale: factor applied to free energies and contours.\n",
    "    dim: feature dimension used to flatten the inputs (defaults to the global transition_dim).\n",
    "    \"\"\"\n",
    "    if dim is None:\n",
    "        dim = transition_dim\n",
    "    samples = unscale(samples.cpu().detach().numpy().reshape(-1, dim))\n",
    "    data = unscale(data[::, ::10].reshape(-1, dim))\n",
    "\n",
    "    density_samples = density_2d(samples, bins)\n",
    "    density_data = density_2d(data, bins)\n",
    "\n",
    "    # NaN outside the shared support: imshow masks it and _finite_stats skips it.\n",
    "    mask = (density_samples > 0) & (density_data > 0)\n",
    "    log_density_samples = np.full(density_samples.shape, np.nan)\n",
    "    log_density_data = np.full(density_data.shape, np.nan)\n",
    "    log_density_samples[mask] = -np.log(density_samples[mask])*scale\n",
    "    log_density_data[mask] = -np.log(density_data[mask])*scale\n",
    "    log_density_samples = log_density_samples.T\n",
    "    log_density_data = log_density_data.T\n",
    "\n",
    "    extent = [-2, 2, -2, 2]\n",
    "    fig, ax = plt.subplots(2, 2, figsize = (14, 12))\n",
    "    fig.suptitle('DDPM Samples and Data Comparison', fontsize=16)\n",
    "\n",
    "    ax[0, 0].imshow(log_density_samples, origin='lower', extent=extent)\n",
    "    ax[0, 0].set_title('Free Energy of DDPM Samples')\n",
    "    ax[0, 0].set_xlabel('x')\n",
    "    ax[0, 0].set_ylabel('y')\n",
    "    add_potential_contours(fig, ax[0, 0], scale=scale)\n",
    "\n",
    "    im = ax[0, 1].imshow(log_density_data, origin='lower', extent=extent)\n",
    "    ax[0, 1].set_title('Free Energy of Data')\n",
    "    ax[0, 1].set_xlabel('x')\n",
    "    ax[0, 1].set_ylabel('y')\n",
    "    fig.colorbar(im, ax=ax[0, 1]).ax.set_ylabel('Free Energy')\n",
    "\n",
    "    diff_free = log_density_samples - log_density_data\n",
    "    free_mean, free_median, free_std = _finite_stats(diff_free)\n",
    "    im = ax[1, 0].imshow(diff_free, origin='lower', extent=extent)\n",
    "    ax[1, 0].set_title('Difference between DDPM Samples and Data \\n Mean: ' + str(free_mean) + ' Median: ' + str(free_median) + ' \\n Standard Deviation: ' + str(free_std))\n",
    "    ax[1, 0].set_xlabel('x')\n",
    "    ax[1, 0].set_ylabel('y')\n",
    "    fig.colorbar(im, ax=ax[1, 0]).ax.set_ylabel('Difference in Free Energy')\n",
    "\n",
    "    # Relative difference |F_samples - F_data| / |F_data| in percent.\n",
    "    diff_per = np.abs(diff_free/log_density_data)*100\n",
    "    per_mean, per_median, per_std = _finite_stats(diff_per)\n",
    "    im = ax[1, 1].imshow(diff_per, origin='lower', extent=extent)\n",
    "    ax[1, 1].set_title('Difference in Percentage \\n Mean: ' + str(per_mean) + '  Median: ' + str(per_median) + ' \\n Standard Deviation: ' + str(per_std))\n",
    "    ax[1, 1].set_xlabel('x')\n",
    "    ax[1, 1].set_ylabel('y')\n",
    "    fig.colorbar(im, ax=ax[1, 1]).ax.set_ylabel('Difference in Free Energy as a Percentage')\n",
    "\n",
    "    fig.tight_layout()\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "def plotting_fn_comp_1d(samples, data, step, bins = 100, scale = 1, dim = None):\n",
    "    \"\"\"Compare the 1D free energy -log p(x) of samples and data along feature 0.\n",
    "\n",
    "    Left: both free-energy profiles at the histogram bin centres. Right: their difference\n",
    "    with summary statistics over finite values. Histograms cover x in (-1.6, 1.5).\n",
    "    Inputs are flattened as in plotting_fn_comp; step and scale are not used.\n",
    "    \"\"\"\n",
    "    if dim is None:\n",
    "        dim = transition_dim\n",
    "    samples = unscale(samples.cpu().detach().numpy().reshape(-1, dim))\n",
    "    data = unscale(data[::, ::10].reshape(-1, dim))\n",
    "\n",
    "    density_samples, bin_edges = np.histogram(samples[:, 0], bins=bins, density=True, range=(-1.6, 1.5))\n",
    "    density_data, _ = np.histogram(data[:, 0], bins=bins, density=True, range=(-1.6, 1.5))\n",
    "    bin_centers = (bin_edges[:-1] + bin_edges[1:])/2\n",
    "\n",
    "    # Empty bins give -log(0) = inf (and inf - inf = nan); _finite_stats skips those.\n",
    "    with np.errstate(divide='ignore', invalid='ignore'):\n",
    "        free_samples = -np.log(density_samples)\n",
    "        free_data = -np.log(density_data)\n",
    "        diff_free = free_samples - free_data\n",
    "    free_mean, free_median, free_std = _finite_stats(diff_free)\n",
    "\n",
    "    fig, ax = plt.subplots(1, 2, figsize=(14, 10))\n",
    "\n",
    "    ax[0].plot(bin_centers, free_samples, label='Samples')\n",
    "    ax[0].plot(bin_centers, free_data, label='Data')\n",
    "    ax[0].set_title('Free Energy of Samples and Data')\n",
    "    ax[0].set_xlabel('x')\n",
    "    ax[0].set_ylabel('Free Energy')\n",
    "    ax[0].legend()\n",
    "\n",
    "    ax[1].plot(bin_centers, diff_free)\n",
    "    ax[1].set_title('Difference in Free Energy \\n Mean: ' + str(free_mean) + '  Median: ' + str(free_median) + ' \\n Standard Deviation: ' + str(free_std))\n",
    "    ax[1].set_xlabel('x')\n",
    "    ax[1].set_ylabel('Difference in Free Energy')\n",
    "\n",
    "    fig.tight_layout()\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "def plotting_fn_corr(samples, data, lag, correlation_samples, horizon):\n",
    "    \"\"\"Plot mean +/- one standard deviation of the autocorrelation function, per feature.\n",
    "\n",
    "    samples: trajectories indexed (trajectory, time, feature).\n",
    "    data: array indexed (time, replica, feature); each draw uses the first horizon[0]\n",
    "        time steps of one random replica.\n",
    "    lag: largest lag passed to statsmodels' acf.\n",
    "    correlation_samples: number of random (sample trajectory, data replica) pairs averaged.\n",
    "    \"\"\"\n",
    "    if torch.is_tensor(samples):\n",
    "        samples = samples.detach().cpu().numpy()\n",
    "    n_dims = samples.shape[-1]\n",
    "\n",
    "    # auto_corr[k, :, 0, d]: ACF of a sample trajectory; auto_corr[k, :, 1, d]: ACF of a data replica.\n",
    "    auto_corr = np.zeros((correlation_samples, lag + 1, 2, n_dims))\n",
    "    for k in range(correlation_samples):\n",
    "        sample_idx = np.random.randint(0, samples.shape[0])\n",
    "        # Replicas are on axis 1 of data (axis 0 is time).\n",
    "        replica_idx = np.random.randint(0, data.shape[1])\n",
    "        for d in range(n_dims):\n",
    "            auto_corr[k, :, 0, d] = acf(samples[sample_idx, :, d], nlags=lag)\n",
    "            auto_corr[k, :, 1, d] = acf(data[:horizon[0], replica_idx, d], nlags=lag)\n",
    "\n",
    "    lags = np.arange(lag + 1)\n",
    "    fig, axes = plt.subplots(n_dims, 1, figsize = (14, 10), squeeze=False)\n",
    "    for d in range(n_dims):\n",
    "        ax = axes[d, 0]\n",
    "        for source, label in ((0, 'DDPM Samples'), (1, 'Data')):\n",
    "            mean = np.mean(auto_corr[:, :, source, d], axis=0)\n",
    "            std = np.std(auto_corr[:, :, source, d], axis=0)\n",
    "            line, = ax.plot(lags, mean, label=label)\n",
    "            ax.fill_between(lags, mean - std, mean + std, color=line.get_color(), alpha=0.3)\n",
    "        ax.set_title('Autocorrelation Function of DDPM Samples and Data for ' + str(d) + 'th Dimension')\n",
    "        ax.set_xlabel('Lag')\n",
    "        ax.set_ylabel('Autocorrelation')\n",
    "        ax.legend()\n",
    "\n",
    "    fig.tight_layout()\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plotting_fn_comp(samples, data, 8000, bins = 150)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plotting_fn_comp_1d(samples, data, 8000, bins = 25)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plotting_fn_corr(samples, data, 40, 10000, horizon)"
   ]
  }
 ],
 "metadata": {"accelerator": "GPU", "colab": {"authorship_tag": "ABX9TyMjwkK+LQFuanFegidX+2Vc", "machine_shape": "hm", "mount_file_id": "1T4CnIKZ8_7dslj0HpkrHuqBJXlIo1a4C", "name": "", "version": ""}, "gpuClass": "premium", "kernelspec": {"display_name": "base", "language": "python", "name": "python3"}, "language_info": {"codemirror_mode": {"name": "ipython", "version": 3}, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.13"}, "vscode": {"interpreter": {"hash": "639d8a7c3e620b1d142eea4deabde5aac9ed3b21a6e651e4622d69fbdac2ed0a"}}},
 "nbformat": 4,
 "nbformat_minor": 0
}
