import os

import numpy as np
import torch
import torch.nn as nn
from torch import optim

from model.Mamba import Mamba

# Problem sizes.
d = 4             # length of every vector drawn by make_vector
dh = 5 * d ** 2   # second constructor argument to Mamba; presumably a hidden width -- TODO confirm
N = 50            # examples appended after the first one, so each sequence has N + 1 rows
n = 3000          # number of training sequences
n_ = 100          # number of test sequences

# Accumulators for the sampled data: prompts (X), queries (Q) and targets (Y).
# A trailing underscore marks the test split.
X, X_ = [], []
Q, Q_ = [], []
Y, Y_ = [], []


def make_vector():
    """Return ``d`` standard-normal samples laid out as a (1, d) row vector."""
    samples = np.random.normal(0, 1, size=d)
    return samples[np.newaxis, :]


def _sample_split(num_sequences, prompts, queries, targets):
    """Append ``num_sequences`` freshly sampled linear tasks to the given lists.

    For every task a weight vector ``w`` of shape (d, 1) is drawn. Then:

    * ``prompts`` receives N + 1 rows ``[x, x @ w]``, shape (N + 1, d + 1);
    * ``queries`` receives one more ``x``, shape (1, d);
    * ``targets`` receives its label ``x @ w``, shape (1, 1).

    Random numbers are drawn in the same order as the original inline loops
    (w, the N + 1 prompt inputs, then the query).
    """
    for _ in range(num_sequences):
        w = make_vector().reshape(d, 1)
        rows = []
        for _ in range(N + 1):
            x = make_vector()
            rows.append(np.concatenate((x, x.dot(w)), axis=1))
        # A single concatenate per task instead of re-copying a growing array.
        prompts.append(np.concatenate(rows, axis=0))
        x_query = make_vector()
        queries.append(x_query)
        targets.append(x_query.dot(w))


def _to_cuda(arrays):
    """Stack a list of same-shaped numpy arrays into one float32 CUDA tensor."""
    # Building one ndarray first avoids torch's slow per-element conversion
    # of a Python list of ndarrays.
    return torch.tensor(np.stack(arrays), dtype=torch.float32).cuda()


# Training split: X (n, N + 1, d + 1), Q (n, 1, d), Y (n, 1, 1).
_sample_split(n, X, Q, Y)
X = _to_cuda(X)
Q = _to_cuda(Q)
Y = _to_cuda(Y)

# Test split, same layout with n_ sequences.
_sample_split(n_, X_, Q_, Y_)
X_ = _to_cuda(X_)
Q_ = _to_cuda(Q_)
Y_ = _to_cuda(Y_)


model = Mamba(d, dh, N).cuda()

optimizer = optim.Adam(model.parameters(), lr=0.01)
# MSELoss holds no parameters or buffers, so there is nothing to move to the GPU.
loss_fn = nn.MSELoss()
EPOCHS = 100

# Per-epoch training (loss) and test (loss_) MSE histories.
loss = []
loss_ = []

# Snapshot the WB / WC weights before any training step.
# Create the output directory up front: otherwise np.save fails here, only
# after all of the data generation above has already run.
os.makedirs('tensor', exist_ok=True)
B0 = model.WB.weight.detach().cpu().numpy()
C0 = model.WC.weight.detach().cpu().numpy()
np.save('tensor/B0.npy', B0)
np.save('tensor/C0.npy', C0)

for epoch in range(1, EPOCHS + 1):
    if epoch % 10 == 0:
        print(epoch)

    # One full-batch gradient step over the whole training split.
    model.train()
    optimizer.zero_grad()
    output = model(X, Q)
    training_loss = loss_fn(output, Y)
    training_loss.backward()
    loss.append(training_loss.item())
    optimizer.step()

    # Evaluate on the test split. no_grad prevents building (and holding on
    # to) an autograd graph for a forward pass that is never backpropagated.
    model.eval()
    with torch.no_grad():
        output = model(X_, Q_)
        test_loss = loss_fn(output, Y_)
    loss_.append(test_loss.item())


# Persist the trained weights; make sure the target directories exist first.
os.makedirs('model', exist_ok=True)
torch.save(model.state_dict(), 'model/mamba.pth')

# Save the trained WB / WC weights and the WB bias.
os.makedirs('tensor', exist_ok=True)
B = model.WB.weight.detach().cpu().numpy()
C = model.WC.weight.detach().cpu().numpy()
bias_B = model.WB.bias.detach().cpu().numpy()
np.save('tensor/B.npy', B)
np.save('tensor/C.npy', C)
np.save('tensor/bias_B.npy', bias_B)

