"""
Utilities for working with JAX.
Some of these functions are taken from
https://github.com/deepmind/ferminet/tree/jax/ferminet
"""

import functools
import importlib
import logging
import operator
from collections.abc import Callable, Iterable
from pathlib import Path
from typing import Any, Self, TypeVar, cast, overload

import jax
import jax.numpy as jnp
from flax import serialization
from flax.serialization import from_bytes, to_bytes, to_state_dict
from flax.struct import PyTreeNode
from jax import shard_map
from jax.sharding import NamedSharding, PartitionSpec
from jaxtyping import Array, PyTree

_BATCH_AXIS = 'qmc_batch'


MESH = jax.make_mesh((jax.device_count(),), (_BATCH_AXIS,))
BATCH_SPEC = PartitionSpec(_BATCH_AXIS)
BATCH_SHARDING = NamedSharding(MESH, BATCH_SPEC)
REPLICATE_SPEC = PartitionSpec()
REPLICATE_SHARDING = NamedSharding(MESH, REPLICATE_SPEC)


def distribute_keys(key: jax.Array) -> jax.Array:
    """Give each device along the batch axis its own PRNG key.

    ``key`` is split into ``jax.device_count()`` sub-keys, and the one at this
    device's batch-axis index (``pidx()``) is returned. ``pidx`` calls
    ``jax.lax.axis_index`` on the batch axis, so this must run where that axis
    is bound (e.g. inside ``shmap``).
    """
    per_device = jax.random.split(key, jax.device_count())
    return per_device[pidx()]


pmean = functools.partial(jax.lax.pmean, axis_name=_BATCH_AXIS)
psum = functools.partial(jax.lax.psum, axis_name=_BATCH_AXIS)
pmax = functools.partial(jax.lax.pmax, axis_name=_BATCH_AXIS)
pmin = functools.partial(jax.lax.pmin, axis_name=_BATCH_AXIS)
pgather = functools.partial(jax.lax.all_gather, axis_name=_BATCH_AXIS)
pall_to_all = functools.partial(jax.lax.all_to_all, axis_name=_BATCH_AXIS)
pidx = functools.partial(jax.lax.axis_index, axis_name=_BATCH_AXIS)


def pvary(x):
    """Mark ``x`` as varying across the batch axis when JAX supports it.

    Delegates to ``jax.lax.pvary`` on JAX versions that provide it. On versions
    without ``pvary``, ``x`` is returned untouched.
    """
    if not hasattr(jax.lax, 'pvary'):
        # Older JAX without pvary: pass the value through unchanged.
        return x
    return jax.lax.pvary(x, axis_name=_BATCH_AXIS)


def wrap_if_pmap[C: Callable](p_func: C) -> C:
    @functools.wraps(p_func)
    def p_func_if_pmap[T](obj: T, *args, **kwargs) -> T:
        try:
            jax.lax.axis_index(_BATCH_AXIS)
            return p_func(obj, *args, **kwargs)
        except NameError:
            return obj

    return p_func_if_pmap  # type: ignore


# Tolerant collectives: they behave like pmean/psum/... where the batch axis is
# bound, and return their first argument unchanged everywhere else.
pmean_if_pmap = wrap_if_pmap(p_func=pmean)
psum_if_pmap = wrap_if_pmap(p_func=psum)
pmax_if_pmap = wrap_if_pmap(p_func=pmax)
pmin_if_pmap = wrap_if_pmap(p_func=pmin)
pgather_if_pmap = wrap_if_pmap(p_func=pgather)


# Module-level callable TypeVar for signatures that do not declare their own
# PEP 695 type parameter.
C = TypeVar('C', bound=Callable)


@overload
def jit[C: Callable](fun: None = None, *jit_args, **jit_kwargs) -> Callable[[C], C]: ...


@overload
def jit[C: Callable](fun: C, *jit_args, **jit_kwargs) -> C: ...


