{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "abec0ba4",
   "metadata": {},
   "source": [
    "## 1. Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ad75d113",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os, sys\n",
    "sys.path.append(\"..\")\n",
    "\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "import wandb\n",
    "import random\n",
    "from math import pi\n",
    "import numpy as np\n",
    "from tqdm import tqdm\n",
    "from sklearn.datasets import make_swiss_roll, make_moons\n",
    "from sklearn import mixture\n",
    "import matplotlib.pyplot as plt\n",
    "from IPython.display import clear_output\n",
    "import ot\n",
    "from matplotlib import collections  as mc\n",
    "\n",
    "from src.ulight_ot import ULightOT\n",
    "from src.plotters import plot_2D, plot_2D_mapping, plot_2D_trajectory\n",
    "from src.distributions import StandardNormalSampler, SwissRollSampler\n",
    "\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "81bc58d8",
   "metadata": {},
   "source": [
    "## 2. Config"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "70930174",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Problem\n",
    "DIM = 2\n",
    "assert DIM > 1\n",
    "OUTPUT_SEED = 42  # seeds torch and numpy in the next cell\n",
    "\n",
    "# ULightOT parametrisation (passed to the ULightOT constructor / init below)\n",
    "N_POTENTIALS = 10\n",
    "INIT_BY_SAMPLES, IS_DIAGONAL = True, True  # INIT_BY_SAMPLES: initialise D from Y samples\n",
    "\n",
    "# Optimisation\n",
    "BATCH_SIZE = SAMPLING_BATCH_SIZE = 128\n",
    "EPSILON = 0.05\n",
    "LR = 1e-3\n",
    "D_GRADIENT_MAX_NORM = float(\"inf\")  # recorded in `config`; not otherwise referenced here\n",
    "\n",
    "# Bookkeeping\n",
    "PLOT_EVERY, MAX_STEPS = 2000, 20000\n",
    "CONTINUE = -1  # > -1: step of the optimizer checkpoint to reload; -1: train from scratch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "70cdedac",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed every RNG the notebook draws from: torch, NumPy, and Python's `random`\n",
    "# (sample_mixture picks mixture components with random.choices).\n",
    "torch.manual_seed(OUTPUT_SEED)\n",
    "np.random.seed(OUTPUT_SEED)\n",
    "random.seed(OUTPUT_SEED)\n",
    "\n",
    "EPS = EPSILON  # alias of EPSILON"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c647e597",
   "metadata": {},
   "outputs": [],
   "source": [
    "EXP_NAME = f'ULightOT_Gauss_{EPSILON}'\n",
    "OUTPUT_PATH = os.path.join('..', 'checkpoints', EXP_NAME)\n",
    "\n",
    "# Hyper-parameters describing this run (passed to the commented-out wandb.init in the model-init cell).\n",
    "config = dict(\n",
    "    DIM=DIM,\n",
    "    OUTPUT_SEED=OUTPUT_SEED,\n",
    "    LR=LR,\n",
    "    BATCH_SIZE=BATCH_SIZE,\n",
    "    SAMPLING_BATCH_SIZE=SAMPLING_BATCH_SIZE,\n",
    "    EPSILON=EPSILON,\n",
    "    D_GRADIENT_MAX_NORM=D_GRADIENT_MAX_NORM,\n",
    "    N_POTENTIALS=N_POTENTIALS,\n",
    "    INIT_BY_SAMPLES=INIT_BY_SAMPLES,\n",
    "    IS_DIAGONAL=IS_DIAGONAL,\n",
    ")\n",
    "\n",
    "# exist_ok avoids the check-then-create race of os.path.exists + os.makedirs.\n",
    "os.makedirs(OUTPUT_PATH, exist_ok=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d2dea248",
   "metadata": {},
   "source": [
    "## 3. Create samplers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9bf675b3",
   "metadata": {},
   "outputs": [],
   "source": [
    "def sample_normal(size=64, loc=(0, 0), scale=(0.2, 0.2), device='cpu'):\n",
    "    \"\"\"Return `size` i.i.d. 2-D Gaussian points as a float32 NumPy array of shape (size, 2).\n",
    "\n",
    "    `device` is accepted for API symmetry with the other samplers but ignored.\n",
    "    \"\"\"\n",
    "    points = np.random.normal(loc=loc, scale=scale, size=(size, 2))\n",
    "    return points.astype(np.float32)\n",
    "\n",
    "def sample_mixture(size=128, loc0=(0, -1), loc1=(0, 1),\n",
    "                   scale=(0.2, 0.2), weights=(0.25, 0.75), device='cpu'):\n",
    "    \"\"\"Sample `size` 2-D points from a two-component Gaussian mixture.\n",
    "\n",
    "    Each point picks centre loc0 or loc1 with relative probabilities `weights`\n",
    "    and adds N(0, diag(scale**2)) noise. Returns a float32 array of shape (size, 2);\n",
    "    `device` is accepted for API symmetry but ignored.\n",
    "    \"\"\"\n",
    "    # float32 centres: integer centres + float32 noise would silently promote to float64.\n",
    "    locs = np.array([loc0, loc1], dtype=np.float32)\n",
    "    # Component labels come from Python's `random` module, which must be seeded separately.\n",
    "    indices = random.choices(range(len(locs)), k=size, weights=weights)\n",
    "    noise = np.random.normal(size=(size, 2), loc=(0, 0), scale=scale).astype(np.float32)\n",
    "    return locs[indices] + noise\n",
    "\n",
    "class Sampler:\n",
    "    \"\"\"Minimal sampler interface: remembers a target `device`; subclasses implement `sample`.\"\"\"\n",
    "\n",
    "    def __init__(self, device='cuda'):\n",
    "        self.device = device\n",
    "\n",
    "    def sample(self, size=5):\n",
    "        \"\"\"Base implementation draws nothing and returns None.\"\"\"\n",
    "        return None\n",
    "    \n",
    "class MixtureNormalSampler(Sampler):\n",
    "    \"\"\"Two-component 2-D Gaussian mixture that yields float32 torch tensors.\n",
    "\n",
    "    `dim` is stored for reference only; samples are always 2-D (see sample_mixture).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, dim=2, loc0=(0, -1), loc1=(0, 1),\n",
    "                 scale=(0.2, 0.2), weights=(0.25, 0.75), device='cpu'):\n",
    "        super().__init__(device=device)\n",
    "        self.dim = dim\n",
    "        self.loc0, self.loc1 = loc0, loc1\n",
    "        self.scale = scale\n",
    "        self.weights = weights\n",
    "\n",
    "    def sample(self, batch_size=10):\n",
    "        \"\"\"Return a (batch_size, 2) float32 tensor on `self.device`.\"\"\"\n",
    "        batch = sample_mixture(size=batch_size, loc0=self.loc0, loc1=self.loc1,\n",
    "                               scale=self.scale, weights=self.weights, device=self.device)\n",
    "        # Force float32 so batches match torch's default dtype whatever NumPy dtype comes back.\n",
    "        return torch.tensor(batch, dtype=torch.float32, device=self.device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "de48a2da",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Source (X) and target (Y) mixtures share centre x-coordinates (-3 and 1);\n",
    "# X sits at y=3, Y at y=0, and their component weights are swapped.\n",
    "X_sampler = MixtureNormalSampler(weights=(0.25, 0.75), loc0=(-3, 3), loc1=(1, 3))\n",
    "Y_sampler = MixtureNormalSampler(weights=(0.75, 0.25), loc0=(-3, 0), loc1=(1, 0))\n",
    "\n",
    "# Fixed reference samples, reused by the plots during training.\n",
    "x_p, y_q = X_sampler.sample(250), Y_sampler.sample(250)\n",
    "\n",
    "lims = ((-4, 2), (-1, 4))\n",
    "marker_style = dict(s=60, edgecolors='black', zorder=3, alpha=1)\n",
    "\n",
    "fig, ax = plt.subplots(1, 1, figsize=(5, 5), squeeze=True, sharex=True, sharey=True)\n",
    "ax.scatter(x_p[:, 0], x_p[:, 1], c='white', label=r'$x\\sim p$', **marker_style)\n",
    "ax.scatter(y_q[:, 0], y_q[:, 1], c='grey', label=r'$y\\sim q$', **marker_style)\n",
    "ax.set_xlim(*lims[0])\n",
    "ax.set_ylim(*lims[1])\n",
    "ax.set_xticklabels([])\n",
    "ax.set_yticklabels([])\n",
    "ax.grid(True)\n",
    "fig.tight_layout(pad=0.5)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5422bfb0",
   "metadata": {},
   "source": [
    "## 3.1 Plotter"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "74edb57c",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_2D_lines_p0(step, model, z, x_1, lims=((-4, 2), (-1, 4)), x_scatter_alpha=1,\n",
    "                     n_samples=250, n_lines=100):\n",
    "    \"\"\"Plot reference data together with samples transported by `model`.\n",
    "\n",
    "    Args:\n",
    "        step: training step, shown as the axes title.\n",
    "        model: object exposing `sample_marginal(n)` and callable on that batch.\n",
    "        z: reference source samples (white markers).\n",
    "        x_1: reference target samples (grey markers).\n",
    "        lims: ((xmin, xmax), (ymin, ymax)) axis limits.\n",
    "        x_scatter_alpha: currently unused; kept for call compatibility.\n",
    "        n_samples: number of points drawn with model.sample_marginal.\n",
    "        n_lines: number of (x, y_pred) pairs joined by line segments.\n",
    "    Returns:\n",
    "        The matplotlib Figure.\n",
    "    \"\"\"\n",
    "    fig, ax = plt.subplots(1, 1, figsize=(5, 5), squeeze=True, sharex=True, sharey=True)\n",
    "\n",
    "    # Sample from the `model` argument (not a global) and skip autograd: plotting only.\n",
    "    with torch.no_grad():\n",
    "        x_0 = model.sample_marginal(n_samples)\n",
    "        y_pred = model(x_0)\n",
    "    x_0 = x_0.detach().cpu().numpy().astype('float')\n",
    "    y_pred = y_pred.detach().cpu().numpy().astype('float')\n",
    "\n",
    "    ax.scatter(z[:, 0], z[:, 1], s=60, c='white', edgecolors='black', zorder=3, label=r'$x\\sim p$', alpha=1)\n",
    "    ax.scatter(x_1[:, 0], x_1[:, 1], s=60, c='grey', edgecolors='black', zorder=3, label=r'$y\\sim q$', alpha=1)\n",
    "    ax.scatter(x_0[:, 0], x_0[:, 1], s=60, c='#ae00deff', edgecolors='black', zorder=3, label=r'$\\hat{x}\\sim p_{\\omega}(x)$')\n",
    "    ax.scatter(y_pred[:, 0], y_pred[:, 1], s=60, c='#ff866dff', edgecolors='black', zorder=3, label=r'$\\hat{y}\\sim \\gamma_{\\theta,\\omega}(\\cdot|x)$')\n",
    "\n",
    "    # Connect the first n_lines sampled points to their mapped counterparts.\n",
    "    lines = list(zip(x_0[:n_lines], y_pred[:n_lines]))\n",
    "    ax.add_collection(mc.LineCollection(lines, linewidths=0.5, color='black'))\n",
    "\n",
    "    ax.set_xlim(*lims[0])\n",
    "    ax.set_ylim(*lims[1])\n",
    "    ax.set_xticklabels([])\n",
    "    ax.set_yticklabels([])\n",
    "    ax.grid(True)\n",
    "    ax.set_title(f'step {step}')\n",
    "    fig.tight_layout(pad=0.5)\n",
    "\n",
    "    return fig"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d936cbc5",
   "metadata": {},
   "source": [
    "## 4. Model initialization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f2e3e62c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Divergence penalising marginal mismatch in the training loop; supported values: \"UKL\", \"Xi2\".\n",
    "DIVERGENCE = \"UKL\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8549e051",
   "metadata": {},
   "outputs": [],
   "source": [
    "# wandb.init(name=EXP_NAME, config=config)\n",
    "\n",
    "D = ULightOT(dim=DIM, n_potentials=N_POTENTIALS, epsilon=EPSILON,\n",
    "            sampling_batch_size=SAMPLING_BATCH_SIZE, is_diagonal=IS_DIAGONAL)\n",
    "\n",
    "# Learnable scalar log-scale; enters the loss via exp(log_m) and the phi potential.\n",
    "log_m = torch.zeros(1, requires_grad=True)\n",
    "\n",
    "if INIT_BY_SAMPLES:\n",
    "    D.init_r_by_samples(Y_sampler.sample(N_POTENTIALS))\n",
    "\n",
    "D_opt = torch.optim.Adam(D.parameters(), lr=LR)\n",
    "m_opt = torch.optim.Adam([log_m], lr=LR)\n",
    "\n",
    "# NOTE(review): the training cell rebuilds D, log_m and both optimizers,\n",
    "# so any state created or loaded here is discarded there.\n",
    "if CONTINUE > -1:\n",
    "    # Checkpoint names are keyed by OUTPUT_SEED, the seed defined in the config cell.\n",
    "    D_opt.load_state_dict(torch.load(os.path.join(OUTPUT_PATH, f'D_opt_{OUTPUT_SEED}_{CONTINUE}.pt')))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "814712e6",
   "metadata": {},
   "source": [
    "## 5. Model training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2fe77ff9",
   "metadata": {},
   "outputs": [],
   "source": [
    "Taus = [5]  # unbalancedness parameter(s); one model is trained from scratch per value\n",
    "WARMUP_STEPS = 2000  # the first steps use WARMUP_TAU instead of Tau\n",
    "WARMUP_TAU = 1000\n",
    "\n",
    "\n",
    "def divergence_term(v, tau, divergence):\n",
    "    \"\"\"Map potential values `v` elementwise to the divergence penalty used in the loss.\n",
    "\n",
    "    Raises ValueError for an unsupported `divergence` instead of failing later\n",
    "    with a NameError on an undefined penalty.\n",
    "    \"\"\"\n",
    "    if divergence == 'UKL':\n",
    "        return tau * (torch.exp(-v / tau) - 1)\n",
    "    if divergence == 'Xi2':\n",
    "        # Keeps v where v < 2*tau and replaces it by tau elsewhere.\n",
    "        # NOTE(review): for v >= 2*tau this yields -0.75*tau, not the -tau reached as\n",
    "        # v -> 2*tau from below; gradients are zero there either way, so only the\n",
    "        # reported loss value is affected. Confirm intent.\n",
    "        v = -(F.relu(-v + 2 * tau) - (1 + (-v > -2 * tau)) * tau)\n",
    "        return 0.25 * v**2 / tau - v\n",
    "    raise ValueError(f\"Unknown DIVERGENCE {divergence!r}; expected 'UKL' or 'Xi2'\")\n",
    "\n",
    "\n",
    "for Tau in Taus:\n",
    "    arr = []  # per-step loss history\n",
    "    D = ULightOT(dim=DIM, n_potentials=N_POTENTIALS, epsilon=EPSILON,\n",
    "                 sampling_batch_size=SAMPLING_BATCH_SIZE, is_diagonal=IS_DIAGONAL)\n",
    "\n",
    "    log_m = torch.zeros(1, requires_grad=True)\n",
    "\n",
    "    if INIT_BY_SAMPLES:\n",
    "        D.init_r_by_samples(Y_sampler.sample(N_POTENTIALS))\n",
    "\n",
    "    D_opt = torch.optim.Adam(D.parameters(), lr=LR)\n",
    "    m_opt = torch.optim.Adam([log_m], lr=LR)\n",
    "\n",
    "    for step in tqdm(range(MAX_STEPS)):\n",
    "        D_opt.zero_grad()\n",
    "        m_opt.zero_grad()\n",
    "        # Warm up with a large tau before switching to the target value.\n",
    "        tau = WARMUP_TAU if step < WARMUP_STEPS else Tau\n",
    "\n",
    "        X, Y = X_sampler.sample(BATCH_SIZE), Y_sampler.sample(BATCH_SIZE)\n",
    "\n",
    "        log_V = D.get_potential(Y)\n",
    "        psi = EPSILON * log_V + torch.norm(Y, p=2, dim=-1)**2 / 2\n",
    "        f_psi = divergence_term(psi, tau, DIVERGENCE)\n",
    "\n",
    "        log_C = D.get_C(X)\n",
    "        log_U = D.get_marginal(X)\n",
    "        phi = EPSILON * (log_U + log_m - log_C) + torch.norm(X, p=2, dim=-1)**2 / 2\n",
    "        f_phi = divergence_term(phi, tau, DIVERGENCE)\n",
    "\n",
    "        D_loss = EPSILON * torch.exp(log_m) + f_phi.mean() + f_psi.mean()\n",
    "        arr.append(D_loss.item())\n",
    "        D_loss.backward()\n",
    "        D_opt.step()\n",
    "        m_opt.step()\n",
    "\n",
    "        if step % PLOT_EVERY == 0:\n",
    "            clear_output(wait=True)\n",
    "            fig = plot_2D_lines_p0(step, D, x_p, y_q)\n",
    "            plt.show()\n",
    "\n",
    "    # Report the configured Tau (the loop-local `tau` still equals WARMUP_TAU\n",
    "    # when MAX_STEPS <= WARMUP_STEPS).\n",
    "    print('Results for Tau: ', Tau)\n",
    "    fig = plot_2D_lines_p0(step, D, x_p, y_q)\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "50c88634",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
