from abc import ABC, abstractmethod
import numpy as np
from collections import namedtuple
from typing import Union, Callable
import logging
from pathlib import Path
import sys

# Register the directory two levels above this file on sys.path (only once),
# presumably so the project-level `rewards` / `utils` imports below resolve
# regardless of the directory the interpreter was started from.
PROJECT_ROOT_DIR = Path(__file__).parent.parent
_root_dir_str = str(PROJECT_ROOT_DIR.absolute())
if _root_dir_str not in sys.path:
    sys.path.append(_root_dir_str)

from rewards.reward_base import RewardBase
from utils.geometry_utils import angle_of_2_3d_vectors


class PonentialRewardBasedOnAngle(RewardBase, ABC):
    """Shaping reward driven by the angle between current and goal velocity.

    The potential of a state is ``phi(s) = -(angle / 180) ** b``, where
    ``angle`` is the value ``angle_of_2_3d_vectors`` returns for the state's
    velocity vector and the goal velocity vector.  Each transition is rewarded
    with ``gamma * phi(next_state) - phi(state)``; the ``phi(next_state)``
    term is replaced by ``0`` when the transition ends the episode.

    NOTE(review): dividing by 180 suggests ``angle_of_2_3d_vectors`` returns
    degrees in [0, 180] -- confirm against ``utils.geometry_utils``.
    """

    # Keyword arguments `get_reward` requires, in the order they are checked.
    _REQUIRED_KWARGS = ("goal_v", "goal_mu", "goal_chi", "next_state", "done")

    def __init__(
        self,
        b: float = 1.,
        gamma: float = 0.99,
        log_history_reward: bool = True,
        my_logger: Union[logging.Logger, None] = None,
    ) -> None:
        """Store the shaping parameters and initialise the base reward.

        Args:
            b: Exponent applied to the normalised angle in ``phi``.
            gamma: Factor multiplying the next state's potential in
                ``get_reward``.
            log_history_reward: Forwarded unchanged to ``RewardBase``.
            my_logger: Forwarded unchanged to ``RewardBase``.
        """
        self.b = b
        self.gamma = gamma
        # NOTE(review): is_potential=False is passed although this class
        # computes a potential-difference reward itself; the flag's meaning is
        # defined by RewardBase, which is not visible here -- TODO confirm.
        super().__init__(is_potential=False, log_history_reward=log_history_reward, my_logger=my_logger)

    def get_reward(self, state: Union[namedtuple, np.ndarray], **kwargs) -> float:
        """Return the shaped reward for the transition ``state -> next_state``.

        Args:
            state: State before the transition; handed to ``phi``.
            **kwargs: Must contain ``goal_v``, ``goal_mu``, ``goal_chi``,
                ``next_state`` and ``done``.

        Returns:
            Whatever ``self._process`` returns for
            ``gamma * phi(next_state) - phi(state)`` (with the first term
            zeroed when ``done`` is truthy).

        Raises:
            AssertionError: If one of the required keyword arguments is
                missing.
        """
        for key in self._REQUIRED_KWARGS:
            if key not in kwargs:
                # Raised explicitly instead of via `assert` so the check is
                # still performed when Python runs with -O.
                raise AssertionError(f"args must include {key}")

        goal = (kwargs["goal_v"], kwargs["goal_mu"], kwargs["goal_chi"])

        # A terminal transition contributes no successor potential.
        next_potential = 0. if kwargs["done"] else self.phi(kwargs["next_state"], *goal)
        reward = self.gamma * next_potential - self.phi(state, *goal)

        return self._process(new_reward=reward)

    def phi(self, state: namedtuple, goal_v: float, goal_mu: float, goal_chi: float) -> float:
        """Return the potential ``-(angle / 180) ** b`` of ``state``.

        Args:
            state: Object exposing ``v``, ``mu`` and ``chi`` attributes.
            goal_v: Goal speed.
            goal_mu: Goal ``mu`` angle, in degrees.
            goal_chi: Goal ``chi`` angle, in degrees.

        Returns:
            The negated, ``b``-powered angle between the state's velocity
            vector and the goal velocity vector, after division by 180.
        """
        plane_current_velocity_vector = self._velocity_vector(state.v, state.mu, state.chi)
        target_velocity_vector = self._velocity_vector(goal_v, goal_mu, goal_chi)

        angle = angle_of_2_3d_vectors(plane_current_velocity_vector, target_velocity_vector)

        return -np.power(angle / 180., self.b)

    @staticmethod
    def _velocity_vector(speed, mu, chi) -> list:
        """Convert (speed, mu, chi) with angles in degrees to a 3-component list.

        The components are ``[v*cos(mu)*sin(chi), v*cos(mu)*cos(chi), v*sin(mu)]``.
        """
        mu_rad = np.deg2rad(mu)
        chi_rad = np.deg2rad(chi)
        # Shared factor of the first two components.
        planar_speed = speed * np.cos(mu_rad)
        return [
            planar_speed * np.sin(chi_rad),
            planar_speed * np.cos(chi_rad),
            speed * np.sin(mu_rad),
        ]

    def reset(self):
        """Reset the reward by delegating to ``RewardBase.reset``."""
        super().reset()