import numpy as np

from Policy.Reward.reward import RewardTerminateTruncate
from collections import deque

class GoalDistanceReward(RewardTerminateTruncate):
    """
    Goal-conditioned reward based on the distance between achieved and desired goals.

    - ``rew``: returns ``goal_scale`` when ``batch.obs.achieved_goal`` lies within the
      (optionally adaptive) goal radius of ``batch.obs.desired_goal``; otherwise returns
      the negative squared half-distance ``-(||achieved - desired|| / 2) ** 2``.
    - ``term``: when ``terminate_on_goal`` is set, terminates on reaching the goal.
    - ``trunc``: truncates once ``batch.time`` reaches the timeout.
    """

    def __init__(self, **kwargs):
        # initialize hyperparameters
        super().__init__(**kwargs)
        config = kwargs["config"]
        reward_config = config.policy.reward
        self.goal_epsilon = reward_config.target_goal_epsilon
        # NOTE: only referenced by commented-out shaping code in _reached_goal
        self.goal_shaping = reward_config.target_goal_shaping
        self.goal_negative_constant = reward_config.reached_goal_negative_constant
        self.reached_graph_indices = reward_config.reached_graph_indices
        self.adaptive_radius_rate = reward_config.adaptive_radius_rate
        self.terminate_on_goal = reward_config.terminate_on_goal
        self.horizon = reward_config.timeout
        self.goal_scale = reward_config.goal_scale
        # running window of recent rewards, used to adapt the goal radius
        self.average_normalized_rewards_queue = deque(maxlen=30)
        # initialized so that the adaptive-radius ratio in _reached_goal starts at 1
        self.average_normalized_rewards = self.horizon * self.goal_negative_constant
        # (num_factors, num_factors + 1): a zero first column followed by the identity;
        # presumably the "passive" graph with only self-edges — TODO confirm column layout
        self.passive_graph = np.concatenate(
            [np.zeros((config.num_factors, 1)), np.eye(config.num_factors)], axis=-1
        )

    def _reached_goal(self, batch):
        """Return a boolean (array) that is True where achieved_goal is within the goal radius.

        If ``adaptive_radius_rate > 0`` the radius is scaled by
        ``mean_reward / (goal_negative_constant * horizon) + adaptive_radius_rate``,
        using the running reward average maintained by ``update_statistics``.
        """
        goal_epsilon = self.goal_epsilon
        if self.adaptive_radius_rate > 0:
            # NOTE(review): divides by goal_negative_constant * horizon; a zero constant
            # would yield inf/nan here — confirm configs never set it to 0 with adaptation on.
            goal_epsilon = self.goal_epsilon * (
                np.mean(self.average_normalized_rewards)
                / (self.goal_negative_constant * self.horizon)
                + self.adaptive_radius_rate
            )
        return np.linalg.norm(batch.obs.desired_goal - batch.obs.achieved_goal, axis=-1) < goal_epsilon

    def _reached_nontrivial_graph(self, batch):
        """Return 1.0 where the selected rows of ``batch.true_graph`` differ from the passive graph, else 0.0."""
        nontrivial_graphs = batch.true_graph - np.expand_dims(self.passive_graph, 0)
        check_graphs = nontrivial_graphs[:, self.reached_graph_indices]  # assumes batch x object being effected, objects
        # any nonzero difference (positive or negative) counts as non-trivial
        check_graphs = check_graphs.sum(axis=-1).sum(axis=-1).astype(bool).astype(float)
        return check_graphs

    def update_state_counts(self, batch):
        pass  # TODO: write a state count manager

    def update_statistics(self, results):
        """Record ``results.rew`` (if present) and refresh the running reward average."""
        if "rew" in results:
            self.average_normalized_rewards_queue.append(results.rew)
            self.average_normalized_rewards = np.mean(self.average_normalized_rewards_queue)

    def check_reached(self, batch):
        """Return ``(nontrivial_graph_reached, goal_reached)`` for ``batch``."""
        return self._reached_nontrivial_graph(batch), self._reached_goal(batch)

    def rew(self, batch):
        """Return ``goal_scale`` where the goal is reached, else ``-(||achieved - desired|| / 2) ** 2``.

        A 1-D (unbatched) goal pair is promoted to a batch of one, so the result is
        always at least 1-D.
        """
        reached = self._reached_goal(batch)

        achieved, desired = batch.obs.achieved_goal, batch.obs.desired_goal
        if len(achieved.shape) == 1:
            achieved, desired = np.expand_dims(achieved, axis=0), np.expand_dims(desired, axis=0)
        # squared half of the euclidean distance between achieved and desired goals
        dist = np.square(np.linalg.norm(achieved - desired, axis=-1) / 2)

        # distance penalty when not reached, flat goal_scale bonus when reached
        reward = -dist * (1 - reached).astype(float) + self.goal_scale * reached.astype(float)
        return reward

    def term(self, batch):
        """Terminate where the goal is reached if ``terminate_on_goal``; otherwise never terminate."""
        if self.terminate_on_goal:
            return self._reached_goal(batch)
        return False

    def trunc(self, batch):
        """Truncate once ``batch.time`` reaches the timeout.

        Uses ``self.timeout`` if the base class provides it; otherwise falls back to the
        configured ``policy.reward.timeout`` stored as ``self.horizon`` (the original code
        read ``self.timeout`` unconditionally, which this class never sets).
        """
        timeout = getattr(self, "timeout", self.horizon)
        return batch.time >= timeout
