{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "502e06d8",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os, sys\n",
    "sys.path.append(\"..\")\n",
    "\n",
    "import torch\n",
    "import numpy as np\n",
    "\n",
    "from src.light_sb import LightSB\n",
    "from src.distributions import LoaderSampler, TensorSampler\n",
    "from tqdm import tqdm\n",
    "from sklearn.decomposition import PCA\n",
    "from matplotlib import pyplot as plt\n",
    "\n",
    "import wandb\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "from sklearn.metrics.pairwise import pairwise_distances"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1ae13841",
   "metadata": {},
   "source": [
    "## Config"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "396f1eb4",
   "metadata": {
    "tags": [
     "parameters"
    ]
   },
   "outputs": [],
   "source": [
    "# Run configuration (this cell carries the `parameters` tag).\n",
    "\n",
    "# --- data selection ---\n",
    "DIM = 100  # picks ../data/full_cite_pcas_{DIM}_day_{day}.npy and is passed to LightSB as `dim`\n",
    "assert DIM > 1\n",
    "DAY_START = 3  # day whose samples feed X_sampler\n",
    "DAY_END = 7  # day whose samples feed Y_sampler\n",
    "DAY_EVAL = 4  # day whose samples are compared with the predictions via mmd\n",
    "SERIES_ID = 1  # only recorded in the wandb config\n",
    "\n",
    "# --- model ---\n",
    "EPSILON = 0.1  # passed to LightSB as `epsilon`\n",
    "N_POTENTIALS = 10  # passed to LightSB as `n_potentials`\n",
    "IS_DIAGONAL = True  # passed to LightSB as `is_diagonal`\n",
    "INIT_BY_SAMPLES = True  # when True, D.init_r_by_samples is called with N_POTENTIALS draws from Y_sampler\n",
    "SAMPLING_BATCH_SIZE = 128  # passed to LightSB as `sampling_batch_size`\n",
    "\n",
    "# --- optimisation ---\n",
    "SEED = 42\n",
    "BATCH_SIZE = 128\n",
    "D_LR = 1e-2  # Adam learning rate\n",
    "D_GRADIENT_MAX_NORM = float(\"inf\")  # max_norm handed to clip_grad_norm_\n",
    "MAX_STEPS = 10000  # exclusive upper bound of the training step range\n",
    "CONTINUE = -1  # training starts at step CONTINUE + 1; values > -1 also reload the optimizer state\n",
    "\n",
    "# --- runtime ---\n",
    "DEVICE = \"cpu\"\n",
    "EVAL_EVERY = 10000  # evaluation runs whenever (step + 1) is a multiple of this"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a44ae96e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Validate the effective configuration here, after the tagged `parameters` cell.\n",
    "# When the notebook is executed with injected parameters (e.g. papermill), the\n",
    "# overrides land in a cell inserted right after the `parameters` cell, so checks\n",
    "# placed inside that cell only ever see the defaults.\n",
    "assert DIM > 1, f\"DIM must be greater than 1, got {DIM}\"\n",
    "# The evaluation time fraction divides by (DAY_END - DAY_START).\n",
    "assert DAY_END != DAY_START, \"DAY_END must differ from DAY_START\"\n",
    "\n",
    "# Seed both RNGs used by the notebook.\n",
    "torch.manual_seed(SEED)\n",
    "np.random.seed(SEED)\n",
    "\n",
    "# Aliases of EPSILON.\n",
    "EPS = EPSILON\n",
    "EPSILON_END = EPSILON"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8f58dca8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The run name encodes the main experimental settings; the checkpoint folder reuses it.\n",
    "EXP_NAME = f'LightSB_Single_Cell_full_CITE_cell_DIM_{DIM}_DAY_EVAL_{DAY_EVAL}_EPSILON_{EPSILON}_SEED_{SEED}'\n",
    "OUTPUT_PATH = '../checkpoints/{}'.format(EXP_NAME)\n",
    "\n",
    "# Hyperparameters recorded with the wandb run.\n",
    "config = dict(\n",
    "    SERIES_ID=SERIES_ID,\n",
    "    DAY_START=DAY_START,\n",
    "    DAY_END=DAY_END,\n",
    "    DAY_EVAL=DAY_EVAL,\n",
    "    DIM=DIM,\n",
    "    D_LR=D_LR,\n",
    "    BATCH_SIZE=BATCH_SIZE,\n",
    "    EPSILON=EPSILON,\n",
    "    D_GRADIENT_MAX_NORM=D_GRADIENT_MAX_NORM,\n",
    "    N_POTENTIALS=N_POTENTIALS,\n",
    "    INIT_BY_SAMPLES=INIT_BY_SAMPLES,\n",
    "    IS_DIAGONAL=IS_DIAGONAL,\n",
    "    SEED=SEED,\n",
    ")\n",
    "\n",
    "# exist_ok avoids the check-then-create race when several runs start at once.\n",
    "os.makedirs(OUTPUT_PATH, exist_ok=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f487a958",
   "metadata": {},
   "source": [
    "## Data loading"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4fc014fb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the per-day arrays, keyed by day number.\n",
    "data = {}\n",
    "for day in [2, 3, 4, 7]:\n",
    "    data[day] = np.load(f\"../data/full_cite_pcas_{DIM}_day_{day}.npy\")\n",
    "\n",
    "eval_data = data[DAY_EVAL]\n",
    "start_data = data[DAY_START]\n",
    "end_data = data[DAY_END]\n",
    "\n",
    "# One scalar scale shared by all three days: the mean of the per-feature\n",
    "# standard deviations over the concatenated start/end/eval samples.\n",
    "constant_scale = np.concatenate([start_data, end_data, eval_data]).std(axis=0).mean()\n",
    "\n",
    "eval_data_scaled = eval_data/constant_scale\n",
    "start_data_scaled = start_data/constant_scale\n",
    "end_data_scaled = end_data/constant_scale\n",
    "\n",
    "# eval_data deliberately stays in the original scale: the training cell\n",
    "# multiplies predictions by constant_scale before comparing them with it.\n",
    "eval_data = torch.tensor(eval_data).float()\n",
    "start_data = torch.tensor(start_data_scaled).float()\n",
    "end_data = torch.tensor(end_data_scaled).float()\n",
    "\n",
    "# start_data / end_data are already float tensors here; clone().detach() hands\n",
    "# the samplers their own copy without the UserWarning that torch.tensor()\n",
    "# raises when it is given an existing tensor.\n",
    "X_sampler = TensorSampler(start_data.clone().detach(), device=\"cpu\")\n",
    "Y_sampler = TensorSampler(end_data.clone().detach(), device=\"cpu\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f3aafef3",
   "metadata": {},
   "source": [
    "## Model initialisation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "949c8324",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Place the model on DEVICE: every batch fed to it below is moved with\n",
    "# .to(DEVICE), so pinning the model to the CPU breaks any non-CPU run.\n",
    "D = LightSB(dim=DIM, n_potentials=N_POTENTIALS, epsilon=EPSILON,\n",
    "            sampling_batch_size=SAMPLING_BATCH_SIZE,\n",
    "            is_diagonal=IS_DIAGONAL).to(DEVICE)\n",
    "\n",
    "if INIT_BY_SAMPLES:\n",
    "    D.init_r_by_samples(Y_sampler.sample(N_POTENTIALS).to(DEVICE))\n",
    "\n",
    "D_opt = torch.optim.Adam(D.parameters(), lr=D_LR)\n",
    "\n",
    "if CONTINUE > -1:\n",
    "    # NOTE(review): only the optimizer state is restored; the model weights are\n",
    "    # not reloaded anywhere in this notebook - confirm that is intended.\n",
    "    # map_location lets a checkpoint written on another device load onto DEVICE.\n",
    "    D_opt_checkpoint = os.path.join(OUTPUT_PATH, f'D_opt_{SEED}_{CONTINUE}.pt')\n",
    "    D_opt.load_state_dict(torch.load(D_opt_checkpoint, map_location=DEVICE))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "07e6188b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def mmd(x, y):\n",
    "    \"\"\"Distance-based two-sample discrepancy between the samples `x` and `y`.\n",
    "\n",
    "    Builds the within-sample and cross-sample distance matrices with\n",
    "    `pairwise_distances` (default metric) and returns\n",
    "\n",
    "        mean(d(x, y)) - 0.5 * offdiag_mean(d(x, x)) - 0.5 * offdiag_mean(d(y, y))\n",
    "\n",
    "    where the within-sample terms drop the diagonal and are normalised by\n",
    "    m * (m - 1) and n * (n - 1) respectively.\n",
    "\n",
    "    Args:\n",
    "        x: 2-D array with m rows, one sample per row.\n",
    "        y: 2-D array with n rows, one sample per row.\n",
    "\n",
    "    Returns:\n",
    "        The scalar estimate described above.\n",
    "    \"\"\"\n",
    "    def _off_diagonal_sum(dist):\n",
    "        # Sum of all entries except those on the main diagonal.\n",
    "        return np.sum(dist - np.diag(np.diagonal(dist)))\n",
    "\n",
    "    dist_xx = pairwise_distances(x, x)\n",
    "    dist_yy = pairwise_distances(y, y)\n",
    "    dist_xy = pairwise_distances(x, y)\n",
    "\n",
    "    n_x, n_y = x.shape[0], y.shape[0]\n",
    "\n",
    "    # Within-x term\n",
    "    weight_xx = 1 / (n_x * (n_x - 1))\n",
    "    total_xx = _off_diagonal_sum(dist_xx)\n",
    "\n",
    "    # Within-y term\n",
    "    weight_yy = 1 / (n_y * (n_y - 1))\n",
    "    total_yy = _off_diagonal_sum(dist_yy)\n",
    "\n",
    "    # Cross term\n",
    "    weight_xy = 1 / (n_x * n_y)\n",
    "    total_xy = np.sum(dist_xy)\n",
    "\n",
    "    return -0.5*weight_xx*total_xx - 0.5*weight_yy*total_yy + weight_xy*total_xy"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "301a060e",
   "metadata": {},
   "source": [
    "## Model training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "19d7866a",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "wandb.init(name=EXP_NAME, config=config)\n",
    "\n",
    "for step in tqdm(range(CONTINUE + 1, MAX_STEPS)):\n",
    "    # training cycle\n",
    "    D_opt.zero_grad()\n",
    "\n",
    "    X0 = X_sampler.sample(BATCH_SIZE).to(DEVICE)\n",
    "    X1 = Y_sampler.sample(BATCH_SIZE).to(DEVICE)\n",
    "\n",
    "    log_potential = D.get_log_potential(X1)\n",
    "    log_C = D.get_log_C(X0)\n",
    "\n",
    "    D_loss = (-log_potential + log_C).mean()\n",
    "    D_loss.backward()\n",
    "    # clip_grad_norm_ returns the total gradient norm, which is what gets logged.\n",
    "    D_gradient_norm = torch.nn.utils.clip_grad_norm_(D.parameters(), max_norm=D_GRADIENT_MAX_NORM)\n",
    "    D_opt.step()\n",
    "\n",
    "    wandb.log({'D gradient norm': D_gradient_norm.item(), 'D_loss': D_loss.item()}, step=step)\n",
    "\n",
    "    # eval and plots\n",
    "    if (step + 1) % EVAL_EVERY == 0:\n",
    "        with torch.no_grad():\n",
    "            # NOTE: a plot of `alphas` used to sit here, but `alphas` is never\n",
    "            # defined in this notebook, so the first evaluation raised NameError\n",
    "            # on a fresh kernel. Re-add it once the values are actually collected.\n",
    "\n",
    "            # NOTE(review): relies on D exposing sum_component_probs, n_samples and\n",
    "            # reset_stats(); their definitions are not visible here - confirm.\n",
    "            plt.bar(np.arange(1, D.sum_component_probs.shape[0]+1), D.sum_component_probs/D.n_samples)\n",
    "            plt.grid()\n",
    "            plt.show()\n",
    "            D.reset_stats()\n",
    "\n",
    "            X = X_sampler.sample(start_data.shape[0]).to(DEVICE)\n",
    "            # Y and XN are not read below; the calls are kept so the random draws\n",
    "            # (and whatever state D updates in its forward pass) stay unchanged.\n",
    "            Y = Y_sampler.sample(end_data.shape[0]).to(DEVICE)\n",
    "\n",
    "            XN = D(X)\n",
    "\n",
    "            # Relative position of DAY_EVAL between DAY_START and DAY_END, built on\n",
    "            # the same device as X so non-CPU runs do not hit a device mismatch.\n",
    "            eval_time = torch.ones(X.shape[0], 1, device=DEVICE)*(DAY_EVAL - DAY_START)/(DAY_END-DAY_START)\n",
    "            X_mid_pred = D.sample_at_time_moment(X, eval_time).detach().cpu().numpy()\n",
    "\n",
    "            # Undo the training-time scaling so predictions match the unscaled eval_data.\n",
    "            X_mid_pred = X_mid_pred*constant_scale\n",
    "\n",
    "            MMD = mmd(X_mid_pred, eval_data)\n",
    "            wandb.log({'MMD': MMD}, step=step)\n",
    "\n",
    "wandb.finish()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d0345ff7",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "celltoolbar": "Tags",
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
