from typing import NamedTuple

from flax import nnx

from offline.modules.base import TargetModel
from offline.modules.critic import VCritic
from offline.modules.mlp import MLP
from offline.types import BoolArray, FloatArray, IntArray


class HLQLearningBatch(NamedTuple):
    """Immutable record bundling the arrays of one Q-learning batch.

    The "HL" prefix presumably stands for "high-level" (cf.
    ``HighLevelTrainState`` in this module) -- TODO confirm.

    Fields are declared in alphabetical order. Because this is a
    ``NamedTuple``, that declaration order is also the positional order used
    for construction and unpacking, so reordering fields would break
    positional callers.

    Array shapes are not established in this module; the per-field notes
    below are inferred from names and annotations only.
    """

    # Integer array; presumably one assignment/index per transition
    # (what is being assigned is not visible here) -- TODO confirm.
    assignments: IntArray
    # Boolean array; presumably episode-termination flags -- TODO confirm.
    dones: BoolArray
    # Float array; presumably the observations following each transition.
    next_observations: FloatArray
    # Float array; presumably the observations at the start of each transition.
    observations: FloatArray
    # Float array; presumably the reward received for each transition.
    rewards: FloatArray


class HighLevelTrainState(NamedTuple):
    """Immutable record grouping the high-level critics and their optimizers.

    Holds a Q-critic, a target wrapper around a Q-critic, a V-critic, and one
    ``nnx.Optimizer`` for each of the two trainable critics.

    Fields are declared in alphabetical order. Because this is a
    ``NamedTuple``, that declaration order is also the positional order used
    for construction and unpacking. The tuple itself cannot have its fields
    rebound; whether the referenced objects are mutated in place during
    training is not visible from this module.
    """

    # Optimizer presumably driving updates of `qcritic` -- TODO confirm.
    optimizer_qcritic: nnx.Optimizer
    # Optimizer presumably driving updates of `vcritic` -- TODO confirm.
    optimizer_vcritic: nnx.Optimizer
    # Q-critic, represented as a plain MLP.
    qcritic: MLP
    # Target-model wrapper parameterized by MLP; presumably a slowly updated
    # copy of `qcritic` (update rule not visible here) -- TODO confirm.
    target_qcritic: TargetModel[MLP]
    # V-critic module.
    vcritic: VCritic
