import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np


class MultiHeadSelfAttention(nn.Module):
    """Multi-head scaled dot-product attention with separate query/key/value inputs.

    Despite the name, the query may come from a different sequence than the
    keys and values. No masking and no dropout are applied.

    Args:
        embed_dim: feature width of the inputs and of the output; must be
            divisible by ``num_heads``.
        num_heads: number of attention heads.
    """

    def __init__(self, embed_dim, num_heads=8):
        super().__init__()
        assert embed_dim % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

    def _to_heads(self, projected, batch, length):
        """Reshape [batch, length, embed] into [batch, heads, length, head_dim]."""
        per_head = projected.view(batch, length, self.num_heads, self.head_dim)
        return per_head.transpose(1, 2)

    def forward(self, query, key, value):
        """Attend from ``query`` over ``key``/``value``.

        Args:
            query: [B, Nq, embed_dim]
            key:   [B, Nk, embed_dim]
            value: [B, Nk, embed_dim]

        Returns:
            A tuple ``(output, weights)`` where ``output`` is [B, Nq, embed_dim]
            (after the output projection) and ``weights`` is the per-head
            softmax attention of shape [B, num_heads, Nq, Nk].
        """
        batch, q_len, width = query.size()
        kv_len = key.size(1)

        q_heads = self._to_heads(self.q_proj(query), batch, q_len)
        k_heads = self._to_heads(self.k_proj(key), batch, kv_len)
        v_heads = self._to_heads(self.v_proj(value), batch, kv_len)

        # Scale by sqrt(head_dim) before the softmax over the key axis.
        scale = self.head_dim ** 0.5
        logits = torch.matmul(q_heads, k_heads.transpose(-2, -1)) / scale
        weights = logits.softmax(dim=-1)

        # Merge the heads back into a single feature axis.
        mixed = torch.matmul(weights, v_heads)
        merged = mixed.transpose(1, 2).contiguous().view(batch, q_len, width)
        return self.out_proj(merged), weights


