{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "f29cefe0",
   "metadata": {},
   "source": [
    "## 1. Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a0656536",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os, sys\n",
    "sys.path.append(\"..\")\n",
    "\n",
    "import wandb\n",
    "import torch\n",
    "import numpy as np\n",
    "import matplotlib\n",
    "from matplotlib import pyplot as plt\n",
    "from sklearn.decomposition import PCA\n",
    "from tqdm import tqdm\n",
    "\n",
    "from src.light_sb import LightSB\n",
    "\n",
    "from eot_benchmark.gaussian_mixture_benchmark import (\n",
    "    get_guassian_mixture_benchmark_sampler,\n",
    "    get_guassian_mixture_benchmark_ground_truth_sampler, \n",
    "    get_test_input_samples\n",
    ")\n",
    "\n",
    "from eot_benchmark.metrics import (\n",
    "    compute_BW_UVP_by_gt_samples, compute_BW_by_gt_samples, calculate_cond_bw\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a796493c",
   "metadata": {},
   "source": [
    "## 2. Config"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "efb9980b",
   "metadata": {
    "tags": [
     "parameters"
    ]
   },
   "outputs": [],
   "source": [
    "DIM = 2\n",
    "assert DIM > 1\n",
    "\n",
    "SEED = 42\n",
    "BATCH_SIZE = 128\n",
    "EPSILON = 10\n",
    "D_LR = 6e-2 # 1e-3 for eps 0.1, 3e-3 for eps 1, 6e-2 for eps 10\n",
    "# ALPHA_LR = 1e-2\n",
    "D_GRADIENT_MAX_NORM = float(\"inf\")\n",
    "N_POTENTIALS = 50\n",
    "SAMPLING_BATCH_SIZE = 128\n",
    "INIT_BY_SAMPLES = True\n",
    "IS_DIAGONAL = True\n",
    "\n",
    "MAX_STEPS = 10000\n",
    "CONTINUE = -1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c52d73ab",
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.manual_seed(SEED); np.random.seed(SEED)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fe552235",
   "metadata": {},
   "outputs": [],
   "source": [
    "EXP_NAME = f'EOT_benchmark_EPSILON_{EPSILON}_DIM_{DIM}_SEED_{SEED}'\n",
    "OUTPUT_PATH = '../checkpoints/{}'.format(EXP_NAME)\n",
    "\n",
    "config = dict(\n",
    "    DIM=DIM,\n",
    "    D_LR=D_LR,\n",
    "    BATCH_SIZE=BATCH_SIZE,\n",
    "    EPSILON=EPSILON,\n",
    "    D_GRADIENT_MAX_NORM=D_GRADIENT_MAX_NORM,\n",
    "    N_POTENTIALS=N_POTENTIALS,\n",
    "    INIT_BY_SAMPLES=INIT_BY_SAMPLES,\n",
    "    SEED=SEED,\n",
    ")\n",
    "\n",
    "if not os.path.exists(OUTPUT_PATH):\n",
    "    os.makedirs(OUTPUT_PATH)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "856c6034",
   "metadata": {},
   "source": [
    "## 3. Initialize Benchmark"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fa2fb89d",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "X_sampler = get_guassian_mixture_benchmark_sampler(input_or_target=\"input\", dim=DIM, eps=EPSILON,\n",
    "                                           batch_size=BATCH_SIZE, device=f\"cpu\", download=True)\n",
    "Y_sampler = get_guassian_mixture_benchmark_sampler(input_or_target=\"target\", dim=DIM, eps=EPSILON,\n",
    "                                          batch_size=BATCH_SIZE, device=f\"cpu\", download=False)\n",
    "\n",
    "ground_truth_plan_sampler = get_guassian_mixture_benchmark_ground_truth_sampler(dim=DIM, eps=EPSILON,\n",
    "                                                                                batch_size=BATCH_SIZE, \n",
    "                                                                                device=f\"cpu\",\n",
    "                                                                                download=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b69102f8",
   "metadata": {},
   "outputs": [],
   "source": [
    "pca = PCA(n_components=2)\n",
    "\n",
    "samples = X_sampler.sample(10000)\n",
    "samples = samples.cpu()\n",
    "\n",
    "target_samples = Y_sampler.sample(10000)\n",
    "target_samples = target_samples.cpu()\n",
    "\n",
    "pca.fit(target_samples)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f3d0f716",
   "metadata": {},
   "outputs": [],
   "source": [
    "samples_pca = pca.transform(samples)\n",
    "\n",
    "plt.scatter(samples_pca[:, 0], samples_pca[:, 1])\n",
    "plt.xlim(-10, 10)\n",
    "plt.ylim(-10, 10)\n",
    "plt.grid()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "76102008",
   "metadata": {},
   "outputs": [],
   "source": [
    "target_samples_pca = pca.transform(target_samples)\n",
    "plt.scatter(target_samples_pca[:, 0], target_samples_pca[:, 1])\n",
    "plt.xlim(-10, 10)\n",
    "plt.ylim(-10, 10)\n",
    "plt.grid()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7609176d",
   "metadata": {},
   "source": [
    "## 4. Model initialization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "84491f14",
   "metadata": {},
   "outputs": [],
   "source": [
    "D = LightSB(dim=DIM, n_potentials=N_POTENTIALS, epsilon=EPSILON,\n",
    "            sampling_batch_size=SAMPLING_BATCH_SIZE, is_diagonal=IS_DIAGONAL, S_diagonal_init=1.0)\n",
    "\n",
    "if INIT_BY_SAMPLES:\n",
    "    D.init_r_by_samples(Y_sampler.sample(N_POTENTIALS))\n",
    "\n",
    "D_opt = torch.optim.Adam(D.parameters(), lr=D_LR)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3a5a164e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def pca_plot(x_0_gt, x_1_gt, x_1_pred, n_plot, step):\n",
    "    fig,axes = plt.subplots(1, 3,figsize=(12,4),squeeze=True,sharex=True,sharey=True)\n",
    "    pca = PCA(n_components=2).fit(x_1_gt)\n",
    "    \n",
    "    x_0_gt_pca = pca.transform(x_0_gt[:n_plot])\n",
    "    x_1_gt_pca = pca.transform(x_1_gt[:n_plot])\n",
    "    x_1_pred_pca = pca.transform(x_1_pred[:n_plot])\n",
    "    \n",
    "    axes[0].scatter(x_0_gt_pca[:,0], x_0_gt_pca[:,1], c=\"g\", edgecolor = 'black',\n",
    "                    label = r'$x\\sim P_0(x)$', s =30)\n",
    "    axes[1].scatter(x_1_gt_pca[:,0], x_1_gt_pca[:,1], c=\"orange\", edgecolor = 'black',\n",
    "                    label = r'$x\\sim P_1(x)$', s =30)\n",
    "    axes[2].scatter(x_1_pred_pca[:,0], x_1_pred_pca[:,1], c=\"yellow\", edgecolor = 'black',\n",
    "                    label = r'$x\\sim T(x)$', s =30)\n",
    "    \n",
    "    for i in range(3):\n",
    "        axes[i].grid()\n",
    "        axes[i].set_xlim([-10, 10])\n",
    "        axes[i].set_ylim([-10, 10])\n",
    "        axes[i].legend()\n",
    "    \n",
    "    fig.tight_layout(pad=0.5)\n",
    "\n",
    "def plot_mapping(independent_mapping, true_mapping, predicted_mapping, target_data, n_plot, step):\n",
    "    s=30\n",
    "    linewidth=0.2\n",
    "    map_alpha=1\n",
    "    data_alpha=1\n",
    "    figsize=(5, 5)\n",
    "    dpi=None\n",
    "    data_color='red'\n",
    "    mapped_data_color='blue'\n",
    "    map_color='green'\n",
    "    map_label=None\n",
    "    data_label=None\n",
    "    mapped_data_label=None\n",
    "    \n",
    "    dim = target_data.shape[-1]\n",
    "    pca = PCA(n_components=2).fit(target_data)\n",
    "    \n",
    "    independent_mapping_pca = np.concatenate((        \n",
    "        pca.transform(independent_mapping[:n_plot, :dim]),\n",
    "        pca.transform(independent_mapping[:n_plot, dim:]),\n",
    "        ), axis=-1)\n",
    " \n",
    "    true_mapping_pca = np.concatenate((\n",
    "        pca.transform(true_mapping[:n_plot, :dim]),\n",
    "        pca.transform(true_mapping[:n_plot, dim:]),\n",
    "    ), axis=-1)\n",
    "    \n",
    "    predicted_mapping_pca = np.concatenate((\n",
    "        pca.transform(predicted_mapping[:n_plot, :dim]),\n",
    "        pca.transform(predicted_mapping[:n_plot, dim:]),\n",
    "    ), axis=-1)\n",
    "    \n",
    "    target_data_pca = pca.transform(target_data)\n",
    "    \n",
    "    fig, axes = plt.subplots(1, 3, figsize=(12,4),squeeze=True,sharex=True,sharey=True)\n",
    "    titles = [\"independent\", \"true\", \"predicted\"]\n",
    "    for i, mapping in enumerate([independent_mapping_pca, true_mapping_pca, predicted_mapping_pca]):\n",
    "        inp = mapping[:, :2]\n",
    "        out = mapping[:, 2:]\n",
    "\n",
    "        lines = np.concatenate([inp, out], axis=-1).reshape(-1, 2, 2)\n",
    "        lc = matplotlib.collections.LineCollection(\n",
    "            lines, color=map_color, linewidths=linewidth, alpha=map_alpha, label=map_label)\n",
    "        axes[i].add_collection(lc)\n",
    "\n",
    "        axes[i].scatter(\n",
    "            inp[:, 0], inp[:, 1], s=s, label=data_label,\n",
    "            alpha=data_alpha, zorder=2, color=data_color)\n",
    "        axes[i].scatter(\n",
    "            out[:, 0], out[:, 1], s=s, label=mapped_data_label,\n",
    "            alpha=data_alpha, zorder=2, color=mapped_data_color)\n",
    "\n",
    "        axes[i].scatter(target_data_pca[:1000,0], target_data_pca[:1000,1], c=\"orange\", edgecolor = 'black',\n",
    "                    label = r'$x\\sim P_1(x)$', s =10)\n",
    "        axes[i].grid()\n",
    "        axes[i].set_title(titles[i])\n",
    "    \n",
    "\n",
    "    \n",
    "    \n",
    "def pca_plot_plan(independent_plan, true_plan, predicted_plan, n_plot, step):\n",
    "    fig,axes = plt.subplots(1, 3,figsize=(12,4),squeeze=True,sharex=True,sharey=True)\n",
    "    pca = PCA(n_components=2).fit(true_plan)\n",
    "    \n",
    "    predicted_plan_pca = pca.transform(predicted_plan[:n_plot])\n",
    "    true_plan_pca = pca.transform(true_plan[:n_plot])\n",
    "    independent_plan_pca = pca.transform(independent_plan[:n_plot])\n",
    "    \n",
    "    axes[0].scatter(independent_plan_pca[:,0], independent_plan_pca[:,1], c=\"g\", edgecolor = 'black',\n",
    "                    label = r'Independent plan', s =30)\n",
    "    axes[1].scatter(true_plan_pca[:,0], true_plan_pca[:,1], c=\"orange\", edgecolor = 'black',\n",
    "                    label = r'True plan', s =30)\n",
    "    axes[2].scatter(predicted_plan_pca[:,0], predicted_plan_pca[:,1], c=\"yellow\", edgecolor = 'black',\n",
    "                    label = r'Predicted plan', s =30)\n",
    "    \n",
    "    for i in range(3):\n",
    "        axes[i].grid()\n",
    "        axes[i].set_xlim([-10, 10])\n",
    "        axes[i].set_ylim([-10, 10])\n",
    "        axes[i].legend()\n",
    "    \n",
    "    fig.tight_layout(pad=0.5)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "20836460",
   "metadata": {},
   "source": [
    "## 5. Model training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9f11cc63",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "wandb.init(name=EXP_NAME, config=config)\n",
    "\n",
    "for step in tqdm(range(CONTINUE + 1, MAX_STEPS)):\n",
    "    # training cycle\n",
    "    D_opt.zero_grad()\n",
    "    \n",
    "    X0, X1 = X_sampler.sample(BATCH_SIZE), Y_sampler.sample(BATCH_SIZE)\n",
    "    \n",
    "    log_potential = D.get_log_potential(X1)\n",
    "    log_C = D.get_log_C(X0)\n",
    "    \n",
    "    D_loss = (-log_potential + log_C).mean()\n",
    "    D_loss.backward()\n",
    "    D_gradient_norm = torch.nn.utils.clip_grad_norm_(D.parameters(), max_norm=D_GRADIENT_MAX_NORM)\n",
    "    D_opt.step()\n",
    "    \n",
    "    wandb.log({f'D gradient norm' : D_gradient_norm.item()}, step=step)\n",
    "    wandb.log({f'D_loss' : D_loss.item()}, step=step)\n",
    "    \n",
    "# plotting and evaluating\n",
    "with torch.no_grad():\n",
    "# bw-uvp\n",
    "    X = X_sampler.sample(100000).cpu()\n",
    "    Y = Y_sampler.sample(100000).cpu()\n",
    "\n",
    "    XN = D(X)\n",
    "\n",
    "    X0_gt, XN_gt = ground_truth_plan_sampler.sample(20)\n",
    "\n",
    "    bw_uvp_target = compute_BW_UVP_by_gt_samples(XN.cpu().numpy(), Y.cpu().numpy())\n",
    "\n",
    "    pca_plot(X.detach().cpu().numpy(), Y.detach().cpu().numpy(),\n",
    "         XN.detach().cpu().numpy(), n_plot=500, step=step)\n",
    "\n",
    "    # calculate cond_bw\n",
    "    test_samples = get_test_input_samples(dim=DIM, device=\"cpu\")\n",
    "\n",
    "    model_input = test_samples.reshape(1000, 1, -1).repeat(1, 1000, 1)\n",
    "    predictions = []\n",
    "    for inp in tqdm(model_input):\n",
    "        predictions.append(D(inp))\n",
    "\n",
    "    predictions = torch.stack(predictions, dim=0)\n",
    "\n",
    "    # calculate cond_bw new\n",
    "    new_cond_bw = calculate_cond_bw(test_samples, predictions, eps=EPSILON, dim=DIM)\n",
    "\n",
    "    X_repeated = X0_gt[:5].repeat(20, 1)\n",
    "    Y_true_mapped = ground_truth_plan_sampler.conditional_plan.sample(X_repeated)\n",
    "\n",
    "    X_repeated = X_repeated.cpu()\n",
    "    Y_true_mapped = Y_true_mapped.cpu()\n",
    "\n",
    "    true_plan = torch.cat((X_repeated, Y_true_mapped), dim=-1)\n",
    "    independent_plan = torch.cat((X_repeated, Y[:100]), dim=-1)\n",
    "    predicted_plan = torch.cat((X_repeated, D(X_repeated)), dim=-1)\n",
    "\n",
    "    plot_mapping(independent_plan.detach().cpu().numpy(), true_plan.detach().cpu().numpy(),\n",
    "             predicted_plan.detach().cpu().numpy(), Y.detach().cpu().numpy(),\n",
    "                 n_plot=20, step=step)\n",
    "\n",
    "    wandb.log({f'BW-UVP_target' : bw_uvp_target}, step=step)\n",
    "    wandb.log({f'new_cond_bw' : new_cond_bw}, step=step)\n",
    "            \n",
    "torch.save(D.state_dict(), os.path.join(OUTPUT_PATH, f'D.pt'))\n",
    "torch.save(D_opt.state_dict(), os.path.join(OUTPUT_PATH, f'D_opt.pt'))\n",
    "\n",
    "wandb.finish()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bd8258cf",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "celltoolbar": "Tags",
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
