import jax
import jax.numpy as jnp
from jaxlib.xla_extension import Device
from optax import OptState


# Runtime type checking is turned off: `beartype` is replaced by a no-op decorator.
from typing import Union, Tuple, Dict, Any, Collection, Callable, Annotated

def beartype(f):
    """No-op replacement for the ``@beartype`` decorator: return *f* unchanged.

    Runtime type checking is disabled, so decorating a callable with
    ``@beartype`` has no effect.
    """
    return f

# -----------------------
# General Parameter Types
# -----------------------

# Generic JAX array type used by the parameter aliases in this module.
Array = jnp.ndarray
# Scalar (0-d) array. Kept as ``Any`` while runtime checking is disabled; the
# intended constraint is a ``jax.Array`` whose ``shape == ()``.
ScalarArray = Any
# Pair of scalar losses.
# NOTE(review): what each element represents is not visible here; confirm
# against the loss function that returns it.
LossTupleType = Tuple[ScalarArray, ScalarArray]

# Generic parameter container keyed by string or integer.
ParamsTypeGen = Dict[Union[str, int], Any]


# A single device: either an ``int`` (presumably a device index; confirm with
# callers) or a device object. ``jax.Device`` is the public name of the same
# class exposed as ``jaxlib.xla_extension.Device``; prefer it over the private
# jaxlib path.
DeviceLike = Union[int, jax.Device]
# One device, or a collection of devices.
DevicesArg = Union[DeviceLike, Collection[DeviceLike]]


# -----------------------
# Coeff Types
# -----------------------
def is_array_like_with_valid_shape(x):
    """Return True if *x* has an ``ndim`` attribute equal to 2 or 3.

    Shape check for coefficient arrays (``CoeffType``). Objects without a
    usable ``ndim`` (missing attribute, a property that raises, or a value
    that cannot be compared) yield False rather than raising.
    """
    try:
        return x.ndim in (2, 3)
    except Exception:
        # Narrower than a bare ``except:`` so KeyboardInterrupt/SystemExit
        # still propagate.
        return False
def is_2d_array(x):
    """Return True if *x* has an ``ndim`` attribute equal to 2.

    Shape check for coarse-level coefficient arrays (``CoarseCoeffType``).
    Objects without a usable ``ndim`` yield False rather than raising.
    """
    try:
        # bool() so an exotic ``ndim`` (e.g. an array) can never leak a
        # non-bool; if it cannot be truth-tested, we fall through to False.
        return bool(x.ndim == 2)
    except Exception:
        # Narrower than a bare ``except:`` so KeyboardInterrupt/SystemExit
        # still propagate.
        return False
def is_3d_array(x):
    """Return True if *x* has an ``ndim`` attribute equal to 3.

    Shape check for fine-level coefficient arrays (``FineCoeffType``).
    Objects without a usable ``ndim`` yield False rather than raising.
    """
    try:
        # bool() so an exotic ``ndim`` (e.g. an array) can never leak a
        # non-bool; if it cannot be truth-tested, we fall through to False.
        return bool(x.ndim == 3)
    except Exception:
        # Narrower than a bare ``except:`` so KeyboardInterrupt/SystemExit
        # still propagate.
        return False
# Coefficient array types. Runtime checking is disabled, so these are plain
# ``Any``. Intended constraints, expressible as
# ``Annotated[jax.Array, Is[predicate]]`` if checking is re-enabled:
#   CoeffType       -> ndim in (2, 3)  (is_array_like_with_valid_shape)
#   CoarseCoeffType -> ndim == 2       (is_2d_array)
#   FineCoeffType   -> ndim == 3       (is_3d_array)
CoeffType = Any
CoarseCoeffType = Any
FineCoeffType = Any
# -----------------------
# Posterior Types
# -----------------------
def is_valid_posterior(x):
    """Return True if *x* has an ``ndim`` attribute equal to 2 or 3.

    Shape check for posterior arrays (``PosteriorType``). Objects without a
    usable ``ndim`` (missing attribute, a property that raises, or a value
    that cannot be compared) yield False rather than raising.
    """
    try:
        return x.ndim in (2, 3)
    except Exception:
        # Narrower than a bare ``except:`` so KeyboardInterrupt/SystemExit
        # still propagate.
        return False
