{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax \n",
    "import jax.numpy as jnp \n",
    "from jaxtyping import Float, Key, Array \n",
    "import numpy as np \n",
    "import pandas as pd \n",
    "import plotly.express as px \n",
    "from plotly import graph_objects as go\n",
    "from plotly.subplots import make_subplots\n",
    "from scipy.special import sph_harm "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def rotate(sph):\n",
    "    \"\"\"\n",
    "    Apply a roll rotation of pi / 2 to the points in spherical coordinates (colatitude, longitude) in [0, pi] x [0, 2pi].\n",
    "    \n",
    "    Parameters:\n",
    "    colatitude (float or jnp.ndarray): Colatitude values in the range [0, pi].\n",
    "    longitude (float or jnp.ndarray): Longitude values in the range [0, 2pi].\n",
    "\n",
    "    Returns:\n",
    "    jnp.ndarray: Array of transformed colatitude and longitude values.\n",
    "    \"\"\"\n",
    "    colatitude, longitude = sph[..., 0], sph[..., 1]\n",
    "    \n",
    "    # Convert back to spherical coordinates\n",
    "    new_colatitude = jnp.arccos(jnp.sin(colatitude) * jnp.sin(longitude))\n",
    "    new_longitude = jnp.arctan2(-jnp.cos(colatitude), jnp.sin(colatitude) * jnp.cos(longitude))\n",
    "    \n",
    "    return jnp.stack([new_colatitude, new_longitude], axis=-1)\n",
    "\n",
    "\n",
    "\n",
    "def reversed_spherical_harmonic(sph, m: int, n: int):\n",
    "    colat, lon = sph[..., 0], sph[..., 1]\n",
    "    return jnp.asarray(sph_harm(m, n, np.asarray(colat), np.asarray(lon)).real)\n",
    "\n",
    "\n",
    "def target_f__reversed_spherical_harmonic(sph: Float) -> Float:\n",
    "    return reversed_spherical_harmonic(sph, m=1, n=2) + reversed_spherical_harmonic(rotate(sph), m=1, n=1)\n",
    "\n",
    "\n",
    "@jax.jit\n",
    "def car_to_sph(car):\n",
    "    x, y, z = car[..., 0], car[..., 1], car[..., 2]\n",
    "    colat = jnp.arccos(z)\n",
    "    lon = jnp.arctan2(y, x)\n",
    "    return jnp.stack([colat, lon], axis=-1)\n",
    "\n",
    "\n",
    "def add_noise(f: Float[Array, \" N\"], noise_std: float = 0.01, *, key: Key) -> Float:\n",
    "    return f + jax.random.normal(key=key, shape=f.shape) * noise_std"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "key = jax.random.key(0)\n",
    "\n",
    "\n",
    "x = jnp.asarray(pd.read_csv('../std_inputs.csv', header=None, names=['x', 'y', 'z']).values)\n",
    "f = target_f__reversed_spherical_harmonic(car_to_sph(x))\n",
    "y = add_noise(f, noise_std=0.01, key=key)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig = go.Figure() \n",
    "fig.add_trace(\n",
    "    go.Scatter3d(\n",
    "        x=x[:, 0], \n",
    "        y=x[:, 1], \n",
    "        z=x[:, 2], \n",
    "        mode='markers', \n",
    "        marker=dict(\n",
    "            color=-y, \n",
    "            size=3,\n",
    "            colorscale='magma',\n",
    "        ),\n",
    "    ), \n",
    ")\n",
    "\n",
    "fig.update_layout(\n",
    "    scene=dict(\n",
    "        xaxis=dict(visible=False),\n",
    "        yaxis=dict(visible=False),\n",
    "        zaxis=dict(visible=False),\n",
    "    ),\n",
    "    width=1300,\n",
    "    height=600,\n",
    "    showlegend=False,\n",
    ")\n",
    "fig.write_image(\"synthetic_target_function.pdf\")\n",
    "fig.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# save the data as csv using the names of the variables\n",
    "data = [\n",
    "    x, \n",
    "    -y, \n",
    "]\n",
    "\n",
    "names = [\n",
    "    'synthetic_y-inputs',\n",
    "    'synthetic_y-outputs', \n",
    "]\n",
    "\n",
    "\n",
    "for datum, name in zip(data, names):\n",
    "    pd.DataFrame(datum).to_csv(f\"{name}.csv\", header=False, index=False)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "mdgp-jax2",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
