{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ac22e2a8",
   "metadata": {},
   "outputs": [],
   "source": [
    "import warnings\n",
    "import sys\n",
    "import os\n",
    "\n",
    "warnings.filterwarnings('ignore')\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = '3'\n",
    "os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '.95'\n",
    "sys.path.append('..')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "941bf314",
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax.numpy as jnp\n",
    "from jax import random, vmap, clear_caches, jit\n",
    "import numpy as np\n",
    "\n",
    "import optax\n",
    "from equinox.nn import Conv1d\n",
    "import matplotlib.pyplot as plt\n",
    "from functools import partial\n",
    "from time import perf_counter\n",
    "import cloudpickle\n",
    "\n",
    "from data.dataset import dataset_qtt\n",
    "from linsolve.cg import ConjGrad\n",
    "from linsolve.precond import llt_prec_trig_solve, llt_inv_prec\n",
    "from model import MessagePassing, FullyConnectedNet, PrecNet, ConstantConv1d, MessagePassingWithDot, CorrectionNet\n",
    "\n",
    "from utils import params_count, asses_cond, iter_per_residual, batch_indices\n",
    "from data.graph_utils import direc_graph_from_linear_system_sparse\n",
    "from train import train\n",
    "\n",
    "plt.rcParams['figure.figsize'] = (11, 7)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6a71ccae",
   "metadata": {},
   "source": [
    "# Train/retrain/overwrite"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "139424da",
   "metadata": {},
   "outputs": [],
   "source": [
    "save_path = '..'\n",
    "model_name = '..'\n",
    "train_from_scratch = True"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5ec2d85a",
   "metadata": {},
   "source": [
    "# Setup experiment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 66,
   "id": "70273935",
   "metadata": {},
   "outputs": [],
   "source": [
    "pde = 'poisson'   \n",
    "grid = 128    \n",
    "variance = .5 \n",
    "lhs_type = 'fd'      \n",
    "N_train = 1000\n",
    "N_test = 200\n",
    "precision = 'f64'\n",
    "\n",
    "fill_factor = 1     # int\n",
    "threshold = 1e-4     # float\n",
    "power = 2            # int\n",
    "N_valid_CG = 300     # Number of CG iterations for validation in the very end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 67,
   "id": "65c61b31",
   "metadata": {},
   "outputs": [],
   "source": [
    "with_cond = False\n",
    "layer_ = Conv1d\n",
    "# layer_ = ConstantConv1d         # 'ConstantConv1d' to make a \"zero\" NN initialization; 'Conv1d' to make a random initialization\n",
    "alpha = jnp.array([0.])\n",
    "\n",
    "loss_type = 'llt'               "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 68,
   "id": "22f31860",
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_size = 8\n",
    "epoch_num = 1000\n",
    "lr = 1e-3\n",
    "schedule_params = None #[1700, 2001, 300, 1e-1]    # [start, stop, step, decay_size]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 69,
   "id": "c64831c6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optionally replace the constant learning rate with a piecewise-constant schedule.\n",
    "# `schedule_params` is laid out as [start, stop, step, decay_size] with start/stop/step given in epochs.\n",
    "# NOTE: `lr` is overwritten with the schedule below, so re-run the cell that sets `lr` before re-running this one.\n",
    "if schedule_params is not None:\n",
    "    assert len(schedule_params) == 4\n",
    "\n",
    "    start, stop, step, decay_size = schedule_params\n",
    "    # The schedule is indexed by optimizer step, so convert epoch counts into step counts.\n",
    "    steps_per_batch = N_train // batch_size\n",
    "    start, stop, step = start*steps_per_batch, stop*steps_per_batch, step*steps_per_batch\n",
    "    # Multiply the learning rate by `decay_size` at every boundary; keys are plain Python ints.\n",
    "    lr = optax.piecewise_constant_schedule(\n",
    "        lr,\n",
    "        {int(boundary): decay_size for boundary in np.arange(start, stop, step)}\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 70,
   "id": "f3edae9e",
   "metadata": {},
   "outputs": [],
   "source": [
    "model_config = {\n",
    "    'node_enc': {\n",
    "        'features': [1, 16, 16],\n",
    "        'N_layers': 2,\n",
    "        'layer_': layer_\n",
    "    },\n",
    "    'edge_enc': {\n",
    "        'features': [1, 16, 16],\n",
    "        'N_layers': 2,\n",
    "        'layer_': layer_\n",
    "    },\n",
    "    'edge_dec': {\n",
    "        'features': [16, 16, 1],\n",
    "        'N_layers': 2,\n",
    "        'layer_': layer_\n",
    "    },\n",
    "    'mp': {\n",
    "        'edge_upd': {\n",
    "            'features': [48, 16, 16],\n",
    "            'N_layers': 2,\n",
    "            'layer_': layer_\n",
    "        },\n",
    "        'node_upd': {\n",
    "            'features': [32, 16, 16],\n",
    "            'N_layers': 2,\n",
    "            'layer_': layer_\n",
    "        },\n",
    "        'mp_rounds': 5\n",
    "    }\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "519031e8",
   "metadata": {},
   "source": [
    "# Make dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6a1c444c",
   "metadata": {},
   "outputs": [],
   "source": [
    "s1 = perf_counter()\n",
    "\n",
    "# Arguments shared by the train and test splits.\n",
    "common_args = (pde, grid, variance, lhs_type)\n",
    "common_kwargs = dict(fill_factor=fill_factor, threshold=threshold, power=power, precision=precision)\n",
    "\n",
    "A_train, A_pad_train, b_train, u_exact_train, bi_edges_train = dataset_qtt(\n",
    "    *common_args, return_train=True, N_samples=N_train, **common_kwargs\n",
    ")\n",
    "A_test, A_pad_test, b_test, u_exact_test, bi_edges_test = dataset_qtt(\n",
    "    *common_args, return_train=False, N_samples=N_test, **common_kwargs\n",
    ")\n",
    "\n",
    "# Elapsed wall-clock time for building both splits.\n",
    "print(perf_counter() - s1)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f315d8a6",
   "metadata": {},
   "source": [
    "# Train model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0a0b6c6b",
   "metadata": {},
   "outputs": [],
   "source": [
    "seed = 42\n",
    "# Give every sub-network its own PRNG key. Passing the same key to all of them would presumably make\n",
    "# identically shaped networks (the node and edge encoders share `features`) start from the same weights.\n",
    "node_enc_key, edge_enc_key, edge_dec_key, edge_upd_key, node_upd_key = random.split(random.PRNGKey(seed), 5)\n",
    "\n",
    "NodeEncoder = FullyConnectedNet(**model_config['node_enc'], key=node_enc_key)\n",
    "EdgeEncoder = FullyConnectedNet(**model_config['edge_enc'], key=edge_enc_key)\n",
    "EdgeDecoder = FullyConnectedNet(**model_config['edge_dec'], key=edge_dec_key)\n",
    "\n",
    "# Take the number of message-passing rounds from the config so there is a single source of truth.\n",
    "mp_rounds = model_config['mp']['mp_rounds']\n",
    "MessagePass = MessagePassing(\n",
    "    update_edge_fn=FullyConnectedNet(**model_config['mp']['edge_upd'], key=edge_upd_key),\n",
    "    update_node_fn=FullyConnectedNet(**model_config['mp']['node_upd'], key=node_upd_key),\n",
    "    mp_rounds=mp_rounds\n",
    ")\n",
    "\n",
    "model = PrecNet(NodeEncoder=NodeEncoder, EdgeEncoder=EdgeEncoder,\n",
    "                EdgeDecoder=EdgeDecoder, MessagePass=MessagePass)\n",
    "\n",
    "# Alternative model that additionally takes `alpha`:\n",
    "# model = CorrectionNet(NodeEncoder=NodeEncoder, EdgeEncoder=EdgeEncoder,\n",
    "#                 EdgeDecoder=EdgeDecoder, MessagePass=MessagePass, alpha=alpha)\n",
    "\n",
    "print(f'Parameter number: {params_count(model)}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 74,
   "id": "394d3287",
   "metadata": {},
   "outputs": [],
   "source": [
    "# data = (X_train, X_test, y_train, y_test)\n",
    "data = (\n",
    "    [A_train, A_pad_train, b_train, bi_edges_train, u_exact_train],\n",
    "    [A_test, A_pad_test, b_test, bi_edges_test, u_exact_test],\n",
    "    jnp.array([1]), jnp.array([1])\n",
    ")\n",
    "train_config = {\n",
    "    'optimizer': optax.adam,\n",
    "    'lr': lr,\n",
    "    'optim_params': {},#{'weight_decay': 1e-8}, \n",
    "    'epoch_num': epoch_num,\n",
    "    'batch_size': batch_size,\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "63e23084",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the checkpoint path once; os.path.join supplies the separator, so `save_path`\n",
    "# works with or without a trailing slash.\n",
    "model_file = os.path.join(save_path, model_name + '.pkl')\n",
    "\n",
    "if train_from_scratch:\n",
    "    # Train, record the wall-clock training time in `dt`, then persist the trained model.\n",
    "    s = perf_counter()\n",
    "    model, losses = train(model, data, train_config, loss_name=loss_type, repeat_step=1, with_cond=with_cond)\n",
    "    dt = perf_counter() - s\n",
    "\n",
    "    with open(model_file, 'wb') as f:\n",
    "        cloudpickle.dump(model, f)\n",
    "else:\n",
    "    # NOTE: unpickling can execute arbitrary code -- only load checkpoints from a trusted source.\n",
    "    with open(model_file, 'rb') as f:\n",
    "        model = cloudpickle.load(f)\n",
    "    # Nothing was trained in this branch, so loss history and training time are placeholders.\n",
    "    losses, dt = [np.nan], np.nan"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bb27c1eb",
   "metadata": {},
   "outputs": [],
   "source": [
    "dt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f8c8c1aa",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Only CorrectionNet is constructed with `alpha` in this notebook; fall back to None so the cell\n",
    "# does not raise AttributeError for a model without that attribute.\n",
    "print('alpha:', getattr(model, 'alpha', None))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d7a4ee1a",
   "metadata": {},
   "source": [
    "## Forward call time"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 75,
   "id": "72d2b4c3",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _time_jitted_call(fn, args, num_rounds):\n",
    "    \"\"\"Time `fn(*args)` and return (mean, std) of the per-call wall-clock time in seconds.\n",
    "\n",
    "    One untimed warm-up call is made first so that compilation is not part of the measurements,\n",
    "    followed by `num_rounds` timed calls.\n",
    "    \"\"\"\n",
    "    # Local import: the notebook only imports selected names from jax.\n",
    "    from jax import block_until_ready\n",
    "\n",
    "    block_until_ready(fn(*args))\n",
    "\n",
    "    t_ls = []\n",
    "    for _ in range(num_rounds):\n",
    "        s = perf_counter()\n",
    "        # JAX dispatches work asynchronously; wait for the result so the timer covers\n",
    "        # the computation itself rather than only the dispatch.\n",
    "        block_until_ready(fn(*args))\n",
    "        t_ls.append(perf_counter() - s)\n",
    "    return np.mean(t_ls), np.std(t_ls)\n",
    "\n",
    "\n",
    "def calc_forward(model, A_i, b_i, bi_edges_i, num_rounds=200):\n",
    "    \"\"\"Time the jitted forward pass `model((nodes, edges, receivers, senders), bi_edges_i)`.\n",
    "\n",
    "    Args:\n",
    "        model: callable invoked with a 4-tuple of graph components and `bi_edges_i`.\n",
    "        A_i, b_i: a single linear system; a leading axis is added before the graph is built\n",
    "            and stripped again from the resulting graph components.\n",
    "        bi_edges_i: passed to the model unchanged.\n",
    "        num_rounds: number of timed calls.\n",
    "\n",
    "    Returns:\n",
    "        Tuple (mean, std) of the per-call time in seconds.\n",
    "    \"\"\"\n",
    "    nodes_i, edges_i, receivers_i, senders_i, _ = direc_graph_from_linear_system_sparse(A_i[None, ...], b_i[None, ...])\n",
    "    jit_model = jit(lambda t1, t2, t3, t4, t5: model((t1, t2, t3, t4), t5))\n",
    "    # Slice the single sample out once so that indexing is not part of the timed region.\n",
    "    args = (nodes_i[0, ...], edges_i[0, ...], receivers_i[0, ...], senders_i[0, ...], bi_edges_i)\n",
    "    return _time_jitted_call(jit_model, args, num_rounds)\n",
    "\n",
    "\n",
    "def calc_forward_PrecNet(model, A_i, A_lhs_i, b_i, bi_edges_i, num_rounds=200):\n",
    "    \"\"\"Time the jitted forward pass of a model that also receives a second (lhs) graph.\n",
    "\n",
    "    The model is invoked as `model(graph, bi_edges_i, lhs_graph)`, where `graph` is built from\n",
    "    (`A_i`, `b_i`) and `lhs_graph` from (`A_lhs_i`, `b_i`); each graph is a 4-tuple\n",
    "    (nodes, edges, receivers, senders).\n",
    "\n",
    "    Args:\n",
    "        model: callable with the three-argument signature described above.\n",
    "        A_i, A_lhs_i, b_i: single-sample inputs; a leading axis is added before the graphs are\n",
    "            built and stripped again from the resulting graph components.\n",
    "        bi_edges_i: passed to the model unchanged.\n",
    "        num_rounds: number of timed calls.\n",
    "\n",
    "    Returns:\n",
    "        Tuple (mean, std) of the per-call time in seconds.\n",
    "    \"\"\"\n",
    "    nodes_i, edges_i, receivers_i, senders_i, _ = direc_graph_from_linear_system_sparse(A_i[None, ...], b_i[None, ...])\n",
    "    lhs_nodes_i, lhs_edges_i, lhs_receivers_i, lhs_senders_i, _ = direc_graph_from_linear_system_sparse(A_lhs_i[None, ...], b_i[None, ...])\n",
    "\n",
    "    jit_model = jit(lambda t1, t2, t3, t4, t5, t6, t7, t8, t9: model((t1, t2, t3, t4), t5, (t6, t7, t8, t9)))\n",
    "    # Slice the single sample out once so that indexing is not part of the timed region.\n",
    "    args = (\n",
    "        nodes_i[0, ...], edges_i[0, ...], receivers_i[0, ...], senders_i[0, ...], bi_edges_i,\n",
    "        lhs_nodes_i[0, ...], lhs_edges_i[0, ...], lhs_receivers_i[0, ...], lhs_senders_i[0, ...],\n",
    "    )\n",
    "    return _time_jitted_call(jit_model, args, num_rounds)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 76,
   "id": "3a0210ae",
   "metadata": {},
   "outputs": [],
   "source": [
    "i = 0\n",
    "# Time the forward pass on test sample `i`; `A_pad_test` is passed as `A_i` and `A_test` as `A_lhs_i`.\n",
    "mean_, std_ = calc_forward_PrecNet(\n",
    "    model,\n",
    "    A_pad_test[i, ...],\n",
    "    A_test[i, ...],\n",
    "    b_test[i, ...],\n",
    "    bi_edges_test[i, ...],\n",
    "    num_rounds=200,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a728bb25",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(f'Grid = {grid}, lhs_type = {lhs_type}. Average over 200 forward calls {mean_:.6f}±{std_:.6f}')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e7c02c19",
   "metadata": {},
   "source": [
    "## Make precs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "23575eb3",
   "metadata": {},
   "outputs": [],
   "source": [
    "nodes, edges, receivers, senders, _ = direc_graph_from_linear_system_sparse(A_pad_test, b_test)\n",
    "\n",
    "# Map the model over the leading axis of every graph component and of `bi_edges_test`.\n",
    "# NOTE(review): the timing cell calls the model with an additional lhs-graph argument, while this\n",
    "# call passes only two arguments -- confirm which signature the current model expects.\n",
    "batched_model = vmap(model, in_axes=((0, 0, 0, 0), 0), out_axes=0)\n",
    "L = batched_model((nodes, edges, receivers, senders), bi_edges_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "34acf438",
   "metadata": {},
   "outputs": [],
   "source": [
    "# `losses` is indexed below as [train loss, test loss, test cond of P^{-1}A], one value per epoch.\n",
    "# When the model was loaded from disk it is only the placeholder `[np.nan]`, which cannot be plotted.\n",
    "if len(losses) < 3:\n",
    "    print('No training history available (model was loaded from disk); skipping loss plots.')\n",
    "else:\n",
    "    epochs = range(len(losses[0]))\n",
    "    _, axes = plt.subplots(1, 2, figsize=(15, 5))\n",
    "\n",
    "    # Left panel: train and test loss per epoch.\n",
    "    axes[0].plot(epochs, losses[1], label='Test')\n",
    "    axes[0].plot(epochs, losses[0], label='Train')\n",
    "    axes[0].legend()\n",
    "    axes[0].set_yscale('log')\n",
    "    axes[0].set_xlabel('Epoch')\n",
    "    axes[0].set_ylabel('Loss')\n",
    "    axes[0].grid()\n",
    "\n",
    "    # Right panel: condition number of the preconditioned system on the test set.\n",
    "    axes[1].plot(epochs, losses[2], label='Test')\n",
    "    axes[1].legend()\n",
    "    axes[1].set_yscale('log')\n",
    "    axes[1].set_xlabel('Epoch')\n",
    "    axes[1].set_ylabel('Cond of $P^{-1}A$')\n",
    "    axes[1].grid()\n",
    "\n",
    "    plt.tight_layout()\n",
    "\n",
    "    print(f'Final values\\n  train loss: {losses[0][-1]:.4f}\\n   test loss: {losses[1][-1]:.4f}\\n    LLT cond: {losses[2][-1]:.0f}')\n",
    "    print(f'\\nMinimum test loss `{jnp.min(losses[1]).item():.4f}` at epoch `{jnp.argmin(losses[1]).item():.0f}`')\n",
    "    print(f'\\nMinimum test P^(-1)A cond `{jnp.min(losses[2]).item():.0f}` at epoch `{jnp.argmin(losses[2]).item():.0f}`')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fbc5a1a3",
   "metadata": {},
   "outputs": [],
   "source": [
    "losses[1][500], losses[1][-1]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c57af87d",
   "metadata": {},
   "source": [
    "# Apply model to CG"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "3f889929",
   "metadata": {},
   "outputs": [],
   "source": [
    "from linsolve.scipy_linsolve import batched_cg_scipy, make_Chol_prec_from_bcoo, cg_scipy\n",
    "from utils import jBCOO_to_scipyCSR\n",
    "\n",
    "import scipy.sparse.linalg as sci_sp_linalg\n",
    "import scipy.linalg as sci_linalg"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "09051e58",
   "metadata": {},
   "outputs": [],
   "source": [
    "P = make_Chol_prec_from_bcoo(L)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "7f06b64b",
   "metadata": {},
   "outputs": [],
   "source": [
    "_, iters_mean, iters_std, time_mean, time_std = batched_cg_scipy(A_test, b_test, P=P, atol=1e-12, maxiter=N_valid_CG)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "918e8e91",
   "metadata": {},
   "outputs": [],
   "source": [
    "iters_mean"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7933a29c",
   "metadata": {},
   "source": [
    "# Spectrum"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8accb81a",
   "metadata": {
    "collapsed": true,
    "jupyter": {
     "outputs_hidden": true
    }
   },
   "outputs": [],
   "source": [
    "i = 0\n",
    "\n",
    "print(i)\n",
    "# Densify as plain ndarrays (`toarray`) rather than `np.matrix` objects (`todense`).\n",
    "A_i = jBCOO_to_scipyCSR(A_test[i, ...]).toarray(order='C')\n",
    "L_i = jBCOO_to_scipyCSR(L[i, ...])\n",
    "# Preconditioner assembled from its factor: P = L L^T.\n",
    "P_i = (L_i @ L_i.T).toarray(order='C')\n",
    "print(' A, L and P are combined')\n",
    "\n",
    "Pinv_i = sci_linalg.inv(P_i)\n",
    "Ainv_i = sci_linalg.inv(A_i)\n",
    "print(' A and P are inverted')\n",
    "\n",
    "Pinv_A_i = Pinv_i @ A_i\n",
    "\n",
    "# Condition estimate of P^{-1}A: ratio of the largest to the smallest eigenvalue magnitude.\n",
    "eigen_i = np.abs(sci_linalg.eigvals(Pinv_A_i))\n",
    "eigen_min_i = np.min(eigen_i)\n",
    "eigen_max_i = np.max(eigen_i)\n",
    "cond_i = eigen_max_i / eigen_min_i\n",
    "print(' Eigenvalues are calculated')\n",
    "\n",
    "\n",
    "# loss = ||P A^{-1} - I||_F^2; the remaining quantities are norm-based bounds that are\n",
    "# printed next to the computed extreme eigenvalues for comparison.\n",
    "P_Ainv_i = P_i @ Ainv_i\n",
    "sqrt_loss = np.linalg.norm(P_Ainv_i - np.eye(P_i.shape[0]), ord='fro')\n",
    "loss_i = np.square(sqrt_loss)\n",
    "min_bound_i = 1 / (np.linalg.norm(P_Ainv_i, ord='fro'))\n",
    "min_bound_long_i = 1 / (sqrt_loss + 1)\n",
    "max_bound_i = np.linalg.norm(A_i - P_i, ord=2) * np.linalg.norm(Pinv_i, ord=2) + 1\n",
    "print(' Norms are calculated')\n",
    "\n",
    "print(' All done')\n",
    "print(f'  cond = {cond_i:.5f}, lambda_min = {eigen_min_i:.5f}, lambda_max = {eigen_max_i:.5f}')\n",
    "print(f'  loss = {loss_i:.5f}, lambda_max_bound = {max_bound_i:.5f}')\n",
    "print(f'  lambda_min_bound = {min_bound_i:.5f}, lambda_min_long_bound = {min_bound_long_i:.5f}')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
