#!/usr/bin/env python3
"""
Created on Mar. 24th, 2023

@author: Anonymous
"""
import tensorflow as tf
import tensorflow.keras as K
from keras import regularizers
# local dep
# When executed directly as a script, prepend a path five directories up to
# `sys.path` so project-level imports resolve. NOTE(review): `os.pardir` paths
# are resolved relative to the current working directory, not this file's
# location — presumably meant to reach the project root; confirm.
if __name__ == "__main__":
    import os, sys
    sys.path.insert(0, os.path.join(os.pardir, os.pardir, os.pardir, os.pardir, os.pardir))

# Public API of this module (names exported by `from ... import *`).
__all__ = [
    "LossLayer",
]

class LossLayer(K.layers.Layer):
    """
    `LossLayer` layer used to calculate contrastive loss.

    The layer projects two inputs `Z` and `Y` with separate projection heads
    into a shared `d_contra`-dimensional space, L2-normalizes the projections,
    and computes a symmetric contrastive loss between them according to
    `loss_mode`.
    """

    # Supported values of the `data_mode` / `loss_mode` constructor arguments.
    _DATA_MODES = ("point", "sequence")
    _LOSS_MODES = ("clip", "clip_orig", "unicl")

    def __init__(self, d_contra, data_mode, loss_mode, **kwargs):
        """
        Initialize `LossLayer` object.

        Args:
            d_contra: int - The dimension of contrastive space after projection layer.
            data_mode: str - The mode of data format, one of "point" or "sequence".
            loss_mode: str - The mode of loss calculation, one of "clip", "clip_orig" or "unicl".
            kwargs: The arguments related to initialize `tf.keras.layers.Layer`-style object.

        Returns:
            None

        Raises:
            ValueError: If `data_mode` or `loss_mode` is not a supported value.
        """
        # First call super class init function to set up `K.layers.Layer`
        # style model and inherit its functionality.
        super().__init__(**kwargs)

        # Validate with explicit raises: `assert` statements are stripped
        # when Python runs with `-O`, which would silently accept bad modes.
        if data_mode not in self._DATA_MODES:
            raise ValueError("ERROR: Unknown data mode {} in LossLayer.".format(data_mode))
        if loss_mode not in self._LOSS_MODES:
            raise ValueError("ERROR: Unknown loss mode {} in LossLayer.".format(loss_mode))
        self.d_contra = d_contra
        self.data_mode = data_mode
        self.loss_mode = loss_mode

    def get_config(self):
        """
        Return the layer configuration so the layer can be serialized/re-created.

        Returns:
            config: dict - Base `Layer` config extended with the constructor arguments.
        """
        config = super().get_config()
        config.update({
            "d_contra": self.d_contra,
            "data_mode": self.data_mode,
            "loss_mode": self.loss_mode,
        })
        return config

    # network funcs
    def build(self, input_shape):
        """
        Build the network on the first call of `call`.

        Args:
            input_shape: tuple - The shape of input data.

        Returns:
            None
        """
        # Initialize temperature variables according to `loss_mode`.
        if self.loss_mode == "clip":
            # Fixed temperature: similarities are divided by `tau`.
            self.tau = tf.Variable(0.15, trainable=False, name="tau")
        elif self.loss_mode in ("clip_orig", "unicl"):
            # Similarities are multiplied by `exp(t)`; with t = 0.0 and
            # trainable=False the scale is fixed at 1. NOTE(review): the
            # original CLIP learns `t` (initialized to log(1/0.07)); confirm
            # that a fixed unit scale is intended here.
            self.t = tf.Variable(0.0, trainable=False, name="t")
        # Initialize projection layers for `Z` & `Y`.
        # proj_[z,y] - (batch_size, n_features) -> (batch_size, d_contra) in "point" mode.
        self.proj_z = self._make_projection("proj_z")
        self.proj_y = self._make_projection("proj_y")
        # Build super to set up `K.layers.Layer`-style model and inherit its network.
        super().build(input_shape)

    def _make_projection(self, name):
        """Create one projection head: `Dense(d_contra, gelu)`, plus `Flatten` in "sequence" mode."""
        proj = K.models.Sequential(name=name)
        proj.add(K.layers.Dense(
            # Modified `Dense` layer parameters.
            self.d_contra, activation="gelu", kernel_initializer="he_uniform",
            kernel_regularizer=K.regularizers.l2(l2=0.01),
            # Default `Dense` layer parameters.
            use_bias=True, bias_initializer="zeros", bias_regularizer=None,
            activity_regularizer=None, kernel_constraint=None, bias_constraint=None
        ))
        if self.data_mode == "sequence":
            # `Dense` acts on the last axis; flattening afterwards collapses any
            # remaining non-batch axes into a single embedding vector
            # (presumably seq_len * d_contra for sequence inputs — confirm).
            proj.add(K.layers.Flatten(data_format="channels_last"))
        return proj

    @staticmethod
    def _symmetric_cross_entropy(logits, labels):
        """
        Average the column-wise (axis=0) and row-wise (axis=1) softmax cross-entropies.

        Args:
            logits: (batch_size, batch_size) - Pairwise `z`x`y` logits.
            labels: (batch_size, batch_size) - Target matrix with the same layout.

        Returns:
            loss: scalar tf.Tensor - Mean of the two directional losses.
        """
        loss_z = tf.nn.softmax_cross_entropy_with_logits(labels=labels, logits=logits, axis=0)
        loss_y = tf.nn.softmax_cross_entropy_with_logits(labels=labels, logits=logits, axis=1)
        return (tf.reduce_mean(loss_z) + tf.reduce_mean(loss_y)) / 2

    def call(self, inputs):
        """
        Forward layers in `LossLayer` to get the final result.

        Args:
            inputs: ((Z, Y), (label_z, label_y)) - The projected inputs and their labels.
                `label_[z,y]` are only used by the "unicl" loss mode.

        Returns:
            loss: scalar tf.Tensor - The symmetric contrastive loss.
            loss_matrix: (batch_size, batch_size) - For "clip", `exp(cos_sim / tau)`
                (un-normalized probabilities); for "clip_orig" / "unicl", the logits
                `cos_sim * exp(t)`.
        """
        # [Z,Y] - (batch_size, n_features) in "point" mode, label - (batch_size, n_labels)
        (Z, Y), (label_z, label_y) = inputs
        # Project and L2-normalize; `l2_normalize` guards against division by
        # zero for an all-zero projection (plain `x / ||x||` would give NaN).
        # emb_[z,y] - (batch_size, d_contra)
        emb_z = tf.math.l2_normalize(self.proj_z(Z), axis=1)
        emb_y = tf.math.l2_normalize(self.proj_y(Y), axis=1)
        # Cosine similarity of every (z_i, y_j) pair.
        # sim - (batch_size, batch_size)
        sim = tf.matmul(emb_z, emb_y, transpose_b=True)
        # Use the dynamic batch size: the static `shape[0]` is `None` when the
        # layer is traced with an unknown batch dimension, which breaks `tf.eye`.
        batch_size = tf.shape(sim)[0]
        if self.loss_mode == "clip":
            logits = sim / self.tau
            labels = tf.eye(batch_size, dtype=logits.dtype)
            # Softmax cross-entropy with identity targets equals
            # `logsumexp(logits) - diag(logits)` along each axis, computed stably.
            loss = self._symmetric_cross_entropy(logits, labels)
            loss_matrix = tf.exp(logits)
        elif self.loss_mode == "clip_orig":
            loss_matrix = sim * tf.exp(self.t)
            labels = tf.eye(batch_size, dtype=loss_matrix.dtype)
            loss = self._symmetric_cross_entropy(loss_matrix, labels)
        elif self.loss_mode == "unicl":
            loss_matrix = sim * tf.exp(self.t)
            # Targets: 1 where `z_i` and `y_j` share a label (for one-hot labels).
            # NOTE(review): targets are not normalized per row/column, so a sample
            # with k positives contributes the sum of k per-positive terms; UniCL
            # averages over positives instead — confirm which is intended.
            labels = tf.matmul(label_z, label_y, transpose_b=True)
            loss = self._symmetric_cross_entropy(loss_matrix, labels)
        # Return the final `loss` & `loss_matrix`.
        return loss, loss_matrix

