import torch
import torch.nn as nn


class ToyConvBinary(nn.Module):
    """Small convolutional model with a single sigmoid output for 64x64 inputs.

    Spatial flow through ``features``: 64 -conv5-> 60 -avgpool4-> 15
    -conv8-> 8 -avgpool2-> 4, so the flattened feature vector has
    16 * 4 * 4 = 256 entries.

    Args:
        input_channel: number of channels per input image. ``forward``
            reshapes its input to ``(-1, input_channel, 64, 64)``.
        endo: if true (default), the flattened features pass through the
            square linear map ``endo_map`` before the output head; if false,
            that map is bypassed (its parameters are still created).
    """

    def __init__(self, input_channel, endo=True):
        super(ToyConvBinary, self).__init__()
        self.input_channel = input_channel
        self.use_endo_layer = endo
        features = [
            nn.Conv2d(input_channel, 8, 5),
            nn.AvgPool2d(4, 4),
            # In training mode this normalizes with batch statistics;
            # call .eval() for inference to use the running estimates.
            nn.BatchNorm2d(8),
            nn.Conv2d(8, 16, 8),
            nn.AvgPool2d(2, 2),
            nn.Flatten(),
        ]
        self.features = nn.Sequential(*features)
        endo_map = [
            nn.Linear(16 * 4 * 4, 16 * 4 * 4),
        ]
        self.endo_map = nn.Sequential(*endo_map)
        dense = [
            nn.Linear(16 * 4 * 4, 1),
        ]
        self.dense = nn.Sequential(*dense)

    def forward(self, x):
        """Return sigmoid outputs of shape ``(batch, 1)``.

        ``x`` may be any tensor whose element count is a multiple of
        ``input_channel * 64 * 64``; it is reshaped with ``view`` to
        ``(-1, input_channel, 64, 64)``.
        """
        # Use the configured channel count: a fixed 1 here would fold extra
        # channels into the batch dimension and mismatch the first conv.
        x = x.view((-1, self.input_channel, 64, 64))
        feats = self.features(x)
        if self.use_endo_layer:
            feat_maps = self.endo_map(feats)
        else:
            feat_maps = feats
        out = torch.sigmoid(self.dense(feat_maps))
        return out


class BirdSimpleConv(nn.Module):
    """Two-stage convolutional model with a single sigmoid output.

    Expects 3-channel 224x224 images. Spatial flow through ``features``:
    224 -conv25-> 200 -avgpool4-> 50 -conv19-> 32 -avgpool4-> 8, so the
    flattened feature vector has 16 * 8 * 8 = 1024 entries.

    Args:
        endo: if true (default), the flattened features pass through the
            square linear map ``endo_map`` before the output head; if false,
            that map is bypassed (its parameters are still created).
    """

    def __init__(self, endo=True):
        super().__init__()
        self.use_endo_layer = endo
        n_flat = 16 * 8 * 8
        self.features = nn.Sequential(
            nn.Conv2d(3, 8, 25),
            nn.AvgPool2d(4, 4),
            nn.Conv2d(8, 16, 19),
            nn.AvgPool2d(4, 4),
            nn.Flatten(),
        )
        self.endo_map = nn.Sequential(nn.Linear(n_flat, n_flat))
        self.dense = nn.Sequential(nn.Linear(n_flat, 1))

    def forward(self, x):
        """Return sigmoid outputs of shape ``(batch, 1)``.

        ``x`` may be any tensor whose element count is a multiple of
        3 * 224 * 224; it is reshaped with ``view`` to ``(-1, 3, 224, 224)``.
        """
        images = x.view(-1, 3, 224, 224)
        flat = self.features(images)
        hidden = self.endo_map(flat) if self.use_endo_layer else flat
        logits = self.dense(hidden)
        return torch.sigmoid(logits)
