""" REFERENCES
The following three references cover everything needed here.
Weights will be downloaded automatically when instantiating a model. 
They will be stored at ~/.keras/models/.
- Available models & How to use Keras pre-trained models: https://keras.io/api/applications/#usage-examples-for-image-classification-models
- Available models detailed: https://www.tensorflow.org/api_docs/python/tf/keras/applications (EfficientNet: https://www.tensorflow.org/api_docs/python/tf/keras/applications/efficientnet)
- How to fine-tune models: https://keras.io/guides/transfer_learning/#do-a-round-of-finetuning-of-the-entire-model

=====================================================================================================

# https://keras.io/api/applications/#usage-examples-for-image-classification-models
Model               Size Top-1 Accuracy Top-5 Accuracy  Parameters Depth
-------------------------------------------------------------------------
Xception           88 MB          0.790          0.945  22,910,480 126
VGG16             528 MB          0.713          0.901 138,357,544 23
VGG19             549 MB          0.713          0.900 143,667,240 26
ResNet50           98 MB          0.749          0.921  25,636,712 -
ResNet101         171 MB          0.764          0.928  44,707,176 -
ResNet152         232 MB          0.766          0.931  60,419,944 -
ResNet50V2         98 MB          0.760          0.930  25,613,800 -    impl wip
ResNet101V2       171 MB          0.772          0.938  44,675,560 -    impl wip
ResNet152V2       232 MB          0.780          0.942  60,380,648 -    impl wip
InceptionV3        92 MB          0.779          0.937  23,851,784 159
InceptionResNetV2 215 MB          0.803          0.953  55,873,736 572
MobileNet          16 MB          0.704          0.895   4,253,864 88
MobileNetV2        14 MB          0.713          0.901   3,538,984 88   impl wip
DenseNet121        33 MB          0.750          0.923   8,062,504 121
DenseNet169        57 MB          0.762          0.932  14,307,880 169
DenseNet201        80 MB          0.773          0.936  20,242,984 201
NASNetMobile       23 MB          0.744          0.919   5,326,716 -
NASNetLarge       343 MB          0.825          0.960  88,949,818 -
EfficientNetB0     29 MB              -              -   5,330,571 -    impl wip
EfficientNetB1     31 MB              -              -   7,856,239 -    impl wip
EfficientNetB2     36 MB              -              -   9,177,569 -    impl wip
EfficientNetB3     48 MB              -              -  12,320,535 -    impl wip
EfficientNetB4     75 MB              -              -  19,466,823 -    impl wip
EfficientNetB5    118 MB              -              -  30,562,527 -    impl wip
EfficientNetB6    166 MB              -              -  43,265,143 -    impl wip
EfficientNetB7    256 MB              -              -  66,658,687 -    impl wip

# https://www.tensorflow.org/api_docs/python/tf/keras/applications/mobilenet_v2
          Version|Input size|MACs (M)|Parameters(M)|Top 1 Accuracy|Top 5 Accuracy|
mobilenet_v2_1.4 |       224|     582|        6.06 |         75.0 |         92.5 | 
mobilenet_v2_1.3 |       224|     509|        5.34 |         74.4 |         92.1 | 
mobilenet_v2_1.0 |       224|     300|        3.47 |         71.8 |         91.0 | 
mobilenet_v2_1.0 |       192|     221|        3.47 |         70.7 |         90.1 | 
mobilenet_v2_1.0 |       160|     154|        3.47 |         68.8 |         89.0 | 
mobilenet_v2_1.0 |       128|      99|        3.47 |         65.3 |         86.9 | 
mobilenet_v2_1.0 |        96|      56|        3.47 |         60.3 |         83.2 | 
mobilenet_v2_0.75|       224|     209|        2.61 |         69.8 |         89.6 | 
mobilenet_v2_0.75|       192|     153|        2.61 |         68.7 |         88.9 | 
mobilenet_v2_0.75|       160|     107|        2.61 |         66.4 |         87.3 | 
mobilenet_v2_0.75|       128|      69|        2.61 |         63.2 |         85.3 | 
mobilenet_v2_0.75|        96|      39|        2.61 |         58.8 |         81.6 | 
mobilenet_v2_0.5 |       224|      97|        1.95 |         65.4 |         86.4 | 
mobilenet_v2_0.5 |       192|      71|        1.95 |         63.9 |         85.4 | 
mobilenet_v2_0.5 |       160|      50|        1.95 |         61.0 |         83.2 | 
mobilenet_v2_0.5 |       128|      32|        1.95 |         57.7 |         80.8 | 
mobilenet_v2_0.5 |        96|      18|        1.95 |         51.2 |         75.8 | 
mobilenet_v2_0.35|       224|      59|        1.66 |         60.3 |         82.9 |
mobilenet_v2_0.35|       192|      43|        1.66 |         58.2 |         81.2 | 
mobilenet_v2_0.35|       160|      30|        1.66 |         55.7 |         79.1 | 
mobilenet_v2_0.35|       128|      20|        1.66 |         50.8 |         75.0 | 
mobilenet_v2_0.35|        96|      11|        1.66 |         45.5 |         70.4 |
MACs stands for Multiply-Adds (in millions). "Version" is the classification checkpoint name.
"""

