# -*- coding: utf-8 -*-


from torch import nn
import torch.nn.functional as F


class SCN(nn.Module):
    """Small CNN classifier that also returns its pooled conv features.

    Layer stack: four unpadded 5x5 convolutions (6, 16, 32, 64 output
    channels), each followed by ReLU, with 2x2 max-pooling after the first
    two; then a global adaptive average pool and two linear layers
    ``Linear(64, 2)`` -> ``Linear(2, class_num)``.

    Args:
        input_num: Number of input channels (``in_channels`` of ``conv1``).
        class_num: Number of output logits produced by ``fc4``.
    """

    def __init__(self, input_num, class_num):
        super().__init__()

        self.conv1 = nn.Conv2d(input_num, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.conv3 = nn.Conv2d(16, 32, 5)
        self.conv4 = nn.Conv2d(32, 64, 5)
        # Global average pool: collapses whatever H x W remains to 1 x 1.
        self.AAPool = nn.AdaptiveAvgPool2d(1)

        # NOTE(review): there is no activation between fc3 and fc4, so the two
        # layers compose into a single affine map through a 2-d bottleneck.
        # Confirm this is intended (e.g. for 2-d feature visualisation).
        self.fc3 = nn.Linear(64, 2)
        self.fc4 = nn.Linear(2, class_num)

    def forward(self, x):
        """Run the network on a batch of images.

        Args:
            x: Tensor of shape ``(N, input_num, H, W)``. Every convolution is
                unpadded, so ``H`` and ``W`` must each be at least 48. That
                leaves ``conv4`` at least one full 5x5 window.

        Returns:
            A tuple ``(logits, embedding)``. ``logits`` has shape
            ``(N, class_num)``. ``embedding`` is the globally pooled ``conv4``
            activation with shape ``(N, 64, 1, 1)``, detached from autograd.
        """
        x = F.max_pool2d(F.relu(self.conv1(x)), 2)
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        # No pooling between conv3 and conv4. An adaptive pool to 1x1 here
        # would leave conv4's 5x5 kernel larger than its input.
        x = F.relu(self.conv3(x))
        x = F.relu(self.conv4(x))

        x = self.AAPool(x)
        embedding = x.detach()
        # flatten(1) keeps the batch dimension even when N == 1. A bare
        # squeeze() would drop it and yield (class_num,) instead.
        x = x.flatten(1)
        x = self.fc3(x)
        x = self.fc4(x)
        return x, embedding
