import numpy as np

from augment.rl.augmentation_functions.augmentation_function import GoalAugmentationFunction


#######################################################################################################################
#######################################################################################################################

class PandaGoalAugmentationFunction(GoalAugmentationFunction):
    """Goal-augmentation base class for panda-gym style environments.

    Success checks and reward computation are delegated to ``env.task``.
    The observation index sets for the achieved and desired goals are read
    from ``env.achieved_idx`` and ``env.goal_idx`` respectively.
    """

    def __init__(self, env, **kwargs):
        super().__init__(env=env, **kwargs)
        wrapped_env = self.env
        # Index sets used to slice goals out of flat observation vectors.
        self.achieved_goal_mask = wrapped_env.achieved_idx
        self.desired_goal_mask = wrapped_env.goal_idx

    def _is_at_goal(self, achieved_goal, desired_goal, **kwargs):
        """Return a boolean mask of which achieved goals satisfy the desired goals."""
        success = self.env.task.is_success(achieved_goal, desired_goal)
        return success.astype(bool)

    def _compute_reward(self, achieved_goal, desired_goal, infos=None, **kwargs):
        """Return the task reward for the given (achieved, desired) goal pairs."""
        task = self.env.task
        return task.compute_reward(achieved_goal, desired_goal, infos)

class TranslateGoal(PandaGoalAugmentationFunction):
    """Relabel every transition with a goal freshly drawn from the task's goal sampler."""

    def __init__(self, env, **kwargs):
        super().__init__(env=env, **kwargs)

    def _sample_goals(self, next_obs, **kwargs):
        """Draw one new goal per row of ``next_obs``."""
        batch_size = next_obs.shape[0]
        sample_goals = self.env.task._sample_n_goals
        return sample_goals(batch_size)

class TranslateGoalProximal(PandaGoalAugmentationFunction):
    """Relabel goals so the batch is either all-success or all-failure.

    With probability ``p`` (one draw per batch, not per transition), each new
    goal is the achieved goal plus a planar offset shorter than the task's
    ``distance_threshold``. Otherwise, goals are drawn from the task's goal
    sampler and rejection-sampled until none is a success for the
    corresponding achieved goal.

    Args:
        env: Environment exposing ``task``, ``achieved_idx`` and ``goal_idx``.
        p: Probability of taking the proximal (near-achieved-goal) branch.
    """

    def __init__(self, env, p=0.5, **kwargs):
        super().__init__(env=env, **kwargs)
        self.p = p

    def _sample_goal_noise(self, n, **kwargs):
        """Sample ``n`` offsets whose norm is below ``distance_threshold``.

        Returns an ``(n, 3)`` array. The z column is always zero, so the
        offset lies in the xy-plane.
        """
        r = np.random.uniform(0, self.env.task.distance_threshold, size=n)
        theta = np.random.uniform(-np.pi, np.pi, size=n)
        phi = np.random.uniform(-np.pi / 2, np.pi / 2, size=n)
        dx = r * np.sin(phi) * np.cos(theta)
        dy = r * np.sin(phi) * np.sin(theta)
        # The vertical component is intentionally dropped; the planar norm
        # r*|sin(phi)| is still strictly below the threshold.
        dz = np.zeros(n)
        return np.stack([dx, dy, dz], axis=1)

    def _sample_goals(self, next_obs, **kwargs):
        """Return one relabeled goal per row of ``next_obs``."""
        n = next_obs.shape[0]
        achieved_goal = next_obs[:, self.env.achieved_idx]

        if np.random.random() < self.p:
            # Perturb the *achieved* goal (not the old desired goal) so that the
            # new goal lies within the success threshold of what was achieved.
            return achieved_goal + self._sample_goal_noise(n)

        new_goal = self.env.task._sample_n_goals(n)
        at_goal = self._is_at_goal(achieved_goal, new_goal)

        # Rejection sampling: redraw only the rows that are currently successes.
        # NOTE(review): this loop never terminates if some achieved goal makes
        # every sample a success; that is assumed impossible for the task's
        # goal space.
        while np.any(at_goal):
            n_redraw = int(np.count_nonzero(at_goal))
            new_goal[at_goal] = self.env.task._sample_n_goals(n_redraw)
            at_goal = self._is_at_goal(achieved_goal, new_goal)
        return new_goal


class TranslateGoalProximal0(TranslateGoalProximal):
    """``TranslateGoalProximal`` with the proximal branch disabled (``p`` fixed to 0)."""

    def __init__(self, env, **kwargs):
        super().__init__(p=0, env=env, **kwargs)


# Registry mapping configuration names to Panda goal-augmentation classes.
PANDA_AUG_FUNCTIONS = dict(
    translate_goal=TranslateGoal,
    translate_goal_proximal=TranslateGoalProximal,
    translate_goal_proximal_0=TranslateGoalProximal0,
)

