from typing import Tuple, List
from collections import namedtuple
import logging
from pathlib import Path
import sys
import numpy as np

# The parent of the directory that contains this file. Its absolute form is
# appended to sys.path (once) so that the project-local packages imported
# below (`terminations`, `utils`) can presumably be found when this module is
# imported from outside the project root.
PROJECT_ROOT_DIR = Path(__file__).parent.parent
_root_entry = str(PROJECT_ROOT_DIR.absolute())
if _root_entry not in sys.path:
    sys.path.append(_root_entry)
del _root_entry  # keep the module namespace identical to the original

from terminations.termination_base import TerminationBase
from utils.geometry_utils import angle_of_2_3d_vectors


class ContinuouselyMoveAwayTermination2(TerminationBase):
    """Terminate when the velocity-vector error keeps growing for too many steps.

    On every call the angle (from ``angle_of_2_3d_vectors``) between the
    current velocity vector built from ``state`` and the goal velocity vector
    is compared with the previous call's angle.

    Consecutive strict increases are counted, but only while the error is
    above ``ignore_velocity_vector_angle_error``. Once the count reaches
    ``time_window_step_length`` the episode is reported as terminated; it is
    never reported as truncated.
    """

    def __init__(self,
        time_window: float=2.,
        ignore_velocity_vector_angle_error: float=1.,
        termination_reward: float=-1.,
        is_termination_reward_based_on_steps_left: bool=False,
        env_config: dict=None,
        my_logger: logging.Logger=None
    ) -> None:
        """Create the termination criterion.

        Args:
            time_window: Window length. It is multiplied by
                ``self.step_frequence`` (provided by ``TerminationBase``) and
                rounded to get a step count.
            ignore_velocity_vector_angle_error: Errors at or below this angle
                reset the streak counter instead of extending it. The angle
                unit is presumably degrees, matching the 180 initial value;
                TODO confirm against ``angle_of_2_3d_vectors``.
            termination_reward: Forwarded to ``TerminationBase``.
            is_termination_reward_based_on_steps_left: Forwarded to
                ``TerminationBase``.
            env_config: Forwarded to ``TerminationBase``.
            my_logger: Forwarded to ``TerminationBase``.
        """
        super().__init__(
            termination_reward=termination_reward, 
            is_termination_reward_based_on_steps_left=is_termination_reward_based_on_steps_left,
            env_config=env_config,
            my_logger=my_logger
        )

        self.time_window = time_window
        self.ignore_velocity_vector_angle_error = ignore_velocity_vector_angle_error

        # 180 is an upper bound for the first comparison, so the first step
        # never counts as an increase. -1 marks "no step evaluated yet".
        self.prev_velocity_vector_errors = 180.
        self.velocity_vector_continuously_increasing_num = -1

    @staticmethod
    def _velocity_vector(v: float, mu: float, chi: float) -> List[float]:
        """Return ``[ve, vn, vh]`` for speed ``v`` and angles ``mu``/``chi`` in degrees.

        The components are computed as:
            ve = v*cos(mu)*sin(chi)
            vn = v*cos(mu)*cos(chi)
            vh = v*sin(mu)

        The component names suggest east/north/up; this is an unverified
        assumption. The multiplication order matches the original inline
        expressions, so the floating-point results are unchanged.
        """
        mu_rad = np.deg2rad(mu)
        chi_rad = np.deg2rad(chi)
        return [
            v * np.cos(mu_rad) * np.sin(chi_rad),
            v * np.cos(mu_rad) * np.cos(chi_rad),
            v * np.sin(mu_rad),
        ]

    @staticmethod
    def _require_kwargs(kwargs: dict, *names: str) -> None:
        """Assert that every name in ``names`` is a key of ``kwargs``.

        Each missing key raises ``AssertionError`` with the message
        ``"args must include <name>"``. Like any ``assert``, these checks are
        skipped under ``python -O``.
        """
        for name in names:
            assert name in kwargs, f"args must include {name}"

    def _get_termination(self, state: namedtuple, goal_v: float, goal_mu: float, goal_chi: float, ve: float, vn: float, vh: float):
        """Update the streak counter with one step and decide termination.

        Args:
            state: Unused here; kept for interface compatibility.
            goal_v, goal_mu, goal_chi: Goal speed and angles (degrees) used to
                build the target velocity vector.
            ve, vn, vh: Components of the current velocity vector.

        Returns:
            ``(terminated, truncated)``; ``truncated`` is always ``False``.
        """
        plane_current_velocity_vector = [ve, vn, vh]
        target_velocity_vector = self._velocity_vector(goal_v, goal_mu, goal_chi)

        cur_velocity_vector_error = angle_of_2_3d_vectors(plane_current_velocity_vector, target_velocity_vector)

        if cur_velocity_vector_error <= self.ignore_velocity_vector_angle_error:
            # Close enough to the target direction: not "moving away".
            self.velocity_vector_continuously_increasing_num = 0
        elif cur_velocity_vector_error <= self.prev_velocity_vector_errors:
            # The error did not grow this step, so the streak is broken.
            self.velocity_vector_continuously_increasing_num = 0
        else:
            # NOTE(review): a NaN error (if angle_of_2_3d_vectors can return
            # one, e.g. for a zero-length vector) also lands here, because
            # every comparison with NaN is False. Confirm whether that is
            # intended.
            self.velocity_vector_continuously_increasing_num += 1

        self.prev_velocity_vector_errors = cur_velocity_vector_error

        # NOTE(review): if this rounds to 0, every call terminates, because
        # the counter is always >= 0 after the first step.
        step_limit = self.time_window_step_length
        terminated = bool(self.velocity_vector_continuously_increasing_num >= step_limit)

        if terminated and self.logger is not None:
            self.logger.info(f"the error continuous change larger for {step_limit} steps!!!")

        return terminated, False

    def _evaluate(self, state: namedtuple, kwargs: dict) -> Tuple[bool, bool]:
        """Build the current velocity from ``state`` and run ``_get_termination``.

        ``state`` must expose ``v``, ``mu`` and ``chi``. ``kwargs`` must
        contain ``goal_v``, ``goal_mu`` and ``goal_chi``.
        """
        ve, vn, vh = self._velocity_vector(state.v, state.mu, state.chi)
        return self._get_termination(
            state=state,
            goal_v=kwargs["goal_v"],
            goal_mu=kwargs["goal_mu"],
            goal_chi=kwargs["goal_chi"],
            ve=ve,
            vn=vn,
            vh=vh,
        )

    def get_termination(self, state: namedtuple, **kwargs) -> Tuple[bool, bool]:
        """Return ``(terminated, truncated)`` for ``state``.

        Requires the ``goal_v``, ``goal_mu`` and ``goal_chi`` keyword arguments.
        """
        self._require_kwargs(kwargs, "goal_v", "goal_mu", "goal_chi")
        return self._evaluate(state, kwargs)

    def get_termination_and_reward(self, state: namedtuple, **kwargs) -> Tuple[bool, bool, float]:
        """Return ``(terminated, truncated, penalty)`` for ``state``.

        Requires the ``goal_v``, ``goal_mu``, ``goal_chi`` and ``step_cnt``
        keyword arguments. The penalty comes from
        ``self.get_termination_penalty`` (defined in ``TerminationBase``).
        """
        self._require_kwargs(kwargs, "goal_v", "goal_mu", "goal_chi", "step_cnt")
        terminated, truncated = self._evaluate(state, kwargs)
        penalty = self.get_termination_penalty(terminated=terminated, steps_cnt=kwargs["step_cnt"])
        return terminated, truncated, penalty

    def reset(self):
        """Restore the per-episode tracking state to its initial values."""
        self.prev_velocity_vector_errors = 180.
        self.velocity_vector_continuously_increasing_num = -1

    @property
    def time_window_step_length(self) -> int:
        """Return the window length in steps: ``round(time_window * step_frequence)``.

        Python's ``round`` rounds ties to even.
        """
        return round(self.time_window * self.step_frequence)

    def __str__(self) -> str:
        return "continuousely_move_away_termination_based_on_velocity_vector_error"