"""
memory_xatten_decoder.py

memory-based cross-attention decoder for dense graphs

default
decode only edges from a graph embedding z_g (connectome recon)

optional node reconstruction
Set `reconstruct_nodes=True`

inputs
  z_g : [B, d_g]

outputs
  A_hat : [B, N, N]                    
  X_hat : [B, N, F] if reconstruct_nodes=True, else None

dependencies: torch
"""

from __future__ import annotations

import math
from dataclasses import dataclass
from typing import List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


@dataclass
class GraphDecoderOutput:
    """Container for the tensors returned by ``GraphDecoder.forward``."""

    # reconstructed (symmetrised) adjacency, [B, N, N]
    A_hat: torch.Tensor
    # reconstructed node features, [B, N, F]; None when node reconstruction is disabled
    X_hat: Optional[torch.Tensor] = None


class CrossAttnLayer(nn.Module):
    """
    One block of per-node, multi-head cross-attention against a memory table.

    Every node builds its query from two sources that are summed:

      Q_i = W_q^z z_g + W_q^h h_i

    Keys and values are handed in by the caller and carry no batch
    dimension, i.e. the same memory is attended to by every graph in the
    batch.

    The block is post-LN: each sub-layer output is added to its input and
    the sum is then layer-normalised.

    shapes
      h   : [B, N, d_h]  current node states
      z_g : [B, d_g]     graph embedding
      K,V : [H, M, d_k]  per-head memory keys/values (M memory slots)
      out : [B, N, d_h]  updated node states
    """

    def __init__(self, d_h: int, d_g: int, num_heads: int, d_ff: int, dropout_p: float) -> None:
        """
        args
          d_h       : node state width; has to split evenly across heads
          d_g       : graph embedding width
          num_heads : number of attention heads H
          d_ff      : inner width of the feed-forward sub-layer
          dropout_p : dropout probability (attention weights and inside the FFN)

        raises
          ValueError if d_h is not a multiple of num_heads
        """
        super().__init__()
        if d_h % num_heads != 0:
            raise ValueError("d_h must be divisible by num_heads")

        hidden, graph_dim, ff_dim = int(d_h), int(d_g), int(d_ff)

        self.H = int(num_heads)
        self.d_k = hidden // self.H
        self.drop_p = float(dropout_p)

        # query projections (graph embedding and node state), then the head-merge projection
        self.W_q_z = nn.Linear(graph_dim, hidden, bias=False)
        self.W_q_h = nn.Linear(hidden, hidden, bias=False)
        self.W_o = nn.Linear(hidden, hidden, bias=False)

        # position-wise feed-forward sub-layer
        ffn_stages = [
            nn.Linear(hidden, ff_dim),
            nn.GELU(),
            nn.Dropout(self.drop_p),
            nn.Linear(ff_dim, hidden),
        ]
        self.ffn = nn.Sequential(*ffn_stages)

        self.ln1 = nn.LayerNorm(hidden)
        self.ln2 = nn.LayerNorm(hidden)
        self.dropout = nn.Dropout(self.drop_p)

    def forward(self, h: torch.Tensor, z_g: torch.Tensor, K: torch.Tensor, V: torch.Tensor) -> torch.Tensor:
        """Attend from every node to the memory (K, V) and return the new node states [B, N, d_h]."""
        batch, n_nodes, width = h.size()
        heads, head_dim = self.H, self.d_k

        # graph-level query part is shared by all nodes -> singleton node axis
        graph_part = self.W_q_z(z_g).view(batch, heads, head_dim).unsqueeze(2)            # [B,H,1,d_k]
        node_part = self.W_q_h(h).view(batch, n_nodes, heads, head_dim).transpose(1, 2)   # [B,H,N,d_k]
        queries = graph_part + node_part                                                  # [B,H,N,d_k]

        # scaled dot-product attention over the memory slots
        logits = torch.einsum("bhnd,hmd->bhnm", queries, K)                               # [B,H,N,M]
        weights = torch.softmax(logits / math.sqrt(head_dim), dim=-1)
        weights = self.dropout(weights)

        # weighted read-out, then put the heads back side by side
        read = torch.einsum("bhnm,hmd->bhnd", weights, V)                                 # [B,H,N,d_k]
        read = read.transpose(1, 2).reshape(batch, n_nodes, width)                        # [B,N,d_h]

        # residual add followed by LayerNorm, for attention and then the FFN
        attended = self.ln1(h + self.W_o(read))
        return self.ln2(attended + self.ffn(attended))


