import tensorflow as tf
import numpy as np
import scipy.stats
# import tensorflow_probability as tfp
# from tensorflow_probability import distributions as tfd

# Small jitter used for numerical stability (e.g. added to matrix diagonals
# before they are inverted).
EPS = 0.0001
# Name of the likelihood link. The helper functions below take their own
# ``mode`` argument (default "probit"); any other value selects the logistic
# sigmoid there. NOTE(review): no live code in this section reads MODE.
MODE = "probit"

def N(x):
    """Standard normal probability density, evaluated element-wise.

    N(x) = exp(-x**2 / 2) / sqrt(2 * pi)

    Args:
        x: tensor or array-like of floating values, any shape.

    Returns:
        Tensor of the same shape as ``x``; its floating dtype follows ``x``
        (Python floats become float32).
    """
    # A plain Python float (np.pi is a Python float) adapts to x's dtype, so
    # float64 inputs work; a tf constant here would be pinned to float32.
    inv_sqrt_2pi = (2.0 * np.pi) ** -0.5
    return inv_sqrt_2pi * tf.exp(-0.5 * x**2)
    
# def probit(x):
#     return tf.sigmoid(x * tf.sqrt(8/np.pi))


def probit_(x):
    """Host-side (NumPy) standard normal CDF, returned as float32.

    Args:
        x: scalar or array of values.

    Returns:
        Phi(x) with the same shape as ``x``, cast to ``np.float32``.
    """
    standard_normal = scipy.stats.norm(loc=0, scale=1)
    cdf_values = standard_normal.cdf(x)
    return cdf_values.astype(np.float32)

@tf.function(input_signature=[tf.TensorSpec([None, 1], tf.float32)])
def probit(x):
    """Standard normal CDF Phi(x), applied element-wise.

    Computed in-graph as Phi(x) = 0.5 * erfc(-x / sqrt(2)). This keeps the op
    differentiable, serializable with the graph, and free of Python callbacks
    (the NumPy/SciPy version is still available as ``probit_``). erfc is
    used instead of 1 + erf to keep relative precision in the lower tail,
    where Phi(x) is tiny and the likelihood helpers divide by it.

    Args:
        x: float32 tensor of shape (n, 1), as enforced by the input signature.

    Returns:
        float32 tensor with the same shape as ``x``.
    """
    return 0.5 * tf.math.erfc(-x / (2.0 ** 0.5))


def p(f, y, mode="probit"):
    """Likelihood p(y | f) for labels y (in {-1, +1}) and latent values f.

    Evaluates link(f * y), where link is the standard normal CDF ``probit``
    when ``mode == 'probit'`` and the logistic sigmoid for any other mode.
    """
    link = probit if mode == 'probit' else tf.sigmoid
    return link(f * y)

def grad_log_p(f, y, mode="probit"):
    """First derivative of log p(y | f) with respect to f.

    probit:   y * N(f) / Phi(y * f)  (N is even, so N(f) == N(y * f) for y = +/-1)
    logistic: t - sigmoid(f), where t = (y + 1) / 2 maps {-1, +1} to {0, 1}
    """
    if mode != 'probit':
        targets = (y + 1.) / 2.
        return targets - p(f, 1., mode)
    return y * N(f) / p(f, y, mode)

def grad_grad_log_p(f, y, mode="probit"):
    """Second derivative of log p(y | f) with respect to f (element-wise).

    probit:   -N(f)**2 / Phi(y f)**2 - y f N(f) / Phi(y f)
    logistic: -sigmoid(f) * (1 - sigmoid(f))
    """
    if mode == 'probit':
        density = N(f)
        cdf = p(f, y, mode)
        return -density**2 / cdf**2 - y * f * density / cdf
    sig = p(f, 1., mode)
    return -sig * (1 - sig)
    
