from __future__ import print_function
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import utils
import time
from tqdm import tqdm


from torch.autograd import Variable
from scipy.spatial.distance import pdist, cdist, squareform


# LID of a batch of query points (`batch`) measured against reference points (`data`)
def mle_batch(data, batch, k):
    """
    Compute the LID score of every row of `batch` using its k nearest
    neighbours among the rows of `data`.

    data:  reference points, array-like convertible to a 2D float32 array
    batch: query points, array-like convertible to a 2D float32 array
    k:     number of neighbours; clipped to len(data) - 1
    return: 1D ndarray holding one LID score per row of `batch`
    """
    reference = np.asarray(data, dtype=np.float32)
    queries = np.asarray(batch, dtype=np.float32)

    k = min(k, len(reference) - 1)

    def lid_of_row(neigh_dists):
        # neigh_dists is ascending, so its last entry is the largest of the
        # k retained distances and every ratio is <= 1.
        return - k / np.sum(np.log(neigh_dists / neigh_dists[-1]))

    # Sort each row of the pairwise distance matrix and drop column 0, the
    # smallest distance (the query itself whenever it is also part of `data`).
    neighbour_dists = np.sort(cdist(queries, reference), axis=1)[:, 1:k + 1]

    return np.apply_along_axis(lid_of_row, axis=1, arr=neighbour_dists)


# this function is from https://github.com/xingjunm/lid_adversarial_subspace_detection
def merge_and_generate_labels(X_pos, X_neg):
    """
    Merge positive and negative artifacts and generate their labels.

    Each input is converted to float32 and flattened to one row per sample.
    return: X: merged samples (positives first, then negatives), 2D ndarray
            y: generated labels, ndarray of shape (len(X), 1);
               1 for rows coming from X_pos, 0 for rows coming from X_neg
    """
    flattened = []
    for samples in (X_pos, X_neg):
        arr = np.asarray(samples, dtype=np.float32)
        # keep the sample axis, collapse everything else into one feature axis
        flattened.append(arr.reshape((arr.shape[0], -1)))
    pos_rows, neg_rows = flattened

    X = np.concatenate((pos_rows, neg_rows))

    # Start from all-negative labels and mark the leading positive rows.
    y = np.zeros((X.shape[0], 1))
    y[:pos_rows.shape[0]] = 1

    return X, y


def sample_estimator(model, num_classes, feature_list, train_loader):
    """
    compute sample mean and precision (inverse of covariance)

    model:        network exposing `feature_list(data)`, which returns
                  `(output, out_features)` with one feature tensor per layer
    num_classes:  number of classes; labels index per-class storage directly
    feature_list: per-layer feature sizes; `len(feature_list)` is the number of
                  layers used and each entry is the width of that layer's
                  class-mean tensor
    train_loader: iterable of `(data, target)` batches

    return: sample_class_mean: list (one per layer) of CUDA tensors of shape
                               (num_classes, feature size) with the class means
            precision: list (one per layer) of CUDA float tensors holding the
                       precision matrix fitted on the class-centered features
    raises: ValueError if some class never appears in `train_loader`

    Also prints the training accuracy of `model` on `train_loader`.
    """
    import sklearn.covariance

    model.eval()
    covariance_estimator = sklearn.covariance.EmpiricalCovariance(assume_centered=False)
    correct, total = 0, 0
    num_output = len(feature_list)

    # per_class_chunks[layer][cls] gathers one (1, feature size) row per sample.
    # The rows are concatenated once after the loop instead of re-copying a
    # growing tensor for every sample.
    per_class_chunks = [[[] for _ in range(num_classes)] for _ in range(num_output)]

    with torch.no_grad():
        for data, target in train_loader:
            total += data.size(0)
            data = data.cuda()
            output, out_features = model.feature_list(data)

            # get hidden features: average each channel over all trailing dims
            pooled = []
            for layer in range(num_output):
                feat = out_features[layer]
                feat = feat.view(feat.size(0), feat.size(1), -1)
                pooled.append(torch.mean(feat, 2))

            # compute the accuracy
            pred = output.max(1)[1]
            correct += pred.eq(target.cuda()).sum().item()

            # construct the sample matrix
            for i in range(data.size(0)):
                label = int(target[i])
                for layer in range(num_output):
                    per_class_chunks[layer][label].append(pooled[layer][i].view(1, -1))

    list_features = []
    for layer in range(num_output):
        layer_features = []
        for cls in range(num_classes):
            chunks = per_class_chunks[layer][cls]
            if not chunks:
                raise ValueError(
                    'class {} has no samples in train_loader; '
                    'cannot estimate its mean'.format(cls))
            layer_features.append(torch.cat(chunks, 0))
        list_features.append(layer_features)

    sample_class_mean = []
    for layer, num_feature in enumerate(feature_list):
        class_means = torch.empty(num_classes, int(num_feature), dtype=torch.float32).cuda()
        for cls in range(num_classes):
            class_means[cls] = torch.mean(list_features[layer][cls], 0)
        sample_class_mean.append(class_means)

    precision = []
    for layer in range(num_output):
        # Center every class on its own mean, then fit one covariance shared
        # by all classes of this layer.
        centered = torch.cat(
            [list_features[layer][cls] - sample_class_mean[layer][cls]
             for cls in range(num_classes)], 0)

        # find inverse
        covariance_estimator.fit(centered.cpu().numpy())
        temp_precision = covariance_estimator.precision_
        temp_precision = torch.from_numpy(temp_precision).float().cuda()
        precision.append(temp_precision)

    print('\n Training Accuracy:({:.2f}%)\n'.format(100. * float(correct) / total))

    return sample_class_mean, precision


