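'''
Locomotion environments with velocity-based rewards (forward, backward,
lateral and vertical movement, plus multi-direction task variants), adapted
from the DM Control Suite Half-Cheetah run task.
'''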
from dm_control import mjcf
import numpy as np
from .base import MorphologyEnv
import torch
from torch_scatter import scatter_sum
import torch_geometric

class XVelPhysics(mjcf.Physics):
    """MJCF physics with accessors for the ``velocity`` sensor reading.

    Each ``speed*`` method returns one component of
    ``named.data.sensordata['velocity']``: 0 -> x, 1 -> y, 2 -> z.
    NOTE(review): assumes the model defines a 3-component sensor named
    ``velocity``; its reference frame is set by the MJCF model, not here.
    """

    def _velocity_component(self, axis_index):
        """Return entry ``axis_index`` of the ``velocity`` sensor data."""
        velocity = self.named.data.sensordata['velocity']
        return velocity[axis_index]

    def speedx(self):
        """Velocity sensor component 0 (x)."""
        return self._velocity_component(0)

    def speedy(self):
        """Velocity sensor component 1 (y)."""
        return self._velocity_component(1)

    def speedz(self):
        """Velocity sensor component 2 (z)."""
        return self._velocity_component(2)

class XVel(MorphologyEnv):
    """Reward +x velocity, minus a quadratic action penalty.

    Per-step reward: ``clip(speedx, -10, 10) - action_penalty * sum(action**2)``.
    ``get_reward_graph`` computes a batched analogue from graph observations
    and actions (note: its run term is not clipped).
    """
    PHYSICS_CLS = XVelPhysics
    DEFAULT_TIME_LIMIT = 10

    def __init__(self, morphology, action_penalty=0.1, **kwargs):
        """
        Args:
            morphology: forwarded to ``MorphologyEnv``.
            action_penalty: weight of the squared-action control cost.
            **kwargs: forwarded to ``MorphologyEnv``.
        """
        self.action_penalty = action_penalty
        super(XVel, self).__init__(morphology, **kwargs)

    def _step(self, action):
        """Advance the physics by ``self._n_sub_steps`` steps and score the result.

        ``action`` is only used here for the control penalty; presumably the
        base class has already applied it to the actuators — TODO confirm.

        Returns:
            ``(obs, reward, False, {})`` — the done flag returned here is always False.

        Raises:
            Any exception raised by ``physics.step()``, unless
            ``self.allow_exceptions`` is true (then stepping stops early).
        """
        try:
            for _ in range(self._n_sub_steps):
                self._physics.step()
        except Exception:
            # Catch Exception (not everything) so KeyboardInterrupt/SystemExit
            # are never swallowed when allow_exceptions is set.
            if not self.allow_exceptions:
                raise
        obs = self._get_obs()
        reward = self._get_reward(obs, action)
        return obs, reward, False, {}

    def _get_reward(self, obs, action):
        """Clipped +x speed minus ``action_penalty * sum(action**2)``."""
        reward = np.clip(self._physics.speedx(), -10, 10)
        reward -= self.action_penalty * np.sum(np.square(action))
        return reward

    @staticmethod
    def _graph_keys(data):
        """Return the attribute names of a torch_geometric data object.

        ``Data.keys`` is a property in older torch_geometric releases and a
        method in newer ones; support both.
        """
        keys = data.keys
        return keys() if callable(keys) else keys

    def get_reward_graph(self, ob, ac):
        """Compute one reward per graph for a batch of graph observations.

        Args:
            ob: ``torch_geometric.data.Batch``. The run reward is read from the
                global attribute ``u`` when present; otherwise from feature
                column 3 of the first (root) node of each graph in ``x``.
            ac: either a tensor whose first dimension is ``ob.num_graphs``, or a
                torch_geometric ``Data``/``Batch`` holding actions on
                ``edge_attr`` (requires ``edge_attr_batch``) or on ``x``
                (requires ``pad_mask``).

        Returns:
            Tensor ``reward_run + reward_ctrl``; presumably shape
            ``(num_graphs,)`` — TODO confirm the shape of ``u``.

        Raises:
            ValueError: if ``ac`` is neither a tensor nor a torch_geometric Data.
        """
        # First determine the run reward.
        assert isinstance(ob, torch_geometric.data.Batch), "get_reward_graph only supports batch graphs."
        if 'u' in self._graph_keys(ob):
            # The reward is stored in the global attribute u.
            reward_run = ob.u.squeeze(-1)
        else:
            # Otherwise it is stored on the root (first) node of each graph.
            # A graph's first-node index is its start offset:
            # cumulative node count minus its own count. (Subtracting counts[0]
            # was only correct when every graph had the same number of nodes.)
            counts = torch.bincount(ob.batch, minlength=ob.num_graphs)
            locations = torch.cumsum(counts, 0) - counts
            reward_run = ob.x[locations, 3]  # 3 xpos then 3 xvel. We thus want component 3.

        # Next, determine the action penalty.
        if isinstance(ac, torch.Tensor):
            assert ac.shape[0] == ob.num_graphs
            squared = torch.square(ac)
            if squared.dim() > 1:
                # Sum over every non-batch dimension.
                squared = squared.flatten(start_dim=1).sum(dim=1)
            # For a 1-D tensor each graph has one scalar action, so its penalty
            # is simply that action squared (not a batch-wide sum).
            reward_ctrl = -self.action_penalty * squared
        elif isinstance(ac, torch_geometric.data.Data):
            if 'edge_attr' in self._graph_keys(ac) and ac.edge_attr is not None:
                # The actions are the edges.
                assert ac.edge_attr.shape[1] == 1  # Must only have one attribute
                # Edge attributes are duplicated (presumably one copy per edge
                # direction), so halve the summed penalty.
                # NOTE: We assume all edges have actions.
                reward_ctrl = -0.5 * self.action_penalty * scatter_sum(
                    torch.square(ac.edge_attr), ac.edge_attr_batch, dim=0
                ).squeeze(-1)
            else:
                # The actions are on the nodes; pad_mask presumably zeroes out
                # padding nodes that carry no real action.
                assert ac.x.shape[1] == 1
                reward_ctrl = -self.action_penalty * scatter_sum(
                    torch.square(ac.x) * ac.pad_mask.unsqueeze(-1), ac.batch, dim=0
                ).sum(dim=1)
        else:
            raise ValueError("Invalid Action object passed into graph reward")

        return reward_run + reward_ctrl

    def _reset(self, noise=True):
        """Perturb the last ``num_joints`` qpos/qvel entries with N(0, 0.001**2) noise.

        NOTE(review): ``noise`` is currently ignored and the perturbation is
        always applied (the original guarded it with ``if True:``). Confirm
        whether ``noise=False`` should skip it.
        """
        num_joints = self._morphology.num_joints
        if num_joints <= 0:
            # qpos[-0:] would select the whole array, not an empty tail.
            return
        # standard_normal(n) draws the same stream as RandomState.randn(n) and
        # also works when np_random is a numpy Generator.
        self._physics.data.qpos[-num_joints:] = self.np_random.standard_normal(num_joints) * 0.001
        self._physics.data.qvel[-num_joints:] = self.np_random.standard_normal(num_joints) * 0.001
        

class XVel_Global(XVel):
    """XVel variant with no control cost and an x-velocity task observation.

    Reward is ``clip(speedx, -10, 10)``; ``action`` does not affect it.
    """

    def _get_reward(self, obs, action):
        forward_speed = self._physics.speedx()
        return np.clip(forward_speed, -10, 10)

    def _get_task_obs(self):
        # A (1, 1) array holding the current (unclipped) x-velocity.
        forward_speed = self._physics.speedx()
        return np.array([[forward_speed]])

    def get_reward_graph(self, obs, ac):
        # TODO: Replace
        # NOTE(review): not implemented — returns None, so any caller that
        # expects a reward tensor will fail further downstream.
        return None

class NegXVel(XVel):
    """Reward -x velocity (clipped to [-10, 10]) minus the quadratic action penalty."""

    def _get_reward(self, obs, action):
        backward_speed = -self._physics.speedx()
        control_cost = self.action_penalty * np.square(action).sum()
        return np.clip(backward_speed, -10, 10) - control_cost

class YVel(XVel):
    """Reward +y velocity (clipped to [-10, 10]) minus the quadratic action penalty."""

    def _get_reward(self, obs, action):
        lateral_speed = self._physics.speedy()
        control_cost = self.action_penalty * np.square(action).sum()
        return np.clip(lateral_speed, -10, 10) - control_cost


class NegYVel(XVel):
    """Reward -y velocity (clipped to [-10, 10]) minus the quadratic action penalty."""

    def _get_reward(self, obs, action):
        lateral_speed = -self._physics.speedy()
        control_cost = self.action_penalty * np.square(action).sum()
        return np.clip(lateral_speed, -10, 10) - control_cost

class ZVel(XVel):
    """Reward +z velocity minus the quadratic action penalty.

    The clip range is asymmetric, [-2, 10], so downward motion is penalized
    less than upward motion is rewarded.
    """

    def _get_reward(self, obs, action):
        vertical_speed = self._physics.speedz()
        control_cost = self.action_penalty * np.square(action).sum()
        return np.clip(vertical_speed, -2, 10) - control_cost

class Directions2D(MorphologyEnv):
    """Multi-task env cycling through three goals: +x, -x and +z velocity.

    ``cur_task`` advances round-robin on every ``_reset`` (so the first episode
    after construction uses task 1), and ``_get_task_obs`` returns it one-hot.
    """
    PHYSICS_CLS = XVelPhysics
    DEFAULT_TIME_LIMIT = 10

    # One entry per task: (physics speed accessor, sign, clip low, clip high).
    # The task count used by _get_task_obs and _reset is derived from this table.
    _TASK_SPECS = (
        ('speedx', 1.0, -10, 10),   # task 0: move in +x
        ('speedx', -1.0, -10, 10),  # task 1: move in -x
        ('speedz', 1.0, -2, 10),    # task 2: move in +z
    )

    def __init__(self, morphology, action_penalty=0.1, **kwargs):
        """
        Args:
            morphology: forwarded to ``MorphologyEnv``.
            action_penalty: weight of the squared-action control cost.
            **kwargs: forwarded to ``MorphologyEnv``.
        """
        self.action_penalty = action_penalty
        self.cur_task = 0
        super(Directions2D, self).__init__(morphology, **kwargs)

    def _step(self, action):
        """Advance the physics by ``self._n_sub_steps`` steps and score the result.

        Returns:
            ``(obs, reward, False, {})`` — the done flag returned here is always False.

        Raises:
            Any exception raised by ``physics.step()``, unless
            ``self.allow_exceptions`` is true (then stepping stops early).
        """
        try:
            for _ in range(self._n_sub_steps):
                self._physics.step()
        except Exception:
            # Catch Exception (not everything) so KeyboardInterrupt/SystemExit
            # are never swallowed when allow_exceptions is set.
            if not self.allow_exceptions:
                raise
        obs = self._get_obs()
        reward = self._get_reward(obs, action)
        return obs, reward, False, {}

    def _get_reward(self, obs, action):
        """Clipped, signed speed for the current task minus the action penalty.

        Raises:
            ValueError: if ``cur_task`` is not a valid task index.
        """
        if not 0 <= self.cur_task < len(self._TASK_SPECS):
            raise ValueError("Invalid Task ID")
        accessor, sign, low, high = self._TASK_SPECS[self.cur_task]
        speed = sign * getattr(self._physics, accessor)()
        reward = np.clip(speed, low, high)
        reward -= self.action_penalty * np.sum(np.square(action))
        return reward

    def _get_task_obs(self):
        """Return the current task as a (1, num_tasks) one-hot array."""
        one_hot = np.zeros(len(self._TASK_SPECS))
        one_hot[self.cur_task] = 1
        return np.expand_dims(one_hot, axis=0)

    def _reset(self, noise=True):
        """Advance to the next task, then perturb joint qpos/qvel with small noise.

        NOTE(review): ``noise`` is currently ignored and the perturbation is
        always applied (the original guarded it with ``if True:``).
        """
        self.cur_task = (self.cur_task + 1) % len(self._TASK_SPECS)
        num_joints = self._morphology.num_joints
        if num_joints <= 0:
            # qpos[-0:] would select the whole array, not an empty tail.
            return
        # standard_normal(n) matches RandomState.randn(n) and also works with Generator.
        self._physics.data.qpos[-num_joints:] = self.np_random.standard_normal(num_joints) * 0.001
        self._physics.data.qvel[-num_joints:] = self.np_random.standard_normal(num_joints) * 0.001

class Directions3D(MorphologyEnv):
    """Multi-task env cycling through four goals: +x, -x, +y and -y velocity.

    NOTE(review): despite the name, all four goals are horizontal (x/y).
    ``cur_task`` advances round-robin on every ``_reset`` (so the first episode
    after construction uses task 1), and ``_get_task_obs`` returns it one-hot.
    """
    PHYSICS_CLS = XVelPhysics
    DEFAULT_TIME_LIMIT = 10

    # One entry per task: (physics speed accessor, sign, clip low, clip high).
    # The task count used by _get_task_obs and _reset is derived from this table.
    _TASK_SPECS = (
        ('speedx', 1.0, -10, 10),   # task 0: move in +x
        ('speedx', -1.0, -10, 10),  # task 1: move in -x
        ('speedy', 1.0, -10, 10),   # task 2: move in +y
        ('speedy', -1.0, -10, 10),  # task 3: move in -y
    )

    def __init__(self, morphology, action_penalty=0.1, **kwargs):
        """
        Args:
            morphology: forwarded to ``MorphologyEnv``.
            action_penalty: weight of the squared-action control cost.
            **kwargs: forwarded to ``MorphologyEnv``.
        """
        self.action_penalty = action_penalty
        self.cur_task = 0
        super(Directions3D, self).__init__(morphology, **kwargs)

    def _step(self, action):
        """Advance the physics by ``self._n_sub_steps`` steps and score the result.

        Returns:
            ``(obs, reward, False, {})`` — the done flag returned here is always False.

        Raises:
            Any exception raised by ``physics.step()``, unless
            ``self.allow_exceptions`` is true (then stepping stops early).
        """
        try:
            for _ in range(self._n_sub_steps):
                self._physics.step()
        except Exception:
            # Catch Exception (not everything) so KeyboardInterrupt/SystemExit
            # are never swallowed when allow_exceptions is set.
            if not self.allow_exceptions:
                raise
        obs = self._get_obs()
        reward = self._get_reward(obs, action)
        return obs, reward, False, {}

    def _get_reward(self, obs, action):
        """Clipped, signed speed for the current task minus the action penalty.

        Raises:
            ValueError: if ``cur_task`` is not a valid task index.
        """
        if not 0 <= self.cur_task < len(self._TASK_SPECS):
            raise ValueError("Invalid Task ID")
        accessor, sign, low, high = self._TASK_SPECS[self.cur_task]
        speed = sign * getattr(self._physics, accessor)()
        reward = np.clip(speed, low, high)
        reward -= self.action_penalty * np.sum(np.square(action))
        return reward

    def _get_task_obs(self):
        """Return the current task as a (1, num_tasks) one-hot array."""
        one_hot = np.zeros(len(self._TASK_SPECS))
        one_hot[self.cur_task] = 1
        return np.expand_dims(one_hot, axis=0)

    def _reset(self, noise=True):
        """Advance to the next task, then perturb joint qpos/qvel with small noise.

        NOTE(review): ``noise`` is currently ignored and the perturbation is
        always applied (the original guarded it with ``if True:``).
        """
        self.cur_task = (self.cur_task + 1) % len(self._TASK_SPECS)
        num_joints = self._morphology.num_joints
        if num_joints <= 0:
            # qpos[-0:] would select the whole array, not an empty tail.
            return
        # standard_normal(n) matches RandomState.randn(n) and also works with Generator.
        self._physics.data.qpos[-num_joints:] = self.np_random.standard_normal(num_joints) * 0.001
        self._physics.data.qvel[-num_joints:] = self.np_random.standard_normal(num_joints) * 0.001

