# -*- coding: utf-8 -*-
from typing import Dict, Optional, Union, Tuple, List
import torch
import torch.nn as nn
import torch.nn.functional as F

from ..common.features_and_fusion import FeatureExtractorAndFusion
from ..modules.dynamic_segmentation import DynamicSegmentationModule
from ..modules.causal_graph import CausalGraphModuleLite

def _try_import(path_mod, name):
    """Best-effort lookup of attribute *name* in module *path_mod*.

    Returns the attribute when the import and the attribute access both
    succeed, otherwise ``None``. The catch is intentionally broad: a module
    that raises anything derived from ``Exception`` while importing is
    treated the same as a missing module.
    """
    try:
        module = __import__(path_mod, fromlist=[name])
        found = getattr(module, name)
    except Exception:
        found = None
    return found

# Optional project components. For each name, try the short import path
# first and fall back to the "model."-prefixed path; the name ends up None
# when neither import succeeds.
TrajHead = (
    _try_import("layers.traj_head", "TrajHead")
    or _try_import("model.layers.traj_head", "TrajHead")
)
AEBlock = (
    _try_import("backbone.ae", "AEBlock")
    or _try_import("model.backbone.ae", "AEBlock")
)
MLPHead = (
    _try_import("backbone.head", "MLPHead")
    or _try_import("model.backbone.head", "MLPHead")
)


if MLPHead is None:
    class MLPHead(nn.Module):
        """Fallback MLP used when no project ``MLPHead`` could be imported.

        Layout: Linear(in_dim, hidden) -> GELU -> Dropout -> Linear(hidden, out_dim),
        stored in ``self.net`` (trainable weights at ``net.0`` and ``net.3``).
        """

        def __init__(self, in_dim, out_dim, hidden=256, dropout=0.1):
            super().__init__()
            layers = [
                nn.Linear(in_dim, hidden),
                nn.GELU(),
                nn.Dropout(dropout),
                nn.Linear(hidden, out_dim),
            ]
            self.net = nn.Sequential(*layers)

        def forward(self, x):
            """Apply the MLP over the last dimension of *x*."""
            net = self.net
            return net(x)


