import distrax
import jax.nn
import numpy as np
import jax.numpy as jnp
from flax import linen as nn
from typing import Sequence, Callable, List
from flax.linen.normalization import _compute_stats, _normalize, _canonicalize_axes
from typing import Any, Callable, Optional, Sequence, Tuple, Union

# Loose type aliases used only in annotations within this module. All except
# Shape and Axes are plain `Any`, so they document intent but are not checked.
PRNGKey = Any
Array = Any
Shape = Tuple[int, ...]
Dtype = Any  # NOTE(review): could be narrowed to a concrete dtype-like type
Axes = Union[int, Sequence[int]]


class BatchRenorm(nn.Module):
    """BatchRenorm Module, implemented based on the Batch Renormalization paper
    (https://arxiv.org/abs/1702.03275) and adapted from Flax's BatchNorm implementation:
    https://github.com/google/flax/blob/ce8a3c74d8d1f4a7d8f14b9fb84b2cc76d7f8dbf/flax/linen/normalization.py#L228

    In training mode the layer normalizes with the minibatch statistics. Once
    ``warmup_steps`` training updates have been counted, those statistics are
    corrected with the renormalization factors ``r`` (clipped to
    ``[1 / r_max, r_max]``) and ``d`` (clipped to ``[-d_max, d_max]``), which are
    computed against the running averages with ``stop_gradient``. In evaluation
    mode (``use_running_average=True``) the running averages are used directly.

    Attributes:
      use_running_average: if True, the statistics stored in batch_stats will be
        used instead of computing the batch statistics on the input.
      axis: the feature or non-batch axis of the input.
      momentum: decay rate for the exponential moving average of the batch
        statistics.
      epsilon: a small float added to variance to avoid dividing by zero.
      warmup_steps: number of training updates during which plain batch
        statistics are used before the r/d correction is switched on.
      dtype: the dtype of the result (default: infer from input and params).
      param_dtype: the dtype passed to parameter initializers (default: float32).
      use_bias:  if True, bias (beta) is added.
      use_scale: if True, multiply by scale (gamma). When the next layer is linear
        (also e.g. nn.relu), this can be disabled since the scaling will be done
        by the next layer.
      bias_init: initializer for bias, by default, zero.
      scale_init: initializer for scale, by default, one.
      axis_name: the axis name used to combine batch statistics from multiple
        devices. See `jax.pmap` for a description of axis names (default: None).
      axis_index_groups: groups of axis indices within that named axis
        representing subsets of devices to reduce over (default: None). For
        example, `[[0, 1], [2, 3]]` would independently batch-normalize over the
        examples on the first two and last two devices. See `jax.lax.psum` for
        more details.

    Variables created in the ``batch_stats`` collection:
      mean, var: float32 running averages of the batch mean and variance
        (initialized to zeros and ones).
      r_max, d_max: clipping bounds for ``r`` and ``d`` (initialized to 3 and 5;
        this module never modifies them).
      steps: number of training updates applied so far (starts at 0).
    """

    use_running_average: Optional[bool] = None
    axis: int = -1
    momentum: float = 0.99
    epsilon: float = 0.001
    warmup_steps: int = 100_000
    dtype: Optional[Dtype] = None
    param_dtype: Dtype = jnp.float32
    use_bias: bool = True
    use_scale: bool = True
    bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = nn.initializers.zeros
    scale_init: Callable[[PRNGKey, Shape, Dtype], Array] = nn.initializers.ones
    axis_name: Optional[str] = None
    axis_index_groups: Any = None

    @nn.compact
    def __call__(self, x, use_running_average: Optional[bool] = None):
        """Normalizes the input using batch statistics.

        NOTE:
        During initialization (when `self.is_initializing()` is `True`) the running
        average of the batch statistics will not be updated. Therefore, the inputs
        fed during initialization don't need to match that of the actual input
        distribution and the reduction axis (set with `axis_name`) does not have
        to exist.

        Args:
          x: the input to be normalized.
          use_running_average: if true, the statistics stored in batch_stats will be
            used instead of computing the batch statistics on the input. Merged
            with the module attribute of the same name via `nn.merge_param`.

        Returns:
          Normalized inputs (the same shape as inputs).
        """

        use_running_average = nn.merge_param("use_running_average", self.use_running_average, use_running_average)
        feature_axes = _canonicalize_axes(x.ndim, self.axis)
        reduction_axes = tuple(i for i in range(x.ndim) if i not in feature_axes)
        feature_shape = [x.shape[ax] for ax in feature_axes]

        # Running statistics, stored in float32 regardless of `dtype`.
        ra_mean = self.variable(
            "batch_stats",
            "mean",
            lambda s: jnp.zeros(s, jnp.float32),
            feature_shape,
        )
        ra_var = self.variable("batch_stats", "var", lambda s: jnp.ones(s, jnp.float32), feature_shape)

        # Clipping bounds for the renormalization factors; set once here and
        # never updated by this module.
        r_max = self.variable(
            "batch_stats",
            "r_max",
            lambda s: s,
            3,
        )
        d_max = self.variable(
            "batch_stats",
            "d_max",
            lambda s: s,
            5,
        )
        # Counts training updates; drives the warmup switch below.
        steps = self.variable(
            "batch_stats",
            "steps",
            lambda s: s,
            0,
        )

        if use_running_average:
            # Evaluation: normalize with the running averages, no state updates.
            custom_mean = ra_mean.value
            custom_var = ra_var.value
        else:
            # NOTE(review): the commented-out argument below refers to an option
            # this module does not define.
            batch_mean, batch_var = _compute_stats(
                x,
                reduction_axes,
                dtype=self.dtype,
                axis_name=self.axis_name if not self.is_initializing() else None,
                axis_index_groups=self.axis_index_groups,
                # use_fast_variance=self.use_fast_variance,
            )
            if self.is_initializing():
                # Init pass: plain batch statistics, batch_stats left untouched.
                custom_mean = batch_mean
                custom_var = batch_var
            else:
                std = jnp.sqrt(batch_var + self.epsilon)
                ra_std = jnp.sqrt(ra_var.value + self.epsilon)
                # scale: r = sigma_batch / sigma_running, no gradient, clipped to [1/r_max, r_max]
                r = jax.lax.stop_gradient(std / ra_std)
                r = jnp.clip(r, 1 / r_max.value, r_max.value)
                # bias: d = (mu_batch - mu_running) / sigma_running, no gradient, clipped to [-d_max, d_max]
                d = jax.lax.stop_gradient((batch_mean - ra_mean.value) / ra_std)
                d = jnp.clip(d, -d_max.value, d_max.value)

                # BatchNorm normalization, using minibatch stats and running average stats
                # Because we use _normalize, this is equivalent to
                # ((x - x_mean) / sigma) * r + d = ((x - x_mean) * r + d * sigma) / sigma
                # where sigma = sqrt(var)
                # NOTE(review): epsilon is passed to _normalize and presumably added to
                # affine_var there, while sigma here omits it, so the identity above
                # holds only approximately.
                affine_mean = batch_mean - d * jnp.sqrt(batch_var) / r
                affine_var = batch_var / (r ** 2)

                # Note: in the original paper, after some warmup phase (batch norm phase of 5k steps)
                # the constraints are linearly relaxed to r_max/d_max over 40k steps
                # Here we only have a warmup phase
                # is_warmed_up is a 0.0/1.0 float used to blend the two choices,
                # presumably to avoid Python control flow on a traced step count.
                is_warmed_up = jnp.greater_equal(steps.value, self.warmup_steps).astype(jnp.float32)
                custom_mean = is_warmed_up * affine_mean + (1.0 - is_warmed_up) * batch_mean
                custom_var = is_warmed_up * affine_var + (1.0 - is_warmed_up) * batch_var

                # Update running averages with the raw batch statistics (only after
                # they were read above) and advance the step counter.
                ra_mean.value = self.momentum * ra_mean.value + (1.0 - self.momentum) * batch_mean
                ra_var.value = self.momentum * ra_var.value + (1.0 - self.momentum) * batch_var
                steps.value += 1

        return _normalize(
            self,
            x,
            custom_mean,
            custom_var,
            reduction_axes,
            feature_axes,
            self.dtype,
            self.param_dtype,
            self.epsilon,
            self.use_bias,
            self.use_scale,
            self.bias_init,
            self.scale_init,
        )


