{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Third-party libraries\n",
    "import matplotlib.patches as mpatches\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import seaborn as sns\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch.distributions import Categorical, MultivariateNormal\n",
    "\n",
    "# Project-local metrics, data and utilities\n",
    "from labproject.data import get_distribution\n",
    "from labproject.metrics.MMD_torch import compute_rbf_mmd, median_heuristic\n",
    "from labproject.metrics.sliced_wasserstein import sliced_wasserstein_distance\n",
    "from labproject.utils import set_seed\n",
    "\n",
    "# Seed the random number generators via the project helper before any sampling\n",
    "set_seed(0)\n",
    "\n",
    "# Shared project figure style (relative path, resolved from the working directory)\n",
    "plt.style.use(\"../../../matplotlibrc\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "Create a mixture of 2 Gaussians that we will use as ground truth\n",
    "\"\"\"\n",
    "\n",
    "class MO2G:\n",
    "    \"\"\"Ground-truth mixture of two 2-D Gaussians.\n",
    "\n",
    "    Components are centred at (-3, 1) and (3, -1), both with identity\n",
    "    covariance, and have mixture weights 0.2 and 0.8.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        self.means = torch.tensor(\n",
    "            [\n",
    "                [-3.0, 1],\n",
    "                [3, -1],\n",
    "            ]\n",
    "        )\n",
    "        self.covariances = torch.tensor(\n",
    "            [\n",
    "                [[1.0, 0], [0, 1.0]],\n",
    "                [[1.0, 0], [0, 1.0]],\n",
    "            ]\n",
    "        )\n",
    "        self.weights = torch.tensor([0.2, 0.8])\n",
    "\n",
    "        self.gaussians = [\n",
    "            MultivariateNormal(mean, covariance)\n",
    "            for mean, covariance in zip(self.means, self.covariances)\n",
    "        ]\n",
    "\n",
    "    def sample(self, sample_shape):\n",
    "        \"\"\"Draw samples from the mixture.\n",
    "\n",
    "        Args:\n",
    "            sample_shape: int or tuple, the leading shape of the output.\n",
    "\n",
    "        Returns:\n",
    "            Tensor of shape (*sample_shape, 2).\n",
    "        \"\"\"\n",
    "        if isinstance(sample_shape, int):\n",
    "            sample_shape = (sample_shape,)\n",
    "        sample_shape = torch.Size(sample_shape)\n",
    "        categorical = Categorical(self.weights)\n",
    "        sample_indices = categorical.sample(sample_shape)\n",
    "        # Draw from every component in one batched call and keep the component\n",
    "        # chosen for each point, instead of one Python-level `.sample()` call per\n",
    "        # point (slow for the thousands of points drawn per training iteration).\n",
    "        all_components = torch.stack(\n",
    "            [g.sample(sample_shape) for g in self.gaussians], dim=-2\n",
    "        )  # (*sample_shape, n_components, dim)\n",
    "        index = sample_indices[..., None, None].expand(\n",
    "            *sample_shape, 1, all_components.shape[-1]\n",
    "        )\n",
    "        return all_components.gather(-2, index).squeeze(-2)\n",
    "\n",
    "    def log_prob(self, input):\n",
    "        \"\"\"Log-density of the mixture at `input` (shape (..., 2)).\n",
    "\n",
    "        Combines the components with logsumexp, so points far from both modes\n",
    "        get a finite value instead of log(0) = -inf.\n",
    "        \"\"\"\n",
    "        component_log_probs = torch.stack(\n",
    "            [g.log_prob(input) for g in self.gaussians], dim=-1\n",
    "        )\n",
    "        return torch.logsumexp(component_log_probs + self.weights.log(), dim=-1)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot some samples from the ground-truth mixture\n",
    "\n",
    "MOG = MO2G()\n",
    "num_samples = 10000\n",
    "mixture_samples = MOG.sample(num_samples)\n",
    "plt.scatter(mixture_samples[:, 0], mixture_samples[:, 1], label=\"Mixture of Gaussians\")\n",
    "# The scatter label is only displayed once a legend is drawn\n",
    "plt.legend()\n",
    "# Empirical covariance of the samples, kept for inspection\n",
    "covar = torch.cov(mixture_samples.T)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create a Gaussian distribution that can be used as a component of the mixture we optimise\n",
    "\n",
    "class Gauss(nn.Module):\n",
    "    \"\"\"Trainable multivariate Gaussian with a learnable mean and scale matrix.\n",
    "\n",
    "    `self.G` is a `MultivariateNormal` built once from the `mean` and `scale_tril`\n",
    "    parameters. It holds references to those tensors rather than copies, which is\n",
    "    what lets optimizer updates (and the in-place projection below) reach it.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, dim):\n",
    "        super(Gauss, self).__init__()\n",
    "        self.dim = dim\n",
    "        self.mean = nn.Parameter(torch.randn(dim))\n",
    "        self.scale_tril = nn.Parameter(torch.eye(dim))\n",
    "        self.G = torch.distributions.MultivariateNormal(self.mean, scale_tril=self.scale_tril)\n",
    "\n",
    "    def _project_scale_tril(self):\n",
    "        \"\"\"Zero the strictly-upper triangle of the scale matrix in place.\n",
    "\n",
    "        Gradient steps do not keep it lower-triangular, so this runs before every use.\n",
    "        \"\"\"\n",
    "        with torch.no_grad():\n",
    "            scale = self.G._unbroadcasted_scale_tril\n",
    "            scale.data.copy_(torch.tril(scale))\n",
    "\n",
    "    def sample(self, size):\n",
    "        \"\"\"Draw `size` reparameterised (differentiable) samples, shape (size, dim).\"\"\"\n",
    "        self._project_scale_tril()\n",
    "        return self.G.rsample((size,))\n",
    "\n",
    "    def cov(self):\n",
    "        \"\"\"Return the current covariance matrix scale @ scale.T as a detached tensor.\n",
    "\n",
    "        Recomputed on every call: `MultivariateNormal.covariance_matrix` is cached on\n",
    "        first access, so reading it would return a stale value after training steps.\n",
    "        \"\"\"\n",
    "        self._project_scale_tril()\n",
    "        scale = self.G._unbroadcasted_scale_tril.detach()\n",
    "        return scale @ scale.transpose(-1, -2)\n",
    "\n",
    "    def log_prob(self, value):\n",
    "        \"\"\"Log-density of `value` under the current parameters.\"\"\"\n",
    "        self._project_scale_tril()\n",
    "        return self.G.log_prob(value)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create a mixture of Gaussians that we will fit to the ground truth distribution\n",
    "\n",
    "class MOGOpt(nn.Module):\n",
    "    \"\"\"The model to optimise: a two-component mixture of trainable `Gauss` components.\n",
    "\n",
    "    The mixture weights are a parameter with `requires_grad=False`; they are only\n",
    "    changed by `e_step(..., update_weights=True)`.\n",
    "    \"\"\"\n",
    "    def __init__(self, dim):\n",
    "        super(MOGOpt, self).__init__()\n",
    "        self.G1 = Gauss(dim)\n",
    "        self.G2 = Gauss(dim)\n",
    "        self.weights = nn.Parameter(torch.tensor([0.5, 0.5]), requires_grad=False)\n",
    "        self.dim = dim\n",
    "\n",
    "    @property\n",
    "    def categorical(self):\n",
    "        \"\"\"Component-selection distribution built from the *current* weights.\n",
    "\n",
    "        Rebuilt on each access: `Categorical` normalises its probs into a new tensor,\n",
    "        so an instance created once in `__init__` would ignore weights updated by `e_step`.\n",
    "        \"\"\"\n",
    "        return Categorical(probs=self.weights.detach(), validate_args=False)\n",
    "\n",
    "    def _project_scale_trils(self):\n",
    "        \"\"\"Zero the strictly-upper triangle of both components' scale matrices in place.\"\"\"\n",
    "        with torch.no_grad():\n",
    "            for gauss in (self.G1, self.G2):\n",
    "                scale = gauss.G._unbroadcasted_scale_tril\n",
    "                scale.data.copy_(torch.tril(scale))\n",
    "\n",
    "    def sample(self, size):\n",
    "        \"\"\"Draw `size` reparameterised samples from the mixture, shape (size, dim).\"\"\"\n",
    "        self._project_scale_trils()\n",
    "        inds = self.categorical.sample((size,))\n",
    "        samples = torch.zeros(size, self.dim)\n",
    "        for component_idx, gauss in enumerate((self.G1, self.G2)):\n",
    "            mask = inds == component_idx\n",
    "            samples[mask] = gauss.G.rsample((int(mask.sum()),))\n",
    "        return samples\n",
    "\n",
    "    def cov(self):\n",
    "        \"\"\"Return the detached covariance matrices of both components as [cov_G1, cov_G2].\"\"\"\n",
    "        # `Gauss` exposes its covariance through `cov()`; it has no `covariance_matrix` attribute.\n",
    "        return [self.G1.cov(), self.G2.cov()]\n",
    "\n",
    "    def log_prob(self, input, labels=None):\n",
    "        \"\"\"Mixture log-density at `input`; `labels` is accepted but unused.\n",
    "\n",
    "        Uses logsumexp over log(weight) + component log-density, so points far from\n",
    "        both components get a finite value rather than log(0) = -inf.\n",
    "        \"\"\"\n",
    "        self._project_scale_trils()\n",
    "        component_log_probs = torch.stack([self.G1.log_prob(input), self.G2.log_prob(input)], dim=-1)\n",
    "        return torch.logsumexp(component_log_probs + self.weights.log(), dim=-1)\n",
    "\n",
    "    # Called form `no_grad()` works as a decorator on every torch version; the bare form does not.\n",
    "    @torch.no_grad()\n",
    "    def e_step(self, input, update_weights=False):\n",
    "        \"\"\"Sample a hard component assignment for each row of `input`.\n",
    "\n",
    "        The responsibilities p(component | x) are computed under the current model and\n",
    "        one assignment is drawn per point. Returns a bool tensor: False selects G1 and\n",
    "        True selects G2, so `clusters == 0` / `clusters == 1` pick the two components.\n",
    "        If `update_weights`, the mixture weights become the mean responsibilities.\n",
    "        \"\"\"\n",
    "        log_joint = torch.stack(\n",
    "            [\n",
    "                self.G1.log_prob(input) + self.weights[0].log(),\n",
    "                self.G2.log_prob(input) + self.weights[1].log(),\n",
    "            ],\n",
    "            dim=-1,\n",
    "        )\n",
    "        # Normalising in log space gives the same responsibilities as dividing the\n",
    "        # densities by their sum, without 0 / 0 = NaN for points far from both components.\n",
    "        probs = torch.softmax(log_joint, dim=-1)\n",
    "        randfloat = torch.rand(probs.shape[0])\n",
    "        clusters = randfloat > probs[:, 0]\n",
    "        if update_weights:\n",
    "            self.weights.data.copy_(probs.mean(dim=0))\n",
    "        return clusters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def c2st_target(samples1, samples2, density1, density2):\n",
    "    r\"\"\"Computes optimal C2ST and resulting classification cross-entropy loss\n",
    "    for optimization.\n",
    "\n",
    "    The classifier uses the two log-densities as logits: label 0 for `density1`,\n",
    "    label 1 for `density2` (Bayes-optimal when the two sample sets are balanced\n",
    "    and drawn from those densities).\n",
    "\n",
    "    Args:\n",
    "        samples1: samples with label 0.\n",
    "        samples2: samples with label 1.\n",
    "        density1, density2: objects whose `log_prob(samples)` returns one value per sample.\n",
    "\n",
    "    Returns:\n",
    "        loss: negative mean cross-entropy of that classifier; minimising it maximises the\n",
    "            cross-entropy (intended to make the two sample sets harder to tell apart).\n",
    "        c2st: classification accuracy (0.5 = chance level), without gradient.\n",
    "    \"\"\"\n",
    "    logits1 = torch.stack([density1.log_prob(samples1), density2.log_prob(samples1)], dim=-1)\n",
    "    logits2 = torch.stack([density1.log_prob(samples2), density2.log_prob(samples2)], dim=-1)\n",
    "    logits = torch.cat([logits1, logits2], dim=0)\n",
    "    labels = torch.cat([torch.zeros(len(samples1)), torch.ones(len(samples2))], dim=0).long()\n",
    "    loss = -nn.functional.cross_entropy(logits, labels)\n",
    "    with torch.no_grad():\n",
    "        # Vectorised tensor reductions; Python's builtin `sum` iterated over the\n",
    "        # tensors element by element.\n",
    "        n_correct = (logits1[:, 0] >= logits1[:, 1]).sum() + (logits2[:, 0] < logits2[:, 1]).sum()\n",
    "        c2st = n_correct / (len(samples1) + len(samples2))\n",
    "    return loss, c2st"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mixture model fitted with the C2ST objective, and its optimiser\n",
    "MOG_model_C2ST = MOGOpt(dim=2)\n",
    "optimizer = torch.optim.Adam(params=MOG_model_C2ST.parameters(), lr=0.01)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Samples from the untrained C2ST model\n",
    "C2ST_samples = MOG_model_C2ST.sample(num_samples).detach().numpy()\n",
    "plt.scatter(C2ST_samples[:, 0], C2ST_samples[:, 1])\n",
    "plt.title(\"C2ST model before training\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training budget shared by the C2ST, MMD and SW fits below\n",
    "n_iters, num_samples = 2000, 5000"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fit the C2ST model: each iteration draws ground-truth samples, assigns them to the\n",
    "# model's components with a stochastic E-step, and takes a gradient step on the\n",
    "# per-component C2ST loss.\n",
    "log_every = 5\n",
    "\n",
    "for epoch in range(n_iters):\n",
    "    optimizer.zero_grad()\n",
    "\n",
    "    toy_samples = MOG.sample(num_samples)\n",
    "    toy_clusters = MOG_model_C2ST.e_step(toy_samples, update_weights=True)\n",
    "\n",
    "    loss = 0.\n",
    "    c2st = 0.\n",
    "    n_used = 0\n",
    "    for cluster_idx, model in enumerate((MOG_model_C2ST.G1, MOG_model_C2ST.G2)):\n",
    "        toy_cluster_samples = toy_samples[toy_clusters == cluster_idx]\n",
    "        toy_n_samples = len(toy_cluster_samples)\n",
    "        if toy_n_samples == 0:\n",
    "            # An empty cluster would make the cross-entropy a mean over zero\n",
    "            # elements (NaN), which would then propagate into the parameters.\n",
    "            continue\n",
    "        model_cluster_samples = model.sample(toy_n_samples)\n",
    "        cent_loss, true_c2st = c2st_target(model_cluster_samples, toy_cluster_samples, model, MOG.gaussians[cluster_idx])\n",
    "        loss += cent_loss\n",
    "        c2st += true_c2st\n",
    "        n_used += 1\n",
    "\n",
    "    if epoch % log_every == 0:\n",
    "        # Report the total loss and the C2ST averaged over the clusters used,\n",
    "        # rather than only the values of the last cluster in the loop.\n",
    "        print(\"Iter: {}     loss: {}     c2st: {}\".format(epoch, float(loss), float(c2st) / max(n_used, 1)))\n",
    "    if n_used > 0:\n",
    "        loss.backward()\n",
    "        optimizer.step()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mixture model fitted with the MMD objective\n",
    "MOG_model_MMD = MOGOpt(2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fit the MMD model with the same E-step / gradient-step scheme, using an RBF-kernel\n",
    "# MMD between model samples and the ground-truth samples of each cluster.\n",
    "bandwidth = 1\n",
    "log_every = 5\n",
    "optimizer = torch.optim.Adam(MOG_model_MMD.parameters(), lr=0.01)\n",
    "MOG_model_MMD.train()\n",
    "\n",
    "for epoch in range(n_iters):\n",
    "    optimizer.zero_grad()\n",
    "\n",
    "    toy_samples = MOG.sample(num_samples)\n",
    "    toy_clusters = MOG_model_MMD.e_step(toy_samples, update_weights=True)\n",
    "\n",
    "    loss = 0.\n",
    "    for cluster_idx, model in enumerate((MOG_model_MMD.G1, MOG_model_MMD.G2)):\n",
    "        toy_cluster_samples = toy_samples[toy_clusters == cluster_idx]\n",
    "        toy_n_samples = len(toy_cluster_samples)\n",
    "        if toy_n_samples == 0:\n",
    "            # Nothing assigned to this cluster: there is no sample set to compare against.\n",
    "            continue\n",
    "        model_cluster_samples = model.sample(toy_n_samples)\n",
    "        cent_loss = compute_rbf_mmd(model_cluster_samples, toy_cluster_samples, bandwidth)\n",
    "        loss += cent_loss\n",
    "\n",
    "    if epoch % log_every == 0:\n",
    "        # Total loss over both clusters; printing every iteration floods the output.\n",
    "        print(\"Iter: {}     loss: {}     \".format(epoch, float(loss)))\n",
    "    if torch.is_tensor(loss):\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "MOG_model_MMD.eval();"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# WS: fit a third model with the sliced Wasserstein distance as the per-cluster loss\n",
    "\n",
    "MOG_model_WS = MOGOpt(dim=2)\n",
    "optimizer = torch.optim.Adam(MOG_model_WS.parameters(), lr=0.01)\n",
    "MOG_model_WS.train()\n",
    "log_every = 5\n",
    "\n",
    "for epoch in range(n_iters):\n",
    "    optimizer.zero_grad()\n",
    "\n",
    "    toy_samples = MOG.sample(num_samples)\n",
    "    toy_clusters = MOG_model_WS.e_step(toy_samples, update_weights=True)\n",
    "\n",
    "    loss = 0.\n",
    "    for cluster_idx, model in enumerate((MOG_model_WS.G1, MOG_model_WS.G2)):\n",
    "        toy_cluster_samples = toy_samples[toy_clusters == cluster_idx]\n",
    "        toy_n_samples = len(toy_cluster_samples)\n",
    "        if toy_n_samples == 0:\n",
    "            # Nothing assigned to this cluster: there is no sample set to compare against.\n",
    "            continue\n",
    "        model_cluster_samples = model.sample(toy_n_samples)\n",
    "        cent_loss = sliced_wasserstein_distance(model_cluster_samples, toy_cluster_samples)\n",
    "        loss += cent_loss\n",
    "\n",
    "    if epoch % log_every == 0:\n",
    "        # Total loss over both clusters; printing every iteration floods the output.\n",
    "        print(\"Iter: {}     loss: {}     \".format(epoch, float(loss)))\n",
    "    if torch.is_tensor(loss):\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "MOG_model_WS.eval();"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Function for plotting ellipses according to a covariance matrix\n",
    "\n",
    "# FROM: https://github.com/joferkington/oost_paper_code/blob/master/error_ellipse.py\n",
    "from matplotlib.patches import Ellipse\n",
    "\n",
    "def plot_cov_ellipse(cov, pos, nstd=(1, 2), ax=None, **kwargs):\n",
    "    \"\"\"\n",
    "    Plots one error ellipse per entry of `nstd` based on the specified covariance\n",
    "    matrix (`cov`). Additional keyword arguments are passed on to the\n",
    "    ellipse patch artist.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "        cov : The 2x2 covariance matrix to base the ellipse on\n",
    "        pos : The location of the center of the ellipse. Expects a 2-element\n",
    "            sequence of [x0, y0].\n",
    "        nstd : Radius of the ellipse(s) in numbers of standard deviations; a\n",
    "            scalar or a sequence of radii. Defaults to (1, 2), i.e. one ellipse\n",
    "            at 1 and one at 2 standard deviations.\n",
    "        ax : The axis that the ellipse will be plotted on. Defaults to the\n",
    "            current axis.\n",
    "        Additional keyword arguments are passed on to the ellipse patch.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "        The last matplotlib ellipse artist added (None if `nstd` is empty)\n",
    "    \"\"\"\n",
    "    def eigsorted(cov):\n",
    "        # Eigen-decomposition sorted by descending eigenvalue, so the first\n",
    "        # eigenvector points along the major axis.\n",
    "        vals, vecs = np.linalg.eigh(cov)\n",
    "        order = vals.argsort()[::-1]\n",
    "        return vals[order], vecs[:, order]\n",
    "\n",
    "    if ax is None:\n",
    "        ax = plt.gca()\n",
    "\n",
    "    vals, vecs = eigsorted(cov)\n",
    "    # Angle of the major axis from the x-axis, in degrees\n",
    "    theta = np.degrees(np.arctan2(*vecs[:, 0][::-1]))\n",
    "    ellip = None\n",
    "    for std in np.atleast_1d(nstd):\n",
    "        # Width and height are \"full\" widths, not radius\n",
    "        width, height = 2 * std * np.sqrt(vals)\n",
    "        ellip = Ellipse(xy=pos, width=width, height=height, angle=theta, **kwargs)\n",
    "        ax.add_artist(ellip)\n",
    "    return ellip"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Make a grey colormap\n",
    "\n",
    "# FROM: https://stackoverflow.com/questions/18926031/how-to-extract-a-subset-of-a-colormap-as-a-new-colormap-in-matplotlib\n",
    "import matplotlib.colors as colors\n",
    "\n",
    "\n",
    "def truncate_colormap(cmap, minval=0.0, maxval=1.0, n=100):\n",
    "    \"\"\"Build a colormap covering only the [minval, maxval] slice of `cmap`.\n",
    "\n",
    "    The slice is resampled at `n` evenly spaced points; the new map's name\n",
    "    records the source map and the two bounds.\n",
    "    \"\"\"\n",
    "    label = 'trunc({n},{a:.2f},{b:.2f})'.format(n=cmap.name, a=minval, b=maxval)\n",
    "    sampled_colors = cmap(np.linspace(minval, maxval, n))\n",
    "    return colors.LinearSegmentedColormap.from_list(label, sampled_colors)\n",
    "\n",
    "# Side-by-side preview: full colormap vs. its truncated version\n",
    "arr = np.linspace(0, 50, 100).reshape((10, 10))\n",
    "fig, ax = plt.subplots(ncols=2)\n",
    "\n",
    "cmap = plt.get_cmap('gist_yarg')\n",
    "new_cmap = truncate_colormap(cmap, 0, 0.75)\n",
    "ax[0].imshow(arr, interpolation='nearest', cmap=cmap)\n",
    "ax[1].imshow(arr, interpolation='nearest', cmap=new_cmap)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compare the fitted mixtures (ellipses) with the ground-truth density (KDE)\n",
    "\n",
    "# Create a dataframe for the ground-truth mixture samples\n",
    "mixture_df = pd.DataFrame({'x': mixture_samples[:, 0], 'y': mixture_samples[:, 1]})\n",
    "\n",
    "# Plot the probability contours: coloured in the first panel, grey behind each fitted model\n",
    "fig, axs = plt.subplots(1, 4, figsize=(6.5, 3))\n",
    "sns.kdeplot(ax=axs[0], data=mixture_df, x='x', y='y', cmap=\"Blues\", fill=True, levels=12)\n",
    "\n",
    "alpha_bg = 1\n",
    "for ax in axs[1:]:\n",
    "    sns.kdeplot(ax=ax, data=mixture_df, x='x', y='y', fill=True, cmap=new_cmap, alpha=alpha_bg, levels=12)\n",
    "\n",
    "# Radii (in standard deviations) of the ellipses drawn for every component\n",
    "std_plot = [.25, .75, 1.5, 2.5]\n",
    "\n",
    "# (fitted model, panel, colour) for each objective\n",
    "fitted_panels = [\n",
    "    (MOG_model_WS, axs[1], '#cc241d'),\n",
    "    (MOG_model_C2ST, axs[2], '#458588'),\n",
    "    (MOG_model_MMD, axs[3], '#eebd35'),\n",
    "]\n",
    "for fitted, panel, color in fitted_panels:\n",
    "    weights = fitted.weights.detach().numpy()\n",
    "    for weight, component in zip(weights, (fitted.G1, fitted.G2)):\n",
    "        # Covariance is scaled by the mixture weight, so ellipse size also reflects the weight\n",
    "        plot_cov_ellipse(weight * component.cov().detach().numpy(), component.mean.detach().numpy(),\n",
    "                         nstd=std_plot, ax=panel, edgecolor=color, lw=1.5, facecolor='none')\n",
    "\n",
    "for ax in axs:\n",
    "    ax.spines[['left', 'bottom']].set_visible(False)\n",
    "    ax.set_xticks([])\n",
    "    ax.set_yticks([])\n",
    "    ax.set_xlabel(\"\")\n",
    "    ax.set_ylabel(\"\")\n",
    "    ax.set_xlim(-10, 10)\n",
    "    ax.set_ylim(-7, 4)\n",
    "    # make square subplots\n",
    "    ax.set_box_aspect(1)\n",
    "axs[0].set_title(r\"$p_{true}$\")\n",
    "axs[1].set_title(\"SW\", color='#cc241d')\n",
    "axs[2].set_title(\"C2ST\", color=\"#458588\")\n",
    "axs[3].set_title(r\"$MMD_1$\", color='#eebd35')\n",
    "\n",
    "fig.tight_layout()\n",
    "fig.savefig(\"mode.pdf\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "labproject",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
