"""
edge_self_atten_encoder.py

edge-conditioned self-attention encoder for dense, weighted graphs.

inputs (note: GraphEncoder.forward takes them in the order W, X)
  X : [B, N, F]  node features
  W : [B, N, N]  weighted adjacency (real-valued, symmetric; self-loops allowed)

outputs (default)
  out.z_g : [B, d_g]       graph embedding (deterministic)
  out.H   : [B, N, d_h]    node embeddings

optional variational
set GraphEncoder(..., variational=True) to also produce (mu_g, logvar_g) and
optionally sample z_g via reparameterization.

dependencies: torch
"""

from __future__ import annotations

import math
from dataclasses import dataclass
from typing import List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


def signed_sqrt(x: torch.Tensor, eps: float = 1e-9) -> torch.Tensor:
    """
    element-wise signed square root (GRIT-style)

    params
    x   : tensor of any shape
    eps : constant added to |x| under the root

    rtns
    sign(x) * sqrt(|x| + eps), same shape as x. exact zeros stay zero because
    sign(0) == 0.
    """
    # |x| is non-negative by construction, so it can go under the root directly
    magnitude = (x.abs() + eps).sqrt()
    return x.sign() * magnitude


@dataclass
class GraphEncoderOutput:
    """
    container for encoder outputs

    mu_g and logvar_g default to None; GraphEncoder.forward only fills them in
    when the encoder was built with variational=True.
    """
    # graph embedding, [B, d_g]
    z_g: torch.Tensor
    # node embeddings, [B, N, d_h]
    H: torch.Tensor
    # mean of the graph embedding, [B, d_g]; None on the deterministic path
    mu_g: Optional[torch.Tensor] = None
    # log-variance of the graph embedding, [B, d_g]; None on the deterministic path
    logvar_g: Optional[torch.Tensor] = None


