import torch.nn as nn
from timm.layers.adaptive_avgmax_pool import SelectAdaptivePool2d


def _weights_init(m):
    """He (Kaiming-normal) initialise the weights of conv and linear layers.

    Meant to be passed to ``nn.Module.apply``. Only ``nn.Conv2d`` and
    ``nn.Linear`` modules are touched, and only their ``weight`` tensor;
    biases and every other module type keep their existing values.
    """
    if not isinstance(m, (nn.Conv2d, nn.Linear)):
        return
    nn.init.kaiming_normal_(m.weight, nonlinearity="relu")


class PlainResnet18(nn.Module):
    """18-layer plain (no skip connections, no normalisation) conv classifier.

    Layout: a 3x3 stride-2 stem conv, four stages of four 3x3 convs each
    (stages 2-4 halve the spatial size and double the channels,
    64 -> 128 -> 256 -> 512), global average pooling and a linear head.
    17 conv layers + 1 linear layer = 18 weighted layers.

    Args:
        num_classes: Number of output logits (default 10).
        in_chans: Number of input image channels (default 3).
        drop_rate: Dropout probability applied after every conv inside the
            stages (default 0.4).
    """

    def __init__(
        self, *, num_classes: int = 10, in_chans: int = 3, drop_rate: float = 0.4
    ) -> None:
        super().__init__()
        # NOTE(review): padding=2 on a 3x3/stride-2 conv is unusual (padding=1
        # is the common choice); e.g. a 32x32 input gives a 17x17 map rather
        # than 16x16. Kept unchanged to preserve behaviour — confirm intent.
        # NOTE(review): no activation follows this conv, so it composes
        # linearly with the first conv of the next stage. Inserting one would
        # shift the Sequential indices (and state_dict keys); left as-is.
        stem = nn.Conv2d(in_chans, 64, 3, stride=2, padding=2)
        stages = [
            self.block(64, downscale=False, pdropout=drop_rate),
            self.block(64, pdropout=drop_rate),
            self.block(128, pdropout=drop_rate),
            self.block(256, pdropout=drop_rate),
        ]
        # timm's ``flatten`` argument is a bool: when true it appends its own
        # nn.Flatten(1) after pooling, so pass True rather than a module object.
        pool = SelectAdaptivePool2d(pool_type="avg", flatten=True)
        head = nn.Linear(512, num_classes)

        # Order (and thus ``_net.<i>`` state_dict keys) is: stem, 4 stages,
        # pool, head — identical to the original layout.
        self._net = nn.Sequential(stem, *stages, pool, head)
        self._net.apply(_weights_init)

    def forward(self, x):
        """Return class logits for the image batch ``x``.

        Assumes ``x`` is ``(batch, in_chans, H, W)``; the result is
        ``(batch, num_classes)`` after pooling, flattening and the linear head.
        """
        return self._net(x)

    @staticmethod
    def block(din: int, downscale: bool = True, pdropout: float = 0.4) -> nn.Sequential:
        """Build one stage of four ``Conv3x3 -> Dropout -> ReLU`` units.

        Args:
            din: Input channel count.
            downscale: If true, the first conv uses stride 2 and the stage
                outputs ``2 * din`` channels; otherwise stride 1 and ``din``.
            pdropout: Dropout probability after each conv.

        Returns:
            An ``nn.Sequential`` of 12 modules laid out as
            ``[conv, dropout, relu] * 4``.
        """
        stride = 2 if downscale else 1
        dout = din * stride
        layers = []
        for i in range(4):
            first = i == 0
            layers += [
                # Only the first conv changes channel count / resolution.
                nn.Conv2d(
                    din if first else dout,
                    dout,
                    3,
                    padding=1,
                    stride=stride if first else 1,
                    bias=False,
                ),
                # Dropout before ReLU is equivalent to after it: zeroing and
                # the positive 1/(1-p) rescale both commute with ReLU.
                nn.Dropout(pdropout),
                nn.ReLU(),
            ]
        return nn.Sequential(*layers)
