import dataclasses
from typing import Any, Mapping, Optional

@dataclasses.dataclass()
class State:
  """Bundle of state values handed to the loss function.

  Field order is part of the public constructor signature; do not reorder.

  Attributes:
    static_state: Static state that is enclosed (closed over) by the loss and
      predict functions rather than passed in as an argument.
    dynamic_state: Dynamic state that is passed as an argument to the JIT-ed
      function.
    dynamic_update_fn: Optional callable that takes the static state, the
      dynamic state and the model itself, and returns a new dynamic state.
      ``None`` means no update function is configured.
    dynamic_update_freq: Frequency at which ``dynamic_update_fn`` is applied.
      It is presumably counted in training steps; confirm this with the caller.
    jit_update: Flag that presumably controls whether the dynamic update is
      JIT-compiled. Confirm this with the caller.
  """

  # A default_factory gives each instance its own fresh mapping, so instances
  # never share one mutable default.
  static_state: Mapping[str, Any] = dataclasses.field(
      default_factory=dict)
  dynamic_state: Mapping[str, Any] = dataclasses.field(
      default_factory=dict)

  dynamic_update_fn: Optional[Any] = None
  dynamic_update_freq: int = 1
  jit_update: bool = True

  