import os

from dm_control import mujoco
from dm_control.rl import control
from dm_control.suite import common
from dm_control.suite import hopper
from dm_control.utils import rewards
from dm_control.utils import io as resources
import numpy as np

_TASKS_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'tasks')

_CONTROL_TIMESTEP = .02  # (Seconds)

# Default duration of an episode, in seconds.
_DEFAULT_TIME_LIMIT = 20

# Minimal height of torso over foot above which stand reward is 1.
_STAND_HEIGHT = 0.6

# Hopping speed above which hop reward is 1.
_HOP_SPEED = 2

# Angular momentum above which reward is 1.
_SPIN_SPEED = 5


def get_model_and_assets():
	"""Returns a tuple containing the model XML string and a dict of assets."""
	return resources.GetResource(os.path.join(_TASKS_DIR, 'hopper.xml')), common.ASSETS


@hopper.SUITE.add('custom')
def hop_backwards(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
	"""Returns the Hop Backwards task."""
	physics = Physics.from_xml_string(*get_model_and_assets())
	task = CustomHopper(goal='hop-backwards', random=random)
	environment_kwargs = environment_kwargs or {}
	return control.Environment(
		physics, task, time_limit=time_limit, control_timestep=_CONTROL_TIMESTEP,
		**environment_kwargs)


@hopper.SUITE.add('custom')
def flip(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
	"""Returns the Flip task."""
	physics = Physics.from_xml_string(*get_model_and_assets())
	task = CustomHopper(goal='flip', random=random)
	environment_kwargs = environment_kwargs or {}
	return control.Environment(
		physics, task, time_limit=time_limit, control_timestep=_CONTROL_TIMESTEP,
		**environment_kwargs)


@hopper.SUITE.add('custom')
def flip_backwards(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
	"""Returns the Flip Backwards task."""
	physics = Physics.from_xml_string(*get_model_and_assets())
	task = CustomHopper(goal='flip-backwards', random=random)
	environment_kwargs = environment_kwargs or {}
	return control.Environment(
		physics, task, time_limit=time_limit, control_timestep=_CONTROL_TIMESTEP,
		**environment_kwargs)


class Physics(hopper.Physics):

	def angmomentum(self):
		"""Returns the angular momentum of torso of the Cheetah about Y axis."""
		return self.named.data.subtree_angmom['torso'][1]


class CustomHopper(hopper.Hopper):
	"""Custom Hopper tasks."""

	def __init__(self, goal='hop-backwards', random=None):
		super().__init__(None, random)
		self._goal = goal
	
	def _hop_backwards_reward(self, physics):
		standing = rewards.tolerance(physics.height(), (_STAND_HEIGHT, 2))
		hopping = rewards.tolerance(physics.speed(),
									bounds=(-float('inf'), -_HOP_SPEED/2),
									margin=_HOP_SPEED/4,
									value_at_margin=0.5,
									sigmoid='linear')
		return standing * hopping
	
	def _flip_reward(self, physics, forward=True):
		reward = rewards.tolerance((1. if forward else -1.) * physics.angmomentum(),
								   bounds=(_SPIN_SPEED, float('inf')),
								   margin=_SPIN_SPEED/2,
								   value_at_margin=0,
								   sigmoid='linear')
		return reward


	def get_reward(self, physics):
		if self._goal == 'hop-backwards':
			return self._hop_backwards_reward(physics)
		elif self._goal == 'flip':
			return self._flip_reward(physics, forward=True)
		elif self._goal == 'flip-backwards':
			return self._flip_reward(physics, forward=False)
		else:
			raise NotImplementedError(f'Goal {self._goal} is not implemented.')


if __name__ == '__main__':
	env = hop_backwards()
	obs = env.reset()
	import numpy as np
	next_obs, reward, done, info = env.step(np.zeros(2))
	print(reward)
