import gym
import torch

from policy.base import BasePolicy


class RandomPolicy(BasePolicy):
    """Policy whose actions come from ``action_space.sample()``.

    Only `explore` is implemented. The learning-related properties,
    `greedy`, `learn` and `forward` all raise ``NotImplementedError``.
    """

    def __init__(self, action_space: gym.spaces.Discrete):
        """Keep a reference to the space that `explore` samples from.

        Args:
            action_space: Space whose ``sample()`` supplies each action.
        """
        super().__init__()
        self.action_space = action_space

        # eval -- presumably this flag is read by BasePolicy or by callers;
        # TODO confirm where it is consumed.
        self.explore_mode = False

    @property
    def replay_buffer_capacity(self) -> int:
        """Not provided by this policy; accessing it always raises."""
        raise NotImplementedError

    @property
    def learn_batch_size(self) -> int:
        """Not provided by this policy; accessing it always raises."""
        raise NotImplementedError

    @property
    def n_backprop_steps(self) -> int:
        """Not provided by this policy; accessing it always raises."""
        raise NotImplementedError

    def explore(self, obs, h) -> tuple[torch.Tensor, torch.Tensor, None, None]:
        """Sample one action from the action space.

        Args:
            obs: Ignored; the sampled action does not depend on it.
            h: Returned unchanged as the second tuple element.

        Returns:
            ``(action, h, None, None)`` where ``action`` is the sample wrapped
            in a tensor with two leading singleton dimensions.
        """
        sampled = self.action_space.sample()
        # Indexing with None twice prepends two size-1 axes, exactly like
        # calling unsqueeze(0) twice.
        action = torch.tensor(sampled)[None, None]
        return action, h, None, None

    def greedy(self, obs, h) -> tuple[torch.Tensor, torch.Tensor, None, None]:
        """Not implemented for this policy; always raises."""
        raise NotImplementedError

    def learn(self, memory) -> list[float]:
        """Not implemented for this policy; always raises."""
        raise NotImplementedError

    def forward(self, obs, h):
        """Not implemented for this policy; always raises."""
        raise NotImplementedError
