import torch
import torch.nn.functional as F
from torch import nn
from torch_geometric.nn import MessagePassing, global_add_pool
from ogb.graphproppred.mol_encoder import AtomEncoder, BondEncoder

class PatientRepresentModel(nn.Module):
    """Encode a patient's admission history into per-visit representations.

    Each admission contributes three bags of code indices (diagnoses,
    procedures, medications). Every bag is embedded, passed through dropout and
    sum-pooled into one vector per visit. The resulting visit sequences are then
    contextualised by one Transformer encoder per code type.

    Args:
        diag_voc: number of diagnosis codes.
        pro_voc: number of procedure codes.
        med_voc: number of medication codes.
        hidden_dim: embedding / Transformer width.
        device: device on which index tensors and the attention mask are created.
        num_heads: attention heads per encoder layer.
        num_layers: encoder layers per Transformer.
    """

    def __init__(self, diag_voc, pro_voc, med_voc, hidden_dim=256, device='cpu', num_heads=4, num_layers=2):
        super().__init__()
        self.device = device
        # One embedding table per code type, ordered (diagnosis, procedure, medication).
        self.embeddings = nn.ModuleList(
            [nn.Embedding(vocab_size, hidden_dim) for vocab_size in (diag_voc, pro_voc, med_voc)]
        )
        self.dropout = nn.Dropout(p=0.2)

        # nn.TransformerEncoder clones the template layer, so the three encoders
        # below end up with independent weights.
        template = nn.TransformerEncoderLayer(d_model=hidden_dim, nhead=num_heads, batch_first=True)
        self.encoders = nn.ModuleList(
            [nn.TransformerEncoder(template, num_layers=num_layers) for _ in range(3)]
        )

        self.init_weights()

    def init_weights(self):
        """Re-initialise every embedding table uniformly in [-0.1, 0.1]."""
        bound = 0.1
        for table in self.embeddings:
            table.weight.data.uniform_(-bound, bound)

    def forward(self, input):
        """Return ``(patient_status, med_status)`` for a single patient.

        Args:
            input: iterable of admissions; ``adm[0]``, ``adm[1]`` and ``adm[2]``
                hold the diagnosis, procedure and medication code indices.

        Returns:
            patient_status: (num_visit, 2 * hidden_dim) concatenation of the
                diagnosis and procedure encodings.
            med_status: (num_visit, hidden_dim) medication encodings, computed
                with the most recent visit masked out as an attention key.
        """
        # pooled[kind] collects one (1, 1, dim) sum-pooled vector per visit.
        pooled = ([], [], [])
        for admission in input:
            for kind in range(3):
                codes = torch.LongTensor(admission[kind]).unsqueeze(dim=0).to(self.device)
                embedded = self.dropout(self.embeddings[kind](codes))  # (1, num_codes, dim)
                pooled[kind].append(embedded.sum(dim=1).unsqueeze(dim=0))

        # each: (1, num_visit, dim)
        diag_seq, pro_seq, med_seq = [torch.cat(chunks, dim=1) for chunks in pooled]

        batch_size, num_visits, _ = med_seq.shape
        # In src_key_padding_mask, True marks a key position to ignore: hide the
        # latest visit's medications from the medication encoder.
        # NOTE(review): with a single visit every key is masked, which may yield
        # NaN med_status rows; confirm callers never rely on them in that case.
        med_mask = torch.zeros(batch_size, num_visits, dtype=torch.bool, device=self.device)
        med_mask[:, -1] = True

        diag_encoded = self.encoders[0](diag_seq)
        pro_encoded = self.encoders[1](pro_seq)
        med_encoded = self.encoders[2](med_seq, src_key_padding_mask=med_mask)

        patient_status = torch.cat([diag_encoded, pro_encoded], dim=-1).squeeze(dim=0)
        med_status = med_encoded.squeeze(dim=0)
        return patient_status, med_status

