from data_provider.data_factory_ele import data_provider
from exp.exp_basic import Exp_Basic
from models import Informer, Autoformer, Transformer, DLinear, Linear, NLinear, PatchTST,iTransformer,TimeXer
from utils.tools import EarlyStopping, adjust_learning_rate, visual, test_params_flop
from utils.metrics import metric
from models.flow_matching.acc_model import TemporalForceNet
from models.flow_matching.acceleration import compute_target_acceleration
from models.flow_matching.conditional_flow_matching import AccelerationOTFlowMatcher
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch import optim
from torch.optim import lr_scheduler 
import torch.nn.functional as F

import os
import time

import warnings
import matplotlib.pyplot as plt
import numpy as np

import sys
import torchcde
from models.flow_matching.cross_attn import UltraFastCrossAttentionFilter
from exp.exp_main_maml import estimate_initial_velocity_ls

# NOTE(review): a bare 'ignore' filter silences every warning for the whole
# process as soon as this module is imported (not just warnings raised here);
# consider narrowing it by category or module.
warnings.filterwarnings('ignore')


def compute_v_error(y_pred, y_true):
    """
    Mean absolute error between the first differences ("velocities") of a
    prediction and of its ground truth along the time axis (dim 1).

    y_pred: [B, S, D] prediction
    y_true: [B, S, D] ground truth
    Returns a 0-d tensor.
    """
    # torch.diff along dim 1 gives x[:, t+1] - x[:, t], shape [B, S-1, D]
    pred_velocity = torch.diff(y_pred, dim=1)
    true_velocity = torch.diff(y_true, dim=1)
    return (pred_velocity - true_velocity).abs().mean()


def compute_batch_crps(final_samples, y_true):
    """
    Mean Continuous Ranked Probability Score of an ensemble forecast.

    final_samples: [N, B, S, D] ensemble members (e.g. N = 100 draws)
    y_true:        [B, S, D] observations
    Returns the CRPS averaged over every (b, s, d) cell (NumPy scalar).
    """
    from properscoring import crps_ensemble

    ensemble = final_samples.detach().cpu().numpy()
    observations = y_true.detach().cpu().numpy()

    # properscoring expects ensemble members on the LAST axis: [B, S, D, N]
    per_cell = crps_ensemble(observations, ensemble.transpose(1, 2, 3, 0))
    return per_cell.mean()



# --- 定义损失函数 ---
def weighted_l1_loss(pred, target, weight_factor):
    """
    Element-wise weighted L1 loss: mean(|weight_factor| * |pred - target|).

    A tiny epsilon (1e-6) is added to the weights so that no element gets
    exactly zero weight. The caller-side intent noted in the original is
    weight_factor = |a_target|.
    """
    elementwise_error = (pred - target).abs()
    weights = weight_factor.abs() + 1e-6
    return (weights * elementwise_error).mean()


def extract_phase_data(sequence, smooth=False, window=3):
    """
    Turn one predicted series into phase-space coordinates (position, velocity).

    sequence: array of shape [pred_len] or [pred_len, 1] (anything with .flatten())
    smooth:   when True, apply a centred rolling mean of width `window` to the
              velocity to damp discrete-difference noise
    window:   rolling-window length used when smooth is True

    Returns (x, v): the flattened series and its first difference; v[0] is 0
    because the first element is differenced against itself.
    """
    positions = sequence.flatten()
    velocities = np.diff(positions, prepend=positions[0])
    if not smooth:
        return positions, velocities

    rolled = pd.Series(velocities).rolling(window=window, min_periods=1, center=True)
    return positions, rolled.mean().values


def extract_phase_data_v2(history_seq, pred_seq):
    """
    Phase-space (position, velocity) view of a forecast stitched onto the end
    of its history, so the first velocity is continuous with the past.

    history_seq: history values; only the last element is used (the previous
                 version also read history_seq[-2] but never used it).
    pred_seq:    predicted values [p1, p2, ..., pn]

    Returns (pred_seq, v) with v[t] = p_{t+1} - p_t where p_0 = history_seq[-1],
    i.e. v[0] is the jump from the last observed value to the first prediction
    and len(v) == len(pred_seq).
    """
    # Prepend the last observed point so the forward differences start at the
    # history/forecast boundary rather than inside the forecast.
    x_anchor = history_seq[-1]
    v = np.diff(np.concatenate([[x_anchor], pred_seq]))
    return pred_seq, v



class DynamicGate(nn.Module):
    """
    Small MLP that maps an embedding to a scalar gate strength in (0, 1).

    Args:
        emb_dim:    size of the last dimension of the input embedding.
        hidden_dim: width of the hidden layer.
    """

    def __init__(self, emb_dim, hidden_dim=32):
        super().__init__()
        layers = [
            nn.Linear(emb_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),  # keeps training stable
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid(),  # squashes the score into an on/off strength
        ]
        # The attribute name and layer order define the state_dict keys
        # (gate.0.*, gate.1.*, gate.3.*); keep them stable for checkpoints.
        self.gate = nn.Sequential(*layers)

    def forward(self, news_emb):
        """
        news_emb: [B, Seq_len, Dim] or [B, Dim]. A 3-D input is mean-pooled
        over dim 1 first. Returns gate scores of shape [B, 1].
        """
        pooled = news_emb.mean(dim=1) if news_emb.dim() == 3 else news_emb
        return self.gate(pooled)