@functools.wraps(jax.jit)
def jit[C: Callable](
    fun: C | None = None,
    *jit_args,
    **jit_kwargs,
) -> C | Callable[[C], C]:
    def inner_jit(fun: C) -> C:
        jitted = jax.jit(fun, *jit_args, **jit_kwargs)

        @functools.wraps(fun)
        def wrapper(*args, **kwargs):
            return jitted(*args, **kwargs)

        return wrapper  # type: ignore

    if fun is None:
        return inner_jit

    return inner_jit(fun)


@overload
def vectorize(fun: None = None, *vec_args, **vec_kwargs) -> Callable[[C], C]: ...


@overload
def vectorize[C: Callable](fun: C, *vec_args, **vec_kwargs) -> C: ...


def vectorize[C: Callable](
    fun: C | None = None,
    *vec_args,
    **vec_kwargs,
) -> C | Callable[[C], C]:
    def inner_jit(fun: C) -> C:
        vectorized = jnp.vectorize(fun, *vec_args, **vec_kwargs)

        @functools.wraps(fun)
        def wrapper(*args, **kwargs):
            return vectorized(*args, **kwargs)

        return wrapper  # type: ignore

    if fun is None:
        return inner_jit

    return inner_jit(fun)


@functools.wraps(shard_map)
def shmap[C: Callable](
    fun: C | None = None,
    *shmap_args,
    **shmap_kwargs,
) -> C | Callable[[C], C]:
    def inner_shmap(fun: C) -> C:
        return shard_map(fun, *shmap_args, mesh=MESH, **shmap_kwargs)  # type: ignore

    if fun is None:
        return inner_shmap

    return inner_shmap(fun)


@overload
def vmap(fun: None = None, *vmap_args, **vmap_kwargs) -> Callable[[C], C]: ...


@overload
def vmap[C: Callable](fun: C, *vmap_args, **vmap_kwargs) -> C: ...


@functools.wraps(jax.vmap)
def vmap[C: Callable](
    fun: C | None = None,
    *vmap_args,
    **vmap_kwargs,
) -> C | Callable[[C], C]:
    def inner_vmap(fun: C) -> C:
        vmapped = jax.vmap(fun, *vmap_args, **vmap_kwargs)

        @functools.wraps(fun)
        def wrapper(*args, **kwargs):
            return vmapped(*args, **kwargs)

        return wrapper  # type: ignore

    if fun is None:
        return inner_vmap

    return inner_vmap(fun)


Axis = int | tuple[int, ...]


def pad_along_axis(
    array: jax.Array,
    pad_width: int | tuple[int, int],
    axis: Axis = -1,
    mode: str | Callable[..., Any] = 'constant',
    **kwargs,
) -> jax.Array:
    """
    Pads an array along a specified axis.
    This is a convenience wrapper around `jax.numpy.pad` that simplifies padding.
    Instead of specifying padding along all axes, you can specify the axes
    along which to pad.

    Args:
        array: The input array to pad.
        pad_width: The number of elements to pad on both sides of the specified axis.
            If `pad_width` is an integer, it will pad all axes on both sides.
            If `pad_width` is a tuple, it should contain two integers specifying the
            number of elements to pad on both sides of the specified axes.
        axis: The axes along which to pad the array.
            Defaults to padding the last axis -1.

    Returns:
        A new array with the specified padding applied.
    """
    if isinstance(pad_width, int):
        pad_width = (pad_width, pad_width)
    full_pad_width = [(0, 0)] * array.ndim

    if isinstance(axis, int):
        full_pad_width[axis] = pad_width
    elif isinstance(axis, Iterable):
        for ax in axis:
            full_pad_width[ax] = pad_width
    return jnp.pad(array, full_pad_width, mode=mode, **kwargs)


