'''
Arm reaching and pushing environments built on MorphologyEnv.

Structure adapted from the DM Control Suite Half-Cheetah run environment.
'''
from dm_control import mjcf
import numpy as np
from .base import MorphologyEnv
import torch

class ArmReachPhysics(mjcf.Physics):
    """Physics class shared by the arm environments in this module.

    Currently identical to ``mjcf.Physics``; every environment below sets it
    as its ``PHYSICS_CLS``.
    """

class ArmReach(MorphologyEnv):
    """Reach task: bring the arm's last body to a random (x, y) goal.

    Each step the reward is the negative distance between the (x, y)
    position of the last body in ``physics.data.xpos`` and the goal, minus a
    quadratic action penalty. Once that distance drops below 0.1 the episode
    ends with a reward of 10.
    """

    PHYSICS_CLS = ArmReachPhysics
    DEFAULT_TIME_LIMIT = 4

    def __init__(self, morphology, action_penalty=0.01, **kwargs):
        """Create the environment.

        Args:
            morphology: Arm morphology, forwarded to ``MorphologyEnv``.
            action_penalty: Weight of the sum-of-squared-actions penalty.
            **kwargs: Extra keyword arguments forwarded to ``MorphologyEnv``.
        """
        self.action_penalty = action_penalty
        # Goals are sampled uniformly from this axis-aligned (x, y) box.
        self.goal_low = (0.5, -0.8)
        self.goal_high = (1.5, 0.8)
        # Initial goal. ``self.np_random`` is presumably set up by the base
        # class, so the global RNG is used before ``super().__init__``;
        # ``_reset`` resamples the goal from the environment's own RNG.
        self.goal = np.random.uniform(low=self.goal_low, high=self.goal_high)
        super().__init__(morphology, **kwargs)

    def _step(self, action):
        """Advance the simulation by ``_n_sub_steps`` and score the result.

        Returns:
            Tuple ``(obs, reward, done, info)`` with an empty ``info`` dict.
        """
        try:
            for _ in range(self._n_sub_steps):
                self._physics.step()
        except Exception:
            # Simulation errors are tolerated only when explicitly allowed;
            # KeyboardInterrupt / SystemExit always propagate.
            if not self.allow_exceptions:
                raise
        obs = self._get_obs()
        end_pos = self._physics.data.xpos[-1, :2].copy()
        dist = np.linalg.norm(end_pos - self.goal)
        if dist < 0.1:
            return obs, 10, True, {}
        reward = -dist - self.action_penalty * np.sum(np.square(action))
        return obs, reward, False, {}

    def _get_task_obs(self):
        """Return the (x, y) offset of the last body from the goal."""
        return self._physics.data.xpos[-1, :2].copy() - self.goal

    def _reset(self, noise=True):
        """Jitter the arm joints and sample a new goal.

        Args:
            noise: Currently ignored; small joint noise is always applied.
        """
        n_joints = self._morphology.num_joints
        # Tiny Gaussian perturbation of the trailing (arm) joint state.
        self._physics.data.qpos[-n_joints:] = self.np_random.randn(n_joints) * 0.001
        self._physics.data.qvel[-n_joints:] = self.np_random.randn(n_joints) * 0.001
        # Use the environment's RNG (not the global one) so seeded runs
        # reproduce the same goal sequence.
        self.goal = self.np_random.uniform(low=self.goal_low, high=self.goal_high)
        self._physics.named.model.geom_pos['target', 'x'] = self.goal[0]
        self._physics.named.model.geom_pos['target', 'y'] = self.goal[1]


