import numpy as np
import torch
from matplotlib import pyplot as plt
from torch import nn

# Problem sizes shared by the model and the data generator in this script.
d, N, n = 4, 50, 1000  # vector dimension, recurrence steps, number of sequences
dh = 5 * d * d         # width of the hidden state

# Pre-computed parameters read from disk. The Mamba class in this file installs
# B and bias_B as the weight/bias of its input projection and C (transposed) as
# the weight of its output projection.
B = np.load('tensor/B.npy')
C = np.load('tensor/C.npy')
bias_B = np.load('tensor/bias_B.npy')


class Mamba(nn.Module):
    """Fixed-weight linear recurrence scored by cosine similarity.

    The module holds two linear maps, ``WB`` (``d + 1 -> dh``) and ``WCT``
    (``dh -> d``), whose weights are installed from pre-computed arrays rather
    than learned. ``forward`` runs ``N`` steps of

        h <- a * h + y_i * (1 - a) * WB(X[:, i]),    a = exp(-log(2) / N)

    where ``y_i`` is the last feature of step ``i``, and after every step
    compares ``WCT(h)`` with a reference vector.

    Args:
        d: Dimension of the input vectors (each step has ``d + 1`` features).
        dh: Width of the hidden state.
        N: Number of recurrence steps performed by ``forward``.
        wb: Optional array-like for ``WB.weight``; defaults to the
            module-level ``B``.
        wb_bias: Optional array-like for ``WB.bias``; defaults to the
            module-level ``bias_B``.
        wc: Optional array-like whose transpose becomes ``WCT.weight``;
            defaults to the module-level ``C``.
    """

    def __init__(self, d, dh, N, wb=None, wb_bias=None, wc=None):
        super().__init__()
        self.d = d
        self.dh = dh
        self.N = N
        self.WB = nn.Linear(d + 1, dh)
        self.WCT = nn.Linear(dh, d)

        # Fall back to the arrays loaded at module level when no explicit
        # weights are supplied (the original behaviour).
        wb = B if wb is None else wb
        wb_bias = bias_B if wb_bias is None else wb_bias
        wc = C if wc is None else wc

        # Cast to float32 so that weights stored as float64 on disk do not
        # clash with float32 inputs inside nn.Linear.
        self.WB.weight.data = torch.tensor(np.asarray(wb, dtype=np.float32))
        self.WB.bias.data = torch.tensor(np.asarray(wb_bias, dtype=np.float32))
        self.WCT.weight.data = torch.tensor(np.asarray(wc, dtype=np.float32).T)
        nn.init.zeros_(self.WCT.bias)

    def forward(self, X, W):
        """Run the recurrence and report per-step cosine similarity.

        Args:
            X: Tensor indexed as ``X[batch, step, feature]`` with at least
                ``N`` steps and ``d + 1`` features; feature ``d`` is the
                scalar that scales the update at each step.
            W: Tensor of shape ``(batch, d, 1)`` holding the reference vector
                of every batch element.

        Returns:
            A pair ``(mean, std)`` of NumPy arrays of length ``N``: the batch
            mean and standard deviation of the cosine similarity between
            ``WCT(h)`` and ``W`` after each step.
        """
        batch_size = X.shape[0]
        # Follow the input's device instead of assuming CUDA.
        device = X.device

        sims_mean = []
        sims_std = []
        # Only NumPy statistics leave this method, so no autograd graph is
        # needed across the N steps.
        with torch.no_grad():
            # The transition matrix is decay * identity, so multiplying by it
            # (and by identity - transition) is a plain scalar scaling; no
            # batch of dh x dh matrices has to be materialised.
            decay = torch.exp(
                -torch.log(torch.tensor(2.0, device=device)) / self.N
            )
            gain = 1 - decay

            h = torch.zeros((batch_size, self.dh), device=device)
            target = W.squeeze(2)
            for i in range(self.N):
                x_i = X[:, i]
                y_i = x_i[:, self.d:self.d + 1]
                h = decay * h + y_i * (gain * self.WB(x_i))
                sim = torch.cosine_similarity(self.WCT(h), target, dim=1)
                sims_mean.append(torch.mean(sim).cpu().numpy())
                sims_std.append(torch.std(sim).cpu().numpy())
        return np.array(sims_mean), np.array(sims_std)


# Use the GPU when one is available; otherwise keep the model on the CPU
# instead of failing on an unconditional .cuda() call.
mamba = Mamba(d, dh, N).to('cuda' if torch.cuda.is_available() else 'cpu')

# Accumulators filled by the data-generation loop further down in this script:
# per-sequence inputs, weight vectors, query vectors and query targets.
X = []
W = []
Q = []
Y = []


def make_vector(dim=None):
    """Draw a random row vector with unit Euclidean norm.

    The entries are sampled i.i.d. from a standard normal distribution and the
    result is divided by its norm.

    Args:
        dim: Number of entries. When omitted, the module-level ``d`` is used,
            which is what existing zero-argument calls rely on.

    Returns:
        A NumPy array of shape ``(1, dim)``.
    """
    if dim is None:
        dim = d
    v = np.random.normal(0, 1, size=(1, dim))
    return v / np.linalg.norm(v)


# Build n synthetic sequences. Each one gets its own unit-norm weight vector w
# and N + 1 rows of the form [x, x.w]; a separate query x and its target x.w
# are stored alongside.
for _ in range(n):
    w = make_vector().reshape(d, 1)
    W.append(w)
    # Collect the rows and concatenate once instead of re-copying the growing
    # array on every iteration. The order of random draws (w, then the N + 1
    # inputs, then the query) is unchanged.
    rows = []
    for _ in range(N + 1):
        x = make_vector()
        rows.append(np.concatenate((x, x.dot(w)), axis=1))
    X.append(np.concatenate(rows, axis=0))
    x = make_vector()
    Q.append(x)
    Y.append(x.dot(w))

# Put the data on the same device as the model rather than assuming CUDA, and
# go through a single ndarray: torch.tensor on a list of arrays is very slow.
device = next(mamba.parameters()).device
X = torch.tensor(np.array(X), dtype=torch.float32, device=device)
W = torch.tensor(np.array(W), dtype=torch.float32, device=device)
Q = torch.tensor(np.array(Q), dtype=torch.float32, device=device)
Y = torch.tensor(np.array(Y), dtype=torch.float32, device=device)

# Entry i of mean/std is measured after i + 1 recurrence steps, so label the
# x-axis with the integer step counts 1..N (np.linspace(0, N, N) would place
# the points at non-integer positions).
x = np.arange(1, N + 1)
mean, std = mamba(X, W)
np.save('tensor/sim_mean.npy', mean)
np.save('tensor/sim_std.npy', std)

# NOTE(review): nothing here calls plt.show() or plt.savefig(); presumably the
# figure is displayed by an interactive backend - confirm.
plt.plot(x, mean, label='cos')
plt.fill_between(x, mean - std, mean + std, color='blue', alpha=0.3)
plt.legend()