#!/usr/bin/env python
#
# MIT License
#
# Copyright (c) 2020-2021 NVIDIA CORPORATION.
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
from abc import ABC, abstractmethod
import copy
import numpy as np
import torch
import torch.autograd.profiler as profiler
class Controller(ABC):
"""Base class for sampling based controllers."""
def __init__(self,
d_action,
action_lows,
action_highs,
horizon,
gamma,
n_iters,
rollout_fn=None,
sample_mode='mean',
hotstart=True,
seed=0,
tensor_args={'device':torch.device('cpu'), 'dtype':torch.float32}):
"""
        Defines an abstract base class for
        sampling-based MPC algorithms.
        Implements the optimize method, which is called to
        generate an action sequence for a given state and
        is common across sampling-based controllers.
Attributes:
d_action : int
size of action space
action_lows : torch.Tensor
lower limits for each action dim
action_highs : torch.Tensor
upper limits for each action dim
horizon : int
horizon of rollouts
gamma : float
discount factor
n_iters : int
number of optimization iterations per MPC call
rollout_fn : function handle
rollout policy (or actions) in simulator
and return states and costs for updating MPC
distribution
sample_mode : {'mean', 'sample'}
how to choose action to be executed
'mean' plays the first mean action and
'sample' samples from the distribution
hotstart : bool
If true, the solution from previous step
is used to warm start current step
seed : int
seed value
        tensor_args : dict
            dictionary with keys 'device' (torch.device) and
            'dtype' (torch.dtype) specifying whether the controller
            runs on cpu or gpu and the floating point precision
            used for internal tensors
"""
self.tensor_args = tensor_args
self.d_action = d_action
self.action_lows = action_lows.to(**self.tensor_args)
self.action_highs = action_highs.to(**self.tensor_args)
self.horizon = horizon
self.gamma = gamma
self.n_iters = n_iters
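        # precompute the discount sequence [1, gamma, gamma^2, ..., gamma^(horizon-1)]
        # used to weight per-step costs along each rollout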
self.gamma_seq = torch.cumprod(torch.tensor([1.0] + [self.gamma] * (horizon - 1)),dim=0).reshape(1, horizon)
self.gamma_seq = self.gamma_seq.to(**self.tensor_args)
self._rollout_fn = rollout_fn
self.sample_mode = sample_mode
self.num_steps = 0
self.hotstart = hotstart
self.seed_val = seed
self.trajectories = None
    @abstractmethod
def _get_action_seq(self, mode='mean'):
"""
Get action sequence to execute on the system based
on current control distribution
Args:
mode : {'mean', 'sample'}
how to choose action to be executed
'mean' plays mean action and
'sample' samples from the distribution
"""
pass
    def sample_actions(self):
"""
Sample actions from current control distribution
"""
        raise NotImplementedError('sample_actions function not implemented')
    @abstractmethod
def _update_distribution(self, trajectories):
"""
Update current control distribution using
rollout trajectories
Args:
trajectories : dict
Rollout trajectories. Contains the following fields
observations : torch.tensor
observations along rollouts
actions : torch.tensor
actions sampled from control distribution along rollouts
costs : torch.tensor
step costs along rollouts
"""
pass
    @abstractmethod
    def _shift(self, shift_steps=1):
"""
Shift the current control distribution
to hotstart the next timestep
"""
pass
    @abstractmethod
def reset_distribution(self):
pass
    def reset(self):
"""
Reset the controller
"""
self.num_steps = 0
self.reset_distribution()
    @abstractmethod
    def _calc_val(self, trajectories):
"""
Calculate value of state given
rollouts from a policy
"""
pass
    def check_convergence(self):
"""
        Checks whether the controller has converged.
        Returns False by default.
"""
return False
# @property
# def set_sim_state_fn(self):
# return self._set_sim_state_fn
# @set_sim_state_fn.setter
# def set_sim_state_fn(self, fn):
# """
# Set function that sets the simulation
# environment to a particular state
# """
# self._set_sim_state_fn = fn
@property
def rollout_fn(self):
return self._rollout_fn
@rollout_fn.setter
def rollout_fn(self, fn):
"""
Set the rollout function from
input function pointer
"""
self._rollout_fn = fn
    @abstractmethod
def generate_rollouts(self, state):
pass
    def optimize(self, state, calc_val=False, shift_steps=1, n_iters=None):
"""
Optimize for best action at current state
Parameters
----------
state : torch.Tensor
state to calculate optimal action from
        calc_val : bool
            If True, calculate the optimal value estimate
            of the state along with the action
        shift_steps : int
            number of timesteps to shift the control distribution
            by when hotstarting the next call
        n_iters : int
            number of optimization iterations to run; if None,
            self.n_iters is used
Returns
-------
action : torch.Tensor
next action to execute
value: float
optimal value estimate (default: 0.)
info: dict
dictionary with side-information
"""
n_iters = n_iters if n_iters is not None else self.n_iters
# get input device:
inp_device = state.device
inp_dtype = state.dtype
        state = state.to(**self.tensor_args)
info = dict(rollout_time=0.0, entropy=[])
# shift distribution to hotstart from previous timestep
if self.hotstart:
self._shift(shift_steps)
else:
self.reset_distribution()
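        # the optimization loop only needs forward rollouts, so gradient
        # tracking is disabled; autocast enables mixed precision on CUDA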
with torch.cuda.amp.autocast(enabled=True):
with torch.no_grad():
for _ in range(n_iters):
# generate random simulated trajectories
trajectory = self.generate_rollouts(state)
# update distribution parameters
with profiler.record_function("mppi_update"):
self._update_distribution(trajectory)
info['rollout_time'] += trajectory['rollout_time']
# check if converged
if self.check_convergence():
break
self.trajectories = trajectory
#calculate best action
# curr_action = self._get_next_action(state, mode=self.sample_mode)
curr_action_seq = self._get_action_seq(mode=self.sample_mode)
#calculate optimal value estimate if required
value = 0.0
if calc_val:
trajectories = self.generate_rollouts(state)
value = self._calc_val(trajectories)
# # shift distribution to hotstart next timestep
# if self.hotstart:
# self._shift()
# else:
# self.reset_distribution()
info['entropy'].append(self.entropy)
self.num_steps += 1
return curr_action_seq.to(inp_device, dtype=inp_dtype), value, info
    def get_optimal_value(self, state):
"""
        Calculate the optimal value of a state, i.e.,
        the value under the optimal policy.
Parameters
----------
state : torch.Tensor
state to calculate optimal value estimate for
Returns
-------
value : float
optimal value estimate of the state
"""
self.reset() #reset the control distribution
        _, value, _ = self.optimize(state, calc_val=True, shift_steps=0)
return value
# def seed(self, seed=None):
# self.np_random, seed = seeding.np_random(seed)
# return seed
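

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the library). The _DummyController
# below is a hypothetical stand-in that fills the abstract methods with
# trivial logic so that Controller.optimize() can be exercised end to end;
# real controllers (e.g. an MPPI implementation) maintain a sampling
# distribution and update it from rollout costs.
# ---------------------------------------------------------------------------
if __name__ == '__main__':

    class _DummyController(Controller):
        """Hypothetical controller that keeps a fixed zero-mean action plan."""

        def __init__(self, **kwargs):
            super().__init__(**kwargs)
            # fixed mean action sequence of shape [horizon, d_action]
            self.mean_action = torch.zeros(self.horizon, self.d_action,
                                           **self.tensor_args)
            self.entropy = 0.0  # placeholder; real controllers compute this

        def _get_action_seq(self, mode='mean'):
            return self.mean_action.clone()

        def _update_distribution(self, trajectories):
            # a real controller would reweight / refit its distribution here
            pass

        def _shift(self, shift_steps=1):
            # roll the plan forward to warm start the next timestep
            self.mean_action = torch.roll(self.mean_action, -shift_steps, dims=0)

        def reset_distribution(self):
            self.mean_action.zero_()

        def _calc_val(self, trajectories):
            return 0.0

        def generate_rollouts(self, state):
            # dummy rollouts: zero costs and zero rollout time
            return {'actions': self.mean_action.unsqueeze(0),
                    'costs': torch.zeros(1, self.horizon, **self.tensor_args),
                    'rollout_time': 0.0}

    controller = _DummyController(
        d_action=2,
        action_lows=torch.tensor([-1.0, -1.0]),
        action_highs=torch.tensor([1.0, 1.0]),
        horizon=10,
        gamma=0.99,
        n_iters=1)

    action_seq, value, info = controller.optimize(torch.zeros(4))
    print(action_seq.shape, value, info['rollout_time'])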