storm_kit.mpc.control.control_base module

class Controller(d_action, action_lows, action_highs, horizon, gamma, n_iters, rollout_fn=None, sample_mode='mean', hotstart=True, seed=0, tensor_args={'device': device(type='cpu'), 'dtype': torch.float32})[source]

Bases: abc.ABC

Base class for sampling-based controllers.

Defines an abstract base class for sampling-based MPC algorithms.

Implements the optimize method, which is called to generate an action sequence for a given state and is common across sampling-based controllers.

Attributes:

d_action : int

    size of action space

action_lows : torch.Tensor

    lower limits for each action dim

action_highs : torch.Tensor

    upper limits for each action dim

horizon : int

    horizon of rollouts

gamma : float

    discount factor

n_iters : int

    number of optimization iterations per MPC call

rollout_fn : function handle

    rollout policy (or actions) in simulator and return states and costs for updating MPC distribution

sample_mode : {‘mean’, ‘sample’}

    how to choose the action to be executed: ‘mean’ plays the first mean action and ‘sample’ samples from the distribution

hotstart : bool

    If true, the solution from the previous step is used to warm-start the current step

seed : int

    seed value

device : torch.device

    controller can run on both cpu and gpu

float_dtype : torch.dtype

    floating point precision for calculations
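
The snippet below is a minimal sketch of a concrete subclass, written against only the constructor signature, documented attributes, and abstract methods listed on this page. The class name RandomShootingController, the num_particles argument, the internal name self._mean_action, the assumption that tensor_args is stored as self.tensor_args, and the call signature used for rollout_fn are all hypothetical and not part of storm_kit; the rollout output is assumed to follow the trajectories fields documented under _update_distribution below:

    import torch
    from storm_kit.mpc.control.control_base import Controller

    class RandomShootingController(Controller):
        # Toy sampling-based controller: keeps a mean action sequence and
        # perturbs it with Gaussian noise (illustration only, not storm_kit API).
        def __init__(self, *args, num_particles=32, **kwargs):
            super().__init__(*args, **kwargs)
            self.num_particles = num_particles
            # assumed shape: [horizon, d_action]
            self._mean_action = torch.zeros(self.horizon, self.d_action,
                                            **self.tensor_args)

        def generate_rollouts(self, state):
            # Sample action sequences around the current mean and roll them out
            # with the user-provided rollout_fn (assumed to return the
            # trajectories dict described under _update_distribution).
            act_seq = self._mean_action.unsqueeze(0) + 0.1 * torch.randn(
                self.num_particles, self.horizon, self.d_action,
                **self.tensor_args)
            return self.rollout_fn(state, act_seq)

        def _update_distribution(self, trajectories):
            # Keep the lowest-cost sampled sequence as the new mean.
            total_cost = trajectories['costs'].sum(dim=-1)
            best = torch.argmin(total_cost)
            self._mean_action = trajectories['actions'][best]

        def _get_action_seq(self, mode='mean'):
            return self._mean_action

        def _shift(self, shift_steps=1):
            # Warm start: move the mean sequence forward in time.
            self._mean_action = torch.roll(self._mean_action, -shift_steps, dims=0)

        def _calc_val(self, cost_seq, act_seq):
            # One possible value estimate: best total cost over the rollouts.
            return cost_seq.sum(dim=-1).min().item()

        def reset_distribution(self):
            self._mean_action.zero_()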

abstract _calc_val(cost_seq, act_seq)[source]

Calculate value of state given rollouts from a policy
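
For intuition, one common choice (an assumption for illustration, not the definition mandated by the base class) is a gamma-discounted cost over the sampled rollouts; cost_seq is assumed to have shape [num_particles, horizon], with gamma, horizon, and tensor_args as documented above:

    gamma_seq = gamma ** torch.arange(horizon, **tensor_args)
    disc_cost = (cost_seq * gamma_seq).sum(dim=-1)   # total discounted cost per rollout
    value = disc_cost.mean().item()                  # or .min(), depending on the estimator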

abstract _get_action_seq(mode='mean')[source]

Get action sequence to execute on the system based on current control distribution

Parameters

mode ({‘mean’, ‘sample’}) – how to choose the action to be executed: ‘mean’ plays the mean action and ‘sample’ samples from the distribution

abstract _shift()[source]

Shift the current control distribution to hotstart the next timestep

abstract _update_distribution(trajectories)[source]

Update current control distribution using rollout trajectories

Parameters

trajectories (dict) – Rollout trajectories. Contains the following fields:

    observations : torch.Tensor

        observations along rollouts

    actions : torch.Tensor

        actions sampled from control distribution along rollouts

    costs : torch.Tensor

        step costs along rollouts
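
For reference, a trajectories dict matching the fields above might be laid out as follows; the exact key names are taken from this page, but the particle/horizon shapes and sizes are assumptions for illustration, not guaranteed by the API:

    num_particles, obs_dim, horizon, d_action = 32, 14, 30, 7   # hypothetical sizes
    trajectories = {
        'observations': torch.zeros(num_particles, horizon, obs_dim),
        'actions':      torch.zeros(num_particles, horizon, d_action),
        'costs':        torch.zeros(num_particles, horizon),
    }
    controller._update_distribution(trajectories)   # controller: any concrete subclass instance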

check_convergence()[source]

Checks if the controller has converged. Returns False by default.

abstract generate_rollouts(state)[source]

get_optimal_value(state)[source]

Calculate the optimal value of a state, i.e., the value under the optimal policy.

Parameters

state (torch.Tensor) – state to calculate optimal value estimate for

Returns

value – optimal value estimate of the state

Return type

float
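
A short hedged usage example; controller names a concrete subclass instance and state_dim is a hypothetical size, neither is defined by storm_kit:

    state = torch.zeros(state_dim, **tensor_args)   # hypothetical state vector
    value = controller.get_optimal_value(state)     # float: optimal value estimate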

optimize(state, calc_val=False, shift_steps=1, n_iters=None)[source]

Optimize for best action at current state

Parameters
  • state (torch.Tensor) – state to calculate optimal action from

  • calc_val (bool) – If true, calculate the optimal value estimate of the state along with action

Returns

  • action (torch.Tensor) – next action to execute

  • value (float) – optimal value estimate (default: 0.)

  • info (dict) – dictionary with side-information
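
A sketch of a receding-horizon loop built on optimize(), assuming it returns the (action, value, info) triple listed above. The controller is the toy subclass sketched earlier on this page, and low, high, my_rollout_fn, initial_state, my_dynamics_step, and tensor_args are placeholders, not part of storm_kit:

    controller = RandomShootingController(      # toy subclass from the sketch above
        d_action=7, action_lows=low, action_highs=high,
        horizon=30, gamma=0.99, n_iters=1,
        rollout_fn=my_rollout_fn, tensor_args=tensor_args)

    state = torch.as_tensor(initial_state, **tensor_args)
    for t in range(200):
        action, value, info = controller.optimize(state, shift_steps=1)
        state = my_dynamics_step(state, action)  # placeholder environment/dynamics step
        if controller.check_convergence():
            break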

reset()[source]

Reset the controller

abstract reset_distribution()[source]

property rollout_fn

sample_actions()[source]

Sample actions from current control distribution