import os
import json
import math
import numpy as np
from typing import Dict, Tuple
import torch
from torch import nn
import torch.nn.functional as F
from torch_geometric.data import Data, Batch
from torch_geometric.loader import DataLoader
from torch_geometric.nn import MessagePassing
from torch_scatter import scatter_add


def build_edge_attr_from_nodes(x_nodes: torch.Tensor, edge_index: torch.Tensor,
                               num_feats: int = 9) -> torch.Tensor:
    """Build edge features as the difference of endpoint node features.

    Only the first ``num_feats`` node-feature columns are used; nodes with fewer
    columns are zero-padded on the right. Edge ``e = (src, dst)`` gets
    ``x[src] - x[dst]``.

    Args:
        x_nodes: node feature matrix ``[N, F]``.
        edge_index: ``[2, E]`` tensor of (source, target) node indices.
        num_feats: number of feature columns in the result (default 9).

    Returns:
        ``[E, num_feats]`` tensor with the same dtype/device as ``x_nodes``
        (also for the empty-graph case, so both branches agree).
    """
    if edge_index.numel() == 0:
        return torch.zeros((0, num_feats), dtype=x_nodes.dtype, device=x_nodes.device)

    missing = num_feats - x_nodes.size(1)
    if missing > 0:
        # Right-pad with zero columns up to num_feats.
        x_use = F.pad(x_nodes, (0, missing))
    else:
        x_use = x_nodes[:, :num_feats]

    src, dst = edge_index[0], edge_index[1]
    return x_use[src] - x_use[dst]


def sample_to_data(sample: dict) -> Data:
    """Pack one preprocessed sample dict into a single PyG ``Data`` graph.

    Frame ``t`` contributes ``num_x_seq[t]`` node rows. All frames are stacked into
    ``x_nodes``, and each frame's ``edge_index_*_seq[t]`` is shifted by that frame's
    first-row offset so that it indexes the stacked table. The edge arrays are
    concatenated along dim 1 next to a ``(2, 0)`` empty fallback, so they are laid
    out as ``[2, E_t]``.

    The ``*_ptr`` tensors hold prefix sums of the per-frame counts, with a leading
    0. ``ego_index_seq`` holds the first stacked row of every frame.
    """
    def as_float(key):
        return torch.from_numpy(sample[key]).float()

    def as_counts(key):
        return torch.as_tensor(sample[key], dtype=torch.long)

    def prefix_ptr(counts, length):
        # ptr[t]:ptr[t+1] is the slice owned by frame t.
        ptr = torch.zeros(length + 1, dtype=torch.long)
        ptr[1:] = torch.cumsum(counts, dim=0)
        return ptr

    vk_vec = as_float('vk_feat')
    future_traj = as_float('future_traj')
    density_seq = as_float('trad_k_seq')
    intent_target = as_float('hist_intent')

    num_x_seq = as_counts('num_x_seq')
    num_diff_seq = as_counts('num_diff_seq')
    num_adv_seq = as_counts('num_adv_seq')

    T = num_x_seq.numel()
    node_ptr = prefix_ptr(num_x_seq, T)
    diff_ptr = prefix_ptr(num_diff_seq, T)
    adv_ptr = prefix_ptr(num_adv_seq, T)

    frames = [torch.from_numpy(arr).float() for arr in sample['x_nodes_seq']]
    if frames:
        x_all = torch.cat(frames, dim=0)
    else:
        x_all = torch.zeros((0, 10), dtype=torch.float32)

    def shifted_edges(key):
        # Re-base each frame's local node ids onto the stacked node table.
        pieces = []
        for t in range(T):
            local = sample[key][t]
            if local.size:
                pieces.append(torch.from_numpy(local).long() + node_ptr[t])
        if pieces:
            return torch.cat(pieces, dim=1)
        return torch.zeros((2, 0), dtype=torch.long)

    ei_diff = shifted_edges('edge_index_diff_seq')
    ei_adv = shifted_edges('edge_index_adv_seq')

    return Data(
        x_nodes=x_all,
        edge_index_diff=ei_diff, edge_attr_diff=build_edge_attr_from_nodes(x_all, ei_diff),
        edge_index_adv=ei_adv, edge_attr_adv=build_edge_attr_from_nodes(x_all, ei_adv),

        vk_feat=vk_vec,
        future_traj=future_traj,
        density_seq=density_seq,
        hist_intent=intent_target,

        T=torch.tensor(T, dtype=torch.long),
        num_x_seq=num_x_seq,
        num_diff_seq=num_diff_seq,
        num_adv_seq=num_adv_seq,
        node_ptr=node_ptr,
        diff_ptr=diff_ptr,
        adv_ptr=adv_ptr,
        ego_index_seq=node_ptr[:-1].clone(),
    )


class GKVKDataset(torch.utils.data.Dataset):
    """In-memory dataset of samples converted eagerly with ``sample_to_data``.

    Args:
        pt_path: path to a ``torch.save``-d list of raw sample dicts.
    """

    def __init__(self, pt_path: str):
        super().__init__()
        # The file is a pickled list of sample dicts holding numpy arrays. torch>=2.6
        # defaults to weights_only=True, which refuses such objects, so the full
        # unpickler is requested explicitly.
        # SECURITY: weights_only=False executes arbitrary pickle code - only load
        # files from a trusted source.
        self.raw = torch.load(pt_path, weights_only=False)
        # Every sample is converted up front, so the whole dataset lives in memory.
        self.items = [sample_to_data(s) for s in self.raw]
        # NOTE(review): fixed feature widths, not derived from the data - confirm.
        self.vk_dim = 10
        self.gk_dim = 10

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        return self.items[idx]


##################################
# VKEncoder
class VKEncoder(nn.Module):
    def __init__(self,in_dim:int,hidden:int=128,out:int=128,dropout:float=0.1):
        super().__init__()
        self.phi=nn.Sequential(
            nn.Linear(in_dim,2*hidden),
            nn.LayerNorm(2*hidden),
            nn.GELU(),
            nn.Dropout(dropout),

            nn.Linear(2*hidden,hidden),
            nn.LayerNorm(hidden),
            nn.GELU(),
            nn.Dropout(dropout),

            nn.Linear(hidden,out),
        )

    def forward(self,vk_seq:torch.Tensor):
        return self.phi(vk_seq)

# GK architecture
class EdgeFeatureNorm(nn.Module):
    """Standardise edge features column-wise using statistics of the current call.

    Every call computes mean/std over the edges it receives (dim 0), with no
    running statistics. The std is floored at 1e-2, and NaN/inf results are zeroed.
    With ``keep_last=True`` the final column is passed through unnormalised.
    With ``keep_last=False`` every column is normalised.

    NOTE(review): the ``mean``/``std`` buffers (parameters when ``trainable=True``)
    and ``eps`` are stored but never read by ``forward``. They are kept so that
    existing state_dicts still load.
    """

    def __init__(self, dim=9, eps=1e-6, trainable=False, keep_last=True):
        super().__init__()
        self.eps = eps
        self.keep_last = keep_last
        self.register_buffer('mean', torch.zeros(dim))
        self.register_buffer('std', torch.ones(dim))
        if trainable:
            self.mean = nn.Parameter(self.mean)
            self.std = nn.Parameter(self.std)

    def forward(self, edge_attr: torch.Tensor):
        if edge_attr.numel() == 0:
            return edge_attr
        if self.keep_last:
            main, last = edge_attr[..., :-1], edge_attr[..., -1:]
        else:
            main, last = edge_attr, None
        mean = main.mean(dim=0, keepdim=True)
        var = (main - mean).pow(2).mean(dim=0, keepdim=True)
        # Floor the std so near-constant columns are not blown up.
        std = torch.clamp(var.sqrt(), min=1e-2)
        norm = torch.nan_to_num((main - mean) / std, nan=0.0, posinf=0.0, neginf=0.0)
        return norm if last is None else torch.cat([norm, last], dim=-1)

