{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7c9cf30b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# torch\n",
    "import torch\n",
    "import torch.distributions as TD\n",
    "from torch import nn\n",
    "import torch.nn.functional as F\n",
    "# from torch.autograd import grad\n",
    "from functorch import grad, vmap\n",
    "# import geotorch\n",
    "\n",
    "# base\n",
    "import matplotlib as mpl\n",
    "import matplotlib.pyplot as plt\n",
    "from matplotlib.legend_handler import HandlerTuple\n",
    "from cycler import cycler\n",
    "\n",
    "from tqdm import tqdm\n",
    "import logging\n",
    "import numpy as np\n",
    "import os\n",
    "import sys\n",
    "import itertools\n",
    "import functools\n",
    "import operator as ops\n",
    "\n",
    "from sklearn.decomposition import PCA\n",
    "\n",
    "import warnings\n",
    "\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "sys.path.append(\"..\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "437a1b03",
   "metadata": {},
   "outputs": [],
   "source": [
    "# dataset utils\n",
    "sys.path.append(\"../..\")\n",
    "from src.utils import Distrib2Sampler\n",
    "from src.utils import Config\n",
    "\n",
    " \n",
    "\n",
    "# langevin sampling\n",
    "from src.eot import sample_langevin_batch\n",
    "from src.eot_utils import computePotGrad, evaluating\n",
    "# from src.plotters import plot_training_phase\n",
    "from src.tools import *\n",
    "from src import distributions\n",
    "from src import benchmarks\n",
    "\n",
    "# training utils\n",
    "from src.dgm_utils.statsmanager import StatsManager, StatsManagerDrawScheduler\n",
    "\n",
    "# typing\n",
    "from typing import Callable, Tuple, Union\n",
    "import ot\n",
    "\n",
    "# models\n",
    "from src.models2D import FullyConnectedMLP\n",
    "DEVICE = 'cuda:0'\n",
    "# DEVICE = \"cpu\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "72c135fb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Experiment configuration shared by every later cell.\n",
    "CONFIG = Config()\n",
    "\n",
    "CONFIG.CLIP_GRADS_NORM = False\n",
    "# entropic regularisation; scales the cost term of the Langevin drift as NOISE**2 / HREG\n",
    "CONFIG.HREG = 0.01\n",
    "# True: plain L2 cost gradient; False: transformed cost ||h(x) - h(y)||^2\n",
    "CONFIG.USE_L2 = False\n",
    "# the discrete free-support OT ground truth is an L2 barycenter, so it is tied to USE_L2\n",
    "CONFIG.DISCRETE_OT_FOR_GT = CONFIG.USE_L2\n",
    "\n",
    "# Langevin sampler\n",
    "CONFIG.LANGEVIN_THRESH = None\n",
    "CONFIG.LANGEVIN_SAMPLING_NOISE = 0.03\n",
    "CONFIG.ENERGY_SAMPLING_ITERATIONS = 300\n",
    "CONFIG.LANGEVIN_DECAY = 1.0\n",
    "CONFIG.LANGEVIN_SCORE_COEFFICIENT = 1.0\n",
    "CONFIG.LANGEVIN_COST_COEFFICIENT = 1.0\n",
    "\n",
    "# learning parameters\n",
    "CONFIG.MAX_STEPS = 200\n",
    "CONFIG.BATCH_SIZE = 256\n",
    "# NOTE(review): passed to TD.Normal as the scale (std) of the initial noise, not a variance;\n",
    "# the two only coincide for the value 1.0\n",
    "CONFIG.BASIC_NOISE_VAR = 1.0\n",
    "\n",
    "CONFIG.DIM = 2\n",
    "CONFIG.NUM = 3\n",
    "# Barycentric weights, one per marginal, summing to 1. (`1 / np.ones(NUM)` would give\n",
    "# all-ones weights summing to NUM instead of the uniform 1 / NUM.)\n",
    "CONFIG.ALPHAS = np.ones(CONFIG.NUM) / CONFIG.NUM\n",
    "# CONFIG.ALPHAS = np.array([0.3, 0.7])  # non-uniform example, requires NUM = 2\n",
    "\n",
    "CONFIG.OUTPUT_SEED = 0xAB0BA\n",
    "\n",
    "assert CONFIG.NUM == len(CONFIG.ALPHAS)\n",
    "assert np.isclose(np.sum(CONFIG.ALPHAS), 1.0), \"barycentric weights must sum to 1\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "340c09f4",
   "metadata": {},
   "outputs": [],
   "source": [
    "class TransformedL2Generic:\n",
    "\n",
    "    def h(self, X):\n",
    "        raise NotImplementedError()\n",
    "\n",
    "    def h_inv(self, Z):\n",
    "        raise NotImplementedError()\n",
    "\n",
    "    def __init__(self):\n",
    "        pass\n",
    "\n",
    "    def dist_squared(self, X, Y):\n",
    "        return self.dist(X, Y, squared=True)\n",
    "\n",
    "    def dist(self, X, Y, squared=False):\n",
    "        z_X = self.h(X)\n",
    "        z_Y = self.h(Y)\n",
    "        dist_squared = torch.sum((z_X - z_Y).pow(2), dim=-1)\n",
    "        if squared:\n",
    "            return dist_squared\n",
    "        return torch.sqrt(dist_squared)\n",
    "\n",
    "    def cdist(self, X, Y, squared=False):\n",
    "        '''\n",
    "        X: (bs, n, D)\n",
    "        Y: (bs, m, D)\n",
    "        '''\n",
    "        z_X = self.h(X.flatten(start_dim=0, end_dim=1)).view(X.shape)\n",
    "        z_Y = self.h(Y.flatten(start_dim=0, end_dim=1)).view(Y.shape)\n",
    "        dists = torch.cdist(z_X, z_Y)\n",
    "        if squared:\n",
    "            return dists.pow(2)\n",
    "        return dists\n",
    "\n",
    "    def bary(self, Xs, alps):\n",
    "        assert isinstance(Xs, list)\n",
    "        assert len(Xs) == len(alps)\n",
    "        alps = np.asarray(alps)\n",
    "        alps /= np.sum(alps)\n",
    "        baryZ = 0.\n",
    "        for i in range(len(alps)):\n",
    "            baryZ += self.h(Xs[i]) * alps[i]\n",
    "        baryX = self.h_inv(baryZ)\n",
    "        return baryX\n",
    "\n",
    "class TransformedL2TwoMaps(TransformedL2Generic):\n",
    "\n",
    "    def __init__(self, h, h_inv):\n",
    "        '''\n",
    "        h: X -> Z\n",
    "        h_inv: Z -> X\n",
    "        '''\n",
    "        super().__init__()\n",
    "        self.h = h\n",
    "        self.h_inv = h_inv"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b98bb49d",
   "metadata": {},
   "outputs": [],
   "source": [
    "H_SLOPE = 0.4\n",
    "\n",
    "def norm2theta(norms):\n",
    "    return H_SLOPE * norms\n",
    "\n",
    "def rotate_batch(Rs, Xs):\n",
    "    if len(Xs.shape) == 1:\n",
    "        Xs = Xs[None]\n",
    "    assert len(Xs.shape) == 2\n",
    "    assert Xs.size(1) == 2\n",
    "    assert Xs.size(0) == Rs.size(0)\n",
    "    assert len(Rs.shape) == 3\n",
    "    assert Rs.size(1) == Rs.size(2) == 2\n",
    "    return torch.matmul(\n",
    "        Xs.unsqueeze(1), \n",
    "        Rs.transpose(1, 2)).squeeze(1)\n",
    "\n",
    "def cossin2R(cos, sin):\n",
    "    assert cos.shape == sin.shape\n",
    "    assert len(cos.shape) == 1\n",
    "    return torch.stack([cos, -sin, sin, cos]).T.view(-1, 2, 2)\n",
    "    \n",
    "def lin_space_rotator(Xs, pos=True):\n",
    "    if len(Xs.shape) == 1:\n",
    "        Xs = Xs[None]\n",
    "    assert len(Xs.shape) == 2\n",
    "    assert Xs.size(1) == 2\n",
    "    X_norms = torch.norm(Xs, dim=-1)\n",
    "    thetas = norm2theta(X_norms)\n",
    "    if not pos:\n",
    "        thetas = - thetas\n",
    "    cos = torch.cos(thetas)\n",
    "    sin = torch.sin(thetas)\n",
    "    Rs = cossin2R(cos, sin)\n",
    "    return rotate_batch(Rs, Xs)\n",
    "\n",
    "def h(Xs):\n",
    "    return lin_space_rotator(Xs, pos=True)\n",
    "\n",
    "def h_inv(Zs):\n",
    "    return lin_space_rotator(Zs, pos=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6cf13595",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Transformed-L2 cost: distances are measured after mapping both points through h.\n",
    "tf = TransformedL2TwoMaps(h, h_inv)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "62600c3c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pin the CUDA device only when a GPU device was requested.\n",
    "if DEVICE != \"cpu\":\n",
    "    assert torch.cuda.is_available(), f\"DEVICE={DEVICE!r} requested but CUDA is not available\"\n",
    "    torch.cuda.set_device(DEVICE)\n",
    "\n",
    "def seed(seed):\n",
    "    \"\"\"Seed numpy and torch (CPU and all CUDA devices) and request deterministic cuDNN kernels.\"\"\"\n",
    "    np.random.seed(seed)\n",
    "    torch.manual_seed(seed)\n",
    "    torch.cuda.manual_seed_all(seed)\n",
    "    # `torch.cuda.deterministic` is not a PyTorch setting (assigning it is a silent no-op);\n",
    "    # the determinism switches live in torch.backends.cudnn.\n",
    "    torch.backends.cudnn.deterministic = True\n",
    "    torch.backends.cudnn.benchmark = False\n",
    "\n",
    "seed(CONFIG.OUTPUT_SEED)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e909c959",
   "metadata": {},
   "source": [
    "## Initializing distributions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "42f3d800",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _unit_normal(mean):\n",
    "    \"\"\"Axis-aligned 2-D Gaussian with unit std per coordinate, placed on DEVICE.\"\"\"\n",
    "    return TD.Normal(\n",
    "        torch.tensor(mean).to(DEVICE),\n",
    "        torch.tensor([1., 1.]).to(DEVICE))\n",
    "\n",
    "# Three Z-space marginals centred roughly on a circle of radius 4, about 120 degrees apart\n",
    "Z1distrib = _unit_normal([0., 4.])\n",
    "Z2distrib = _unit_normal([3.46, -2.])\n",
    "Z3distrib = _unit_normal([-3.46, -2.])\n",
    "\n",
    "# Equal-weight L2 barycenter in Z-space: same covariance, centred at the mean of the centres (the origin).\n",
    "# NOTE(review): only the ground truth for uniform CONFIG.ALPHAS\n",
    "Zgtdistrib = _unit_normal([0.0, 0.0])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a64e736d",
   "metadata": {},
   "source": [
    "# Nets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "90d0f925",
   "metadata": {},
   "outputs": [],
   "source": [
    "# one MLP per marginal (presumably DIM -> [32, 32] -> 1); every potential below mixes all of them\n",
    "nets = [FullyConnectedMLP(CONFIG.DIM, [32, 32], 1).to(DEVICE) for _ in range(CONFIG.NUM)]\n",
    "param_gens = [net.parameters() for net in nets]\n",
    "opt = torch.optim.Adam(itertools.chain(*param_gens), lr=1e-2)\n",
    "\n",
    "def make_f_pot(idx):\n",
    "    '''\n",
    "    Return the potential of marginal `idx`:\n",
    "\n",
    "        f_idx(x) = net_idx(x) - sum_{j != idx} alpha_j * net_j(x) / ((NUM - 1) * alpha_idx)\n",
    "\n",
    "    This parameterisation makes sum_k alpha_k * f_k(x) = 0 for every x (up to\n",
    "    floating-point error), whatever the network weights are.\n",
    "    CONFIG.ALPHAS is read at call time, not when the closure is created.\n",
    "    '''\n",
    "    def f_pot(x):\n",
    "        res = 0.0\n",
    "        for j, (net, alpha) in enumerate(zip(nets, CONFIG.ALPHAS)):\n",
    "            out = net(x)\n",
    "            if j != idx:\n",
    "                res -= alpha * out / (CONFIG.NUM - 1) / CONFIG.ALPHAS[idx]\n",
    "            else:\n",
    "                res += out\n",
    "        return res\n",
    "    return f_pot\n",
    "\n",
    "f_pots = list(map(make_f_pot, range(CONFIG.NUM)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ef43153d",
   "metadata": {},
   "outputs": [],
   "source": [
    "def cost_grad_y(y: torch.Tensor, x: torch.Tensor):\n",
    "    '''\n",
    "    Row-wise gradient w.r.t. y of the transformed cost 0.5 * ||h(x) - h(y)||^2.\n",
    "    vmap maps over the leading (batch) dim of y and x, so every row gets its own gradient.\n",
    "    '''\n",
    "    def half_cost(y_row, x_row):\n",
    "        return 0.5 * tf.dist_squared(x_row, y_row).squeeze()\n",
    "    per_row_grad = vmap(grad(half_cost))\n",
    "    return per_row_grad(y, x)\n",
    "\n",
    "def l2_grad_y(y, x):\n",
    "    '''\n",
    "    Gradient w.r.t. y of the plain cost 0.5 * ||x - y||^2, i.e. y - x.\n",
    "    '''\n",
    "    return y - x\n",
    "\n",
    "# cost gradient used by the Langevin sampler; chosen once, when this cell runs\n",
    "grad_fn = cost_grad_y if not CONFIG.USE_L2 else l2_grad_y\n",
    "\n",
    "def cond_score(\n",
    "        f: Callable[[torch.Tensor], torch.Tensor],\n",
    "        cost_grad_y_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],\n",
    "        y: torch.Tensor,\n",
    "        x: torch.Tensor,\n",
    "        config: Config,\n",
    "        ret_stats=False\n",
    "    ) -> Union[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]:\n",
    "    '''\n",
    "    Drift of the conditional Langevin sampler at y given sources x:\n",
    "\n",
    "        SCORE_COEFFICIENT * g(y) - COST_COEFFICIENT * (NOISE**2 / HREG) * cost_grad_y_fn(y, x)\n",
    "\n",
    "    where g is computePotGrad applied to f(y) (presumably grad_y f(y)).\n",
    "    Sets requires_grad on `y` in place.\n",
    "    Returns the drift, or (drift, cost_part, score_part) when ret_stats is True.\n",
    "    '''\n",
    "    with torch.enable_grad():\n",
    "        y.requires_grad_(True)\n",
    "        potential = f(y)\n",
    "        pot_grad = computePotGrad(y, potential)\n",
    "        assert pot_grad.shape == y.shape\n",
    "    cost_part = cost_grad_y_fn(y, x) * (\n",
    "        config.LANGEVIN_COST_COEFFICIENT * (config.LANGEVIN_SAMPLING_NOISE ** 2 / config.HREG))\n",
    "    score_part = pot_grad * config.LANGEVIN_SCORE_COEFFICIENT\n",
    "    drift = score_part - cost_part\n",
    "    return (drift, cost_part, score_part) if ret_stats else drift\n",
    "\n",
    "def sample_langevin_mu_f(\n",
    "        f: Callable[[torch.Tensor], torch.Tensor],\n",
    "        x: torch.Tensor,\n",
    "        y_init: torch.Tensor,\n",
    "        config: Config\n",
    "    ) -> torch.Tensor:\n",
    "    '''\n",
    "    Run sample_langevin_batch from y_init with the conditional drift of f given x\n",
    "    and return only the final samples (the returned diagnostics are discarded).\n",
    "    '''\n",
    "    def score(y, ret_stats=False):\n",
    "        return cond_score(f, grad_fn, y, x, config, ret_stats=ret_stats)\n",
    "\n",
    "    langevin_kwargs = dict(\n",
    "        n_steps=config.ENERGY_SAMPLING_ITERATIONS,\n",
    "        decay=config.LANGEVIN_DECAY,\n",
    "        thresh=config.LANGEVIN_THRESH,\n",
    "        noise=config.LANGEVIN_SAMPLING_NOISE,\n",
    "        data_projector=lambda pts: pts,  # identity: samples are not projected\n",
    "        compute_stats=True,\n",
    "    )\n",
    "    y_final, _r_t, _cost_r_t, _score_r_t, _noise_norm = sample_langevin_batch(\n",
    "        score, y_init, **langevin_kwargs)\n",
    "    return y_final"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "34bcf75f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Starting points of every Langevin chain: isotropic Gaussian on DEVICE.\n",
    "# NOTE(review): TD.Normal takes a std as `scale`, so BASIC_NOISE_VAR acts as a std here\n",
    "# (identical to the variance only for the default 1.0).\n",
    "init_noise_sampler = Distrib2Sampler(\n",
    "    TD.Normal(\n",
    "        loc=torch.zeros(CONFIG.DIM).to(DEVICE),\n",
    "        scale=torch.ones(CONFIG.DIM).to(DEVICE) * CONFIG.BASIC_NOISE_VAR,\n",
    "    )\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d7a040b7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Z-space marginals (torch distributions, sampled with .sample((n,))), one per potential\n",
    "samplers = [Z1distrib, Z2distrib, Z3distrib]\n",
    "# later cells zip samplers with f_pots / nets / CONFIG.ALPHAS, which would silently truncate on a mismatch\n",
    "assert len(samplers) == CONFIG.NUM, \"need exactly one marginal distribution per potential\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "18f7b81a",
   "metadata": {},
   "outputs": [],
   "source": [
    "def alpha_color(color_rgb, alpha=0.5):\n",
    "    color_rgb = np.asanyarray(color_rgb)\n",
    "    alpha_color_rgb = 1. - (1. - color_rgb) * alpha\n",
    "    return alpha_color_rgb"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0f4cc424",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_bary_i(f_pot, sampler, ax, i, n_samples=512, n_maps=0, n_arrows_per_map=1):\n",
    "    '''\n",
    "    Plot samples of marginal i next to the barycenter samples generated from its potential.\n",
    "\n",
    "    n_samples points are drawn from `sampler` (Z-space) and pushed to X-space with\n",
    "    tf.h_inv; the conditional Langevin sampler driven by f_pot maps them to y. If\n",
    "    n_maps > 0, n_maps extra source points are each repeated n_arrows_per_map times\n",
    "    and joined to their sampled images by segments, visualising the conditional plan.\n",
    "\n",
    "    f_pot: potential of marginal i (e.g. an element of f_pots)\n",
    "    sampler: distribution of marginal i in Z-space (needs .sample((n,)))\n",
    "    ax: matplotlib Axes to draw on\n",
    "    i: index of the marginal; selects the colour and the legend labels\n",
    "    '''\n",
    "    n_arrows = n_maps * n_arrows_per_map\n",
    "    X = sampler.sample((n_samples,)).to(DEVICE)\n",
    "    if n_maps > 0:\n",
    "        Xm = sampler.sample((n_maps,)).to(DEVICE)\n",
    "        # repeat each mapped source point -> several stochastic images per x\n",
    "        Xm = torch.tile(Xm, (n_arrows_per_map, 1))\n",
    "        X = torch.cat((X, Xm), dim=0)\n",
    "\n",
    "    X = tf.h_inv(X)\n",
    "    Y_init = init_noise_sampler.sample(n_samples + n_arrows).to(DEVICE)\n",
    "    # no autograd graph is needed for plotting (same as in the training loop);\n",
    "    # cond_score re-enables grad locally for the potential's gradient\n",
    "    with torch.no_grad():\n",
    "        Y = sample_langevin_mu_f(f_pot, X, Y_init, CONFIG).to(DEVICE)\n",
    "    X_np = X.detach().cpu().numpy()\n",
    "    Y_np = Y.detach().cpu().numpy()\n",
    "\n",
    "    cols = mpl.colormaps[\"Dark2\"].colors\n",
    "    col_bary = mpl.colormaps[\"tab10\"].colors[CONFIG.NUM]\n",
    "    p4 = ax.scatter(\n",
    "        X_np[:n_samples, 0], X_np[:n_samples, 1],\n",
    "        edgecolors=alpha_color((0, 0, 0)), color=alpha_color(cols[i]), zorder=0, linewidth=.5,\n",
    "    )\n",
    "    p1 = ax.scatter(\n",
    "        Y_np[:n_samples, 0], Y_np[:n_samples, 1],\n",
    "        edgecolors=(0, 0, 0), color=col_bary, zorder=0, linewidth=.5,\n",
    "    )\n",
    "\n",
    "    # legend groups: (barycenter samples[, arrow ends]) and ([arrow starts, ]marginal samples);\n",
    "    # the arrow handles only exist when arrows are drawn\n",
    "    y_handles, x_handles = [p1], [p4]\n",
    "    if n_arrows > 0:\n",
    "        p3 = ax.scatter(\n",
    "            X_np[-n_arrows:, 0], X_np[-n_arrows:, 1],\n",
    "            linewidth=.5, edgecolors='black', color=cols[i], zorder=2,\n",
    "        )\n",
    "        p2 = ax.scatter(\n",
    "            Y_np[-n_arrows:, 0], Y_np[-n_arrows:, 1],\n",
    "            linewidth=.5, edgecolors='black', color=cols[CONFIG.NUM + 2], zorder=2,\n",
    "        )\n",
    "        ax.quiver(\n",
    "            X_np[-n_arrows:, 0], X_np[-n_arrows:, 1],\n",
    "            Y_np[-n_arrows:, 0] - X_np[-n_arrows:, 0], Y_np[-n_arrows:, 1] - X_np[-n_arrows:, 1],\n",
    "            angles='xy', scale_units='xy', scale=0.95, width=.005, zorder=1, headwidth=0.0, headlength=0.0,\n",
    "        )\n",
    "        y_handles.append(p2)\n",
    "        x_handles.insert(0, p3)\n",
    "\n",
    "    ax.legend(\n",
    "        [\n",
    "            tuple(y_handles),\n",
    "            tuple(x_handles),\n",
    "        ], [\n",
    "            f\"$y \\\\sim \\\\pi_{i + 1}^{{f_{{\\\\theta^*,{i + 1}}}}}(\\\\cdot \\\\mid x_{i + 1})$\",\n",
    "            f\"$x_{i + 1} \\\\sim \\\\mathbb{{P}}_{i + 1}$\",\n",
    "        ],\n",
    "        handler_map={tuple: HandlerTuple(ndivide=None)},\n",
    "        loc=\"upper left\",\n",
    "        prop={\"size\": 13.5},\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e8ef0031",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_distributions(ax, samplers, n_samples=512):\n",
    "    '''\n",
    "    Scatter n_samples points of every marginal (mapped to X-space with tf.h_inv) and of\n",
    "    the ground-truth barycenter on one Axes.\n",
    "\n",
    "    The ground truth is either the discrete free-support OT barycenter of the drawn\n",
    "    marginal samples (when CONFIG.DISCRETE_OT_FOR_GT) or samples of Zgtdistrib mapped\n",
    "    through tf.h_inv.\n",
    "    '''\n",
    "    cols = mpl.colormaps[\"Dark2\"].colors\n",
    "    Xs = []\n",
    "    for i, distr in enumerate(samplers):\n",
    "        X = tf.h_inv(distr.sample((n_samples,))).detach().cpu().numpy()\n",
    "        Xs.append(X)\n",
    "        ax.scatter(\n",
    "            X[:, 0], X[:, 1],\n",
    "            label=f\"$x_{{{i + 1}}} \\\\sim \\\\mathbb{{P}}_{{{i + 1}}}$\",\n",
    "            edgecolors=alpha_color((0, 0, 0)), color=alpha_color(cols[i]), linewidth=.5,\n",
    "        )\n",
    "\n",
    "    if CONFIG.DISCRETE_OT_FOR_GT:\n",
    "        mw = [ot.unif(x.shape[0]) for x in Xs]\n",
    "        # weight the marginals like the training loss does (CONFIG.ALPHAS, normalised)\n",
    "        bary_weights = np.asarray(CONFIG.ALPHAS, dtype=float)\n",
    "        bary_weights = bary_weights / bary_weights.sum()\n",
    "        # init_noise_sampler.sample takes an int count everywhere else in this notebook\n",
    "        Y_init = init_noise_sampler.sample(Xs[0].shape[0]).detach().cpu().numpy()\n",
    "        Xgt = ot.lp.free_support_barycenter(Xs, mw, Y_init, weights=bary_weights)\n",
    "    else:\n",
    "        Xgt = tf.h_inv(Zgtdistrib.sample((n_samples,))).detach().cpu().numpy()\n",
    "    ax.scatter(\n",
    "        Xgt[:, 0], Xgt[:, 1],\n",
    "        label=r\"$y \\sim \\mathbb{Q}_*$\",\n",
    "        edgecolors='black', color=cols[CONFIG.NUM + 1], linewidth=.5,\n",
    "    )\n",
    "    ax.legend(ncol=2, loc=\"upper center\", prop={\"size\": 13.5})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "195fb873",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_bary(potential_fns, samplers, arrows=True, n_samples=512, n_maps=5, n_arrows_per_map=3):\n",
    "    '''\n",
    "    Figure with CONFIG.NUM + 1 panels: panel 0 shows every marginal and the ground-truth\n",
    "    barycenter; panel k + 1 shows marginal k and the barycenter samples produced by\n",
    "    potential_fns[k].\n",
    "\n",
    "    arrows: draw n_maps * n_arrows_per_map x -> y segments in each per-marginal panel\n",
    "    n_samples: points per marginal in the per-marginal panels\n",
    "    Returns the matplotlib Figure.\n",
    "    '''\n",
    "    fig, axs = plt.subplots(\n",
    "        ncols=CONFIG.NUM + 1,\n",
    "        figsize=(15, 3.75),\n",
    "        sharex=True, sharey=True,\n",
    "        dpi=200,\n",
    "    )\n",
    "\n",
    "    plot_distributions(axs[0], samplers)\n",
    "    # axes are shared, so these limits apply to every panel\n",
    "    axs[0].set_xlim(-7, 7)\n",
    "    axs[0].set_ylim(-7, 7)\n",
    "\n",
    "    maps_per_panel = n_maps if arrows else 0\n",
    "    for i, (f_pot, sampler, ax) in enumerate(zip(potential_fns, samplers, axs[1:])):\n",
    "        plot_bary_i(f_pot, sampler, ax, i, n_samples, maps_per_panel, n_arrows_per_map)\n",
    "\n",
    "    fig.tight_layout()\n",
    "    return fig"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7abd1990",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Loss tracker/plotter (positional args presumably panel layout 1 x 1 and figsize (5, 4) -- confirm in src.dgm_utils)\n",
    "SMDS = StatsManagerDrawScheduler(StatsManager('loss'), 1, 1, (5, 4), epoch_freq=10)\n",
    "# NOTE(review): not read by any cell in this notebook\n",
    "last_plot_it = -1\n",
    "last_score_it = -1\n",
    "\n",
    "for it in tqdm(range(CONFIG.MAX_STEPS)):\n",
    "    # fresh batch from every marginal (mapped to X-space) and Langevin starting points\n",
    "    Xs = [tf.h_inv(s.sample((CONFIG.BATCH_SIZE,))).to(DEVICE) for s in samplers]\n",
    "    Ys_init = [init_noise_sampler.sample(CONFIG.BATCH_SIZE).to(DEVICE) for _ in range(CONFIG.NUM)]\n",
    "\n",
    "    # y ~ pi^{f_k}(. | x) for every marginal under the current potentials, without parameter grads\n",
    "    for net in nets:\n",
    "        net.eval()\n",
    "    with torch.no_grad():\n",
    "        Ys = [\n",
    "            sample_langevin_mu_f(f_pot, X.to(DEVICE), Y_init, CONFIG)\n",
    "            for f_pot, X, Y_init in zip(f_pots, Xs, Ys_init)\n",
    "        ]\n",
    "    for net in nets:\n",
    "        net.train()\n",
    "\n",
    "    # alpha-weighted mean potential on the sampled points; its gradient updates all nets\n",
    "    weighted_terms = [alpha * f_pot(Y).mean() for alpha, f_pot, Y in zip(CONFIG.ALPHAS, f_pots, Ys)]\n",
    "    loss = sum(weighted_terms)\n",
    "    opt.zero_grad()\n",
    "    loss.backward()\n",
    "    opt.step()\n",
    "\n",
    "    SMDS.SM.upd('loss', loss.item())\n",
    "    SMDS.epoch()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8e48d59c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reseed before plotting so the figure does not depend on how much RNG training consumed.\n",
    "seed(4)\n",
    "plot_bary(f_pots, samplers)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8b4ac980",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
