import numpy as np
from GraphL_Utils import MapFineToCoarse, PlotCore, SBMDistProp, MMSBMDistProb, plotScatter
from GraphL_Utils import plot_single_regression, relPlot, df_empty, plot_line, vecNormalize
from Community_utils import MMSBMCore, measuringCore, CoarseningCommunityParams, communityRecovery
from Community_utils import error2Vecs, vecNorm
from ctrl_utils import ctrlEval, fine2coarseCtrl
import pandas as pd
from datetime import datetime
import itertools
import sys
# from numba import jit, cuda
# import graph_tool.all as gt
# from graph_tool import spectral
# Console formatting: numpy arrays print with 4 decimals; printed pandas frames are
# truncated to 10 rows x 10 columns.
np.set_printoptions(precision=4)
pd.set_option('display.max_rows', 10)
pd.set_option('display.max_columns', 10)


# TODO: open items for this experiment script (the original note listed none).

# --- Experiment configuration --------------------------------------------------
windows = False  # True: show figures interactively; False: save them to disk
n_array = [5000]  # fine-graph sizes; e.g. [2000, 3000, 4000, 5000, 6000]
K_array = [5]  # number of communities; n should be divisible by K
m_array = [20, 50, 75, 100, 150, 200, 400, 600]  # coarse sizes to sweep; e.g. [100]


xPureSize_array = [1]
regFig = False  # plot per-node normalized controllability after the run
errorFig = True
saveLogFile = True  # redirect stdout to '<fileName>_logFile.log'
errorType = 'L1_minus1_sumNorm'  # alternatives: 'L1_normSize', 'L1_Zscore_normSize', 'L1_relative_normSize'
comStructType_array = ['assortative_symmetric']  # also: 'non_symmetric', 'disassortative_symmetric'
comSizeMode_array = ['high']  # also: 'medium', 'low'

dist_case = 'MMSBM'  # 'MMSBM' or 'SSBM'
fine_sample_size = 20  # fine graphs drawn per configuration
networkGenMode = 'synthetic'  # or 'com-youtube.ungraph', 'email-Eu-core'
genModule = 'networkX'  # or 'graph-tool'
scaling = 1  # or 'row_normalize'
coarsening_flag = True
pruningDegree_array = [0.05]
membership_recoveryMethod_array = ['SPACL_MaoOriginal']  # also: 'SPACL', 'CDMVSI_HuangOriginal'

# Every (pruningDegree, membership_recoveryMethod) pair to evaluate.
hyperparams_array = list(itertools.product(pruningDegree_array, membership_recoveryMethod_array))

# Output file stem. Only the first value of each swept array appears in the name.
# NOTE(review): 'Dirch04' is a fixed tag and does not follow the overlap values swept later — confirm.
fileName = 'Fig_m_coarseOut_n{}_K{}_m{}_Dirch04_Q{}_pureX{}'.format(
    n_array[0], K_array[0], m_array[0], comStructType_array[0], xPureSize_array[0])

if saveLogFile:
    # Line-buffered UTF-8 so the log can be followed while a long run is in progress,
    # and everything up to the last complete line survives if the run is killed.
    sys.stdout = open(fileName + '_logFile' + '.log', 'w', encoding='utf-8', buffering=1)

def _community_sizes(comSizeMode, n, K):
    """Return ``(comSizes, DirichletParam_fine)`` for one community-size regime.

    'high' and 'medium' give explicit community sizes (fixed fractions of ``n``,
    truncated to int) for K in {3, 4, 5}, and an empty list for any other K;
    their DirichletParam_fine is None. 'low' gives no explicit sizes
    (``comSizes`` is None) and DirichletParam_fine = [0.01] * K.

    Raises:
        ValueError: if ``comSizeMode`` is not 'high', 'medium' or 'low'.
    """
    fractions_by_mode = {
        'high': {5: [0.1, 0.15, 0.2, 0.25, 0.3], 4: [0.1, 0.2, 0.3, 0.4], 3: [0.2, 0.3, 0.5]},
        'medium': {5: [0.15, 0.2, 0.2, 0.2, 0.25], 4: [0.2, 0.2, 0.3, 0.3], 3: [0.3, 0.3, 0.4]},
    }
    if comSizeMode == 'low':
        return None, [0.01]*K
    if comSizeMode not in fractions_by_mode:
        raise ValueError('Unknown comSizeMode: {!r}'.format(comSizeMode))
    fractions = fractions_by_mode[comSizeMode].get(K)
    if fractions is None:
        return [], None
    return list(np.array(np.array(fractions)*n, dtype=int)), None


