import numpy as np
import matlab
import matlab.engine
from Community_utils import communityRecovery, maxKeigenDecomposition
import scipy.sparse.linalg as sla
import scipy.linalg as la
# Show NumPy arrays with 4 decimal places in the diagnostic prints.
np.set_printoptions(precision=4)


# Start one MATLAB session when this module is imported. avg_ctrl and
# modal_ctrl call into it through the module-level name `matlab_engine`.
# '-desktop' asks start_matlab to open the MATLAB desktop for this session.
matlab_engine = matlab.engine.start_matlab('-desktop')
# matlabAddPath is a MATLAB-side function that is not defined in this file.
# Presumably it puts ave_control / modal_control on the MATLAB path (verify).
matlab_engine.matlabAddPath(nargout=0)

def avg_ctrl(A, use_matlab=True):
    """Return the per-node average controllability of the square matrix ``A``.

    Parameters
    ----------
    A : numpy.ndarray
        Square (n, n) matrix.
    use_matlab : bool, optional
        True (default, the original behaviour): call the MATLAB function
        ``ave_control`` through the module-level ``matlab_engine``.
        False: use the NumPy/SciPy port below, which is intended to mirror
        ``ave_control``.

    Returns
    -------
    numpy.ndarray
        1-D array with one value per row of ``A``.

    Notes
    -----
    The NumPy path works in three steps:

    1. Scale ``A`` by ``1 / (1 + sigma_max(A))``.
    2. Take the real Schur decomposition ``A = U T U^T`` and set ``v = diag(T)``.
    3. Node ``j`` gets ``sum_i U[j, i]**2 / (1 - v[i]**2)``.
    """
    if use_matlab:
        out = matlab_engine.ave_control(matlab.double(A.tolist()), nargout=1)
        return np.array(out).flatten()

    A = np.asarray(A, dtype=float)
    # Exact largest singular value. It equals svds(A, 1) but, unlike
    # scipy's svds, it also works for 1x1 matrices. After this scaling the
    # spectral radius is below 1, so 1 - v**2 stays positive.
    A = A / (1 + np.linalg.norm(A, 2))
    T, U = la.schur(A, output='real')
    v = np.diag(T)
    # Row j of U**2 weighted by 1 / (1 - v_i**2). This is the matrix form of
    # MATLAB's sum((U.^2)' ./ repmat(1 - v.^2, 1, n))'.
    out = (U ** 2) @ (1.0 / (1.0 - v ** 2))
    return np.array(out).flatten()


def modal_ctrl(A, use_matlab=True):
    """Return the per-node modal controllability of the square matrix ``A``.

    Parameters
    ----------
    A : numpy.ndarray
        Square (n, n) matrix.
    use_matlab : bool, optional
        True (default, the original behaviour): call the MATLAB function
        ``modal_control`` through the module-level ``matlab_engine``.
        False: use the NumPy/SciPy port below.

    Returns
    -------
    numpy.ndarray
        1-D array with one value per row of ``A``.

    Notes
    -----
    The NumPy path works in three steps:

    1. Scale ``A`` by ``1 / (1 + sigma_max(A))``.
    2. Take the real Schur decomposition ``A = U T U^T``.
    3. Compute ``phi[i] = sum_j U[i, j]**2 * (1 - T[j, j]**2)``.
    """
    if use_matlab:
        out = matlab_engine.modal_control(matlab.double(A.tolist()), nargout=1)
        return np.array(out).flatten()

    A = np.asarray(A, dtype=float)
    # Exact largest singular value. Unlike scipy's svds(A, k=1), this also
    # works for 1x1 matrices.
    A = A / (1 + np.linalg.norm(A, 2))
    T, U = la.schur(A, output='real')
    eig_vals = np.diag(T)
    # Vectorised form of the per-row loop: row i of U**2 dotted with (1 - eig**2).
    phi = (U ** 2) @ (1 - eig_vals ** 2)
    return np.array(phi).flatten()

def inCtrlEval(A, ctrl_measure):
    """Evaluate the controllability measure selected by ``ctrl_measure`` on ``A``.

    Parameters
    ----------
    A : numpy.ndarray
        Square matrix passed unchanged to the selected measure.
    ctrl_measure : str
        A string containing 'avg' selects avg_ctrl. Otherwise a string
        containing 'modal' selects modal_ctrl. 'avg' is checked first, so
        mode names such as 'avg_trueComtyBased' are accepted.

    Returns
    -------
    numpy.ndarray
        Flattened 1-D array of per-node values.

    Raises
    ------
    ValueError
        If ``ctrl_measure`` contains neither 'avg' nor 'modal'.
    """
    if 'avg' in ctrl_measure:
        ctrlout = avg_ctrl(A)
    elif 'modal' in ctrl_measure:
        ctrlout = modal_ctrl(A)
    else:
        # Without this branch `ctrlout` would be unbound and the function
        # would fail with a confusing UnboundLocalError.
        raise ValueError(
            "ctrl_measure must contain 'avg' or 'modal', got %r" % (ctrl_measure,))
    return ctrlout.flatten()
        

