#coding:utf-8

import random
import os
import numpy as np
import math
import torch
from dppy.finite_dpps import FiniteDPP
#from scipy.sparse import linalg as spla
#import scipy.sparse as spr
from scipy.sparse.linalg import eigsh, svds#, splu
from scipy.sparse import csc_array, identity, csr_array, spdiags
from sksparse.cholmod import cholesky_AAt, cholesky, CholmodNotPositiveDefiniteError
#import torch

def seed_everything(seed):
    """Seed the Python, NumPy and PyTorch RNGs and request deterministic torch kernels.

    :param seed: integer seed applied to every generator
    :return: a new ``numpy.random.Generator`` created from the same seed
    """
    # Global generators of each library (torch.cuda.manual_seed targets the current device).
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
        seeder(seed)
    # NOTE: the interpreter reads PYTHONHASHSEED at start-up, so this only
    # influences Python processes launched after this call.
    os.environ['PYTHONHASHSEED'] = str(seed)
    torch.backends.cudnn.deterministic = True
    torch.use_deterministic_algorithms(True)
    return np.random.default_rng(seed)

def chunks(N, n):
    """Yield successive n-sized chunks of range(N) without materialising range(N).

    Each chunk is a ``range`` object; the last one may be shorter than ``n``.
    Nothing is yielded when ``N <= 0``.

    :param N: total number of indices
    :param n: chunk size, must be positive
    :raises ValueError: if ``n <= 0`` (raised on the first ``next()``, since this
        is a generator); a non-positive step would otherwise never terminate.
    """
    if n <= 0:
        raise ValueError(f"chunk size must be positive, got {n}")
    # range(0, N, n) is lazy, so no list of N indices is ever built.
    for start in range(0, N, n):
        yield range(start, min(start + n, N), 1)
    
## Determinantal Point Processes functions

## from https://github.com/laming-chen/fast-map-dpp/blob/master/dpp.py
## https://proceedings.neurips.cc/paper_files/paper/2018/file/dbbf603ff0e99629dda5d75b6f75f966-Paper.pdf
# adapted so that the kernel L may also be given through a Gram factor (L = LG @ LG.T)
def greedy(LG, max_length, kernel, eta, epsilon=1E-10):
    """
    Fast greedy MAP inference for a DPP (Chen et al., NeurIPS 2018), adapted so
    that the likelihood kernel may be given either as a full matrix or as a
    Gram factor.

    :param LG: 2-d array or scipy sparse matrix/array. If square, it is used
        directly as the kernel L; otherwise it is treated as a Gram factor and
        kernel rows are computed as ``LG[i, :] @ LG.T``.
    :param max_length: maximum number of items to select
    :param kernel: object exposing ``diag(LG, eta=...)``; its result is used as
        the initial marginal gains (one value per row of LG)
    :param eta: forwarded to ``kernel.diag``
    :param epsilon: selection stops once the best remaining gain is below it
    :return: list of selected row indices (python ints), in selection order;
        empty when ``max_length <= 0`` or LG has no rows
    """
    item_size = LG.shape[0]
    if max_length <= 0 or item_size == 0:
        return []
    is_square = LG.shape[0] == LG.shape[1]
    cis = np.zeros((max_length, item_size))
    # Flat float COPY of the diagonal: it is updated in place below, so it must
    # not alias data owned by ``kernel``; being 1-D also makes di2s[i] a scalar.
    di2s = np.array(kernel.diag(LG, eta=eta), dtype=float).reshape(-1)
    selected_items = []
    selected_item = int(np.argmax(di2s))
    selected_items.append(selected_item)
    while len(selected_items) < max_length:
        k = len(selected_items) - 1
        ci_optimal = cis[:k, selected_item]
        di_optimal = math.sqrt(di2s[selected_item])
        row = LG[[selected_item], :]
        # Square LG: the row of L itself; Gram factor: row of LG @ LG.T.
        elements = row if is_square else row @ LG.T
        if hasattr(elements, "toarray"):  # sparse result -> dense
            elements = elements.toarray()
        elements = np.asarray(elements, dtype=float).reshape(-1)
        eis = (elements - np.dot(ci_optimal, cis[:k, :])) / di_optimal
        cis[k, :] = eis
        di2s -= np.square(eis)
        di2s[selected_item] = -np.inf  # never pick the same item twice
        selected_item = int(np.argmax(di2s))
        if di2s[selected_item] < epsilon:
            break
        selected_items.append(selected_item)
    return selected_items

