import os

from dm_control.rl import control
from dm_control.suite import common
from dm_control.suite import cheetah
from dm_control.utils import rewards
from dm_control.utils import io as resources

# Directory that holds the model XML loaded by `get_model_and_assets`: the
# `tasks` folder inside the parent of this file's directory.
_TASKS_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'tasks')

# Lower bound of the height target in the stand-*/jump/run-*/legs-up rewards.
_CHEETAH_JUMP_HEIGHT = 1.2
# Upper bound of the height target in the lie-down and legs-up rewards.
_CHEETAH_LIE_HEIGHT = 0.25
# Lower bound of the signed angular-momentum target in the flip rewards.
_CHEETAH_SPIN_SPEED = 8


def get_model_and_assets():
    """Loads the cheetah model definition together with the shared assets.

    Returns:
        A `(model, assets)` tuple: whatever `resources.GetResource` returns
        for `cheetah.xml` inside `_TASKS_DIR`, and `common.ASSETS`.
    """
    model_path = os.path.join(_TASKS_DIR, 'cheetah.xml')
    model_xml = resources.GetResource(model_path)
    return model_xml, common.ASSETS


@cheetah.SUITE.add('custom')
def run_backwards(time_limit=cheetah._DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
    """Builds the Run Backwards environment.

    Args:
        time_limit: Forwarded to `control.Environment` as `time_limit`.
        random: Forwarded to `CustomCheetah` as `random`.
        environment_kwargs: Optional dict of extra keyword arguments for
            `control.Environment`; any falsy value is treated as no extras.

    Returns:
        A `control.Environment` pairing the cheetah model with a
        `CustomCheetah` task for the 'run-backwards' goal, using
        `cheetah._RUN_SPEED*0.8` as the move speed.
    """
    model_xml, assets = get_model_and_assets()
    sim = cheetah.Physics.from_xml_string(model_xml, assets)
    goal_task = CustomCheetah(
        goal='run-backwards',
        move_speed=cheetah._RUN_SPEED*0.8,
        random=random,
    )
    extra_kwargs = environment_kwargs if environment_kwargs else {}
    return control.Environment(
        sim, goal_task, time_limit=time_limit, **extra_kwargs)


@cheetah.SUITE.add('custom')
def stand_front(time_limit=cheetah._DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
    """Builds the Stand Front environment.

    Args:
        time_limit: Forwarded to `control.Environment` as `time_limit`.
        random: Forwarded to `CustomCheetah` as `random`.
        environment_kwargs: Optional dict of extra keyword arguments for
            `control.Environment`; any falsy value is treated as no extras.

    Returns:
        A `control.Environment` pairing the cheetah model with a
        `CustomCheetah` task for the 'stand-front' goal, using 0.5 as the
        move speed.
    """
    model_xml, assets = get_model_and_assets()
    sim = cheetah.Physics.from_xml_string(model_xml, assets)
    goal_task = CustomCheetah(
        goal='stand-front',
        move_speed=0.5,
        random=random,
    )
    extra_kwargs = environment_kwargs if environment_kwargs else {}
    return control.Environment(
        sim, goal_task, time_limit=time_limit, **extra_kwargs)


@cheetah.SUITE.add('custom')
def stand_back(time_limit=cheetah._DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
    """Builds the Stand Back environment.

    Args:
        time_limit: Forwarded to `control.Environment` as `time_limit`.
        random: Forwarded to `CustomCheetah` as `random`.
        environment_kwargs: Optional dict of extra keyword arguments for
            `control.Environment`; any falsy value is treated as no extras.

    Returns:
        A `control.Environment` pairing the cheetah model with a
        `CustomCheetah` task for the 'stand-back' goal, using 0.5 as the
        move speed.
    """
    model_xml, assets = get_model_and_assets()
    sim = cheetah.Physics.from_xml_string(model_xml, assets)
    goal_task = CustomCheetah(
        goal='stand-back',
        move_speed=0.5,
        random=random,
    )
    extra_kwargs = environment_kwargs if environment_kwargs else {}
    return control.Environment(
        sim, goal_task, time_limit=time_limit, **extra_kwargs)


@cheetah.SUITE.add('custom')
def jump(time_limit=cheetah._DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
    """Builds the Jump environment.

    Args:
        time_limit: Forwarded to `control.Environment` as `time_limit`.
        random: Forwarded to `CustomCheetah` as `random`.
        environment_kwargs: Optional dict of extra keyword arguments for
            `control.Environment`; any falsy value is treated as no extras.

    Returns:
        A `control.Environment` pairing the cheetah model with a
        `CustomCheetah` task for the 'jump' goal, using 0.5 as the move
        speed.
    """
    model_xml, assets = get_model_and_assets()
    sim = cheetah.Physics.from_xml_string(model_xml, assets)
    goal_task = CustomCheetah(
        goal='jump',
        move_speed=0.5,
        random=random,
    )
    extra_kwargs = environment_kwargs if environment_kwargs else {}
    return control.Environment(
        sim, goal_task, time_limit=time_limit, **extra_kwargs)


@cheetah.SUITE.add('custom')
def run_front(time_limit=cheetah._DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
    """Builds the Run Front environment.

    Args:
        time_limit: Forwarded to `control.Environment` as `time_limit`.
        random: Forwarded to `CustomCheetah` as `random`.
        environment_kwargs: Optional dict of extra keyword arguments for
            `control.Environment`; any falsy value is treated as no extras.

    Returns:
        A `control.Environment` pairing the cheetah model with a
        `CustomCheetah` task for the 'run-front' goal, using
        `cheetah._RUN_SPEED*0.6` as the move speed.
    """
    model_xml, assets = get_model_and_assets()
    sim = cheetah.Physics.from_xml_string(model_xml, assets)
    goal_task = CustomCheetah(
        goal='run-front',
        move_speed=cheetah._RUN_SPEED*0.6,
        random=random,
    )
    extra_kwargs = environment_kwargs if environment_kwargs else {}
    return control.Environment(
        sim, goal_task, time_limit=time_limit, **extra_kwargs)


@cheetah.SUITE.add('custom')
def run_back(time_limit=cheetah._DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
    """Builds the Run Back environment.

    Args:
        time_limit: Forwarded to `control.Environment` as `time_limit`.
        random: Forwarded to `CustomCheetah` as `random`.
        environment_kwargs: Optional dict of extra keyword arguments for
            `control.Environment`; any falsy value is treated as no extras.

    Returns:
        A `control.Environment` pairing the cheetah model with a
        `CustomCheetah` task for the 'run-back' goal, using
        `cheetah._RUN_SPEED*0.6` as the move speed.
    """
    model_xml, assets = get_model_and_assets()
    sim = cheetah.Physics.from_xml_string(model_xml, assets)
    goal_task = CustomCheetah(
        goal='run-back',
        move_speed=cheetah._RUN_SPEED*0.6,
        random=random,
    )
    extra_kwargs = environment_kwargs if environment_kwargs else {}
    return control.Environment(
        sim, goal_task, time_limit=time_limit, **extra_kwargs)


@cheetah.SUITE.add('custom')
def lie_down(time_limit=cheetah._DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
    """Builds the Lie Down environment.

    Args:
        time_limit: Forwarded to `control.Environment` as `time_limit`.
        random: Forwarded to `CustomCheetah` as `random`.
        environment_kwargs: Optional dict of extra keyword arguments for
            `control.Environment`; any falsy value is treated as no extras.

    Returns:
        A `control.Environment` pairing the cheetah model with a
        `CustomCheetah` task for the 'lie-down' goal; the task keeps its
        default move speed.
    """
    model_xml, assets = get_model_and_assets()
    sim = cheetah.Physics.from_xml_string(model_xml, assets)
    goal_task = CustomCheetah(goal='lie-down', random=random)
    extra_kwargs = environment_kwargs if environment_kwargs else {}
    return control.Environment(
        sim, goal_task, time_limit=time_limit, **extra_kwargs)


@cheetah.SUITE.add('custom')
def legs_up(time_limit=cheetah._DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
    """Builds the Legs Up environment.

    Args:
        time_limit: Forwarded to `control.Environment` as `time_limit`.
        random: Forwarded to `CustomCheetah` as `random`.
        environment_kwargs: Optional dict of extra keyword arguments for
            `control.Environment`; any falsy value is treated as no extras.

    Returns:
        A `control.Environment` pairing the cheetah model with a
        `CustomCheetah` task for the 'legs-up' goal; the task keeps its
        default move speed.
    """
    model_xml, assets = get_model_and_assets()
    sim = cheetah.Physics.from_xml_string(model_xml, assets)
    goal_task = CustomCheetah(goal='legs-up', random=random)
    extra_kwargs = environment_kwargs if environment_kwargs else {}
    return control.Environment(
        sim, goal_task, time_limit=time_limit, **extra_kwargs)


@cheetah.SUITE.add('custom')
def flip(time_limit=cheetah._DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
    """Builds the Flip environment.

    Uses this module's `Physics` subclass (rather than `cheetah.Physics`)
    because the flip reward calls `physics.angmomentum()`.

    Args:
        time_limit: Forwarded to `control.Environment` as `time_limit`.
        random: Forwarded to `CustomCheetah` as `random`.
        environment_kwargs: Optional dict of extra keyword arguments for
            `control.Environment`; any falsy value is treated as no extras.

    Returns:
        A `control.Environment` pairing the cheetah model with a
        `CustomCheetah` task for the 'flip' goal, using `cheetah._RUN_SPEED`
        as the move speed.
    """
    model_xml, assets = get_model_and_assets()
    sim = Physics.from_xml_string(model_xml, assets)
    goal_task = CustomCheetah(
        goal='flip',
        move_speed=cheetah._RUN_SPEED,
        random=random,
    )
    extra_kwargs = environment_kwargs if environment_kwargs else {}
    return control.Environment(
        sim, goal_task, time_limit=time_limit, **extra_kwargs)


@cheetah.SUITE.add('custom')
def flip_backwards(time_limit=cheetah._DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
    """Builds the Flip Backwards environment.

    Uses this module's `Physics` subclass (rather than `cheetah.Physics`)
    because the flip reward calls `physics.angmomentum()`.

    Args:
        time_limit: Forwarded to `control.Environment` as `time_limit`.
        random: Forwarded to `CustomCheetah` as `random`.
        environment_kwargs: Optional dict of extra keyword arguments for
            `control.Environment`; any falsy value is treated as no extras.

    Returns:
        A `control.Environment` pairing the cheetah model with a
        `CustomCheetah` task for the 'flip-backwards' goal, using
        `cheetah._RUN_SPEED*0.8` as the move speed.
    """
    model_xml, assets = get_model_and_assets()
    sim = Physics.from_xml_string(model_xml, assets)
    goal_task = CustomCheetah(
        goal='flip-backwards',
        move_speed=cheetah._RUN_SPEED*0.8,
        random=random,
    )
    extra_kwargs = environment_kwargs if environment_kwargs else {}
    return control.Environment(
        sim, goal_task, time_limit=time_limit, **extra_kwargs)


class Physics(cheetah.Physics):
    """Cheetah physics extended with an angular-momentum readout."""

    def angmomentum(self):
        """Returns the Y component of `subtree_angmom` for the 'torso' body.

        NOTE(review): MuJoCo documents `subtree_angmom` as the angular
        momentum of the subtree rooted at a body, so this presumably covers
        the torso and its descendants rather than the torso alone -- confirm.
        """
        torso_angmom = self.named.data.subtree_angmom['torso']
        # Index 1 selects the second (Y) component of the 3-vector.
        return torso_angmom[1]


class CustomCheetah(cheetah.Cheetah):
    """Cheetah task whose reward is selected by a goal name.

    Supported goals: 'run-backwards', 'stand-front', 'stand-back', 'jump',
    'run-front', 'run-back', 'lie-down', 'legs-up', 'flip' and
    'flip-backwards'. Every reward term is built with `rewards.tolerance`.
    """

    # Ordered (goal, reward method name, extra keyword arguments) entries
    # scanned by `get_reward`. A tuple compared with `==` (not a dict) so
    # that any goal value, hashable or not, is handled the same way.
    _GOAL_REWARDS = (
        ('run-backwards', '_run_backwards_reward', {}),
        ('stand-front', '_stand_front_reward', {}),
        ('stand-back', '_stand_back_reward', {}),
        ('jump', '_jump_reward', {}),
        ('run-front', '_run_front_reward', {}),
        ('run-back', '_run_back_reward', {}),
        ('lie-down', '_lie_down_reward', {}),
        ('legs-up', '_legs_up_reward', {}),
        ('flip', '_flip_reward', {'forward': True}),
        ('flip-backwards', '_flip_reward', {'forward': False}),
    )

    def __init__(self, goal='run-backwards', move_speed=0, random=None):
        """Initializes the task.

        Args:
            goal: Name of the goal that selects the reward in `get_reward`.
            move_speed: Speed threshold used by the speed-based reward terms.
            random: Forwarded to the parent task constructor.
        """
        super().__init__(random)
        self._move_speed = move_speed
        self._goal = goal

    @staticmethod
    def _raised_reward(height):
        """Scores `height` against the `_CHEETAH_JUMP_HEIGHT` lower bound."""
        return rewards.tolerance(
            height,
            bounds=(_CHEETAH_JUMP_HEIGHT, float('inf')),
            margin=_CHEETAH_JUMP_HEIGHT/2)

    @staticmethod
    def _lowered_reward(height):
        """Scores `height` against the `_CHEETAH_LIE_HEIGHT` upper bound.

        Uses a linear falloff that reaches 0 at a distance of
        `_CHEETAH_LIE_HEIGHT` above the bound.
        """
        return rewards.tolerance(
            height,
            bounds=(-float('inf'), _CHEETAH_LIE_HEIGHT),
            margin=_CHEETAH_LIE_HEIGHT,
            value_at_margin=0,
            sigmoid='linear')

    def _min_speed_reward(self, speed):
        """Scores `speed` against the `self._move_speed` lower bound.

        Uses a linear falloff that reaches 0 at a distance of
        `self._move_speed` below the bound.
        """
        return rewards.tolerance(
            speed,
            bounds=(self._move_speed, float('inf')),
            margin=self._move_speed,
            value_at_margin=0,
            sigmoid='linear')

    def _run_backwards_reward(self, physics):
        """Rewards `physics.speed()` values at or below `-self._move_speed`."""
        target = self._move_speed
        return rewards.tolerance(
            physics.speed(),
            bounds=(-float('inf'), -target),
            margin=target,
            value_at_margin=0,
            sigmoid='linear')

    def _stand_one_foot_reward(self, physics, foot):
        """Rewards a raised torso/foot pair while staying nearly still.

        Note: `foot` is the foot that is *not* on the ground.

        The height term scores the mean z position of the torso and `foot`;
        the speed term scores `physics.speed()` staying within
        +/- `self._move_speed`. They are mixed with weights 5:1.
        """
        xpos = physics.named.data.xpos
        mean_height = (xpos['torso', 'z'] + xpos[foot, 'z'])/2
        raised = self._raised_reward(mean_height)
        speed_limit = self._move_speed
        still = rewards.tolerance(
            physics.speed(),
            bounds=(-speed_limit, speed_limit),
            margin=speed_limit,
            value_at_margin=0,
            sigmoid='linear')
        return (5*raised + still) / 6

    def _stand_front_reward(self, physics):
        """Stand reward with the back foot ('bfoot') as the raised foot."""
        return self._stand_one_foot_reward(physics, 'bfoot')

    def _stand_back_reward(self, physics):
        """Stand reward with the front foot ('ffoot') as the raised foot."""
        return self._stand_one_foot_reward(physics, 'ffoot')

    def _jump_reward(self, physics):
        """Averages the stand-front and stand-back rewards."""
        front = self._stand_front_reward(physics)
        back = self._stand_back_reward(physics)
        return (front + back) / 2

    def _run_one_foot_reward(self, physics, foot):
        """Rewards a raised torso and foot, optionally scaled by speed.

        Note: `foot` is the foot that is *not* on the ground.

        The height term mixes the foot and torso scores with weights 3:2.
        When `self._move_speed` is 0 the height term is returned as is;
        otherwise it is multiplied by a factor that grows with
        `physics.speed()`.
        """
        xpos = physics.named.data.xpos
        torso_score = self._raised_reward(xpos['torso', 'z'])
        foot_score = self._raised_reward(xpos[foot, 'z'])
        raised = (3*foot_score + 2*torso_score) / 5
        if self._move_speed == 0:
            return raised
        moving = self._min_speed_reward(physics.speed())
        return raised * (5*moving + 1) / 6

    def _run_front_reward(self, physics):
        """Run reward with the back foot ('bfoot') as the raised foot."""
        return self._run_one_foot_reward(physics, 'bfoot')

    def _run_back_reward(self, physics):
        """Run reward with the front foot ('ffoot') as the raised foot."""
        return self._run_one_foot_reward(physics, 'ffoot')

    def _lie_down_reward(self, physics):
        """Rewards a low torso and low feet, mixed with weights 3:1.

        The feet term scores the mean z position of 'ffoot' and 'bfoot'.
        """
        xpos = physics.named.data.xpos
        torso_score = self._lowered_reward(xpos['torso', 'z'])
        mean_foot_height = (xpos['ffoot', 'z'] + xpos['bfoot', 'z']) / 2
        feet_score = self._lowered_reward(mean_foot_height)
        return (3*torso_score + feet_score) / 4

    def _legs_up_reward(self, physics):
        """Rewards a low torso combined with the 'bfoot' run reward (5:1)."""
        torso_low = rewards.tolerance(
            physics.named.data.xpos['torso', 'z'],
            bounds=(-float('inf'), _CHEETAH_LIE_HEIGHT),
            margin=_CHEETAH_LIE_HEIGHT/2)
        lifted = self._run_one_foot_reward(physics, 'bfoot')
        return (5*torso_low + lifted) / 6

    def _flip_reward(self, physics, forward=True):
        """Rewards spinning and travelling in the requested direction.

        Args:
            physics: Physics instance providing `angmomentum()` and `speed()`.
            forward: When False, both readings are negated before scoring.

        Returns:
            The spin and speed terms mixed with weights 2:1.
        """
        direction = 1. if forward else -1.
        spin = rewards.tolerance(
            direction * physics.angmomentum(),
            bounds=(_CHEETAH_SPIN_SPEED, float('inf')),
            margin=_CHEETAH_SPIN_SPEED,
            value_at_margin=0,
            sigmoid='linear')
        travel = self._min_speed_reward(direction * physics.speed())
        return (2*spin + travel) / 3

    def get_reward(self, physics):
        """Returns the reward for the configured goal.

        Raises:
            NotImplementedError: If the goal matches none of the known names.
        """
        for goal, method_name, extra in self._GOAL_REWARDS:
            if self._goal == goal:
                return getattr(self, method_name)(physics, **extra)
        raise NotImplementedError(f'Goal {self._goal} is not implemented.')


if __name__ == '__main__':
    # Smoke test: build the jump task, take a single all-zero action and
    # print the reward that comes back.
    import numpy as np

    env = jump()
    env.reset()
    # Size the action from the environment's spec instead of hard-coding it.
    action_spec = env.action_spec()
    zero_action = np.zeros(action_spec.shape, dtype=action_spec.dtype)
    # `control.Environment.step` returns a dm_env `TimeStep`
    # (step_type, reward, discount, observation), not a gym-style
    # (obs, reward, done, info) tuple, so read the reward by field name.
    time_step = env.step(zero_action)
    print(time_step.reward)