class BatchNormBlock(nn.Module):
    """BatchRenorm, then a bias-free Dense layer, then the activation.

    Attributes:
      output_size: width of the Dense layer.
      activation_fn: non-linearity applied last (default: SiLU).
      residual: not used by this block.
    """
    output_size: int
    activation_fn: Callable = jax.nn.silu
    residual: bool = True

    @nn.compact
    def __call__(self, inputs, is_training: bool = True):
        """Apply the block; batch statistics are used only when ``is_training``."""
        normed = BatchRenorm(use_running_average=not is_training)(inputs)
        projected = nn.Dense(self.output_size, use_bias=False)(normed)
        return self.activation_fn(projected)


class MLP(nn.Module):
    """Stack of LinearBlocks followed by a biased linear read-out.

    Attributes:
      output_size: width of the final Dense layer.
      hidden_sizes: width of each hidden LinearBlock, in order.
      activation_fn: activation used inside every LinearBlock.
        NOTE(review): the default is a single ``nn.PReLU()`` instance created at
        class-definition time; confirm whether sharing it is intended.
      layer_norm: forwarded to each LinearBlock.
      d2rl: if True, the raw inputs are concatenated onto every hidden output.
    """
    output_size: int
    hidden_sizes: Sequence[int] = (128, 128, 128, 128)
    activation_fn: Callable = nn.PReLU()
    layer_norm: bool = True
    d2rl: bool = False

    @nn.compact
    def __call__(self, inputs, is_training: bool = True):
        """Run the MLP; ``is_training`` is accepted for API symmetry but unused."""
        hidden = inputs
        for width in self.hidden_sizes:
            block = LinearBlock(width, activation_fn=self.activation_fn,
                                layer_norm=self.layer_norm)
            hidden = block(hidden)
            if self.d2rl:
                hidden = jnp.concatenate([inputs, hidden], axis=-1)
        head = nn.Dense(self.output_size, use_bias=True)
        return head(hidden)