class EMAScaler:
    """
    StandardScaler whose mean and variance are tracked with an exponential
    moving average (EMA), so the statistics can follow a drifting stream
    (continual learning) instead of staying frozen after a single fit.
    """

    def __init__(self, momentum=0.01, epsilon=1e-5):
        # Weight given to each new batch: 0.01 -> 99% old stats, 1% new batch.
        self.momentum = momentum
        # Added to the variance before sqrt so the std is never zero.
        self.epsilon = epsilon
        self.mean = None
        self.var = None
        # Number of batches folded in so far.
        self.n = 0

    def partial_fit(self, X):
        """Fold one batch X (rows are samples) into the running statistics."""
        data = np.asarray(X)
        mu = data.mean(axis=0)
        sigma2 = data.var(axis=0)

        if self.n:
            keep = 1 - self.momentum
            self.mean = keep * self.mean + self.momentum * mu
            self.var = keep * self.var + self.momentum * sigma2
        else:
            # The first batch initialises the statistics directly.
            self.mean, self.var = mu, sigma2

        self.n += 1

    def transform(self, X):
        """Return the Z-score of X using the current running statistics."""
        if self.mean is None:
            raise RuntimeError("Scaler must be partially fitted before transformation.")
        scale = np.sqrt(self.var + self.epsilon)
        return (np.asarray(X) - self.mean) / scale

    
class GPhi(nn.Module):
    """
    Vector field for a neural CDE (passed as `func` to torchcde.cdeint).

    Maps a hidden state z of shape [B, hidden_dim] to a matrix of shape
    [B, hidden_dim, hidden_dim + 1], one column per channel of the control path.

    Args:
        hidden_dim:    size of the CDE hidden state.
        news_dim:      stored on the module; not used by forward.
        hidden_hidden: width of the internal MLP layer.
    """

    def __init__(self, hidden_dim, news_dim, hidden_hidden=128):
        super().__init__()
        out_features = (hidden_dim + 1) * hidden_dim
        # Layer order defines the state_dict keys (net.0.*, net.1.*, net.3.*).
        self.net = nn.Sequential(
            nn.Linear(hidden_dim, hidden_hidden),
            nn.LayerNorm(hidden_hidden),
            nn.SiLU(),
            nn.Linear(hidden_hidden, out_features),
            nn.Tanh(),  # bounds every matrix entry to [-1, 1]
        )
        self.hidden_dim = hidden_dim
        self.news_dim = news_dim

    def forward(self, t, z):
        """`t` is required by the torchcde vector-field signature but unused."""
        batch = z.size(0)
        normed = F.layer_norm(z, normalized_shape=[self.hidden_dim])
        return self.net(normed).view(batch, self.hidden_dim, self.hidden_dim + 1)


def pad_feature_middle(batch, max_feat=6):
    """
    Insert one all-zero feature column immediately before the target column.

    batch:    tensor [B, L, D] whose last column is the target.
    max_feat: expected total feature count. A column is inserted only when
              D < max_feat (default 6, the previously hard-coded value).
              Exactly ONE column is inserted, however far D is below max_feat.

    Returns the padded tensor [B, L, D + 1], or `batch` itself (same object)
    when D >= max_feat.
    """
    B, L, D = batch.shape
    if D >= max_feat:
        return batch
    pad = torch.zeros(B, L, 1, device=batch.device, dtype=batch.dtype)
    # [:, :, :-1] + zero column + target column
    return torch.cat([batch[:, :, :-1], pad, batch[:, :, -1:]], dim=2)

