from abc import ABC, abstractmethod

import gym
import torch


class BaseEncoder(ABC, torch.nn.Module):
    """Abstract interface for encoders: pixel observation -> latent.

    Both ``__init__`` and ``forward`` are abstract, so this class cannot be
    instantiated directly and a concrete encoder has to override the two of
    them. A subclass constructor can still delegate to this ``__init__`` via
    ``super().__init__(observation_space, latent_size)`` to initialise the
    underlying ``torch.nn.Module`` and store the shared attributes.
    """

    @abstractmethod
    def __init__(self, observation_space: gym.spaces.Box, latent_size: int):
        """Initialise the module and record the encoder configuration.

        Args:
            observation_space: Space of the observations fed to the encoder;
                stored unchanged as ``self.observation_space``.
            latent_size: Size of the latent; stored unchanged as
                ``self.latent_size``.
        """
        # Module initialisation must run before any attribute is assigned.
        super().__init__()
        self.latent_size = latent_size
        self.observation_space = observation_space

    @abstractmethod
    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        """Encode an observation tensor into a latent tensor.

        Must be overridden by subclasses. The base implementation has no
        body and returns ``None`` if it is invoked explicitly.

        Args:
            obs: Observation tensor to encode.

        Returns:
            The latent tensor produced by the concrete encoder.
        """
