import jax
from jax import lax
from jax.nn import initializers
import jax.numpy as jnp
import flax
from flax.linen.module import merge_param
import flax.linen as nn
from typing import Callable, Iterable, Optional, Tuple, Union, Any
import functools
import pickle
from . import utils

PRNGKey = Any
Array = Any
Shape = Tuple[int]
Dtype = Any


class InceptionV3(nn.Module):
    """
    InceptionV3 network.
    Reference: https://arxiv.org/abs/1512.00567
    Ported mostly from: https://github.com/pytorch/vision/blob/master/torchvision/models/inception.py

    Attributes:
        include_head (bool): If True, include classifier head.
        num_classes (int): Number of classes.
        pretrained (bool): If True, use pretrained weights. 
        transform_input (bool): If True, preprocesses the input according to the method with which it
                                was trained on ImageNet.
        aux_logits (bool): If True, add an auxiliary branch that can improve training.
        dtype (str): Data type.
    """
    include_head: bool=False
    num_classes: int=1000
    pretrained: bool=False
    transform_input: bool=False
    aux_logits: bool=False
    ckpt_path: str='https://www.dropbox.com/s/xt6zvlvt22dcwck/inception_v3_weights_fid.pickle?dl=1'
    dtype: str='float32'

    def setup(self):
        if self.pretrained:
            ckpt_file = utils.download(self.ckpt_path)
            self.params_dict = pickle.load(open(ckpt_file, 'rb'))
            self.num_classes_ = 1000
        else:
            self.params_dict = None
            self.num_classes_ = self.num_classes

    @nn.compact
    def __call__(self, x, train=True, rng=jax.random.PRNGKey(0)):
        """
        Args:
            x (tensor): Input image, shape [B, H, W, C].
            train (bool): If True, training mode.
            rng (jax.random.PRNGKey): Random seed.
        """
        x = self._transform_input(x)
        x = BasicConv2d(out_channels=32,
                        kernel_size=(3, 3),
                        strides=(2, 2),
                        params_dict=utils.get(self.params_dict, 'Conv2d_1a_3x3'),
                        dtype=self.dtype)(x, train)
        x = BasicConv2d(out_channels=32,
                        kernel_size=(3, 3),
                        params_dict=utils.get(self.params_dict, 'Conv2d_2a_3x3'),
                        dtype=self.dtype)(x, train)
        x = BasicConv2d(out_channels=64,
                        kernel_size=(3, 3),
                        padding=((1, 1), (1, 1)),
                        params_dict=utils.get(self.params_dict, 'Conv2d_2b_3x3'),
                        dtype=self.dtype)(x, train)
        x = nn.max_pool(x, window_shape=(3, 3), strides=(2, 2))
        x = BasicConv2d(out_channels=80,
                        kernel_size=(1, 1),
                        params_dict=utils.get(self.params_dict, 'Conv2d_3b_1x1'),
                        dtype=self.dtype)(x, train)
        x = BasicConv2d(out_channels=192,
                        kernel_size=(3, 3),
                        params_dict=utils.get(self.params_dict, 'Conv2d_4a_3x3'),
                        dtype=self.dtype)(x, train)
        x = nn.max_pool(x, window_shape=(3, 3), strides=(2, 2))
        x = InceptionA(pool_features=32,
                       params_dict=utils.get(self.params_dict, 'Mixed_5b'),
                       dtype=self.dtype)(x, train)
        x = InceptionA(pool_features=64,
                       params_dict=utils.get(self.params_dict, 'Mixed_5c'),
                       dtype=self.dtype)(x, train)
        x = InceptionA(pool_features=64,
                       params_dict=utils.get(self.params_dict, 'Mixed_5d'),
                       dtype=self.dtype)(x, train)
        x = InceptionB(params_dict=utils.get(self.params_dict, 'Mixed_6a'),
                       dtype=self.dtype)(x, train)
        x = InceptionC(channels_7x7=128,
                       params_dict=utils.get(self.params_dict, 'Mixed_6b'),
                       dtype=self.dtype)(x, train)
        x = InceptionC(channels_7x7=160,
                       params_dict=utils.get(self.params_dict, 'Mixed_6c'),
                       dtype=self.dtype)(x, train)
        x = InceptionC(channels_7x7=160,
                       params_dict=utils.get(self.params_dict, 'Mixed_6d'),
                       dtype=self.dtype)(x, train)
        x = InceptionC(channels_7x7=192,
                       params_dict=utils.get(self.params_dict, 'Mixed_6e'),
                       dtype=self.dtype)(x, train)
        aux = None
        if self.aux_logits and train:
            aux = InceptionAux(num_classes=self.num_classes_,
                               params_dict=utils.get(self.params_dict, 'AuxLogits'),
                               dtype=self.dtype)(x, train)
        x = InceptionD(params_dict=utils.get(self.params_dict, 'Mixed_7a'),
                       dtype=self.dtype)(x, train)
        x = InceptionE(avg_pool, params_dict=utils.get(self.params_dict, 'Mixed_7b'),
                       dtype=self.dtype)(x, train)
        # Following the implementation by @mseitzer, we use max pooling instead
        # of average pooling here.
        # See: https://github.com/mseitzer/pytorch-fid/blob/master/src/pytorch_fid/inception.py#L320
        x = InceptionE(nn.max_pool, params_dict=utils.get(self.params_dict, 'Mixed_7c'),
                       dtype=self.dtype)(x, train)
        x = jnp.mean(x, axis=(1, 2), keepdims=True)
        if not self.include_head:
            return x
        x = nn.Dropout(rate=0.5)(x, deterministic=not train, rng=rng)
        x = jnp.reshape(x, newshape=(x.shape[0], -1))
        x = Dense(features=self.num_classes_,
                  params_dict=utils.get(self.params_dict, 'fc'),
                  dtype=self.dtype)(x)
        if self.aux_logits:
            return x, aux
        return x

    def _transform_input(self, x):
        if self.transform_input:
            x_ch0 = jnp.expand_dims(x[..., 0], axis=-1) * (0.229 / 0.5) + (0.485 - 0.5) / 0.5
            x_ch1 = jnp.expand_dims(x[..., 1], axis=-1) * (0.224 / 0.5) + (0.456 - 0.5) / 0.5
            x_ch2 = jnp.expand_dims(x[..., 2], axis=-1) * (0.225 / 0.5) + (0.406 - 0.5) / 0.5
            x = jnp.concatenate((x_ch0, x_ch1, x_ch2), axis=-1)
        return x


