import dataclasses
import functools
from typing import TYPE_CHECKING

import folx
import jax
import jax.numpy as jnp

if TYPE_CHECKING:
    import kfac_jax  # pyright: ignore[reportMissingImports]
    from kfac_jax._src.layers_and_loss_tags import (  # pyright: ignore[reportMissingImports]
        layer_tag,
    )
else:
    try:
        import kfac_jax
        from kfac_jax._src.layers_and_loss_tags import (
            layer_tag,
        )

    except ImportError:
        kfac_jax = None
        layer_tag = None

# Fallback used when kfac_jax could not be imported; the name is reassigned
# further down in this module when kfac_jax is available.
GRAPH_PATTERNS = ()


def register_repeated_dense(y, x, w, b, **kwargs):
    """Tag ``y`` as the output of a 'repeated_dense' layer for kfac_jax.

    Args:
        y: Output of the dense computation.
        x: Input of the dense computation.
        w: Weight of the dense computation.
        b: Bias of the dense computation; forwarded as-is to kfac_jax.
        **kwargs: Extra keyword arguments forwarded to
            ``kfac_jax.register_dense``.

    Returns:
        The value returned by ``kfac_jax.register_dense`` (the tagged
        output), or ``y`` unchanged when kfac_jax is not installed.
    """
    if kfac_jax is None:
        # Optional dependency missing: tagging is a no-op.
        return y
    # Return the tag's own output rather than the raw ``y``: previously the
    # result was discarded, so callers kept computing with the untagged value
    # and nothing downstream depended on the registered tag.
    return kfac_jax.register_dense(y, x, w, b, variant='repeated_dense', **kwargs)


