import numpy as np
from pdb import set_trace
from statsmodels.multivariate.cancorr import CanCorr
from math import log, pow
from scipy.stats import chi2
from scipy.linalg import eigh
import time




"""
_summary_
The data are an ndarray with zero mean and standard deviation = 1. The shape is n x features.
# Centre the data
# data = data - np.mean(data, axis=0)
# data = data/data.std()
# data = np.array(data) 
"""   
         
         
# when the rank equals to r, return true
def test_rank_equal_expected_phase1(X, Y, r=1, significance_level=0.5, rescale_rank_test=1):
    """Decide whether the rank tested by `_test` equals exactly `r`.

    `_test(X, Y, k, ...)` tests the null hypothesis "rank <= k" and reports
    whether that null is rejected, together with its p-value.

    Args:
        X, Y: the two data slices forwarded unchanged to `_test`.
        r: the expected rank.
        significance_level: level forwarded to `_test` (default 0.5 here).
        rescale_rank_test: forwarded unchanged to `_test`.

    Returns:
        For r == 0: ``(ok, [p])`` where ``ok`` is True when "rank <= 0" is
        not rejected.
        For r >= 1: ``(ok, [p_r, p_r_minus_1])`` where ``ok`` is True only
        when "rank <= r" is not rejected AND "rank <= r - 1" is rejected.
        For any other r (e.g. negative) nothing is returned (None).
    """
    if r == 0:
        null_rejected, p_val = _test(X, Y, r, significance_level, rescale_rank_test)
        return (not null_rejected), [p_val]

    if r >= 1:
        # Upper bound first (rank <= r), then the bound just below it.
        rejected_at_r, p_at_r = _test(X, Y, r, significance_level, rescale_rank_test)
        rejected_below, p_below = _test(X, Y, r - 1, significance_level, rescale_rank_test)
        p_values = [p_at_r, p_below]

        # Rank equals r only if r is compatible with the data while r - 1 is not.
        if rejected_at_r or not rejected_below:
            return False, p_values
        return True, p_values

    # Neither branch matched (e.g. negative r): fall through with None.
    return None
  
        
# when the rank equals to r, return true
def test_rank_equal_expected_phase2(X, Y, r=1, significance_level=0.05, rescale_rank_test=1):
    """Decide whether the rank tested by `_test` equals exactly `r`.

    Same decision rule as the phase-1 variant; only the default
    significance level (0.05) differs.

    Args:
        X, Y: the two data slices forwarded unchanged to `_test`.
        r: the expected rank.
        significance_level: level forwarded to `_test`.
        rescale_rank_test: forwarded unchanged to `_test`.

    Returns:
        For r == 0: ``(ok, [p])`` with ``ok`` True when the null
        "rank <= 0" is not rejected.
        For r >= 1: ``(ok, [p_r, p_r_minus_1])`` with ``ok`` True only when
        "rank <= r" is not rejected and "rank <= r - 1" is rejected.
        For any other r nothing is returned (None).
    """
    extra_args = (significance_level, rescale_rank_test)

    if r == 0:
        rejected, p_val = _test(X, Y, r, *extra_args)
        verdict = False if rejected else True
        return verdict, [p_val]

    if r >= 1:
        # Order of the two tests is kept: rank <= r first, then rank <= r - 1.
        upper_rejected, upper_p = _test(X, Y, r, *extra_args)
        lower_rejected, lower_p = _test(X, Y, r - 1, *extra_args)

        exact_rank = (not upper_rejected) and bool(lower_rejected)
        return exact_rank, [upper_p, lower_p]

    # r is neither 0 nor >= 1: no result is produced.
    return None
        






