### Author: anonymized for review 

### All code and data released with this supplementary material uses 
### the following license:
### Creative Commons Attribution 4.0 International (CC BY 4.0)
### http://creativecommons.org/licenses/by/4.0

### This license permits use, sharing, adaptation, distribution and 
### reproduction in any medium or format, as long as you give 
### appropriate credit to the original authors and the paper with 
### title "Learning to See Topological Properties in 4D", provide
### a link to the Creative Commons license, and indicate if changes 
### were made.


import numpy as np

import torch
from torch import nn
from torch.nn import functional as F

from layers4d import Conv4de, MaxPool4de


###=== Network parameters

# Per-head class counts: head i of the network outputs
# distribution_of_classes[i] + 1 logits (presumably labels 0..k inclusive
# for k = distribution_of_classes[i] — confirm against the dataset).
distribution_of_classes = np.array([1, 8, 8, 8])
# Total width of the final linear layer: the logits of all heads concatenated.
num_classes = int((distribution_of_classes + 1).sum())

out_chan = 8        # channels of the first conv; conv2/conv3 use 2x and 4x this
ker_size = 5        # kernel size of every Conv4de layer
padding = 1         # constant pad width on both sides of each of the last 4 axes
pool_ker_size = 2   # window of every MaxPool4de layer
# Flattened feature width fed to fc1: conv3's channel count (out_chan * 4)
# times the remaining spatial volume. NOTE(review): 2**4 presumably means
# 2 positions along each of the 4 axes after three poolings, which ties this
# value to one specific input size — confirm.
reshape_size = (out_chan * 4) * 2**4


###=== The network

class Network(torch.nn.Module):
  """4D convolutional classifier with one output head per class group.

  Built from the module-level parameters: three
  [pad -> Conv4de -> ReLU -> MaxPool4de] stages with out_chan, 2*out_chan
  and 4*out_chan channels, then two fully connected layers
  (reshape_size -> reshape_size // 2 -> num_classes). The final logits
  are split into one tensor per entry of ``distribution_of_classes``.

  NOTE(review): assumes the input is (batch, 1, d1, d2, d3, d4) and that
  its spatial size makes the flattened conv features exactly
  ``reshape_size`` wide — confirm against layers4d and the data pipeline.
  """

  def __init__(self):
    super().__init__()

    self.conv1 = Conv4de(1, out_chan, kernel_size=ker_size)
    self.pool1 = MaxPool4de(pool_ker_size)

    self.conv2 = Conv4de(out_chan, out_chan * 2, kernel_size=ker_size)
    self.pool2 = MaxPool4de(pool_ker_size)

    self.conv3 = Conv4de(out_chan * 2, out_chan * 4, kernel_size=ker_size)
    self.pool3 = MaxPool4de(pool_ker_size)

    self.fc1 = nn.Linear(reshape_size, reshape_size // 2)
    self.fc2 = nn.Linear(reshape_size // 2, num_classes)

  @staticmethod
  def _pad(x, value=0):
    """Constant-pad the last four axes of ``x`` by ``padding`` on both sides."""
    # F.pad consumes (left, right) pairs starting from the last dimension,
    # so 8 entries pad the trailing 4 axes.
    return F.pad(x, (padding,) * 8, 'constant', value=value)

  @staticmethod
  def _split_heads(logits, head_sizes):
    """Split ``logits`` along dim 1 into consecutive per-head slices.

    Args:
      logits: 2D tensor (batch, total_width).
      head_sizes: iterable of ints; head ``i`` receives
        ``head_sizes[i] + 1`` columns, in order.

    Returns:
      A tuple with one tensor (a view of ``logits``) per head.
    """
    heads = []
    start = 0
    for size in head_sizes:
      stop = start + int(size) + 1
      heads.append(logits[:, start:stop])
      start = stop
    return tuple(heads)

  def forward(self, x):
    """Run the network and return a tuple of per-head logit tensors.

    The tuple has ``len(distribution_of_classes)`` entries; entry ``i`` has
    ``distribution_of_classes[i] + 1`` columns.
    """
    # NOTE(review): the raw input is padded with 1 while hidden feature maps
    # are padded with 0; presumably deliberate (input background value?) —
    # confirm.
    x = self._pad(x, value=1)
    x = self.pool1(F.relu(self.conv1(x)))

    x = self._pad(x)
    x = self.pool2(F.relu(self.conv2(x)))

    x = self._pad(x)
    x = self.pool3(F.relu(self.conv3(x)))

    # Flatten all non-batch axes for the fully connected layers.
    x = x.reshape(x.shape[0], -1)

    x = F.relu(self.fc1(x))
    x = self.fc2(x)

    return self._split_heads(x, distribution_of_classes)
