# -*- coding: utf-8 -*-
import os
import math
from typing import Dict, List, Tuple, Optional, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

# =========================
# Optional project imports (best effort)
# =========================
def _try_import(path_mod, name):
    """Best-effort lookup of attribute ``name`` in module ``path_mod``.

    Returns the attribute on success, or ``None`` when either the import or
    the attribute lookup raises any ``Exception``.
    """
    try:
        module = __import__(path_mod, fromlist=[name])
    except Exception:
        return None
    try:
        return getattr(module, name)
    except Exception:
        return None

# Every name below is resolved with _try_import, first from the short package
# path and then from the "model."-prefixed path. A name is None when neither
# import succeeds; later code checks for that before using it.

# Encoders
ImageEncoder = _try_import("layers.image_encoder", "ImageEncoder") or _try_import("model.layers.image_encoder", "ImageEncoder")
TextEncoder  = _try_import("layers.text_encoder",  "TextEncoder")  or _try_import("model.layers.text_encoder",  "TextEncoder")
TimeSeriesEncoder = _try_import("layers.ts_encoder", "TimeSeriesEncoder") or _try_import("model.layers.ts_encoder", "TimeSeriesEncoder")
StaticEncoder = _try_import("layers.st_encoder", "StaticEncoder") or _try_import("model.layers.st_encoder", "StaticEncoder")

# EAMC (segmentation, multiscale attention, segment scoring)
DynamicSegmentationModule = _try_import("modules.dynamic_segmentation", "DynamicSegmentationModule") or _try_import("model.modules.dynamic_segmentation", "DynamicSegmentationModule")
GatedMultiscaleAttention = _try_import("layers.gma", "GatedMultiscaleAttention") or _try_import("model.layers.gma", "GatedMultiscaleAttention")
SegmentImportanceScorer  = _try_import("layers.segment_importance", "SegmentImportanceScorer") or _try_import("model.layers.segment_importance", "SegmentImportanceScorer")
DynamicComputeAllocator  = _try_import("layers.segment_importance", "DynamicComputeAllocator") or _try_import("model.layers.segment_importance", "DynamicComputeAllocator")

# Graph
CausalGraphModule = _try_import("layers.causal_graph", "CausalGraphModule") or _try_import("model.layers.causal_graph", "CausalGraphModule")

# Heads & utils
AEBlock = _try_import("backbone.ae", "AEBlock") or _try_import("model.backbone.ae", "AEBlock")
MLPHead = _try_import("backbone.head", "MLPHead") or _try_import("model.backbone.head", "MLPHead")
TrajHead = _try_import("layers.traj_head", "TrajHead") or _try_import("model.layers.traj_head", "TrajHead")
TemperatureScaledAttention = _try_import("layers.temperature_attention", "TemperatureScaledAttention") or _try_import("model.layers.temperature_attention", "TemperatureScaledAttention")
compute_irr_bcr = _try_import("backbone.eamc_utils", "compute_irr_bcr") or _try_import("model.backbone.eamc_utils", "compute_irr_bcr")

# Fallbacks (minimal local stand-ins used when the project modules above cannot be imported)
if MLPHead is None:
    class MLPHead(nn.Module):
        """Fallback head: Linear -> GELU -> Dropout -> Linear."""

        def __init__(self, in_dim, out_dim, hidden=256, dropout=0.1):
            super().__init__()
            # Layers are created in the same order as before so that
            # parameter initialisation and state_dict keys are unchanged.
            layers = [
                nn.Linear(in_dim, hidden),
                nn.GELU(),
                nn.Dropout(dropout),
                nn.Linear(hidden, out_dim),
            ]
            self.net = nn.Sequential(*layers)

        def forward(self, x):
            """Apply the MLP to the last dimension of ``x``."""
            return self.net(x)

if TemperatureScaledAttention is None:
    class TemperatureScaledAttention(nn.Module):
        """Fallback attention: ``nn.MultiheadAttention`` whose inputs are divided by ``tau``.

        ``tau`` is an ``nn.Parameter`` when ``learnable`` is true, otherwise a
        plain 0-dim tensor attribute.
        """

        def __init__(self, d, nhead=8, dropout=0.1, batch_first=True, init_tau=1.0, learnable=True):
            super().__init__()
            self.attn = nn.MultiheadAttention(d, nhead, dropout=dropout, batch_first=batch_first)
            tau = torch.tensor(float(init_tau))
            if learnable:
                tau = nn.Parameter(tau)
            self.tau = tau

        def forward(self, q, k, v, key_padding_mask=None, attn_mask=None):
            """Return the attention output only (attention weights are discarded)."""
            # NOTE(review): the value tensor is scaled by 1/tau as well as the
            # query and key; confirm this is intended rather than scaling the
            # attention logits only.
            scaled_q, scaled_k, scaled_v = (t / self.tau for t in (q, k, v))
            attended, _weights = self.attn(
                scaled_q,
                scaled_k,
                scaled_v,
                key_padding_mask=key_padding_mask,
                attn_mask=attn_mask,
            )
            return attended

