import sys
import numpy as np
import torch
from torch import nn, optim

"""
Simulate the attention model at finite N and D: train it by minimizing the regularized squared loss with L-BFGS, then compute its train and test risks (mean squared errors).
Usage: python <script> D alpha nu L Lrandom model rV rK sigmaT initI
Prints one comma-separated line: m_vv, m_vk, m_kv, m_kk, q_vv, q_vk, q_kk, V_vv, V_vk, V_kk, train risk, test risk (the V_* entries are -1 unless computeV is set to True).
D : int ; dimension of the embeddings.
alpha : float ; sample complexity ; the number of train (and test) samples is N = int(alpha*D).
nu : float or "inf" ; signal strength.
L : int ; (maximal) sequence length.
Lrandom : int (0 or 1) ; if 1 the sequence lengths are drawn uniformly in [1, L] ; if 0 all sequences have length L.
model : in ["spiked", "max"] ; choice of the model between the spiked-SLR and the max-SLR.
rV : float ; regularization on v.
rK : float ; regularization on k.
sigmaT : in ["lin", "softmax"] ; choice of the activation function between linear and softmax.
initI : int (0 or 1) ; if 1 starts from the informed initialization (the teacher's v and k) ; if 0 starts from a random Gaussian initialization.
"""

# Command-line configuration; argument order:
#   D alpha nu L Lrandom model rV rK sigmaT initI
if len(sys.argv) < 11:
    sys.exit("usage: {} D alpha nu L Lrandom model rV rK sigmaT initI".format(sys.argv[0]))
D = int(sys.argv[1])
alpha = float(sys.argv[2])
N = int(alpha*D)  # number of training samples (the test set has the same size)
nu = torch.inf if sys.argv[3].lower()=="inf" else float(sys.argv[3])
L = int(sys.argv[4])
Lrandom = int(sys.argv[5])  # used as a boolean flag
model = sys.argv[6]
rV = float(sys.argv[7])
rK = float(sys.argv[8])
sigmaT = sys.argv[9]
initI = int(sys.argv[10])  # used as a boolean flag
# Explicit checks instead of `assert`, which is removed under `python -O`
# (an unchecked sigmaT would otherwise surface later as a NameError on `sigma`).
if model not in ("spiked", "max"):
    raise ValueError("model must be 'spiked' or 'max', got {!r}".format(model))
if sigmaT not in ("lin", "softmax"):
    raise ValueError("sigmaT must be 'lin' or 'softmax', got {!r}".format(sigmaT))

computeV = False  # whether to compute V (inverse-Hessian traces); slow if D > 3000
tMax = 1000  # maximum number of outer L-BFGS steps

# Attention nonlinearity `sigma` applied to the scores chi (shape (N, L)), and
# `neutre`, the score written at masked positions so that they get zero weight:
#   - softmax: a score of -inf gets probability 0;
#   - lin:     the weight is 1 + chi, which is 0 at chi = -1.
if sigmaT == "softmax":
    def sigma(chi):
        return torch.softmax(chi, dim=1)
    neutre = -torch.inf
elif sigmaT == "lin":
    def sigma(chi):
        return 1 + chi
    neutre = -1

def instance_spiked():
    """Draw a teacher (v, k) and train/test datasets for the spiked-SLR model.

    Every token is standard Gaussian, except that the first token of each
    sequence is shifted by ``sqrt(nu / D) * k``. The label of a sequence is the
    projection of that first token on ``v``, scaled by ``1 / sqrt(D)``.

    Reads the module-level settings ``D``, ``N``, ``L``, ``nu`` and ``Lrandom``.

    Returns:
        (v, k, X, X_test, mask, mask_test, y, y_test): X and X_test have shape
        (N, L, D); mask and mask_test have shape (N, L), with True marking the
        positions beyond a sequence's length; y and y_test have shape (N,).
    """
    # NOTE(review): nu = inf makes the spike infinite (sqrt(inf)); only a
    # finite nu seems meaningful for this model.
    spike_scale = np.sqrt(nu / D)
    norm = np.sqrt(D)

    v = torch.randn(D)
    k = torch.randn(D)
    spike = k[None, :] * spike_scale

    X = torch.randn((N, L, D))
    X[:, 0, :] += spike
    mask = torch.zeros(N, L, dtype=bool)

    X_test = torch.randn((N, L, D))
    X_test[:, 0, :] += spike
    mask_test = torch.zeros(N, L, dtype=bool)

    if Lrandom:
        # Row `row` keeps its first 1 + randint(L) tokens: a length uniform in [1, L].
        for row in range(N):
            mask[row, 1 + int(torch.randint(L, (1,))):] = True
            mask_test[row, 1 + int(torch.randint(L, (1,))):] = True

    y = X[:, 0, :] @ v / norm
    y_test = X_test[:, 0, :] @ v / norm
    return v, k, X, X_test, mask, mask_test, y, y_test