class HypergraphAttention(nn.Module):
    """
    This module applies self-attention within each hyperedge to update the incidence
    matrix H_p and compute a hyperedge representation X_E.

    For each hyperedge e, with N(e) the member nodes of e that are also active
    under node_mask, we compute the diagonal attention weight
      alpha_e(i, i) = softmax_{j in N(e)} ( Q_i · K_j / sqrt(d) ) evaluated at j = i
    and update:
      H_p(i, e) = H(i, e) * alpha_e(i, i)
    where H is the original binary incidence matrix. The hyperedge
    representation is the H_p-weighted average of the value vectors.

    Samples in which a hyperedge has no active member get an all-zero H_p
    column and an all-zero representation for that hyperedge.

    Inputs:
      X: [batch_size, num_nodes, embedding_dim] representing the current disease state S(t)
      T: [batch_size, num_nodes] representing time information for each node
      H: [num_nodes, num_hyperedges] the original hypergraph association matrix (binary 0/1);
         a leading singleton dimension is squeezed away
      node_mask (optional): [batch_size, num_nodes] where False indicates inactive nodes

    Outputs:
      H_p: [batch_size, num_nodes, num_hyperedges] updated incidence matrix
      X_E: [batch_size, num_hyperedges, embedding_dim] aggregated hyperedge representations
    """
    def __init__(self, embedding_dim):
        super(HypergraphAttention, self).__init__()
        self.embedding_dim = embedding_dim
        self.W_Q = nn.Parameter(torch.randn(embedding_dim, embedding_dim) * 0.01)
        self.W_K = nn.Parameter(torch.randn(embedding_dim, embedding_dim) * 0.01)
        self.W_V = nn.Parameter(torch.randn(embedding_dim, embedding_dim) * 0.01)

    def forward(self, X, T, H, node_mask=None):
        batch_size, num_nodes, d = X.shape
        # Ensure H is 2D: [num_nodes, num_hyperedges]
        if H.dim() > 2:
            H = H.squeeze(0)
        num_hyperedges = H.shape[1]

        # The same time-shifted input feeds all three projections; build it once.
        time_enc = torch.sin(T.unsqueeze(-1).expand(batch_size, num_nodes, d))
        X_time = X + time_enc
        Q = torch.matmul(X_time, self.W_Q)  # [batch_size, num_nodes, d]
        K = torch.matmul(X_time, self.W_K)  # [batch_size, num_nodes, d]
        V = torch.matmul(X_time, self.W_V)  # [batch_size, num_nodes, d]
        scores = torch.bmm(Q, K.transpose(1, 2)) / (d ** 0.5)  # [batch_size, num_nodes, num_nodes]

        active = node_mask.bool() if node_mask is not None else None

        H_p_list = []
        hyperedge_repr_list = []

        for e in range(num_hyperedges):
            valid_mask = (H[:, e] > 0).unsqueeze(0).expand(batch_size, num_nodes)
            if active is not None:
                valid_mask = valid_mask & active
            # Samples where this hyperedge has no active member at all.
            empty = ~valid_mask.any(dim=1)  # [batch_size]

            # Restrict the softmax to the valid key nodes of this hyperedge.
            scores_e = scores.masked_fill(~valid_mask.unsqueeze(1), -1e9)
            attn_e = torch.softmax(scores_e, dim=2)
            attn_e = torch.nan_to_num(attn_e, nan=0.0)
            diag_alpha = torch.diagonal(attn_e, dim1=1, dim2=2)  # [batch_size, num_nodes]

            H_col = H[:, e].unsqueeze(0).expand(batch_size, num_nodes).float()
            H_p_e = H_col * diag_alpha  # [batch_size, num_nodes]
            # With every key masked the softmax degenerates to uniform weights,
            # so samples without any valid node must be zeroed explicitly.
            H_p_e = H_p_e.masked_fill(empty.unsqueeze(1), 0.0)
            H_p_list.append(H_p_e.unsqueeze(-1))  # [batch_size, num_nodes, 1]

            weights = H_p_e.unsqueeze(-1)  # [batch_size, num_nodes, 1]
            sum_weights = weights.sum(dim=1, keepdim=True)  # [batch_size, 1, 1]
            weighted_sum = (weights * V).sum(dim=1, keepdim=True)  # [batch_size, 1, d]
            hyperedge_repr = weighted_sum / (sum_weights + 1e-6)  # [batch_size, 1, d]
            hyperedge_repr = hyperedge_repr.masked_fill(empty.view(-1, 1, 1), 0.0)
            hyperedge_repr_list.append(hyperedge_repr)

        H_p = torch.cat(H_p_list, dim=-1)  # [batch_size, num_nodes, num_hyperedges]
        X_E = torch.cat(hyperedge_repr_list, dim=1)  # [batch_size, num_hyperedges, d]
        return H_p, X_E

