{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b6223a87-fbe4-46ef-b0ee-4fc2ac7f4adf",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get two figures for the radial map (original and transformed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8184ec35-cd76-48b3-826d-7c4551582f77",
   "metadata": {},
   "outputs": [],
   "source": [
    "from jax import numpy as jnp\n",
    "import matplotlib.pyplot as plt\n",
    "from matplotlib.collections import LineCollection\n",
    "import numpy as onp\n",
    "\n",
    "def cart2pol(x, y):\n",
    "    '''\n",
    "    From cartesian to polar coordinates\n",
    "    '''\n",
    "    rho = jnp.sqrt(x**2 + y**2)\n",
    "    phi = jnp.arctan2(y, x)\n",
    "    return(rho, phi)\n",
    "\n",
    "def pol2cart(rho, phi):\n",
    "    '''\n",
    "    From polar to cartesian coordinates\n",
    "    '''\n",
    "    x = rho * jnp.cos(phi)\n",
    "    y = rho * jnp.sin(phi)\n",
    "    return(x, y)\n",
    "\n",
    "def identity(rho, phi):\n",
    "    '''\n",
    "    Identity\n",
    "    '''\n",
    "    return(rho, phi)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "64a83d0f-c56c-4ea1-8b23-4766b82e9948",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f2b2ea4f-6a39-4197-89f8-27cede036257",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_grid_x(x,y, ax=None, **kwargs):\n",
    "# def plot_grid(x,y, ax=None):\n",
    "    ax = ax or plt.gca()\n",
    "    segs1 = jnp.stack((x,y), axis=2)\n",
    "    segs2 = segs1.transpose(1,0,2)\n",
    "#     ax.add_collection(LineCollection(segs1, edgecolors=\"black\"))\n",
    "    ax.add_collection(LineCollection(segs1, **kwargs))\n",
    "#     ax.add_collection(LineCollection(segs2, edgecolors=\"black\"))\n",
    "#     ax.add_collection(LineCollection(segs2, **kwargs))\n",
    "    ax.autoscale()\n",
    "    \n",
    "def plot_grid_y(x,y, ax=None, **kwargs):\n",
    "# def plot_grid(x,y, ax=None):\n",
    "    ax = ax or plt.gca()\n",
    "    segs1 = jnp.stack((x,y), axis=2)\n",
    "    segs2 = segs1.transpose(1,0,2)\n",
    "#     ax.add_collection(LineCollection(segs1, edgecolors=\"black\"))\n",
    "#     ax.add_collection(LineCollection(segs1, **kwargs))\n",
    "#     ax.add_collection(LineCollection(segs2, edgecolors=\"black\"))\n",
    "    ax.add_collection(LineCollection(segs2, **kwargs))\n",
    "    ax.autoscale()\n",
    "    \n",
    "def show_grid_plot(f, multi_argument=False, extremes_x=(0,1), extremes_y=(0,1),savefig=False, fname=\"grplot\"):\n",
    "    '''\n",
    "    Plots how a regularly spaced grid in a 2d space is distorted under the action of the function f\n",
    "    \n",
    "    f: A mixing function\n",
    "    multi_argument: A Boolean variable; checks whether f takes a (N,2) array as input, or two (N,) arrays.\n",
    "                    In the latter case, internally builds a version of f which takes two (N,) arrays as input.\n",
    "    '''\n",
    "    \n",
    "    if multi_argument==False:\n",
    "        def f_grid(x, y):\n",
    "            z = jnp.array([x, y])\n",
    "            z_ = f(z)\n",
    "            return z_[0], z_[1]\n",
    "    else:\n",
    "        f_grid = f\n",
    "\n",
    "    bottom_x, top_x = extremes_x\n",
    "    bottom_y, top_y = extremes_y\n",
    "    \n",
    "    fig, ax = plt.subplots()\n",
    "\n",
    "    grid_x,grid_y = jnp.meshgrid(jnp.linspace(bottom_x,top_x,15),jnp.linspace(bottom_y,top_y,15))\n",
    "#     plot_grid(grid_x,grid_y, ax=ax,  color=\"lightgrey\")\n",
    "\n",
    "#1A85FF\n",
    "#D41159\n",
    "\n",
    "    distx, disty = f_grid(grid_x,grid_y)\n",
    "#     plot_grid(distx, disty, ax=ax)\n",
    "    plot_grid_x(distx, disty, ax=ax, color=\"#1A85FF\")\n",
    "    plot_grid_y(distx, disty, ax=ax, color=\"#D41159\")\n",
    " \n",
    "#     plt.gca().set_aspect('equal', adjustable='box')\n",
    "#     if savefig==True:\n",
    "#         plt.savefig(fname, dpi=None, facecolor='w', edgecolor='w',\n",
    "#             orientation='portrait', papertype=None, format=None,\n",
    "#             transparent=False, bbox_inches=None, pad_inches=0.1,\n",
    "#             frameon=None, metadata=None)\n",
    "#     plt.show()\n",
    "    return fig, ax"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0b5474b8-07e9-4c61-8e75-99cb00cca167",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bdaf7096-7361-483f-b28d-e7aa64327de3",
   "metadata": {},
   "outputs": [],
   "source": [
    "figure_path = \"/Users/luigigresele/Documents/Plots_IMA\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a3a49c51-7fa6-4fc7-9954-cef3a72fea5d",
   "metadata": {},
   "outputs": [],
   "source": [
    "from matplotlib.ticker import MaxNLocator, MultipleLocator, Locator, FixedLocator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f6185260-4dfc-4baa-b06a-8bf37c8efd8e",
   "metadata": {},
   "outputs": [],
   "source": [
    "fontsize_math = 14\n",
    "fontsize_ticks = 13"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0bd1dcdf-2fd1-46c3-a3b6-825ee1596dc7",
   "metadata": {},
   "outputs": [],
   "source": [
    "extremes_x=(0.7,1.5)\n",
    "extremes_y=(-0.4,.4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fec9fc9c-28ab-4fec-862b-1e5d5211c0fa",
   "metadata": {},
   "outputs": [],
   "source": [
    "extremes_x=(0.5,1.5)\n",
    "extremes_y=(-0.5,.5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "70fd6245-0b47-4b87-8c2c-6ab9aedd792f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# # Roughly square observations plot\n",
    "# extremes_x=(0.2,0.4)\n",
    "# extremes_y=(-jnp.pi/12,jnp.pi/12)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e711ef1b-7cdb-40ca-8b17-ab59444f00c9",
   "metadata": {},
   "outputs": [],
   "source": [
    "fname = os.path.join(figure_path,\"source_grid.pdf\")\n",
    "fig, ax = show_grid_plot(identity, extremes_x=extremes_x, extremes_y=extremes_y, multi_argument=True)\n",
    "# plt.gca().set_aspect('equal', adjustable='box')\n",
    "plt.gca().set_aspect(1.0/ax.get_data_ratio(), adjustable='box')\n",
    "plt.xlabel(r\"$r$\", fontsize=fontsize_math)\n",
    "plt.ylabel(r\"$\\theta$\", fontsize=fontsize_math)\n",
    "# ax.xaxis.set_major_locator(FixedLocator([0.5, 1.5]))\n",
    "ax.yaxis.set_major_locator(FixedLocator([-jnp.pi/8, 0, jnp.pi/8]))\n",
    "# ax.set_yticks(jnp.arange(-jnp.pi/8, jnp.pi/8))\n",
    "labels = [r'$\\pi/8$', r'$0$', r'$\\pi/8$']\n",
    "ax.set_yticklabels(labels)\n",
    "ax.tick_params(labelsize = fontsize_ticks)\n",
    "plt.savefig(fname, dpi=None, facecolor='w', edgecolor='w',\n",
    "            orientation='portrait', papertype=None, format=None,\n",
    "            transparent=False, bbox_inches='tight', pad_inches=0.1,\n",
    "            frameon=None, metadata=None)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2af141cf-d40a-4e9b-92b4-1cba781a07dc",
   "metadata": {},
   "outputs": [],
   "source": [
    "fname = os.path.join(figure_path,\"observations_grid.pdf\")\n",
    "fig, ax = show_grid_plot(pol2cart, extremes_x=extremes_x, extremes_y=extremes_y, multi_argument=True)\n",
    "plt.gca().set_aspect('equal', adjustable='box')\n",
    "# plt.gca().set_aspect(1.0/ax.get_data_ratio(), adjustable='box')\n",
    "plt.xlabel(r\"$x$\", fontsize=fontsize_math)\n",
    "plt.ylabel(r\"$y$\", fontsize=fontsize_math)\n",
    "ax.tick_params(labelsize = fontsize_ticks)\n",
    "plt.savefig(fname, dpi=None, facecolor='w', edgecolor='w',\n",
    "            orientation='portrait', papertype=None, format=None,\n",
    "            transparent=False, bbox_inches='tight', pad_inches=0.1,\n",
    "            frameon=None, metadata=None)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "17d690b9-5675-4520-82c6-cfd28b968ae4",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c0765004-f7a6-4864-b9dd-2ee458a7dae3",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
