# model/layers/causal_graph.py


from typing import Dict, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F


def _safe_row_normalize(A: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # A: [B,N,N]  ->
    row_sum = A.sum(-1, keepdim=True).clamp_min(eps)  # [B,N,1]
    return A / row_sum


class CausalGraphModule(nn.Module):
    """Encode node features and build a soft, row-normalized adjacency over them.

    Each edge weight is the product of three factors:

    * a temporal decay ``exp(-dt / tau)`` with a learnable ``tau``. Under strict
      causality it is restricted to edges whose source precedes the target.
    * a cosine similarity, mapped to [0, 1], between learned projections of
      the encoded node features.
    * an optional cross-modality gate that down-weights pairs with different
      modalities unless they are time-aligned.

    Config keys read here (defaults in parentheses):
        input_dim (256), hidden_dim (max(64, input_dim // 2)),
        strict_causality (True), allow_instantaneous (False),
        window_epsilon (0.0), min_lag (0.0), max_lag (1e9), tau_time (4.0),
        gamma_cross (0.8), edge_threshold (0.0; read on every call),
        cross_modal: {enabled (True), modality_pairs ([]),
                      modality_vocab ({0: "ts", 1: "img", 2: "txt"}),
                      alignment ({"same_index_only": True, "window": 0})}.

    Raises:
        ValueError: if ``tau_time`` is not strictly positive.
    """

    def __init__(self, config: Dict):
        super().__init__()
        self.cfg = config

        self.D = int(config.get("input_dim", 256))
        self.H = int(config.get("hidden_dim", max(64, self.D // 2)))

        # Causality and temporal-window settings.
        self.strict = bool(config.get("strict_causality", True))
        self.allow_instant = bool(config.get("allow_instantaneous", False))

        self.win_eps = float(config.get("window_epsilon", 0.0))
        self.min_lag = float(config.get("min_lag", 0.0))
        self.max_lag = float(config.get("max_lag", 1e9))

        self.tau_time = float(config.get("tau_time", 4.0))
        # tau is stored in log space below. A value of zero would give a -inf
        # parameter and a negative value a NaN one, so reject both here.
        # `not x > 0` also catches NaN.
        if not self.tau_time > 0:
            raise ValueError(f"tau_time must be > 0, got {self.tau_time}")

        c = config.get("cross_modal", {}) or {}
        self.cross_enabled = bool(c.get("enabled", True))
        # Kept for config introspection; not read elsewhere in this class.
        self.modality_pairs = c.get("modality_pairs", [])
        self.mod_vocab = c.get("modality_vocab", {0: "ts", 1: "img", 2: "txt"})
        # `or` also covers an explicit `alignment: None` (and `{}`, whose lookups
        # would fall back to these same defaults anyway).
        self.align_cfg = c.get("alignment") or {"same_index_only": True, "window": 0}

        self.gamma_cross = float(config.get("gamma_cross", 0.8))

        self.node_encoder = nn.Sequential(
            nn.Linear(self.D, self.H), nn.GELU(), nn.LayerNorm(self.H),
            nn.Linear(self.H, self.D),
        )

        # Bias-free projection used only for the similarity term.
        self.sim_proj = nn.Linear(self.D, self.D, bias=False)
        nn.init.xavier_uniform_(self.sim_proj.weight)

        # Learnable decay scale, parameterized in log space so exp() keeps it positive.
        self.log_tau_pos = nn.Parameter(torch.log(torch.tensor(self.tau_time)))

    # ----------
    @staticmethod
    def _pairwise_time_delta(
        tpos: Optional[torch.Tensor], B: int, N: int, device: torch.device
    ) -> torch.Tensor:
        """Return float32 ``dt[b, i, j] = t[b, i] - t[b, j]`` with shape [B, N, N].

        If ``tpos`` is None, every pair gets dt = 0. A 1-D [N] position tensor
        is shared across the batch.
        """
        if tpos is None:
            return torch.zeros(B, N, N, device=device)
        if tpos.is_floating_point():
            # float64 inputs stay float64 until after the subtraction.
            t = tpos if tpos.dtype == torch.float64 else tpos.float()
        else:
            # Subtract integer timestamps exactly in int64 before the float32
            # cast. Epoch-scale values would otherwise lose precision, and
            # unsigned dtypes would wrap around on negative differences.
            t = tpos.long()
        if t.dim() == 1:
            t = t.unsqueeze(0).expand(B, -1)
        return (t.unsqueeze(2) - t.unsqueeze(1)).float()

    def _temporal_weights(self, dt: torch.Tensor) -> torch.Tensor:
        """Map pairwise time offsets ``dt`` ([B, N, N]) to edge weights."""
        tau = self.log_tau_pos.exp().clamp_min(1e-3)  # > 0

        # Strict mode decays only over the non-negative part of dt; the mask
        # below zeroes the rest. Relaxed mode decays symmetrically.
        decay_arg = dt.clamp_min(0.0) if self.strict else dt.abs()
        time_w = torch.exp(-decay_arg / tau)
        if self.win_eps > 0:
            # Pairs within the tolerance window get full weight.
            time_w = torch.where(dt.abs() <= self.win_eps, torch.ones_like(time_w), time_w)
        if not self.strict:
            return time_w

        # Keep edge (i, j) only when dt[i, j] = t_i - t_j > 0, i.e. j precedes i
        # (or dt == 0 when instantaneous edges are allowed), and the lag lies
        # in [min_lag, max_lag].
        allow = (dt >= 0) if self.allow_instant else (dt > 0)
        allow = allow & (dt >= self.min_lag) & (dt <= self.max_lag)
        return time_w * allow.float()

    def _cross_modal_gate(
        self,
        mods: Optional[torch.Tensor],   # [B,N] or [N] (long ids), or None
        dt: torch.Tensor,               # [B,N,N]
        like: torch.Tensor,             # tensor whose shape/dtype/device the gate copies
    ) -> torch.Tensor:
        """Return a multiplicative gate: 1 for same-modality or time-aligned pairs, else gamma_cross."""
        if not self.cross_enabled or mods is None:
            return torch.ones_like(like)

        mods = mods.long()
        if mods.dim() == 1:
            mods = mods.unsqueeze(0).expand(dt.shape[0], -1)
        same = mods.unsqueeze(2) == mods.unsqueeze(1)  # [B,N,N]

        if self.align_cfg.get("same_index_only", True):
            aligned = dt == 0
        else:
            # float (not int) so fractional windows such as 1.5 are honored.
            aligned = dt.abs() <= float(self.align_cfg.get("window", 0))

        return torch.where(
            same | aligned,
            torch.ones_like(like),
            torch.full_like(like, self.gamma_cross),
        )

    def _build_relaxed_adjacency(
        self,
        X: torch.Tensor,                # [B,N,D]
        tpos: Optional[torch.Tensor],   # [B,N], [N], or None
        mods: Optional[torch.Tensor],   # [B,N], [N], or None (long ids)
    ) -> torch.Tensor:
        """Build a row-normalized [B, N, N] adjacency with a zero diagonal.

        Rows whose candidate edges are all zero remain all-zero after
        normalization.

        NOTE(review): with ``tpos=None`` every dt is 0. Under strict causality
        that leaves no edges unless ``allow_instantaneous`` is set, and every
        pair counts as time-aligned, so ``gamma_cross`` never applies.
        Confirm this is intended.
        """
        B, N, _ = X.shape
        device = X.device

        dt = self._pairwise_time_delta(tpos, B, N, device)   # [B,N,N]
        time_w = self._temporal_weights(dt)                   # [B,N,N]

        # Cosine similarity of projected features, mapped from [-1, 1] to [0, 1].
        Xp = F.normalize(self.sim_proj(X), dim=-1, eps=1e-6)
        sim = torch.einsum("bid,bjd->bij", Xp, Xp).clamp(-1.0, 1.0)
        sim01 = 0.5 * (sim + 1.0)

        cross_gate = self._cross_modal_gate(mods, dt, sim01)

        A_raw = (time_w * sim01 * cross_gate).clamp(0.0, 1.0)

        # Remove self-loops.
        eye = torch.eye(N, device=device).unsqueeze(0)
        A_no_self = A_raw * (1.0 - eye)

        thr = float(self.cfg.get("edge_threshold", 0.0))
        if thr > 0:
            A_thr = torch.where(A_no_self >= thr, A_no_self, torch.zeros_like(A_no_self))
            # If thresholding would empty a row, keep that row's unthresholded weights.
            row_nonzero = (A_thr.sum(-1, keepdim=True) > 0).float()
            A = A_thr * row_nonzero + A_no_self * (1.0 - row_nonzero)
        else:
            A = A_no_self

        return _safe_row_normalize(A, eps=1e-6)  # [B,N,N]

    def forward(
        self,
        features: torch.Tensor,                   # [B,N,D]
        temporal_positions: Optional[torch.Tensor] = None,  # [B,N], [N], or None
        modality_labels: Optional[torch.Tensor] = None,     # [B,N], [N], or None
        use_sparse_causal: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encode node features and build their adjacency.

        Args:
            features: [B, N, D] node features, with D == ``input_dim``.
            temporal_positions: Per-node times; None means all nodes are simultaneous.
            modality_labels: Per-node integer modality ids, or None to disable gating.
            use_sparse_causal: Accepted for interface compatibility; currently ignored.

        Returns:
            ``(feat, A)``: encoded features [B, N, D] and adjacency [B, N, N].
        """
        feat = self.node_encoder(features)  # [B,N,D]
        A = self._build_relaxed_adjacency(feat, temporal_positions, modality_labels)
        return feat, A