class GraphTransformerLayer(nn.Module):
    """
    one encoder layer: multi-head self-attention whose logits and values are
    conditioned on pairwise edge tokens, followed by feed-forward blocks.

    both streams are updated with residual + LayerNorm:
      nodes : h <- LN(h + W_O(attn)),    h <- LN(h + FFN(h))
      edges : e <- LN(e + W_Eo(e_hat)),  e <- LN(e + FFN(e))
    """

    def __init__(
        self,
        d_h: int,
        d_e: int,
        num_heads: int,
        d_ff: int,
        d_ff_edge: int,
        dropout_p: float,
    ) -> None:
        """
        params
        d_h       : node token width (must be a multiple of num_heads)
        d_e       : edge token width
        num_heads : number of attention heads
        d_ff      : hidden width of the node feed-forward block
        d_ff_edge : hidden width of the edge feed-forward block
        dropout_p : dropout probability for attention weights and both FFNs

        raises
        ValueError if d_h is not divisible by num_heads
        """
        super().__init__()
        if d_h % num_heads:
            raise ValueError("d_h must be divisible by num_heads.")

        self.d_h = int(d_h)
        self.d_e = int(d_e)
        self.H = int(num_heads)
        self.d_k = self.d_h // self.H  # per-head width
        self.drop_p = float(dropout_p)

        # NOTE: submodules are created in a fixed order so that seeded
        # initialisation and state_dict keys stay stable.

        # node-side query / key / value / output maps (heads are split in forward)
        self.W_Q = self._proj(self.d_h, self.d_h)
        self.W_K = self._proj(self.d_h, self.d_h)
        self.W_V = self._proj(self.d_h, self.d_h)
        self.W_O = self._proj(self.d_h, self.d_h)

        # edge token -> multiplicative gate (W_Ew) and additive bias (W_Eb)
        self.W_Ew = self._proj(self.d_e, self.d_h)
        self.W_Eb = self._proj(self.d_e, self.d_h)

        # per-head map that turns the pair feature into a value correction
        self.W_Ev = self._proj(self.d_k, self.d_k)

        # merges the concatenated heads (H * d_k) back into an edge token
        self.W_Eo = self._proj(self.H * self.d_k, self.d_e)

        # one d_k-vector per head; dotting with it yields the scalar logit
        self.w_A = nn.Parameter(torch.randn(self.H, self.d_k))

        # node stream: feed-forward + the two norms
        self.ffn_node = self._ffn(self.d_h, d_ff, self.drop_p)
        self.ln1_node = nn.LayerNorm(self.d_h)
        self.ln2_node = nn.LayerNorm(self.d_h)

        # edge stream: feed-forward + the two norms
        self.ffn_edge = self._ffn(self.d_e, d_ff_edge, self.drop_p)
        self.ln1_edge = nn.LayerNorm(self.d_e)
        self.ln2_edge = nn.LayerNorm(self.d_e)

    @staticmethod
    def _proj(d_in: int, d_out: int) -> nn.Linear:
        """bias-free linear map d_in -> d_out"""
        return nn.Linear(d_in, d_out, bias=False)

    @staticmethod
    def _ffn(d_model: int, d_hidden: int, p: float) -> nn.Sequential:
        """Linear -> GELU -> Dropout -> Linear block that maps d_model back to d_model"""
        return nn.Sequential(
            nn.Linear(d_model, d_hidden),
            nn.GELU(),
            nn.Dropout(p),
            nn.Linear(d_hidden, d_model),
        )

    def _split_node_heads(self, t: torch.Tensor, B: int, N: int) -> torch.Tensor:
        """[B, N, d_h] -> [B, H, N, d_k]"""
        return t.view(B, N, self.H, self.d_k).transpose(1, 2)

    def _split_edge_heads(self, t: torch.Tensor, B: int, N: int) -> torch.Tensor:
        """[B, N, N, d_h] -> [B, H, N, N, d_k]"""
        return t.view(B, N, N, self.H, self.d_k).permute(0, 3, 1, 2, 4).contiguous()

    def forward(self, h: torch.Tensor, e: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        params
        h : [B, N, d_h]       node states
        e : [B, N, N, d_e]    edge/pair tokens

        rtns
        h_new : [B, N, d_h]
        e_new : [B, N, N, d_e]
        """
        B, N, _ = h.size()

        # per-head node projections: [B, H, N, d_k]
        Q = self._split_node_heads(self.W_Q(h), B, N)
        K = self._split_node_heads(self.W_K(h), B, N)
        V = self._split_node_heads(self.W_V(h), B, N)

        # per-head edge gate / bias: [B, H, N, N, d_k]
        E_w = self._split_edge_heads(self.W_Ew(e), B, N)
        E_b = self._split_edge_heads(self.W_Eb(e), B, N)

        # pair feature for (i, j): query_i + key_j, gated by the edge token,
        # squashed with the signed sqrt, shifted by the edge bias
        pair = Q[:, :, :, None, :] + K[:, :, None, :, :]   # [B, H, N, N, d_k]
        e_hat = F.gelu(signed_sqrt(pair * E_w) + E_b)      # [B, H, N, N, d_k]

        # collapse d_k with the per-head vector -> attention over j for each i
        head_vec = self.w_A.view(1, self.H, 1, 1, self.d_k)
        logits = (e_hat * head_vec).sum(-1)                # [B, H, N, N]
        alpha = torch.softmax(logits / math.sqrt(self.d_k), dim=-1)
        alpha = F.dropout(alpha, p=self.drop_p, training=self.training)

        # values: V_j (shared across all i) plus an edge-dependent correction
        e_v = self.W_Ev(e_hat.flatten(0, 3)).view_as(e_hat)
        value = V[:, :, None, :, :] + e_v                  # [B, H, N, N, d_k]

        # attention-weighted sum over neighbours j -> [B, H, N, d_k]
        m = (alpha[..., None] * value).sum(dim=3)

        # node stream: merge heads, then residual + norm twice
        attn = m.transpose(1, 2).reshape(B, N, self.d_h)
        h = self.ln1_node(h + self.W_O(attn))
        h = self.ln2_node(h + self.ffn_node(h))

        # edge stream: merge heads of the pair feature back to d_e, same pattern
        e_cat = e_hat.permute(0, 2, 3, 1, 4).reshape(B, N, N, self.H * self.d_k)
        e = self.ln1_edge(e + self.W_Eo(e_cat))
        e = self.ln2_edge(e + self.ffn_edge(e))

        return h, e


class GraphEncoder(nn.Module):
    """
    graph encoder: runs a stack of GraphTransformerLayer blocks over node and
    edge tokens, then attention-pools the final node states into a graph-level
    embedding z_g. node embeddings H are returned alongside.

    deterministic by default. set `variational=True` to also output
    (mu_g, logvar_g) and optionally sample z_g via the reparameterization trick.
    """

    def __init__(
        self,
        in_dim: int,
        d_h: int,
        d_e: int,
        num_heads: int,
        d_ff: int,
        d_ff_edge: int,
        num_layers: int,
        d_g: int,
        dropout_p: float,
        *,
        variational: bool = False,
    ) -> None:
        """
        params
        in_dim      : raw node feature width F
        d_h         : node token width
        d_e         : edge token width
        num_heads   : attention heads per layer
        d_ff        : hidden width of each layer's node FFN
        d_ff_edge   : hidden width of each layer's edge FFN
        num_layers  : number of stacked GraphTransformerLayer blocks
        d_g         : graph embedding width
        dropout_p   : dropout probability handed to every layer
        variational : if True, also build the log-variance head
        """
        super().__init__()
        self.variational = bool(variational)

        # token initialisers: node features -> d_h, scalar edge weight -> d_e
        self.input_proj = nn.Linear(int(in_dim), int(d_h))
        self.edge_init = nn.Linear(1, int(d_e))

        # transformer stack
        stack = []
        for _ in range(int(num_layers)):
            stack.append(
                GraphTransformerLayer(
                    d_h=int(d_h),
                    d_e=int(d_e),
                    num_heads=int(num_heads),
                    d_ff=int(d_ff),
                    d_ff_edge=int(d_ff_edge),
                    dropout_p=float(dropout_p),
                )
            )
        self.layers = nn.ModuleList(stack)

        # attention pooling: score_i = tanh(pool_proj(H_i)) . pool_context
        self.pool_proj = nn.Linear(int(d_h), int(d_h))
        self.pool_context = nn.Parameter(torch.randn(int(d_h)))

        # graph embedding heads (log-variance head exists only when variational)
        self.mu_proj = nn.Linear(int(d_h), int(d_g))
        if self.variational:
            self.logvar_proj = nn.Linear(int(d_h), int(d_g))
        else:
            self.logvar_proj = None

    def forward(
        self,
        W: torch.Tensor,
        X: torch.Tensor,
        *,
        sample: bool = False,
    ) -> GraphEncoderOutput:
        """
        params
        W : [B, N, N]  weighted adjacency
        X : [B, N, F]  node features
        sample:
            only consulted when `variational=True`. if True, z_g is sampled as
            mu_g + exp(0.5 * logvar_g) * noise; otherwise z_g is the mean.

        rtns
        GraphEncoderOutput with:
          z_g : [B, d_g]
          H   : [B, N, d_h]
          mu_g/logvar_g set only if variational=True

        raises
        ValueError if W or X is not 3-D, or if W is not [B, N, N] for X's B and N
        """
        if W.ndim != 3:
            raise ValueError(f"W must be [B,N,N]; got {tuple(W.shape)}")
        if X.ndim != 3:
            raise ValueError(f"X must be [B,N,F]; got {tuple(X.shape)}")
        batch, n_nodes = X.shape[0], X.shape[1]
        if tuple(W.shape) != (batch, n_nodes, n_nodes):
            raise ValueError(f"Shape mismatch: W={tuple(W.shape)}, X={tuple(X.shape)}")

        # initial tokens; the signed scalar weight is fed through a linear map as-is
        h = self.input_proj(X)             # [B, N, d_h]
        e = self.edge_init(W[..., None])   # [B, N, N, d_e]

        for block in self.layers:
            h, e = block(h, e)

        H = h  # final node embeddings

        # attention pooling over nodes: s = sum_i w_i * H_i
        pool_scores = torch.matmul(torch.tanh(self.pool_proj(H)), self.pool_context)  # [B, N]
        pool_w = torch.softmax(pool_scores, dim=-1)                                   # [B, N]
        s = (pool_w.unsqueeze(-1) * H).sum(dim=1)                                     # [B, d_h]

        mu_g = self.mu_proj(s)  # [B, d_g]

        # deterministic path: the mean is the embedding
        if not self.variational:
            return GraphEncoderOutput(z_g=mu_g, H=H)

        assert self.logvar_proj is not None
        logvar_g = self.logvar_proj(s)  # [B, d_g]

        z_g = mu_g
        if sample:
            # reparameterization: z = mu + sigma * eps with eps ~ N(0, I)
            sigma = torch.exp(0.5 * logvar_g)
            z_g = mu_g + sigma * torch.randn_like(sigma)

        return GraphEncoderOutput(z_g=z_g, H=H, mu_g=mu_g, logvar_g=logvar_g)