def SAMPLE(Q, M, B, K, seed, eta=1, n_components=100, maxiter=1000, tol=1e-3, rls_oversample_bless=4., rls_oversample_dppvfx=4., use_exact=10000):
    """
    Draw one sample of exactly B items from a k-DPP (via dppy's FiniteDPP)
    whose likelihood kernel is built from Q and M.

    The kernel is L = Q @ M @ Q when M is square, and L = (Q @ M) @ (Q @ M).T
    otherwise (M is then treated as a Gram factor). Strategy:

    * ``M.shape[0] <= use_exact``: L is densified and sampled exactly;
    * M square: a truncated eigendecomposition of Q @ M @ Q (``eigsh``) is
      used, keeping only positive eigenpairs;
    * otherwise: dppy's "alpha" k-DPP sampler on the Gram factor Q @ M, tuned
      by ``rls_oversample_dppvfx`` / ``rls_oversample_bless``.

    If that last sampler reports that the expected sample size is smaller than
    B, sampling is retried with M replaced by the identity (quality scores
    only, no diversity term). Any other ValueError is re-raised.

    :param Q: matrix multiplying M on the left (and right, for square M);
        presumably a sparse diagonal quality matrix -- TODO confirm with callers
    :param M: square similarity matrix or Gram factor (dense or sparse)
    :param B: sample size
    :param K: unused here; only forwarded to the fallback call
    :param seed: ``random_state`` passed to dppy
    :param eta: unused here; only forwarded to the fallback call
    :param n_components: number of eigenpairs requested from ``eigsh`` for large M
    :param maxiter: ``eigsh`` iteration limit
    :param tol: ``eigsh`` tolerance
    :param use_exact: largest ``M.shape[0]`` for which exact sampling is used
    :return: the sampled item indices (``DPP.list_of_samples[0]``)
    """
    # Text of the dppy error that triggers the quality-only fallback.
    too_few_msg = "The expected sample size is smaller than the desired sample size or k (if sampling froma k-DPP)."
    if (M.shape[0]<=use_exact): ## no need for approximations
        if (M.shape[0]==M.shape[1]):
            L = Q @ M @ Q
        else:
            QM = Q @ M
            L = QM @ QM.T
        # dppy needs a dense kernel; accept sparse as well as dense products.
        L_dense = L.toarray() if hasattr(L, "toarray") else np.asarray(L)
        DPP = FiniteDPP('likelihood', **{"L": L_dense})
        DPP.sample_exact_k_dpp(size=B, random_state=seed)
        return DPP.list_of_samples[0]
    if (M.shape[0]==M.shape[1]):
        QM = Q @ M @ Q
        k_value = QM.shape[0]-1 if (QM.shape[0]<=1000) else max(min(n_components, Q.shape[0]-1), B+1)
        e_vals_L, e_vecs = eigsh(QM, k=k_value, maxiter=maxiter, tol=tol, return_eigenvectors=True)
        e_vals_L, e_vecs = np.real(e_vals_L), np.real(e_vecs)
        ## Filter out nonpositive eigenpairs
        positive = e_vals_L > 0
        DPP = FiniteDPP(kernel_type='likelihood', projection=False, **{'L_eig_dec': (e_vals_L[positive], e_vecs[:, positive])})
        DPP.sample_exact_k_dpp(size=B, random_state=seed)
        return DPP.list_of_samples[0]
    LG = Q @ M
    LG = LG.toarray() if hasattr(LG, "toarray") else np.asarray(LG)
    def eval_L_linear(X, Y=None):
        ## note that it works for any kernel,
        ## thanks to the Nystroem approximation
        X = np.atleast_2d(X)
        if Y is None:
            return X @ X.T
        Y = np.atleast_2d(Y)
        return X @ Y.T
    try:
        DPP = FiniteDPP(kernel_type='likelihood', projection=False, **{'L_eval_X_data': (eval_L_linear, LG)})
        DPP.sample_exact_k_dpp(size=B, random_state=seed, mode="alpha",
                params={
                    'rls_oversample_alphadpp': rls_oversample_dppvfx, # [2,10]
                    'rls_oversample_bless' : rls_oversample_bless, # [2,10]
                }, verbose=False)
    except ValueError as e:
        # Only the "not enough diversity" failure is recoverable; an explicit
        # check (unlike ``assert``) is not stripped under ``python -O``.
        if too_few_msg not in str(e):
            raise
        ## use quality scores instead without diversity
        return SAMPLE(Q, identity(Q.shape[0]), B, K, seed, eta=eta, n_components=n_components, maxiter=maxiter, tol=tol,
                      rls_oversample_bless=rls_oversample_bless, rls_oversample_dppvfx=rls_oversample_dppvfx,
                      use_exact=use_exact)
    return DPP.list_of_samples[0]