# Test null hypothesis that rank is less than or equal to r
# Return True if reject null
# X and Y are two column slices of the ndarray data
def _test(X, Y, r=1, significance_level=0.05, rescale_rank_test=1):
    """Chi-square test of the null hypothesis "rank <= r" from canonical correlations.

    The statistic is ``-(n - 0.5 * (p + q + 1)) * sum(log(1 - rho_i ** 2))``
    over the canonical correlations ``rho_i`` with index >= r, compared
    against a chi-square distribution with ``(p - r) * (q - r)`` degrees of
    freedom.

    Args:
        X: array-like of shape (n, p).
        Y: array-like of shape (n, q).
        r: non-negative rank bound under the null hypothesis.
        significance_level: the null is rejected when p-value < this level.
        rescale_rank_test: accepted for signature compatibility; not used
            by this implementation.

    Returns:
        ``(reject, p_value)`` where ``reject`` is True when the null
        hypothesis is rejected.

    Raises:
        ValueError: if ``r`` is negative.
    """
    if r < 0:
        # A negative r would silently slice the correlations from the end.
        raise ValueError(f"r must be non-negative, got {r}")

    X = np.asarray(X)
    Y = np.asarray(Y)
    p = X.shape[1]
    q = Y.shape[1]
    num = X.shape[0]

    # The rank can never exceed min(p, q): the null holds trivially. The
    # chi-square below would have zero (or meaningless) degrees of freedom
    # and yield a NaN p-value, so answer directly.
    if r >= min(p, q):
        return False, 1.0

    cancorr = CanCorr(X, Y, tolerance=1e-8).cancorr

    # Clamp just below 1 so that log(1 - rho**2) stays finite.
    tail = np.minimum(np.asarray(cancorr[r:], dtype=float), 1 - 1e-15)

    # NOTE(review): an earlier variant of this code used (p + q + 3) in the
    # scale factor instead of (p + q + 1); confirm which one is intended.
    testStat = -(num - 0.5 * (p + q + 1)) * np.sum(np.log1p(-tail ** 2))

    dfreedom = (p - r) * (q - r)

    # Survival function: same value as 1 - cdf, but keeps precision for
    # very small p-values.
    p_value = chi2.sf(testStat, dfreedom)

    result = p_value < significance_level

    return result, p_value
   

def _test_old(X, Y, r=1, significance_level=0.05, rescale_rank_test=1):
    """Older variant of the "rank <= r" test with a corrected scale factor.

    Canonical correlations are computed with `kcca_modified` / `recon`; if
    that fails for any reason, statsmodels' `CanCorr` is used instead.

    The statistic is ``sum(-log(1 - rho_i ** 2))`` over correlations with
    index >= r, multiplied by
    ``num * rescale_rank_test - r - 0.5 * (p + q + 1)
    + sum(1 / rho_i ** 2 - 1 for i < r)``,
    and compared against the chi-square critical value with
    ``(p - r) * (q - r)`` degrees of freedom.

    Args:
        X: array-like of shape (n, p).
        Y: array-like of shape (n, q).
        r: rank bound under the null hypothesis.
        significance_level: level used for the critical value.
        rescale_rank_test: multiplier applied to the sample size.

    Returns:
        ``(reject, p_value)`` where ``reject`` is True when the statistic
        exceeds the critical value.
    """
    X = np.array(X)
    Y = np.array(Y)
    p = X.shape[1]
    q = Y.shape[1]
    num = X.shape[0]

    # Auto- and cross-product matrices in the order XX, XY, YX, YY.
    kernel = [d.T for d in [X, Y]]
    crosscovs_scaled = [np.dot(ki, kj.T) for ki in kernel for kj in kernel]

    try:
        comps = kcca_modified([X, Y], reg=0.,
            numCC=None, kernelcca=False, ktype='linear',
            gausigma=1.0, degree=2, crosscovs=crosscovs_scaled)

        cancorr, _, _ = recon([X, Y], comps, kernelcca=False)
        cancorr = cancorr[:, 0, 1]
    except Exception:
        # Deliberate best-effort fallback; `Exception` (rather than a bare
        # except) lets KeyboardInterrupt / SystemExit propagate.
        print(f"calculating cancorr error, using another implementation instead")
        cancorr = CanCorr(X, Y, tolerance=1e-8).cancorr

    testStat = 0
    for li in cancorr[r:]:
        # Clamp just below 1 so the logarithm stays finite.
        li = min(li, 1 - 1e-15)
        testStat += log(1) - log(1 - li * li)

    # Correction built from the first r correlations plus the sample-size term.
    ratio = 0
    for i in range(r):
        li = cancorr[i]
        ratio += 1 / (li * li) - 1
    ratio += num * rescale_rank_test - r - 0.5 * (p + q + 1)

    testStat = testStat * ratio

    dfreedom = (p - r) * (q - r)
    criticalValue = chi2.ppf(1 - significance_level, dfreedom)

    # Survival function: same value as 1 - cdf, more precise in the tail.
    p_value = chi2.sf(testStat, dfreedom)
    result = testStat > criticalValue

    return result, p_value









