###############
# Code for the circles experiment of the MIDL 2020 submission "Poolability and Transferability in CNN. A Thrifty Approach"
###############

##############
#Requirements:
#-All packages in the import statement
#-A CUDA enabled graphics card

##############
#Runtime:
# about 10 minutes on a Nvidia GTX 1070 GPU

##############
#Usage:
#run the script

##############
#Troubleshooting:
#Should the script get stuck at "generating batches", change the seed of the random number generator (the np.random.seed call in TestProviderCircles.__init__)

import torch  #pytorch
import numpy as np
import matplotlib.pyplot as plt

class PoolingBlockIndividual(torch.nn.Module):
    """Bank of stride-1 max-pooling layers with increasing square windows.

    One ``MaxPool2d`` is created for every entry of a fixed list of window
    lengths (1, 3, 5, 9, ..., 2049) that does not exceed ``2*imageSize``.
    Each layer uses stride 1 and padding ``window//2``, so the spatial size
    of its input is preserved.

    Attributes:
        poolingOperations: the pooling layers, smallest window first.
        windowSizesList: ``[side, side]`` for every pooling layer.
        windowSizes: ``windowSizesList`` as an int32 tensor.
        outputMultiplier: number of pooling layers.
    """

    def __init__(self,imageSize):
        super(PoolingBlockIndividual, self).__init__()

        self.poolingOperations = []
        self.windowSizesList = []

        candidateLengths = (1, 3, 5, 9, 17, 33, 65, 129, 257, 513, 1025, 2049)

        for length in candidateLengths:
            # candidates are ascending, so nothing after this one fits either
            if length > 2*imageSize:
                break

            # the window has to be odd so that padding side//2 keeps the size
            side = length if length % 2 else length + 1
            window = [side, side]
            self.windowSizesList.append(window)

            pool = torch.nn.MaxPool2d(kernel_size=window, stride=1,
                                      padding=[side//2, side//2],
                                      return_indices=False)
            self.poolingOperations.append(pool)
            # the plain list above is invisible to torch, register explicitly
            self.add_module("pool"+str(side), pool)

        self.windowSizes = torch.tensor(self.windowSizesList, dtype=torch.int32)
        self.outputMultiplier = len(self.windowSizesList)

    def forward(self, x, multiplicity):
        """Pool consecutive channel groups of ``x`` with different windows.

        The j-th pooling layer is applied to channels
        ``[j*multiplicity, (j+1)*multiplicity)`` of ``x`` (channel axis 1);
        the pooled groups are concatenated along the channel axis again.
        """
        pooledGroups = [
            pool(x[:, k*multiplicity:(k+1)*multiplicity])
            for k, pool in enumerate(self.poolingOperations)
        ]
        return torch.cat(pooledGroups, 1)
    
class TransferBlock(torch.nn.Module):
    """Convolution -> multi-window max pooling -> convolution -> 1x1 classifier.

    The first convolution produces ``multiplicity`` channels for every
    pooling window of the ``PoolingBlockIndividual``; the convolution after
    the pooling reduces them to one channel per window and a final 1x1
    convolution maps those to ``numberOfClasses`` output channels.
    """

    def __init__(self,inputChannels,multiplicity,convSize,imageSize,numberOfClasses):
        super(TransferBlock, self).__init__()

        self.poolingBlock = PoolingBlockIndividual(imageSize)

        # multiplicity channels per pooling window
        pooledChannels = multiplicity*self.poolingBlock.outputMultiplier
        # one channel per pooling window after the second convolution
        mergedChannels = pooledChannels//multiplicity
        samePadding = convSize//2

        self.initialConv = torch.nn.Conv2d(
            in_channels=inputChannels, out_channels=pooledChannels,
            kernel_size=convSize, stride=1, padding=samePadding,
            dilation=1, groups=1, bias=True)

        self.afterPoolConv = torch.nn.Conv2d(
            in_channels=pooledChannels, out_channels=mergedChannels,
            kernel_size=convSize, stride=1, padding=samePadding,
            dilation=1, groups=1, bias=True)

        self.finalConvolution = torch.nn.Conv2d(
            in_channels=mergedChannels, out_channels=numberOfClasses,
            kernel_size=1, stride=1, padding=0,
            dilation=1, groups=1, bias=True)

        self.multiplicity = multiplicity

    def forward(self, x):
        """Apply the block to every entry along axis 0 of ``x`` separately.

        Each ``x[i]`` is passed through the 2D layers on its own and the
        results are stacked along a new leading axis.
        """
        elu = torch.nn.functional.elu

        perSlice = []
        for sliceIndex in range(x.size()[0]):
            features = elu(self.initialConv(x[sliceIndex]))
            pooled = self.poolingBlock(features, self.multiplicity)
            merged = elu(self.afterPoolConv(pooled))
            perSlice.append(self.finalConvolution(merged))

        return torch.stack(perSlice, dim=0)

class TransferBlockDouble(torch.nn.Module):
    """Transfer block with two pooling stages.

    Layer order in ``forward``: initial convolution, first pooling block,
    channel-preserving middle convolution, second pooling block, channel
    reducing convolution, 1x1 classifier.
    """

    def __init__(self,inputChannels,multiplicity,convSize,imageSize,numberOfClasses):
        super(TransferBlockDouble, self).__init__()

        samePadding = convSize//2

        self.poolingBlock = PoolingBlockIndividual(imageSize)

        # multiplicity channels per pooling window
        pooledChannels = multiplicity*self.poolingBlock.outputMultiplier
        # one channel per pooling window after the last 2D convolution
        mergedChannels = pooledChannels//multiplicity

        self.initialConv = torch.nn.Conv2d(
            in_channels=inputChannels, out_channels=pooledChannels,
            kernel_size=convSize, stride=1, padding=samePadding,
            dilation=1, groups=1, bias=True)

        self.afterPoolConv = torch.nn.Conv2d(
            in_channels=pooledChannels, out_channels=mergedChannels,
            kernel_size=convSize, stride=1, padding=samePadding,
            dilation=1, groups=1, bias=True)

        self.poolingBlock2 = PoolingBlockIndividual(imageSize)

        # sits between the two pooling stages and keeps the channel count
        self.middleConv = torch.nn.Conv2d(
            in_channels=pooledChannels, out_channels=pooledChannels,
            kernel_size=convSize, stride=1, padding=samePadding,
            dilation=1, groups=1, bias=True)

        self.finalConvolution = torch.nn.Conv2d(
            in_channels=mergedChannels, out_channels=numberOfClasses,
            kernel_size=1, stride=1, padding=0,
            dilation=1, groups=1, bias=True)

        self.multiplicity = multiplicity

    def forward(self, x):
        """Apply the block to every entry along axis 0 of ``x`` separately.

        Each ``x[i]`` is processed on its own; the per-entry results are
        stacked along a new leading axis.
        """
        elu = torch.nn.functional.elu

        perSlice = []
        for sliceIndex in range(x.size()[0]):
            features = elu(self.initialConv(x[sliceIndex]))

            firstPooling = self.poolingBlock(features, self.multiplicity)
            between = elu(self.middleConv(firstPooling))
            secondPooling = self.poolingBlock2(between, self.multiplicity)

            merged = elu(self.afterPoolConv(secondPooling))
            perSlice.append(self.finalConvolution(merged))

        return torch.stack(perSlice, dim=0)
    
class ASPPArchitectureN(torch.nn.Module):
    """Atrous Spatial Pyramid Pooling head.

    One dilated 3x3 convolution is created for every dilation rate in
    ``range(6, 2*imageSize, 6)``. The branch outputs are concatenated, passed
    through two grouped 1x1 convolutions (one group per branch) and the
    per-branch class scores are summed.

    Args:
        inputChannels: channels of the input data.
        startChannels: feature channels produced by every dilated branch.
        convSize: unused; kept so the signature matches the other networks.
        imageSize: determines how many dilation rates are used.
        numberOfClasses: number of output channels.

    Raises:
        ValueError: if ``imageSize`` is too small for any dilation rate.
    """

    def __init__(self,inputChannels,startChannels,convSize,imageSize,numberOfClasses):
        super(ASPPArchitectureN, self).__init__()

        # dilation rates for the dilated convolutions
        dilationRates = list(range(6, 2*imageSize, 6))
        if not dilationRates:
            raise ValueError(
                "imageSize=%r yields no dilation rate; it must be at least 4" % (imageSize,))

        convNumber = len(dilationRates)
        self.convNumber = convNumber

        self.atrousConvolutions = list()

        for dilation in dilationRates:
            # a 3x3 kernel with dilation d spans 2*d+1 pixels, so padding d
            # keeps the spatial size ((3+2*d)//2-1 == d)
            self.atrousConvolutions.append(torch.nn.Conv2d(
                in_channels=inputChannels, out_channels=startChannels,
                kernel_size=3, stride=1, padding=dilation,
                dilation=dilation, groups=1, bias=True))
            self.add_module("atrous conv dilated"+str(dilation), self.atrousConvolutions[-1])

        # groups=convNumber: every branch is processed independently
        self.oneXone = torch.nn.Conv2d(
            in_channels=convNumber*startChannels, out_channels=convNumber*startChannels,
            kernel_size=1, stride=1, padding=0, dilation=1, groups=convNumber, bias=True)

        self.final = torch.nn.Conv2d(
            in_channels=convNumber*startChannels, out_channels=convNumber*numberOfClasses,
            kernel_size=1, stride=1, padding=0, dilation=1, groups=convNumber, bias=True)

    def forward(self, x):
        """Apply the head to every entry along axis 0 of ``x`` separately.

        Returns the per-entry results stacked along a new leading axis; each
        result has ``numberOfClasses`` channels.
        """
        elu = torch.nn.functional.elu
        k = self.convNumber

        outputList = list()
        for i in range(x.size()[0]):
            branches = [elu(atrousConv(x[i])) for atrousConv in self.atrousConvolutions]

            x1 = torch.cat(branches, dim=1)
            x2 = elu(self.oneXone(x1))
            x3 = self.final(x2)

            n, c, h, w = x3.size()

            # A grouped convolution emits its channels group by group, i.e.
            # channel = branch*numberOfClasses + class. Split into
            # (branch, class) in that order and sum over the branches so that
            # numberOfClasses channels remain.
            outputList.append(x3.reshape(n, k, c//k, h, w).sum(1))

        return torch.stack(outputList, dim=0)

class ReferenceBlock(torch.nn.Module):
    """Transfer block without the pooling stage.

    Has the same layers and channel counts as ``TransferBlock``; a
    ``PoolingBlockIndividual`` is still constructed to derive the channel
    counts, but ``forward`` never calls it.
    """

    def __init__(self,inputChannels,multiplicity,convSize,imageSize,numberOfClasses):
        super(ReferenceBlock, self).__init__()

        # only used to size the convolutions
        self.poolingBlock = PoolingBlockIndividual(imageSize)

        wideChannels = multiplicity*self.poolingBlock.outputMultiplier
        narrowChannels = wideChannels//multiplicity
        samePadding = convSize//2

        self.initialConv = torch.nn.Conv2d(
            in_channels=inputChannels, out_channels=wideChannels,
            kernel_size=convSize, stride=1, padding=samePadding,
            dilation=1, groups=1, bias=True)

        self.afterPoolConv = torch.nn.Conv2d(
            in_channels=wideChannels, out_channels=narrowChannels,
            kernel_size=convSize, stride=1, padding=samePadding,
            dilation=1, groups=1, bias=True)

        self.finalConvolution = torch.nn.Conv2d(
            in_channels=narrowChannels, out_channels=numberOfClasses,
            kernel_size=1, stride=1, padding=0,
            dilation=1, groups=1, bias=True)

        self.multiplicity = multiplicity

    def forward(self, x):
        """Apply the three convolutions to every entry along axis 0 of ``x``.

        Each ``x[i]`` is processed on its own; the per-entry results are
        stacked along a new leading axis.
        """
        elu = torch.nn.functional.elu

        perSlice = []
        for sliceIndex in range(x.size()[0]):
            features = elu(self.initialConv(x[sliceIndex]))
            reduced = elu(self.afterPoolConv(features))
            perSlice.append(self.finalConvolution(reduced))

        return torch.stack(perSlice, dim=0)

class Unet_Architecture(torch.nn.Module):
    """2D U-Net with ``layers`` resolution levels.

    Every level has a channel-changing 3x3 convolution followed by a
    channel-preserving one. All levels but the deepest additionally own a
    2x2 max pooling (way down) and a transposed convolution, a channel
    reduction convolution and a second channel-preserving convolution
    (way up). A 1x1 convolution maps the top level to the class scores.

    Args:
        startChannels: feature channels after the first convolution layer;
            doubled at every deeper level.
        inputChannels: channels of the input data.
        layers: number of resolution levels.
        numberOfClasses: number of output channels.
    """

    def __init__(self,startChannels,inputChannels,layers,numberOfClasses):
        super(Unet_Architecture, self).__init__()

        self.startChannels = startChannels
        self.layers = layers

        self.channelIncreaseLayers = []
        self.convLayers1 = []
        self.downsamplingLayers = []
        self.upsamplingLayers = []
        self.channelReductionLayers = []
        self.convLayers2 = []

        def conv3x3(inChannels, outChannels):
            return torch.nn.Conv2d(in_channels=inChannels, out_channels=outChannels,
                                   kernel_size=3, stride=1, padding=1,
                                   dilation=1, groups=1, bias=True)

        def keep(collection, prefix, level, layer):
            # the python lists are invisible to torch, so register by name too
            collection.append(layer)
            self.add_module(prefix+str(level), layer)

        width = None
        for level in range(layers):
            # the first level maps the input to startChannels,
            # every deeper level doubles the channel number
            if level == 0:
                incoming, width = inputChannels, startChannels
            else:
                incoming, width = width, 2*width

            keep(self.channelIncreaseLayers, "ci", level, conv3x3(incoming, width))
            keep(self.convLayers1, "c1", level, conv3x3(width, width))

            # the deepest level neither downsamples nor receives an upsampling
            if level == layers-1:
                break

            keep(self.downsamplingLayers, "max", level, torch.nn.MaxPool2d(kernel_size=2))

            # brings the output of the level below back to this level
            keep(self.upsamplingLayers, "ups", level,
                 torch.nn.ConvTranspose2d(in_channels=2*width, out_channels=width,
                                          kernel_size=2, stride=2, padding=0,
                                          output_padding=0, groups=1, bias=True,
                                          dilation=1))

            # upsampled and skipped features are concatenated: 2*width -> width
            keep(self.channelReductionLayers, "cr", level, conv3x3(2*width, width))
            keep(self.convLayers2, "c2", level, conv3x3(width, width))

        self.finalConvolution = torch.nn.Conv2d(
            in_channels=startChannels, out_channels=numberOfClasses,
            kernel_size=1, stride=1, padding=0, dilation=1, groups=1, bias=True)

    def forward(self, x):
        """Run the U-Net on every entry along axis 0 of ``x`` separately.

        Each ``x[i]`` is processed on its own; the per-entry results are
        stacked along a new leading axis.
        """
        relu = torch.nn.functional.relu
        deepest = self.layers-1

        perSlice = []
        for sliceIndex in range(x.size()[0]):
            current = x[sliceIndex]

            # contracting path, remembering the output of every level
            skips = []
            for level in range(self.layers):
                widened = relu(self.channelIncreaseLayers[level](current))
                skips.append(relu(self.convLayers1[level](widened)))
                if level < deepest:
                    current = self.downsamplingLayers[level](skips[-1])

            # expanding path, from the level above the deepest up to the top
            decoded = skips[deepest]
            for level in range(deepest-1, -1, -1):
                lifted = relu(self.upsamplingLayers[level](decoded))
                joined = torch.cat((lifted, skips[level]), dim=1)
                narrowed = relu(self.channelReductionLayers[level](joined))
                decoded = relu(self.convLayers2[level](narrowed))

            perSlice.append(self.finalConvolution(decoded))

        return torch.stack(perSlice, dim=0)
    
def init_TransferBlock(m):
    """Initialise a ``torch.nn.Conv2d`` with small normally distributed values.

    Intended for ``Module.apply``: modules whose type is not exactly
    ``torch.nn.Conv2d`` are left untouched. Weights are drawn from
    N(0, 0.001) and the bias, if the layer has one, from N(0.01, 0.001).
    """
    if type(m) is torch.nn.Conv2d:
        torch.nn.init.normal_(m.weight, mean=0,std=0.001)
        # layers created with bias=False have no bias tensor to initialise
        if m.bias is not None:
            torch.nn.init.normal_(m.bias, mean=0.01,std=0.001)

class Pyramid_Net:
    """Trains six segmentation networks side by side on identical batches.

    The networks are, in this order: two ``Unet_Architecture`` instances
    (``layers`` and ``layers-1`` levels), ``TransferBlockDouble``,
    ``TransferBlock``, ``ReferenceBlock`` and ``ASPPArchitectureN``. Every
    network gets its own Adam optimizer and its own loss/Dice history.
    All networks are moved to the GPU with ``.cuda()``.

    Args:
        learningRate: sequence of learning rates. With more than one entry,
            entry ``i`` is used for network ``i``; otherwise the single entry
            is used for all networks.
        startChannels, inputChannels, layers, multiplicity, convSize,
        imageSize, numberOfClasses: forwarded to the network constructors.
        batchSize: batch size requested from the data provider in ``trainLog``.
        overlap: three border widths cropped from outputs and labels before
            the loss is computed; ``overlap[2]`` applies to axis 0 and
            ``overlap[0]``/``overlap[1]`` to the last two axes.
    """

    def __init__(self,learningRate,startChannels,inputChannels,layers,batchSize,overlap,multiplicity,convSize,imageSize,numberOfClasses):
        blockArguments = (inputChannels,multiplicity,convSize,imageSize,numberOfClasses)

        # (factory, use init_TransferBlock). The networks are built one after
        # the other, in this order, so that random initialisation consumes
        # the random number generator in a fixed sequence.
        recipes = [
            (lambda: Unet_Architecture(startChannels,inputChannels,layers,numberOfClasses), False),
            (lambda: Unet_Architecture(startChannels,inputChannels,layers-1,numberOfClasses), False),
            (lambda: TransferBlockDouble(*blockArguments), True),
            (lambda: TransferBlock(*blockArguments), True),
            (lambda: ReferenceBlock(*blockArguments), True),
            (lambda: ASPPArchitectureN(inputChannels,startChannels,convSize,imageSize,numberOfClasses), False),
        ]

        self.netList = list()
        for build, customInit in recipes:
            net = build()
            if customInit:
                net.apply(init_TransferBlock)
            net.cuda()
            self.netList.append(net)

        self.NumberOfNetworks = len(self.netList)
        self.numberOfClasses = numberOfClasses
        self.overlap = overlap
        self.lossFunction = torch.nn.BCEWithLogitsLoss()

        self.optimizerList = list()
        for i in range(self.NumberOfNetworks):
            # a single learning rate is shared by all networks
            rate = learningRate[i] if len(learningRate)>1 else learningRate[0]
            self.optimizerList.append(torch.optim.Adam(self.netList[i].parameters(), lr=rate))

        self.batchSize = batchSize

        # one history list per network
        self.lossListList = [list() for _ in range(self.NumberOfNetworks)]
        self.accListList = [list() for _ in range(self.NumberOfNetworks)]
        self.veriLossListList = [list() for _ in range(self.NumberOfNetworks)]
        self.veriAccListList = [list() for _ in range(self.NumberOfNetworks)]

    def _cropLabel(self,label):
        """Remove the ``overlap`` border from a 4-D label tensor (axes 0, 2, 3)."""
        oRow, oCol, oSlice = self.overlap[0], self.overlap[1], self.overlap[2]
        return label[oSlice:label.size()[0]-oSlice,
                     :,
                     oRow:label.size()[2]-oRow,
                     oCol:label.size()[3]-oCol]

    def _cropOutput(self,outputs):
        """Remove the ``overlap`` border from a 5-D network output (axes 0, 3, 4)."""
        oRow, oCol, oSlice = self.overlap[0], self.overlap[1], self.overlap[2]
        return outputs[oSlice:outputs.size()[0]-oSlice,
                       :,
                       :,
                       oRow:outputs.size()[3]-oRow,
                       oCol:outputs.size()[4]-oCol]

    def _meanClassLoss(self,outputCropped,labelCropped):
        """Average the one-vs-rest ``lossFunction`` over all classes.

        Channel ``c`` of the output (axis 2) is compared with the binary
        mask ``label == c``.
        """
        # move axis 0 to the end so that output and label axes line up
        outputPermuted = outputCropped.permute(1,2,3,4,0)
        labelPermuted = labelCropped.permute(1,2,3,0)

        perClass = [
            self.lossFunction(outputPermuted[:,c], (labelPermuted==c).float())
            for c in range(self.numberOfClasses)
        ]
        return torch.mean(torch.stack(perClass))

    def trainStep(self,batch,label):
        """Run one optimisation step on every network.

        Args:
            batch: input passed unchanged to every network.
            label: 4-D integer label tensor; cropped by ``overlap``.

        Returns:
            Three lists with one float per network: loss, Dice score and the
            largest absolute gradient entry (0.0 if no parameter received a
            gradient).
        """
        labelCropped = self._cropLabel(label)

        losses = list()
        accuracies = list()
        grads = list()

        for net, optimizer in zip(self.netList, self.optimizerList):
            optimizer.zero_grad()

            outputCropped = self._cropOutput(net(batch))
            loss = self._meanClassLoss(outputCropped, labelCropped)
            acc = self.dice(torch.sigmoid(outputCropped), labelCropped)

            loss.backward()

            # largest gradient magnitude, for monitoring only
            maxGrad = max(
                (float(torch.max(torch.abs(p.grad))) for p in net.parameters() if p.grad is not None),
                default=0.0)

            optimizer.step()

            losses.append(loss.item())
            accuracies.append(acc.item())
            grads.append(maxGrad)

        return losses, accuracies, grads

    def evaluationStep(self,batch,label):
        """Compute loss and Dice score of every network without training.

        Runs under ``torch.no_grad()`` so that no autograd graph is built
        and neither parameters nor gradients are modified.

        Returns:
            Two lists with one float per network: loss and Dice score.
        """
        labelCropped = self._cropLabel(label)

        losses = list()
        accuracies = list()

        with torch.no_grad():
            for net in self.netList:
                outputCropped = self._cropOutput(net(batch))
                loss = self._meanClassLoss(outputCropped, labelCropped)

                losses.append(loss.item())
                accuracies.append(self.dice(torch.sigmoid(outputCropped), labelCropped).item())

        return losses, accuracies

    def dice(self,A,B):
        """Dice score between the arg-max of ``A`` (over axis 2) and labels ``B``.

        For every label ``i`` in ``1 .. max(B)-1`` the term
        ``2*|A==i and B==i| / (|A==i| + |B==i|)`` is accumulated (the
        denominator is replaced by 1 if it is zero); the sum is divided by
        ``max(B)-1``.

        NOTE(review): the label equal to ``max(B)`` is never scored and
        ``max(B) == 1`` divides 0 by 0 (NaN). This looks like ``nLabels`` was
        meant to be ``max(B)+1`` - confirm against the label convention of
        the data provider before changing it.
        """
        nLabels = torch.max(B)
        predicted = torch.argmax(A,dim=2)
        dice = 0.0
        for i in range(1,nLabels):
            a = (predicted==i).float()
            b = (B==i).float()

            overlap = int(torch.sum(a*b,dtype=torch.float32))

            cardinalities = torch.sum(a,dtype=torch.float32)+torch.sum(b,dtype=torch.float32)

            # label absent in both prediction and ground truth
            if cardinalities==0:
                cardinalities = 1

            dice += 2*overlap/cardinalities

        return dice/(nLabels.float()-1.0)

    def countTrainableParameters(self):
        """Print and return the number of trainable parameters per network."""
        nList = [
            sum(p.numel() for p in net.parameters() if p.requires_grad)
            for net in self.netList
        ]

        print("Number of trainable Parameters of U-Net 5, U-Net 4, DT-Net, T-Net, R-Net, ASPP")
        print(nList)

        return nList

    def plotMovingAverage(self,kernelLength):
        """Plot box-filtered training loss, training Dice and validation Dice.

        A new figure with three axes is created on every call; each history
        is smoothed with a moving average of ``kernelLength`` samples.

        NOTE(review): ``trainLog`` calls this every 100 iterations, so long
        runs accumulate open figures - confirm whether that is intended.
        """
        kernel = np.ones(kernelLength)
        kernel = kernel/np.sum(kernel)

        def smooth(history):
            return np.convolve(np.asarray(history),kernel,mode='same')

        legend = ['U-Net 5', 'U-Net 4','DT-Net','T-Net','R-Net','ASPP']
        fig, axes = plt.subplots(1,3)

        axes[0].set_title('Train Loss')
        axes[1].set_title('Train Dice')
        axes[2].set_title('Validation Dice')
        axes[0].set_ylim(0,1)

        for i in range(self.NumberOfNetworks):
            axes[0].plot(smooth(self.lossListList[i]),label=legend[i])
            axes[1].plot(smooth(self.accListList[i]))
            axes[2].plot(smooth(self.veriAccListList[i]))

        axes[0].legend()

    def trainLog(self,provider,iterations):
        """Train all networks for ``iterations`` steps and log the progress.

        ``provider.getVerificationBatch(batchSize)`` is queried once for a
        fixed validation batch; ``provider(batchSize)`` is called in every
        iteration for a training batch. Both return ``(data, labels)``; the
        data array is transposed with ``[0,1,4,2,3]`` before it is moved to
        the GPU. Every 100 iterations the iteration number is printed and
        the smoothed histories are plotted.
        """
        gc = torch.device("cuda") #graphics card

        verificationBatch,verificationBatchLabel = provider.getVerificationBatch(self.batchSize)
        verificationBatch = np.transpose(verificationBatch,[0,1,4,2,3])

        vbatchGPU = torch.tensor(verificationBatch, device=gc, dtype=torch.float32, requires_grad=False)
        vlabelGPU = torch.tensor(verificationBatchLabel, device=gc, dtype=torch.int64, requires_grad=False)

        for i in range(0,iterations):
            data, batchLabels = provider(self.batchSize)[0:2]
            batch = np.transpose(data,[0,1,4,2,3])

            batchGPU = torch.tensor(batch, device=gc, dtype=torch.float32, requires_grad=False)
            labelGPU = torch.tensor(batchLabels*1, device=gc, dtype=torch.int64, requires_grad=False)

            losses, accuracies, grads = self.trainStep(batchGPU,labelGPU)
            for j in range(self.NumberOfNetworks):
                self.lossListList[j].append(losses[j])
                self.accListList[j].append(accuracies[j])

            losses, accuracies = self.evaluationStep(vbatchGPU,vlabelGPU)
            for j in range(self.NumberOfNetworks):
                self.veriLossListList[j].append(losses[j])
                self.veriAccListList[j].append(accuracies[j])

            if i%100 == 0:
                print(i)
                # the smoothing window grows with the length of the history
                self.plotMovingAverage(i//100+1)
                plt.draw()
                plt.pause(0.000001)
########################
########################
class TestProviderCircles:
    """Synthetic data provider producing images of filled circles of three sizes.

    Every sample contains `numberOfCircles` (18) circles whose centres are
    placed by rejection sampling so that two centres in the same frame are
    at least 11 pixels apart. The centre list is split into three
    consecutive thirds which are drawn as

        * "radius 4" circles: pixels with distance < 5, label 1
        * "radius 3" circles: pixels with distance < 4, label 2
        * "radius 5" circles: pixels with distance < 6, label 3

    Background pixels have label 0. The input image `x` is 1 inside any
    circle and 0 elsewhere.

    `patchSize` is read as [height, width, sequenceLength]. Batches are
    returned as `x` of shape (sequenceLength, batch, height, width, 1) and
    `y` of shape (sequenceLength, batch, height, width).

    The constructor seeds NumPy's global random generator with 1, samples
    the centres of `batchSize + verificationNumber` held-out samples (the
    first `batchSize` form the verification batch, the rest the validation
    set) and then pre-generates `exampleNumber` training batches.
    """

    #(label written to y, exclusive distance threshold) for the three
    #consecutive thirds of a sample's centre list
    _CIRCLE_CLASSES = ((1, 5), (2, 4), (3, 6))

    #two centres within the same frame must be at least this far apart
    _MIN_CENTER_DISTANCE = 11

    #upper bound on redraws for one centre; turns an endless rejection loop
    #(no free position left) into an explicit error
    _MAX_PLACEMENT_ATTEMPTS = 1000000

    def __init__(self,batchSize,patchSize,exampleNumber,verificationNumber):
        np.random.seed(seed=1)

        self.patchSize = patchSize
        self.numberOfCircles = 18#int(np.round(patchSize[2]*1.5))

        self.batchSize = batchSize
        self.verificationNumber = verificationNumber

        #centres of the held-out samples: per sample, and flattened
        self.vcentersList = list()
        self.vcenters = list()
        for _ in range(batchSize+verificationNumber):
            centers = self._sampleCenters()
            self.vcentersList.append(centers)
            self.vcenters.extend(centers)

        self.batchList = list()
        for i in range(exampleNumber):
            print("generating batch number "+str(i))
            self.batchList.append(self.calculateBatch())

        self.currentBatch = 0

    def _drawCenter(self):
        """Draw one random centre [row, column, frame], keeping a 5 pixel margin."""
        inputSize = self.patchSize[0:2]
        sequenceLength = self.patchSize[2]
        #the order of the three draws is part of the reproducible random stream
        return [np.random.randint(5,inputSize[0]-5),np.random.randint(5,inputSize[1]-5),np.random.randint(0,sequenceLength)]

    def _sampleCenters(self,reserved=frozenset()):
        """Sample the circle centres of one sample by rejection sampling.

        A candidate is redrawn while it lies closer than
        `_MIN_CENTER_DISTANCE` to an already accepted centre of the same
        frame, or while it equals one of the `reserved` centres (a set of
        (row, column, frame) tuples). The very first centre is accepted
        unconditionally, i.e. it is not compared against `reserved`; this
        keeps the random stream, and therefore the generated data,
        identical to earlier versions of this class.

        Raises:
            RuntimeError: if no valid position is found for a centre within
                `_MAX_PLACEMENT_ATTEMPTS` draws.
        """
        #integer comparison d^2 < 11^2 is exactly equivalent to sqrt(d^2) < 11
        minSquared = self._MIN_CENTER_DISTANCE**2
        centers = [self._drawCenter()]

        def rejected(candidate):
            for other in centers:
                if candidate[2] != other[2]:
                    continue
                if (candidate[0]-other[0])**2+(candidate[1]-other[1])**2 < minSquared:
                    return True
            return tuple(candidate) in reserved

        for _ in range(1,self.numberOfCircles):
            candidate = self._drawCenter()
            attempts = 1
            while rejected(candidate):
                if attempts >= self._MAX_PLACEMENT_ATTEMPTS:
                    raise RuntimeError(
                        "could not place circle number "+str(len(centers))
                        +" after "+str(attempts)+" attempts; "
                        +"change the random seed or use a larger patchSize")
                candidate = self._drawCenter()
                attempts += 1
            centers.append(candidate)

        return centers

    def _renderSample(self,centers):
        """Rasterise one sample; returns (x, y), each of shape (sequenceLength, 1, height, width)."""
        sequenceLength = self.patchSize[2]
        inputSize = self.patchSize[0:2]

        x = np.zeros([sequenceLength,1,*inputSize])
        y = np.zeros([sequenceLength,1,*inputSize])

        #coordinate grids that broadcast to a full (height, width) distance map
        rows = np.arange(inputSize[0])[:,None]
        cols = np.arange(inputSize[1])[None,:]

        third = int(self.numberOfCircles/3)
        bounds = [0,third,third*2,self.numberOfCircles]

        for (label,limit),first,last in zip(self._CIRCLE_CLASSES,bounds[:-1],bounds[1:]):
            for k in range(first,last):
                row,col,frame = centers[k]
                #all quantities are integers, so d^2 < limit^2 <=> sqrt(d^2) < limit
                inside = (rows-row)**2+(cols-col)**2 < limit**2
                x[frame,0][inside] = 1
                y[frame,0][inside] = label

        return x,y

    @staticmethod
    def _stackSamples(samples):
        """Join per-sample (x, y) pairs along the batch axis; x gets a trailing channel axis."""
        x = np.concatenate([sample[0] for sample in samples],axis=1)[:,:,:,:,None]
        y = np.concatenate([sample[1] for sample in samples],axis=1)
        return x,y

    def calculateBatch(self):
        """Generate a fresh random training batch of `self.batchSize` samples.

        Except for the first centre of each sample, training centres never
        coincide exactly with a centre of the held-out samples.

        Returns:
            (x, y) with shapes (sequenceLength, batchSize, height, width, 1)
            and (sequenceLength, batchSize, height, width).
        """
        #set of tuples gives O(1) lookups instead of scanning the list
        reserved = {tuple(center) for center in self.vcenters}
        samples = [self._renderSample(self._sampleCenters(reserved)) for _ in range(self.batchSize)]
        return self._stackSamples(samples)

    def getVerificationBatch(self,batchSize):
        """Return the fixed verification batch built from the first `batchSize` held-out samples.

        Raises IndexError if `batchSize` exceeds the number of held-out samples.
        """
        samples = [self._renderSample(self.vcentersList[n]) for n in range(batchSize)]
        return self._stackSamples(samples)

    def getValidationBatches(self):
        """Return the validation set as lists of batches and label batches.

        Uses the held-out samples after the first `self.batchSize` ones and
        groups them into batches of `self.batchSize`; samples that do not
        fill a complete batch are dropped.
        """
        batchSize = self.batchSize
        samples = [self._renderSample(self.vcentersList[n])
                   for n in range(batchSize,batchSize+self.verificationNumber)]

        batches = list()
        labels = list()

        numberOfBatches = int(self.verificationNumber/self.batchSize)

        for i in range(numberOfBatches):
            x,y = self._stackSamples(samples[i*batchSize:i*batchSize+batchSize])
            batches.append(x)
            labels.append(y)

        return batches,labels

    def __call__(self,batchSize):
        """Return the next pre-generated training batch, cycling through `batchList`.

        `batchSize` is ignored; the batches were generated with the batch
        size given to the constructor. The index is advanced before the
        lookup, so the first call returns `batchList[1]` when more than one
        batch exists.
        """
        self.currentBatch = (self.currentBatch+1)%len(self.batchList)
        return self.batchList[self.currentBatch]
    
    
#Experiment configuration and entry point: builds the circle data set,
#constructs the networks and runs the logged training loop.

#patchSize is [height, width, sequenceLength] (this is how TestProviderCircles reads it)
patchSize = [64,64,1]

inputSize = [patchSize[0],patchSize[1]]
batchSize = 3
numberOfInputChannels = 1
sequenceLength = patchSize[2]

#passed to TestProviderCircles as exampleNumber (number of pre-generated
#training batches) and verificationNumber (number of validation samples)
trainingSetSize = 20
verificationSetSize = 10

#NOTE: the constructor seeds NumPy's global RNG and generates all training batches up front
provider = TestProviderCircles(batchSize,patchSize,trainingSetSize,verificationSetSize)
#labels produced by the provider: 0 (background), 1, 2 and 3
numberOfClasses = 4

#network configurations

#U-Net
startFeatures = 64
resolutionSteps = 5

#Transfer Block
logicMultiplicity = 10
logicConvSize = 5

#Learning rates, one entry per network in this order:
#[U-Net 5, U-Net 4, DT-Net, T-Net, R-Net, ASPP]
learningRate = [0.001,0.002,0.0005,0.002,0.005,0.005]

#NOTE(review): meaning of overlap is defined by Pyramid_Net (not visible here) - confirm
overlap = [0,0,0]
#"netz" is German for "network"
netz = Pyramid_Net(learningRate,startFeatures,numberOfInputChannels,resolutionSteps,batchSize,overlap,logicMultiplicity,logicConvSize,patchSize[0],numberOfClasses)
netz.countTrainableParameters()

#trainLog reports/plots on every 100th iteration, so 3001 iterations include the report at iteration 3000
netz.trainLog(provider,iterations = 3001)