class TrainHelper:
    """Train and query a GP classifier whose kernel is built from inducing points.

    ``model`` is called as ``model((x, I))`` with ``I`` an identity matrix over
    the inducing points, and returns ``(i_points, K_ii, K_ix, K_xx, x)``.
    ``regressor`` maps ``i_points`` to values whose projection
    ``K_xi K_ii^{-1} regressor(i_points)`` is used as the prior mean of the
    latent function. Training minimizes a per-point Laplace-approximation
    objective; labels are assumed to be in {-1, +1}.
    """

    def __init__(self, model, regressor, lr=1e-3, inv_iterations=50):
        self.model = model
        # Maps inducing feature points to their ``labels`` (see class docstring).
        self.regressor = regressor
        # Number of inducing points, read from the static shape of the model's
        # second input.
        self.n_points = model.inputs[1].shape[-1]
        # Number of Newton steps used to locate the Laplace mode.
        self.inv_iterations = inv_iterations

        # Identity matrix fed as the model's second input: a workaround for
        # using inducing points (presumably the model maps it to learned
        # inducing locations -- TODO confirm against the model definition).
        self.dummy = tf.constant(np.eye(self.n_points, dtype=np.float32))

        # ``learning_rate`` is the supported keyword; ``lr`` is deprecated and
        # newer Keras versions warn about it or reject it outright.
        self.opt = tf.keras.optimizers.Adam(learning_rate=lr)

    def iterative_inverse(self, A, niter=4, eps=0.02):
        """Approximate A^{-1} with ``niter`` Newton-Schulz iterations.

        Iterates V <- V (2I - A V) from V_0 = I / (||A||_F + eps). For symmetric
        positive-definite A this start point makes the iteration converge;
        for other matrices convergence is not guaranteed.

        Args:
            A: square float32 matrix.
            niter: number of iterations.
            eps: added to the norm to avoid dividing by zero.

        Returns:
            Approximate inverse, same shape as ``A``.
        """
        n = tf.shape(A)[0]
        I = tf.eye(n)
        V = I / (tf.linalg.norm(A) + eps)
        for _ in tf.range(niter):
            V = tf.matmul(V, 2 * I - tf.matmul(A, V))
        return V

    def iterative_f(self, K, y, max_iter=20):
        """Locate the Laplace-approximation mode of the latent f by Newton steps.

        Each step computes (cf. Rasmussen & Williams, GPML eq. 3.18)
            W  = diag(-d^2/df^2 log p(y | f))
            f <- K (I + W K + EPS*I)^{-1} (W f + d/df log p(y | f))
        using the likelihood helpers' default (probit) mode.

        Args:
            K: (n, n) prior covariance of f at the batch points.
            y: (n, 1) float32 labels in {-1, +1}.
            max_iter: number of Newton steps.

        Returns:
            (W_, f): ``W_`` is the (n, n) diagonal matrix of -grad grad log p
            evaluated at the iterate *before* the final update (all zeros when
            ``max_iter`` is 0); ``f`` is the final (n, 1) iterate.
        """
        f = tf.zeros_like(y)
        W = tf.zeros_like(y)
        I = tf.eye(tf.shape(y)[0])
        # Defined before the loop so a value exists even when the loop runs
        # zero times, and so the tf.range loop (a while_loop under tf.function)
        # has fixed-shape loop variables.
        W_ = tf.zeros_like(I)
        for _ in tf.range(max_iter):
            # Negative Hessian of the log-likelihood, as a diagonal matrix
            # ((n, n) * (n, 1) broadcasting keeps only the diagonal of I).
            W = -grad_grad_log_p(f, y)
            W_ = I * W
            # Solve the linear system instead of forming an explicit inverse:
            # cheaper and better conditioned. An eigen-/SVD-based solve with
            # eigenvalue clipping to EPS would be more robust still.
            A = I + W_ @ K + EPS * I
            f = K @ tf.linalg.solve(A, W_ @ f + grad_log_p(f, y))
        return W_, f

    def train_step(self, x, y):
        """Run one optimization step on a batch and return the loss.

        Args:
            x: batch of training inputs, passed to the model unchanged.
            y: labels in {-1, +1}, shape (bs,) or (bs, 1). Cast to float32
                because ``probit``'s input signature requires float32.

        Returns:
            The objective as a numpy array of shape (1, 1).
        """
        y = tf.cast(y, tf.float32)
        if len(y.shape) == 1:
            y = y[:, None]
        return self._train_step(x, y).numpy()

    def predict(self, x):
        """Predict latent moments at ``x`` with a probit moment-matching correction.

        With ``mu`` the projected mean and ``sigma2`` the diagonal of the
        predictive covariance from ``_predict``, and z = mu / sqrt(1 + sigma2),
        returns
            mu     + sigma2 N(z) / (Phi(z) sqrt(1 + sigma2))
            sigma2 - sigma2^2 N(z) / ((1 + sigma2) Phi(z)) * (z + N(z) / Phi(z))
        which are the moments of the normalized Phi(f) N(f | mu, sigma2)
        (EP-style moment matching with y = +1, cf. GPML section 3.6).
        NOTE(review): confirm these, rather than Phi(z), are the intended outputs.

        Args:
            x: array-like batch of inputs; converted to float32.

        Returns:
            (mean, variance) as flat numpy arrays.
        """
        x = np.asarray(x, dtype=np.float32)
        y_pred, K_pred = self._predict(x)

        mu = tf.reshape(y_pred, [-1, 1])
        sigma2 = tf.reshape(tf.linalg.diag_part(K_pred), [-1, 1])

        # Evaluate the shared terms once.
        scale = tf.sqrt(1 + sigma2)
        z = mu / scale
        pdf_z = N(z)
        cdf_z = probit(z)

        mu_shift = sigma2 * pdf_z / (cdf_z * scale)
        sigma2_shift = -sigma2**2 * pdf_z / ((1 + sigma2) * cdf_z) * (z + pdf_z / cdf_z)

        new_mu = mu + mu_shift
        new_sigma2 = sigma2 + sigma2_shift
        return new_mu.numpy().ravel(), new_sigma2.numpy().ravel()

    @tf.function
    def _predict(self, x):
        """Project the inducing-point model onto ``x``.

        Returns:
            (y_pred, K_pred): ``y_pred`` = squeeze(K_xi K_ii^{-1} regressor(i_points))
            and ``K_pred`` = K_xx - K_xi K_ii^{-1} K_ix (presumably (n, n) for
            n query points -- depends on the model's output shapes).
        """
        i_points, K_ii, K_ix, K_xx, x = self.model((x, self.dummy), training=False)
        # Exact (EPS-jittered) inverse at prediction time. tf.eye needs an
        # integer size, so take it from the dynamic shape.
        K_ii_inv = tf.linalg.inv(K_ii + tf.eye(tf.shape(K_ii)[0]) * EPS)
        K_xi = tf.transpose(K_ix)
        P = tf.matmul(K_xi, K_ii_inv)
        y_pred = tf.squeeze(tf.matmul(P, self.regressor(i_points, training=False)))
        K_pred = K_xx - tf.matmul(P, K_ix)
        return y_pred, K_pred

    @tf.function
    def _train_step(self, x, y):
        """One optimizer step on the per-point negative Laplace objective.

        With K the batch covariance conditioned on the inducing points,
        a = K_xi K_ii^{-1} regressor(i_points), and W, f_hat from
        ``iterative_f`` held constant (stop_gradient), minimizes
            (log|W + K^{-1}| + log|K| + (f_hat - a)^T K^{-1} (f_hat - a)) / (2 bs)
        i.e. -log q(y | X) / bs (GPML eq. 3.32) without the log p(y | f_hat)
        term, which has no gradient here because f_hat is stopped.

        Returns:
            The (1, 1) objective tensor.
        """
        with tf.GradientTape() as tape:
            # Batch size, taken from the input batch before ``x`` is rebound.
            den = tf.cast(tf.shape(x)[0], tf.float32)
            # i: inducing points; x: training points; y: training labels.
            i_points, K_ii, K_ix, K_xx, x = self.model((x, self.dummy))
            K_xi = tf.transpose(K_ix)

            r_i_points = self.regressor(i_points)

            K_ii_inv = tf.linalg.inv(K_ii + EPS * tf.eye(tf.shape(K_ii)[0]))
            part_K = K_xi @ K_ii_inv

            # Covariance of f at the batch, conditioned on the inducing points.
            K = K_xx - part_K @ K_ix
            # Dynamic size so the graph does not rely on a static batch dim.
            I_bs = tf.eye(tf.shape(K)[0])
            K_inv = tf.linalg.inv(K + EPS * I_bs)

            W, f_hat = self.iterative_f(K, y, max_iter=self.inv_iterations)
            W = tf.stop_gradient(W)
            f_hat = tf.stop_gradient(f_hat)
            # Prior mean of f at the batch, projected from the inducing points.
            a = part_K @ r_i_points
            f_diff_a = f_hat - a

            # slogdet instead of log(det(.) + EPS): the determinant of a
            # (bs, bs) float32 matrix easily under/overflows, which leaves the
            # log term constant (zero gradient) or infinite. The signs are
            # dropped; both matrices are expected to be positive definite.
            _, logdet_w_kinv = tf.linalg.slogdet(W + K_inv + EPS * I_bs)
            _, logdet_k = tf.linalg.slogdet(K + EPS * I_bs)
            ll_first = -0.5 / den * (logdet_w_kinv + logdet_k)
            ll_second = -0.5 / den * tf.transpose(f_diff_a) @ K_inv @ f_diff_a
            objective = -(ll_first + ll_second)
        tw = self.model.trainable_weights + self.regressor.trainable_weights
        grads = tape.gradient(objective, tw)
        self.opt.apply_gradients(zip(grads, tw))
        return objective
