#!/usr/bin/python
# -*- encoding: utf-8 -*-

# For general purposes
import numpy as np
import random
import time

from sklearn.model_selection import train_test_split
from sklearn.metrics import euclidean_distances

# For model definition/training/evaluation and loss definition
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.layers import concatenate
from tensorflow.keras.optimizers import Adam
from tensorflow.python.keras import Model
from tensorflow.python.keras import Sequential
from tensorflow.python.keras.layers import Input
from tensorflow.python.keras.layers import Conv2D
from tensorflow.python.keras.layers import Dense
from tensorflow.python.keras.layers import Flatten
from tensorflow.python.keras.layers import MaxPooling2D
from tensorflow.python.keras.layers import Dropout
from tensorflow.python.keras.layers import Lambda
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops

# Fixed seed shared by Python's and NumPy's global random generators so that
# repeated runs draw the same pseudo-random numbers from both of them.
seed = 11502
np.random.seed(seed)
random.seed(seed)


def normalize_vector(features: np.ndarray) -> np.ndarray:
    """Min-max scale ``features`` to the range [0, 1].

    For a 2-D input every row is scaled independently using that row's own
    minimum and maximum. For a 1-D input the whole vector is scaled at once.

    Args:
        features: Array-like of numeric values.

    Returns:
        A float array with the same shape as ``features``. A row whose values
        are all equal is mapped to zeros instead of NaN.
    """
    # Work on the transpose so that reducing over axis 0 yields one
    # minimum/maximum per original row.
    transposed = np.asarray(features).T
    features_max = np.amax(transposed, axis=0)
    features_min = np.amin(transposed, axis=0)

    # A constant row has a zero range; dividing by 1 instead keeps the
    # result at 0 rather than producing 0/0 = NaN.
    value_range = features_max - features_min
    safe_range = np.where(value_range == 0, 1, value_range)
    return ((transposed - features_min) / safe_range).T