def ctrlEval(graph, df, ctrl_mode, alpha_scale=False):
    """
    Compute a per-node controllability score for ``graph``.

    Two families of modes are handled:

    * Plain modes ('avg', 'modal'): the measure is evaluated directly on the
      adjacency matrix via inCtrlEval. When df['A-normalization'] == 'Adim',
      A is first divided by its trace. For any other normalization, A is used
      as is.
    * Community-based modes (ctrl_mode contains 'ComtyBased'): node scores are
      assembled from a community membership matrix and a cross-community
      matrix.
      - With 'true' in ctrl_mode, these come from graph.nodeMembershipMatrix
        and df['Q'].
      - Otherwise, they are estimated from A by communityRecovery.

    input:
        graph: object exposing .adjacencyMatrix (and .nodeMembershipMatrix
            for the 'true...ComtyBased' modes).
        df: dict-like parameter container. Keys read here:
            'A-normalization' (compared against 'comSizes' and 'Adim'),
            'Q', 'fine community sizes', 'K', 'm', 'n'.
            In the estimated modes it is also passed to communityRecovery.
        ctrl_mode: 'avg', 'modal', 'avg_estComtyBased', 'modal_estComtyBased', 'avg_trueComtyBased', 'modal_trueComtyBased'
        alpha_scale: unused (only referenced by commented-out code).

    output:
        1-D numpy array of per-node controllability values.

    Diagnostic values are printed to stdout throughout.

    NOTE(review): in the community-based path, if df['A-normalization'] is
    neither 'comSizes' nor 'Adim', ctrlout is never assigned and the final
    return raises UnboundLocalError.
    """
    A = graph.adjacencyMatrix

    if('ComtyBased' in ctrl_mode):
        if('true' in ctrl_mode):
            # Ground-truth communities: membership from the graph, cross-community matrix df['Q'].
            membershipMat = graph.nodeMembershipMatrix
#             membershipMat = graph.comNorm_nodeMembershipMatrix
            if(df['A-normalization']=='comSizes'):
                # Divide membership column k by df['fine community sizes'][k].
                membershipMat = graph.nodeMembershipMatrix/((df['fine community sizes'])[np.newaxis,:])
#             colNorm_membershipMat = true_membershipMat/np.sum(true_membershipMat,0)[np.newaxis,:]
            print('true cross com = ', df['Q'])
            # Measure evaluated on the community-level matrix df['Q'] itself
            # (only used below in the 'comSizes' combination).
            ctrlCom = inCtrlEval(df['Q'], ctrl_mode) # *df['n']/df['K']

            print('true community controllability = ', ctrlCom)
            print('true membership matrix = ', membershipMat)
            # Cross-community matrix rescaled so its largest entry is 1.
            max_crossCom = np.max(df['Q'])
            crossCom = df['Q']/max_crossCom
        else:
            # Estimated communities recovered from A.
            # NOTE(review): this rebinds the local `df` only. ctrlEval returns
            # just ctrlout, so the caller sees changes only if communityRecovery
            # mutates df in place.
            df, estimated_membershipMat, estimated_crossCom, estimated_pureIdx = communityRecovery(A, df)
        # ideal
            # max_Est = np.max(df['Q'])
        # estimated from Atilde
            # Toggle: the first branch is disabled, so max_Est is always 1.
            if(False):
                max_Est = np.min([1, np.max(A[estimated_pureIdx[:, None], estimated_pureIdx])])
                print('max A = ', max_Est)
            else:
                max_Est = 1
            # Rescale the estimated cross-community matrix so its largest entry equals max_Est.
            max_crossCom = np.max(estimated_crossCom)
            estimated_crossCom = estimated_crossCom * max_Est / max_crossCom

            if(df['A-normalization']=='comSizes'):
    #             estimated_membershipMat = estimated_membershipMat/np.sqrt(np.sum(estimated_membershipMat,0))[np.newaxis,:]
            #----------------------------------------------
            # ideal
                # estimated_membershipMat = estimated_membershipMat/((df['fine community sizes'])[np.newaxis,:])
            # estimated from Atilde
                # Scale membership column k by A[p, p], where p = estimated_pureIdx[k].
                estimated_membershipMat = estimated_membershipMat * (np.diag(A)[estimated_pureIdx])[np.newaxis,:]
                # estimated_membershipMat = estimated_membershipMat * np.diag(A)[:,np.newaxis]
                # ctrlCom is only computed for 'comSizes'; the 'Adim' combination below does not use it.
                ctrlCom = inCtrlEval(estimated_crossCom, ctrl_mode) # *df['n']/df['K']
    #             estimated_membershipMat = estimated_membershipMat*np.min(np.where(estimated_membershipMat>0,estimated_membershipMat, np.inf),0)[np.newaxis,:]
                print('estimated community controllability = ', ctrlCom)
            #----------------------------------------------
