import numpy as np

def hopper(state, action, next_state):
	"""Flag which Hopper transitions are terminal.

	A row is non-terminal only when all of the following hold for ``next_state``:
	every entry is finite; every entry from column 1 onward has magnitude
	below 100; column 0 (treated as height) exceeds 0.7; and column 1
	(treated as angle) has magnitude below 0.2.

	NOTE(review): these thresholds appear to mirror Gym's Hopper "healthy"
	check on observations — confirm against the environment version in use.

	Args:
		state: 2-D array; only its rank is checked.
		action: 2-D array; only its rank is checked.
		next_state: 2-D array of shape (N, obs_dim).

	Returns:
		Boolean array of shape (N, 1); True marks a terminal transition.
	"""
	assert len(state.shape) == len(next_state.shape) == len(action.shape) == 2

	all_finite = np.isfinite(next_state).all(axis = -1)
	within_bounds = (np.abs(next_state[:, 1:]) < 100).all(axis = -1)
	high_enough = next_state[:, 0] > 0.7
	small_angle = np.abs(next_state[:, 1]) < 0.2

	alive = all_finite & within_bounds & high_enough & small_angle
	# Add a trailing axis so callers receive an (N, 1) column.
	return np.logical_not(alive)[:, np.newaxis]

def halfcheetah(state, action, next_state):
	"""Flag which HalfCheetah transitions are terminal (never, here).

	Every row is reported as non-terminal; the inputs contribute only their
	row count to the output.

	Args:
		state: 2-D array of shape (N, obs_dim).
		action: 2-D array; only its rank is checked.
		next_state: 2-D array; only its rank is checked.

	Returns:
		Boolean array of shape (N, 1) filled with False.
	"""
	assert len(state.shape) == len(next_state.shape) == len(action.shape) == 2

	# `np.bool` was an alias of builtin `bool` that NumPy removed in 1.24;
	# using it raises AttributeError there. Builtin `bool` yields np.bool_.
	return np.zeros((len(state), 1), dtype = bool)

def walker2d(state, action, next_state):
	"""Flag which Walker2d transitions are terminal.

	A row is non-terminal only when column 0 of ``next_state`` (treated as
	height) lies strictly inside (0.8, 2.0) and column 1 (treated as angle)
	lies strictly inside (-1.0, 1.0).

	NOTE(review): this appears to mirror Gym's Walker2d termination check on
	observations — confirm against the environment version in use.

	Args:
		state: 2-D array; only its rank is checked.
		action: 2-D array; only its rank is checked.
		next_state: 2-D array of shape (N, obs_dim).

	Returns:
		Boolean array of shape (N, 1); True marks a terminal transition.
	"""
	assert len(state.shape) == len(next_state.shape) == len(action.shape) == 2

	obs_height = next_state[:, 0]
	obs_angle = next_state[:, 1]

	height_in_range = (0.8 < obs_height) & (obs_height < 2.0)
	angle_in_range = (-1.0 < obs_angle) & (obs_angle < 1.0)

	terminal = ~(height_in_range & angle_in_range)
	return np.expand_dims(terminal, axis = 1)


# Registry from environment key to its termination function; each function
# takes (state, action, next_state) and returns an (N, 1) boolean array.
environment = dict(
	hopper = hopper,
	halfcheetah = halfcheetah,
	walker2d = walker2d,
)

def termination_function(state, action, next_state, env_name):
	"""Compute termination flags using the function registered for ``env_name``.

	Args:
		state: 2-D array passed through unchanged.
		action: 2-D array passed through unchanged.
		next_state: 2-D array passed through unchanged.
		env_name: key into the module-level ``environment`` registry.

	Returns:
		Whatever the selected termination function returns.

	Raises:
		KeyError: if ``env_name`` is not a key of ``environment``.
	"""
	done_fn = environment[env_name]
	return done_fn(state, action, next_state)