"""Common layers to use in different models """

import math
import jax
import numpy as np
import einops
from jax import numpy as jnp
from flax import linen as nn
from flax.linen.dtypes import promote_dtype
from .xPos import XPos

from jax.scipy.linalg import block_diag
from .ssm.ssm_init import make_DPLR_HiPPO
from .ssm.ssm import init_S5SSM


class RMSNorm(nn.Module):
    r"""RMSNorm layer: :math:`y_i = g_i \, a_i / \mathrm{rms}(a)` with
    :math:`\mathrm{rms}(a) = \sqrt{\frac{1}{n}\sum_{i=1}^n a_i^2}`.

    Thin wrapper around ``flax.linen.RMSNorm``; the learnable scale ``g`` lives
    in that submodule.

    Args:
        width: int - number of features of the input and output. Currently
            unused: ``nn.RMSNorm`` infers the feature size from its input.
        eps: float - small constant meant to avoid division by zero.
            NOTE(review): this value is NOT forwarded to ``nn.RMSNorm``, which
            therefore uses its own default epsilon. Forwarding it would change
            the numerics of existing models, so confirm before wiring it through.
        dtype: jnp.dtype - dtype used for computation (useful for mixed precision).
    """

    width: int
    eps: float = 1e-8
    dtype: jnp.dtype = jnp.float32

    @nn.compact
    def __call__(self, a) -> jax.Array:
        """Normalize ``a`` with ``nn.RMSNorm`` computing in ``self.dtype``."""
        return nn.RMSNorm(dtype=self.dtype)(a)


class Conv1D(nn.Module):
    # TODO: Not sure why this uses more memory and runs out of memory (OOM?).
    """A causal 1D temporal convolution layer (dense over channels).

    Input and output use the ``(batch, time, channels)`` ("NWC") layout. The
    kernel is stored as ``(kernel_size, out_channels, nchannels)`` ("WOI"). The
    time axis is left-padded by ``kernel_size - 1``, so the output has the same
    length as the input and step ``t`` only sees inputs at steps ``<= t``.

    Attributes:
      nchannels: Number of input channels (also the number of output channels
        when ``out_channels`` is None).
      out_channels: Number of output channels; defaults to ``nchannels``.
      kernel_size: The size of the temporal receptive field of the convolution.
      w_init_variance_scale: A parameter that scales the variance of the
        initialization of the weights.
      dtype: dtype used for computation.
    """

    nchannels: int
    out_channels: int = None
    kernel_size: int = 3
    w_init_variance_scale: float = 0.01
    dtype: jnp.dtype = jnp.float32

    @property
    def kernel_init(self) -> nn.initializers.Initializer:
        """Initializer for the ``(kernel_size, out_channels, nchannels)`` kernel.

        ``variance_scaling`` assumes a ``(..., in, out)`` layout by default. With
        those default axes, fan_in would be ``kernel_size * out_channels``, which
        is only right when ``out_channels == nchannels``. The input axis (-1) and
        output axis (-2) of the "WOI" kernel are therefore passed explicitly,
        giving ``fan_in = kernel_size * nchannels``.
        """
        return nn.initializers.variance_scaling(
            scale=self.w_init_variance_scale,
            mode="fan_in",
            distribution="normal",
            in_axis=-1,
            out_axis=-2,
        )

    @nn.compact
    def __call__(self, x):
        """Apply the causal convolution.

        Args:
          x: input laid out as ``(batch, time, nchannels)`` ("NWC").

        Returns:
          Array laid out as ``(batch, time, out_channels)`` in ``self.dtype``.
        """
        out_channels = (
            self.nchannels if self.out_channels is None else self.out_channels
        )
        # Kernel layout "WOI": (width, out_channels, in_channels).
        conv_shape = (self.kernel_size, out_channels, self.nchannels)
        dimension_numbers = jax.lax.conv_dimension_numbers(
            x.shape, conv_shape, ("NWC", "WOI", "NWC")
        )
        conv1d_w = self.param(
            "conv1d_w",
            self.kernel_init,
            conv_shape,
        )
        conv1d_b = self.param(
            "conv1d_b",
            nn.initializers.zeros_init(),
            (out_channels,),
        )
        x, b, w = promote_dtype(x, conv1d_b, conv1d_w, dtype=self.dtype)
        # Left-pad only (causal); with stride 1 the output length equals T.
        res = jax.lax.conv_general_dilated(
            x,
            w,
            padding=((self.kernel_size - 1, 0),),
            window_strides=(1,),
            dimension_numbers=dimension_numbers,
        )
        # Broadcast the per-channel bias over the leading (batch, time) axes.
        return res + b.reshape((1,) * (res.ndim - b.ndim) + b.shape)