if __name__ == "__main__":
    # Quick smoke test of `LossLayer` on random "point"-mode data.

    # Initialize macros.
    batch_size = 16
    n_features_y = 1024
    n_features_z = 1024
    n_labels = 15
    d_contra = 256
    data_mode = "point"
    loss_mode = "clip_orig"

    # Instantiate LossLayer; `d_contra`, `data_mode` and `loss_mode` are all required.
    ll_inst = LossLayer(d_contra, data_mode, loss_mode)
    # Initialize input data, including `Z` & `Y`.
    Z = tf.random.normal((batch_size, n_features_z), dtype=tf.float32)
    Y = tf.random.normal((batch_size, n_features_y), dtype=tf.float32)
    # One-hot labels; wrap the indices with `% n_labels` so no sample gets an
    # all-zero label row when `batch_size > n_labels`.
    label_ids = tf.range(batch_size) % n_labels
    label_z = tf.one_hot(label_ids, n_labels, dtype=tf.float32)
    label_y = tf.one_hot(label_ids, n_labels, dtype=tf.float32)
    # Forward layers in `ll_inst`.
    loss, prob_matrix = ll_inst(((Z, Y), (label_z, label_y)))
    print("loss:", float(loss))
    print("prob_matrix shape:", prob_matrix.shape)

