import numpy as np
from common.util import get_output_bs_layer, load_status

def load_adv(adv_path, dataset_name, attack, transformers=True):
    """Load a saved ``.npy`` array of adversarial samples.

    The file location is assembled as
    ``<adv_path><dataset_name>/transformers/<dataset_name>_transformers<transformers><attack>.npy``.
    The pieces are concatenated as-is, so ``adv_path`` has to carry its own
    trailing separator and ``transformers`` appears in its string form
    (e.g. ``True``). The resolved path is printed before loading.

    Args:
        adv_path: Prefix of the directory tree holding the saved arrays.
        dataset_name: Dataset name; used both as a sub-directory and as the
            file-name prefix.
        attack: Attack identifier appended to the file name.
        transformers: Value interpolated right after ``_transformers`` in the
            file name.

    Returns:
        Whatever ``np.load`` returns for the resolved path.
    """
    file_name = f'{dataset_name}_transformers{transformers}{attack}.npy'
    adv_file_path = f'{adv_path}{dataset_name}/transformers/{file_name}'
    print(adv_file_path)
    return np.load(adv_file_path)


def get_prediction_train_test_test_adv(layer, X_train, X_test, X_test_adv, batch_size=500, train_adv=False):
    """Run ``get_output_bs_layer`` on the train, test and adversarial sets.

    The three datasets are processed in that order, each with the same
    ``layer`` and ``batch_size``.
    NOTE(review): ``get_output_bs_layer`` comes from ``common.util``;
    presumably it evaluates ``layer`` on the data in batches and uses ``desc``
    as a progress label -- confirm against that helper.

    Args:
        layer: Passed through unchanged as the first argument of the helper.
        X_train: Training data.
        X_test: Test data.
        X_test_adv: Adversarial test data.
        batch_size: Batch size forwarded to the helper.
        train_adv: When true, the training pass is labelled ``'train adv'``
            instead of ``'train'``; only the label changes.

    Returns:
        Tuple ``(train_output, test_output, adv_output)``.
    """
    train_label = 'train adv' if train_adv else 'train'
    labelled_inputs = (
        (X_train, train_label),
        (X_test, 'test'),
        (X_test_adv, 'attack'),
    )

    outputs = []
    for data, label in labelled_inputs:
        outputs.append(get_output_bs_layer(layer, data, batch_size=batch_size, desc=label))

    return tuple(outputs)


def flatten_datasets(train, test, adv):
    """Flatten each sample of the three arrays into one feature vector.

    The number of features per sample is the product of every axis of
    ``train`` after the first. That same count is used for all three arrays,
    so ``test`` and ``adv`` must hold as many values per sample as ``train``.
    Arrays of any rank are accepted (not only 4-D ones); a 1-D ``train``
    yields a single feature column.

    Args:
        train: Array whose first axis indexes samples; defines the feature
            count.
        test: Array with the same number of values per sample as ``train``.
        adv: Array with the same number of values per sample as ``train``.

    Returns:
        Tuple ``(train_flat, test_flat, adv_flat)`` of 2-D arrays shaped
        ``(n_samples, n_features)``.

    Raises:
        ValueError: If ``test`` or ``adv`` cannot be reshaped to
            ``(n_samples, n_features)``.
    """
    # Multiply in plain Python ints so the feature count works for any rank
    # and also for arrays with zero samples (where reshape(..., -1) would be
    # ambiguous).
    n_features = 1
    for dim in train.shape[1:]:
        n_features *= dim

    flat_train, flat_test, flat_adv = (
        data.reshape((data.shape[0], n_features)) for data in (train, test, adv)
    )
    return flat_train, flat_test, flat_adv
