import jax
import jax.numpy as jnp
import jax.tree_util as jtu
import equinox as eqx
import inspect
import math
import functools
import dataclasses
from typing import List, Tuple

import smarter_jax as sj


def Reshape(*args):
  """Build a parameter-free module that reshapes its input.

  The target shape is either the single positional argument (an int or a
  shape tuple) or all positional arguments taken together, so
  `Reshape(2, 3)` and `Reshape((2, 3))` describe the same target shape.
  The wrapped function accepts and ignores a `key` argument.
  """
  if len(args) == 1:
    (shape,) = args
  else:
    shape = args

  def _reshape(x, key=None):
    # `key` is unused; it only lets callers pass a PRNG key uniformly.
    return x.reshape(shape)

  return eqx.nn.Lambda(_reshape)


class Ravel(eqx.Module):
  """Flatten every array leaf of the inputs and concatenate the results.

  Accepts any number of positional pytrees. Leaves that are not arrays
  (per `eqx.is_array`) are dropped, and `None` subtrees contribute no
  leaves. Each remaining array is flattened with `jnp.ravel`, and all of
  them are joined, in `tree_leaves` order, into a single 1-D array.

  If the inputs contain no arrays at all, an empty 1-D array is returned
  instead of raising.
  """
  def __call__(self, *xs, key=None):
    # `key` is accepted and ignored so Ravel can stand in for modules that
    # take a PRNG key.
    flat = [jnp.ravel(x) for x in jtu.tree_leaves(xs) if eqx.is_array(x)]
    if not flat:
      # jnp.hstack([]) raises "need at least one array to concatenate".
      return jnp.zeros((0,))
    return jnp.hstack(flat)


class MLP(eqx.Module):
  """A smarter replacement for equinox.nn.MLP

  Flattens its inputs into a 1D array (via `Ravel`) before passing them to
  an `eqx.nn.MLP`. The `in_size` argument must be the inputs' total size.
  It may be given either as an int or as a shape (tuple/list), in which case
  the product of its entries is used. All other arguments are forwarded to
  `eqx.nn.MLP` unchanged.
  """
  mlp: eqx.nn.MLP  # wrapped MLP, applied to the flattened inputs

  def __init__(self, in_size, *args, **kwargs):
    if isinstance(in_size, (tuple, list)):
      # Accept an input shape, matching how MultiMLP treats its in_sizes.
      in_size = math.prod(in_size)
    self.mlp = eqx.nn.MLP(in_size, *args, **kwargs)

  def __call__(self, *xs, **kwargs):
    # kwargs (e.g. `key`) are forwarded to eqx.nn.MLP.__call__.
    return self.mlp(Ravel()(*xs), **kwargs)