class GraphDecoder(nn.Module):
    """
    Decoder that turns a graph-level embedding into a dense adjacency
    matrix (and, optionally, node features) by cross-attending over a
    learned per-node memory table.

    args
    num_nodes : N, number of nodes in every decoded graph
    node_feat_dim : F (only used if reconstruct_nodes=True)
    d_g : graph embedding size
    d_m : memory embedding size per node (learned)
    d_h : hidden size in cross-attn (must be divisible by num_heads)
    d_r : node code size used for decoding edges (and optionally nodes)
    num_heads : number of attention heads
    num_layers : number of stacked cross-attention layers
    d_ff : feed-forward width inside each cross-attention layer
    dropout_p : dropout probability inside each cross-attention layer
    d_x_hidden : hidden width of the node-feature MLP
    d_e_hidden : hidden width of the edge MLP
    reconstruct_nodes : also build and run the node-feature MLP
    """

    def __init__(
        self,
        num_nodes: int,
        node_feat_dim: int,
        d_g: int,
        *,
        d_m: int = 64,
        d_h: int = 128,
        d_r: int = 64,
        num_heads: int = 4,
        num_layers: int = 2,
        d_ff: int = 256,
        dropout_p: float = 0.1,
        d_x_hidden: int = 128,
        d_e_hidden: int = 128,
        reconstruct_nodes: bool = False,
    ) -> None:
        super().__init__()
        if d_h % num_heads != 0:
            raise ValueError("d_h must be divisible by num_heads")

        def two_layer_mlp(n_in: int, n_hidden: int, n_out: int) -> nn.Sequential:
            # Linear -> GELU -> Linear, shared shape of all decoding heads
            return nn.Sequential(
                nn.Linear(n_in, n_hidden),
                nn.GELU(),
                nn.Linear(n_hidden, n_out),
            )

        mem_dim, hidden = int(d_m), int(d_h)

        self.N = int(num_nodes)
        self.F = int(node_feat_dim)
        self.d_g = int(d_g)
        self.d_r = int(d_r)
        self.H = int(num_heads)
        self.d_k = hidden // self.H
        self.reconstruct_nodes = bool(reconstruct_nodes)

        # learned memory: one d_m-dimensional row per node
        self.E = nn.Parameter(torch.randn(self.N, mem_dim))

        # projections of the memory into attention keys / values
        self.W_K = nn.Linear(mem_dim, hidden, bias=False)
        self.W_V = nn.Linear(mem_dim, hidden, bias=False)

        # projection of the memory into the initial node states
        self.in_proj = nn.Linear(mem_dim, hidden)

        # stack of cross-attention layers, applied in order
        self.layers = nn.ModuleList()
        for _ in range(int(num_layers)):
            self.layers.append(
                CrossAttnLayer(
                    d_h=hidden,
                    d_g=self.d_g,
                    num_heads=self.H,
                    d_ff=int(d_ff),
                    dropout_p=float(dropout_p),
                )
            )

        # node state -> node code r_i (hidden width is at least 128)
        self.r_net = two_layer_mlp(hidden, max(2 * self.d_r, 128), self.d_r)

        # node code + graph embedding -> node features (only when requested)
        self.x_net = (
            two_layer_mlp(self.d_r + self.d_g, int(d_x_hidden), self.F)
            if self.reconstruct_nodes
            else None
        )

        # pairwise features [r_i, r_j, r_i*r_j, |r_i-r_j|, z_g] -> scalar edge value
        self.edge_net = two_layer_mlp(4 * self.d_r + self.d_g, int(d_e_hidden), 1)

    def forward(self, z_g: torch.Tensor) -> GraphDecoderOutput:
        """
        Decode a batch of graph embeddings.

        z_g : [B, d_g]

        returns a GraphDecoderOutput with
          A_hat : [B, N, N], symmetrised by averaging with its transpose
          X_hat : [B, N, F] if reconstruct_nodes=True else None

        raises ValueError when z_g is not 2-D or its last dim is not d_g
        """
        if not (z_g.ndim == 2 and z_g.shape[1] == self.d_g):
            raise ValueError(f"z_g must be [B,{self.d_g}], got {tuple(z_g.shape)}")

        batch, n = z_g.size(0), self.N

        def split_heads(proj: nn.Module) -> torch.Tensor:
            # [N, d_h] -> [H, N, d_k]; shared by every graph in the batch
            return proj(self.E).view(n, self.H, self.d_k).transpose(0, 1).contiguous()

        K = split_heads(self.W_K)
        V = split_heads(self.W_V)

        # every graph starts from the same memory-derived node states: [B,N,d_h]
        h = self.in_proj(self.E)[None].expand(batch, -1, -1)
        for layer in self.layers:
            h = layer(h, z_g, K, V)

        # node codes: [B,N,d_r]
        r = self.r_net(h)

        # node features (optional): each node code is paired with the graph embedding
        X_hat: Optional[torch.Tensor] = None
        if self.reconstruct_nodes:
            assert self.x_net is not None
            z_nodes = z_g.unsqueeze(1).expand(-1, n, -1)          # [B,N,d_g]
            X_hat = self.x_net(torch.cat((r, z_nodes), dim=-1))

        # pairwise edge features for every ordered node pair (i, j)
        src = r.unsqueeze(2)                                      # [B,N,1,d_r]
        dst = r.unsqueeze(1)                                      # [B,1,N,d_r]
        pair_feat = torch.cat(
            (
                src.expand(-1, -1, n, -1),
                dst.expand(-1, n, -1, -1),
                src * dst,
                (src - dst).abs(),
                z_g[:, None, None, :].expand(-1, n, n, -1),
            ),
            dim=-1,
        )                                                         # [B,N,N,4d_r+d_g]

        raw = self.edge_net(pair_feat).squeeze(-1)                # [B,N,N]
        # the first two feature slots depend on pair order, so average (i,j) with (j,i)
        A_hat = 0.5 * (raw + raw.transpose(1, 2))

        return GraphDecoderOutput(A_hat=A_hat, X_hat=X_hat)

    # Losses

    @staticmethod
    def edge_mse(A: torch.Tensor, A_hat: torch.Tensor) -> torch.Tensor:
        """Squared error averaged over the last two dims (all N*N entries, diagonal included): returns [B]."""
        sq_err = (A - A_hat).pow(2)
        return sq_err.mean(dim=(-2, -1))

    @staticmethod
    def node_mse(X: torch.Tensor, X_hat: torch.Tensor) -> torch.Tensor:
        """Squared error averaged over the last two dims (nodes and features): returns [B]."""
        sq_err = (X - X_hat).pow(2)
        return sq_err.mean(dim=(-2, -1))

    @staticmethod
    def total_mse(
        A: torch.Tensor,
        A_hat: torch.Tensor,
        X: Optional[torch.Tensor] = None,
        X_hat: Optional[torch.Tensor] = None,
        *,
        beta_node: float = 1.0,
    ) -> torch.Tensor:
        """
        Combined reconstruction loss, per graph.

        - always contains the edge MSE
        - adds beta_node * node MSE only when both X and X_hat are given;
          if either one is None the node term is skipped
        """
        edge_term = GraphDecoder.edge_mse(A, A_hat)
        if X is None or X_hat is None:
            return edge_term
        return edge_term + float(beta_node) * GraphDecoder.node_mse(X, X_hat)