def MAX(Q, M, B, K, seed, eta=1):
    """Greedy (non-random) selection of up to B items with ``greedy``.

    ``greedy`` receives Q @ M @ Q when M is square (full kernel) and Q @ M
    otherwise (Gram factor). If fewer than B items come back, a message is
    printed and the selection is redone on Q alone. ``seed`` is not used.
    """
    kernel_input = (Q @ M @ Q) if M.shape[0] == M.shape[1] else (Q @ M)
    S = greedy(kernel_input, B, K, eta=eta)
    if len(S) >= B:
        return S
    print(f"Number of current recommendations {len(S)} is too small (<{B}). Using quality scores to make recommendations")
    return greedy(Q, B, K, eta=eta)

## https://gist.github.com/sumartoyo/edba2eee645457a98fdf046e1b4297e4
def vars(a, axis=None):
    """ Variance of a sparse (or dense) matrix ``a``:
    var = mean(a**2) - mean(a)**2

    NOTE: this function shadows the builtin ``vars`` inside this module.

    :param a: scipy sparse matrix/array, or anything NumPy accepts as an array
    :param axis: None for the variance of all entries, otherwise the axis to reduce
    :return: the variance (scalar, or array/matrix along ``axis``)
    """
    if hasattr(a, "power"):
        # Sparse input: ``power`` merges duplicate coordinates before squaring,
        # whereas squaring ``.data`` directly yields x1**2 + x2**2 instead of
        # (x1 + x2)**2 for a duplicated entry. Squaring a copy avoids
        # canonicalising the caller's matrix in place.
        a_squared = a.copy().power(2)
        return a_squared.mean(axis) - np.square(a.mean(axis))
    # Dense input (``.data`` on an ndarray is a raw buffer and cannot be squared).
    return np.mean(np.square(a), axis=axis) - np.square(np.mean(a, axis=axis))

def stds(a, axis=None):
    """Standard deviation of the sparse matrix ``a`` along ``axis``.

    Computed as the square root of this module's ``vars(a, axis)``
    (which shadows the builtin of the same name).
    """
    variance = vars(a, axis)
    return np.sqrt(variance)

