{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "6c93d7f8-5a2d-4ae1-b3e2-c781a306b8ae",
   "metadata": {},
   "source": [
    "### The following code can be used to reconstruct figure 4 **(c)**, (bottom row) in the paper."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e22e0e2d-f5bd-4691-960c-b72cfb8f3e6b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys,os\n",
    "sys.path.append('/Users/luigigresele/git/projects/ica_and_icm')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0c5d8aaf",
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "from jax import numpy as jnp\n",
    "import numpy as np\n",
    "import distrax\n",
    "import haiku as hk\n",
    "from residual import TriangularResidual, ConstantScaling\n",
    "from utils import get_config\n",
    "\n",
    "from plotting import cart2pol \n",
    "\n",
    "from mixing_functions import build_moebius_transform, build_automorphism"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5f90410d-e23f-4064-a8fd-defc935265d5",
   "metadata": {},
   "source": [
    "### Get train/test data and parameters of the Möbius transformation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6c72873a",
   "metadata": {},
   "outputs": [],
   "source": [
    "number_darmois = '0308'\n",
    "model_root_darmois = '/Users/luigigresele/Desktop/ICA and ICM/Experiments_Vincent/'+ number_darmois +'/projects/ica-flows/experiments/triresflow/2d/'+ number_darmois +'/'\n",
    "\n",
    "config_darmois = get_config(model_root_darmois + 'config/config.yaml')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7deba11b",
   "metadata": {},
   "outputs": [],
   "source": [
    "S_train = jnp.array(jnp.load(model_root_darmois + 'data/sources_train.npy'))\n",
    "S_test = jnp.array(jnp.load(model_root_darmois + 'data/sources_test.npy'))\n",
    "X_train = jnp.array(jnp.load(model_root_darmois + 'data/observation_train.npy'))\n",
    "X_test = jnp.array(jnp.load(model_root_darmois + 'data/observation_test.npy'))\n",
    "mean_std = jnp.load(model_root_darmois + 'data/observation_mean_std.npy', allow_pickle=True).item()\n",
    "mean_train, std_train = mean_std['mean'], mean_std['std']\n",
    "moeb_params = jnp.load(model_root_darmois + 'data/moebius_transform_params.npy', allow_pickle=True).item()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c12887ff-4e3a-468f-ad15-4becef0c6ec3",
   "metadata": {},
   "source": [
    "### Specify how many datapoints to use for the plots and re-define the train/test splits accordingly"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a905f688-5f6e-4d7c-b58b-cda876735a02",
   "metadata": {},
   "outputs": [],
   "source": [
    "howmany = 5000\n",
    "\n",
    "S_train = S_train[:howmany]\n",
    "S_test = S_test[:howmany]\n",
    "X_train = X_train[:howmany]\n",
    "X_test = X_test[:howmany]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a6af0ce5-43c7-49ef-bb46-3d4a7b65acb4",
   "metadata": {},
   "source": [
    "### Define the true mixing and unmixing, based on the parameters."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "97a329bc",
   "metadata": {},
   "outputs": [],
   "source": [
    "alpha = 1.0\n",
    "A = jnp.array(moeb_params['A'])\n",
    "a = jnp.array(moeb_params['a'])\n",
    "b = jnp.zeros(2) \n",
    "\n",
    "mixing_moebius, mixing_moebius_inv = build_moebius_transform(alpha, A, a, b, epsilon=2)\n",
    "mixing_batched = jax.vmap(mixing_moebius)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b3e3f24d-8e0f-42f9-8efc-739102a10e77",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "be0d2b0c-254e-41c0-98f1-d6331e55fbc8",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "figure_path = \"/Users/luigigresele/Documents/Plots_IMA\"#/perceptually_uniform\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1ecc5d47-e63c-4255-887b-fff788747ba4",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1d4c4d89-849e-4464-83b5-4bf17066f303",
   "metadata": {},
   "outputs": [],
   "source": [
    "from jax import vmap, jacfwd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3ea93937-7b9c-48c1-95e8-d6ac768ab10f",
   "metadata": {},
   "outputs": [],
   "source": [
    "from metrics import cima"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "946ec68d-1350-4a7a-ae4f-650a3442df06",
   "metadata": {},
   "source": [
    "### Plot $C_{\\operatorname{IMA}}$ of the measure preserving automorphism composed on the right with the true unmixing, for different values of the rotation angle $\\theta$ (takes roughly one minute on my laptop)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "149b0f5e-72ad-4687-9207-1cd30e5e55a7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reconstruct observations by mixing the sources, since the observations are saved de-meaned.\n",
    "S_test_mixed = vmap(mixing_moebius)(S_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bf8c1951-6044-4651-a3ac-e2ef2f4e5292",
   "metadata": {},
   "outputs": [],
   "source": [
    "angles = jnp.linspace(0, 360, num=720)\n",
    "cimas_mpa = np.copy(angles)\n",
    "\n",
    "for i, angle in enumerate(angles):\n",
    "    # Build a rotation matrix\n",
    "    theta = jnp.radians(angle)\n",
    "    c, s = jnp.cos(theta), jnp.sin(theta)\n",
    "    R = jnp.array([[c, -s], [s, c]])\n",
    "\n",
    "    # Build measure pres. automorphism\n",
    "    measure_preserving, measure_preserving_inv = build_automorphism(R)\n",
    "\n",
    "    # Map the sources through **direct** mpa\n",
    "    measure_preserving_batched = vmap(measure_preserving)\n",
    "    Y = measure_preserving_batched(S_test)\n",
    "    \n",
    "    # Compose **its inverse** it with the mixing\n",
    "    def composed_inverse_transformation(x):\n",
    "        y = mixing_moebius_inv(x) + 0.5\n",
    "        return measure_preserving(y)\n",
    "    comp_inv_batched = jax.vmap(composed_inverse_transformation)\n",
    "        \n",
    "    # Compute the Jacobian\n",
    "    Jcomposed = jacfwd(composed_inverse_transformation)\n",
    "    Jcomposed_batched = vmap(Jcomposed)\n",
    "    # Compute aDM true\n",
    "    composed_cima = jnp.mean(cima(S_test_mixed, Jcomposed_batched))\n",
    "    cimas_mpa[i] = composed_cima    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "640d221f-d7ca-48c4-bfdc-24f11486c1a1",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "plt.plot(angles, cimas_mpa)\n",
    "plt.xlabel(r'$\\theta$')\n",
    "plt.ylabel(r'$C_{IMA}$')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d4105e05-0987-4f3b-8ff2-44eb47ca82da",
   "metadata": {},
   "source": [
    "### Plot $C_{\\operatorname{IMA}}$ of the measure preserving automorphism composed on the right with the **learnt Darmois construction**, for different values of the rotation angle $\\theta$."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "26c91a0a-e7c3-4450-96f4-a56b225acece",
   "metadata": {},
   "source": [
    "#### Load parameters of the trained Darmois construction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9c2eb05b",
   "metadata": {},
   "outputs": [],
   "source": [
    "params_darmois = hk.data_structures.to_immutable_dict(jnp.load(model_root_darmois + 'checkpoints/model_100000.npy', allow_pickle=True).item())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c298edd8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Setup model\n",
    "n_layers_darmois = config_darmois['model']['flow_layers']\n",
    "hidden_units_darmois = config_darmois['model']['nn_layers'] * [config_darmois['model']['nn_hidden_units']]\n",
    "\n",
    "def inv_map_fn_darmois(x):\n",
    "    flows = distrax.Chain([TriangularResidual(hidden_units_darmois + [2], name='residual_' + str(i))\n",
    "                           for i in range(n_layers_darmois)] + [ConstantScaling(std_train)])\n",
    "    return flows.inverse(x)\n",
    "\n",
    "def fw_map_fn_darmois(x):\n",
    "    flows = distrax.Chain([TriangularResidual(hidden_units_darmois + [2], name='residual_' + str(i))\n",
    "                           for i in range(n_layers_darmois)] + [ConstantScaling(std_train)])\n",
    "    return flows.forward(x)\n",
    "\n",
    "fw_map_darmois = hk.transform(fw_map_fn_darmois)\n",
    "inv_map_darmois = hk.transform(inv_map_fn_darmois)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3c6886cd-e3e8-41b1-b89d-a44c8c8d17b6",
   "metadata": {},
   "outputs": [],
   "source": [
    "inv_map_darmois_apply = lambda y: inv_map_darmois.apply(params_darmois, None, y)\n",
    "jac_invmap = jax.vmap(jax.jacfwd(inv_map_darmois_apply))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ddd5892c-4820-4998-8283-3690e58011dd",
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "\n",
    "time_0 = time.time()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "deee249f-96f2-4995-811f-18fceeb256a9",
   "metadata": {},
   "source": [
    "#### Set the number of angles to test. \n",
    "#### Figure in the paper is made with 720, which took $\\sim 1$ hour on my laptop."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a989bdf4-c8a4-4df4-9da3-a4edf517ffc7",
   "metadata": {},
   "outputs": [],
   "source": [
    "num = 13"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f11bb84e-a93c-4e00-88b0-a82166c23a52",
   "metadata": {},
   "outputs": [],
   "source": [
    "angles = jnp.linspace(0, 360, num=num)\n",
    "cimas = np.copy(angles)\n",
    "\n",
    "start = np.copy(time_0)\n",
    "\n",
    "for i, angle in enumerate(angles):\n",
    "    # Build a rotation matrix\n",
    "    theta = jnp.radians(angle)\n",
    "    c, s = jnp.cos(theta), np.sin(theta)\n",
    "    R = jnp.array([[c, -s], [s, c]])\n",
    "\n",
    "    # Build measure pres. automorphism\n",
    "    measure_preserving, measure_preserving_inv = build_automorphism(R)\n",
    "\n",
    "# #     Map the sources through **direct** mpa\n",
    "#     measure_preserving_batched = vmap(measure_preserving)\n",
    "#     Y = measure_preserving_batched(S_rec_uni)\n",
    "\n",
    "    \n",
    "    # Compose **its inverse** it with the mixing\n",
    "    def composed_inverse_transformation(x):\n",
    "        y = inv_map_darmois_apply(x)\n",
    "        y= jax.scipy.stats.norm.cdf(y)\n",
    "        return measure_preserving(y)\n",
    "    comp_inv_batched = vmap(composed_inverse_transformation)\n",
    "    \n",
    "    Y_1 = comp_inv_batched(X_test)\n",
    "    \n",
    "    # Compute the Jacobian\n",
    "    Jcomposed = jacfwd(composed_inverse_transformation)\n",
    "    Jcomposed_batched = vmap(Jcomposed)\n",
    "    # Compute aDM true\n",
    "    c_cima_ = cima(X_test, Jcomposed_batched)\n",
    "    composed_cima = jnp.mean(c_cima_)\n",
    "    cimas[i] = composed_cima\n",
    "    if i%5==0:\n",
    "        stop = time.time()\n",
    "        print(i, \"Duration:\", stop - start, \"; Total: \", stop - time_0)\n",
    "        start = time.time()    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ce760ce4-5e96-43cd-81c1-125180b006e5",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "plt.plot(angles, cimas)\n",
    "plt.xlabel(r'$\\theta$')\n",
    "plt.ylabel(r'$C_{IMA}$')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "02937b5e-f41a-46a7-9176-816cea86b9d5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# np.save('cima_darmois_mpa_0308', cimas)#, fmt='%d')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "32efa4b3-b26f-4b22-b557-1f413e58e274",
   "metadata": {},
   "source": [
    "#### Load the values here if already pre-computed."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f1626593-430f-4df4-828c-de12ad4291bc",
   "metadata": {},
   "outputs": [],
   "source": [
    "loaded_angles = jnp.linspace(0, 360, num=720)\n",
    "loaded_cimas = np.load('/Users/luigigresele/git/projects/ica_and_icm/cima_darmois_mpa_0308.npy')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "938e793b-0251-450a-9050-edaa072f7892",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "plt.plot(loaded_angles, loaded_cimas)\n",
    "plt.xlabel(r'$\\theta$')\n",
    "plt.ylabel(r'$C_{IMA}$')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "80fbc228-b270-46a6-9bb0-60b784998b14",
   "metadata": {},
   "source": [
    "### Plot both together"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7c054771-3d4b-401b-89eb-b4414396c327",
   "metadata": {},
   "outputs": [],
   "source": [
    "angles_rad = np.radians(loaded_angles)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fb7d9877-6271-4f95-abbc-7d9f88b4e2e5",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7a81605d-ecc7-46eb-b826-130f3dd06a7c",
   "metadata": {},
   "outputs": [],
   "source": [
    "figure_path = \"/Users/luigigresele/Documents/Plots_IMA\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ed815e5e-e38a-4d4b-b1af-1469cf6a86ce",
   "metadata": {},
   "outputs": [],
   "source": [
    "from matplotlib.ticker import FuncFormatter, MultipleLocator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c766000d-3df3-4a97-abe6-12c803d899e7",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "fig,ax = plt.subplots()\n",
    "ax.plot(angles_rad, cimas_mpa, label='MPA')\n",
    "ax.plot(angles_rad, loaded_cimas, label='Darmois+MPA')\n",
    "plt.xlabel(r'$\\theta$')\n",
    "plt.ylabel(r'$C_{IMA}$')\n",
    "\n",
    "ax.set_xticks(np.arange(0, 2*np.pi+0.01, np.pi/2))\n",
    "labels = ['$0$', r'$\\pi/2$', r'$\\pi$',\n",
    "            r'$3\\pi/2$', r'$2\\pi$']\n",
    "ax.set_xticklabels(labels)\n",
    "\n",
    "# Uncomment if you want to save the figure\n",
    "\n",
    "# plt.savefig(os.path.join(figure_path, 'mpa_darmois.pdf'), \n",
    "#             dpi=None, facecolor='w', edgecolor='w',\n",
    "#             orientation='portrait', format=None,\n",
    "#             transparent=True, bbox_inches='tight', pad_inches=0.1, metadata=None)\n",
    "\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
