# --- Imports -----------------------------------------------------------------
# NOTE: the import order is kept exactly as in the original notebook because
# `from pylab import *` injects many names (it is where `plt` comes from) and
# can shadow earlier imports.
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

import time
import io
import copy
import os
import math
from scipy import integrate
from functools import partial

import pickle
from pylab import *

from torchvision import transforms

from diffusers import DDPMScheduler, EulerDiscreteScheduler
from statsmodels.tsa.stattools import acf

import einops
from einops.layers.torch import Rearrange
from functorch import jacrev, jacfwd, vmap

import sys

# --- Make the local SeqDiff package importable -------------------------------
# The package is expected in a `SeqDiff` folder two directories above the
# current working directory.
current_path = os.getcwd()
two_folders_up = os.path.join(current_path, '..', '..')
desired_folder = os.path.join(two_folders_up, 'SeqDiff')
absolute_desired_folder = os.path.abspath(desired_folder)
# Guard against stacking duplicate entries when the cell is re-run.
if absolute_desired_folder not in sys.path:
    sys.path.insert(0, absolute_desired_folder)

#from dynamicsdiffusion import count_parameters, cosine_beta_schedule, Losses, EMA, extract, Silent, make_timesteps, sort_by_values, Sample
from dynamicsdiffusion import count_parameters
from conditionalgeneration import TemporalUnetConditional, GaussianDiffusion, Trainer, default_sample_fn

# --- Global configuration ----------------------------------------------------
plt.rcParams.update({'font.size': 15})
# Fall back to the CPU when no CUDA device is available instead of failing on
# the first `.to(device)` call.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
#@title Sampling and Logging
# Sampling functions depend on both the data shape and on whether the model is
# conditional, so they are defined in each notebook.
def batch_sampler(traj, batch_size, horizon, transition_dim):
    """Sample a batch of fixed-length trajectory windows from the data.

    Args:
        traj: array indexed as ``traj[time, condition, replica, feature]``.
            The size of the last axis must equal ``transition_dim``.
        batch_size: number of windows to draw.
        horizon: window length in time steps, or a list of candidate lengths
            from which one is drawn uniformly for the whole batch.
        transition_dim: number of features per time step.

    Returns:
        ``(batch, cond)`` where ``batch`` is a float tensor of shape
        ``(batch_size, horizon, transition_dim)`` and ``cond`` is a tensor of
        shape ``(batch_size, 1)`` holding the condition index rescaled as
        ``(index / n_conditions - 0.5) * 2``, i.e. into ``[-1, 1)``.

    Raises:
        ValueError: if the trajectories are shorter than ``horizon``.
    """
    if isinstance(horizon, list):
        horizon = np.random.choice(horizon)

    n_steps, n_conds, n_replicas = traj.shape[0], traj.shape[1], traj.shape[2]
    if n_steps < horizon:
        # Previously this surfaced as an opaque shape-mismatch error when the
        # too-short trajectory was written into the batch tensor.
        raise ValueError(
            f"trajectory length {n_steps} is shorter than the requested horizon {horizon}"
        )

    batch = torch.zeros(size = (batch_size, horizon, transition_dim))
    cond = torch.zeros(size = (batch_size, 1))

    for i in range(batch_size):
        start_cond = np.random.randint(0, n_conds)
        start_replica = np.random.randint(0, n_replicas)

        if n_steps == horizon:
            # Only one window fits: the whole trajectory.
            window = traj[:, start_cond, start_replica, :]
        else:
            # randint's upper bound is exclusive, so `+ 1` is required for the
            # last valid start (n_steps - horizon) to be reachable.
            start_step = np.random.randint(0, n_steps - horizon + 1)
            window = traj[start_step:start_step + horizon, start_cond, start_replica, :]

        batch[i] = torch.tensor(window)
        cond[i] = (start_cond/n_conds - 0.5)*2
    return batch, cond

