import numpy as np
import matplotlib.pyplot as plt
import incomplete_cholesky
#import incomplete_cholesky_cython

def generate_samples_rewards(n, eps, gamma, r):
    """Simulate n transitions of a "sticky uniform" Markov chain on [0, 1).

    At every step the chain, with probability ``eps``, jumps to a fresh
    uniform draw, and otherwise stays where it is. Because of this, the
    uniform distribution is invariant, and the first state is drawn from it.

    Args:
        n: number of transitions; n + 1 states are produced.
        eps: probability of redrawing the state at each step.
        gamma: unused here; accepted so all sampling helpers share one signature.
        r: reward function, called as ``r(state)``.

    Returns:
        (X, rwds): X has n + 1 states and rwds has n rewards, with
        ``rwds[k] = r(X[k])``. X[0] serves V_1, and X[n] is only the
        successor state for the last update (its reward is never needed).
    """
    states = np.zeros(n + 1)
    current = np.random.rand()  # start on the invariant (uniform) distribution
    states[0] = current
    for step in range(1, n + 1):
        # The coin flip is drawn first, then (only on a jump) the new state,
        # so the random-number stream is consumed in that order.
        if np.random.rand() < eps:
            current = np.random.rand()
        states[step] = current
    rewards = np.empty(n)
    for k, state in enumerate(states[:n]):
        rewards[k] = r(state)
    return states, rewards


def generate_skipped_samples_rewards(n, tau, eps, gamma, r):
    """Generate a trajectory and keep only every tau-th state.

    A chain is simulated with generate_samples_rewards and then subsampled
    with stride tau. Subsampling weakens the correlation between consecutive
    kept samples.

    Args:
        n: number of kept transitions.
        tau: positive integer stride between kept states.
        eps, gamma, r: forwarded unchanged to generate_samples_rewards.

    Returns:
        (X, rwds) with the same layout as generate_samples_rewards(n, ...):
        X has n + 1 states (X[k] is simulated state k * tau), and rwds has
        n rewards, with rwds[k] the reward of X[k].
    """
    valid_tau = isinstance(tau, (int, np.integer)) and not isinstance(tau, bool) and tau >= 1
    assert valid_tau, "tau must be a positive integer"
    # One step more than strictly needed is simulated. This is kept so the
    # amount of randomness consumed is unchanged.
    X, rwds = generate_samples_rewards(n * tau + 1, eps, gamma, r)
    # Explicit stops are required. A bare X[::tau] returns n + 2 states when
    # tau == 1, which no longer matches the n rewards.
    return X[:n * tau + 1:tau], rwds[:n * tau:tau]

    
def _kernel_matrix(X, K, icd_nmax):
    """Return the (n+1, n+1) Gram matrix of K over the samples X (n = len(X) - 1).

    When ``0 < icd_nmax < n``, the matrix is replaced by a rank <= icd_nmax
    incomplete-Cholesky approximation ``G @ G.T``. That path is only
    implemented for kernels whose ``__name__`` is 'Kb1' or 'Kb2'. Otherwise
    (set icd_nmax to 0 or to at least n), K is evaluated on every pair of samples.

    Raises:
        ValueError: ICD was requested for a kernel it is not implemented for.
    """
    n = X.shape[0] - 1
    if 0 < icd_nmax < n:
        icd_routines = {'Kb1': 'icd_bern_s1', 'Kb2': 'icd_bern_s2'}
        kernel_name = getattr(K, '__name__', None)
        if kernel_name not in icd_routines:
            raise ValueError(
                "incomplete Cholesky is only implemented for kernels named "
                "'Kb1' or 'Kb2', got %r; set icd_nmax=0 to use the exact "
                "kernel matrix" % (kernel_name,))
        tol = 1e-3  # approximation tolerance (1e-2 was noted as the default)
        nmax = icd_nmax  # maximal rank of the factor
        print('doing icd with max rank:', nmax)
        icd = getattr(incomplete_cholesky, icd_routines[kernel_name])
        G, P, _, _ = icd(X, tol, nmax)
        # Undo the pivoting: argsort(P) is the inverse permutation, so after
        # reindexing, row i of G belongs to X[i] (assumes row r of G
        # corresponds to X[P[r]] — TODO confirm against incomplete_cholesky).
        G = G[np.argsort(P), :]
        return G @ G.T
    Kmat = np.zeros((n + 1, n + 1))
    for i in range(n + 1):
        for j in range(n + 1):
            Kmat[i, j] = K(X[i], X[j])
    return Kmat