class Exp_Main(Exp_Basic):
    def __init__(self, args):
        """Forward the experiment arguments to Exp_Basic.__init__ unchanged."""
        super().__init__(args)

         

    def _build_model(self):
        """
        Create the backbone named by ``self.args.model``, cast to float32.

        The model is wrapped in ``nn.DataParallel`` over ``args.device_ids``
        when both ``args.use_multi_gpu`` and ``args.use_gpu`` are set.
        An unknown model name raises ``KeyError``.
        """
        backbones = dict(
            Autoformer=Autoformer,
            Transformer=Transformer,
            Informer=Informer,
            DLinear=DLinear,
            NLinear=NLinear,
            Linear=Linear,
            PatchTST=PatchTST,
            iTransformer=iTransformer,
            TimeXer=TimeXer,
        )
        net = backbones[self.args.model].Model(self.args).float()

        wants_data_parallel = self.args.use_multi_gpu and self.args.use_gpu
        if not wants_data_parallel:
            return net
        return nn.DataParallel(net, device_ids=self.args.device_ids)

    def _get_data(self, flag):
        """Return ``(dataset, dataloader)`` for split `flag` (e.g. 'train', 'val', 'test')."""
        dataset, loader = data_provider(self.args, flag)
        return dataset, loader

    def _select_optimizer(self):
        """Adam over every backbone parameter at ``args.learning_rate``."""
        return optim.Adam(self.model.parameters(), lr=self.args.learning_rate)

    def _select_criterion(self):
        """Mean-squared-error loss shared by training and validation."""
        return nn.MSELoss()

    def vali(self, vali_data, vali_loader, criterion):
        """
        Average `criterion` of the backbone over `vali_loader` without gradients.

        Only the last ``pred_len`` steps of the last feature column are scored
        (f_dim = -1). The model is in eval mode during the pass and is switched
        back to train mode before returning.

        Args:
            vali_data:   dataset object (unused; kept for interface compatibility).
            vali_loader: yields (batch_x, batch_y, batch_x_mark, batch_y_mark,
                         batch_x_flow, batch_y_flow); the flow tensors are unused.
            criterion:   loss callable, e.g. nn.MSELoss().

        Returns:
            Mean of the per-batch losses (NaN for an empty loader).
        """
        total_loss = []
        f_dim = -1  # score only the target column (last feature)
        plain_model = 'Linear' in self.args.model or 'TST' in self.args.model
        self.model.eval()
        with torch.no_grad():
            for batch_x, batch_y, batch_x_mark, batch_y_mark, _batch_x_flow, _batch_y_flow in vali_loader:
                batch_x = batch_x.float().to(self.device)
                batch_y = batch_y.float()  # stays on the host; loss is computed on CPU
                batch_x_mark = batch_x_mark.float().to(self.device)
                batch_y_mark = batch_y_mark.float().to(self.device)

                # decoder input: label_len known steps followed by pred_len zeros
                dec_inp = torch.zeros_like(batch_y[:, -self.args.pred_len:, :]).float()
                dec_inp = torch.cat([batch_y[:, :self.args.label_len, :], dec_inp], dim=1).float().to(self.device)

                # autocast(enabled=False) is a no-op, so one code path serves AMP and fp32
                with torch.cuda.amp.autocast(enabled=bool(self.args.use_amp)):
                    if plain_model:
                        outputs = self.model(batch_x)
                    elif self.args.output_attention:
                        outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)[0]
                    else:
                        outputs, _, _ = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)

                pred = outputs[:, -self.args.pred_len:, f_dim:].detach().cpu()
                true = batch_y[:, -self.args.pred_len:, f_dim:]

                # store plain floats rather than tensors
                total_loss.append(criterion(pred, true).item())
        self.model.train()
        return np.average(total_loss)

    def train(self, setting):
        """
        Train the backbone with early stopping on the validation loss.

        Args:
            setting: run identifier; checkpoints are written under
                     ``os.path.join(args.checkpoints, setting)``.

        Each epoch reports train/val/test loss (val and test via ``self.vali``).
        With ``args.lradj == 'TST'`` the OneCycleLR scheduler is stepped once
        per batch; otherwise ``adjust_learning_rate`` is called once per epoch.
        After training, ``checkpoint.pth`` from the run directory is loaded back
        into ``self.model``.

        Returns:
            self.model with the checkpointed weights loaded.
        """
        train_data, train_loader = self._get_data(flag='train')
        vali_data, vali_loader = self._get_data(flag='val')
        test_data, test_loader = self._get_data(flag='test')

        path = os.path.join(self.args.checkpoints, setting)
        # exist_ok avoids the check-then-create race of os.path.exists + makedirs
        os.makedirs(path, exist_ok=True)

        time_now = time.time()

        train_steps = len(train_loader)
        early_stopping = EarlyStopping(patience=self.args.patience, verbose=True)

        model_optim = self._select_optimizer()
        criterion = self._select_criterion()

        if self.args.use_amp:
            scaler = torch.cuda.amp.GradScaler()

        scheduler = lr_scheduler.OneCycleLR(optimizer=model_optim,
                                            steps_per_epoch=train_steps,
                                            pct_start=self.args.pct_start,
                                            epochs=self.args.train_epochs,
                                            max_lr=self.args.learning_rate)

        f_dim = -1  # the loss is computed on the target column (last feature) only
        pred_len = self.args.pred_len
        plain_model = 'Linear' in self.args.model or 'TST' in self.args.model

        for epoch in range(self.args.train_epochs):
            iter_count = 0
            train_loss = []

            self.model.train()
            epoch_time = time.time()

            for i, (batch_x, batch_y, batch_x_mark, batch_y_mark, _batch_x_flow, _batch_y_flow) in enumerate(train_loader):
                # The flow tensors are not used by the backbone, so they are not copied to the device.
                batch_x = batch_x.float().to(self.device)
                batch_y = batch_y.float().to(self.device)
                batch_x_mark = batch_x_mark.float().to(self.device)
                batch_y_mark = batch_y_mark.float().to(self.device)

                iter_count += 1
                model_optim.zero_grad()

                # decoder input: label_len known steps followed by pred_len zeros
                dec_inp = torch.zeros_like(batch_y[:, -pred_len:, :]).float()
                dec_inp = torch.cat([batch_y[:, :self.args.label_len, :], dec_inp], dim=1).float().to(self.device)

                # encoder - decoder
                if self.args.use_amp:
                    with torch.cuda.amp.autocast():
                        if plain_model:
                            outputs = self.model(batch_x)
                        elif self.args.output_attention:
                            outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)[0]
                        else:
                            outputs, _, _ = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)
                        outputs = outputs[:, -pred_len:, f_dim:]
                        batch_y = batch_y[:, -pred_len:, f_dim:]
                        loss = criterion(outputs, batch_y)
                else:
                    if plain_model:
                        outputs = self.model(batch_x)
                    elif self.args.output_attention:
                        outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)[0]
                    else:
                        # NOTE(review): this branch passes batch_y to the model while the
                        # AMP branch and vali() do not -- confirm the asymmetry is intended.
                        outputs, _, _ = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark, batch_y)
                    outputs = outputs[:, -pred_len:, f_dim:]
                    batch_y = batch_y[:, -pred_len:, f_dim:]
                    loss = criterion(outputs, batch_y)
                train_loss.append(loss.item())

                if (i + 1) % 100 == 0:
                    print("\titers: {0}, epoch: {1} | loss: {2:.7f}".format(i + 1, epoch + 1, loss.item()),
                          file=sys.stderr, flush=True)
                    speed = (time.time() - time_now) / iter_count
                    left_time = speed * ((self.args.train_epochs - epoch) * train_steps - i)
                    print('\tspeed: {:.4f}s/iter; left time: {:.4f}s'.format(speed, left_time),
                          file=sys.stderr, flush=True)
                    iter_count = 0
                    time_now = time.time()

                if self.args.use_amp:
                    scaler.scale(loss).backward()
                    scaler.step(model_optim)
                    scaler.update()
                else:
                    loss.backward()
                    model_optim.step()

                if self.args.lradj == 'TST':
                    # OneCycleLR is stepped once per batch in TST mode
                    adjust_learning_rate(model_optim, scheduler, epoch + 1, self.args, printout=False)
                    scheduler.step()

            print("Epoch: {} cost time: {}".format(epoch + 1, time.time() - epoch_time),
                  file=sys.stderr, flush=True)
            train_loss = np.average(train_loss)
            vali_loss = self.vali(vali_data, vali_loader, criterion)
            test_loss = self.vali(test_data, test_loader, criterion)

            print("Epoch: {0}, Steps: {1} | Train Loss: {2:.7f} Vali Loss: {3:.7f} Test Loss: {4:.7f}".format(
                epoch + 1, train_steps, train_loss, vali_loss, test_loss),
                file=sys.stderr, flush=True)
            early_stopping(vali_loss, self.model, path)
            if early_stopping.early_stop:
                print("Early stopping", file=sys.stderr, flush=True)
                break

            if self.args.lradj != 'TST':
                adjust_learning_rate(model_optim, scheduler, epoch + 1, self.args)
            else:
                print('Updating learning rate to {}'.format(scheduler.get_last_lr()[0]),
                      file=sys.stderr, flush=True)

        # NOTE: the log below says "saving" but this step loads the best checkpoint.
        best_model_path = path + '/' + 'checkpoint.pth'
        self.model.load_state_dict(torch.load(best_model_path))
        print('saving model ', best_model_path, file=sys.stderr, flush=True)

        return self.model
    

    def news_cond(self, us_news, a_share_news, H):
        """
        Build conditional / unconditional news embeddings for classifier-free guidance.

        With B = H.shape[0], the first B // 2 batch rows receive `us_news` and
        the remaining rows receive `a_share_news`. Both are zero-padded along
        dim 1 to a common length N_max and masked accordingly. The padded news
        goes through ``self.cross_attn`` together with `H`. A [0, 1] time
        channel is then prepended, and the result drives a neural CDE
        (``self.g_phi``) whose final hidden state is the conditional embedding.

        Returns:
            cond_emb:   final CDE hidden state, [B, hidden].
            uncond_emb: ``self.null_news`` broadcast over the batch.

        NOTE(review): ``self.cross_attn`` is not assigned anywhere in the
        visible code (its checkpointing is commented out in ``inference``), so
        this method presumably is not called at the moment -- confirm before use.
        """
        batch_size = H.shape[0]
        half = batch_size // 2
        N_max = max(us_news.shape[1], a_share_news.shape[1])
        news_all = torch.zeros(batch_size, N_max, us_news.shape[-1], device=H.device)
        news_mask = torch.zeros(batch_size, N_max, dtype=torch.bool, device=H.device)

        news_all[:half, :us_news.shape[1], :] = us_news
        news_mask[:half, :us_news.shape[1]] = True

        news_all[half:, :a_share_news.shape[1], :] = a_share_news
        news_mask[half:, :a_share_news.shape[1]] = True

        e_prime, _gates = self.cross_attn(H, news_all, news_mask)

        B, N, D = e_prime.shape

        # Prepend a time channel in [0, 1] as channel 0 of the CDE control path.
        t = torch.linspace(0, 1, N, device=e_prime.device, dtype=e_prime.dtype)
        t_expanded = t.unsqueeze(0).expand(B, N).unsqueeze(-1)  # [B, N, 1]
        X = torch.cat([t_expanded, e_prime], dim=2)              # [B, N, D + 1]
        Xcde = torchcde.LinearInterpolation(X)

        z0 = self.z0_initializer(e_prime.mean(dim=1))  # [B, hidden]
        z_T = torchcde.cdeint(X=Xcde, func=self.g_phi, z0=z0, t=t)
        cond_emb = z_T[:, -1, :]
        uncond_emb = self.null_news.unsqueeze(0).expand(B, -1)  # [B, news_dim]

        return cond_emb, uncond_emb


    def ode_preocess(self, x_t, v_t, uncond_emb, cond_emb, B, w, D_orig, steps=20):
        """
        Integrate the second-order flow ODE with classifier-free guidance (CFG).

        The state is (x, v). Each of `steps` explicit-Euler steps over t in [0, 1)
        does the following:
          * evaluate ``self.flow_network`` once on the stacked batch
            [uncond; cond], with input S = concat(x, v) along the last dim;
          * form the guided field U = u_uncond + w * (u_cond - u_uncond) and
            split it into (U_x, U_v), each of width D_orig; U_v is the acceleration;
          * update x <- x + v * dt using the pre-update v, then v <- v + U_v * dt.

        Args:
            x_t, v_t:   initial position / velocity, last dim D_orig.
            uncond_emb: null-condition embedding (first half of the CFG batch).
            cond_emb:   condition embedding (second half of the CFG batch).
            B:          batch size used to split the doubled network output.
            w:          CFG guidance weight.
            D_orig:     feature width of x (network output must be 2 * D_orig wide).
            steps:      number of Euler steps (default 20, the former constant).

        Returns:
            x_t:              final position.
            accel_to_analyze: guided acceleration averaged over all steps.
            first_accel:      guided acceleration of the first step.
            avg_guidance:     per-step L2 norm over the last dim of
                              w * (u_cond - u_uncond), averaged over steps.

        Raises:
            ValueError: if steps < 1 (first_accel would otherwise be undefined).
        """
        if steps < 1:
            raise ValueError("steps must be >= 1")
        dt = 1.0 / steps

        all_steps_accel = []
        all_diffs = []
        first_accel = None

        # The conditioning batch does not change between steps.
        cond_dup = torch.cat([uncond_emb, cond_emb], dim=0)

        for i in range(steps):
            t_batch = torch.full((x_t.shape[0],), i / steps, device=self.device)
            t_dup = torch.cat([t_batch, t_batch], dim=0)

            S_t = torch.cat([x_t, v_t], dim=-1)
            S_dup = torch.cat([S_t, S_t], dim=0)

            ut_uncond, ut_cond = self.flow_network(S_dup, t_dup, cond_dup).split(B, dim=0)

            # CFG: the scaled difference is used both for logging and for guidance.
            guidance = w * (ut_cond - ut_uncond)
            all_diffs.append(torch.norm(guidance, p=2, dim=-1))
            U_guided = ut_uncond + guidance

            _U_x_guided, U_v_guided = U_guided.split(D_orig, dim=-1)

            if i == 0:
                first_accel = U_v_guided
            all_steps_accel.append(U_v_guided)

            x_t = x_t + v_t * dt
            v_t = v_t + U_v_guided * dt

        accel_to_analyze = torch.stack(all_steps_accel).mean(dim=0)
        avg_guidance = torch.stack(all_diffs).mean(dim=0)
        return x_t, accel_to_analyze, first_accel, avg_guidance
        
    def flow_predict(self, batch_x, batch_y, batch_x_mark, batch_y_mark):
        """
        Compute the residual target for the flow model on one batch.

        Runs the backbone, which is unpacked as (output, trend_part, H), and returns:
            flow_target: batch_y - trend_part over the last pred_len steps (on self.device).
            true:        the ground-truth pred_len window, detached and moved to CPU.
        """
        device = self.device
        pred_len = self.args.pred_len

        batch_x = batch_x.float().to(device)
        batch_y = batch_y.float().to(device)
        batch_x_mark = batch_x_mark.float().to(device)
        batch_y_mark = batch_y_mark.float().to(device)

        # decoder input: label_len known steps followed by pred_len zeros
        zeros = torch.zeros_like(batch_y[:, -pred_len:, :]).float()
        dec_inp = torch.cat([batch_y[:, :self.args.label_len, :], zeros], dim=1).float().to(device)

        model_out = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)
        if self.args.output_attention:
            # NOTE(review): element [0] is unpacked into three values here --
            # confirm the model's return structure when output_attention is set.
            model_out = model_out[0]
        _output, trend_part, _H = model_out

        target_window = batch_y[:, -pred_len:, :]
        flow_target = target_window - trend_part[:, -pred_len:, :]
        return flow_target, target_window.detach().cpu()
    
    
    def vali_inf(self, vali_data, vali_loader, criterion, path):
        """
        Evaluate the frozen backbone plus the guided flow refinement on a loader.

        Loads ``checkpoint_flow_s0.pth`` (flow network) and
        ``checkpoint_null_s0.pth`` (null condition embedding) from `path`.
        Then, for every batch, it:
          * runs the backbone once in eval mode, giving (output, trend_part, H);
          * integrates ``self.ode_preocess`` from x = output (last pred_len steps)
            and v = the last observed step difference of batch_x, conditioned on
            the last pred_len steps of batch_x_flow (zero-padded on the left if
            shorter) with guidance weight w = 1.0;
          * scores the refined prediction, the raw backbone output and the trend
            part against the last pred_len steps of batch_y.

        Args:
            vali_data: dataset object (unused; kept for interface compatibility).
            vali_loader: yields (batch_x, batch_y, batch_x_mark, batch_y_mark,
                         batch_x_flow, batch_y_flow).
            criterion: loss callable applied on CPU tensors.
            path: directory holding the flow / null-embedding checkpoints.

        Returns:
            (total_loss, original_loss, trend_loss): batch-averaged criterion for
            the refined prediction, the raw backbone output and the trend part.

        On exit the flow network and g_phi are back in train mode; the backbone
        stays in eval mode with requires_grad disabled.
        """
        total_loss = []
        loss_original = []
        loss_trend = []
        for param in self.model.parameters():
            param.requires_grad = False

        # map_location puts the loaded tensors on self.device directly. That keeps
        # self.null_news the same nn.Parameter object that was registered with the
        # optimizer; a cross-device `.to()` would return a new, non-Parameter tensor.
        best_model_path_flow = path + '/' + 'checkpoint_flow_s0.pth'
        self.flow_network.load_state_dict(torch.load(best_model_path_flow, map_location=self.device))

        best_model_path_null = path + '/' + 'checkpoint_null_s0.pth'
        self.null_news.data = torch.load(best_model_path_null, map_location=self.device)

        self.model.eval()
        self.flow_network.eval()
        self.g_phi.eval()

        w = 1.0  # classifier-free guidance weight passed to ode_preocess
        pred_len = self.args.pred_len

        # Everything below is inference-only. A single eval-mode backbone pass per
        # batch is enough: the previous train-mode (dropout) sampling loop produced
        # values that were never used and could update normalisation statistics.
        with torch.no_grad():
            for batch_x, batch_y, batch_x_mark, batch_y_mark, batch_x_flow, _batch_y_flow in vali_loader:
                batch_x = batch_x.float().to(self.device)
                batch_y = batch_y.float().to(self.device)
                batch_x_mark = batch_x_mark.float().to(self.device)
                batch_y_mark = batch_y_mark.float().to(self.device)
                batch_x_flow = batch_x_flow.float().to(self.device)

                # decoder input: label_len known steps followed by pred_len zeros
                dec_inp = torch.zeros_like(batch_y[:, -pred_len:, :]).float()
                dec_inp = torch.cat([batch_y[:, :self.args.label_len, :], dec_inp], dim=1).float().to(self.device)

                model_out = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)
                if self.args.output_attention:
                    # NOTE(review): element [0] is unpacked into three values --
                    # confirm the model's return structure in this mode.
                    model_out = model_out[0]
                output, trend_part, _H = model_out

                output_ = output[:, -pred_len:, :]
                trend_part = trend_part[:, -pred_len:, :]
                true = batch_y[:, -pred_len:, :].detach().cpu()

                loss_original.append(criterion(output_.detach().cpu(), true).item())
                loss_trend.append(criterion(trend_part.detach().cpu(), true).item())

                B, Seq, D = output_.shape
                uncond_emb = self.null_news.view(1, 1, -1).expand(B, Seq, -1)

                # Condition on the last Seq steps of batch_x_flow, left-padding with zeros if shorter.
                if batch_x_flow.shape[1] >= Seq:
                    cond_emb = batch_x_flow[:, -Seq:, :]
                else:
                    pad_len = Seq - batch_x_flow.shape[1]
                    pad = torch.zeros(batch_x_flow.shape[0], pad_len, batch_x_flow.shape[-1],
                                      device=batch_x_flow.device, dtype=batch_x_flow.dtype)
                    cond_emb = torch.cat([pad, batch_x_flow], dim=1)

                # Initial velocity: last observed step difference, broadcast over the horizon.
                v_start = batch_x[:, -1, :] - batch_x[:, -2, :]  # [B, D]
                v_t = v_start.unsqueeze(1).repeat(1, Seq, 1)      # [B, S, D]

                x_t_final, _accel, _first_accel, _avg_guidance = self.ode_preocess(
                    output_, v_t, uncond_emb, cond_emb, B, w, D)

                total_loss.append(criterion(x_t_final.detach().cpu(), true).item())

        total_loss = np.mean(total_loss)
        original_loss = np.mean(loss_original)
        trend_loss = np.mean(loss_trend)

        self.flow_network.train()
        self.g_phi.train()
        print('total_loss', 'original_loss', 'trend_loss', total_loss, original_loss, trend_loss, "1111", file=sys.stderr, flush=True)
        return total_loss, original_loss, trend_loss

    def inference(self, setting):
        """
        1. Autoformer 预测趋势 冻结
        2. OT-CFM 生成加速度 (Shock)
        3. 物理方程融合
        """
        train_data, train_loader = self._get_data(flag='train')
        vali_data,  vali_loader= self._get_data(flag='val')
        test_data,  test_loader = self._get_data(flag='test')
        criterion = self._select_criterion()
        
        path_pre = os.path.join(self.args.checkpoints, setting)
        best_model_path = path_pre + '/' + 'checkpoint.pth'
        self.model.load_state_dict(torch.load(best_model_path))

         
        
        path = os.path.join(self.args.checkpoints, setting)
        if not os.path.exists(path):
            os.makedirs(path)
        
        time_now = time.time()

        train_steps = len(train_loader)
        early_stopping = EarlyStopping(patience=self.args.patience, verbose=True)
        
        self.news_dim = 768
        hidden_dim = 128
        
        input_phys_dim=6
        self.flow_network = TemporalForceNet(state_dim=self.args.enc_in,phys_dim=input_phys_dim).to(self.device)
        self.ot_matcher = AccelerationOTFlowMatcher()# .to(self.device)
        self.g_phi = GPhi(hidden_dim, self.news_dim+1).to(self.device)
        
        self.null_news = nn.Parameter(torch.randn(input_phys_dim, device=self.device))
        
        self.v0_predictor = nn.Sequential(
            nn.Linear(self.args.d_model, self.args.d_model // 2),
            nn.ReLU(),
            nn.Linear(self.args.d_model // 2, self.args.enc_in) # 输出维度与 OT 一致
        ).to(self.device)

        self.z0_initializer = nn.Linear(hidden_dim, hidden_dim).to(self.device)

        self.p_drop = 0.1
        model_params = list(self.flow_network.parameters()) \
            + [self.null_news] + list(self.g_phi.parameters()) + list(self.z0_initializer.parameters())

        WEIGHT_DECAY = 1e-3 
        model_optim = optim.AdamW(
            model_params, 
            lr=self.args.learning_rate, 
            weight_decay=WEIGHT_DECAY 
        )

        for param in self.model.parameters():
            param.requires_grad = False

        if self.args.use_amp:
            scaler = torch.cuda.amp.GradScaler()
        
        iters_per_epoch =len(train_loader) 
        self.model.eval()
        scheduler = lr_scheduler.OneCycleLR(optimizer=model_optim, steps_per_epoch=iters_per_epoch, pct_start=self.args.pct_start, epochs=self.args.train_epochs, max_lr=self.args.learning_rate)
        
        for epoch in range(self.args.train_epochs):
            iter_count = 0
            train_loss = []

            epoch_time = time.time()
            i=0

            ix_i = 0
            ix = 100
            count=0
             
            for i, (batch_x, batch_y, batch_x_mark, batch_y_mark, batch_x_flow,batch_y_flow) in enumerate(train_loader):
            # for (us_batch, a_share_batch) in zip(train_loader_us, train_loader_a_share):
                print(ix_i,"11111", file=sys.stderr, flush=True)
                if ix_i%ix==0 and ix_i!=0:
                    print("Epoch: {} cost time: {}".format(epoch + 1, time.time() - epoch_time)\
                    , file=sys.stderr, flush=True)
                    train_loss_ = np.average(train_loss)
                    best_model_path_flow = path + '/' + 'checkpoint_flow_s0.pth'
                    torch.save(self.flow_network.state_dict(), best_model_path_flow)
                    # best_model_path_cross = path + '/' + 'checkpoint_cross_s0.pth'
                    # torch.save(self.cross_attn.state_dict(), best_model_path_cross)
                    best_model_path_null = path + '/' + 'checkpoint_null_s0.pth'
                    torch.save(self.null_news.data,best_model_path_null)
                    # best_model_path_gphi = path + '/' + 'checkpoint_gphi_s0.pth'
                    # torch.save(self.g_phi.state_dict(), best_model_path_gphi)
                    
                    vali_loss,original_loss,trend_loss = self.vali_inf(vali_data, vali_loader, criterion,path)
                    test_loss,original_loss,trend_loss = self.vali_inf(test_data, test_loader, criterion,path)

                    
                    print("Epoch: {0}, Steps: {1} | Train Loss: {2:.7f} Vali Loss: {3:.7f} Test Loss: {4:.7f}".format(
                        epoch + 1, train_steps, train_loss_, vali_loss, test_loss)\
                            , file=sys.stderr, flush=True)
                    # break
                    count+=1
                ix_i+=1
                if count>15:
                    break
                
                iter_count += 1
                model_optim.zero_grad()
                batch_x = batch_x.float().to(self.device)
                
                batch_y = batch_y.float().to(self.device)
                batch_x_mark = batch_x_mark.float().to(self.device)
                batch_y_mark = batch_y_mark.float().to(self.device)

                batch_x_flow = batch_x_flow.float().to(self.device)
                batch_y_flow = batch_y_flow.float().to(self.device)
                 
                dec_inp = torch.zeros_like(batch_y[:, -self.args.pred_len:, :]).float()
                dec_inp = torch.cat([batch_y[:, :self.args.label_len, :], dec_inp], dim=1).float().to(self.device)

                if self.args.output_attention:
                    output,trend_part,H = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)[0]
                else:
                    output,trend_part,H = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark, batch_y)
                # torch.Size([32, 7198, 768]) 1111
                # torch.Size([32, 2045, 768]) 1111
                output_1 = output[:, -self.args.pred_len:, :]
                batch_y_1 = batch_y[:, -self.args.pred_len:, :].to(self.device)
                loss_1 = criterion(output_1, batch_y_1)
                # print("original_loss",loss_1.item(),file=sys.stderr, flush=True)
                output_2 = trend_part[:, -self.args.pred_len:, :]
                loss_2 = criterion(output_2, batch_y_1)
                
                # (49089, 11) (49089, 11)

                trend_part = trend_part[:, -self.args.pred_len:, :]
                batch_y = batch_y[:, -self.args.pred_len:, :].to(self.device)
                flow_target = batch_y - trend_part
                
                output_ = output[:, -self.args.pred_len:, :]
                B, Seq, D = output_.shape
                 
                X0 = output_ 
                X1 = batch_y
                 
                v_start = (batch_x[:, -1, :] - batch_x[:, -2, :]) # [B, D]
                V0 = v_start.unsqueeze(1).repeat(1, self.args.pred_len, 1) # 广播到整个预测长度 [B, S, D]
               
                a_target = 2 * (X1 - X0 - V0) 
                
                 
                V1 = V0 + a_target
                 
                S0 = torch.cat([X0, V0], dim=-1)         # [B, S, 2D]
                S1_target = torch.cat([X1, V1], dim=-1)   # [B, S, 2D]

              
                t_ = torch.rand(B, device=self.device)
                t = t_.view(B, 1, 1)
                S_t = (1 - t) * S0 + t * S1_target
 

                S_t_dup = torch.cat([S_t, S_t], dim=0)         
                t_dup = torch.cat([t_, t_], dim=0) 
                
                 
                uncond_emb = self.null_news.view(1, 1, -1).expand(B, Seq, -1)   
                
                if batch_x_flow.shape[1] >= Seq:
                    x_seq = batch_x_flow[:, -Seq:, :]
                else:
                    pad_len = Seq - batch_x_flow.shape[1]
                    pad = torch.zeros(batch_x_flow.shape[0], pad_len, batch_x_flow.shape[-1], device=batch_x_flow.device, dtype=batch_x_flow.dtype)
                    x_seq = torch.cat([pad, batch_x_flow], dim=1)
                
                
                cond_dup = torch.cat([uncond_emb, x_seq], dim=0)   
                
                
                ut_dup = self.flow_network(S_t_dup, t_dup, cond_dup)  
                
                ut_uncond, ut_cond = ut_dup.split(B, dim=0)

                
                ut_x_pred, ut_v_pred = ut_cond.split(D, dim=-1)
                ut_x_pred_un, ut_v_pred_un = ut_uncond.split(D, dim=-1)
 
                loss_v_cond = criterion(ut_v_pred, a_target)
                loss_v_uncond = criterion(ut_v_pred_un, a_target)

                
                loss_x_cond = criterion(ut_x_pred, S_t[:, :, D:])  
                loss_x_uncond = criterion(ut_x_pred_un, S_t[:, :, D:])
 
                cos_sim = torch.nn.functional.cosine_similarity(ut_v_pred, a_target, dim=-1)
                loss_direction = 0.0#(1.0 - cos_sim).mean()

                 
                loss_direction = (1.0 - cos_sim).mean()

                
                lambda_x = 0.1
                loss_cond = loss_v_cond + lambda_x * loss_x_cond
                loss_uncond = loss_v_uncond + lambda_x * loss_x_uncond

                
                loss = 0.5 * (loss_cond + loss_uncond) + 0.1 * loss_direction 
                
                train_loss.append(loss.item())
                 

                if (i + 1) % 100 == 0:
                    # print(loss,"1111", file=sys.stderr, flush=True)
                    print("\titers: {0}, epoch: {1} | loss: {2:.7f}".format(i + 1, epoch + 1, loss.item())\
                          , file=sys.stderr, flush=True)
                    speed = (time.time() - time_now) / iter_count
                    left_time = speed * ((self.args.train_epochs - epoch) * train_steps - i)
                    print('\tspeed: {:.4f}s/iter; left time: {:.4f}s'.format(speed, left_time)\
                          , file=sys.stderr, flush=True)
                    iter_count = 0
                    time_now = time.time()
                
                if self.args.use_amp:
                    scaler.scale(loss).backward()
                    scaler.step(model_optim)
                    scaler.update()
                else:
                    loss.backward()
                    # torch.nn.utils.clip_grad_norm_(self.g_phi.parameters(), max_norm=0.1) 
                    torch.nn.utils.clip_grad_norm_(self.flow_network.parameters(), max_norm=1.0) 
                    model_optim.step()
                    
                # if self.args.lradj == 'TST':
                adjust_learning_rate(model_optim, scheduler, epoch + 1, self.args, printout=False)
                scheduler.step()

                i+=1

            print("Epoch: {} cost time: {}".format(epoch + 1, time.time() - epoch_time)\
                  , file=sys.stderr, flush=True)
            train_loss = np.average(train_loss)
            best_model_path_flow = path + '/' + 'checkpoint_flow_s0.pth'
            torch.save(self.flow_network.state_dict(), best_model_path_flow)
              
            best_model_path_null = path + '/' + 'checkpoint_null_s0.pth'
            torch.save(self.null_news.data,best_model_path_null)
            
            
            vali_loss,original_loss,trend_loss = self.vali_inf(vali_data, vali_loader, criterion,path)
            test_loss,original_loss,trend_loss = self.vali_inf(test_data, test_loader, criterion,path)

            
            print("Epoch111: {0}, Steps: {1} | Train Loss: {2:.7f} Vali Loss: {3:.7f} Test Loss: {4:.7f}".format(
                epoch + 1, train_steps, train_loss, vali_loss, test_loss)\
                    , file=sys.stderr, flush=True)
            # early_stopping(vali_loss, self.flow_network, path)
            # if early_stopping.early_stop:
            #     print("Early stopping", file=sys.stderr, flush=True)
            #     break
            
            if self.args.lradj != 'TST':
                adjust_learning_rate(model_optim, scheduler, epoch + 1, self.args)
            else:
                print('Updating learning rate to {}'.format(scheduler.get_last_lr()[0])\
                      , file=sys.stderr, flush=True)


        best_model_path_flow = path + '/' + 'checkpoint_flow_s0.pth'
        torch.save(self.flow_network.state_dict(), best_model_path_flow)
        # best_model_path_cross = path + '/' + 'checkpoint_cross_s0.pth'
        # torch.save(self.cross_attn.state_dict(), best_model_path_cross)
        best_model_path_null = path + '/' + 'checkpoint_null_s0.pth'
        torch.save(self.null_news.data,best_model_path_null)
        # best_model_path_gphi= path + '/' + 'checkpoint_gphi_s0.pth'
        # torch.save(self.g_phi,best_model_path_gphi)
        
        return self.flow_network
    
    def test(self, setting, test=0):
        test_data, test_loader = self._get_data(flag='test')
        criterion = self._select_criterion()
        
        path = os.path.join(self.args.checkpoints, setting)
        vali_loss,original_loss,trend_loss = self.vali_inf(test_data, test_loader, criterion,path)
        
        return

    def predict(self, setting, load=False):
        """Run ``self.model`` over the 'pred' split and save its forecasts.

        Args:
            setting: Experiment name; selects ``<args.checkpoints>/<setting>/`` for the
                optional checkpoint and ``./results/<setting>/`` for the output file.
            load: If True, first load ``checkpoint.pth`` from the checkpoint directory
                into ``self.model``.

        Writes ``./results/<setting>/real_prediction.npy``: every batch's model output
        stacked along the first axis, then reshaped to ``(-1, L, D)`` using the
        output's last two dimensions. Returns None.

        Raises:
            ValueError: If the prediction loader yields no batches.
        """
        import contextlib

        _, pred_loader = self._get_data(flag='pred')

        if load:
            path = os.path.join(self.args.checkpoints, setting)
            best_model_path = path + '/' + 'checkpoint.pth'
            # NOTE(review): the training code in this file saves checkpoint_flow_s0.pth /
            # checkpoint_null_s0.pth, not checkpoint.pth -- confirm which run writes this file.
            # map_location lets a checkpoint saved on another device (e.g. GPU) load here.
            self.model.load_state_dict(torch.load(best_model_path, map_location=self.device))

        # Models whose name contains 'Linear' or 'TST' are called with the input window only.
        input_only = 'Linear' in self.args.model or 'TST' in self.args.model
        preds = []

        self.model.eval()
        with torch.no_grad():
            for batch_x, batch_y, batch_x_mark, batch_y_mark in pred_loader:
                batch_x = batch_x.float().to(self.device)
                batch_y = batch_y.float()
                batch_x_mark = batch_x_mark.float().to(self.device)
                batch_y_mark = batch_y_mark.float().to(self.device)

                # Decoder input: first label_len steps of batch_y, then pred_len zeros.
                dec_inp = torch.zeros([batch_y.shape[0], self.args.pred_len, batch_y.shape[2]]).float()
                dec_inp = torch.cat([batch_y[:, :self.args.label_len, :], dec_inp], dim=1).float().to(self.device)

                # Single forward path; autocast is only entered when AMP is enabled.
                amp_ctx = torch.cuda.amp.autocast() if self.args.use_amp else contextlib.nullcontext()
                with amp_ctx:
                    if input_only:
                        outputs = self.model(batch_x)
                    elif self.args.output_attention:
                        outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)[0]
                    else:
                        # NOTE(review): the training loop calls self.model with batch_y as a 5th
                        # argument and unpacks three outputs; confirm this call returns a tensor.
                        outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)
                preds.append(outputs.detach().cpu().numpy())

        if not preds:
            raise ValueError('prediction loader yielded no batches')
        # np.concatenate (rather than np.array) handles a smaller final batch, which
        # would otherwise form a ragged array and fail before the reshape.
        preds = np.concatenate(preds, axis=0)
        preds = preds.reshape(-1, preds.shape[-2], preds.shape[-1])

        # result save
        folder_path = './results/' + setting + '/'
        os.makedirs(folder_path, exist_ok=True)
        np.save(folder_path + 'real_prediction.npy', preds)

        return


