import sys
import numpy as np
from scipy import optimize, special

"""
Computes the optimal population risk for the attention. Prints (m_vv, m_kk, R_vv, R_kk, risk).
nu : float or "inf" ; signal strength.
Lmax : int ; sequence length.
Lrandom : int used as a boolean (0 or 1) ; if nonzero, the sequence lengths are uniformly distributed in [1, Lmax] ; if zero, all sequence lengths are Lmax.
model : in ["spiked", "max"] ; choice of the model between the spiked-SLR or the max-SLR.
sigmaT : in ["lin", "1+erfB", "softmax", "softplusn"] ; choice of the activation function: linear, erf with a constant term and a bias, softmax, or normalized softplus.
"""

# Positional command-line arguments: nu Lmax Lrandom model sigmaT.
# Extra trailing arguments are tolerated, as before.
if len(sys.argv) < 6:
    sys.exit("usage: {} nu Lmax Lrandom model sigmaT".format(sys.argv[0]))

# "inf" (case-insensitive) selects an infinite signal strength.
nu = np.inf if sys.argv[1].lower() == "inf" else float(sys.argv[1])
Lmax = int(sys.argv[2])
# Parsed as an integer; used further down as a boolean flag (0 = fixed length).
Lrandom = int(sys.argv[3])
model = sys.argv[4]
sigmaT = sys.argv[5]

# Explicit checks instead of `assert`: assertions are removed when Python runs
# with -O, which would let invalid settings through silently.
if Lmax < 1:
    raise ValueError("Lmax must be a positive integer, got {}".format(Lmax))
if model not in ("spiked", "max"):
    raise ValueError(
        "model must be 'spiked' or 'max', got {!r}".format(model))
if sigmaT not in ("lin", "1+erfB", "softmax", "softplusn"):
    raise ValueError(
        "sigmaT must be one of 'lin', '1+erfB', 'softmax', 'softplusn', "
        "got {!r}".format(sigmaT))

# Only the "1+erfB" activation carries an extra bias parameter to optimise.
learnBiais = sigmaT == "1+erfB"

# Number of Monte Carlo samples used to estimate the expectations.
nMC = int(1e5)
if nu == np.inf and model == "max":
    # NOTE(review): presumably compensates for the hard-argmax weight being
    # nonzero on only about 1/Lmax of the samples -- confirm.
    nMC *= Lmax

# Four independent standard-normal arrays of shape (Lmax, nMC), drawn in the
# order chis, zs, zetas, xis (randn(a, b) is standard_normal((a, b))).
chis, zs, zetas, xis = (
    np.random.standard_normal((Lmax, nMC)) for _ in range(4)
)

# In the spiked model the first row of chis is shifted by sqrt(nu).
if model == "spiked":
    chis[0, :] += np.sqrt(nu)

def _g0(chis):
    """Return the per-sample weight applied to the squared error.

    Reads the module-level settings ``model`` and ``nu``.

    chis : array whose axis 0 indexes positions in the sequence.

    Returns the scalar 1 for the "spiked" model. For the "max" model, returns
    an array over the remaining axis: the weight of position 0 (a hard argmax
    indicator when nu is infinite, otherwise the softmax of nu*chis along
    axis 0), multiplied by the number of positions ``chis.shape[0]``.
    """
    if model == "max":
        n_positions = chis.shape[0]
        if nu == np.inf:
            # Hard limit: 1 where position 0 holds the maximum, else 0.
            is_first_max = np.argmax(chis, axis=0) == 0
            g0 = n_positions * (1 * is_first_max)
        else:
            g0 = n_positions * special.softmax(nu * chis, axis=0)[0, :]
    elif model == "spiked":
        g0 = 1
    return g0

# Small constant added to the softplus normaliser to avoid dividing by zero.
eps = 2**-30


def _sigma(b, sigmaT, c):
    """Apply the activation selected by ``sigmaT`` to the scores ``b``.

    b : array of scores; axis 0 is the axis that "softmax" and "softplusn"
        normalise over.
    sigmaT : one of "lin", "1+erfB", "softmax", "softplusn".
    c : bias added inside the erf; only used when sigmaT is "1+erfB".

    Returns an array with the same shape as ``b``.
    Raises ValueError if ``sigmaT`` is not one of the supported names
    (previously this fell through to an UnboundLocalError).
    """
    if sigmaT == "lin":
        return b + 1
    if sigmaT == "1+erfB":
        return 1 + special.erf(b + c)
    if sigmaT == "softmax":
        return special.softmax(b, axis=0)
    if sigmaT == "softplusn":
        # logaddexp(0, b) = log(1 + exp(b)) is the softplus function. It is
        # used instead of special.softplus because that function is missing
        # from older SciPy releases (it appears to have been added in 1.15 --
        # confirm).
        softplus = np.logaddexp(0, b)
        return softplus / (eps + np.sum(softplus, axis=0, keepdims=True))
    raise ValueError("unknown activation sigmaT={!r}".format(sigmaT))

def _err(x):
    """Monte Carlo estimate of the weighted squared error at parameters ``x``.

    x : sequence whose first four entries are (mvv, mkk, Rvv, Rkk). When the
        module-level flag ``learnBiais`` is set, the last entry is the bias
        ``c`` passed to the activation; otherwise ``c`` is 0.

    Uses the module-level samples ``zs``, ``zetas``, ``chis``, ``xis`` and the
    settings ``Lrandom``, ``Lmax``, ``sigmaT``.

    Returns the mean over samples of ``_g0 * (zs[0] - prediction)**2``. When
    ``Lrandom`` is truthy this is averaged with equal weight over the prefix
    lengths 1..Lmax; otherwise the full length is used.
    """
    mvv, mkk, Rvv, Rkk = x[:4]
    c = x[-1] if learnBiais else 0

    values = mvv * zs + Rvv * zetas
    scores = mkk * chis + Rkk * xis
    target = zs[0, :]

    def risk_at_length(length):
        # Weighted squared error using only the first `length` positions.
        attention = _sigma(scores[:length, :], sigmaT, c)
        prediction = np.sum(values[:length, :] * attention, axis=0)
        return np.mean(_g0(chis[:length, :]) * (target - prediction) ** 2)

    if Lrandom:
        total = 0
        for length in range(1, Lmax + 1):
            total += risk_at_length(length) / Lmax
        return total

    # Fixed length: the sample arrays have exactly Lmax rows, so this slice is
    # the whole array.
    return risk_at_length(Lmax)

# Starting point (mvv, mkk, Rvv, Rkk), plus an initial bias of 0 when the
# bias is learned.
if learnBiais:
    x = 0, 0, 1, 1, 0
else:
    x = 0, 0, 1, 1

optim = optimize.minimize(_err, x)
if not optim.success:
    # Do not silently ignore the optimizer's status. Report it on stderr so
    # that stdout still contains only the result line.
    print("warning: optimizer did not report success: {}".format(optim.message),
          file=sys.stderr)
err = optim.fun
mvv, mkk, Rvv, Rkk = optim.x[:4]

# Output: mvv, mkk, Rvv, Rkk, risk (comma-separated on one line).
print("{}, {}, {}, {}, {}".format(mvv, mkk, Rvv, Rkk, err))