class ArmReachEnd(MorphologyEnv):
    """Reach task scored on the arm's last site rather than its last body.

    Each step the reward is the negative distance between the (x, y)
    position of the last entry in ``physics.data.site_xpos`` and a random
    goal, minus a quadratic action penalty. Once that distance drops below
    0.1 the episode ends with a reward of 100 and ``info['success'] == 1.0``.
    """

    PHYSICS_CLS = ArmReachPhysics
    DEFAULT_TIME_LIMIT = 4

    def __init__(self, morphology, action_penalty=0.01, **kwargs):
        """Create the environment.

        Args:
            morphology: Arm morphology, forwarded to ``MorphologyEnv``.
            action_penalty: Weight of the sum-of-squared-actions penalty.
            **kwargs: Extra keyword arguments forwarded to ``MorphologyEnv``.
        """
        self.action_penalty = action_penalty
        # Goals are sampled uniformly from this axis-aligned (x, y) box.
        self.goal_low = (0.5, -0.8)
        self.goal_high = (1.5, 0.8)
        # Initial goal. ``self.np_random`` is presumably set up by the base
        # class, so the global RNG is used before ``super().__init__``;
        # ``_reset`` resamples the goal from the environment's own RNG.
        self.goal = np.random.uniform(low=self.goal_low, high=self.goal_high)
        super().__init__(morphology, **kwargs)

    def _step(self, action):
        """Advance the simulation by ``_n_sub_steps`` and score the result.

        Returns:
            Tuple ``(obs, reward, done, info)`` where ``info['success']`` is
            ``float(done)``.
        """
        try:
            for _ in range(self._n_sub_steps):
                self._physics.step()
        except Exception:
            # Simulation errors are tolerated only when explicitly allowed;
            # KeyboardInterrupt / SystemExit always propagate.
            if not self.allow_exceptions:
                raise
        obs = self._get_obs()
        end_pos = self._physics.data.site_xpos[-1, :2].copy()
        dist = np.linalg.norm(end_pos - self.goal)
        if dist < 0.1:
            return obs, 100, True, {'success' : float(True)}
        reward = -dist - self.action_penalty * np.sum(np.square(action))
        return obs, reward, False, {'success' : float(False)}

    def _get_task_obs(self):
        """Return the last body's (x, y) offset from the goal, shape (1, 2)."""
        # NOTE(review): the reward uses ``site_xpos[-1]`` but this observation
        # uses ``xpos[-1]``; confirm whether that mismatch is intentional.
        return np.expand_dims(self._physics.data.xpos[-1, :2].copy() - self.goal, axis=0)

    def _reset(self, noise=True):
        """Jitter the arm joints and sample a new goal.

        Args:
            noise: Currently ignored; small joint noise is always applied.
        """
        n_joints = self._morphology.num_joints
        # Tiny Gaussian perturbation of the trailing (arm) joint state.
        self._physics.data.qpos[-n_joints:] = self.np_random.randn(n_joints) * 0.001
        self._physics.data.qvel[-n_joints:] = self.np_random.randn(n_joints) * 0.001
        # Use the environment's RNG (not the global one) so seeded runs
        # reproduce the same goal sequence.
        self.goal = self.np_random.uniform(low=self.goal_low, high=self.goal_high)
        self._physics.named.model.geom_pos['target', 'x'] = self.goal[0]
        self._physics.named.model.geom_pos['target', 'y'] = self.goal[1]

class ArmPush1(MorphologyEnv):
    """Push task: move the body at ``xpos[BOX_BODY_ID]`` to ``GOAL``.

    ``GOAL`` is an (x, y) array that subclasses must provide; with the
    default ``None`` the goal arithmetic fails. The reward is the negative
    box-to-goal distance, minus half the box-to-last-site distance, minus a
    quadratic action penalty. Once the box is within 0.05 of the goal the
    episode ends with a reward of 100.
    """

    PHYSICS_CLS = ArmReachPhysics
    DEFAULT_TIME_LIMIT = 4

    # (x, y) target for the pushed body; set by subclasses.
    GOAL = None
    # Index into ``physics.data.xpos`` of the body that must be pushed.
    BOX_BODY_ID = 1

    def __init__(self, morphology, action_penalty=0.01, **kwargs):
        """Create the environment.

        Args:
            morphology: Arm morphology, forwarded to ``MorphologyEnv``.
            action_penalty: Weight of the sum-of-squared-actions penalty.
            **kwargs: Extra keyword arguments forwarded to ``MorphologyEnv``.
        """
        self.action_penalty = action_penalty
        super().__init__(morphology, **kwargs)

    def _step(self, action):
        """Advance the simulation by ``_n_sub_steps`` and score the result.

        Returns:
            Tuple ``(obs, reward, done, info)`` with an empty ``info`` dict.
        """
        try:
            for _ in range(self._n_sub_steps):
                self._physics.step()
        except Exception:
            # Simulation errors are tolerated only when explicitly allowed;
            # KeyboardInterrupt / SystemExit always propagate.
            if not self.allow_exceptions:
                raise
        obs = self._get_obs()
        box_pos = self._physics.data.xpos[self.BOX_BODY_ID, :2].copy()
        end_pos = self._physics.data.site_xpos[-1, :2].copy()
        dist = np.linalg.norm(box_pos - self.GOAL)
        arm_dist = np.linalg.norm(box_pos - end_pos)
        if dist < 0.05:
            return obs, 100, True, {}
        reward = (-dist
                  - 0.5 * arm_dist
                  - self.action_penalty * np.sum(np.square(action)))
        return obs, reward, False, {}

    def _get_task_obs(self):
        """Return [box - goal, box, box - last site] as a flat length-6 array."""
        box_pos = self._physics.data.xpos[self.BOX_BODY_ID, :2].copy()
        end_pos = self._physics.data.site_xpos[-1, :2].copy()
        return np.concatenate((box_pos - self.GOAL, box_pos, box_pos - end_pos))

    def _reset(self, noise=True):
        """Jitter the arm joints.

        Args:
            noise: Currently ignored; small joint noise is always applied.
        """
        n_joints = self._morphology.num_joints
        # Tiny Gaussian perturbation of the trailing (arm) joint state.
        self._physics.data.qpos[-n_joints:] = self.np_random.randn(n_joints) * 0.001
        self._physics.data.qvel[-n_joints:] = self.np_random.randn(n_joints) * 0.001