class DepthConv1D(nn.Module):
    """Causal depthwise 1D temporal convolution.

    The input is laid out as ``(batch, time, channels)`` ("NWC"). The kernel has
    shape ``(kernel_size, out_channels, 1)`` ("WOI"), and the convolution is
    grouped with ``feature_group_count`` equal to the input channel count, so
    each input channel is only convolved with its own filter(s). The time axis
    is left-padded by ``kernel_size - 1``: the output keeps the input length and
    step ``t`` only sees steps ``<= t``. ``__call__`` is wrapped in
    ``nn.checkpoint``, so its activations are recomputed in the backward pass.

    Attributes:
      nchannels: Number of input channels (and output channels when
        ``out_channels`` is None).
      out_channels: Number of output channels; defaults to ``nchannels``.
      kernel_size: The size of the temporal receptive field of the convolution.
      w_init_variance_scale: A parameter that scales the variance of the
        initialization of the weights.
      dtype: dtype used for computation.
    """

    nchannels: int
    out_channels: int = None
    kernel_size: int = 3
    w_init_variance_scale: float = 0.01
    dtype: jnp.dtype = jnp.float32

    @property
    def kernel_init(self) -> nn.initializers.Initializer:
        """Normal fan-in variance-scaling initializer for the depthwise kernel.

        NOTE(review): with the default ``in_axis=-2`` the fan-in for this
        ``(W, O, 1)`` kernel is ``kernel_size * out_channels`` rather than the
        per-group ``kernel_size``; confirm this scaling is intended.
        """
        return nn.initializers.variance_scaling(
            scale=self.w_init_variance_scale, mode="fan_in", distribution="normal"
        )

    def setup(self):
        n_out = self.out_channels if self.out_channels is not None else self.nchannels
        # One input channel per group -> the "I" dimension of the kernel is 1.
        self.conv_shape = (self.kernel_size, n_out, 1)
        self.depthwise_w = self.param(
            "depthwise_w", self.kernel_init, self.conv_shape, jnp.float32
        )
        self.depthwise_b = self.param(
            "depthwise_b", nn.initializers.zeros_init(), (n_out,), jnp.float32
        )

    @nn.checkpoint
    def __call__(self, x):
        """Convolve ``x`` (``(batch, time, channels)``) along time, per channel."""
        groups = x.shape[-1]
        dim_nums = jax.lax.conv_dimension_numbers(
            x.shape, self.conv_shape, ("NWC", "WOI", "NWC")
        )
        x, bias, kernel = nn.dtypes.promote_dtype(
            x, self.depthwise_b, self.depthwise_w, dtype=self.dtype
        )
        causal_padding = ((self.kernel_size - 1, 0),)
        out = jax.lax.conv_general_dilated(
            x,
            kernel,
            window_strides=(1,),
            padding=causal_padding,
            dimension_numbers=dim_nums,
            feature_group_count=groups,
            preferred_element_type=self.dtype,
            precision="default",
        )
        # Bias is per output channel; prepend singleton axes to broadcast.
        bias = bias.reshape((1,) * (out.ndim - bias.ndim) + bias.shape)
        return out + bias


# class DepthConv1D(nn.Module):
#     """A 1D temporal convolution layer.

#     Attributes:
#       nchannels: The number of dimensions of the input and output.
#       kernel_size: The size of the temporal receptive field of the convolution.
#       w_init_variance_scale: A parameter that scales the variance of the
#         initialization of the weights.
#       dtype: dtype used for computation.
#     """

#     nchannels: int
#     out_channels: int = None
#     kernel_size: int = 3
#     w_init_variance_scale: float = 0.01
#     dtype: jnp.dtype = jnp.float32

#     @property
#     def kernel_init(self) -> nn.initializers.Initializer:
#         """Initializer for the kernel of the Conv1D."""
#         return nn.initializers.variance_scaling(
#             scale=self.w_init_variance_scale,
#             mode="fan_in",
#             distribution="normal",
#         )

