from absl import logging
from time import time
from typing import Union, Sequence, Callable, Any

import flax.typing
import jax
import jax.numpy as jnp
import numpy as np
import scipy.interpolate
import scipy.fft

# Array inputs accepted by the helpers in this module: JAX or NumPy arrays.
Array = Union[jnp.ndarray, np.ndarray]
# Same runtime type as `Array`; the separate name presumably documents arrays
# holding scalar quantities.
ScalarArray = Array
# Arbitrary (presumably JAX pytree-like) nested structure; left untyped.
Tree = Any

class disable_logging:
  """Context manager that temporarily sets the absl logging verbosity.

  On entry the current verbosity is saved and replaced by ``level``; on exit
  (including when the body raises) the saved verbosity is restored.

  Example:
    with disable_logging():
      noisy_function()

  Args:
    level: Verbosity to use inside the context.
      NOTE(review): -1 is presumably absl's WARNING verbosity, i.e. INFO and
      DEBUG messages are hidden but warnings and errors are not — confirm.
  """

  def __init__(self, level: int = -1):
    # Verbosity applied while inside the context.
    self.level_context = level
    # Verbosity that was active before entering; captured in `__enter__`.
    self.level_init = None

  def __enter__(self):
    self.level_init = logging.get_verbosity()
    logging.set_verbosity(self.level_context)
    # Return the manager so `with disable_logging() as ctx:` binds it.
    return self

  def __exit__(self, exc_type, exc_value, traceback):
    logging.set_verbosity(self.level_init)
    # Never suppress an exception raised inside the context.
    return False

def is_multiple(b, a, tol=1e-08):
  """Return whether ``b`` is (numerically) an integer multiple of ``a``.

  Args:
    b: The value to test.
    a: The base value; must be non-zero.
    tol: Absolute tolerance on ``|k * a - b|``, where ``k`` is the integer
      closest to ``b / a``.

  Returns:
    True if ``b`` lies within ``tol`` of an integer multiple of ``a``.
  """
  # Round (rather than truncate) the quotient: floating-point division can land
  # just below an integer, e.g. 0.3 / 0.1 == 2.9999999999999996, which `int`
  # would truncate to 2 and wrongly report as "not a multiple".
  k = round(b / a)
  return abs(k * a - b) < tol

def profile(f: Callable, kwargs: dict, repeats: int = 1, block_until_ready: bool = False):
  """Measure the average time of calling ``f(**kwargs)``.

  Args:
    f: The callable to time.
    kwargs: Keyword arguments passed to ``f`` on every call.
    repeats: Number of calls to average over; must be at least 1.
    block_until_ready: If True, wait (via ``jax.block_until_ready``) until the
      output of the last call is ready before stopping the clock. Useful for
      JAX functions, whose dispatch is asynchronous. The output may be a single
      array or a pytree of arrays.

  Returns:
    The average time per call, in seconds.

  Raises:
    ValueError: If ``repeats`` is smaller than 1.
  """
  # Monotonic, high-resolution clock; better suited for benchmarking than time().
  from time import perf_counter

  if repeats < 1:
    raise ValueError(f'repeats must be at least 1, got {repeats}.')

  t_0 = perf_counter()
  for _ in range(repeats):
    u = f(**kwargs)
  if block_until_ready:
    # Handles a single array as well as arbitrary pytrees of arrays.
    jax.block_until_ready(u)
  return (perf_counter() - t_0) / repeats