if compute_irr_bcr is None:

    def compute_irr_bcr(x: torch.Tensor, seg_mask: torch.Tensor, assign: torch.Tensor) -> Dict[str, float]:
        """Fallback IRR/BCR statistics, averaged over the batch.

        x: [B,T,D]; seg_mask: [B,S] (accepted but not read here);
        assign: [B,T] segment id of every time step.

        IRR is ``1 - NRMSE`` of reconstructing each time step by the mean of
        its segment; BCR is the fraction of adjacent time steps whose segment
        id differs. Runs under ``torch.no_grad()`` and returns Python floats.
        """
        with torch.no_grad():
            num_samples, num_steps, _ = x.shape
            irr_scores, bcr_scores = [], []
            for b in range(num_samples):
                sample = x[b]
                labels = assign[b]  # [T]

                # Piecewise-constant reconstruction: every step takes its segment mean.
                recon = torch.zeros_like(sample)
                for seg_id in torch.unique(labels).tolist():
                    members = (labels == int(seg_id))
                    recon[members] = sample[members].mean(dim=0, keepdim=True)

                residual = sample - recon
                nrmse = torch.sqrt((residual ** 2).mean()) / (sample.std() + 1e-6)
                irr_scores.append(float(1.0 - nrmse.item()))

                if num_steps > 1:
                    boundary_rate = (labels[1:] != labels[:-1]).float().mean().item()
                else:
                    boundary_rate = 0.0
                bcr_scores.append(float(boundary_rate))

            irr = sum(irr_scores) / len(irr_scores)
            bcr = sum(bcr_scores) / len(bcr_scores)
            return {"IRR": irr, "BCR": bcr}

