#coding:utf-8

from sklearn.gaussian_process.kernels import Kernel
from sklearn.kernel_approximation import Nystroem
from scipy.sparse import csr_array, csc_array, identity, lil_array
from scipy.sparse.linalg import norm
from sksparse.cholmod import cholesky_AAt
from tqdm import tqdm

from bdivrec.fabaphe.utils import power, chunks

class DenudedKernel(Kernel):
	"""Kernel exposed through a sparse Nystroem feature map.

	Instead of materialising a full kernel matrix, calling the object
	returns the (optionally thresholded) Nystroem feature map of the
	input as a sparse array; `diag` then recovers the kernel diagonal
	from such a feature map.

	Parameters
	----------
	kernel : name (or callable) forwarded to `Nystroem(kernel=...)`.
	n_components : number of Nystroem components.
	nchunks : second argument given to `chunks` when the exact diagonal
		is computed piecewise (presumably the chunk size -- confirm
		against `chunks`).
	seed : random state forwarded to `Nystroem`.
	beta : optional threshold; feature-map entries that are not strictly
		greater than `beta` are zeroed out.
	"""

	def __init__(self, kernel="linear", n_components=100, nchunks=5000, seed=1234, beta=None):
		## Every constructor argument is stored under its own name: the
		## parameter introspection of the scikit-learn Kernel base class
		## (get_params and what is built on it: theta, repr, equality,
		## cloning) reads getattr(self, <argument name>).
		self.kernel = kernel
		self.kernel_name = kernel ## historical alias, kept for existing callers
		self.n_components = n_components
		self.nchunks = nchunks
		self.seed = seed
		self.beta = beta
		self.kernel_approx = Nystroem(kernel=kernel, random_state=seed, n_components=n_components)
		self.Factor = None ## Cholesky factor cached by diag(mode="approx")
		self.is_fit = False
		self.name = f"{kernel}Kernel"

	def __call__(self, X, Y=None, force_fit=False):
		"""Return the sparse feature maps of X and, if given, of Y.

		Returns a tuple `(KX, KY)` where `KY` is None when `Y` is None.
		X is always processed first, so an unfitted approximation is
		fitted on X.
		"""
		KX = self.auxcall(X, force_fit=force_fit)
		KY = None if Y is None else self.auxcall(Y, force_fit=force_fit)
		return KX, KY

	def auxcall(self, X, force_fit=False):
		"""Map X through the Nystroem approximation and return a csr_array.

		- On the first call the shared approximation is fitted on X
		  (regardless of `force_fit`) and then applied to X.
		- On later calls with `force_fit=True`, a throw-away approximation
		  is fitted on X itself, with at most `X.shape[0]` components; the
		  shared approximation is left untouched.
		- Otherwise the already fitted shared approximation is applied.

		When `self.beta` is set, entries not strictly greater than it are
		zeroed before the conversion to sparse format.
		"""
		if (not self.is_fit):
			self.kernel_approx.fit(X)
			self.is_fit = True
			f_mapX = self.kernel_approx.transform(X)
		elif (force_fit): ## to compute volumes meaningfully
			kernel_approx = Nystroem(kernel=self.kernel_name, random_state=self.seed, n_components=min(X.shape[0], self.n_components))
			kernel_approx.fit(X)
			f_mapX = kernel_approx.transform(X)
		else:
			f_mapX = self.kernel_approx.transform(X)
		if (self.beta is not None):
			## multiply by a 0/1 mask so that small entries become exact zeros
			f_mapX *= (f_mapX>self.beta).astype(int)
		return csr_array(f_mapX)

	def diag(self, KX, eta=0, mode="exact"):
		"""Return the diagonal of `KX @ KX.T` for a feature map `KX`.

		Parameters
		----------
		KX : feature map with one row per sample (as returned by calling
			the kernel), or a square matrix whose diagonal is returned
			directly.
		eta : only used when mode == "approx"; forwarded as the `beta`
			argument of `cholesky_AAt`.
		mode : "approx" goes through a Cholesky factor of the feature map;
			any other value (default "exact") computes the squared row
			norms chunk by chunk.

		Returns
		-------
		A 1-D diagonal when `KX` is square, otherwise a dense (N, 1) array.
		"""
		N = KX.shape[0]
		## note that it works for any kernel,
		## thanks to the Nystroem approximation
		if (KX.shape[0] == KX.shape[1]): ## diagonal is already computed
			## NOTE(review): a feature map with as many rows as columns is
			## also caught here (looks possible after force_fit with fewer
			## samples than n_components) -- confirm callers never hit this.
			return KX.diagonal()
		if (mode == "approx"):
			## NOTE(review): the factor is cached on first use and reused
			## for every later KX and eta -- confirm this is intended.
			if (self.Factor is None):
				self.Factor = cholesky_AAt(csc_array(KX), beta=eta, mode="auto", ordering_method="default", use_long=None)
			L = self.Factor.L()
			## squared row norms of L are the diagonal of L @ L.T
			## NOTE(review): cholmod may apply a fill-reducing permutation,
			## in which case the entries would come back in permuted order
			## -- verify against Factor.P().
			D = csr_array(norm(L.T, axis=0).reshape((N,1)))
			D = power(D, 2)
		else:
			D = lil_array((N,1),dtype=float)
			for i, ilst in enumerate(pbar := tqdm(
				chunks(N, self.nchunks),
				position=4,
				leave=False
			)):
				pbar.set_description(f"Computing diagonal of kernel {min((i+1)*self.nchunks, N)}/{N}")
				## diag(A @ A.T) equals the row-wise sum of squares of A;
				## this avoids building the chunk x chunk Gram matrix.
				rows = csr_array(KX[ilst,:])
				D[ilst] = rows.multiply(rows).sum(axis=1)
		return D.toarray()

	def is_stationary(self):
		"""Report the kernel as stationary.

		NOTE(review): True is returned whatever `kernel` was chosen --
		confirm this is appropriate for non-stationary choices such as
		"linear".
		"""
		return True