class Dense(nn.Module):
    features: int
    kernel_init: functools.partial=nn.initializers.lecun_normal()
    bias_init: functools.partial=nn.initializers.zeros
    params_dict: dict=None
    dtype: str='float32'

    @nn.compact
    def __call__(self, x):
        x = nn.Dense(features=self.features,
                     kernel_init=self.kernel_init if self.params_dict is None else lambda *_ : jnp.array(self.params_dict['kernel']),
                     bias_init=self.bias_init if self.params_dict is None else lambda *_ : jnp.array(self.params_dict['bias']))(x)
        return x


class BasicConv2d(nn.Module):
    out_channels: int
    kernel_size: Union[int, Iterable[int]]=(3, 3)
    strides: Optional[Iterable[int]]=(1, 1)
    padding: Union[str, Iterable[Tuple[int, int]]]='valid'
    use_bias: bool=False
    kernel_init: functools.partial=nn.initializers.lecun_normal()
    bias_init: functools.partial=nn.initializers.zeros
    params_dict: dict=None
    dtype: str='float32'

    @nn.compact
    def __call__(self, x, train=True):
        x = nn.Conv(features=self.out_channels,
                    kernel_size=self.kernel_size,
                    strides=self.strides,
                    padding=self.padding,
                    use_bias=self.use_bias,
                    kernel_init=self.kernel_init if self.params_dict is None else lambda *_ : jnp.array(self.params_dict['conv']['kernel']),
                    bias_init=self.bias_init if self.params_dict is None else lambda *_ : jnp.array(self.params_dict['conv']['bias']),
                    dtype=self.dtype)(x)
        if self.params_dict is None:
            x = BatchNorm(epsilon=0.001,
                          momentum=0.1,
                          use_running_average=not train,
                          dtype=self.dtype)(x)
        else:
            x = BatchNorm(epsilon=0.001,
                          momentum=0.1,
                          bias_init=lambda *_ : jnp.array(self.params_dict['bn']['bias']),
                          scale_init=lambda *_ : jnp.array(self.params_dict['bn']['scale']),
                          mean_init=lambda *_ : jnp.array(self.params_dict['bn']['mean']),
                          var_init=lambda *_ : jnp.array(self.params_dict['bn']['var']),
                          use_running_average=not train,
                          dtype=self.dtype)(x)
        x = jax.nn.relu(x)
        return x