def kernel_TD(lbda, rho, X, rwds, K, gamma, icd_nmax=0, plot=False):
    """Run kernel TD(0) with shrinkage and return the coefficients of all iterates.

    Iterate k (0-based) is represented as V_k(x) = sum_{j<=k} alpha[k, j] K(X[j], x).
    It is obtained from iterate k-1 by multiplying its coefficients by
    (1 - rho * lbda) and adding the new coefficient
        alpha[k, k] = rho * (rwds[k] + gamma * V_{k-1}(X[k+1]) - V_{k-1}(X[k])),
    with V_{-1} = 0.

    Args:
        lbda: shrinkage/regularisation weight. It is not a TD(lambda) trace
            parameter. The factor (1 - rho*lbda) corresponds to a gradient
            step on an lbda/2 * ||V||^2 penalty.
        rho: constant step size.
        X: array of n + 1 visited states (X[n] is only used as a successor).
        rwds: at least n rewards; rwds[k] is the reward received at X[k].
        K: kernel callable K(x, y). On the ICD path it is identified only by
            its __name__ ('Kb1' or 'Kb2').
        gamma: discount factor.
        icd_nmax: maximal rank of the incomplete Cholesky approximation of the
            Gram matrix. 0 (default) or >= n uses the exact matrix.
        plot: if True, plot the sorted log-magnitudes of the Gram eigenvalues.

    Returns:
        alpha: (n, n) lower-triangular array. Row k holds the coefficients of
        iterate k on X[0..k].

    Raises:
        ValueError: fewer than n rewards, or ICD requested for an unsupported kernel.
    """
    n = X.shape[0] - 1
    if len(rwds) < n:
        raise ValueError(
            "need at least %d rewards for %d states, got %d" % (n, n + 1, len(rwds)))
    Kmat = _kernel_matrix(X, K, icd_nmax)
    if plot:
        plt.plot(np.sort(np.log(abs(np.linalg.eigvals(Kmat)) + 1e-12)))
    decay = 1. - rho * lbda  # use rho[k] here for a decreasing step size
    alpha = np.zeros((n, n))
    for k in range(n):
        td_target = rwds[k]
        if k > 0:
            prev = alpha[k - 1, :k]
            alpha[k, :k] = decay * prev
            # gamma * V_{k-1}(x_{k+1}) - V_{k-1}(x_k), evaluated through the Gram matrix
            td_target = td_target + np.dot(prev, gamma * Kmat[:k, k + 1] - Kmat[:k, k])
        alpha[k, k] = rho * td_target
    return alpha

def V(x, k, X, alpha, K):
    """Evaluate the k-th value-function iterate at the point x.

    Computes sum_{i=0}^{k} alpha[k, i] * K(x, X[i]), i.e. row k of the
    coefficient matrix combined with the kernel sections at X[0..k].
    """
    n_terms = k + 1
    coeffs = alpha[k, :n_terms]
    kernel_row = np.array([K(x, X[i]) for i in range(n_terms)])
    return np.dot(coeffs, kernel_row)

def exp_avg(n, alpha, rho, lbda):
    """Exponentially weighted average of TD coefficient rows.

    Weights are w[j] = c**(n - j) for j = 0..n-1 with c = 1 - rho*lbda,
    normalised to sum to 1. Column j of the average is
        sum_{k=j}^{n-2} w[k+1] * alpha[k, j],
    so row k of alpha gets weight w[k+1]. Row n-1 is not used, and w[0] is
    assigned to no row. If row k holds iterate V_{k+1} and V_0 = 0, this is
    the normalised average of V_0..V_{n-1} with weight c**(n - j) on V_j.
    NOTE(review): confirm that excluding the last iterate is intended.

    Returns:
        (n, n) array, all zeros except row n-1, which holds the averaged
        coefficients (presumably so it can be evaluated like a kernel_TD
        output at k = n - 1).
    """
    decay = 1. - rho * lbda
    weights = np.array([decay ** (n - j) for j in range(n)])
    weights /= np.sum(weights)
    result = np.zeros((n, n))
    result[n - 1, :] = [np.dot(weights[col + 1:], alpha[col:n - 1, col]) for col in range(n)]
    return result
