from typing import Optional, Sequence, Tuple, Union

import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np


class WNConvND(hk.ConvND):
    """General N-dimensional convolution with weight-normalized kernels.

    The kernel is reparameterized as ``w = g * v / ||v||``:

    * ``v`` (parameter ``"v"``) has the regular ``hk.ConvND`` kernel shape
      ``kernel_shape + (in_channels // feature_group_count, output_channels)``.
    * ``||v||`` is the L2 norm taken over every kernel axis except the
      input-channel axis (axis ``-2``).
    * ``g`` (parameter ``"g"``, float32) holds one scale per input-channel
      slice, so ``w[..., i, :]`` has L2 norm ``g[i]`` (before ``mask`` is
      applied).

    NOTE(review): standard weight normalization (Salimans & Kingma, 2016)
    normalizes per *output* unit; this module normalizes per *input* channel.
    Confirm this is intended.
    """

    # TODO: extend to any data_format (now only working with NCHW)
    def __init__(
        self,
        num_spatial_dims: int,
        output_channels: int,
        kernel_shape: Union[int, Sequence[int]],
        stride: Union[int, Sequence[int]] = 1,
        rate: Union[int, Sequence[int]] = 1,
        padding: Union[
            str, Sequence[Tuple[int, int]], hk.pad.PadFn, Sequence[hk.pad.PadFn]
        ] = "SAME",
        with_bias: bool = True,
        w_init: Optional[hk.initializers.Initializer] = None,
        b_init: Optional[hk.initializers.Initializer] = None,
        data_format: str = "channels_last",
        mask: Optional[jnp.ndarray] = None,
        feature_group_count: int = 1,
        name: Optional[str] = None,
    ):
        """Initializes the module.

        Args:
          num_spatial_dims: The number of spatial dimensions of the input.
          output_channels: Number of output channels.
          kernel_shape: The shape of the kernel. Either an integer or a sequence of
            length ``num_spatial_dims``.
          stride: Optional stride for the kernel. Either an integer or a sequence of
            length ``num_spatial_dims``. Defaults to 1.
          rate: Optional kernel dilation rate. Either an integer or a sequence of
            length ``num_spatial_dims``. 1 corresponds to standard ND convolution,
            ``rate > 1`` corresponds to dilated convolution. Defaults to 1.
          padding: Optional padding algorithm. Either ``VALID`` or ``SAME``, a
            sequence of n ``(low, high)`` integer pairs that give the padding to
            apply before and after each spatial dimension, or a callable or sequence
            of callables of size ``num_spatial_dims``. Any callables must take a
            single integer argument equal to the effective kernel size and return a
            sequence of two integers representing the padding before and after. See
            ``haiku.pad.*`` for more details and example functions. Defaults to
            ``SAME``. See:
            https://www.tensorflow.org/xla/operation_semantics#conv_convolution.
          with_bias: Whether to add a bias. By default, true.
          w_init: Optional initializer for ``v`` (and the sample used to
            initialize ``g``). By default, truncated normal with stddev
            ``1 / sqrt(fan_in)``.
          b_init: Optional bias initialization. By default, zeros.
          data_format: The data format of the input.  Can be either
            ``channels_first``, ``channels_last``, ``N...C`` or ``NC...``. By
            default, ``channels_last``. See :func:`get_channel_index`.
          mask: Optional mask of the weights, multiplied into the normalized
            kernel. Must have the full kernel shape.
          feature_group_count: Optional number of groups in group convolution.
            Default value of 1 corresponds to normal dense convolution. If a higher
            value is used, convolutions are applied separately to that many groups,
            then stacked together. This reduces the number of parameters
            and possibly the compute for a given ``output_channels``. See:
            https://www.tensorflow.org/xla/operation_semantics#conv_convolution.
          name: The name of the module.
        """
        super().__init__(
            num_spatial_dims=num_spatial_dims,
            output_channels=output_channels,
            kernel_shape=kernel_shape,
            stride=stride,
            rate=rate,
            padding=padding,
            with_bias=with_bias,
            w_init=w_init,
            b_init=b_init,
            data_format=data_format,
            mask=mask,
            feature_group_count=feature_group_count,
            name=name,
        )

    def __call__(
        self,
        inputs: jnp.ndarray,
        *,
        precision: Optional[jax.lax.Precision] = None,
    ) -> jnp.ndarray:
        """Connects the weight-normalized ``ConvND`` layer.

        Args:
          inputs: An array of rank ``num_spatial_dims + 1`` if unbatched, or
            ``num_spatial_dims + 2`` if batched, laid out according to
            ``data_format`` (e.g. ``[N, spatial_dims, C]`` for channels-last).
          precision: Optional :class:`jax.lax.Precision` to pass to
            :func:`jax.lax.conv_general_dilated`.

        Returns:
          The convolution output, with the same rank and layout convention as
          ``inputs`` and ``output_channels`` channels.

        Raises:
          ValueError: If ``inputs`` has the wrong rank, if its channel count is
            not a multiple of ``feature_group_count``, or if ``mask`` does not
            match the kernel shape.
        """
        unbatched_rank = self.num_spatial_dims + 1
        allowed_ranks = [unbatched_rank, unbatched_rank + 1]
        if inputs.ndim not in allowed_ranks:
            raise ValueError(
                f"Input to ConvND needs to have rank in {allowed_ranks},"
                f" but input has shape {inputs.shape}."
            )

        unbatched = inputs.ndim == unbatched_rank
        if unbatched:
            inputs = jnp.expand_dims(inputs, axis=0)

        if inputs.shape[self.channel_index] % self.feature_group_count != 0:
            raise ValueError(
                f"Inputs channels {inputs.shape[self.channel_index]} "
                f"should be a multiple of feature_group_count "
                f"{self.feature_group_count}"
            )
        w_shape = self.kernel_shape + (
            inputs.shape[self.channel_index] // self.feature_group_count,
            self.output_channels,
        )

        if self.mask is not None and self.mask.shape != w_shape:
            raise ValueError(
                "Mask needs to have the same shape as weights. "
                f"Shapes are: {self.mask.shape}, {w_shape}"
            )

        w_init = self.w_init
        if w_init is None:
            fan_in_shape = np.prod(w_shape[:-1])
            stddev = 1.0 / np.sqrt(fan_in_shape)
            w_init = hk.initializers.TruncatedNormal(stddev=stddev)

        # Normalize over every kernel axis except the input-channel axis (-2).
        # Derived from the kernel rank so it is valid for any num_spatial_dims
        # (for 2D kernels this is (0, 1, 3)).
        in_channel_axis = len(w_shape) - 2
        reduce_axes = tuple(a for a in range(len(w_shape)) if a != in_channel_axis)

        def g_init(shape: Tuple[int, ...], dtype: jnp.dtype) -> jnp.ndarray:
            # Start each slice of w with the norm of a fresh w_init sample.
            # This sample is drawn independently of v's initial value.
            del shape  # Always (w_shape[-2],); computed from w_shape instead.
            w0 = w_init(w_shape, dtype)
            return jnp.sqrt(jnp.sum(jnp.square(w0), axis=reduce_axes))

        # One scale per kernel input-channel slice. w_shape[-2] equals
        # in_channels // feature_group_count, matching g_init's output even
        # when feature_group_count > 1.
        g = hk.get_parameter("g", (w_shape[-2],), jnp.float32, init=g_init)
        v = hk.get_parameter("v", w_shape, inputs.dtype, init=w_init)
        # True L2 norm (sqrt of the sum of squares), consistent with g_init.
        norm = jnp.sqrt(jnp.sum(jnp.square(v), axis=reduce_axes, keepdims=True))
        # g[:, None] has shape (C_in, 1) and broadcasts over the trailing
        # (C_in, C_out) kernel axes.
        w = g[:, None] * (v / norm)

        if self.mask is not None:
            w = w * self.mask

        out = jax.lax.conv_general_dilated(
            inputs,
            w,
            window_strides=self.stride,
            padding=self.padding,
            lhs_dilation=self.lhs_dilation,
            rhs_dilation=self.kernel_dilation,
            dimension_numbers=self.dimension_numbers,
            feature_group_count=self.feature_group_count,
            precision=precision,
        )

        if self.with_bias:
            if self.channel_index == -1:
                bias_shape = (self.output_channels,)
            else:
                bias_shape = (self.output_channels,) + (1,) * self.num_spatial_dims
            b = hk.get_parameter("b", bias_shape, inputs.dtype, init=self.b_init)
            b = jnp.broadcast_to(b, out.shape)
            out = out + b

        if unbatched:
            out = jnp.squeeze(out, axis=0)
        return out