class Biometrics(object):
    """Verification (EER) and retrieval (recall@k) metrics for feature vectors.

    Pairwise Euclidean distances between two sets of features are split into
    genuine scores (same label) and impostor scores (different label). From
    those the False Match Rate / False Non-Match Rate curves and the Equal
    Error Rate are computed, together with recall@k for the requested ``k``.

    When only one set of features is given the set is compared against itself
    ("intra-session") and every sample's comparison with itself is ignored.
    """

    def __init__(
            self,
            features_1: np.ndarray,
            labels_1: np.ndarray,
            features_2: "np.ndarray | None" = None,
            labels_2: "np.ndarray | None" = None,
            recalls_k: "np.ndarray | list | tuple" = (1, 2, 4, 8),
            limit: int = 10000,
            resolution: int = 5000,
            normalize_features: bool = False):
        """Compute all metrics immediately.

        Args:
            features_1: 2-D array, one feature vector per row.
            labels_1: Labels of ``features_1`` (any shape that flattens to one
                label per row).
            features_2: Optional second set of features. If omitted,
                ``features_1`` is compared against itself.
            labels_2: Labels of ``features_2``; required when ``features_2``
                is given.
            recalls_k: Values of ``k`` for which recall@k is computed.
            limit: Maximum number of leading samples used from each set.
            resolution: Number of thresholds sampled for the FMR/FNMR curves.
            normalize_features: Whether to pass both feature sets through
                ``normalize_vector`` first.

        Raises:
            ValueError: If ``features_2`` is given without ``labels_2``, or if
                there are no genuine or no impostor comparisons.
        """
        self._intra_session = features_2 is None
        if self._intra_session:
            features_2, labels_2 = features_1, labels_1
        elif labels_2 is None:
            raise ValueError('labels_2 is required when features_2 is given')

        labels_1 = np.asarray(labels_1).reshape(-1, 1)
        labels_2 = np.asarray(labels_2).reshape(-1, 1)

        if normalize_features:
            features_1 = normalize_vector(features_1)
            features_2 = normalize_vector(features_2)

        # Cap the number of samples so the distance matrix stays tractable.
        min_features_1 = min(len(features_1), limit)
        min_features_2 = min(len(features_2), limit)
        labels_1 = labels_1[:min_features_1]
        labels_2 = labels_2[:min_features_2]
        distances = euclidean_distances(features_1[:min_features_1, :], features_2[:min_features_2, :])

        self._get_genuine_and_impostor_distances(distances, labels_1, labels_2)
        self._calculate_eer(resolution)
        # Rows are ranked against the second set, so its labels are the ones
        # that must be looked up for the retrieved neighbours.
        self._calculate_recalls_k(distances, labels_1, recalls_k, labels_2)

    def _get_genuine_and_impostor_distances(self, distances: np.ndarray, labels_1: np.ndarray, labels_2: np.ndarray) -> None:
        """Split ``distances`` into sorted genuine and impostor score arrays.

        ``distances[i, j]`` is treated as the score between sample ``i`` of
        ``labels_1`` and sample ``j`` of ``labels_2``. In intra-session mode
        the diagonal (a sample against itself) is skipped.

        Sets ``_genuine``, ``_impostor`` (both sorted ascending) and their
        sizes ``_g_size``, ``_i_size``.
        """
        row_labels = np.asarray(labels_1).reshape(-1, 1)
        col_labels = np.asarray(labels_2).reshape(1, -1)
        distances = np.asarray(distances)[:row_labels.shape[0], :col_labels.shape[1]]

        # Broadcasting a column against a row gives the full label-equality
        # matrix without a Python-level double loop.
        same_label = row_labels == col_labels
        comparable = np.ones(same_label.shape, dtype=bool)
        if self._intra_session:
            np.fill_diagonal(comparable, False)

        self._genuine = np.sort(distances[same_label & comparable])
        self._impostor = np.sort(distances[~same_label & comparable])
        self._g_size = int(self._genuine.size)
        self._i_size = int(self._impostor.size)

    def _calculate_eer(self, resolution: int = 5000) -> None:
        """Compute the FMR/FNMR curves and the Equal Error Rate.

        Thresholds are sampled evenly between the smallest genuine distance
        and the largest impostor distance. For each threshold ``t``:

        * FMR  = fraction of impostor distances ``<= t``
        * FNMR = fraction of genuine distances ``> t``

        The EER is the mean of FMR and FNMR at the threshold where their
        absolute difference is smallest. Sets ``_fmr``, ``_fnmr``, ``_eer``
        and ``_threshold``.

        Raises:
            ValueError: If there are no genuine or no impostor distances.
        """
        if self._g_size == 0 or self._i_size == 0:
            raise ValueError(
                'EER needs at least one genuine and one impostor comparison '
                '(got {} genuine, {} impostor)'.format(self._g_size, self._i_size))

        min_distance = np.amin(self._genuine)
        max_distance = np.amax(self._impostor)
        t = np.linspace(min_distance, max_distance, resolution)

        # Both score arrays are sorted, so a right-sided binary search returns
        # exactly the number of scores <= threshold, including the case where
        # every score is accepted.
        accepted_impostors = np.searchsorted(self._impostor, t, side='right')
        accepted_genuines = np.searchsorted(self._genuine, t, side='right')
        self._fmr = accepted_impostors / self._i_size
        self._fnmr = (self._g_size - accepted_genuines) / self._g_size

        # Equal Error Rate (EER)
        min_index = int(np.argmin(np.abs(self._fmr - self._fnmr)))
        self._eer = float((self._fmr[min_index] + self._fnmr[min_index]) / 2)
        self._threshold = t[min_index]

    def _compute_recall_at_k(self, distances, k, labels, dim, gallery_labels=None):
        """Return the fraction of the first ``dim`` rows with a hit in the top ``k``.

        A row counts as a hit when at least one of its ``k`` smallest-distance
        columns carries the same label as the row.

        Args:
            distances: 2-D score matrix (rows are queries, columns are ranked).
            k: Number of nearest columns inspected per row.
            labels: Labels of the rows.
            dim: Number of rows to evaluate.
            gallery_labels: Labels of the columns; defaults to ``labels``.
        """
        if gallery_labels is None:
            gallery_labels = labels
        labels = np.asarray(labels)
        gallery_labels = np.asarray(gallery_labels)

        num_correct = 0
        for i in range(dim):
            nearest = np.argsort(distances[i, :])[:k]
            if np.any(gallery_labels[nearest] == labels[i]):
                num_correct += 1
        return num_correct / dim

    def _calculate_recalls_k(self, distances: np.ndarray, labels: np.ndarray, recalls_k: "np.ndarray | list | tuple",
                             gallery_labels: "np.ndarray | None" = None) -> None:
        """Compute recall@k for every ``k`` in ``recalls_k`` and store it in ``_recalls``.

        Args:
            distances: 2-D score matrix; it is not modified.
            labels: Labels of the rows of ``distances``.
            recalls_k: Iterable of ``k`` values.
            gallery_labels: Labels of the columns; defaults to ``labels``.
        """
        # np.sqrt returns a new array, so masking below leaves the caller's
        # matrix untouched.
        d = np.sqrt(np.abs(distances))
        if self._intra_session:
            # A sample must not retrieve itself; across two different sets the
            # diagonal is an ordinary comparison and has to stay.
            np.fill_diagonal(d, np.inf)
        self._recalls = [
            self._compute_recall_at_k(d, k, labels, d.shape[0], gallery_labels)
            for k in recalls_k
        ]

    def get_eer(self) -> float:
        """Return the Equal Error Rate as a fraction in [0, 1]."""
        return self._eer

    def get_recalls_at_k(self) -> list:
        """Return the recall@k values, in the order of the requested ``k``."""
        return self._recalls


