from typing import Union, NamedTuple, Mapping, Sequence, Callable

import jax.tree
import jax.numpy as jnp
from flax import linen as nn

from ol.utils import Array


class Inputs(NamedTuple):
  """Bundle of arrays handed to an operator's ``__call__``.

  ``AbstractOperator.__call__`` validates ``s``, ``a`` and every leaf of ``q``
  against ``x_inp``, and requires every leaf of ``m`` to be a 3D boolean array.
  """

  s: Array  # Contains the geometry-related domain features
  a: Array  # Contains the full-domain functions
  q: Mapping[str, Array]  # Contains the segment function values, accompanied by a mask
  m: Mapping[str, Array]  # Binary function masks, make sure to not normalize at all
  x_inp: Array  # Passed as `x` to the function-shape checks; presumably the input coordinates
  x_out: Array  # NOTE(review): presumably the output (query) coordinates -- confirm
  t: Union[Array, float, None] = None  # NOTE(review): presumably the time -- confirm
  tau: Union[Array, float, None] = None  # NOTE(review): presumably the lead time -- confirm

class AbstractOperator(nn.Module):
  """
  Base class for operators that consume an ``Inputs`` bundle.
  ``__call__`` validates the inputs and then delegates to ``call``.
  Both ``setup`` and ``call`` raise NotImplementedError here.
  """

  def setup(self):
    raise NotImplementedError

  def _check_coordinates(self, x: Array) -> None:
    """Assert that x is a 2D array with at most 3 columns and values within [-1, +1]."""

    assert x is not None
    # Rank and width are checked first so that the value checks see a well-formed array
    assert x.ndim == 2
    assert x.shape[1] <= 3
    assert x.min() >= -1
    assert x.max() <= +1

  def _check_function(self, u: Array, x: Array) -> None:
    """Assert that u is a 4D array with a singleton axis 1 whose axis 2 matches axis 2 of x."""

    assert u is not None
    assert u.ndim == 4
    assert u.shape[1] == 1
    assert u.shape[2] == x.shape[2], f'u: {u.shape}, x: {x.shape}'

  def __call__(self, inputs: Inputs, **kwargs) -> Array:
    """Validate the functions and masks in ``inputs``, then forward everything to ``call``."""

    def check_against_x_inp(fn):
      self._check_function(fn, x=inputs.x_inp)

    # Dense input functions, followed by every segment function in the q tree
    check_against_x_inp(inputs.s)
    check_against_x_inp(inputs.a)
    jax.tree.map(check_against_x_inp, inputs.q)

    # Masks get a trailing axis added so that they can go through the same shape check
    jax.tree.map(lambda mask: check_against_x_inp(mask[..., None]), inputs.m)
    masks_are_3d = jax.tree.map(lambda mask: mask.ndim == 3, inputs.m)
    assert jax.tree.all(masks_are_3d)
    masks_are_bool = jax.tree.map(lambda mask: mask.dtype == jnp.dtype(bool), inputs.m)
    assert jax.tree.all(masks_are_bool)

    return self.call(inputs, **kwargs)

  def call(self, inputs: Inputs) -> Array:
    raise NotImplementedError

  @property
  def configs(self):
    """Map every annotated attribute name (except 'parent') to its value on this instance."""

    collected = {}
    for attr in self.__annotations__.keys():
      if attr == 'parent':
        continue
      collected[attr] = self.__getattr__(attr)
    return collected

