import flax
from typing import NamedTuple, Union
import jax.numpy as jnp

from jax_rl.networks.common import Model, PRNGKey
from jax_rl.networks.pe_net import RunningMeanStd


@flax.struct.dataclass
class ActorCriticTemp:
    """Container bundling four ``Model`` objects and a PRNG key.

    Declared with ``flax.struct.dataclass``, so instances are immutable;
    derive a modified copy with ``.replace(...)`` instead of assigning to
    attributes. Field order is part of the positional constructor
    signature and must not be reordered.
    """

    # NOTE(review): the roles below are inferred from the field names only;
    # nothing in this file shows how the models are built or used — confirm
    # against the code that constructs this container.
    actor: Model  # presumably the policy network
    critic: Model  # presumably the online value/Q network
    target_critic: Model  # presumably the target copy of `critic`
    temp: Model  # presumably a temperature parameter — TODO confirm
    rng: PRNGKey  # random key carried along with the models


@flax.struct.dataclass
class Gater:
    """Container bundling five ``Model`` objects and a ``RunningMeanStd``.

    Declared with ``flax.struct.dataclass``, so instances are immutable;
    derive a modified copy with ``.replace(...)`` instead of assigning to
    attributes. Field order is part of the positional constructor
    signature and must not be reordered.
    """

    # NOTE(review): this file only declares the fields; the roles noted
    # below are guesses from the names and need confirming by the owner.
    gater_q: Model
    gater_q_tgt: Model  # presumably the target copy of `gater_q` — confirm
    gater_p: Model
    # TODO(review): the meaning of `pA` / `pB` is not evident from this
    # file; document what these two models represent.
    pA: Model
    pB: Model
    gater_stats: RunningMeanStd  # running statistics object from pe_net
    
class Pre(NamedTuple):
    """Named tuple grouping five arrays.

    Being a ``NamedTuple``, it is immutable and unpacks positionally in
    the field order declared below, so the order must be preserved.
    """

    # NOTE(review): shapes, dtypes and exact semantics are not visible in
    # this file. The `lam_*` names look like they pair with the `gater_p`,
    # `gater_q` and `gater_q_tgt` models of `Gater` — confirm with the
    # code that builds this tuple.
    lam_p: jnp.ndarray
    lam_q: jnp.ndarray
    lam_tgt: jnp.ndarray
    next_actions: jnp.ndarray  # presumably actions sampled for next states
    next_log_probs: jnp.ndarray  # presumably log-probs of `next_actions`