from abc import abstractmethod, ABC
from typing import (
    Callable,
    Dict,
    Generic,
    Literal,
    Tuple,
    TypeVar,
    List,
    Union,
    Any,
    Union,
    Optional,
)
from utils.common import Action, Info, Reward, ActionInfo, State
from utils.transition import TransitionTuple
from utils.reporter import Reporter, ReportTrait

# Run mode passed to the ``Algorithm`` hooks: either "train" or "eval".
# ``Literal["train", "eval"]`` is the PEP 586 shorthand for, and is treated by
# type checkers exactly like, ``Union[Literal["train"], Literal["eval"]]``.
Mode = Literal["train", "eval"]
# Plain alias of ``Info``; presumably the payload type handed to reporters
# (NOTE(review): inferred from the name only — confirm with callers).
ReportInfo = Info

# String-keyed settings mappings. ``Algorithm.__init__`` merges a ``Config``
# with an optional ``Params``; on key collisions the ``Params`` entry wins.
Config = Dict[str, Any]
Params = Dict[str, Any]


class Algorithm(ABC, ReportTrait):
    """Abstract base class for algorithms that run in "train" or "eval" mode.

    Subclasses must implement ``take_action``, ``eval`` and ``manual_train``.
    Every other hook (``pretrain``, ``valid``, ``after_step``, ``save``,
    ``on_episode_termination``) is optional and does nothing by default.
    """

    # Class-level default for the step counter. NOTE(review): nothing in this
    # class updates it; presumably subclasses or the driving loop do — confirm.
    trained_steps: int = 0

    def __init__(
        self,
        name: str,
        config: Config,
        params: Optional[Params] = None,
    ):
        """Initialise both base classes, store ``name`` and build ``self.p``.

        Args:
            name: Identifier stored on ``self.name``.
            config: Base settings mapping.
            params: Optional overrides merged on top of ``config``; on key
                collisions these entries win. ``None`` (or any falsy value)
                means "no overrides".

        Raises:
            AssertionError: if ``self.name`` already exists once the base
                initialisers have run (only checked when assertions are on).
        """
        ABC.__init__(self)
        ReportTrait.__init__(self)
        # Make sure neither a base class nor a subclass already claimed
        # ``name`` before this initialiser assigns it.
        assert not hasattr(self, "name")
        self.name = name

        overrides = params or {}
        # Later unpacking takes precedence, so overrides beat the base config.
        self.p = {**config, **overrides}

    @abstractmethod
    def take_action(self, mode: Mode, state: State,
                    env) -> Union[ActionInfo, Action]:
        """Choose what to do in ``state`` under the given ``mode``.

        Implementations return either a bare ``Action`` or an ``ActionInfo``.
        ``env`` is passed through untyped.
        """
        raise NotImplementedError

    @abstractmethod
    def eval(self, info: Info):
        """Evaluation entry point; must be provided by subclasses."""
        raise NotImplementedError

    def pretrain(self, info: Info):
        """Optional pre-training hook; a no-op unless overridden."""

    def valid(self, info: Info):
        """Optional validation hook; a no-op unless overridden."""

    @abstractmethod
    def manual_train(self, info: Info):
        """Explicit training entry point; must be provided by subclasses."""
        raise NotImplementedError

    def after_step(self, mode: Mode, env, transitions: List[TransitionTuple]):
        """Optional hook that receives a list of transitions; no-op by default.

        NOTE(review): presumably invoked after each environment step by the
        driving loop — confirm against the caller.
        """

    def save(self, times: int):
        """Optional persistence hook; a no-op unless overridden.

        NOTE(review): the meaning of ``times`` is not visible here — confirm.
        """

    def on_episode_termination(self, mode: Mode, env, index: int):
        """Optional end-of-episode hook; a no-op unless overridden."""

    def set_reporter(self, reporter: Reporter):
        """Register ``reporter`` by forwarding it to ``self.add_reporter``."""
        self.add_reporter(reporter)