def instance_max():
    """Draw a teacher (v, k) and train/test datasets for the max-SLR model.

    Tokens are i.i.d. standard Gaussian. Each sequence is labelled by the
    projection on ``v`` (scaled by ``1 / sqrt(D)``) of one selected token:
    the token maximizing ``chi = X @ k / sqrt(D)`` when ``nu`` is infinite,
    otherwise a token sampled from ``softmax(nu * chi)``. Masked positions
    (beyond a sequence's length) are never selected.

    Reads the module-level settings ``D``, ``N``, ``L``, ``nu`` and ``Lrandom``.

    Returns:
        (v, k, X, X_test, mask, mask_test, y, y_test): X and X_test have shape
        (N, L, D); mask and mask_test have shape (N, L), with True marking the
        positions beyond a sequence's length; y and y_test have shape (N,).
    """
    v, k = torch.randn(D), torch.randn(D)
    X = torch.randn((N, L, D))
    mask = torch.zeros(N, L, dtype=bool)
    X_test = torch.randn((N, L, D))
    mask_test = torch.zeros(N, L, dtype=bool)
    if Lrandom:
        # Row n keeps its first 1 + randint(L) tokens: a length uniform in [1, L].
        for n in range(N):
            mask[n, 1 + torch.randint(L, (1,)).item():] = True
            mask_test[n, 1 + torch.randint(L, (1,)).item():] = True

    def _select(X_, mask_):
        """Return the index of the selected token of each sequence of X_."""
        chi = X_ @ k / np.sqrt(D)
        if nu == torch.inf:
            return torch.argmax(chi.masked_fill(mask_, -torch.inf), dim=1)
        # Mask *after* scaling by nu: nu * (-inf) would be nan for nu == 0.
        logits = (nu * chi).masked_fill(mask_, -torch.inf)
        return torch.multinomial(torch.softmax(logits, dim=1), 1)[:, 0]

    # Each set is selected with its own mask (train: mask, test: mask_test).
    epsilon = _select(X, mask)
    epsilon_test = _select(X_test, mask_test)
    rows = torch.arange(N)
    y = X[rows, epsilon, :] @ v / np.sqrt(D)
    y_test = X_test[rows, epsilon_test, :] @ v / np.sqrt(D)
    return v, k, X, X_test, mask, mask_test, y, y_test
    
class ERM(nn.Module):
    """Student attention model fitted by regularized empirical risk minimization.

    For inputs X of shape (batch, L, D) the prediction is
        h = sum_l sigma(chi_l) * (X_l . v / sqrt(D)),   chi = X . k / sqrt(D),
    where the scores at positions flagged in ``mask`` are replaced by the
    module-level ``neutre`` (chosen together with ``sigma`` so that those
    positions get zero weight).
    """

    def __init__(self, rV, rK, v=None, k=None):
        """Create the trainable vectors v and k and store the ridge strengths.

        When the module-level flag ``initI`` is set, v and k start as copies of
        the given teacher vectors; otherwise they are drawn from N(0, I_D).
        """
        super().__init__()
        if initI:
            v_start, k_start = v.clone(), k.clone()
        else:
            v_start = torch.randn(D)
            k_start = torch.randn(D)
        self.v = nn.Parameter(v_start)
        self.k = nn.Parameter(k_start)
        self.rV = rV
        self.rK = rK

    def forward(self, X, mask):
        """Return the prediction h (one scalar per sequence) for inputs X."""
        scale = np.sqrt(D)
        scores = (X @ self.k / scale).masked_fill(mask, neutre)
        values = X @ self.v / scale
        return (sigma(scores) * values).sum(dim=1)

    def loss(self, h, y):
        """Half sum of squared residuals plus ridge penalties rV|v|^2/2 + rK|k|^2/2."""
        fit = torch.sum((y - h) ** 2 / 2)
        ridge_v = self.rV * torch.sum(self.v ** 2) / 2
        ridge_k = self.rK * torch.sum(self.k ** 2) / 2
        return fit + ridge_v + ridge_k