class TrajectoryPredictor(nn.Module):
    """Scene encoder plus per-agent multi-hypothesis trajectory decoder.

    ``forward`` runs the following stages:

    1. Feature extraction.
    2. Time-series segmentation, dynamic or uniform.
    3. Cross-modal fusion, or a mean-pool fallback.
    4. The causal-graph module, or a segment-token bypass.
    5. Concatenation of the pooled scene vector with each agent's GRU-encoded
       history.
    6. A trajectory head that emits K candidate trajectories of T 2-D steps
       plus K unnormalized confidence logits per agent.

    Ablation switches are read from ``config``: ``use_dynseg``, ``use_causal``,
    ``use_multimodal`` and ``use_nbc``, all defaulting to True.
    """

    def __init__(self, config: Dict):
        super().__init__()
        # Shallow copy, so the setdefault() below never mutates the caller's dict.
        self.config = dict(config)
        self.shared_dim = int(self.config.get("shared_dim", 128))

        # ---- ablation switches ----
        self.use_dynseg = bool(self.config.get("use_dynseg", True))
        self.use_causal = bool(self.config.get("use_causal", True))
        self.use_multimodal = bool(self.config.get("use_multimodal", True))
        self.use_nbc = bool(self.config.get("use_nbc", True))
        # Record the resolved use_nbc default in the config handed to the submodules below.
        self.config.setdefault("use_nbc", self.use_nbc)

        # Number of equal-width bins used by _uniform_segment when use_dynseg=False.
        self.seg_bins = int(self.config.get("seg_bins", 8))

        self.ff = FeatureExtractorAndFusion(self.config)

        # Both modules are always constructed, even when forward() bypasses them,
        # so their parameters are always present in the state_dict.
        self.seg_mod = DynamicSegmentationModule(shared_dim=self.shared_dim, config=self.config)
        self.cg_mod = CausalGraphModuleLite(shared_dim=self.shared_dim, config=self.config)

        self.traj_K = int(self.config.get("traj_K", 5))
        self.traj_T = int(self.config.get("traj_T", 12))
        self.traj_hist_dim = int(self.config.get("traj_hist_dim", 64))
        self.traj_hist_enc = nn.GRU(input_size=2, hidden_size=self.traj_hist_dim, num_layers=1, batch_first=True)

        # Trajectory head. Its confidence output is raw logits; no softmax is applied here.
        final_dim = self.shared_dim * 2  # fused_global + graph_pool
        head_in_dim = final_dim + self.traj_hist_dim
        traj_hidden = int(self.config.get("traj_hidden", 256))
        traj_dropout = float(self.config.get("traj_dropout", 0.1))
        if TrajHead is None:
            # Fallback: a single MLP emitting K*T*2 coordinates followed by K logits.
            self.traj_head = nn.Sequential(
                nn.Linear(head_in_dim, traj_hidden),
                nn.GELU(),
                nn.Dropout(traj_dropout),
                nn.Linear(traj_hidden, self.traj_K * self.traj_T * 2 + self.traj_K),
            )
            self._traj_head_is_fallback = True
        else:
            self.traj_head = TrajHead(in_dim=head_in_dim,
                                      K=self.traj_K, T=self.traj_T,
                                      hidden=traj_hidden,
                                      dropout=traj_dropout)
            self._traj_head_is_fallback = False

    # ----------------- helpers -----------------
    def _encode_agent_histories(self, hist: torch.Tensor, lengths: torch.Tensor) -> torch.Tensor:
        """Encode every agent's 2-D history with the GRU and return its final hidden state.

        Args:
            hist: [B, M, H, 2] per-agent history. The trailing 2 is enforced by the reshape.
            lengths: [B, M] number of valid steps per agent. Values are clamped to
                at least 1 because packing needs a non-zero length. An agent with no
                valid steps is therefore encoded from its first (padded) step.

        Returns:
            [B, M, hidden_size] tensor, one GRU encoding per agent.

        NOTE(review): packing keeps the FIRST ``lengths`` steps of each history,
        which assumes valid steps are left-aligned in ``hist``. Confirm against the data loader.
        """
        B, M, H, _ = hist.shape
        if B * M == 0:
            # pack_padded_sequence refuses empty batches; return an empty encoding instead.
            return hist.new_zeros(B, M, self.traj_hist_enc.hidden_size)
        x = hist.reshape(B * M, H, 2)
        seq_len = lengths.reshape(B * M).clamp(min=1)
        packed = nn.utils.rnn.pack_padded_sequence(x, lengths=seq_len.cpu(), batch_first=True, enforce_sorted=False)
        self.traj_hist_enc.flatten_parameters()
        _, h_n = self.traj_hist_enc(packed)  # [num_layers, B*M, hidden]; original order restored
        # Take the top layer's final state. This equals squeeze(0) for a 1-layer GRU
        # and stays correct if the GRU is ever made multi-layer.
        return h_n[-1].reshape(B, M, -1)

    def _uniform_segment(self, ts_feat: torch.Tensor, bins: int) -> Tuple[torch.Tensor, torch.Tensor, Dict]:
        """Split ``ts_feat`` [B, T, D] into S = max(1, bins) contiguous, near-equal bins.

        Bin i covers frames [floor(i*T/S), floor((i+1)*T/S)). A bin that would be
        empty (possible when T < S) is represented by a single all-zero frame.

        Returns:
            seg_tokens: [B, S, D], the per-bin mean of its frames.
            seg_seq: [B, S, max_len, D], each bin's frames, zero-padded to the longest bin.
            info: ``{"type": "uniform", "bins": S, "boundaries": [(st, ed), ...]}``.

        NOTE(review): bins span the full (possibly padded) length T; per-sample
        valid lengths are not used here, unlike the dynamic-segmentation path.
        """
        assert ts_feat.dim() == 3, f"ts_feat expect [B,T,D], got {tuple(ts_feat.shape)}"
        B, T, D = ts_feat.shape
        S = max(1, int(bins))
        # Integer floor division keeps the boundaries exact for any T and S.
        boundaries: List[Tuple[int, int]] = [((i * T) // S, ((i + 1) * T) // S) for i in range(S)]

        # Every entry has at least one frame, so the mean and max_len below are well defined.
        segs = [ts_feat[:, st:ed, :] if ed > st else ts_feat.new_zeros(B, 1, D)
                for st, ed in boundaries]
        seg_tokens = torch.stack([seg.mean(dim=1) for seg in segs], dim=1)  # [B, S, D]

        max_len = max(seg.size(1) for seg in segs)
        seg_seq = ts_feat.new_zeros(B, S, max_len, D)
        for i, seg in enumerate(segs):
            seg_seq[:, i, :seg.size(1), :] = seg
        info = {"type": "uniform", "bins": S, "boundaries": boundaries}
        return seg_tokens, seg_seq, info

    def _fallback_fuse(self, seg_tokens: torch.Tensor) -> torch.Tensor:
        """Fuse when use_multimodal=False: average the segment tokens [B, S, D] to [B, D]."""
        return seg_tokens.mean(dim=1)

    @staticmethod
    def _resolve_ts_lengths(packed_ts: Union[Dict[str, torch.Tensor], torch.Tensor],
                            ts_feat: torch.Tensor) -> torch.Tensor:
        """Return per-sample time-series lengths [B] for the segmentation stage.

        A dict input with a ``"seq_lengths"`` tensor of rank >= 1 supplies the
        lengths. A 2-D tensor is reduced with a per-row minimum; this presumably
        means one length per channel, so the shortest is kept (TODO confirm).
        Otherwise every sample is assumed to span the full ``ts_feat.size(1)``.
        """
        if isinstance(packed_ts, dict) and "seq_lengths" in packed_ts and packed_ts["seq_lengths"].dim() >= 1:
            ts_lengths = packed_ts["seq_lengths"]
            if ts_lengths.dim() == 2:
                ts_lengths = ts_lengths.min(dim=1)[0]
            return ts_lengths
        return torch.full((ts_feat.size(0),), ts_feat.size(1), device=ts_feat.device, dtype=torch.long)

    def _run_traj_head(self, head_in_flat: torch.Tensor, B: int, M: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply the trajectory head to [B*M, F] inputs.

        Returns:
            traj_mean: [B, M, K, T, 2] coordinates.
            traj_conf: [B, M, K] logits.
        """
        K, T = self.traj_K, self.traj_T
        if self._traj_head_is_fallback:
            raw = self.traj_head(head_in_flat)  # [B*M, K*T*2 + K]
            n_coord = K * T * 2
            traj_mean = raw[:, :n_coord].reshape(B, M, K, T, 2)
            traj_conf = raw[:, n_coord:].reshape(B, M, K)
        else:
            traj_mean_flat, traj_conf_flat = self.traj_head(head_in_flat)  # conf are logits
            # reshape (not view), so a non-contiguous head output does not raise.
            traj_mean = traj_mean_flat.reshape(B, M, K, T, 2)
            traj_conf = traj_conf_flat.reshape(B, M, K)
        return traj_mean, traj_conf

    # ----------------- forward -----------------
    def forward(self,
                images: torch.Tensor,
                text_item: Optional[Dict[str, torch.Tensor]],
                packed_ts: Union[Dict[str, torch.Tensor], torch.Tensor],
                static_data: torch.Tensor,
                traj: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Predict K candidate trajectories and their confidence logits for every agent.

        Args:
            images, text_item, static_data: passed unchanged to ``self.ff.extract_features``.
            packed_ts: time-series input for ``self.ff.extract_features``. If it is a
                dict with ``"seq_lengths"``, those lengths drive the segmentation stage.
            traj: dict with ``"hist"`` [B, M, H, 2] agent histories and
                ``"hist_mask"`` [B, M, H] step-validity mask.

        Returns:
            A dict with these keys:

            - ``"traj_mean"``: [B, M, K, T, 2] trajectories.
            - ``"traj_conf"``: [B, M, K] unnormalized logits; no softmax is applied.
            - ``"eamc_info"`` and ``"graph_info"``: diagnostic dicts, annotated with
              the ablation flags.

        Raises:
            AssertionError: if ``traj`` is None or lacks ``"hist"`` or ``"hist_mask"``.
        """
        assert traj is not None and ("hist" in traj) and ("hist_mask" in traj), "traj dict missing 'hist'/'hist_mask'."

        # 1) Per-modality features and time-series lengths.
        img_tok, txt_tok, ts_feat, static_vec = self.ff.extract_features(images, text_item, packed_ts, static_data)
        ts_lengths = self._resolve_ts_lengths(packed_ts, ts_feat)

        # 2) Segmentation: dynamic module, or equal-width bins as the ablation.
        if self.use_dynseg:
            seg_tokens, seg_seq, eamc_info = self.seg_mod(ts_feat, ts_lengths)  # [B,S',D], [B,S',L',D]
        else:
            seg_tokens, seg_seq, uni_info = self._uniform_segment(ts_feat, self.seg_bins)
            eamc_info = {"bypass": True, "uniform": uni_info}

        # 3) Global fusion: cross-modal, or a plain mean of the segment tokens.
        if self.use_multimodal:
            fused_global = self.ff.cross_modal_fusion(img_tok, txt_tok, seg_tokens, static_vec)  # [B, D]
        else:
            fused_global = self._fallback_fuse(seg_tokens)  # [B, D]

        # 4) Graph stage: causal module, or segment tokens used directly as nodes.
        if self.use_causal:
            graph_nodes, graph_info = self.cg_mod(seg_seq, img_tok, txt_tok)  # [B,S',D], dict
            if graph_nodes.numel() > 0:
                graph_pool = graph_nodes.mean(dim=1)
            else:
                graph_pool = fused_global.new_zeros(fused_global.shape)
        else:
            graph_nodes = seg_tokens
            graph_pool = seg_tokens.mean(dim=1)  # [B, D]
            graph_info = {"bypass": True, "nodes": graph_nodes.shape[1]}

        # 5) Scene vector.
        scene = torch.cat([fused_global, graph_pool], dim=-1)  # [B, 2D]
        B, D2 = scene.shape

        # 6) Agent history encoding.
        hist = traj["hist"].to(scene.device).float()  # [B,M,H,2]
        hist_mask = traj["hist_mask"].to(scene.device).bool()  # [B,M,H]
        lengths = hist_mask.long().sum(dim=-1)  # [B,M]
        hist_enc = self._encode_agent_histories(hist, lengths)  # [B,M,Hdim]
        M = hist_enc.size(1)

        # 7) Decode: the scene vector is shared by every agent of a sample.
        scene_exp = scene.unsqueeze(1).expand(B, M, D2)  # [B,M,2D]
        head_in = torch.cat([scene_exp, hist_enc], dim=-1)  # [B,M,2D+Hdim]
        traj_mean, traj_conf = self._run_traj_head(head_in.reshape(B * M, -1), B, M)

        # Record the ablation flags in the diagnostic dicts, for logging and analysis.
        # NOTE(review): this mutates the dicts returned by seg_mod / cg_mod in place.
        eamc_info["use_dynseg"] = bool(self.use_dynseg)
        eamc_info["use_nbc"] = bool(self.use_nbc)
        graph_info["use_causal"] = bool(self.use_causal)
        eamc_info["seg_bins_if_uniform"] = int(self.seg_bins)

        return {
            "traj_mean": traj_mean,    # [B,M,K,T,2]
            "traj_conf": traj_conf,    # [B,M,K] logits
            "eamc_info": eamc_info,    # dict
            "graph_info": graph_info,  # dict
        }
