import abc
from abc import abstractmethod
from typing import Any, Dict, List, Optional, Tuple, Union

import torch as th  # type:ignore

from .loggers import WandbLogger


class RLAlgorithm(abc.ABC):
    """Abstract interface shared by reinforcement-learning algorithms.

    Subclasses must implement ``act``, ``train``, ``eval``, ``update`` and
    ``play_best_policy``. ``save_best_policy`` is an optional hook: the base
    implementation does nothing, so subclasses override it only if they
    support saving.
    """

    def __init__(self) -> None:
        # Kept explicit so subclasses can always chain to it safely.
        super().__init__()

    @abstractmethod
    def act(
        self, state: th.Tensor, accrued_reward: th.Tensor
    ) -> Tuple[th.Tensor, Any]:
        """Produce the algorithm's output for ``state``.

        Args:
            state: Current state as a tensor. Shape and dtype are defined by
                the concrete subclass.
            accrued_reward: Reward accrued so far, as a tensor. Its exact
                meaning is subclass-defined.

        Returns:
            A pair ``(tensor, extra)``. NOTE(review): presumably the tensor is
            the selected action and ``extra`` is auxiliary data; confirm
            against concrete implementations.
        """

    @abstractmethod
    def train(
        self,
        n_episodes: int,
        eval_env: Any,
        n_evals: int = 16,
        eval_freq: int = 100,
        log: bool = False,
        logger: Optional[WandbLogger] = None,
    ) -> None:
        """Train the algorithm.

        Args:
            n_episodes: Number of episodes to train for.
            eval_env: Environment used for evaluation during training.
            n_evals: Evaluation count (default 16). Presumably forwarded to
                ``eval``; confirm in subclasses.
            eval_freq: How often to evaluate (default 100). Presumably
                measured in episodes; confirm in subclasses.
            log: Whether logging is enabled (default False).
            logger: Optional ``WandbLogger`` instance. ``None`` when no logger
                is supplied.
        """

    @abstractmethod
    def eval(
        self, eval_env: Any, n_evals: int
    ) -> Tuple[th.Tensor, th.Tensor, Dict[str, Union[float, int]]]:
        """Evaluate the algorithm in ``eval_env``.

        Args:
            eval_env: Environment to evaluate in.
            n_evals: Number of evaluations to run.

        Returns:
            A triple of two tensors plus a dict mapping string keys to
            ``float``/``int`` values. What each tensor holds is
            subclass-defined.
        """

    @abstractmethod
    def update(self) -> None:
        """Perform an update of the algorithm's internal state.

        The details (what is updated and from which data) are
        subclass-defined.
        """

    @abstractmethod
    def play_best_policy(self, env: Any) -> None:
        """Run the algorithm's best policy in ``env``.

        NOTE(review): "best" is defined by the subclass; confirm how it is
        selected.
        """

    def save_best_policy(self, path: Any) -> None:
        """Optional hook for persisting the best policy to ``path``.

        The base implementation is a deliberate no-op. The expected type of
        ``path`` (e.g. ``str`` or a path object) is left to subclasses.
        """


class DecentralizedRLAlgorithm(RLAlgorithm):
    """Abstract base for decentralized, multi-agent algorithms.

    Stores the number of agents and an ``agents`` list of ``RLAlgorithm``
    instances. The list starts empty here. NOTE(review): presumably
    subclasses fill it with one learner per agent; confirm in concrete
    implementations.
    """

    def __init__(self, n_agents: int) -> None:
        super(DecentralizedRLAlgorithm, self).__init__()
        self.n_agents = n_agents
        # Created empty; nothing in this class populates it.
        self.agents: List[RLAlgorithm] = list()
