{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "14756656",
   "metadata": {},
   "outputs": [],
   "source": [
    "import warnings\n",
    "import sys\n",
    "import os\n",
    "import jax\n",
    "\n",
    "# os.environ['JAX_PLATFORMS'] = 'cpu'\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = '2'\n",
    "os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '.95'\n",
    "jax.config.update(\"jax_enable_x64\", True)\n",
    "warnings.filterwarnings('ignore')\n",
    "sys.path.append('..')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "5f6354ab",
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax.experimental.sparse as jsparse\n",
    "import jax.numpy as jnp\n",
    "from jax import device_put, random\n",
    "from jax.lax import scan\n",
    "import numpy as np\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "from functools import partial\n",
    "from time import perf_counter\n",
    "import ilupp\n",
    "\n",
    "from data.dataset import dataset_qtt\n",
    "from linsolve.scipy_linsolve import batched_cg_scipy, make_Chol_prec\n",
    "from utils import iter_per_residual, jBCOO_to_scipyCSR\n",
    "\n",
    "plt.rcParams['figure.figsize'] = (11, 7)\n",
    "plt.rcParams['font.size'] = 20\n",
    "plt.rcParams[\"lines.linewidth\"] = 3"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c0431b5a",
   "metadata": {},
   "source": [
    "# Benchmark dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "35a65435",
   "metadata": {},
   "outputs": [],
   "source": [
    "def scipy_benchmark_linsys(A, b, cg_iter, pcg_iter, atol):\n",
    "    L0, Lt1, Lt5 = [], [], []\n",
    "    L0_time, Lt1_time, Lt5_time = [], [], []\n",
    "    Annz, L0nnz, Lt1nnz, Lt5nnz = [], [], [], []\n",
    "    \n",
    "    for i in range(A.shape[0]):\n",
    "        A_i = jBCOO_to_scipyCSR(A[i, ...])\n",
    "        Annz.append((A_i.nnz*100) / (A_i.shape[-1] ** 2))\n",
    "        \n",
    "        # IC(0)\n",
    "        s = perf_counter()\n",
    "        L0_i = ilupp.ichol0(A_i)\n",
    "        L0_time.append(perf_counter() - s)\n",
    "        L0.append(L0_i)\n",
    "        L0nnz.append((L0_i.nnz*100) / (L0_i.shape[-1] ** 2))\n",
    "\n",
    "        # ICt(1)\n",
    "        s = perf_counter()\n",
    "        Lt1_i = ilupp.icholt(A_i, add_fill_in=1, threshold=1e-4)\n",
    "        Lt1_time.append(perf_counter() - s)\n",
    "        Lt1.append(Lt1_i)\n",
    "        Lt1nnz.append((Lt1_i.nnz*100) / (Lt1_i.shape[-1] ** 2))\n",
    "        \n",
    "        # ICt(5)\n",
    "        s = perf_counter()\n",
    "        Lt5_i = ilupp.icholt(A_i, add_fill_in=5, threshold=1e-4)\n",
    "        Lt5_time.append(perf_counter() - s)\n",
    "        Lt5.append(Lt5_i)\n",
    "        Lt5nnz.append((Lt5_i.nnz*100) / (Lt5_i.shape[-1] ** 2))\n",
    "    \n",
    "    ## Save precs props\n",
    "    # Average time\n",
    "    comb_L0_time = [np.mean(L0_time), np.std(L0_time)]\n",
    "    comb_Lt1_time = [np.mean(Lt1_time), np.std(Lt1_time)]\n",
    "    comb_Lt5_time  = [np.mean(Lt5_time), np.std(Lt5_time)]\n",
    "    \n",
    "    # Average nnz\n",
    "    Annz = [np.mean(Annz), np.std(Annz)]\n",
    "    L0nnz = [np.mean(L0nnz), np.std(L0nnz)]\n",
    "    Lt1nnz = [np.mean(Lt1nnz), np.std(Lt1nnz)]\n",
    "    Lt5nnz = [np.mean(Lt5nnz), np.std(Lt5nnz)]\n",
    "    \n",
    "    ## Assemble precs\n",
    "    P_L0 = make_Chol_prec(L0)\n",
    "    P_Lt1 = make_Chol_prec(Lt1)\n",
    "    P_Lt5 = make_Chol_prec(Lt5)\n",
    "    print('  Precs are combined')\n",
    "        \n",
    "    ## Run PCG\n",
    "    # I\n",
    "    if isinstance(cg_iter, int):\n",
    "        _, I_iters_mean, I_iters_std, I_time_mean, I_time_std = batched_cg_scipy(A, b, P=None, atol=atol, maxiter=cg_iter)\n",
    "        print('  I is done')\n",
    "    else:\n",
    "        I_iters_mean, I_iters_std, I_time_mean, I_time_std = [-1]*4, [-1]*4, [-1]*4, [-1]*4\n",
    "        print('  I is skipped')\n",
    "\n",
    "    # L0\n",
    "    _, L0_iters_mean, L0_iters_std, L0_time_mean, L0_time_std = batched_cg_scipy(A, b, P=P_L0, atol=atol, maxiter=pcg_iter)\n",
    "    print('  L0 is done')\n",
    "\n",
    "    # Lt1\n",
    "    _, Lt1_iters_mean, Lt1_iters_std, Lt1_time_mean, Lt1_time_std = batched_cg_scipy(A, b, P=P_Lt1, atol=atol, maxiter=pcg_iter)\n",
    "    print('  Lt1 is done')\n",
    "\n",
    "    # Lt5\n",
    "    _, Lt5_iters_mean, Lt5_iters_std, Lt5_time_mean, Lt5_time_std = batched_cg_scipy(A, b, P=P_Lt5, atol=atol, maxiter=pcg_iter)\n",
    "    print('  Lt5 is done')\n",
    "    \n",
    "    out = (\n",
    "        (Annz, I_iters_mean, I_iters_std, I_time_mean, I_time_std, [-1, -1]),\n",
    "        (L0nnz, L0_iters_mean, L0_iters_std, L0_time_mean, L0_time_std, comb_L0_time),\n",
    "        (Lt1nnz, Lt1_iters_mean, Lt1_iters_std, Lt1_time_mean, Lt1_time_std, comb_Lt1_time),\n",
    "        (Lt5nnz, Lt5_iters_mean, Lt5_iters_std, Lt5_time_mean, Lt5_time_std, comb_Lt5_time),        \n",
    "    )\n",
    "    return out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6741e9e1",
   "metadata": {},
   "outputs": [],
   "source": [
    "def print_results(prec_name, results, tol_ls=[1e-3, 1e-6, 1e-9, 1e-12]):\n",
    "    nnz, res_mean, res_std = results[0], results[1], results[2]\n",
    "    cg_time_mean, cg_time_std, prec_time = results[3], results[4], results[5]\n",
    "    \n",
    "    print(f'{prec_name}, nnz = {nnz[0]:.3f}±{nnz[1]:.3f}, prec time = {prec_time[0]:6f}±{prec_time[1]:6f}')\n",
    "    print('Residuals:   ', end='')\n",
    "    for i, t in enumerate(tol_ls):\n",
    "        mean_t, std_t = res_mean[i], res_std[i]\n",
    "        print(f'{t}: {mean_t:.0f}±{std_t:.2f}', end='; ')\n",
    "    print()\n",
    "    \n",
    "    print('Time:        ', end='')\n",
    "    for i, t in enumerate(tol_ls):\n",
    "        mean_t, std_t = cg_time_mean[i], cg_time_std[i]\n",
    "        print(f'{t}: {mean_t:.4f}±{std_t:.5f}', end='; ')\n",
    "    print(end='\\n\\n')\n",
    "    return"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c048a06e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def scipy_benchmark_all(params_grid, N, cg_iter, pcg_iter, atol):\n",
    "    for p in params_grid:\n",
    "        print(p)\n",
    "        A, _, b, _, _ = dataset_qtt(pde=p[0], grid=p[1], variance=p[2], lhs_type='fd', return_train=False, N_samples=N, precision='f64')\n",
    "        out = scipy_benchmark_linsys(A, b, cg_iter=cg_iter, pcg_iter=pcg_iter, atol=atol)\n",
    "        for i, name in enumerate(['I', 'L0', 'Lt1', 'Lt5']):\n",
    "            print_results(name, out[i])\n",
    "        print('------------------------------------------------------------------------------------------------\\n')\n",
    "    return"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2083616d",
   "metadata": {},
   "outputs": [],
   "source": [
    "params_grid = [\n",
    "    ['poisson',    32,  .1],\n",
    "    ['poisson',    64,  .1],\n",
    "    ['poisson',    128, .1],\n",
    "# \n",
    "    ['div_k_grad', 32,  .1],\n",
    "    ['div_k_grad', 64,  .1],\n",
    "    ['div_k_grad', 128, .1],\n",
    "#     \n",
    "    ['div_k_grad', 32,  .5],\n",
    "    ['div_k_grad', 64,  .5],\n",
    "    ['div_k_grad', 128, .5],\n",
    "#\n",
    "    ['div_k_grad', 32,  .7],\n",
    "    ['div_k_grad', 64,  .7],\n",
    "    ['div_k_grad', 128, .7]\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "62811049",
   "metadata": {},
   "outputs": [],
   "source": [
    "scipy_benchmark_all(params_grid, N=200, cg_iter=None, pcg_iter=350, atol=0)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
