"""Control-affine systems discrete-time simulation."""
from dataclasses import dataclass
from typing import Callable
import torch


# A vector field maps a state tensor to a tensor of the same kind.
VectorField = Callable[[torch.Tensor], torch.Tensor]


@dataclass(init=True, repr=True, eq=True)
class ControlAffineSystem:
    """Vector fields describing a control-affine system.

    The dynamics are assembled as the drift field plus one contribution
    from each control field, weighted by the matching control input
    (see ``step`` in this module for the discrete-time update).
    """

    # Uncontrolled part of the dynamics, evaluated at the current state.
    drift_field: VectorField
    # One field per control input, listed in the same order as the inputs.
    control_fields: list[VectorField]


def step(
        x: torch.Tensor,
        u: list[torch.Tensor],
        system: ControlAffineSystem,
        dt: float,
        ) -> torch.Tensor:
    """Advance the control-affine system by one explicit (forward) Euler step.

    The state derivative is ``h0(x) + sum_i h_i(x) u_i``, where ``h0`` is
    ``system.drift_field`` and ``h_i`` are ``system.control_fields``; the
    returned state is ``x + dt * derivative``.

    - `x`: current state
    - `u`: control inputs, one per control field and in the same order.
           A 0-dim tensor or a plain Python number scales its field's
           value; an input with at least one dimension is combined with
           its field's value via ``torch.matmul`` (e.g. an ``(n, k)``
           field value with a ``(k,)`` input).
    - `system`: control affine system with as
              many controllable vector fields
              as inputs in `u`
    - `dt`: timestep size

    Returns the next state.

    Raises `ValueError` if the number of inputs in `u` differs from the
    number of control fields in `system`.
    """
    # Unpack system components
    h0 = system.drift_field
    h = list(system.control_fields)
    u = list(u)

    # zip() would silently drop unmatched fields or inputs; fail loudly instead.
    if len(h) != len(u):
        raise ValueError(
            f"expected {len(h)} control inputs (one per control field), "
            f"got {len(u)}"
        )

    # Compute derivative (same summation order as before: drift + sum of terms)
    xp = h0(x) + sum(_control_term(hi(x), ui) for hi, ui in zip(h, u))

    # Integrate (explicit Euler)
    return x + xp*dt


def _control_term(field_value: torch.Tensor, ui) -> torch.Tensor:
    """Return the contribution of one control field for input `ui`."""
    if isinstance(ui, torch.Tensor) and ui.dim() > 0:
        return torch.matmul(field_value, ui)
    # torch.matmul rejects 0-dim operands; a scalar input simply scales its field.
    return field_value * ui