def volume(P, K, eta=0):
    """
    Compute vol(Kernel(P)) = sqrt(det(KP @ KP.T + eta * I)).

    ``KP`` is the first value returned by ``K(P, force_fit=True)``; the
    determinant comes from a CHOLMOD factorisation (``cholesky_AAt``).

    :param P: input forwarded to ``K``
    :param K: callable returning a pair whose first element is the matrix KP
    :param eta: value added to the diagonal of KP @ KP.T
    :return: the volume as a float; 0 when the factorisation reports a
        non-positive-definite matrix, when any entry of ``Factor.D()`` is
        <= 0, or when the log-determinant is NaN; ``math.inf`` when the
        volume exceeds the float range.
    """
    KP, _ = K(P, force_fit=True)
    try:
        Factor = cholesky_AAt(csc_array(KP), beta=eta, mode="auto", ordering_method="default", use_long=None)
        if (Factor.D() <= 0).any():
            return 0
        ldet = Factor.logdet()
    except CholmodNotPositiveDefiniteError:
        return 0
    if np.isnan(ldet):
        return 0
    try:
        # sqrt(exp(ldet)) == exp(ldet / 2); halving first avoids overflowing
        # the intermediate exp(ldet) for 709 < ldet < 1419.
        return math.exp(0.5 * ldet)
    except OverflowError:
        return math.inf
    
def set_score(Q, M=None, eta=0):
    """
    Determinant score det(A + eta * I), computed via a CHOLMOD factorisation.

    The matrix A is:

    * Q @ M @ Q                 if M is given and square (full matrix);
    * (Q @ M) @ (Q @ M).T       if M is given and not square (Gram factor);
    * Q                         if M is None and Q is square;
    * Q @ Q.T                   if M is None and Q is not square.

    :param Q: matrix (dense or sparse; converted to CSC)
    :param M: optional square matrix or Gram factor
    :param eta: value added to the diagonal (``beta`` of CHOLMOD)
    :return: exp(logdet) as a float; 0 when CHOLMOD reports the matrix is not
        positive definite or the log-determinant is NaN; ``math.inf`` when the
        determinant exceeds the float range.
    """
    try:
        if M is not None:
            if M.shape[0] == M.shape[1]:
                L = Q @ M @ Q  ## full matrix
                Factor = cholesky(csc_array(L), beta=eta, mode="auto", ordering_method="default", use_long=None)
            else:
                L = Q @ M  ## only the Gram factor
                Factor = cholesky_AAt(csc_array(L), beta=eta, mode="auto", ordering_method="default", use_long=None)
        elif Q.shape[0] == Q.shape[1]:
            Factor = cholesky(csc_array(Q), beta=eta, mode="auto", ordering_method="default", use_long=None)  ## full matrix
        else:
            Factor = cholesky_AAt(csc_array(Q), beta=eta, mode="auto", ordering_method="default", use_long=None)  ## only Gram factor
        ldet = Factor.logdet()
    except CholmodNotPositiveDefiniteError:
        return 0
    if np.isnan(ldet):
        return 0
    try:
        return math.exp(ldet)
    except OverflowError:
        # finite determinant that does not fit in a float
        return math.inf

## Approximate matrix power from a truncated SVD a ~ U D V, where V (k x n) has orthonormal rows
## (V V^T = I_k, while V^T V is only a projector), using (V^T D V)^p = V^T D^p V -- exact on the retained subspace when a is symmetric PSD
def power(a, p, n_components=100, maxiter=1000, tol=1e-3):
    """
    Approximate matrix power computed from a truncated SVD (``svds``).

    * Square ``a``: returns V.T @ diag(D**p) @ V, where D / V are the top-k
      singular values / right singular vectors of ``a``. For a symmetric PSD
      ``a`` this is a**p restricted to the top-k eigen-subspace.
    * Non-square ``a``: ``a`` is treated as a Gram factor of a @ a.T, and the
      function returns a Gram factor G = V.T @ diag(D**(p/2)) of (a @ a.T)**p,
      i.e. G @ G.T approximates (a @ a.T)**p.

    k = n - 1 when n <= 1000, else max(min(n_components, n - 1), 11), where n is
    the size of the (square) matrix passed to ``svds``.

    :param a: dense or sparse matrix
    :param p: exponent; ``a`` itself is returned unchanged when p == 1
    :return: sparse matrix (n x n when ``a`` is square, n x k otherwise)
    """
    if p == 1:
        return a
    is_square = a.shape[0] == a.shape[1]
    aa = a if is_square else a @ a.T  # svds does not modify its input
    n = aa.shape[0]
    k_value = n - 1 if n <= 1000 else max(min(n_components, n - 1), 10 + 1)
    _, D, V = svds(np.real(aa), k=k_value, tol=tol, maxiter=maxiter)
    # Singular values are >= 0 in exact arithmetic; clip round-off so that
    # fractional exponents below cannot produce NaN.
    D = np.maximum(np.real(D), 0.0)
    V_sp = csr_array(V)
    if is_square:
        # a^p ~ V^T D^p V
        return V_sp.T @ spdiags(D ** p, np.array([0]), (k_value, k_value)) @ V_sp
    # Gram factor of (a a^T)^p: sqrt(D^p) = D^(p/2) since D is diagonal
    return V_sp.T @ spdiags(D ** (p / 2.0), np.array([0]), (k_value, k_value))
    