class InceptionA(nn.Module):
    pool_features: int
    params_dict: dict=None
    dtype: str='float32'

    @nn.compact
    def __call__(self, x, train=True):
        branch1x1 = BasicConv2d(out_channels=64,
                                kernel_size=(1, 1),
                                params_dict=utils.get(self.params_dict, 'branch1x1'),
                                dtype=self.dtype)(x, train)
        branch5x5 = BasicConv2d(out_channels=48,
                                kernel_size=(1, 1),
                                params_dict=utils.get(self.params_dict, 'branch5x5_1'),
                                dtype=self.dtype)(x, train)
        branch5x5 = BasicConv2d(out_channels=64,
                                kernel_size=(5, 5),
                                padding=((2, 2), (2, 2)),
                                params_dict=utils.get(self.params_dict, 'branch5x5_2'),
                                dtype=self.dtype)(branch5x5, train)

        branch3x3dbl = BasicConv2d(out_channels=64,
                                   kernel_size=(1, 1),
                                   params_dict=utils.get(self.params_dict, 'branch3x3dbl_1'),
                                   dtype=self.dtype)(x, train)
        branch3x3dbl = BasicConv2d(out_channels=96,
                                   kernel_size=(3, 3),
                                   padding=((1, 1), (1, 1)),
                                   params_dict=utils.get(self.params_dict, 'branch3x3dbl_2'),
                                   dtype=self.dtype)(branch3x3dbl, train)
        branch3x3dbl = BasicConv2d(out_channels=96,
                                   kernel_size=(3, 3),
                                   padding=((1, 1), (1, 1)),
                                   params_dict=utils.get(self.params_dict, 'branch3x3dbl_3'),
                                   dtype=self.dtype)(branch3x3dbl, train)

        branch_pool = avg_pool(x, window_shape=(3, 3), strides=(1, 1), padding=((1, 1), (1, 1)))
        branch_pool = BasicConv2d(out_channels=self.pool_features,
                                  kernel_size=(1, 1),
                                  params_dict=utils.get(self.params_dict, 'branch_pool'),
                                  dtype=self.dtype)(branch_pool, train)
        
        output = jnp.concatenate((branch1x1, branch5x5, branch3x3dbl, branch_pool), axis=-1)
        return output


class InceptionB(nn.Module):
    params_dict: dict=None
    dtype: str='float32'

    @nn.compact
    def __call__(self, x, train=True):
        branch3x3 = BasicConv2d(out_channels=384,
                                kernel_size=(3, 3),
                                strides=(2, 2),
                                params_dict=utils.get(self.params_dict, 'branch3x3'),
                                dtype=self.dtype)(x, train)

        branch3x3dbl = BasicConv2d(out_channels=64,
                                   kernel_size=(1, 1),
                                   params_dict=utils.get(self.params_dict, 'branch3x3dbl_1'),
                                   dtype=self.dtype)(x, train)
        branch3x3dbl = BasicConv2d(out_channels=96,
                                   kernel_size=(3, 3),
                                   padding=((1, 1), (1, 1)),
                                   params_dict=utils.get(self.params_dict, 'branch3x3dbl_2'),
                                   dtype=self.dtype)(branch3x3dbl, train)
        branch3x3dbl = BasicConv2d(out_channels=96,
                                   kernel_size=(3, 3),
                                   strides=(2, 2),
                                   params_dict=utils.get(self.params_dict, 'branch3x3dbl_3'),
                                   dtype=self.dtype)(branch3x3dbl, train)

        branch_pool = nn.max_pool(x, window_shape=(3, 3), strides=(2, 2))

        output = jnp.concatenate((branch3x3, branch3x3dbl, branch_pool), axis=-1)
        return output


