from collections import OrderedDict

from .layers import *
from torch import nn


class LayerBuilder:
    """Factory for convolution, batch-norm and activation layers.

    Hyper-parameters shared by every layer are read from ``config``, which
    must support ``config[key]`` lookup for these keys:

    * ``'conv_init'``   -- passed as ``mode`` to ``nn.init.kaiming_normal_``
      (PyTorch accepts ``'fan_in'`` or ``'fan_out'``).
    * ``'bn_momentum'`` -- BatchNorm momentum, or ``None`` to keep the
      ``nn.BatchNorm2d`` default.
    * ``'bn_epsilon'``  -- BatchNorm epsilon, or ``None`` to keep the
      ``nn.BatchNorm2d`` default.
    * ``'activation'``  -- ``'silu'`` or ``'relu'``.
    """

    def __init__(self, config):
        self.config = config

    def conv(
        self,
        kernel_size,
        in_planes,
        out_planes,
        groups=1,
        stride=1,
        bn=False,
        zero_init_bn=False,
        act=False,
    ):
        """Build a bias-free ``nn.Conv2d``, optionally followed by BN and activation.

        Padding is ``(kernel_size - 1) // 2``, which keeps the spatial size
        unchanged at ``stride=1`` for odd kernel sizes.

        Args:
            kernel_size: Integer side length of the square kernel.
            in_planes: Number of input channels.
            out_planes: Number of output channels.
            groups: Convolution groups (``in_planes`` gives a depthwise conv).
            stride: Convolution stride.
            bn: Append a ``BatchNorm2d`` built by :meth:`batchnorm`.
            zero_init_bn: Initialise that BatchNorm's weight (gamma) to 0.
            act: Append the activation built by :meth:`activation`.

        Returns:
            The bare ``nn.Conv2d`` when neither ``bn`` nor ``act`` is set.
            Otherwise an ``nn.Sequential`` whose children are named ``"conv"``,
            then ``"bn"`` (if ``bn``), then ``"act"`` (if ``act``).
        """
        conv = nn.Conv2d(
            in_planes,
            out_planes,
            kernel_size=kernel_size,
            groups=groups,
            stride=stride,
            # Integer floor division; equals the former int((k - 1) / 2)
            # for every kernel_size >= 1 without a float round-trip.
            padding=(kernel_size - 1) // 2,
            bias=False,
        )

        # ReLU gain is used regardless of the configured activation.
        nn.init.kaiming_normal_(
            conv.weight, mode=self.config['conv_init'], nonlinearity="relu"
        )

        layers = [("conv", conv)]
        if bn:
            layers.append(("bn", self.batchnorm(out_planes, zero_init_bn)))
        if act:
            layers.append(("act", self.activation()))

        if bn or act:
            return nn.Sequential(OrderedDict(layers))
        return conv

    def convDepSep(
        self,
        kernel_size,
        in_planes,
        out_planes,
        stride=1,
        bn=False,
        act=False,
    ):
        """Build a depthwise convolution (``groups=in_planes``) with same-style padding.

        Despite the name, no pointwise (1x1) stage is added: this is only the
        depthwise half of a depthwise-separable convolution. Any
        ``kernel_size`` is accepted. ``nn.Conv2d`` requires ``out_planes`` to
        be divisible by ``in_planes`` here.

        Returns whatever :meth:`conv` returns for the same flags.
        """
        return self.conv(
            kernel_size,
            in_planes,
            out_planes,
            groups=in_planes,
            stride=stride,
            bn=bn,
            act=act,
        )

    def batchnorm(self, planes, zero_init=False):
        """Build an ``nn.BatchNorm2d`` over ``planes`` channels.

        ``bn_momentum`` / ``bn_epsilon`` from the config override the PyTorch
        defaults only when they are not ``None``. The weight (gamma) is set
        to 0 if ``zero_init`` is true, else 1; the bias (beta) is set to 0.
        """
        bn_cfg = {}
        if self.config['bn_momentum'] is not None:
            bn_cfg["momentum"] = self.config['bn_momentum']
        if self.config['bn_epsilon'] is not None:
            bn_cfg["eps"] = self.config['bn_epsilon']

        bn = nn.BatchNorm2d(planes, **bn_cfg)
        nn.init.constant_(bn.weight, 0 if zero_init else 1)
        nn.init.constant_(bn.bias, 0)
        return bn

    def activation(self):
        """Return a new in-place activation module selected by ``config['activation']``.

        Raises:
            KeyError: if the configured name is neither ``'silu'`` nor ``'relu'``.
        """
        return {
            "silu": lambda: nn.SiLU(inplace=True),
            "relu": lambda: nn.ReLU(inplace=True),
        }[self.config['activation']]()


class EFATLayerBuilder(LayerBuilder):
    """``LayerBuilder`` whose :meth:`conv` wraps ``nn.Conv2d`` in ``EFATConv``.

    The constructor, ``batchnorm``, ``activation`` and ``convDepSep`` are
    inherited; ``convDepSep`` therefore goes through this ``conv`` with the
    default (empty) ``EFAT_kwargs``.
    """

    def conv(
        self,
        kernel_size,
        in_planes,
        out_planes,
        groups=1,
        stride=1,
        bn=False,
        zero_init_bn=False,
        act=False,
        EFAT_kwargs=None,
    ):
        """Build an ``EFATConv`` layer, optionally followed by BN and activation.

        ``EFATConv`` (from ``.layers``) is given ``nn.Conv2d`` and the bias-free
        conv keyword arguments below. Padding is ``(kernel_size - 1) // 2``.

        Args:
            kernel_size: Integer side length of the square kernel.
            in_planes: Number of input channels.
            out_planes: Number of output channels.
            groups: Convolution groups.
            stride: Convolution stride.
            bn: Append a ``BatchNorm2d`` built by :meth:`batchnorm`.
            zero_init_bn: Initialise that BatchNorm's weight (gamma) to 0.
            act: Append the activation built by :meth:`activation`.
            EFAT_kwargs: Extra keyword arguments forwarded to ``EFATConv``;
                ``None`` (the default) forwards nothing.

        Returns:
            The bare ``EFATConv`` when neither ``bn`` nor ``act`` is set.
            Otherwise an ``nn.Sequential`` whose children are named ``"conv"``,
            then ``"bn"`` (if ``bn``), then ``"act"`` (if ``act``).
        """
        # None sentinel instead of a shared mutable {} default.
        if EFAT_kwargs is None:
            EFAT_kwargs = {}

        conv_kwargs = {
            'kernel_size': kernel_size,
            'groups': groups,
            'stride': stride,
            'padding': (kernel_size - 1) // 2,
            'bias': False,
        }
        conv = EFATConv(in_planes, out_planes, nn.Conv2d, conv_kwargs, **EFAT_kwargs)

        # NOTE(review): assumes EFATConv exposes two conv sub-modules as
        # `fun_l` / `fun_r` (presumably built from nn.Conv2d + conv_kwargs);
        # both get the same Kaiming init the base builder uses.
        for fn in (conv.fun_l, conv.fun_r):
            nn.init.kaiming_normal_(
                fn.weight, mode=self.config['conv_init'], nonlinearity="relu"
            )

        layers = [("conv", conv)]
        if bn:
            layers.append(("bn", self.batchnorm(out_planes, zero_init_bn)))
        if act:
            layers.append(("act", self.activation()))

        if bn or act:
            return nn.Sequential(OrderedDict(layers))
        return conv