def is_posterior_2d(x):
    """Return True if *x* has an ``ndim`` attribute equal to 2.

    Shape check for coarse-level posteriors (``CoarsePosteriorType``).
    Objects without a usable ``ndim`` yield False rather than raising.
    """
    try:
        # bool() so an exotic ``ndim`` (e.g. an array) can never leak a
        # non-bool; if it cannot be truth-tested, we fall through to False.
        return bool(x.ndim == 2)
    except Exception:
        # Narrower than a bare ``except:`` so KeyboardInterrupt/SystemExit
        # still propagate.
        return False
def is_posterior_3d(x):
    """Return True if *x* has an ``ndim`` attribute equal to 3.

    Shape check for fine-level posteriors (``FinePosteriorType``).
    Objects without a usable ``ndim`` yield False rather than raising.
    """
    try:
        # bool() so an exotic ``ndim`` (e.g. an array) can never leak a
        # non-bool; if it cannot be truth-tested, we fall through to False.
        return bool(x.ndim == 3)
    except Exception:
        # Narrower than a bare ``except:`` so KeyboardInterrupt/SystemExit
        # still propagate.
        return False
# Posterior array types. Runtime checking is disabled, so these are plain
# ``Any``. Intended constraints, expressible as
# ``Annotated[jax.Array, Is[predicate]]`` if checking is re-enabled:
#   PosteriorType       -> ndim in (2, 3)  (is_valid_posterior)
#   CoarsePosteriorType -> ndim == 2       (is_posterior_2d)
#   FinePosteriorType   -> ndim == 3       (is_posterior_3d)
PosteriorType = Any
CoarsePosteriorType = Any
FinePosteriorType = Any

# -----------------------
# Layer Param Types
# -----------------------
# Weights of one layer as a pair of arrays, typically (weight, bias).
LayerWeightsType = Tuple[Array, Array]

# A single layer's parameters: string key -> (weight, bias)-style pair.
LayerParamsType = Dict[str, LayerWeightsType]

# -----------------------
# Gating Network Types
# -----------------------
# Gating-network parameters follow a two-level, Flax-style nesting:
# {'params': {layer_name: layer_params}}.
# NOTE(review): stock Flax layers store 'kernel' and 'bias' as separate dict
# entries rather than a tuple; confirm this tuple-based layout matches how
# these params are actually built.
# Inner level: layer name -> that layer's parameters.
GatingParamsInner = Dict[str, LayerParamsType]
# Outer level: top-level key (e.g. 'params') -> inner mapping.
GatingParamsOuter = Dict[str, GatingParamsInner]

# Coarse level has a single gating network, e.g.
# {'params': {layer_name: {'layer_weights': (w, b)}}}
GatingCParams = GatingParamsOuter

# Fine level has multiple gating networks: a variable-length tuple with one
# gating network per coarse partition.
GatingFParams = Tuple[GatingParamsOuter, ...]

# -----------------------
# Basis Network Types
# -----------------------
# MLP basis networks use the same nested layout as a gating network.
BasisMLPParams = GatingParamsOuter

# Classical basis functions have no trainable parameters, so this dict is
# typically empty.
BasisClassicalParams = Dict[str, Any]

# Parameters for one basis, whose values may be MLP or classical params.
# NOTE(review): this adds one dict level on top of BasisMLPParams
# (key -> {'params': ...}); confirm this matches how basis params are built.
BasisSingleParams = Dict[str, Union[BasisMLPParams, BasisClassicalParams]]

# Coarse-level basis params: a variable-length tuple, one entry per partition.
BasisCParams = Tuple[BasisSingleParams, ...]

# Fine-level basis params, nested coarse x fine: the outer tuple is indexed by
# coarse partition, the inner tuple by fine partition within it.
BasisFParams = Tuple[Tuple[BasisSingleParams, ...], ...]

# -----------------------
# Combined Parameter Types
# -----------------------
# Full parameter set for one level, as a (gating params, basis params) pair.
ParamsCType = Tuple[GatingCParams, BasisCParams]  # coarse level
ParamsFType = Tuple[GatingFParams, BasisFParams]  # fine level

# Optimizer-related pair whose second element is an optax ``OptState``.
# NOTE(review): the first element is untyped (``Any``), presumably the
# parameters being optimized; confirm with the training loop.
OptParamsType = Tuple[Any, OptState]

