# hpc_sgt.py
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal, kl_divergence

class GELU(nn.Module):
    """Tanh approximation of the Gaussian Error Linear Unit, applied element-wise.

    Computes ``0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x**3)))``,
    the same formula PyTorch documents for ``approximate="tanh"`` GELU.
    """

    # sqrt(2 / pi), evaluated once at class creation rather than on every call.
    _SQRT_2_OVER_PI = math.sqrt(2.0 / math.pi)

    def forward(self, x):
        """Return the tanh-approximated GELU of ``x`` (same shape as ``x``)."""
        inner = self._SQRT_2_OVER_PI * (x + 0.044715 * torch.pow(x, 3))
        return 0.5 * x * (1.0 + torch.tanh(inner))

class GraphTransformerLayer(nn.Module):
    """Post-norm Transformer encoder layer: self-attention, then a GELU feed-forward net.

    Each sub-layer is applied as ``norm(x + dropout(sublayer(x)))``.

    Args:
        embed_dim: token width, used for both input and output.
        num_heads: number of heads passed to ``nn.MultiheadAttention``.
        dropout: rate used inside attention and on both residual branches.
    """

    def __init__(self, embed_dim, num_heads, dropout=0.1):
        super().__init__()
        hidden_width = embed_dim * 4  # 4x expansion inside the feed-forward block
        self.multihead_attn = nn.MultiheadAttention(
            embed_dim, num_heads, dropout=dropout, batch_first=True
        )
        self.ffn = nn.Sequential(
            nn.Linear(embed_dim, hidden_width),
            GELU(),
            nn.Linear(hidden_width, embed_dim),
        )
        self.norm1, self.norm2 = nn.LayerNorm(embed_dim), nn.LayerNorm(embed_dim)
        self.dropout1, self.dropout2 = nn.Dropout(dropout), nn.Dropout(dropout)

    def forward(self, x, attn_mask=None):
        """Run attention and feed-forward sub-layers over ``x``.

        Args:
            x: batch-first input ``(batch, seq, embed_dim)`` (``batch_first=True``).
            attn_mask: optional mask forwarded unchanged to ``nn.MultiheadAttention``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        attended = self.multihead_attn(x, x, x, attn_mask=attn_mask)[0]
        after_attn = self.norm1(x + self.dropout1(attended))
        return self.norm2(after_attn + self.dropout2(self.ffn(after_attn)))

class VAE(nn.Module):
    """Small MLP variational auto-encoder with a diagonal-Gaussian latent.

    The encoder emits ``2 * latent_dim`` values per row, split into a mean and a
    log-variance; the decoder maps a latent sample back to ``input_dim``.
    """

    def __init__(self, input_dim, latent_dim, hidden_dim=64):
        super().__init__()
        encoder_layers = [
            nn.Linear(input_dim, hidden_dim),
            GELU(),
            nn.Linear(hidden_dim, latent_dim * 2),  # [mu | logvar] concatenated
        ]
        decoder_layers = [
            nn.Linear(latent_dim, hidden_dim),
            GELU(),
            nn.Linear(hidden_dim, input_dim),
        ]
        self.encoder = nn.Sequential(*encoder_layers)
        self.decoder = nn.Sequential(*decoder_layers)

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + sigma * eps`` with ``eps ~ N(0, I)`` and ``sigma = exp(logvar / 2)``."""
        sigma = (0.5 * logvar).exp()
        return mu + torch.randn_like(sigma) * sigma

    def forward(self, x):
        """Encode, sample and decode ``x``.

        Returns:
            ``(x_recon, mu, logvar)``; ``mu`` and ``logvar`` are the two halves of
            the encoder output along the last axis.
        """
        mu, logvar = self.encoder(x).chunk(2, dim=-1)
        return self.decoder(self.reparameterize(mu, logvar)), mu, logvar

