import functools
import operator as op
from munch import Munch

from jax import lax
from jax import random
import jax.numpy as jnp

from jax.nn import (
    relu,
    log_softmax,
    softmax,
    softplus,
    sigmoid,
    elu,
    leaky_relu,
    selu,
    gelu,
    celu,
    normalize,
)
from jax.nn.initializers import ones, zeros
from initializers import glorot_normal, normal, bernoulli

# Following the convention used in Keras and tf.layers, we use CamelCase for the
# names of layer constructors, like Conv and Relu, while using snake_case for
# other functions, like lax.conv and relu.

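# Each layer constructor function returns a dictionary (a Munch). Common fields are:
#   init: takes an rng key and an input shape and returns an
#     (output_shape, params) pair.
#   logp (optional): takes params and returns their log probability. If absent,
#     the layer is assumed to have no randomness.
#   apply: takes params and inputs (plus keyword arguments, e.g. an optional
#     `rng` PRNG key) and applies the layer.
#   init_01 (optional): like init, but samples parameters from a Bernoulli
#     distribution where possible. If absent, fall back to init.
#   init_any (optional): like init, but takes an extra `dist` initializer to
#     sample parameters from. If absent, fall back to init.
#   param_num (optional): takes params and returns the parameter count. If
#     absent, the layer is assumed to have no parameters.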


def Dense(out_dim, W_init=glorot_normal(), b_init=normal()):
    """Layer constructor function for a dense (fully-connected) layer.

    Args:
      out_dim: number of output features.
      W_init, b_init: initializer pairs. Element 0 is called as
        ``sampler(rng, shape)`` to draw a parameter; element 1 is called on the
        parameter and its results are summed in ``logp``.

    Returns:
      A Munch with ``init``, ``logp``, ``apply``, ``init_01``, ``init_any``,
      ``param_num`` and ``out_dim``. Parameters are the pair ``(W, b)`` with
      ``W`` of shape ``(input_shape[-1], out_dim)`` and ``b`` of shape
      ``(out_dim,)``.
    """

    def _draw(rng, input_shape, weight_dist, bias_dist):
        # Shared by every init variant: sample W then b from independent keys.
        out_shape = input_shape[:-1] + (out_dim,)
        weight_key, bias_key = random.split(rng)
        weights = weight_dist[0](weight_key, (input_shape[-1], out_dim))
        bias = bias_dist[0](bias_key, (out_dim,))
        return out_shape, (weights, bias)

    def init(rng, input_shape):
        return _draw(rng, input_shape, W_init, b_init)

    def logp(params):
        weights, bias = params
        return W_init[1](weights) + b_init[1](bias)

    def apply(params, inputs, **kwargs):
        weights, bias = params
        return jnp.dot(inputs, weights) + bias

    def init_01(rng, input_shape):
        # Bernoulli-initialised variant; W and b each get their own initializer.
        return _draw(rng, input_shape, bernoulli(), bernoulli())

    def init_any(rng, input_shape, dist):
        # Use the caller-supplied initializer pair for both W and b.
        return _draw(rng, input_shape, dist, dist)

    def param_num(params):
        weights, bias = params
        return weights.size + bias.size

    return Munch(
        init=init,
        logp=logp,
        apply=apply,
        init_01=init_01,
        init_any=init_any,
        param_num=param_num,
        out_dim=out_dim,
    )


def ScaleDense(out_dim, input_norm=None, b_var_scale=1e-2, init_scale=1.0):
    """Dense layer whose stored parameters are rescaled when applied.

    Stored parameters ``(W, b)`` are sampled from ``normal(init_scale)`` by
    ``init``. ``apply`` multiplies them by the constants from ``get_c``:

        out = sqrt(1 - b_var_scale) / inn * (inputs @ W) + sqrt(b_var_scale) * b

    where ``inn = input_norm or sqrt(W.shape[0])``.

    When ``b_var_scale == 0`` the bias term would vanish, so a
    ``ScaleDenseNoBias`` layer is returned instead.

    Args:
      out_dim: number of output features.
      input_norm: optional normaliser for the weight term; if falsy,
        ``sqrt(input_dim)`` is used.
      b_var_scale: fraction of the scaling given to the bias term.
      init_scale: argument passed to ``normal`` for ``init``/``logp``.

    Returns:
      A Munch with ``init``, ``logp``, ``apply``, ``init_01``, ``init_any``,
      ``param_num`` and ``out_dim``.
    """
    if b_var_scale == 0:
        return ScaleDenseNoBias(out_dim, input_norm=input_norm, init_scale=init_scale)

    def get_c(leninput):
        """c*real parameters = effective parameters"""
        inn = input_norm or jnp.sqrt(leninput)
        return jnp.sqrt((1 - b_var_scale)) / inn, jnp.sqrt(b_var_scale)

    def _sample(rng, input_shape, W_dist, b_dist):
        # Shared by all init variants: W is (input_dim, out_dim), b is (out_dim,).
        output_shape = input_shape[:-1] + (out_dim,)
        k1, k2 = random.split(rng)
        W = W_dist[0](k1, (input_shape[-1], out_dim))
        b = b_dist[0](k2, (out_dim,))
        return output_shape, (W, b)

    def init(rng, input_shape):
        # The scaling constants are applied in `apply`, so none are needed here.
        return _sample(rng, input_shape, normal(init_scale), normal(init_scale))

    def logp(params):
        W, b = params
        W_init = normal(init_scale)
        b_init = normal(init_scale)
        return W_init[1](W) + b_init[1](b)

    def apply(params, inputs, **kwargs):
        W, b = params
        c = get_c(W.shape[0])
        return c[0] * jnp.dot(inputs, W) + c[1] * b

    def init_01(rng, input_shape):
        return _sample(rng, input_shape, bernoulli(), bernoulli())

    def init_any(rng, input_shape, dist):
        return _sample(rng, input_shape, dist, dist)

    def param_num(params):
        W, b = params
        return W.size + b.size

    return Munch(
        init=init,
        logp=logp,
        apply=apply,
        init_01=init_01,
        init_any=init_any,
        param_num=param_num,
        out_dim=out_dim,
    )


def ScaleDenseNoBias(out_dim, input_norm=None, init_scale=1.0):
    """Bias-free dense layer whose stored weights are rescaled when applied.

    ``apply`` computes ``get_c(W.shape[0]) * (inputs @ W)``, where
    ``get_c(n) = 1 / (input_norm or sqrt(n))``. Parameters are the single
    weight matrix ``W`` of shape ``(input_shape[-1], out_dim)``.

    Args:
      out_dim: number of output features.
      input_norm: optional normaliser; if falsy, ``sqrt(input_dim)`` is used.
      init_scale: argument passed to ``normal`` for ``init``/``logp``.

    Returns:
      A Munch with ``init``, ``logp``, ``apply``, ``init_01``, ``init_any``,
      ``param_num``, ``out_dim`` and ``get_c``.
    """

    def get_c(leninput):
        """c*real parameters = effective parameters"""
        inn = input_norm or jnp.sqrt(leninput)
        return 1 / inn

    def _sample(rng, input_shape, W_dist):
        # Shared by all init variants: draw W of shape (input_dim, out_dim).
        output_shape = input_shape[:-1] + (out_dim,)
        W = W_dist[0](rng, (input_shape[-1], out_dim))
        return output_shape, W

    def init(rng, input_shape):
        # Scaling happens in `apply`; the stored weights are unscaled samples.
        return _sample(rng, input_shape, normal(init_scale))

    def logp(params):
        W = params
        W_init = normal(init_scale)
        return W_init[1](W)

    def apply(params, inputs, **kwargs):
        W = params
        c = get_c(W.shape[0])
        return c * jnp.dot(inputs, W)

    def init_01(rng, input_shape):
        return _sample(rng, input_shape, bernoulli())

    def init_any(rng, input_shape, dist):
        return _sample(rng, input_shape, dist)

    def param_num(params):
        W = params
        return W.size

    return Munch(
        init=init,
        logp=logp,
        apply=apply,
        init_01=init_01,
        init_any=init_any,
        param_num=param_num,
        out_dim=out_dim,
        get_c=get_c,
    )


def elementwise(fun, **fun_kwargs):
    """Layer that applies a scalar function elementwise on its inputs.

    Args:
      fun: function called as ``fun(inputs, **fun_kwargs)``.
      **fun_kwargs: extra keyword arguments forwarded to ``fun`` on every call.

    Returns:
      A parameter-free Munch layer whose ``init`` keeps the input shape.
    """

    def init(rng, input_shape):
        # No parameters; the output shape equals the input shape.
        return input_shape, ()

    def apply(params, inputs, **kwargs):
        return fun(inputs, **fun_kwargs)

    return Munch(init=init, apply=apply)


# Ready-made, parameter-free activation layers (Munch instances, not constructors).
Tanh = elementwise(fun=jnp.tanh)
Relu = elementwise(fun=relu)
Exp = elementwise(fun=jnp.exp)
LogSoftmax = elementwise(fun=log_softmax, axis=-1)
Softmax = elementwise(fun=softmax, axis=-1)
Softplus = elementwise(fun=softplus)
Sigmoid = elementwise(fun=sigmoid)
Elu = elementwise(fun=elu)
LeakyRelu = elementwise(fun=leaky_relu)
Selu = elementwise(fun=selu)
Gelu = elementwise(fun=gelu)
Celu = elementwise(fun=celu)


def _pooling_layer(reducer, init_val, rescaler=None):
    """Build a pooling-layer constructor around ``lax.reduce_window``.

    Args:
      reducer: binary reduction passed to ``lax.reduce_window``.
      init_val: initial value passed to ``lax.reduce_window`` with ``reducer``.
      rescaler: optional factory called as
        ``rescaler(window_shape, strides, padding)``. The callable it returns
        post-processes the pooled output as ``rescale(out, inputs, spec)``.

    Returns:
      ``PoolingLayer(window_shape, strides=None, padding="VALID", spec=None)``,
      which returns a Munch with ``init`` and ``apply``.
    """

    def PoolingLayer(window_shape, strides=None, padding="VALID", spec=None):
        """Layer construction function for a pooling layer."""
        strides = strides or (1,) * len(window_shape)

        # The rescaler is built from the spatial-only window/strides, i.e.
        # before the size-1 batch/channel entries are inserted below.
        rescale = None
        if rescaler:
            rescale = rescaler(window_shape, strides, padding)

        if spec is not None:
            batch_and_channel = spec.index("N"), spec.index("C")
        else:
            batch_and_channel = 0, len(window_shape) + 1

        # Insert a size-1 window/stride on each non-spatial axis. Ascending
        # order makes each insertion index refer to the final layout.
        for axis in sorted(batch_and_channel):
            window_shape = window_shape[:axis] + (1,) + window_shape[axis:]
            strides = strides[:axis] + (1,) + strides[axis:]

        def init(rng, input_shape):
            pads = lax.padtype_to_pads(input_shape, window_shape, strides, padding)
            unit_dilation = (1,) * len(window_shape)
            out_shape = lax.reduce_window_shape_tuple(
                input_shape, window_shape, strides, pads, unit_dilation, unit_dilation
            )
            return out_shape, ()

        def apply(params, inputs, **kwargs):
            pooled = lax.reduce_window(
                inputs, init_val, reducer, window_shape, strides, padding
            )
            if not rescale:
                return pooled
            return rescale(pooled, inputs, spec)

        return Munch(init=init, apply=apply)

    return PoolingLayer


# Max pooling starts from -inf; sum pooling starts from 0.
MaxPool = _pooling_layer(reducer=lax.max, init_val=-jnp.inf)
SumPool = _pooling_layer(reducer=lax.add, init_val=0.0)


def _normalize_by_window_size(dims, strides, padding):
    def rescale(outputs, inputs, spec):
        if spec is None:
            non_spatial_axes = 0, inputs.ndim - 1
        else:
            non_spatial_axes = spec.index("N"), spec.index("C")

        spatial_shape = tuple(
            inputs.shape[i] for i in range(inputs.ndim) if i not in non_spatial_axes
        )
        one = jnp.ones(spatial_shape, dtype=inputs.dtype)
        window_sizes = lax.reduce_window(one, 0.0, lax.add, dims, strides, padding)
        for i in sorted(non_spatial_axes):
            window_sizes = jnp.expand_dims(window_sizes, i)

        return outputs / window_sizes

    return rescale


# Average pooling = sum pooling followed by division by the window size.
AvgPool = _pooling_layer(
    reducer=lax.add, init_val=0.0, rescaler=_normalize_by_window_size
)


def Flatten():
    """Layer construction function for flattening all but the leading dim.

    ``init`` maps a shape ``(d0, d1, ..., dk)`` to ``(d0, d1 * ... * dk)``,
    or ``(d0, 1)`` when there are no trailing dims. ``apply`` reshapes the
    inputs to ``(inputs.shape[0], -1)``.
    """

    def init(rng, input_shape):
        leading = input_shape[0]
        trailing = functools.reduce(op.mul, input_shape[1:], 1)
        return (leading, trailing), ()

    def apply(params, inputs, **kwargs):
        leading = inputs.shape[0]
        return jnp.reshape(inputs, (leading, -1))

    return Munch(init=init, apply=apply)


# Replace the constructor with a ready-made layer instance.
Flatten = Flatten()


def Identity():
    """Layer construction function for an identity layer.

    ``init`` passes the input shape through with no parameters, and ``apply``
    returns its inputs unchanged.
    """

    def init(rng, input_shape):
        return input_shape, ()

    def apply(params, inputs, **kwargs):
        return inputs

    return Munch(init=init, apply=apply)


# Replace the constructor with a ready-made layer instance.
Identity = Identity()


def FanOut(num):
    """Layer construction function for a fan-out layer.

    Both ``init`` and ``apply`` replicate their argument into a list of
    ``num`` references to the same input shape / input value.
    """

    def init(rng, input_shape):
        return [input_shape] * num, ()

    def apply(params, inputs, **kwargs):
        return [inputs] * num

    return Munch(init=init, apply=apply)


def FanInSum():
    """Layer construction function for a fan-in sum layer.

    ``init`` reports the first input's shape as the output shape, and
    ``apply`` adds all inputs together with the builtin ``sum``.
    """

    def init(rng, input_shape):
        return input_shape[0], ()

    def apply(params, inputs, **kwargs):
        return sum(inputs)

    return Munch(init=init, apply=apply)


# Replace the constructor with a ready-made layer instance.
FanInSum = FanInSum()


def FanInConcat(axis=-1):
    """Layer construction function for a fan-in concatenation layer.

    ``init`` takes a sequence of input shapes and returns the first shape
    with its ``axis`` entry replaced by the sum of all inputs' sizes on that
    axis. ``apply`` concatenates the inputs along ``axis``.
    """

    def init(rng, input_shape):
        reference = input_shape[0]
        ax = axis % len(reference)  # normalise a negative axis
        joined = sum(shape[ax] for shape in input_shape)
        return reference[:ax] + (joined,) + reference[ax + 1 :], ()

    def apply(params, inputs, **kwargs):
        return jnp.concatenate(inputs, axis)

    return Munch(init=init, apply=apply)


def Dropout(rate, mode="train"):
    """Layer construction function for a dropout layer with given rate.

    Args:
      rate: probability of *keeping* each element. The mask is
        ``random.bernoulli(rng, rate, inputs.shape)`` and kept values are
        scaled by ``1 / rate``; dropped values become 0.
      mode: ``"train"`` applies dropout. Any other value makes ``apply`` the
        identity.

    Returns:
      A parameter-free Munch layer. In train mode, ``apply`` requires a PRNG
      key passed as the keyword argument ``rng`` and raises ValueError without
      one.
    """

    def init(rng, input_shape):
        return input_shape, ()

    def apply(params, inputs, **kwargs):
        if mode != "train":
            # Outside training the layer is the identity, so no key is needed.
            return inputs
        rng = kwargs.get("rng", None)
        if rng is None:
            # `apply` only accepts the key as a keyword, so the message
            # must show `rng=rng` rather than a positional argument.
            msg = (
                "Dropout layer requires apply_fun to be called with a PRNG key "
                "argument. That is, instead of `apply_fun(params, inputs)`, call "
                "it like `apply_fun(params, inputs, rng=rng)` where `rng` is a "
                "jax.random.PRNGKey value."
            )
            raise ValueError(msg)
        keep = random.bernoulli(rng, rate, inputs.shape)
        return jnp.where(keep, inputs / rate, 0)

    return Munch(init=init, apply=apply)


# Composing layers via combinators


def serial(*layers):
    """Combinator for composing layers in serial.

    Args:
      *layers: a sequence of layers, each a Munch with at least ``init`` and
        ``apply`` (optionally ``logp``, ``init_01``, ``init_any``,
        ``param_num``).

    Returns:
      A Munch layer for the serial composition. It provides ``init``,
      ``apply``, ``apply_with_middle`` (which also returns every intermediate
      activation, input first), ``logp`` (only if some layer has one),
      ``init_01``, ``init_any``, ``param_num`` and ``serial_layers``.
    """
    nlayers = len(layers)

    def _per_layer_keys(rng):
        # One independent key per layer, or a None placeholder for each.
        if rng is None:
            return (None,) * nlayers
        return random.split(rng, nlayers)

    def _init_in_sequence(rng, input_shape, choose):
        # Thread the shape through the layers, splitting a fresh subkey for
        # each; `choose(layer)` picks the init callable to use.
        chosen = [choose(layer) for layer in layers]
        params = []
        shape = input_shape
        for layer_init in chosen:
            rng, subkey = random.split(rng)
            shape, layer_params = layer_init(subkey, shape)
            params.append(layer_params)
        return shape, params

    def _run(params, inputs, kwargs, trace):
        # Feed each layer's output into the next; optionally record outputs.
        keys = _per_layer_keys(kwargs.pop("rng", None))
        appliers = [layer["apply"] for layer in layers]
        out = inputs
        for apply_fn, layer_params, key in zip(appliers, params, keys):
            out = apply_fn(layer_params, out, rng=key, **kwargs)
            if trace is not None:
                trace.append(out)
        return out

    def init(rng, input_shape):
        return _init_in_sequence(rng, input_shape, lambda layer: layer["init"])

    def apply(params, inputs, **kwargs):
        return _run(params, inputs, kwargs, None)

    def apply_with_middle(params, inputs, **kwargs):
        middle = [inputs]
        return _run(params, inputs, kwargs, middle), middle

    def logp(params):
        terms = [
            layer["logp"](layer_params)
            for layer, layer_params in zip(layers, params)
            if "logp" in layer
        ]
        return jnp.stack(terms).sum()

    def init_01(rng, input_shape):
        return _init_in_sequence(
            rng, input_shape, lambda layer: layer.get("init_01", layer["init"])
        )

    def init_any(rng, input_shape, dist):
        def choose(layer):
            layer_init = layer.get("init_any", layer["init"])
            if "init_any" in layer:
                # init_any layers additionally receive the distribution.
                return lambda key, shape: layer_init(key, shape, dist)
            return layer_init

        return _init_in_sequence(rng, input_shape, choose)

    def param_num(params):
        counters = [layer.get("param_num", lambda p: 0) for layer in layers]
        return sum(count(layer_params) for count, layer_params in zip(counters, params))

    def serial_layers():
        return layers

    net = Munch(init=init, apply=apply, apply_with_middle=apply_with_middle)
    if any("logp" in layer for layer in layers):
        net["logp"] = logp
    net["init_01"] = init_01
    net["init_any"] = init_any
    net["param_num"] = param_num
    net["serial_layers"] = serial_layers
    return net


def parallel(*layers):
    """Combinator for composing layers in parallel.

    The layer resulting from this combinator is often used with the FanOut and
    FanInSum layers.

    Args:
      *layers: a sequence of layers, each a Munch with at least ``init`` and
        ``apply`` (optionally ``logp``, ``init_01``, ``init_any``,
        ``param_num``).

    Returns:
      A Munch layer representing the parallel composition of ``layers``.
      - ``apply`` takes a sequence of inputs (one per layer) and returns a
        list of outputs with the same length as ``layers``.
      - The init functions take a sequence of input shapes (one per layer) and
        return an ``(output_shapes, params)`` pair of tuples.
    """
    nlayers = len(layers)
    net = Munch()

    def _collect(results):
        # [(shape_i, params_i), ...] -> ((shape_0, ...), (params_0, ...)).
        # Concrete tuples (rather than a one-shot zip iterator) can be indexed
        # and re-iterated; an empty composition yields ((), ()).
        if not results:
            return (), ()
        out_shapes, params = zip(*results)
        return out_shapes, params

    def init(rng, input_shape):
        rngs = random.split(rng, nlayers)  # one key per layer
        inits = [layer["init"] for layer in layers]
        return _collect(
            [f(key, shape) for f, key, shape in zip(inits, rngs, input_shape)]
        )

    net["init"] = init

    def apply(params, inputs, **kwargs):
        rng = kwargs.pop("rng", None)
        rngs = random.split(rng, nlayers) if rng is not None else (None,) * nlayers
        applys = [layer["apply"] for layer in layers]
        return [
            f(p, x, rng=r, **kwargs) for f, p, x, r in zip(applys, params, inputs, rngs)
        ]

    net["apply"] = apply

    def logp(params):
        logps = [
            layer["logp"](param)
            for layer, param in zip(layers, params)
            if "logp" in layer
        ]
        return jnp.stack(logps).sum()

    if any("logp" in layer for layer in layers):
        net["logp"] = logp

    def init_01(rng, input_shape):
        rngs = random.split(rng, nlayers)
        inits = [layer.get("init_01", layer["init"]) for layer in layers]
        return _collect(
            [f(key, shape) for f, key, shape in zip(inits, rngs, input_shape)]
        )

    net["init_01"] = init_01

    def init_any(rng, input_shape, dist):
        rngs = random.split(rng, nlayers)
        inits = [layer.get("init_any", layer["init"]) for layer in layers]
        is_init_anys = ["init_any" in layer for layer in layers]
        return _collect(
            [
                f(key, shape, dist) if is_init_any else f(key, shape)
                for f, key, shape, is_init_any in zip(
                    inits, rngs, input_shape, is_init_anys
                )
            ]
        )

    net["init_any"] = init_any

    def param_num(params):
        funs = [layer.get("param_num", lambda p: 0) for layer in layers]
        return sum(fun(param) for fun, param in zip(funs, params))

    net["param_num"] = param_num

    return net
