from typing import Tuple, List
from collections import namedtuple
import logging
from pathlib import Path
import sys

# Put the project root (the parent of the directory that contains this file)
# on sys.path so the project-local `terminations.termination_base` import
# below can be resolved. The entry is appended only once.
PROJECT_ROOT_DIR = Path(__file__).parent.parent
_PROJECT_ROOT_STR = str(PROJECT_ROOT_DIR.absolute())
if _PROJECT_ROOT_STR not in sys.path:
    sys.path.append(_PROJECT_ROOT_STR)
del _PROJECT_ROOT_STR  # temporary; keep the module namespace unchanged

from terminations.termination_base import TerminationBase


class ContinuouselyMoveAwayTermination(TerminationBase):
    """Terminate when both the mu error and the chi error keep growing.

    On every call the absolute mu error ``|goal_mu - state.mu|`` and the
    wrapped chi error (shortest distance on a 360-unit circle) are compared
    with the errors from the previous call. For each angle a counter tracks
    how many consecutive calls the error strictly increased; the counter is
    reset to 0 whenever the error does not increase or is at or below the
    corresponding ``ignore_*`` tolerance. The episode is terminated (never
    truncated) once BOTH counters are at least ``time_window_step_length``
    at the same time.

    Args:
        time_window: Observation window, converted to a step count via
            ``time_window_step_length`` (presumably seconds if
            ``step_frequence`` is in Hz -- TODO confirm).
        ignore_mu_error: mu errors at or below this value never count as
            "moving away" and reset the mu counter.
        ignore_chi_error: chi errors at or below this value never count as
            "moving away" and reset the chi counter.
        termination_reward, is_termination_reward_based_on_steps_left,
        env_config, my_logger: Forwarded unchanged to ``TerminationBase``.
    """

    def __init__(self,
        time_window: float=2.,
        ignore_mu_error: float=1.,
        ignore_chi_error: float=1.,
        termination_reward: float=-1.,
        is_termination_reward_based_on_steps_left: bool=False,
        env_config: dict=None,
        my_logger: logging.Logger=None
    ) -> None:
        super().__init__(
            termination_reward=termination_reward,
            is_termination_reward_based_on_steps_left=is_termination_reward_based_on_steps_left,
            env_config=env_config,
            my_logger=my_logger
        )

        self.time_window = time_window
        self.ignore_mu_error = ignore_mu_error
        self.ignore_chi_error = ignore_chi_error

        # Call the private helper rather than self.reset() so that a subclass
        # overriding reset() is not invoked before its own __init__ has run.
        self._reset_tracking_state()

    def _reset_tracking_state(self) -> None:
        """Reset the previous errors and the "moving away" counters.

        The sentinels (-360 for previous errors, -1 for counters) guarantee
        that the first call after a reset leaves each counter at 0: any
        absolute error is greater than -360, so the counter goes -1 -> 0.
        """
        self.prev_mu_errors = -360.
        self.mu_continuously_increasing_num = -1
        self.prev_chi_errors = -360.
        self.chi_continuously_increasing_num = -1

    @staticmethod
    def _wrapped_angle_error(goal: float, actual: float) -> float:
        """Return the shortest distance between two angles on a 360-unit circle.

        The raw difference is reduced modulo 360 first, so angles that differ
        by more than a full turn still give a result in [0, 180] instead of a
        negative value. For raw differences within [0, 360] the result is
        identical to ``min(d, 360 - d)``.
        """
        diff = abs(goal - actual) % 360.
        return min(diff, 360. - diff)

    @staticmethod
    def _update_streak(cur_error: float, prev_error: float, ignore_error: float, streak: int) -> int:
        """Return the new count of consecutive calls on which the error grew.

        The streak restarts at 0 when the error is within tolerance or did not
        strictly increase relative to the previous call.
        """
        if cur_error <= ignore_error or cur_error <= prev_error:
            return 0
        return streak + 1

    @staticmethod
    def _check_required_kwargs(kwargs: dict, *names: str) -> None:
        """Assert that every name in *names* is a key of *kwargs*.

        NOTE: these are ``assert`` statements, so the checks are skipped when
        Python runs with ``-O``.
        """
        for name in names:
            assert name in kwargs, f"args must include {name}"

    def _get_termination(self, state: namedtuple, goal_v: float, goal_mu: float, goal_chi: float) -> Tuple[bool, bool]:
        """Update the counters with *state* and return ``(terminated, truncated)``.

        ``state`` must expose ``mu`` and ``chi`` attributes. ``goal_v`` is
        accepted for a uniform signature but is not used by this criterion.
        ``truncated`` is always False.
        """
        cur_mu_error = abs(goal_mu - state.mu)
        cur_chi_error = self._wrapped_angle_error(goal_chi, state.chi)

        self.mu_continuously_increasing_num = self._update_streak(
            cur_mu_error, self.prev_mu_errors, self.ignore_mu_error,
            self.mu_continuously_increasing_num,
        )
        self.chi_continuously_increasing_num = self._update_streak(
            cur_chi_error, self.prev_chi_errors, self.ignore_chi_error,
            self.chi_continuously_increasing_num,
        )
        self.prev_mu_errors, self.prev_chi_errors = cur_mu_error, cur_chi_error

        # NOTE(review): if time_window * step_frequence rounds to 0, this
        # terminates on every call, because both counters are always >= 0
        # after an update. Confirm that configuration cannot occur.
        window = self.time_window_step_length
        terminated = (
            self.mu_continuously_increasing_num >= window
            and self.chi_continuously_increasing_num >= window
        )

        if terminated and self.logger is not None:
            self.logger.info(f"mu and chi both change larger continously for {window} steps!!!")

        return terminated, False

    def get_termination(self, state: namedtuple, **kwargs) -> Tuple[bool, bool]:
        """Return ``(terminated, truncated)`` for *state*.

        Required kwargs: ``goal_v``, ``goal_mu``, ``goal_chi``. Updates the
        internal counters as a side effect.
        """
        self._check_required_kwargs(kwargs, "goal_v", "goal_mu", "goal_chi")

        return self._get_termination(
            state=state,
            goal_v=kwargs["goal_v"],
            goal_mu=kwargs["goal_mu"],
            goal_chi=kwargs["goal_chi"]
        )

    def get_termination_and_reward(self, state: namedtuple, **kwargs) -> Tuple[bool, bool, float]:
        """Return ``(terminated, truncated, penalty)`` for *state*.

        Required kwargs: ``goal_v``, ``goal_mu``, ``goal_chi``, ``step_cnt``.
        The penalty comes from ``get_termination_penalty``, which is presumably
        provided by ``TerminationBase`` (not visible here).
        """
        self._check_required_kwargs(kwargs, "goal_v", "goal_mu", "goal_chi", "step_cnt")

        terminated, truncated = self._get_termination(
            state=state,
            goal_v=kwargs["goal_v"],
            goal_mu=kwargs["goal_mu"],
            goal_chi=kwargs["goal_chi"]
        )
        penalty = self.get_termination_penalty(terminated=terminated, steps_cnt=kwargs["step_cnt"])

        return terminated, truncated, penalty

    def reset(self):
        """Clear the tracked errors and counters, e.g. between episodes."""
        self._reset_tracking_state()

    @property
    def time_window_step_length(self) -> int:
        """Number of steps in ``time_window``: ``round(time_window * step_frequence)``.

        ``step_frequence`` is presumably set by ``TerminationBase`` from
        ``env_config`` -- TODO confirm.
        """
        return round(self.time_window * self.step_frequence)

    def __str__(self) -> str:
        """Return the identifier string for this termination condition."""
        return "continuousely_move_away_termination_based_on_mu_error_and_chi_error"