"""A set of tools to generate and solve linear ODE systems."""
from typing import Tuple, Union, Text, Sequence, Dict

import pathlib
import concurrent.futures

import h5py
import jax
from jax import random as jrandom
import jax.numpy as jnp
import numpy as np
from jax.experimental.ode import odeint
from tqdm import tqdm


# PRNG keys are annotated as plain ``jax.Array``. The dedicated
# ``jax.random.PRNGKeyArray`` alias was deprecated and later removed from JAX,
# so referencing it makes this module fail at import time on current releases.
Key = jax.Array
# A filesystem location, given either as a string or as a ``pathlib.Path``.
Path = Union[Text, pathlib.Path]


def generate_sparse_system_classes(
    key: Key,
    dim: int,
    frac_zeros: float,
    method: str = "independent",
) -> Union[jax.Array, Tuple[jax.Array, jax.Array]]:
  """Generate a random sparse square system matrix.

  Args:
    key: PRNG key. It is split internally so that the sparsity mask and the
      entry values are drawn from separate random streams.
    dim: Number of rows and columns of the matrix.
    frac_zeros: Probability, in [0, 1], with which each entry is zeroed out.
    method: Sampling scheme. Only "independent" is implemented: every entry
      is kept independently with probability `1 - frac_zeros`.

  Returns:
    The elementwise product of a `(dim, dim)` Bernoulli keep-mask and the
    values returned by `sample_gaussian(key, (dim, dim))`.
    NOTE(review): the previous docstring promised "a minimum abs value for all
    entries"; presumably `sample_gaussian` enforces that, but it is not
    defined in the visible part of the file -- confirm.

  Raises:
    AssertionError: If `frac_zeros` is outside [0, 1].
    ValueError: If `method` is not a known sampling scheme (previously the
      function silently returned `None` in that case).
  """
  assert 0.0 <= frac_zeros <= 1.0, (
    f"frac_zeros must be in [0, 1], got {frac_zeros}")

  if method != "independent":
    raise ValueError(
      f"Unknown method {method!r}; only 'independent' is implemented.")

  # Never feed the same key to two samplers: both would consume the same
  # random bits, which can correlate the sparsity pattern with the values.
  mask_key, value_key = jrandom.split(key)
  keep_mask = jrandom.bernoulli(mask_key, p=1 - frac_zeros, shape=(dim, dim))
  gaussian_vals = sample_gaussian(value_key, (dim, dim))
  return keep_mask * gaussian_vals



def generate_iv(
    key: Key,
    dim: int,
    unit_norm: bool = False,
    sigma: float = 5.0
) -> jax.Array:
  """Draw a random initial value for the ODE.

  Args:
    key: PRNG key used for the normal draw.
    dim: Length of the initial-value vector.
    unit_norm: If True, rescale the sample to Euclidean norm one.
    sigma: Factor multiplying the standard-normal sample, i.e. the standard
      deviation of each component before any normalisation.

  Returns:
    A vector of shape `(dim,)`.
  """
  sample = jrandom.normal(key, (dim,)) * sigma
  # Normalisation is optional; otherwise the scaled draw is returned as is.
  return sample / jnp.linalg.norm(sample) if unit_norm else sample

def candidate_problems(
    key: Key,
    dim: int,
    frac_zeros: float,
    num_systems: int,
    num_iv_per_system: int,
    class_A: str = 'independent',
    unit_norm: bool = True
) -> Tuple[jax.Array, jax.Array]:
  """Sample a batch of system matrices together with initial values.

  Args:
    key: PRNG key from which all per-system and per-IV keys are derived.
    dim: Dimension of every system and every initial value.
    frac_zeros: Forwarded to `generate_sparse_system_classes`.
    num_systems: Number of system matrices to generate.
    num_iv_per_system: Number of initial values generated per system.
    class_A: Sampling method forwarded to `generate_sparse_system_classes`.
    unit_norm: Forwarded to `generate_iv`.

  Returns:
    A tuple `(A, x0)`: `A` stacks the `num_systems` generated matrices along
    a new leading axis, and `x0` stacks the
    `num_systems * num_iv_per_system` initial values into a flat batch of
    shape `(num_systems * num_iv_per_system, dim)`.
  """
  # One key per system; the first split result is carried on for the IVs.
  key, *system_keys = jrandom.split(key, num_systems + 1)
  A = jnp.stack([
    generate_sparse_system_classes(system_key, dim, frac_zeros, class_A)
    for system_key in system_keys
  ])

  # One key per initial value, all derived from the carried-over key.
  iv_keys = jrandom.split(key, num_systems * num_iv_per_system)
  x0 = jnp.stack([generate_iv(iv_key, dim, unit_norm) for iv_key in iv_keys])
  return A, x0


def _linear_rhs(x, t, system):  # pylint: disable=unused-argument
  """Right-hand side of dx/dt = system @ x (time-invariant, so t is unused)."""
  return jnp.dot(system, x)


def _solve_single(system, iv, tt):
  """Integrate one system from one initial value; result is (dim, steps)."""
  return odeint(_linear_rhs, iv, tt, system).T


# Inner vmap: all initial values of a single system. Outer vmap: all systems.
# The jitted function is created once at module level so that repeated calls
# to `solve_system` with equal shapes reuse the compiled computation instead
# of re-tracing a freshly created closure every time.
_solve_batched = jax.jit(
  jax.vmap(
    jax.vmap(_solve_single, in_axes=(None, 0, None)),
    in_axes=(0, 0, None),
  )
)