# =========================
# Model
# =========================
class RestructuredMultiModalNet(nn.Module):
    """
    EAMC + C^2SG + Multi-Modal Fusion.
    Supports:
      - Trajectory prediction: task in {'traj','trajectory'} or config['task_type']=='traj'
      - Default tasks: classification / forecasting (maintains compatibility)
    """
    def __init__(self, config: Dict):
        """Build all sub-modules from ``config``.

        ``config`` is read with ``dict.get`` throughout, so every key is
        optional and falls back to the default written next to it. Optional
        project classes that failed to import (module-level names equal to
        ``None``) are replaced by ``nn.Identity`` or ``None`` as shown below.
        """
        super().__init__()
        self.config = config
        self.shared_dim = self.config.get("shared_dim", 128)

        # caches (for logging/plots)
        self.last_eamc_info: Dict[str, Union[float, torch.Tensor, List[int]]] = {}
        self.last_graph_info: Dict[str, Union[float, int]] = {}

        # ---- Ablation Control Flags ----
        self.use_multimodal = self.config.get("use_multimodal", True)
        self.use_image = self.config.get("use_image", True)
        self.use_text = self.config.get("use_text", True)
        self.use_static = self.config.get("use_static", True)
        self.use_dynseg = self.config.get("use_dynseg", True)
        self.use_causal = self.config.get("use_causal", True)
        self.use_nbc = self.config.get("use_nbc", True)

        # ---- Encoders ----
        # Image encoder (conditional on use_image and use_multimodal)
        if self.use_multimodal and self.use_image and ImageEncoder:
            self.image_encoder = ImageEncoder(out_dim=self.config.get("image_dim", 128))
            self.img_norm = nn.LayerNorm(self.config.get("image_dim", 128))
        else:
            self.image_encoder = nn.Identity()
            self.img_norm = nn.Identity()

        # Text encoder (conditional on use_text and use_multimodal)
        if self.use_multimodal and self.use_text and TextEncoder:
            self.text_encoder = TextEncoder()
            # 768 is the hard-coded input width of the projection applied to the text encoder output.
            self.text_proj = nn.Linear(768, self.config.get("text_dim", 128))
            self.txt_norm = nn.LayerNorm(self.config.get("text_dim", 128))
        else:
            self.text_encoder = nn.Identity()
            # text_proj/txt_norm are still created here; extract_features skips the
            # text branch entirely when text_encoder is an nn.Identity.
            self.text_proj = nn.Linear(self.shared_dim, self.config.get("text_dim", 128))
            self.txt_norm = nn.LayerNorm(self.config.get("text_dim", 128))

        # Time-series encoder has no ablation flag; it only depends on the import succeeding.
        self.ts_encoder = TimeSeriesEncoder(out_dim=self.config.get("ts_dim", 64)) if TimeSeriesEncoder else nn.Identity()
        self.ts_proj = nn.Linear(self.config.get("ts_dim", 64) if TimeSeriesEncoder else self.shared_dim, self.shared_dim)
        self.ts_norm = nn.LayerNorm(self.shared_dim)

        # Static encoder (conditional on use_static)
        if self.use_static and StaticEncoder:
            self.static_encoder = StaticEncoder(in_dim=self.config.get("static_dim", 16),
                                                out_dim=self.config.get("static_dim", 16))
            self.static_norm = nn.LayerNorm(self.config.get("static_dim", 16))
        else:
            self.static_encoder = nn.Identity()
            self.static_norm = nn.Identity()

        # ---- EAMC ----
        if self.use_dynseg and DynamicSegmentationModule:
            seg_cfg = dict(
                d_model=self.shared_dim, d_pe=self.shared_dim,
                max_len=self.config.get("max_seq_len", 704),
                desired_threshold=self.config.get("desired_threshold", 0.8),
                fixed_max_segments=self.config.get("max_segments", 8),
                fixed_max_len=self.config.get("segment_len", 64),
                segment_mask_top_k_ratio=self.config.get("segment_mask_top_k_ratio", 0.3),
            )
            self.dynamic_segmenter = DynamicSegmentationModule(self.shared_dim, seg_cfg)
        else:
            self.dynamic_segmenter = None
        # Configuration dict handed to GatedMultiscaleAttention; its exact
        # interpretation lives in that class (not visible in this file).
        gma_cfg = {
            "scales": [8,16,32,64], "num_heads": 8, "head_dim": self.shared_dim//8,
            "time_branch": {"enabled": True, "filter_orders": [3,5,7,9]},
            "frequency_branch": {"enabled": self.config.get("frequency_branch_enabled", False),
                                 "use_fft": self.config.get("use_fft", False),
                                 "use_wavelet": self.config.get("use_wavelet", False),
                                 "wavelet_levels": self.config.get("wavelet_levels", 2),
                                 "fft_bins": self.config.get("fft_bins", 32)},
            "filter_bank": {"num_filters": 8, "filter_sizes": [3,5,7,9], "dilation_rates": [1,2,4,8]},
            "gating": {"depth": 3, "hidden_dim": self.config.get("eamc_hidden_dim", 128)}
        }
        self.gma_fusion = GatedMultiscaleAttention(self.shared_dim, gma_cfg) if GatedMultiscaleAttention else nn.Identity()
        # None when the scorer class is unavailable; apply_eamc then scores segments by L2 norm.
        self.segment_importance_scorer = SegmentImportanceScorer(
            feature_dim=self.shared_dim,
            scoring_method=self.config.get("eamc_scoring_method", "energy"),
            hidden_dim=self.config.get("eamc_hidden_dim", 128),
            num_heads=self.config.get("eamc_num_heads", 8),
        ) if SegmentImportanceScorer else None

        # ---- Cross-modal ----
        # Per-modality projections into the shared token width.
        self.img_proj = nn.Linear(self.config.get("image_dim", 128), self.shared_dim)
        self.txt_proj_unified = nn.Linear(self.config.get("text_dim", 128), self.shared_dim)
        self.static_proj_unified = nn.Linear(self.config.get("static_dim", 16), self.shared_dim)
        # Pairwise attention modules (positional args: dim, 8 heads, dropout 0.1, batch_first=True).
        self.img_txt_attn = TemperatureScaledAttention(self.shared_dim, 8, 0.1, True,
                                                       init_tau=self.config.get("tau_img_txt", 1.0),
                                                       learnable=self.config.get("tau_img_txt_learnable", True))
        self.img_ts_attn  = TemperatureScaledAttention(self.shared_dim, 8, 0.1, True,
                                                       init_tau=self.config.get("tau_img_ts", 0.1),
                                                       learnable=self.config.get("tau_img_ts_learnable", False))
        self.txt_ts_attn  = TemperatureScaledAttention(self.shared_dim, 8, 0.1, True,
                                                       init_tau=self.config.get("tau_txt_ts", 0.5),
                                                       learnable=self.config.get("tau_txt_ts_learnable", True))
        self.fusion_attn = nn.MultiheadAttention(self.shared_dim, 8, 0.1, batch_first=True)
        self.fusion_out  = nn.Sequential(nn.Linear(self.shared_dim, self.shared_dim),
                                         nn.GELU(), nn.Dropout(0.1),
                                         nn.Linear(self.shared_dim, self.shared_dim))

        # ---- Graph ----
        # epsilon is stored as float; apply_causal_graph truncates it to int for the window size.
        self.epsilon = float(self.config.get("causal_window_epsilon", 3.0))
        self.gnn_layers = nn.ModuleList([nn.Linear(self.shared_dim, self.shared_dim)
                                         for _ in range(self.config.get("gnn_layers", 2))])
        self.gnn_dropout = nn.Dropout(self.config.get("gnn_dropout", 0.1))
        self.gnn_act = nn.GELU()

        # ---- Heads ----
        final_dim = self.shared_dim * 2  # fused_global + graph_pool
        # Both heads share the same output width: out_dim * out_len.
        self.classification_head = MLPHead(final_dim, self.config.get("out_dim", 2) * self.config.get("out_len", 1))
        self.forecast_head       = MLPHead(final_dim, self.config.get("out_dim", 2) * self.config.get("out_len", 1))

        # Trajectory
        self.traj_K = int(self.config.get("traj_K", 6))
        self.traj_T = int(self.config.get("traj_T", 12))
        self.traj_hist_dim = int(self.config.get("traj_hist_dim", 64))
        # GRU over per-agent histories with 2 input features per step.
        self.traj_hist_enc = nn.GRU(input_size=2, hidden_size=self.traj_hist_dim, num_layers=1, batch_first=True)
        # traj_head is None when TrajHead could not be imported.
        self.traj_head = TrajHead(in_dim=final_dim + self.traj_hist_dim,
                                  K=self.traj_K, T=self.traj_T,
                                  hidden=int(self.config.get("traj_hidden", 256)),
                                  dropout=float(self.config.get("traj_dropout", 0.1))) if TrajHead else None

        # ---- Aux ----
        self.enable_ae   = bool(self.config.get("enable_ae", False))
        self.enable_clip = bool(self.config.get("enable_clip", False))
        # The ae_* attributes exist only when enable_ae is set and AEBlock imported.
        if self.enable_ae and AEBlock:
            self.ae_img = AEBlock(self.shared_dim, self.config.get("ae_loss_type", "mse"))
            self.ae_txt = AEBlock(self.shared_dim, self.config.get("ae_loss_type", "mse"))
            self.ae_ts  = AEBlock(self.shared_dim, self.config.get("ae_loss_type", "mse"))

    # =========================

    # =========================
    def extract_features(self,
                         images: torch.Tensor,
                         text_item: Optional[Dict[str, torch.Tensor]],
                         packed_ts: Union[Dict[str, torch.Tensor], torch.Tensor],
                         static_data: torch.Tensor):
        """Encode every modality into tokens of width ``shared_dim``.

        Args:
            images: [B,3,H,W] or [B,N,3,H,W]. May be a non-tensor (e.g. ``None``)
                when no image is available; a zero token is used in that case.
            text_item: input for the text encoder, or ``None``.
            packed_ts: either a dict holding ``"ts_data"`` of shape [B,N_med,T,F]
                or a tensor passed straight to the time-series encoder.
            static_data: static features, one row per sample.

        Returns:
            ``(img_tok, txt_tok, ts_tok, static_vec)`` where the first three are
            [B,N,D] token tensors and ``static_vec`` is the static token with
            dimension 1 squeezed (already projected by ``static_proj_unified``
            when ``use_static`` is set, zeros of width ``shared_dim`` otherwise).
        """
        has_images = torch.is_tensor(images)
        device = images.device if has_images else next(self.parameters()).device

        # Batch size must not depend on ``images`` being a tensor: the device
        # lookup above already allows for a missing image input.
        if has_images:
            batch = images.size(0)
        elif isinstance(packed_ts, dict) and "ts_data" in packed_ts:
            batch = packed_ts["ts_data"].size(0)
        elif torch.is_tensor(packed_ts):
            batch = packed_ts.size(0)
        else:
            batch = static_data.size(0)

        # ---- image ----
        if self.use_multimodal and self.use_image and has_images:
            encoder_active = not isinstance(self.image_encoder, nn.Identity)
            if images.ndim == 5:  # [B,N,3,H,W]
                _, n_img, channels, height, width = images.shape
                # reshape (not view) so non-contiguous inputs are accepted too
                img_flat = images.reshape(batch * n_img, channels, height, width)
                img_feat = self.image_encoder(img_flat) if encoder_active else img_flat.mean(dim=[2, 3])
                img_feat = img_feat.reshape(batch, n_img, -1)  # [B,N,D_img]
            else:  # [B,3,H,W]
                img_feat = self.image_encoder(images) if encoder_active else images.mean(dim=[2, 3])
                if img_feat.dim() == 2:
                    img_feat = img_feat.unsqueeze(1)  # [B,1,D_img]
            img_feat = self.img_norm(img_feat)
            img_tok = self.img_proj(img_feat)  # [B,Ni,D]
        else:
            img_tok = torch.zeros(batch, 1, self.shared_dim, device=device)

        # ---- text ----
        text_active = (self.use_multimodal and self.use_text and (text_item is not None)
                       and not isinstance(self.text_encoder, nn.Identity))
        if text_active:
            try:
                txt_raw = self.text_encoder(text_item=text_item, preprocessed=True, return_dict=False)
            except TypeError:
                # Encoders with a simpler call signature take the item positionally.
                txt_raw = self.text_encoder(text_item)
            if txt_raw.dim() == 2:
                txt_raw = txt_raw.unsqueeze(1)
            txt_feat = self.txt_norm(self.text_proj(txt_raw))  # [B,Nt,D_txt]
            txt_tok = self.txt_proj_unified(txt_feat)          # [B,Nt,D]
        else:
            txt_tok = torch.zeros(batch, 1, self.shared_dim, device=device)

        # ---- time-series ----
        ts_encoder_active = not isinstance(self.ts_encoder, nn.Identity)
        if isinstance(packed_ts, dict) and "ts_data" in packed_ts:
            ts_raw = packed_ts["ts_data"].float()  # [B, N_med, T, F]
            ts_batch, n_series, n_steps, n_feat = ts_raw.shape
            ts_flat = ts_raw.reshape(ts_batch * n_series, n_steps, n_feat)
            ts_lat = self.ts_encoder(ts_flat) if ts_encoder_active else ts_flat
            if ts_lat.dim() == 3:
                # average the N_med series of each sample
                ts_lat = ts_lat.reshape(ts_batch, n_series, n_steps, -1).mean(dim=1)  # [B,T,ts_dim]
            else:
                # NOTE(review): this branch discards the encoder output and
                # regroups the raw input as [B,T,N_med*F]; kept as in the
                # original, confirm it is the intended fallback.
                ts_lat = ts_flat.reshape(ts_batch, n_steps, -1)
        else:
            ts_lat = self.ts_encoder(packed_ts) if ts_encoder_active else packed_ts
        ts_tok = self.ts_norm(self.ts_proj(ts_lat))  # [B,T,D]

        # ---- static ----
        if self.use_static:
            if isinstance(self.static_encoder, nn.Identity):
                st = static_data
            else:
                st = self.static_encoder(static_data)
            st = self.static_norm(st)
            st_tok = self.static_proj_unified(st.unsqueeze(1))  # [B,1,D]
        else:
            st_tok = torch.zeros(batch, self.shared_dim, device=device)

        return img_tok, txt_tok, ts_tok, st_tok.squeeze(1)

    # =========================
    # EAMC (segmentation, importance scoring and budget-aware merging)
    # =========================
    def _keep_topB_merge(self, seg_tokens: torch.Tensor, scores: torch.Tensor, B: int):
        """Keep the top-``B`` segments and merge the rest into the nearest kept one.

        "Nearest" is measured by segment index distance; ties go to the kept
        segment with the lower index. Kept segments retain their original
        (index) order in the output, so new id ``k`` is the k-th kept segment
        by position, for kept and merged segments alike.

        Args:
            seg_tokens: [Bsz,S,D] segment tokens.
            scores: [Bsz,S] importance scores (higher is kept first).
            B: number of segments to keep; clipped to ``S``.

        Returns:
            new_tokens: [Bsz,S',D], mean of the old tokens assigned to each slot.
            assign: [Bsz,S] long tensor mapping old segment id -> new id.
            merge_ops: per sample, list of ``(old_id, new_id)`` for merged segments.
            old2new: per sample, list (one entry per new id) of old ids it contains.

        Raises:
            ValueError: if ``B`` is smaller than 1.
        """
        if B < 1:
            raise ValueError(f"budget B must be >= 1, got {B}")
        Bsz, S, _ = seg_tokens.shape
        device = seg_tokens.device

        # Only the index selection is excluded from autograd; the token
        # averaging below must stay differentiable so gradients still reach
        # the time-series branch after a merge.
        with torch.no_grad():
            keep_idx = torch.topk(scores, k=min(int(B), S), dim=1, largest=True, sorted=True).indices
        keep_rows = keep_idx.tolist()

        assign = torch.zeros(Bsz, S, dtype=torch.long, device=device)
        merge_ops, old2new = [], []
        for b in range(Bsz):
            kept = sorted(keep_rows[b])
            slot_of = {old: new for new, old in enumerate(kept)}

            mapping = []
            for s in range(S):
                if s in slot_of:
                    mapping.append(slot_of[s])
                else:
                    # nearest kept segment by index distance (first wins on ties)
                    mapping.append(min(range(len(kept)), key=lambda i: abs(kept[i] - s)))

            groups = [[] for _ in kept]
            for s, new_id in enumerate(mapping):
                groups[new_id].append(s)

            merge_ops.append([(s, new_id) for s, new_id in enumerate(mapping) if s not in slot_of])
            old2new.append(groups)
            assign[b] = torch.tensor(mapping, dtype=torch.long, device=device)

        new_tokens = torch.stack([
            torch.stack([seg_tokens[b, group].mean(dim=0) for group in groups], 0)
            for b, groups in enumerate(old2new)
        ], 0)  # [Bsz,S',D]
        return new_tokens, assign, merge_ops, old2new

    def apply_eamc(self, ts_feat: torch.Tensor, ts_lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Segment the time-series features and return segment-level tokens.

        Steps:
          1. Segmentation: the dynamic segmenter when enabled and available,
             otherwise (or when it raises) ``max_segments`` equal-length
             zero-padded chunks.
          2. Budget-aware merging (``use_nbc``): if the number of segments
             exceeds the budget (env ``NBC_BUDGET``, else
             ``config['node_budget']``), keep the top-scoring segments and
             merge the rest via ``_keep_topB_merge``.
          3. Logging: IRR/BCR, per-segment time-step counts and merge ops are
             stored in ``self.last_eamc_info``.

        Args:
            ts_feat: [B,T,D] time-series features.
            ts_lengths: accepted for interface compatibility. As in the
                original code, the segmenter is always called with the full
                length ``T`` for every sample. NOTE(review): confirm whether
                per-sample lengths should be forwarded instead.

        Returns:
            ``(seg_tokens, seg_seq)`` with shapes [B,S',D] and [B,S',L',D].
        """
        B, T, D = ts_feat.shape
        device = ts_feat.device

        # 1) segmentation
        if not self.use_dynseg or self.dynamic_segmenter is None:
            # Uniform path returns directly (no merging, no IRR/BCR), as before.
            seg_tokens, seg_seq, seg_len = self._uniform_segments(ts_feat)
            S = seg_tokens.size(1)
            self.last_eamc_info = {"S": torch.tensor([S]*B, device=device), "len_hist": [seg_len]*B,
                                   "merge_ops": [], "IRR": None, "BCR": None}
            return seg_tokens, seg_seq

        try:
            full_lengths = torch.full((B,), T, device=device, dtype=torch.long)
            _, seg_seq, segment_info = self.dynamic_segmenter(ts_feat, full_lengths)
            seg_tokens = self._pool_segments(seg_seq)
        except Exception as e:
            # Best-effort: report the failure and fall back to uniform chunks.
            print(f"[EAMC] dynamic segmentation failed ({e!r}); falling back to uniform segmentation")
            seg_tokens, seg_seq, _ = self._uniform_segments(ts_feat)
            segment_info = {"IRR": None, "BCR": None}
        S = seg_tokens.size(1)
        segments_before_merge = S

        # 2) budget-aware merging
        merge_ops = []
        seg_assign = None  # [B,S_old] old segment id -> merged segment id
        if self.use_nbc:
            raw_budget = os.environ.get("NBC_BUDGET")
            if raw_budget is None:
                raw_budget = self.config.get("node_budget")
            budget = S if raw_budget is None else int(raw_budget)
            # A non-positive budget cannot keep any segment; skip merging then.
            if 0 < budget < S:
                if self.segment_importance_scorer is not None:
                    scores = self.segment_importance_scorer(seg_tokens)
                else:
                    scores = seg_tokens.norm(p=2, dim=-1)
                seg_tokens, seg_assign, merge_ops, _ = self._keep_topB_merge(seg_tokens, scores, budget)
                seg_seq = seg_tokens.unsqueeze(2)  # [B,S',1,D]
                S = seg_tokens.size(1)

        # 3) time-step -> segment assignment, approximated by evenly spaced
        #    boundaries over the pre-merge segments
        edges = torch.linspace(0, T, steps=segments_before_merge + 1, device=device).long()
        step_to_seg = torch.repeat_interleave(
            torch.arange(segments_before_merge, device=device), edges[1:] - edges[:-1]
        )  # [T]
        assign = step_to_seg.unsqueeze(0).repeat(B, 1)  # [B,T]
        if seg_assign is not None:
            # compose time-step -> old segment with old segment -> merged segment
            assign = seg_assign.gather(1, assign)

        # IRR/BCR
        irr_bcr = compute_irr_bcr(ts_feat, seg_tokens.new_ones(B, S).bool(), assign) if callable(compute_irr_bcr) else {}
        len_hist = [torch.bincount(assign[b], minlength=S).tolist() for b in range(B)]

        # Values reported by the segmenter take precedence; otherwise use the
        # ones computed above.
        irr_value = segment_info.get("IRR", None) if isinstance(segment_info, dict) else None
        bcr_value = segment_info.get("BCR", None) if isinstance(segment_info, dict) else None
        if irr_value is None and isinstance(irr_bcr, dict):
            irr_value = irr_bcr.get("IRR")
        if bcr_value is None and isinstance(irr_bcr, dict):
            bcr_value = irr_bcr.get("BCR")

        self.last_eamc_info = {
            "S": torch.tensor([S]*B, device=device),
            "merge_ops": merge_ops,
            "len_hist": len_hist,
            "IRR": irr_value,
            "BCR": bcr_value,
        }
        return seg_tokens, seg_seq

    def _pool_segments(self, seg_seq: torch.Tensor) -> torch.Tensor:
        """Reduce segment sequences [B,S,L,D] to one token per segment [B,S,D]."""
        if isinstance(self.gma_fusion, nn.Identity):
            # No attention module available: plain mean over the segment length.
            return seg_seq.mean(dim=2)
        seg_tokens = self.gma_fusion(seg_seq)
        if seg_tokens.dim() == 4:
            seg_tokens = seg_tokens.squeeze(2) if seg_tokens.size(2) == 1 else seg_tokens.mean(dim=2)
        return seg_tokens

    def _uniform_segments(self, ts_feat: torch.Tensor):
        """Split [B,T,D] into ``max_segments`` equal zero-padded chunks; return (tokens, seq, chunk_len)."""
        B, T, D = ts_feat.shape
        max_segments = max(1, int(self.config.get("max_segments", 8)))
        seg_len = max(1, math.ceil(T / max_segments))
        pad = seg_len * max_segments - T
        padded = F.pad(ts_feat, (0, 0, 0, pad))  # [B, seg_len*max_segments, D]
        seg_seq = padded.reshape(B, max_segments, seg_len, D)
        return self._pool_segments(seg_seq), seg_seq, seg_len

    # =========================

    # =========================
    def cross_modal_fusion(self, img_tok, txt_tok, ts_seg_tokens, static_vec):
        """Fuse all modality tokens with one self-attention pass and mean-pool.

        Args:
            img_tok: [B,Ni,D] image tokens.
            txt_tok: [B,Nt,D] text tokens.
            ts_seg_tokens: [B,S,D] time-series segment tokens.
            static_vec: static features, [B,Ds] (raw, projected here) or
                [B,D] (already projected, as returned by ``extract_features``);
                a [B,1,*] tensor is accepted as well.

        Returns:
            fused_global: [B,D].
        """
        st_tok = static_vec if static_vec.dim() == 3 else static_vec.unsqueeze(1)  # [B,1,*]

        # extract_features already applies static_proj_unified, so projecting
        # again would fail whenever static_dim != shared_dim. Project only when
        # the input still has the projection's input width. When both widths
        # coincide the input is projected, exactly as the original code did.
        feat_dim = st_tok.size(-1)
        already_projected = (feat_dim != self.static_proj_unified.in_features
                             and feat_dim == self.shared_dim)
        if not already_projected:
            st_tok = self.static_proj_unified(st_tok)  # [B,1,D]

        # concat tokens
        tokens = torch.cat([img_tok, txt_tok, ts_seg_tokens, st_tok], dim=1)  # [B, Nall, D]
        # self-attention followed by a residual feed-forward block
        fused, _ = self.fusion_attn(tokens, tokens, tokens)
        fused = fused + self.fusion_out(fused)
        # mean-pool over all tokens
        return fused.mean(dim=1)  # [B,D]

    # =========================

    # =========================
    def apply_causal_graph(self, seg_seq, img_tok=None, txt_tok=None):
        """Message passing over segment nodes with a causal window.

        Node ``i`` receives the mean of the nodes ``j`` with
        ``1 <= i - j <= int(self.epsilon)`` (strictly earlier segments inside
        the window). Each GNN layer adds ``dropout(act(linear(message)))`` to
        the node state. ``img_tok`` and ``txt_tok`` are accepted but not used.

        seg_seq: [B,S,L,D] (mean-pooled over L) or [B,S,D].
        Returns the node states [B,S,D]; graph statistics are written to
        ``self.last_graph_info``.
        """
        nodes = seg_seq.mean(dim=2) if seg_seq.dim() == 4 else seg_seq  # [B,S,D]
        batch, n_seg, _ = nodes.shape
        n_nodes = int(batch * n_seg)

        if not self.use_causal:
            self.last_graph_info = {"V": n_nodes, "E": 0, "dbar": 0.0}
            return nodes

        # adjacency: lag[i, j] = i - j; an edge exists for lags in [1, window]
        window = int(self.epsilon)
        positions = torch.arange(n_seg, device=nodes.device)
        lag = positions.unsqueeze(1) - positions.unsqueeze(0)
        in_window = (lag >= 1) & (lag <= window)
        adj = in_window.to(torch.get_default_dtype()).unsqueeze(0).repeat(batch, 1, 1)

        # row-normalise; rows without neighbours keep a divisor of 1
        degree = adj.sum(dim=-1, keepdim=True).clamp_min(1.0)
        adj_norm = adj / degree

        state = nodes
        for layer in self.gnn_layers:
            message = torch.einsum('bij,bjd->bid', adj_norm, state)
            update = self.gnn_dropout(self.gnn_act(layer(message)))
            state = state + update

        n_edges = int(adj.sum().item())
        if n_seg > 0:
            mean_degree = float((2 * n_edges) / max(1, batch * n_seg))
        else:
            mean_degree = 0.0
        self.last_graph_info = {"V": n_nodes, "E": n_edges, "dbar": mean_degree}
        return state  # [B,S,D]

    # =========================
    # Trajectory
    # =========================
    def _encode_agent_histories(self, hist: torch.Tensor, lengths: torch.Tensor) -> torch.Tensor:
        """Encode each agent's history with the GRU and return its final hidden state.

        hist: [B,M,H,2]; lengths: [B,M] number of valid steps per agent
        (values below 1 are clamped to 1 so packing never sees an empty sequence).
        Returns [B,M,traj_hist_dim].
        """
        n_batch, n_agents, n_steps, _ = hist.shape
        n_seqs = n_batch * n_agents

        flat_hist = hist.reshape(n_seqs, n_steps, 2)
        valid_steps = lengths.reshape(n_seqs).clamp(min=1)

        # pack_padded_sequence needs the lengths on the CPU
        packed = nn.utils.rnn.pack_padded_sequence(
            flat_hist, lengths=valid_steps.cpu(), batch_first=True, enforce_sorted=False
        )
        self.traj_hist_enc.flatten_parameters()
        _, last_hidden = self.traj_hist_enc(packed)  # [1, B*M, Hdim]
        return last_hidden.squeeze(0).reshape(n_batch, n_agents, -1)

    def _forward_traj(self,
                      images: torch.Tensor,
                      text_item: Optional[Dict[str, torch.Tensor]],
                      packed_ts: Union[Dict[str, torch.Tensor], torch.Tensor],
                      static_data: torch.Tensor,
                      traj: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Trajectory forward pass.

        Pipeline: encoding -> EAMC -> fusion -> causal graph -> scene feature
        -> per-agent history encoding -> TrajHead.

        Args:
            traj: dict with ``"hist"`` [B,M,H,2] and ``"hist_mask"`` [B,M,H].

        Returns:
            dict with ``traj_mean``, ``traj_conf`` (softmax over the last
            dimension), ``eamc_info`` and ``graph_info``.

        Raises:
            AssertionError: if ``traj`` is missing or lacks the required keys.
            RuntimeError: if no ``TrajHead`` was available at construction time.
        """
        assert traj is not None and ("hist" in traj) and ("hist_mask" in traj), "traj dict missing 'hist'/'hist_mask'."
        if self.traj_head is None:
            raise RuntimeError("TrajHead is unavailable; trajectory prediction cannot run.")

        # 1) per-modality encoding (static_data is a required argument)
        img_tok, txt_tok, ts_feat, static_vec = self.extract_features(images, text_item, packed_ts, static_data)

        # 2) EAMC: resolve per-sample lengths, then segment
        if isinstance(packed_ts, dict) and ("seq_lengths" in packed_ts) and packed_ts["seq_lengths"].dim() >= 1:
            ts_lengths = packed_ts["seq_lengths"]
            if ts_lengths.dim() == 2:  # [B,N_med] -> shortest series per sample
                ts_lengths = ts_lengths.min(dim=1)[0]
        else:
            ts_lengths = torch.full((ts_feat.size(0),), ts_feat.size(1), device=ts_feat.device, dtype=torch.long)

        seg_tokens, seg_seq = self.apply_eamc(ts_feat, ts_lengths)  # [B,S',D], [B,S',L',D]

        # 3) cross-modal fusion
        fused_global = self.cross_modal_fusion(img_tok, txt_tok, seg_tokens, static_vec)  # [B,D]

        # 4) causal graph over segments
        graph_nodes = self.apply_causal_graph(seg_seq, img_tok, txt_tok)  # [B,S',D]
        graph_pool = graph_nodes.mean(dim=1)                              # [B,D]

        # 5) scene feature
        scene = torch.cat([fused_global, graph_pool], dim=-1)  # [B,2D]
        B, D2 = scene.shape

        # 6) agent histories
        hist = traj["hist"].to(scene.device).float()           # [B,M,H,2]
        hist_mask = traj["hist_mask"].to(scene.device).bool()  # [B,M,H]
        lengths = hist_mask.long().sum(dim=-1)                 # [B,M]
        hist_enc = self._encode_agent_histories(hist, lengths)  # [B,M,Hdim]
        M = hist_enc.size(1)

        # 7) TrajHead on [scene ; history] per agent
        scene_exp = scene.unsqueeze(1).expand(B, M, D2)        # [B,M,2D]
        head_in = torch.cat([scene_exp, hist_enc], dim=-1)     # [B,M,2D+Hdim]
        traj_mean, traj_conf = self.traj_head(head_in)
        traj_conf = torch.softmax(traj_conf, dim=-1)

        return {
            "traj_mean": traj_mean,
            "traj_conf": traj_conf,
            "eamc_info": self.last_eamc_info,
            "graph_info": self.last_graph_info,
        }

    # =========================

    # =========================
    def _forward_default(self,
                         images: torch.Tensor,
                         text_item: Optional[Dict[str, torch.Tensor]],
                         packed_ts: Union[Dict[str, torch.Tensor], torch.Tensor],
                         static_data: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Classification / forecasting forward pass.

        Pipeline: encoding -> EAMC -> fusion -> causal graph -> both heads.

        Returns:
            dict with ``logits_cls``, ``forecast``, ``eamc_info`` and
            ``graph_info``. Both heads are always evaluated on the same
            [B,2D] feature.
        """
        # static_data is a required argument of extract_features
        img_tok, txt_tok, ts_feat, static_vec = self.extract_features(images, text_item, packed_ts, static_data)

        # EAMC: resolve per-sample lengths, then segment
        if isinstance(packed_ts, dict) and ("seq_lengths" in packed_ts) and packed_ts["seq_lengths"].dim() >= 1:
            ts_lengths = packed_ts["seq_lengths"]
            if ts_lengths.dim() == 2:  # [B,N_med] -> shortest series per sample
                ts_lengths = ts_lengths.min(dim=1)[0]
        else:
            ts_lengths = torch.full((ts_feat.size(0),), ts_feat.size(1), device=ts_feat.device, dtype=torch.long)
        seg_tokens, seg_seq = self.apply_eamc(ts_feat, ts_lengths)

        fused_global = self.cross_modal_fusion(img_tok, txt_tok, seg_tokens, static_vec)  # [B,D]
        graph_nodes = self.apply_causal_graph(seg_seq, img_tok, txt_tok)                  # [B,S',D]
        graph_pool = graph_nodes.mean(dim=1)                                              # [B,D]
        feat = torch.cat([fused_global, graph_pool], dim=-1)                              # [B,2D]

        return {
            "logits_cls": self.classification_head(feat),
            "forecast":   self.forecast_head(feat),
            "eamc_info":  self.last_eamc_info,
            "graph_info": self.last_graph_info,
        }

    # =========================

    # =========================
    def forward(self,
                images: torch.Tensor,
                text_item: Optional[Dict[str, torch.Tensor]],
                packed_ts: Union[Dict[str, torch.Tensor], torch.Tensor],
                static_data: torch.Tensor,
                task: Optional[str] = None,
                traj: Optional[Dict[str, torch.Tensor]] = None):
        """Dispatch to the trajectory or the default forward pass.

        The task name is ``task`` when it is truthy, otherwise
        ``config['task_type']`` (default ``"classification"``); it is compared
        case-insensitively. ``"traj"`` and ``"trajectory"`` select the
        trajectory pass, anything else the default pass.
        """
        task_name = task if task else self.config.get("task_type", "classification")
        wants_trajectory = task_name.lower() in ("traj", "trajectory")
        if not wants_trajectory:
            return self._forward_default(images, text_item, packed_ts, static_data)
        return self._forward_traj(images, text_item, packed_ts, static_data, traj)


# =========================

# =========================
def create_restructured_model(config: Dict) -> RestructuredMultiModalNet:
    """Factory: build a ``RestructuredMultiModalNet`` from ``config``."""
    model = RestructuredMultiModalNet(config)
    return model