def potential(shape):
    """Evaluate the double-well potential on a regular grid over [-2, 2] x [-2, 2].

    Args:
        shape: ``(n_x, n_y)`` number of grid points along x and y.

    Returns:
        ``(X, Y, V)`` where ``X`` and ``Y`` are the ``np.meshgrid`` coordinate
        arrays (shape ``(n_y, n_x)``) and
        ``V = -9*X**2 + X**3 + 4.5*X**4 + 3*Y**2``.
    """
    x = np.linspace(-2, 2, shape[0])
    y = np.linspace(-2, 2, shape[1])
    X, Y = np.meshgrid(x, y)
    # Use the meshgrid X throughout (the original mixed in the 1-D `x`, which
    # only gave the same result through broadcasting).
    return X, Y, -9*X**2 + X**3 + 4.5*X**4 + 3*Y**2
def _free_energy_image(points, bins=100, hist_range=((-2, 2), (-2, 2))):
    """Return -log of the normalised 2D histogram of ``points[:, :2]``, transposed for imshow."""
    density = np.histogram2d(points[:, 0], points[:, 1], bins=bins, range=[list(r) for r in hist_range])[0]
    density = density/np.sum(density)
    # Empty bins give -log(0) = +inf; that is intended (they render as blank),
    # so silence the divide-by-zero warning instead of flooding the output.
    with np.errstate(divide='ignore'):
        return -np.log(density.T)

def plotting_fn(samples, data, step, dim, training = True, title_addition = ''):
    """Plot free-energy maps of generated samples and of the data side by side.

    Reads the module-level ``max_min_array`` to map both inputs from [-1, 1]
    to the interval [max_min_array[1], max_min_array[0]], and calls
    ``potential`` for the contour overlay on the left panel.

    Args:
        samples: torch tensor of generated trajectories; flattened to (-1, dim).
        data: array of reference trajectories; every 10th entry along the
            first axis is used and flattened to (-1, dim).
        step: value appended to the figure title.
        dim: number of features per frame.
        training: if True the title reads "... at Trainingstep: <step>",
            otherwise "... at <title_addition><step>".
        title_addition: text inserted before ``step`` when ``training`` is False.
    """
    samples = samples.cpu().detach().numpy().reshape(-1, dim)
    data = data[::10].reshape(-1, dim)

    # Undo the [-1, 1] scaling.
    scale = max_min_array[0] - max_min_array[1]
    offset = max_min_array[1]
    samples = (samples + 1)/2*scale + offset
    data = (data + 1)/2*scale + offset

    extent = [-2, 2, -2, 2]
    fig, ax = plt.subplots(1, 2, figsize = (14, 5))

    # Left panel: free energy of the samples with potential contours on top.
    ax[0].imshow(_free_energy_image(samples), origin='lower', extent=extent)
    ax[0].set_title('Free Energy of DDPM Samples')
    ax[0].set_xlabel('x')
    ax[0].set_ylabel('y')

    X, Y, Z = potential((100, 100))
    # Log-spaced contour levels measured from the potential minimum.
    levels = np.logspace(0, 1, 10) + np.min(Z)
    # Draw the contours once and reuse the handle for the colorbar (the
    # original drew the identical contour set twice).
    contours = ax[0].contour(X, Y, Z, levels=levels, alpha=1)
    cbar = fig.colorbar(contours, ax=ax[0])
    cbar.ax.set_ylabel('Potential Energy')

    # Right panel: free energy of the data; image drawn once and reused for
    # the colorbar.
    data_image = ax[1].imshow(_free_energy_image(data), origin='lower', extent=extent)
    ax[1].set_title('Free Energy of Data')
    ax[1].set_xlabel('x')
    ax[1].set_ylabel('y')

    cbar = fig.colorbar(data_image, ax=ax[1])
    cbar.ax.set_ylabel('Free Energy')

    # Figure title
    if training:
        fig.suptitle('DDPM Samples and Data at Trainingstep: ' + str(step), fontsize=16)
    else:
        fig.suptitle('DDPM Samples and Data at ' + title_addition + str(step), fontsize=16)

    plt.show()