def shuffle_arrays(rngkey: 'flax.typing.PRNGKey', arrays: Sequence[Array], axis: int = 0) -> Sequence[Array]:
  """Shuffle several arrays with one shared random permutation along an axis.

  Every array is permuted with the same permutation, so entries that are
  aligned along ``axis`` before the call stay aligned afterwards.

  Args:
    rngkey: JAX PRNG key used to draw the permutation.
    arrays: A sequence (or any pytree) of arrays that all have the same length
      along ``axis``.
    axis: The axis along which to shuffle.

  Returns:
    The shuffled arrays, in a container with the same structure as ``arrays``.
    If ``arrays`` contains no arrays, it is returned unchanged.

  Raises:
    ValueError: If the arrays do not share the same length along ``axis``.
  """
  leaves = jax.tree.leaves(arrays)
  if not leaves:
    return arrays

  # Validate explicitly (an `assert` would be stripped under `python -O`).
  length = jnp.shape(leaves[0])[axis]
  if any(jnp.shape(leaf)[axis] != length for leaf in leaves):
    raise ValueError(f'All arrays must have the same length along axis {axis}.')

  permutation = jax.random.permutation(rngkey, length)

  # Gather along `axis` directly; equivalent to moving the axis to the front,
  # indexing with the permutation and moving it back.
  return jax.tree.map(lambda v: jnp.take(v, permutation, axis=axis), arrays)

