# /usr/bin/env python
# -*- coding: utf-8 -*-

import torch
import torch.nn as nn


class Base(nn.Module):
    """``nn.Module`` that remembers the keyword arguments it was built with.

    The stored mapping lets :meth:`new` construct another instance of the
    same concrete class with the same keyword configuration.
    """

    def __init__(self, **kwargs):
        super().__init__()
        # Kept verbatim so new() can replay construction.
        self.kwargs = kwargs

    def new(self):
        """Return a freshly constructed instance of ``type(self)``.

        The new instance is built by calling the class again with the stored
        keyword arguments, so it gets its own newly created submodules and
        parameters.
        """
        cls = type(self)
        saved = self.kwargs
        return cls(**saved)


class CNN(Base):
    """Four-block convolutional encoder followed by a linear head.

    Each block is ``Conv2d(3x3, padding=1) -> BatchNorm2d -> ReLU ->
    MaxPool2d(2)``. Channel widths go
    ``num_input_channels -> 32 -> 64 -> 64 -> hid_dim``. The encoder output is
    averaged over both spatial axes, so the representation has ``hid_dim``
    features.

    Args:
        num_input_channels: Number of channels in the input images.
        hid_dim: Output channels of the last block, i.e. the size of the
            pooled representation.
        num_classes: Number of logits produced by the linear head.
    """

    _MODES = ('label', 'representation')

    def __init__(self, num_input_channels=3, hid_dim=128, num_classes=2):
        # Forward the configuration to Base, which stores it for ``new()``.
        # Without this, ``new()`` would rebuild a CNN with the default sizes
        # rather than this model's architecture.
        super(CNN, self).__init__(num_input_channels=num_input_channels,
                                  hid_dim=hid_dim,
                                  num_classes=num_classes)
        self.enc = nn.Sequential(
            self.conv_block(num_input_channels, 32),
            self.conv_block(32, 64),
            self.conv_block(64, 64),
            self.conv_block(64, hid_dim),
        )
        self.linear = nn.Linear(hid_dim, num_classes)

    def conv_block(self, in_channels, out_channels):
        """Build one ``Conv3x3 -> BatchNorm -> ReLU -> MaxPool(2)`` stage.

        The padded 3x3 convolution preserves spatial size; the max-pool then
        halves it (rounding down).
        """
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )

    def forward(self, x, mode='label'):
        """Encode ``x`` and return either pooled features or class logits.

        Args:
            x: Image batch of shape ``(N, num_input_channels, H, W)``
                (``BatchNorm2d`` requires 4-D input). H and W must be at
                least 16 so that four 2x max-pools leave a non-empty map.
            mode: ``'representation'`` returns the ``(N, hid_dim)`` pooled
                features; ``'label'`` (default) returns ``(N, num_classes)``
                logits.

        Raises:
            ValueError: If ``mode`` is neither ``'label'`` nor
                ``'representation'``.
        """
        # Validate before running the encoder. An unknown mode used to fall
        # through and return None, which only failed later in the caller.
        if mode not in self._MODES:
            raise ValueError(
                "mode must be 'label' or 'representation', got {!r}".format(mode))

        # Global average pooling over the two spatial axes -> (N, hid_dim).
        features = torch.mean(self.enc(x), dim=(2, 3))
        if mode == 'representation':
            return features
        return self.linear(features)