#     def setup(self):
#         out_channels = (
#             self.nchannels if self.out_channels is None else self.out_channels
#         )
#         self.conv_shape = (self.kernel_size, out_channels, 1)
#         self.w = self.param(
#             "deeptwise_w", self.kernel_init, self.conv_shape, jnp.float32
#         )
#         self.b = self.param(
#             "deeptwise_b", nn.initializers.zeros_init(), (out_channels,), jnp.float32
#         )

#     @nn.checkpoint
#     def __call__(self, x):
#         # Parameters.
#         # default: input: 'NCHW', parameters: 'OIHW', output: 'NCHW'

#         inp_shape = x.shape
#         dimension_numbers = jax.lax.conv_dimension_numbers(
#             inp_shape, self.conv_shape, ("NWC", "WOI", "NWC")
#         )

#         x, b, w = nn.dtypes.promote_dtype(x, self.b, self.w, dtype=self.dtype)
#         # feature_group_count=self.nchannels,
#         # padding = ((self.kernel_size - 1, 0),)
#         res = jax.lax.conv_general_dilated(
#             x,
#             w,
#             padding=((self.kernel_size - 1, 0),),
#             window_strides=(1,),
#             dimension_numbers=dimension_numbers,
#             feature_group_count=inp_shape[-1],
#         )
#         b = b.reshape((1,) * (res.ndim - b.ndim) + b.shape)
#         res += b
#         return res


class S5Layer(nn.Module):
    """S5 state-space sequence layer with a block-diagonal HiPPO initialization.

    The state of size ``ssm_size`` is split into ``blocks`` equal blocks. Each
    block reuses the eigendecomposition returned by
    ``make_DPLR_HiPPO(block_size)``, and the eigenvector matrices are assembled
    block-diagonally. The SSM built by ``init_S5SSM`` is ``nn.vmap``-ed over the
    leading axis of the input with shared (unsplit) parameters.

    Attributes:
        ssm_size: total state size ``P``; must be a positive multiple of ``blocks``.
        hidden_dim: feature size ``H`` passed to ``init_S5SSM``.
        blocks: number of diagonal blocks.
        dtype: dtype the output is cast to.
    """

    ssm_size: int = 128
    hidden_dim: int = 128
    blocks: int = 1
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        # With a remainder, the block-diagonal V/Vinv (block_size * blocks) would
        # not match P=ssm_size passed to init_S5SSM; fail early and clearly.
        if self.blocks <= 0 or self.ssm_size % self.blocks != 0:
            raise ValueError(
                f"ssm_size ({self.ssm_size}) must be a positive multiple of "
                f"blocks ({self.blocks})"
            )
        self.block_size = self.ssm_size // self.blocks
        # Initialize state matrix A using approximation to HiPPO-LegS matrix
        Lambda, _, _, V, _ = make_DPLR_HiPPO(self.block_size)
        Lambda = Lambda[: self.block_size]
        V = V[:, : self.block_size]
        # Conjugate transpose is used as V^{-1} (presumably V is unitary as
        # returned by make_DPLR_HiPPO — confirm).
        Vc = V.conj().T
        # Repeat the per-block eigenvalues for every block -> length ssm_size.
        Lambda = (Lambda * np.ones((self.blocks, self.block_size))).ravel()

        V = block_diag(*([V] * self.blocks))
        Vinv = block_diag(*([Vc] * self.blocks))

        ssm_init_fn = init_S5SSM(
            H=self.hidden_dim,
            P=self.ssm_size,
            Lambda_re_init=Lambda.real,
            Lambda_im_init=Lambda.imag,
            V=V,
            Vinv=Vinv,
            C_init="trunc_standard_normal",
            discretization="zoh",
            dt_min=0.001,
            dt_max=0.1,
            conj_sym=False,
            clip_eigs=False,
            bidirectional=False,
        )
        # Map over the leading input axis (presumably batch); params are shared.
        ssm = nn.vmap(
            ssm_init_fn,
            in_axes=0,
            out_axes=0,
            variable_axes={"params": None},
            split_rngs={"params": False},
        )
        self.ssm = ssm(step_rescale=1)

    def __call__(self, X):
        """Apply the (vmapped) SSM to ``X`` and cast the result to ``self.dtype``."""
        return self.ssm(X).astype(self.dtype)