def pairwise_distance(feature, squared=False):
    """Compute the pairwise distance matrix with numerical stability.

    Uses the expansion ``|a - b|^2 = |a|^2 - 2 * <a, b> + |b|^2`` on the rows
    of ``feature``.

    Args:
        feature: 2-D floating-point tensor with one sample per row.
        squared: If True return squared Euclidean distances, otherwise their
            square root.

    Returns:
        A square tensor whose entry ``[i, j]`` is the (squared) distance
        between rows ``i`` and ``j``; the diagonal is forced to zero.
    """
    feature_t = array_ops.transpose(feature)
    pairwise_distances_squared = math_ops.add(
        math_ops.reduce_sum(math_ops.square(feature), axis=[1], keepdims=True),
        math_ops.reduce_sum(math_ops.square(feature_t), axis=[0], keepdims=True)
    ) - 2.0 * math_ops.matmul(feature, feature_t)

    # Rounding can push tiny distances below zero; clamp them.
    pairwise_distances_squared = math_ops.maximum(pairwise_distances_squared, 0.0)
    # Entries that are exactly zero after clamping.
    error_mask = math_ops.less_equal(pairwise_distances_squared, 0.0)

    # Cast masks to the tensor's own dtype (math_ops.to_float is deprecated
    # and always produced float32).
    dtype = pairwise_distances_squared.dtype
    if squared:
        pairwise_distances = pairwise_distances_squared
    else:
        # Add a tiny offset where the value is zero so the square root is never
        # evaluated at exactly 0 (presumably to keep its gradient finite).
        pairwise_distances = math_ops.sqrt(
            pairwise_distances_squared + math_ops.cast(error_mask, dtype) * 1e-16)

    # Put exact zeros back where the offset was added.
    pairwise_distances = math_ops.multiply(
        pairwise_distances, math_ops.cast(math_ops.logical_not(error_mask), dtype))

    # Explicitly zero the diagonal (distance of every sample to itself).
    num_data = array_ops.shape(feature)[0]
    mask_off_diagonals = array_ops.ones_like(pairwise_distances) - array_ops.diag(
        array_ops.ones([num_data], dtype=dtype))
    pairwise_distances = math_ops.multiply(pairwise_distances, mask_off_diagonals)
    return pairwise_distances


# noinspection DuplicatedCode
def calculate_genuine_and_impostor_distances(_embeddings, _labels):
    """Split the pairwise squared distances of a batch into genuine/impostor sets.

    Args:
        _embeddings: 2-D tensor with one embedding per row.
        _labels: Column tensor with one label per row of ``_embeddings``
            (it is compared against its own transpose).

    Returns:
        ``(genuines, impostors)``, each of shape ``(-1, 1)``. ``genuines``
        holds the distances between different samples that share a label;
        ``impostors`` holds the distances between samples with different
        labels.
    """
    pdist_matrix = pairwise_distance(_embeddings, squared=True)
    adjacency = math_ops.equal(_labels, array_ops.transpose(_labels))
    adjacency_not = math_ops.logical_not(adjacency)

    # Exclude each sample's distance to itself by masking the diagonal.
    # Sorting the (N, 1) tensor and dropping the first `shape[0]` rows does not
    # achieve this: tf.sort works along the last axis (size 1 here), and the
    # static batch size may be None, in which case nothing is dropped.
    num_samples = array_ops.shape(_embeddings)[0]
    is_self = tf.cast(tf.eye(num_samples), tf.bool)
    genuine_mask = math_ops.logical_and(adjacency, math_ops.logical_not(is_self))

    genuines = tf.reshape(pdist_matrix[genuine_mask], (-1, 1))
    impostors = tf.reshape(pdist_matrix[adjacency_not], (-1, 1))

    return genuines, impostors