## Performs the sparse Cholesky decomposition of M+beta.Id (CHOLMOD, possibly with a fill-reducing permutation)
## and returns its lower-triangular factor L
def Cholesky(M, eta=0):
    """Sparse Cholesky factor of M + eta * I, computed by CHOLMOD (scikit-sparse).

    :param M: matrix to factorise (converted to CSC); expected symmetric
        positive definite once eta is added
    :param eta: value added to the diagonal (``beta``)
    :return: the lower-triangular factor from ``Factor.L()``. NOTE: CHOLMOD may
        apply a fill-reducing permutation P (see ``Factor.P()``), in which case
        L @ L.T equals the permuted matrix rather than M + eta * I itself.
    """
    csc_input = csc_array(M)
    return cholesky(csc_input, beta=eta, mode="auto", ordering_method="default", use_long=None).L()
    
## Performs the Cholesky decomposition of M.M^T+beta.Id
## then compute the matrix X such that (M.M^T+beta.Id)X = Id
def fast_inverse(M, eta=0):
    """Return X solving (M @ M.T + eta * I) X = I, i.e. the inverse of M M^T + eta I.

    The factorisation of M M^T + eta I is done by CHOLMOD (``cholesky_AAt``),
    then the factor is applied to a sparse identity right-hand side.
    """
    factor = cholesky_AAt(csc_array(M), beta=eta, mode="auto", ordering_method="default", use_long=None)
    identity_rhs = csc_array(identity(M.shape[0]))
    return factor(identity_rhs)
    
## Performs the Cholesky decomposition of M+beta.Id
## then compute the matrix X such that (M+beta.Id)X = Id
def fast_inverse_fullmat(M, eta=0):
    """Return X solving (M + eta * I) X = I, i.e. the inverse of M + eta I.

    M + eta I is factorised by CHOLMOD (``cholesky``) and the factor is then
    applied to a sparse identity right-hand side.
    """
    factor = cholesky(csc_array(M), beta=eta, mode="auto", ordering_method="default", use_long=None)
    identity_rhs = csc_array(identity(M.shape[0]))
    return factor(identity_rhs)
        
if __name__ == "__main__":
    # Micro-benchmark: fetch a few indices from range(N), first through a fully
    # materialised list, then chunk by chunk with `chunks`, and check both agree.
    from time import time
    N = 100000000
    find_ids = sorted([5, 1000, 25436, 23, 536, 78, 22543])
    t = time()
    full_list = list(range(N))
    res = [full_list[i] for i in find_ids]
    print(f"Enumeration with traditional range {time()-t} seconds")
    del full_list  # release the large list before the second measurement
    t = time()
    res1 = []
    nchunk = 1000
    for i, chunk in enumerate(chunks(N, nchunk)):
        for idx in find_ids:
            if i == idx // nchunk:
                # range objects support indexing; no need to build a list
                res1.append(chunk[idx % nchunk])
    print(f"Enumeration with chunk approach {time()-t} seconds")
    # Whole-list comparison: a length mismatch fails the assertion instead of raising IndexError.
    assert res1 == res