# Diffusion term
class DiffusionConvLinear(MessagePassing):
    """Non-negative, edge-feature-weighted ("diffusion") message passing.

    Each edge gets ``w = softplus(sum(alpha * phi) + bias)``, where ``phi`` is the
    edge feature vector after ``EdgeFeatureNorm``. Node ``i`` receives the sum of
    ``w_e * lin_node(x_j)`` over its incoming edges.

    ``forward`` returns ``(out, w, contrib)``. ``out`` has shape ``[N, out]``,
    ``w`` holds one weight per edge, and ``contrib`` holds the per-edge,
    per-feature terms ``alpha * phi``.
    """

    def __init__(self, in_node: int, edge_dim: int, out: int):
        # Explicit sum aggregation, matching AdvectionLineGraphConv.
        super().__init__(aggr="add")
        self.lin_node = nn.Linear(in_node, out, bias=False)
        self.edge_norm = EdgeFeatureNorm(dim=edge_dim, trainable=False, keep_last=True)
        self.alpha = nn.Parameter(torch.zeros(edge_dim))
        self.bias = nn.Parameter(torch.tensor(0.0))
        self.softplus = nn.Softplus()

        with torch.no_grad():
            # Uniform 0.2 weight on every edge feature, 0.4 on the last one, bias -0.2.
            # (An unused per-feature init vector with hard-coded indices 4..8 used to
            # be built here; it made construction fail for edge_dim < 9.)
            self.alpha.fill_(0.2)
            self.alpha[-1] = 0.4
            self.bias.fill_(-0.2)

    def forward(self, x, edge_index, edge_attr, batch=None):
        if edge_index.numel() == 0:
            opts = dict(dtype=x.dtype, device=x.device)
            return (torch.zeros(x.size(0), self.lin_node.out_features, **opts),
                    torch.zeros((0,), **opts),
                    torch.zeros((0, edge_attr.size(1)), **opts))
        xh = self.lin_node(x)
        phi = self.edge_norm(edge_attr)
        contrib = phi * self.alpha[None, :]
        logits = contrib.sum(dim=-1) + self.bias
        w = self.softplus(logits)
        out = self.propagate(edge_index, x=xh, w=w)
        return out, w, contrib

    def message(self, x_j, w):
        return w.unsqueeze(-1) * x_j

class AdvectionLineGraphConv(MessagePassing):
    """Signed, edge-weighted ("advection") message passing with sum aggregation.

    Edge scores come from ``EdgePairFeature`` over normalised edge features and are
    squashed with ``tanh`` into signed weights in (-1, 1). Node ``i`` receives the
    sum of ``w_e * lin_node(x_j)``. Returns ``(out, w, contrib)``.
    """

    def __init__(self, in_node: int, phi_dim: int, out: int):
        super().__init__(aggr="add")
        self.lin_node = nn.Linear(in_node, out, bias=False)
        self.edge_norm = EdgeFeatureNorm(dim=phi_dim, trainable=False, keep_last=True)
        self.pair_head = EdgePairFeature(phi_dim=phi_dim, out_dim=1)

        with torch.no_grad():
            # Small uniform initial weights and zero bias for the pair scorer.
            nn.init.constant_(self.pair_head.beta.weight, 0.1)
            beta_bias = getattr(self.pair_head.beta, "bias", None)
            if beta_bias is not None:
                nn.init.constant_(beta_bias, 0.0)

    def forward(self, x, edge_index_adv, edge_attr_adv, batch=None):
        n_nodes = x.size(0)
        width = self.lin_node.out_features
        x_proj = self.lin_node(x)
        has_attr = edge_attr_adv is not None and edge_attr_adv.numel() > 0

        if edge_index_adv is None or edge_index_adv.numel() == 0:
            attr_width = edge_attr_adv.size(1) if has_attr else 1
            opts = {"dtype": x.dtype, "device": x.device}
            return (torch.zeros((n_nodes, width), **opts),
                    torch.zeros((0,), **opts),
                    torch.zeros((0, attr_width), **opts))

        phi_e = self.edge_norm(edge_attr_adv) if has_attr else edge_attr_adv
        # NOTE(review): phi_e has one row per edge, but pair_head indexes it with the
        # *node* ids in edge_index_adv. The rows only line up by coincidence, and
        # node ids >= E go out of range. Confirm the intended pairing.
        s_raw, contrib = self.pair_head(phi_e, edge_index_adv[0], edge_index_adv[1])
        w = torch.tanh(s_raw)
        return self.propagate(edge_index_adv, x=x_proj, w=w), w, contrib

    def message(self, x_j, w):
        # Scale each neighbour message by its signed edge weight.
        return x_j * w.unsqueeze(-1)


class EdgePairFeature(nn.Module):
    """Linear score of the absolute difference between two rows of ``phi_e``.

    With the (hard-wired) ``use_absdiff=True`` path, each index pair gets
    ``psi = |phi_e[dst] - phi_e[src]|``. Its per-feature contributions are
    ``psi * W`` and its score is ``sum(psi * W) + b``.

    Returns ``(s, contrib)``. For ``out_dim == 1``, ``s`` is ``[E]`` and ``contrib``
    is ``[E, F]``. For ``out_dim > 1``, ``s`` is ``[E, out_dim]`` and ``contrib`` is
    ``[E, out_dim, F]``. ``src_idx``/``dst_idx`` must be valid row indices into
    ``phi_e``.
    """

    def __init__(self, phi_dim=8, out_dim=1):
        super().__init__()
        self.use_absdiff = True
        in_dim = phi_dim if self.use_absdiff else (2 * phi_dim)
        self.beta = nn.Linear(in_dim, out_dim, bias=True)

    def forward(self, phi_e: torch.Tensor, src_idx: torch.Tensor, dst_idx: torch.Tensor):
        out_dim = self.beta.out_features
        tail = () if out_dim == 1 else (out_dim,)
        if phi_e.size(0) == 0 or src_idx.numel() == 0:
            return (phi_e.new_zeros((0, *tail)),
                    phi_e.new_zeros((0, *tail, phi_e.size(1))))
        fe = phi_e[dst_idx]
        fep = phi_e[src_idx]
        if self.use_absdiff:
            psi = (fe - fep).abs()
            if out_dim == 1:
                W = self.beta.weight.squeeze(0)
                contrib = psi * W[None, :]                                  # [E, F]
            else:
                # Broadcast explicitly over outputs; psi * W[None, :] would
                # mis-align the edge and output axes when out_dim > 1.
                contrib = psi[:, None, :] * self.beta.weight[None, :, :]    # [E, out, F]
            s = contrib.sum(dim=-1) + self.beta.bias
        else:
            psi = torch.cat([fe, fep], dim=-1)
            s = self.beta(psi).squeeze(-1)
            contrib = torch.zeros((psi.size(0), psi.size(1)), dtype=psi.dtype, device=psi.device)
        return s, contrib

class GKBlock(nn.Module):
    """Run diffusion and advection convolutions in parallel and fuse them with an MLP.

    ``forward`` returns ``(gk_out, aux)``. ``gk_out`` is ``[N, out_dim]``. ``aux``
    carries the edge weights and per-feature contributions of both branches.
    """

    def __init__(self, in_node_dim: int, edge_phi_dim: int, hidden: int = 64, out_dim: int = 64, dropout: float = 0.1):
        super().__init__()
        self.diff = DiffusionConvLinear(in_node=in_node_dim, edge_dim=edge_phi_dim, out=hidden)
        self.adv = AdvectionLineGraphConv(in_node=in_node_dim, phi_dim=edge_phi_dim, out=hidden)

        width = 2 * hidden
        self.fuse = nn.Sequential(
            nn.Linear(width, width), nn.LayerNorm(width), nn.Tanh(), nn.Dropout(dropout),
            nn.Linear(width, hidden), nn.LayerNorm(hidden), nn.Tanh(),
            nn.Linear(hidden, out_dim),
        )

    def forward(self, x, edge_index_diff, edge_attr_diff, edge_index_adv, edge_attr_adv, batch=None):
        diff_nodes, diff_w, diff_contrib = self.diff(x, edge_index_diff, edge_attr_diff, batch)
        adv_nodes, adv_w, adv_contrib = self.adv(x, edge_index_adv, edge_attr_adv, batch)
        fused = self.fuse(torch.cat((diff_nodes, adv_nodes), dim=-1))
        details = dict(
            w_diff=diff_w,
            contrib_diff=diff_contrib,
            w_adv=adv_w,
            contrib_adv=adv_contrib,
        )
        return fused, details