class HypergraphCrossAttention(nn.Module):
    """
    Cross-attention variant that builds the weighted incidence matrix H_p and
    the hyperedge representations X_E from an anchor node per hyperedge.

    For each hyperedge e the anchor A is the highest-index member node that is
    active in the FIRST sample of the batch. The anchor queries all nodes
    through two stacked attention layers in two branches:
      front: nodes encoded with sin(T)            (time encoding)
      back:  nodes encoded with sin(node index)   (positional encoding)
    The hyperedge representation is the mean of the two branch outputs, and
      H_p(i, e) = H(i, e) * (head-averaged front attention + back attention)(A, i).

    Hyperedges without any active member produce an all-zero column in H_p and
    an all-zero row in X_E, so the outputs always have num_hyperedges columns.

    Inputs:
      X: [batch_size, num_nodes, embedding_dim] node states
      T: [batch_size, num_nodes] time value per node
      H: [num_nodes, num_hyperedges] binary incidence matrix; a leading
         singleton dimension is squeezed away
      node_mask (optional): [batch_size, num_nodes], False marks inactive nodes

    Outputs:
      H_p: [batch_size, num_nodes, num_hyperedges]
      X_E: [batch_size, num_hyperedges, embedding_dim]
    """
    def __init__(self, embedding_dim):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.front_attn1 = MultiHeadSelfAttention(embedding_dim, num_heads=8)
        self.front_attn2 = MultiHeadSelfAttention(embedding_dim, num_heads=8)
        self.back_attn1 = MultiHeadSelfAttention(embedding_dim, num_heads=8)
        self.back_attn2 = MultiHeadSelfAttention(embedding_dim, num_heads=8)

    def forward(self, X, T, H, node_mask=None):
        B, N, D = X.shape
        # Accept H with a leading singleton batch dimension as well.
        if H.dim() > 2:
            H = H.squeeze(0)
        E = H.shape[1]

        time_enc_front = torch.sin(T.unsqueeze(-1).expand(B, N, D))
        index_enc = torch.arange(N, device=X.device).float()
        time_enc_back = torch.sin(index_enc.unsqueeze(-1).expand(N, D)).unsqueeze(0).expand(B, N, D)

        # Key/value inputs do not depend on the hyperedge; build them once.
        front_input = X + time_enc_front
        back_input = X + time_enc_back

        H_p_list = []
        X_E_list = []

        for e in range(E):
            valid = (H[:, e] > 0).unsqueeze(0).expand(B, N)
            if node_mask is not None:
                valid = valid & node_mask.bool()
            # NOTE(review): the anchor is derived from the first sample's mask
            # only; confirm this is intended when node_mask differs per sample.
            indices = torch.nonzero(valid[0], as_tuple=True)[0]
            if len(indices) == 0:
                # Keep column e aligned with hyperedge e instead of dropping it.
                H_p_list.append(X.new_zeros(B, N, 1))
                X_E_list.append(X.new_zeros(B, 1, D))
                continue
            A = int(indices[-1])

            # NOTE(review): attention runs over all nodes; node_mask and hyperedge
            # membership only affect the anchor choice and the final H(i, e) gating.
            front_query = front_input[:, A:A + 1, :]  # anchor state + its time encoding
            front_context1, _ = self.front_attn1(front_query, front_input, front_input)
            front_context2, attn_front2 = self.front_attn2(front_context1, front_input, front_input)

            back_query = back_input[:, A:A + 1, :]  # anchor state + its index encoding
            back_context1, _ = self.back_attn1(back_query, back_input, back_input)
            back_context2, attn_back2 = self.back_attn2(back_context1, back_input, back_input)

            combined = torch.cat([front_context2, back_context2], dim=1)  # [B, 2, D]
            X_E_list.append(combined.mean(dim=1, keepdim=True))  # [B, 1, D]

            avg_attn_front = attn_front2.mean(dim=1)  # [B, 1, N]
            avg_attn_back = attn_back2.mean(dim=1)    # [B, 1, N]
            attn_sum = avg_attn_front + avg_attn_back  # [B, 1, N]
            H_col = H[:, e].unsqueeze(0).expand(B, N).float()
            H_p_e = H_col * attn_sum.squeeze(1)  # [B, N]
            H_p_list.append(H_p_e.unsqueeze(-1))

        # The lists are only empty when the hypergraph has no hyperedges at all.
        H_p = torch.cat(H_p_list, dim=-1) if H_p_list else X.new_zeros(B, N, E)
        X_E = torch.cat(X_E_list, dim=1) if X_E_list else X.new_zeros(B, E, D)
        return H_p, X_E


class NeuralODEFunc(nn.Module):
    """Right-hand side of the state ODE: dS/dt = -L · (S · theta).

    ``theta`` is a learnable [embedding_dim, embedding_dim] matrix. The
    Laplacian ``L`` is not a parameter; it must be supplied through
    ``set_laplacian`` before the module is called.
    """

    def __init__(self, embedding_dim):
        super().__init__()
        self.theta = nn.Parameter(0.01 * torch.randn(embedding_dim, embedding_dim))
        self.L = None

    def set_laplacian(self, L):
        """Store the batched Laplacian used by subsequent ``forward`` calls."""
        self.L = L

    def forward(self, t, S):
        """Return dS/dt for state ``S``; the time argument ``t`` is not used.

        Raises:
            ValueError: if no Laplacian has been set.
        """
        laplacian = self.L
        if laplacian is None:
            raise ValueError("Laplacian matrix L has not been set.")
        transformed = torch.matmul(S, self.theta)
        # bmm requires both operands to be 3-D batched matrices.
        return -torch.bmm(laplacian, transformed)

def rk4_step(func, s, t, dt):
    """Advance state ``s`` from time ``t`` by one classical Runge-Kutta-4 step.

    Args:
        func: callable ``func(t, s)`` returning the derivative ds/dt.
        s: current state.
        t: current time.
        dt: step size.

    Returns:
        The state after one step of size ``dt``.
    """
    half_dt = dt / 2
    t_mid = t + half_dt

    slope_start = func(t, s)
    slope_mid_a = func(t_mid, s + half_dt * slope_start)
    slope_mid_b = func(t_mid, s + half_dt * slope_mid_a)
    slope_end = func(t + dt, s + dt * slope_mid_b)

    # Weighted average of the four slopes: 1, 2, 2, 1.
    blended = slope_start + 2 * slope_mid_a + 2 * slope_mid_b + slope_end
    return s + dt / 6 * blended

