import sys
from fractions import Fraction

import numpy as np

import _AMP_BP as AB

# Command-line arguments: N alpha mu2 lambda rho
if len(sys.argv) < 6:
    sys.exit("usage: {} N alpha mu2 lambda rho".format(sys.argv[0]))

N = int(sys.argv[1])  # number of nodes
alpha = float(sys.argv[2])  # ratio N/P
# P = floor(N/alpha), computed exactly from the decimal string of alpha.
# int(N/alpha) in floating point can fall one short of an exact integer
# ratio (e.g. 1100/1.1 == 999.9999999999999, so int() would give 999).
P = int(Fraction(N) / Fraction(sys.argv[2]))

d = 5  # passed to AB.mapParameters with _lambda; presumably the mean degree -- TODO confirm

mu2 = float(sys.argv[3])  # its square root is the mu given to AB._B and AB.step
_lambda = float(sys.argv[4])  # mapped to the graph parameters (ci, co) by AB.mapParameters
rho = float(sys.argv[5])  # passed to AB._logpU; (1-rho)*N nodes end up as test nodes

maxIter = 200  # max AMP-BP iterations (also max outer iterations of parameter estimation)
varInit = 10**-3  # passed to AB.initPrior for the initialization -- exact meaning defined there
err = 1e-3  # convergence tolerance on the std of the most recent overlaps / estimates
doParameterEstimation = False  # True: learn ci, co, mu instead of using the true values

Nexp = 10  # number of independent experiments averaged in the printed summary

ci, co = AB.mapParameters(d, _lambda)


def main(u, v, B, edgesIn, edgesInT, corrLN, logpU, isTest):
    """
    Run a single AMP-BP experiment using the true cSBM parameters (ci, co, sqrt(mu2)).

    AMP-BP is iterated from AB.initPrior for at most `maxIter` steps. The run is
    declared converged as soon as, after more than 20 steps, the standard
    deviation of the last 10 overlaps with both the groups and the centroid
    is below `err`.

    u: (N,) array, the 0-1 communities
    v: (P,) array, the centroid
    B: (P,N) array, the observed features
    edgesIn: (L,x) array, as returned by AB.createGraph
    edgesInT: (L,y) array, as returned by AB.createGraph
    corrLN: (L,) array, as returned by AB.createGraph
    logpU: (N,2) array, log-prior on the groups, as returned by AB._logpU
    isTest: ((1-rho)*N,) array, test node indices, as returned by AB._logpU

    Returns (qU, qV, converge):
    qU: float, overlap with the true groups at the last iteration, between 0 and 1
    qV: float, overlap with the true centroid at the last iteration, between 0 and 1
    converge: int, 1 if the convergence criterion was met, 0 otherwise. At snr<1
        the algorithm can fluctuate slightly around qU=0, so 0 may be returned
        even though the run is fine.
    """
    mu = mu2**0.5  # AB.step (like AB._B) is given sqrt(mu2)
    nEdges = len(edgesIn)
    # for a different initialization change here
    chi, uC, vC = AB.initPrior(N, P, nEdges, varInit)
    historyU, historyV = [], []

    for it in range(maxIter):
        chi, uC, vC = AB.step(chi, uC, vC, B, corrLN, edgesIn, edgesInT, mu, ci, co, logpU)
        historyU.append(AB.overlapU(uC, u, isTest))
        historyV.append(AB.overlapV(vC, v))

        settled = it > 20 and np.std(historyU[-10:]) < err and np.std(historyV[-10:]) < err
        if settled:
            return historyU[-1], historyV[-1], 1

    return historyU[-1], historyV[-1], 0


