import functools
import logging
from collections.abc import Callable, Sequence
from typing import TypeVar

import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
from jaxtyping import Array, ArrayLike, Float

from neural_pfaffian.kfac_blocks import register_repeated_dense
from neural_pfaffian.utils.jax_utils import vectorize


def residual(x: Array, y: Array):
    """Combine ``x`` and ``y`` as a residual connection when their shapes agree.

    Args:
        x: Input of the wrapped layer (the skip branch).
        y: Output of the wrapped layer.

    Returns:
        ``(x + y) / sqrt(2)`` if both arrays have the same shape; otherwise
        ``y`` is returned untouched, i.e. no skip connection is applied.
    """
    if x.shape != y.shape:
        # Shapes differ, so the skip branch cannot be added.
        return y
    total = x + y
    return total / jnp.sqrt(2)


def log1p_rescale(x: Float[Array, '... 3+1']):
    """Rescale ``x`` by ``log1p(s) / s`` where ``s`` is its last component.

    Every entry along the final axis (including the last one itself) is
    divided by the last entry and multiplied by ``log1p`` of that entry.

    Args:
        x: Array whose trailing axis ends with the scale entry; the
            annotation suggests three coordinates followed by their norm
            (NOTE(review): confirm with callers).

    Returns:
        Array of the same shape as ``x``.
    """
    # Slice with ``-1:`` to keep the axis so the scale broadcasts against ``x``.
    scale = x[..., -1:]
    normalized = x / scale
    return normalized * jnp.log1p(scale)


T = TypeVar('T', bound=Array)

ActivationOrName = Callable[[T], T]


class Activation(nn.Module):
    """Apply an activation given either as a callable or by name.

    Attributes:
        activation: A callable applied directly to the input, or otherwise a
            value used as an attribute name that is looked up on
            ``flax.linen``.
    """

    activation: ActivationOrName

    def __call__(self, x: Float[Array, ' ...']) -> Float[Array, ' ...']:
        """Return the activation applied to ``x``."""
        fn = self.activation
        if not callable(fn):
            # Not a callable: treat it as the name of a ``flax.linen`` function.
            fn = getattr(nn, fn)
        return fn(x)


class Dense(nn.Module):
    """Linear layer ``x @ kernel (+ bias)`` with float32 parameters.

    The kernel uses a truncated-normal, fan-in variance-scaling initializer
    with scale ``out_std**2``; the bias (if any) starts at zero. The result is
    passed through ``register_repeated_dense`` before being returned
    (presumably to tag the layer for the KFAC optimizer; see
    ``neural_pfaffian.kfac_blocks``).

    Attributes:
        dim: Size of the output feature axis.
        use_bias: Whether to create and add a bias vector.
        out_std: Scale factor; its square is the variance-scaling ``scale``.
    """

    dim: int
    use_bias: bool = True
    out_std: ArrayLike = 1.0

    @nn.compact
    def __call__(self, x: jax.Array) -> jax.Array:
        """Apply the layer to the last axis of ``x``."""
        in_features = x.shape[-1]
        weight_init = jax.nn.initializers.variance_scaling(
            self.out_std**2,
            'fan_in',
            'truncated_normal',
            in_axis=-2,
            out_axis=-1,
            batch_axis=(),
            dtype=jnp.float32,
        )
        kernel = self.param('kernel', weight_init, (in_features, self.dim), jnp.float32)

        bias = None
        if self.use_bias:
            bias = self.param(
                'bias',
                jax.nn.initializers.zeros,
                (self.dim,),
                jnp.float32,
            )

        y = x @ kernel
        if bias is not None:
            y = y + bias
        # ``bias`` is forwarded as ``None`` when the layer has no bias.
        return register_repeated_dense(y, x, kernel, bias)