class NeuralODE(nn.Module):
    """Integrates the state with a fixed number of RK4 steps of ``NeuralODEFunc``.

    Args:
        embedding_dim: feature width of the state, passed to ``NeuralODEFunc``.
        n_steps: number of equally sized RK4 steps per ``forward`` call.
    """
    def __init__(self, embedding_dim, n_steps=10):
        super(NeuralODE, self).__init__()
        self.ode_func = NeuralODEFunc(embedding_dim)
        self.n_steps = n_steps

    def forward(self, S, t_span, L):
        """Integrate ``S`` over ``t_span`` under Laplacian ``L``.

        Args:
            S: initial state; it is multiplied as ``bmm(L, S @ theta)`` by the
               ODE function, so it is expected to be batched.
            t_span: two-element sequence ``(t_start, t_end)``. A non-positive
               start is replaced by 1e-5.
            L: batched Laplacian handed to the ODE function.

        Returns:
            Tensor with the initial and the final state stacked along a new
            dimension 1, i.e. ``out[:, 0]`` is ``S`` and ``out[:, 1]`` is the
            integrated state.
        """
        self.ode_func.set_laplacian(L)
        t0 = t_span[0] if t_span[0] > 0.0 else 1e-5
        t1 = t_span[1]
        dt = (t1 - t0) / self.n_steps
        s = S
        t = t0
        for _ in range(self.n_steps):
            s = rk4_step(self.ode_func, s, t, dt)
            # Out-of-place on purpose: when t_span holds tensors, ``t`` aliases
            # t_span[0] and an in-place ``+=`` would overwrite the caller's data.
            t = t + dt
        return torch.stack([S, s], dim=1)