class MultiMLP(eqx.Module):
  """An MLP that accepts multiple inputs

  Each input is first flattened and, optionally, processed by its own MLP.
  The per-input vectors are then combined into a single vector, which is
  processed further by a final `eqx.nn.MLP`.

  Args:
    in_sizes: one entry per input, either an int (the input's flattened
      size) or an iterable shape whose product is used.
    out_size: output size of the final MLP.
    width_size: hidden width of every MLP; also the output size of each
      per-input MLP.
    depths: either a single int, used as the final MLP's depth with no
      per-input MLPs (inputs are only flattened), or a pair
      `(pre_depth, post_depth)`. `pre_depth == 0` also means "flatten only".
    combine: how the per-input vectors are merged. `'cat'` concatenates
      them, `'mul'` multiplies them elementwise (all sizes must be equal),
      and `'prod'` takes their Kronecker product.
    **kwargs: forwarded to `eqx.nn.MLP`. When given, `use_bias` and
      `activation` also configure the final layer of each per-input MLP, in
      place of `use_final_bias` and `final_activation`. If `key` is given, it
      is split so that every MLP is initialised from an independent key.

  Raises:
    ValueError: if `combine` is unknown, or if `'mul'` is used with unequal
      sizes.
  """
  pre_mlp: tuple[eqx.Module, ...]  # one MLP (or Ravel) per input
  combine: eqx.Module  # merges the tuple of per-input vectors
  post_mlp: eqx.nn.MLP  # applied to the combined vector

  def __init__(self, in_sizes, out_size, width_size, depths,
               *, combine='cat', **kwargs):
    in_sizes = tuple(math.prod(i) if hasattr(i, '__iter__') else i
                     for i in in_sizes)
    if hasattr(depths, '__iter__'):
      pre_depth, post_depth = depths
    else:
      pre_depth, post_depth = 0, depths

    # Setup pre-combine MLPs. Their final layer is configured like a hidden
    # layer, so the user's "final" settings only apply to `post_mlp`.
    pre_kwargs = dict(kwargs)
    if 'use_bias' in kwargs:
      pre_kwargs['use_final_bias'] = kwargs['use_bias']
    elif 'use_final_bias' in pre_kwargs:
      del pre_kwargs['use_final_bias']

    # NOTE(review): when `activation` is not given, the per-input MLPs end
    # with eqx's default final activation rather than its default hidden
    # activation -- confirm this is intended.
    if 'activation' in kwargs:
      pre_kwargs['final_activation'] = kwargs['activation']
    elif 'final_activation' in pre_kwargs:
      del pre_kwargs['final_activation']

    # Give every sub-MLP its own PRNG key. Reusing the caller's key for all
    # of them would initialise equally-sized per-input MLPs identically.
    per_input_kwargs = [pre_kwargs] * len(in_sizes)
    post_kwargs = kwargs
    if pre_depth > 0 and kwargs.get('key') is not None:
      keys = jax.random.split(kwargs['key'], len(in_sizes) + 1)
      per_input_kwargs = [{**pre_kwargs, 'key': k} for k in keys[:-1]]
      post_kwargs = {**kwargs, 'key': keys[-1]}

    if pre_depth > 0:
      self.pre_mlp = tuple(
        MLP(in_size, width_size, width_size, pre_depth, **kw)
        for in_size, kw in zip(in_sizes, per_input_kwargs))
      combine_sizes = [width_size] * len(in_sizes)
    else:
      self.pre_mlp = (Ravel(),) * len(in_sizes)
      combine_sizes = in_sizes

    # Setup combine function (raise rather than assert: asserts vanish
    # under `python -O`).
    if combine == 'cat':
      self.combine = eqx.nn.Lambda(
        lambda xs: jnp.hstack(xs))
      post_in_size = sum(combine_sizes)
    elif combine == 'mul':
      if min(combine_sizes) != max(combine_sizes):
        raise ValueError("Cannot combine with 'mul' with unequal sizes")
      self.combine = eqx.nn.Lambda(
        lambda xs: math.prod(xs))
      post_in_size = combine_sizes[0]
    elif combine == 'prod':
      self.combine = eqx.nn.Lambda(
        lambda xs: functools.reduce(jnp.kron, xs))
      post_in_size = math.prod(combine_sizes)
    else:
      raise ValueError(
        f"Unknown combine {combine!r}; expected 'cat', 'mul' or 'prod'")

    # Setup post-combine MLP
    self.post_mlp = eqx.nn.MLP(
      post_in_size, out_size, width_size, post_depth, **post_kwargs)

  @sj.with_subkeys
  def __call__(self, *xs, **kwargs):
    # zip() would silently drop surplus or missing inputs; with 'mul' that
    # still yields a correctly-sized (but wrong) vector, so check up front.
    if len(xs) != len(self.pre_mlp):
      raise ValueError(
        f"MultiMLP expects {len(self.pre_mlp)} inputs, got {len(xs)}")
    # NOTE(review): kwargs are not forwarded to sub-modules; presumably key
    # handling is done by sj.with_subkeys -- confirm.
    xs = tuple(mlp(x) for x, mlp in zip(xs, self.pre_mlp))
    x = self.combine(xs)
    return self.post_mlp(x)


class MultiLayerLSTM(eqx.Module):
  """A stack of `eqx.nn.LSTMCell`s advanced together by one step.

  Layer 0 consumes the input `x`; every later layer consumes the new hidden
  output of the layer below it. The recurrent state of all layers is packed
  in one array whose row `i` is `concatenate([h_i, c_i])` for layer `i`.
  """
  cells: Tuple[eqx.Module, ...]  # one LSTMCell per layer, bottom first
  hidden_size: int  # hidden-state size shared by every layer

  def __init__(self, in_size: int, nlayers: int, hidden_size: int, *, key: jax.random.PRNGKey):
    layer_keys = jax.random.split(key, nlayers)
    # Only the bottom layer sees the external input; the others are fed the
    # hidden state of the layer below.
    layer_in_sizes = [in_size] + [hidden_size] * (nlayers - 1)
    self.cells = tuple(
      eqx.nn.LSTMCell(n_in, hidden_size, key=layer_key)
      for n_in, layer_key in zip(layer_in_sizes, layer_keys))
    self.hidden_size = hidden_size

  def __call__(self, x: jax.Array, hidden: jax.Array):
    """Advance every layer by one step.

    Args:
      x: input to the bottom layer.
      hidden: packed per-layer state. Row `i` holds layer `i`'s `h` in its
        first `hidden_size` entries and its `c` in the remaining entries.

    Returns:
      `(y, new_hidden)`, where `y` is the top layer's new hidden state (or
      `x` unchanged when there are no layers) and `new_hidden` is the packed
      new state of every layer.
    """
    size = self.hidden_size
    packed = []
    for layer, cell in enumerate(self.cells):
      state = (hidden[layer, :size], hidden[layer, size:])
      h_out, c_out = cell(x, state)
      packed.append(jnp.concatenate([h_out, c_out], axis=0))
      x = h_out
    return x, jnp.array(packed)

  @property
  def num_layers(self):
    """Number of stacked LSTM cells."""
    return len(self.cells)
