import torch
from torch import nn


class NeuralNetworkImage(nn.Module):

    def __init__(self):
        super(NeuralNetworkImage, self).__init__()
        self.image_features_ = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=5, stride=2, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Dropout(),
            nn.Conv2d(16, 64, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=1, stride=2),
            nn.Conv2d(64, 64, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=1),
            nn.Conv2d(64, 16, kernel_size=5, padding=3),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Dropout(),
        )

        self.numeric_features_ = nn.Sequential(
            nn.Linear(5, 16),
            nn.ReLU(inplace=True),
            nn. Dropout(),
            nn.Linear(16, 8),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(8, 8),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(8, 16 * 16),
            nn.ReLU(inplace=True),
            nn.Dropout(),
        )
        self.combined_features_ = nn.Sequential(
            nn.Linear(16*16*2, 8),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(8, 8),
            nn.ReLU(inplace=True),
            nn.Linear(8, 16),
            nn.Linear(16, 1),
        )

    def forward(self, x, y):
        x = self.image_features_(x)
        x = x.view(-1, 16*16)
        y = self.numeric_features_(y)

        z = torch.cat((x, y), 1)
        z = self.combined_features_(z)
        return z
