import torch
from torch import nn
import torch.nn.functional as F
from mamba_main.mamba_ssm.modules.mamba_simple import MambaHealth

class MambaHealthModel(nn.Module):
    """Score ``vocab_size[2]`` output codes from a sequence of admissions.

    Every admission supplies two lists of code indices (looked up in the
    embedding tables sized ``vocab_size[0]`` and ``vocab_size[1]``). Each
    list is embedded and summed into one vector, the two resulting visit
    sequences are run through a stack of three ``MambaLayer`` blocks, and
    the representation of the last admission is projected to logits.

    Args:
        vocab_size: Indexable holding at least three sizes: the two input
            vocabularies and the output vocabulary.
        ddi_adj: Array-like accepted by ``torch.FloatTensor``. It is handed
            to every ``MambaLayer`` and also weights the pairwise penalty
            returned by ``forward``; it must broadcast against a
            ``(vocab_size[2], vocab_size[2])`` matrix. Presumably a
            drug-drug-interaction adjacency matrix -- TODO confirm.
        ehr_adj: Same role and shape requirement as ``ddi_adj`` for the
            second penalty term. Presumably an EHR co-occurrence adjacency
            matrix -- TODO confirm.
        emb_dim: Embedding width, also used as ``d_model`` of the layers.
        d_state, d_conv, expand: Forwarded unchanged to ``MambaLayer``.
        num_layers: Forwarded to ``MambaLayer`` (it does not change the
            number of stacked layers, which is fixed at three).
        dropout_prob: Dropout rate for embeddings, layers and attention.
        num_heads: Number of heads of the ``nn.MultiheadAttention`` module.
        device: Device on which the adjacency tensors are first created.
    """

    def __init__(self, vocab_size, ddi_adj, ehr_adj, emb_dim=512, d_state=256, d_conv=4, expand=2, num_layers=1, dropout_prob=0.5, num_heads=8, device=torch.device('cpu:0')):
        super().__init__()

        # Kept as a public attribute for existing callers; forward() follows
        # the live device of the parameters instead of this value.
        self.device = device

        self.embeddings = nn.ModuleList(
            [nn.Embedding(vocab_size[i], emb_dim) for i in range(2)]
        )
        self.dropout = nn.Dropout(dropout_prob)

        # Each layer receives its own freshly built copy of both matrices.
        self.patient_mamba_layers = nn.ModuleList([
            MambaLayer(
                d_model=emb_dim,
                d_state=d_state,
                d_conv=d_conv,
                expand=expand,
                dropout=dropout_prob,
                num_layers=num_layers,
                ddi_matrix=torch.FloatTensor(ddi_adj).to(device),
                ehr_matrix=torch.FloatTensor(ehr_adj).to(device),
            ) for _ in range(3)
        ])

        # Maps the concatenated (2 * emb_dim) visit representation back to emb_dim.
        self.query = nn.Sequential(
            nn.ReLU(),
            nn.Linear(2 * emb_dim, emb_dim)
        )

        self.attention_layer = nn.MultiheadAttention(embed_dim=emb_dim, num_heads=num_heads, dropout=dropout_prob)

        self.output_layer = nn.Linear(emb_dim, vocab_size[2])

        # Registered as buffers so that ``.to()`` / ``.cuda()`` move them
        # together with the parameters. ``persistent=False`` keeps them out of
        # ``state_dict`` so existing checkpoints still load unchanged.
        self.register_buffer(
            "tensor_ddi_adj", torch.FloatTensor(ddi_adj).to(device), persistent=False
        )
        self.register_buffer(
            "tensor_ehr_adj", torch.FloatTensor(ehr_adj).to(device), persistent=False
        )

        self.init_weights()

    def init_weights(self):
        """Draw both embedding tables from U(-0.1, 0.1)."""
        initrange = 0.1
        for table in self.embeddings:
            nn.init.uniform_(table.weight, -initrange, initrange)

    def forward(self, input):
        """Compute logits and the two adjacency-weighted penalties.

        Args:
            input: Non-empty sequence of admissions. ``adm[0]`` and
                ``adm[1]`` are the code-index lists for the first and second
                embedding table respectively.

        Returns:
            A tuple ``(result, batch_neg_ddi, batch_neg_ehr)``: ``result``
            holds the raw (pre-sigmoid) logits with shape
            ``(1, vocab_size[2])``; the two scalars are ``0.0005`` times the
            sum of pairwise probability products weighted by the respective
            adjacency tensor.
        """
        # Use the device the weights actually live on, so the model keeps
        # working after it has been moved with ``.to(...)``.
        device = self.embeddings[0].weight.device

        def embed_codes(table, codes):
            ids = torch.LongTensor(codes).unsqueeze(dim=0).to(device)
            # Sum over the codes of one admission -> (1, 1, emb_dim).
            return self.dropout(table(ids)).sum(dim=1).unsqueeze(dim=0)

        i1_seq = []
        i2_seq = []
        # The two lookups stay interleaved per admission so the dropout RNG
        # stream is consumed in the same order as before.
        for adm in input:
            i1_seq.append(embed_codes(self.embeddings[0], adm[0]))
            i2_seq.append(embed_codes(self.embeddings[1], adm[1]))

        # (1, num_admissions, emb_dim)
        i1_seq = torch.cat(i1_seq, dim=1)
        i2_seq = torch.cat(i2_seq, dim=1)

        # The same layer (shared weights) processes both code sequences.
        for mamba_layer in self.patient_mamba_layers:
            i1_seq = mamba_layer(i1_seq)
            i2_seq = mamba_layer(i2_seq)

        # (num_admissions, 2 * emb_dim), assuming the layers preserve shape.
        patient_representations = torch.cat([i1_seq, i2_seq], dim=-1).squeeze(dim=0)

        # Only the last admission is used as the query: (1, 1, emb_dim).
        query = self.query(patient_representations)[-1:, :].unsqueeze(0)

        # Self-attention over that single position.
        query, _ = self.attention_layer(query, query, query)

        result = self.output_layer(query.reshape(-1, query.size(-1)))

        # Outer product of the predicted probabilities: entry (i, j) is p_i * p_j.
        neg_pred_prob = torch.sigmoid(result)
        neg_pred_prob = neg_pred_prob.t() * neg_pred_prob

        batch_neg_ddi = 0.0005 * neg_pred_prob.mul(self.tensor_ddi_adj).sum()
        batch_neg_ehr = 0.0005 * neg_pred_prob.mul(self.tensor_ehr_adj).sum()

        return result, batch_neg_ddi, batch_neg_ehr


