### Author: anonymized for review 

### All code and data released with this supplementary material uses 
### the following license:
### Creative Commons Attribution 4.0 International (CC BY 4.0)
### http://creativecommons.org/licenses/by/4.0

### This license permits use, sharing, adaptation, distribution and 
### reproduction in any medium or format, as long as you give 
### appropriate credit to the original authors and the paper with 
### title ''Learning to See Topological Properties in 4D'', provide 
### a link to the Creative Commons license, and indicate if changes 
### were made.


import os
import numpy as np
from torch.utils.data import Dataset
import torch
from torch import nn

from torch.autograd import Variable
import random
import time
import sys

from cnn4d import Network

###=== Copy to GPU if available

# The second command-line argument is a delimited, comma-separated list of
# GPU indices (e.g. "[0,1]"): drop its first and last character, then parse
# every remaining field as an integer.
GPU_arg = sys.argv[2][1:-1].split(',')
GPU_ids = [int(gpu_index) for gpu_index in GPU_arg]

print("GPU_ids: ", GPU_ids)

# More than one id means the model gets wrapped in nn.DataParallel below.
multi_GPU = len(GPU_ids) > 1

# The first listed GPU is the primary device; without CUDA everything
# runs on the CPU.
if not torch.cuda.is_available():
    device = torch.device("cpu")
    print("CPU")
else:
    device = torch.device(f"cuda:{GPU_ids[0]}")
    print("GPU")


# Working directory; the dataset folder is resolved relative to it.
CWD = os.getcwd()

###=== Dataset

class ARCDataset(Dataset):
    """Dataset of ``.npz`` samples with random 90-degree rotation augmentation.

    Every file must provide two arrays: ``data`` (the input volume, which
    needs at least four dimensions because the rotation plane is drawn
    from axes 0-3) and ``bettiNumbers`` (the integer targets).

    Args:
        data_dir: Directory that contains the sample files.
        files: File names (relative to ``data_dir``) belonging to this split.
    """

    def __init__(self, data_dir, files):
        self.data_dir = data_dir
        self.files = files

    def __len__(self):
        """Return the number of sample files in this split."""
        return len(self.files)

    def __getitem__(self, idx):
        """Load sample ``idx``, rotate it randomly and move it to ``device``.

        Returns:
            A ``(sample, label)`` tuple: ``sample`` is a float tensor with a
            leading singleton (channel) dimension added, ``label`` is a long
            tensor built from ``bettiNumbers``. Both live on the module-level
            ``device``.
        """
        _path = os.path.join(self.data_dir, self.files[idx])

        # np.load returns an NpzFile that keeps the archive open; read both
        # arrays inside a context manager so the file handle is released
        # right away instead of lingering until garbage collection.
        # NOTE(security): allow_pickle=True can execute arbitrary code when
        # loading untrusted files -- only use with trusted datasets.
        with np.load(_path, allow_pickle=True) as archive:
            label = archive['bettiNumbers']
            sample = archive['data']

        ###=== Random augmentation by rotation
        # Pick a random plane (two distinct axes out of 0-3) and rotate by
        # 0-3 quarter turns. To disable augmentation, skip the rot90 call.
        _axes = random.sample([0, 1, 2, 3], 2)
        _rots = random.randint(0, 3)
        sample = torch.from_numpy(sample)
        sample = torch.rot90(sample, _rots, _axes)
        sample = sample.unsqueeze(0).float()

        label = torch.from_numpy(np.asarray(label)).long()

        # Tensors are moved to the target device here, so the training loop
        # receives batches that are already on `device`.
        sample = sample.to(device)
        label = label.to(device)

        return sample, label


# The dataset folder is given by the first command-line argument and is
# resolved relative to the parent of the working directory.
data_folder = CWD+'/../'+sys.argv[1]

_files = [d for d in os.listdir(data_folder) if d.endswith('.npz')]
_valid_files = []

# Keep only archives that can be opened. The archive is closed again
# immediately so no file handle stays open per sample, and only real errors
# are treated as corruption (a bare `except:` would also swallow
# KeyboardInterrupt / SystemExit).
for f in _files:
    try:
        with np.load(data_folder+'/'+f, allow_pickle=True):
            pass
    except Exception:
        print(f, ' is corrupt')
    else:
        _valid_files.append(f)

random.shuffle(_valid_files)

noSamples = len(_valid_files)

print('noSamples: ', noSamples)

# 90% / 5% / 5% split into train / validation / test; the boundaries are
# computed once and reused for slicing and reporting.
_train_end = int(noSamples*0.9)
_val_end = int(noSamples*0.95)

_train_files = _valid_files[:_train_end]
_val_files = _valid_files[_train_end:_val_end]
_test_files = _valid_files[_val_end:]

train_data = ARCDataset(data_folder, _train_files)
val_data = ARCDataset(data_folder, _val_files)
test_data = ARCDataset(data_folder, _test_files)

print(len(_train_files))
print(len(_val_files))
print(len(_test_files), '\n')


from torch.utils.data import DataLoader

# 32 samples per GPU listed on the command line.
batch_size = 32 * len(GPU_ids)
print('batch_size: ', batch_size)

train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_data, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=True)


###=== Instantiate network

model = Network()