class ArmPush2(MorphologyEnv):
    """Push task: move the body at ``xpos[BOX_BODY_ID]`` to ``GOAL``.

    ``GOAL`` is an (x, y) array that subclasses must provide; with the
    default ``None`` the goal arithmetic fails. The reward is the negative
    box-to-goal distance, minus half the box-to-last-site distance, minus a
    quadratic action penalty. Once the box is within 0.05 of the goal the
    episode ends with a reward of 100.
    """

    PHYSICS_CLS = ArmReachPhysics
    DEFAULT_TIME_LIMIT = 4

    # (x, y) target for the pushed body; set by subclasses. Declared here,
    # matching ArmPush1, so a missing override is explicit.
    GOAL = None
    # Index into ``physics.data.xpos`` of the body that must be pushed.
    BOX_BODY_ID = 2

    def __init__(self, morphology, action_penalty=0.01, **kwargs):
        """Create the environment.

        Args:
            morphology: Arm morphology, forwarded to ``MorphologyEnv``.
            action_penalty: Weight of the sum-of-squared-actions penalty.
            **kwargs: Extra keyword arguments forwarded to ``MorphologyEnv``.
        """
        self.action_penalty = action_penalty
        super().__init__(morphology, **kwargs)

    def _step(self, action):
        """Advance the simulation by ``_n_sub_steps`` and score the result.

        Returns:
            Tuple ``(obs, reward, done, info)`` with an empty ``info`` dict.
        """
        try:
            for _ in range(self._n_sub_steps):
                self._physics.step()
        except Exception:
            # Simulation errors are tolerated only when explicitly allowed;
            # KeyboardInterrupt / SystemExit always propagate.
            if not self.allow_exceptions:
                raise
        obs = self._get_obs()
        box_pos = self._physics.data.xpos[self.BOX_BODY_ID, :2].copy()
        end_pos = self._physics.data.site_xpos[-1, :2].copy()
        dist = np.linalg.norm(box_pos - self.GOAL)
        arm_dist = np.linalg.norm(box_pos - end_pos)
        if dist < 0.05:
            return obs, 100, True, {}
        reward = (-dist
                  - 0.5 * arm_dist
                  - self.action_penalty * np.sum(np.square(action)))
        return obs, reward, False, {}

    def _get_task_obs(self):
        """Return [box - goal, box, box - last site] as a flat length-6 array."""
        box_pos = self._physics.data.xpos[self.BOX_BODY_ID, :2].copy()
        end_pos = self._physics.data.site_xpos[-1, :2].copy()
        return np.concatenate((box_pos - self.GOAL, box_pos, box_pos - end_pos))

    def _reset(self, noise=True):
        """Jitter the arm joints.

        Args:
            noise: Currently ignored; small joint noise is always applied.
        """
        n_joints = self._morphology.num_joints
        # Tiny Gaussian perturbation of the trailing (arm) joint state.
        self._physics.data.qpos[-n_joints:] = self.np_random.randn(n_joints) * 0.001
        self._physics.data.qvel[-n_joints:] = self.np_random.randn(n_joints) * 0.001

class ArmPush1G1(ArmPush1):
    """ArmPush1 variant whose (x, y) goal is (0.0, 0.65)."""

    GOAL = np.array((0.0, 0.65))

class ArmPush1G2(ArmPush1):
    """ArmPush1 variant whose (x, y) goal is (0.9, 1.4)."""

    GOAL = np.array((0.9, 1.4))

