import tensorflow as tf
from tensorflow import keras
layers = keras.layers

from utils.utils import get_time, exec_time
from losses.losses import different_weights_loss
from metrics.metrics import weights_convergence_metric

from functools import partial

def get_model(args, nb_classes, nb_channels, norm_layer):
    """Build the network selected by ``args.model``.

    Args:
        args: Configuration object. The attributes read here are
            ``image_size``, ``nb_features``, ``fc_end``, ``add_layer_loss``
            and ``model``.
        nb_classes: Number of classes, forwarded to the model factory.
        nb_channels: Number of channels of the (square) input images.
        norm_layer: Forwarded to the factory as its ``normalization`` argument.

    Returns:
        Whatever the selected ``create_*`` factory returns.

    Raises:
        NotImplementedError: If ``args.model`` is not one of
            ``"standard_vgg32"``, ``"vgg16"``, ``"resnet50"`` or ``"efb0"``.
    """
    side = args.image_size

    # Every factory shares the same keyword interface, so the arguments are
    # assembled once and forwarded unchanged.
    factory_kwargs = dict(
        input_shape=(side, side, nb_channels),
        nb_classes=nb_classes,
        nb_features=args.nb_features,
        normalization=norm_layer,
        fc_end=args.fc_end,
        add_layer_loss=args.add_layer_loss,
    )

    requested = args.model
    if requested == "standard_vgg32":
        return create_VGG32(**factory_kwargs)
    if requested == "vgg16":
        return create_VGG16(**factory_kwargs)
    if requested == "resnet50":
        return create_ResNet50(**factory_kwargs)
    if requested == "efb0":
        return create_EfficientNetB0(**factory_kwargs)

    raise NotImplementedError(f"Model {args.model} not implemented")

def wrap_model(model, input_shape, nb_classes, nb_features, 
               normalization=None, fc_end=False, add_layer_loss=True):
    """Wrap a backbone with input preprocessing and a feature head.

    The returned model applies, in order: ``Rescaling(1/255)``, the optional
    ``normalization`` callable, the backbone ``model``, then one of two heads:

    * ``fc_end`` falsy: a bias-free 1x1 ``Conv2D`` with
      ``nb_features * nb_classes`` filters followed by global average pooling.
    * ``fc_end`` truthy: global average pooling followed by a bias-free
      ``Dense`` layer with ``nb_classes * nb_features`` units.

    Args:
        model: Callable backbone applied to the preprocessed input tensor.
        input_shape: Shape given to ``keras.Input`` (batch axis excluded).
        nb_classes: Number of classes.
        nb_features: Number of features per class.
        normalization: Optional callable applied right after rescaling.
        fc_end: Selects the dense head instead of the convolutional head.
        add_layer_loss: When true, and only for the convolutional head with
            ``nb_features > 1``, the result of ``different_weights_loss`` on
            the head kernel is registered through ``add_loss``.

    Returns:
        A ``keras.Model`` mapping the raw input to the feature vector.
    """
    inputs = keras.Input(shape=input_shape)

    # Rescaling lives inside the model because augmentation needs data in the
    # 0-255 range; this way test data and new data are rescaled and
    # normalized by the model itself.
    x = layers.Rescaling(1./255)(inputs)
    if normalization:
        x = normalization(x)

    x = model(x)

    # The convolutional head is kept in a variable so the optional loss can
    # be attached to it once the functional model exists.
    head_conv = None
    if not fc_end:
        head_conv = layers.Conv2D(filters=nb_features*nb_classes,
                                  kernel_size=1,
                                  use_bias=False,
                                  padding="same")
        x = head_conv(x)

    pooled = layers.GlobalAveragePooling2D(name="features_layer")(x)
    if fc_end:
        outputs = layers.Dense(nb_classes*nb_features, use_bias=False)(pooled)
    else:
        outputs = pooled

    wrapped = keras.Model(inputs=inputs, outputs=outputs)

    # No need to add a loss that maximizes the std when nb_features == 1,
    # since it is always 0 in that case.
    # NOTE(review): different_weights_loss is defined outside this file; what
    # it returns (tensor vs. callable) is not visible here.
    if head_conv is not None and nb_features > 1 and add_layer_loss:
        kernel = head_conv.trainable_variables[0]
        head_conv.add_loss(different_weights_loss(kernel, nb_classes, nb_features))

    return wrapped

