{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "arabic-relief",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "num_threads = \"16\"\n",
    "os.environ[\"OMP_NUM_THREADS\"] = num_threads\n",
    "os.environ[\"OPENBLAS_NUM_THREADS\"] = num_threads\n",
    "os.environ[\"MKL_NUM_THREADS\"] = num_threads\n",
    "os.environ[\"VECLIB_MAXIMUM_THREADS\"] = num_threads\n",
    "os.environ[\"NUMEXPR_NUM_THREADS\"] = num_threads"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "incident-january",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import torch\n",
    "from torch.autograd import grad, Variable\n",
    "import autograd\n",
    "import autograd.numpy as np\n",
    "import copy\n",
    "import scipy as sp\n",
    "from scipy import stats\n",
    "from sklearn import metrics\n",
    "import sys\n",
    "import ot\n",
    "import gwot\n",
    "from gwot import models, sim, ts, util\n",
    "import gwot.bridgesampling as bs\n",
    "import dcor\n",
    "from tqdm import tqdm\n",
    "\n",
    "sys.path.append(\"..\")\n",
    "import importlib\n",
    "import models\n",
    "import random\n",
    "import model_fig1 as model_sim\n",
    "import mmd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "authorized-coach",
   "metadata": {},
   "outputs": [],
   "source": [
    "PLT_CELL = 2.5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "rapid-surveillance",
   "metadata": {},
   "outputs": [],
   "source": [
    "import glob\n",
    "\n",
    "fnames_all = glob.glob(\"out_N_*.npy\")\n",
    "fnames_all_gwot = glob.glob(\"out_gwot_N_*.npy\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "found-works",
   "metadata": {},
   "outputs": [],
   "source": [
    "srand_all = np.array([int(f.split(\"_\")[4]) for f in fnames_all])\n",
    "srand_all_gwot = np.array([int(f.split(\"_\")[5]) for f in fnames_all_gwot])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "minute-affairs",
   "metadata": {},
   "outputs": [],
   "source": [
    "lamda_all = np.array([float(f.split(\"_\")[6].split(\".npy\")[0]) for f in fnames_all])\n",
    "lamda_all_gwot = np.array([float(f.split(\"_\")[7].split(\".npy\")[0]) for f in fnames_all_gwot])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "involved-chamber",
   "metadata": {},
   "outputs": [],
   "source": [
    "N_all = np.array([int(f.split(\"_\")[2]) for f in fnames_all])\n",
    "N_all_gwot = np.array([int(f.split(\"_\")[3]) for f in fnames_all_gwot])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "sought-database",
   "metadata": {},
   "outputs": [],
   "source": [
    "x_all = [np.load(f, allow_pickle = True).item(0)[\"x\"] for f in fnames_all]\n",
    "x_gwot_all = [np.load(f, allow_pickle = True).item(0)[\"samples_gwot\"] for f in fnames_all_gwot]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "approximate-pattern",
   "metadata": {},
   "outputs": [],
   "source": [
    "# setup simulation object\n",
    "sim = gwot.sim.Simulation(V = model_sim.Psi, dV = model_sim.dPsi, birth_death = False, \n",
    "                          N = None,\n",
    "                          T = model_sim.T, \n",
    "                          d = model_sim.dim, \n",
    "                          D = model_sim.D, \n",
    "                          t_final = model_sim.t_final, \n",
    "                          ic_func = model_sim.ic_func, \n",
    "                          pool = None)\n",
    "\n",
    "sim_gt = copy.deepcopy(sim)\n",
    "sim_gt.N = np.array([1_000, ]*model_sim.T)\n",
    "sim_gt.sample(steps_scale = int(model_sim.sim_steps/sim.T));"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "interpreted-bottle",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.scatter(np.linspace(0, model_sim.t_final, model_sim.T)[sim_gt.t_idx], sim_gt.x[:, 0], alpha = 0.01, color = \"blue\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "refined-syracuse",
   "metadata": {},
   "outputs": [],
   "source": [
    "x_all[0].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "offshore-glenn",
   "metadata": {},
   "outputs": [],
   "source": [
    "with torch.no_grad():\n",
    "    d_reconstruct = np.array([[dcor.energy_distance(sim_gt.x[sim_gt.t_idx == i, :], x_all[j][i, :]) for i in range(x_all[j].shape[0])] for j in tqdm(range(len(x_all)), position = 0, leave = True)])\n",
    "d_gwot = np.array([[dcor.energy_distance(sim_gt.x[sim_gt.t_idx == i, :], x_gwot_all[j][i, :]) for i in range(x_gwot_all[j].shape[0])] for j in tqdm(range(len(x_gwot_all)), position = 0, leave = True)])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bridal-pierce",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.plot(d_gwot[(N_all_gwot == 1) & (lamda_all_gwot == 0.005), :].mean(0), 'o-', label = \"gWOT\")\n",
    "plt.plot(d_reconstruct[(N_all == 1) & (lamda_all == 0.05), :].mean(0), 'o-', label = \"Langevin\")\n",
    "plt.legend()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "hawaiian-nature",
   "metadata": {},
   "outputs": [],
   "source": [
    "fnames_all[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "boring-champagne",
   "metadata": {},
   "outputs": [],
   "source": [
    "N_vals, _ = np.unique(N_all, return_index = True)\n",
    "N_vals_gwot, _ = np.unique(N_all_gwot, return_index = True)\n",
    "lamda_vals, _ = np.unique(lamda_all, return_index = True)\n",
    "lamda_vals_gwot, _ = np.unique(lamda_all_gwot, return_index = True)\n",
    "srand_vals, _ = np.unique(srand_all, return_index = True)\n",
    "srand_vals_gwot, _ = np.unique(srand_all_gwot, return_index = True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "subjective-reynolds",
   "metadata": {},
   "outputs": [],
   "source": [
    "d_reconstruct_tensor = np.full((len(N_vals), len(lamda_vals), len(srand_vals), sim_gt.T), float(\"NaN\"))\n",
    "for (_N, _lamda, _srand) in zip(N_all, lamda_all, srand_all):\n",
    "    d_reconstruct_tensor[N_vals == _N, lamda_vals == _lamda, srand_vals == _srand, :] = d_reconstruct[(N_all == _N) & (lamda_all == _lamda) & (srand_all == _srand), :].flatten()\n",
    "\n",
    "d_gwot_tensor = np.full((len(N_vals_gwot), len(lamda_vals_gwot), len(srand_vals_gwot), sim_gt.T), float(\"NaN\"))\n",
    "for (_N, _lamda, _srand) in zip(N_all_gwot, lamda_all_gwot, srand_all_gwot):\n",
    "    d_gwot_tensor[N_vals_gwot == _N, lamda_vals_gwot == _lamda, srand_vals_gwot == _srand, :] = d_gwot[(N_all_gwot == _N) & (lamda_all_gwot == _lamda) & (srand_all_gwot == _srand), :].flatten()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "latter-screening",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.subplot(1, 2, 1)\n",
    "plt.plot(lamda_vals, np.sqrt(d_reconstruct_tensor[0].mean(-1)).mean(-1), 'o-')\n",
    "plt.title(\"Langevin\")\n",
    "plt.subplot(1, 2, 2)\n",
    "plt.plot(lamda_vals_gwot, np.sqrt(d_gwot_tensor[0].mean(-1)).mean(-1), 'o-')\n",
    "plt.title(\"gWOT\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "generous-tyler",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize = (3*PLT_CELL, 3/2*PLT_CELL))\n",
    "plt.subplot(1, 2, 1)\n",
    "# im = plt.imshow(np.nanmean(d_reconstruct_tensor, (2, 3)), origin = \"lower\")\n",
    "im = plt.imshow(np.sqrt(d_reconstruct_tensor.mean(-1)).mean(-1), origin = \"lower\")\n",
    "plt.xticks(range(len(lamda_vals)), lamda_vals, rotation = 30)\n",
    "plt.yticks(range(len(N_vals)), N_vals)\n",
    "plt.colorbar(im,fraction=0.038, pad=0.04)\n",
    "plt.xlabel(\"$\\lambda$\")\n",
    "plt.ylabel(\"N\")\n",
    "plt.title(\"Langevin\")\n",
    "\n",
    "plt.subplot(1, 2, 2)\n",
    "# im = plt.imshow(np.nanmean(d_gwot_tensor, (2, 3)), origin = \"lower\")\n",
    "im = plt.imshow(np.sqrt(d_gwot_tensor.mean(-1)).mean(-1), origin = \"lower\")\n",
    "plt.xticks(range(len(lamda_vals_gwot)), lamda_vals_gwot, rotation = 30)\n",
    "plt.yticks(range(len(N_vals_gwot)), N_vals_gwot)\n",
    "plt.colorbar(im,fraction=0.038, pad=0.04)\n",
    "plt.xlabel(\"$\\lambda$\")\n",
    "plt.ylabel(\"N\")\n",
    "plt.title(\"gWOT\")\n",
    "\n",
    "plt.tight_layout()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cardiovascular-interpretation",
   "metadata": {},
   "outputs": [],
   "source": [
    "np.sqrt(d_reconstruct_tensor.mean(-1)).mean(-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "guided-emphasis",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.plot(lamda_vals, np.sqrt(d_reconstruct_tensor.mean(-1)).mean(-1).T, 'o-');"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "thousand-recovery",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.subplot(1, 2, 1)\n",
    "plt.plot(N_vals, np.sqrt(d_reconstruct_tensor.mean(-1)).mean(-1), 'o-', label = \"Langevin\");\n",
    "plt.ylim(0, 0.5)\n",
    "plt.legend()\n",
    "plt.xscale(\"log\")\n",
    "plt.subplot(1, 2, 2)\n",
    "plt.plot(N_vals, np.sqrt(d_gwot_tensor.mean(-1)).mean(-1), 'o-', label = \"gWOT\");\n",
    "plt.ylim(0, 0.5)\n",
    "plt.legend()\n",
    "plt.xscale(\"log\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "proof-singing",
   "metadata": {},
   "outputs": [],
   "source": [
    "means"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "indirect-robinson",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize = (PLT_CELL, 1.75*PLT_CELL))\n",
    "\n",
    "sds = np.std(np.sqrt(d_reconstruct_tensor.mean(-1)), axis = 2)\n",
    "means = np.sqrt(d_reconstruct_tensor.mean(-1)).mean(-1)\n",
    "min_idx = np.nanargmin(means, axis = 1)\n",
    "sds_minmean = np.array([x[y] for (x, y) in zip(sds, min_idx)])\n",
    "plt.errorbar(N_vals, np.nanmin(means, 1), sds_minmean, label = \"MFL\", color = \"blue\", marker = \"o\")\n",
    "\n",
    "sds = np.std(np.sqrt(d_gwot_tensor.mean(-1)), axis = 2)\n",
    "means = np.sqrt(d_gwot_tensor.mean(-1)).mean(-1)\n",
    "min_idx = np.nanargmin(means, axis = 1)\n",
    "sds_minmean = np.array([x[y] for (x, y) in zip(sds, min_idx)])\n",
    "plt.errorbar(N_vals, np.nanmin(means, 1), sds_minmean, label = \"gWOT\", color = \"red\", marker = \"o\")\n",
    "\n",
    "plt.ylabel(\"RMS Energy Distance\")\n",
    "plt.xlabel(\"N\")\n",
    "\n",
    "plt.xscale(\"log\")\n",
    "plt.legend()\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.savefig(\"../fig1_distances.pdf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cooperative-commitment",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