class GINConv(MessagePassing):
    """GIN-style graph convolution whose messages include encoded bond features.

    For each node i the layer computes
    ``nn((1 + eps) * x_i + AGG_j relu(x_j + e_ij))``, where ``AGG`` is the
    MessagePassing aggregation (``aggr``) over neighbours j and ``e_ij`` is the
    BondEncoder embedding of the edge attributes.

    Args:
        nn: module applied to the combined node features (an MLP in this file).
        edge_dim: output size of the BondEncoder; it should match the node
            feature size so that ``x_j + edge_attr`` is well defined.
        aggr: neighbourhood aggregation scheme forwarded to MessagePassing.
    """

    def __init__(self, nn, edge_dim, aggr="add"):
        super().__init__(aggr=aggr)
        self.nn = nn  # post-aggregation network
        # Learnable weight on the node's own features, starting at 0.
        self.eps = torch.nn.Parameter(torch.Tensor([0]))
        self.bond_encoder = BondEncoder(edge_dim)

    def forward(self, x, edge_index, edge_attr):
        edge_embedding = self.bond_encoder(edge_attr)
        self_term = (1 + self.eps) * x
        neighbour_term = self.propagate(edge_index, x=x, edge_attr=edge_embedding)
        return self.nn(self_term + neighbour_term)

    def message(self, x_j, edge_attr):
        # Neighbour features shifted by the bond embedding, then rectified.
        combined = x_j + edge_attr
        return F.relu(combined)

    def update(self, inputs):
        # Identity update: aggregated messages are returned unchanged.
        return inputs


# GNN encoder for molecule graph
class MoleculeEncoder(nn.Module):
    """GNN that maps a batch of molecule graphs to one embedding per graph.

    Atoms are embedded with AtomEncoder and refined by three rounds of the same
    GINConv layer, each followed by ReLU. The atom vectors are then sum-pooled
    per graph with ``global_add_pool``.

    Args:
        node_features: not used by this class; kept for signature compatibility.
        edge_features: not used by this class; kept for signature compatibility.
        mol_dim: width of the atom, bond and molecule embeddings.
        device: device the graph tensors are moved to in ``forward``.
    """

    def __init__(self, node_features=9, edge_features=3, mol_dim=256, device='cpu'):
        super().__init__()
        self.node_encoder = AtomEncoder(mol_dim)
        mlp = nn.Sequential(
            nn.Linear(mol_dim, 2 * mol_dim),
            nn.BatchNorm1d(2 * mol_dim),
            nn.ReLU(),
            nn.Linear(2 * mol_dim, mol_dim),
            nn.ReLU(),
        )
        # NOTE(review): this single GINConv is applied three times in forward(),
        # i.e. all message-passing rounds share weights. Confirm that is intended.
        self.conv = GINConv(mlp, edge_dim=mol_dim).to(device)
        self.device = device

    def forward(self, data):
        """Return sum-pooled graph embeddings for a PyG batch ``data``.

        ``data`` must expose ``x``, ``edge_index``, ``edge_attr`` and ``batch``.
        """
        x = data.x.to(self.device)
        edge_index = data.edge_index.to(self.device)
        edge_attr = data.edge_attr.to(self.device)
        batch = data.batch.to(self.device)

        h = self.node_encoder(x)
        for _ in range(3):  # three message-passing rounds with the shared layer
            h = F.relu(self.conv(h, edge_index, edge_attr))

        # global pooling: sum atom vectors within each graph
        return global_add_pool(h, batch)


class SelfAttention(nn.Module):
    """Post-norm Transformer block: self-attention then feed-forward.

    Each sub-layer is wrapped in a residual connection followed by LayerNorm.
    Inputs are batch-first: ``(batch, seq, dim)``.

    Args:
        dim: feature size of the input.
        num_heads: number of attention heads (``dim`` must be divisible by it).
    """

    def __init__(self, dim, num_heads):
        super().__init__()
        self.multihead_attn = nn.MultiheadAttention(embed_dim=dim, num_heads=num_heads, batch_first=True)
        self.layer_norm = nn.LayerNorm(dim)
        hidden = dim * 2
        self.ff = nn.Sequential(nn.Linear(dim, hidden), nn.ReLU(), nn.Linear(hidden, dim))
        self.layer_norm2 = nn.LayerNorm(dim)

    def forward(self, x):
        # Index 0 of the MultiheadAttention result is the attended output.
        attended = self.layer_norm(x + self.multihead_attn(x, x, x)[0])
        return self.layer_norm2(attended + self.ff(attended))


