{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
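    "# Approximating the true solution with linear regression (supervised baseline)\n",
    "\n",
    "Points are sampled in the interior and on the boundary of a $d$-dimensional ball. The exact solution $u(x, t) = t + \\lVert x \\rVert^2 / (2d)$ of the heat equation $u_t = \\Delta u$ (zero forcing, Dirichlet boundary data) is evaluated on these points. An ordinary least-squares linear model is then fitted to its values at $t = 0$ as a supervised reference."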
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Imports\n",
    "import sys\n",
    "sys.path.append('../../')\n",
    "sys.path.append('../../src/')\n",
    "from swimpde import Domain\n",
    "from swimpde import BasicAnsatz\n",
    "from swimpde import Diffusion_Solver\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from sklearn.metrics import mean_squared_error\n",
    "import matplotlib.cm as cm\n",
    "import time\n",
    "cmap = cm.jet\n",
    "from examples.utils import *\n",
    "from mpl_toolkits.mplot3d import Axes3D\n",
    "#from exautils import *\n",
    "\n",
    "# Set seeds\n",
    "np.random.seed(2)\n",
    "rng = np.random.default_rng(seed=123)\n",
    "print(sys.path)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Sampling training and test points"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Problem size: space dimension and number of sampled points\n",
    "d = 100  # Space dimension\n",
    "n_b_train = 1000  # Boundary training points\n",
    "n_b_test = 2000  # Boundary test points\n",
    "n_int_train = 1000  # Interior training points\n",
    "n_int_test = 8000  # Interior test points\n",
    "\n",
    "# Train and test boundary points\n",
    "X_b_train = sample_boundary_lhs_ball(d, n_b_train)\n",
    "X_b_test = sample_boundary_lhs_ball(d, n_b_test)\n",
    "\n",
    "# Train and test interior points\n",
    "X_int_train = sample_interior_lhs_ball(d, n_int_train)\n",
    "X_int_test = sample_interior_lhs_ball(d, n_int_test)\n",
    "\n",
    "# Test set: interior test points stacked on top of boundary test points\n",
    "X_test = np.vstack((X_int_test, X_b_test))\n",
    "\n",
    "# Plot the interior training points (comment this line out to skip the plot)\n",
    "plot_interior_points(X_int_train)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### PDE data: initial condition, forcing, and analytical solution"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "# initial condition\n",
    "def u0(x):\n",
    "    space_dim = np.shape(x)[1]\n",
    "    return np.sum(x**2, axis=1, keepdims=True)/(2. * space_dim)\n",
    "\n",
    "# forcing\n",
    "def forcing(x, t):    \n",
    "    return 0.\n",
    "\n",
    "# boundary condition\n",
    "boundary_condition = \"dirichlet\"\n",
    "\n",
    "# Analytical solution\n",
    "def analytical_sol(x, t):\n",
    "    space_dim = np.shape(x)[1]\n",
    "    return t + np.sum(x**2, axis=1, keepdims=True)/(2. * space_dim)\n",
    "    \n",
    "# Reference data: exact solution on the test points at 100 times in [0, 1]\n",
    "t_eval = np.linspace(0, 1, 100)[:, None, None]  # shape (n_t, 1, 1) so it broadcasts over the points\n",
    "x_train = X_int_train  # spatial training points\n",
    "\n",
    "# analytical_sol gives shape (n_t, n_test, 1); keep only the first two axes\n",
    "u_true = analytical_sol(X_test, t_eval)\n",
    "u_true = u_true.reshape(u_true.shape[:2])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.linear_model import LinearRegression\n",
    "\n",
    "# Baseline: least-squares linear fit of the exact solution at t = 0,\n",
    "# i.e. of the initial condition |x|^2 / (2d) (numpy and mean_squared_error come from the imports cell).\n",
    "# NOTE: the target is quadratic in x, so an affine model cannot represent it exactly.\n",
    "\n",
    "# 1. Exact target values at t = 0\n",
    "y_train = analytical_sol(X_int_train, t=0)\n",
    "y_test = analytical_sol(X_test, t=0)\n",
    "\n",
    "# 2. Fit a first-order (affine) model; it is trained on interior points only,\n",
    "#    while X_test also contains boundary points.\n",
    "model = LinearRegression()\n",
    "model.fit(X_int_train, y_train)\n",
    "\n",
    "# 3. Predict and report the relative RMSE, ||y_test - y_pred|| / ||y_test||\n",
    "y_pred = model.predict(X_test)\n",
    "rel_rmse = np.sqrt(mean_squared_error(y_test, y_pred)) / np.sqrt(mean_squared_error(y_test, np.zeros_like(y_test)))\n",
    "print(f\"Test relative RMSE: {rel_rmse}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "test2",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
