import sys
import numpy as np
from scipy import special, integrate

"""
Computes the Bayes-optimal risk at finite sample complexity for the max-SLR at nu=+∞. Prints "m_v, m_k, risk".
Command-line arguments, in this order:
alpha : float ; sample complexity.
L : int ; sequence length.
initI : int (0 or 1) ; if non-zero, starts from an informed initialization ; if zero, starts from an uninformed initialization.
"""

# Positional command-line parameters: sample complexity, sequence length and
# initialization flag (non-zero selects the informed initialization).
alpha, L, initI = float(sys.argv[1]), int(sys.argv[2]), int(sys.argv[3])

# Number of Monte-Carlo samples.
nMC = 10000
# Tolerance shared by the quadratures and by the stopping test of the iteration.
tol = 1e-5
# Maximum number of fixed-point iterations.
tMax = 150

# Standard Gaussian samples, drawn once and reused by every Monte-Carlo average.
xis = np.random.randn(nMC, L)
zetas = np.random.randn(nMC, L)
ys = np.random.randn(nMC)

def logErf(xs):
    """
    Computes log(1/2+erf(x)/2) elementwise, accurately even for very negative x.

    xs : float or array of floats ; point(s) at which the function is evaluated.
    Returns the logarithm(s), with the same shape as xs.

    1/2+erf(x)/2 is the standard normal CDF evaluated at sqrt(2)*x, so the work is
    delegated to scipy.special.log_ndtr. That routine does not underflow in the left
    tail and does not rely on a leading-order asymptotic formula, which would be off
    by a few percent (in the log) around x=-4.
    """
    return special.log_ndtr(np.sqrt(2)*np.asarray(xs, dtype=float))
    
def _h(e, gamma, R):
    """
    Evaluates, by numerical quadrature over chi in [gamma[e]-10*sqrt(R), gamma[e]+10*sqrt(R)],

        L/sqrt(2*pi*R) * integral of exp(-(chi-gamma[e])**2/(2R)) * prod_{j != e} exp(logErf((chi-gamma[j])/sqrt(2R))).

    e : int ; index singled out in gamma.
    gamma : 1-d array ; one entry per index.
    R : float ; variance-like parameter (must be positive for the formula to make sense).
    Uses the module-level tolerance tol and sequence length L.
    """
    others = list(range(len(gamma)))
    others.remove(e)
    center = gamma[e]
    halfWidth = 10*np.sqrt(R)

    def integrand(chi):
        # Sum of logs instead of a product of (possibly tiny) factors.
        logFactors = logErf((chi-gamma[others])/np.sqrt(2*R))
        return np.exp(-(chi-center)**2/2/R + np.sum(logFactors))

    integral = integrate.quad(integrand, center-halfWidth, center+halfWidth,
                              epsabs=tol, epsrel=tol)[0]
    return integral*L/np.sqrt(2*np.pi*R)

def _df_h(e, f, gamma, R):
    """
    Quadrature companion of _h ; from the integrands it appears to be the partial
    derivative of _h(e, gamma, R) with respect to gamma[f] (integration bounds held fixed).

    e, f : int ; indices in gamma.
    gamma : 1-d array ; one entry per index.
    R : float ; variance-like parameter.

    If e==f, the integrand of _h is multiplied by (chi-gamma[e])/R.
    Otherwise the logErf factor of index f is replaced by a Gaussian factor
    exp(-(chi-gamma[f])**2/(2R)) and the result carries the prefactor -L/(2*pi*R).
    Uses the module-level tolerance tol and sequence length L.
    """
    others = list(range(len(gamma)))
    others.remove(e)
    lower = gamma[e]-10*np.sqrt(R)
    upper = gamma[e]+10*np.sqrt(R)

    if e != f:
        others.remove(f)

        def crossIntegrand(chi):
            logFactors = logErf((chi-gamma[others])/np.sqrt(2*R))
            return np.exp(-(chi-gamma[e])**2/2/R-(chi-gamma[f])**2/2/R + np.sum(logFactors))

        integral = integrate.quad(crossIntegrand, lower, upper, epsabs=tol, epsrel=tol)[0]
        return -integral*L/(2*np.pi*R)

    def selfIntegrand(chi):
        logFactors = logErf((chi-gamma[others])/np.sqrt(2*R))
        return (chi-gamma[e])/R*np.exp(-(chi-gamma[e])**2/2/R + np.sum(logFactors))

    integral = integrate.quad(selfIntegrand, lower, upper, epsabs=tol, epsrel=tol)[0]
    return integral*L/np.sqrt(2*np.pi*R)

