from abc import ABC, abstractmethod
from typing import Tuple
from collections import namedtuple
import logging
import numpy as np


class TerminationBase(ABC):
    """Abstract base for episode-termination conditions.

    Subclasses implement ``get_termination``, ``get_termination_and_reward``
    and ``reset``. This base class stores the shared configuration and
    provides the termination-penalty computation plus read-only accessors
    into ``env_config["task"]``.
    """

    def __init__(self, 
        termination_reward: float=-1.,
        is_termination_reward_based_on_steps_left: bool=False,
        env_config: dict=None,
        my_logger: logging.Logger=None
    ) -> None:
        """Store termination settings.

        Args:
            termination_reward: Fixed reward returned by
                ``get_termination_penalty`` on termination when the
                steps-left mode is disabled.
            is_termination_reward_based_on_steps_left: If True, the penalty on
                termination is computed from the number of remaining steps
                (see ``get_penalty_base_on_steps_left``) instead of using
                ``termination_reward``.
            env_config: Configuration dict; the properties below read
                ``env_config["task"]`` keys ``"gamma"``, ``"step_frequence"``
                and ``"max_simulate_time"``.
            my_logger: Optional logger stored as ``self.logger``.
        """
        self.termination_reward = termination_reward
        self.is_termination_reward_based_on_steps_left = is_termination_reward_based_on_steps_left
        self.env_config = env_config
        self.logger = my_logger

    @abstractmethod
    def get_termination(self, state: namedtuple, **kwargs) -> Tuple[bool, bool]:
        """Return two termination flags for ``state`` (meaning defined by subclasses)."""
        raise NotImplementedError
    
    @abstractmethod
    def get_termination_and_reward(self, state: namedtuple, **kwargs) -> Tuple[bool, bool, float]:
        """Return two termination flags for ``state`` plus an associated reward."""
        raise NotImplementedError

    @abstractmethod
    def reset(self):
        """Reset any per-episode state held by the subclass."""
        raise NotImplementedError

    def get_penalty_base_on_steps_left(self, steps_cnt: int=1):
        """Return minus the discounted sum of unit rewards over the steps left.

        With ``n = max(max_episode_steps - steps_cnt, 0)`` remaining steps the
        penalty is ``-(1 + gamma + ... + gamma**(n - 1))``, which equals
        ``-(1 - gamma**n) / (1 - gamma)`` for ``gamma != 1`` and ``-n`` for
        ``gamma == 1``.

        Args:
            steps_cnt: Number of steps already taken in the episode.

        Returns:
            A non-positive penalty.
        """
        gamma = self.rl_gamma
        # Clamp so a step count beyond the horizon yields zero penalty instead
        # of a negative exponent silently turning the penalty into a bonus.
        steps_left = max(self.max_episode_steps - steps_cnt, 0)
        if gamma == 1:
            # The closed form degenerates to 0/0 here; the geometric sum of
            # n ones is simply n.
            return -float(steps_left)
        return - (1 - np.power(gamma, steps_left)) / (1 - gamma)
    
    def get_termination_penalty(self, terminated: bool=False, steps_cnt: int=1):
        """Return the reward to apply for this step's termination status.

        Args:
            terminated: Whether the episode terminated on this step.
            steps_cnt: Number of steps already taken (used only in the
                steps-left penalty mode).

        Returns:
            ``0.`` when not terminated; otherwise either the steps-left penalty
            or the fixed ``termination_reward``, depending on configuration.
        """
        if not terminated:
            return 0.
        if self.is_termination_reward_based_on_steps_left:
            return self.get_penalty_base_on_steps_left(steps_cnt)
        return self.termination_reward
    
    @property
    def rl_gamma(self):
        """Discount factor read from ``env_config["task"]["gamma"]`` (None if absent)."""
        return self.env_config["task"].get("gamma")

    @property
    def step_frequence(self):
        """Value of ``env_config["task"]["step_frequence"]`` (None if absent)."""
        return self.env_config["task"].get("step_frequence")

    @property
    def max_simulate_time(self):
        """Value of ``env_config["task"]["max_simulate_time"]`` (None if absent)."""
        return self.env_config["task"].get("max_simulate_time")

    @property
    def max_episode_steps(self):
        """Episode horizon in steps: ``max_simulate_time * step_frequence``."""
        return self.max_simulate_time * self.step_frequence