# noinspection PyUnresolvedReferences,DuplicatedCode
def dloss(y_true, y_pred):
    """Loss equal to the reciprocal of the genuine/impostor decidability.

    The first column of ``y_true`` is cast to int32 and used as the labels;
    every column of ``y_pred`` after the first is used as the embedding.
    Decidability is ``|mean_i - mean_g| / sqrt(0.5 * (std_g^2 + std_i^2))``
    over the genuine (g) and impostor (i) distances of the batch. The returned
    value ``1 / decidability`` is clipped to ``[0, 10000]``.
    """
    embeddings = y_pred[:, 1:]
    labels = tf.cast(y_true[:, :1], dtype='int32')
    genuines, impostors = calculate_genuine_and_impostor_distances(embeddings, labels)

    # First and second order statistics of both score distributions.
    g_mean = tf.math.reduce_mean(genuines)
    g_std = tf.math.reduce_std(genuines)
    i_mean = tf.math.reduce_mean(impostors)
    i_std = tf.math.reduce_std(impostors)

    mean_gap = tf.math.abs(tf.math.subtract(i_mean, g_mean))
    pooled_std = tf.math.sqrt(0.5 * ((g_std ** 2) + (i_std ** 2)))
    decidability = tf.math.divide(mean_gap, pooled_std)

    # Larger decidability means better separation, so minimise its reciprocal.
    return tf.clip_by_value(t=1/decidability, clip_value_min=0, clip_value_max=10000)


def create_model(embedding_size=256):
    """Build the convolutional embedding network for 28x28x1 inputs.

    The network is three Conv2D (kernel 2, 'same' padding, ReLU) ->
    MaxPooling2D(2) -> Dropout(0.3) stages with 32, 64 and 64 filters,
    followed by Flatten, a linear Dense layer with ``embedding_size`` units
    and an L2-normalisation Lambda layer named 'embeddings'.
    """
    # First stage carries the input shape of the whole network.
    stack = [
        Conv2D(filters=32, kernel_size=2, padding='same', activation='relu', input_shape=(28, 28, 1)),
        MaxPooling2D(pool_size=2),
        Dropout(0.3),
    ]
    # Two further stages that differ from the first only in filter count.
    for n_filters in (64, 64):
        stack.append(Conv2D(filters=n_filters, kernel_size=2, padding='same', activation='relu'))
        stack.append(MaxPooling2D(pool_size=2))
        stack.append(Dropout(0.3))
    # Projection head: flatten, linear projection, unit-length embeddings.
    stack.append(Flatten())
    stack.append(Dense(embedding_size, activation=None))
    stack.append(Lambda(lambda x: tf.math.l2_normalize(x, axis=1), name='embeddings'))
    return Sequential(stack)


def load_mnist():
    """Load MNIST and return train/validation/test splits.

    Images get a trailing channel axis and are cast to float32. 20% of the
    original training set, stratified by label, is held out for validation.

    Returns:
        ``((x_train, y_train), (x_val, y_val), (x_test, y_test))``.
    """
    (_x_train, _y_train), (_x_test, _y_test) = keras.datasets.mnist.load_data()

    # Add the channel axis from each array's own shape, and cast before the
    # split so the validation images are float32 as well.
    _x_train = np.expand_dims(_x_train, axis=-1).astype('float32')
    _x_test = np.expand_dims(_x_test, axis=-1).astype('float32')

    _x_train, _x_val, _y_train, _y_val = train_test_split(_x_train, _y_train, test_size=0.2, stratify=_y_train)
    return (_x_train, _y_train), (_x_val, _y_val), (_x_test, _y_test)


def get_model_trained_for_verification(
        identification_model, layer_name=None):
    """Extract a model that outputs embeddings from a trained model.

    With ``layer_name`` the result maps ``identification_model``'s input to
    the output of that layer. Without it, the layers of
    ``identification_model.layers[2]`` are reused (assumed to be the nested
    embedding network -- TODO confirm the index): first by stacking them in a
    new Sequential, and if that raises, by wrapping its first layer's input
    and last layer's output in a Model.
    """
    if layer_name is not None:
        return Model(inputs=identification_model.input, outputs=identification_model.get_layer(layer_name).output)

    try:
        extracted = Sequential()
        # Re-adding the layer objects carries their trained weights along.
        for trained_layer in identification_model.layers[2].layers:
            extracted.add(trained_layer)
    except Exception:
        extracted = Model(inputs=identification_model.layers[2].layers[0].input,
                          outputs=identification_model.layers[2].layers[-1].output)
    return extracted