class ArmPush1G3(ArmPush1):
    """ArmPush1 variant whose (x, y) goal is (0.15, 1.25)."""

    GOAL = np.array((0.15, 1.25))

class ArmPush2G1(ArmPush2):
    """ArmPush2 variant whose (x, y) goal is (0.0, -0.65)."""

    GOAL = np.array((0.0, -0.65))

class ArmPush2G2(ArmPush2):
    """ArmPush2 variant whose (x, y) goal is (0.9, -1.4)."""

    GOAL = np.array((0.9, -1.4))

class ArmPush2G3(ArmPush2):
    """ArmPush2 variant whose (x, y) goal is (0.15, -1.25)."""

    GOAL = np.array((0.15, -1.25))

class ArmPushAll(MorphologyEnv):
    """Multi-task push environment that cycles through the fixed ``GOALS``.

    Task ``i`` pushes the body at ``xpos[BOX_BODY_IDS[i]]`` to ``GOALS[i]``.
    The first three goals use body index 1 and the last three use body
    index 2; the values mirror ArmPush1G1-3 and ArmPush2G1-3. ``cur_task``
    advances by one on every ``_reset``, so the first reset after
    construction moves it from 0 to 1.
    """

    PHYSICS_CLS = ArmReachPhysics
    DEFAULT_TIME_LIMIT = 4

    # (x, y) goal for each task.
    GOALS = np.array([
                        [0.0, 0.65],
                        [0.9, 1.4],
                        [0.15, 1.25],
                        [0.0, -0.65],
                        [0.9, -1.4],
                        [0.15, -1.25]
                    ])
    # Index into ``physics.data.xpos`` of the body pushed in each task;
    # must stay aligned with GOALS.
    BOX_BODY_IDS = (1, 1, 1, 2, 2, 2)

    def __init__(self, morphology, action_penalty=0.01, **kwargs):
        """Create the environment.

        Args:
            morphology: Arm morphology, forwarded to ``MorphologyEnv``.
            action_penalty: Weight of the sum-of-squared-actions penalty.
            **kwargs: Extra keyword arguments forwarded to ``MorphologyEnv``.
        """
        self.action_penalty = action_penalty
        self.cur_task = 0
        super().__init__(morphology, **kwargs)

    def _step(self, action):
        """Advance the simulation by ``_n_sub_steps`` and score the result.

        The reward is the negative box-to-goal distance, minus half the
        box-to-last-site distance, minus a quadratic action penalty. Once
        the box is within 0.05 of the goal the episode ends with 100.

        Returns:
            Tuple ``(obs, reward, done, info)`` with an empty ``info`` dict.
        """
        try:
            for _ in range(self._n_sub_steps):
                self._physics.step()
        except Exception:
            # Simulation errors are tolerated only when explicitly allowed;
            # KeyboardInterrupt / SystemExit always propagate.
            if not self.allow_exceptions:
                raise
        obs = self._get_obs()
        box_pos = self._physics.data.xpos[self.BOX_BODY_IDS[self.cur_task], :2].copy()
        end_pos = self._physics.data.site_xpos[-1, :2].copy()
        dist = np.linalg.norm(box_pos - self.GOALS[self.cur_task])
        arm_dist = np.linalg.norm(box_pos - end_pos)
        if dist < 0.05:
            return obs, 100, True, {}
        reward = (-dist
                  - 0.5 * arm_dist
                  - self.action_penalty * np.sum(np.square(action)))
        return obs, reward, False, {}

    def _get_task_obs(self):
        """Return [box - goal, box, box - last site] with shape (1, 6)."""
        box_pos = self._physics.data.xpos[self.BOX_BODY_IDS[self.cur_task], :2].copy()
        end_pos = self._physics.data.site_xpos[-1, :2].copy()
        obs = np.concatenate((
            box_pos - self.GOALS[self.cur_task],
            box_pos,
            box_pos - end_pos,
        ))
        return np.expand_dims(obs, axis=0)

    def _reset(self, noise=True):
        """Advance to the next task and jitter the arm joints.

        Args:
            noise: Currently ignored; small joint noise is always applied.
        """
        self.cur_task = (self.cur_task + 1) % len(self.GOALS)
        n_joints = self._morphology.num_joints
        # Tiny Gaussian perturbation of the trailing (arm) joint state.
        self._physics.data.qpos[-n_joints:] = self.np_random.randn(n_joints) * 0.001
        self._physics.data.qvel[-n_joints:] = self.np_random.randn(n_joints) * 0.001
