import time
from typing import Dict, List, Any
import numpy as np
import gym

# from augment.rl.augmentation_functions import validate_augmentation
from augment.rl.augmentation_functions.augmentation_function import AugmentationFunction


class AntReflect(AugmentationFunction):
    """Mirror an Ant transition by swapping paired entries and flipping signs.

    The lookup tables built in ``__init__`` are hard-wired for a
    27-dimensional observation and an 8-dimensional action.

    NOTE(review): apart from the "joint angles" / "joint vels" groups, the
    physical meaning of individual observation indices is not visible in this
    file -- confirm against the environment's observation layout.
    """

    def __init__(self, **kwargs):
        """Build the index tables; ``kwargs`` are forwarded to the base class."""
        super().__init__(**kwargs)

        # Observation permutation: identity everywhere except four 4-wide
        # groups, inside each of which the two halves trade places.
        #   groups starting at 5 and 9   -> joint angles
        #   groups starting at 19 and 23 -> joint vels
        self.obs_permute = np.arange(27)
        for first in (5, 9, 19, 23):
            group = self.obs_permute[first:first + 4].copy()
            self.obs_permute[first:first + 2] = group[2:]
            self.obs_permute[first + 2:first + 4] = group[:2]

        # Observation entries that get negated after the permutation:
        # indices 5..13 and 17..26.
        self.obs_reflect = np.zeros(27, dtype=bool)
        self.obs_reflect[5:14] = True
        self.obs_reflect[17:] = True

        # Action permutation: index pairs (0, 1) <-> (6, 7) and
        # (2, 3) <-> (4, 5) trade places.
        self.action_permute = np.arange(8)
        self.action_permute[:] = [6, 7, 4, 5, 2, 3, 0, 1]

    def _swap_action_left_right(self, action):
        """Permute the action columns and negate every entry, in place.

        Args:
            action: 2-D array indexed as ``action[:, column]``; it is
                overwritten and nothing is returned.
        """
        # Fancy indexing yields a fresh array, so writing back into ``action``
        # cannot read partially-overwritten columns.
        action[:, :] = action[:, self.action_permute] * -1

    def _reflect_orientation(self, obs):
        """Negate observation columns 3 and 4 in place.

        NOTE(review): presumably two components of the torso orientation --
        verify against the observation layout.
        """
        obs[:, 3:5] *= -1

    def _augment(self,
                 obs: np.ndarray,
                 next_obs: np.ndarray,
                 action: np.ndarray,
                 reward: np.ndarray,
                 done: np.ndarray,
                 infos: List[Dict[str, Any]],
                 delta=None,
                 p=None
                 ):
        """Reflect one batch of transitions in place.

        Args:
            obs: observations, indexed as ``obs[:, feature]``.
            next_obs: next observations, same layout as ``obs``.
            action: actions, indexed as ``action[:, dim]``.
            reward: rewards; corrected for the reflected forward term.
            done: passed through untouched.
            infos: read as ``infos[0][0]['reward_forward']``.
            delta: unused here.
            p: unused here.

        Returns:
            The tuple ``(obs, next_obs, action, reward, done, infos)``; the
            arrays are the same objects that were passed in, now modified.
        """
        both = (obs, next_obs)

        # 1) swap the paired entries
        for arr in both:
            arr[:, :] = arr[:, self.obs_permute]

        # 2) flip the sign of the masked entries
        for arr in both:
            arr[:, self.obs_reflect] *= -1

        # 3) flip the two orientation columns
        for arr in both:
            self._reflect_orientation(arr)

        self._swap_action_left_right(action)

        # Subtracting twice the forward term turns a +reward_forward
        # contribution into -reward_forward.
        # NOTE(review): assumes ``reward`` contains ``reward_forward`` with
        # weight 1, and that the first info entry is valid for every row of
        # the batch -- confirm with the caller.
        forward = infos[0][0]['reward_forward']
        reward += -2 * forward

        return obs, next_obs, action, reward, done, infos



