import os

import numpy as np
from tensorflow.keras import backend as K
from tensorflow.keras import layers
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# Batch sizes shared by the data pipelines in this module.
BATCH_SIZE_STAT = 250  # statistics calculations
BATCH_SIZE_TRAIN = 64  # model training
BATCH_SIZE_TEST = 1  # model evaluation (default batch size of load_test_data)


# Per-channel (R, G, B) mean and standard deviation used by the WideResNet50
# preprocessing. These are the ImageNet statistics commonly used with
# torchvision models, applied to pixels already scaled to [0, 1].
_WRN_MEAN = (0.485, 0.456, 0.406)
_WRN_STD = (0.229, 0.224, 0.225)


def _normalize_wrn(x):
    """Scale pixel values by 1/255, then standardize each channel.

    Returns a new array; *x* is left untouched. The channel axis follows
    ``K.image_data_format()``, so the statistics broadcast correctly for
    both channels_last (H, W, C) and channels_first (C, H, W) images.
    """
    if K.image_data_format() == 'channels_first':
        stat_shape = (3, 1, 1)
    else:
        stat_shape = (3,)
    mean = np.reshape(_WRN_MEAN, stat_shape)
    std = np.reshape(_WRN_STD, stat_shape)
    # Out-of-place on purpose: ``x /= 255.`` would mutate the caller's array
    # and raise on integer dtypes.
    return (x / 255. - mean) / std


def _preprocessing_for(model):
    """Return ``(target_size, preprocessing_fn)`` for a supported model name.

    ``target_size`` is ``(height, width)``, which does not depend on the
    channel ordering. Raises ValueError for unknown model names.
    """
    if model == 'ViT-B':
        # Imported inside the branch, so vit_keras is only needed for ViT-B.
        from vit_keras.vit import preprocess_inputs as preprocess_input
        return (224, 224), preprocess_input
    if model == 'ResNet50':
        from tensorflow.keras.applications.resnet50 import preprocess_input
        return (224, 224), preprocess_input
    if model == 'Inceptionv3':
        from tensorflow.keras.applications.inception_v3 import preprocess_input
        return (299, 299), preprocess_input
    if model == 'WideResNet50':
        return (224, 224), _normalize_wrn
    raise ValueError('Unknown model name: {}'.format(model))


def load_test_data(batch_size=BATCH_SIZE_TEST, class_mode=None, model='Inceptionv3'):
    """Create a non-shuffled iterator over the ``val_data`` directory.

    Images are read from the ``val_data`` folder located next to this module,
    resized to the input resolution of *model*, loaded as RGB and passed
    through that model's preprocessing function.

    Args:
        batch_size: Number of images per yielded batch.
        class_mode: Forwarded to ``flow_from_directory`` (``None`` yields
            images without labels).
        model: One of ``'ViT-B'``, ``'ResNet50'``, ``'Inceptionv3'`` or
            ``'WideResNet50'``.

    Returns:
        The iterator returned by ``ImageDataGenerator.flow_from_directory``.
        Iteration order is deterministic (``shuffle=False``).

    Raises:
        ValueError: If *model* is not a supported name.
    """
    image_shape, preprocessing_fn = _preprocessing_for(model)

    test_data_preprocessor = ImageDataGenerator(
        preprocessing_function=preprocessing_fn,
        samplewise_center=False,
    )
    # abspath guards against a relative __file__, which would make the data
    # directory resolve against the current working directory.
    data_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'val_data')
    return test_data_preprocessor.flow_from_directory(
        data_dir,
        target_size=image_shape,
        batch_size=batch_size,
        class_mode=class_mode,
        color_mode='rgb',
        shuffle=False,
    )
