{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# EDA"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import seaborn as sns\n",
    "from sklearn.gaussian_process import GaussianProcessRegressor\n",
    "from sklearn.svm import SVR\n",
    "from tqdm import tqdm\n",
    "\n",
    "from aiau.data.data_manager import DataManager\n",
    "from aiau.data.index_initialisation_utils import random_indices, lower_left_corner_indices_2d\n",
    "from aiau.data.synthetic_datasets import generate_2d_synthetic_data, generate_non_regular_2d_synthetic_data, generate_freq_power_2d_synthetic_data\n",
    "from aiau.oracles import Oracle, NoisyBenchmarkOracle, aleotoric_and_epistemic_noise_function, CorrelatedNoiseOracle"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Problem 1: Regular Periodic field\n",
    "\n",
    "$$y = \\sin(1.5x_1) * \\sin(1.5x_2)$$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the Problem 1 dataset: X is later used as observations, y as targets\n",
    "X, y = generate_2d_synthetic_data(\n",
    "    x1_size=50,\n",
    "    x2_size=50,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Scatter the two input columns against each other, coloured by the target\n",
    "fig, ax = plt.subplots()\n",
    "ax.scatter(X[:, 0], X[:, 1], c=y)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seeded generator so the randomly chosen initial labelled set is identical\n",
    "# on every Restart & Run All.\n",
    "rng = np.random.default_rng(42)\n",
    "\n",
    "indices = np.arange(X.shape[0])\n",
    "# Draw 25 distinct row indices to start out as labelled\n",
    "initially_labelled_indices = rng.choice(indices, 25, replace=False)\n",
    "\n",
    "# Create the DataManager over the full dataset with the chosen initial labels\n",
    "dm = DataManager(indices=indices,\n",
    "                 observations=X,\n",
    "                 targets=y,\n",
    "                 initially_labelled_indices=initially_labelled_indices)\n",
    "\n",
    "# Plot the data held by the DataManager and mark the initially labelled points.\n",
    "# The labelled indices are used directly as a 1-D row index (no extra list\n",
    "# wrapping), so the selected coordinates stay 1-D.\n",
    "labelled = dm.initially_labelled_indices\n",
    "fig, ax = plt.subplots()\n",
    "ax.scatter(dm.full_X[:, 0], dm.full_X[:, 1], c=dm.full_y)\n",
    "ax.scatter(dm.full_X[labelled, 0], dm.full_X[labelled, 1], c='red', marker='x')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "indices = np.arange(X.shape[0])\n",
    "# Pick the initial labelled set with the lower-left-corner helper\n",
    "# (arguments: X, 100, bounds 0 and pi).\n",
    "initially_labelled_indices = lower_left_corner_indices_2d(X, 100, lower_bound=0, upper_bound=np.pi)\n",
    "\n",
    "# Create the DataManager over the full dataset with the chosen initial labels\n",
    "dm = DataManager(indices=indices,\n",
    "                 observations=X,\n",
    "                 targets=y,\n",
    "                 initially_labelled_indices=initially_labelled_indices)\n",
    "\n",
    "# Plot the data held by the DataManager and mark the initially labelled points.\n",
    "# The labelled indices are used directly as a 1-D row index (no extra list\n",
    "# wrapping), so the selected coordinates stay 1-D.\n",
    "labelled = dm.initially_labelled_indices\n",
    "fig, ax = plt.subplots()\n",
    "ax.scatter(dm.full_X[:, 0], dm.full_X[:, 1], c=dm.full_y)\n",
    "ax.scatter(dm.full_X[labelled, 0], dm.full_X[labelled, 1], c='red', marker='x')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 3D scatter: the two input columns on the horizontal axes, the target on the\n",
    "# vertical axis and as the colour.\n",
    "fig, ax = plt.subplots(subplot_kw={'projection': '3d'})\n",
    "x1_vals, x2_vals = dm.full_X[:, 0], dm.full_X[:, 1]\n",
    "ax.scatter(x1_vals, x2_vals, dm.full_y, c=dm.full_y)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Problem 2: Proportionally modulated 2D field\n",
    "\n",
    "$$y = \\sin(m_1 * \\sqrt{2\\pi - x_1}) * \\sin (m_2 * \\sqrt{2\\pi - x_2})$$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "X1_MOD = 2\n",
    "X2_MOD = 4"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Generate the Problem 2 dataset with the configured modulation values\n",
    "X, y = generate_non_regular_2d_synthetic_data(x1_mod=X1_MOD, x2_mod=X2_MOD, x1_size=50, x2_size=50)\n",
    "\n",
    "indices = np.arange(X.shape[0])\n",
    "# Pick the initial labelled set with the lower-left-corner helper\n",
    "# (arguments: X, 100, bounds 0 and pi).\n",
    "initially_labelled_indices = lower_left_corner_indices_2d(X, 100, lower_bound=0, upper_bound=np.pi)\n",
    "\n",
    "# Create the DataManager over the full dataset with the chosen initial labels\n",
    "dm = DataManager(indices=indices,\n",
    "                 observations=X,\n",
    "                 targets=y,\n",
    "                 initially_labelled_indices=initially_labelled_indices)\n",
    "\n",
    "# Plot the data held by the DataManager and mark the initially labelled points.\n",
    "# The labelled indices are used directly as a 1-D row index (no extra list\n",
    "# wrapping), so the selected coordinates stay 1-D.\n",
    "labelled = dm.initially_labelled_indices\n",
    "fig, ax = plt.subplots()\n",
    "ax.scatter(dm.full_X[:, 0], dm.full_X[:, 1], c=dm.full_y)\n",
    "ax.scatter(dm.full_X[labelled, 0], dm.full_X[labelled, 1], c='red', marker='x')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 3D scatter: the two input columns on the horizontal axes, the target on the\n",
    "# vertical axis and as the colour.\n",
    "fig, ax = plt.subplots(subplot_kw={'projection': '3d'})\n",
    "x1_vals, x2_vals = dm.full_X[:, 0], dm.full_X[:, 1]\n",
    "ax.scatter(x1_vals, x2_vals, dm.full_y, c=dm.full_y)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note the issue around $x = 2\\pi$ because of the division by 0."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Frequency and Power modulated field\n",
    "\n",
    "    \n",
    "$$y = cos( f * \\frac{x_{1}^{p}}{(2\\pi)^{p-1}} ) * cos( f * \\frac{(x_{2}^{p})} {(2\\pi)^{p-1}} )$$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "POWER = 4\n",
    "FREQ = 4"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Generate the frequency/power modulated dataset with the configured values\n",
    "X, y = generate_freq_power_2d_synthetic_data(freq=FREQ, power=POWER, x1_size=50, x2_size=50)\n",
    "\n",
    "indices = np.arange(X.shape[0])\n",
    "# Pick the initial labelled set with the lower-left-corner helper\n",
    "# (arguments: X, 100, bounds 0 and pi).\n",
    "initially_labelled_indices = lower_left_corner_indices_2d(X, 100, lower_bound=0, upper_bound=np.pi)\n",
    "\n",
    "# Create the DataManager over the full dataset with the chosen initial labels\n",
    "dm = DataManager(indices=indices,\n",
    "                 observations=X,\n",
    "                 targets=y,\n",
    "                 initially_labelled_indices=initially_labelled_indices)\n",
    "\n",
    "# Plot the data held by the DataManager and mark the initially labelled points.\n",
    "# The labelled indices are used directly as a 1-D row index (no extra list\n",
    "# wrapping), so the selected coordinates stay 1-D.\n",
    "labelled = dm.initially_labelled_indices\n",
    "fig, ax = plt.subplots()\n",
    "ax.scatter(dm.full_X[:, 0], dm.full_X[:, 1], c=dm.full_y)\n",
    "ax.scatter(dm.full_X[labelled, 0], dm.full_X[labelled, 1], c='red', marker='x')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 3D scatter: the two input columns on the horizontal axes, the target on the\n",
    "# vertical axis and as the colour.\n",
    "fig, ax = plt.subplots(subplot_kw={'projection': '3d'})\n",
    "x1_vals, x2_vals = dm.full_X[:, 0], dm.full_X[:, 1]\n",
    "ax.scatter(x1_vals, x2_vals, dm.full_y, c=dm.full_y)\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "noisyal",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
