{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "c18b0eae",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "import warnings\n",
    "\n",
    "# Configure GPU selection and the XLA memory fraction before JAX is imported,\n",
    "# so the settings are in place no matter when JAX reads them.\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = '3'\n",
    "os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '.95'\n",
    "\n",
    "import jax\n",
    "\n",
    "# Uncomment to enable 64-bit (x64) mode in JAX.\n",
    "# jax.config.update(\"jax_enable_x64\", True)\n",
    "\n",
    "# Silences every warning for the rest of the session.\n",
    "warnings.filterwarnings('ignore')\n",
    "\n",
    "# Make the parent directory importable; skip the append when the cell is re-run.\n",
    "if '..' not in sys.path:\n",
    "    sys.path.append('..')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "b0bfc0e9",
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax.numpy as jnp\n",
    "from jax import random, vmap\n",
    "from jax.experimental import sparse as jsparse\n",
    "from jax.scipy.sparse.linalg import cg as jcg\n",
    "\n",
    "import numpy as np\n",
    "from scipy.sparse.linalg import LinearOperator, cg\n",
    "from scipy.sparse import tril, triu, eye\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "from functools import partial\n",
    "from time import perf_counter\n",
    "import ilupp\n",
    "\n",
    "from linsolve.precond import llt_prec_trig_solve, llt_inv_prec\n",
    "from utils import factorsILUp, batchedILUp, ILU2, jILU2, iter_per_residual, jBCOO_to_scipyCSR\n",
    "from linsolve.cg import ConjGrad\n",
    "from data.qtt import poisson, div_k_grad, scipy_validation, solve_precChol, solve_precLU"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "50213e7a",
   "metadata": {},
   "source": [
    "## Poisson and div-k-grad datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "8ba9ffa4",
   "metadata": {},
   "outputs": [],
   "source": [
    "path = '..'\n",
    "save_dir = os.path.join(path, 'paper_datasets')\n",
    "\n",
    "# Create the output directory if it does not exist yet; any other failure\n",
    "# (e.g. missing permissions) is raised instead of being silently ignored.\n",
    "os.makedirs(save_dir, exist_ok=True)\n",
    "\n",
    "# Grid sizes iterated by the div-k-grad generation loop.\n",
    "grid_ls = [32, 64, 128]\n",
    "# First argument passed to the dataset generators for the train / test split.\n",
    "N_train = 1000\n",
    "N_test = 200"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "26c553f5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Poisson: generate a train and a test split per grid and store each as .npz.\n",
    "# NOTE: this loop runs over the grids listed here, not over grid_ls.\n",
    "\n",
    "for g in [512, 1024]:\n",
    "    train_file = os.path.join(save_dir, f'poisson{g}_train.npz')\n",
    "    test_file = os.path.join(save_dir, f'poisson{g}_test.npz')\n",
    "\n",
    "    A_train, b_train, x_train = poisson(N_train, g)\n",
    "    A_test, b_test, x_test = poisson(N_test, g)\n",
    "\n",
    "    train_arrays = dict(Aval=A_train.data, Aind=A_train.indices, b=b_train, x=x_train)\n",
    "    test_arrays = dict(Aval=A_test.data, Aind=A_test.indices, b=b_test, x=x_test)\n",
    "\n",
    "    jnp.savez(train_file, **train_arrays)\n",
    "    jnp.savez(test_file, **test_arrays)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "1d9f0875",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Div-k-grad\n",
    "# Row i of boundaries_ls is paired with grid_ls[i]; inside a row, entry j is\n",
    "# passed as `bounds` together with var_ls[j].\n",
    "boundaries_ls = [\n",
    "    [[5, 11], [35, 180], [180, 700]],\n",
    "    [[5, 12], [45, 200], [200, 750]],\n",
    "    [[6, 14], [50, 300], [300, 800]],\n",
    "]\n",
    "var_ls = [0.1, 0.5, 0.7]\n",
    "\n",
    "\n",
    "def _save_split(file_path, A, b, x, k_stats):\n",
    "    \"\"\"Save A.data, A.indices, b, x and the min/max/mean of k_stats to file_path.\"\"\"\n",
    "    jnp.savez(file_path, Aval=A.data, Aind=A.indices, b=b, x=x,\n",
    "              kmin=k_stats['min'], kmax=k_stats['max'], kmean=k_stats['mean'])\n",
    "\n",
    "\n",
    "for g, bound_ls in zip(grid_ls, boundaries_ls):\n",
    "    for var_i, b_ in zip(var_ls, bound_ls):\n",
    "        stem = os.path.join(save_dir, f'div_k_grad{g}_Gauss{var_i}')\n",
    "        # Build both splits first, then write them as two separate archives.\n",
    "        train_split = div_k_grad(N_train, g, bounds=b_, cov_model='gaussian', var=var_i)\n",
    "        test_split = div_k_grad(N_test, g, bounds=b_, cov_model='gaussian', var=var_i)\n",
    "        _save_split(stem + '_train.npz', *train_split)\n",
    "        _save_split(stem + '_test.npz', *test_split)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fc892091",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Print the k statistics stored in every div-k-grad archive, grouped by grid size.\n",
    "stats_grids = [32, 64, 128, 256]\n",
    "\n",
    "# os.listdir returns bare names, so resolve them against save_dir before the\n",
    "# directory check; sorting keeps the report order stable.\n",
    "dataset_names = sorted(\n",
    "    name for name in os.listdir(save_dir)\n",
    "    if not os.path.isdir(os.path.join(save_dir, name))\n",
    ")\n",
    "\n",
    "for group_idx, grid in enumerate(stats_grids):\n",
    "    if group_idx > 0:\n",
    "        print()  # blank line between grid sizes\n",
    "    for name in dataset_names:\n",
    "        if f'div_k_grad{grid}' not in name:\n",
    "            continue\n",
    "        if 'train' in name:\n",
    "            split = 'train,'\n",
    "        elif 'test' in name:\n",
    "            split = 'test, '  # padded so the columns line up with 'train,'\n",
    "        else:\n",
    "            continue\n",
    "        # The three characters after 'Gauss' in the file name are the variance tag.\n",
    "        var = name.split('Gauss')[1][:3]\n",
    "        # Close the .npz archive as soon as the three scalars have been printed.\n",
    "        with jnp.load(os.path.join(save_dir, name)) as stats:\n",
    "            print(f\"Grid {grid}, {split} var {var}. Min k: {stats['kmin']:.0f}, max k: {stats['kmax']:.0f}, mean k: {stats['kmean']:.0f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "87ecfc27",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