# Wrap the network in DataParallel when several GPU ids were requested.
if multi_GPU:
    model = nn.DataParallel(model, device_ids=GPU_ids)
    print('Using DataParallel')

model.to(device)


###=== Training

# Per-epoch loss history, one list per output head (four heads).
train_losses = [[] for _ in range(4)]
val_losses = [[] for _ in range(4)]

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)


###=== The scheduler
# Multiplies the learning rate by 0.1 at epochs 160 and 180.
scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer,
    milestones=[160, 180],
    gamma=0.1,
)

# Automatic mixed precision: the scaler guards the backward pass of the
# autocast-enabled training step.
use_amp = True
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)


print('Training has started')
training_start_time = time.time()


# Number of output heads; labels[:, i] is the target of head i.
degree = 4

for epoch in range(0,2):
    epoch_start_time = time.time()

    ###=== Train
    model.train()
    train_loss = np.zeros(degree)
    # Holds the number of correct predictions while iterating; it is turned
    # into a fraction of the dataset after the loop.
    train_accuracy = np.zeros(degree)

    for inputs, labels in train_loader:
        optimizer.zero_grad()

        with torch.cuda.amp.autocast(enabled=use_amp):
            output = model(inputs)

            # One cross-entropy term per head; the optimised loss is their sum.
            losses = [criterion(output[i], labels[:,i]) for i in range(degree)]
            loss = sum(losses[1:], losses[0])

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        n_batch = inputs.size(0)
        for i in range(degree):
            # criterion returns a batch mean, so weight it by the batch size
            # to obtain a per-sample average over the whole dataset later.
            train_loss[i] += losses[i].item()*n_batch
            _, pred_train = torch.max(output[i], 1)
            # Count correct samples (as a Python number) instead of averaging
            # per-batch accuracies, which would over-weight a smaller last
            # batch and be inconsistent with the sample-weighted loss.
            train_accuracy[i] += (pred_train == labels[:,i]).sum().item()

    n_train = len(train_loader.dataset)
    for i in range(degree):
        train_accuracy[i] = train_accuracy[i]/n_train
        train_loss[i] = train_loss[i]/n_train
        train_losses[i].append(train_loss[i])

    scheduler.step()

    ###=== Validation
    model.eval()
    val_loss = np.zeros(degree)
    val_accuracy = np.zeros(degree)

    # No gradients are needed for evaluation; disabling autograd avoids
    # building a graph for every validation batch.
    with torch.no_grad():
        for inputs, labels in val_loader:
            output = model(inputs)

            losses = [criterion(output[i], labels[:,i]) for i in range(degree)]

            n_batch = inputs.size(0)
            for i in range(degree):
                val_loss[i] += losses[i].item()*n_batch
                _, pred = torch.max(output[i], 1)
                val_accuracy[i] += (pred == labels[:,i]).sum().item()

    n_val = len(val_loader.dataset)
    for i in range(degree):
        val_accuracy[i] = val_accuracy[i]/n_val
        val_loss[i] = val_loss[i]/n_val
        val_losses[i].append(val_loss[i])


    print('Epoch: {}, LR: {:.1E} \nTraining Loss: {}, \ntrain_accuracy: {}, \nval_loss: {} \nval_accuracy: {}\nepoch_duration: {} minutes\n'.format(
        epoch+1,
        optimizer.state_dict()['param_groups'][0]['lr'],
        train_loss,
        train_accuracy*100,
        val_loss,
        val_accuracy*100,
        (time.time() - epoch_start_time)/60
    ))



print('\nTraining duration: ', (time.time() - training_start_time)/3600, ' hours')



###=== Save final model in working directory

# The file name carries the last epoch's validation accuracy in percent,
# averaged over the `degree` output heads.
_mean_val_accuracy = sum(val_accuracy*100)/degree
_path = 'demo_saved_model_' + str(_mean_val_accuracy)

# Only the model weights are stored, under the key 'model_state_dict'.
torch.save({'model_state_dict': model.state_dict()}, _path)



###=== Inference on test set

criterion = nn.CrossEntropyLoss()

model.eval()

test_losses = [[],[],[],[]]
# Number of output heads; labels[:, i] is the target of head i.
degree = 4
test_loss = np.zeros(degree)
# Holds the number of correct predictions while iterating; it is turned into
# a fraction of the test set after the loop.
test_accuracy = np.zeros(degree)

# Evaluation needs no gradients; disabling autograd avoids building a graph
# for every test batch.
with torch.no_grad():
    for inputs, labels in test_loader:
        output = model(inputs)

        losses = [criterion(output[i], labels[:,i]) for i in range(degree)]

        n_batch = inputs.size(0)
        for i in range(degree):
            # criterion returns a batch mean; weight it by the batch size so
            # the final value is a per-sample average over the test set.
            test_loss[i] += losses[i].item()*n_batch
            _, pred = torch.max(output[i], 1)
            # Count correct samples rather than averaging per-batch
            # accuracies, which would over-weight a smaller last batch.
            test_accuracy[i] += (pred == labels[:,i]).sum().item()


n_test = len(test_loader.dataset)
for i in range(degree):
    test_accuracy[i] = test_accuracy[i]/n_test
    test_loss[i] = test_loss[i]/n_test
    test_losses[i].append(test_loss[i])



print('\ntest_loss: {} \ntest_accuracy: {}'.format(test_loss,test_accuracy*100))