class PlainMLP(nn.Module):
    """LinearBlocks (always with LayerNorm) followed by a biased linear read-out.

    Attributes:
      output_size: width of the final Dense layer.
      hidden_sizes: width of each hidden LinearBlock, in order.
      activation_fn: activation used inside every LinearBlock.
    """
    output_size: int
    hidden_sizes: Sequence[int] = (256, 256, 256)
    activation_fn: Callable = nn.PReLU()

    @nn.compact
    def __call__(self, inputs, is_training: bool = True):
        """Run the MLP; ``is_training`` is accepted for API symmetry but unused."""
        hidden = inputs
        for width in self.hidden_sizes:
            hidden = LinearBlock(width, activation_fn=self.activation_fn,
                                 layer_norm=True)(hidden)
        head = nn.Dense(self.output_size, use_bias=True)
        return head(hidden)


class LinearBlock(nn.Module):
    """Dense -> (optional LayerNorm) -> activation.

    The Dense bias is dropped when LayerNorm follows, since LayerNorm has its own
    bias.

    Attributes:
      output_size: width of the Dense layer.
      activation_fn: non-linearity applied last (default: SiLU).
      layer_norm: whether to insert a LayerNorm after the Dense layer.
    """
    output_size: int
    activation_fn: Callable = jax.nn.silu
    layer_norm: bool = True

    def setup(self):
        dense = nn.Dense(self.output_size, use_bias=not self.layer_norm)
        norm = [nn.LayerNorm()] if self.layer_norm else []
        self.block = nn.Sequential([dense, *norm, self.activation_fn])

    def __call__(self, x):
        """Apply the block to ``x``."""
        return self.block(x)