def plotting_fn_cross(samples, data, step, idx = None):
    """Overlay histograms of the first coordinate for samples and data.

    Args:
        samples: torch tensor indexed as ``samples[batch, time, feature]``.
        data: array indexed with four axes; axis 1 is selected by ``idx``.
        step: unused; kept for signature compatibility with ``plotting_fn``.
        idx: optional index along axis 1 of ``data``. When given, only that
            slice of the data is histogrammed and the title reports
            ``np.linspace(0.5, 1.5, 200)[idx]`` as kT.
            NOTE(review): the 200 is hard-coded and presumably equals
            ``data.shape[1]`` - confirm.
    """
    data_plot = data[:, :, :, 0] if idx is None else data[:, idx, :, 0]

    plt.hist(samples[:, :, 0].cpu().numpy().flatten(), bins=100, density=True)
    plt.hist(data_plot.flatten(), bins=100, density=True, alpha=0.3)

    if idx is None:
        plt.title('Histogram of the samples')
    else:
        kT = np.linspace(0.5, 1.5, 200)
        plt.title('Histogram of the samples at kT = ' + str(kT[idx]))
    plt.xlabel('x')
    plt.ylabel('Density')
    plt.show()
# --- cell: load data ---------------------------------------------------------
import copy

# NOTE(security): pickle.load executes arbitrary code from the file; only load
# pickles that were produced locally by the data-generation notebooks.
data_dir = "../Data Generation/Data/traj_dw.pkl"
# Use context managers so the file handles are closed (the original left both open).
with open(data_dir,'rb') as file:
    data = pickle.load(file)[::, ::, ::, 0:2]  # keep only the first two features

scaler_dir = "../Data Generation/Data/dw_scaler.pkl"
with open(scaler_dir,'rb') as file:
    max_min_array = pickle.load(file)[::, 0:2]  # scaler entries for the same two features

transition_dim = 2

horizon = [256]
# Bind the feature dimension so plotting_fn can be called without passing `dim`.
plotting_fn = partial(plotting_fn, dim = transition_dim)

# --- cell: build model and diffusion wrapper ---------------------------------
model = TemporalUnetConditional(horizon = horizon[0], transition_dim = transition_dim, cond_dim = None, dim=64, dim_mults=(1, 2, 4), attention = True).to(device)
diffuser = GaussianDiffusion(model = model, horizon = horizon[0], action_dim=transition_dim, observation_dim=0, n_timesteps = 1000).to(device)
batch_size = 256
count_parameters(diffuser)

# --- cell: trainer for the demo run -------------------------------------------
trainer = Trainer(diffuser, training_data = data, training_horizons = horizon, batch_sampler = batch_sampler, plotting_fn = None, train_lr=4e-4, train_batch_size=batch_size, run_name = "demo_") #, train_lr=4e-4, train_batch_size=batch_size, run_name = "demo_"

# --- cell ---
trainer.load(5000) #load the model at step 5000

# --- cell ---
trainer.train(5001) #train the model for 5001 steps

# --- cell: second trainer with a lower learning rate --------------------------
trainer = Trainer(diffuser, training_data = data, training_horizons = horizon, batch_sampler = batch_sampler, plotting_fn = None, train_lr=4e-6, train_batch_size=batch_size, run_name = "run_") #create a new trainer with a lower learning rate

# --- cell ---
trainer.load(5000) #load the model at step 5000 to continue training