def main_parameterEstimation(u, v, B, edgesIn, edgesInT, corrLN, logpU, isTest):
    """
    Run one cSBM experiment in which the parameters ci, co and mu are learned.

    Each outer iteration restarts AMP-BP from AB.initPrior and runs it with the
    current estimates. The inner run takes at most `maxIter` steps and stops
    early on the same overlap criterion as `main`. The estimates are then
    updated with AB.stepParameterEstimation. The outer loop stops once, after
    more than 5 updates, the standard deviation of the last 5 values of every
    estimate is below `err`, or after `maxIter` outer iterations.

    u: (N,) array, the 0-1 communities
    v: (P,) array, the centroid
    B: (P,N) array, the observed features
    edgesIn: (L,x) array, as returned by AB.createGraph
    edgesInT: (L,y) array, as returned by AB.createGraph
    corrLN: (L,) array, as returned by AB.createGraph
    logpU: (N,2) array, log-prior on the groups, as returned by AB._logpU
    isTest: ((1-rho)*N,) array, test node indices, as returned by AB._logpU

    Returns (qU, qV, ci, co, mu, converge):
    qU: float, overlap with the true groups from the last AMP-BP run, between 0 and 1
    qV: float, overlap with the true centroid from the last AMP-BP run, between 0 and 1
    ci, co, mu: floats, final estimated parameters
    converge: int, 1 if the parameter estimates stabilized, 0 otherwise
    """
    nEdges = len(edgesIn)

    def _runAMPBP(ciCur, coCur, muCur):
        # One AMP-BP run from a fresh initialization with the parameters held
        # fixed; returns the final (chi, uC, vC) and both overlap histories.
        chi, uC, vC = AB.initPrior(N, P, nEdges, varInit)
        histU, histV = [], []
        for it in range(maxIter):
            chi, uC, vC = AB.step(chi, uC, vC, B, corrLN, edgesIn, edgesInT, muCur, ciCur, coCur, logpU)
            histU.append(AB.overlapU(uC, u, isTest))
            histV.append(AB.overlapV(vC, v))
            if it > 20 and np.std(histU[-10:]) < err and np.std(histV[-10:]) < err:
                break
        return chi, uC, vC, histU, histV

    ciEst, coEst, muEst = 2, 1, 1  # starting guess; another initialization may work better
    historyCi, historyCo, historyMu = [], [], []
    converge = 0

    for outer in range(maxIter):
        chi, uC, vC, ovsU, ovsV = _runAMPBP(ciEst, coEst, muEst)

        sigmaV = 1/(1+np.mean(uC**2)*muEst)
        ciEst, coEst, muEst = AB.stepParameterEstimation(chi, uC, vC, sigmaV, B, ciEst, coEst, muEst)

        historyCi.append(ciEst)
        historyCo.append(coEst)
        historyMu.append(muEst)

        # all() short-circuits in the same order as an explicit `and` chain
        if outer > 5 and all(np.std(hist[-5:]) < err for hist in (historyCi, historyCo, historyMu)):
            converge = 1
            break

    return ovsU[-1], ovsV[-1], ciEst, coEst, muEst, converge



overlapsU = []
overlapsV = []
convergences = []
cisEst, cosEst, musEst = [], [], []

for n in range(Nexp):
    # Draw a fresh instance: centroid, communities, features, label prior, graph.
    v = AB._v(P)
    u = AB._u(N)
    B = AB._B(v, u, mu2**0.5)
    logpU, isTest = AB._logpU(u, rho)
    L, edgesIn, edgesInT, corrLN = AB.createGraph(u, ci, co)

    if doParameterEstimation:
        ovU, ovV, ciEst, coEst, muEst, converge = main_parameterEstimation(u, v, B, edgesIn, edgesInT, corrLN, logpU, isTest)
        cisEst.append(ciEst)
        cosEst.append(coEst)
        musEst.append(muEst)
    else:
        ovU, ovV, converge = main(u, v, B, edgesIn, edgesInT, corrLN, logpU, isTest)

    # recorded in both modes
    overlapsU.append(ovU)
    overlapsV.append(ovV)
    convergences.append(converge)

# One CSV-like summary line: parameters, then mean/std of each recorded quantity.
if doParameterEstimation:
    template = "{}, {}, {}, {}, {}, {:.4}, {:.4}, {:.4}, {:.4}, {:.4}, {:.4}, {:.4}, {:.4}, {:.4}, {:.4}, {}"
    fields = (N, alpha, mu2**0.5, _lambda, rho,
              np.mean(overlapsU), np.std(overlapsU),
              np.mean(overlapsV), np.std(overlapsV),
              np.mean(cisEst), np.std(cisEst),
              np.mean(cosEst), np.std(cosEst),
              np.mean(musEst), np.std(musEst),
              np.mean(convergences))
else:
    template = "{}, {}, {}, {}, {}, {:.4}, {:.4}, {:.4}, {:.4}, {}"
    fields = (N, alpha, mu2, _lambda, rho,
              np.mean(overlapsU), np.std(overlapsU),
              np.mean(overlapsV), np.std(overlapsV),
              np.mean(convergences))
print(template.format(*fields))
