import torch.nn as nn
import torch

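# Backbones for ImageNet-resolution inputs: with 224x224 images the last conv
# stage outputs a 7x7 map, which the final AvgPool2d(7) reduces to 1x1.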


def first_block(in_channels, out_channels):
  """Return the stem stage: 7x7/stride-2 conv -> BatchNorm -> ReLU -> 2x2 max-pool.

  The strided convolution (padding 3) maps a spatial size H to ceil(H / 2),
  and the pool then halves that again (flooring), e.g. 224 -> 112 -> 56.

  Args:
    in_channels: Channels of the incoming tensor.
    out_channels: Channels produced by the convolution.

  Returns:
    An ``nn.Sequential`` whose children are indexed 0-3 in the order above.
  """
  stem_conv = nn.Conv2d(
      in_channels,
      out_channels,
      kernel_size=7,
      stride=2,
      padding=3,
  )
  layers = [stem_conv, nn.BatchNorm2d(out_channels), nn.ReLU(), nn.MaxPool2d(2)]
  return nn.Sequential(*layers)


def conv_block(in_channels, out_channels):
  """Return one conv stage: 3x3 same-padded conv -> BatchNorm -> ReLU -> 2x2 max-pool.

  The convolution keeps the spatial size (stride 1, padding 1); the pool
  halves it (flooring).

  Args:
    in_channels: Channels of the incoming tensor.
    out_channels: Channels produced by the convolution.

  Returns:
    An ``nn.Sequential`` whose children are indexed 0-3 in the order above.
  """
  conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
  norm = nn.BatchNorm2d(out_channels)
  return nn.Sequential(conv, norm, nn.ReLU(), nn.MaxPool2d(kernel_size=2))


class ConvNet(nn.Module):
  """Four-block convolutional backbone returning pooled feature vectors.

  Layout: ``first_block`` (7x7 stride-2 conv stem), then three ``conv_block``
  stages (the last widens to ``z_dim`` channels), a 7x7 average pool, and a
  flatten.

  With 224x224 inputs ``block4`` emits a 7x7 map, so ``AvgPool2d(7)`` reduces
  it to 1x1 and ``features`` has shape ``(batch, z_dim)``. Other input sizes
  change the pooled spatial size (and hence the flattened length); inputs too
  small for a 7x7 pool raise an error.

  Args:
    x_dim: Number of input image channels.
    hid_dim: Channel width of blocks 1-3.
    z_dim: Output channel width of block 4, i.e. the feature dimension.
  """

  def __init__(self, x_dim=3, hid_dim=128, z_dim=512):
    super().__init__()
    self.block1 = first_block(x_dim, hid_dim)
    self.block2 = conv_block(hid_dim, hid_dim)
    self.block3 = conv_block(hid_dim, hid_dim)
    self.block4 = conv_block(hid_dim, z_dim)
    self.avgpool = nn.AvgPool2d(7)
    # The pooled features have z_dim channels; a hard-coded 512 here would be
    # wrong whenever z_dim is overridden.
    self.out_dim = z_dim

  def forward(self, x):
    """Run all four blocks, average-pool, and flatten.

    Args:
      x: Image batch with ``x_dim`` channels (224x224 expected for a 1x1 pool).

    Returns:
      dict with key ``"features"``: tensor of shape ``(batch, -1)``, which is
      ``(batch, z_dim)`` for 224x224 inputs.
    """
    x = self.block1(x)
    x = self.block2(x)
    x = self.block3(x)
    x = self.block4(x)

    x = self.avgpool(x)
    features = x.view(x.shape[0], -1)

    return {
        "features": features
    }


class GeneralizedConvNet(nn.Module):
  """Shared trunk of the split backbone: the stem plus two conv stages.

  Produces an unpooled ``hid_dim``-channel feature map (14x14 for 224x224
  inputs).

  Args:
    x_dim: Number of input image channels.
    hid_dim: Channel width of all three blocks.
    z_dim: Accepted for signature parity with ``ConvNet``; not used here.
  """

  def __init__(self, x_dim=3, hid_dim=128, z_dim=512):
    super().__init__()
    self.block1 = first_block(x_dim, hid_dim)
    self.block2 = conv_block(hid_dim, hid_dim)
    self.block3 = conv_block(hid_dim, hid_dim)

  def forward(self, x):
    """Apply block1, block2 and block3 in order and return the feature map."""
    for stage in (self.block1, self.block2, self.block3):
      x = stage(x)
    return x


class SpecializedConvNet(nn.Module):
  """Final stage of the split backbone: block 4, 7x7 average pool, flatten.

  Expects a ``hid_dim``-channel input map. With a 14x14 input (what the first
  three blocks produce from 224x224 images), ``block4`` yields 7x7, the pool
  yields 1x1, and the output has shape ``(batch, z_dim)``.

  Args:
    hid_dim: Channel count of the incoming feature map.
    z_dim: Output channel width of block 4, i.e. the feature dimension.
  """

  def __init__(self, hid_dim=128, z_dim=512):
    super().__init__()
    self.block4 = conv_block(hid_dim, z_dim)
    self.avgpool = nn.AvgPool2d(7)
    # The pooled features have z_dim channels; a hard-coded 512 here would be
    # wrong whenever z_dim is overridden.
    self.feature_dim = z_dim

  def forward(self, x):
    """Apply block 4, average-pool, and flatten to ``(batch, -1)``."""
    x = self.block4(x)
    x = self.avgpool(x)
    features = x.view(x.shape[0], -1)
    return features


def conv4():
  """Build a ``ConvNet`` using its default sizes (x_dim=3, hid_dim=128, z_dim=512)."""
  return ConvNet()


def conv_a2fc_imagenet():
  """Build the split backbone as a ``(base, adaptive)`` pair.

  ``base`` is a default ``GeneralizedConvNet`` (blocks 1-3) and ``adaptive``
  is a default ``SpecializedConvNet`` (block 4 + pooling). ``base`` is
  constructed first, so parameter initialisation consumes the RNG in that
  order.
  """
  return GeneralizedConvNet(), SpecializedConvNet()