import functools
import math

import equinox as eqx
import jax
import jax.numpy as jnp


class Ravel(eqx.Module):
    """Ravels every array in the inputs and concatenates them into one 1D array.

    All positional arguments are flattened together as a single pytree.
    Leaves for which ``eqx.is_array`` is false are skipped. ``key`` is
    accepted and ignored (presumably so this module can be called like other
    Equinox layers).
    """

    def __call__(self, *xs, key=None):
        flat_pieces = []
        for leaf in jax.tree.leaves(xs):
            # Non-array leaves (Python scalars, callables, ...) are skipped.
            if eqx.is_array(leaf):
                flat_pieces.append(jnp.ravel(leaf))
        return jnp.hstack(flat_pieces)


class MLP(eqx.Module):
    """A smarter replacement for equinox.nn.MLP

    Flattens its inputs into a 1D array before passing to an MLP.  The
    `in_size` argument must be the inputs' total size.
    """

    # The wrapped Equinox MLP that consumes the flattened input vector.
    mlp: eqx.nn.MLP

    def __init__(self, in_size, *args, **kwargs):
        """Build the wrapped MLP.

        Args:
            in_size: Total number of array elements across all inputs passed
                to ``__call__`` (after flattening).
            *args, **kwargs: Forwarded unchanged to ``eqx.nn.MLP`` (e.g.
                ``out_size, width_size, depth`` and ``key``).
        """
        self.mlp = eqx.nn.MLP(in_size, *args, **kwargs)

    def __call__(self, *xs, **kwargs):
        """Flatten and concatenate all array leaves of ``xs``, then apply the MLP.

        ``kwargs`` (e.g. ``key``) are forwarded to ``eqx.nn.MLP.__call__``.
        """
        return self.mlp(Ravel()(*xs), **kwargs)


class MultiMLP(eqx.Module):
    """An MLP that accepts multiple inputs

    It processes each input with an MLP then combines them into a single
    vector and processes them further with another MLP.

    Args:
        in_sizes: One entry per input. Each entry is either an int (the
            input's total number of elements) or an iterable shape, whose
            product is used.
        out_size: Output size of the post-combine MLP.
        width_size: Hidden width of every MLP, and the output width of each
            pre-combine MLP.
        depths: Either an int, meaning no pre-combine MLPs (each input is only
            flattened) and that int as the post-combine depth, or a pair
            ``(pre_depth, post_depth)``. A ``pre_depth`` of 0 also means
            "flatten only".
        combine: How the per-input vectors are merged. ``'cat'`` concatenates
            them, ``'mul'`` multiplies them elementwise (all sizes must be
            equal), and ``'prod'`` takes the chained Kronecker product of all
            of them.
        **kwargs: Forwarded to ``eqx.nn.MLP`` and must include ``key``. For
            the pre-combine MLPs, the final layer uses ``use_bias`` /
            ``activation`` when those are given, otherwise ``eqx.nn.MLP``'s
            defaults. ``use_final_bias`` / ``final_activation`` only affect
            the post-combine MLP.

    Raises:
        TypeError: if ``key`` is not supplied.
        ValueError: if ``combine`` is unknown, or ``'mul'`` is requested with
            unequal per-input sizes.
    """

    # One module per input: an ``MLP`` when pre_depth > 0, otherwise ``Ravel``.
    pre_mlp: tuple[eqx.Module, ...]
    # ``eqx.nn.Lambda`` wrapping the chosen combine function.
    combine: eqx.Module
    post_mlp: eqx.nn.MLP

    def __init__(self, in_sizes, out_size, width_size, depths, *, combine='cat', **kwargs):
        # Shapes are accepted too; reduce them to total element counts.
        in_sizes = tuple(math.prod(i) if hasattr(i, '__iter__') else i for i in in_sizes)
        if hasattr(depths, '__iter__'):
            pre_depth, post_depth = depths
        else:
            pre_depth, post_depth = 0, depths

        # Each sub-network needs its own PRNG key. Forwarding the caller's key
        # to all of them would give identical initial weights to pre-MLPs with
        # equal in_size and correlate them with the post-MLP.
        key = kwargs.pop('key', None)
        if key is None:
            raise TypeError("MultiMLP.__init__() missing required keyword argument: 'key'")

        # Pre-combine MLPs output intermediate features, so their last layer
        # follows the hidden-layer bias/activation settings.
        pre_kwargs = dict(kwargs)
        if 'use_bias' in kwargs:
            pre_kwargs['use_final_bias'] = kwargs['use_bias']
        else:
            pre_kwargs.pop('use_final_bias', None)

        if 'activation' in kwargs:
            pre_kwargs['final_activation'] = kwargs['activation']
        else:
            pre_kwargs.pop('final_activation', None)

        if pre_depth > 0:
            keys = jax.random.split(key, len(in_sizes) + 1)
            pre_keys, post_key = keys[:-1], keys[-1]
            self.pre_mlp = tuple(
                MLP(in_size, width_size, width_size, pre_depth, key=pre_key, **pre_kwargs)
                for in_size, pre_key in zip(in_sizes, pre_keys, strict=True)
            )
            combine_sizes = [width_size] * len(in_sizes)
        else:
            # Inputs are only flattened. The post-MLP is the only consumer of
            # randomness, so it keeps the caller's key unchanged.
            self.pre_mlp = (Ravel(),) * len(in_sizes)
            combine_sizes = in_sizes
            post_key = key

        # Setup combine function
        if combine == 'cat':
            self.combine = eqx.nn.Lambda(lambda xs: jnp.hstack(xs))
            post_in_size = sum(combine_sizes)
        elif combine == 'mul':
            if min(combine_sizes) != max(combine_sizes):
                raise ValueError("Cannot combine with 'mul' with unequal sizes")
            # math.prod multiplies the arrays elementwise (starting from 1).
            self.combine = eqx.nn.Lambda(lambda xs: math.prod(xs))
            post_in_size = combine_sizes[0]
        elif combine == 'prod':
            self.combine = eqx.nn.Lambda(lambda xs: functools.reduce(jnp.kron, xs))
            post_in_size = math.prod(combine_sizes)
        else:
            raise ValueError(f'Unknown combine method: {combine}')

        # Setup post-combine MLP
        self.post_mlp = eqx.nn.MLP(
            post_in_size, out_size, width_size, post_depth, key=post_key, **kwargs
        )

    def __call__(self, *xs, **kwargs):
        """Apply each pre-network to its input, combine, then apply the post-MLP.

        ``kwargs`` are accepted but not used.
        """
        # NOTE(review): zip is non-strict, so inputs beyond len(pre_mlp) are
        # silently ignored, and missing inputs shrink the combined vector.
        xs = tuple(mlp(x) for x, mlp in zip(xs, self.pre_mlp, strict=False))
        return self.post_mlp(self.combine(xs))


