### Author: anonymized for review 

### All code and data released with this supplementary material uses 
### the following license:
### Creative Commons Attribution 4.0 International (CC BY 4.0)
### http://creativecommons.org/licenses/by/4.0

### This license permits use, sharing, adaptation, distribution and 
### reproduction in any medium or format, as long as you give 
### appropriate credit to the original authors and the paper with 
### title ''Learning to See Topological Properties in 4D'', provide 
### a link to the Creative Commons license, and indicate if changes 
### were made.


import torch
from torch import nn
import math


###=== 4D convolution layer (Extension)

class Conv4de(torch.nn.Module):
    """4D convolution assembled from ``kernel_size[0]`` 3D convolutions ("Extension").

    The 4D kernel is sliced along its first axis; slice ``i`` is an
    ``nn.Conv3d`` (no bias, no padding) acting on the trailing three axes.
    Output frame ``t`` is the sum over ``i`` of ``conv3d_layers[i]`` applied
    to padded input frame ``t * stride[0] + i``.

    Args:
        C_in: number of input channels.
        C_out: number of output channels.
        kernel_size: int, or a length-4 sequence ``(kT, kD, kH, kW)``.
        stride: int, length-4 sequence, or an int ``<= 0`` meaning
            "use ``kernel_size``".
        padding: int (that many zeros on both sides of all four trailing
            axes) or a sequence in ``torch.nn.functional.pad`` order
            (last axis first).
        bias: a learnable per-channel bias is added only when ``bias is True``.

    Input is ``(N, C_in, T, D, H, W)``; output is ``(N, C_out, T', D', H', W')``
    and has the same dtype as the input.
    """

    def __init__(self, C_in, C_out, kernel_size, stride = 1, padding = 0, bias = True):
        super().__init__()

        self.C_in = C_in
        self.C_out = C_out

        ###=== Normalise kernel size and stride to 4-tuples
        if isinstance(kernel_size, int):
            self.kernel_size = (kernel_size,) * 4
        else:
            self.kernel_size = tuple(kernel_size)

        if isinstance(stride, int):
            self.stride = self.kernel_size if stride <= 0 else (stride,) * 4
        else:
            self.stride = tuple(stride)

        self.padding = padding

        # Bias init scale as in the original implementation (depends on kT only).
        stdv = 1. / math.sqrt(self.kernel_size[0])

        self.is_bias = False

        if bias is True:
            self.is_bias = True
            self.bias = torch.nn.Parameter(torch.zeros(C_out))
            self.bias.data.uniform_(-stdv, stdv)

        ###=== One bias-free Conv3d per kernel slice along the first axis.
        # Stored in a ModuleList so their weights are registered as parameters.
        self.conv3d_layers = torch.nn.ModuleList()
        for _ in range(self.kernel_size[0]):
            self.conv3d_layers.append(
                torch.nn.Conv3d(self.C_in,
                                self.C_out,
                                kernel_size=self.kernel_size[1:],
                                padding=0,
                                stride=self.stride[1:],
                                bias=False)
            )

    def forward(self, x):
        """Convolve ``x`` of shape ``(N, C_in, T, D, H, W)``."""
        if isinstance(self.padding, int):
            pad = (self.padding,) * 8
        else:
            pad = tuple(self.padding)

        padded = torch.nn.functional.pad(x, pad)  ###=== (N, C, T, D, H, W)

        ###=== Output extent per axis (padding is already included in padded.shape)
        out_sizes = [int((padded.shape[2 + a] - self.kernel_size[a]) / self.stride[a] + 1)
                     for a in range(4)]

        # Allocate in the input's dtype so float64/half models are not
        # silently routed through a default-float32 buffer.
        y = torch.zeros(padded.shape[0], self.C_out, *out_sizes,
                        dtype=x.dtype, device=x.device)

        ###=== Kernel slice i feeds output frame t from input frame t*stride[0] + i
        for i, conv3d in enumerate(self.conv3d_layers):
            for t in range(out_sizes[0]):
                y[:, :, t] += conv3d(padded[:, :, t * self.stride[0] + i])

        if self.is_bias:
            y = y + self.bias.view(1, -1, 1, 1, 1, 1)

        return y
    


###=== 4D max pool layer (Extension)