class LeadTimeConditionedNorm(nn.Module):
  """
  Learned correction layer, designed to be applied after a normalization layer.
  Based on a conditioning input c (e.g., lead time), it scales and shifts its input x:
  x * (1 + c * mlp_scale(c)) + c * mlp_bias(c).
  Both corrections are multiplied by c, so the layer is the identity where c is zero.

  Attributes:
    latent_size: Number of features of the hidden layer of both MLPs.
    correction_size: Number of output features of both MLPs. The corrections are
      broadcast against the last axis of x, so this must be 1 or x.shape[-1].
  """

  # An integer feature count: it is passed directly to nn.Dense
  latent_size: int
  correction_size: int = 1

  def setup(self):
    self.mlp_scale = nn.Sequential(layers=[
      nn.Dense(self.latent_size, kernel_init=nn.initializers.normal(stddev=.01)),
      nn.sigmoid,
      nn.Dense(self.correction_size, kernel_init=nn.initializers.normal(stddev=.01)),
    ])
    self.mlp_bias = nn.Sequential(layers=[
      nn.Dense(self.latent_size, kernel_init=nn.initializers.normal(stddev=.01)),
      nn.sigmoid,
      nn.Dense(self.correction_size, kernel_init=nn.initializers.normal(stddev=.01)),
    ])

  def __call__(self, c, x):
    """
    Applies the conditional scale and shift to x and returns an array of the same shape.
    c: conditioning input (NOTE(review): presumably of shape [batch, 1] -- confirm)
    x: array of shape [batch, ..., features]
    """

    scale = 1 + c * self.mlp_scale(c)
    bias = c * self.mlp_bias(c)
    shape = x.shape
    # Collapse all the middle axes so that a single inserted axis broadcasts over them
    x = x.reshape(shape[0], -1, shape[-1])
    scale = jnp.expand_dims(scale, axis=1)
    bias = jnp.expand_dims(bias, axis=1)
    x = x * scale + bias
    # Restore the original layout of the input
    x = x.reshape(*shape)
    return x

class FeedForward(nn.Module):
  """
  Multi-layer perceptron with optional layer norm and learned correction on the last layer.
  Activation is applied on all layers except the last one.
  Multiple inputs are concatenated before being fed to the MLP.

  Attributes:
    layer_sizes: Number of output features of each layer, in order.
    activation: Function applied after every layer except the last one.
    use_layer_norm: Whether to apply a layer norm on the output of the last layer.
    use_conditional_norm: Whether to apply a LeadTimeConditionedNorm on the output;
      the conditioning input c must then be passed to __call__.
    cond_norm_hidden_size: Hidden size of the conditional normalization MLPs.
    concatenate_axis: Axis along which the inputs are concatenated.
    conv: Use nn.Conv layers with kernel size 1 instead of nn.Dense layers.
    dropout: Dropout rate applied after every layer; None disables dropout.
  """

  layer_sizes: Sequence[int]
  activation: Callable
  use_layer_norm: bool = False
  use_conditional_norm: bool = False
  cond_norm_hidden_size: int = 4
  concatenate_axis: int = -1
  conv: bool = False
  dropout: Union[float, None] = None

  def setup(self):
    # Set up layers
    if self.conv:
      self.layers = [nn.Conv(features, kernel_size=1) for features in self.layer_sizes]
    else:
      self.layers = [nn.Dense(features) for features in self.layer_sizes]
    # Set dropout layers (one per layer, including the last one)
    if self.dropout is not None:
      self.dropouts = [nn.Dropout(self.dropout) for _ in self.layer_sizes]

    # Set up normalization layer
    self.layernorm = nn.LayerNorm(
      reduction_axes=-1,
      feature_axes=-1,
      use_scale=True,
      use_bias=True,
    ) if self.use_layer_norm else None

    # Set conditional normalization layer
    self.correction = None
    if self.use_conditional_norm:
      self.correction = LeadTimeConditionedNorm(
        latent_size=self.cond_norm_hidden_size,
        correction_size=self.layer_sizes[-1],
      )

  def __call__(self, *args, c: Union[Array, None] = None, deterministic: bool = False, **kwargs):
    """
    Concatenates all positional and keyword inputs and passes them through the MLP.
    c: conditioning input, required when use_conditional_norm is set
    deterministic: forwarded to the dropout layers
    """

    x = concatenate_args(args=args, kwargs=kwargs, axis=self.concatenate_axis)
    use_dropout = self.dropout is not None
    # Hidden layers: layer -> activation -> dropout
    for i, layer in enumerate(self.layers[:-1]):
      x = layer(x)
      x = self.activation(x)
      if use_dropout:
        x = self.dropouts[i](x, deterministic=deterministic)
    # Last layer: no activation; dropout comes before the normalization layers
    x = self.layers[-1](x)
    if use_dropout:
      x = self.dropouts[-1](x, deterministic=deterministic)
    if self.layernorm is not None:
      x = self.layernorm(x)
    if self.correction is not None:
      assert c is not None, 'c must be provided when use_conditional_norm is set'
      x = self.correction(c=c, x=x)
    return x

