from typing import Any, NamedTuple
import jax.numpy as jnp
import jax

class Transition(NamedTuple):
    """Immutable record bundling the per-step quantities of a rollout.

    Field order is significant: it defines the positional constructor and
    the tuple layout. Presumably one instance holds a single environment
    step (or a batch/stack of steps) for actor-critic training -- confirm
    against the code that builds it.
    """

    done: jnp.ndarray
    action: jnp.ndarray
    value: jnp.ndarray
    reward: jnp.ndarray
    log_prob: jnp.ndarray
    obs: jnp.ndarray
    info: Any  # free-form extra data; structure is not constrained here

def load_feat_extractor_params(actor_critic_params):
    """Select the feature-extractor layers out of actor-critic params.

    Picks the ``"Dense_0"`` and ``"Dense_1"`` entries of
    ``actor_critic_params`` and wraps them under a top-level ``"params"``
    key. The selected sub-trees are returned by reference (not copied).

    Raises:
        KeyError: if either layer entry is missing.
    """
    feature_layers = ("Dense_0", "Dense_1")
    selected = {layer: actor_critic_params[layer] for layer in feature_layers}
    return {"params": selected}


def extract_submodel(params: dict, model_idx: int = 0):
    """
    Return a copy of `params` with a single ensemble member selected.

    The ensemble size is inferred as the leading dimension of the first
    non-scalar array leaf in pytree traversal order. Every array leaf whose
    leading dimension equals that size is indexed at `model_idx` (dropping
    the leading axis); all other leaves are returned untouched.

    NOTE: matching is purely shape-based, so a non-ensemble leaf whose first
    axis happens to equal the ensemble size is sliced as well. Assumes *all*
    ensemble leaves share the same leading size.

    Args:
        params: Pytree (e.g. nested dict) of parameters.
        model_idx: Index of the ensemble member to extract. Negative Python
            ints index from the end.

    Returns:
        A new pytree with the same structure as `params`.

    Raises:
        ValueError: if no non-scalar array leaf is found.
        IndexError: if `model_idx` is a Python int outside the ensemble range.
    """
    # Use the stable `jax.tree_util` API: the top-level `jax.tree_leaves` /
    # `jax.tree_map` aliases are deprecated and absent from recent releases.
    tree_util = jax.tree_util

    # find the ensemble size once
    ensemble_dim = next(
        (
            leaf.shape[0]
            for leaf in tree_util.tree_leaves(params)
            if isinstance(leaf, jnp.ndarray) and leaf.ndim > 0
        ),
        None,
    )
    if ensemble_dim is None:
        raise ValueError("Could not infer ensemble dimension")

    # JAX clamps out-of-range integer indices instead of raising, which would
    # silently return the wrong member; validate concrete Python ints up front.
    if isinstance(model_idx, int) and not -ensemble_dim <= model_idx < ensemble_dim:
        raise IndexError(
            f"model_idx {model_idx} is out of range for ensemble size {ensemble_dim}"
        )

    def maybe_slice(leaf):
        # slice if first axis matches ensemble size
        if isinstance(leaf, jnp.ndarray) and leaf.shape and leaf.shape[0] == ensemble_dim:
            return leaf[model_idx]          # drop leading axis
        return leaf                         # leave untouched

    return tree_util.tree_map(maybe_slice, params)
