import tensorflow as tf
import numpy as np

# Small positive constant used for numerical stabilisation, e.g. as a
# diagonal jitter added to kernel matrices before they are inverted/solved.
EPS: float = 1.0e-4
    
class TrainHelper:
    """Fit and query an inducing-point model under a Gaussian likelihood.

    ``model`` is called as ``model((x, dummy))`` and must return the tuple
    ``(i_points, K_ii, K_ix, K_xx)``: the inducing points, the kernel matrix
    among the inducing points, the inducing-point/input cross-kernel and the
    kernel matrix among the inputs. The matrix products below imply the
    shapes ``K_ii: (n_i, n_i)``, ``K_ix: (n_i, n_x)`` and ``K_xx: (n_x, n_x)``.

    ``regressor`` maps the inducing points to their "labels"; predictions at
    ``x`` interpolate those labels through ``P = K_xi (K_ii + EPS*I)^{-1}``.
    """

    def __init__(self, model, regressor, lr=1e-3, inv_iterations=50):
        """
        Args:
            model: Keras model whose second input has ``n_points`` columns and
                which returns ``(i_points, K_ii, K_ix, K_xx)``.
            regressor: Keras model mapping inducing points to their labels.
            lr: Adam learning rate.
            inv_iterations: number of Newton-Schulz iterations intended for
                ``iterative_inverse`` (the training step itself uses an exact
                linear solve).
        """
        self.model = model
        # The regressor maps the inducing feature points to their ``labels``.
        self.regressor = regressor
        self.n_points = model.inputs[1].shape[-1]
        self.inv_iterations = inv_iterations

        # Workaround for using inducing points: an n_points x n_points
        # identity matrix is fed as the model's second input.
        self.dummy = tf.constant(np.eye(self.n_points, dtype=np.float32))

        # ``learning_rate`` is the supported keyword; the legacy ``lr`` alias
        # is rejected by recent Keras optimizers.
        self.opt = tf.keras.optimizers.Adam(learning_rate=lr)

    def iterative_inverse(self, A, niter=4, eps=0.02):
        """Approximate ``A^{-1}`` with the Newton-Schulz iteration.

        Starts from ``V0 = I / (||A||_F + eps)`` and repeats
        ``V <- V (2I - A V)`` ``niter`` times. For symmetric positive-definite
        ``A`` this start keeps the spectral radius of ``I - A V0`` below one,
        so the iteration converges; for other matrices it may diverge.

        Args:
            A: square matrix (anything convertible to a tensor).
            niter: number of iterations.
            eps: added to the norm so the initial scale is finite for A = 0.

        Returns:
            The approximate inverse, with the same dtype as ``A``.
        """
        A = tf.convert_to_tensor(A)
        n = tf.shape(A)[0]
        I = tf.eye(n, dtype=A.dtype)
        V = I / (tf.linalg.norm(A) + eps)
        for _ in tf.range(niter):
            V = tf.matmul(V, 2 * I - tf.matmul(A, V))
        return V

    def train_step(self, x, y):
        """Run one optimisation step on a batch and return the loss.

        ``x`` and ``y`` are converted to float32 (matching ``predict``); a
        1-D ``y`` is treated as a single output column.

        Returns:
            The objective as a NumPy scalar.
        """
        x = np.asarray(x, dtype=np.float32)
        y = np.asarray(y, dtype=np.float32)
        if y.ndim == 1:
            y = y[:, None]
        return self._train_step(x, y).numpy()

    def predict(self, x):
        """Return ``(y_pred, K_pred)`` as NumPy arrays for inputs ``x``.

        ``y_pred`` is the interpolated mean and ``K_pred`` the conditional
        covariance ``K_xx - P K_ix`` of the inputs given the inducing points.
        """
        x = np.asarray(x, dtype=np.float32)
        y_pred, K_pred = self._predict(x)
        return y_pred.numpy(), K_pred.numpy()

    @staticmethod
    def _projection(K_ii, K_ix):
        """Return ``P = K_xi (K_ii + EPS*I)^{-1}`` where ``K_xi = K_ix^T``.

        Computed with a linear solve instead of an explicit inverse:
        ``P^T = (K_ii + EPS*I)^{-T} K_ix`` (``adjoint=True`` solves with the
        transposed matrix, so no symmetry of ``K_ii`` is assumed).
        ``iterative_inverse`` is an approximate alternative to this solve.
        """
        # Integer size taken from the matrix itself: tf.eye requires integer
        # row counts, and the dtype must match K_ii.
        jitter = EPS * tf.eye(tf.shape(K_ii)[0], dtype=K_ii.dtype)
        return tf.transpose(
            tf.linalg.solve(K_ii + jitter, K_ix, adjoint=True))

    @tf.function
    def _predict(self, x):
        """Graph-compiled core of ``predict``; returns tensors."""
        i_points, K_ii, K_ix, K_xx = self.model((x, self.dummy), training=False)
        P = self._projection(K_ii, K_ix)
        y_pred = tf.matmul(P, self.regressor(i_points, training=False))
        K_pred = K_xx - tf.matmul(P, K_ix)
        return y_pred, K_pred

    @tf.function
    def _train_step(self, x, y):
        """Take one Adam step on the Gaussian negative log-likelihood.

        Args:
            x: training inputs for this batch.
            y: 2-D training labels, one column per output (``train_step``
                adds the column axis to 1-D labels).

        Returns:
            The scalar objective, evaluated before the parameter update.

        The objective is the multivariate-normal NLL of ``y`` with mean
        ``y_pred = P @ regressor(i_points)`` and covariance
        ``K = K_xx - P K_ix + EPS*I``, divided by the batch size ``bs``::

            log|K| / (2 bs) + mean_j(r_j^T K^{-1} r_j) / (2 bs) + log(2 pi) / 2

        where ``r_j`` is the j-th column of ``y - y_pred``.
        """
        with tf.GradientTape() as tape:
            # i: inducing points, x: training points, y: training labels
            i_points, K_ii, K_ix, K_xx = self.model((x, self.dummy))
            y_i_points = self.regressor(i_points)

            bs = tf.shape(x)[0]
            two_bs = 2.0 * tf.cast(bs, K_xx.dtype)

            P = self._projection(K_ii, K_ix)
            y_pred = tf.matmul(P, y_i_points)
            K_x_given_i = (K_xx - tf.matmul(P, K_ix)
                           + EPS * tf.eye(bs, dtype=K_xx.dtype))

            # slogdet stays finite where det() under/overflows for larger
            # batches (log(det + EPS) would collapse to a constant with zero
            # gradient, or go NaN for a slightly negative det).
            _, log_det = tf.linalg.slogdet(K_x_given_i)
            t0 = log_det / two_bs

            # One quadratic form r_j^T K^{-1} r_j per output column, i.e. the
            # diagonal of diff^T K^{-1} diff; the off-diagonal cross-output
            # terms are not part of the likelihood. A solve avoids forming
            # the explicit inverse.
            diff = y - y_pred
            quad = tf.reduce_sum(
                diff * tf.linalg.solve(K_x_given_i, diff), axis=0)
            t1 = tf.reduce_mean(quad) / two_bs

            objective = t0 + t1 + 0.5 * float(np.log(2.0 * np.pi))

        tw = self.model.trainable_weights + self.regressor.trainable_weights
        grads = tape.gradient(objective, tw)
        self.opt.apply_gradients(zip(grads, tw))
        return objective