def _connectivity_matrices(comStructType, p, q, K):
    """Return a list of unscaled K x K block matrices for one structure type.

    'assortative_symmetric' puts p on the diagonal and q elsewhere;
    'disassortative_symmetric' is the reverse. 'non_symmetric' uses a
    hand-written layout that exists only for K == 5 and K == 4; other K give
    an empty list, so the configuration is skipped.

    Raises:
        ValueError: for an unknown ``comStructType``.
    """
    if comStructType == 'assortative_symmetric':
        return [q*np.ones((K, K)) + (p-q)*np.eye(K)]
    if comStructType == 'disassortative_symmetric':
        return [p*np.ones((K, K)) + (q-p)*np.eye(K)]
    if comStructType == 'non_symmetric':
        if K == 5:
            return [np.array([[p, q, (p+q)/2, (p-q)/2, p/3],
                              [q, p, (p+q)/2, (p-q)/2, p/4],
                              [(p+q)/2, (p+q)/2, p, q, p/5],
                              [(p-q)/2, (p-q)/2, q, p, p/6],
                              [p/3, p/4, p/5, p/6, p]])]
        if K == 4:
            return [np.array([[p, q, (p+q)/2, (p-q)/2],
                              [q, p, (p+q)/2, (p-q)/2],
                              [(p+q)/2, (p+q)/2, p, q],
                              [(p-q)/2, (p-q)/2, q, p]])]
        return []  # e.g. np.random.uniform(low=q, high=p, size=(K, K))
    raise ValueError('Unknown comStructType: {!r}'.format(comStructType))


def _append_row(df, row):
    """Return ``df`` with the dict ``row`` appended as one new row.

    ``DataFrame.append`` was removed in pandas 2.0; ``pd.concat`` works on both
    old and new pandas. List/array values are stored as single cells.
    """
    return pd.concat([df, pd.DataFrame([row])], ignore_index=True)


def _evaluate_coarsening(fine_graph, df_in, m):
    """Coarsen ``fine_graph`` and store controllability results in ``df_in``.

    For each control mode (currently only 'avg'), this records:
      * the fine controllability (``<mode>_fine``);
      * its projection onto the coarse graph (``<mode>_fineProjected``);
      * the coarse controllability for ``<mode>``, ``<mode>_estComtyBased``
        and ``<mode>_trueComtyBased`` (``<in_mode>_coarse``);
      * each coarse vector's ``errorType`` error against the projection;
      * the same error for a uniform-random baseline vector of length ``m``.

    Raises:
        Exception: 'Negative Ctrlability!' if a projected or coarse
            controllability vector has a negative entry. Errors raised by
            the helper calls propagate unchanged.
    """
    for ctrl_mode in ['avg']:  # , 'modal'
        df_in[ctrl_mode+'_fine'] = ctrlEval(fine_graph, df_in, ctrl_mode, alpha_scale=True)

        measuring_core = measuringCore(df_in, fine_graph)
        coarse_graph = MapFineToCoarse(fine_graph, measuring_core, df_in)
        df_in['coarseMembershipMatrix'] = list(np.max(coarse_graph.nodeMembershipMatrix, 1))
        projected = fine2coarseCtrl(df_in[ctrl_mode+'_fine'], measuring_core, df_in['coarse-normalization'])
        if np.any(projected < 0):
            raise Exception('Negative Ctrlability!')
        df_in[ctrl_mode+'_fineProjected'] = list(projected)

        # NOTE(review): this is densityScaling * n, not a degree derived from Q — confirm intent.
        df_in['avgDegree'] = df_in['densityScaling']*df_in['n']

        for in_ctrl_mode in [ctrl_mode, ctrl_mode+'_estComtyBased', ctrl_mode+'_trueComtyBased']:
            coarse_ctrl = ctrlEval(coarse_graph, df_in, in_ctrl_mode)
            if np.any(coarse_ctrl < 0):
                raise Exception('Negative Ctrlability!')
            df_in[in_ctrl_mode+'_coarse'] = list(coarse_ctrl)
            df_in[in_ctrl_mode+'_fineProj '+errorType+' error'] = \
                error2Vecs(df_in[in_ctrl_mode+'_coarse'], df_in[ctrl_mode+'_fineProjected'], errorType)
        # Baseline: how a random coarse vector scores against the same projection.
        df_in['random '+errorType+' error'] = \
            error2Vecs(np.random.rand(m), df_in[ctrl_mode+'_fineProjected'], errorType)