class CrossAttention(nn.Module):
    """Multi-head attention of queries ``q`` over keys ``k`` and values ``v``.

    The inputs first pass through the linear maps ``W_Q``, ``W_K`` and ``W_V``.
    They then go into ``nn.MultiheadAttention``, which applies its own input
    projections as well.

    Args:
        dim: feature size shared by queries, keys and values.
        num_heads: number of attention heads.
    """

    def __init__(self, dim, num_heads):
        super().__init__()
        self.multihead_attn = nn.MultiheadAttention(embed_dim=dim, num_heads=num_heads, batch_first=True)
        # NOTE(review): layer_norm is never used in forward(); presumably kept so
        # existing checkpoints (state_dict keys) still load.
        self.layer_norm = nn.LayerNorm(dim)
        self.W_Q, self.W_K, self.W_V = (nn.Linear(dim, dim) for _ in range(3))

    def forward(self, q, k, v):
        """Return ``(attn_output, attn_weights)`` from ``nn.MultiheadAttention``."""
        projected = (self.W_Q(q), self.W_K(k), self.W_V(v))
        return self.multihead_attn(*projected)



class CVAE(nn.Module):
    """Conditional variational auto-encoder with sigmoid-bounded reconstructions.

    Encoder: ``[x, c] -> 512 -> 256 -> (mean, logvar)`` of a ``latent_dim``
    diagonal Gaussian. Decoder: ``[z, c] -> 256 -> 512 -> input_dim``, squashed
    into (0, 1) by a sigmoid.

    Data and condition are concatenated on their last dimension, so unbatched
    (1-D) inputs and any number of matching leading batch dimensions work.

    NOTE(review): because of the final sigmoid, targets outside [0, 1] cannot be
    reconstructed exactly.

    Args:
        input_dim: size of the data vectors ``x``.
        cond_dim: size of the condition vectors ``c``.
        latent_dim: size of the latent code ``z``.
    """

    def __init__(self, input_dim, cond_dim, latent_dim=32):
        super(CVAE, self).__init__()
        # encoder q(z | x, c)
        self.fc1 = nn.Linear(input_dim + cond_dim, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc_mean = nn.Linear(256, latent_dim)
        self.fc_logvar = nn.Linear(256, latent_dim)
        # decoder p(x | z, c)
        self.fc3 = nn.Linear(latent_dim + cond_dim, 256)
        self.fc4 = nn.Linear(256, 512)
        self.fc5 = nn.Linear(512, input_dim)

    def encode(self, x, c):
        """Return ``(mean, logvar)`` of the approximate posterior for ``x`` given ``c``."""
        # Concatenate on the feature (last) axis; dim=1 only worked for 2-D input.
        h = torch.cat([x, c], dim=-1)
        h = torch.relu(self.fc1(h))
        h = torch.relu(self.fc2(h))
        return self.fc_mean(h), self.fc_logvar(h)

    def reparameterize(self, mean, logvar):
        """Sample ``z = mean + eps * exp(logvar / 2)`` with ``eps ~ N(0, I)``, differentiably."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mean + eps * std

    def decode(self, z, c):
        """Map latent ``z`` and condition ``c`` to a reconstruction in (0, 1)."""
        h = torch.cat([z, c], dim=-1)
        h = torch.relu(self.fc3(h))
        h = torch.relu(self.fc4(h))
        return torch.sigmoid(self.fc5(h))

    def forward(self, x, c):
        """Return ``(reconstruction, mean, logvar)`` for data ``x`` and conditions ``c``."""
        mean, logvar = self.encode(x, c)
        z = self.reparameterize(mean, logvar)
        return self.decode(z, c), mean, logvar


class myModel(nn.Module):
    """Medication recommendation model.

    Combines three components:
    - a Transformer patient encoder;
    - a GNN molecule encoder, whose outputs are precomputed once under ``no_grad``;
    - a conditional VAE that generates one embedding per ATC4 medication code.

    The current visit is scored against the ATC4 embedding table. When a
    previous visit exists, logits derived from its medication status are added.
    """

    def __init__(self, diag_voc_size, pro_voc_size, med_voc_size, med_voc, graph_data_dict, ddi_graph, node_features=9,
                 edge_features=3, patient_dim=256,
                 mol_dim=256, latent_dim=32, device='cpu'):
        """
        :param diag_voc_size: size of diagnosis vocabulary
        :param pro_voc_size: size of procedure vocabulary
        :param med_voc_size: size of medication vocabulary
        :param med_voc: medication vocabulary (its ``word2idx`` maps ATC4 code -> index)
        :param graph_data_dict: dict { atc4: molecule data in torch_geometric format };
            each value must expose ``x``, ``edge_index``, ``edge_attr`` and ``batch``
        :param ddi_graph: ddi matrix
        :param node_features: molecule graph node feature dim
        :param edge_features: molecule graph edge feature dim
        :param patient_dim: patient embedding dim
        :param mol_dim: molecule embedding dim
        :param latent_dim: latent size of the CVAE that generates ATC4 embeddings
        :param device: device for sub-modules and internally created tensors
        """
        super(myModel, self).__init__()
        # patient model to encode patient status and medication status
        self.patient_model = PatientRepresentModel(diag_voc=diag_voc_size, pro_voc=pro_voc_size, med_voc=med_voc_size,
                                                   hidden_dim=patient_dim, num_heads=8, num_layers=4,
                                                   device=device).to(device)

        # cross attention between the current patient query and the ATC4 table
        self.CA = CrossAttention(dim=mol_dim, num_heads=4).to(device)

        self.molecule_encoder = MoleculeEncoder(node_features, edge_features, mol_dim, device=device).to(device)
        # Honour latent_dim: generate_atc4_embeddings samples self.latent_dim-sized
        # noise, so a hard-coded 32 here broke any other value.
        self.CVAE = CVAE(input_dim=mol_dim, cond_dim=mol_dim, latent_dim=latent_dim).to(device)
        # learnable ATC4 condition vectors for the CVAE, indexed by med_voc index
        self.atc4_embedding_layer = nn.Embedding(med_voc_size, mol_dim).to(device)

        # projector to map patient representation to molecule representation space
        self.patient_projector = nn.Sequential(
            nn.Linear(2 * patient_dim, patient_dim), nn.ReLU(), nn.Linear(patient_dim, mol_dim)
        ).to(device)

        # projector to map medication status to medication logits
        self.med_projector = nn.Sequential(nn.Linear(patient_dim, med_voc_size)).to(device)

        # LayerNorm
        self.patient_layernorm = nn.LayerNorm(mol_dim).to(device)
        self.atc4_layernorm = nn.LayerNorm(mol_dim).to(device)
        self.pred_layernorm = nn.LayerNorm(med_voc_size).to(device)

        # attr
        self.med_voc = med_voc
        self.graph_data_dict = graph_data_dict
        self.mol_dim = mol_dim
        self.latent_dim = latent_dim
        self.device = device
        # Plain dict of tensors (not registered buffers): model.to() will not move it.
        self.atc_emb = self._build_atc4_embedding(self.graph_data_dict)
        # NOTE(review): not read inside this class; presumably used by the caller's DDI loss.
        self.ddi_graph = torch.tensor(ddi_graph, dtype=torch.float, device=self.device)

    def _build_atc4_embedding(self, graph_data_dict):
        """Encode every ATC4 code's molecule graphs once.

        One ATC4 code may map to several molecules (ATC5 entries), so each value
        holds one row per molecule. Encoding runs under ``torch.no_grad()``, so
        the GNN encoder gets no gradient through these embeddings.

        NOTE(review): this runs from ``__init__`` while the module is still in
        training mode, so BatchNorm in the molecule encoder uses per-graph batch
        statistics and updates its running stats. Confirm that is intended.

        :return: dict { atc4: (num_atc5, mol_dim) tensor }
        """
        atc4_emb = {}
        for atc4, graph_data in graph_data_dict.items():
            with torch.no_grad():  # if you want to train GNN encoder then remove this line
                graph_emb = self.molecule_encoder(graph_data)  # (num_atc5, emb)
            atc4_emb[atc4] = graph_emb
        return atc4_emb

    def _build_atc4_matrix(self):
        """Build the ATC4 embedding table aligned with ``med_voc`` indices.

        Rows for codes that have molecule graphs are filled with CVAE samples.
        Codes without a graph keep an all-zero row.

        :return: (atc4_matrix, build_loss, z, y)
            atc4_matrix: (len(med_voc.word2idx), mol_dim)
            build_loss: CVAE loss, or 0 when no code has a graph
            z, y: CVAE posterior samples and conditions, or None when no code has a graph
        """
        vocab_size = len(self.med_voc.word2idx)
        atc4_matrix = torch.zeros((vocab_size, self.mol_dim)).to(self.device)
        tensors = []
        indices = []
        for atc4, idx in self.med_voc.word2idx.items():
            if atc4 in self.atc_emb:
                tensors.append(self.atc_emb[atc4])
                indices.append(idx)

        if not tensors:
            # Previously z / y were unbound on this path and the return raised.
            return atc4_matrix, 0, None, None

        processed_embs, build_loss, z, y = self.generate_atc4_embeddings(tensors, indices)  # (num_atc4, mol_dim)
        atc4_matrix[indices] = processed_embs
        return atc4_matrix, build_loss, z, y

    def generate_atc4_embeddings(self, stacked_tensors, indices):
        """Compute the CVAE loss on molecule embeddings and sample one embedding per ATC4 code.

        :param stacked_tensors: list of (num_atc5_i, mol_dim) molecule embeddings, one per ATC4 code
        :param indices: med_voc indices of those ATC4 codes, aligned with ``stacked_tensors``
        :return: (atc4_samples, loss, z, y)
            atc4_samples: (len(indices), mol_dim) decoded from z ~ N(0, I) under each code's condition
            loss: reconstruction MSE (mean) + KL divergence (summed)
            z: (total_num_atc5, latent_dim) posterior samples
            y: (total_num_atc5, mol_dim) CVAE conditions, i.e. each ATC4 embedding
               repeated once per molecule mapped to it
        """
        labels = torch.tensor(indices, device=self.device)
        all_c = self.atc4_embedding_layer(labels)  # (num_atc4, mol_dim)
        counts = torch.tensor([t.shape[0] for t in stacked_tensors], device=self.device)
        y = torch.repeat_interleave(all_c, counts, dim=0)  # (total_num_atc5, mol_dim)
        x = torch.cat(stacked_tensors, dim=0)  # (total_num_atc5, mol_dim)

        mean, logvar = self.CVAE.encode(x, y)  # (total_num_atc5, latent_dim)
        z = self.CVAE.reparameterize(mean, logvar)
        reconstructed = self.CVAE.decode(z, y)  # (total_num_atc5, mol_dim)

        recon_loss = nn.MSELoss()(reconstructed, x)
        # NOTE(review): KL is summed while MSE is averaged; confirm the relative weighting.
        kl_loss = -0.5 * torch.sum(1 + logvar - mean.pow(2) - logvar.exp())
        total_loss = recon_loss + kl_loss

        # One prior sample per code that has a graph. Sampling len(word2idx) rows
        # broke decode (and the row assignment in _build_atc4_matrix) whenever some
        # vocabulary entry had no molecule graph.
        sampled_z = torch.randn(all_c.size(0), self.latent_dim).to(self.device)
        atc4_samples = self.CVAE.decode(sampled_z, all_c)

        return atc4_samples, total_loss, z, y

    def forward(self, input):
        """Score every medication for the latest admission in ``input``.

        :param input: sequence of admissions as consumed by ``PatientRepresentModel``
            (``adm[0]``, ``adm[1]``, ``adm[2]`` = diagnosis, procedure, medication codes)
        :return: (logits, atc4_emb_matrix, build_loss, z, y, r)
            logits: (1, med_voc_size) scores for the current visit
            atc4_emb_matrix: un-normalised ATC4 embedding table
            build_loss, z, y: as returned by ``_build_atc4_matrix``
            r: attention weights of the cross-attention
        """
        patient_status, med_status = self.patient_model(input)
        # Query from the most recent visit: (1, 2 * patient_dim)
        e_c = patient_status[-1:, :]

        atc4_emb_matrix, build_loss, z, y = self._build_atc4_matrix()
        atc4_emb = self.atc4_layernorm(atc4_emb_matrix)  # (med_voc_size, mol_dim)

        previous_pred_logits = None
        if len(input) > 1:
            med_status_previous = med_status[-2:-1, :]  # (1, patient_dim)
            previous_pred_logits = self.pred_layernorm(self.med_projector(med_status_previous))

        # get prediction based on current status
        e_c = self.patient_layernorm(self.patient_projector(e_c))  # (1, mol_dim)
        # NOTE(review): keys/values use the un-normalised table while scoring below
        # uses atc4_emb; confirm this is intended.
        e_r, r = self.CA(e_c, atc4_emb_matrix, atc4_emb_matrix)

        logits = self.pred_layernorm(torch.matmul(e_r, atc4_emb.t()))
        if previous_pred_logits is not None:
            # Out-of-place so the LayerNorm output is not mutated under autograd.
            logits = logits + previous_pred_logits

        return logits, atc4_emb_matrix, build_loss, z, y, r