class WNConv2D(WNConvND):
    """Two-dimensional convolution with weight-normalized kernels.

    Thin wrapper around ``WNConvND`` with ``num_spatial_dims=2``; see that
    class for how the kernel is reparameterized from ``g`` and ``v``.
    """

    def __init__(
        self,
        output_channels: int,
        kernel_shape: Union[int, Sequence[int]],
        stride: Union[int, Sequence[int]] = 1,
        rate: Union[int, Sequence[int]] = 1,
        padding: Union[
            str, Sequence[Tuple[int, int]], hk.pad.PadFn, Sequence[hk.pad.PadFn]
        ] = "SAME",
        with_bias: bool = True,
        w_init: Optional[hk.initializers.Initializer] = None,
        b_init: Optional[hk.initializers.Initializer] = None,
        data_format: str = "NHWC",
        mask: Optional[jnp.ndarray] = None,
        feature_group_count: int = 1,
        name: Optional[str] = None,
    ):
        """Initializes the module.

        Args:
          output_channels: Number of output channels.
          kernel_shape: The shape of the kernel. Either an integer or a sequence of
            length 2.
          stride: Optional stride for the kernel. Either an integer or a sequence of
            length 2. Defaults to 1.
          rate: Optional kernel dilation rate. Either an integer or a sequence of
            length 2. 1 corresponds to standard 2D convolution,
            ``rate > 1`` corresponds to dilated convolution. Defaults to 1.
          padding: Optional padding algorithm. Either ``VALID`` or ``SAME``, a
            sequence of 2 ``(low, high)`` integer pairs giving the padding before
            and after each spatial dimension, or a callable or sequence of
            callables of length 2. Any callables must take a single integer
            argument equal to the effective kernel size and return a list of two
            integers representing the padding before and after.
            See haiku.pad.* for more details and example functions.
            Defaults to ``SAME``. See:
            https://www.tensorflow.org/xla/operation_semantics#conv_convolution.
          with_bias: Whether to add a bias. By default, true.
          w_init: Optional initializer for the unnormalized kernel ``v``. By
            default, truncated normal.
          b_init: Optional bias initialization. By default, zeros.
          data_format: The data format of the input. Either ``NHWC`` or ``NCHW``. By
            default, ``NHWC``.
          mask: Optional mask of the weights, multiplied into the normalized
            kernel.
          feature_group_count: Optional number of groups in group convolution.
            Default value of 1 corresponds to normal dense convolution. If a higher
            value is used, convolutions are applied separately to that many groups,
            then stacked together. This reduces the number of parameters
            and possibly the compute for a given ``output_channels``. See:
            https://www.tensorflow.org/xla/operation_semantics#conv_convolution.
          name: The name of the module.
        """
        super().__init__(
            num_spatial_dims=2,
            output_channels=output_channels,
            kernel_shape=kernel_shape,
            stride=stride,
            rate=rate,
            padding=padding,
            with_bias=with_bias,
            w_init=w_init,
            b_init=b_init,
            data_format=data_format,
            mask=mask,
            feature_group_count=feature_group_count,
            name=name,
        )


