import sys
import numpy as np
from scipy import optimize

"""
Computes the risk of the attention at finite sample complexity. Prints (m_v, m_k, q_v, q_k, V_v, V_k, risk).
alpha : float ; sample complexity.
nu : float or "inf" ; signal strength.
L : int ; sequence length.
model : in ["spiked", "max"] ; choice of the model between the spiked-SLR or the max-SLR.
rv : float ; regularization on v.
rk : float ; regularization on k.
sigmaT : in ["lin", "softmax"] ; choice of the activation function between linear and softmax.
initI : bool ; if true starts from an informed initialization ; if false starts from an uninformed initialization.
"""

# Command-line parsing. Validation uses explicit exits rather than `assert`,
# which is stripped when Python runs with -O.
if len(sys.argv) < 9:
    sys.exit("usage: {} alpha nu L model rv rk sigmaT initI".format(sys.argv[0]))

alpha = float(sys.argv[1])
nu = np.inf if sys.argv[2].lower()=="inf" else float(sys.argv[2])
L = int(sys.argv[3])
model = sys.argv[4]
rv, rk = float(sys.argv[5]), float(sys.argv[6])
sigmaT = sys.argv[7]
# Accept 0/1 as before, and also true/false.
_initArg = sys.argv[8].strip().lower()
if _initArg in ("true", "false"):
    initI = int(_initArg == "true")
else:
    initI = int(_initArg)

if sigmaT not in ("lin", "softmax"):
    sys.exit("sigmaT must be 'lin' or 'softmax', got {!r}".format(sigmaT))
if model not in ("spiked", "max"):
    sys.exit("model must be 'spiked' or 'max', got {!r}".format(model))
if L < 1:
    sys.exit("L must be a positive integer, got {}".format(L))

# Number of Monte Carlo samples. With the hard max (model "max", nu = inf) the
# sample weight gs is zero unless the first entry of the sample is the largest,
# so L times more samples are drawn to keep the effective sample size.
nMC = int(1e5)
if model=="max" and nu==np.inf:
    nMC *= L
tMax = 150  # maximum number of accelerated fixed-point iterations
tol = 1e-4  # relative tolerance of the fixed point; tol**2 is the Newton-CG tolerance

def _lin(chi):
    """Linear activation: sigma(chi) = 1 + chi, elementwise."""
    return chi + 1
def _dlin(chi):
    """Jacobian of `_lin`: the (L, L) identity, independent of chi."""
    return np.eye(L)
def _ddlin(chi):
    """Second derivative of `_lin`: an (L, L, L) array of zeros."""
    return np.zeros((L,) * 3)

def _softmax(chi):
    """Softmax of chi, normalized over all entries; shifted by the max for stability."""
    weights = np.exp(chi - np.max(chi))
    return weights / np.sum(weights)
def _dsoftmax(chi):
    """Jacobian of `_softmax`: entry [i, j] = d s_j / d chi_i = s_j (delta_ij - s_i)."""
    s = _softmax(chi)
    s_out = s[np.newaxis, :]
    s_in = s[:, np.newaxis]
    return s_out * (np.identity(L) - s_in)
def _ddsoftmax(chi):
    """Second derivative of `_softmax`: entry [i, j, k] = d^2 s_k / d chi_i d chi_j."""
    s = _softmax(chi)
    eye = np.identity(L)
    s_i = s[:, np.newaxis, np.newaxis]
    s_j = s[np.newaxis, :, np.newaxis]
    s_k = s[np.newaxis, np.newaxis, :]
    # d/dchi_i of s_k (delta_jk - s_j): derivative of s_k ...
    first = s_k * (eye[:, np.newaxis, :] - s_i) * (eye[np.newaxis, :, :] - s_j)
    # ... minus s_k times the derivative of s_j.
    second = s_k * s_j * (eye[:, :, np.newaxis] - s_i)
    return first - second

if sigmaT=="lin":
    sigma, dsigma, ddsigma = _lin, _dlin, _ddlin
elif sigmaT=="softmax":
    sigma, dsigma, ddsigma = _softmax, _dsoftmax, _ddsoftmax

