import os

import numpy as np
import torch
import torch.nn as nn
from torch import optim

from model.Mamba import Mamba

# Experiment configuration.
d = 4            # dimension of each input vector x and weight vector w
dh = 5 * d ** 2  # second argument to Mamba(d, dh, N); presumably a hidden width — confirm in model.Mamba
n = 3000         # number of training tasks sampled per run
n_ = 1000        # number of held-out test tasks sampled per run


def make_vector():
    """Return a (1, d) row vector of i.i.d. standard-normal samples.

    Draws from NumPy's global legacy RNG, so the result depends on, and
    advances, the ``np.random`` state.
    """
    return np.random.normal(loc=0.0, scale=1.0, size=(1, d))


def _sample_tasks(num_tasks, N, device):
    """Sample ``num_tasks`` linear-regression tasks and return (X, Q, Y) tensors.

    Each task draws a weight vector w and N + 1 context inputs x_i, all via
    ``make_vector`` (standard normal). The returned float32 tensors are:

    * X: shape (num_tasks, N + 1, d + 1); each row is [x_i, x_i . w]
    * Q: shape (num_tasks, 1, d); a fresh query input x
    * Y: shape (num_tasks, 1, 1); the query's target x . w

    Per task, the RNG is consumed in the order w, x_0 .. x_N, query.
    """
    contexts, queries, targets = [], [], []
    for _ in range(num_tasks):
        w = make_vector().reshape(d, 1)
        # The first context row plus N more, in the same draw order as a
        # grow-by-one loop. A single concatenate avoids its O(N^2) copying.
        xs = np.concatenate([make_vector()] + [make_vector() for _ in range(N)], axis=0)
        contexts.append(np.concatenate((xs, xs.dot(w)), axis=1))
        query = make_vector()
        queries.append(query)
        targets.append(query.dot(w))

    # Collapse each list into one ndarray before handing it to torch.
    # torch.tensor on a list of ndarrays is very slow and warns about it.
    X = torch.tensor(np.array(contexts), dtype=torch.float32, device=device)
    Q = torch.tensor(np.array(queries), dtype=torch.float32, device=device)
    Y = torch.tensor(np.array(targets), dtype=torch.float32, device=device)
    return X, Q, Y


def generate_data(N, device="cuda"):
    """Generate training and test sets of in-context linear-regression tasks.

    Args:
        N: number of context rows per task beyond the first (each task's
            context has N + 1 rows).
        device: torch device the tensors are placed on (default "cuda",
            matching the original ``.cuda()`` behavior).

    Returns:
        (X, Q, Y, X_, Q_, Y_): the training tensors for ``n`` tasks followed
        by the test tensors for ``n_`` tasks. See ``_sample_tasks`` for the
        shapes.
    """
    X, Q, Y = _sample_tasks(n, N, device)
    X_, Q_, Y_ = _sample_tasks(n_, N, device)
    return X, Q, Y, X_, Q_, Y_


def get_test_loss(N):
    """Train a fresh Mamba model for context length N and return its test MSE.

    Training uses full-batch Adam (lr=0.01) on data from ``generate_data(N)``
    for at most MAX_EPOCHS epochs. Once past MIN_EPOCHS, it stops early when
    the training loss changes by less than TOLERANCE between consecutive
    epochs.

    Returns:
        A 0-dim torch tensor holding the mean-squared error on the held-out
        test tasks, computed without gradient tracking.
    """
    model = Mamba(d, dh, N).cuda()
    X, Q, Y, X_, Q_, Y_ = generate_data(N)
    optimizer = optim.Adam(model.parameters(), lr=0.01)
    loss_fn = nn.MSELoss()  # parameter-free, so it needs no .cuda()
    MIN_EPOCHS = 75
    MAX_EPOCHS = 200
    TOLERANCE = 0.0005

    last_loss = float("inf")
    for epoch in range(MAX_EPOCHS):
        model.train()
        optimizer.zero_grad()
        output = model(X, Q)
        training_loss = loss_fn(output, Y)
        training_loss.backward()
        optimizer.step()
        # Store a plain float rather than the loss tensor, so this epoch's
        # autograd graph is not kept alive into the next iteration.
        current_loss = training_loss.item()
        if epoch > MIN_EPOCHS and abs(last_loss - current_loss) < TOLERANCE:
            break
        last_loss = current_loss

    model.eval()
    # Evaluation needs no gradients; skipping graph construction saves memory.
    with torch.no_grad():
        output = model(X_, Q_)
        test_loss = loss_fn(output, Y_)
    return test_loss


# Sweep the context length N over 4, 6, ..., 80. For each N, train NUM_TRIALS
# independent models and record the mean and standard deviation of their
# halved test MSE. The /2 presumably follows the 1/2 * MSE convention; confirm.
NUM_TRIALS = 10

# Create the output directory up front. Otherwise a missing 'tensor/' folder
# only surfaces as an np.save error after the entire sweep has finished.
os.makedirs('tensor', exist_ok=True)

Ns = np.arange(4, 82, 2)
means = []
stds = []
for N in Ns:
    # Report progress against the largest N actually run (arange excludes 82).
    print(N, '/', Ns[-1])
    trial_losses = [
        get_test_loss(N).cpu().detach().numpy() / 2
        for _ in range(NUM_TRIALS)
    ]
    means.append(np.mean(trial_losses))
    stds.append(np.std(trial_losses))

means = np.array(means)
stds = np.array(stds)
np.save('tensor/means.npy', means)
np.save('tensor/stds.npy', stds)