def _coarsened_rows(fine_graph, df_in, K):
    """Yield ``df_in`` once for each successful coarsening of ``fine_graph``.

    Sweeps ``hyperparams_array``, ``m_array``, the coverage grid and the
    overlap grid. Before each evaluation, the current setting is written into
    ``df_in``. Settings whose evaluation raises are printed and skipped.

    The same dict object is yielded every time and mutated afterwards, so
    the caller must consume each row (e.g. append it) before advancing.
    """
    for pruningDegree, membership_recoveryMethod in hyperparams_array:
        df_in['membership_recoveryMethod'] = membership_recoveryMethod
        df_in['pruningDegree'] = pruningDegree
        for m in m_array:
            df_in['m'] = m
            coverage_array = [4]  # e.g. [4, 6, 8, 10, 12, 16]
            for r in coverage_array:
                df_in['r'] = r
                overlapExtent_array = [0.1]  # e.g. [0.01, 0.1, 1, 10]
                for overlapExtent in overlapExtent_array:
                    df_in['overlap'] = overlapExtent
                    df_in['DirichletParam_coarse'] = [overlapExtent]*K
                    try:
                        _evaluate_coarsening(fine_graph, df_in, m)
                    except Exception as exc:
                        # Best effort: report the failed setting and move on to the next one.
                        print(exc)
                        continue
                    yield df_in


def main(max_consecutive_failures=100):
    """Run the fine-to-coarse controllability sweep and return the results table.

    For every combination of the module-level parameter arrays, draw
    ``fine_sample_size`` fine graphs from an MMSBM generator. Coarsen each
    graph over the hyperparameter, m, coverage and overlap grids, and record
    one row per successful coarsening. The table is written to
    ``fileName + '.csv'`` after every new row and once more at the end.

    Args:
        max_consecutive_failures: how many failed fine-graph attempts in a row
            are tolerated for one configuration before it is abandoned (with a
            printed message). ``None`` retries indefinitely.

    Returns:
        The accumulated results DataFrame.
    """
    df = df_empty(['n', 'K', 'm', 'r', 'scaling', 'densityScaling'],
                  ['int', 'int', 'int', 'int', 'int', 'float'])
    use_fixed_pq = True  # False: shrink p and q like log(n)/n relative to n_array[0]
    for n in n_array:
        for K in K_array:
            for comSizeMode in comSizeMode_array:
                comSizes, DirichletParam_fine = _community_sizes(comSizeMode, n, K)
                if use_fixed_pq:
                    p_array = [0.05]  # e.g. [0.05, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.8, 0.9]
                    q_array = [0.01]
                else:
                    density_factor = (np.log(n)/n) * (n_array[0]/np.log(n_array[0]))
                    p_array = np.multiply([0.1], density_factor)
                    q_array = np.multiply([0.01], density_factor)
                for p, q, comStructType in itertools.product(p_array, q_array, comStructType_array):
                    nonScaledQ_array = _connectivity_matrices(comStructType, p, q, K)
                    for densityScaling in [1]:  # e.g. [1, 2, 4, 6, 8, 10, 15, 20, 30, 50]
                        for nonScaledQ in nonScaledQ_array:
                            Q = nonScaledQ * densityScaling
                            # 'A-normalization': 'Adim' or 'comSizes';
                            # 'coarse-normalization': 'mean' or 'sqrt-mean'.
                            df_in = {'n': n, 'K': K, 'scaling': scaling, 'densityScaling': densityScaling,
                                     'DirichletParam_fine': DirichletParam_fine, 'comSizeMode': comSizeMode,
                                     'Q': Q, 'p': p, 'q': q, 'A-normalization': 'Adim',
                                     'coarse-normalization': 'mean', 'Qmode': comStructType}

                            graphGenMode = MMSBMDistProb(MMSBMCore(df_in))  # alt: SBMDistProp(size=n, df=df_in)
                            fine_sample_counter = 0
                            consecutive_failures = 0
                            # A fine sample counts only when its whole xPureSize sweep finishes.
                            # If an exception escapes, a fresh graph is drawn instead.
                            while fine_sample_counter < fine_sample_size:
                                try:
                                    for xPureSize in xPureSize_array:
                                        df_in['xPureSize'] = xPureSize
                                        fine_graph, df_in = graphGenMode.sample(size=n, df=df_in, comSizes=comSizes)
                                        for row in _coarsened_rows(fine_graph, df_in, K):
                                            df = _append_row(df, row)
                                            print(df)
                                            # Checkpoint after every row so partial results survive a crash.
                                            df.to_csv(fileName+'.csv', index=False)
                                except Exception as exc:
                                    print(exc)
                                    consecutive_failures += 1
                                    # Without a cap, a deterministic failure would retry forever.
                                    if (max_consecutive_failures is not None
                                            and consecutive_failures >= max_consecutive_failures):
                                        print('Giving up on this configuration after {} consecutive failures.'
                                              .format(consecutive_failures))
                                        break
                                    continue
                                consecutive_failures = 0
                                fine_sample_counter += 1

    df.to_csv(fileName+'.csv', index=False)
    return df