def concatenate_args(args, kwargs, axis: int = -1):
  """
  Concatenates all the leaves of args and kwargs along the given axis.
  The leaves of args come first, followed by the leaves of kwargs, each in the
  order produced by the JAX pytree flattening.
  """

  leaves = jax.tree.leaves(args)
  leaves.extend(jax.tree.leaves(kwargs))
  return jnp.concatenate(leaves, axis=axis)

def segment_mean(arr, idx, num_segments: int, indices_are_sorted: bool = False):
  """
  Averages the entries of arr that share a segment index, along the leading axis.
  Segments that receive no entries come out as zeros.
  """

  seg_kwargs = dict(num_segments=num_segments, indices_are_sorted=indices_are_sorted)
  totals = jax.ops.segment_sum(arr, idx, **seg_kwargs)
  # Summing ones gives the number of entries that landed in each segment
  occupancy = jax.ops.segment_sum(jnp.ones_like(arr), idx, **seg_kwargs)
  # Clamp the count to one so that empty segments do not divide by zero
  return totals / jnp.maximum(occupancy, 1.0)

def segment_softmax(arr, segment_ids, num_segments: int, indices_are_sorted: bool = False):
  """
  Performs the softmax operation segment-wise along the leading axis.
  arr shape: [N, *]
  segment_ids: segment index of each leading-axis entry of arr
    (segment_attention passes an array of shape [N])
  Returns an array with the same shape as arr.
  """

  seg_kwargs = dict(num_segments=num_segments, indices_are_sorted=indices_are_sorted)
  # Shift by the per-segment maximum so that the exponent arguments are non-positive
  seg_max = jax.ops.segment_max(arr, segment_ids, **seg_kwargs)  # [num_segments, *]
  shifted_exp = jnp.exp(arr - seg_max[segment_ids])  # [N, *]
  # Normalize every entry by the total of its own segment
  seg_total = jax.ops.segment_sum(shifted_exp, segment_ids, **seg_kwargs)  # [num_segments, *]
  return shifted_exp / seg_total[segment_ids]

def segment_attention(q, k, v, segment_ids, num_segments: int = None, sorted: bool = False):
  """
  Attention-weighted aggregation of the values within each segment.
  q: query array of shape [N, H, D]
  k: key array of shape [N, H, D]
  v: value array of shape [N, H, D]
  segment_ids: of shape [N]
  Returns an array of shape [num_segments, H, D].
  """

  # Scaled dot product between every query and its own key
  scale = q.shape[-1] ** -0.5
  scores = jnp.sum(q * k, axis=-1) * scale  # [N, H]
  # Normalize the scores so that the weights within each segment sum to one
  weights = segment_softmax(
    arr=scores,
    segment_ids=segment_ids,
    num_segments=num_segments,
    indices_are_sorted=sorted,
  )  # [N, H]
  # Weight the values and sum them up segment-wise
  weighted_values = jnp.expand_dims(weights, axis=-1) * v  # [N, H, D]
  return jax.ops.segment_sum(
    weighted_values,
    segment_ids=segment_ids,
    num_segments=num_segments,
    indices_are_sorted=sorted,
  )  # [num_segments, H, D]