def get_Mahalanobis_score(model, test_loader, num_classes, outf, out_flag, dataset, sample_mean, precision, layer_index, magnitude):
    """
    Compute the proposed Mahalanobis confidence score on input dataset

    For every batch the features of layer `layer_index` are averaged over their
    trailing dims, the best-scoring class under
    -0.5 * (f - mean_c)^T P (f - mean_c) is selected, the input is moved by
    `magnitude` against the sign of the gradient of that class's negated score,
    and the maximum class score of the perturbed input is kept.

    model:        network exposing `intermediate_forward(data, layer_index)`
    test_loader:  iterable of `(data, target)` batches; targets are ignored
    num_classes:  number of class means to score against
    outf:         directory receiving the confidence text file
    out_flag:     True writes `confidence_Ga<layer>_In.txt`, otherwise
                  `confidence_Ga<layer>_Out.txt`
    dataset:      'cifar10', 'cifar100' or 'svhn' rescale the gradient by that
                  dataset's per-channel std; any other value leaves it unscaled
    sample_mean:  per-layer class means, indexed as sample_mean[layer_index][class]
    precision:    per-layer precision matrices, indexed as precision[layer_index]
    magnitude:    step size of the input perturbation

    return: Mahalanobis score from layer_index (list, one value per sample);
            the same values are written to the confidence file, one per line
    """
    model.eval()
    Mahalanobis = []

    # Per-channel std of the input normalization used for each dataset.
    channel_std_by_dataset = {
        'cifar10': (0.247, 0.243, 0.261),
        'cifar100': (0.267, 0.256, 0.276),
        'svhn': (0.198, 0.201, 0.197),
    }
    channel_std = channel_std_by_dataset.get(dataset, ())

    layer_mean = sample_mean[layer_index]
    layer_precision = precision[layer_index]

    def class_scores(features):
        # One column per class: -0.5 * (f - mean_c)^T P (f - mean_c) per sample.
        columns = []
        for i in range(num_classes):
            zero_f = features - layer_mean[i]
            term_gau = -0.5*torch.mm(torch.mm(zero_f, layer_precision), zero_f.t()).diag()
            columns.append(term_gau.view(-1, 1))
        return torch.cat(columns, 1)

    if out_flag == True:
        temp_file_name = '%s/confidence_Ga%s_In.txt'%(outf, str(layer_index))
    else:
        temp_file_name = '%s/confidence_Ga%s_Out.txt'%(outf, str(layer_index))

    with open(temp_file_name, 'w') as g:
        for data, _ in test_loader:
            # fresh leaf tensor so the input gradient can be read from data.grad
            data = data.cuda().detach().requires_grad_(True)

            out_features = model.intermediate_forward(data, layer_index)
            out_features = out_features.view(out_features.size(0), out_features.size(1), -1)
            out_features = torch.mean(out_features, 2)

            # compute Mahalanobis score
            gaussian_score = class_scores(out_features.detach())

            # Input_processing
            sample_pred = gaussian_score.max(1)[1]
            batch_sample_mean = layer_mean.index_select(0, sample_pred)
            zero_f = out_features - batch_sample_mean
            pure_gau = -0.5*torch.mm(torch.mm(zero_f, layer_precision), zero_f.t()).diag()
            loss = torch.mean(-pure_gau)
            loss.backward()

            # keep only the sign of the input gradient (+1 / -1)
            gradient = torch.ge(data.grad.data, 0)
            gradient = (gradient.float() - 0.5) * 2
            for channel, std in enumerate(channel_std):
                gradient[:, channel].div_(std)

            tempInputs = torch.add(data.detach(), gradient, alpha=-magnitude)

            # The perturbed pass needs no gradients; `volatile=True` no longer
            # disables graph construction, so use no_grad explicitly.
            with torch.no_grad():
                noise_out_features = model.intermediate_forward(tempInputs, layer_index)
                noise_out_features = noise_out_features.view(noise_out_features.size(0), noise_out_features.size(1), -1)
                noise_out_features = torch.mean(noise_out_features, 2)
                noise_gaussian_score, _ = torch.max(class_scores(noise_out_features), dim=1)

            batch_scores = noise_gaussian_score.cpu()
            Mahalanobis.extend(batch_scores.numpy())

            # one device-to-host copy per batch instead of one per sample
            for score in batch_scores.tolist():
                g.write("{}\n".format(score))

    return Mahalanobis