""" HOW TO IMPLEMENT YET ANOTHER MODEL
1. Add an entry to dc_available in ModelManagerFE.__init__ (append new models at the end).
2. Add the corresponding builder method to ModelManagerFE, e.g. by copying and adapting ResNet50V2Model.
"""

import numpy as np
import tensorflow as tf
import tensorflow.keras.applications as tka

class ModelManagerFE():
    """
    Build a Keras image classifier from tf.keras.applications.

    Usage:
        mm = ModelManagerFE("ResNet50V2", "imagenet", None, False,
                            classes=10, pooling="avg")
        model, preproc = mm.model, mm.preproc   # or: model, preproc = mm

    # Current support
    (base model, last Dense layer) and (trainable or not, trainable or not) =
    (imagenet pretr, imagenet pretr) and (False, True),
    (imagenet pretr, imagenet pretr) and (True,  True),
    (imagenet pretr, random init)    and (False, True),
    (imagenet pretr, random init)    and (True,  True),
    (random init,    random init)    and (True,  True).
    The last Dense layer is always trainable; trainable_base only controls the
    layers before it.
    More last layers (e.g., 3 Dense layers after GAP layer of ResNet) can be
    equipped similarly in add_last_dense.
    """
    def __init__(self, name_model, weights_base, weights_top, trainable_base, **kwargs):
        """
        Build the requested model and store it on the instance.

        # Args
        name_model: A str. One of the keywords in dc_available.
        weights_base: One of None (random initialization),
            or 'imagenet' (pre-training on ImageNet).
        weights_top: One of None (random initialization),
            or 'imagenet' (pre-training on ImageNet). 'imagenet' requires
            weights_base='imagenet' and keeps the Keras ImageNet top layer.
        trainable_base: If False, every layer except the final classifier
            Dense layer is frozen. If True, all weights are trainable.
        **kwargs: Forwarded to the builder method: input_tensor, input_shape,
            pooling, classes, classifier_activation.
        # Attributes set
        model: A Keras Model.
        preproc: The matching preprocessing function.
        # Raises
        ValueError: If name_model is unknown, a weights argument is not None
            or 'imagenet', or weights_top='imagenet' without
            weights_base='imagenet'.
        # Note
        __init__ cannot return a value, so the result is exposed through the
        attributes above. The instance can also be unpacked:
        `model, preproc = ModelManagerFE(...)`.
        """
        # Add new models at the end.
        dc_available = {
            "ResNet50V2": self.ResNet50V2Model,
            "ResNet101V2": self.ResNet101V2Model,
            "ResNet152V2": self.ResNet152V2Model,
            "EfficientNetB0": self.EfficientNetB0Model,
            "EfficientNetB1": self.EfficientNetB1Model,
            "EfficientNetB2": self.EfficientNetB2Model,
            "EfficientNetB3": self.EfficientNetB3Model,
            "EfficientNetB4": self.EfficientNetB4Model,
            "EfficientNetB5": self.EfficientNetB5Model,
            "EfficientNetB6": self.EfficientNetB6Model,
            "EfficientNetB7": self.EfficientNetB7Model,
        }

        # Validation: raise instead of assert so the checks survive `python -O`.
        for arg_name, value in (("weights_base", weights_base),
                                ("weights_top", weights_top)):
            if value is not None and value != "imagenet":
                raise ValueError("{} should be None or 'imagenet', got {!r}."
                                 .format(arg_name, value))
        if weights_top == "imagenet" and weights_base != "imagenet":
            raise ValueError(
                "weights_top='imagenet' requires weights_base='imagenet'.")
        if name_model not in dc_available:
            raise ValueError("Wrong name_model. It should be in {}.".
                format(list(dc_available.keys())))

        # Pretrained weights are loaded whenever either part asks for them;
        # the Keras top layer is kept only when its ImageNet weights are wanted.
        weights = "imagenet" if "imagenet" in (weights_base, weights_top) else None
        include_top = weights_top == "imagenet"

        # __init__ must return None, so expose the result as attributes.
        self.model, self.preproc = dc_available[name_model](
            include_top=include_top, weights=weights,
            trainable_base=trainable_base, **kwargs)

    def __iter__(self):
        """Allow `model, preproc = ModelManagerFE(...)` tuple unpacking."""
        return iter((self.model, self.preproc))

    def add_last_dense(self, base_model, classes, classifier_activation): # sub-function
        """
        Append a Dense classifier named "fc_logit" on top of base_model.

        If base_model's output is a 4D feature map (e.g. pooling=None), a
        GlobalAveragePooling2D layer is inserted first so the head produces
        one prediction vector per image rather than one per spatial position.

        # Args
        base_model: A Keras Model built with include_top=False.
        classes: Number of output units.
        classifier_activation: Activation of the Dense layer (str, None, or
            callable).
        # Returns
        model: A Keras Model from base_model.input to the new Dense output.
        """
        x = base_model.output
        if x.shape.rank == 4:
            x = tf.keras.layers.GlobalAveragePooling2D(name="fc_avg_pool")(x)
        preds = tf.keras.layers.Dense(classes, activation=classifier_activation, name="fc_logit")(x)
        return tf.keras.models.Model(inputs=base_model.input, outputs=preds)

    def _build_model(self, constructor, preproc, include_top, weights,
        trainable_base, input_tensor, input_shape, pooling, classes,
        classifier_activation):
        """
        Shared implementation behind every <Name>Model builder method.

        # Args
        constructor: A tf.keras.applications model function, e.g.
            tka.resnet_v2.ResNet50V2.
        preproc: The matching preprocess_input function.
        include_top: Whether to include the Keras fully-connected top layer.
            If False, a Dense head is appended via add_last_dense.
        weights: One of None (random initialization),
            'imagenet' (pre-training on ImageNet),
            or the path to the weights file to be loaded.
            Weights will be downloaded in ~/.keras/.
        trainable_base: If False, every layer except the final classifier
            Dense layer (Keras top or appended head) is frozen.
        input_tensor: Optional Keras tensor (i.e. output of layers.Input())
            to use as image input for the model.
        input_shape: Optional shape tuple, only to be specified if include_top
            is False (otherwise the input shape has to be the model's default,
            e.g. (224, 224, 3) for ResNetV2 with 'channels_last').
            It should have exactly 3 input channels.
        pooling: Optional pooling mode for feature extraction when
            include_top is False.
            - None means that the output of the base model will be the 4D
            tensor output of the last convolutional block (add_last_dense
            then applies global average pooling before the Dense head).
            - "avg" means that global average pooling will be applied.
            - "max" means that global max pooling will be applied.
        classes: Number of output units of the classifier Dense layer: the
            Keras top layer if include_top is True (pretrained top weights
            expect the default 1000), or the appended "fc_logit" layer.
        classifier_activation: A str, None, or callable. Activation of the
            classifier Dense layer. Set to None to return logits. When loading
            pretrained top weights, it can only be None or "softmax".
        # Returns
        model: A Keras Model.
        preproc: The preprocessing function passed in. It takes arguments
            x (a float numpy.array or tf.Tensor, 3D or 4D with 3 color channels,
            values in [0, 255]; numpy inputs may be overwritten in place, use
            numpy.copy(x) to avoid this) and data_format (defaults to
            tf.keras.backend.image_data_format()).
        """
        base_model = constructor(include_top=include_top, weights=weights,
            input_tensor=input_tensor, input_shape=input_shape,
            pooling=pooling, classes=classes,
            classifier_activation=classifier_activation)

        if include_top:
            model = base_model
        else:
            model = self.add_last_dense(base_model, classes, classifier_activation)

        if not trainable_base:
            # Freeze layer by layer, leaving the final classifier Dense trainable.
            # Setting `trainable = False` on the whole model would also freeze
            # the Keras ImageNet top layer when include_top is True.
            for layer in model.layers[:-1]:
                layer.trainable = False

        return model, preproc

    def ResNet50V2Model(self, include_top=True, weights="imagenet",
        trainable_base=True,
        input_tensor=None, input_shape=None, pooling=None, classes=1000,
        classifier_activation='softmax'):
        """
        Build ResNet50V2; returns (model, tka.resnet_v2.preprocess_input).
        Arguments are documented in _build_model.
        https://www.tensorflow.org/api_docs/python/tf/keras/applications/resnet_v2/ResNet50V2
        """
        return self._build_model(tka.resnet_v2.ResNet50V2,
            tka.resnet_v2.preprocess_input, include_top, weights,
            trainable_base, input_tensor, input_shape, pooling, classes,
            classifier_activation)

    def ResNet101V2Model(self, include_top=True, weights="imagenet",
        trainable_base=True,
        input_tensor=None, input_shape=None, pooling=None, classes=1000,
        classifier_activation='softmax'):
        """
        Build ResNet101V2; returns (model, tka.resnet_v2.preprocess_input).
        Arguments are documented in _build_model.
        https://www.tensorflow.org/api_docs/python/tf/keras/applications/resnet_v2/ResNet101V2
        """
        return self._build_model(tka.resnet_v2.ResNet101V2,
            tka.resnet_v2.preprocess_input, include_top, weights,
            trainable_base, input_tensor, input_shape, pooling, classes,
            classifier_activation)

    def ResNet152V2Model(self, include_top=True, weights="imagenet",
        trainable_base=True,
        input_tensor=None, input_shape=None, pooling=None, classes=1000,
        classifier_activation='softmax'):
        """
        Build ResNet152V2; returns (model, tka.resnet_v2.preprocess_input).
        Arguments are documented in _build_model.
        https://www.tensorflow.org/api_docs/python/tf/keras/applications/resnet_v2/ResNet152V2
        """
        return self._build_model(tka.resnet_v2.ResNet152V2,
            tka.resnet_v2.preprocess_input, include_top, weights,
            trainable_base, input_tensor, input_shape, pooling, classes,
            classifier_activation)

    # EfficientNet builders take trainable_base last so the original positional
    # order of EfficientNetB0Model's parameters is preserved.
    def EfficientNetB0Model(self, include_top=True, weights='imagenet',
        input_tensor=None, input_shape=None, pooling=None, classes=1000,
        classifier_activation='softmax', trainable_base=True):
        """
        Build EfficientNetB0; returns (model, tka.efficientnet.preprocess_input).
        Arguments are documented in _build_model.
        https://www.tensorflow.org/api_docs/python/tf/keras/applications/efficientnet/EfficientNetB0
        https://www.tensorflow.org/api_docs/python/tf/keras/applications/efficientnet/preprocess_input
        """
        return self._build_model(tka.efficientnet.EfficientNetB0,
            tka.efficientnet.preprocess_input, include_top, weights,
            trainable_base, input_tensor, input_shape, pooling, classes,
            classifier_activation)

    def EfficientNetB1Model(self, include_top=True, weights='imagenet',
        input_tensor=None, input_shape=None, pooling=None, classes=1000,
        classifier_activation='softmax', trainable_base=True):
        """
        Build EfficientNetB1; returns (model, tka.efficientnet.preprocess_input).
        Arguments are documented in _build_model.
        https://www.tensorflow.org/api_docs/python/tf/keras/applications/efficientnet/EfficientNetB1
        """
        return self._build_model(tka.efficientnet.EfficientNetB1,
            tka.efficientnet.preprocess_input, include_top, weights,
            trainable_base, input_tensor, input_shape, pooling, classes,
            classifier_activation)

    def EfficientNetB2Model(self, include_top=True, weights='imagenet',
        input_tensor=None, input_shape=None, pooling=None, classes=1000,
        classifier_activation='softmax', trainable_base=True):
        """
        Build EfficientNetB2; returns (model, tka.efficientnet.preprocess_input).
        Arguments are documented in _build_model.
        https://www.tensorflow.org/api_docs/python/tf/keras/applications/efficientnet/EfficientNetB2
        """
        return self._build_model(tka.efficientnet.EfficientNetB2,
            tka.efficientnet.preprocess_input, include_top, weights,
            trainable_base, input_tensor, input_shape, pooling, classes,
            classifier_activation)

    def EfficientNetB3Model(self, include_top=True, weights='imagenet',
        input_tensor=None, input_shape=None, pooling=None, classes=1000,
        classifier_activation='softmax', trainable_base=True):
        """
        Build EfficientNetB3; returns (model, tka.efficientnet.preprocess_input).
        Arguments are documented in _build_model.
        https://www.tensorflow.org/api_docs/python/tf/keras/applications/efficientnet/EfficientNetB3
        """
        return self._build_model(tka.efficientnet.EfficientNetB3,
            tka.efficientnet.preprocess_input, include_top, weights,
            trainable_base, input_tensor, input_shape, pooling, classes,
            classifier_activation)

    def EfficientNetB4Model(self, include_top=True, weights='imagenet',
        input_tensor=None, input_shape=None, pooling=None, classes=1000,
        classifier_activation='softmax', trainable_base=True):
        """
        Build EfficientNetB4; returns (model, tka.efficientnet.preprocess_input).
        Arguments are documented in _build_model.
        https://www.tensorflow.org/api_docs/python/tf/keras/applications/efficientnet/EfficientNetB4
        """
        return self._build_model(tka.efficientnet.EfficientNetB4,
            tka.efficientnet.preprocess_input, include_top, weights,
            trainable_base, input_tensor, input_shape, pooling, classes,
            classifier_activation)

    def EfficientNetB5Model(self, include_top=True, weights='imagenet',
        input_tensor=None, input_shape=None, pooling=None, classes=1000,
        classifier_activation='softmax', trainable_base=True):
        """
        Build EfficientNetB5; returns (model, tka.efficientnet.preprocess_input).
        Arguments are documented in _build_model.
        https://www.tensorflow.org/api_docs/python/tf/keras/applications/efficientnet/EfficientNetB5
        """
        return self._build_model(tka.efficientnet.EfficientNetB5,
            tka.efficientnet.preprocess_input, include_top, weights,
            trainable_base, input_tensor, input_shape, pooling, classes,
            classifier_activation)

    def EfficientNetB6Model(self, include_top=True, weights='imagenet',
        input_tensor=None, input_shape=None, pooling=None, classes=1000,
        classifier_activation='softmax', trainable_base=True):
        """
        Build EfficientNetB6; returns (model, tka.efficientnet.preprocess_input).
        Arguments are documented in _build_model.
        https://www.tensorflow.org/api_docs/python/tf/keras/applications/efficientnet/EfficientNetB6
        """
        return self._build_model(tka.efficientnet.EfficientNetB6,
            tka.efficientnet.preprocess_input, include_top, weights,
            trainable_base, input_tensor, input_shape, pooling, classes,
            classifier_activation)

    def EfficientNetB7Model(self, include_top=True, weights='imagenet',
        input_tensor=None, input_shape=None, pooling=None, classes=1000,
        classifier_activation='softmax', trainable_base=True):
        """
        Build EfficientNetB7; returns (model, tka.efficientnet.preprocess_input).
        Arguments are documented in _build_model.
        https://www.tensorflow.org/api_docs/python/tf/keras/applications/efficientnet/EfficientNetB7
        """
        return self._build_model(tka.efficientnet.EfficientNetB7,
            tka.efficientnet.preprocess_input, include_top, weights,
            trainable_base, input_tensor, input_shape, pooling, classes,
            classifier_activation)