# --- cell ---
trainer.train(2001) #train the model for 2001 steps
def plot_trajectories(data, *trajectories, labels=None, title='Trajectory Plot', point_sizes=None, colors=None):
    """Scatter-plot trajectories on top of the data free-energy map.

    Args:
        data: reference data array; every 100th entry along the first axis is
            flattened to (-1, 2) and histogrammed over [-pi, pi]^2 (200 bins)
            to form the background image.
        *trajectories: torch tensors; each is moved to CPU, detached and
            flattened to (-1, 2) before plotting.
        labels: legend label per trajectory (default "Trajectory <n>").
        title: axes title.
        point_sizes: scatter marker size per trajectory (default 0.8).
        colors: scatter color per trajectory (default: matplotlib's cycle).

    Raises:
        ValueError: if ``labels``, ``point_sizes`` or ``colors`` has fewer
            entries than there are trajectories.
    """
    n_traj = len(trajectories)

    if labels is None:
        labels = [f'Trajectory {i+1}' for i in range(n_traj)]
    if point_sizes is None:
        point_sizes = [0.8] * n_traj
    if colors is None:
        colors = [None] * n_traj

    # Validate before any drawing so a bad call does not leave a half-built figure
    # (the original failed with an IndexError inside the plotting loop).
    for arg_name, values in (('labels', labels), ('point_sizes', point_sizes), ('colors', colors)):
        if len(values) < n_traj:
            raise ValueError(f"{arg_name} has {len(values)} entries but {n_traj} trajectories were given")

    pi = np.pi
    data_plot = data[::100].reshape(-1, 2)
    fig, ax = plt.subplots(1, 1, figsize=(5, 5))

    density_data = np.histogram2d(data_plot[:, 0], data_plot[:, 1], bins=200, range=[[-pi, pi], [-pi, pi]])[0]
    density_data = density_data / np.sum(density_data)
    # Empty bins map to +inf free energy on purpose; suppress the log(0) warning.
    with np.errstate(divide='ignore'):
        free_energy = -np.log(density_data.T)
    ax.imshow(free_energy, origin='lower', extent=[-pi, pi, -pi, pi], alpha=0.5, label='Data Free Energy')

    for trajectory, label, size, color in zip(trajectories, labels, point_sizes, colors):
        traj_plot = trajectory.cpu().detach().numpy().reshape(-1, 2)
        ax.scatter(traj_plot[:, 0], traj_plot[:, 1], s=size, alpha=1, label=label, color=color)

    ax.set_title(title)
    ax.set_xlabel('x')
    ax.set_ylabel('y')
    ax.legend(loc='lower right')
    # The view is clipped to [-1, 1]^2 although the histogram spans [-pi, pi]^2.
    ax.set_xlim(-1, 1)
    ax.set_ylim(-1, 1)
    plt.show()
# Inpainting demo: sample trajectories whose first and last frames are pinned
# to given points and let the model fill in the frames in between.
traj_count = 2 # number of trajectories to sample
length = 320 # number of frames in each sampled trajectory
repeats_x = 1 # number of frames pinned to the start point
repeats_y = 50 # number of frames pinned to the end point

start_point = [-0.4, 0.0] # (x, y) the trajectory is pinned to at the beginning
end_point = [0.4, 0.0] # (x, y) the trajectory is pinned to at the end

# Conditioning value passed to the sampler, one per trajectory (all zeros here).
# NOTE(review): presumably 0 is the middle of the temperature range used in training - confirm.
cond = 0*torch.ones((traj_count, 1)).to(device).float() #temperature of the trajectory

# fixed_frames are (1, repeats, 2) tensors; start_frames gives where each pinned
# block is placed. NOTE(review): the fixed-frame tensors are created on the CPU
# while `cond` is on `device` - confirm that p_sample_loop_inpainting moves them.
sample_inpainting = trainer.ema_model.p_sample_loop_inpainting((traj_count, length, transition_dim), cond, fixed_frames = [torch.tensor(start_point).repeat(1, repeats_x, 1), torch.tensor(end_point).repeat(1, repeats_y, 1)], start_frames= [0, length-1-repeats_y]).trajectories

# Keep one pinned frame from each end in a separate tensor so they can be
# highlighted in the plot: the first frame and the second-to-last frame.
sample_inpainting_fixed_start = sample_inpainting[:, 0:1, :]
sample_inpainting_fixed_end = sample_inpainting[:, -2:-1, :]
sample_inpainting_fixed = torch.cat((sample_inpainting_fixed_start, sample_inpainting_fixed_end), dim = 1)

# Drop the first `repeats_x` and the last `repeats_y` frames from the sample.
# NOTE(review): the end block starts at length-1-repeats_y, so that index is
# still inside the kept slice - confirm this is intended.
sample_inpainting = sample_inpainting[:, repeats_x:length-repeats_y, :]