def solve_system(
    A: jax.Array,
    x0: jax.Array,
    steps: int = 128,
    interval: Tuple[float, float] = (0.0, 1.0)
) -> Tuple[jax.Array, jax.Array]:
  """Solve linear ordinary differential equations from initial conditions.

  Integrates dx/dt = A x for every system and each of its initial values.

  Args:
    A: A single `(dim, dim)` matrix or a batch `(n_systems, dim, dim)`.
    x0: A single `(dim,)` initial value or a flat batch `(n_ivs, dim)`, where
      `n_ivs` is a multiple of `n_systems`. Rows
      `i * k, ..., (i + 1) * k - 1` (with `k = n_ivs // n_systems`) are the
      initial values of system `i`.
    steps: Number of equally spaced time points at which to report the
      solution.
    interval: `(start, end)` of the time grid; both end points are included.

  Returns:
    A tuple `(xs, tt)` where `xs` has shape
    `(n_systems, n_ivs // n_systems, dim, steps)` and `tt` is the time grid of
    shape `(steps,)`.

  Raises:
    AssertionError: If the number of initial values is not a multiple of the
      number of systems.
  """
  # Promote single-problem inputs to batches of size one.
  if jnp.ndim(A) == 2:
    A = jnp.expand_dims(A, axis=0)
  if jnp.ndim(x0) == 1:
    x0 = jnp.expand_dims(x0, axis=0)
  n_systems = A.shape[0]
  n_ivs = x0.shape[0]
  assert n_ivs % n_systems == 0, f'{n_ivs} IVs for {n_systems} systems'
  num_iv_per_system = n_ivs // n_systems
  tt = jnp.linspace(*interval, steps)

  # Group the flat IV batch by system: row i * k + j becomes entry [i, j].
  x0_grouped = jnp.reshape(
    x0, (n_systems, num_iv_per_system) + tuple(x0.shape[1:]))

  xs = _solve_batched(A, x0_grouped, tt)
  return xs, tt


def store_systems_and_iv(path: Path, A: jax.Array, x0: jax.Array) -> None:
  """Store the system matrices and initial conditions to disk in hdf5 format.

  The file is opened in append mode and two datasets are created: "A" and
  "x0", each with a human-readable "comment" attribute. Missing parent
  directories are created first.

  Args:
    path: Target hdf5 file, given as a string or a `pathlib.Path`.
    A: Batch of system matrices; the first axis indexes systems and the
      second axis is taken as the system dimension.
    x0: Flat batch of initial values whose first axis has
      `num_systems * num_iv_per_system` entries.
  """
  # `Path` also admits plain strings, which have no `.parent`; normalise
  # first.
  path = pathlib.Path(path)
  path.parent.mkdir(parents=True, exist_ok=True)
  num_systems, dim = A.shape[0], A.shape[1]
  num_iv_per_sys = x0.shape[0] // num_systems
  with h5py.File(path, "a") as f:
    dsetA = f.create_dataset("A", data=np.array(A))
    dsetA.attrs["comment"] = f"{num_systems} systems of dimension {dim}"
    dsetx0 = f.create_dataset("x0", data=np.array(x0))
    dsetx0.attrs["comment"] = f"{num_iv_per_sys} IVs for {num_systems} systems"


def store_solutions(path: Path, xs: jax.Array, tt: jax.Array) -> None:
  """Store the solutions to disk in hdf5 format.

  The file is opened in append mode and two datasets are created: "xs" and
  "tt", each with a human-readable "comment" attribute. Missing parent
  directories are created first.

  Args:
    path: Target hdf5 file, given as a string or a `pathlib.Path`.
    xs: Solutions of shape `(num_systems, num_iv_per_system, dim, steps)`.
    tt: The time grid on which the solutions were evaluated.
  """
  # `Path` also admits plain strings, which have no `.parent`; normalise
  # first.
  path = pathlib.Path(path)
  path.parent.mkdir(parents=True, exist_ok=True)
  num_systems, num_iv_per_system, dim, steps = xs.shape
  with h5py.File(path, "a") as f:
    dsetxs = f.create_dataset("xs", data=np.array(xs))
    # Adjacent f-strings are concatenated; the former backslash continuations
    # inside a single literal leaked the source indentation into the text.
    dsetxs.attrs["comment"] = (
      f"{num_systems} systems, {num_iv_per_system} IVs "
      f"per system, in {dim} dimensions, solved for "
      f"{steps} steps"
    )
    dsettt = f.create_dataset("tt", data=np.array(tt))
    dsettt.attrs["comment"] = f"the grid of {steps} time steps"


def load_data(path: Path) -> Dict[Text, np.ndarray]:
  """Read every top-level entry of an hdf5 file into memory.

  Args:
    path: Location of the hdf5 file to open read-only.

  Returns:
    A dictionary mapping each top-level key of the file to a NumPy array
    built from the corresponding entry.
  """
  with h5py.File(path, 'r') as f:
    # Convert inside the `with` block so the data is copied out before the
    # file handle is closed.
    return {name: np.array(f[name]) for name in list(f.keys())}


if __name__ == '__main__':
  # Small end-to-end run: sample problems, write them to disk, solve them,
  # and append the solutions to the same file in the working directory.
  jax.config.update('jax_enable_x64', True)
  outpath = pathlib.Path.cwd().joinpath('data.h5')

  # 5 systems of dimension 20, two unit-norm initial values each.
  A_test, x0_test = candidate_problems(
    jrandom.PRNGKey(0),
    dim=20,
    frac_zeros=0.5,
    num_systems=5,
    num_iv_per_system=2,
    unit_norm=True,
  )
  store_systems_and_iv(outpath, A_test, x0_test)

  # Solve with the default time grid, then store solutions next to the inputs.
  xs_test, tt_test = solve_system(A_test, x0_test)
  store_solutions(outpath, xs_test, tt_test)