class MambaLayer(nn.Module):
    """One ``MambaHealth`` mixer followed by a position-wise feed-forward block.

    The mixer output goes through dropout and layer normalisation; a residual
    connection to the layer input is added before the normalisation only when
    ``num_layers`` is not 1.

    Args:
        d_model: Feature width of the input and output.
        d_state, d_conv, expand: Forwarded unchanged to ``MambaHealth``.
        dropout: Dropout rate for this layer and its feed-forward block.
        num_layers: Only controls whether the residual connection is used.
        ddi_matrix, ehr_matrix: Forwarded unchanged to ``MambaHealth``.
    """

    def __init__(self, d_model, d_state, d_conv, expand, dropout, num_layers, ddi_matrix, ehr_matrix):
        super().__init__()
        self.num_layers = num_layers
        mixer_kwargs = dict(
            d_model=d_model,
            d_state=d_state,
            d_conv=d_conv,
            expand=expand,
            ddi_matrix=ddi_matrix,
            ehr_matrix=ehr_matrix,
        )
        self.mamba = MambaHealth(**mixer_kwargs)
        self.dropout = nn.Dropout(dropout)
        self.LayerNorm = nn.LayerNorm(d_model, eps=1e-12)
        # The inner width of the feed-forward block is four times d_model.
        self.ffn = FeedForward(d_model=d_model, inner_size=4 * d_model, dropout=dropout)

    def forward(self, input_tensor):
        """Apply mixer -> dropout -> (optional residual) -> norm -> FFN."""
        mixed = self.dropout(self.mamba(input_tensor))
        if self.num_layers != 1:
            # The residual path is skipped for the single-layer configuration.
            mixed = mixed + input_tensor
        return self.ffn(self.LayerNorm(mixed))

class FeedForward(nn.Module):
    """Two-layer GELU feed-forward block with residual and layer norm.

    Computes ``LayerNorm(dropout(w_2(dropout(gelu(w_1(x))))) + x)``.

    Args:
        d_model: Width of the input and output features.
        inner_size: Width of the hidden projection.
        dropout: Dropout rate applied after the activation and after the
            output projection.
    """

    def __init__(self, d_model, inner_size, dropout=0.2):
        super().__init__()
        # Expand to the hidden width, then project back to d_model.
        self.w_1 = nn.Linear(in_features=d_model, out_features=inner_size)
        self.w_2 = nn.Linear(in_features=inner_size, out_features=d_model)
        self.activation = nn.GELU()
        self.dropout = nn.Dropout(p=dropout)
        self.LayerNorm = nn.LayerNorm(d_model, eps=1e-12)

    def forward(self, input_tensor):
        """Return the normalised sum of the block output and its input."""
        expanded = self.dropout(self.activation(self.w_1(input_tensor)))
        projected = self.dropout(self.w_2(expanded))
        return self.LayerNorm(projected + input_tensor)

