import torch
from torch import nn
from torch.nn import functional as F

from abc import abstractmethod
from typing import List, Callable, Union, Any, TypeVar, Tuple


# Alias for ``torch.Tensor`` used in the annotations of this module. It is the
# concrete tensor class rather than an unbound ``TypeVar``, so type checkers
# (and ``isinstance`` checks) see the real type.
Tensor = torch.Tensor

class _NotImplementedSampleError(NotImplementedError, RuntimeWarning):
    """Raised by ``Basenet.sample`` when a subclass does not override it.

    Derives from ``NotImplementedError`` to match the other unimplemented
    hooks on ``Basenet``. It also derives from ``RuntimeWarning`` (the type
    ``sample`` used to raise), so existing ``except RuntimeWarning`` handlers
    keep working.
    """


class Basenet(nn.Module):
    """Common interface for models built on top of ``nn.Module``.

    Subclasses override the hooks below. ``encode``, ``decode``, ``sample``
    and ``generate`` raise ``NotImplementedError`` until overridden.

    NOTE: this class does not use ``abc.ABCMeta`` (it only subclasses
    ``nn.Module``, whose metaclass is plain ``type``). The ``@abstractmethod``
    markers on ``forward`` and ``loss_function`` are therefore documentary
    only. A subclass that forgets to override them can still be instantiated,
    and the inherited versions simply return ``None``.
    """

    def __init__(self) -> None:
        super().__init__()

    def encode(self, input: Tensor) -> List[Tensor]:
        """Encode ``input`` into a list of tensors.

        Presumably the parameters of a latent distribution; the exact
        contract is defined by each subclass (TODO confirm).

        Raises:
            NotImplementedError: unless overridden by a subclass.
        """
        raise NotImplementedError(
            f"{type(self).__name__} does not implement encode()"
        )

    def decode(self, input: Tensor) -> Any:
        """Decode ``input`` (presumably a latent code) back to model output.

        Raises:
            NotImplementedError: unless overridden by a subclass.
        """
        raise NotImplementedError(
            f"{type(self).__name__} does not implement decode()"
        )

    def sample(self, batch_size: int, current_device: int, **kwargs) -> Tensor:
        """Draw ``batch_size`` samples from the model.

        ``current_device`` is presumably the device index the samples should
        be placed on (TODO confirm against subclasses).

        Raises:
            NotImplementedError: unless overridden by a subclass. The raised
                exception is also a ``RuntimeWarning`` for backward
                compatibility with callers catching the old exception type.
        """
        raise _NotImplementedSampleError(
            f"{type(self).__name__} does not implement sample()"
        )

    def generate(self, x: Tensor, **kwargs) -> Tensor:
        """Produce model output for input ``x`` (e.g. a reconstruction).

        Raises:
            NotImplementedError: unless overridden by a subclass.
        """
        raise NotImplementedError(
            f"{type(self).__name__} does not implement generate()"
        )

    @abstractmethod
    def forward(self, *inputs: Tensor) -> Tensor:
        """Run the model's forward pass; subclasses must override.

        Not enforced (see class docstring): the base version returns ``None``.
        """
        pass

    @abstractmethod
    def loss_function(self, *inputs: Any, **kwargs) -> Tensor:
        """Compute the training loss; subclasses must override.

        Not enforced (see class docstring): the base version returns ``None``.
        """
        pass