class HPC_SGT(nn.Module):
    """Edge-level graph transformer trained with prototype, VAE and consistency losses.

    Each edge ``(u, v)`` is embedded from a learned row of ``features_a`` (indexed
    by ``u``) and a learned row of ``features_b`` (indexed by ``v``). All edge
    embeddings form one sequence that a stack of ``GraphTransformerLayer`` blocks
    refines, with attention biased by ``rse_bias`` plus a learned mix of
    ``tme_paths``. ``loss`` combines a multi-prototype contrastive objective, a
    VAE reconstruction / oversampling objective on minority edges, and a
    consistency term between the pre- and post-transformer embeddings.

    Args:
        n: number of rows in ``features_a`` (valid range of edge endpoint ``u``).
        m: number of rows in ``features_b`` (valid range of edge endpoint ``v``).
        num_line_nodes: accepted for interface compatibility; unused in this class.
        embed_size: embedding width used throughout.
        num_heads: attention heads per transformer layer.
        num_layers: number of transformer layers.
        num_prototypes_per_class: prototypes per class; exactly two classes are used.
        dropout: dropout rate inside each transformer layer.
    """

    def __init__(self, n, m, num_line_nodes, embed_size=32, num_heads=4, num_layers=2,
                 num_prototypes_per_class=4, dropout=0.1):
        super().__init__()
        self.n = n
        self.m = m
        self.embed_size = embed_size
        self.num_prototypes_per_class = num_prototypes_per_class

        # Free learned embeddings for the two endpoint tables an edge (u, v) indexes.
        self.features_a = nn.Parameter(torch.randn((n, embed_size)))
        self.features_b = nn.Parameter(torch.randn((m, embed_size)))

        # Projects the concatenated endpoint embeddings to a single edge embedding.
        self.link_embed_layer = nn.Linear(embed_size * 2, embed_size)

        # One learned weight per slice of tme_paths' last axis (size 3).
        self.tme_weights = nn.Parameter(torch.randn(3))
        self.transformer_layers = nn.ModuleList(
            [GraphTransformerLayer(embed_size, num_heads, dropout) for _ in range(num_layers)]
        )

        self.lvo_vae = VAE(input_dim=embed_size, latent_dim=embed_size // 2)
        # Rows [0, P) are class-0 prototypes, rows [P, 2P) are class-1 prototypes.
        self.prototypes = nn.Parameter(torch.randn(num_prototypes_per_class * 2, embed_size))

        self.classifier = nn.Linear(embed_size, 1)

    def forward(self, idx_to_edge, rse_bias, tme_paths):
        """Embed every edge and refine the embeddings with the transformer stack.

        Args:
            idx_to_edge: mapping ``edge_index -> (u, v)``. Output rows follow the
                mapping's iteration order.
                NOTE(review): rse_bias / tme_paths are presumably indexed in the
                same order; confirm callers insert keys 0..E-1 in ascending order.
            rse_bias: additive attention bias, broadcastable to ``(E, E)``.
            tme_paths: tensor ``(E, E, 3)``; its last axis is mixed by the learned
                ``tme_weights`` into an extra ``(E, E)`` bias.

        Returns:
            ``(h_b, h_l)``: edge embeddings before and after the transformer stack,
            each ``(E, embed_size)``.
        """
        u_indices, v_indices = zip(*idx_to_edge.values())

        # Build index tensors on the parameters' device so GPU models index cleanly.
        device = self.features_a.device
        h_u = self.features_a[torch.tensor(u_indices, device=device).long()]
        h_v = self.features_b[torch.tensor(v_indices, device=device).long()]
        h_b = self.link_embed_layer(torch.cat([h_u, h_v], dim=1))

        # Weighted sum over the path-type axis: (E, E, 3) x (3,) -> (E, E).
        path_bias = torch.einsum("ijk,k->ij", tme_paths, self.tme_weights)
        # nn.MultiheadAttention adds a float attn_mask to the attention logits.
        attn_bias = rse_bias + path_bias

        h_l = h_b.unsqueeze(0)  # one batch holding a sequence of E edge tokens
        for layer in self.transformer_layers:
            h_l = layer(h_l, attn_mask=attn_bias)
        return h_b, h_l.squeeze(0)

    def loss(self, h_b, h_l, labels, minority_mask, gamma1, gamma2, beta_vae, beta1, beta2, tau):
        """Compute the total training loss and the per-edge classifier logits.

        ``total = mpcl + gamma1 * lvo + gamma2 * co_train`` with
        ``mpcl = compact + beta1 * separation + beta2 * prototype_contrastive``.

        Args:
            h_b, h_l: the two outputs of ``forward``, each ``(E, embed_size)``.
            labels: per-edge labels ``(E,)``; presumably in {-1, +1}, mapped to
                class ``(label + 1) / 2`` (truncated to an integer).
            minority_mask: selects the minority rows of ``h_l`` that train the VAE
                and are oversampled. May select nothing.
            gamma1, gamma2: weights of the VAE term and the consistency term.
            beta_vae: weight of the VAE KL term.
            beta1, beta2: weights of the separation and prototype-contrastive terms.
            tau: temperature dividing the cosine similarities.

        Returns:
            ``(total_loss, logits)`` with ``logits`` of shape ``(E,)``.
        """
        lvo_loss, h_syn, labels_syn = self._lvo_loss(h_l, labels, minority_mask, beta_vae)

        h_combined = torch.cat([h_l, h_syn], dim=0)
        labels_combined = torch.cat([labels, labels_syn], dim=0)
        # -1 -> class 0, +1 -> class 1.
        class_labels = ((labels_combined + 1) / 2).long().to(h_combined.device)

        distances = torch.cdist(h_combined, self.prototypes)
        compact_loss = self._compactness_loss(distances, class_labels)
        sep_loss = self._separation_loss()
        proto_loss = self._prototype_contrastive_loss(h_combined, distances, class_labels, tau)
        mpcl_loss = compact_loss + beta1 * sep_loss + beta2 * proto_loss

        co_train_loss = self._co_train_loss(h_b, h_l)
        total_loss = mpcl_loss + gamma1 * lvo_loss + gamma2 * co_train_loss

        # squeeze(-1), not squeeze(): a single edge must still yield shape (1,).
        logits = self.classifier(h_l).squeeze(-1)
        return total_loss, logits

    def _lvo_loss(self, h_l, labels, minority_mask, beta_vae):
        """Fit the VAE on minority rows and decode as many synthetic minority rows.

        Returns ``(lvo_loss, h_syn, labels_syn)``. If ``minority_mask`` selects no
        rows, the loss is zero and ``h_syn`` / ``labels_syn`` are empty.
        """
        minority_h = h_l[minority_mask]
        n_minority = minority_h.size(0)
        if n_minority == 0:
            return (
                h_l.new_zeros(()),
                h_l.new_zeros((0, h_l.size(1))),
                labels.new_zeros((0,)),
            )

        recon_h, mu, logvar = self.lvo_vae(minority_h)
        recon_loss = F.mse_loss(recon_h, minority_h)
        # Closed-form KL(N(mu, sigma^2) || N(0, I)), summed over rows and latent dims.
        kld_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        lvo_loss = recon_loss + beta_vae * kld_loss

        # Synthetic rows are treated as fixed data: no gradient flows into the decoder.
        with torch.no_grad():
            latent_dim = self.lvo_vae.decoder[0].in_features
            z_syn = torch.randn(n_minority, latent_dim, device=h_l.device)
            h_syn = self.lvo_vae.decoder(z_syn)
        # Every synthetic row takes the label of the first minority row.
        labels_syn = labels[minority_mask][:1].expand(n_minority)
        return lvo_loss, h_syn, labels_syn

    def _compactness_loss(self, distances, class_labels):
        """Mean over rows of ``-log softmax(-d)`` at the closest own-class prototype."""
        n_rows = distances.size(0)
        rows = torch.arange(n_rows, device=distances.device)
        per_class = distances.reshape(n_rows, 2, self.num_prototypes_per_class)
        in_class = per_class[rows, class_labels]  # (N, P)
        closest = in_class.min(dim=1).values
        # -log(exp(-c) / sum(exp(-d))) == c + logsumexp(-d), without exp underflow.
        return (closest + torch.logsumexp(-in_class, dim=1)).mean()

    def _separation_loss(self):
        """Sum of ``exp(-distance)`` over all class-0 / class-1 prototype pairs."""
        p = self.num_prototypes_per_class
        class0_protos = self.prototypes[:p]
        class1_protos = self.prototypes[p:]
        return torch.sum(torch.exp(-torch.cdist(class0_protos, class1_protos)))

    def _prototype_contrastive_loss(self, h, distances, class_labels, tau):
        """InfoNCE over cosine similarities between each row and the prototypes.

        The positive is the nearest (by Euclidean distance) own-class prototype;
        the negatives are all prototypes of the other class.
        """
        n_rows = h.size(0)
        p = self.num_prototypes_per_class
        rows = torch.arange(n_rows, device=h.device)

        sims = F.cosine_similarity(h.unsqueeze(1), self.prototypes.unsqueeze(0), dim=-1) / tau
        sims = sims.reshape(n_rows, 2, p)
        nearest = distances.reshape(n_rows, 2, p)[rows, class_labels].argmin(dim=1)

        sim_in = sims[rows, class_labels, nearest]  # (N,)
        sim_out = sims[rows, 1 - class_labels]  # (N, P)
        # -log(e^in / (e^in + sum e^out)) == logsumexp([in, out]) - in, overflow-safe.
        all_sims = torch.cat([sim_in.unsqueeze(1), sim_out], dim=1)
        return (torch.logsumexp(all_sims, dim=1) - sim_in).mean()

    def _co_train_loss(self, h_b, h_l):
        """Symmetric consistency between prototype assignments of ``h_b`` and ``h_l``.

        NOTE(review): log-probabilities are used as means of unit-variance Normals,
        so each KL reduces to ``0.5 * (log p_b - log p_l) ** 2``; if a categorical
        KL between p_b and p_l was intended, this differs — confirm.
        """
        # log_softmax instead of softmax().log(): avoids log(0) = -inf -> NaN.
        log_p_b = F.log_softmax(-torch.cdist(h_b, self.prototypes), dim=1)
        log_p_l = F.log_softmax(-torch.cdist(h_l, self.prototypes), dim=1)
        unit = torch.ones_like(log_p_b)
        kl_b_l = kl_divergence(Normal(log_p_b, unit), Normal(log_p_l, unit)).sum(dim=1)
        kl_l_b = kl_divergence(Normal(log_p_l, unit), Normal(log_p_b, unit)).sum(dim=1)
        return (kl_b_l + kl_l_b).mean()