import copy
from typing import Any, List, Dict

import numpy as np

from augment.rl.augmentation_functions.panda.common import PANDA_AUG_FUNCTIONS, PandaGoalAugmentationFunction


class TranslateGoalProximal(PandaGoalAugmentationFunction):
    """Goal augmentation that either perturbs the achieved goal or resamples a goal.

    On each call to ``_sample_goals`` a single coin flip (probability ``p``)
    selects the strategy for the whole batch:

    * with probability ``p``: build new goals by multiplying the achieved goal
      with a randomly drawn quaternion;
    * otherwise: draw goals from ``env.task._sample_n_goals`` and reject any
      that ``env.task.is_success`` already reports as reached.
    """

    def __init__(self, env, p=0.5, **kwargs):
        """Store the perturbation probability.

        Args:
            env: Environment handed to the base class; this class reads
                ``env.achieved_idx`` and ``env.task`` from it.
            p: Probability that a call to ``_sample_goals`` perturbs the
                achieved goal instead of resampling from the task.
            **kwargs: Forwarded unchanged to the base-class constructor.
        """
        super().__init__(env=env, **kwargs)
        self.p = p

    def quaternion_multiply(self, q0, q1):
        """Return the Hamilton product ``q1 * q0`` of two quaternions.

        Both arguments are unpacked as ``(w, x, y, z)``; each component may be
        a scalar or an array, in which case the product is evaluated
        element-wise.

        Args:
            q0: Right-hand operand, an iterable of exactly four components.
            q1: Left-hand operand, an iterable of exactly four components.

        Returns:
            ``np.float64`` array whose first axis holds the ``(w, x, y, z)``
            components of the product.
        """
        w0, x0, y0, z0 = q0
        w1, x1, y1, z1 = q1
        return np.array([-x1 * x0 - y1 * y0 - z1 * z0 + w1 * w0,
                         x1 * w0 + y1 * z0 - z1 * y0 + w1 * x0,
                         -x1 * z0 + y1 * w0 + z1 * x0 + w1 * y0,
                         x1 * y0 - y1 * x0 + z1 * w0 + w1 * z0], dtype=np.float64)

    def _sample_goals(self, next_obs, **kwargs):
        """Sample one new goal per row of ``next_obs``.

        Args:
            next_obs: 2-D array of observations; the achieved goal is read
                from the columns selected by ``env.achieved_idx``.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            Array of new goals, one row per observation.
        """
        ep_length = next_obs.shape[0]
        achieved_goal = next_obs[:, self.env.achieved_idx]

        if np.random.random() < self.p:
            # The achieved goal is unpacked as a 4-component quaternion below,
            # so column 0 is its scalar (w) part.  Round-off can push a value
            # marginally outside [-1, 1], where arccos returns NaN and would
            # poison the whole goal; clipping leaves in-range values untouched.
            a = np.arccos(np.clip(achieved_goal[:, 0], -1.0, 1.0))
            theta = np.random.uniform(-0.927, +0.927, size=(ep_length,))  # arccos(0.6) ~= +/-0.927

            # NOTE(review): the same scalar `a` scales all three vector
            # components, so this quaternion is not unit-norm in general.
            # Behaviour is kept as-is -- confirm this is intended.
            vector_part = a * np.sin(theta / 2)
            q_rotation = np.array([
                np.cos(theta / 2),
                vector_part,
                vector_part,
                vector_part,
            ])  # shape (4, ep_length): one component per row, ready to unpack

            new_goal = self.quaternion_multiply(achieved_goal.T, q_rotation).T
        else:
            new_goal = self.env.task._sample_n_goals(ep_length)
            at_goal = self.env.task.is_success(achieved_goal, new_goal).astype(bool)

            # Rejection sampling: redraw only the rows whose sampled goal is
            # already satisfied by the achieved goal, until none remain.
            while np.any(at_goal):
                new_goal[at_goal] = self.env.task._sample_n_goals(ep_length)[at_goal]
                at_goal = self.env.task.is_success(achieved_goal, new_goal).astype(bool)
        return new_goal


class TranslateGoalProximal0(TranslateGoalProximal):
    """``TranslateGoalProximal`` with the perturbation probability fixed at 0.

    Because the parent compares ``np.random.random() < p``, a value of 0 means
    the perturbation branch is never taken and goals are always resampled from
    the task.
    """

    def __init__(self, env, **kwargs):
        """Build the parent with ``p=0``; all other keyword arguments pass through."""
        fixed_p = 0
        super().__init__(env=env, p=fixed_p, **kwargs)


# Registry of augmentation functions: an independent (deep) copy of the shared
# Panda registry, extended with the goal-augmentation classes from this module.
PANDA_FLIP_AUG_FUNCTIONS = copy.deepcopy(PANDA_AUG_FUNCTIONS)
PANDA_FLIP_AUG_FUNCTIONS.update(
    translate_goal_proximal=TranslateGoalProximal,
    translate_goal_proximal_0=TranslateGoalProximal0,
)
