### Author: anonymized for review 

### All code and data released with this supplementary material use
### the following license:
### Creative Commons Attribution 4.0 International (CC BY 4.0)
### http://creativecommons.org/licenses/by/4.0

### This license permits use, sharing, adaptation, distribution and 
### reproduction in any medium or format, as long as you give 
### appropriate credit to the original authors and the paper with 
### title ''Learning to See Topological Properties in 4D'', provide 
### a link to the Creative Commons license, and indicate if changes 
### were made.


import os
import numpy as np
from torch.utils.data import Dataset
import torch
from torch import nn

from torch.autograd import Variable

import time
import sys

from cnn4d import Network


###=== Parse the requested GPU ids and pick the compute device

# The first and last character of sys.argv[2] are dropped and the rest is
# split on commas (presumably a bracketed list such as "[0,1]" -- confirm
# against the launch command).
GPU_arg = sys.argv[2][1:-1].split(',')
GPU_ids = [int(gpu_id) for gpu_id in GPU_arg]

print("GPU_ids: ", GPU_ids)

# True when more than one GPU id was passed on the command line.
multi_GPU = len(GPU_ids) > 1

# The first listed id is the primary device; without CUDA everything
# runs on the CPU.
if not torch.cuda.is_available():
    device = torch.device("cpu")
    print("CPU")
else:
    device = torch.device("cuda:{}".format(GPU_ids[0]))
    print("GPU")


# Working directory at start-up; later paths are built relative to it.
CWD = os.getcwd()

###=== Dataset

class ARCDataset(Dataset):
    """Dataset backed by one NumPy archive file per sample.

    Every archive must provide a ``data`` entry (the network input) and a
    ``bettiNumbers`` entry (the label).
    """

    def __init__(self, data_dir, files):
        """Store the sample location.

        Args:
            data_dir: Directory that contains the sample files.
            files: Sequence of file names, relative to ``data_dir``.
        """
        self.data_dir = data_dir
        self.files = files

    def __len__(self):
        """Return the number of sample files."""
        return len(self.files)

    def __getitem__(self, idx):
        """Load sample ``idx`` from disk.

        Args:
            idx: Index into ``files``.

        Returns:
            Tuple ``(sample, label)``: ``sample`` is the ``data`` array as a
            float tensor with a new leading dimension of size 1, ``label``
            is ``bettiNumbers`` as a long tensor. Both are moved to the
            module-level ``device``.
        """
        _path = os.path.join(self.data_dir, self.files[idx])

        # The archive object keeps its file handle open until closed, so
        # read both entries inside a context manager.
        # NOTE(review): allow_pickle=True unpickles object arrays; only use
        # this with trusted data files.
        with np.load(_path, allow_pickle=True) as archive:
            label = archive['bettiNumbers']
            data = archive['data']

        ###=== Without augmentation
        sample = torch.from_numpy(data).unsqueeze(0).float()
        label = torch.from_numpy(np.asarray(label)).long()

        return sample.to(device), label.to(device)


# Data directory: sys.argv[1] is resolved relative to the parent of the
# working directory.
data_folder = CWD+'/../'+sys.argv[1]

_files = [d for d in os.listdir(data_folder) if d.endswith('.npz')]
_valid_files = []

# Keep only the files that can be opened as an archive. The archive is
# closed again right away; only genuine errors are treated as corruption,
# so Ctrl-C and interpreter exit still propagate.
for f in _files:
    try:
        with np.load(os.path.join(data_folder, f), allow_pickle=True):
            pass
    except Exception:
        print(f, ' is corrupt')
    else:
        _valid_files.append(f)


noSamples = len(_valid_files)

print('noSamples: ', noSamples)

test_data = ARCDataset(data_folder, _valid_files )

print(len(_valid_files), '\n')


from torch.utils.data import DataLoader

# 32 samples per requested GPU id.
batch_size = 32 * len(GPU_ids)
print('batch_size: ', batch_size)

test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=True)


###=== Instantiate network and load trained weights

# Checkpoint file: sys.argv[3] is resolved relative to ../models.
_path = CWD+'/../models/'+sys.argv[3]

model = Network()
model.to(device)

# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
checkpoint = torch.load(_path, map_location=device )

###=== Accept checkpoints saved with or without nn.DataParallel

from collections import OrderedDict
new_state_dict = OrderedDict()

# nn.DataParallel stores parameters under a 'module.' prefix. Strip it only
# when it is actually present so that checkpoints saved from an unwrapped
# model load unchanged.
_prefix = 'module.'
for k, v in checkpoint['model_state_dict'].items():
    name = k[len(_prefix):] if k.startswith(_prefix) else k
    new_state_dict[name] = v


model.load_state_dict(new_state_dict)

# Wrap only after the weights are loaded, so that the wrapped model (and not
# a discarded copy) is the one used for inference.
if multi_GPU:
    model = nn.DataParallel(model, device_ids = GPU_ids)
    print('Using DataParallel')


###=== Inference on test set

# One cross-entropy loss is evaluated per model output; output[i] is scored
# against label column i.
criterion = nn.CrossEntropyLoss()

model.eval()

degree = 4
test_losses = [[] for _ in range(degree)]
test_loss = np.zeros(degree)
test_accuracy = np.zeros(degree)

# No gradients are needed for evaluation; disabling autograd avoids building
# the graph for every batch.
with torch.no_grad():
    for inputs, labels in test_loader:
        output = model(inputs)
        n_batch = inputs.size(0)

        for i in range(degree):
            target = labels[:, i]

            # The criterion returns the batch mean; re-weight by the batch
            # size so that a smaller final batch is counted correctly.
            test_loss[i] += criterion(output[i], target).item() * n_batch

            # Count correct predictions as a Python int (.item()) so the
            # NumPy accumulator never receives a device tensor.
            _, pred = torch.max(output[i], 1)
            test_accuracy[i] += (pred == target).sum().item()


# Normalise by the number of samples (not the number of batches) so every
# sample carries the same weight.
n_samples = len(test_loader.dataset)
for i in range(degree):
    test_accuracy[i] = test_accuracy[i] / n_samples
    test_loss[i] = test_loss[i] / n_samples
    test_losses[i].append(test_loss[i])



print('\ntest_loss: {} \ntest_accuracy: {}'.format(test_loss,test_accuracy*100))

