from abc import ABC, abstractmethod

import torch

from collector.replay_buffer.episode_replay import TrajectoryReplayBuffer


class BasePolicy(ABC, torch.nn.Module):
    """Abstract interface for policy modules.

    Combines ``ABC`` with ``torch.nn.Module``. Every member declared here is
    abstract, so a subclass has to override all of them before it can be
    instantiated.
    """

    @property
    @abstractmethod
    def ON_POLICY(self) -> bool:
        """Boolean flag supplied by the concrete policy.

        Presumably tells callers whether the algorithm is on-policy --
        TODO confirm against the code that reads it.
        """

    @property
    @abstractmethod
    def REPLAY_BUFFER_CAPACITY(self) -> int:
        """Integer supplied by the concrete policy.

        Presumably the capacity of the replay buffer used with this policy;
        the unit is not visible here -- TODO confirm.
        """

    @property
    @abstractmethod
    def LEARN_BATCH_SIZE(self) -> int:
        """Integer supplied by the concrete policy.

        Presumably the batch size used while learning -- TODO confirm.
        """

    @property
    @abstractmethod
    def n_backprop_steps(self) -> int:
        """Integer supplied by the concrete policy.

        Presumably the number of backpropagation steps per learning phase
        -- TODO confirm.
        """

    # NOTE(review): in the return annotations of ``explore`` and ``greedy``
    # the trailing ``tuple[torch.Tensor]`` denotes a tuple of exactly one
    # tensor. If a variable-length tuple is meant, the annotation would be
    # ``tuple[torch.Tensor, ...]`` -- confirm with the concrete policies.

    @abstractmethod
    def explore(
        self, obs: torch.Tensor, h: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor | None, tuple[torch.Tensor]]:
        """Abstract hook taking an observation tensor and a tensor ``h``.

        Presumably the exploratory (non-greedy) action-selection path, with
        ``h`` being a recurrent/hidden state -- TODO confirm.

        Args:
            obs: Observation tensor; shape and dtype are not specified here.
            h: Second input tensor; shape and dtype are not specified here.

        Returns:
            A 5-tuple: three tensors, then a tensor or ``None``, then a tuple
            of tensors. The meaning of each position is defined by the
            concrete policy.
        """

    @abstractmethod
    def greedy(
        self, obs: torch.Tensor, h: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor | None, tuple[torch.Tensor]]:
        """Abstract hook with the same signature as ``explore``.

        Presumably the greedy (non-exploratory) action-selection path --
        TODO confirm.

        Args:
            obs: Observation tensor; shape and dtype are not specified here.
            h: Second input tensor; shape and dtype are not specified here.

        Returns:
            A 5-tuple: three tensors, then a tensor or ``None``, then a tuple
            of tensors. The meaning of each position is defined by the
            concrete policy.
        """

    @abstractmethod
    def learn(
        self, memory: TrajectoryReplayBuffer
    ) -> tuple[dict[str, float], dict[str, float]]:
        """Abstract hook that receives a ``TrajectoryReplayBuffer``.

        Presumably performs a learning update from the buffered data --
        TODO confirm.

        Args:
            memory: The replay buffer handed to the policy.

        Returns:
            A pair of ``str -> float`` dictionaries. What each dictionary
            holds is not visible here (presumably metrics/statistics --
            confirm with callers).
        """

    # @abstractmethod
    # def forward(
    #         self,
    #         obs: torch.Tensor,
    #         h: torch.Tensor
    # ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    #     pass