def get_posterior(model, dataset, test_loader, magnitude, temperature, outf, out_flag, loss_type, score_type=None):
    """
    Compute a confidence score from the (processed) posterior distribution - ODIN

    Each batch is perturbed by `magnitude` against the sign of the gradient of
    the cross-entropy between the temperature-scaled outputs and the model's own
    predictions; the perturbed input is then scored.

    model:       callable network; for `loss_type` starting with "eml" its
                 output is softmaxed over dim 2 and averaged over dim 1
    dataset:     'cifar10', 'cifar100' or 'svhn' rescale the gradient by that
                 dataset's per-channel std; any other value leaves it unscaled
    test_loader: iterable of `(data, target)` batches; targets are ignored
    magnitude:   step size of the input perturbation
    temperature: divisor applied to the model outputs
    outf:        directory receiving the confidence files
    out_flag:    True selects the `_In` file names, otherwise `_Out`
    loss_type:   string; only its "eml" prefix is inspected
    score_type:  "MPS" (maximum of the probabilities), "ES" (negated
                 `utils.entropies_from_probabilities`) or "MIDS" (maximum of
                 the temperature-scaled outputs)

    return: null; scores are written one per line to `confidence_PoT_*.txt`.
            `confidence_PoV_*.txt` is created (truncated) but left empty.
    raises: ValueError if `score_type` is not one of the supported values
            (including the default None)
    """
    # Fail before touching files or the model: an unsupported score type would
    # otherwise leave the score undefined when it is first written.
    if score_type not in ("MPS", "ES", "MIDS"):
        raise ValueError(
            'score_type must be "MPS", "ES" or "MIDS", got {!r}'.format(score_type))

    criterion = nn.CrossEntropyLoss()
    model.eval()

    if out_flag == True:
        temp_file_name_val = '%s/confidence_PoV_In.txt'%(outf)
        temp_file_name_test = '%s/confidence_PoT_In.txt'%(outf)
    else:
        temp_file_name_val = '%s/confidence_PoV_Out.txt'%(outf)
        temp_file_name_test = '%s/confidence_PoT_Out.txt'%(outf)

    # Per-channel std of the input normalization used for each dataset.
    channel_std_by_dataset = {
        'cifar10': (0.247, 0.243, 0.261),
        'cifar100': (0.267, 0.256, 0.276),
        'svhn': (0.198, 0.201, 0.197),
    }
    channel_std = channel_std_by_dataset.get(dataset, ())

    # The validation file is only created/truncated; all scores go to the test file.
    with open(temp_file_name_val, 'w'), open(temp_file_name_test, 'w') as f:
        for data, _ in test_loader:
            # fresh leaf tensor so the input gradient can be read from data.grad
            data = data.cuda().detach().requires_grad_(True)

            batch_output = model(data)
            # temperature scaling
            outputs = batch_output / temperature
            # the model's own prediction serves as the label
            labels = outputs.detach().max(1)[1]
            loss = criterion(outputs, labels)
            loss.backward()

            # keep only the sign of the input gradient (+1 / -1)
            gradient = torch.ge(data.grad.data, 0)
            gradient = (gradient.float() - 0.5) * 2
            for channel, std in enumerate(channel_std):
                gradient[:, channel].div_(std)

            with torch.no_grad():
                tempInputs = torch.add(data.detach(), gradient, alpha=-magnitude)
                outputs = model(tempInputs)
                outputs = outputs / temperature

                logits = outputs
                if loss_type.startswith("eml"):
                    outputs = nn.Softmax(dim=2)(outputs).mean(dim=1)
                else:
                    outputs = nn.Softmax(dim=1)(outputs)

                if score_type == "MPS":
                    soft_out = outputs.max(dim=1)[0]
                elif score_type == "ES":
                    soft_out = - utils.entropies_from_probabilities(outputs)
                else:  # "MIDS"
                    soft_out = logits.max(dim=1)[0]

            for i in range(data.size(0)):
                f.write("{}\n".format(soft_out[i]))



