from typing import Union
import numpy as np
from collections import namedtuple
from pathlib import Path
import logging
import sys

# Project root = two directory levels above this file. It is appended to
# sys.path so the project-local imports below (`rewards.*`, `utils.*`) resolve
# even when this module is executed directly.
# `.absolute()` must come BEFORE walking up: `Path.parent` is purely lexical,
# so for a relative `__file__` like "dense_reward.py", `.parent.parent` yields
# "." (the current directory) rather than the real project root.
PROJECT_ROOT_DIR = Path(__file__).absolute().parent.parent
if str(PROJECT_ROOT_DIR) not in sys.path:
    sys.path.append(str(PROJECT_ROOT_DIR))

from rewards.reward_base import RewardBase
from utils.geometry_utils import angle_of_2_3d_vectors


class DenseReward(RewardBase):
    """Dense (per-step) reward penalizing misalignment with a goal velocity.

    On each call, the velocity of ``next_state`` and the goal velocity are each
    converted from (magnitude, mu, chi) -- angles in degrees -- into a 3-D
    Cartesian vector. The reward is ``-(angle / 180) ** b``, where ``angle`` is
    whatever ``angle_of_2_3d_vectors`` returns for the two vectors.

    NOTE(review): the division by 180 suggests that helper returns degrees in
    [0, 180]. If so, the reward lies in [-1, 0] for ``b > 0`` and is 0 when the
    vectors are aligned. Confirm against ``utils.geometry_utils``.
    """

    # Keyword arguments ``get_reward`` requires, validated in this order.
    _REQUIRED_KWARGS = ("goal_v", "goal_mu", "goal_chi", "next_state")

    def __init__(
        self,
        b: float = 1.,
        log_history_reward: bool = True,
        my_logger: Union[logging.Logger, None] = None,
    ) -> None:
        """Create the reward.

        Args:
            b: Exponent applied to the normalized angle ``angle / 180``.
            log_history_reward: Forwarded unchanged to ``RewardBase``.
            my_logger: Forwarded unchanged to ``RewardBase``; may be None.
        """
        # `b` is set before the base initializer runs, preserving the original
        # ordering in case the base class touches subclass attributes.
        self.b = b
        # Registered with the base class as a non-potential reward.
        super().__init__(is_potential=False, log_history_reward=log_history_reward, my_logger=my_logger)

    @staticmethod
    def _velocity_vector(v, mu_deg, chi_deg) -> list:
        """Convert (magnitude, mu, chi) with angles in degrees to ``[x, y, z]``.

        ``mu`` tilts the vector out of the x-y plane (``z = v*sin(mu)``).
        ``chi`` rotates it within that plane, measured from +y toward +x
        (``x = v*cos(mu)*sin(chi)``, ``y = v*cos(mu)*cos(chi)``).
        """
        mu = np.deg2rad(mu_deg)
        chi = np.deg2rad(chi_deg)
        return [
            v * np.cos(mu) * np.sin(chi),
            v * np.cos(mu) * np.cos(chi),
            v * np.sin(mu),
        ]

    def get_reward(self, state: Union[namedtuple, np.ndarray], **kwargs) -> float:
        """Return the velocity-alignment penalty for the step into ``next_state``.

        Args:
            state: Current state. Not read by this reward; presumably accepted
                to match the ``RewardBase`` call signature (confirm).
            **kwargs: Must contain:
                ``goal_v``, ``goal_mu``, ``goal_chi`` -- goal velocity magnitude
                and angles in degrees.
                ``next_state`` -- an object exposing ``v``, ``mu`` and ``chi``
                attributes with the same meaning.

        Returns:
            ``-(angle / 180) ** self.b``, where ``angle`` is the angle between
            the next-state velocity vector and the goal velocity vector.

        Raises:
            AssertionError: If a required keyword argument is missing.
        """
        # Explicit raise instead of `assert`, so the check is not stripped
        # under `python -O`. AssertionError is kept so callers see the same type.
        for key in self._REQUIRED_KWARGS:
            if key not in kwargs:
                raise AssertionError(f"args must include {key}")

        next_state = kwargs["next_state"]
        plane_current_velocity_vector = self._velocity_vector(
            next_state.v, next_state.mu, next_state.chi
        )
        target_velocity_vector = self._velocity_vector(
            kwargs["goal_v"], kwargs["goal_mu"], kwargs["goal_chi"]
        )

        angle = angle_of_2_3d_vectors(plane_current_velocity_vector, target_velocity_vector)

        return -np.power(angle / 180., self.b)

    def reset(self):
        """Reset via the base class; this reward keeps no per-episode state of its own."""
        super().reset()