def conv_block(inputs, 
               filters, 
               activation="leaky_relu",
               padding="same",
               bias=False,
               initializer="he_normal", # he_uniform
               dropout=0.2):
    """Stack ``Conv2D -> BatchNormalization -> activation`` units.

    An optional ``Dropout`` is applied to ``inputs`` first. One 3x3
    convolution is then created per entry of ``filters``; every convolution
    uses stride 1 except the last one, which uses stride 2.

    Args:
        inputs: Tensor fed to the block.
        filters: Sequence with the number of filters of each convolution.
        activation: ``"leaky_relu"`` selects ``LeakyReLU(0.2)``; any other
            value is passed to ``layers.Activation``.
        padding: Padding mode of every convolution.
        bias: Whether the convolutions use a bias term.
        initializer: Currently not read by this function; kernels are always
            initialized with ``RandomNormal(mean=0.0, stddev=0.05)``.
        dropout: Dropout rate applied before the convolutions; a falsy value
            skips the dropout layer.

    Returns:
        The output tensor of the last activation (or the possibly dropped-out
        input when ``filters`` is empty).
    """
    x = layers.Dropout(dropout)(inputs) if dropout else inputs

    for idx, nb_filters in enumerate(filters):
        # Only the last convolution of the block downsamples.
        is_last = idx == len(filters) - 1

        conv = layers.Conv2D(
            nb_filters, 3,
            padding=padding,
            strides=2 if is_last else 1,
            # A fresh initializer instance is created for each convolution.
            kernel_initializer=tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.05),
            use_bias=bias)
        x = conv(x)
        x = layers.BatchNormalization()(x)

        if activation == "leaky_relu":
            act = layers.LeakyReLU(0.2)
        else:
            act = layers.Activation(activation)
        x = act(x)

    return x

def create_VGG32(input_shape, nb_classes, nb_features, 
                 normalization=None, fc_end=False, add_layer_loss=False):
    """Build the small VGG-style network made of three ``conv_block`` stages.

    The model applies ``Rescaling(1/255)``, the optional ``normalization``
    callable, three ``conv_block`` stages, and then one of two heads that
    both produce ``nb_classes * nb_features`` outputs:

    * ``fc_end`` falsy: a bias-free 3x3 ``Conv2D`` named ``"last_conv"``
      followed by global average pooling.
    * ``fc_end`` truthy: global average pooling followed by a bias-free
      ``Dense`` layer.

    Args:
        input_shape: Shape given to ``keras.Input`` (batch axis excluded).
        nb_classes: Number of classes.
        nb_features: Number of features per class.
        normalization: Optional callable applied right after rescaling.
        fc_end: Selects the dense head instead of the convolutional head.
        add_layer_loss: When true, and only for the convolutional head with
            ``nb_features > 1``, the result of ``different_weights_loss`` on
            the head kernel is registered through ``add_loss``. Note that
            the default here is ``False``.

    Returns:
        A ``keras.Model`` mapping the raw input to the feature vector.
    """
    architecture = [
        [64,64,128],
        [128,128,128],
        [128,128,128]
    ]
    dropout = [0.2, 0.2, 0.2]

    # Both heads must emit the same width so the output size does not depend
    # on `fc_end` (the dense head previously emitted only `nb_classes`).
    embedding_dim = nb_classes * nb_features

    inputs = keras.Input(shape=input_shape)
    
    # Rescaling lives inside the model because augmentation needs data in the
    # 0-255 range; this way test data and new data are rescaled and
    # normalized by the model itself.
    out = keras.layers.Rescaling(1./255)(inputs)

    if normalization:
        out = normalization(out)

    for filters, d in zip(architecture, dropout):
        out = conv_block(out, filters=filters, dropout=d)

    # Kept in a variable so the optional loss can be attached after the
    # functional model has been created.
    features_layer = None
    if not fc_end:
        features_layer = layers.Conv2D(filters=embedding_dim, 
                            kernel_size=3, 
                            use_bias=False,
                            padding="same",
                            name="last_conv")
        out = features_layer(out)
    
    outputs = layers.GlobalAveragePooling2D(name="features_layer")(out)

    if fc_end:
        outputs = layers.Dense(embedding_dim, use_bias=False)(outputs)
    
    model = keras.Model(inputs=inputs, outputs=outputs)
    
    # No need to add a loss that maximizes the std when nb_features == 1,
    # since it is always 0 in that case.
    # NOTE(review): different_weights_loss is defined outside this file; what
    # it returns (tensor vs. callable) is not visible here.
    if features_layer is not None and nb_features > 1 and add_layer_loss:
        features_layer.add_loss(different_weights_loss(features_layer.trainable_variables[0], 
                                                    nb_classes,
                                                    nb_features))
    
    return model