def convert_model_to_verification_mode(
        model,
        input_shape):
    """Wrap ``model`` so that its output is prefixed with the sample label.

    The returned model takes two inputs, 'input_image' (shape ``input_shape``)
    and 'input_label' (shape ``(1,)``), and outputs the label concatenated
    with the embedding that ``model`` produces for the image.
    """
    image_in = Input(shape=input_shape, name='input_image')
    label_in = Input(shape=(1,), name='input_label')

    # Label first, embedding after it.
    labelled_embedding = concatenate([label_in, model([image_in])])

    return Model(inputs=[image_in, label_in], outputs=labelled_embedding)


# noinspection DuplicatedCode
def protocol(_embedding_size,
             _embedding_layer,
             _batch_size=150,
             _epochs=100,
             _number_of_classes=10,
             _shuffle_in_training=True,
             _lr=1e-3,
             _train_portion=0.9,
             _patience=2):
    """Train the embedding network on MNIST with ``dloss`` and report metrics.

    Args:
        _embedding_size: Dimension of the embedding produced by the network
            and of the dummy training targets.
        _embedding_layer: Name of the layer whose output is the embedding.
        _batch_size: Batch size passed to ``fit``.
        _epochs: Number of epochs passed to ``fit``.
        _number_of_classes: Accepted for compatibility; not used here.
        _shuffle_in_training: ``shuffle`` flag passed to ``fit``.
        _lr: Learning rate of the Adam optimiser.
        _train_portion: Accepted for compatibility; not used here.
        _patience: Accepted for compatibility; not used here.

    Prints the training time, the EER and recall@{1, 2, 4, 8} on the test set.
    """

    _input_image_shape = (28, 28, 1)

    # The data, split between train, validation and test sets
    (x_train, y_train), (x_val, y_val), (x_test, y_test) = load_mnist()

    x_train = x_train.astype('float32')
    x_val = x_val.astype('float32')
    x_test = x_test.astype('float32')

    ##############################################################################################################################
    # -- Getting model for verification and preparing it for embedding training
    ##############################################################################################################################
    # The network must emit exactly _embedding_size values, otherwise the
    # dummy targets built below would not match the model output.
    model = create_model(embedding_size=_embedding_size)

    _base_model = Model(inputs=model.input, outputs=model.get_layer(_embedding_layer).output)

    model = convert_model_to_verification_mode(_base_model, _input_image_shape)

    ##############################################################################################################################
    # -- Preparing training session
    ##############################################################################################################################
    _opt = Adam(learning_rate=_lr)  # choose optimiser. RMS is good too!
    model.compile(loss=dloss, optimizer=_opt)

    ##############################################################################################################################
    # -- Preparing embeddings
    ##############################################################################################################################
    # Dummy targets with the same width as the model output (label + embedding).
    # Only column 0, the ground-truth label, is filled in; dloss reads the
    # labels from that column.
    _dummy_gt_train = np.zeros((len(x_train), _embedding_size + 1))
    _dummy_gt_train[:, 0] = y_train
    _dummy_gt_val = np.zeros((len(x_val), _embedding_size + 1))
    _dummy_gt_val[:, 0] = y_val

    ####################################################################################################################
    # -- Starting training
    ####################################################################################################################
    start = time.time()
    model.fit(
        x=[x_train, y_train],
        y=_dummy_gt_train,
        batch_size=_batch_size,
        shuffle=_shuffle_in_training,
        epochs=_epochs,
        validation_data=([x_val, y_val], _dummy_gt_val))
    stop = time.time()
    print('Time spent to train: {}s'.format(stop - start))

    ##############################################################################################################################
    # -- Evaluating model trained
    ##############################################################################################################################
    model = get_model_trained_for_verification(model)

    print('Evaluating model.......')
    recalls = (1, 2, 4, 8)
    embeddings = model.predict(x_test)
    biometric_metrics = Biometrics(features_1=embeddings, labels_1=y_test, recalls_k=recalls)
    recalls_after = biometric_metrics.get_recalls_at_k()
    print('         EER: {:.2f}%'.format(biometric_metrics.get_eer() * 100))
    for k, recall_after in zip(recalls, recalls_after):
        print(' Recall@K{:3d}: {:.2f}'.format(k, recall_after))


# Script entry point: run the full train-and-evaluate protocol with a
# 256-dimensional embedding read from the layer named 'embeddings'.
if __name__ == "__main__":
    protocol(_embedding_size=256, _embedding_layer='embeddings')
