import numpy as np
import tensorflow as tf
from gpflow.base import Parameter
from gpflow.kernels import Stationary, IsotropicStationary
from gpflow.utilities import positive, to_default_float
import tensorflow_probability as tfp

class SPKernel(IsotropicStationary):
    """Graph kernel combining path, node-feature and edge-feature covariances.

    The kernel value is the weighted sum ``alpha * Kp + beta * Kn + gamma * Ke``
    of three covariance components computed between pairs of graph objects.
    The node-feature term is only included when ``beta > 0`` and the
    edge-feature term only when ``gamma > 0``. When ``trainable_variance`` is
    True the sum is additionally mapped through ``variance * exp(.)``.

    Graph objects must expose ``N``, ``Ln``, ``S``, ``D``, ``P`` and
    ``edge_attr``. NOTE(review): ``N`` is presumably the node count; confirm
    the attribute layout against the graph class defined elsewhere.
    """

    def __init__(self,
                 variance: float = 1,
                 alpha: float = 0,
                 beta: float = 0,
                 gamma: float = 0,
                 lengthscales: float = 1,
                 kernel_type: str = "SP",
                 trainable_variance: bool = True,
                 trainable_alpha: bool = True,
                 trainable_beta: bool = True,
                 trainable_gamma: bool = False,
                 trainable_lengthscales: bool = False,
                 **kwargs):
        """
        Args:
            variance (float): output scale, applied as ``variance * exp(K)``
                when ``trainable_variance`` is True (exponential kernels, ESP).
            alpha (float): weight of the path covariance ``Kp``.
            beta (float): weight of the node-feature covariance ``Kn``.
            gamma (float): weight of the edge-feature covariance ``Ke``.
            lengthscales (float): defaults to 1; wrapped in a positive
                trainable Parameter only if ``trainable_lengthscales``.
            kernel_type (str): "SP" or "SSP"; selects how ``Kp`` is computed.
            trainable_variance (bool): if True, ``variance`` becomes a
                trainable Parameter AND the kernel output is exponentiated.
            trainable_alpha, trainable_beta, trainable_gamma (bool): if True,
                the corresponding weight becomes a trainable Parameter
                initialised at ``value + 1``; otherwise the raw value is kept.
            trainable_lengthscales (bool): see ``lengthscales``.
            **kwargs: only ``name`` and ``active_dims`` are accepted; they are
                forwarded to the gpflow base class.

        Raises:
            TypeError: if any other keyword argument is passed.
        """
        for kwarg in kwargs:
            if kwarg not in {"name", "active_dims"}:
                raise TypeError(f"Unknown keyword argument: {kwarg}")

        # Single base-class initialisation; it receives name/active_dims.
        super().__init__(**kwargs)
        self.variance = (self._soft_clipped(variance)
                         if trainable_variance else variance)
        self.lengthscales = (Parameter(lengthscales, transform=positive())
                             if trainable_lengthscales else lengthscales)
        # Trainable weights are offset by +1 before wrapping; presumably so the
        # default value 0 starts inside the SoftClip range [0.01, 100].
        self.alpha = self._soft_clipped(alpha + 1) if trainable_alpha else alpha
        self.beta = self._soft_clipped(beta + 1) if trainable_beta else beta
        self.gamma = self._soft_clipped(gamma + 1) if trainable_gamma else gamma
        self.kernel_type = kernel_type
        # The exponential output form is tied to whether variance is trainable.
        self.exp_option = trainable_variance
        self._validate_ard_active_dims(self.lengthscales)

    @staticmethod
    def _soft_clipped(value):
        """Wrap ``value`` in a Parameter softly clipped to [0.01, 100]."""
        return Parameter(value, transform=tfp.bijectors.SoftClip(
            to_default_float(0.01),
            to_default_float(100),
        ))

    def _extract_graph_feature(self, G):
        """Return the graph's ``S`` attribute as a NumPy array."""
        return np.array(G.S)

    def _get_V_Matrix(self, G_data):
        '''Return a feature matrix for the given graph dataset.

        Args:
            G_data: graph dataset, size n
        Returns:
            V_matrix: float64 tensor. NOTE(review): currently a placeholder
            that returns zeros of shape (n, 1), not the (n x Ln) features the
            name suggests.
        '''
        V_matrix = np.zeros(shape=(len(G_data), 1))
        return tf.convert_to_tensor(V_matrix, dtype=tf.float64)

    def calculate_path_covariance(self, G1, G2):
        """Path covariance between two graphs.

        For "SSP" this is ``sum_l D1[l] * D2[l]``. For "SP" it is
        ``sum_{l, l1, l2} P1[(l, l1, l2)] * P2[(l, l1, l2)]`` with
        ``l1 < G1.Ln`` and ``l2 < G2.Ln``. In both cases ``l`` runs over
        ``range(min(G1.N, G2.N))``. The sum is divided by the product of
        ``N * (N - 1) / 2`` for both graphs.

        Raises:
            NotImplementedError: if ``kernel_type`` is neither "SP" nor "SSP".
        """
        n_common = min(G1.N, G2.N)
        if self.kernel_type == "SSP":
            Kp = sum((G1.D[l] * G2.D[l] for l in range(n_common)), 0.)
        elif self.kernel_type == "SP":
            Kp = sum(
                (G1.P[(l, l1, l2)] * G2.P[(l, l1, l2)]
                 for l in range(n_common)
                 for l1 in range(G1.Ln)
                 for l2 in range(G2.Ln)),
                0.,
            )
        else:
            raise NotImplementedError("Graph kernel not supported.")

        # NOTE(review): the denominator is zero when either graph has N < 2.
        Kp /= (G1.N * (G1.N - 1) / 2) * (G2.N * (G2.N - 1) / 2)
        return Kp

    def calculate_node_feature_covariance(self, G1, G2):
        """Return ``sum_{i < G1.Ln} S1[i] * S2[i]`` divided by ``G1.N * G2.N``."""
        Kn = sum((G1.S[i] * G2.S[i] for i in range(G1.Ln)), 0.)
        return Kn / (G1.N * G2.N)

    def calculate_edge_feature_covariance(self, G1, G2):
        """Count matching, non-None ``edge_attr[u, v]`` entries of both graphs.

        ``u`` ranges over ``G1.N`` and ``v`` over ``G2.N``. The count is
        divided by ``G1.N * (G1.N - 1) / 2``. NOTE(review): normalisation uses
        only G1, and indices from one graph are used on the other; confirm this
        is intended when the graphs differ in size.
        """
        Ke = 0.
        for u in range(G1.N):
            for v in range(G2.N):
                attr1 = G1.edge_attr[u, v]
                if attr1 is None:
                    continue
                attr2 = G2.edge_attr[u, v]
                if attr2 is not None and attr1 == attr2:
                    Ke += 1
        return Ke / (G1.N * (G1.N - 1) / 2)

    def _combine_covariances(self, Kp, Kn, Ke):
        """Weighted sum ``alpha*Kp (+ beta*Kn if beta > 0) (+ gamma*Ke if gamma > 0)``."""
        combined = tf.cast(self.alpha, dtype=tf.float64) * Kp
        if self.beta > 0:
            combined += tf.cast(self.beta, dtype=tf.float64) * Kn
        if self.gamma > 0:
            combined += tf.cast(self.gamma, dtype=tf.float64) * Ke
        return combined

    def K(self, Kp, Kn, Ke):
        """Combine precomputed covariance matrices into the kernel matrix.

        NOTE(review): unlike gpflow's ``Kernel.K(X, X2=None)``, this takes the
        three component matrices returned by ``extract_kernel_matrix``.

        Returns:
            The weighted sum of the components, mapped through
            ``variance * exp(.)`` when ``exp_option`` is set.
        """
        K = self._combine_covariances(Kp, Kn, Ke)
        if self.exp_option:
            K = self.variance * tf.exp(K)
        return K

    def calculate_shift(self, G_data):
        """Return the midpoint of the min and max diagonal entries of the
        (non-exponentiated) weighted kernel over ``G_data``."""
        Kp, Kn, Ke = self.extract_kernel_matrix(G_data, full_cov=False)
        Kxx = np.diag(self._combine_covariances(Kp, Kn, Ke))
        shift = (np.min(Kxx) + np.max(Kxx)) / 2
        return shift

    def extract_kernel_matrix(self, G_data1, G_data2=None, full_cov=True):
        """Compute the path, node-feature and edge-feature covariance matrices.

        Args:
            G_data1: sized sequence of m graphs.
            G_data2: sized sequence of n graphs; defaults to ``G_data1``.
            full_cov: if False, only entries ``[i, i]`` for
                ``i < min(m, n)`` are computed; all others stay 0.

        Returns:
            Tuple ``(Kp, Kn, Ke)`` of float64 tensors of shape (m, n). ``Kn``
            is all zeros unless ``beta > 0``, and ``Ke`` unless ``gamma > 0``.
        """
        G_data2 = G_data1 if G_data2 is None else G_data2
        m, n = len(G_data1), len(G_data2)
        Kp = np.zeros(shape=(m, n))
        Kn = np.zeros(shape=(m, n))
        Ke = np.zeros(shape=(m, n))
        # Evaluated once: these may be tensor comparisons on Parameters.
        use_nodes = self.beta > 0
        use_edges = self.gamma > 0

        if full_cov:
            pairs = ((i, j, G1, G2)
                     for i, G1 in enumerate(G_data1)
                     for j, G2 in enumerate(G_data2))
        else:
            # Diagonal only: pair graphs index-by-index instead of scanning m*n.
            pairs = ((i, i, G1, G2)
                     for i, (G1, G2) in enumerate(zip(G_data1, G_data2)))

        for i, j, G1, G2 in pairs:
            Kp[i, j] = self.calculate_path_covariance(G1, G2)
            if use_nodes:
                Kn[i, j] = self.calculate_node_feature_covariance(G1, G2)
            if use_edges:
                Ke[i, j] = self.calculate_edge_feature_covariance(G1, G2)

        return tuple(tf.convert_to_tensor(M, dtype=tf.float64)
                     for M in (Kp, Kn, Ke))