def _updateVK(mvC, mkC, qvC, qkC, VvC, VkC):
    """Map conjugate quantities back to (mv, mk, qv, qk, Vv, Vk).

    For each of v and k, with regularization r (module-level rv or rk):
    V = 1/(r + V_C), m = m_C*V, q = (m_C**2 + q_C)*V**2.
    """
    def _single(mC, qC, VC, reg):
        V = 1/(reg+VC)
        return mC*V, (mC**2+qC)*V**2, V

    mv, qv, Vv = _single(mvC, qvC, VvC, rv)
    mk, qk, Vk = _single(mkC, qkC, VkC, rk)
    return mv, mk, qv, qk, Vv, Vk

def _psi(zchi, omega, gamma, Vv, Vk, y):
    """Per-sample objective minimized in `_proxi` (the negative of psi).

    zchi interleaves the two unknowns: z = zchi[0::2], chi = zchi[1::2].
    Returns (y - z.sigma(chi))**2/2 + |z - omega|**2/(2 Vv) + |chi - gamma|**2/(2 Vk).
    """
    z, chi = zchi[::2], zchi[1::2]
    fit = (y-z@sigma(chi))**2/2
    penalty_z = np.sum((z-omega)**2)/Vv/2
    penalty_chi = np.sum((chi-gamma)**2)/Vk/2
    return fit + penalty_z + penalty_chi

def _gradPsi(zchi, omega, gamma, Vv, Vk, y):
    """Gradient of `_psi` with respect to zchi.

    Returns a length-2L array laid out like zchi (z entries at even indices,
    chi entries at odd indices).
    """
    values, scores = zchi[::2], zchi[1::2]
    weights = sigma(scores)
    residual = y-values@weights

    ascent = np.zeros(2*L)
    ascent[::2] += weights*residual-(values-omega)/Vv
    # dsigma(chi) @ z is the gradient of z . sigma(chi) with respect to chi.
    ascent[1::2] += dsigma(scores)@values*residual-(scores-gamma)/Vk
    return -ascent

def _hessPsi(zchi, omega, gamma, Vv, Vk, y):
    """Hessian of `_psi` with respect to zchi.

    Uses the Jacobian convention of `dsigma`: dsigma(chi)[i, j] = d sigma_j / d chi_i,
    and ddsigma(chi)[i, j, k] = d^2 sigma_k / d chi_i d chi_j.
    Returns a (2L, 2L) array laid out like zchi (z at even, chi at odd indices).
    """
    z = zchi[::2]
    chi = zchi[1::2]
    sigmachi = sigma(chi)
    dsigmachi = dsigma(chi)
    preAct = y-z@sigmachi
    dz = dsigmachi@z  # gradient of z . sigma(chi) with respect to chi

    A = np.zeros((2*L, 2*L))
    A[::2,::2] += -np.outer(sigmachi, sigmachi)
    # d^2/(dz_i dchi_j) contains d sigma_i / d chi_j = dsigmachi[j, i], hence the
    # transpose here; the chi-z block below is its mirror image. (The previous
    # version had the two swapped, which was only correct for symmetric Jacobians.)
    A[::2,1::2] += -np.outer(sigmachi, dz)+dsigmachi.T*preAct
    A[1::2,::2] += -np.outer(dz, sigmachi)+dsigmachi*preAct
    A[1::2,1::2] += -np.outer(dz, dz)+(ddsigma(chi)@z)*preAct
    idx = np.arange(L)
    A[2*idx, 2*idx] += -1/Vv
    A[2*idx+1, 2*idx+1] += -1/Vk
    return -A

def _omegasGammas(mv, mk, qv, qk):
    """Build the Monte Carlo fields (omegas, gammas), each shaped like zetas.

    omegas: sqrt(qv)*zetas, except row 0, which is correlated with the targets
    ys through mv: mv*ys + sqrt(qv - mv**2)*zetas[0].
    gammas: for "spiked", sqrt(qk)*xis with sqrt(nu)*mk added to row 0;
    for "max", mk*chis_ + sqrt(qk - mk**2)*xis.
    """
    omegas = np.sqrt(qv)*zetas
    omegas[0, :] = mv*ys + np.sqrt(qv-mv**2)*zetas[0, :]

    if model == "spiked":
        gammas = np.sqrt(qk)*xis
        gammas[0, :] = gammas[0, :] + np.sqrt(nu)*mk
    elif model == "max":
        gammas = mk*chis_ + np.sqrt(qk-mk**2)*xis
    return omegas, gammas