def _mC(mk, mv):
    """
    Monte-Carlo estimate of the pair (mkC, mvC) consumed by the fixed-point iteration.

    mk, mv : float ; current values of the two order parameters.
    Returns (alpha*mean of the mk integrand, alpha*mean of the mv integrand), the means
    being taken over the nMC pre-drawn samples (xis, zetas, ys).

    Indices 0 and 1 of each sample are accessed explicitly, so L must be at least 2.
    """
    R = 1-mk
    V = 1-mv
    integrMkC, integrMvC = np.zeros(nMC), np.zeros(nMC)

    for n in range(nMC):
        gamma = np.sqrt(mk)*xis[n,:]
        omega = np.sqrt(mv)*zetas[n,:]
        hs = np.array([_h(l, gamma, R) for l in range(L)])
        df_h0, df_h1 = _df_h(0, 0, gamma, R), _df_h(1, 0, gamma, R)

        # Gaussian weights, evaluated once per sample for every index:
        # wSelf[l] weights the term where index l is singled out, wOther[l] the remaining terms.
        resid = np.sqrt(1-mv)*ys[n]-omega
        wSelf = np.exp(-resid**2/2)
        wOther = np.exp(-(ys[n]-omega)**2/2/V)

        # Denominators ; d0 is shared by the mk and the mv integrands.
        d0 = hs[0]*wSelf[0]
        d0 += np.sum(hs[1:]*wOther[1:])
        d1 = hs[1]*wSelf[1]
        d1 += hs[0]*wOther[0]+np.sum(hs[2:]*wOther[2:])

        n0 = df_h0**2*wSelf[0]
        n0 += 2*(L-1)*df_h0*df_h1*wOther[1]
        n1 = (L-1)*df_h1**2*wSelf[1]
        if L>2:
            # This term needs a third index, hence the guard.
            n1 += (L**2-3*L+2)*df_h1*_df_h(2, 0, gamma, R)*wOther[2]
        integrMkC[n] = n0/d0+n1/d1

        # hs[0] already holds _h(0, gamma, R) ; reuse it rather than running the quadrature again.
        integrMvC[n] = hs[0]*wSelf[0]*resid[0]**2/V/d0

    return alpha*np.mean(integrMkC), alpha*np.mean(integrMvC)

def _err(mk, mv):
    """
    Returns 1 - mv*mean(_h(0, gamma, R)**2)/L, with R=1-mk and gamma=sqrt(mk)*xis[n,:],
    the mean being taken over the nMC pre-drawn samples xis.

    mk, mv : float ; values of the two order parameters at which the quantity is evaluated.
    """
    R = 1-mk
    squares = np.array([_h(0, np.sqrt(mk)*xis[n,:], R)**2 for n in range(nMC)])
    return 1-mv*np.mean(squares)/L

# Starting point of the iteration: both order parameters close to 1 when the
# informed initialization is requested, both small otherwise.
mv, mk = (1-1e-2, 1-1e-2) if initI else (0.1, 0.1)

for t in range(tMax):
    mkHat, mvHat = _mC(mk, mv)
    # Map x -> x/(1+x), then cap each value strictly below 1.
    mkNext = min(mkHat/(1+mkHat), 1-2e-2)
    mvNext = min(mvHat/(1+mvHat), 1-1e-6)

    # Largest change of either order parameter during this step.
    diff = max(abs(mv-mvNext), abs(mk-mkNext))
    mv, mk = mvNext, mkNext

    # Always run at least 7 steps before accepting convergence.
    if diff<tol and t>5:
        break

print("{}, {}, {}".format(mv, mk, _err(mk, mv)))