class GaussianFourierProjection(nn.Module):
    """Gaussian random features for encoding time steps.

    Each scalar ``x`` is mapped to ``[sin(2*pi*w*x), cos(2*pi*w*x)]``, with
    ``w`` drawn from ``N(0, scale**2)`` and of length ``embed_dim // 2``. The
    features are appended on a new trailing axis of size ``2 * (embed_dim // 2)``.
    The output is wrapped in ``stop_gradient``, so ``w`` receives no gradient
    through it.
    """
    embed_dim: int
    scale: float = 30

    def setup(self) -> None:
        freq_init = nn.initializers.normal(stddev=self.scale)
        self.w = self.param('w', freq_init, [self.embed_dim // 2])

    def __call__(self, x):
        angles = x[..., None] * self.w[None, :] * 2 * np.pi
        features = jnp.concatenate([jnp.sin(angles), jnp.cos(angles)], axis=-1)
        return jax.lax.stop_gradient(features)


class GaussianFourierWithGrad(GaussianFourierProjection):
    """Gaussian random features for encoding time steps, with gradients enabled.

    Same sin/cos features as the parent class, but without ``stop_gradient``, so
    gradients flow into the frequencies ``w``. The default frequency scale is
    0.01.
    """
    embed_dim: int
    scale: float = 0.01

    def __call__(self, x):
        angles = x[..., None] * self.w[None, :] * 2 * np.pi
        sin_part, cos_part = jnp.sin(angles), jnp.cos(angles)
        return jnp.concatenate([sin_part, cos_part], axis=-1)


class LFF(nn.Module):
    """Fourier-feature layer computing ``sin(pi * (W x + b))`` with a trainable Dense.

    Attributes:
      n_dim: number of output features.
      scale: variance-scaling factor for the kernel initializer (fan-in, normal).
    """
    n_dim: int = 64
    scale: float = 3e-4

    def setup(self):
        # Kernel variance = scale / fan_in; bias drawn uniformly with scale 1.
        self.kernel_initializer = nn.initializers.variance_scaling(
            self.scale, mode='fan_in', distribution='normal')
        self.bias_initializer = nn.initializers.uniform(1)
        self.dense = nn.Dense(
            self.n_dim,
            kernel_init=self.kernel_initializer,
            bias_init=self.bias_initializer,
        )
        # NOTE(review): `proj` is never called in __call__.
        self.proj = nn.Dense(self.n_dim)

    def __call__(self, x):
        phase = self.dense(x)
        return jnp.sin(jnp.pi * phase)


class DiffusionResidualBlock(nn.Module):
    """Residual MLP block conditioned on an embedding ``t``.

    Computes ``dense2(dense1(x) + time_mlp(t)) + shortcut(x)``, where ``dense1``
    and ``dense2`` are Dense+activation, ``time_mlp`` is activation+Dense, and
    ``shortcut`` is a linear projection of ``x`` to ``output_dim``.

    Attributes:
      output_dim: width of every Dense layer in the block (and of the output).
      activation_fn: non-linearity (default: SiLU).
    """
    output_dim: int
    activation_fn: Callable = jax.nn.silu

    def setup(self):
        # Flax builds the module through its dataclass constructor; setup only
        # declares submodules and must not re-run any constructor.
        self.time_mlp = nn.Sequential([self.activation_fn, nn.Dense(self.output_dim)])
        self.dense1 = nn.Sequential([nn.Dense(self.output_dim), self.activation_fn])
        self.dense2 = nn.Sequential([nn.Dense(self.output_dim), self.activation_fn])
        # Linear projection so the residual matches output_dim.
        self.shortcut = nn.Dense(self.output_dim)

    def __call__(self, x, t):
        """Apply the block to input ``x`` conditioned on embedding ``t``.

        Args:
          x: block input; its last axis is projected to ``output_dim``.
          t: conditioning embedding; its last axis is projected to ``output_dim``
            and added after the first layer.

        Returns:
          Array whose last axis has size ``output_dim``.
        """
        h1 = self.dense1(x) + self.time_mlp(t)
        h2 = self.dense2(h1)
        res = self.shortcut(x)
        return h2 + res


class DiffusionMLP(nn.Module):
    """Encoder/decoder MLP with skip connections for diffusion denoising.

    The time step ``t`` (random Fourier features + Dense) and a processed
    ``condition`` are fused into one embedding that conditions every residual
    block. Down blocks go 512 -> 256 -> 128. After a 128-wide middle block, the
    up blocks go 256 -> 512, each consuming the matching down-block output
    concatenated with the previous output. The final linear layer sees the
    first down-block output concatenated with the last up-block output.

    Attributes:
      output_dim: width of the final linear layer.
      time_emd_dim: width of the time and condition embeddings.
      activation_fn: non-linearity used throughout (default: SiLU).
    """
    output_dim: int
    time_emd_dim: int = 32
    activation_fn: Callable = jax.nn.silu

    def setup(self):
        act = self.activation_fn
        self.embed = nn.Sequential([GaussianFourierProjection(embed_dim=self.time_emd_dim),
                                    nn.Dense(self.time_emd_dim)])
        self.pre_sort_condition = nn.Sequential([nn.Dense(self.time_emd_dim), act])
        self.sort_t = nn.Sequential([nn.Dense(128), act, nn.Dense(128)])
        self.down_block1 = DiffusionResidualBlock(512, activation_fn=act)
        self.down_block2 = DiffusionResidualBlock(256, activation_fn=act)
        self.down_block3 = DiffusionResidualBlock(128, activation_fn=act)
        self.middle1 = DiffusionResidualBlock(128, activation_fn=act)
        self.up_block3 = DiffusionResidualBlock(256, activation_fn=act)
        self.up_block2 = DiffusionResidualBlock(512, activation_fn=act)
        self.last = nn.Dense(self.output_dim)

    def __call__(self, x, t, condition):
        time_features = self.embed(t)
        cond_features = self.pre_sort_condition(condition)
        cond = self.sort_t(jnp.concatenate([cond_features, time_features], axis=-1))

        # Encoder path: keep every output for the skip connections.
        skips = []
        hidden = x
        for block in (self.down_block1, self.down_block2, self.down_block3):
            hidden = block(hidden, cond)
            skips.append(hidden)

        hidden = self.middle1(hidden, cond)

        # Decoder path: pair each block with the deepest unused skip.
        for block in (self.up_block3, self.up_block2):
            hidden = block(jnp.concatenate([skips.pop(), hidden], axis=-1), cond)

        return self.last(jnp.concatenate([skips.pop(), hidden], axis=-1))


class PointWisePReLU(nn.PReLU):
    """PReLU with one learnable negative slope per feature of the last axis.

    The ``negative_slope`` parameter has shape ``(inputs.shape[-1],)``. It is
    initialized to ``negative_slope_init`` with dtype ``param_dtype`` and is
    broadcast over all leading axes of the input.
    """

    @nn.compact
    def __call__(self, inputs: Array) -> Array:
        """Return ``inputs`` where non-negative, else ``slope[feature] * inputs``."""
        num_features = inputs.shape[-1]
        negative_slope = self.param(
            'negative_slope',
            lambda _rng: jnp.full((num_features,), self.negative_slope_init, self.param_dtype),
        )
        # Reshape to (1, ..., 1, F) so it broadcasts against inputs.
        negative_slope = jnp.reshape(negative_slope, (1,) * (inputs.ndim - 1) + (-1,))

        return jnp.where(
            inputs >= 0, inputs, jnp.asarray(negative_slope, inputs.dtype) * inputs
        )


class VAE(nn.Module):
    """Encoder/decoder pair built from two MLPs.

    NOTE(review): no forward pass (``__call__``) is defined; only the
    sub-networks are declared.

    Attributes:
      output_dim: output width of the decoder.
      latent_dim: latent size; the encoder emits ``2 * latent_dim`` values
        (presumably location and scale parameters — confirm).
      layer_norm: forwarded to both MLPs.
      residual: not used (MLP has no residual option).
      activation_fn: activation used inside both MLPs.
      monotone: not used.
    """
    output_dim: int
    latent_dim: int
    layer_norm: bool = True
    residual: bool = True
    activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.PReLU()
    monotone: bool = False

    def setup(self) -> None:
        # MLP accepts no `residual` keyword, so it must not be forwarded.
        self.encoder = MLP(self.latent_dim * 2, layer_norm=self.layer_norm,
                           activation_fn=self.activation_fn)
        # The decoder maps back to the output space, hence output_dim.
        self.decoder = MLP(self.output_dim, layer_norm=self.layer_norm,
                           activation_fn=self.activation_fn)


class LagrangianCoefficient(nn.Module):
    """A single learnable scalar named ``log_lambd``, initialized to 0.

    The caller is expected to treat it as a log-space multiplier.
    """

    @nn.compact
    def __call__(self):
        """Return the scalar parameter ``log_lambd``."""
        return self.param('log_lambd', nn.initializers.constant(0.), shape=())


class SplineCurve(nn.Module):
    """Maps quantile levels ``taus`` to values conditioned on ``feature``.

    The result is a Gaussian quantile function with feature-dependent location
    and scale, plus a feature-dependent rational-quadratic spline correction.

    NOTE(review): ``num_range`` is not used by ``__call__``.
    """
    num_range: int

    @nn.compact
    def __call__(self, feature, taus):
        # Compress large feature magnitudes before the linear heads.
        feature = jnp.arcsinh(feature)
        # Shrink taus away from {0, 1} so norm.ppf stays finite
        # (0 -> ~1e-3, 1 -> 1 - 1e-6).
        taus = (taus + 1e-3) * (1 - 1e-3)
        # Standard-normal quantiles of taus; no gradient flows back through them.
        taus = jax.scipy.stats.norm.ppf(taus)
        taus = jax.lax.stop_gradient(taus)

        # 13 = 3 * 4 + 1: presumably 4 bins in distrax's RationalQuadraticSpline
        # parameterization — TODO confirm.
        spline = nn.Dense(13)(feature)
        mu_sigma = nn.Dense(2)(feature)

        # List indexing keeps a trailing singleton axis so mu/sigma broadcast over taus.
        mu = mu_sigma[..., [0]]
        # Log-scale clipped from below at -2, i.e. sigma >= exp(-2).
        sigma = jnp.exp(mu_sigma[..., [1]].clip(-2, ))
        # Spline defined on [-4, 3]; per distrax docs, boundary_slopes='identity'
        # makes it the identity outside that interval.
        spline_map = distrax.RationalQuadraticSpline(spline, -4, 3,
                                                     boundary_slopes='identity')
        # Swap axes so per-sample spline params broadcast against taus; assumes
        # taus is (batch, n_taus) and spline is (batch, 13) — TODO confirm.
        splined = spline_map.forward(taus.swapaxes(0, 1)).swapaxes(0, 1)
        # Map the (pre-spline) Gaussian quantiles back into (0, 1) and shrink again.
        taus = jax.nn.sigmoid(taus)
        taus = (taus + 1e-3) * (1 - 1e-3)

        # Gaussian quantile with learned loc/scale plus the spline correction.
        taus = jax.scipy.stats.norm.ppf(taus, loc=mu, scale=sigma) + splined

        return taus


class GaussianKernel(nn.Module):
    """Elementwise Gaussian similarity ``exp(-(feature - taus_embedding)**2)``.

    ``feature`` gains an axis at position -2, so it is compared against every
    row of ``taus_embedding`` along that axis.
    """

    @nn.compact
    def __call__(self, feature, taus_embedding):
        diff = feature[..., None, :] - taus_embedding
        return jnp.exp(-(diff ** 2))


class BiLinear(nn.Dense):
    """Bilinear layer ``y[..., k] = sum_ij x1[..., i] * W[i, j, k] * x2[..., j] (+ b[k])``.

    Reuses ``nn.Dense``'s attributes: ``features`` is the output width;
    ``kernel_init``, ``bias_init``, ``use_bias`` and ``param_dtype`` configure the
    parameters; ``precision`` is forwarded to the contraction.
    """

    @nn.compact
    def __call__(self, inputs1, inputs2):
        """Contract ``inputs1`` and ``inputs2`` with a 3-D kernel.

        Args:
          inputs1: array whose last axis has size ``I``.
          inputs2: array whose last axis has size ``J``; leading axes must
            broadcast with those of ``inputs1``.

        Returns:
          Array with the broadcast leading axes and a last axis of size ``features``.
        """
        kernel = self.param(
            'kernel',
            self.kernel_init,
            (jnp.shape(inputs1)[-1], jnp.shape(inputs2)[-1], self.features),
            self.param_dtype,
        )
        # `precision` must be None, a string or a lax.Precision — not a dtype.
        y = jnp.einsum('...i,ijk,...j->...k', inputs1, kernel, inputs2,
                       precision=self.precision)
        if self.use_bias:
            bias = self.param('bias', self.bias_init, (self.features,), self.param_dtype)
            y = y + bias
        return y


class SirenNet(nn.Module):
    """Small SIREN-style network: three 64-unit sine layers and a linear output.

    Attributes:
      out_features: width of the final linear layer.
      omega_0: frequency multiplier applied to the first layer's pre-activation
        (default 30).
    """
    out_features: int
    omega_0: float = 30.

    @nn.compact
    def __call__(self, x):
        """Map ``x`` to ``out_features`` outputs along the last axis."""
        # Rescale from [0, 1] to [-1, 1] (assumes x lies in [0, 1] — TODO confirm).
        x = (x - 0.5) * 2
        # NOTE(review): jax's uniform(scale=1) samples from [0, 1), i.e. a
        # non-symmetric first-layer init; confirm this is intended.
        y = nn.Dense(64, kernel_init=nn.initializers.uniform(scale=1, ), dtype=jnp.float32)(x)
        y = jnp.sin(self.omega_0 * y)
        # He-style uniform init for the remaining layers: U(+-sqrt(6 / fan_in)).
        hidden_init = nn.initializers.variance_scaling(scale=2, mode='fan_in',
                                                       distribution='uniform')
        y = jnp.sin(nn.Dense(64, kernel_init=hidden_init)(y))
        y = jnp.sin(nn.Dense(64, kernel_init=hidden_init)(y))
        return nn.Dense(self.out_features, kernel_init=hidden_init)(y)


class CosineQuantileHead(nn.Module):
    """Quantile head that embeds quantile levels ``taus`` and modulates ``feature``.

    Each tau is embedded, projected to ``embedding_size``, and multiplied
    elementwise with ``feature``. An MLP then maps every (feature, tau) pair to
    one scalar.

    Attributes:
      activation_fn: activation used inside the output MLP.
      embedding_size: width of the tau embedding. Must equal ``feature.shape[-1]``,
        because the two are multiplied elementwise.
      n_cosine: number of frequencies in the tau embedding.
      smooth: if True, frequencies are learned (a bias-free Dense) and the
        embedding is ``[sin, cos, tau]``. Otherwise the frequencies are fixed at
        ``pi * (1..n_cosine)`` and the embedding is a cosine.
      w_kernel: activation applied after the embedding projection.
    """
    activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu
    embedding_size: int = 64
    n_cosine: int = 64
    smooth: bool = True
    w_kernel: Callable = nn.relu

    def setup(self):
        if self.smooth:
            # Learned frequencies, applied to taus[..., None].
            self.units = nn.Dense(self.n_cosine, use_bias=False,
                                  kernel_init=nn.initializers.normal(stddev=0.03))
        else:
            # Fixed frequencies pi, 2*pi, ..., n_cosine*pi.
            self.units = jnp.arange(1, self.n_cosine + 1, dtype=jnp.float32) * jnp.pi
        self.w = nn.Sequential([nn.Dense(self.embedding_size), self.w_kernel])
        self.out = MLP(1, hidden_sizes=(256,),
                       activation_fn=self.activation_fn, layer_norm=True)

    def __call__(self, feature, taus):
        """Evaluate the head.

        Args:
          feature: array with last axis ``embedding_size``.
          taus: quantile levels; presumably (batch, n_taus) with batch matching
            ``feature``'s leading axes — TODO confirm.

        Returns:
          One value per tau, shaped like ``taus`` after broadcasting.
        """
        if self.smooth:
            taus_embedding = 2 * jnp.pi * self.units(taus[..., None])
            taus_embedding = jnp.concatenate([jnp.sin(taus_embedding), jnp.cos(taus_embedding), taus[..., None]],
                                             axis=-1)
        else:
            taus_embedding = taus[..., None] * self.units
            taus_embedding = jnp.cos(taus_embedding)
        taus_embedding = self.w(taus_embedding)
        # Broadcast feature over the taus axis and modulate it elementwise.
        feature = feature[..., None, :] * taus_embedding
        qfs = (self.out(feature)).squeeze(axis=-1)
        return qfs


if __name__ == '__main__':
    # Visual smoke test: forward pass of an untrained SplineCurve on random features.
    key = jax.random.PRNGKey(42)
    tau_grid = jnp.linspace(0, 1, 32)
    layer = SplineCurve(8)
    params = layer.init(key, jnp.ones((1, 32)), tau_grid[None])
    batch_taus = jnp.repeat(tau_grid[None], axis=0, repeats=64)
    random_features = jax.random.normal(key, (64, 32))
    outs = layer.apply(params, random_features, batch_taus)
    import matplotlib.pyplot as plt

    plt.plot(np.linspace(0, 1, 32), outs[0])
    plt.show()