if kfac_jax is not None:
    type Array = kfac_jax.utils.Array
    type Scalar = kfac_jax.utils.Scalar
    type Numeric = kfac_jax.utils.Numeric

    vmap_psd_inv = jax.vmap(kfac_jax.utils.psd_inv, (0, None), 0)
    vmap_matmul = jax.vmap(jnp.matmul, in_axes=(0, 0), out_axes=0)

    _dense = functools.partial(
        kfac_jax.tag_graph_matcher._dense,
        axes=1,
        with_reshape=False,
    )
    _repeated_dense_parameter_extractor = functools.partial(
        kfac_jax.tag_graph_matcher._dense_parameter_extractor,
        variant='repeated_dense',
    )

    class LayerNormBlock(kfac_jax.ScaleAndShiftDiagonal):
        """Scale-and-shift diagonal block that flattens all leading axes.

        Before delegating to the parent class, the layer input and the output
        tangent are reshaped to 2D so that every axis except the last one is
        merged into a single leading axis.
        """

        def fixed_scale(self) -> Numeric:
            """Return the size of the second-to-last axis of the single input."""
            [input_shape] = self.inputs_shapes
            return input_shape[-2]

        def update_curvature_matrix_estimate(
            self,
            state: kfac_jax.Diagonal.State,
            estimation_data: kfac_jax.LayerVjpData[Array],
            ema_old: Numeric,
            ema_new: Numeric,
            identity_weight: Numeric,
            batch_size: Numeric,
        ) -> kfac_jax.Diagonal.State:
            """Update the curvature estimate from flattened inputs and tangents.

            The incoming ``batch_size`` is not used: the parent implementation
            receives the number of rows of the flattened input instead.

            Args:
                state: Current block state, passed through to the parent.
                estimation_data: VJP data holding exactly one primal input and
                    exactly one output tangent.
                ema_old: Passed through to the parent.
                ema_new: Passed through to the parent.
                identity_weight: Passed through to the parent.
                batch_size: Ignored; recomputed from the input's shape.

            Returns:
                Whatever the parent's ``update_curvature_matrix_estimate``
                returns for the flattened data.
            """
            (inputs,) = estimation_data.primals.inputs
            (out_tangents,) = estimation_data.tangents.outputs

            # Merge every axis but the last into one leading axis.
            flat_primals = dataclasses.replace(
                estimation_data.primals,
                inputs=(inputs.reshape([-1, inputs.shape[-1]]),),
            )
            flat_tangents = dataclasses.replace(
                estimation_data.tangents,
                outputs=(out_tangents.reshape([-1, out_tangents.shape[-1]]),),
            )
            flat_data = dataclasses.replace(
                estimation_data,
                primals=flat_primals,
                tangents=flat_tangents,
            )

            # Number of rows in the flattened input.
            flat_rows = inputs.size // inputs.shape[-1]
            return super().update_curvature_matrix_estimate(
                state=state,
                estimation_data=flat_data,
                ema_old=ema_old,
                ema_new=ema_new,
                identity_weight=identity_weight,
                batch_size=flat_rows,
            )

    class RepeatedDenseBlock(kfac_jax.DenseTwoKroneckerFactored):
        """Kronecker-factored dense block for a layer applied to many inputs.

        Intended for dense layers that are repeated over extra input axes
        (e.g. via ``vmap``). The layer input and the output tangent are
        reshaped to 2D, merging every axis except the last, before the parent
        class updates its factors.
        """

        def fixed_scale(self) -> Numeric:
            """Return the size of the second-to-last axis of the single input."""
            [input_shape] = self.inputs_shapes
            return input_shape[-2]

        def update_curvature_matrix_estimate(
            self,
            state: kfac_jax.KroneckerFactored.State,
            estimation_data: kfac_jax.LayerVjpData[Array],
            ema_old: Numeric,
            ema_new: Numeric,
            identity_weight: Numeric,
            batch_size: int,
        ) -> kfac_jax.KroneckerFactored.State:
            """Update the curvature estimate from flattened inputs and tangents.

            The incoming ``batch_size`` is not used: the parent implementation
            receives the number of rows of the flattened input instead.

            Args:
                state: Current block state, passed through to the parent.
                estimation_data: VJP data holding exactly one primal input and
                    exactly one output tangent.
                ema_old: Passed through to the parent.
                ema_new: Passed through to the parent.
                identity_weight: Passed through to the parent.
                batch_size: Ignored; recomputed from the input's shape.

            Returns:
                Whatever the parent's ``update_curvature_matrix_estimate``
                returns for the flattened data.
            """
            (inputs,) = estimation_data.primals.inputs
            (out_tangents,) = estimation_data.tangents.outputs

            # Merge every axis but the last into one leading axis.
            flat_primals = dataclasses.replace(
                estimation_data.primals,
                inputs=(inputs.reshape([-1, inputs.shape[-1]]),),
            )
            flat_tangents = dataclasses.replace(
                estimation_data.tangents,
                outputs=(out_tangents.reshape([-1, out_tangents.shape[-1]]),),
            )
            flat_data = dataclasses.replace(
                estimation_data,
                primals=flat_primals,
                tangents=flat_tangents,
            )

            # Number of rows in the flattened input.
            flat_rows = inputs.size // inputs.shape[-1]
            return super().update_curvature_matrix_estimate(
                state=state,
                estimation_data=flat_data,
                ema_old=ema_old,
                ema_new=ema_new,
                identity_weight=identity_weight,
                batch_size=flat_rows,
            )

    # The dense computation vmapped over one (``...1``) or two (``...2``)
    # leading axes of the input. The parameter list (weight and bias, or just
    # the weight for the ``_no_b`` variants) is not mapped, so the same
    # parameters are applied along every repeated axis. The nested ``in_axes``
    # list must mirror the structure of the parameter list.
    _repeated_dense1 = jax.vmap(_dense, in_axes=[0, [None, None]])
    _repeated_dense2 = jax.vmap(_repeated_dense1, in_axes=[0, [None, None]])
    _repeated_dense1_no_b = jax.vmap(_dense, in_axes=[0, [None]])
    _repeated_dense2_no_b = jax.vmap(_repeated_dense1_no_b, in_axes=[0, [None]])

    # Graph patterns for the repeated dense computations defined above. Each
    # pattern pairs one of the vmapped compute functions with the
    # 'repeated_dense' parameter extractor and with example arguments:
    # an input of shape [9, 11, 13] (one repeated axis) or [8, 9, 11, 13]
    # (two repeated axes), a [13, 7] weight and, for the "with_bias"
    # variants, a [7] bias.
    repeated_dense1_with_bias_pattern = kfac_jax.tag_graph_matcher.GraphPattern(
        name='repeated_dense1_with_bias',
        tag_primitive=kfac_jax.layers_and_loss_tags.layer_tag,
        compute_func=_repeated_dense1,
        parameters_extractor_func=_repeated_dense_parameter_extractor,
        example_args=[jnp.zeros([9, 11, 13]), [jnp.zeros([13, 7]), jnp.zeros([7])]],
    )

    # One repeated axis, weight only.
    repeated_dense1_no_bias_pattern = kfac_jax.tag_graph_matcher.GraphPattern(
        name='repeated_dense1_no_bias',
        tag_primitive=kfac_jax.layers_and_loss_tags.layer_tag,
        compute_func=_repeated_dense1_no_b,
        parameters_extractor_func=_repeated_dense_parameter_extractor,
        example_args=[jnp.zeros([9, 11, 13]), [jnp.zeros([13, 7])]],
    )

    # Two repeated axes, weight and bias.
    repeated_dense2_with_bias_pattern = kfac_jax.tag_graph_matcher.GraphPattern(
        name='repeated_dense2_with_bias',
        tag_primitive=kfac_jax.layers_and_loss_tags.layer_tag,
        compute_func=_repeated_dense2,
        parameters_extractor_func=_repeated_dense_parameter_extractor,
        example_args=[jnp.zeros([8, 9, 11, 13]), [jnp.zeros([13, 7]), jnp.zeros([7])]],
    )

    # Two repeated axes, weight only.
    repeated_dense2_no_bias_pattern = kfac_jax.tag_graph_matcher.GraphPattern(
        name='repeated_dense2_no_bias',
        tag_primitive=kfac_jax.layers_and_loss_tags.layer_tag,
        compute_func=_repeated_dense2_no_b,
        parameters_extractor_func=_repeated_dense_parameter_extractor,
        example_args=[jnp.zeros([8, 9, 11, 13]), [jnp.zeros([13, 7])]],
    )

    # Full pattern set exported by this module: the four repeated-dense
    # patterns (with-bias variants first) followed by kfac_jax's defaults.
    # NOTE(review): the custom patterns are presumably listed first so they
    # are tried before the defaults — confirm against kfac_jax's matcher.
    GRAPH_PATTERNS = (
        repeated_dense1_with_bias_pattern,
        repeated_dense2_with_bias_pattern,
        repeated_dense1_no_bias_pattern,
        repeated_dense2_no_bias_pattern,
        *kfac_jax.tag_graph_matcher.DEFAULT_GRAPH_PATTERNS,
    )

    # Import-time side effect: change kfac_jax's default block constructors.
    # Both the 'repeated_dense' and the plain 'dense' tags are routed to
    # RepeatedDenseBlock, and 'scale_and_shift' to LayerNormBlock.
    kfac_jax.set_default_tag_to_block_ctor('repeated_dense', RepeatedDenseBlock)
    kfac_jax.set_default_tag_to_block_ctor('dense', RepeatedDenseBlock)
    kfac_jax.set_default_tag_to_block_ctor('scale_and_shift', LayerNormBlock)

    # Register a folx handler for the layer tag that simply returns the tag's
    # first argument. This assumes that every tag has a single output.
    folx.register_function(layer_tag, lambda args, kwargs, sparsity_threshold: args[0])