class InceptionC(nn.Module):
    channels_7x7: int
    params_dict: dict=None
    dtype: str='float32'

    @nn.compact
    def __call__(self, x, train=True):
        branch1x1 = BasicConv2d(out_channels=192,
                                kernel_size=(1, 1),
                                params_dict=utils.get(self.params_dict, 'branch1x1'),
                                dtype=self.dtype)(x, train)
            
        branch7x7 = BasicConv2d(out_channels=self.channels_7x7,
                                kernel_size=(1, 1),
                                params_dict=utils.get(self.params_dict, 'branch7x7_1'),
                                dtype=self.dtype)(x, train)
        branch7x7 = BasicConv2d(out_channels=self.channels_7x7,
                                kernel_size=(1, 7),
                                padding=((0, 0), (3, 3)),
                                params_dict=utils.get(self.params_dict, 'branch7x7_2'),
                                dtype=self.dtype)(branch7x7, train)
        branch7x7 = BasicConv2d(out_channels=192,
                                kernel_size=(7, 1),
                                padding=((3, 3), (0, 0)),
                                params_dict=utils.get(self.params_dict, 'branch7x7_3'),
                                dtype=self.dtype)(branch7x7, train)

        branch7x7dbl = BasicConv2d(out_channels=self.channels_7x7,
                                   kernel_size=(1, 1),
                                   params_dict=utils.get(self.params_dict, 'branch7x7dbl_1'),
                                   dtype=self.dtype)(x, train)
        branch7x7dbl = BasicConv2d(out_channels=self.channels_7x7,
                                   kernel_size=(7, 1),
                                   padding=((3, 3), (0, 0)),
                                   params_dict=utils.get(self.params_dict, 'branch7x7dbl_2'),
                                   dtype=self.dtype)(branch7x7dbl, train)
        branch7x7dbl = BasicConv2d(out_channels=self.channels_7x7,
                                   kernel_size=(1, 7),
                                   padding=((0, 0), (3, 3)),
                                   params_dict=utils.get(self.params_dict, 'branch7x7dbl_3'),
                                   dtype=self.dtype)(branch7x7dbl, train)
        branch7x7dbl = BasicConv2d(out_channels=self.channels_7x7,
                                   kernel_size=(7, 1),
                                   padding=((3, 3), (0, 0)),
                                   params_dict=utils.get(self.params_dict, 'branch7x7dbl_4'),
                                   dtype=self.dtype)(branch7x7dbl, train)
        branch7x7dbl = BasicConv2d(out_channels=self.channels_7x7,
                                   kernel_size=(1, 7),
                                   padding=((0, 0), (3, 3)),
                                   params_dict=utils.get(self.params_dict, 'branch7x7dbl_5'),
                                   dtype=self.dtype)(branch7x7dbl, train)

        branch_pool = avg_pool(x, window_shape=(3, 3), strides=(1, 1), padding=((1, 1), (1, 1)))
        branch_pool = BasicConv2d(out_channels=192,
                                  kernel_size=(1, 1),
                                  params_dict=utils.get(self.params_dict, 'branch_pool'),
                                  dtype=self.dtype)(branch_pool, train)
        
        output = jnp.concatenate((branch1x1, branch7x7, branch7x7dbl, branch_pool), axis=-1)
        return output


class InceptionD(nn.Module):
    params_dict: dict=None
    dtype: str='float32'

    @nn.compact
    def __call__(self, x, train=True):
        branch3x3 = BasicConv2d(out_channels=192,
                                kernel_size=(1, 1),
                                params_dict=utils.get(self.params_dict, 'branch3x3_1'),
                                dtype=self.dtype)(x, train)
        branch3x3 = BasicConv2d(out_channels=320,
                                kernel_size=(3, 3),
                                strides=(2, 2),
                                params_dict=utils.get(self.params_dict, 'branch3x3_2'),
                                dtype=self.dtype)(branch3x3, train)
            
        branch7x7x3 = BasicConv2d(out_channels=192,
                                  kernel_size=(1, 1),
                                  params_dict=utils.get(self.params_dict, 'branch7x7x3_1'),
                                  dtype=self.dtype)(x, train)
        branch7x7x3 = BasicConv2d(out_channels=192,
                                  kernel_size=(1, 7),
                                  padding=((0, 0), (3, 3)),
                                  params_dict=utils.get(self.params_dict, 'branch7x7x3_2'),
                                  dtype=self.dtype)(branch7x7x3, train)
        branch7x7x3 = BasicConv2d(out_channels=192,
                                  kernel_size=(7, 1),
                                  padding=((3, 3), (0, 0)),
                                  params_dict=utils.get(self.params_dict, 'branch7x7x3_3'),
                                  dtype=self.dtype)(branch7x7x3, train)
        branch7x7x3 = BasicConv2d(out_channels=192,
                                  kernel_size=(3, 3),
                                  strides=(2, 2),
                                  params_dict=utils.get(self.params_dict, 'branch7x7x3_4'),
                                  dtype=self.dtype)(branch7x7x3, train)

        branch_pool = nn.max_pool(x, window_shape=(3, 3), strides=(2, 2))
        
        output = jnp.concatenate((branch3x3, branch7x7x3, branch_pool), axis=-1)
        return output


