import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import ipdb as pdb
import numpy as np
from mamba_ssm import Mamba





class MambaBlock(nn.Module):
	"""Pre-norm residual Mamba block: ``x + Dropout(Mamba(LayerNorm(x)))``.

	Args:
		d_model: model (hidden) dimension H of the input ``[B L H]``.
		dropout: dropout probability applied to the Mamba branch output
			before it is added back onto the residual stream.
		d_state: SSM state expansion factor forwarded to ``Mamba``.
		d_conv: local convolution width forwarded to ``Mamba``.
		expand: block expansion factor forwarded to ``Mamba``.
	"""

	def __init__(self, d_model, dropout=0.0, d_state=16, d_conv=4, expand=2):
		super(MambaBlock, self).__init__()
		self.d_model = d_model

		self.ln_1 = nn.LayerNorm(self.d_model)

		self.mamba = Mamba(
			# This module uses roughly 3 * expand * d_model^2 parameters
			d_model=d_model,  # Model dimension d_model
			d_state=d_state,  # SSM state expansion factor
			d_conv=d_conv,    # Local convolution width
			expand=expand,    # Block expansion factor
		)

		# Residual dropout on the Mamba branch (identity when p == 0 or in eval mode).
		self.drop = nn.Dropout(p=dropout)

	def forward(self, x):
		'''
			x: [B L H] input sequence, H == d_model.
			returns: x + Dropout(Mamba(LayerNorm(x))).
		'''
		ln_x = self.ln_1(x)                     # Layernorm
		y = self.mamba(ln_x)                    # Mamba [B L H] -> [B L H]
		# Apply the configured dropout to the branch so the `dropout` argument takes effect.
		return x + self.drop(y)                 # Residual connection
		
	




class MambaModel(nn.Module):
	"""Sequential stack of ``n_layers`` MambaBlocks with a final LayerNorm.

	Args:
		d_model: hidden dimension shared by every block and the output norm.
		n_layers: number of MambaBlocks applied one after another.
	"""

	def __init__(self, d_model, n_layers):
		super().__init__()

		self.d_model = d_model
		self.n_layers = n_layers
		# Output normalisation, applied once after the last block.
		self.layernorm = nn.LayerNorm(d_model)

		layers = (MambaBlock(d_model) for _ in range(n_layers))
		self.blocks = nn.ModuleList(layers)

	def forward(self, x):
		"""Feed ``x`` ([B L H]) through every block in order, then normalise."""
		hidden = x
		for layer in self.blocks:
			hidden = layer(hidden)
		normed = self.layernorm(hidden)
		return normed


