from abc import ABC, abstractmethod

import torch


class StochasticLayer(torch.nn.Module, ABC):

    def __init__(self, network: torch.nn.Module, output_dim: int):

        super().__init__()

        self._network = network
        self._output_dim = output_dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self._network(x)

    def initialize(self) -> None:
        self._network.initialize()

    @property
    def network(self) -> torch.nn.Module:
        return self._network

    @abstractmethod
    def build_distribution(self, x: torch.Tensor,
                           y_mean: torch.Tensor | None = None) -> torch.distributions.Distribution:
        pass