class InceptionE(nn.Module):
    pooling: Callable
    params_dict: dict=None
    dtype: str='float32'

    @nn.compact
    def __call__(self, x, train=True):
        branch1x1 = BasicConv2d(out_channels=320,
                                kernel_size=(1, 1),
                                params_dict=utils.get(self.params_dict, 'branch1x1'),
                                dtype=self.dtype)(x, train)
          
        branch3x3 = BasicConv2d(out_channels=384,
                                kernel_size=(1, 1),
                                params_dict=utils.get(self.params_dict, 'branch3x3_1'),
                                dtype=self.dtype)(x, train)
        branch3x3_a = BasicConv2d(out_channels=384,
                                  kernel_size=(1, 3),
                                  padding=((0, 0), (1, 1)),
                                  params_dict=utils.get(self.params_dict, 'branch3x3_2a'),
                                  dtype=self.dtype)(branch3x3, train)
        branch3x3_b = BasicConv2d(out_channels=384,
                                  kernel_size=(3, 1),
                                  padding=((1, 1), (0, 0)),
                                  params_dict=utils.get(self.params_dict, 'branch3x3_2b'),
                                  dtype=self.dtype)(branch3x3, train)
        branch3x3 = jnp.concatenate((branch3x3_a, branch3x3_b), axis=-1)

        branch3x3dbl = BasicConv2d(out_channels=448,
                                   kernel_size=(1, 1),
                                   params_dict=utils.get(self.params_dict, 'branch3x3dbl_1'),
                                   dtype=self.dtype)(x, train)
        branch3x3dbl = BasicConv2d(out_channels=384,
                                   kernel_size=(3, 3),
                                   padding=((1, 1), (1, 1)),
                                   params_dict=utils.get(self.params_dict, 'branch3x3dbl_2'),
                                   dtype=self.dtype)(branch3x3dbl, train)
        branch3x3dbl_a = BasicConv2d(out_channels=384,
                                     kernel_size=(1, 3),
                                     padding=((0, 0), (1, 1)),
                                     params_dict=utils.get(self.params_dict, 'branch3x3dbl_3a'),
                                     dtype=self.dtype)(branch3x3dbl, train)
        branch3x3dbl_b = BasicConv2d(out_channels=384,
                                     kernel_size=(3, 1),
                                     padding=((1, 1), (0, 0)),
                                     params_dict=utils.get(self.params_dict, 'branch3x3dbl_3b'),
                                     dtype=self.dtype)(branch3x3dbl, train)
        branch3x3dbl = jnp.concatenate((branch3x3dbl_a, branch3x3dbl_b), axis=-1)

        branch_pool = self.pooling(x, window_shape=(3, 3), strides=(1, 1), padding=((1, 1), (1, 1)))
        branch_pool = BasicConv2d(out_channels=192,
                                  kernel_size=(1, 1),
                                  params_dict=utils.get(self.params_dict, 'branch_pool'),
                                  dtype=self.dtype)(branch_pool, train)
        
        output = jnp.concatenate((branch1x1, branch3x3, branch3x3dbl, branch_pool), axis=-1)
        return output


class InceptionAux(nn.Module):
    num_classes: int
    kernel_init: functools.partial=nn.initializers.lecun_normal()
    bias_init: functools.partial=nn.initializers.zeros
    params_dict: dict=None
    dtype: str='float32'

    @nn.compact
    def __call__(self, x, train=True):
        x = avg_pool(x, window_shape=(5, 5), strides=(3, 3))
        x = BasicConv2d(out_channels=128,
                        kernel_size=(1, 1),
                        params_dict=utils.get(self.params_dict, 'conv0'),
                        dtype=self.dtype)(x, train)
        x = BasicConv2d(out_channels=768,
                        kernel_size=(5, 5),
                        params_dict=utils.get(self.params_dict, 'conv1'),
                        dtype=self.dtype)(x, train)
        x = jnp.mean(x, axis=(1, 2))
        x = jnp.reshape(x, newshape=(x.shape[0], -1))
        x = Dense(features=self.num_classes,
                  params_dict=utils.get(self.params_dict, 'fc'),
                  dtype=self.dtype)(x)
        return x
    