class TDHNODE(nn.Module):
    """
    Hypergraph neural-ODE model: at every time step the node states are turned
    into an attention-weighted hypergraph Laplacian, integrated forward with a
    neural ODE, classified per biomarker and then shifted by the projected
    risk factors of the next step.

    Args:
        num_biomarkers: number of leading nodes that receive a prediction.
        num_risk_factors: input width of the risk-factor projection ``W_C``.
        hyper_nodes: number of hypergraph nodes (defaults to ``num_biomarkers``).
        hyper_edges: number of hyperedges (defaults to 13).
        hidden_dim: width of the node state.
        use_node_mask: if True, nodes whose biomarker value at the current
            step is <= 0.5 are passed to the attention module as inactive.
    """
    def __init__(self, num_biomarkers=21, num_risk_factors=34, hyper_nodes=None, hyper_edges=None, hidden_dim=128, use_node_mask=False):
        super(TDHNODE, self).__init__()
        self.num_biomarkers = num_biomarkers
        self.num_risk_factors = num_risk_factors
        self.hyper_nodes = hyper_nodes if hyper_nodes is not None else num_biomarkers
        self.hyper_edges = hyper_edges if hyper_edges is not None else 13
        self.hidden_dim = hidden_dim
        self.use_node_mask = use_node_mask

        self.S0 = nn.Parameter(torch.randn(self.hyper_nodes, hidden_dim) * 0.01)
        self.W_C = nn.Linear(num_risk_factors, hidden_dim)
        self.hyper_attn = HypergraphCrossAttention(hidden_dim)
        # Not used by forward() in this file; left in place so the module's
        # parameter set is unchanged.
        self.H_proj = nn.Linear(self.hyper_edges, self.hyper_nodes)
        self.neural_ode = NeuralODE(hidden_dim, n_steps=10)
        self.classifier = nn.Linear(hidden_dim, 1)

    def _hypergraph_laplacian(self, H_p, X_E, identity, eps=1e-3):
        """Build the symmetrised normalised Laplacian I - Dv^-1/2 H W De^-1 H^T Dv^-1/2.

        Args:
            H_p: [B, N, E] weighted incidence matrix.
            X_E: [B, E, d] hyperedge representations; their cosine similarities
                 form the hyperedge weight matrix W.
            identity: identity matrix broadcastable to [B, N, N].
            eps: added to the degrees before inversion.

        Returns:
            [B, N, N] Laplacian with NaN/inf removed and entries clamped to
            [-1e3, 1e3].
        """
        G = F.normalize(X_E, dim=-1)  # [B, E, d]
        W_p = torch.bmm(G, G.transpose(1, 2))  # cosine similarities in [-1, 1]

        D_e = H_p.sum(dim=1) + eps  # hyperedge degrees [B, E]
        D_e_inv = torch.diag_embed(torch.clamp(1.0 / D_e, max=1e3))

        D_v = H_p.sum(dim=2) + eps  # node degrees [B, N]
        D_v_inv_sqrt = torch.diag_embed(torch.clamp(1.0 / torch.sqrt(D_v), max=1e3))

        Hp_Wp = torch.bmm(H_p, torch.bmm(W_p, D_e_inv))
        middle = torch.bmm(Hp_Wp, H_p.transpose(1, 2))
        D_norm = torch.bmm(D_v_inv_sqrt, torch.bmm(middle, D_v_inv_sqrt))

        L_new = identity - D_norm
        # W_p @ D_e_inv is not symmetric in general, so symmetrise explicitly.
        L_new = 0.5 * (L_new + L_new.transpose(1, 2))
        L_new = torch.nan_to_num(L_new, nan=0.0, posinf=1e3, neginf=-1e3)
        return torch.clamp(L_new, -1e3, 1e3)

    def forward(self, times, train_biomarkers, risk_factor, H, mask):
        """Roll the state forward over the sequence and predict at every step.

        Args:
            times: indexed as ``times[:, t]`` and ``times[:, t + 1]`` for every
                step, so it needs at least ``seq_len + 1`` columns.
            train_biomarkers: [B, seq_len, hyper_nodes]; step 0 scales the
                learned initial node state, and later steps provide the node
                mask when ``use_node_mask`` is set.
            risk_factor: [B, >= seq_len, num_risk_factors]; step ``t + 1`` is
                added to the state after step ``t``.
            H: incidence matrix handed to the hypergraph attention module.
            mask: accepted for interface compatibility; not read here.

        Returns:
            Tensor of per-step classifier outputs stacked along dimension 1,
            one value for each of the first ``num_biomarkers`` nodes.
        """
        batch_size = train_biomarkers.size(0)
        seq_len = train_biomarkers.size(1)
        device = train_biomarkers.device

        # Initial state: learned node embeddings gated by the first observation,
        # shifted by the projected risk factors of step 0.
        presence = train_biomarkers[:, 0, :]
        S0_expanded = self.S0.unsqueeze(0).expand(batch_size, -1, -1)
        S_t = presence.unsqueeze(-1) * S0_expanded
        S_t = S_t + self.W_C(risk_factor[:, 0, :]).unsqueeze(1)

        # Loop-invariant; broadcasts over the batch dimension.
        identity = torch.eye(self.hyper_nodes, device=device).unsqueeze(0)

        predictions = []

        for t in range(seq_len):
            t_i = times[:, t]
            t_ip1 = times[:, t + 1]
            # One shared step length per batch: the mean gap, limited to [1e-3, 0.5].
            delta_t = torch.mean(t_ip1 - t_i).item()
            delta_t = float(np.clip(delta_t, 1e-3, 0.5))
            time_span = [0.0, delta_t]

            T_node = t_i.unsqueeze(1).expand(batch_size, self.hyper_nodes)
            node_mask = (train_biomarkers[:, t, :] > 0.5) if self.use_node_mask else None

            H_p, X_E = self.hyper_attn(S_t, T_node, H, node_mask=node_mask)
            if H_p.dim() == 2:
                H_p = H_p.unsqueeze(0)
            if X_E.dim() == 2:
                X_E = X_E.unsqueeze(0)

            L_new = self._hypergraph_laplacian(H_p, X_E, identity)

            ode_out = self.neural_ode(S_t, time_span, L_new)
            S_next = ode_out[:, -1, :, :]

            logits = self.classifier(S_next[:, :self.num_biomarkers, :]).squeeze(-1)
            predictions.append(logits)

            # The state after the last step is never read, so the risk factors
            # of a step beyond the sequence are not needed.
            if t + 1 < seq_len:
                S_t = S_next + self.W_C(risk_factor[:, t + 1, :]).unsqueeze(1)

        return torch.stack(predictions, dim=1)