def _proxi(omegas, gammas, Vv, Vk):
    """Solve the per-sample proximal problem for every Monte Carlo sample.

    For each sample n with gs[n] != 0, minimizes `_psi` over the interleaved
    vector zchi with Newton-CG. The search starts from the previous solution
    stored in the module-level `zchisPr`, which is overwritten with the new one.
    Samples with gs[n] == 0 get all-zero solutions and covariances.

    Returns (zs, chis, covZs, covChis), each of shape (L, nMC): the minimizer
    split into its z and chi parts, and the matching diagonal entries of the
    inverse Hessian of `_psi` at the minimizer.
    """
    zs = np.zeros((L, nMC))
    chis = np.zeros((L, nMC))
    covZs = np.zeros((L, nMC))
    covChis = np.zeros((L, nMC))

    for n in range(nMC):
        if gs[n] == 0:
            best = np.zeros(2*L)
            invHessDiag = np.zeros(2*L)
        else:
            extra = (omegas[:, n], gammas[:, n], Vv, Vk, ys[n])
            best = optimize.minimize(_psi, zchisPr[:, n], args=extra,
                                     method='Newton-CG', jac=_gradPsi, hess=_hessPsi, tol=tol**2).x
            invHessDiag = np.diag(np.linalg.inv(_hessPsi(best, *extra)))
        zchisPr[:, n] = best
        zs[:, n], chis[:, n] = best[::2], best[1::2]
        covZs[:, n], covChis[:, n] = invHessDiag[::2], invHessDiag[1::2]

    return zs, chis, covZs, covChis
    
def _update(x):
    """One application of the fixed-point map iterated by `fixed_point`.

    Takes x in the coordinates decoded by `reparam` and returns
    np.array of |mv|, |mk|, |qv - mv**2|, |qk - mk**2|, |Vv|, |Vk| for the
    updated order parameters. These are estimated by Monte Carlo from the
    proximal solutions of `_proxi`, weighted by the sample weights gs.

    Raises ValueError if the module-level `model` is neither "spiked" nor "max".
    """
    mv, mk, qv, qk, Vv, Vk = reparam(x)
    omegas, gammas = _omegasGammas(mv, mk, qv, qk)
    zs, chis, covZs, covChis = _proxi(omegas, gammas, Vv, Vk)

    # Conjugate quantities (suffix C), mapped back to (m, q, V) by _updateVK.
    mvC = alpha/Vv*np.mean(gs*(ys*zs[0,:]-covZs[0,:]*mv/Vv))
    if model=="spiked":
        mkC = alpha/Vk*(np.sqrt(nu)*np.mean(chis[0,:])-nu*mk)
    elif model=="max":
        mkC = alpha/Vk*np.mean(gs*np.sum(chis_*chis-covChis*mk/Vk, axis=0))
    else:
        # Previously `pass`, which left mkC unbound and failed later with a
        # confusing UnboundLocalError.
        raise ValueError("unknown model {!r}; expected 'spiked' or 'max'".format(model))
    qvC = alpha/Vv**2*np.mean(gs*np.sum((zs-omegas)**2, axis=0))
    qkC = alpha/Vk**2*np.mean(gs*np.sum((chis-gammas)**2, axis=0))
    VvC = alpha*L/Vv-alpha/Vv**2*np.mean(gs*np.sum(covZs, axis=0))
    VkC = alpha*L/Vk-alpha/Vk**2*np.mean(gs*np.sum(covChis, axis=0))

    update_mv, update_mk, update_qv, update_qk, update_Vv, update_Vk = _updateVK(mvC, mkC, qvC, qkC, VvC, VkC)
    # Encode back into the reparam coordinates (excess variances q - m**2).
    x = np.abs(update_mv), np.abs(update_mk), np.abs(update_qv-update_mv**2), np.abs(update_qk-update_mk**2), np.abs(update_Vv), np.abs(update_Vk)
    return np.array(x)
    
def _err(mv, mk, qv, qk, Vv, Vk):
    """Monte Carlo estimate of the risk at the given order parameters.

    Draws the fields (omegas, gammas) for (mv, mk, qv, qk) with `_omegasGammas`.
    It then averages gs[n]*(ys[n] - omegas[:, n] . sigma(gammas[:, n]))**2
    over the nMC samples. Vv and Vk are accepted to match the output of
    `reparam`, but are not used.
    """
    # The fields must be recomputed from the arguments: `omegas`/`gammas` only
    # exist as locals inside `_update`, never at module scope.
    zs, chis = _omegasGammas(mv, mk, qv, qk)
    err = np.mean([gs[n]*(ys[n]-zs[:,n]@sigma(chis[:,n]))**2 for n in range(nMC)])
    return err

