storm_kit.mpc.control.control_base module

class Controller(d_action, action_lows, action_highs, horizon, gamma, n_iters, rollout_fn=None, sample_mode='mean', hotstart=True, seed=0, tensor_args={'device': device(type='cpu'), 'dtype': torch.float32})

Bases: abc.ABC

Base class for sampling-based controllers.

Defines an abstract base class for sampling-based MPC algorithms. Implements the optimize method, which is called to generate an action sequence for a given state and is common across sampling-based controllers.
Attributes:
- d_action : int
  size of the action space
- action_lows : torch.Tensor
  lower limits for each action dimension
- action_highs : torch.Tensor
  upper limits for each action dimension
- horizon : int
  horizon of rollouts
- gamma : float
  discount factor
- n_iters : int
  number of optimization iterations per MPC call
- rollout_fn : function handle
  rolls out the policy (or actions) in the simulator and returns states and costs for updating the MPC distribution
- sample_mode : {'mean', 'sample'}
  how to choose the action to be executed: 'mean' plays the first mean action and 'sample' samples from the distribution
- hotstart : bool
  if True, the solution from the previous step is used to warm-start the current step
- seed : int
  seed value
- device : torch.device
  the controller can run on both CPU and GPU
- float_dtype : torch.dtype
  floating-point precision for calculations
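A concrete controller is expected to implement the abstract methods listed further below (_update_distribution, _get_action_seq, _calc_val). The following skeleton is a minimal sketch based only on the signatures documented on this page; the subclass name, the "keep the best sampled sequence" update rule, the tensor shapes, and the best_act_seq attribute are illustrative assumptions and not part of storm_kit:

    import torch
    from storm_kit.mpc.control.control_base import Controller

    class RandomShootingController(Controller):
        """Illustrative sketch of a concrete subclass; not part of storm_kit."""

        def _update_distribution(self, trajectories):
            # trajectories is the dict described under _update_distribution below.
            # Assumed shapes: costs [num_rollouts, horizon],
            # actions [num_rollouts, horizon, d_action].
            costs = trajectories['costs']
            actions = trajectories['actions']
            best = torch.argmin(costs.sum(dim=-1))
            self.best_act_seq = actions[best]  # assumed attribute, illustration only

        def _get_action_seq(self, mode='mean'):
            # 'mean' returns the current best sequence; a real controller would
            # sample from its control distribution when mode == 'sample'.
            return self.best_act_seq

        def _calc_val(self, cost_seq, act_seq):
            # Discounted sum of step costs over the horizon (sketch);
            # self.gamma, self.horizon and self.tensor_args are assumed to be
            # stored by the base class constructor.
            gamma_seq = self.gamma ** torch.arange(self.horizon, **self.tensor_args)
            return torch.sum(cost_seq * gamma_seq, dim=-1).mean().item()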
abstract _calc_val(cost_seq, act_seq)

Calculate the value of a state given rollouts from a policy.
abstract _get_action_seq(mode='mean')

Get the action sequence to execute on the system based on the current control distribution.

Parameters:
- mode – {'mean', 'sample'} how to choose the action to be executed: 'mean' plays the mean action and 'sample' samples from the distribution
abstract _update_distribution(trajectories)

Update the current control distribution using rollout trajectories.

Parameters:
- trajectories (dict) – rollout trajectories, containing the following fields:
  - observations : torch.Tensor
    observations along rollouts
  - actions : torch.Tensor
    actions sampled from the control distribution along rollouts
  - costs : torch.Tensor
    step costs along rollouts
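For concreteness, the trajectories argument can be pictured as a dict of stacked tensors along the lines below; the tensor shapes and sizes are assumptions chosen for illustration only:

    import torch

    num_rollouts, horizon, d_obs, d_action = 64, 30, 14, 7   # illustrative sizes
    trajectories = {
        'observations': torch.zeros(num_rollouts, horizon, d_obs),
        'actions':      torch.zeros(num_rollouts, horizon, d_action),
        'costs':        torch.zeros(num_rollouts, horizon),
    }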
get_optimal_value(state)

Calculate the optimal value of a state, i.e. the value under the optimal policy.

Parameters:
- state (torch.Tensor) – state to calculate the optimal value estimate for

Returns:
- value – optimal value estimate of the state

Return type:
- float
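Usage sketch, assuming controller is an instance of a concrete Controller subclass (such as the skeleton above) and an illustrative state vector:

    import torch

    state = torch.zeros(14)                       # illustrative state vector
    value = controller.get_optimal_value(state)   # float, per the return type above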
optimize(state, calc_val=False, shift_steps=1, n_iters=None)

Optimize for the best action at the current state.

Parameters:
- state (torch.Tensor) – state to calculate the optimal action from
- calc_val (bool) – if True, calculate the optimal value estimate of the state along with the action

Returns:
- action (torch.Tensor) – next action to execute
- value (float) – optimal value estimate (default: 0.)
- info (dict) – dictionary with side information
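A typical receding-horizon loop calls optimize once per control step. The loop below is a sketch: it assumes controller is a concrete subclass instance, that the three documented return values come back as a tuple in the order listed above, and the environment step is only a placeholder comment:

    import torch

    state = torch.zeros(14)                        # illustrative initial state
    for step in range(100):
        action, value, info = controller.optimize(state, calc_val=False, shift_steps=1)
        # apply `action` to the robot / simulator here, read back the next state,
        # and overwrite `state` with it before the next iteration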
property rollout_fn
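Based on the rollout_fn description in the attribute list above, a rollout function receives sampled action sequences, simulates them, and returns the states and costs consumed by _update_distribution. The callable below is a hypothetical stand-in illustrating that contract; it is not the library's rollout API:

    import torch

    def dummy_rollout_fn(act_seq):
        """Hypothetical rollout; act_seq assumed to be [num_rollouts, horizon, d_action]."""
        num_rollouts, horizon, _ = act_seq.shape
        observations = torch.zeros(num_rollouts, horizon, 14)   # illustrative state dim
        costs = (act_seq ** 2).sum(dim=-1)                       # toy quadratic action cost
        return {'observations': observations, 'actions': act_seq, 'costs': costs}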