import torch
from torch import nn


class Mamba(nn.Module):
    """Linear state-space recurrence over the first ``N`` tokens of ``X``, read out by ``Q``.

    With ``a = exp(-log(2) / N)`` (so ``a ** N == 1/2``), ``forward`` computes::

        h_0     = 0
        h_{i+1} = a * h_i + y_i * (1 - a) * WB(x_i)      for i = 0 .. N-1
        h_out   = a * h_N
        output  = WC(Q) @ h_out

    Here ``x_i = X[:, i]`` (all ``d + 1`` columns) and ``y_i = X[:, i, d]``
    (the last column).

    The update is written with ``A_bar = a * I`` and
    ``B_bar = -(A_bar - I) @ WB(x_i)``. This matches a zero-order-hold
    discretisation of ``dh/dt = -h + WB(x) * y`` with step ``delta = log(2) / N``.

    NOTE(review): the state matrix is fixed to ``-I``; only ``B`` (via ``WB``)
    depends on the input. Confirm this simplification is intended.
    """

    def __init__(self, d, dh, N):
        """Create the input (``WB``) and readout (``WC``) projections.

        Args:
            d: feature width of ``Q``. ``X`` tokens carry ``d + 1`` columns, and
                column ``d`` is used as the scalar input ``y_i``.
            dh: size of the hidden state.
            N: number of leading tokens of ``X`` consumed by ``forward``.

        Weights are initialised from ``N(0, 1)``; biases are zeroed.
        """
        super().__init__()
        self.d = d
        self.dh = dh
        self.N = N
        self.WB = nn.Linear(d + 1, dh)
        self.WC = nn.Linear(d, dh)
        nn.init.normal_(self.WB.weight, mean=0.0, std=1)
        nn.init.zeros_(self.WB.bias)
        nn.init.normal_(self.WC.weight, mean=0.0, std=1)
        nn.init.zeros_(self.WC.bias)

    def forward(self, X, Q):
        """Run the recurrence over ``X[:, :N]`` and project the final state with ``WC(Q)``.

        Args:
            X: indexed as ``X[:, i]`` (fed to ``WB``, so its last dimension is
                ``d + 1``) and ``X[:, i:i+1, d:d+1]``. In other words, its shape
                is ``(batch, seq, d + 1)`` with ``seq >= N``. Only the first
                ``N`` positions are read.
            Q: query fed to ``WC`` (last dimension ``d``). Presumably shaped
                ``(batch, 1, d)``, giving a ``(batch, 1, 1)`` result — TODO
                confirm. A 2-D ``(batch, d)`` query is broadcast by
                ``torch.matmul`` and yields ``(batch, batch, 1)``.

        Returns:
            ``torch.matmul(WC(Q), h)``, where ``h`` is the final
            ``(batch, dh, 1)`` state. The result is on ``X``'s device and dtype.

        Raises:
            IndexError: if ``X`` has fewer than ``N`` positions along dim 1.
        """
        batch = X.shape[0]
        # Allocate on X's device/dtype rather than hard-coding ``.cuda()`` and
        # the default dtype, so the module also runs on CPU and in float64 etc.
        factory = {"device": X.device, "dtype": X.dtype}
        h = torch.zeros((batch, self.dh, 1), **factory)
        # One identity per batch element; ``expand`` shares a single storage
        # instead of stacking ``batch`` separate copies.
        eye = torch.eye(self.dh, **factory).expand(batch, self.dh, self.dh)

        # delta = log(2) / N, so a = exp(-delta) satisfies a ** N == 1/2.
        delta = torch.log(torch.tensor(2.0, **factory)) / self.N
        A_bar = torch.exp(-delta) * eye
        # Loop-invariant factor of B_bar: -(A_bar - I) == (1 - a) * I.
        gain = eye - A_bar

        # Project all N tokens through WB in one call: (batch, N, dh).
        WBx = self.WB(X[:, :self.N])

        for i in range(self.N):
            B_bar = torch.bmm(gain, WBx[:, i].unsqueeze(2))
            # Scalar input y_i (last column of token i), shape (batch, 1, 1).
            y_i = X[:, i:i + 1, self.d:self.d + 1]
            h = torch.matmul(A_bar, h) + y_i * B_bar
        h = torch.matmul(A_bar, h)
        C = self.WC(Q)
        return torch.matmul(C, h)