from torchdiffeq import odeint
import torch
import torch.nn as nn
def inference(autoformer, cfm_model, history_data, future_news):
    """
    1. Autoformer 预测趋势
    2. OT-CFM 生成加速度 (Shock)
    3. 物理方程融合
    """
    
    # --- A. Autoformer 预测趋势 ---
    # h_trend: [1, Seq_Len, Dim]
    h_trend = autoformer.predict(history_data) 
    
    # --- B. OT-CFM 生成加速度 (Generating Force) ---
    # 定义 ODE 函数 wrapper
    class ODEFunc(nn.Module):
        def __init__(self, model, news):
            super().__init__()
            self.model = model
            self.news = news
        
        def forward(self, t, x):
            # t 是标量，需要广播
            t_vec = torch.ones(x.shape[0], 1).to(x.device) * t
            # 模型预测向量场 dx/dt
            return self.model(x, t_vec, self.news)

    # 准备初始噪声
    x0 = torch.randn_like(h_trend) # 形状与趋势一致
    
    # 使用 ODE Solver 从 t=0 积分到 t=1
    # 这一步生成的是 "Predicted Acceleration Sequence"
    ode_func = ODEFunc(cfm_model, future_news)
    
    # 你可以选择 'euler' (快) 或 'dopri5' (准)
    trajectory = odeint(ode_func, x0, torch.tensor([0., 1.]), method='euler')
    
    # 取 t=1 的结果作为生成的加速度
    acc_pred = trajectory[-1] # [1, Seq_Len, Dim]
    
    # --- C. 物理融合 (Physics Integration) ---
    # 现在我们有了 h_trend (惯性) 和 acc_pred (外力)
    # h_final = h_trend + DoubleIntegral(acc_pred)
    
    # 简单实现：假设 h_trend 是基准线，acc_pred 是对于基准线的偏离加速度
    # 我们需要数值积分两次得到位置偏离
    
    # 第一次积分: Acceleration -> Velocity Deviation
    # cumsum 近似积分
    vel_deviation = torch.cumsum(acc_pred, dim=1) 
    
    # 第二次积分: Velocity Deviation -> Position Deviation
    pos_deviation = torch.cumsum(vel_deviation, dim=1)
    
    # 最终结果
    h_final = h_trend + pos_deviation
    
    return h_final