class AntRotate(AugmentationFunction):
    """Rotate an Ant transition by a random angle.

    Each call to ``_augment`` draws one angle ``alpha`` uniformly from
    ``[-noise_scale, +noise_scale]`` and applies it to the quaternion stored
    in observation columns 1..4, to the velocity pair in columns 13/14, and to
    the forward-reward term.
    """

    def __init__(self, noise_scale=np.pi/4, **kwargs):
        """Store the sampling half-width; ``kwargs`` go to the base class.

        Args:
            noise_scale: half-width of the uniform range ``alpha`` is drawn
                from.
        """
        super().__init__(**kwargs)
        self.noise_scale = noise_scale

    def quat_mul(self, quat0, quat1):
        """Return the Hamilton product ``quat0 * quat1``.

        Both arguments must have identical shapes whose last axis has length
        4. The arithmetic takes index 3 as the scalar part and indices 0..2 as
        the vector part, and the result is packed the same way.

        NOTE(review): the previous comment here said "mujoco stores quats as
        (qw, qx, qy, qz)", which does not match this indexing -- see the note
        in ``_augment`` for how the two fit together.
        """
        assert quat0.shape == quat1.shape
        assert quat0.shape[-1] == 4

        x0, y0, z0, w0 = (quat0[..., i] for i in range(4))
        x1, y1, z1, w1 = (quat1[..., i] for i in range(4))

        scalar = w0 * w1 - x0 * x1 - y0 * y1 - z0 * z1
        vec_x = w0 * x1 + x0 * w1 + y0 * z1 - z0 * y1
        vec_y = w0 * y1 + y0 * w1 + z0 * x1 - x0 * z1
        vec_z = w0 * z1 + z0 * w1 + x0 * y1 - y0 * x1
        product = np.stack([vec_x, vec_y, vec_z, scalar], axis=-1)

        assert product.shape == quat0.shape
        return product

    def _rotate_torso(self, obs, quat_rotate_by):
        """Replace the quaternion in row 0, columns 1..4 with its rotated value.

        Only the first row is touched; ``_augment`` asserts a batch size of 1.
        The product is written back without renormalising.
        """
        obs[0, 1:5] = self.quat_mul(quat0=obs[0, 1:5], quat1=quat_rotate_by)

    def _rotate_vel(self, obs, sin, cos):
        """Apply a 2-D rotation, given by ``sin``/``cos``, to columns 13 and 14.

        Both new columns are computed before either is stored, so the second
        one is not derived from an already-rotated first one.
        """
        rotated_x = obs[:, 13] * cos - obs[:, 14] * sin
        rotated_y = obs[:, 13] * sin + obs[:, 14] * cos
        obs[:, 13] = rotated_x
        obs[:, 14] = rotated_y

    def _augment(self,
                 obs: np.ndarray,
                 next_obs: np.ndarray,
                 action: np.ndarray,
                 reward: np.ndarray,
                 done: np.ndarray,
                 infos: List[Dict[str, Any]],
                 delta=None,
                 p=None
                 ):
        """Rotate a single transition in place by a freshly sampled angle.

        Args:
            obs: observation batch; must contain exactly one row.
            next_obs: next-observation batch, same layout as ``obs``.
            action: returned unchanged.
            reward: the forward term is swapped for its rotated counterpart.
            done: returned unchanged.
            infos: read as ``infos[0][0]`` for the keys ``'x_velocity'``,
                ``'y_velocity'`` and ``'reward_forward'``.
            delta: unused here.
            p: unused here.

        Returns:
            The tuple ``(obs, next_obs, action, reward, done, infos)``; the
            arrays are the same objects that were passed in, now modified.
        """
        assert obs.shape[0] == 1 # for now.

        alpha = np.random.uniform(low=-self.noise_scale, high=+self.noise_scale)

        # Half-angle terms for the quaternion, stored as
        # [sin, 0, 0, cos] in array order.
        half_sin = np.sin(alpha/2)
        half_cos = np.cos(alpha/2)
        quat_rotate_by = np.array([half_sin, 0, 0, half_cos])

        for arr in (obs, next_obs):
            self._rotate_torso(arr, quat_rotate_by)

        # The planar velocity is rotated by -alpha.
        # NOTE(review): if the observation quaternion is laid out as
        # (qw, qx, qy, qz), pushing it through ``quat_mul`` with the array
        # above works out to LEFT-multiplying by a rotation of -alpha about
        # the z axis, which would explain the sign used here. Worth
        # confirming against the environment's observation layout.
        sin = np.sin(-alpha)
        cos = np.cos(-alpha)
        for arr in (obs, next_obs):
            self._rotate_vel(arr, sin, cos)

        step_info = infos[0][0]
        vx = step_info['x_velocity']
        vy = step_info['y_velocity']
        reward_forward = step_info['reward_forward']

        # Remove the old forward term, then add the x-velocity as seen after
        # the same rotation that was applied to the observation velocity.
        # NOTE(review): assumes ``reward_forward`` equals the x-velocity with
        # weight 1 -- confirm against the environment's reward definition.
        reward -= reward_forward
        reward += vx*cos - vy*sin

        return obs, next_obs, action, reward, done, infos


# Registry: augmentation name -> class implementing it.
ANT_AUG_FUNCTIONS = dict(
    reflect=AntReflect,
    rotate=AntRotate,
)

