# -*- coding: utf-8 -*-
from typing import Dict
import torch
import torch.nn as nn

class CausalGraphModuleLite(nn.Module):
    """Lightweight message passing over a windowed, strictly-causal graph.

    Every segment is a node. Node ``i`` receives messages from the at most
    ``epsilon`` nodes immediately preceding it (``i - epsilon <= j < i``);
    messages are mean-aggregated, passed through a linear layer + GELU +
    dropout, and added residually to the node state.

    Config keys read (with defaults):
        ``causal_window_epsilon`` (3.0): window size, truncated to ``int``.
        ``gnn_layers`` (2): number of linear message-passing layers.
        ``gnn_dropout`` (0.1): dropout probability applied to each update.
    """

    def __init__(self, shared_dim: int, config: Dict):
        super().__init__()
        self.config = config
        self.shared_dim = shared_dim
        # Accept "3", 3 or 3.0 in the config; the window is a whole number.
        self.epsilon = int(float(self.config.get("causal_window_epsilon", 3.0)))
        self.gnn_layers = nn.ModuleList([nn.Linear(shared_dim, shared_dim)
                                         for _ in range(self.config.get("gnn_layers", 2))])
        self.gnn_dropout = nn.Dropout(self.config.get("gnn_dropout", 0.1))
        self.gnn_act = nn.GELU()
        # Diagnostics of the most recent forward pass (also returned by it).
        self.last_info: Dict = {}

    def forward(self, seg_seq: torch.Tensor, img_tok=None, txt_tok=None):
        """Run causal message passing over the segment nodes.

        Args:
            seg_seq: ``[B, S, L, D]`` (averaged over ``L`` to obtain nodes)
                or ``[B, S, D]`` (used as nodes directly).
            img_tok: Unused; accepted for interface compatibility.
            txt_tok: Unused; accepted for interface compatibility.

        Returns:
            ``(h, info)`` where ``h`` is the updated node tensor ``[B, S, D]``
            and ``info`` (also stored as ``self.last_info``) holds:
            ``V`` (node count ``B*S``), ``E`` (number of mask entries set),
            ``dbar`` (``2*E / V``), ``epsilon``, ``causal_mask`` (0/1,
            ``[B, S, S]``), ``adj_logits`` (cosine similarity between nodes),
            ``edge_scores`` (``sigmoid(adj_logits)`` restricted to the mask)
            and ``flops_est``.

        Raises:
            ValueError: if ``seg_seq`` is neither 3-D nor 4-D.
        """
        # seg_seq: [B,S,L,D] -> pool tokens inside each segment; else [B,S,D].
        nodes = seg_seq.mean(dim=2) if seg_seq.dim() == 4 else seg_seq
        if nodes.dim() != 3:
            raise ValueError(
                "seg_seq must have shape [B,S,L,D] or [B,S,D], got "
                f"{tuple(seg_seq.shape)}"
            )
        B, S, D = nodes.shape
        device = nodes.device
        eps = int(self.epsilon)

        # Windowed causal mask: entry (i, j) is 1 iff 1 <= i - j <= eps.
        # The pattern is the same for every batch element, so build it once
        # in vectorized form and replicate it to keep the [B,S,S] layout.
        idx = torch.arange(S, device=device)
        lag = idx.unsqueeze(1) - idx.unsqueeze(0)  # lag[i, j] = i - j
        window = (lag >= 1) & (lag <= eps)
        causal_mask = (
            window.to(torch.get_default_dtype()).unsqueeze(0).repeat(B, 1, 1)
        )

        # Row-normalise so aggregation is a mean over predecessors; rows
        # without predecessors keep degree 1 to avoid division by zero.
        deg = causal_mask.sum(dim=-1, keepdim=True).clamp_min(1.0)
        # Cast to the node dtype: the batched matmul below does not promote
        # dtypes, so a float32 adjacency would fail for half/double models.
        adj_norm = (causal_mask / deg).to(nodes.dtype)  # [B,S,S]

        # Cosine-similarity edge scores. These are diagnostic outputs only;
        # the message passing below uses the uniform adj_norm weights.
        x = torch.nn.functional.normalize(nodes, dim=-1)
        sim = torch.einsum('bid,bjd->bij', x, x)
        adj_logits = sim
        edge_scores = torch.sigmoid(adj_logits) * causal_mask

        # Residual message passing.
        h = nodes
        for lin in self.gnn_layers:
            m = torch.einsum('bij,bjd->bid', adj_norm, h)  # [B,S,D]
            h2 = self.gnn_act(lin(m))
            h2 = self.gnn_dropout(h2)
            h = h + h2

        # Graph statistics. Row i of the mask has min(i, eps) ones (none when
        # eps <= 0), so the edge count is known on the host without a
        # device synchronisation.
        edges_per_graph = sum(min(i, eps) for i in range(S)) if eps > 0 else 0
        E = int(B * edges_per_graph)
        dbar = float((2 * E) / max(1, B * S)) if S > 0 else 0.0
        num_layers = len(self.gnn_layers)

        # Per layer: dense aggregation (S*S*D) plus the linear map (S*D*D).
        flops_est = int(B * num_layers * (S * S * D + S * D * D))

        self.last_info = {
            "V": int(B * S),
            "E": E,
            "dbar": dbar,
            "epsilon": eps,
            # Tensors exposed for downstream inspection.
            "causal_mask": causal_mask,   # 0/1
            "adj_logits": adj_logits,
            "edge_scores": edge_scores,
            "flops_est": flops_est,
        }
        return h, self.last_info
