import torch
import torch.nn as nn
import torch.nn.functional as F


class Expression(nn.Module):
    """Adapt an arbitrary callable into an ``nn.Module`` layer.

    Args:
        func: Callable that ``forward`` applies to its input and whose result
            it returns unchanged.
    """

    def __init__(self, func):
        super().__init__()
        self.func = func

    def forward(self, input):
        wrapped = self.func
        result = wrapped(input)
        return result


class Model(nn.Module):
    """Two-conv / two-FC classifier with a selectable ReLU or Softplus activation.

    Pipeline: conv5x5(i_c->32) -> act -> maxpool2 -> conv5x5(32->64) -> act
    -> maxpool2 -> flatten -> fc(flat_features->1024) -> act -> fc(1024->n_c).
    The 5x5 convolutions use stride 1 and padding 2, so only the two 2x2
    max-pools change the spatial size (each floors H and W by 2).

    Args:
        i_c: Number of input channels.
        n_c: Number of output logits.
        act: ``'ReLU'`` selects ReLU; any other value selects Softplus.
        softplus_beta: ``beta`` of the Softplus activation (unused for ReLU).
        flat_features: Input width of ``fc1``, i.e. ``64 * (H // 4) * (W // 4)``
            for an ``H x W`` input. The default ``14 * 7 * 64`` fits e.g. a
            56x28 or 28x56 input; pass ``7 * 7 * 64`` for 28x28 inputs.
            NOTE(review): confirm the intended default input size.
    """

    def __init__(self, i_c=1, n_c=10, act='ReLU', softplus_beta=0.5,
                 flat_features=14 * 7 * 64):
        super(Model, self).__init__()

        self.conv1 = nn.Conv2d(i_c, 32, 5, stride=1, padding=(2, 2), bias=True)
        self.pool1 = nn.MaxPool2d((2, 2), stride=(2, 2), padding=0)

        self.conv2 = nn.Conv2d(32, 64, 5, stride=1, padding=2, bias=True)
        self.pool2 = nn.MaxPool2d((2, 2), stride=(2, 2), padding=0)

        # nn.Flatten replaces a lambda wrapped in Expression. A lambda local
        # to __init__ cannot be pickled, which breaks torch.save(model).
        # Flatten has no parameters, so state_dict keys are unchanged.
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(flat_features, 1024, bias=True)
        self.fc2 = nn.Linear(1024, n_c)
        self.relu_act = act == 'ReLU'
        self.softplus_beta = softplus_beta

    def _act(self, x):
        """Apply the configured activation (ReLU, or Softplus with softplus_beta)."""
        if self.relu_act:
            return torch.relu(x)
        # Functional form: same result as nn.Softplus(beta=...)(x) without
        # constructing a new module on every call.
        return F.softplus(x, beta=self.softplus_beta)

    def forward(self, x, _eval=False):
        """Compute logits for a batch.

        Args:
            x: Input batch of shape (N, i_c, H, W), with
                ``64 * (H // 4) * (W // 4) == flat_features``.
            _eval: If True the pass runs in eval mode, otherwise in train mode.

        Returns:
            Logits of shape (N, n_c).

        The module's train/eval mode is restored to what it was before the
        call, including when an exception is raised.
        """
        was_training = self.training
        # No mode-dependent layers (dropout/batchnorm) are currently active,
        # so the mode switch does not affect the output; it is kept for
        # compatibility with callers that pass _eval.
        self.train(not _eval)
        try:
            x = self.pool1(self._act(self.conv1(x)))
            x = self.pool2(self._act(self.conv2(x)))
            x = self._act(self.fc1(self.flatten(x)))
            return self.fc2(x)
        finally:
            self.train(was_training)


if __name__ == '__main__':
    # Smoke test. fc1 is sized for 64 * 14 * 7 flattened features, so the
    # input must pool (by 4 in each dimension) down to 14x7 -- e.g. 56x28.
    # A 28x28 input fails with a shape mismatch in fc1.
    # torch.randn is used because torch.FloatTensor(...) is uninitialized.
    i = torch.randn(4, 1, 56, 28)

    n = Model()

    print(n(i).size())
