{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch import Tensor\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.patches as  mpatches\n",
    "import seaborn as sns\n",
    "import pandas as pd\n",
    "from labproject.metrics.wasserstein_sinkhorn import sinkhorn_loss,sinkhorn_algorithm\n",
    "from labproject.metrics.wasserstein_kuhn import kuhn_transport\n",
    "from labproject.metrics.sliced_wasserstein import sliced_wasserstein_distance\n",
    "from labproject.metrics.MMD_torch import compute_rbf_mmd,median_heuristic\n",
    "from labproject.data import get_distribution\n",
    "from labproject.utils import set_seed\n",
    "from dataclasses import dataclass\n",
    "from torch.distributions import MultivariateNormal, Categorical\n",
    "set_seed(0)\n",
    "\n",
    "plt.style.use(\"../../../matplotlibrc\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "Create a mixture of two Gaussians that we use as the ground truth\n",
    "\"\"\"\n",
    "\n",
    "class MO2G:\n",
    "    \"\"\"Two-component 2-D Gaussian mixture used as the ground-truth target.\n",
    "\n",
    "    Components: N([-3, 1], I) with weight 0.2 and N([3, -1], I) with weight 0.8.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        self.means = torch.tensor([[-3.0, 1.0], [3.0, -1.0]])\n",
    "        self.covariances = torch.tensor(\n",
    "            [\n",
    "                [[1.0, 0.0], [0.0, 1.0]],\n",
    "                [[1.0, 0.0], [0.0, 1.0]],\n",
    "            ]\n",
    "        )\n",
    "        self.weights = torch.tensor([0.2, 0.8])\n",
    "\n",
    "        self.gaussians = [\n",
    "            MultivariateNormal(mean, covariance)\n",
    "            for mean, covariance in zip(self.means, self.covariances)\n",
    "        ]\n",
    "\n",
    "    def sample(self, sample_shape):\n",
    "        \"\"\"Draw samples of shape ``(*sample_shape, 2)`` from the mixture.\n",
    "\n",
    "        ``sample_shape`` may be an int or any shape tuple.\n",
    "        \"\"\"\n",
    "        if isinstance(sample_shape, int):\n",
    "            sample_shape = (sample_shape,)\n",
    "        sample_shape = torch.Size(sample_shape)\n",
    "        # Choose a mixture component for every sample position.\n",
    "        component = Categorical(self.weights).sample(sample_shape)\n",
    "        # Draw from every component at every position (shape: *sample_shape, K, D)\n",
    "        # and keep the draw of the chosen component. This is vectorised, avoiding a\n",
    "        # Python loop over thousands of samples on every training iteration.\n",
    "        draws = torch.stack([g.sample(sample_shape) for g in self.gaussians], dim=-2)\n",
    "        index = component[..., None, None].expand(*sample_shape, 1, draws.shape[-1])\n",
    "        return draws.gather(-2, index).squeeze(-2)\n",
    "\n",
    "    def log_prob(self, input):\n",
    "        \"\"\"Mixture log-density log(sum_k w_k N(input; mu_k, Sigma_k)) over the last dim.\"\"\"\n",
    "        # Combine components in log space: exponentiating first underflows to 0\n",
    "        # far from both modes, which yields -inf log-probs and NaN gradients.\n",
    "        component_log_probs = torch.stack([g.log_prob(input) for g in self.gaussians], dim=-1)\n",
    "        return torch.logsumexp(component_log_probs + self.weights.log(), dim=-1)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot some samples from the ground-truth mixture\n",
    "\n",
    "MOG = MO2G()\n",
    "num_samples = 5000\n",
    "mixture_samples = MOG.sample(num_samples)\n",
    "plt.scatter(mixture_samples[:, 0], mixture_samples[:, 1], label=\"Mixture of Gaussians\")\n",
    "plt.legend()  # the scatter label above is only shown once a legend is drawn\n",
    "covar = torch.cov(mixture_samples.T)  # empirical covariance of the plotted samples\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create a Gaussian distribution that can be optimised\n",
    "\n",
    "class Gauss(nn.Module):\n",
    "    \"\"\"Trainable multivariate Gaussian N(mean, L L^T) with L = tril(scale_tril).\n",
    "\n",
    "    ``mean`` and ``scale_tril`` are ``nn.Parameter``s and ``sample`` uses\n",
    "    ``rsample``, so losses on the samples can be back-propagated to them.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, dim):\n",
    "        super().__init__()\n",
    "        self.dim = dim\n",
    "        self.mean = nn.Parameter(torch.randn(dim))\n",
    "        self.scale_tril = nn.Parameter(torch.eye(dim))\n",
    "\n",
    "    @property\n",
    "    def G(self):\n",
    "        \"\"\"MultivariateNormal built from the current parameter values.\n",
    "\n",
    "        Rebuilt on every access: a distribution created once would memoise\n",
    "        ``covariance_matrix`` on first use, so ``cov()`` would go stale as the\n",
    "        parameters are optimised.\n",
    "        \"\"\"\n",
    "        # tril() drops the upper triangle inside the graph (no in-place zeroing of\n",
    "        # the parameter). validate_args=False: the diagonal is unconstrained, so the\n",
    "        # factor is not re-validated on every step.\n",
    "        return torch.distributions.MultivariateNormal(\n",
    "            self.mean, scale_tril=torch.tril(self.scale_tril), validate_args=False\n",
    "        )\n",
    "\n",
    "    def sample(self, size):\n",
    "        \"\"\"Draw ``size`` reparameterised samples, shape ``(size, dim)``.\"\"\"\n",
    "        return self.G.rsample((size,))\n",
    "\n",
    "    def cov(self):\n",
    "        \"\"\"Current covariance matrix ``L @ L.T``, detached from the graph.\"\"\"\n",
    "        return self.G.covariance_matrix.detach()\n",
    "\n",
    "    def log_prob(self, value):\n",
    "        \"\"\"Log-density of ``value`` under the current parameters (differentiable).\"\"\"\n",
    "        return self.G.log_prob(value)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def c2st_target(samples1, samples2, density1, density2):\n",
    "    r\"\"\"Computes the optimal C2ST and the resulting classification cross-entropy\n",
    "    loss for optimization.\n",
    "\n",
    "    The classifier is built from the known densities: a point is assigned to\n",
    "    whichever density gives it the higher log-probability.\n",
    "\n",
    "    Args:\n",
    "        samples1: samples labelled 0 (e.g. drawn from ``density1``).\n",
    "        samples2: samples labelled 1 (e.g. drawn from ``density2``).\n",
    "        density1, density2: objects with a ``log_prob`` method returning one\n",
    "            log-density per sample.\n",
    "\n",
    "    Returns:\n",
    "        loss: negative cross-entropy of that classifier (differentiable);\n",
    "            minimising it makes the two sample sets harder to tell apart.\n",
    "        c2st: classification accuracy in [0, 1] as a 0-d tensor, computed\n",
    "            without gradients.\n",
    "    \"\"\"\n",
    "    # Column k holds log p_k(x); cross_entropy treats them as logits, so its\n",
    "    # softmax is p_k(x) / (p_1(x) + p_2(x)).\n",
    "    log_probs1 = torch.stack([density1.log_prob(samples1), density2.log_prob(samples1)], dim=-1)\n",
    "    log_probs2 = torch.stack([density1.log_prob(samples2), density2.log_prob(samples2)], dim=-1)\n",
    "    logits = torch.cat([log_probs1, log_probs2], dim=0)\n",
    "    labels = torch.cat([torch.zeros(len(samples1)), torch.ones(len(samples2))], dim=0).long()\n",
    "    loss = -nn.functional.cross_entropy(logits, labels)\n",
    "    with torch.no_grad():\n",
    "        # Tensor .sum() instead of builtin sum(), which iterates element by element.\n",
    "        correct = (log_probs1[:, 0] >= log_probs1[:, 1]).sum() + (log_probs2[:, 0] < log_probs2[:, 1]).sum()\n",
    "        c2st = correct / (len(samples1) + len(samples2))\n",
    "    return loss, c2st"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Gaussian fitted with the (negated) optimal-classifier cross-entropy loss\n",
    "gauss_model_C2ST = Gauss(dim=2)\n",
    "\n",
    "model_toy_opt = torch.optim.Adam(params=gauss_model_C2ST.parameters(), lr=1e-2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_iters = 2500\n",
    "num_samples = 10000\n",
    "\n",
    "for step in range(n_iters):\n",
    "    model_toy_opt.zero_grad()\n",
    "\n",
    "    # Fresh model and target samples on every step\n",
    "    model_samples = gauss_model_C2ST.sample(num_samples)\n",
    "    toy_samples = MOG.sample(num_samples)\n",
    "    cent_loss, true_c2st = c2st_target(model_samples, toy_samples, gauss_model_C2ST, MOG)\n",
    "\n",
    "    if not step % 5:\n",
    "        print(\"Iter: {}     loss: {}     c2st: {}\".format(step, cent_loss.item(), true_c2st.item()))\n",
    "    cent_loss.backward()\n",
    "    model_toy_opt.step()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# MMD: inspect the median-heuristic bandwidth for model vs. target samples\n",
    "gauss_model_MMD = Gauss(dim=2)\n",
    "with torch.no_grad():\n",
    "    samples = gauss_model_MMD.sample(num_samples)\n",
    "    target_samples = MOG.sample(num_samples)\n",
    "    # Use the freshly drawn model samples; the training cell below sets its own bandwidth.\n",
    "    bandwidth = median_heuristic(samples, target_samples)\n",
    "    print(\"bandwidth: \", bandwidth)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fit a Gaussian by minimising the RBF-kernel MMD to MOG samples.\n",
    "# The bandwidth is fixed to 1 (presumably what the \"MMD_1\" title in the final\n",
    "# figure refers to); this replaces the heuristic value from the previous cell.\n",
    "bandwidth = 1\n",
    "log_every = 5  # print progress every `log_every` iterations, not every one\n",
    "optimizer = torch.optim.Adam(gauss_model_MMD.parameters(), lr=0.01)\n",
    "gauss_model_MMD.train()\n",
    "\n",
    "for epoch in range(n_iters):\n",
    "    gauss_model_MMD.zero_grad()\n",
    "    samples = gauss_model_MMD.sample(num_samples)\n",
    "    target_samples = MOG.sample(num_samples)\n",
    "    loss = compute_rbf_mmd(samples, target_samples, bandwidth=bandwidth)\n",
    "    if epoch % log_every == 0:\n",
    "        print(\"Iter: {}     loss: {}\".format(epoch, loss.item()))\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "\n",
    "gauss_model_MMD.eval();"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# WS: fit a Gaussian by minimising the sliced Wasserstein distance to MOG samples\n",
    "\n",
    "log_every = 5  # print progress every `log_every` iterations, not every one\n",
    "gauss_model_WS = Gauss(dim=2)\n",
    "optimizer = torch.optim.Adam(gauss_model_WS.parameters(), lr=0.01)\n",
    "gauss_model_WS.train()\n",
    "for epoch in range(n_iters):\n",
    "    gauss_model_WS.zero_grad()\n",
    "    samples = gauss_model_WS.sample(num_samples)\n",
    "    target_samples = MOG.sample(num_samples)\n",
    "    loss = sliced_wasserstein_distance(samples, target_samples)\n",
    "    if epoch % log_every == 0:\n",
    "        print(\"Iter: {}     loss: {}\".format(epoch, loss.item()))\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "\n",
    "gauss_model_WS.eval();"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Function for plotting ellipses according to a covariance matrix\n",
    "\n",
    "\n",
    "#FROM: https://github.com/joferkington/oost_paper_code/blob/master/error_ellipse.py\n",
    "from matplotlib.patches import Ellipse\n",
    "\n",
    "def plot_cov_ellipse(cov, pos, nstd=(1, 2), ax=None, **kwargs):\n",
    "    \"\"\"\n",
    "    Plots one error ellipse per entry of `nstd`, based on the specified\n",
    "    covariance matrix (`cov`). Additional keyword arguments are passed on to\n",
    "    every ellipse patch artist.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "        cov : The 2x2 covariance matrix to base the ellipses on.\n",
    "        pos : The location of the center of the ellipses. Expects a 2-element\n",
    "            sequence of [x0, y0].\n",
    "        nstd : Radius (or iterable of radii) of the ellipses in numbers of\n",
    "            standard deviations. Defaults to (1, 2).\n",
    "        ax : The axis that the ellipses will be plotted on. Defaults to the\n",
    "            current axis.\n",
    "        Additional keyword arguments are passed on to each ellipse patch.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "        The last matplotlib Ellipse artist added (None if `nstd` is empty).\n",
    "    \"\"\"\n",
    "    if ax is None:\n",
    "        ax = plt.gca()\n",
    "\n",
    "    # eigh returns eigenvalues in ascending order; sort descending so the first\n",
    "    # eigenvector is the major axis.\n",
    "    vals, vecs = np.linalg.eigh(cov)\n",
    "    order = vals.argsort()[::-1]\n",
    "    vals, vecs = vals[order], vecs[:, order]\n",
    "    # Major-axis orientation in degrees, counter-clockwise from the x-axis.\n",
    "    theta = np.degrees(np.arctan2(vecs[1, 0], vecs[0, 0]))\n",
    "    # Clip tiny negative round-off eigenvalues so sqrt never returns NaN.\n",
    "    radii = np.sqrt(np.clip(vals, 0.0, None))\n",
    "\n",
    "    ellip = None\n",
    "    for std in np.atleast_1d(nstd):\n",
    "        # Width and height are \"full\" widths (diameters), not radii.\n",
    "        width, height = 2 * std * radii\n",
    "        ellip = Ellipse(xy=pos, width=width, height=height, angle=theta, **kwargs)\n",
    "        ax.add_artist(ellip)\n",
    "    return ellip"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build a truncated version of the 'gist_yarg' colormap for the KDE backgrounds\n",
    "\n",
    "# FROM: https://stackoverflow.com/questions/18926031/how-to-extract-a-subset-of-a-colormap-as-a-new-colormap-in-matplotlib\n",
    "import matplotlib.colors as colors\n",
    "\n",
    "\n",
    "def truncate_colormap(cmap, minval=0.0, maxval=1.0, n=100):\n",
    "    \"\"\"Return a colormap covering only the [minval, maxval] slice of `cmap`.\n",
    "\n",
    "    `n` evenly spaced colours are sampled from `cmap` over that interval and\n",
    "    linearly interpolated by the new LinearSegmentedColormap.\n",
    "    \"\"\"\n",
    "    label = 'trunc({n},{a:.2f},{b:.2f})'.format(n=cmap.name, a=minval, b=maxval)\n",
    "    sampled_colours = cmap(np.linspace(minval, maxval, n))\n",
    "    return colors.LinearSegmentedColormap.from_list(label, sampled_colours)\n",
    "\n",
    "# Preview: full colormap (left) next to the truncated one (right)\n",
    "arr = np.linspace(0, 50, 100).reshape((10, 10))\n",
    "fig, ax = plt.subplots(ncols=2)\n",
    "\n",
    "cmap = plt.get_cmap('gist_yarg')\n",
    "new_cmap = truncate_colormap(cmap, 0, 0.75)\n",
    "for panel, panel_cmap in zip(ax, (cmap, new_cmap)):\n",
    "    panel.imshow(arr, interpolation='nearest', cmap=panel_cmap)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Final figure: ground-truth density (left) and the Gaussian fitted with each\n",
    "# objective, drawn as covariance ellipses over a KDE of the same samples.\n",
    "mixture_df = pd.DataFrame({'x': mixture_samples[:, 0], 'y': mixture_samples[:, 1]})\n",
    "\n",
    "kde_levels = 12\n",
    "alpha_bg = 1\n",
    "std_plot = [.25, .75, 1.5, 2.5]  # ellipse radii, in standard deviations\n",
    "\n",
    "fig, axs = plt.subplots(1, 4, figsize=(6.5, 3))\n",
    "sns.kdeplot(ax=axs[0], data=mixture_df, x='x', y='y', cmap=\"Blues\", fill=True, levels=kde_levels)\n",
    "axs[0].set_title(r\"$p_{true}$\")\n",
    "\n",
    "# (axis, fitted model, colour, title) for each fitted Gaussian\n",
    "fitted_panels = [\n",
    "    (axs[1], gauss_model_WS, '#cc241d', \"SW\"),\n",
    "    (axs[3], gauss_model_MMD, '#eebd35', r\"$MMD_1$\"),\n",
    "    (axs[2], gauss_model_C2ST, '#458588', \"C2ST\"),\n",
    "]\n",
    "for panel_ax, model, color, title in fitted_panels:\n",
    "    sns.kdeplot(ax=panel_ax, data=mixture_df, x='x', y='y', fill=True, cmap=new_cmap, alpha=alpha_bg, levels=kde_levels)\n",
    "    plot_cov_ellipse(model.cov().numpy(), model.mean.detach().numpy(),\n",
    "                     nstd=std_plot, ax=panel_ax, edgecolor=color, lw=1.5, facecolor='none')\n",
    "    panel_ax.set_title(title, color=color)\n",
    "\n",
    "for ax in axs:\n",
    "    ax.spines[['left', 'bottom']].set_visible(False)\n",
    "    ax.set_xticks([])\n",
    "    ax.set_yticks([])\n",
    "    ax.set_xlabel(\"\")\n",
    "    ax.set_ylabel(\"\")\n",
    "    ax.set_xlim(-10, 10)\n",
    "    ax.set_ylim(-7, 4)\n",
    "    ax.set_box_aspect(1)  # square panels\n",
    "\n",
    "fig.tight_layout()\n",
    "fig.savefig(\"mode.pdf\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "labproject",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