class WNLinear(hk.Module):
    """Fully connected layer whose weight matrix is weight-normalized.

    The effective weight is ``w = g * v / ||v||_F``. Here ``v`` (parameter
    ``"v"``) is an unconstrained ``[input_size, output_size]`` matrix and ``g``
    (parameter ``"g"``, shape ``(1,)``, float32) is a single learned scale, so
    the whole matrix shares one Frobenius norm.
    """

    def __init__(
        self,
        output_size: int,
        with_bias: bool = True,
        w_init: Optional[hk.initializers.Initializer] = None,
        b_init: Optional[hk.initializers.Initializer] = None,
        name: Optional[str] = None,
    ):
        """Stores the layer configuration; parameters are created on first call.

        Args:
          output_size: Number of output features.
          with_bias: If true, a bias vector ``"b"`` is added to the output.
          w_init: Initializer for ``v`` (and for the sample that seeds ``g``).
            When omitted, a truncated normal with stddev ``1 / sqrt(fan_in)`` is
            used; see https://arxiv.org/abs/1502.03167v3.
          b_init: Initializer for the bias. When omitted (or falsy), zeros.
          name: Module name.
        """
        super().__init__(name=name)
        self.output_size = output_size
        self.with_bias = with_bias
        self.w_init = w_init
        self.b_init = b_init if b_init else jnp.zeros
        # Filled in with the trailing input dimension on the first call.
        self.input_size = None

    def __call__(
        self,
        inputs: jnp.ndarray,
        *,
        precision: Optional[jax.lax.Precision] = None,
    ) -> jnp.ndarray:
        """Returns ``inputs @ w`` (plus bias, if enabled) along the last axis.

        Raises:
          ValueError: If ``inputs`` is a scalar.
        """
        if not inputs.shape:
            raise ValueError("Input must not be scalar.")

        in_features = inputs.shape[-1]
        self.input_size = in_features
        weight_shape = [in_features, self.output_size]
        param_dtype = inputs.dtype

        weight_init = self.w_init
        if weight_init is None:
            weight_init = hk.initializers.TruncatedNormal(
                stddev=1.0 / np.sqrt(in_features)
            )

        def init_scale(shape: Tuple[int, ...], dtype: jnp.dtype) -> jnp.ndarray:
            # Seed g with the Frobenius norm of an independent weight sample.
            del shape
            sample = weight_init(weight_shape, dtype)  # type: ignore
            return jnp.linalg.norm(sample).reshape((1,))

        # Creation order (g, then v, then b) fixes the RNG key consumption.
        scale = hk.get_parameter("g", (1,), jnp.float32, init=init_scale)
        direction = hk.get_parameter("v", weight_shape, param_dtype, init=weight_init)
        weight = scale * direction / jnp.linalg.norm(direction)

        out = jnp.dot(inputs, weight, precision=precision)
        if not self.with_bias:
            return out

        bias = hk.get_parameter("b", [self.output_size], param_dtype, init=self.b_init)
        return out + jnp.broadcast_to(bias, out.shape)
