{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-05-05T16:56:29.096300Z",
     "start_time": "2024-05-05T16:56:23.876613Z"
    }
   },
   "outputs": [],
   "source": [
    "import sys\n",
    "from functools import partial\n",
    "sys.path.append(\"../../\")\n",
    "\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "import matplotlib.pyplot as plt\n",
    "from ml_collections import ConfigDict\n",
    "import orbax.checkpoint\n",
    "\n",
    "from sdebridge.sdes import fourier_gaussian_kernel_sde\n",
    "from sdebridge import diffusion_bridge \n",
    "from sdebridge import plotting\n",
    "from sdebridge.networks.score_unet import ScoreUNet\n",
    "from sdebridge.utils import create_train_state"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-05-05T16:56:33.693719Z",
     "start_time": "2024-05-05T16:56:32.883920Z"
    }
   },
   "outputs": [
    {
     "ename": "FileNotFoundError",
     "evalue": "[Errno 2] No such file or directory: '../data/papilonidae_mean_pts.npy'",
     "output_type": "error",
     "traceback": [
      "\u001B[0;31m---------------------------------------------------------------------------\u001B[0m",
      "\u001B[0;31mFileNotFoundError\u001B[0m                         Traceback (most recent call last)",
      "Cell \u001B[0;32mIn[2], line 1\u001B[0m\n\u001B[0;32m----> 1\u001B[0m papilonidae_mean \u001B[38;5;241m=\u001B[39m \u001B[43mjnp\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mload\u001B[49m\u001B[43m(\u001B[49m\u001B[38;5;124;43m'\u001B[39;49m\u001B[38;5;124;43m../data/papilonidae_mean_pts.npy\u001B[39;49m\u001B[38;5;124;43m'\u001B[39;49m\u001B[43m)\u001B[49m\n\u001B[1;32m      2\u001B[0m papilonidaes \u001B[38;5;241m=\u001B[39m jnp\u001B[38;5;241m.\u001B[39mload(\u001B[38;5;124m'\u001B[39m\u001B[38;5;124m../data/papilonidae_pts.npy\u001B[39m\u001B[38;5;124m'\u001B[39m)\n\u001B[1;32m      3\u001B[0m ambrax \u001B[38;5;241m=\u001B[39m papilonidaes[\u001B[38;5;241m2\u001B[39m]\n",
      "File \u001B[0;32m~/Documents/Python/sdebridge-project/env/lib/python3.9/site-packages/jax/_src/numpy/lax_numpy.py:288\u001B[0m, in \u001B[0;36mload\u001B[0;34m(*args, **kwargs)\u001B[0m\n\u001B[1;32m    284\u001B[0m \u001B[38;5;129m@util\u001B[39m\u001B[38;5;241m.\u001B[39m_wraps(np\u001B[38;5;241m.\u001B[39mload, update_doc\u001B[38;5;241m=\u001B[39m\u001B[38;5;28;01mFalse\u001B[39;00m)\n\u001B[1;32m    285\u001B[0m \u001B[38;5;28;01mdef\u001B[39;00m \u001B[38;5;21mload\u001B[39m(\u001B[38;5;241m*\u001B[39margs: Any, \u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39mkwargs: Any) \u001B[38;5;241m-\u001B[39m\u001B[38;5;241m>\u001B[39m Array:\n\u001B[1;32m    286\u001B[0m   \u001B[38;5;66;03m# The main purpose of this wrapper is to recover bfloat16 data types.\u001B[39;00m\n\u001B[1;32m    287\u001B[0m   \u001B[38;5;66;03m# Note: this will only work for files created via np.save(), not np.savez().\u001B[39;00m\n\u001B[0;32m--> 288\u001B[0m   out \u001B[38;5;241m=\u001B[39m \u001B[43mnp\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mload\u001B[49m\u001B[43m(\u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43margs\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43mkwargs\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m    289\u001B[0m   \u001B[38;5;28;01mif\u001B[39;00m \u001B[38;5;28misinstance\u001B[39m(out, np\u001B[38;5;241m.\u001B[39mndarray):\n\u001B[1;32m    290\u001B[0m     \u001B[38;5;66;03m# numpy does not recognize bfloat16, so arrays are serialized as void16\u001B[39;00m\n\u001B[1;32m    291\u001B[0m     \u001B[38;5;28;01mif\u001B[39;00m out\u001B[38;5;241m.\u001B[39mdtype \u001B[38;5;241m==\u001B[39m \u001B[38;5;124m'\u001B[39m\u001B[38;5;124mV2\u001B[39m\u001B[38;5;124m'\u001B[39m:\n",
      "File \u001B[0;32m~/Documents/Python/sdebridge-project/env/lib/python3.9/site-packages/numpy/lib/npyio.py:427\u001B[0m, in \u001B[0;36mload\u001B[0;34m(file, mmap_mode, allow_pickle, fix_imports, encoding, max_header_size)\u001B[0m\n\u001B[1;32m    425\u001B[0m     own_fid \u001B[38;5;241m=\u001B[39m \u001B[38;5;28;01mFalse\u001B[39;00m\n\u001B[1;32m    426\u001B[0m \u001B[38;5;28;01melse\u001B[39;00m:\n\u001B[0;32m--> 427\u001B[0m     fid \u001B[38;5;241m=\u001B[39m stack\u001B[38;5;241m.\u001B[39menter_context(\u001B[38;5;28;43mopen\u001B[39;49m\u001B[43m(\u001B[49m\u001B[43mos_fspath\u001B[49m\u001B[43m(\u001B[49m\u001B[43mfile\u001B[49m\u001B[43m)\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;124;43m\"\u001B[39;49m\u001B[38;5;124;43mrb\u001B[39;49m\u001B[38;5;124;43m\"\u001B[39;49m\u001B[43m)\u001B[49m)\n\u001B[1;32m    428\u001B[0m     own_fid \u001B[38;5;241m=\u001B[39m \u001B[38;5;28;01mTrue\u001B[39;00m\n\u001B[1;32m    430\u001B[0m \u001B[38;5;66;03m# Code to distinguish from NumPy binary files and pickles.\u001B[39;00m\n",
      "\u001B[0;31mFileNotFoundError\u001B[0m: [Errno 2] No such file or directory: '../data/papilonidae_mean_pts.npy'"
     ]
    }
   ],
   "source": [
    "# NOTE(review): the output saved with this cell is a FileNotFoundError for the\n",
    "# first path -- confirm which working directory the notebook is meant to run from.\n",
    "papilonidae_mean = jnp.load('../data/papilonidae_mean_pts.npy')\n",
    "papilonidaes = jnp.load('../data/papilonidae_pts.npy')\n",
    "\n",
    "# Named shapes picked out of the stacked array by their row index.\n",
    "ambrax, deiphobus, protenor, phestus, polytes = (\n",
    "    papilonidaes[row] for row in (2, 5, 6, 10, 23)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-05-05T16:56:39.738480Z",
     "start_time": "2024-05-05T16:56:39.720012Z"
    }
   },
   "outputs": [
    {
     "ename": "NameError",
     "evalue": "name 'papilonidae_mean' is not defined",
     "output_type": "error",
     "traceback": [
      "\u001B[0;31m---------------------------------------------------------------------------\u001B[0m",
      "\u001B[0;31mNameError\u001B[0m                                 Traceback (most recent call last)",
      "Cell \u001B[0;32mIn[3], line 1\u001B[0m\n\u001B[0;32m----> 1\u001B[0m plt\u001B[38;5;241m.\u001B[39mscatter(\u001B[43mpapilonidae_mean\u001B[49m[:, \u001B[38;5;241m0\u001B[39m], papilonidae_mean[:, \u001B[38;5;241m1\u001B[39m], s\u001B[38;5;241m=\u001B[39m\u001B[38;5;241m5\u001B[39m, label\u001B[38;5;241m=\u001B[39m\u001B[38;5;124m'\u001B[39m\u001B[38;5;124mmean\u001B[39m\u001B[38;5;124m'\u001B[39m)\n\u001B[1;32m      2\u001B[0m plt\u001B[38;5;241m.\u001B[39mscatter(ambrax[:, \u001B[38;5;241m0\u001B[39m], ambrax[:, \u001B[38;5;241m1\u001B[39m], s\u001B[38;5;241m=\u001B[39m\u001B[38;5;241m5\u001B[39m, label\u001B[38;5;241m=\u001B[39m\u001B[38;5;124m'\u001B[39m\u001B[38;5;124mambrax\u001B[39m\u001B[38;5;124m'\u001B[39m)\n\u001B[1;32m      3\u001B[0m plt\u001B[38;5;241m.\u001B[39mscatter(deiphobus[:, \u001B[38;5;241m0\u001B[39m], deiphobus[:, \u001B[38;5;241m1\u001B[39m], s\u001B[38;5;241m=\u001B[39m\u001B[38;5;241m5\u001B[39m, label\u001B[38;5;241m=\u001B[39m\u001B[38;5;124m'\u001B[39m\u001B[38;5;124mdeiphobus\u001B[39m\u001B[38;5;124m'\u001B[39m)\n",
      "\u001B[0;31mNameError\u001B[0m: name 'papilonidae_mean' is not defined"
     ]
    }
   ],
   "source": [
    "# Overlay the mean shape and the five selected shapes in one scatter plot.\n",
    "_labelled_shapes = [\n",
    "    (papilonidae_mean, 'mean'),\n",
    "    (ambrax, 'ambrax'),\n",
    "    (deiphobus, 'deiphobus'),\n",
    "    (protenor, 'protenor'),\n",
    "    (phestus, 'phestus'),\n",
    "    (polytes, 'polytes'),\n",
    "]\n",
    "for _pts, _label in _labelled_shapes:\n",
    "    plt.scatter(_pts[:, 0], _pts[:, 1], s=5, label=_label)\n",
    "plt.legend()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_bases = 16\n",
    "\n",
    "# Keyword arguments that simulate_bridge unpacks into fourier_gaussian_kernel_sde.\n",
    "sde_config_template = ConfigDict(\n",
    "    dict(\n",
    "        init_S=None,  # supplied per shape by simulate_bridge\n",
    "        n_bases=n_bases,\n",
    "        n_grid=64,\n",
    "        grid_range=[-1.0, 1.0],\n",
    "        alpha=0.15,\n",
    "        sigma=0.1,\n",
    "        T=1.0,\n",
    "        N=50,\n",
    "        dim=2,\n",
    "    )\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_X_coeffs(initial_shape, target_shape=papilonidae_mean, n_bases=n_bases):\n",
    "    \"\"\"Build the start and end Fourier-coefficient vectors for a bridge.\n",
    "\n",
    "    The displacement ``target_shape - initial_shape`` is transformed with an FFT\n",
    "    along axis 0, shifted so the zero frequency sits in the middle, and cut\n",
    "    down to the ``n_bases`` central rows.\n",
    "\n",
    "    Args:\n",
    "        initial_shape: array indexed as ``[point, coordinate]``; columns 0 and 1 are used.\n",
    "        target_shape: shape the displacement is measured towards; defaults to the\n",
    "            ``papilonidae_mean`` value bound when this cell is executed.\n",
    "        n_bases: number of central frequency rows to keep.\n",
    "\n",
    "    Returns:\n",
    "        ``(X0_flatten, XT_flatten)``: an all-zero complex64 vector of length\n",
    "        ``2 * n_bases`` and the kept coefficients, column 0 followed by column 1.\n",
    "    \"\"\"\n",
    "    n_pts = initial_shape.shape[0]\n",
    "    lo, hi = (n_pts - n_bases) // 2, (n_pts + n_bases) // 2\n",
    "    spectrum = jnp.fft.fft(target_shape - initial_shape, axis=0)\n",
    "    centred = jnp.fft.fftshift(spectrum, axes=0)\n",
    "    kept = centred[lo:hi, :]\n",
    "    XT_flatten = jnp.concatenate([kept[:, 0], kept[:, 1]], axis=0)\n",
    "    X0_flatten = jnp.zeros((2*n_bases, ), dtype=jnp.complex64)\n",
    "    return X0_flatten, XT_flatten\n",
    "\n",
    "def reconstruct_traj(Xs_flatten, initial_shape):\n",
    "    \"\"\"Map flattened Fourier-coefficient trajectories back to point trajectories.\n",
    "\n",
    "    Args:\n",
    "        Xs_flatten: complex array whose last axis holds the coefficients of\n",
    "            column 0 followed by those of column 1. The padding below expects\n",
    "            exactly two leading axes in front of that coefficient axis.\n",
    "        initial_shape: array of shape ``(n_pts, 2)`` that is added back to every\n",
    "            reconstructed displacement.\n",
    "\n",
    "    Returns:\n",
    "        Real array with the two leading axes of ``Xs_flatten`` followed by\n",
    "        ``(n_pts, 2)``.\n",
    "    \"\"\"\n",
    "    n_pts = initial_shape.shape[0]\n",
    "    Xs = jnp.stack(jnp.split(Xs_flatten, 2, axis=-1), axis=-1)\n",
    "    n_bases = Xs.shape[-2]\n",
    "    # Mirror the centred truncation slice\n",
    "    # [(n_pts - n_bases) // 2 : (n_pts + n_bases) // 2] used by compute_X_coeffs.\n",
    "    # When n_pts - n_bases is odd the two sides differ by one bin; padding both\n",
    "    # sides equally would rebuild only n_pts - 1 bins and shift every\n",
    "    # negative-frequency coefficient to the wrong frequency.\n",
    "    pad_before = (n_pts - n_bases) // 2\n",
    "    pad_after = n_pts - n_bases - pad_before\n",
    "    Xs = jnp.pad(Xs, ((0, 0), (0, 0), (pad_before, pad_after), (0, 0)))\n",
    "    Xs = jnp.fft.ifftshift(Xs, axes=-2)\n",
    "    # Only the real part is kept; any imaginary remainder is discarded.\n",
    "    traj = jnp.fft.ifft(Xs, n=n_pts, axis=-2, norm='backward').real\n",
    "    traj = traj + initial_shape[None, None, :, :]\n",
    "    return traj\n",
    "\n",
    "def load_ckpts(ckpt_path):\n",
    "    \"\"\"Restore the train state stored in an Orbax checkpoint.\n",
    "\n",
    "    The checkpoint is read twice: first without a template, to obtain the saved\n",
    "    network configuration, then again with a freshly built train state as the\n",
    "    restore template.\n",
    "\n",
    "    Args:\n",
    "        ckpt_path: path handed to ``PyTreeCheckpointer.restore``.\n",
    "\n",
    "    Returns:\n",
    "        The ``\"state\"`` entry of the restored checkpoint.\n",
    "    \"\"\"\n",
    "    checkpointer = orbax.checkpoint.PyTreeCheckpointer()\n",
    "    raw_ckpt = checkpointer.restore(ckpt_path)\n",
    "    net_kwargs = raw_ckpt[\"training_config\"][\"network\"]\n",
    "    # The input width is derived from the notebook-level n_bases.\n",
    "    # NOTE(review): the optimizer settings look like placeholders for building\n",
    "    # the restore template -- confirm they are not meant to match training.\n",
    "    template_state = create_train_state(\n",
    "        model=ScoreUNet(**net_kwargs),\n",
    "        rng_key=jax.random.PRNGKey(0),\n",
    "        input_shapes=((1, 2*2*n_bases), (1, 1)),\n",
    "        learning_rate=0.01,\n",
    "        warmup_steps=1,\n",
    "        decay_steps=10\n",
    "    )\n",
    "    restore_target = {\"state\": template_state, \"training_config\": {}}\n",
    "    restored = checkpointer.restore(ckpt_path, item=restore_target)\n",
    "    return restored[\"state\"]\n",
    "\n",
    "def simulate_bridge(initial_shape, traget_shape=papilonidae_mean, n_bases=n_bases, score_p=None):\n",
    "    \"\"\"Simulate backward bridge trajectories starting from ``initial_shape``.\n",
    "\n",
    "    Args:\n",
    "        initial_shape: point array used both as ``init_S`` of the SDE and as the\n",
    "            shape the reconstructed displacements are added to.\n",
    "        traget_shape: target point array. The parameter keeps its original\n",
    "            spelling so existing keyword callers continue to work.\n",
    "        n_bases: number of Fourier coefficients per coordinate, forwarded to\n",
    "            ``compute_X_coeffs``.\n",
    "        score_p: score callable forwarded to ``simulate_backward_bridge``.\n",
    "\n",
    "    Returns:\n",
    "        The point-space trajectories produced by ``reconstruct_traj``.\n",
    "    \"\"\"\n",
    "    # Build per-call keyword arguments instead of writing init_S into the shared\n",
    "    # sde_config_template, so one call cannot leak state into the next.\n",
    "    sde_kwargs = {**sde_config_template, 'init_S': initial_shape}\n",
    "    sde = fourier_gaussian_kernel_sde(**sde_kwargs)\n",
    "    # Only the diffusion_bridge module is imported by this notebook, so the\n",
    "    # class has to be reached through it (a bare DiffusionBridge is a NameError).\n",
    "    bridge = diffusion_bridge.DiffusionBridge(sde)\n",
    "    X0_flatten, XT_flatten = compute_X_coeffs(initial_shape, traget_shape, n_bases)\n",
    "    Xs_flatten = bridge.simulate_backward_bridge(X0_flatten, XT_flatten, score_p=score_p, num_batches=4)[\"trajectories\"]\n",
    "    bwd_traj = reconstruct_traj(Xs_flatten, initial_shape)\n",
    "    return bwd_traj"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Restore the train state stored under this checkpoint path (see load_ckpts).\n",
    "score_p_state = load_ckpts('./ckpts/score_p_16_bases_mean')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "@jax.jit\n",
    "def score_p(val, time):\n",
    "    \"\"\"Evaluate the restored score network and return a complex-valued score.\n",
    "\n",
    "    Args:\n",
    "        val: passed to the network as ``x_complex``.\n",
    "        time: passed to the network as ``t``.\n",
    "\n",
    "    Returns:\n",
    "        ``real + 1j * imag``, where the two parts come from splitting the\n",
    "        network output in half along its last axis.\n",
    "    \"\"\"\n",
    "    variables = {\"params\": score_p_state.params, \"batch_stats\": score_p_state.batch_stats}\n",
    "    raw_output = score_p_state.apply_fn(variables, x_complex=val, t=time, train=False)\n",
    "    real_part, imag_part = jnp.split(raw_output, 2, axis=-1)\n",
    "    return real_part + 1j * imag_part"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One backward-bridge simulation per selected shape, run in the listed order.\n",
    "(\n",
    "    bwd_traj_ambrax,\n",
    "    bwd_traj_deiphobus,\n",
    "    bwd_traj_protenor,\n",
    "    bwd_traj_phestus,\n",
    "    bwd_traj_polytes,\n",
    ") = (\n",
    "    simulate_bridge(shape, score_p=score_p)\n",
    "    for shape in (ambrax, deiphobus, protenor, phestus, polytes)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(4, 50, 100, 2)\n",
      "(4, 50, 100, 2)\n",
      "(4, 50, 100, 2)\n",
      "(4, 50, 100, 2)\n",
      "(4, 50, 100, 2)\n"
     ]
    }
   ],
   "source": [
    "# Report the array shape of every simulated trajectory batch.\n",
    "for _traj in (\n",
    "    bwd_traj_ambrax,\n",
    "    bwd_traj_deiphobus,\n",
    "    bwd_traj_protenor,\n",
    "    bwd_traj_phestus,\n",
    "    bwd_traj_polytes,\n",
    "):\n",
    "    print(_traj.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist each trajectory batch to its own .npy file.\n",
    "_trajectory_files = (\n",
    "    ('../data/papilonidae_traj/bwd_traj_ambrax.npy', bwd_traj_ambrax),\n",
    "    ('../data/papilonidae_traj/bwd_traj_deiphobus.npy', bwd_traj_deiphobus),\n",
    "    ('../data/papilonidae_traj/bwd_traj_protenor.npy', bwd_traj_protenor),\n",
    "    ('../data/papilonidae_traj/bwd_traj_phestus.npy', bwd_traj_phestus),\n",
    "    ('../data/papilonidae_traj/bwd_traj_polytes.npy', bwd_traj_polytes),\n",
    ")\n",
    "for _path, _traj in _trajectory_files:\n",
    "    jnp.save(_path, _traj)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "sdebridge",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