# Plot the free (generated) part in blue and the pinned frames as larger red markers.
plot_trajectories(data, sample_inpainting, sample_inpainting_fixed, labels=['Inpainting Trajectory', 'Inpainting Condition'], title='Inpainting Trajectories', point_sizes=[0.8, 15], colors=['blue', 'red'])
# Search the training data for a window that goes from state A to state B
# within a fixed number of frames.
state_a_bound_x = [-2, -0.6]   # (lower, upper) x-bounds of state A
state_b_bound_x = [0.6, 2]     # (lower, upper) x-bounds of state B
time_horizon = 256             # window length in frames

def find_trajectory(data, state_a_bound_x, state_a_bound_y, time_horizon):
    """Return the first window that starts in state A and ends in state B.

    Args:
        data: array indexed as ``data[time, trajectory, feature]``; feature 0
            is the x coordinate that is tested against the bounds.
        state_a_bound_x: ``(lower, upper)`` open interval for x in state A.
        state_a_bound_y: ``(lower, upper)`` open interval for x in state B.
            The name is historical (kept for interface compatibility); the
            only caller passes the state-B x-bounds here.
        time_horizon: number of frames in the returned window.

    Returns:
        ``data[j:j + time_horizon, i, :]`` for the first trajectory ``i`` and
        start ``j`` whose first frame lies in state A and whose last frame
        lies in state B, or ``None`` if no such window exists.
    """
    # Use the argument, not the module-level `state_b_bound_x` the original
    # silently read instead.
    lo_a, hi_a = state_a_bound_x[0], state_a_bound_x[1]
    lo_b, hi_b = state_a_bound_y[0], state_a_bound_y[1]

    n_steps, n_traj = data.shape[0], data.shape[1]
    for i in range(n_traj):
        # `+ 1` so the window ending on the very last frame is also considered.
        for j in range(n_steps - time_horizon + 1):
            x_first = data[j, i, 0]
            # Test the last frame that is actually part of the returned window
            # (the original looked one frame past it).
            x_last = data[j + time_horizon - 1, i, 0]
            if lo_a < x_first < hi_a and lo_b < x_last < hi_b:
                return data[j:j + time_horizon, i, :]
    return None
# --- cell: pick a reference transition from the data --------------------------
# Find a window that goes from state A to state B within `time_horizon` frames,
# searching only condition index 100 of the data.
varation_sample = find_trajectory(data[:, 100], state_a_bound_x, state_b_bound_x, time_horizon)
if varation_sample is None:
    # Fail here with a clear message rather than later inside torch.tensor(None).
    raise RuntimeError("find_trajectory found no transition from state A to state B within the time horizon")

# --- cell: Variation demo -----------------------------------------------------
noise_level = 0.3
count = 1
cond = 0*torch.ones((count, 1)).to(device).float()
# Build the reference tensor once; the CPU copy is reused for plotting.
variation_reference = torch.tensor(varation_sample).unsqueeze(0)
sample_variation = trainer.ema_model.p_sample_loop_variations(variation_reference.to(device), count, noise_level, cond, return_chain=True, sample_fn=default_sample_fn).trajectories

# Generated variation in blue, the reference window it was derived from in red.
plot_trajectories(data, sample_variation, variation_reference, labels=['Variation Trajectory', 'Variation Condition'], title='Variation Trajectories', point_sizes=[0.8, 0.8], colors=['blue', 'red'])

# --- cell: Conditional Generation demo -----------------------------------------
traj_count = 20 #number of trajectories to sample

cond = 0*torch.ones((traj_count, 1)).to(device).float() #set the temperature of the trajectory to 1 KbT
sample_cond1 = trainer.ema_model.p_sample_loop((traj_count, horizon[0], transition_dim), cond).trajectories

cond = 1*torch.ones((traj_count, 1)).to(device).float() #set the temperature of the trajectory to 1.5 KbT
sample_cond15 = trainer.ema_model.p_sample_loop((traj_count, horizon[0], transition_dim), cond).trajectories

#plot the trajectories
plot_trajectories(data, sample_cond1, sample_cond15, labels=['Trajectory 1KbT', 'Trajectory 1.5KbT'], title='Trajectories at Different Temperatures', point_sizes=[0.8, 0.8], colors=['blue', 'red'])