def create_rope(n_pos, d_model):
    """Build the sinusoid table used for rotary embeddings.

    Row ``p`` holds ``sin(p * w_i)`` in its first ``d_model // 2`` columns and
    ``cos(p * w_i)`` in its last ``d_model // 2`` columns, where
    ``w_i = 10000 ** (-2i / d_model)`` for ``i = 0 .. d_model // 2 - 1``.

    Args:
        n_pos: int = max number of positional embeddings
        d_model: int = hidden dim of the embedding/model; must be even

    Returns:
        jax.Array of shape ``(n_pos, d_model)``.

    Raises:
        ValueError: if ``d_model`` is odd (the sin/cos halves cannot be formed).
    """
    if d_model % 2:
        raise ValueError(f"d_model must be even for rotary embeddings, got {d_model}")
    sub_space = jnp.arange(start=0, stop=d_model, step=2)  # 2i
    inv_freq = jnp.exp(sub_space * (-math.log(10000.0) / d_model))
    # Outer product: (n_pos, 1) * (1, d_model // 2) -> (n_pos, d_model // 2).
    angles = jnp.arange(n_pos)[:, None] * inv_freq[None, :]
    return jnp.concatenate([jnp.sin(angles), jnp.cos(angles)], axis=-1)


class RopeEmbeds(nn.Module):
    """Rotary position embeddings (RoPE) in the "rotate-half" layout.

    Feature ``j`` and feature ``j + d_model // 2`` form a 2-D pair that is
    rotated by angle ``p * w_j`` at position ``p``. This matches the table built
    by ``create_rope``, whose sin/cos halves pair exactly those features.

    Attributes:
        n_pos: max number of positions supported.
        d_model: size of the last (feature) axis of the inputs; must be even.
        dtype: dtype of the precomputed sin/cos table.
    """

    n_pos: int
    d_model: int
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        # (n_pos, d_model): sin values in the first half, cos in the second half.
        self.rel_pos = create_rope(self.n_pos, self.d_model).astype(self.dtype)

    def _sin_cos(self, length):
        """Return ``(sin, cos)`` tables of shape ``(length, d_model)``.

        Each half-width table is tiled twice along the feature axis, so features
        ``j`` and ``j + d_model // 2`` share the same angle.
        """
        rel_pos = self.rel_pos[:length, :]
        sin, cos = jnp.split(rel_pos, indices_or_sections=2, axis=-1)
        return jnp.tile(sin, reps=(1, 2)), jnp.tile(cos, reps=(1, 2))

    def _rotate_half(self, mat, neg=False):
        """Map halves ``(x1, x2)`` of the last axis to ``(-x2, x1)``.

        With ``neg=True`` the result is ``(x2, -x1)`` (the inverse rotation).
        The halves must match the half-split layout of ``_sin_cos``. Pairing
        even/odd features instead would not be a rotation and would break RoPE's
        relative-position property.
        """
        half = mat.shape[-1] // 2
        x1, x2 = mat[..., :half], mat[..., half:]
        if neg:
            return jnp.concatenate([x2, -x1], axis=-1)
        return jnp.concatenate([-x2, x1], axis=-1)

    def __call__(self, mat):
        """Rotate ``mat`` laid out as ``(B, H, T, D)``; positions index axis 2."""
        sin, cos = self._sin_cos(mat.shape[2])
        # rotated = cos * mat + sin * rotate_half(mat), tables broadcast over B, H
        return jnp.einsum("BHTD,TD->BHTD", mat, cos) + jnp.einsum(
            "BHTD,TD->BHTD", self._rotate_half(mat), sin
        )

    def apply_vapor(self, mat, neg=False):
        """Rotate ``mat`` laid out as ``(T, B, H, D)``; positions index axis 0.

        Args:
            mat: array with time on axis 0 and features on the last axis.
            neg: if True, rotate by the negative angle (inverse rotation).
        """
        sin, cos = self._sin_cos(mat.shape[0])
        return jnp.einsum("TBHD,TD->TBHD", mat, cos) + jnp.einsum(
            "TBHD,TD->TBHD", self._rotate_half(mat, neg=neg), sin
        )


