{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "bures_weighted_flows_comparison.ipynb",
      "provenance": [],
      "collapsed_sections": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "r1M9Oi7LEHUo"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/anon-author-dev/gsw/blob/master/code_flow/bures_weighted_flows_comparison.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "3w6ye6Yn4vdw"
      },
      "source": [
        "# Install the third-party dependencies with the %pip magic so they land in the\n",
        "# environment of the running kernel rather than whatever `pip` the shell finds.\n",
        "%pip install POT\n",
        "%pip install celluloid"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "3cO3I6E3D6mD"
      },
      "source": [
        "# Clone the project repository, then change the working directory to the package\n",
        "# folder so the project-local modules can be imported by name.\n",
        "# NOTE(review): not idempotent - re-running this cell in the same session presumably\n",
        "# fails, since the clone target already exists and the relative %cd path no longer resolves.\n",
        "!git clone https://github.com/anon-author-dev/gsw\n",
        "%cd gsw/code_flow/gsw"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "fMzrUrF15qiW"
      },
      "source": [
        "import numpy as np\n",
        "from bures_weighted import WGSB\n",
        "from weighted_utils import w2_weighted\n",
        "from gsw_utils import w2,load_data\n",
        "\n",
        "import torch\n",
        "from torch import optim\n",
        "\n",
        "from celluloid import Camera\n",
        "from tqdm import tqdm\n",
        "from IPython import display\n",
        "import matplotlib.pyplot as plt"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "3Ft7rgUQ7he8"
      },
      "source": [
        "# Available datasets: 'swiss_roll', 'circle', '8gaussians', '25gaussians'\n",
        "dataset_name = 'circle'\n",
        "np.random.seed(10)\n",
        "N = 1000  # number of samples drawn from the target p_X\n",
        "X = load_data(name=dataset_name, n_samples=N)\n",
        "X -= X.mean(dim=0, keepdim=True)  # centre the samples at the origin\n",
        "meanX = 0\n",
        "_, d = X.shape\n",
        "\n",
        "# Scatter plot of the (centred) target samples\n",
        "fig = plt.figure(figsize=(5, 5))\n",
        "plt.scatter(X[:, 0], X[:, 1])\n",
        "plt.show()"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "0MJ3LPr07wKO"
      },
      "source": [
        "# Use GPU if available, CPU otherwise\n",
        "if torch.cuda.is_available():\n",
        "    device = torch.device('cuda')\n",
        "else:\n",
        "    device = torch.device('cpu')"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Isa8P9lZ7xxY"
      },
      "source": [
        "# Number of optimisation iterations; the training loop records one animation\n",
        "# frame per iteration, so this is also the number of frames in the video.\n",
        "nofiterations = 250"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "1GOlvBwR79jR"
      },
      "source": [
        "# # Define the different defining functions\n",
        "titles = ['Max Sliced W2: linear', 'RFB-2000-0.2', 'RFB-2000-0.1', 'Max Sliced Bures: linear', 'RFB-2000-0.2', 'RFB-2000-0.1']"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "JV2hVSfb9oXo"
      },
      "source": [
        "# Initial support points: N samples drawn uniformly from the square [-2, 2)^d.\n",
        "# (Alternative Gaussian initialisation, kept for reference:)\n",
        "#Y = torch.from_numpy(np.random.normal(loc=meanX, scale=0.5, size=(N,d))).float()\n",
        "Y = torch.from_numpy((np.random.rand(N,d)-0.5)*4 ).float()\n",
        "temp = np.ones((N,))/N  # uniform initial weights, one per support point\n",
        "\n",
        "new_lr = 1e-2  # Adam learning rate shared by all flows\n",
        "nb = 2000 # number of bases in the random Fourier bases\n",
        "sigma0 = 0.2\n",
        "sigma1 = 0.1\n",
        "\n",
        "# One entry per flow. The order matters: the training loop selects the loss by\n",
        "# index (0-2 use the max-sliced W2 losses, 3-5 the max-sliced Bures losses), and\n",
        "# within each group the order is (linear, kernel sigma0, kernel sigma1).\n",
        "wgsb_configs = [\n",
        "    dict(ftype='linear'),\n",
        "    dict(ftype='kernel', nofbases=nb, sigma=sigma0),\n",
        "    dict(ftype='kernel', nofbases=nb, sigma=sigma1),\n",
        "] * 2\n",
        "\n",
        "# Per-flow state: learnable weights, their Adam optimizer and the WGSB instance.\n",
        "beta = list()\n",
        "optimizer = list()\n",
        "wmsd = list()\n",
        "for wgsb_kwargs in wgsb_configs:\n",
        "    beta.append(torch.tensor(temp, dtype=torch.float, device=device, requires_grad=True))\n",
        "    optimizer.append(optim.Adam([beta[-1]], lr=new_lr))\n",
        "    wmsd.append([WGSB(**wgsb_kwargs)])"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "v0K1qfIhcKl0"
      },
      "source": [
        "# 3x3 grid: rows 0-1 show the six flows, row 2 is merged into one wide axis\n",
        "fig, f_axs = plt.subplots(ncols=3, nrows=3, figsize=(15, 15));\n",
        "camera = Camera(fig)\n",
        "gs = f_axs[2, 0].get_gridspec()\n",
        "# Remove the three bottom-row axes and replace them with a single wide one\n",
        "for ax in f_axs[-1, :]:\n",
        "    ax.remove()\n",
        "axbig = fig.add_subplot(gs[-1, :])\n",
        "axbig.set_title('Wasserstein-2 Distance', fontsize=22)\n",
        "axbig.set_ylabel(r'$Log_{10}(W_2)$', fontsize=22)\n",
        "colors = ['#1f77b4',\n",
        "          '#ff7f0e',\n",
        "          '#2ca02c',\n",
        "          '#d62728',\n",
        "          '#9467bd',\n",
        "          '#8c564b']\n",
        "# 2-Wasserstein distance per iteration (rows) and per flow (columns).\n",
        "# Initialised to NaN so iterations that have not run yet are not drawn.\n",
        "w2_dist = np.nan * np.zeros((nofiterations, 6))\n",
        "fig.suptitle('FLOWS COMPARISON', fontsize=44)\n",
        "\n",
        "# X and Y are never modified inside the loop (only the weights beta are optimised),\n",
        "# so move them to the device and convert them to NumPy once instead of every step.\n",
        "X_dev, Y_dev = X.to(device), Y.to(device)\n",
        "X_np, Y_np = X.detach().cpu().numpy(), Y.detach().cpu().numpy()\n",
        "\n",
        "for i in range(nofiterations):\n",
        "    loss = list()\n",
        "    # Loop over the six flows\n",
        "    for k in range(6):\n",
        "        # Select the loss for flow k. Integers are compared with == / in: `is` tests\n",
        "        # object identity, which only happens to work for small ints in CPython and\n",
        "        # triggers a SyntaxWarning on Python >= 3.8.\n",
        "        if k == 0:\n",
        "            loss_ = wmsd[k][0].max_gsw_weighted(X_dev, Y_dev, beta[k].to(device))\n",
        "        elif k in (1, 2):\n",
        "            loss_ = wmsd[k][0].max_kernel_gsw_weighted(X_dev, Y_dev, beta[k].to(device))\n",
        "        elif k == 3:\n",
        "            loss_ = wmsd[k][0].max_sliced_bures_weighted(X_dev, Y_dev, beta[k].to(device))\n",
        "        else:\n",
        "            loss_ = wmsd[k][0].max_sliced_kernel_bures_weighted(X_dev, Y_dev, beta[k].to(device))\n",
        "\n",
        "        # Optimization step on the weights of flow k\n",
        "        loss.append(loss_)\n",
        "        optimizer[k].zero_grad()\n",
        "        loss_.backward()\n",
        "        optimizer[k].step()\n",
        "\n",
        "        # Build a valid probability vector from a NumPy copy of the weights:\n",
        "        # clip negatives to zero and renormalise so the entries sum to one.\n",
        "        # Note that beta itself is left unconstrained; only this copy is normalised.\n",
        "        nu = beta[k].detach().cpu().numpy()\n",
        "        nu = np.maximum(0,nu).astype(np.float64)\n",
        "        nu = nu/np.sum(nu)\n",
        "        # Compute the 2-Wasserstein distance to compare the distributions\n",
        "        w2_dist[i, k] = w2_weighted(X_np, Y_np, nu)\n",
        "\n",
        "        # Rescale the weights to marker sizes (largest weight -> size 15)\n",
        "        nu = nu/np.max(nu)*15\n",
        "        # Flows 0-2 go to the first row, flows 3-5 to the second row\n",
        "        row, col = divmod(k, 3)\n",
        "        f_axs[row,col].scatter(X[:, 0], X[:, 1], c='b')\n",
        "        f_axs[row,col].scatter(Y[:, 0], Y[:, 1], s=nu, c='r')\n",
        "        f_axs[row,col].set_title(titles[k], fontsize=22)\n",
        "\n",
        "    # Plot the 2-Wasserstein distance curves recorded so far\n",
        "    for p, color in enumerate(colors):\n",
        "        axbig.plot(np.log10(w2_dist[:,p]), color = color)\n",
        "\n",
        "    axbig.legend(titles, fontsize=22, bbox_to_anchor=(.1,-.55), loc=\"lower left\",\n",
        "                 ncol=2, fancybox=True, shadow=True)\n",
        "    # Record the current state of the figure as one animation frame\n",
        "    camera.snap()\n",
        "\n",
        "plt.close()"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "U7hWanL11SCW"
      },
      "source": [
        "# Build an animation from the camera with the default settings and save it as MP4.\n",
        "animation_full = camera.animate()\n",
        "animation_full.save('animation_full.mp4')\n",
        "# Second rendering from the same camera with blit=False and interval=10\n",
        "# (presumably the per-frame delay in milliseconds - confirm against the celluloid / matplotlib docs).\n",
        "animation_reduced = camera.animate(blit=False, interval=10)\n",
        "animation_reduced.save('animation_reduced.mp4')"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "rQgu-R0dBso6"
      },
      "source": [
        "# Render the full animation inline as an HTML5 video.\n",
        "display.HTML(animation_full.to_html5_video())"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "hwifsuNHzLJg"
      },
      "source": [
        ""
      ],
      "execution_count": null,
      "outputs": []
    }
  ]
}