def create_VGG16(input_shape, nb_classes, nb_features,
                 normalization=None, fc_end=False, add_layer_loss=True):
    """Build a VGG16 backbone and wrap it with ``wrap_model``.

    The backbone is created without its classification top
    (``include_top=False``), without pretrained weights (``weights=None``)
    and without a pooling layer (``pooling=None``).

    Args:
        input_shape: Input shape used for both the backbone and the wrapper.
        nb_classes: Forwarded to ``wrap_model``.
        nb_features: Forwarded to ``wrap_model``.
        normalization: Forwarded to ``wrap_model``.
        fc_end: Forwarded to ``wrap_model``.
        add_layer_loss: Forwarded to ``wrap_model``.

    Returns:
        The model returned by ``wrap_model``.
    """
    backbone_options = {
        "include_top": False,
        "weights": None,
        "input_tensor": None,
        "input_shape": input_shape,
        "pooling": None,
    }
    backbone = tf.keras.applications.vgg16.VGG16(**backbone_options)

    head_options = {
        "normalization": normalization,
        "fc_end": fc_end,
        "add_layer_loss": add_layer_loss,
    }
    return wrap_model(backbone, input_shape, nb_classes, nb_features,
                      **head_options)

def create_ResNet50(input_shape, nb_classes, nb_features,
                    normalization=None, fc_end=False, add_layer_loss=True):
    """Build a ResNet50 backbone and wrap it with ``wrap_model``.

    The backbone is created without its classification top
    (``include_top=False``), without pretrained weights (``weights=None``)
    and without a pooling layer (``pooling=None``).

    Args:
        input_shape: Input shape used for both the backbone and the wrapper.
        nb_classes: Forwarded to ``wrap_model``.
        nb_features: Forwarded to ``wrap_model``.
        normalization: Forwarded to ``wrap_model``.
        fc_end: Forwarded to ``wrap_model``.
        add_layer_loss: Forwarded to ``wrap_model``.

    Returns:
        The model returned by ``wrap_model``.
    """
    backbone_options = {
        "include_top": False,
        "weights": None,
        "input_tensor": None,
        "input_shape": input_shape,
        "pooling": None,
    }
    backbone = tf.keras.applications.resnet50.ResNet50(**backbone_options)

    head_options = {
        "normalization": normalization,
        "fc_end": fc_end,
        "add_layer_loss": add_layer_loss,
    }
    return wrap_model(backbone, input_shape, nb_classes, nb_features,
                      **head_options)
    
def create_EfficientNetB0(input_shape, nb_classes, nb_features,
                            normalization=None, fc_end=False, add_layer_loss=True):
    """Build an EfficientNetB0 backbone and wrap it with ``wrap_model``.

    The backbone is created without its classification top
    (``include_top=False``), without pretrained weights (``weights=None``)
    and without a pooling layer (``pooling=None``).

    NOTE(review): the Keras documentation for EfficientNet states that the
    application model rescales its inputs itself, while ``wrap_model`` also
    applies ``Rescaling(1/255)`` before the backbone - confirm that this
    double rescaling is intended.

    Args:
        input_shape: Input shape used for both the backbone and the wrapper.
        nb_classes: Forwarded to ``wrap_model``.
        nb_features: Forwarded to ``wrap_model``.
        normalization: Forwarded to ``wrap_model``.
        fc_end: Forwarded to ``wrap_model``.
        add_layer_loss: Forwarded to ``wrap_model``.

    Returns:
        The model returned by ``wrap_model``.
    """
    backbone_options = {
        "include_top": False,
        "weights": None,
        "input_tensor": None,
        "input_shape": input_shape,
        "pooling": None,
    }
    backbone = tf.keras.applications.efficientnet.EfficientNetB0(**backbone_options)

    head_options = {
        "normalization": normalization,
        "fc_end": fc_end,
        "add_layer_loss": add_layer_loss,
    }
    return wrap_model(backbone, input_shape, nb_classes, nb_features,
                      **head_options)


def predict_distance(model, ds, anchors):
    """Return the distance between each model prediction and each anchor.

    The predictions obtained with ``model.predict(ds)`` get a new axis
    inserted at position 1, ``anchors`` is subtracted from them (relying on
    broadcasting) and ``tf.norm`` with its default order is taken along
    axis 2.

    Args:
        model: Object exposing ``predict``.
        ds: Dataset passed as-is to ``model.predict``.
        anchors: Anchor points subtracted from the predictions.
            Assumes predictions are (n_samples, dim) and anchors are
            (n_anchors, dim) - TODO confirm against the callers.

    Returns:
        The tensor of norms; (n_samples, n_anchors) under the shape
        assumption above.
    """
    predictions = model.predict(ds)

    # The extra axis lets every prediction broadcast against every anchor.
    expanded = tf.expand_dims(predictions, 1)
    offsets = expanded - anchors

    return tf.norm(offsets, axis=2)