def split_arrays(arrays: Sequence[Array], size: int) -> Sequence[Array]:
  """Split each array along its leading axis into consecutive chunks of ``size``.

  Args:
    arrays: Arrays that all have the same leading dimension ``length``.
    size: Number of entries per chunk; must be positive and divide ``length``.

  Returns:
    A list with one array per input, each of shape
    ``(length // size, size, *arr.shape[1:])``. An empty input gives ``[]``.

  Raises:
    ValueError: If ``size`` is not positive, the leading dimensions differ, or
      ``length`` is not divisible by ``size``.
  """
  if len(arrays) == 0:
    return []
  if size <= 0:
    raise ValueError(f'size must be positive, got {size}.')

  length = arrays[0].shape[0]
  if any(arr.shape[0] != length for arr in arrays):
    raise ValueError('All arrays must have the same leading dimension.')
  if length % size:
    raise ValueError(
        f'The leading dimension ({length}) is not divisible by size ({size}).')

  # Reshaping yields exactly the chunks of `size` consecutive entries (the
  # previous split-by-count could produce chunks of a different size).
  return [jnp.reshape(arr, (length // size, size) + tuple(arr.shape[1:])) for arr in arrays]

def normalize(arr: Array, shift: Array, scale: Array):
  """Return ``(arr - shift) / scale``.

  Entries where ``scale`` is exactly zero are divided by one instead, so those
  entries are only shifted rather than turned into inf/nan.
  """
  safe_scale = jnp.where(scale == 0., 1., scale)
  centered = arr - shift
  return centered / safe_scale

def unnormalize(arr: Array, mean: Array, std: Array):
  """Map normalized values back to the original scale: ``std * arr + mean``.

  Unlike ``normalize``, zero entries of ``std`` are not replaced, so those
  entries evaluate to ``mean`` regardless of ``arr``.
  """
  scaled = std * arr
  return scaled + mean

def segment_mean(arr: Array, chunks: Sequence, axis: int = 0):
  """Average the slices of ``arr`` along ``axis`` that share the same chunk label.

  The slices are sorted by label and summed per label with ``np.add.reduceat``.

  Args:
    arr: Array to reduce; converted to a NumPy array.
    chunks: One label per slice of ``arr`` along ``axis``. Labels only need to
      be sortable and comparable for equality (e.g. integers or strings).
    axis: The axis along which to average.

  Returns:
    A NumPy array in which ``axis`` has one entry per distinct label, ordered by
    increasing label; all other axes are unchanged.

  Raises:
    ValueError: If ``chunks`` is not one-dimensional or its length does not
      match ``arr.shape[axis]``.
  """
  arr = np.asarray(arr)
  chunks = np.asarray(chunks)
  # Explicit checks (an `assert` would be stripped under `python -O`).
  if chunks.ndim != 1:
    raise ValueError(f'chunks must be one-dimensional, got shape {chunks.shape}.')
  if chunks.shape[0] != arr.shape[axis]:
    raise ValueError(
        f'chunks has {chunks.shape[0]} labels but arr has {arr.shape[axis]} '
        f'entries along axis {axis}.')

  # Bring the reduced axis to the front.
  arr = arr.swapaxes(0, axis)
  if arr.shape[0] == 0:
    # No labels, hence no segments; keep the dtype a division would produce.
    return np.true_divide(arr, 1).swapaxes(0, axis)

  # Group equal labels together; a stable sort keeps the summation order
  # within each segment deterministic.
  order = np.argsort(chunks, kind='stable')
  chunks = chunks[order]
  arr = arr[order]

  # Start index of each run of equal labels. Comparing with `!=` (instead of
  # subtracting) also supports non-numeric labels such as strings or booleans.
  starts = np.flatnonzero(chunks[1:] != chunks[:-1]) + 1
  starts = np.concatenate([[0], starts])

  # Sum per segment, then divide by the number of slices in each segment.
  reduced = np.add.reduceat(arr, indices=starts, axis=0)
  counts = np.diff(np.append(starts, arr.shape[0]))
  counts = counts.reshape((-1,) + (1,) * (arr.ndim - 1))
  if np.issubdtype(reduced.dtype, np.inexact):
    # Keep the floating dtype of the input (e.g. float32 stays float32).
    counts = counts.astype(reduced.dtype)
  out = reduced / counts

  # Back to the original order of the axes.
  return out.swapaxes(0, axis)

def interpolate_on_grid(coordinates, values, bbox, dim):
  """Interpolate values from unstructured 2-D coordinates onto a uniform grid.

  Uses ``scipy.interpolate.RBFInterpolator`` with its default settings.

  Args:
    coordinates: Sample locations, shape ``(n, 2)``.
    values: Sample values, shape ``(n,)`` or ``(n, ...)`` for vector-valued data.
    bbox: Bounding box ``((x_min, y_min), (x_max, y_max))`` of the grid.
    dim: Number of grid points along each of the two grid axes.

  Returns:
    A tuple ``(x, f)``. ``x`` has shape ``(dim, dim, 2)`` and holds the grid
    coordinates built with ``np.meshgrid``'s default ``'xy'`` indexing, so
    ``x[i, j] == (xs[j], ys[i])``. ``f`` has shape ``(dim, dim, *values.shape[1:])``
    and holds the interpolated values at those points.
  """
  function = scipy.interpolate.RBFInterpolator(coordinates, values)
  xs = np.linspace(bbox[0][0], bbox[1][0], dim)
  ys = np.linspace(bbox[0][1], bbox[1][1], dim)
  x = np.stack(np.meshgrid(xs, ys), axis=-1)
  # Keep any trailing value dimensions so vector-valued data is supported.
  f = function(x.reshape(-1, 2)).reshape(dim, dim, *np.shape(values)[1:])

  return x, f

def compute_amplitude_spectrum(coordinates, values, bbox, dim):
  """Compute the 2-D Fourier amplitude spectrum of values interpolated on a grid.

  The values are first interpolated onto a uniform ``dim x dim`` grid inside
  ``bbox`` with ``interpolate_on_grid``. The spectrum is the absolute value of
  the 2-D FFT over the two grid axes. The returned amplitudes are shifted so
  that the zero frequency is centered.

  Args:
    coordinates: Sample locations, forwarded to ``interpolate_on_grid``.
    values: Sample values, forwarded to ``interpolate_on_grid``.
    bbox: Bounding box of the grid, forwarded to ``interpolate_on_grid``.
    dim: Number of grid points per grid axis.

  Returns:
    The centered amplitude spectrum; its first two axes are the grid axes.
  """
  _, f = interpolate_on_grid(coordinates, values, bbox, dim)
  # Transform and shift only over the grid axes (0, 1), so trailing value
  # dimensions, if any, are neither transformed nor shifted.
  amplitude = np.abs(scipy.fft.fft2(f, axes=(0, 1)))
  amplitude = scipy.fft.fftshift(amplitude, axes=(0, 1))

  return amplitude