#utils used for testing rank hypothesis
def kcca_modified(
        data, reg=0.0, numCC=None, kernelcca=False, ktype="linear", gausigma=1.0, degree=2, crosscovs=None
):
    """Set up and solve the linear CCA generalized eigenproblem.

    Args:
        data: list of 2-D arrays, each of shape (n_samples, n_features_i).
        reg: ridge term added to the diagonal blocks of the right-hand side.
        numCC: number of components to keep; defaults to the smallest
            feature count among the datasets.
        kernelcca: must be False; the kernel variant is not implemented.
        ktype, gausigma, degree: accepted for signature compatibility; not
            used by the linear path.
        crosscovs: optional precomputed list of the ``len(data) ** 2``
            product matrices ``d_i.T @ d_j`` in row-major (i, j) order.
            Computed from ``data`` when None.

    Returns:
        A list with one (n_features_i, numCC) array of weights per dataset,
        columns ordered by decreasing eigenvalue.

    Raises:
        NotImplementedError: if ``kernelcca`` is true.
    """
    if kernelcca:
        raise NotImplementedError

    kernel = [d.T for d in data]

    nDs = len(kernel)
    nFs = [k.shape[0] for k in kernel]
    numCC = min(nFs) if numCC is None else numCC

    # Get the auto- and cross-covariance matrices
    if crosscovs is None:
        crosscovs = [np.dot(ki, kj.T) for ki in kernel for kj in kernel]

    # Start/stop offsets of each dataset's block, computed once.
    offsets = [0]
    for nf in nFs:
        offsets.append(offsets[-1] + nf)
    blocks = [slice(offsets[i], offsets[i + 1]) for i in range(nDs)]

    # Allocate left-hand side (LH) and right-hand side (RH):
    n = offsets[-1]
    LH = np.zeros((n, n))
    RH = np.zeros((n, n))

    # RH holds the (regularised) auto-covariances on its diagonal blocks,
    # LH holds the cross-covariances on its off-diagonal blocks.
    for i in range(nDs):
        RH[blocks[i], blocks[i]] = crosscovs[i * (nDs + 1)] + reg * np.eye(nFs[i])

        for j in range(nDs):
            if i != j:
                LH[blocks[j], blocks[i]] = crosscovs[nDs * j + i]

    # Symmetrise to remove round-off asymmetry before calling eigh.
    LH = (LH + LH.T) / 2.0
    RH = (RH + RH.T) / 2.0

    maxCC = LH.shape[0]
    # `subset_by_index` replaces the `eigvals` keyword, which newer SciPy
    # releases no longer accept. Selects the numCC largest eigenvalues.
    r, Vs = eigh(LH, RH, subset_by_index=[maxCC - numCC, maxCC - 1])
    r[np.isnan(r)] = 0
    rindex = np.argsort(r)[::-1]
    Vs = Vs[:, rindex]

    return [Vs[blocks[i], :numCC] for i in range(nDs)]

def _listdot(d1, d2):
    """Return ``[np.dot(a.T, b), ...]`` for the items of d1 and d2 paired in order."""
    products = []
    for left, right in zip(d1, d2):
        products.append(np.dot(left.T, right))
    return products

def _listcorr(a):
    """Correlate matching columns of every ordered pair of arrays in `a`.

    Returns an array of shape ``(a[0].shape[1], len(a), len(a))``. Entry
    ``[k, i, j]`` with ``j > i`` is the correlation coefficient between
    column k of ``a[i]`` and column k of ``a[j]`` (NaN replaced via
    ``np.nan_to_num``). The diagonal and the ``j < i`` entries stay zero.
    """
    n_sets = len(a)
    corrs = np.zeros((a[0].shape[1], n_sets, n_sets))
    for first in range(n_sets):
        # Only the strictly upper triangle (second > first) is filled.
        for second in range(first + 1, n_sets):
            column_corrs = []
            for col_first, col_second in zip(a[first].T, a[second].T):
                coef = np.corrcoef(col_first, col_second)[0, 1]
                column_corrs.append(np.nan_to_num(coef))
            corrs[:, first, second] = column_corrs
    return corrs

def recon(data, comp, corronly=False, kernelcca=False):
    """Project each dataset onto its weights and correlate the projections.

    Args:
        data: list of arrays, one per dataset.
        comp: list of component arrays, one per dataset.
        corronly: when true, return only the correlations.
        kernelcca: when true, the weights are ``_listdot(data, comp)``;
            otherwise ``comp`` is used as the weights directly.

    Returns:
        ``corrs`` if ``corronly`` is true, otherwise the tuple
        ``(corrs, ws, ccomp)`` where ``ws`` are the weights and ``ccomp``
        the projected data.
    """
    # Weights: derived from the data in the kernel case, given directly otherwise.
    ws = _listdot(data, comp) if kernelcca else comp

    # _listdot transposes its first argument, so pass d.T to obtain d . w.
    transposed = [d.T for d in data]
    ccomp = _listdot(transposed, ws)
    corrs = _listcorr(ccomp)

    if not corronly:
        return corrs, ws, ccomp
    return corrs