import sys
import numpy as np

"""
Computes the Bayes-optimal risk at finite sample complexity for the max-SLR at (finite) signal strength nu, by iterating the fixed-point equations for the overlaps (m_v, m_k). Prints (m_v, m_k, risk).
alpha : float ; sample complexity.
nu : float ; signal strength.
L : int ; sequence length.
initI : int ; if nonzero, starts from an informed initialization (m_v = m_k = 1-1e-4) ; if 0, starts from an uninformed initialization (m_v = m_k = 0.1).
"""

# Positional command-line arguments: alpha nu L initI.
alpha, nu = float(sys.argv[1]), float(sys.argv[2])
L, initI = int(sys.argv[3]), int(sys.argv[4])

nMC = 10**6   # number of Monte Carlo samples used for every expectation
tMax = 300    # maximal number of fixed-point iterations

# Standard Gaussian samples drawn once and reused by every call to _update and
# _err. The draw order (zetasV, zetasK, ys) is kept so that a seeded global RNG
# reproduces the same samples.
zetasV = np.random.normal(loc=0, scale=1, size=(L, nMC))
zetasK = np.random.normal(loc=0, scale=1, size=(L, nMC))
ys = np.random.normal(loc=0, scale=1, size=nMC)

def _update(mv, mk):
    """Perform one fixed-point iteration of the overlaps (m_v, m_k).

    Expectations are Monte Carlo averages over the module-level samples
    ``zetasV``, ``zetasK`` (indexed as ``[position, sample]``) and ``ys``
    (indexed as ``[sample]``). The parameters ``alpha`` and ``nu`` are also
    read from module scope.

    mv, mk : float ; current overlaps. ``1 - mv`` is used as a divisor, so
             mv must be < 1.
    Returns the updated pair (mv, mk), each obtained from its conjugate
    variable x through x -> x / (1 + x).
    """
    # Log-ratio of the weight of each position l = 1..L-1 to that of position 0,
    # per Monte Carlo sample. Both updates share this array and it dominates the
    # cost, so it is built once.
    logw = (-nu*mk
            + np.sqrt(nu*mk)*(zetasK[1:, :] - zetasK[0, :])
            - (ys - np.sqrt(mv)*zetasV[1:, :])**2/2/(1 - mv)
            + (ys*np.sqrt(1 - mv) - zetasV[0, :]*np.sqrt(mv))**2/2)
    # 1/denom is the normalized (softmax) weight of position 0 among all L
    # positions. If np.exp overflows to inf, 1/denom -> 0, which is the right
    # limit (numpy emits a RuntimeWarning).
    denom = 1 + np.sum(np.exp(logw), axis=0)

    mvC = alpha*np.mean((np.sqrt(1 - mv)*ys - np.sqrt(mv)*zetasV[0, :])**2/denom)/(1 - mv)
    mkC = alpha*nu*np.mean(1/denom)

    # Map the (non-negative) conjugate variables back to overlaps in [0, 1).
    return mvC/(1 + mvC), mkC/(1 + mkC)

def _err(mv, mk):
    """Return the risk 1 - m_v * E[w_0] for the overlaps (m_v, m_k).

    w_0 is the normalized weight given to position 0 when only the key samples
    ``zetasK`` (module level, indexed ``[position, sample]``) and ``nu``
    (module level) enter. E[.] is a Monte Carlo mean over samples.
    """
    shift = np.sqrt(nu*mk)
    # Log-ratio of the weight of positions 1..L-1 to that of position 0.
    logits = shift*(zetasK[1:, :] - zetasK[0, :] - shift)
    weight0 = 1/(1 + np.exp(logits).sum(axis=0))
    return 1 - mv*np.mean(weight0)


# Informed start: overlaps just below 1. Uninformed start: small overlaps.
if initI:
    mv, mk = 1-1e-4, 1-1e-4
else:
    mv, mk = 0.1, 0.1

# Iterate the fixed-point map. Stop once the largest change in the overlaps
# falls below 1e-8, or once m_v saturates near 1, but only after t > 5
# (i.e. at least 7 iterations).
for t in range(tMax):
    mvN, mkN = _update(mv, mk)

    diff = max(abs(mv - mvN), abs(mk - mkN))
    mv, mk = mvN, mkN

    if t > 5 and (diff < 1e-8 or mv > 1 - 1e-6):
        break

# Output: m_v, m_k, risk. _err's signature is _err(mv, mk), so pass in that order.
print("{}, {}, {}".format(mv, mk, _err(mv, mk)))