class MultiLayerLSTM(eqx.Module):
    """A stack of ``eqx.nn.LSTMCell`` layers applied to a single step.

    Layer 0 reads the step input. Each later layer reads the hidden output
    ``h`` of the layer below it.
    """

    cells: tuple[eqx.Module, ...]
    hidden_size: int

    def __init__(self, in_size: int, nlayers: int, hidden_size: int, *, key: jax.Array):
        """Create ``nlayers`` LSTM cells, each initialized from its own split key."""
        layer_keys = jax.random.split(key, nlayers)
        self.hidden_size = hidden_size
        self.cells = tuple(
            eqx.nn.LSTMCell(
                in_size if layer == 0 else hidden_size,
                hidden_size,
                key=layer_keys[layer],
            )
            for layer in range(nlayers)
        )

    def __call__(self, x: jax.Array, hidden: jax.Array):
        """Advance every layer by one step.

        Args:
            x: Input to the first cell (assumed 1D of length ``in_size``).
            hidden: Packed per-layer state. Row ``i`` holds ``h`` in its first
                ``hidden_size`` entries and ``c`` in the rest (presumably shape
                ``(num_layers, 2 * hidden_size)``; confirm with callers).

        Returns:
            ``(h_top, new_hidden)``: the last layer's hidden output (or ``x``
            unchanged if there are no layers), and the updated states packed
            the same way as ``hidden``.
        """
        split = self.hidden_size
        packed_states = []
        for layer, cell in enumerate(self.cells):
            row = hidden[layer]
            h_out, c_out = cell(x, (row[:split], row[split:]))
            packed_states.append(jnp.concatenate([h_out, c_out], axis=0))
            x = h_out
        return x, jnp.array(packed_states)

    @property
    def num_layers(self):
        """Number of stacked LSTM cells."""
        return len(self.cells)