# Run the experiment sweep, or (runExperiment = False) reload a previously saved table.
runExperiment = True
if runExperiment:
    df = main()
else:
    # NOTE: list-valued columns are read back from CSV as strings, not lists.
    df = pd.read_csv(fileName+'.csv')  # 'out_wrt_measurementSize_for_p.csv' , 'out_wrt_coverage_for_nu.csv'

# print(df[['n', 'm', 'r', 'K']].join(df.loc[:, df.columns.str.contains('Recovery Error')]))

if regFig:
    for ctrl_mode in ['avg']:  # , 'modal'
        # Column names must follow main()'s scheme: '<in_ctrl_mode>_coarse' for each
        # coarse-controllability variant, plus '<ctrl_mode>_fineProjected'.
        value_cols = [in_ctrl_mode + '_coarse' for in_ctrl_mode in
                      (ctrl_mode, ctrl_mode+'_estComtyBased', ctrl_mode+'_trueComtyBased')]
        value_cols.append(ctrl_mode+'_fineProjected')
        rowID = 0 if ctrl_mode == 'avg' else 1

        # One row per coarse node (1-based id); each value column is normalized via vecNormalize.
        fig_data = {'nodeID': np.arange(df['m'][rowID]) + 1}
        for col in value_cols:
            fig_data[col] = vecNormalize(df[col].iloc[rowID])
        df_fig = pd.DataFrame(data=fig_data)
        df_melt = df_fig.melt(id_vars=['nodeID'], value_vars=value_cols,
                              value_name='normalized controllability', var_name='controllability type')

        x_col = 'nodeID'
        y_col = 'normalized controllability'
        hue_col = 'controllability type'
        title = ctrl_mode + ' controllability'
        figName = fileName + ' ' + title
        plot_core = PlotCore(title='', figName=figName, saveFlag=not windows, showFlag=windows,
                             log_scale=False, minX_axis=0, maxX_axis=1)
        usePlotLine = False
        if usePlotLine:
            plot_line(df_melt, x_col=x_col, target_cols=[y_col], hue_col=hue_col, plot_core=plot_core)
        else:
            plotScatter(df_melt, x_col=x_col, y_col=y_col, hue_col=hue_col, plot_core=plot_core)


if saveLogFile:
    sys.stdout.close()
    # Point stdout back at the real console so later writes don't hit a closed file.
    sys.stdout = sys.__stdout__
