{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "54873307",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os, sys\n",
    "sys.path.append(\"..\")\n",
    "\n",
    "import torch\n",
    "import numpy as np\n",
    "from TrajectoryNet.dataset import EBData\n",
    "\n",
    "from src.light_sb_ou import LightSB_OU\n",
    "from src.distributions import LoaderSampler, TensorSampler\n",
    "from tqdm import tqdm\n",
    "from sklearn.decomposition import PCA\n",
    "from TrajectoryNet.optimal_transport.emd import earth_mover_distance"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "278e4eac",
   "metadata": {},
   "source": [
    "## Config"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "fb2772a4",
   "metadata": {
    "tags": [
     "parameters"
    ]
   },
   "outputs": [],
   "source": [
    "# Data/model dimensionality, passed to LightSB_OU as `dim`; the assert requires more than one dimension.\n",
    "DIM = 5\n",
    "assert DIM > 1\n",
    "\n",
    "# NOTE(review): SEED is not read anywhere in this notebook — confirm it is still needed.\n",
    "SEED = 42\n",
    "# Number of samples drawn from each sampler per optimisation step.\n",
    "BATCH_SIZE = 128\n",
    "# Passed to LightSB_OU as `epsilon`.\n",
    "EPSILON = 0.1\n",
    "# Learning rate of the Adam optimiser.\n",
    "D_LR = 1e-2\n",
    "# `max_norm` used for gradient clipping; infinity leaves the gradients unscaled.\n",
    "D_GRADIENT_MAX_NORM = float(\"inf\")\n",
    "# Passed to LightSB_OU as `n_potentials`; also the number of target samples used for initialisation.\n",
    "N_POTENTIALS = 100\n",
    "# Passed to LightSB_OU as `sampling_batch_size`.\n",
    "SAMPLING_BATCH_SIZE = 128\n",
    "# When true, `init_r_by_samples` is called with target samples before training starts.\n",
    "INIT_BY_SAMPLES = True\n",
    "# Passed to LightSB_OU as `is_diagonal`.\n",
    "IS_DIAGONAL = True\n",
    "# Frame index (1, 2 or 3) used by the data-loading cell: frame T is the middle set,\n",
    "# frames T-1 and T+1 are the two endpoints. The training cell evaluates all three\n",
    "# values regardless of this setting.\n",
    "T = 3\n",
    "# Device that sampled batches are moved to.\n",
    "DEVICE = \"cpu\"\n",
    "\n",
    "# The optimisation loop iterates over range(CONTINUE + 1, MAX_STEPS).\n",
    "MAX_STEPS = 2000\n",
    "CONTINUE = -1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "32fc49dd-257d-43fd-b797-06231c770672",
   "metadata": {},
   "outputs": [],
   "source": [
    "def setup_consistent_evaluation():\n",
    "    EVAL_SEED = 0xBADBEEF \n",
    "    torch.manual_seed(EVAL_SEED)\n",
    "    np.random.seed(EVAL_SEED)\n",
    "    if torch.cuda.is_available():\n",
    "        torch.cuda.manual_seed(EVAL_SEED)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "3c485eec",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Both aliases are bound to the configured EPSILON value.\n",
    "EPS = EPSILON_END = EPSILON"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "db8ba029",
   "metadata": {},
   "source": [
    "## Data loading"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "5bc87ec0",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Warning: Clipping dimensionality to 5\n"
     ]
    }
   ],
   "source": [
    "# Clip the dataset to the configured dimensionality so it always matches the `dim`\n",
    "# the model is built with (previously a literal 5 that could drift from DIM).\n",
    "ds = EBData('pcs', max_dim=DIM)\n",
    "\n",
    "\n",
    "def _frame_bounds(label):\n",
    "    \"\"\"Return the first and last row index whose entry in `ds.labels` equals `label`.\"\"\"\n",
    "    rows = np.where(ds.labels == label)[0]\n",
    "    if rows.size == 0:\n",
    "        raise ValueError(f\"no rows labelled {label} in the dataset\")\n",
    "    return rows[0], rows[-1]\n",
    "\n",
    "\n",
    "def _frame_points(start, end):\n",
    "    \"\"\"Return the data rows from `start` to `end`, both inclusive.\"\"\"\n",
    "    # NOTE(review): slicing first..last assumes each label occupies one contiguous\n",
    "    # run of rows — confirm against the dataset.\n",
    "    return ds.get_data()[start:end + 1]\n",
    "\n",
    "\n",
    "frame_0_start, frame_0_end = _frame_bounds(0)\n",
    "frame_1_start, frame_1_end = _frame_bounds(1)\n",
    "frame_2_start, frame_2_end = _frame_bounds(2)\n",
    "frame_3_start, frame_3_end = _frame_bounds(3)\n",
    "frame_4_start, frame_4_end = _frame_bounds(4)\n",
    "\n",
    "X_mid_1 = _frame_points(frame_1_start, frame_1_end)\n",
    "X_mid_2 = _frame_points(frame_2_start, frame_2_end)\n",
    "X_mid_3 = _frame_points(frame_3_start, frame_3_end)\n",
    "\n",
    "# Frame T is the middle set; frames T-1 and T+1 are the two endpoints.\n",
    "if T == 1:\n",
    "    X_mid = X_mid_1\n",
    "    X_0_f = _frame_points(frame_0_start, frame_0_end)\n",
    "    X_1_f = _frame_points(frame_2_start, frame_2_end)\n",
    "elif T == 2:\n",
    "    X_mid = X_mid_2\n",
    "    X_0_f = _frame_points(frame_1_start, frame_1_end)\n",
    "    X_1_f = _frame_points(frame_3_start, frame_3_end)\n",
    "elif T == 3:\n",
    "    X_mid = X_mid_3\n",
    "    X_0_f = _frame_points(frame_2_start, frame_2_end)\n",
    "    X_1_f = _frame_points(frame_4_start, frame_4_end)\n",
    "else:\n",
    "    # Fail here instead of with a NameError on X_0_f further down.\n",
    "    raise ValueError(f\"T must be 1, 2 or 3, got {T!r}\")\n",
    "\n",
    "X_sampler = TensorSampler(torch.tensor(X_0_f).float(), device=\"cpu\")\n",
    "Y_sampler = TensorSampler(torch.tensor(X_1_f).float(), device=\"cpu\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0c7d1f42",
   "metadata": {},
   "source": [
    "## Model training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "b3a27187-64e0-4956-8fdf-84d68c798261",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One entry per evaluated frame index (1, 2, 3), looked up with index - 1 in the\n",
    "# training cell: the b_T entry is passed to LightSB_OU as `b`, the mu_T entry as `m`.\n",
    "# NOTE(review): how these particular values were chosen is not shown here — confirm.\n",
    "b_T = [-0.2, -0.2, 0.2]\n",
    "mu_T = [4, 4, -1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "827a4344-81c1-4924-8e49-43cb18ec7873",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "T mean_emd  1 0.7866985570527272\n",
      "T mean_emd  2 0.8137121286435318\n",
      "T mean_emd  3 0.8379492544440297\n",
      "0.8127866467134295 0.020933241242513858\n"
     ]
    }
   ],
   "source": [
    "# Number of independently initialised models trained per held-out frame.\n",
    "N_TRIALS = 5\n",
    "\n",
    "\n",
    "def _frames_for(frame):\n",
    "    \"\"\"Return (start, middle, end) point sets for the held-out frame index.\n",
    "\n",
    "    The middle set is frame `frame`; the start and end sets are the frames\n",
    "    immediately before and after it.\n",
    "\n",
    "    Raises:\n",
    "        KeyError: if `frame` is not 1, 2 or 3.\n",
    "    \"\"\"\n",
    "    bounds = {\n",
    "        0: (frame_0_start, frame_0_end),\n",
    "        1: (frame_1_start, frame_1_end),\n",
    "        2: (frame_2_start, frame_2_end),\n",
    "        3: (frame_3_start, frame_3_end),\n",
    "        4: (frame_4_start, frame_4_end),\n",
    "    }\n",
    "    middle = {1: X_mid_1, 2: X_mid_2, 3: X_mid_3}[frame]\n",
    "    before_lo, before_hi = bounds[frame - 1]\n",
    "    after_lo, after_hi = bounds[frame + 1]\n",
    "    start = ds.get_data()[before_lo:before_hi + 1]\n",
    "    end = ds.get_data()[after_lo:after_hi + 1]\n",
    "    return start, middle, end\n",
    "\n",
    "\n",
    "def _train_light_sb(start_sampler, end_sampler, b, m):\n",
    "    \"\"\"Train one LightSB_OU model between the two samplers and return it.\n",
    "\n",
    "    Args:\n",
    "        start_sampler: sampler whose batches are fed to `get_log_C`.\n",
    "        end_sampler: sampler whose batches are fed to `get_log_potential`\n",
    "            and, when INIT_BY_SAMPLES is set, to `init_r_by_samples`.\n",
    "        b, m: forwarded to the LightSB_OU constructor.\n",
    "    \"\"\"\n",
    "    # Keep the model on the same device the batches are moved to; it used to be\n",
    "    # pinned to the CPU, which only worked while DEVICE was \"cpu\".\n",
    "    model = LightSB_OU(dim=DIM, n_potentials=N_POTENTIALS, epsilon=EPSILON,\n",
    "                       m=m, b=b, sampling_batch_size=SAMPLING_BATCH_SIZE,\n",
    "                       is_diagonal=IS_DIAGONAL).to(DEVICE)\n",
    "\n",
    "    if INIT_BY_SAMPLES:\n",
    "        model.init_r_by_samples(end_sampler.sample(N_POTENTIALS).to(DEVICE))\n",
    "\n",
    "    optimizer = torch.optim.Adam(model.parameters(), lr=D_LR)\n",
    "\n",
    "    for _ in range(CONTINUE + 1, MAX_STEPS):\n",
    "        optimizer.zero_grad()\n",
    "        # Draw the start batch first, then the end batch, to keep the random stream order.\n",
    "        x0 = start_sampler.sample(BATCH_SIZE).to(DEVICE)\n",
    "        x1 = end_sampler.sample(BATCH_SIZE).to(DEVICE)\n",
    "\n",
    "        log_potential = model.get_log_potential(x1)\n",
    "        log_C = model.get_log_C(x0)\n",
    "\n",
    "        loss = (-log_potential + log_C).mean()\n",
    "        loss.backward()\n",
    "        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=D_GRADIENT_MAX_NORM)\n",
    "        optimizer.step()\n",
    "\n",
    "    return model\n",
    "\n",
    "\n",
    "def _midpoint_emd(model, start_sampler, n_samples, x_middle):\n",
    "    \"\"\"Sample the model at time 0.5 and return the EMD to the held-out points.\"\"\"\n",
    "    with torch.no_grad():\n",
    "        x = start_sampler.sample(n_samples).to(DEVICE)\n",
    "        # Build the time tensor on DEVICE so it matches the inputs and the model.\n",
    "        t_half = torch.full((x.shape[0], 1), 0.5, device=DEVICE)\n",
    "        x_pred = model.sample_at_time_moment(x, t_half).cpu().numpy()\n",
    "    return earth_mover_distance(x_pred, x_middle)\n",
    "\n",
    "\n",
    "result = []\n",
    "setup_consistent_evaluation()\n",
    "\n",
    "# Loop-local names are used so the configured T and the data-loading cell's\n",
    "# X_* / sampler variables are not overwritten by this cell.\n",
    "for held_out in [1, 2, 3]:\n",
    "    x_start, x_middle, x_end = _frames_for(held_out)\n",
    "\n",
    "    start_sampler = TensorSampler(torch.tensor(x_start).float(), device=\"cpu\")\n",
    "    end_sampler = TensorSampler(torch.tensor(x_end).float(), device=\"cpu\")\n",
    "\n",
    "    trial_results = []\n",
    "    for _ in range(N_TRIALS):\n",
    "        D = _train_light_sb(start_sampler, end_sampler,\n",
    "                            b=b_T[held_out - 1], m=mu_T[held_out - 1])\n",
    "        trial_results.append(_midpoint_emd(D, start_sampler, x_start.shape[0], x_middle))\n",
    "\n",
    "    mean_emd = np.mean(trial_results)\n",
    "    print(\"T mean_emd \", held_out, mean_emd)\n",
    "    result.append(mean_emd)\n",
    "\n",
    "print(np.mean(result), np.std(result))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "673cab3c-5e27-43ff-a195-0f86cd7d6983",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "celltoolbar": "Tags",
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