def reparam(x):
    """Decode the 6-vector x into (mv, mk, qv, qk, Vv, Vk).

    x stores |mv|, |mk|, the excess variances qv - mv**2 and qk - mk**2, and
    Vv, Vk. Absolute values are taken so that any real x decodes to
    nonnegative overlaps and qv >= mv**2, qk >= mk**2.
    """
    mv, mk, excess_v, excess_k, Vv, Vk = (np.abs(x[i]) for i in range(6))
    qv = excess_v + mv**2
    qk = excess_k + mk**2
    return mv, mk, qv, qk, Vv, Vk

def fixed_point(f, x0, xtol, maxiter):
    """Find a fixed point x = f(x) with Aitken del^2 (Steffensen) acceleration.

    Parameters
    ----------
    f : callable
        Map whose fixed point is sought; takes and returns arrays shaped like x0.
    x0 : array_like
        Starting point.
    xtol : float
        Convergence threshold on the componentwise relative change
        |1 - p0/p|. Where p == 0, the absolute change |p - p0| is used instead.
    maxiter : int
        Maximum number of accelerated steps (two calls to f each).

    Returns
    -------
    p : ndarray
        Last iterate (returned even if not converged).
    err : ndarray
        Componentwise change of the last step (all inf if no step was taken).
    t : int
        Index of the last iteration performed (0 if none).
    """
    p0 = np.asarray(x0, dtype=float)
    p, err, t = p0, np.full(p0.shape, np.inf), 0
    for t in range(maxiter):
        p1 = f(p0)
        p2 = f(p1)
        d = p2-2*p1+p0
        if np.any(d==0):
            # Aitken's formula is undefined; fall back to the plain double step.
            p = p2
        else:
            p = p0-(p1-p0)**2/d
        # Relative change, guarded so that a component that converges to
        # exactly 0 does not produce inf/nan and block convergence.
        nonzero = p != 0
        safe_p = np.where(nonzero, p, 1.0)
        err = np.where(nonzero, np.abs(1-p0/safe_p), np.abs(p-p0))
        if np.all(err<xtol):
            break
        p0 = p
    return p, err, t
    
# Monte Carlo draws shared by every iteration. The draw order is fixed:
# zetas, xis, ys, then chis_ (max model only).
zetas = np.random.normal(0, 1, (L, nMC))   # noise for omegas
xis = np.random.normal(0, 1, (L, nMC))     # noise for gammas
ys = np.random.normal(0, 1, nMC)           # targets y
# Warm-start buffer for the proximal solver; _proxi overwrites it in place.
zchisPr = np.zeros((2*L, nMC))
if model == "spiked":
    gs = np.ones(nMC)
elif model == "max":
    chis_ = np.random.normal(0, 1, (L, nMC))
    if nu == np.inf:
        # Weight L on samples whose first entry is the largest, 0 elsewhere.
        gs = L*(np.argmax(chis_, axis=0) == 0)
    else:
        # L times the softmax (inverse temperature nu) weight of entry 0.
        shifted = chis_ - np.max(chis_, axis=0, keepdims=True)
        boltzmann = np.exp(nu*shifted)
        gs = L*boltzmann[0, :]/np.sum(boltzmann, axis=0)

if initI:
    # Informed start (reparam coordinates): mv = mk = 1, zero excess
    # variance, small Vv = Vk = eps.
    eps = 2**-3
    x = np.array([1., 1, 0, 0, eps, eps])
else:
    # Uninformed start: mv = mk = 0, unit excess variance, Vv = Vk = 1.
    x = np.array([0., 0, 1, 1, 1, 1])

x, diff, t = fixed_point(_update, x, tol, tMax)
if not np.all(diff < tol):
    # fixed_point returns its last iterate even without convergence. Report it
    # on stderr so stdout keeps the single result line.
    print("warning: fixed point not converged after {} iterations (max change {})".format(t+1, np.max(diff)),
          file=sys.stderr)
mv, mk, qv, qk, Vv, Vk = reparam(x)
err = _err(mv, mk, qv, qk, Vv, Vk)

print("{}, {}, {}, {}, {}, {}, {}".format(mv, mk, qv, qk, Vv, Vk, err))