#             else:
#                 max_Est = np.max(A) # A[estimated_pureIdx[:, None], estimated_pureIdx]
#                 estimated_crossCom = estimated_crossCom * max_Est /np.max(estimated_crossCom)

            print('estimated cross com = ', estimated_crossCom)
            print('estimated membership matrix = ', estimated_membershipMat)
            crossCom = estimated_crossCom
            membershipMat = estimated_membershipMat
#         ctrlout = np.matmul(membershipMat, ctrlCom)
        if(df['A-normalization']=='comSizes'):
            # Node score = 1 + membershipMat @ (community score - 1).
            # df['K'] must match the number of membership columns, and df['m']
            # the number of rows, for the shapes to broadcast.
            ctrlout = np.matmul(membershipMat, ctrlCom - np.ones((df['K'],))) + np.ones((df['m'],))
        elif(df['A-normalization']=='Adim'):
            # R = relComSizeHat: diagonal matrix of membership column sums, normalised to sum to 1.
            relComSizeHat = np.sum(membershipMat, 0)[:]
            relComSizeHat = np.diag(relComSizeHat/relComSizeHat.sum())
            print('estimated relComSizeHat = ', relComSizeHat)

        # ideal
            # alphaHat = 1/df['n']
        # estimate
            # alphaHat = 1 #/df['m']
#             vals, _ = sla.eigs(A,k=1)
#             sr = np.real(vals[0])
            # Output scale: 1/df['n'] for ground-truth communities, 1 for estimated ones.
            if('true' in ctrl_mode):
                sr = 1/df['n']
            else:
                sr = 1 # 1/df['n'] # 1/np.trace(A)
            # Normalise C = crossCom so that trace(R @ C) == 1.
            crossCom = crossCom/np.trace(np.matmul(relComSizeHat,crossCom))
            # Node score = 1 + sr * membershipMat @ diag( (C R C) @ inv(I_K - (R C)^2) ).
            ctrlout =  np.matmul(membershipMat, \
                        np.diag(np.matmul(np.matmul(crossCom, np.matmul(relComSizeHat,crossCom)), \
                                           np.linalg.inv(np.eye(df['K'])-\
                                                np.linalg.matrix_power(np.matmul(relComSizeHat,crossCom), 2)  ) )))*sr \
                                                + np.ones((df['m'],)) #* df['n'] /df['r']                   /(df['m']**2)
    else:
        # Plain mode: evaluate the measure directly on (optionally normalised) A.
#         if(alpha_scale):
# #             est_E, _ = maxKeigenDecomposition(A, 1)
# #             A = A/(1+est_E[0])
#             A = A/df['n']
        if(df['A-normalization']=='Adim'):
            # Toggle: the trace normalisation is active; the `else` branch is disabled.
            if(True):
#                 vals, _ = sla.eigs(A,k=1)
#                 sr = np.real(vals[0])
                sr = np.trace(A)
                print('trace adj = ', sr)
                # NOTE(review): if A has a zero diagonal, trace(A) == 0 and this divides by zero.
                A = A/(sr) # - np.eye(A.shape[0])
            else:
                # Disabled alternative: scale by n times the largest cross-community value.
                if(A.shape[0]==df['n']):
                    crossCom = df['Q']
                else:
                    _, _, crossCom, _ = communityRecovery(A, df)
                A = A/(A.shape[0]*np.max(crossCom))

        ctrlout = inCtrlEval(A, ctrl_mode) #*A.shape[0]
#         print('ctrlout.shape=', ctrlout.shape)
#         ctrlout = ctrlout/np.sum(ctrlout)
#         print('ctrlout=', ctrlout)
#     outs.append(ctrlout)
    return ctrlout # outs[0], outs[1]




def fine2coarseCtrl(v, measuring_core, coarse_normalization):
    """Average ``v`` over each index group in ``measuring_core.senseIdx``.

    Parameters
    ----------
    v : numpy.ndarray
        Fine-level values, indexed by the entries of each group.
    measuring_core : object
        Anything with an iterable ``senseIdx`` attribute whose items are index
        collections into ``v``.
    coarse_normalization : str
        Only 'mean' and 'sqrt-mean' are handled. Both compute the plain mean
        of each group.

    Returns
    -------
    numpy.ndarray or list
        One mean per group, or an empty list for any other
        ``coarse_normalization`` value.
    """
    if coarse_normalization not in ('sqrt-mean', 'mean'):
        return []
    groups = measuring_core.senseIdx
    return np.array([v[np.array(idx)].mean() for idx in groups])
