{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "cd972eac",
   "metadata": {},
   "source": [
    "## 1. Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6261775b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os, sys\n",
    "sys.path.append(\"..\")\n",
    "\n",
    "import torch\n",
    "import wandb\n",
    "import numpy as np\n",
    "from tqdm import tqdm\n",
    "from sklearn.datasets import make_swiss_roll, make_moons\n",
    "\n",
    "from src.light_sb import LightSB\n",
    "from src.plotters import plot_2D, plot_2D_mapping, plot_2D_trajectory\n",
    "from src.distributions import StandardNormalSampler, SwissRollSampler"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f036daa0",
   "metadata": {},
   "source": [
    "## 2. Config"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d8508840",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dimensionality handed to LightSB as `dim`.\n",
    "DIM = 2\n",
    "assert DIM > 1\n",
    "\n",
    "# Seed applied to both the torch and the numpy global RNGs.\n",
    "OUTPUT_SEED = 42\n",
    "\n",
    "# LightSB constructor / initialisation settings.\n",
    "N_POTENTIALS = 500      # forwarded as `n_potentials`; also the number of Y samples used for init\n",
    "INIT_BY_SAMPLES = True  # when True, `init_r_by_samples` is called before training\n",
    "IS_DIAGONAL = True      # forwarded as `is_diagonal`\n",
    "\n",
    "# BATCH_SIZE: samples drawn from each sampler per training step.\n",
    "# SAMPLING_BATCH_SIZE: forwarded to LightSB as `sampling_batch_size`.\n",
    "BATCH_SIZE = 128\n",
    "SAMPLING_BATCH_SIZE = 128\n",
    "\n",
    "# Forwarded to LightSB as `epsilon` and embedded in the experiment name.\n",
    "EPSILON = 0.002\n",
    "\n",
    "# Adam learning rate: 1e-3 for eps 0.1, 0.01 and 3e-4 for eps 0.002\n",
    "D_LR = 3e-4\n",
    "# max_norm for clip_grad_norm_; with an infinite threshold the gradients are not rescaled,\n",
    "# the call only reports the norm.\n",
    "D_GRADIENT_MAX_NORM = float(\"inf\")\n",
    "\n",
    "# NOTE(review): PLOT_EVERY is not referenced by any visible cell.\n",
    "PLOT_EVERY = 500\n",
    "# Training iterates over range(CONTINUE + 1, MAX_STEPS); CONTINUE > -1 also reloads optimizer state.\n",
    "MAX_STEPS = 20_000\n",
    "CONTINUE = -1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "339530f0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed both global RNGs (torch first, then numpy) with the configured value.\n",
    "torch.manual_seed(OUTPUT_SEED)\n",
    "np.random.seed(OUTPUT_SEED)\n",
    "\n",
    "# Short alias for EPSILON.\n",
    "EPS = EPSILON"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "caf71c45",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run name (also used as the wandb run name) and the checkpoint directory derived from it.\n",
    "EXP_NAME = f'LightSB_Swiss_Roll_EPSILON_{EPSILON}'\n",
    "OUTPUT_PATH = f'../checkpoints/{EXP_NAME}'\n",
    "\n",
    "# Hyperparameters passed to wandb.init as the run config.\n",
    "config = {\n",
    "    'DIM': DIM,\n",
    "    'D_LR': D_LR,\n",
    "    'BATCH_SIZE': BATCH_SIZE,\n",
    "    'EPSILON': EPSILON,\n",
    "    'D_GRADIENT_MAX_NORM': D_GRADIENT_MAX_NORM,\n",
    "    'N_POTENTIALS': N_POTENTIALS,\n",
    "    'INIT_BY_SAMPLES': INIT_BY_SAMPLES,\n",
    "    'IS_DIAGONAL': IS_DIAGONAL,\n",
    "}\n",
    "\n",
    "# exist_ok=True replaces the exists()-then-makedirs() pair: no check-then-create race,\n",
    "# and re-running the cell is a no-op when the directory is already there.\n",
    "os.makedirs(OUTPUT_PATH, exist_ok=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8c47f945",
   "metadata": {},
   "source": [
    "## 3. Create samplers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5973a562",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Both samplers take their dimensionality from DIM (the same constant given to LightSB)\n",
    "# instead of a hard-coded 2, so the samplers and the model cannot silently disagree.\n",
    "X_sampler = StandardNormalSampler(dim=DIM, device=\"cpu\")\n",
    "Y_sampler = SwissRollSampler(dim=DIM, device=\"cpu\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cfd8d737",
   "metadata": {},
   "source": [
    "## 4. Model initialization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c9172cfc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the LightSB model from the config constants.\n",
    "D = LightSB(dim=DIM, n_potentials=N_POTENTIALS, epsilon=EPSILON,\n",
    "            sampling_batch_size=SAMPLING_BATCH_SIZE, is_diagonal=IS_DIAGONAL)\n",
    "\n",
    "if INIT_BY_SAMPLES:\n",
    "    # Initialise from N_POTENTIALS draws of the Y sampler.\n",
    "    D.init_r_by_samples(Y_sampler.sample(N_POTENTIALS))\n",
    "\n",
    "D_opt = torch.optim.Adam(D.parameters(), lr=D_LR)\n",
    "\n",
    "if CONTINUE > -1:\n",
    "    # Resume the optimizer state. The filename previously interpolated `SEED`, which is not\n",
    "    # defined anywhere in this notebook (NameError); OUTPUT_SEED is the only seed constant.\n",
    "    # NOTE(review): the training cell saves 'D.pt' / 'D_opt.pt', so a file with this\n",
    "    # seed/step pattern must come from elsewhere, and the model weights themselves are\n",
    "    # not restored here -- confirm the intended resume workflow.\n",
    "    D_opt_checkpoint = os.path.join(OUTPUT_PATH, f'D_opt_{OUTPUT_SEED}_{CONTINUE}.pt')\n",
    "    D_opt.load_state_dict(torch.load(D_opt_checkpoint))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a3dbdab0",
   "metadata": {},
   "source": [
    "## 5. Model training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "410ec61f",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "wandb.init(name=EXP_NAME, config=config)\n",
    "\n",
    "for step in tqdm(range(CONTINUE + 1, MAX_STEPS)):\n",
    "    D_opt.zero_grad()\n",
    "\n",
    "    # One minibatch from each sampler (X first, then Y, so the RNG draw order is unchanged).\n",
    "    X0 = X_sampler.sample(BATCH_SIZE)\n",
    "    X1 = Y_sampler.sample(BATCH_SIZE)\n",
    "\n",
    "    log_potential = D.get_log_potential(X1)\n",
    "    log_C = D.get_log_C(X0)\n",
    "\n",
    "    # Loss is the batch mean of log_C(X0) - log_potential(X1).\n",
    "    D_loss = (-log_potential + log_C).mean()\n",
    "    D_loss.backward()\n",
    "    # clip_grad_norm_ returns the total gradient norm, which is what gets logged.\n",
    "    D_gradient_norm = torch.nn.utils.clip_grad_norm_(D.parameters(), max_norm=D_GRADIENT_MAX_NORM)\n",
    "    D_opt.step()\n",
    "\n",
    "    # Log both metrics for this step in a single call.\n",
    "    wandb.log(\n",
    "        {'D gradient norm': D_gradient_norm.item(), 'D_loss': D_loss.item()},\n",
    "        step=step,\n",
    "    )\n",
    "\n",
    "# Persist the final model and optimizer state into the experiment's checkpoint directory.\n",
    "torch.save(D.state_dict(), os.path.join(OUTPUT_PATH, 'D.pt'))\n",
    "torch.save(D_opt.state_dict(), os.path.join(OUTPUT_PATH, 'D_opt.pt'))\n",
    "\n",
    "wandb.finish()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a98523d4",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
