import sys
from collections import deque

import numpy as np

import _AMP_BP as AB

"""
Requires:
numpy, scipy

Use:
python AMP-BP.py alpha lambda rho

rho is the proportion of train nodes ; rho=0 for unsupervised

Output:
alpha, lambda, rho, S-overlap, error, W-overlap, error
"""

# ---- Experiment parameters ------------------------------------------------
# N is the length of the label vector ss (the number of nodes); it is passed
# to AB._F(N, M) and AB.initPrior(N, ...).
N = 10000

# Fail with a usage message instead of an IndexError when arguments are missing.
if len(sys.argv) < 4:
    sys.exit("usage: python {} alpha lambda rho".format(sys.argv[0]))

alpha = float(sys.argv[1])
# `not alpha > 0` also rejects NaN; alpha <= 0 would divide by zero or give M < 0.
if not alpha > 0:
    sys.exit("alpha must be > 0, got {}".format(alpha))
M = int(N/alpha)  # alpha = N/M, so M is the second dimension handed to AB._F(N, M)
if M < 1:
    sys.exit("alpha must be <= N={} so that M >= 1, got {}".format(N, alpha))

# NOTE(review): c looks like the mean degree of the generated graph -- confirm
# against AB.createGraph.
c = 5
_lambda = float(sys.argv[2])
rho = float(sys.argv[3])
# rho is documented as a proportion, so anything outside [0, 1] is meaningless.
if not 0 <= rho <= 1:
    sys.exit("rho must lie in [0, 1], got {}".format(rho))

maxIter = 500       # hard cap on BP/AMP iterations per run
varInit = N**-0.5   # initial variance handed to AB.initPrior
err = 1e-3          # convergence tolerance on the std of the recent overlaps

Nexp = 10           # number of independent runs averaged in the output



# ci/co are handed to AB._cs and AB.createGraph; presumably the within- and
# across-group connectivities.
# TODO(review): co becomes negative when _lambda > sqrt(c) -- confirm whether
# createGraph accepts that.
d = np.sqrt(c)*_lambda
ci, co = c+d, c-d

# One final overlap per run; summarized by mean/std at the end.
overlapS, overlapW = [], []

# ---- Experiments -----------------------------------------------------------
nWindow = 10  # overlaps are averaged over (and tested for stability on) the last nWindow iterations
minIter = 30  # convergence is only tested once t > minIter

for n in range(Nexp):
    # Draw a fresh instance: ws, F, the 0/1 labels ss and the graph built from them.
    ws = AB._ws(M)
    F = AB._F(N, M)

    ss = 1*(np.dot(F, ws)>0)  # label = 1 where F @ ws is positive, else 0
    cs = AB._cs(ci, co)
    L, edgesIn, edgesInT, corrLN = AB.createGraph(ss, ci, co)
    pS = AB._pS(ss, rho)

    logChis, chisT, psis, marginals, a, v, goPrev = AB.initPrior(N, L, M, varInit, F)

    # Bounded windows of the most recent overlaps. They start empty (not
    # zero-filled), so a short run is never averaged together with fake
    # zeros; deque evicts the oldest value once nWindow values are stored.
    oSpr, oWpr = deque(maxlen=nWindow), deque(maxlen=nWindow)

    for t in range(maxIter):
        # One BP update on the graph, then one AMP update fed by chisT[:, 0].
        logChis, chisT, marginals = AB.stepBP(logChis, marginals, psis, corrLN, edgesIn, edgesInT, cs, pS)
        a, v, psis, goPrev = AB.stepAMP(a, v, chisT[:,0], goPrev, F)

        oS, oW = AB.overlapS(marginals, ss, rho), AB.overlapW(a, ws)
        oSpr.append(oS)
        oWpr.append(oW)
        # Stop once both overlaps have been stable (std < err) over the window.
        if t > minIter and np.std(oSpr) < err and np.std(oWpr) < err:
            break
    else:
        # Never broke out: the reported overlaps may not be converged. Warn on
        # stderr so the single result line on stdout stays machine-parseable.
        print("warning: run {} did not converge within maxIter={} iterations".format(n, maxIter), file=sys.stderr)

    overlapS.append(np.mean(oSpr))
    overlapW.append(np.mean(oWpr))

# alpha, lambda, rho, mean/std of S-overlap, mean/std of W-overlap over the Nexp runs.
print("{}, {}, {}, {:.4}, {:.4}, {:.4}, {:.4}".format(alpha, _lambda, rho, np.mean(overlapS), np.std(overlapS), np.mean(overlapW), np.std(overlapW)))