class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding added to input embeddings.

    Even feature columns hold ``sin(p * w_i)`` and odd columns hold
    ``cos(p * w_i)``, with ``w_i = exp(2i * -ln(10000) / d_model)``.

    Attributes:
        d_model: hidden dimensionality of the input.
        max_len: maximum sequence length to precompute.
        dtype: dtype of the stored table.
    """

    d_model: int
    max_len: int = 2048
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        """Precompute the ``(1, max_len, d_model)`` encoding table."""
        freqs = jnp.exp(
            jnp.arange(0, self.d_model, 2, dtype=jnp.float32)
            * (-math.log(10000.0) / self.d_model)
        )
        angles = jnp.arange(0, self.max_len, dtype=jnp.float32)[..., None] * freqs
        table = (
            jnp.zeros((self.max_len, self.d_model))
            .at[:, 0::2]
            .set(jnp.sin(angles))
            .at[:, 1::2]
            .set(jnp.cos(angles))
        )
        # Leading singleton axis broadcasts over the batch.
        self.pe = jnp.expand_dims(table, 0).astype(self.dtype)

    def __call__(
        self,
        embeds: jnp.array,
        do_inference: bool = False,
        time_pos: int = None,
    ) -> jnp.array:
        """Add positional encodings to ``embeds``.

        Args:
            embeds: jnp.array(BTD) - embedding vectors
            do_inference: bool - if True, add the single encoding at ``time_pos``
                (sequential decoding) instead of the first ``T`` encodings
            time_pos: int - position in the sentence for sequential inference
        """
        if not do_inference:
            seq_len = embeds.shape[1]
            return self.pe[:, :seq_len] + embeds
        return self.pe[:, time_pos] + embeds


class Embedder(nn.Module):
    """Token embedding table shared by input encoding and output decoding.

    Attributes:
      vocab_size: The size of the token vocabulary.
      embed_dim: The dimensionality of each token embedding.
      scale_by_sqrt_dim: Whether ``encode`` multiplies its output by
        ``sqrt(self.embed_dim)`` (the factor is rounded to bfloat16).
      dtype: dtype used for computation.
    """

    vocab_size: int
    embed_dim: int
    scale_by_sqrt_dim: bool
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        table_init = nn.initializers.variance_scaling(
            scale=1.0,
            mode="fan_in",
            distribution="normal",
            in_axis=1,
            out_axis=0,
        )
        # Parameter "embedding": one row of size embed_dim per token id.
        self.input_embedding_table = self.param(
            "embedding", table_init, (self.vocab_size, self.embed_dim)
        )

    def encode(self, x):
        """Encodes an input sequence of tokens by looking up their table rows."""
        (embedded,) = promote_dtype(self.input_embedding_table[(x,)], dtype=self.dtype)
        if not self.scale_by_sqrt_dim:
            return embedded
        # Cast to bfloat16 to match training.
        return embedded * jnp.sqrt(self.embed_dim).astype(jnp.bfloat16)

    def decode(self, x):
        """Decodes activations by projecting them onto the embedding table."""
        activations, table = promote_dtype(
            x, self.input_embedding_table, dtype=self.dtype
        )
        return activations @ table.T


class GatedMLP(nn.Module):
    """Gated MLP computing ``down(silu(gate(x)) * up(x))``.

    All three projections are bias-free ``nn.Dense`` layers whose kernels are
    drawn from ``normal(stddev=initializer_range)``.

    Attributes:
      hidden_dim: output feature size (features of the ``down`` projection).
      intermediate_dim: feature size of the ``gate`` and ``up`` projections.
      initializer_range: stddev of the normal kernel initializer.
      final_w_init_variance_scale: scale used by ``out_kernel_init``.
        NOTE(review): ``out_kernel_init`` is not used by ``__call__``; the
        ``down`` projection uses the same normal initializer as the others.
      dtype: dtype used for computation.
    """

    hidden_dim: int
    intermediate_dim: int
    initializer_range: float = 0.02
    final_w_init_variance_scale: float = 1.0
    dtype: jnp.dtype = jnp.float32

    @property
    def out_kernel_init(self) -> nn.initializers.Initializer:
        """Fan-in normal initializer scaled by ``final_w_init_variance_scale``."""
        return nn.initializers.variance_scaling(
            scale=self.final_w_init_variance_scale,
            mode="fan_in",
            distribution="normal",
        )

    @nn.compact
    def __call__(self, x):
        def dense(features):
            # Bias-free projection sharing the normal kernel initializer.
            return nn.Dense(
                features=features,
                use_bias=False,
                kernel_init=jax.nn.initializers.normal(stddev=self.initializer_range),
                dtype=self.dtype,
            )

        # Keep creation order gate -> up -> down: auto-generated submodule names
        # (and thus the parameter tree) presumably follow it.
        gate_proj = dense(self.intermediate_dim)
        up_proj = dense(self.intermediate_dim)
        down_proj = dense(self.hidden_dim)

        gated = jax.nn.silu(gate_proj(x)) * up_proj(x)
        return down_proj(gated)


class SlidingWindowAtt:
    """Sliding-window (local) attention applied to precomputed Q, K, V.

    Not a proper flax module: it holds no parameters and only combines the
    given Q, K, V. The (auto-padded) sequence is cut into windows of
    ``window_size`` tokens. Each query window attends to its own window plus the
    previous one (and the next one in the bidirectional case). The cost is
    therefore O(T * 2W) (O(T * 3W) bidirectional); a dedicated kernel would be
    needed for O(T * W). Pass ``exact_windowsize=True`` to mask the keys of the
    neighbouring windows that are more than ``window_size`` positions away from
    the query, since the look-around otherwise exposes up to 2W keys per side.
    """

    def __init__(self, window_size: int, exact_windowsize: bool = True, causal=True):
        """
        Args:
            window_size: attention window. In the bidirectional case it is
                halved (``window_size // 2``) and applied on each side of the query.
            exact_windowsize: restrict attention to exactly the window (see class doc).
            causal: if True, queries only attend to current and past positions.

        Raises:
            ValueError: if the effective window size is smaller than 1 (e.g.
                ``window_size=1`` with ``causal=False``), which would otherwise
                fail later with a division by zero.
        """
        self.exact_windowsize = exact_windowsize
        self.causal = causal
        # bidirectional case: half the window looks back, half looks ahead
        self.window_size = window_size if causal else window_size // 2
        if self.window_size < 1:
            raise ValueError(
                f"effective window size must be >= 1, got {self.window_size} "
                f"(window_size={window_size}, causal={causal})"
            )

    @staticmethod
    def look_around(x, backward=1, forward=0, pad_value=-1, dim=2):
        """Concatenate every window with its ``backward`` previous and ``forward``
        next windows.

        Axis 1 of ``x`` indexes windows. It is padded with ``pad_value`` so the
        first/last windows get (padded) neighbours, and the shifted copies are
        concatenated along ``dim``.
        """
        n_windows = x.shape[1]
        pad_width = [(0, 0)] * x.ndim
        pad_width[1] = (backward, forward)
        padded_x = jnp.pad(x, pad_width=pad_width, constant_values=pad_value)
        shifted = [
            padded_x[:, i : (i + n_windows), ...]
            for i in range(backward + forward + 1)
        ]
        return jnp.concatenate(shifted, axis=dim)

    @staticmethod
    def pad_to_multiple(x, multiple, dim=-1, value=0):
        """Right-pad axis ``dim`` of ``x`` with ``value`` up to a multiple of ``multiple``.

        Returns:
            ``(padded, x)``, where ``padded`` is False and ``x`` is returned
            unchanged when the length is already a multiple.
        """
        remainder = -x.shape[dim] % multiple
        if remainder == 0:
            return False, x
        pad_width = [(0, 0)] * x.ndim
        pad_width[dim] = (0, remainder)
        return True, jnp.pad(x, pad_width=pad_width, constant_values=value)

    def __call__(self, Q, K, V, input_mask, attn_dropout, rot_embeds=None):
        """
        Args:
            Q,K,V: jax.Array(B,H,T,D)
            input_mask: jax.Array(BT) or None - 1 = attend, 0 = ignore. Mostly
                useful for the bidirectional case.
            attn_dropout: callable applied to the attention probabilities.
            rot_embeds: optional rotary embedding applied to the windowed queries
                and keys. ``XPos`` instances are called with ``offset`` and
                ``downscale`` keywords; anything else as ``rot_embeds(x)``.
        Returns:
            jax.Array(B,H,T,D)

        Raises:
            ValueError: if the packed batch*heads is not a multiple of
                ``input_mask.shape[0]``.
        """
        head_dim = Q.shape[-1]
        pad_value = -1
        mask_value = -9e15
        forward = 0 if self.causal else 1

        # merge batch and heads for ease: (B, H, T, D) -> (B*H, T, D)
        (Q, packed_shape), (K, _), (V, _) = map(
            lambda t: einops.pack([t], "* n d"), (Q, K, V)
        )
        # autopadding so the sequence length is divisible by the window size; the
        # padded positions are masked out as keys and dropped before returning
        orig_seq_len = Q.shape[1]
        (_, Q), (_, K), (_, V) = map(
            lambda t: self.pad_to_multiple(
                t, self.window_size, dim=-2, value=pad_value
            ),
            (Q, K, V),
        )
        batch_heads, T, _ = Q.shape
        windows = T // self.window_size

        bq, bk, bv = map(
            lambda t: einops.rearrange(t, "b (w n) d -> b w n d", w=windows), (Q, K, V)
        )
        bq = bq * (head_dim**-0.5)  # attention scale sqrt(1/dim_head)
        # concatenate the previous (and, bidirectionally, next) window so the
        # first token of each window still has a full window of context
        bk = self.look_around(bk, backward=1, forward=forward, pad_value=pad_value)
        bv = self.look_around(bv, backward=1, forward=forward, pad_value=pad_value)

        # apply rotary positions to bq, bk.
        # NOTE(review): positions seen by rot_embeds are presumably window-local
        # (queries 0..n-1, keys starting one window earlier), i.e. a constant
        # offset between matching tokens — confirm this is intended.
        if rot_embeds is not None:
            if isinstance(rot_embeds, XPos):
                bk = rot_embeds(bk, offset=0, downscale=True)
                bq = rot_embeds(bq, offset=0, downscale=False)
            else:
                bk = rot_embeds(bk)
                bq = rot_embeds(bq)
        sim = einops.einsum(bq, bk, "b h i e, b h j e -> b h i j")

        # token position of every query row / key column; key columns that come
        # from look_around padding hold pad_value
        seq = jnp.arange(T)
        b_t = einops.rearrange(seq, "(w n) -> 1 w n", w=windows, n=self.window_size)
        bq_t = einops.rearrange(b_t, "... i -> ... i 1")
        bq_k = self.look_around(b_t, backward=1, forward=forward, pad_value=pad_value)
        bq_k = einops.rearrange(bq_k, "... j -> ... 1 j")
        # Mask look-around padding AND the autopadded tail (positions >=
        # orig_seq_len). Without the latter, bidirectional attention without an
        # input_mask would attend to the pad_value-filled keys/values.
        pad_mask = (bq_k == pad_value) | (bq_k >= orig_seq_len)

        if self.causal:
            causal_mask = bq_t < bq_k
            if self.exact_windowsize:
                causal_mask = causal_mask | (bq_t > (bq_k + self.window_size))
            sim = jnp.where(causal_mask, mask_value, sim)

        # bidirectional case
        if not self.causal and self.exact_windowsize:
            window_mask = ((bq_k - self.window_size) > bq_t) | (
                bq_t > (bq_k + self.window_size)
            )
            sim = jnp.where(window_mask, mask_value, sim)
        # everything has a pad mask
        sim = jnp.where(pad_mask, mask_value, sim)

        if input_mask is not None:
            if batch_heads % input_mask.shape[0] != 0:
                raise ValueError(
                    f"batch*heads ({batch_heads}) must be a multiple of the "
                    f"input_mask batch size ({input_mask.shape[0]})"
                )
            h = batch_heads // input_mask.shape[0]
            _, input_mask = self.pad_to_multiple(
                input_mask, self.window_size, dim=-1, value=False
            )
            input_mask = einops.rearrange(
                input_mask, "... (w n) -> (...) w n", w=windows, n=self.window_size
            )
            input_mask = self.look_around(
                input_mask, backward=1, forward=forward, pad_value=False
            )
            input_mask = einops.rearrange(input_mask, "... j -> ... 1 j")
            # repeat per head to match the packed (b h) leading axis
            input_mask = einops.repeat(input_mask, "b ... -> (b h) ...", h=h)
            sim = jnp.where(input_mask, sim, mask_value)

        attn = jax.nn.softmax(sim, axis=-1)
        attn = attn_dropout(attn)
        out = einops.einsum(attn, bv, "b h i j, b h j e -> b h i e")
        out = einops.rearrange(out, "b w n d -> b (w n) d")

        # drop autopadded positions and restore (B, H, T, D)
        out = out[:, :orig_seq_len, :]
        out, *_ = einops.unpack(out, packed_shape, "* n d")
        return out
