from torch import nn

from ..utils import DistriConfig


class BaseModule(nn.Module):
    """Common base for wrappers that hold an ``nn.Module`` plus a ``DistriConfig``.

    The class only stores state; ``forward`` is left for subclasses to
    implement.

    Attributes set by ``__init__``:
        module: The wrapped module (registered as a child submodule, since
            it is assigned as an attribute of an ``nn.Module``).
        distri_config: The configuration object handed to the constructor.
        comm_manager: ``None`` until ``set_comm_manager`` is called.
        counter: Integer counter, starts at ``0``; see ``set_counter``.
        buffer_list: Starts as ``None``; not touched by this class
            (presumably filled in by subclasses -- confirm against them).
        idx: Starts as ``None``; not touched by this class
            (presumably filled in by subclasses -- confirm against them).
    """

    def __init__(
        self,
        module: nn.Module,
        distri_config: DistriConfig,
    ) -> None:
        """Store the wrapped module and its config, and reset all bookkeeping.

        Args:
            module: The module to wrap.
            distri_config: Configuration object kept on the instance.
        """
        # nn.Module must be initialised before any attribute assignment.
        super().__init__()

        # Constructor arguments.
        self.distri_config = distri_config
        self.module = module

        # State that is supplied or updated after construction.
        self.buffer_list = self.idx = None
        self.comm_manager = None
        self.counter = 0

    def forward(self, *args, **kwargs):
        """Abstract placeholder; always raises ``NotImplementedError``."""
        raise NotImplementedError()

    def set_counter(self, counter: int = 0) -> None:
        """Overwrite ``self.counter`` with ``counter`` (``0`` when omitted)."""
        self.counter = counter

    def set_comm_manager(self, comm_manager) -> None:
        """Attach ``comm_manager`` to this instance as ``self.comm_manager``."""
        self.comm_manager = comm_manager