class GKEncoder(nn.Module):
    """Encode a PyG batch of per-frame graphs into a ``[B, T, gk_hidden]`` token sequence.

    Node embeddings come from one ``GKBlock`` pass over the whole batch. The nodes of
    graph ``b`` are split into ``T`` consecutive frames using ``num_x_seq``
    (concatenated per graph, ``T`` entries each). Each frame is pooled to one token,
    either the mean of its nodes or its first ("ego") node.
    """

    def __init__(self,in_node_dim:int,edge_phi_dim:int,gk_hidden:int,dropout:float=0.1,pool:str='mean'):
        super().__init__()
        assert pool in ['mean', 'ego']
        self.pool = pool
        self.gk=GKBlock(
            in_node_dim=in_node_dim,
            edge_phi_dim=edge_phi_dim,
            hidden=gk_hidden,
            out_dim=gk_hidden,
            dropout=dropout
        )

    @staticmethod
    def _pool_frame(X_t:torch.Tensor,method:str,gk_hidden:int):
        """Pool the ``[n_t, gk_hidden]`` nodes of one frame into a ``[gk_hidden]`` token.

        A frame with no nodes yields a zero token. Otherwise 'mean' returns NaN and
        'ego' raises IndexError on an empty frame.
        """
        if X_t.size(0) == 0:
            return X_t.new_zeros(gk_hidden)
        if method == 'mean':
            return X_t.mean(dim=0)
        return X_t[0]

    def forward(self,
                x_nodes:torch.Tensor,
                edge_index_diff:torch.Tensor,
                edge_attr_diff:torch.Tensor,
                edge_index_adv:torch.Tensor,
                edge_attr_adv:torch.Tensor,
                batch:torch.Tensor,
                num_x_seq:torch.Tensor,
                B:int,
                T:int):
        Z_all,aux=self.gk(x_nodes, edge_index_diff, edge_attr_diff, edge_index_adv, edge_attr_adv, batch)
        gk_hidden=Z_all.size(-1)

        G_tokens_per_batch=[]
        for b in range(B):
            # Nodes of graph b (PyG batches keep each graph's nodes contiguous).
            Z_b=Z_all[batch==b]
            counts=num_x_seq[b*T:(b+1)*T].to(torch.long)

            ptr=torch.zeros(T+1,dtype=torch.long,device=counts.device)
            if counts.numel()>0:
                ptr[1:]=torch.cumsum(counts,dim=0)
            # One host transfer per graph instead of two .item() syncs per frame.
            bounds=ptr.tolist()

            tokens_b=[self._pool_frame(Z_b[bounds[t]:bounds[t+1]],self.pool,gk_hidden) for t in range(T)]
            G_tokens_per_batch.append(torch.stack(tokens_b,dim=0))

        G_tokens_seq=torch.stack(G_tokens_per_batch,dim=0)
        return G_tokens_seq,aux

class TokenProjector(nn.Module):
    """Map VK and GK features into a shared ``d_model`` token space with bias-free linear layers."""

    def __init__(self, d_gk: int, d_vk: int, d_model: int):
        super().__init__()
        self.proj_gk = nn.Linear(d_gk, d_model, bias=False)
        self.proj_vk = nn.Linear(d_vk, d_model, bias=False)

    def forward(self, vk_out: torch.Tensor, gk_out: torch.Tensor):
        """Return ``(vk_tokens, gk_tokens)``, VK first."""
        vk_tokens = self.proj_vk(vk_out)
        gk_tokens = self.proj_gk(gk_out)
        return vk_tokens, gk_tokens

class CrossAttentionBlock(nn.Module):
    """Two-way cross attention between paired VK and GK token sets.

    For each ``(G, v)`` pair, VK queries first attend over GK tokens (residual +
    LayerNorm). GK queries then attend over that refreshed VK context. Returns lists
    of the VK contexts, the GK contexts, and per-head attention weights.
    """

    def __init__(self, d_model: int, n_heads: int = 4, p_drop: float = 0.1):
        super().__init__()
        self.attn_vk_q_gk = nn.MultiheadAttention(d_model, n_heads, dropout=p_drop, batch_first=True)
        self.attn_gk_q_vk = nn.MultiheadAttention(d_model, n_heads, dropout=p_drop, batch_first=True)
        self.norm_v1 = nn.LayerNorm(d_model)
        self.norm_g1 = nn.LayerNorm(d_model)

    def forward(self, G_list, v_list):
        v_ctx_list, G_ctx_list, attn_list = [], [], []

        for G, v in zip(G_list, v_list):
            # VK attends to GK: add a batch axis to 2-D inputs.
            query_v = v if v.dim() != 2 else v.unsqueeze(0)
            mem_g = G.unsqueeze(0)
            upd_v, w_v = self.attn_vk_q_gk(query_v, mem_g, mem_g,
                                           need_weights=True, average_attn_weights=False)
            v_ctx = self.norm_v1(v + upd_v.squeeze(0))

            # GK attends to the refreshed VK context.
            mem_v = v_ctx.unsqueeze(0)
            upd_g, w_g = self.attn_gk_q_vk(mem_g, mem_v, mem_v,
                                           need_weights=True, average_attn_weights=False)
            G_ctx = self.norm_g1(G + upd_g.squeeze(0))

            v_ctx_list.append(v_ctx.squeeze(0))
            G_ctx_list.append(G_ctx)
            attn_list.append({"w_vk_q_gk": w_v.squeeze(0), "w_gk_q_vk": w_g.squeeze(0)})

        return v_ctx_list, G_ctx_list, attn_list

class ControlHead(nn.Module):
    """Pool token features and map them to a tanh-bounded control vector.

    Accepts a 1-D ``[d_model]`` vector, a 2-D ``[N, d_model]`` token matrix, or a list
    of such tensors. A 2-D input is pooled to one vector by its first row
    (``pool="ego"``) or by its mean (any other value). A single-row matrix is simply
    squeezed.

    Raises:
        ValueError: for any tensor (or list element) that is not 1-D or 2-D.
    """

    def __init__(self, d_model: int, d_control: int, pool: str = "mean"):
        super().__init__()
        self.pool = pool
        self.mlp = nn.Sequential(
            nn.Linear(d_model, d_model), nn.Tanh(),
            nn.Linear(d_model, d_control),nn.Tanh(),
        )

    def _reduce(self, x: torch.Tensor) -> torch.Tensor:
        """Collapse one input tensor to a single ``[d_model]`` vector."""
        if x.dim() == 1:
            return x
        if x.dim() == 2:
            if x.size(0) == 1:
                return x.squeeze(0)
            return x[0] if self.pool == "ego" else x.mean(dim=0, keepdim=False)
        raise ValueError("Unexpected shape for control head input.")

    def forward(self, x):
        if isinstance(x, list):
            if not x:
                # mlp[-1] is Tanh (no out_features); the last Linear sits at index 2.
                proj = self.mlp[2]
                return torch.empty(0, proj.out_features,
                                   dtype=proj.weight.dtype, device=proj.weight.device)
            # Unsupported ranks raise instead of being silently dropped, which would
            # misalign the stacked batch.
            return torch.stack([self.mlp(self._reduce(x_g)) for x_g in x])
        return self.mlp(self._reduce(x))