def _err(h, y):
    with torch.no_grad():
        return torch.mean((y-h)**2).item()


# ---- Draw the teacher and the train/test data ------------------------------
if model == "spiked":
    v, k, X, X_test, mask, mask_test, y, y_test = instance_spiked()
elif model == "max":
    v, k, X, X_test, mask, mask_test, y, y_test = instance_max()
else:
    # Explicit guard: the `assert` on `model` is skipped under `python -O`.
    raise ValueError("model must be 'spiked' or 'max', got {!r}".format(model))

# ---- Train the student by ERM with L-BFGS ----------------------------------
# NOTE: `model` is rebound here from the model-name string to the ERM module.
model = ERM(rV, rK, v, k)
optimiseur = optim.LBFGS(model.parameters(), lr=1, max_iter=20, history_size=10)
losses = []


def closure():
    """Recompute the regularized training loss and its gradients (for L-BFGS)."""
    optimiseur.zero_grad()
    h = model(X, mask)
    loss = model.loss(h, y)
    loss.backward()
    return loss


for n in range(tMax):
    optimiseur.step(closure)

    with torch.no_grad():
        h = model(X, mask)
        losses.append(model.loss(h, y).item())
    # Stop once the last 10 losses have plateaued (relative spread < 1e-8).
    # Written multiplicatively so that a loss of exactly 0 does not give
    # 0/0 = nan, which never passes the test and would run all tMax steps.
    if n > 10:
        recent = losses[-10:]
        if np.std(recent) <= 1e-8 * np.abs(np.mean(recent)):
            break

# ---- Risks and overlaps ----------------------------------------------------
with torch.no_grad():
    h = model(X, mask)
    loss = model.loss(h, y)
    err = _err(h, y)
    h_test = model(X_test, mask_test)
    err_test = _err(h_test, y_test)
    # Overlaps with the teacher (m) and self-overlaps (q), as Python floats so
    # that they print as numbers rather than as "tensor(...)".
    mVV = torch.mean(v*model.v).item()
    mVK = torch.mean(v*model.k).item()
    mKV = torch.mean(k*model.v).item()
    mKK = torch.mean(k*model.k).item()
    qVV = torch.mean(model.v*model.v).item()
    qVK = torch.mean(model.v*model.k).item()
    qKK = torch.mean(model.k*model.k).item()

if computeV:
    def _loss(vk):
        """Training objective as a function of the concatenated vector (v, k).

        Applies the same masking as ERM.forward, so the Hessian is taken of the
        objective that was actually minimized.
        """
        v_, k_ = vk[:D], vk[D:]
        chi = (X@k_/np.sqrt(D)).masked_fill(mask, neutre)
        h_ = torch.sum(sigma(chi)*(X@v_/np.sqrt(D)), dim=1)
        return torch.sum((y-h_)**2/2)+rV*torch.sum(v_**2)/2+rK*torch.sum(k_**2)/2
    hess = torch.autograd.functional.hessian(_loss, inputs=torch.cat((model.v, model.k)).detach())
    hessInv = torch.linalg.inv(hess)
    # Normalized traces of the (v,v), (k,v) and (k,k) blocks of the inverse Hessian.
    vVV = torch.trace(hessInv[:D,:D]).item()/D
    vVK = torch.trace(hessInv[D:,:D]).item()/D
    vKK = torch.trace(hessInv[D:,D:]).item()/D
else:
    vVV, vVK, vKK = -1, -1, -1

print("{}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}".format(mVV, mVK, mKV, mKK, qVV, qVK, qKK,
                                                              vVV, vVK, vKK, err, err_test))
