{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "6c93d7f8-5a2d-4ae1-b3e2-c781a306b8ae",
   "metadata": {},
   "source": [
    "### The following code can be used to reconstruct figure 4 (top row) in the paper."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "461d90a5-e78a-4761-bd01-3cb568dc36aa",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys,os"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0c5d8aaf",
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "from jax import numpy as jnp\n",
    "import numpy as np\n",
    "import distrax\n",
    "import haiku as hk\n",
    "from ima.residual import TriangularResidual, ConstantScaling\n",
    "from ima.utils import get_config\n",
    "\n",
    "from ima.plotting import cart2pol \n",
    "\n",
    "from ima.mixing_functions import build_moebius_transform, build_automorphism"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5f90410d-e23f-4064-a8fd-defc935265d5",
   "metadata": {},
   "source": [
    "### Get train/test data and parameters of the Möbius transformation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6c72873a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read the YAML configuration; its 'model' entries (flow_layers, nn_layers,\n",
    "# nn_hidden_units) are used further down to rebuild the flow architectures.\n",
    "config_darmois = get_config('config/darmois_2d_moeb.yaml')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7deba11b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sources (S_*) and observations (X_*), loaded in a fixed order and converted to JAX arrays.\n",
    "S_train, S_test, X_train, X_test = (\n",
    "    jnp.array(jnp.load(npy_path))\n",
    "    for npy_path in (\n",
    "        'data/figure_4_top_sources_train.npy',\n",
    "        'data/figure_4_top_sources_test.npy',\n",
    "        'data/figure_4_top_observation_train.npy',\n",
    "        'data/figure_4_top_observation_test.npy',\n",
    "    )\n",
    ")\n",
    "\n",
    "# The two files below hold pickled Python objects (hence allow_pickle=True and .item()).\n",
    "# Unpickling can execute arbitrary code, so only load files from a trusted source.\n",
    "mean_std = jnp.load('data/figure_4_top_observation_mean_std.npy', allow_pickle=True).item()\n",
    "mean_train = mean_std['mean']\n",
    "std_train = mean_std['std']\n",
    "moeb_params = jnp.load('data/figure_4_top_moebius_transform_params.npy', allow_pickle=True).item()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c12887ff-4e3a-468f-ad15-4becef0c6ec3",
   "metadata": {},
   "source": [
    "### Specify how many datapoints to use for the plots and re-define the train/test splits accordingly"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a905f688-5f6e-4d7c-b58b-cda876735a02",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of leading samples kept from every array for the plots.\n",
    "howmany = 5000\n",
    "\n",
    "# The same slice is applied to all four arrays.\n",
    "S_train, S_test = S_train[:howmany], S_test[:howmany]\n",
    "X_train, X_test = X_train[:howmany], X_test[:howmany]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a6af0ce5-43c7-49ef-bb46-3d4a7b65acb4",
   "metadata": {},
   "source": [
    "### Define the true mixing and unmixing, based on the parameters."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "97a329bc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Arguments for build_moebius_transform: A and a come from the loaded parameter\n",
    "# dictionary, alpha and b are fixed here.\n",
    "alpha = 1.0\n",
    "b = jnp.zeros(2)\n",
    "A, a = jnp.array(moeb_params['A']), jnp.array(moeb_params['a'])\n",
    "\n",
    "# The builder returns two functions, bound here as the mixing and its inverse;\n",
    "# jax.vmap maps the mixing over the leading (sample) axis.\n",
    "mixing_moebius, mixing_moebius_inv = build_moebius_transform(\n",
    "    alpha, A, a, b, epsilon=2\n",
    ")\n",
    "mixing_batched = jax.vmap(mixing_moebius)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b3e3f24d-8e0f-42f9-8efc-739102a10e77",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "\n",
    "def scatterplot_variables(X, title, colors='None', cmap='hsv'):\n",
    "    '''\n",
    "    Scatterplot of 2d variables, can be used both for the mixing and the unmixing\n",
    "    X : (N,D) array -- N samples, D dimensions (D=2).ss\n",
    "    '''\n",
    "    if colors=='None':\n",
    "        plt.scatter(X[:,0], X[:,1], color='r', s=30)\n",
    "    else:\n",
    "        plt.scatter(X[:,0], X[:,1], c=colors, s=30, alpha=0.75, cmap=cmap)\n",
    "\n",
    "    plt.axis('off')\n",
    "    plt.gca().set_aspect('equal', adjustable='box')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1522058e-f7e9-4be9-ad3b-1b1b16fa6b87",
   "metadata": {},
   "source": [
    "### Path where figures should be saved."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "be0d2b0c-254e-41c0-98f1-d6331e55fbc8",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Output directory for the optional savefig snippets (each joins it with a file name).\n",
    "figure_path = \"Plots_IMA\"#/perceptually_uniform\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9a440489-e20b-4d65-a55b-ca1b69ebb029",
   "metadata": {},
   "source": [
    "### Colormap for the plots."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f3035866-9b9b-430e-9c10-4eeca5ec5d83",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Colormap forwarded to scatterplot_variables in the plotting cells.\n",
    "cmap = 'hsv'\n",
    "\n",
    "# Uncomment for perceptually uniform colormaps\n",
    "# import cmocean\n",
    "# cmap = cmocean.cm.phase \n",
    "\n",
    "# Colour every sample by the second output of cart2pol applied to its ground-truth\n",
    "# source coordinates (presumably the polar angle -- confirm in ima.plotting).\n",
    "_, colors_train = cart2pol(S_train[:, 0], S_train[:, 1])\n",
    "_, colors_test = cart2pol(S_test[:, 0], S_test[:, 1])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "80b5049d-761d-4dcf-bde2-e4422d1eb417",
   "metadata": {},
   "source": [
    "### 1. Plot ground truth sources"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9444b002-e530-4b22-a292-c02bf92c57eb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Test sources, coloured with colors_test (computed from S_test).\n",
    "scatterplot_variables(S_test, 'Sources (test)', colors=colors_test, cmap=cmap)\n",
    "plt.title('Ground truth', fontsize=19)\n",
    "\n",
    "# Optional: uncomment to write the figure to disk.\n",
    "# (papertype/frameon are omitted: newer Matplotlib versions no longer accept them.)\n",
    "# fname = os.path.join(figure_path, \"true_sources.pdf\")\n",
    "# plt.savefig(fname, dpi=None, facecolor='w', edgecolor='w',\n",
    "#             orientation='portrait', format=None,\n",
    "#             transparent=False, bbox_inches='tight', pad_inches=0.0,\n",
    "#             metadata=None)\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "25f5c6a2-3f0b-43c0-bcee-b0052001ebba",
   "metadata": {},
   "source": [
    "### 2. Plot the observations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3346dc69-c229-4e64-b7e6-2fb0ca2b985b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Test observations, coloured with colors_test (computed from S_test).\n",
    "scatterplot_variables(X_test, 'Observations (test)', colors=colors_test, cmap=cmap)\n",
    "plt.title('Observations', fontsize=19)\n",
    "\n",
    "# Optional: uncomment to write the figure to disk.\n",
    "# (papertype/frameon are omitted: newer Matplotlib versions no longer accept them.)\n",
    "# fname = os.path.join(figure_path, \"observations.pdf\")\n",
    "# plt.savefig(fname, dpi=None, facecolor='w', edgecolor='w',\n",
    "#             orientation='portrait', format=None,\n",
    "#             transparent=False, bbox_inches='tight', pad_inches=0.0,\n",
    "#             metadata=None)\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "26c91a0a-e7c3-4450-96f4-a56b225acece",
   "metadata": {},
   "source": [
    "### Load parameters of the trained Darmois construction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9c2eb05b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The checkpoint is a pickled parameter dictionary (allow_pickle=True); load trusted files only.\n",
    "params_darmois = hk.data_structures.to_immutable_dict(\n",
    "    jnp.load('models/figure_4_top_darmois.npy', allow_pickle=True).item()\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c298edd8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Architecture hyper-parameters taken from the configuration.\n",
    "n_layers_darmois = config_darmois['model']['flow_layers']\n",
    "hidden_units_darmois = [config_darmois['model']['nn_hidden_units']] * config_darmois['model']['nn_layers']\n",
    "\n",
    "def _flow_darmois():\n",
    "    '''Chain of n_layers_darmois TriangularResidual layers followed by ConstantScaling(std_train).'''\n",
    "    layers = []\n",
    "    for i in range(n_layers_darmois):\n",
    "        layers.append(TriangularResidual(hidden_units_darmois + [2], name='residual_' + str(i)))\n",
    "    layers.append(ConstantScaling(std_train))\n",
    "    return distrax.Chain(layers)\n",
    "\n",
    "def inv_map_fn_darmois(x):\n",
    "    '''Apply the inverse of the flow chain to x.'''\n",
    "    return _flow_darmois().inverse(x)\n",
    "\n",
    "def fw_map_fn_darmois(x):\n",
    "    '''Apply the forward direction of the flow chain to x.'''\n",
    "    return _flow_darmois().forward(x)\n",
    "\n",
    "# hk.transform turns both functions into (init, apply) pairs that take explicit parameters.\n",
    "fw_map_darmois = hk.transform(fw_map_fn_darmois)\n",
    "inv_map_darmois = hk.transform(inv_map_fn_darmois)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a6911052-f3eb-46a7-b652-7fa0aa0329ed",
   "metadata": {},
   "source": [
    "### 3. Plot Darmois construction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "60a51f39",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run the inverse flow on the test observations (None is passed as the RNG argument).\n",
    "S_rec_darmois = inv_map_darmois.apply(params_darmois, None, X_test)\n",
    "\n",
    "# Pass each of the two recovered coordinates through the normal CDF, then subtract 0.5.\n",
    "S_rec_uni_darmois = jnp.column_stack(\n",
    "    [jax.scipy.stats.norm.cdf(S_rec_darmois[:, dim]) for dim in (0, 1)]\n",
    ") - 0.5\n",
    "\n",
    "scatterplot_variables(S_rec_uni_darmois, 'Reconstructed sources (test)', colors=colors_test, cmap=cmap)\n",
    "plt.title('Darmois', fontsize=19)\n",
    "\n",
    "# Optional: uncomment to write the figure to disk.\n",
    "# (papertype/frameon are omitted: newer Matplotlib versions no longer accept them.)\n",
    "# fname = os.path.join(figure_path, \"darmois.pdf\")\n",
    "# plt.savefig(fname, dpi=None, facecolor='w', edgecolor='w',\n",
    "#             orientation='portrait', format=None,\n",
    "#             transparent=False, bbox_inches='tight', pad_inches=0.0,\n",
    "#             metadata=None)\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "36223c51-8b04-466e-97eb-473928418f07",
   "metadata": {},
   "source": [
    "### 4. Plot effects of a \"Rotated-Gaussian\" measure-preserving automorphism, with $\\pi/4$ rotation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "392a1276",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 2x2 rotation matrix for an angle of 45 degrees.\n",
    "theta = np.radians(45)\n",
    "c = np.cos(theta)\n",
    "s = np.sin(theta)\n",
    "R = np.array([\n",
    "    [c, -s],\n",
    "    [s, c],\n",
    "])\n",
    "\n",
    "# build_automorphism returns two functions; the first one is vmapped to act on batches.\n",
    "measure_preserving, measure_preserving_inv = build_automorphism(R)\n",
    "measure_preserving_batched = jax.vmap(measure_preserving)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f7e0bc0b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Shift the test sources by 0.5 and apply the batched automorphism.\n",
    "S_ = measure_preserving_batched(S_test + 0.5)\n",
    "\n",
    "scatterplot_variables(S_, 'Mapped sources (test)',\n",
    "                      colors=colors_test, cmap=cmap)\n",
    "\n",
    "# Raw string: '\\p' is not a valid Python escape sequence, so a plain literal\n",
    "# triggers an invalid-escape warning. The rendered title text is unchanged.\n",
    "plt.title(r'MPA $\\pi/4$', fontsize=19)\n",
    "\n",
    "# Uncomment if you want to save the figure\n",
    "# (papertype/frameon are omitted: newer Matplotlib versions no longer accept them.)\n",
    "\n",
    "# fname = os.path.join(figure_path, \"mpa.pdf\")\n",
    "# plt.savefig(fname, dpi=None, facecolor='w', edgecolor='w',\n",
    "#             orientation='portrait', format=None,\n",
    "#             transparent=False, bbox_inches='tight', pad_inches=0.0,\n",
    "#             metadata=None)\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "66d0b314-1baa-4c1a-ae97-0efcd03571e1",
   "metadata": {},
   "source": [
    "### 5. Plot effects of the composition of the Darmois construction with a \"Rotated-Gaussian\" measure-preserving automorphism, with $\\pi/4$ rotation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fe6d603b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Undo the -0.5 shift of the Darmois reconstruction and apply the batched automorphism.\n",
    "S_rec_uni_ = measure_preserving_batched(S_rec_uni_darmois + 0.5)\n",
    "\n",
    "scatterplot_variables(S_rec_uni_, 'Reconstructed sources (test)',\n",
    "                      colors=colors_test, cmap=cmap)\n",
    "\n",
    "# Raw string: '\\p' is not a valid Python escape sequence, so a plain literal\n",
    "# triggers an invalid-escape warning. The rendered title text is unchanged.\n",
    "plt.title(r'Darmois + MPA $\\pi/4$', fontsize=19)\n",
    "\n",
    "# Uncomment if you want to save the figure\n",
    "# (papertype/frameon are omitted: newer Matplotlib versions no longer accept them.)\n",
    "\n",
    "# fname = os.path.join(figure_path, \"mpa_composed_darmois.pdf\")\n",
    "# plt.savefig(fname, dpi=None, facecolor='w', edgecolor='w',\n",
    "#             orientation='portrait', format=None,\n",
    "#             transparent=False, bbox_inches='tight', pad_inches=0.0,\n",
    "#             metadata=None)\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "09aa8a25-6a54-4d25-9091-0f6a47cc3c29",
   "metadata": {},
   "source": [
    "### 6. Plotting the maximum likelihood learned model ($\\lambda = 0$)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "65a50d6a-e9dc-4d0f-bb49-0253de4d5203",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The checkpoint is a pickled parameter dictionary (allow_pickle=True); load trusted files only.\n",
    "params_mle = hk.data_structures.to_immutable_dict(\n",
    "    jnp.load('models/figure_4_top_mle.npy', allow_pickle=True).item()\n",
    ")\n",
    "\n",
    "# Architecture hyper-parameters.\n",
    "# NOTE(review): these are read from config_darmois -- presumably the MLE model shares\n",
    "# the same architecture; confirm.\n",
    "n_layers_mle = config_darmois['model']['flow_layers']\n",
    "hidden_units_mle = [config_darmois['model']['nn_hidden_units']] * config_darmois['model']['nn_layers']\n",
    "\n",
    "def _flow_mle():\n",
    "    '''Chain of n_layers_mle TriangularResidual layers followed by ConstantScaling(std_train).'''\n",
    "    layers = []\n",
    "    for i in range(n_layers_mle):\n",
    "        layers.append(TriangularResidual(hidden_units_mle + [2], name='residual_' + str(i)))\n",
    "    layers.append(ConstantScaling(std_train))\n",
    "    return distrax.Chain(layers)\n",
    "\n",
    "def inv_map_fn_mle(x):\n",
    "    '''Apply the inverse of the flow chain to x.'''\n",
    "    return _flow_mle().inverse(x)\n",
    "\n",
    "def fw_map_fn_mle(x):\n",
    "    '''Apply the forward direction of the flow chain to x.'''\n",
    "    return _flow_mle().forward(x)\n",
    "\n",
    "# hk.transform turns both functions into (init, apply) pairs that take explicit parameters.\n",
    "fw_map_mle = hk.transform(fw_map_fn_mle)\n",
    "inv_map_mle = hk.transform(inv_map_fn_mle)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c34afab4-18a6-41df-abb5-495c3d3ea401",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run the inverse flow on the test observations (None is passed as the RNG argument).\n",
    "S_rec_mle = inv_map_mle.apply(params_mle, None, X_test)\n",
    "# Pass each recovered coordinate through the logistic CDF, then subtract 0.5.\n",
    "S_rec_uni_mle = jnp.column_stack([jax.scipy.stats.logistic.cdf(S_rec_mle[:, 0]),\n",
    "                              jax.scipy.stats.logistic.cdf(S_rec_mle[:, 1])])\n",
    "S_rec_uni_mle -= 0.5\n",
    "\n",
    "scatterplot_variables(S_rec_uni_mle, 'Reconstructed sources (test)',\n",
    "                      colors=colors_test, cmap=cmap)\n",
    "\n",
    "# Raw string: '\\l' is not a valid Python escape sequence, so a plain literal\n",
    "# triggers an invalid-escape warning. The rendered title text is unchanged.\n",
    "plt.title(r'MLE, $\\lambda=0$', fontsize=19)\n",
    "\n",
    "# Uncomment if you want to save the figure\n",
    "# (papertype/frameon are omitted: newer Matplotlib versions no longer accept them.)\n",
    "\n",
    "# fname = os.path.join(figure_path, \"mle.pdf\")\n",
    "# plt.savefig(fname, dpi=None, facecolor='w', edgecolor='w',\n",
    "#             orientation='portrait', format=None,\n",
    "#             transparent=False, bbox_inches='tight', pad_inches=0.0,\n",
    "#             metadata=None)\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1014dace-8ed5-47e8-bc51-a965172e46dd",
   "metadata": {},
   "source": [
    "### 7. Plot reconstructed sources with $C_{\\operatorname{IMA}}$ regularized model ($\\lambda=1$)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7b08abb9-6054-44c9-92b6-8eae76f9f47f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The checkpoint is a pickled parameter dictionary (allow_pickle=True); load trusted files only.\n",
    "params_cima = hk.data_structures.to_immutable_dict(\n",
    "    jnp.load('models/figure_4_top_cima_obj.npy', allow_pickle=True).item()\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2ffdd4e3-2464-4cd4-a7de-80e9012dc16e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Architecture hyper-parameters.\n",
    "# NOTE(review): these are read from config_darmois -- presumably the C_IMA model shares\n",
    "# the same architecture; confirm.\n",
    "n_layers_cima = config_darmois['model']['flow_layers']\n",
    "hidden_units_cima = [config_darmois['model']['nn_hidden_units']] * config_darmois['model']['nn_layers']\n",
    "\n",
    "def _flow_cima():\n",
    "    '''Chain of n_layers_cima TriangularResidual layers followed by ConstantScaling(std_train).'''\n",
    "    layers = []\n",
    "    for i in range(n_layers_cima):\n",
    "        layers.append(TriangularResidual(hidden_units_cima + [2], name='residual_' + str(i)))\n",
    "    layers.append(ConstantScaling(std_train))\n",
    "    return distrax.Chain(layers)\n",
    "\n",
    "def inv_map_fn_cima(x):\n",
    "    '''Apply the inverse of the flow chain to x.'''\n",
    "    return _flow_cima().inverse(x)\n",
    "\n",
    "def fw_map_fn_cima(x):\n",
    "    '''Apply the forward direction of the flow chain to x.'''\n",
    "    return _flow_cima().forward(x)\n",
    "\n",
    "# hk.transform turns both functions into (init, apply) pairs that take explicit parameters.\n",
    "fw_map_cima = hk.transform(fw_map_fn_cima)\n",
    "inv_map_cima = hk.transform(inv_map_fn_cima)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e843ba53-c8ef-46c2-aadc-b3ec6c1dbf38",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run the inverse flow on the test observations (None is passed as the RNG argument).\n",
    "S_rec_cima = inv_map_cima.apply(params_cima, None, X_test)\n",
    "# Pass each recovered coordinate through the logistic CDF, then subtract 0.5.\n",
    "S_rec_uni_cima = jnp.column_stack([jax.scipy.stats.logistic.cdf(S_rec_cima[:, 0]),\n",
    "                              jax.scipy.stats.logistic.cdf(S_rec_cima[:, 1])])\n",
    "S_rec_uni_cima -= 0.5\n",
    "\n",
    "scatterplot_variables(S_rec_uni_cima, 'Reconstructed sources (test)',\n",
    "                      colors=colors_test, cmap=cmap)\n",
    "\n",
    "# Raw string: '\\o' and '\\l' are not valid Python escape sequences, so a plain literal\n",
    "# triggers invalid-escape warnings. The rendered title text is unchanged.\n",
    "plt.title(r'$C_{\\operatorname{IMA}}, \\lambda=1$', fontsize=19)\n",
    "\n",
    "# Uncomment if you want to save the figure\n",
    "# (papertype/frameon are omitted: newer Matplotlib versions no longer accept them.)\n",
    "\n",
    "# fname = os.path.join(figure_path, \"cima_model.pdf\")\n",
    "# plt.savefig(fname, dpi=None, facecolor='w', edgecolor='w',\n",
    "#             orientation='portrait', format=None,\n",
    "#             transparent=False, bbox_inches='tight', pad_inches=0.0,\n",
    "#             metadata=None)\n",
    "\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}