"""
spatial_temporal_encoder.py

encoder utils for spatial-temporal modeling

components
- DynamicStateExtractor:
    X_time [B, N, T] -> S_dyn [B, N, d_s]
    GRU over time where each step is an N-dim whole-brain vector, then dense map

- GraphEncoderFusion (fuse at node)
    fuse structural node features S and dynamic node features S_dyn at node-token level
    then run edge-conditioned graph transformer and pool to z_g

- GraphEncoderDual (dual branch)
    spatial branch: transformer on (W, S) -> pooled s_space
    temporal branch: DynamicStateExtractor on X_time -> pooled s_time
    fusion MLP([s_space || s_time]) -> z_g

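Deterministic by default. Set `variational=True` to also produce (mu_g, logvar_g); z_g is then drawn as mu_g + exp(0.5 * logvar_g) * eps only when forward is called with `sample=True`, otherwise z_g = mu_g.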

dependencies: torch
"""

from __future__ import annotations

import math
from dataclasses import dataclass
from typing import List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


# Utils

def signed_sqrt(x: torch.Tensor, eps: float = 1e-9) -> torch.Tensor:
    """
    Element-wise signed square root: sign(x) * sqrt(|x| + eps).

    Compresses magnitudes while keeping the sign (GRIT-style stabilizer).
    `eps` keeps sqrt away from zero so its derivative stays finite; exact
    zeros still map to exactly zero because sign(0) == 0.
    """
    # |x| is already non-negative, so no clamp is needed before the sqrt
    magnitude = x.abs() + eps
    return magnitude.sqrt() * x.sign()


@dataclass
class SpatialTemporalEncoderOutput:
    """
    Result bundle returned by the graph encoders in this module.

    Fields
      z_g      : graph embedding; equals mu_g unless a sample was drawn
      aux      : encoder-specific extras (e.g. final node tokens, or a tuple of
                 spatial node tokens and dynamic node states); None if unused
      mu_g     : posterior mean, filled in only by variational encoders
      logvar_g : posterior log-variance, filled in only by variational encoders
    """

    z_g: torch.Tensor
    # free-form payload; its type depends on the producing encoder
    aux: Optional[object] = None
    # the two fields below stay None for deterministic encoders
    mu_g: Optional[torch.Tensor] = None
    logvar_g: Optional[torch.Tensor] = None


# temporal feature extractor

class DynamicStateExtractor(nn.Module):
    """
    global temporal encoder over per-node time series.

    Input
      X_time : [B, N, T]
        assumed already z-scored per node over time. set do_norm=True if not

    Steps
      1) (optional) per-node z-score over time
      2) treat time as the sequence axis; each step is an N-dim whole-brain
         vector. A single-layer GRU (optionally bidirectional) runs over it
         and is summarised as h_g : [B, H_out], either the time-mean of the
         GRU outputs (default) or the final hidden state(s) (use_last=True)
      3) dense mapping h_g -> [B, N, d_s] and LayerNorm over d_s

    Args
      N, T          : node count and sequence length; both are fixed at
                      construction and checked in forward
      H_g           : GRU hidden size per direction
      d_s           : per-node output width
      use_last      : summarise with the final hidden state(s) instead of the
                      time-mean of the GRU outputs
      bidirectional : bidirectional GRU (H_out = 2 * H_g)
      do_norm       : z-score each node's series over time before encoding
      dropout_p     : dropout applied to the summary vector h_g

    output
      S_dyn : [B, N, d_s]
    """

    def __init__(
        self,
        N: int,
        T: int,
        H_g: int,
        d_s: int,
        *,
        use_last: bool = False,
        bidirectional: bool = False,
        do_norm: bool = False,
        dropout_p: float = 0.0,
    ) -> None:
        super().__init__()
        self.N, self.T = int(N), int(T)
        self.H_g, self.d_s = int(H_g), int(d_s)
        self.use_last = bool(use_last)
        self.do_norm = bool(do_norm)
        self.bi = bool(bidirectional)
        self.H_out = 2 * self.H_g if self.bi else self.H_g

        self.encoder = nn.GRU(
            input_size=self.N,
            hidden_size=self.H_g,
            batch_first=True,
            bidirectional=self.bi,
        )
        self.dropout = nn.Dropout(float(dropout_p))

        # Dense map: H_out -> N * d_s
        self.out_linear = nn.Linear(self.H_out, self.N * self.d_s)
        self.out_norm = nn.LayerNorm(self.d_s)

    def forward(self, X_time: torch.Tensor) -> torch.Tensor:
        """
        X_time: [B, N, T]  ->  S_dyn: [B, N, d_s]

        Raises ValueError if X_time is not 3-D or if its (N, T) differ from
        the values given at construction.
        """
        if X_time.ndim != 3:
            raise ValueError(f"X_time must be [B,N,T]; got {tuple(X_time.shape)}")
        B, N, T = X_time.shape
        if N != self.N or T != self.T:
            raise ValueError(f"Mismatched (N,T): expected ({self.N},{self.T}), got ({N},{T})")

        x = X_time
        if self.do_norm:
            mu = x.mean(dim=-1, keepdim=True)
            if T > 1:
                sd = x.std(dim=-1, keepdim=True).clamp_min(1e-6)
            else:
                # The unbiased std of a single sample is NaN, and clamp_min
                # keeps NaN, which would poison every downstream value. With
                # T == 1 the centred series is identically zero, so any finite
                # divisor gives the same (zero) result.
                sd = torch.ones_like(mu)
            x = (x - mu) / sd

        # Sequence view: [B, T, N]
        seq = x.transpose(1, 2).contiguous()

        out, h_n = self.encoder(seq)  # out: [B,T,H_out], h_n: [dir,B,H_g]
        if self.use_last:
            if self.bi:
                # single-layer GRU: h_n[0] is the forward direction's final
                # state, h_n[1] the backward direction's
                h = torch.cat([h_n[0], h_n[1]], dim=-1)  # [B, 2H_g]
            else:
                h = h_n[0]                                # [B, H_g]
        else:
            h = out.mean(dim=1)                           # [B, H_out]

        h = self.dropout(h)
        dense = self.out_linear(h)                        # [B, N*d_s]
        S_dyn = dense.view(B, self.N, self.d_s)           # [B, N, d_s]
        S_dyn = self.out_norm(S_dyn)
        return S_dyn


# Graph Transformer block

class GraphTransformerLayer(nn.Module):
    """
    multi-head self-attention over nodes with edge-conditioned logits & values.
    updates both node tokens h and edge/pair tokens e.

    Per head, for every ordered node pair (i, j):
      e_hat_ij = GELU(signed_sqrt((Q_i + K_j) * Ew_ij) + Eb_ij)      -> [d_k]
      logit_ij = <e_hat_ij, w_A[head]> / sqrt(d_k)
      alpha_ij = softmax over j of logit_ij, then dropout
                 (no adjacency mask is applied: every node attends to all N)
      m_i      = sum_j alpha_ij * (V_j + W_Ev(e_hat_ij))
    Node update (residual, then LayerNorm):
      h <- LN(h + W_O m);  h <- LN(h + FFN(h))
    Edge update (residual, then LayerNorm), driven by e_hat rather than alpha:
      e <- LN(e + W_Eo(concat over heads of e_hat));  e <- LN(e + FFN_edge(e))

    Args
      d_h       : node token width; must be divisible by num_heads
      d_e       : edge token width
      num_heads : number of heads H; per-head width d_k = d_h // H
      d_ff      : hidden width of the node FFN
      d_ff_edge : hidden width of the edge FFN
      dropout_p : dropout on attention weights and inside both FFNs

    Raises
      ValueError: if d_h is not divisible by num_heads

    Memory: several intermediates (E_w, E_b, qk, e_hat, e_v, value) are
    [B, H, N, N, d_k], so activation memory grows as O(B * N^2 * d_h).
    """

    def __init__(
        self,
        d_h: int,
        d_e: int,
        num_heads: int,
        d_ff: int,
        d_ff_edge: int,
        dropout_p: float,
    ) -> None:
        super().__init__()
        if d_h % num_heads != 0:
            raise ValueError("d_h must be divisible by num_heads")
        self.d_h, self.d_e, self.H = int(d_h), int(d_e), int(num_heads)
        self.d_k = self.d_h // self.H
        self.drop_p = float(dropout_p)

        # node projections (Q/K/V split into H heads of width d_k in forward)
        self.W_Q = nn.Linear(self.d_h, self.d_h, bias=False)
        self.W_K = nn.Linear(self.d_h, self.d_h, bias=False)
        self.W_V = nn.Linear(self.d_h, self.d_h, bias=False)
        self.W_O = nn.Linear(self.d_h, self.d_h, bias=False)

        # edge projections: W_Ew gives a multiplicative gate, W_Eb an additive
        # bias (both d_e -> H * d_k); W_Ev is a single d_k -> d_k map shared
        # by all heads, applied to e_hat to condition the values
        self.W_Ew = nn.Linear(self.d_e, self.d_h, bias=False)
        self.W_Eb = nn.Linear(self.d_e, self.d_h, bias=False)
        self.W_Ev = nn.Linear(self.d_k, self.d_k, bias=False)

        # mix heads back for edges: concat(H*d_k) -> d_e
        self.W_Eo = nn.Linear(self.H * self.d_k, self.d_e, bias=False)

        # per-head vector to reduce d_k -> scalar logit
        # NOTE(review): standard-normal init with no fan-in scaling; logits are
        # only rescaled by 1/sqrt(d_k) in forward
        self.w_A = nn.Parameter(torch.randn(self.H, self.d_k))

        # node FFN + norms
        self.ffn_node = nn.Sequential(
            nn.Linear(self.d_h, int(d_ff)),
            nn.GELU(),
            nn.Dropout(self.drop_p),
            nn.Linear(int(d_ff), self.d_h),
        )
        self.ln1_node = nn.LayerNorm(self.d_h)
        self.ln2_node = nn.LayerNorm(self.d_h)

        # edge FFN + norms
        self.ffn_edge = nn.Sequential(
            nn.Linear(self.d_e, int(d_ff_edge)),
            nn.GELU(),
            nn.Dropout(self.drop_p),
            nn.Linear(int(d_ff_edge), self.d_e),
        )
        self.ln1_edge = nn.LayerNorm(self.d_e)
        self.ln2_edge = nn.LayerNorm(self.d_e)

    def forward(self, h: torch.Tensor, e: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        h: [B, N, d_h]     node tokens
        e: [B, N, N, d_e]  edge/pair tokens

        returns (h, e), updated, with the same shapes as the inputs
        """
        B, N, _ = h.shape

        # node project -> [B, H, N, d_k]
        Q = self.W_Q(h).view(B, N, self.H, self.d_k).transpose(1, 2)
        K = self.W_K(h).view(B, N, self.H, self.d_k).transpose(1, 2)
        V = self.W_V(h).view(B, N, self.H, self.d_k).transpose(1, 2)

        # edge project -> [B, H, N, N, d_k]
        E_w = self.W_Ew(e).view(B, N, N, self.H, self.d_k).permute(0, 3, 1, 2, 4).contiguous()
        E_b = self.W_Eb(e).view(B, N, N, self.H, self.d_k).permute(0, 3, 1, 2, 4).contiguous()

        # edge-conditioned logits; qk[b,h,i,j] = Q_i + K_j (additive, not a dot product)
        qk = Q.unsqueeze(3) + K.unsqueeze(2)               # [B,H,N,N,d_k]
        gated = qk * E_w
        e_hat = F.gelu(signed_sqrt(gated) + E_b)           # [B,H,N,N,d_k]

        # reduce each e_hat_ij to a scalar with the head's w_A; softmax over
        # source nodes j (last dim) with no masking
        logits = (e_hat * self.w_A.view(1, self.H, 1, 1, self.d_k)).sum(-1)  # [B,H,N,N]
        alpha = torch.softmax(logits / math.sqrt(self.d_k), dim=-1)
        alpha = F.dropout(alpha, p=self.drop_p, training=self.training)

        # edge-conditioned value: value[b,h,i,j] = V_j + W_Ev(e_hat_ij)
        V_j = V.unsqueeze(2)                                # [B,H,1,N,d_k]
        e_v = self.W_Ev(e_hat.reshape(-1, self.d_k)).view(B, self.H, N, N, self.d_k)
        value = V_j + e_v

        # aggregation over source nodes j -> [B,H,N,d_k]
        m = (alpha.unsqueeze(-1) * value).sum(dim=3)

        # update node tokens: merge heads, residual + LayerNorm, FFN residual + LayerNorm
        m = m.transpose(1, 2).contiguous().view(B, N, self.d_h)
        h = self.ln1_node(h + self.W_O(m))
        h = self.ln2_node(h + self.ffn_node(h))

        # update edge tokens from e_hat (heads concatenated -> d_e)
        e_cat = e_hat.permute(0, 2, 3, 1, 4).contiguous().view(B, N, N, self.H * self.d_k)
        m_e = self.W_Eo(e_cat)
        e = self.ln1_edge(e + m_e)
        e = self.ln2_edge(e + self.ffn_edge(e))

        return h, e


# Encoders
class GraphEncoderFusion(nn.Module):
    """
    Early-fusion graph encoder: structural and dynamic node features are merged
    into one node token *before* any message passing.

    Pipeline
      1) tokens = LayerNorm(input_proj(S) + softplus(lambda0) * dyn_proj(S_dyn)),
         where lambda0 is a learnable scalar initialised to `gate_init`
      2) edge tokens = edge_init(W[..., None]) : [B, N, N, d_e]
      3) `num_layers` GraphTransformerLayer blocks update (tokens, edge tokens)
      4) additive attention pooling over nodes -> s : [B, d_h]
      5) mu_proj(s) -> mu_g : [B, d_g]; with variational=True also
         logvar_proj(s) -> logvar_g

    Deterministic unless `variational=True` and forward receives `sample=True`.
    """

    def __init__(
        self,
        in_dim: int,
        d_s: int,
        d_h: int,
        d_e: int,
        num_heads: int,
        d_ff: int,
        d_ff_edge: int,
        num_layers: int,
        d_g: int,
        dropout_p: float,
        *,
        gate_init: float = 0.1,
        variational: bool = False,
    ) -> None:
        super().__init__()
        self.variational = bool(variational)
        width = int(d_h)
        edge_width = int(d_e)

        # Registration order below is deliberate: it fixes the order in which
        # parameters draw from the RNG (seeded init) and the state_dict keys.

        # token initialisers
        self.input_proj = nn.Linear(int(in_dim), width)
        self.dyn_proj = nn.Linear(int(d_s), width)
        self.edge_init = nn.Linear(1, edge_width)

        # learnable dynamic-feature gate; it is passed through softplus in
        # forward, so the effective weight is non-negative
        self.lambda0 = nn.Parameter(torch.tensor(float(gate_init), dtype=torch.float32))
        self.ln_fuse = nn.LayerNorm(width)

        self.layers = nn.ModuleList(
            GraphTransformerLayer(width, edge_width, int(num_heads), int(d_ff), int(d_ff_edge), float(dropout_p))
            for _ in range(int(num_layers))
        )

        # additive attention pooling: score_n = <tanh(pool_proj(t_n)), pool_context>
        self.pool_proj = nn.Linear(width, width)
        self.pool_context = nn.Parameter(torch.randn(width))

        # graph embedding heads
        self.mu_proj = nn.Linear(width, int(d_g))
        self.logvar_proj = nn.Linear(width, int(d_g)) if self.variational else None

    @staticmethod
    def _attention_pool(tokens: torch.Tensor, proj: nn.Module, context: torch.Tensor) -> torch.Tensor:
        """Pool [B, N, D] tokens to [B, D] using weights softmax_n(<tanh(proj(t_n)), context>)."""
        weights = torch.softmax(torch.tanh(proj(tokens)) @ context, dim=-1)
        return (weights.unsqueeze(-1) * tokens).sum(dim=1)

    def forward(
        self,
        W: torch.Tensor,
        S: torch.Tensor,
        S_dyn: torch.Tensor,
        *,
        sample: bool = False,
    ) -> SpatialTemporalEncoderOutput:
        """
        W      : [B, N, N] scalar edge weights
        S      : [B, N, in_dim] structural node features
        S_dyn  : [B, N, d_s] dynamic node features
        sample : when the encoder is variational, draw z_g by
                 reparameterisation instead of returning mu_g

        Returns SpatialTemporalEncoderOutput whose aux holds the final node
        tokens [B, N, d_h]. Raises ValueError on rank or shape mismatches.
        """
        if W.ndim != 3:
            raise ValueError(f"W must be [B,N,N]; got {tuple(W.shape)}")
        if not (S.ndim == 3 and S_dyn.ndim == 3):
            raise ValueError(f"S and S_dyn must be [B,N,*]; got S={tuple(S.shape)}, S_dyn={tuple(S_dyn.shape)}")
        batch, nodes = S.shape[0], S.shape[1]
        if tuple(W.shape) != (batch, nodes, nodes):
            raise ValueError(f"Shape mismatch: W={tuple(W.shape)}, S={tuple(S.shape)}")
        if tuple(S_dyn.shape[:2]) != (batch, nodes):
            raise ValueError(f"Shape mismatch: S={tuple(S.shape)}, S_dyn={tuple(S_dyn.shape)}")

        # early fusion of the two node-feature views
        gate = F.softplus(self.lambda0)
        tokens = self.ln_fuse(self.input_proj(S) + gate * self.dyn_proj(S_dyn))
        edges = self.edge_init(W.unsqueeze(-1))

        for block in self.layers:
            tokens, edges = block(tokens, edges)

        pooled = self._attention_pool(tokens, self.pool_proj, self.pool_context)
        mu_g = self.mu_proj(pooled)

        if not self.variational:
            return SpatialTemporalEncoderOutput(z_g=mu_g, aux=tokens)

        assert self.logvar_proj is not None
        logvar_g = self.logvar_proj(pooled)
        z_g = mu_g
        if sample:
            sigma = torch.exp(0.5 * logvar_g)
            z_g = mu_g + sigma * torch.randn_like(sigma)
        return SpatialTemporalEncoderOutput(z_g=z_g, mu_g=mu_g, logvar_g=logvar_g, aux=tokens)


class GraphEncoderDual(nn.Module):
    """
    Two-branch encoder that keeps space and time separate until the very end.

    Branches
      spatial  : s_input(S) node tokens + edge_init(W) edge tokens, a stack of
                 GraphTransformerLayer blocks, attention pooling -> s_space : [B, d_h]
      temporal : DynamicStateExtractor(X_time) -> [B, N, d_s], t_proj -> d_h,
                 attention pooling -> s_time : [B, d_h]
      fusion   : MLP([s_space || s_time]) -> s_fuse : [B, d_h] -> z_g : [B, d_g]

    Deterministic unless `variational=True` and forward is called with
    `sample=True`.

    forward takes:
      W      : [B, N, N]
      S      : [B, N, in_dim]
      X_time : [B, N, T]
    """

    def __init__(
        self,
        # spatial
        in_dim: int,
        d_h: int,
        d_e: int,
        num_heads: int,
        d_ff: int,
        d_ff_edge: int,
        num_layers: int,
        # temporal
        N: int,
        T: int,
        H_g: int,
        d_s: int,
        # posterior
        d_g: int,
        dropout_p: float,
        *,
        bidirectional: bool = False,
        use_last: bool = False,
        do_norm_time: bool = False,
        variational: bool = False,
    ) -> None:
        super().__init__()
        self.variational = bool(variational)
        width = int(d_h)
        edge_width = int(d_e)

        # Registration order below is deliberate: it fixes the order in which
        # parameters draw from the RNG (seeded init) and the state_dict keys.

        # --- spatial pathway ---
        self.s_input = nn.Linear(int(in_dim), width)
        self.edge_init = nn.Linear(1, edge_width)
        self.s_layers = nn.ModuleList(
            GraphTransformerLayer(width, edge_width, int(num_heads), int(d_ff), int(d_ff_edge), float(dropout_p))
            for _ in range(int(num_layers))
        )
        self.s_pool_proj = nn.Linear(width, width)
        self.s_pool_context = nn.Parameter(torch.randn(width))

        # --- temporal pathway ---
        self.temporal = DynamicStateExtractor(
            N=int(N),
            T=int(T),
            H_g=int(H_g),
            d_s=int(d_s),
            use_last=use_last,
            bidirectional=bidirectional,
            do_norm=do_norm_time,
            dropout_p=float(dropout_p),
        )
        self.t_proj = nn.Linear(int(d_s), width)
        self.t_pool_proj = nn.Linear(width, width)
        self.t_pool_context = nn.Parameter(torch.randn(width))

        # --- fusion MLP + graph heads ---
        self.fuse = nn.Sequential(
            nn.Linear(2 * width, width),
            nn.GELU(),
            nn.Linear(width, width),
        )
        self.mu_proj = nn.Linear(width, int(d_g))
        self.logvar_proj = nn.Linear(width, int(d_g)) if self.variational else None

    @staticmethod
    def _attention_pool(tokens: torch.Tensor, proj: nn.Module, context: torch.Tensor) -> torch.Tensor:
        """Pool [B, N, D] tokens to [B, D] using weights softmax_n(<tanh(proj(t_n)), context>)."""
        weights = torch.softmax(torch.tanh(proj(tokens)) @ context, dim=-1).unsqueeze(-1)
        return (weights * tokens).sum(dim=1)

    def forward(
        self,
        W: torch.Tensor,
        S: torch.Tensor,
        X_time: torch.Tensor,
        *,
        sample: bool = False,
    ) -> SpatialTemporalEncoderOutput:
        """
        Encode (W, S, X_time) into a graph embedding.

        Returns SpatialTemporalEncoderOutput with aux = (final spatial node
        tokens [B, N, d_h], S_dyn [B, N, d_s]). Raises ValueError on rank or
        shape mismatches between the three inputs.
        """
        if any(t.ndim != 3 for t in (W, S, X_time)):
            raise ValueError(f"Expected W [B,N,N], S [B,N,F], X_time [B,N,T]; got W={tuple(W.shape)}, S={tuple(S.shape)}, X_time={tuple(X_time.shape)}")
        batch, nodes = S.shape[0], S.shape[1]
        if tuple(W.shape) != (batch, nodes, nodes):
            raise ValueError(f"Shape mismatch: W={tuple(W.shape)}, S={tuple(S.shape)}")
        if tuple(X_time.shape[:2]) != (batch, nodes):
            raise ValueError(f"Shape mismatch: X_time={tuple(X_time.shape)}, S={tuple(S.shape)}")

        # spatial branch
        node_tok = self.s_input(S)
        edge_tok = self.edge_init(W.unsqueeze(-1))
        for layer in self.s_layers:
            node_tok, edge_tok = layer(node_tok, edge_tok)
        s_space = self._attention_pool(node_tok, self.s_pool_proj, self.s_pool_context)

        # temporal branch
        S_dyn = self.temporal(X_time)               # [B, N, d_s]
        s_time = self._attention_pool(self.t_proj(S_dyn), self.t_pool_proj, self.t_pool_context)

        # late fusion
        s_fuse = self.fuse(torch.cat([s_space, s_time], dim=-1))
        mu_g = self.mu_proj(s_fuse)
        aux = (node_tok, S_dyn)

        if not self.variational:
            return SpatialTemporalEncoderOutput(z_g=mu_g, aux=aux)

        assert self.logvar_proj is not None
        logvar_g = self.logvar_proj(s_fuse)
        z_g = mu_g
        if sample:
            sigma = torch.exp(0.5 * logvar_g)
            z_g = mu_g + sigma * torch.randn_like(sigma)
        return SpatialTemporalEncoderOutput(z_g=z_g, mu_g=mu_g, logvar_g=logvar_g, aux=aux)