class GatedLinearUnit(nn.Module):
    """Gated linear unit with an optionally multi-dimensional, chunked output.

    The input is (optionally) layer-normalized, then
    ``hidden = activation(Dense(x)) * Dense(x)`` is computed with two bias-free
    dense layers and finally projected to ``dim`` without a bias.

    Attributes:
        dim: Output size. An ``int`` gives a single output feature axis; a
            tuple gives the shape of the trailing output axes.
        activation: Activation used for the gate branch (callable or name).
        hidden_dim: Width of the gated hidden layer. Defaults to ``dim``,
            which then has to be an ``int``.
        normalize: Whether to apply ``nn.LayerNorm`` to the input first.
        out_std: Scale for the output projection; its square is used as the
            variance-scaling ``scale`` of the output kernel(s).
        chunk_axis: Axis or axes of ``dim`` (negative values allowed) whose
            entries ("chunks") are produced by separate kernel parameters
            ``kernel_0``, ``kernel_1``, ... instead of one large kernel.
            Requires ``dim`` to be a tuple.
        max_chunk_size: If given, consecutive chunks are packed into one
            kernel as long as the kernel has at most ``max_chunk_size`` output
            columns (at least one chunk per kernel). If ``None`` every chunk
            gets its own kernel.
    """

    dim: int | tuple[int, ...]
    activation: ActivationOrName
    hidden_dim: int | None = None
    normalize: bool = True
    out_std: ArrayLike = 1.0
    chunk_axis: int | Sequence[int] | None = None
    max_chunk_size: int | None = None

    @nn.compact
    def __call__(self, x: jax.Array) -> jax.Array:
        """Apply the gated unit to the last axis of ``x``.

        Args:
            x: Input array; the last axis is the feature axis.

        Returns:
            Array of shape ``(*x.shape[:-1], dim)`` for an integer ``dim`` or
            ``(*x.shape[:-1], *dim)`` for a tuple ``dim``.

        Raises:
            ValueError: If ``chunk_axis`` is set together with an integer
                ``dim``, if an axis is listed twice, or if ``max_chunk_size``
                is not positive.
            TypeError: If ``chunk_axis`` is neither an int, a sequence nor
                ``None``.
        """
        if self.chunk_axis is not None and isinstance(self.dim, int):
            raise ValueError(f'Cannot chunk over axis {self.chunk_axis} with 1D output.')

        if self.hidden_dim is None:
            assert isinstance(self.dim, int)
            hidden_dim = self.dim
        else:
            hidden_dim = self.hidden_dim
        if self.normalize:
            x = nn.LayerNorm()(x)
        # Gate branch (activated) times value branch; the construction order of
        # the two Dense layers determines their auto-generated parameter names.
        hidden = Activation(self.activation)(
            Dense(hidden_dim, use_bias=False)(x),
        ) * Dense(hidden_dim, use_bias=False)(x)

        OutDense = functools.partial(
            Dense,
            use_bias=False,
            out_std=self.out_std,
        )

        if isinstance(self.dim, int):
            return OutDense(self.dim)(hidden)

        output_shape = tuple(self.dim)

        def project_unchunked() -> jax.Array:
            # One dense layer over the flattened output shape, then unflatten.
            flat = OutDense(int(np.prod(output_shape)))(hidden)
            return flat.reshape(*hidden.shape[:-1], *output_shape)

        if self.chunk_axis is None:
            return project_unchunked()

        chunk_axes = self.chunk_axis
        if isinstance(chunk_axes, int):
            chunk_axes = (chunk_axes,)
        elif isinstance(chunk_axes, Sequence):
            chunk_axes = tuple(chunk_axes)
        else:
            raise TypeError(
                f'chunk_axis must be an int, sequence of ints, or None, got {type(chunk_axes)}.',
            )

        # Convert negative axes to positive (and check for duplicates).
        ndim = len(output_shape)
        normalized_axes: list[int] = []
        for axis in chunk_axes:
            axis = int(axis) % ndim
            if axis in normalized_axes:
                raise ValueError(f'Duplicate chunk axis {axis}.')
            normalized_axes.append(axis)
        chunk_axes = tuple(normalized_axes)

        # In case chunk_axes is empty, we revert back to the unchunked version.
        if not chunk_axes:
            return project_unchunked()

        unchunked_axes = tuple(ax for ax in range(ndim) if ax not in chunk_axes)
        unchunked_shape = tuple(output_shape[ax] for ax in unchunked_axes)
        chunk_shape = tuple(output_shape[ax] for ax in chunk_axes)
        unchunk_prod = int(np.prod(unchunked_shape)) if unchunked_shape else 1
        num_chunks = int(np.prod(chunk_shape)) if chunk_shape else 1

        # Determine how many chunks we can fit in a single kernel based on max_chunk_size.
        chunks_per_kernel = 1
        if self.max_chunk_size is not None:
            max_chunk_size = int(self.max_chunk_size)
            if max_chunk_size <= 0:
                raise ValueError('max_chunk_size must be positive when provided.')
            if max_chunk_size < unchunk_prod:
                logging.warning(
                    'max_chunk_size is less than the unchunked output size.'
                    ' Ignoring the max_chunk_size; consider adding a chunk_axis to param'
                    ' %s.',
                    self.path,
                )
            # Floors to 0 when a single chunk already exceeds the limit, hence
            # the lower bound of one chunk per kernel.
            chunks_per_kernel = max(1, max_chunk_size // unchunk_prod)

        kernel_init = jax.nn.initializers.variance_scaling(
            self.out_std**2,
            'fan_in',
            'truncated_normal',
            in_axis=-2,
            out_axis=-1,
            batch_axis=(),
            dtype=jnp.float32,
        )

        # Spread the chunks across the kernels; the last kernel may hold fewer
        # chunks when num_chunks is not a multiple of chunks_per_kernel.
        in_features = hidden.shape[-1]
        outs = []
        for group_idx, first_chunk in enumerate(range(0, num_chunks, chunks_per_kernel)):
            group_size = min(chunks_per_kernel, num_chunks - first_chunk)
            kernel_2d = self.param(
                f'kernel_{group_idx}',
                kernel_init,
                (in_features, unchunk_prod * group_size),
                jnp.float32,
            )
            kernel_3d = kernel_2d.reshape(in_features, unchunk_prod, group_size)
            outs.append(jnp.einsum('...i,ijc->...jc', hidden, kernel_3d))
        # Concatenate (rather than stack) along the chunk axis: stacking needs
        # identically shaped groups and therefore fails for a smaller last
        # group, and concatenation keeps the chunks of each kernel contiguous.
        out = jnp.concatenate(outs, axis=-1)
        out = out.reshape(
            *hidden.shape[:-1],
            *unchunked_shape,
            *chunk_shape,
        )

        # Move all axes back to where they belong. `out` is laid out as
        # (batch..., unchunked axes..., chunk axes...); `position` maps each
        # original output axis to its index within that trailing layout.
        batch_ndim = hidden.ndim - 1
        position = {axis: idx for idx, axis in enumerate((*unchunked_axes, *chunk_axes))}
        perm = [
            *range(batch_ndim),  # batch dims are unchanged and stay in front
            *(batch_ndim + position[axis] for axis in range(ndim)),
        ]
        return jnp.transpose(out, perm)


@vectorize(signature='(a,b)->(c,d)', excluded={1, 2, 3})
def pad_block_constant(
    x: Array,
    top_right: ArrayLike,
    bottom_left: ArrayLike,
    bottom_right: ArrayLike,
):
    """Extend a matrix by one constant-filled column and row.

    For an ``(n, m)`` input the result is the ``(n + 1, m + 1)`` block matrix
    ``[[x, top_right], [bottom_left, bottom_right]]`` where each fill value is
    broadcast to the shape of its block and cast to ``x.dtype``. The
    ``vectorize`` decorator (signature ``(a,b)->(c,d)``, arguments 1-3
    excluded) presumably maps this over leading batch axes of ``x``.

    Args:
        x: Matrix to pad.
        top_right: Fill value of the new ``(n, 1)`` column.
        bottom_left: Fill value of the new ``(1, m)`` row.
        bottom_right: Fill value of the new ``(1, 1)`` corner.

    Returns:
        The padded matrix.
    """
    rows, cols = x.shape
    fill_dtype = x.dtype
    right_column = jnp.full((rows, 1), top_right, dtype=fill_dtype)
    bottom_row = jnp.full((1, cols), bottom_left, dtype=fill_dtype)
    corner = jnp.full((1, 1), bottom_right, dtype=fill_dtype)
    # Assemble the 2x2 block layout row by row.
    upper = jnp.concatenate([x, right_column], axis=-1)
    lower = jnp.concatenate([bottom_row, corner], axis=-1)
    return jnp.concatenate([upper, lower], axis=-2)


@vectorize(signature='(a,b),(a),(b),()->(c,d)')
def block(x: Array, top_right: Array, bottom_left: Array, bottom_right: Array):
    """Extend a matrix by one column, one row and a corner element.

    Builds the block matrix ``[[x, top_right], [bottom_left, bottom_right]]``
    for an ``(a, b)`` matrix ``x``, a length-``a`` vector ``top_right``, a
    length-``b`` vector ``bottom_left`` and a scalar ``bottom_right`` (shapes
    as declared in the ``vectorize`` signature, which presumably also maps the
    function over leading batch axes).

    Returns:
        The assembled matrix with one more row and column than ``x``.
    """
    right_column = top_right[:, jnp.newaxis]
    bottom_row = bottom_left[jnp.newaxis]
    corner = bottom_right[jnp.newaxis, jnp.newaxis]
    upper = [x, right_column]
    lower = [bottom_row, corner]
    return jnp.block([upper, lower])


class MLP(nn.Module):
    """Stack of ``Dense`` layers with an activation between them.

    Attributes:
        dims: Output size of every layer; the last entry is the output size
            and is not followed by an activation.
        activation: Activation (callable or name) applied after every layer
            except the last.
    """

    dims: Sequence[int]
    activation: ActivationOrName

    @nn.compact
    def __call__(self, x: jax.Array) -> jax.Array:
        """Run ``x`` through all layers and return the final layer's output."""
        act = Activation(self.activation)
        h = x
        # Hidden layers: dense followed by the shared activation module.
        for width in self.dims[:-1]:
            h = Dense(width)(h)
            h = act(h)
        # Output layer stays linear.
        return Dense(self.dims[-1])(h)


def normal_init(mean: Float[ArrayLike, ''], std: Float[ArrayLike, '']):
    """Create an initializer that samples from a normal distribution.

    Args:
        mean: Mean of the distribution.
        std: Standard deviation of the distribution.

    Returns:
        A function ``init(key, shape, dtype=jnp.float32)`` that draws standard
        normal samples of the given shape and dtype, scales them by ``std``
        and shifts them by ``mean`` (both cast to ``dtype``).
    """

    def init(key: jax.Array, shape: Sequence[int], dtype=jnp.float32):
        noise = jax.random.normal(key, shape, dtype=dtype)
        scale = jnp.array(std, dtype=dtype)
        shift = jnp.array(mean, dtype=dtype)
        return noise * scale + shift

    return init
