import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from abc import ABC, abstractmethod

class AbstractAutoencoder(nn.Module, ABC):
    """Abstract interface for autoencoder modules.

    Combines ``nn.Module`` with ``ABC``: a subclass can only be instantiated
    once it overrides every method marked ``@abstractmethod`` below.
    """

    def __init__(self):
        # Only the nn.Module machinery is initialised here; this base class
        # holds no state of its own.
        super().__init__()

    @abstractmethod
    def forward(self, x):
        """Forward pass over ``x``.

        NOTE: implementations return logits.
        """

    @abstractmethod
    def semantic_BHWC_to_onehot_with_background(self, x):
        """Convert a semantic encoding to one-hot form (with background).

        x: (batchsize, H, W, encoding_dim)
        return onehots: (batchsize, H, W, num_classes)
        """

    @abstractmethod
    def groupids_to_semantics(self, groupids):
        """Map group ids to their semantic vectors.

        groupids: (N,) dtype=torch.int32
        return semantics: (N, 3)
        """

    @abstractmethod
    def get_bg4semantics(self, H, W):
        """Build the background for a semantics image of the given size.

        H, W: int
        return bg: (H, W, 3)
        """