class SerializeablePyTree(PyTreeNode):
    """Flax ``PyTreeNode`` with byte/file serialization and sharding helpers.

    Subclasses may override ``partition_spec`` to describe how their leaves are
    laid out over ``MESH``; the default replicates everything.
    """

    # flax's ``to_bytes(target)`` / ``from_bytes(target, data)`` bound as
    # methods, so ``self`` is the target pytree in both.
    serialize = to_bytes
    deserialize = from_bytes

    def to_file(self, path: str | Path):
        """Write ``self.serialize()`` to ``path``, truncating any existing file."""
        Path(path).write_bytes(self.serialize())

    def from_file(self, path: str | Path) -> Self:
        """Restore a pytree from the bytes at ``path``, using ``self`` as the template."""
        raw = Path(path).read_bytes()
        return cast('Self', self.deserialize(raw))

    @property
    def partition_spec(self) -> PartitionSpec | Self:
        """Partition spec (or a pytree of specs) for this object; replicated by default."""
        return REPLICATE_SPEC

    @property
    def sharding(self):
        """``partition_spec`` with every ``PartitionSpec`` turned into a ``NamedSharding`` on MESH."""
        spec_tree = self.partition_spec
        return jax.tree.map(
            lambda spec: NamedSharding(MESH, spec),
            spec_tree,
            # PartitionSpec is tuple-like; treat it as a leaf rather than
            # descending into its entries.
            is_leaf=lambda node: isinstance(node, PartitionSpec),
        )

    @property
    def sharded(self):
        """A copy of ``self`` placed on devices according to ``sharding``."""
        placement = self.sharding
        return jax.device_put(self, placement)


try:
    import kfac_jax  # pyright: ignore[reportMissingImports]
except ImportError:
    # kfac_jax is optional; without it none of the helpers below are defined.
    pass
else:

    def serialize_kfac_state(instance: kfac_jax.utils.State) -> PyTree[Array]:
        """Convert a KFAC state into a flax state dict keyed by field name.

        Used instead of `kfac_jax.utils.serialize_state_tree` to avoid
        serializing tuples. Each field in `instance.field_names()` is converted
        with `flax.serialization.to_state_dict`.
        """
        state_dict: dict[str, dict[str, Any]] = {
            name: to_state_dict(getattr(instance, name))
            for name in instance.field_names()
        }
        return state_dict

    def deserialize_kfac_state(target: kfac_jax.utils.State, state: dict[str, Any]):
        """Rebuild a KFAC state from a state dict produced by `serialize_kfac_state`.

        Used instead of `kfac_jax.utils.deserialize_state_tree` to avoid
        deserializing tuples. `target` supplies the field names and the template
        value each field is restored into; a new `type(target)` instance is
        returned.

        Raises:
            ValueError: If `state` lacks one of `target`'s fields, or contains
                fields that `target` does not have.
        """
        # Fields are popped as they are consumed; copy so the caller's dict is untouched.
        remaining = state.copy()
        kwargs = {}
        for name in target.field_names():
            if name not in remaining:
                raise ValueError(
                    f'Missing field {name} in state dict while restoring'
                    f' an instance of {type(target).__name__}'
                    f' at path {serialization.current_path()}'
                )
            value = getattr(target, name)
            value_state = remaining.pop(name)
            kwargs[name] = serialization.from_state_dict(value, value_state, name=name)
        if remaining:
            names = ','.join(remaining)
            raise ValueError(
                f'Unexpected fields {names} in state dict while restoring'
                f' an instance of {type(target).__name__}'
                f' at path {serialization.current_path()}'
            )
        return target.__class__(**kwargs)

    def register_kfac_state_with_flax() -> None:
        """Register every KFAC state class with flax serialization.

        Reads `STATE_CLASSES_SERIALIZATION_DICT` from the private module
        `kfac_jax._src.utils.misc` and registers each listed class with
        `serialize_kfac_state` / `deserialize_kfac_state`. Nested states are
        handled because each field goes through flax's `to_state_dict` /
        `from_state_dict`. If the lookup fails, a warning is logged and
        nothing is registered.
        """
        try:
            mod = importlib.import_module('kfac_jax._src.utils.misc')
            kfac_states = mod.STATE_CLASSES_SERIALIZATION_DICT
        except (AttributeError, ImportError):
            logging.warning(
                'KFAC state serialization dictionary not found.'
                ' Maybe KFAC Jax internals have changed?',
            )
            kfac_states = {}

        for ty in kfac_states.values():
            serialization.register_serialization_state(
                ty,
                serialize_kfac_state,
                deserialize_kfac_state,
            )
