import tensorflow as tf
import numpy as np
import torch


class ModelAdapter:
    """Expose a TF2 ``tf.keras`` model through a PyTorch-tensor interface.

    The public methods accept PyTorch tensors, run the wrapped TensorFlow
    model (and, where needed, a ``tf.GradientTape``), and hand the results
    back as PyTorch tensors. Image batches are taken in channels-first
    order on the PyTorch side and are transposed to channels-last before
    being fed to the model when ``self.data_format == 'channels_last'``.
    """

    def __init__(self, model, num_classes=10):
        """
        Please note that model should be tf.keras model without activation function 'softmax'

        Args:
            model: tf.keras model whose output is the pre-softmax logits
            num_classes: (int) Number of classes, used as the one-hot depth.
                The DLR loss indexes the 3rd largest logit and the targeted
                DLR loss the 4th largest, so those losses need at least
                3 and 4 classes respectively.
        """
        self.num_classes = num_classes
        self.tf_model = model
        self.data_format = self.__check_channel_ordering()

    def __tf_to_pt(self, tf_tensor):
        """ Private function
        Convert tf tensor to pt format

        Args:
            tf_tensor: (tf_tensor) TF tensor

        Returns:
            pt_tensor: (pt_tensor) Pytorch tensor, placed on the GPU when
                CUDA is available and left on the CPU otherwise
        """

        pt_tensor = torch.from_numpy(tf_tensor.numpy())

        # Only move to the GPU when there is one; an unconditional .cuda()
        # raises on CPU-only machines.
        if torch.cuda.is_available():
            pt_tensor = pt_tensor.cuda()

        return pt_tensor

    def __pt_to_tf_input(self, x):
        """ Private function
        Convert a Pytorch input batch to a float32 TF tensor in the model's layout

        Args:
            x: (pytorch_tensor) Input data in channels-first order

        Returns:
            x_tf: (tf_tensor) Input data, transposed to channels-last when
                the model uses 'channels_last'
        """

        # detach() is a no-op for tensors that do not track gradients and
        # avoids the error .numpy() raises for tensors that do.
        x_tf = tf.convert_to_tensor(x.detach().cpu().numpy(), dtype=tf.float32)
        if self.data_format == 'channels_last':
            x_tf = tf.transpose(x_tf, perm=[0, 2, 3, 1])

        return x_tf

    @staticmethod
    def __pt_to_tf_label(y):
        """ Private function
        Convert Pytorch labels to an int32 TF tensor

        Args:
            y: (pytorch_tensor) Labels

        Returns:
            y_tf: (tf_tensor) Labels as int32
        """

        return tf.convert_to_tensor(y.detach().cpu().numpy(), dtype=tf.int32)

    def __grad_to_pt(self, grad):
        """ Private function
        Convert an input-shaped TF gradient to a channels-first Pytorch tensor

        Args:
            grad: (tf_tensor) Gradient with the same layout as the model input

        Returns:
            grad_pt: (pt_tensor) Gradient in channels-first order
        """

        if self.data_format == 'channels_last':
            grad = tf.transpose(grad, perm=[0, 3, 1, 2])

        return self.__tf_to_pt(grad)

    def set_data_format(self, data_format):
        """
        Set data_format manually

        Args:
            data_format: A string, whose value should be either 'channels_last' or 'channels_first'

        Raises:
            ValueError: If data_format is neither of the two accepted values
        """

        # Membership test: comparing with `!= ... or != ...` is true for
        # every value and would reject valid formats as well.
        if data_format not in ('channels_last', 'channels_first'):
            raise ValueError("data_format should be either 'channels_last' or 'channels_first'")

        self.data_format = data_format

    def __check_channel_ordering(self):
        """ Private function
        Determine TF model's channel ordering based on model's information.
        Default ordering is 'channels_last' in TF.
        However, 'channels_first' is used in Pytorch.

        The first Conv2D found among the model's top-level layers decides.
        Without one, the ordering is guessed from a 4-D input shape by
        looking for a channel count of 3 or 1, checking the last axis
        before axis 1.

        Returns:
            data_format: 'channels_last', 'channels_first', or None when the
                ordering could not be determined
        """

        # Get the ordering of the dimensions in data from TF model
        for layer in self.tf_model.layers:
            if isinstance(layer, tf.keras.layers.Conv2D):
                print("[INFO] set data_format = '{:s}'".format(layer.data_format))
                return layer.data_format

        # Guess the ordering of the dimensions in data by input dimensions which should be 4-D tensor
        print("[WARNING] Can not find Conv2D layer")
        input_shape = self.tf_model.input_shape

        # Only a 4-D shape can be probed at index 3; anything else is
        # reported as an unknown case instead of raising IndexError.
        if len(input_shape) == 4:
            # Assume a colour (3) or gray (1) image; the probing order is
            # [3]==3, [3]==1, [1]==3, [1]==1.
            for axis, data_format in ((3, 'channels_last'), (1, 'channels_first')):
                for channels in (3, 1):
                    if input_shape[axis] == channels:
                        print("[INFO] Because detecting input_shape[{:d}] == {:d}, set data_format = '{:s}'".format(
                            axis, channels, data_format))
                        return data_format

        print("[ERROR] Unknow case")
        return None

    # Common function which may be called in tf.function #
    def __get_logits(self, x_input):
        """ Private function
        Get model's pre-softmax output in inference mode

        Args:
            x_input: (tf_tensor) Input data

        Returns:
            logits: (tf_tensor) Logits
        """

        return self.tf_model(x_input, training=False)

    def __get_xent(self, logits, y_input):
        """ Private function
        Get cross entropy loss

        Args:
            logits: (tf_tensor) Logits.
            y_input: (tf_tensor) Label.

        Returns:
            xent: (tf_tensor) Cross entropy
        """

        return tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=y_input)

    def __get_dlr(self, logit, y_input):
        """ Private function
        Get DLR loss

        Computes -(z_y - z_i) / (z_p1 - z_p3 + 1e-12), where z_y is the
        logit of the given label, z_p1 and z_p3 are the largest and third
        largest logits, and z_i is the second largest logit when z_y is the
        maximum and the largest logit otherwise.

        Args:
            logit: (tf_tensor) Logits
            y_input: (tf_tensor) Input label

        Returns:
            loss: (tf_tensor) DLR loss
        """

        # Logits sorted in ascending order along the class axis
        logit_sort = tf.sort(logit, axis=1)

        # Logit of the given label
        y_onehot = tf.one_hot(y_input, self.num_classes, dtype=tf.float32)
        logit_y = tf.reduce_sum(y_onehot * logit, axis=1)

        # z_i: runner-up when the label's logit is the maximum, else the maximum
        logit_pred = tf.reduce_max(logit, axis=1)
        is_top = tf.equal(logit_pred, logit_y)
        z_i = tf.where(is_top, logit_sort[:, -2], logit_sort[:, -1])

        # The small constant keeps the division defined when z_p1 == z_p3
        z_p1 = logit_sort[:, -1]
        z_p3 = logit_sort[:, -3]

        return - (logit_y - z_i) / (z_p1 - z_p3 + 1e-12)

    def __get_dlr_target(self, logits, y_input, y_target):
        """ Private function
        Get targeted version of DLR loss

        Computes -(z_y - z_t) / (z_p1 - 0.5 * z_p3 - 0.5 * z_p4 + 1e-12),
        where z_y and z_t are the logits of the input and targeted labels
        and z_pk is the k-th largest logit.

        Args:
            logits: (tf_tensor) Logits
            y_input: (tf_tensor) Input label
            y_target: (tf_tensor) Input targeted label

        Returns:
            loss: (tf_tensor) Targeted DLR loss
        """

        logits_sorted = tf.sort(logits, axis=1)
        y_onehot = tf.one_hot(y_input, self.num_classes)
        y_target_onehot = tf.one_hot(y_target, self.num_classes)

        z_y = tf.reduce_sum(logits * y_onehot, axis=1)
        z_t = tf.reduce_sum(logits * y_target_onehot, axis=1)
        denominator = (logits_sorted[:, -1] - .5 * logits_sorted[:, -3]
                       - .5 * logits_sorted[:, -4] + 1e-12)

        return -(z_y - z_t) / denominator

    # function called by public API directly #
    @tf.function
    @tf.autograph.experimental.do_not_convert
    def __get_jacobian(self, x_input):
        """ Private function
        Get Jacobian

        Args:
            x_input: (tf_tensor) Input data

        Returns:
            logits: (tf_tensor) Logits
            jacobian: (tf_tensor) Per-sample Jacobian of the logits with respect to the input
        """

        # Only the input is watched; model variables are not differentiated.
        with tf.GradientTape(watch_accessed_variables=False) as g:
            g.watch(x_input)
            logits = self.__get_logits(x_input)

        jacobian = g.batch_jacobian(logits, x_input)

        return logits, jacobian

    @tf.function
    @tf.autograph.experimental.do_not_convert
    def __get_grad_xent(self, x_input, y_input):
        """ Private function
        Get gradient of cross entropy

        Args:
            x_input: (tf_tensor) Input data
            y_input: (tf_tensor) Input label

        Returns:
            logits: (tf_tensor) Logits
            xent: (tf_tensor) Cross entropy
            grad_xent: (tf_tensor) Gradient of cross entropy
        """

        with tf.GradientTape(watch_accessed_variables=False) as g:
            g.watch(x_input)
            logits = self.__get_logits(x_input)
            xent = self.__get_xent(logits, y_input)

        grad_xent = g.gradient(xent, x_input)

        return logits, xent, grad_xent

    @tf.function
    @tf.autograph.experimental.do_not_convert
    def __get_grad_diff_logits_target(self, x, la, la_target):
        """ Private function
        Get difference of logits and corresponding gradient

        Args:
            x: (tf_tensor) Input data
            la: (tf_tensor) Input label
            la_target: (tf_tensor) Input targeted label

        Returns:
            difflogits: (tf_tensor) Targeted-label logit minus input-label logit
            grad_diff: (tf_tensor) Gradient of difference of logits
        """

        la_mask = tf.one_hot(la, self.num_classes)
        la_target_mask = tf.one_hot(la_target, self.num_classes)

        with tf.GradientTape(watch_accessed_variables=False) as g:
            g.watch(x)
            logits = self.__get_logits(x)
            difflogits = tf.reduce_sum((la_target_mask - la_mask) * logits, axis=1)

        grad_diff = g.gradient(difflogits, x)

        return difflogits, grad_diff

    @tf.function
    @tf.autograph.experimental.do_not_convert
    def __get_grad_dlr(self, x_input, y_input):
        """ Private function
        Get gradient of DLR loss

        Args:
            x_input: (tf_tensor) Input data
            y_input: (tf_tensor) Input label

        Returns:
            logits: (tf_tensor) Logits
            val_dlr: (tf_tensor) DLR loss
            grad_dlr: (tf_tensor) Gradient of DLR loss
        """

        with tf.GradientTape(watch_accessed_variables=False) as g:
            g.watch(x_input)
            logits = self.__get_logits(x_input)
            val_dlr = self.__get_dlr(logits, y_input)

        grad_dlr = g.gradient(val_dlr, x_input)

        return logits, val_dlr, grad_dlr

    @tf.function
    @tf.autograph.experimental.do_not_convert
    def __get_grad_dlr_target(self, x_input, y_input, y_target):
        """ Private function
        Get gradient of targeted DLR loss

        Args:
            x_input: (tf_tensor) Input data
            y_input: (tf_tensor) Input label
            y_target: (tf_tensor) Input targeted label

        Returns:
            logits: (tf_tensor) Logits
            dlr_target: (tf_tensor) Targeted DLR loss
            grad_target: (tf_tensor) Gradient of targeted DLR loss
        """

        with tf.GradientTape(watch_accessed_variables=False) as g:
            g.watch(x_input)
            logits = self.__get_logits(x_input)
            dlr_target = self.__get_dlr_target(logits, y_input, y_target)

        grad_target = g.gradient(dlr_target, x_input)

        return logits, dlr_target, grad_target

    # Public API #
    def predict(self, x):
        """
        Get model's pre-softmax output in inference mode

        Args:
            x: (pytorch_tensor) Input data

        Returns:
            y: (pytorch_tensor) Pre-softmax output
        """

        x_tf = self.__pt_to_tf_input(x)

        return self.__tf_to_pt(self.__get_logits(x_tf))

    def grad_logits(self, x):
        """
        Get logits and gradient of logits

        Args:
            x: (pytorch_tensor) Input data

        Returns:
            logits: (pytorch_tensor) Logits
            g2: (pytorch_tensor) Jacobian
        """

        x_tf = self.__pt_to_tf_input(x)

        logits, g2 = self.__get_jacobian(x_tf)

        # The Jacobian carries an extra class axis after the batch axis, so
        # the channel axis to move to the front of the image axes is axis 4.
        if self.data_format == 'channels_last':
            g2 = tf.transpose(g2, perm=[0, 1, 4, 2, 3])

        return self.__tf_to_pt(logits), self.__tf_to_pt(g2)

    def get_logits_loss_grad_xent(self, x, y):
        """
        Get gradient of cross entropy

        Args:
            x: (pytorch_tensor) Input data
            y: (pytorch_tensor) Input label

        Returns:
            logits_val: (pytorch_tensor) Logits
            loss_indiv_val: (pytorch_tensor) Cross entropy
            grad_val: (pytorch_tensor) Gradient of cross entropy
        """

        x_tf = self.__pt_to_tf_input(x)
        y_tf = self.__pt_to_tf_label(y)

        logits_val, loss_indiv_val, grad_val = self.__get_grad_xent(x_tf, y_tf)

        return (self.__tf_to_pt(logits_val),
                self.__tf_to_pt(loss_indiv_val),
                self.__grad_to_pt(grad_val))

    def set_target_class(self, y, y_target):
        """
        Do nothing; both arguments are ignored.

        NOTE(review): presumably kept so callers can use the same interface
        as other model adapters -- confirm against the caller.
        """
        pass

    def get_grad_diff_logits_target(self, x, y, y_target):
        """
        Get difference of logits and corresponding gradient

        Args:
            x: (pytorch_tensor) Input data
            y: (pytorch_tensor) Input label
            y_target: (pytorch_tensor) Input targeted label

        Returns:
            difflogits: (pytorch_tensor) Difference of logits
            g2: (pytorch_tensor) Gradient of difference of logits
        """

        la = self.__pt_to_tf_label(y)
        la_target = self.__pt_to_tf_label(y_target)
        x_tf = self.__pt_to_tf_input(x)

        difflogits, g2 = self.__get_grad_diff_logits_target(x_tf, la, la_target)

        return self.__tf_to_pt(difflogits), self.__grad_to_pt(g2)

    def get_logits_loss_grad_dlr(self, x, y):
        """
        Get gradient of DLR loss

        Args:
            x: (pytorch_tensor) Input data
            y: (pytorch_tensor) Input label

        Returns:
            logits_val: (pytorch_tensor) Logits
            loss_indiv_val: (pytorch_tensor) DLR loss
            grad_val: (pytorch_tensor) Gradient of DLR loss
        """

        x_tf = self.__pt_to_tf_input(x)
        y_tf = self.__pt_to_tf_label(y)

        logits_val, loss_indiv_val, grad_val = self.__get_grad_dlr(x_tf, y_tf)

        return (self.__tf_to_pt(logits_val),
                self.__tf_to_pt(loss_indiv_val),
                self.__grad_to_pt(grad_val))

    def get_logits_loss_grad_target(self, x, y, y_target):
        """
        Get gradient of targeted DLR loss

        Args:
            x: (pytorch_tensor) Input data
            y: (pytorch_tensor) Input label
            y_target: (pytorch_tensor) Input targeted label

        Returns:
            logits_val: (pytorch_tensor) Logits
            loss_indiv_val: (pytorch_tensor) Targeted DLR loss
            grad_val: (pytorch_tensor) Gradient of targeted DLR loss
        """

        x_tf = self.__pt_to_tf_input(x)
        y_tf = self.__pt_to_tf_label(y)
        y_targ = self.__pt_to_tf_label(y_target)

        logits_val, loss_indiv_val, grad_val = self.__get_grad_dlr_target(x_tf, y_tf, y_targ)

        return (self.__tf_to_pt(logits_val),
                self.__tf_to_pt(loss_indiv_val),
                self.__grad_to_pt(grad_val))