def get_Mahalanobis_score_adv(model, test_data, test_label, num_classes, outf, dataset, sample_mean, precision, layer_index, magnitude):
    """
    Compute the proposed Mahalanobis confidence score on adversarial samples

    `test_data` is processed in batches of 100; a trailing partial batch is
    skipped, so the result holds (len(test_data) // 100) * 100 scores.

    model:       network exposing `intermediate_forward(data, layer_index)`
    test_data:   tensor of samples, sliced along dim 0
    test_label:  accepted for interface compatibility; not used
    num_classes: number of class means to score against
    outf:        accepted for interface compatibility; not used
    dataset:     'cifar10', 'cifar100' or 'svhn' rescale the gradient by that
                 dataset's per-channel std; any other value leaves it unscaled
    sample_mean: per-layer class means, indexed as sample_mean[layer_index][class]
    precision:   per-layer precision matrices, indexed as precision[layer_index]
    magnitude:   step size of the input perturbation

    return: Mahalanobis score from layer_index (list, one value per sample)
    """
    model.eval()
    Mahalanobis = []
    batch_size = 100

    # Per-channel std of the input normalization used for each dataset.
    channel_std_by_dataset = {
        'cifar10': (0.247, 0.243, 0.261),
        'cifar100': (0.267, 0.256, 0.276),
        'svhn': (0.198, 0.201, 0.197),
    }
    channel_std = channel_std_by_dataset.get(dataset, ())

    layer_mean = sample_mean[layer_index]
    layer_precision = precision[layer_index]

    def class_scores(features):
        # One column per class: -0.5 * (f - mean_c)^T P (f - mean_c) per sample.
        columns = []
        for i in range(num_classes):
            zero_f = features - layer_mean[i]
            term_gau = -0.5*torch.mm(torch.mm(zero_f, layer_precision), zero_f.t()).diag()
            columns.append(term_gau.view(-1, 1))
        return torch.cat(columns, 1)

    num_batches = test_data.size(0) // batch_size
    for batch_index in range(num_batches):
        start = batch_index * batch_size
        # fresh leaf tensor so the input gradient can be read from data.grad
        data = test_data[start : start + batch_size].cuda().detach().requires_grad_(True)

        out_features = model.intermediate_forward(data, layer_index)
        out_features = out_features.view(out_features.size(0), out_features.size(1), -1)
        out_features = torch.mean(out_features, 2)

        gaussian_score = class_scores(out_features.detach())

        # Input_processing
        sample_pred = gaussian_score.max(1)[1]
        batch_sample_mean = layer_mean.index_select(0, sample_pred)
        zero_f = out_features - batch_sample_mean
        pure_gau = -0.5*torch.mm(torch.mm(zero_f, layer_precision), zero_f.t()).diag()
        loss = torch.mean(-pure_gau)
        loss.backward()

        # keep only the sign of the input gradient (+1 / -1)
        gradient = torch.ge(data.grad.data, 0)
        gradient = (gradient.float() - 0.5) * 2
        for channel, std in enumerate(channel_std):
            gradient[:, channel].div_(std)

        tempInputs = torch.add(data.detach(), gradient, alpha=-magnitude)

        with torch.no_grad():
            noise_out_features = model.intermediate_forward(tempInputs, layer_index)
            noise_out_features = noise_out_features.view(noise_out_features.size(0), noise_out_features.size(1), -1)
            noise_out_features = torch.mean(noise_out_features, 2)
            noise_gaussian_score, _ = torch.max(class_scores(noise_out_features), dim=1)
            Mahalanobis.extend(noise_gaussian_score.cpu().numpy())

    return Mahalanobis