class SinusoidalPE(nn.Module):
    """Fixed sine/cosine positional encoding added to ``[B, T, D]`` inputs.

    Even columns hold ``sin(pos * div)`` and odd columns hold ``cos(pos * div)``.
    Odd ``d_model`` is supported: it has one more sine column than cosine columns.
    """

    def __init__(self,d_model:int,max_len:int=8192):
        super().__init__()
        pe=torch.zeros(max_len,d_model)
        pos=torch.arange(0,max_len,dtype=torch.float32).unsqueeze(1)
        div=torch.exp(torch.arange(0,d_model,2).float()* (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(pos * div)
        # For odd d_model there are only d_model // 2 cosine columns.
        pe[:, 1::2] = torch.cos(pos * div[: d_model // 2])
        self.register_buffer('pe', pe)

    def forward(self, x: torch.Tensor, pos_start: int = 0) -> torch.Tensor:
        """Add encodings for positions ``pos_start .. pos_start + T - 1`` to ``x`` ([B, T, D])."""
        T, D = x.size(1), x.size(2)
        return x + self.pe[pos_start:pos_start + T, :D].unsqueeze(0)

def _causal_mask(T: int, device) -> torch.Tensor:
    m = torch.full((T, T), float('-inf'), device=device)
    return torch.triu(m, diagonal=1)

def _inv_sigmoid(y: torch.Tensor, eps=1e-6):
    y = y.clamp(eps, 1 - eps)
    return torch.log(y / (1 - y))

def _inv_tanh(y: torch.Tensor, eps=1e-6):
    y = y.clamp(-1 + eps, 1 - eps)
    return 0.5 * torch.log((1 + y) / (1 - y))

class IntentDiscriminator(nn.Module):
    """Predict a distribution over ``n_modes`` discrete modes from raw features and a control input.

    Raw feature vectors are embedded (``embed_raw``), offset with sinusoidal positions
    and passed through a pre-norm Transformer encoder. The last hidden state is
    concatenated with the embedded control (``embed_u``) and mapped to ``n_modes``
    logits. Probabilities are ``softmax(logits / temperature)``.
    """
    def __init__(self,feat_dim:int,u_dim:int,n_modes:int,tf_dim:int=128,n_heads:int=8,n_layers:int=2,dim_ff:int=256,dropout:float=0.1,window:int=25,temperature:float=0.5):
        super().__init__()
        self.feat_dim = feat_dim
        self.u_dim = u_dim
        self.n_modes = n_modes
        self.tf_dim = tf_dim
        # Maximum number of buffered raw tokens kept in the recurrent state.
        self.window = window
        # Softmax temperature; values < 1 sharpen the mode distribution.
        self.temperature=temperature

        self.embed_raw=nn.Sequential(
            nn.Linear(feat_dim,tf_dim),
            nn.ReLU(),
            nn.Dropout(p=dropout),
            nn.Linear(tf_dim,tf_dim),
        )
        self.embed_u=nn.Sequential(
            nn.Linear(u_dim,tf_dim),
            nn.ReLU(),
            nn.Dropout(p=dropout),
            nn.Linear(tf_dim,tf_dim),
        )
        enc_layer=nn.TransformerEncoderLayer(d_model=tf_dim,nhead=n_heads,dim_feedforward=dim_ff,dropout=dropout,batch_first=True,norm_first=True)
        self.encoder = nn.TransformerEncoder(enc_layer,num_layers=n_layers)
        self.pe=SinusoidalPE(d_model=tf_dim,max_len=8192)
        self.cls=nn.Sequential(
            nn.Linear(2*tf_dim,tf_dim),
            nn.ReLU(),
            nn.Dropout(p=dropout),
            nn.Linear(tf_dim,n_modes),
        )

    def init_state(self, B: int, device):
        """Create an empty recurrent state for a batch of ``B`` sequences.

        Keys: ``raw_tokens`` (``[B, 0, tf_dim]`` buffer of embedded tokens),
        ``abs_pos`` (0-dim count of tokens appended so far) and
        ``time_since_switch`` (``[B]`` accumulator advanced by ``dt`` in ``step``).
        """
        return {"raw_tokens": torch.zeros(B, 0, self.tf_dim, device=device),
                "abs_pos": torch.zeros((), dtype=torch.long, device=device),
                "time_since_switch": torch.zeros(B, device=device)}

    def step(self, raw_vk_t: torch.Tensor | None, u_t: torch.Tensor,
             state: dict, dt: float):
        """Advance the discriminator by one step.

        Args:
            raw_vk_t: new raw feature vector (presumably ``[B, feat_dim]`` - TODO
                confirm), or None to reuse the buffered tokens without appending.
            u_t: control input fed to ``embed_u``; ``u_t.size(0)`` is taken as B.
            state: dict from ``init_state``, or None for a stateless one-token pass.
            dt: increment added to ``state["time_since_switch"]``.

        Returns:
            With ``state is None``: ``(probs, h_last, None, logits)``, four values.
            Otherwise: ``(probs, h_last, state)``, three values; ``state`` is mutated
            in place (token buffer, ``abs_pos``, ``time_since_switch``).

        NOTE(review): the two paths return tuples of different length, so callers
        must know which path they are on.
        """
        if state is None:
            # Stateless path: a single token at position 0, no causal mask needed.
            tok=self.embed_raw(raw_vk_t).unsqueeze(1)
            h=self.pe(tok,pos_start=0)
            H=self.encoder(h)
            h_last=H[:,-1,:]
            u_emb=self.embed_u(u_t)
            logits=self.cls(torch.cat([h_last,u_emb],-1))
            probs=F.softmax(logits/self.temperature,dim=-1)
            return probs,h_last,None,logits
        else:
            B = u_t.size(0)

            if raw_vk_t is not None:
                tok = self.embed_raw(raw_vk_t).unsqueeze(1)  # [B,1,tf_dim]
                tokens = torch.cat([state["raw_tokens"], tok], dim=1)
                # Sliding window: keep only the most recent `window` tokens.
                if tokens.size(1) > self.window:
                    tokens = tokens[:, -self.window:, :]
                # The stored buffer is detached, so gradients only flow through tokens
                # embedded in this call; `tokens` itself stays attached for this step.
                state["raw_tokens"] = tokens.detach()
                state["abs_pos"] = torch.tensor(int(state["abs_pos"].item()) + 1,
                                                device=tokens.device, dtype=torch.long)
            else:
                tokens = state["raw_tokens"]

            if tokens.numel() == 0:
                # Nothing buffered yet: encode a single all-zero placeholder token.
                tokens = torch.zeros(B, 1, self.tf_dim, device=u_t.device)
            # Absolute position of the oldest buffered token, so positional encodings
            # stay consistent as the window slides.
            pos_start = max(0, int(state["abs_pos"].item()) - tokens.size(1))
            h = self.pe(tokens, pos_start=pos_start)  # [B,S,tf]
            S = h.size(1)
            mask = _causal_mask(S, h.device)
            H = self.encoder(h, mask=mask)  # [B,S,tf]
            h_last = H[:, -1, :]  # [B,tf]

            u_emb = self.embed_u(u_t)  # [B,tf]
            logits = self.cls(torch.cat([h_last, u_emb], dim=-1))  # [B,n_modes]
            probs = F.softmax(logits / self.temperature, dim=-1)

            # Advanced on every stateful call; nothing in this class resets it.
            state["time_since_switch"] = state["time_since_switch"] + dt
        return probs, h_last, state

class VKEvolverBlock(nn.Module):
    """Single linear expert: ``z_next = z_t @ K.T + B(u_t) * tanh(B_gate) * B_max``.

    ``K`` is block-diagonal. It holds ``Nc = d_model // 2`` scaled 2x2 rotations with
    radius ``rho_max * sigmoid(_rad_logits)`` and angle ``_theta``. For odd
    ``d_model`` it also holds one real entry ``rho_max * tanh(_real_logits)``. All
    eigenvalue magnitudes are therefore below ``rho_max``.

    ``forward`` returns ``(z_next, aux)``. ``aux`` contains ``rho_max``, the largest
    block magnitude ``rho_est`` and ``K_diag_abs_max``, the last two repeated per
    batch row.
    """

    def __init__(self,d_model:int,d_control:int,rho_max:float=0.90,B_max:float=0.80):
        super().__init__()
        self.Nc=d_model//2
        self.Nr=d_model-2*self.Nc
        self.d_model=d_model
        self.d_control=d_control
        self.rho_max=rho_max
        self.B_max=B_max

        if self.Nc>0:
            self._rad_logits=nn.Parameter(torch.zeros(self.Nc))
            self._theta=nn.Parameter(torch.zeros(self.Nc))
        else:
            self.register_parameter('_rad_logits',None)
            self.register_parameter('_theta',None)

        if self.Nr>0:
            self._real_logits=nn.Parameter(torch.zeros(self.Nr))
        else:
            self.register_parameter('_real_logits',None)

        self.B=nn.Linear(d_control,d_model,bias=False)
        self.B_gate=nn.Parameter(torch.zeros(d_model))

    def _build_K(self, device):
        """Assemble the ``[d_model, d_model]`` block-diagonal transition matrix on ``device``."""
        blocks = []
        if self.Nc > 0:
            R = self.rho_max * torch.sigmoid(self._rad_logits)   # [Nc] in (0, rho_max)
            c, s = torch.cos(self._theta), torch.sin(self._theta)
            # [Nc, 2, 2] stack of [[R c, -R s], [R s, R c]].
            Kc = torch.stack((torch.stack((R * c, -R * s), dim=-1),
                              torch.stack((R * s, R * c), dim=-1)), dim=-2)
            blocks.extend(Kc.unbind(0))
        if self.Nr > 0:
            Sr = self.rho_max * torch.tanh(self._real_logits)    # [Nr] in (-rho_max, rho_max)
            blocks.append(torch.diag(Sr))
        if not blocks:
            return torch.zeros(self.d_model, self.d_model, device=device)
        # One block_diag call replaces a Python loop of per-block slice assignments
        # and keeps the parameters' dtype.
        return torch.block_diag(*blocks).to(device)

    def forward(self, z_t: torch.Tensor, u_t: torch.Tensor):
        B, D = z_t.shape
        device = z_t.device

        K = self._build_K(device)
        z_lin = z_t @ K.T

        # Control injection, bounded per dimension by tanh(B_gate) * B_max.
        gate = torch.tanh(self.B_gate).unsqueeze(0)
        Bu = self.B(u_t) * gate * self.B_max

        z_next = z_lin + Bu

        with torch.no_grad():
            vals = []
            if self.Nc > 0:
                vals.append(self.rho_max * torch.sigmoid(self._rad_logits))
            if self.Nr > 0:
                vals.append(self.rho_max * torch.tanh(self._real_logits).abs())
            if len(vals) > 0:
                rho_est = torch.max(torch.cat(vals))
            else:
                rho_est = torch.tensor(0., device=device)
            K_diag_abs_max = K.diag().abs().max()

        aux = {
            "rho_max": torch.tensor(self.rho_max, device=device),
            "rho_est": rho_est.repeat(B),               # [B]
            "K_diag_abs_max": K_diag_abs_max.repeat(B)  # [B]
        }
        return z_next, aux

DEFAULT_MODE_NAMES = ['free_flow', 'car_following', 'lane_changing', 'merging', 'emergency']

# Every mode currently starts from the same prior. Each entry is its own dict, so
# per-mode tweaks do not leak into the other modes.
DEFAULT_MODE_PRIORS = {
    mode: dict(rho_max=0.95, B_max=0.20, target_R=0.90, theta_std=0.01, target_Sr=0.90)
    for mode in DEFAULT_MODE_NAMES
}

class VKEvolver(nn.Module):
    """Mixture of ``n_modes`` VKEvolverBlock experts with a hard, straight-through mode choice.

    Every expert is evaluated. Each sample then uses the output of one expert: the
    argmax of ``mode_probs``, or, in training mode, a uniformly random expert with
    probability ``eps_greedy``. The one-hot selection is combined with
    ``mode_probs - mode_probs.detach()``, so gradients still reach ``mode_probs``.
    ``aux`` reports the probability-weighted ``rho_mm`` and ``Kdiag_mm``.
    """

    def __init__(self,n_modes:int,d_model:int,d_control:int,rho_max:float=0.90,B_max:float=0.80):
        super().__init__()
        self.n_modes = n_modes
        self.experts = nn.ModuleList(
            [VKEvolverBlock(d_model=d_model, d_control=d_control, rho_max=rho_max, B_max=B_max)
             for _ in range(n_modes)]
        )

    def forward(self,z_t:torch.Tensor,u_t:torch.Tensor,mode_probs:torch.Tensor,eps_greedy:float=0.05):
        results = [expert(z_t, u_t) for expert in self.experts]
        z_stack = torch.stack([z_k for z_k, _ in results], dim=-1)                          # [B,D,M]
        rho_stack = torch.stack([aux_k['rho_est'] for _, aux_k in results], dim=-1)         # [B,M]
        kdiag_stack = torch.stack([aux_k['K_diag_abs_max'] for _, aux_k in results], dim=-1)  # [B,M]

        B, M = mode_probs.shape
        greedy = mode_probs.argmax(dim=-1)
        if self.training and eps_greedy > 0:
            # Draw the random expert ids before the explore coin flips (RNG order matters).
            random_pick = torch.randint(0, M, (B,), device=z_t.device)
            explore = torch.rand(B, device=z_t.device) < eps_greedy
            chosen = torch.where(explore, random_pick, greedy)
        else:
            chosen = greedy

        # Straight-through: forward uses the one-hot choice, backward flows into mode_probs.
        weights = F.one_hot(chosen, M).float() + (mode_probs - mode_probs.detach())  # [B,M]
        z_next = (z_stack * weights.unsqueeze(1)).sum(dim=-1)                         # [B,D]

        aux = {
            'rho_mm': (rho_stack * mode_probs).sum(dim=-1),
            'Kdiag_mm': (kdiag_stack * mode_probs).sum(dim=-1),
        }
        return z_next, aux

class ModeAwareVKEvolver(VKEvolver):
    """VKEvolver whose experts are initialised in place from per-mode priors.

    Expert ``k`` is paired with ``mode_names[k]`` (default: the first ``n_modes``
    entries of DEFAULT_MODE_NAMES). It is initialised from ``mode_priors[name]``
    (default: DEFAULT_MODE_PRIORS), or from an empty prior when the name is missing.
    Missing keys fall back to target_R=0.85, theta_mean=0.0, theta_std=0.02,
    rot_blocks=min(2, Nc) and target_Sr=0.85. A prior's ``rho_max``/``B_max``
    overrides the values passed to this constructor for that expert.
    """

    def __init__(self, n_modes:int, d_model:int, d_control:int,
                 mode_names=None, mode_priors=None,
                 rho_max:float=0.90, B_max:float=0.80):
        super().__init__(n_modes=n_modes, d_model=d_model, d_control=d_control,
                         rho_max=rho_max, B_max=B_max)
        self.mode_names = mode_names or DEFAULT_MODE_NAMES[:n_modes]
        priors = mode_priors or DEFAULT_MODE_PRIORS

        for k, name in enumerate(self.mode_names):
            self._apply_prior_to_expert(k, priors.get(name, {}))

    @torch.no_grad()
    def _apply_prior_to_expert(self, k:int, prior:dict):
        """Overwrite expert ``k``'s spectrum, rotation angles and control gate from ``prior``."""
        blk: VKEvolverBlock = self.experts[k]

        if 'rho_max' in prior:
            blk.rho_max = float(prior['rho_max'])
        if 'B_max' in prior:
            blk.B_max = float(prior['B_max'])

        Nc, Nr = blk.Nc, blk.Nr
        device = blk.B.weight.device

        if Nc > 0:
            target_R = float(prior.get('target_R', 0.85))
            theta_mean = float(prior.get('theta_mean', 0.0))
            theta_std = float(prior.get('theta_std', 0.02))
            # Clamp into [0, Nc]: requesting more rotating blocks than exist used to
            # fail with a shape mismatch when filling theta below.
            rot_blocks = max(0, min(int(prior.get('rot_blocks', min(2, Nc))), Nc))

            # Choose logits so that rho_max * sigmoid(logit) == target_R, keeping the
            # fraction strictly inside (0, 1) so the inverse stays finite.
            R_frac = max(1e-3, min(0.999, target_R / blk.rho_max))
            blk._rad_logits.copy_(_inv_sigmoid(torch.tensor(R_frac, device=device)).expand(Nc))

            # Small random angles; the first rot_blocks are centred on theta_mean.
            theta = torch.randn(Nc, device=device) * theta_std
            if rot_blocks > 0:
                theta[:rot_blocks] = torch.randn(rot_blocks, device=device) * theta_std + theta_mean
            blk._theta.copy_(theta)

        if Nr > 0:
            target_Sr = float(prior.get('target_Sr', 0.85))
            Sr_frac = max(1e-3, min(0.999, target_Sr / blk.rho_max))
            blk._real_logits.copy_(_inv_tanh(torch.tensor(Sr_frac, device=device)).expand(Nr))

        # Pre-tanh control-gate level per mode name; unknown names get 0.
        gate_level = {
            'free_flow':   -0.1,
            'car_following': 0.0,
            'lane_changing':  0.1,
            'merging':       0.25,
            'emergency':    -0.15,
        }.get(self.mode_names[k], 0.0)
        blk.B_gate.fill_(gate_level)

class GKEvolverBlock(nn.Module):
    """Autonomous linear evolver for graph tokens: ``Z_next = Z_t @ K.T``.

    ``K`` is block-diagonal with ``d_model // 2`` scaled 2x2 rotations. Each has
    radius ``rho_max * sigmoid(_rad_logits)`` and angle ``_theta``, so every
    eigenvalue magnitude is below ``rho_max``. A 1-D input is treated as one batch row.
    ``aux`` holds ``rho_est``, ``K_diag_abs_max`` (both ``[B]``) and ``K`` expanded
    to ``[B, D, D]``.
    """

    def __init__(self, d_model:int, rho_max:float=0.95):
        super().__init__()
        assert d_model % 2 == 0, "d_model must be an even number"
        self.d_model = d_model

        self.Nc = d_model // 2
        self.Nr = d_model - 2 * self.Nc   # always 0 given the even-size check
        self.rho_max = float(rho_max)

        if self.Nc > 0:
            self._rad_logits = nn.Parameter(torch.zeros(self.Nc))
            self._theta      = nn.Parameter(torch.zeros(self.Nc))
        else:
            self.register_parameter("_rad_logits", None)
            self.register_parameter("_theta", None)

        if self.Nr > 0:
            self._real_logits = nn.Parameter(torch.zeros(self.Nr))
        else:
            self.register_parameter("_real_logits", None)

    def _build_K(self, device):
        """Assemble the ``[d_model, d_model]`` block-diagonal transition matrix on ``device``."""
        blocks = []
        if self.Nc > 0:
            R = self.rho_max * torch.sigmoid(self._rad_logits)  # (0, rho_max)
            c, s = torch.cos(self._theta), torch.sin(self._theta)
            # [Nc, 2, 2] stack of [[R c, -R s], [R s, R c]].
            Kc = torch.stack((torch.stack((R * c, -R * s), dim=-1),
                              torch.stack((R * s, R * c), dim=-1)), dim=-2)
            blocks.extend(Kc.unbind(0))
        if self.Nr > 0:
            blocks.append(torch.diag(self.rho_max * torch.tanh(self._real_logits)))
        if not blocks:
            return torch.zeros(self.d_model, self.d_model, device=device)
        # One block_diag call instead of 4*Nc scalar assignments per forward pass.
        return torch.block_diag(*blocks).to(device)

    def forward(self, Z_t: torch.Tensor):
        if Z_t.dim() == 1:
            Z_t = Z_t.unsqueeze(0)     # [1,D]
        B, D = Z_t.shape
        device = Z_t.device

        K = self._build_K(device)      # [D,D]
        Z_next = Z_t @ K.T             # [B,D]

        with torch.no_grad():
            vals = []
            if self.Nc > 0:
                vals.append(self.rho_max * torch.sigmoid(self._rad_logits))     # [Nc]
            if self.Nr > 0:
                vals.append(self.rho_max * torch.tanh(self._real_logits).abs())  # [Nr]
            rho_est = torch.max(torch.cat(vals)) if len(vals) else torch.tensor(0., device=device)
            # Guard the degenerate d_model == 0 case (max of an empty tensor raises).
            K_diag_abs_max = K.diag().abs().max() if D > 0 else torch.tensor(0., device=device)

        aux = {
            "rho_est": rho_est.repeat(B),                 # [B]
            "K_diag_abs_max": K_diag_abs_max.repeat(B),   # [B]
            "K": K.unsqueeze(0).expand(B, D, D),
        }
        return Z_next, aux


class GKEvolver(nn.Module):
    """Graph-token evolver that delegates every step to a single ``GKEvolverBlock``.

    Extra keyword arguments are accepted and ignored (presumably so that this class
    can share a constructor call with other evolvers - confirm with callers).
    """

    def __init__(self, d_model: int, rho_max: float = 0.95, **_ignored):
        super().__init__()
        self.block = GKEvolverBlock(d_model=d_model, rho_max=rho_max)

    def forward(self, Z_in: torch.Tensor):
        """Return ``(Z_next, aux)`` exactly as produced by the wrapped block."""
        Z_next, aux = self.block(Z_in)
        return Z_next, aux

# Decoder
class ResidualDecoder(nn.Module):
    """MLP head with one residual connection, mapping ``d_model`` features to ``out_dim``.

    ``fc_out`` is zero-initialised, so a freshly built decoder outputs all zeros.
    Accepted inputs:
    - a list of tensors, each decoded separately and then stacked;
    - a 1-D or 2-D tensor;
    - a 3-D ``[B, T, d_model]`` tensor, decoded row-wise into ``[B, T, out_dim]``.

    Raises:
        ValueError: for any other input rank.
    """

    def __init__(self, d_model: int, hidden: int = 256, out_dim: int = 2, dropout: float = 0.1):
        super().__init__()
        self.fc1 = nn.Linear(d_model, hidden)
        self.fc2 = nn.Linear(hidden, hidden)
        self.fc3 = nn.Linear(hidden, hidden)
        self.fc_out = nn.Linear(hidden, out_dim)

        self.norm1 = nn.LayerNorm(hidden)
        self.norm2 = nn.LayerNorm(hidden)
        self.dropout = nn.Dropout(dropout)

        self.res_proj = nn.Linear(d_model, hidden) if d_model != hidden else nn.Identity()

        nn.init.zeros_(self.fc_out.weight)
        nn.init.zeros_(self.fc_out.bias)

    def forward(self, x: torch.Tensor):
        if isinstance(x, list):
            return torch.stack([self._forward_single(xi) for xi in x])
        if x.dim() in (1, 2):
            return self._forward_single(x)
        if x.dim() == 3:
            B, T, D = x.shape
            # reshape, not view: view() raises on non-contiguous input (e.g. transposed).
            y = self._forward_single(x.reshape(B * T, D))
            return y.reshape(B, T, -1)
        raise ValueError(f'error decoder input shape: {tuple(x.shape)}')

    def _forward_single(self, x):
        res = self.res_proj(x)
        x = F.relu(self.norm1(self.fc1(x)))
        x = self.dropout(x)
        # The residual from the input is added before the second LayerNorm.
        x = torch.tanh(self.norm2(self.fc2(x) + res))
        x = self.dropout(x)
        x = torch.tanh(self.fc3(x))
        return self.fc_out(x)


class TemporalPositionalEncoding(nn.Module):
    """Fixed sinusoidal encoding added to a latent at one integer time step.

    ``pe[p, 2i] = sin(p / 10000**(2i / d_model))`` and
    ``pe[p, 2i + 1] = cos(p / 10000**(2i / d_model))`` for ``p`` in
    ``[0, max_len)``. An odd ``d_model`` is supported: the last (even-index)
    column gets a sine term with no cosine partner.
    """

    def __init__(self, d_model: int, max_len: int = 125):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()  # [max_len, 1]
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # There are ceil(d/2) frequencies but only floor(d/2) odd columns;
        # slicing keeps odd d_model from failing with a shape mismatch.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        self.register_buffer('pe', pe)

    def forward(self, x, timestep: int):
        """Return ``x + pe[timestep]``.

        ``x`` may be a list of tensors (each shifted independently), a 1-D
        tensor, or any tensor whose last dim is ``d_model`` (the encoding row
        broadcasts over leading dims). ``timestep`` follows ordinary tensor
        indexing, so a negative value wraps to the end of the table.
        """
        if isinstance(x, list):
            return [xi + self.pe[timestep] for xi in x]
        if x.dim() == 1:
            return x + self.pe[timestep]
        return x + self.pe[timestep].unsqueeze(0)

class GraphKoopmanModel(nn.Module):
    """Joint latent rollout over a graph stream (GK) and a vehicle-feature stream (VK).

    ``forward`` encodes the observed history, performs one-step latent
    predictions from every history frame (reconstruction / stability
    targets), then rolls the latent states forward for ``fut_iter_steps``
    steps of ``dt_future`` seconds. It accumulates decoded VK outputs and
    reports them, together with GK decodings, at fixed evaluation horizons.
    """

    def __init__(self, cfg: Dict,
                 gk_node_dim: int, gk_edge_dim: int, vk_dim: int,
                 gk_hidden=128, vk_hidden=128, d_model=128,
                 d_control_vk=128,
                 out_dim_gk=1, out_dim_vk=2, n_modes=5,
                 n_heads=8, stride=25):
        super().__init__()

        # forward() reads this many frames from data.vk_feat (cfg value + 1).
        self.history_steps = int(cfg.get('history_steps', 75)) + 1
        self.future_steps = int(cfg.get('future_steps', 125))
        # Lag (in history frames) between the source and target of the history
        # reconstruction pairs built in forward().
        self.stride = stride
        self.dt_hist = float(cfg.get('dt', 0.1))
        # Future rollout: fut_iter_steps steps of dt_future seconds each.
        # Alternatives used previously: dt_future=1.0 / 0.4 and
        # fut_iter_steps=int(future_steps / stride).
        self.dt_future = 0.2
        self.fut_iter_steps = 12

        # NOTE(review): not read in this class; presumably consumed by a loss elsewhere.
        self.phys_params = nn.ParameterDict(
            {
                'c': nn.Parameter(torch.ones(1)),
                'nu': nn.Parameter(torch.ones(1)),
            }
        )

        self.gk_encoder = GKEncoder(in_node_dim=gk_node_dim, edge_phi_dim=gk_edge_dim, gk_hidden=gk_hidden, pool='ego')
        self.vk_encoder = VKEncoder(in_dim=vk_dim, hidden=vk_hidden, out=vk_hidden)

        self.proj = TokenProjector(d_gk=gk_hidden, d_vk=vk_hidden, d_model=d_model)

        self.intent = IntentDiscriminator(feat_dim=vk_dim, u_dim=d_control_vk, n_modes=n_modes, tf_dim=d_model, n_heads=n_heads, n_layers=2, dim_ff=2 * d_model, dropout=0.1, window=32, temperature=0.5)

        self.xattn = CrossAttentionBlock(d_model=d_model, n_heads=n_heads, p_drop=0.1)
        self.ctrl_vk = ControlHead(d_model=d_model, d_control=d_control_vk, pool='ego')

        # Alternative evolvers that have been tried:
        #   VKEvolver(n_modes, d_model, d_control, rho_max=0.90, B_max=0.0)
        #   KoopmanEvolver(d_model, d_control, n_cplx=8, n_real=None, dt=0.04, b_max=0.6)
        self.evo_vk = ModeAwareVKEvolver(n_modes=n_modes, d_model=d_model, d_control=d_control_vk, mode_names=['free_flow', 'car_following', 'lane_changing', 'merging', 'emergency'])
        self.evo_gk = GKEvolver(d_model=d_model, rho_max=0.90)

        self.dec_gk = ResidualDecoder(d_model=d_model, hidden=gk_hidden, out_dim=out_dim_gk, dropout=0.1)
        self.dec_vk = ResidualDecoder(d_model=d_model, hidden=vk_hidden, out_dim=out_dim_vk, dropout=0.1)

        # The same tables serve history steps (index t) and future steps (index j).
        self.temporal_pe_vk = TemporalPositionalEncoding(d_model=d_model, max_len=self.future_steps)
        self.temporal_pe_gk = TemporalPositionalEncoding(d_model=d_model, max_len=self.future_steps)

        # Applied only to the last observed latents that seed the future rollout.
        self.pre_ln_vk = nn.LayerNorm(d_model, eps=1e-6)
        self.pre_ln_gk = nn.LayerNorm(d_model, eps=1e-6)

        # Online intent blending during the future rollout.
        self.intent_online_beta = 0.30
        self.intent_switch_thresh = 0.55
        self.intent_min_stay_sec = 1.0
        # Detach state['raw_tokens'] every tbptt_steps history steps (0 disables).
        self.tbptt_steps = 32

        # False: open-loop, True: closed-loop. NOTE(review): not read by forward().
        self.latent_closed_loop = True

    @staticmethod
    def _pack_lists_for_xattn(v_b, G_b):
        """Wrap one sample's VK and GK tokens as single-element lists with a leading dim of 1."""
        return [v_b.unsqueeze(0)], [G_b.unsqueeze(0)]

    def forward(self, data: Data):
        """Encode the history, replay it step by step, then roll the latents into the future.

        Args:
            data: batched graph data. Fields read here: ``num_graphs``,
                ``x_nodes``, ``edge_index_diff``/``edge_attr_diff``,
                ``edge_index_adv``/``edge_attr_adv``, ``batch``,
                ``num_x_seq`` and ``vk_feat`` (viewed as
                ``[B, history_steps, -1]``).

        Returns:
            ``((y_gk_135, y_vk_135), aux)``.
            ``y_vk_135`` is ``[B, N_c, out_dim_vk]``: the running sum of
            decoded VK outputs over future steps, sampled at the collected
            horizons. ``y_gk_135`` is the GK decoding at the same steps.
            ``N_c`` is the number of evaluation horizons (1 s, 3 s, 5 s) that
            fit inside ``fut_iter_steps * dt_future``.
            ``aux["future_eval"]["gt_indices_25hz"]`` lists the matching
            ground-truth indices.

        Raises:
            ValueError: if no evaluation horizon is reachable by the rollout.
        """
        B = data.num_graphs
        T_h = self.history_steps
        device = data.x_nodes.device
        L = self.stride

        gk_obs, aux_gk = self.gk_encoder(
            x_nodes=data.x_nodes,
            edge_index_diff=data.edge_index_diff, edge_attr_diff=data.edge_attr_diff,
            edge_index_adv=data.edge_index_adv, edge_attr_adv=data.edge_attr_adv,
            batch=data.batch,
            num_x_seq=data.num_x_seq,
            B=B, T=T_h,
        )
        vk_raw = data.vk_feat.view(B, T_h, -1)
        vk_obs = self.vk_encoder(vk_raw)

        vk_obs, gk_obs = self.proj(vk_obs, gk_obs)  # [B,T_h,d_model]

        state = self.intent.init_state(B, device)
        step_cnt = 0
        tbptt = self.tbptt_steps

        # ---- History: one-step latent prediction from each observed frame (no rollout) ----
        z_hist_pred, Z_hist_pred = [], []
        y_hist_vk, y_hist_gk = [], []
        u_hist_seq, p_hist_seq = [], []
        rho_hist, kdiag_hist = [], []
        K_hist = []
        top1_hist = []
        for t in range(T_h - 1):
            z_in = vk_obs[:, t, :]
            Z_in = gk_obs[:, t, :]

            z_pe = self.temporal_pe_vk(z_in, t)
            Z_pe = self.temporal_pe_gk(Z_in, t)

            v_list, g_list = [], []
            for b in range(B):
                vl, gl = self._pack_lists_for_xattn(z_pe[b], Z_pe[b])
                v_list += vl
                g_list += gl
            v_ctx_list, _, _ = self.xattn(g_list, v_list)
            u_t = self.ctrl_vk(v_ctx_list)
            u_hist_seq.append(u_t)

            p_t, _feat_t, state = self.intent.step(raw_vk_t=vk_raw[:, t, :], u_t=u_t, state=state, dt=self.dt_hist)
            p_hist_seq.append(p_t)
            top1_hist.append(p_t.argmax(dim=-1))

            z_next, aux_vk_h = self.evo_vk(z_in, u_t, p_t, 0.05)
            z_hist_pred.append(z_next)
            Z_next, aux_gk_h = self.evo_gk(Z_in)
            Z_hist_pred.append(Z_next)

            y_hist_vk.append(self.dec_vk(z_next))
            # GK history decodings are computed but currently not returned in aux.
            y_hist_gk.append(self.dec_gk(Z_next))

            rho_hist.append(aux_vk_h.get('rho_mm', torch.zeros(B, device=device)))
            kdiag_hist.append(aux_vk_h.get('Kdiag_mm', torch.zeros(B, device=device)))
            K_hist.append(aux_gk_h.get('K'))

            step_cnt += 1
            if tbptt and (step_cnt % tbptt == 0):
                # Truncated BPTT: cut the gradient path through state['raw_tokens'].
                state['raw_tokens'] = state['raw_tokens'].detach()

        z_hist_pred = torch.stack(z_hist_pred, dim=1)
        Z_hist_pred = torch.stack(Z_hist_pred, dim=1)
        y_hist_vk = torch.stack(y_hist_vk, dim=1)
        u_hist_seq = torch.stack(u_hist_seq, dim=1)
        p_hist_seq = torch.stack(p_hist_seq, dim=1)
        rho_hist = torch.stack(rho_hist, dim=1)
        kdiag_hist = torch.stack(kdiag_hist, dim=1)
        K_hist = torch.stack(K_hist, dim=1)
        top1_hist = torch.stack(top1_hist, dim=1)

        # Reconstruction pairs: prediction made at frame t vs. observation at frame t + L.
        idx_src_hist = torch.arange(0, T_h - L, device=device, dtype=torch.long)  # t
        idx_tgt_hist = idx_src_hist + L  # t + L
        y_hist_delta_1s = y_hist_vk[:, idx_src_hist, :]  # [B,N_hist,out_dim_vk]
        z_hist_pred_1s = z_hist_pred[:, idx_src_hist, :]  # [B,N_hist,D]
        Z_hist_pred_1s = Z_hist_pred[:, idx_src_hist, :]  # [B,N_hist,D]
        z_hist_tgt_1s = vk_obs[:, idx_tgt_hist, :]  # [B,N_hist,D]
        Z_hist_tgt_1s = gk_obs[:, idx_tgt_hist, :]  # [B,N_hist,D]

        # ---- Future: latent rollout seeded from the last observed frame ----
        z_c = self.pre_ln_vk(vk_obs[:, -1, :])
        Z_c = self.pre_ln_gk(gk_obs[:, -1, :])

        p_c = p_hist_seq[:, -1, :].detach() if p_hist_seq.numel() > 0 else torch.full((B, self.evo_vk.n_modes), 1.0 / self.evo_vk.n_modes, device=device)

        rho_fut, kdiag_fut = [], []
        p_future_seq = []
        vk_135, gk_135 = [], []
        K_fut = []

        # No future observations: the intent module keeps seeing the last raw frame.
        raw_vk_t = vk_raw[:, -1, :]
        # NOTE(review): this method resets time_since_switch on a switch but never
        # advances it; presumably self.intent.step advances it by dt — confirm,
        # otherwise can_switch below is never true and p_c stays fixed.
        state["time_since_switch"] = torch.zeros(B, device=device)

        # Running sum of decoded VK outputs (width follows the decoder's output size).
        y_vk_acc = torch.zeros(B, self.dec_vk.fc_out.out_features, device=device)

        # Evaluation horizons (seconds) -> rollout step index and ground-truth index.
        # Horizons beyond fut_iter_steps * dt_future cannot be collected and are
        # skipped. Collecting nothing would leave vk_135/gk_135 empty, and
        # torch.stack would then fail, so that case is rejected up front.
        gt_rate_hz = 25  # assumed ground-truth sampling rate (key "gt_indices_25hz") — TODO confirm
        collect_at_steps, gt_indices = [], []
        for horizon_sec in (1.0, 3.0, 5.0):
            j_h = round(horizon_sec / self.dt_future) - 1
            if 0 <= j_h < self.fut_iter_steps and j_h not in collect_at_steps:
                collect_at_steps.append(j_h)
                gt_indices.append(round(horizon_sec * gt_rate_hz) - 1)
        if not collect_at_steps:
            raise ValueError(
                f"no evaluation horizon is reachable with fut_iter_steps={self.fut_iter_steps} "
                f"and dt_future={self.dt_future}"
            )

        for j in range(self.fut_iter_steps):
            # Future steps index the encoding table from 0 (j, not j - 1, which
            # would wrap to the last row at j == 0).
            z_pe = self.temporal_pe_vk(z_c, j)
            Z_pe = self.temporal_pe_gk(Z_c, j)

            v_list, g_list = [], []
            for b in range(B):
                vl, gl = self._pack_lists_for_xattn(z_pe[b], Z_pe[b])
                v_list += vl
                g_list += gl
            v_ctx_list, _, _ = self.xattn(g_list, v_list)
            u_j = self.ctrl_vk(v_ctx_list)

            p_hat, _ff, state = self.intent.step(raw_vk_t=raw_vk_t, u_t=u_j, state=state, dt=self.dt_future)

            # Move the mode distribution toward p_hat only when the new estimate is
            # confident enough and the current mode has been held long enough.
            maxprob, _ = p_hat.max(dim=-1)
            can_switch = (maxprob >= self.intent_switch_thresh) & \
                         (state["time_since_switch"] >= self.intent_min_stay_sec)
            p_blend = (1.0 - self.intent_online_beta) * p_c + self.intent_online_beta * p_hat
            p_c = torch.where(can_switch.unsqueeze(-1), p_blend, p_c)
            state["time_since_switch"] = torch.where(
                can_switch, torch.zeros_like(state["time_since_switch"]), state["time_since_switch"]
            )
            p_future_seq.append(p_c)

            z_c, aux_vk_f = self.evo_vk(z_c, u_j, p_c, eps_greedy=0.0)
            Z_c, aux_gk_f = self.evo_gk(Z_c)
            y_vk_acc = y_vk_acc + self.dec_vk(z_c)

            rho_fut.append(aux_vk_f.get('rho_mm', torch.zeros(B, device=device)))
            kdiag_fut.append(aux_vk_f.get('Kdiag_mm', torch.zeros(B, device=device)))
            K_fut.append(aux_gk_f.get('K'))

            if j in collect_at_steps:
                vk_135.append(y_vk_acc.clone())
                gk_135.append(self.dec_gk(Z_c))

        y_vk_135 = torch.stack(vk_135, dim=1)
        y_gk_135 = torch.stack(gk_135, dim=1)
        rho_fut = torch.stack(rho_fut, dim=1)
        kdiag_fut = torch.stack(kdiag_fut, dim=1)
        p_future_seq = torch.stack(p_future_seq, dim=1)
        K_fut = torch.stack(K_fut, dim=1)

        aux = {
            "intent": {
                "hist": p_hist_seq,  # [B,T_h-1,M]
                "future": p_future_seq,  # [B,fut_iter_steps,M]
                "top1": top1_hist,  # [B,T_h-1]
                "thresh": self.intent_switch_thresh,
                "beta": self.intent_online_beta,
                "minstay_sec": self.intent_min_stay_sec,
            },
            "recon": {
                "lag1s": {
                    "y_delta_seq": y_hist_delta_1s,  # [B,N_hist,out_dim_vk] dec_vk output at each source frame
                    "src_idx": idx_src_hist,  # [N_hist]
                    "tgt_idx": idx_tgt_hist,  # [N_hist]
                    "lag_steps": L,  # stride (default 25)
                },
                "latent_1s": {
                    "z_pred": z_hist_pred_1s, "z_tgt": z_hist_tgt_1s,  # [B,N_hist,D]
                    "Z_pred": Z_hist_pred_1s, "Z_tgt": Z_hist_tgt_1s,  # [B,N_hist,D]
                },
                "micro_out_hist": y_hist_vk,  # [B,T_h-1,out_dim_vk] dec_vk output at every history frame
                "u_hist": u_hist_seq,  # [B,T_h-1,U]
            },
            "stability": {
                "rho_hist": rho_hist, "kdiag_hist": kdiag_hist,
                "rho_fut": rho_fut, "kdiag_fut": kdiag_fut,
                "K_hist": K_hist, "K_fut": K_fut,
            },
            "gk_aux": aux_gk,
            "future_eval": {
                "vk_delta_135": y_vk_135,  # [B,N_c,out_dim_vk]
                "gk_out_135": y_gk_135,  # [B,N_c,*]
                "gt_indices_25hz": torch.tensor(gt_indices, device=device, dtype=torch.long),  # [N_c]
            },
        }

        return (y_gk_135, y_vk_135), aux