class MaxPool4de(torch.nn.Module):
    """4D max pooling built from 3D max pooling ("Extension").

    Every frame along the T axis is max-pooled over (D, H, W) with
    ``max_pool3d``. The pooled frames are then reduced with a max over
    sliding windows of ``kernel_size[0]`` frames taken every ``stride[0]``
    frames.

    Args:
        kernel_size: int, or a length-4 sequence ``(kT, kD, kH, kW)``.
        stride: int, length-4 sequence, or an int ``<= 0`` (the default)
            meaning "use ``kernel_size``", as in PyTorch's pooling layers.
        padding: int, or a sequence in ``torch.nn.functional.pad`` order
            (last axis first). Padding is filled with ``-inf`` so it never
            wins the max, matching ``nn.MaxPool3d``.

    Input ``(N, C, T, D, H, W)``; output ``(N, C, T', D', H', W')`` in the
    input's dtype. After padding, every axis must be at least as long as
    the kernel along that axis.
    """

    def __init__(self, kernel_size, stride = 0, padding = 0):
        super().__init__()

        if isinstance(kernel_size, int):
            self.kernel_size = (kernel_size,) * 4
        else:
            self.kernel_size = tuple(kernel_size)

        ###=== according to PyTorch default
        if isinstance(stride, int):
            self.stride = self.kernel_size if stride <= 0 else (stride,) * 4
        else:
            self.stride = tuple(stride)

        self.padding = padding

    def forward(self, x):
        if isinstance(self.padding, int):
            pad = (self.padding,) * 8
        else:
            pad = tuple(self.padding)

        # -inf padding: a padded cell must never be the maximum. The previous
        # version padded with zeros and started from a zero buffer, which
        # clamped every result at 0.
        padded = torch.nn.functional.pad(x, pad, value=float('-inf'))  ###=== (N, C, T, D, H, W)
        n, c, t, d, h, w = padded.shape

        ###=== Fold T into the channel axis so one max_pool3d call pools every frame
        pooled = torch.nn.functional.max_pool3d(padded.reshape(n, c * t, d, h, w),
                                                kernel_size=self.kernel_size[1:],
                                                stride=self.stride[1:])
        pooled = pooled.reshape(n, c, t, *pooled.shape[2:])

        ###=== Max over windows of kT pooled frames, one window every sT frames
        return pooled.unfold(2, self.kernel_size[0], self.stride[0]).amax(dim=-1)


###=== 4D average pool layer (Extension)

class AvgPool4de(torch.nn.Module):
    """4D average pooling built from 3D average pooling ("Extension").

    Every frame along the T axis is average-pooled over (D, H, W) with
    ``avg_pool3d``. The pooled frames are then averaged over sliding windows
    of ``kernel_size[0]`` frames taken every ``stride[0]`` frames. Zero
    padding is counted in the average.

    Args:
        kernel_size: int, or a length-4 sequence ``(kT, kD, kH, kW)``.
        stride: int, length-4 sequence, or an int ``<= 0`` (the default)
            meaning "use ``kernel_size``", as in PyTorch's pooling layers.
        padding: int, or a sequence in ``torch.nn.functional.pad`` order
            (last axis first).

    Input ``(N, C, T, D, H, W)``; output ``(N, C, T', D', H', W')`` in the
    input's dtype. After padding, every axis must be at least as long as
    the kernel along that axis.
    """

    def __init__(self, kernel_size, stride = 0, padding = 0):
        super().__init__()

        if isinstance(kernel_size, int):
            self.kernel_size = (kernel_size,) * 4
        else:
            self.kernel_size = tuple(kernel_size)

        ###=== according to PyTorch default
        if isinstance(stride, int):
            self.stride = self.kernel_size if stride <= 0 else (stride,) * 4
        else:
            self.stride = tuple(stride)

        self.padding = padding

    def forward(self, x):
        if isinstance(self.padding, int):
            pad = (self.padding,) * 8
        else:
            pad = tuple(self.padding)

        padded = torch.nn.functional.pad(x, pad)  ###=== (N, C, T, D, H, W)
        n, c, t, d, h, w = padded.shape

        ###=== Fold T into the channel axis so one avg_pool3d call pools every frame
        pooled = torch.nn.functional.avg_pool3d(padded.reshape(n, c * t, d, h, w),
                                                kernel_size=self.kernel_size[1:],
                                                stride=self.stride[1:])
        pooled = pooled.reshape(n, c, t, *pooled.shape[2:])

        ###=== Mean over windows of kT pooled frames (= sum / kernel_size[0])
        return pooled.unfold(2, self.kernel_size[0], self.stride[0]).mean(dim=-1)