def get_Entropy_score_adv(model, test_data, test_label, num_classes, outf, dataset, sample_mean, precision, layer_index, magnitude):
    """
    Compute a softmax confidence score on adversarial samples

    Despite the function name, the value returned for each sample is the
    maximum softmax probability of the model output (softmax over dim 1); no
    entropy is computed and no intermediate layer is used.

    `test_data` is processed in batches of 100; a trailing partial batch is
    skipped, so the result holds (len(test_data) // 100) * 100 scores.

    model:     callable network returning one row of class outputs per sample
    test_data: tensor of samples, sliced along dim 0
    test_label, num_classes, outf, dataset, sample_mean, precision,
    layer_index, magnitude: accepted so the signature matches
                            get_Mahalanobis_score_adv; not used

    return: list with one maximum softmax probability per processed sample
    """
    model.eval()
    Entropy = []
    batch_size = 100

    num_batches = test_data.size(0) // batch_size
    with torch.no_grad():
        for batch_index in range(num_batches):
            start = batch_index * batch_size
            data = test_data[start: start + batch_size].cuda()

            outputs = model(data)
            soft_out, _ = torch.max(F.softmax(outputs, dim=1), dim=1)
            Entropy.extend(soft_out.cpu().numpy())

    return Entropy


def _pooled_layer_activations(model, batch, num_output):
    """Return, for each of the first `num_output` layers, the features of `batch`
    averaged over their trailing dims, as 2D float32 ndarrays (one row per sample)."""
    _, out_features = model.feature_list(batch)
    activations = []
    for i in range(num_output):
        feat = out_features[i]
        feat = feat.view(feat.size(0), feat.size(1), -1)
        feat = torch.mean(feat.detach(), 2)
        activations.append(np.asarray(feat.cpu(), dtype=np.float32).reshape((feat.size(0), -1)))
    return activations


def get_LID(model, test_clean_data, test_adv_data, test_noisy_data, test_label, num_output):
    """
    Compute LID score on adversarial samples

    The three data tensors are walked in lockstep in batches of 100; a trailing
    partial batch is skipped. Within a batch the clean activations serve as the
    reference set for the clean, adversarial and noisy queries alike.

    model:           network exposing `feature_list(data)`, which returns
                     `(output, out_features)` with one feature tensor per layer
    test_clean_data: tensor of clean samples, sliced along dim 0
    test_adv_data:   tensor of adversarial samples, aligned with the clean ones
    test_noisy_data: tensor of noisy samples, aligned with the clean ones
    test_label:      accepted for interface compatibility; not used
    num_output:      number of layers whose features are scored

    return: LID, LID_adv, LID_noisy: three lists with one entry per neighbour
            count in [10, 20, ..., 90]; each entry is a list holding, per
            processed sample, a 1D array of `num_output` LID scores
    """
    model.eval()
    batch_size = 100

    overlap_list = [10, 20, 30, 40, 50, 60, 70, 80, 90]
    LID = [[] for _ in overlap_list]
    LID_adv = [[] for _ in overlap_list]
    LID_noisy = [[] for _ in overlap_list]

    num_batches = test_clean_data.size(0) // batch_size
    with torch.no_grad():
        for batch_index in range(num_batches):
            start = batch_index * batch_size
            end = start + batch_size

            X_act = _pooled_layer_activations(
                model, test_clean_data[start:end].cuda(), num_output)
            X_act_adv = _pooled_layer_activations(
                model, test_adv_data[start:end].cuda(), num_output)
            X_act_noisy = _pooled_layer_activations(
                model, test_noisy_data[start:end].cuda(), num_output)

            # LID
            for list_counter, overlap in enumerate(overlap_list):
                LID_list = []
                LID_adv_list = []
                LID_noisy_list = []

                for j in range(num_output):
                    lid_score = mle_batch(X_act[j], X_act[j], k = overlap)
                    lid_adv_score = mle_batch(X_act[j], X_act_adv[j], k = overlap)
                    lid_noisy_score = mle_batch(X_act[j], X_act_noisy[j], k = overlap)

                    # turn each per-sample score vector into a column
                    LID_list.append(lid_score.reshape((lid_score.shape[0], -1)))
                    LID_adv_list.append(lid_adv_score.reshape((lid_adv_score.shape[0], -1)))
                    LID_noisy_list.append(lid_noisy_score.reshape((lid_noisy_score.shape[0], -1)))

                # one column per layer -> one row of layer scores per sample
                LID[list_counter].extend(np.concatenate(LID_list, axis=1))
                LID_adv[list_counter].extend(np.concatenate(LID_adv_list, axis=1))
                LID_noisy[list_counter].extend(np.concatenate(LID_noisy_list, axis=1))

    return LID, LID_adv, LID_noisy