def _absolute_dims(rank, dims):
    return tuple([rank + dim if dim < 0 else dim for dim in dims])


class BatchNorm(nn.Module):
    """BatchNorm Module.
    Taken from: https://github.com/google/flax/blob/master/flax/linen/normalization.py
    Attributes:
        use_running_average: if True, the statistics stored in batch_stats
                             will be used instead of computing the batch statistics on the input.
    axis: the feature or non-batch axis of the input.
    momentum: decay rate for the exponential moving average of the batch statistics.
    epsilon: a small float added to variance to avoid dividing by zero.
    dtype: the dtype of the computation (default: float32).
    use_bias:  if True, bias (beta) is added.
    use_scale: if True, multiply by scale (gamma).
               When the next layer is linear (also e.g. nn.relu), this can be disabled
               since the scaling will be done by the next layer.
    bias_init: initializer for bias, by default, zero.
    scale_init: initializer for scale, by default, one.
    axis_name: the axis name used to combine batch statistics from multiple
               devices. See `jax.pmap` for a description of axis names (default: None).
    axis_index_groups: groups of axis indices within that named axis
                       representing subsets of devices to reduce over (default: None). For
                       example, `[[0, 1], [2, 3]]` would independently batch-normalize over
                       the examples on the first two and last two devices. See `jax.lax.psum`
                       for more details.
    """
    use_running_average: Optional[bool] = None
    axis: int = -1
    momentum: float = 0.99
    epsilon: float = 1e-5
    dtype: Dtype = jnp.float32
    use_bias: bool = True
    use_scale: bool = True
    bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = initializers.zeros
    scale_init: Callable[[PRNGKey, Shape, Dtype], Array] = initializers.ones
    mean_init: Callable[[Shape], Array] = lambda s: jnp.zeros(s, jnp.float32)
    var_init: Callable[[Shape], Array] = lambda s: jnp.ones(s, jnp.float32)
    axis_name: Optional[str] = None
    axis_index_groups: Any = None

    @nn.compact
    def __call__(self, x, use_running_average: Optional[bool] = None):
        """Normalizes the input using batch statistics.
        
        NOTE:
        During initialization (when parameters are mutable) the running average
        of the batch statistics will not be updated. Therefore, the inputs
        fed during initialization don't need to match that of the actual input
        distribution and the reduction axis (set with `axis_name`) does not have
        to exist.
        Args:
            x: the input to be normalized.
            use_running_average: if true, the statistics stored in batch_stats
                                 will be used instead of computing the batch statistics on the input.
        Returns:
            Normalized inputs (the same shape as inputs).
        """
        use_running_average = merge_param(
            'use_running_average', self.use_running_average, use_running_average)
        x = jnp.asarray(x, jnp.float32)
        axis = self.axis if isinstance(self.axis, tuple) else (self.axis,)
        axis = _absolute_dims(x.ndim, axis)
        feature_shape = tuple(d if i in axis else 1 for i, d in enumerate(x.shape))
        reduced_feature_shape = tuple(d for i, d in enumerate(x.shape) if i in axis)
        reduction_axis = tuple(i for i in range(x.ndim) if i not in axis)

        # see NOTE above on initialization behavior
        initializing = self.is_mutable_collection('params')

        ra_mean = self.variable('batch_stats', 'mean',
                                self.mean_init,
                                reduced_feature_shape)
        ra_var = self.variable('batch_stats', 'var',
                               self.var_init,
                               reduced_feature_shape)

        if use_running_average:
            mean, var = ra_mean.value, ra_var.value
        else:
            mean = jnp.mean(x, axis=reduction_axis, keepdims=False)
            mean2 = jnp.mean(lax.square(x), axis=reduction_axis, keepdims=False)
            if self.axis_name is not None and not initializing:
                concatenated_mean = jnp.concatenate([mean, mean2])
                mean, mean2 = jnp.split(
                    lax.pmean(
                        concatenated_mean,
                        axis_name=self.axis_name,
                        axis_index_groups=self.axis_index_groups), 2)
            var = mean2 - lax.square(mean)

            if not initializing:
                ra_mean.value = self.momentum * ra_mean.value + (1 - self.momentum) * mean
                ra_var.value = self.momentum * ra_var.value + (1 - self.momentum) * var

        y = x - mean.reshape(feature_shape)
        mul = lax.rsqrt(var + self.epsilon)
        if self.use_scale:
            scale = self.param('scale',
                               self.scale_init,
                               reduced_feature_shape).reshape(feature_shape)
            mul = mul * scale
        y = y * mul
        if self.use_bias:
            bias = self.param('bias',
                              self.bias_init,
                              reduced_feature_shape).reshape(feature_shape)
            y = y + bias
        return jnp.asarray(y, self.dtype)


