import jax.numpy as jnp
import timeit
import jax
from typing import Callable
from jax.scipy.linalg import cholesky, solve_triangular,solve

def sqeuclidean_distance(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
    """Return the squared Euclidean distance ``sum((x - y) ** 2)``.

    Parameters
    ----------
    x, y : arrays of broadcast-compatible shape (typically two vectors).

    Returns
    -------
    0-d array holding the sum over all elements of ``(x - y) ** 2``.
    """
    diff = x - y
    return jnp.sum(diff ** 2)


def distmat(func: Callable, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
    """Evaluate ``func`` on every pair of rows of ``x`` and ``y``.

    Parameters
    ----------
    func : callable ``func(x_row, y_row)`` returning an array (often a scalar).
    x : array whose leading axis is mapped over (n rows).
    y : array whose leading axis is mapped over (m rows).

    Returns
    -------
    Array ``out`` with ``out[i, j] = func(x[i], y[j])``; its leading two
    axes have sizes (n, m), followed by the shape of ``func``'s output.
    """
    # Inner vmap: hold the x row fixed, map over rows of y.
    # Outer vmap: map over rows of x, pass y through unbatched.
    pairwise = jax.vmap(jax.vmap(func, in_axes=(None, 0)), in_axes=(0, None))
    return pairwise(x, y)
  


def pdist_squareform(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
    """Squared Euclidean distance matrix between the rows of ``x`` and ``y``.

    Returns an array ``D`` of shape (n, m) with
    ``D[i, j] = sum((x[i] - y[j]) ** 2)`` for ``x`` of shape (n, d) and
    ``y`` of shape (m, d).

    Notes
    -----
    This is equivalent to the SciPy command::

        from scipy.spatial.distance import cdist
        dists = cdist(x, y, metric='sqeuclidean')

    and, only when ``x`` and ``y`` are the same array ``X``, to::

        from scipy.spatial.distance import pdist, squareform
        dists = squareform(pdist(X, metric='sqeuclidean'))
    """
    return distmat(sqeuclidean_distance, x, y)
  

def rbf(x: jnp.ndarray, y: jnp.ndarray, lengthscale: float = 1.0) -> jnp.ndarray:
    """Squared-exponential (RBF) kernel matrix between the rows of ``x`` and ``y``.

    ``K[i, j] = exp(-0.5 * ||x[i] - y[j]||^2 / lengthscale^2)``

    Parameters
    ----------
    x : array of shape (n, d).
    y : array of shape (m, d).
    lengthscale : kernel lengthscale. The default 1.0 reproduces the
        unit-lengthscale kernel this function always computed.

    Returns
    -------
    Array of shape (n, m) with values in (0, 1].
    """
    return jnp.exp(-0.5 * pdist_squareform(x, y) / lengthscale ** 2)

def is_equal(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
    """Pairwise closeness indicator between the rows of ``x`` and ``y``.

    Returns an array ``E`` of shape (n, m) with ``E[i, j] = 1.0`` when
    ``jnp.allclose(x[i], y[j])`` (default tolerances) and ``0.0`` otherwise.
    """
    def _close(row_x, row_y):
        # Multiplying the boolean result by 1. yields a float indicator.
        return jnp.allclose(row_x, row_y) * 1.

    return distmat(_close, x, y)


def constrainer(a, b):
    """Build a map from the real line onto the open interval (a, b).

    The returned function computes ``a + (b - a) * logistic(x)``, so 0 maps
    to the midpoint and large |x| approaches the interval ends.
    """
    def _squash(x):
        unit = jax.lax.logistic(x)  # value in (0, 1)
        return a + (b - a) * unit

    return _squash


def unconstrainer(a, b):
    """Build the inverse of ``constrainer(a, b)``: map (a, b) onto the real line.

    The returned function computes ``log((y - a) / (b - y))``, i.e. the
    logit of ``(y - a) / (b - a)``.
    """
    def _unsquash(y):
        odds = (y - a) / (b - y)
        return jax.lax.log(odds)

    return _unsquash

def update_cholesky(K_XX, L_XX, K_XX_inv, K_XZ, K_ZZ):
    """Extend an upper-triangular Cholesky factor with a block of new points.

    Given ``L_XX`` with ``L_XX.T @ L_XX == K_XX``, return the upper-triangular
    factor ``L`` of the joint matrix::

        K = [[K_XX,   K_XZ],
             [K_XZ.T, K_ZZ]]

    such that ``L.T @ L == K``, via the block update::

        S11 = L_XX
        S12 = L_XX^{-T} @ K_XZ
        S22 = upper_chol(K_ZZ - S12.T @ S12)

    Parameters
    ----------
    K_XX : (n, n) matrix for the existing points. Not needed by the update
        (the factor ``L_XX`` suffices); kept for backward compatibility.
    L_XX : (n, n) upper-triangular factor with ``L_XX.T @ L_XX == K_XX``.
    K_XX_inv : unused; kept for backward compatibility.
    K_XZ : (n, m) cross block between existing and new points.
    K_ZZ : (m, m) block for the new points.

    Returns
    -------
    (n + m, n + m) upper-triangular ``L`` with ``L.T @ L == K``.
    If the Schur complement is not numerically positive definite, JAX's
    Cholesky returns NaNs rather than raising.
    """
    # Solve L_XX^T S12 = K_XZ with a triangular solve (O(n^2 m)). This reuses
    # the existing factor instead of refactorising K_XX (O(n^3)).
    S12 = solve_triangular(L_XX, K_XZ, trans="T", lower=False)
    # Schur complement of K_XX in K; its upper factor is the new corner block.
    schur = K_ZZ - S12.T @ S12
    S22 = jnp.linalg.cholesky(schur).T
    S21 = jnp.zeros_like(S12).T
    return jnp.block([[L_XX, S12], [S21, S22]])

  
class CodeTimer:
    """Context manager that prints the wall-clock time spent inside its block.

    Usage::

        with CodeTimer("fit") as t:
            ...
        # prints "Code block 'fit' took: <seconds> s"; t.took holds the value

    NOTE: JAX dispatches work asynchronously, so call ``.block_until_ready()``
    on results inside the block if the computation itself should be timed.
    """

    def __init__(self, name=None):
        # Pre-format the optional label as " 'name'" for the report line.
        self.name = " '" + name + "'" if name else ''

    def __enter__(self):
        self.start = timeit.default_timer()
        # Return self so `with CodeTimer() as t:` can read t.took afterwards.
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.took = (timeit.default_timer() - self.start)
        print('Code block' + self.name + ' took: ' + str(self.took) + ' s')
        # Falsy return: exceptions raised in the block propagate unchanged.
        return False

