
from torch import nn

# from .batch_norm import FrozenBatchNorm2d


class CNNBlockBase(nn.Module):
    """
    Base class for a CNN block with fixed input channels, output channels and stride.

    Subclasses implement `forward()`, which must accept and return NCHW tensors.
    The computation inside `forward()` is unrestricted, but its input/output must
    agree with the `in_channels`, `out_channels` and `stride` declared here.

    Attributes:
        in_channels (int): number of channels the block expects on its input.
        out_channels (int): number of channels the block produces on its output.
        stride (int): stride the block declares for its input-to-output mapping.
    """

    in_channels: int
    out_channels: int
    stride: int

    def __init__(self, in_channels: int, out_channels: int, stride: int) -> None:
        """
        Store the channel and stride specification of the block.

        Any subclass `__init__` is expected to accept these same arguments.

        Args:
            in_channels (int): number of input channels.
            out_channels (int): number of output channels.
            stride (int): declared stride of the block.
        """
        # Initialize nn.Module's internal bookkeeping before setting attributes.
        super().__init__()
        self.in_channels, self.out_channels, self.stride = (
            in_channels,
            out_channels,
            stride,
        )

    # def freeze(self):
    #     """
    #     Make this block not trainable.
    #     This method sets all parameters to `requires_grad=False`,
    #     and convert all BatchNorm layers to FrozenBatchNorm

    #     Returns:
    #         the block itself
    #     """
    #     for p in self.parameters():
    #         p.requires_grad = False
    #     FrozenBatchNorm2d.convert_frozen_batchnorm(self)
    #     return self