def pool(inputs, init, reduce_fn, window_shape, strides, padding):
    """
    Taken from: https://github.com/google/flax/blob/main/flax/linen/pooling.py

    Helper function to define pooling functions.
    Pooling functions are implemented using the ReduceWindow XLA op.
    NOTE: Be aware that pooling is not generally differentiable.
    That means providing a reduce_fn that is differentiable does not imply
    that pool is differentiable.
    Args:
      inputs: input data with dimensions (batch, window dims..., features).
      init: the initial value for the reduction
      reduce_fn: a reduce function of the form `(T, T) -> T`.
      window_shape: a shape tuple defining the window to reduce over.
      strides: a sequence of `n` integers, representing the inter-window
          strides.
      padding: either the string `'SAME'`, the string `'VALID'`, or a sequence
        of `n` `(low, high)` integer pairs that give the padding to apply before
        and after each spatial dimension.
    Returns:
      The output of the reduction for each window slice.
    """
    strides = strides or (1,) * len(window_shape)
    assert len(window_shape) == len(strides), (
        f"len({window_shape}) == len({strides})")
    strides = (1,) + strides + (1,)
    dims = (1,) + window_shape + (1,)

    is_single_input = False
    if inputs.ndim == len(dims) - 1:
      # add singleton batch dimension because lax.reduce_window always
      # needs a batch dimension.
      inputs = inputs[None]
      is_single_input = True

    assert inputs.ndim == len(dims), f"len({inputs.shape}) != len({dims})"
    if not isinstance(padding, str):
      padding = tuple(map(tuple, padding))
      assert(len(padding) == len(window_shape)), (
        f"padding {padding} must specify pads for same number of dims as "
        f"window_shape {window_shape}")
      assert(all([len(x) == 2 for x in padding])), (
        f"each entry in padding {padding} must be length 2")
      padding = ((0,0),) + padding + ((0,0),)
    y = jax.lax.reduce_window(inputs, init, reduce_fn, dims, strides, padding)
    if is_single_input:
      y = jnp.squeeze(y, axis=0)
    return y


def avg_pool(inputs, window_shape, strides=None, padding='VALID'):
    """
    Pools the input by taking the average over a window.

    In comparison to flax.linen.avg_pool, this pooling operation does not
    consider the padded zero's for the average computation.

    Args:
      inputs: input data with dimensions (batch, window dims..., features).
      window_shape: a shape tuple defining the window to reduce over.
      strides: a sequence of `n` integers, representing the inter-window
          strides (default: `(1, ..., 1)`).
      padding: either the string `'SAME'`, the string `'VALID'`, or a sequence
        of `n` `(low, high)` integer pairs that give the padding to apply before
        and after each spatial dimension (default: `'VALID'`).
    Returns:
      The average for each window slice.
    """
    assert inputs.ndim == 4
    assert len(window_shape) == 2

    y = pool(inputs, 0., jax.lax.add, window_shape, strides, padding)
    ones = jnp.ones(shape=(1, inputs.shape[1], inputs.shape[2], 1)).astype(inputs.dtype)
    counts = jax.lax.conv_general_dilated(ones,
                                          jnp.expand_dims(jnp.ones(window_shape).astype(inputs.dtype), axis=(-2, -1)),
                                          window_strides=(1, 1),
                                          padding=((1, 1), (1, 1)),
                                          dimension_numbers=nn.linear._conv_dimension_numbers(ones.shape),
                                          feature_group_count=1)
    y = y / counts 
    return y