###=== 4D convolution layer (Reformulation)

class Conv4dr(torch.nn.Module):
    """4D convolution reformulated as patch extraction + matrix product ("Reformulation").

    The padded input is unfolded into one flattened patch per output
    position. Each patch is multiplied with the weight, reshaped to
    ``(C_out, C_in * kT * kD * kH * kW)``.

    Args:
        C_in: number of input channels.
        C_out: number of output channels.
        kernel_size: int, or a length-4 sequence ``(kT, kD, kH, kW)``.
        stride: int, length-4 sequence, or an int ``<= 0`` meaning
            "use ``kernel_size``".
        padding: int (that many zeros on both sides of all four trailing
            axes) or a sequence in ``torch.nn.functional.pad`` order
            (last axis first).
        bias: a learnable per-channel bias is added only when ``bias is True``
            (the default ``None`` means no bias).

    Input ``(N, C_in, T, D, H, W)``; output ``(N, C_out, T', D', H', W')``.
    """

    def __init__(self, C_in, C_out, kernel_size, stride=1, padding=0, bias=None):
        super().__init__()

        self.C_in = C_in
        self.C_out = C_out

        if isinstance(kernel_size, int):
            self.kernel_size = (kernel_size,) * 4
        else:
            self.kernel_size = tuple(kernel_size)

        if isinstance(stride, int):
            self.stride = self.kernel_size if stride <= 0 else (stride,) * 4
        else:
            self.stride = tuple(stride)

        self.padding = padding

        self.weight = torch.nn.Parameter(torch.zeros((C_out, C_in, *self.kernel_size)))

        # Init scale depends on C_in only (kept from the original implementation).
        stdv = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)

        self.is_bias = False

        if bias is True:
            self.is_bias = True
            self.bias = torch.nn.Parameter(torch.zeros(C_out))
            self.bias.data.uniform_(-stdv, stdv)

    def conv4d_subcubes(self, x, kernel_size, padding=0, stride=1):
        """Extract flattened 4D patches.

        Returns a tensor of shape ``(N, L, C * kT * kD * kH * kW)``, where
        ``L = T' * D' * H' * W'`` is the number of output positions. Each
        row is flattened in ``(C, kT, kD, kH, kW)`` order.
        """
        if isinstance(kernel_size, int):
            kernel_size = (kernel_size, kernel_size, kernel_size, kernel_size)
        if isinstance(padding, int):
            padding = (padding, padding, padding, padding, padding, padding, padding, padding)
        if isinstance(stride, int):
            stride = (stride, stride, stride, stride)

        channels = x.shape[1]

        x = torch.nn.functional.pad(x, padding)  ###=== (N, C, T, D, H, W)
        # -> (N, C, T', D', H', W', kT, kD, kH, kW)
        x = x.unfold(2, kernel_size[0], stride[0]).unfold(3, kernel_size[1], stride[1]).unfold(4, kernel_size[2], stride[2]).unfold(5, kernel_size[3], stride[3])

        # -> (N, C, L, K) with L output positions and K kernel elements
        x = x.contiguous().view(x.shape[0], channels, x.shape[2]*x.shape[3]*x.shape[4]*x.shape[5], kernel_size[0]*kernel_size[1]*kernel_size[2]*kernel_size[3])

        # -> (N, L, C, K) -> (N, L, C*K)
        x = torch.transpose(x, 1, 2)
        x = x.reshape(x.shape[0], x.shape[1], x.shape[2]*x.shape[3])

        return x

    def forward(self, x):
        if isinstance(self.padding, int):
            pad = (self.padding,) * 8
        else:
            pad = tuple(self.padding)

        patches = self.conv4d_subcubes(x, kernel_size=self.kernel_size, padding=pad, stride=self.stride)

        # (C_out, C_in*K), flattened in the same (C_in, kT, kD, kH, kW) order as the patches
        weight = self.weight.view(self.C_out, -1)

        out = torch.matmul(patches, weight.T)  # (N, L, C_out)

        if self.is_bias:
            out = out + self.bias

        # Total padding per (T, D, H, W) axis. F.pad lists (left, right)
        # pairs starting from the last axis, so pair p belongs to axis 3 - p.
        extra = [0, 0, 0, 0]
        for p in range(len(pad) // 2):
            extra[3 - p] = pad[2 * p] + pad[2 * p + 1]

        out_sizes = [int((x.shape[2 + a] + extra[a] - self.kernel_size[a]) / self.stride[a] + 1)
                     for a in range(4)]

        return torch.transpose(out, 1, 2).reshape(x.shape[0], self.C_out, *out_sizes)
    

###=== 4D convolution layer (Naive)

class Conv4dn(torch.nn.Module):
    """Naive 4D convolution: one explicit weighted sum per output element.

    Extremely slow (six nested Python loops). Intended as a readable
    reference for checking the faster implementations.

    Args:
        C_in: number of input channels.
        C_out: number of output channels.
        kernel_size: int, or a length-4 sequence ``(kT, kD, kH, kW)``.
        stride: int, length-4 sequence, or an int ``<= 0`` meaning
            "use ``kernel_size``".
        padding: int (that many zeros on both sides of all four trailing
            axes) or a sequence in ``torch.nn.functional.pad`` order (last
            axis first). Stored in ``self.padding`` as a tuple.
        bias: a learnable per-channel bias is added only when ``bias is True``
            (the default ``None`` means no bias).

    Input ``(N, C_in, T, D, H, W)``; output ``(N, C_out, T', D', H', W')``.
    The output dtype is the promotion of the input and weight dtypes.
    """

    def __init__(self, C_in, C_out, kernel_size, stride=1, padding=0, bias=None):
        super().__init__()

        self.C_in = C_in
        self.C_out = C_out

        if isinstance(kernel_size, int):
            self.kernel_size = (kernel_size,) * 4
        else:
            self.kernel_size = tuple(kernel_size)

        if isinstance(stride, int):
            self.stride = self.kernel_size if stride <= 0 else (stride,) * 4
        else:
            self.stride = tuple(stride)

        # Always stored in F.pad order; previously a non-int padding left
        # self.padding unset and forward() failed with AttributeError.
        if isinstance(padding, int):
            self.padding = (padding,) * 8
        else:
            self.padding = tuple(padding)

        self.weight = torch.nn.Parameter(torch.zeros((self.C_out, self.C_in, *self.kernel_size)))

        # Init scale depends on C_in only (kept from the original implementation).
        stdv = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)

        self.is_bias = False

        if bias is True:
            self.is_bias = True
            self.bias = torch.nn.Parameter(torch.zeros(C_out))
            self.bias.data.uniform_(-stdv, stdv)

    def forward(self, x):
        padded = torch.nn.functional.pad(x, self.padding)  ###=== (N, C, T, D, H, W)

        kt, kd, kh, kw = self.kernel_size
        st, sd, sh, sw = self.stride

        # Output extent per axis, taken from the padded tensor so asymmetric
        # padding is handled correctly.
        n_t, n_d, n_h, n_w = (int((padded.shape[2 + a] - self.kernel_size[a]) / self.stride[a] + 1)
                              for a in range(4))

        # Promote instead of defaulting to float32 so float64 references stay exact.
        result = torch.zeros(padded.shape[0], self.C_out, n_t, n_d, n_h, n_w,
                             dtype=torch.promote_types(x.dtype, self.weight.dtype),
                             device=x.device)

        for n in range(padded.shape[0]):
            for c in range(self.C_out):
                for ti in range(n_t):
                    for di in range(n_d):
                        for hi in range(n_h):
                            for wi in range(n_w):
                                window = padded[n, :,
                                                ti * st:ti * st + kt,
                                                di * sd:di * sd + kd,
                                                hi * sh:hi * sh + kh,
                                                wi * sw:wi * sw + kw]
                                result[n, c, ti, di, hi, wi] = torch.sum(window * self.weight[c])

        if self.is_bias:
            result = result + self.bias.view(1, -1, 1, 1, 1, 1)

        return result


