from flax.training.train_state import TrainState as BaseTrainState
from flax.core.frozen_dict import FrozenDict
from flax import struct
from typing import Any


class TrainState(BaseTrainState):
    """Flax ``TrainState`` extended with an extra ``target_params`` parameter tree.

    Adds one field, ``target_params``, on top of the inherited training-state
    fields (step, apply_fn, params, tx, opt_state).

    NOTE(review): the name suggests a target-network copy of ``params`` (e.g.
    for DQN/SAC-style Polyak or periodic updates). Nothing in this file
    establishes that, so confirm against the callers.

    Usage notes:
      * The field has no default, so it must be supplied at construction time,
        e.g. ``TrainState.create(apply_fn=..., params=p, tx=..., target_params=p)``.
        Flax's ``create`` forwards extra keyword arguments to the constructor.
      * The inherited ``apply_gradients`` only replaces ``step``, ``params`` and
        ``opt_state``. It therefore carries ``target_params`` over unchanged, and
        any update must be done explicitly, e.g.
        ``state.replace(target_params=new_tree)`` or
        ``state.apply_gradients(grads=g, target_params=new_tree)``.
    """

    # pytree_node=True (also the struct.field default) makes this field part of
    # the pytree. It is flattened, traced and transformed by jax (jit, tree_map,
    # device_put, checkpointing) like `params`, rather than treated as static
    # metadata.
    target_params: FrozenDict[str, Any] = struct.field(pytree_node=True)
