from data_provider.data_factory import data_provider
from exp.exp_basic import Exp_Basic
from models import Informer, Autoformer, Transformer, DLinear, Linear, NLinear, PatchTST
from utils.tools import EarlyStopping, adjust_learning_rate, visual, test_params_flop
from utils.metrics import metric
from models.flow_matching.acc_model import TemporalForceNet
from models.flow_matching.acceleration import compute_target_acceleration
from models.flow_matching.conditional_flow_matching import AccelerationOTFlowMatcher
import numpy as np
import torch
import torch.nn as nn
from torch import optim
from torch.optim import lr_scheduler 
import torch.nn.functional as F

import os
import time

import warnings
import matplotlib.pyplot as plt
import numpy as np

import sys
import torchcde
from models.flow_matching.cross_attn import UltraFastCrossAttentionFilter
from exp.exp_main_maml import estimate_initial_velocity_ls

# Silence every Python warning process-wide at import time (including
# deprecation warnings from torch / torchcde). NOTE(review): consider narrowing
# this by category or module if useful warnings are being hidden.
warnings.filterwarnings('ignore')


# --- Loss functions ---
def weighted_l1_loss(pred, target, weight_factor, eps=1e-6):
    """Return the mean element-wise weighted L1 loss.

    Each absolute error ``|pred - target|`` is multiplied by
    ``|weight_factor| + eps`` and the products are averaged over all elements.

    Args:
        pred: Predicted tensor.
        target: Target tensor, broadcast-compatible with ``pred``.
        weight_factor: Per-element weights (e.g. ``|a_target|``); only their
            magnitude is used.
        eps: Small offset added to every weight so elements whose weight is
            zero still contribute to the loss. Defaults to ``1e-6``, the value
            that was previously hard-coded.

    Returns:
        A scalar tensor.
    """
    # abs() makes every weight non-negative; eps keeps zero-weight elements
    # from being ignored entirely (no division takes place here).
    weights = weight_factor.abs() + eps
    return (weights * (pred - target).abs()).mean()

import numpy as np

import torch
import torch.nn as nn

class DynamicGate(nn.Module):
    """Map a news embedding to a scalar gate value in (0, 1).

    A small MLP (Linear -> LayerNorm -> ReLU -> Linear -> Sigmoid) scores each
    sample. Sequence inputs are mean-pooled over their second axis first.
    """

    def __init__(self, emb_dim, hidden_dim=32):
        super().__init__()
        layers = [
            nn.Linear(emb_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),  # keeps training stable
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid(),  # squashes the score into a (0, 1) "switch" strength
        ]
        # The attribute name and layer order define the state_dict keys
        # (gate.0.weight, ...); keep them stable for saved checkpoints.
        self.gate = nn.Sequential(*layers)

    def forward(self, news_emb):
        """Return gate scores of shape [Batch, 1].

        Args:
            news_emb: [Batch, Dim], or [Batch, Seq_len, Dim], which is averaged
                over Seq_len before scoring.
        """
        pooled = news_emb.mean(dim=1) if news_emb.dim() == 3 else news_emb
        return self.gate(pooled)

class EMAScaler:
    """Standard scaler whose statistics are tracked as an exponential moving average.

    Intended for continual learning, where a fixed StandardScaler goes stale.
    The first ``partial_fit`` call sets the statistics directly; later calls
    blend each batch's column mean/variance in with weight ``momentum``.
    """

    def __init__(self, momentum=0.01, epsilon=1e-5):
        # momentum = 0.01 -> 99% old statistics, 1% new batch statistics
        self.momentum = momentum
        self.epsilon = epsilon  # added to the variance before sqrt in transform
        self.mean = None
        self.var = None
        self.n = 0  # number of partial_fit calls so far

    def partial_fit(self, X):
        """Fold the column-wise mean/variance of batch ``X`` into the running statistics."""
        data = np.asarray(X)
        new_mean, new_var = data.mean(axis=0), data.var(axis=0)

        if self.n:
            keep = 1 - self.momentum
            self.mean = keep * self.mean + self.momentum * new_mean
            self.var = keep * self.var + self.momentum * new_var
        else:
            self.mean, self.var = new_mean, new_var

        self.n += 1

    def transform(self, X):
        """Return ``X`` z-scored with the current running statistics.

        Raises:
            RuntimeError: if ``partial_fit`` has never been called.
        """
        if self.mean is None:
            raise RuntimeError("Scaler must be partially fitted before transformation.")
        scale = np.sqrt(self.var + self.epsilon)
        return (np.asarray(X) - self.mean) / scale

    
class GPhi(nn.Module):
    """Vector field for a neural CDE: maps a hidden state z to a matrix G(z).

    ``forward`` returns a tensor of shape [B, hidden_dim, hidden_dim + 1]. For
    ``torchcde.cdeint`` the last axis must match the channel count of the
    driving path, so that path is expected to have ``hidden_dim + 1`` channels.

    Args:
        hidden_dim: Size of the CDE hidden state.
        news_dim: Stored as ``self.news_dim`` but not used to size any layer.
            NOTE(review): the output width is ``hidden_dim + 1`` regardless of
            this value; confirm that is intended.
        hidden_hidden: Width of the MLP's hidden layer.
    """

    def __init__(self, hidden_dim, news_dim, hidden_hidden=128):
        super().__init__()
        out_features = (hidden_dim + 1) * hidden_dim
        # Layer order defines the state_dict keys (net.0 ... net.4); keep it stable.
        self.net = nn.Sequential(
            nn.Linear(hidden_dim, hidden_hidden),
            nn.LayerNorm(hidden_hidden),
            nn.SiLU(),
            nn.Linear(hidden_hidden, out_features),
            nn.Tanh(),  # bounds every entry of G(z) to [-1, 1]
        )
        self.hidden_dim = hidden_dim
        self.news_dim = news_dim

    def forward(self, t, z):
        """Return G(z) with shape [B, hidden_dim, hidden_dim + 1].

        Args:
            t: Current time. Unused, but part of the torchcde vector-field signature.
            z: Hidden state of shape [B, hidden_dim].
        """
        batch = z.size(0)
        # Parameter-free layer norm keeps the MLP input on a fixed scale.
        normed = F.layer_norm(z, normalized_shape=[self.hidden_dim])
        return self.net(normed).view(batch, self.hidden_dim, self.hidden_dim + 1)


def pad_feature_middle(batch, target_dim=6):
    """Insert one all-zero feature column just before the last (target) column.

    The zero column is added only when ``batch`` has fewer than ``target_dim``
    features. Otherwise ``batch`` itself is returned unchanged.

    Args:
        batch: Tensor of shape [B, L, D] whose last feature column is the target.
        target_dim: Feature count below which the column is inserted. Defaults
            to 6, the value that was previously hard-coded.

    Returns:
        A [B, L, D + 1] tensor if ``D < target_dim``, else ``batch``. Exactly one
        column is inserted even when ``D + 1`` is still below ``target_dim``.
    """
    B, L, D = batch.shape  # also enforces a 3-D input
    if D >= target_dim:
        return batch
    zeros = batch.new_zeros(B, L, 1)  # same device and dtype as batch
    # [features except target] + [zero column] + [target]
    return torch.cat([batch[:, :, :-1], zeros, batch[:, :, -1:]], dim=2)

class Exp_Main(Exp_Basic):
    def __init__(self, args):
        """Set up the experiment; all initialization is delegated to ``Exp_Basic``."""
        super().__init__(args)

         

    def _build_model(self):
        """Instantiate the model named by ``self.args.model`` in float32.

        Wraps it in ``nn.DataParallel`` when both ``use_multi_gpu`` and
        ``use_gpu`` are set. An unknown model name raises ``KeyError``.
        """
        registry = {
            'Autoformer': Autoformer,
            'Transformer': Transformer,
            'Informer': Informer,
            'DLinear': DLinear,
            'NLinear': NLinear,
            'Linear': Linear,
            'PatchTST': PatchTST,
        }
        net = registry[self.args.model].Model(self.args).float()
        if not (self.args.use_multi_gpu and self.args.use_gpu):
            return net
        return nn.DataParallel(net, device_ids=self.args.device_ids)

    def _get_data(self, flag):
        """Return ``(data_set, loader_a_share, loader_us)`` for the split ``flag``.

        Thin wrapper over ``data_provider``. Callers in this class unpack the two
        loaders as the A-share loader followed by the US loader.
        """
        data_set, first_loader, second_loader = data_provider(self.args, flag)
        return data_set, first_loader, second_loader

    def _select_optimizer(self):
        """Return an Adam optimizer over all model parameters at ``args.learning_rate``."""
        return optim.Adam(self.model.parameters(), lr=self.args.learning_rate)

    def _select_criterion(self):
        """Return the mean-squared-error loss used for training and evaluation."""
        return nn.MSELoss()

    # def vali(self, vali_data, vali_loader_a_share,vali_loader_btc,vali_loader_us, criterion):
    def vali(self, vali_data, vali_loader_a_share, vali_loader_us, criterion):
        """Return the mean loss of ``self.model`` over one evaluation split.

        US and A-share batches are consumed pairwise (``zip`` stops at the
        shorter loader) and concatenated along the batch axis, US rows first.
        The model runs in eval mode under ``torch.no_grad()`` and is switched
        back to train mode afterwards.

        Args:
            vali_data: Dataset object; unused, kept for interface compatibility.
            vali_loader_a_share: A-share loader whose batches hold
                ``(x, y, x_mark, y_mark)`` at indices 0-3.
            vali_loader_us: US loader with the same batch layout.
            criterion: Loss called as ``criterion(pred, true)`` on CPU tensors.

        Returns:
            ``np.average`` of the per-batch losses.
        """
        f_dim = -1 if self.args.features == 'MS' else 0
        pred_len = self.args.pred_len

        def forward(batch_x, batch_x_mark, dec_inp, batch_y_mark):
            # Dispatch on model family; the encoder-decoder models return a 3-tuple.
            if 'Linear' in self.args.model or 'TST' in self.args.model:
                return self.model(batch_x)
            if self.args.output_attention:
                return self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)[0]
            outputs, _, _ = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)
            return outputs

        total_loss = []
        self.model.eval()
        with torch.no_grad():
            for us_batch, a_share_batch in zip(vali_loader_us, vali_loader_a_share):
                batch_x = torch.cat([us_batch[0], a_share_batch[0]], dim=0).float().to(self.device)
                batch_y = torch.cat([us_batch[1], a_share_batch[1]], dim=0).float().to(self.device)
                batch_x_mark = torch.cat([us_batch[2], a_share_batch[2]], dim=0).float().to(self.device)
                batch_y_mark = torch.cat([us_batch[3], a_share_batch[3]], dim=0).float().to(self.device)

                # decoder input: known label_len prefix followed by pred_len zeros
                dec_inp = torch.zeros_like(batch_y[:, -pred_len:, :]).float()
                dec_inp = torch.cat([batch_y[:, :self.args.label_len, :], dec_inp], dim=1).float().to(self.device)

                if self.args.use_amp:
                    with torch.cuda.amp.autocast():
                        outputs = forward(batch_x, batch_x_mark, dec_inp, batch_y_mark)
                else:
                    outputs = forward(batch_x, batch_x_mark, dec_inp, batch_y_mark)

                pred = outputs[:, -pred_len:, f_dim:].detach().cpu()
                true = batch_y[:, -pred_len:, f_dim:].detach().cpu()
                # Store plain floats; appending tensors relied on NumPy converting
                # a list of 0-d tensors inside np.average.
                total_loss.append(criterion(pred, true).item())
        total_loss = np.average(total_loss)
        self.model.train()
        return total_loss

    def train(self, setting):
        """Train ``self.model`` on paired US / A-share batches and return the best model.

        Each step concatenates one US batch and one A-share batch along the
        batch axis (US first). After every epoch the model is evaluated on the
        validation and test splits with ``vali``. ``EarlyStopping`` is given the
        checkpoint directory ``<args.checkpoints>/<setting>``, and the
        ``checkpoint.pth`` stored there is reloaded before returning.

        Args:
            setting: Experiment name, used as the checkpoint sub-directory.

        Returns:
            ``self.model`` with the weights from ``checkpoint.pth`` loaded.
        """
        train_data, train_loader_a_share, train_loader_us = self._get_data(flag='train')
        vali_data, vali_loader_a_share, vali_loader_us = self._get_data(flag='val')
        test_data, test_loader_a_share, test_loader_us = self._get_data(flag='test')

        path = os.path.join(self.args.checkpoints, setting)
        os.makedirs(path, exist_ok=True)

        time_now = time.time()

        # zip() below stops at the shorter loader, so that is the real number of
        # optimizer steps per epoch; OneCycleLR must be sized with it.
        train_steps = min(len(train_loader_a_share), len(train_loader_us))
        early_stopping = EarlyStopping(patience=self.args.patience, verbose=True)

        model_optim = self._select_optimizer()
        criterion = self._select_criterion()

        if self.args.use_amp:
            scaler = torch.cuda.amp.GradScaler()

        scheduler = lr_scheduler.OneCycleLR(optimizer=model_optim,
                                            steps_per_epoch=train_steps,
                                            pct_start=self.args.pct_start,
                                            epochs=self.args.train_epochs,
                                            max_lr=self.args.learning_rate)

        f_dim = -1 if self.args.features == 'MS' else 0
        is_linear_like = 'Linear' in self.args.model or 'TST' in self.args.model

        for epoch in range(self.args.train_epochs):
            iter_count = 0
            train_loss = []

            self.model.train()
            epoch_time = time.time()
            for i, (us_batch, a_share_batch) in enumerate(zip(train_loader_us, train_loader_a_share)):
                # Batch layout: (x, y, x_mark, y_mark, ...); US rows come first.
                batch_x = torch.cat([us_batch[0], a_share_batch[0]], dim=0)
                batch_y = torch.cat([us_batch[1], a_share_batch[1]], dim=0)
                batch_x_mark = torch.cat([us_batch[2], a_share_batch[2]], dim=0)
                batch_y_mark = torch.cat([us_batch[3], a_share_batch[3]], dim=0)

                iter_count += 1
                model_optim.zero_grad()
                batch_x = batch_x.float().to(self.device)
                batch_y = batch_y.float().to(self.device)
                batch_x_mark = batch_x_mark.float().to(self.device)
                batch_y_mark = batch_y_mark.float().to(self.device)

                # decoder input: known label_len prefix followed by pred_len zeros
                dec_inp = torch.zeros_like(batch_y[:, -self.args.pred_len:, :]).float()
                dec_inp = torch.cat([batch_y[:, :self.args.label_len, :], dec_inp], dim=1).float().to(self.device)

                if self.args.use_amp:
                    with torch.cuda.amp.autocast():
                        if is_linear_like:
                            outputs = self.model(batch_x)
                        elif self.args.output_attention:
                            outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)[0]
                        else:
                            outputs, _, _ = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)
                        outputs = outputs[:, -self.args.pred_len:, f_dim:]
                        batch_y = batch_y[:, -self.args.pred_len:, f_dim:]
                        loss = criterion(outputs, batch_y)
                else:
                    if is_linear_like:
                        outputs = self.model(batch_x)
                    elif self.args.output_attention:
                        outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)[0]
                    else:
                        # NOTE(review): only this path passes batch_y (which still
                        # contains the pred_len target steps) to the model; the AMP
                        # path and vali() do not. Confirm the model cannot use it to
                        # see the targets during training.
                        outputs, _, _ = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark, batch_y)
                    outputs = outputs[:, -self.args.pred_len:, f_dim:]
                    batch_y = batch_y[:, -self.args.pred_len:, f_dim:]
                    loss = criterion(outputs, batch_y)
                train_loss.append(loss.item())

                if (i + 1) % 100 == 0:
                    print("\titers: {0}, epoch: {1} | loss: {2:.7f}".format(i + 1, epoch + 1, loss.item()),
                          file=sys.stderr, flush=True)
                    speed = (time.time() - time_now) / iter_count
                    left_time = speed * ((self.args.train_epochs - epoch) * train_steps - i)
                    print('\tspeed: {:.4f}s/iter; left time: {:.4f}s'.format(speed, left_time),
                          file=sys.stderr, flush=True)
                    iter_count = 0
                    time_now = time.time()

                if self.args.use_amp:
                    scaler.scale(loss).backward()
                    scaler.step(model_optim)
                    scaler.update()
                else:
                    loss.backward()
                    model_optim.step()

                if self.args.lradj == 'TST':
                    adjust_learning_rate(model_optim, scheduler, epoch + 1, self.args, printout=False)
                    scheduler.step()

            print("Epoch: {} cost time: {}".format(epoch + 1, time.time() - epoch_time),
                  file=sys.stderr, flush=True)
            train_loss = np.average(train_loss)
            vali_loss = self.vali(vali_data, vali_loader_a_share, vali_loader_us, criterion)
            test_loss = self.vali(test_data, test_loader_a_share, test_loader_us, criterion)

            print("Epoch: {0}, Steps: {1} | Train Loss: {2:.7f} Vali Loss: {3:.7f} Test Loss: {4:.7f}".format(
                epoch + 1, train_steps, train_loss, vali_loss, test_loss), file=sys.stderr, flush=True)
            early_stopping(vali_loss, self.model, path)
            if early_stopping.early_stop:
                print("Early stopping", file=sys.stderr, flush=True)
                break

            if self.args.lradj != 'TST':
                adjust_learning_rate(model_optim, scheduler, epoch + 1, self.args)
            else:
                print('Updating learning rate to {}'.format(scheduler.get_last_lr()[0]),
                      file=sys.stderr, flush=True)

        best_model_path = os.path.join(path, 'checkpoint.pth')
        self.model.load_state_dict(torch.load(best_model_path))

        return self.model
    

    def news_cond(self, us_news, a_share_news, H):
        """Build conditional and unconditional news embeddings for a mixed US/A-share batch.

        Rows of ``H`` are expected in the order used elsewhere in this class:
        US rows first, then A-share rows. Each group's news tokens are
        zero-padded to a common length and masked. They are then filtered by
        ``self.cross_attn`` against ``H`` and summarized by integrating a neural
        CDE (``self.g_phi``) along the filtered token sequence.

        Args:
            us_news: News tensor for the US rows, ``[B_us, N_us, D_news]``.
            a_share_news: News tensor for the A-share rows, ``[B_a, N_a, D_news]``.
            H: States passed to ``self.cross_attn``; ``H.shape[0]`` is assumed to
                equal ``B_us + B_a`` (TODO confirm for all backbones).

        Returns:
            ``(cond_emb, uncond_emb)``. ``cond_emb`` is the final CDE state and
            ``uncond_emb`` is ``self.null_news`` broadcast to the batch.
        """
        batch = H.shape[0]
        # Split between US and A-share rows. Using H.shape[0] // 2 only works when
        # both groups have equal batch sizes (it breaks e.g. on a short last batch
        # from one loader); the US batch size is the true boundary.
        n_us = us_news.shape[0]
        n_max = max(us_news.shape[1], a_share_news.shape[1])
        news_all = torch.zeros(batch, n_max, us_news.shape[-1], device=H.device)
        news_mask = torch.zeros(batch, n_max, dtype=torch.bool, device=H.device)

        news_all[:n_us, :us_news.shape[1], :] = us_news
        news_mask[:n_us, :us_news.shape[1]] = True
        news_all[n_us:, :a_share_news.shape[1], :] = a_share_news
        news_mask[n_us:, :a_share_news.shape[1]] = True

        e_prime, _gates = self.cross_attn(H, news_all, news_mask)
        B, N, _ = e_prime.shape

        # Driving path for the CDE: a time channel in [0, 1] followed by the features.
        t = torch.linspace(0, 1, N, device=e_prime.device, dtype=e_prime.dtype)
        t_channel = t.unsqueeze(0).expand(B, N).unsqueeze(-1)  # [B, N, 1]
        X = torch.cat([t_channel, e_prime], dim=2)
        Xcde = torchcde.LinearInterpolation(X)

        z0 = self.z0_initializer(e_prime.mean(dim=1))  # initial state from mean token
        z_T = torchcde.cdeint(X=Xcde, func=self.g_phi, z0=z0, t=t)
        cond_emb = z_T[:, -1, :]  # state at the last time point
        uncond_emb = self.null_news.unsqueeze(0).expand(B, -1)

        return cond_emb, uncond_emb


    def ode_preocess(self, x_t, v_t, uncond_emb, cond_emb, B, w, D_orig, steps=20, use_cond=False):
        """Integrate the second-order flow ODE from t=0 to t=1 with explicit Euler steps.

        At each step, ``self.flow_network`` is evaluated on the stacked state
        ``[x_t, v_t]`` for two copies of the batch. The outputs are combined with
        classifier-free guidance, ``u = u_uncond + w * (u_cond - u_uncond)``.
        The acceleration half of ``u`` (the chunk after the first ``D_orig``
        features) then drives the update::

            x_{k+1} = x_k + v_k * dt
            v_{k+1} = v_k + a_k * dt

        Args:
            x_t: Initial position; presumably ``[B, S, D_orig]`` (TODO confirm).
            v_t: Initial velocity, same shape as ``x_t``.
            uncond_emb: Unconditional embedding fed to the first copy.
            cond_emb: News embedding; only used when ``use_cond`` is True.
            B: Batch size, used to split the doubled network output.
            w: Guidance scale.
            D_orig: Feature width; the network output is split into chunks of
                this size along the last axis.
            steps: Number of Euler steps (default 20, the previous hard-coded value).
            use_cond: If False (default, the previous behaviour) the second copy is
                also fed ``uncond_emb``, so ``u_cond == u_uncond``, guidance has no
                effect and ``cond_emb`` is ignored. Set True to apply real guidance.

        Returns:
            The final position ``x_t`` after ``steps`` steps.
        """
        dt = 1.0 / steps
        # NOTE(review): with use_cond=False the news embedding never influences
        # the result; confirm this ablation is intended.
        second_emb = cond_emb if use_cond else uncond_emb
        cond_dup = torch.cat([uncond_emb, second_emb], dim=0)  # [2B, D_cond]

        for i in range(steps):
            t_batch = torch.full((x_t.shape[0],), i / steps, device=self.device)
            t_dup = torch.cat([t_batch, t_batch], dim=0)  # [2B]

            S_t = torch.cat([x_t, v_t], dim=-1)
            S_dup = torch.cat([S_t, S_t], dim=0)
            ut_uncond, ut_cond = self.flow_network(S_dup, t_dup, cond_dup).split(B, dim=0)

            U_guided = ut_uncond + w * (ut_cond - ut_uncond)
            # Only the acceleration chunk is integrated; the first chunk is unused.
            _, U_v_guided = U_guided.split(D_orig, dim=-1)

            x_t = x_t + v_t * dt
            v_t = v_t + U_v_guided * dt

        return x_t
        
    def flow_predict(self, us_x, us_y, us_time_x, us_time_y, us_news,
                     a_share_x, a_share_y, a_share_time_x, a_share_time_y, a_share_news):
        """Run the backbone on a US + A-share batch and build flow-matching inputs.

        Returns:
            ``(flow_target, true, cond_emb, uncond_emb)``:
            - ``flow_target`` is the ground truth minus the backbone's trend
              component over the last ``pred_len`` steps.
            - ``true`` is that ground truth, detached and on CPU.
            - The two embeddings come from ``self.news_cond``.
        """
        device = self.device
        pairs = (
            (us_x, a_share_x),
            (us_y, a_share_y),
            (us_time_x, a_share_time_x),
            (us_time_y, a_share_time_y),
        )
        # US rows first, then A-share rows.
        batch_x, batch_y, batch_x_mark, batch_y_mark = [
            torch.cat(pair, dim=0).float().to(device) for pair in pairs
        ]

        pred_len = self.args.pred_len
        label_len = self.args.label_len
        # decoder input: known label_len prefix followed by pred_len zeros
        placeholder = torch.zeros_like(batch_y[:, -pred_len:, :]).float()
        dec_inp = torch.cat([batch_y[:, :label_len, :], placeholder], dim=1).float().to(device)

        model_out = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)
        if self.args.output_attention:
            # NOTE(review): [0] keeps only the first returned item, which is then
            # unpacked into three names along its first axis. This looks wrong
            # unless the model nests (output, trend, H) inside element 0; confirm.
            model_out = model_out[0]
        _output, trend_part, H = model_out

        horizon_y = batch_y[:, -pred_len:, :]
        true = horizon_y.detach().cpu()

        cond_emb, uncond_emb = self.news_cond(us_news, a_share_news, H)

        flow_target = horizon_y - trend_part[:, -pred_len:, :]
        return flow_target, true, cond_emb, uncond_emb
    
    
    def flow_inversion(self, flow_target, S_max, D_orig, cond_emb, uncond_emb, w, num_iterations=20, learning_rate=0.05):
        """Fit an initial state ``(x_0, v_0)`` whose ODE rollout reproduces ``flow_target``.

        ``x_0`` and ``v_0`` start as Gaussian noise with std 1e-2 and are
        optimized with Adam. The objective is the MSE between the first
        ``L_obs`` steps of the ``ode_preocess`` rollout and ``flow_target``, plus
        a 1e-3-weighted L2 penalty on ``x_0`` and ``v_0``. Latent gradients are
        clipped to total norm 1.0 at each step.

        Args:
            flow_target: Observed residual sequence ``[B, L_obs, D_orig]``,
                with ``L_obs <= S_max``.
            S_max: Length of the latent sequence being optimized.
            D_orig: Feature width of the data.
            cond_emb, uncond_emb, w: Forwarded to ``ode_preocess``. The embeddings
                are detached here, so no gradient reaches the modules that
                produced them.
            num_iterations: Number of Adam steps.
            learning_rate: Adam learning rate.

        Returns:
            ``mu_z = cat([x_0, v_0], dim=-1)``, detached, shape ``[B, S_max, 2 * D_orig]``.
        """
        B, L_obs, _ = flow_target.shape

        # Only the latents are optimized. Detaching the embeddings keeps every
        # iteration's graph self-contained, so retain_graph is not needed.
        cond_emb = cond_emb.detach()
        uncond_emb = uncond_emb.detach()

        sigma = 1e-2
        x_0 = (torch.randn(B, S_max, D_orig, device=self.device) * sigma).requires_grad_(True)
        v_0 = (torch.randn(B, S_max, D_orig, device=self.device) * sigma).requires_grad_(True)
        optimizer = optim.Adam([x_0, v_0], lr=learning_rate)

        lambda_reg = 0.001  # L2 regularization weight on the latents
        for _ in range(num_iterations):
            optimizer.zero_grad()

            X_pred = self.ode_preocess(x_0, v_0, uncond_emb, cond_emb, B, w, D_orig)
            # Compare only the observed prefix of the rollout.
            recon_loss = torch.mean((X_pred[:, :L_obs, :] - flow_target) ** 2)
            reg_loss = lambda_reg * (torch.mean(x_0 ** 2) + torch.mean(v_0 ** 2))
            loss = recon_loss + reg_loss

            loss.backward()
            torch.nn.utils.clip_grad_norm_([x_0, v_0], max_norm=1.0)
            optimizer.step()

        return torch.cat([x_0.detach(), v_0.detach()], dim=-1)

    def vali_inf(self, vali_data, vali_loader_a_share, vali_loader_us, criterion, path):
        """Evaluate the backbone and the flow-refined prediction on at most 201 batch pairs.

        Loads the flow-network, cross-attention and null-embedding checkpoints
        from ``path`` and freezes ``self.model``'s parameters (they stay frozen
        afterwards). For each US/A-share batch pair it computes three losses
        against the last ``pred_len`` ground-truth steps:

        - ``original_loss``: the backbone output.
        - ``trend_loss``: the backbone trend component.
        - ``total_loss``: the backbone output refined by ``ode_preocess``,
          starting from the per-sample standardized output with zero initial
          velocity, then de-standardized.

        Args:
            vali_data: Unused; kept for interface compatibility.
            vali_loader_a_share: A-share loader. Batches hold
                ``(x, y, x_mark, y_mark, news)`` at indices 0-4. A second group
                at indices 5-9 is read only by the disabled flow-inversion branch.
            vali_loader_us: US loader with the same layout.
            criterion: Loss applied to CPU tensors.
            path: Directory holding the ``checkpoint_*_s0.pth`` files.

        Returns:
            ``(total_loss, original_loss, trend_loss)``, each averaged over the
            evaluated batches.
        """
        total_loss = []
        loss_original = []
        loss_trend = []
        # The backbone is frozen from here on; nothing below re-enables its gradients.
        for param in self.model.parameters():
            param.requires_grad = False

        self.flow_network.load_state_dict(
            torch.load(os.path.join(path, 'checkpoint_flow_s0.pth'), map_location=self.device))
        self.cross_attn.load_state_dict(
            torch.load(os.path.join(path, 'checkpoint_cross_s0.pth'), map_location=self.device))
        # Copy into the existing Parameter in place. Rebinding self.null_news via
        # .to() could replace the nn.Parameter with a plain tensor, detaching it
        # from any optimizer that holds the original object.
        null_state = torch.load(os.path.join(path, 'checkpoint_null_s0.pth'), map_location=self.device)
        with torch.no_grad():
            self.null_news.copy_(null_state)
        # NOTE(review): g_phi weights are not restored here (that load was
        # commented out), so g_phi keeps whatever weights it currently has.

        self.model.eval()
        self.flow_network.eval()
        self.cross_attn.eval()
        self.g_phi.eval()

        max_batch_index = 200  # batches 0..200 are evaluated, i.e. at most 201
        pred_len = self.args.pred_len
        label_len = self.args.label_len
        w = 1.0  # guidance scale; other values tried: 0.1, 1.5, 2.0
        use_flow_inversion = False  # True: start the ODE from flow_inversion() instead

        for batch_idx, (us_batch, a_share_batch) in enumerate(zip(vali_loader_us, vali_loader_a_share)):
            if batch_idx > max_batch_index:
                break

            us_x, us_y, us_time_x, us_time_y, us_news = us_batch[:5]
            a_share_x, a_share_y, a_share_time_x, a_share_time_y, a_share_news = a_share_batch[:5]

            with torch.no_grad():
                # US rows first, then A-share rows.
                batch_x = torch.cat([us_x, a_share_x], dim=0).float().to(self.device)
                batch_y = torch.cat([us_y, a_share_y], dim=0).float().to(self.device)
                batch_x_mark = torch.cat([us_time_x, a_share_time_x], dim=0).float().to(self.device)
                batch_y_mark = torch.cat([us_time_y, a_share_time_y], dim=0).float().to(self.device)

                # decoder input: known label_len prefix followed by pred_len zeros
                dec_inp = torch.zeros_like(batch_y[:, -pred_len:, :]).float()
                dec_inp = torch.cat([batch_y[:, :label_len, :], dec_inp], dim=1).float().to(self.device)

                model_out = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)
                if self.args.output_attention:
                    # NOTE(review): [0] keeps only the first returned item, which is
                    # then unpacked into three names along its first axis. This looks
                    # wrong unless the model nests (output, trend, H) in element 0.
                    model_out = model_out[0]
                output, trend_part, H = model_out

                output_ = output[:, -pred_len:, :]
                trend_part = trend_part[:, -pred_len:, :]
                true = batch_y[:, -pred_len:, :].detach().cpu()

                loss_original.append(criterion(output_.detach().cpu(), true).item())
                loss_trend.append(criterion(trend_part.detach().cpu(), true).item())

                cond_emb, uncond_emb = self.news_cond(us_news, a_share_news, H)

                # Start the ODE from the backbone output, standardized per sample
                # and feature over the time axis, with zero initial velocity.
                mean = output_.mean(dim=1, keepdim=True)
                std = torch.sqrt(torch.var(output_, dim=1, keepdim=True) + 1e-5)
                x_t = (output_ - mean) / std
                v_t = torch.zeros_like(x_t)
                B, S, D_orig = output_.shape

            if use_flow_inversion:
                # Needs gradients, so it runs outside no_grad. It uses the second
                # batch group stored at indices 5-9.
                flow_target_flow, _, cond_emb_flow, uncond_emb_flow = self.flow_predict(
                    *us_batch[5:10], *a_share_batch[5:10])
                mu_z = self.flow_inversion(flow_target_flow, S, D_orig, cond_emb_flow, uncond_emb_flow, w)
                x_t, v_t = mu_z.clone().detach().split(D_orig, dim=-1)

            with torch.no_grad():
                x_t_final = self.ode_preocess(x_t, v_t, uncond_emb, cond_emb, B, w, D_orig)
                # Undo the standardization before scoring.
                pred = (x_t_final * std + mean).detach().cpu()
            total_loss.append(criterion(pred, true).item())

        total_loss = np.mean(total_loss)
        original_loss = np.mean(loss_original)
        trend_loss = np.mean(loss_trend)

        self.model.train()
        self.flow_network.train()
        self.cross_attn.train()
        self.g_phi.train()
        print('total_loss', 'original_loss', 'trend_loss',total_loss, original_loss, trend_loss,"1111", file=sys.stderr, flush=True)
        return total_loss, original_loss, trend_loss

    def inference(self, setting):
        """
        Train the news-conditioned flow-matching head on top of a frozen backbone.

        Pipeline (author's outline):
        1. The Autoformer-style backbone ``self.model`` predicts the trend and is frozen.
        2. An OT-CFM flow network predicts an acceleration ("shock").
        3. The two are fused through a constant-acceleration kinematic equation.

        Concretely this method:
        * loads ``<args.checkpoints>/<setting>/checkpoint.pth`` into ``self.model``
          and sets ``requires_grad = False`` on all of its parameters;
        * creates ``self.flow_network``, ``self.ot_matcher``, ``self.cross_attn``,
          ``self.g_phi``, ``self.null_news`` and ``self.z0_initializer`` and trains
          all of them except ``self.ot_matcher`` with AdamW + OneCycleLR;
        * iterates the US and A-share training loaders in lockstep (``zip``);
          every 20 steps it saves checkpoints and runs ``self.vali_inf`` on the
          validation and test splits, and each epoch stops after the sixth such
          evaluation;
        * saves the flow / cross-attention / null-news checkpoints
          (``checkpoint_*_s0.pth``) after every epoch and once more at the end.

        Args:
            setting: experiment name, joined onto ``self.args.checkpoints`` to
                locate the backbone checkpoint and the output directory.

        Returns:
            ``self.flow_network`` after training.
        """
        # train_data, train_loader_a_share,train_loader_btc,train_loader_us = self.(flag='train')
        # vali_data,  vali_loader_a_share,vali_loader_btc,vali_loader_us= self._get_data(flag='val')
        # test_data,  test_loader_a_share,test_loader_btc,test_loader_us = self._get_data(flag='test')
        train_data, train_loader_a_share,train_loader_us = self._get_data(flag='train')
        vali_data,  vali_loader_a_share,vali_loader_us= self._get_data(flag='val')
        test_data,  test_loader_a_share,test_loader_us = self._get_data(flag='test')
        criterion = self._select_criterion()

        # Load the pretrained backbone (frozen further below). map_location lets a
        # checkpoint saved on GPU be loaded on whatever device this run uses.
        path_pre = os.path.join(self.args.checkpoints, setting)
        best_model_path = path_pre + '/' + 'checkpoint.pth'
        self.model.load_state_dict(torch.load(best_model_path, map_location=self.device))


        path = os.path.join(self.args.checkpoints, setting)
        if not os.path.exists(path):
            os.makedirs(path)

        time_now = time.time()

        train_steps = len(train_loader_a_share)
        # NOTE(review): early_stopping is only referenced by commented-out code below.
        early_stopping = EarlyStopping(patience=self.args.patience, verbose=True)

        self.news_dim = 768
        hidden_dim = 128
        self.flow_network = TemporalForceNet(state_dim=self.args.enc_in,news_dim=hidden_dim).to(self.device)
        # NOTE(review): ot_matcher is created but neither used nor optimized in this method.
        self.ot_matcher = AccelerationOTFlowMatcher()# .to(self.device)
        self.cross_attn = UltraFastCrossAttentionFilter(self.args.d_model,self.news_dim).to(self.device)
        self.g_phi = GPhi(hidden_dim, self.news_dim+1).to(self.device)

        # Learned "no news" embedding: used for the unconditional branch and as
        # the replacement value when a sample's news condition is dropped.
        self.null_news = nn.Parameter(torch.randn(hidden_dim, device=self.device))

        # best_model_path_flow = path_pre + '/' + 'checkpoint_flow_s0.pth'
        # self.flow_network.load_state_dict(torch.load(best_model_path_flow))
        # best_model_path_cross = path_pre + '/' + 'checkpoint_cross_s0.pth'
        # self.cross_attn.load_state_dict(torch.load(best_model_path_cross))
        # best_model_path_null = path_pre + '/' + 'checkpoint_null_s0.pth'
        # self.null_news.data = torch.load(best_model_path_null)
        # self.null_news = self.null_news.to(self.device)

        # Maps the mean cross-attended news feature to the initial CDE state z0.
        # NOTE(review): in_features is hidden_dim, so this assumes the
        # cross-attention output width D_ equals 128 — confirm.
        self.z0_initializer = nn.Linear(hidden_dim, hidden_dim).to(self.device)

        # Probability of replacing a sample's news embedding with null_news.
        self.p_drop = 0.1
        model_params = list(self.flow_network.parameters()) + list(self.cross_attn.parameters())\
            + [self.null_news] + list(self.g_phi.parameters()) + list(self.z0_initializer.parameters())

        WEIGHT_DECAY = 1e-3
        model_optim = optim.AdamW(
            model_params,
            lr=self.args.learning_rate,
            weight_decay=WEIGHT_DECAY
        )

        # Freeze the backbone: only the flow-matching head is trained.
        for param in self.model.parameters():
            param.requires_grad = False

        if self.args.use_amp:
            # NOTE(review): no autocast region is used below and the AMP branch
            # skips gradient clipping — confirm AMP is meant to be supported here.
            scaler = torch.cuda.amp.GradScaler()

        # zip() stops at the shorter loader, so this is the real per-epoch step count.
        iters_per_epoch = min(len(train_loader_us), len(train_loader_a_share))
        self.model.eval()
        scheduler = lr_scheduler.OneCycleLR(optimizer=model_optim, steps_per_epoch=iters_per_epoch, pct_start=self.args.pct_start, epochs=self.args.train_epochs, max_lr=self.args.learning_rate)
        # print("1111", file=sys.stderr, flush=True)
        # vali_loss,original_loss,trend_loss = self.vali_inf(vali_data, vali_loader_a_share,vali_loader_us, criterion,path)
        # print("22222", file=sys.stderr, flush=True)
        # test_loss,original_loss,trend_loss = self.vali_inf(test_data, test_loader_a_share,test_loader_us, criterion,path)
        # print("3333", file=sys.stderr, flush=True)
        # exit()
        for epoch in range(self.args.train_epochs):
            iter_count = 0
            train_loss = []

            epoch_time = time.time()
            i=0

            ix_i = 0
            # Checkpoint + evaluate every `ix` steps; the epoch ends once more
            # than 5 such evaluations have run (count > 5).
            ix = 20
            count=0
            # for i, (batch_x, batch_y, batch_x_mark, batch_y_mark) in enumerate(train_loader_a_share):
            for (us_batch, a_share_batch) in zip(train_loader_us, train_loader_a_share):
                print(ix_i,"11111", file=sys.stderr, flush=True)
                if ix_i%ix==0 and ix_i!=0:
                    print("Epoch: {} cost time: {}".format(epoch + 1, time.time() - epoch_time)\
                    , file=sys.stderr, flush=True)
                    train_loss_ = np.average(train_loss)
                    best_model_path_flow = path + '/' + 'checkpoint_flow_s0.pth'
                    torch.save(self.flow_network.state_dict(), best_model_path_flow)
                    best_model_path_cross = path + '/' + 'checkpoint_cross_s0.pth'
                    torch.save(self.cross_attn.state_dict(), best_model_path_cross)
                    best_model_path_null = path + '/' + 'checkpoint_null_s0.pth'
                    torch.save(self.null_news.data,best_model_path_null)
                    # best_model_path_gphi = path + '/' + 'checkpoint_gphi_s0.pth'
                    # torch.save(self.g_phi.state_dict(), best_model_path_gphi)

                    vali_loss,original_loss,trend_loss = self.vali_inf(vali_data, vali_loader_a_share,vali_loader_us, criterion,path)
                    test_loss,original_loss,trend_loss = self.vali_inf(test_data, test_loader_a_share,test_loader_us, criterion,path)
                    # vali_inf switches self.model back to train mode on exit;
                    # restore eval mode on the frozen backbone (as set before
                    # the epoch loop) so dropout etc. stay off while training.
                    self.model.eval()


                    print("Epoch: {0}, Steps: {1} | Train Loss: {2:.7f} Vali Loss: {3:.7f} Test Loss: {4:.7f}".format(
                        epoch + 1, train_steps, train_loss_, vali_loss, test_loss)\
                            , file=sys.stderr, flush=True)
                    # break
                    count+=1
                ix_i+=1
                if count>5:
                    break
                # crypto for zero-shot
                # crypto_x = pad_feature_middle(crypto_batch[0])
                # crypto_y = pad_feature_middle(crypto_batch[1])
                # crypto_time_x = crypto_batch[2]
                # crypto_time_y = crypto_batch[3]

                us_x = us_batch[0]
                us_y = us_batch[1]
                us_time_x = us_batch[2]
                us_time_y = us_batch[3]
                us_news = us_batch[4]
                a_share_x = a_share_batch[0]
                a_share_y = a_share_batch[1]
                a_share_time_x = a_share_batch[2]
                a_share_time_y = a_share_batch[3]
                a_share_news = a_share_batch[4]


                # Joint batch: US samples first, then A-share samples, along dim 0.
                batch_x = torch.cat([us_x, a_share_x], dim=0)
                batch_y = torch.cat([us_y, a_share_y], dim=0)
                batch_x_mark = torch.cat([us_time_x, a_share_time_x], dim=0)
                batch_y_mark = torch.cat([us_time_y, a_share_time_y], dim=0)
                # import sys
                # print(crypto_batch[0].shape, file=sys.stderr, flush=True)
                # print("222222", file=sys.stderr, flush=True)

                iter_count += 1
                model_optim.zero_grad()
                batch_x = batch_x.float().to(self.device)

                batch_y = batch_y.float().to(self.device)
                batch_x_mark = batch_x_mark.float().to(self.device)
                batch_y_mark = batch_y_mark.float().to(self.device)

                # sample_stds = batch_y.std(dim=(1, 2)) # [B]
                # print(f"Batch Std Variation: {sample_stds.max() / (sample_stds.min() + 1e-6):.2f}",file=sys.stderr, flush=True)


                # decoder input: label_len known steps followed by pred_len zeros
                dec_inp = torch.zeros_like(batch_y[:, -self.args.pred_len:, :]).float()
                dec_inp = torch.cat([batch_y[:, :self.args.label_len, :], dec_inp], dim=1).float().to(self.device)

                # Frozen backbone; its result is unpacked as (output, trend_part, H),
                # and H is used below as the query side of the news cross-attention.
                if self.args.output_attention:
                    output,trend_part,H = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)[0]
                else:
                    output,trend_part,H = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark, batch_y)
                # torch.Size([32, 7198, 768]) 1111
                # torch.Size([32, 2045, 768]) 1111
                # Backbone losses on the horizon; only read by the commented-out
                # prints (loss_1 / loss_2 are reassigned further down).
                output_1 = output[:, -self.args.pred_len:, :]
                batch_y_1 = batch_y[:, -self.args.pred_len:, :].to(self.device)
                loss_1 = criterion(output_1, batch_y_1)
                # print("original_loss",loss_1.item(),file=sys.stderr, flush=True)
                output_2 = trend_part[:, -self.args.pred_len:, :]
                loss_2 = criterion(output_2, batch_y_1)
                # print("trend_loss",loss_2.item(),file=sys.stderr, flush=True)

                # news cross attention
                # Pad both markets' news to a common length N_max; rows [0, B/2)
                # hold US news and rows [B/2, B) hold A-share news.
                # NOTE(review): this assumes the US and A-share batches have the
                # same size (e.g. a short final batch on one side would misalign).
                N_max = max(us_news.shape[1],a_share_news.shape[1])
                news_all = torch.zeros(H.shape[0], N_max, us_news.shape[-1]).to(H.device)
                news_mask = torch.zeros(H.shape[0], N_max, dtype=torch.bool).to(H.device)

                news_all[:H.shape[0]//2, :us_news.shape[1], :] = us_news
                news_mask[:H.shape[0]//2, :us_news.shape[1]] = True

                news_all[H.shape[0]//2:, :a_share_news.shape[1], :] = a_share_news
                news_mask[H.shape[0]//2:, :a_share_news.shape[1]] = True
                # TODO: consider keeping a memory of past news for long-horizon training.
                e_prime, gates = self.cross_attn(H,news_all,news_mask)

                # Build the time-augmented path [t, e_prime] over t in [0, 1] and
                # integrate the CDE with vector field g_phi along it.
                B_, N_, D_ = e_prime.shape
                t = torch.linspace(0, 1, N_, device=e_prime.device, dtype=e_prime.dtype)
                t_expanded = t.unsqueeze(0).expand(B_, N_).unsqueeze(-1) # [B, N, 1]

                X = torch.cat([t_expanded, e_prime], dim=2)
                Xcde = torchcde.LinearInterpolation(X)
                hidden_dim = 128
                input_dim = D_ + 1

                e_prime_mean = e_prime.mean(dim=1)  # [B, D_]
                z0 = self.z0_initializer(e_prime_mean) # [B, hidden_dim]
                z_T = torchcde.cdeint(X=Xcde, func=self.g_phi, z0=z0, t=t)

                # The CDE state at the final time is the per-sample news embedding.
                news_emb = z_T[:, -1, :]

                # news_emb = F.normalize(news_emb, p=2, dim=1)

                trend_part = trend_part[:, -self.args.pred_len:, :]
                batch_y = batch_y[:, -self.args.pred_len:, :].to(self.device)
                # Part of the target the trend component does not explain.
                flow_target = batch_y - trend_part

                output_ = output[:, -self.args.pred_len:, :]
                B, Seq, D = flow_target.shape
                # --- Training phase ---

                # 1. Basic kinematic quantities: start at the backbone forecast X0
                #    with zero velocity, end at the ground truth X1.
                X0 = output_
                X1 = batch_y
                V0 = torch.zeros_like(X0)                 # initial velocity is 0

                # Constant acceleration over unit time: X1 = X0 + V0 + a/2,
                # hence a = 2 * (X1 - X0 - V0).
                a_target = 2 * (X1 - X0 - V0)
                # residual_ratio = a_target.abs().mean() / X0.abs().mean()
                # print(f"Residual/Signal Ratio: {residual_ratio.item():.6f}",file=sys.stderr, flush=True)
                # print(a_target.abs().mean(),a_target.abs().min(),a_target.abs().max(),file=sys.stderr, flush=True)
                # exit()
                # 3. Terminal velocity V1 and terminal state S1 = [X1, V1].
                V1 = V0 + a_target
                S0 = torch.cat([X0, V0], dim=-1)         # [B, S, 2D]
                S1_target = torch.cat([X1, V1], dim=-1)   # [B, S, 2D]

                # 4. Sample t ~ U(0, 1) per sample and linearly interpolate the
                #    state to get the training input S_t.
                t_ = torch.rand(B, device=self.device)
                t = t_.view(B, 1, 1)
                S_t = (1 - t) * S0 + t * S1_target

                # Duplicate the batch: first half unconditional, second half conditional.
                S_t_dup = torch.cat([S_t, S_t], dim=0)
                t_dup = torch.cat([t_, t_], dim=0)
                # Condition dropout: with probability p_drop a sample's news
                # embedding is replaced by null_news.
                drop_mask = (torch.rand(news_emb.shape[0], device=self.device) < self.p_drop)

                cond_emb = news_emb.clone()
                if drop_mask.any():
                    cond_emb[drop_mask] = self.null_news.unsqueeze(0)

                uncond_emb = self.null_news.unsqueeze(0).expand(B_, -1)
                # NOTE(review): both halves currently use uncond_emb, so cond_emb
                # (and with it news_emb, g_phi and z0_initializer) never reaches
                # flow_network; the commented line is the conditioned variant.
                # cond_dup = torch.cat([uncond_emb, cond_emb], dim=0)
                cond_dup = torch.cat([uncond_emb, uncond_emb], dim=0)

                ut_dup = self.flow_network(S_t_dup, t_dup, cond_dup)

                # Each half presumably has 2*D channels (position part, then the
                # predicted acceleration / A_impact part) — split below.
                ut_uncond, ut_cond = ut_dup.split(B_, dim=0)

                D_orig = D

                ut_x_pred, ut_v_pred = ut_cond.split(D_orig, dim=-1)
                ut_x_pred_un, ut_v_pred_un = ut_uncond.split(D_orig, dim=-1)

                # Auxiliary loss: the magnitude of the predicted acceleration should
                # correlate positively with how strongly the news gates activate.
                # NOTE(review): torch.corrcoef returns NaN for a single sample or
                # zero-variance inputs, which would make the total loss NaN.
                gate_intensity = gates.mean(dim=1) # [B]
                acc_intensity = torch.norm(ut_v_pred, dim=(1,2)) # [B]
                loss_corr = -torch.corrcoef(torch.stack([gate_intensity, acc_intensity]))[0,1]


                pred_acceleration = ut_v_pred     # force supplied by the news branch

                # Cosine-similarity loss along the channel dimension: we want
                # cos(theta) -> 1, so loss = 1 - cos.
                cos_sim = torch.nn.functional.cosine_similarity(pred_acceleration, flow_target, dim=-1)
                loss_direction = (1.0 - cos_sim).mean()

                # Velocity half is regressed onto a_target (= V1 - V0).
                # NOTE(review): the position half is regressed onto S_t[..., :D]
                # (the current interpolated position) rather than X1 - X0 — confirm.
                E_v= criterion(ut_v_pred, a_target) # velocity-part residual fit
                E_v_un = criterion(ut_v_pred_un, a_target)
                E_x= criterion(ut_x_pred, S_t[:,:,:D]) # position-part residual fit
                E_x_un = criterion(ut_x_pred_un, S_t[:,:,:D])


                lambda_v = 1.0 # suggested starting value: 1.0
                loss_1 = 0.1*E_x + lambda_v * E_v
                loss_2 = 0.1* E_x_un + lambda_v * E_v_un

                # Small penalty on the mean gate activation.
                l1_penalty = 1e-4 * gates.mean()
                loss = 0.5*loss_1 + 0.5*loss_2 + 0.1*loss_corr +l1_penalty + loss_direction

                train_loss.append(loss.item())


                if (i + 1) % 100 == 0:
                    # print(loss,"1111", file=sys.stderr, flush=True)
                    print("\titers: {0}, epoch: {1} | loss: {2:.7f}".format(i + 1, epoch + 1, loss.item())\
                          , file=sys.stderr, flush=True)
                    speed = (time.time() - time_now) / iter_count
                    left_time = speed * ((self.args.train_epochs - epoch) * train_steps - i)
                    print('\tspeed: {:.4f}s/iter; left time: {:.4f}s'.format(speed, left_time)\
                          , file=sys.stderr, flush=True)
                    iter_count = 0
                    time_now = time.time()

                if self.args.use_amp:
                    scaler.scale(loss).backward()
                    scaler.step(model_optim)
                    scaler.update()
                else:
                    loss.backward()
                    # torch.nn.utils.clip_grad_norm_(self.g_phi.parameters(), max_norm=0.1)
                    # Only the flow network's gradients are clipped.
                    torch.nn.utils.clip_grad_norm_(self.flow_network.parameters(), max_norm=1.0)
                    model_optim.step()

                # if self.args.lradj == 'TST':
                adjust_learning_rate(model_optim, scheduler, epoch + 1, self.args, printout=False)
                scheduler.step()

                i+=1

            print("Epoch: {} cost time: {}".format(epoch + 1, time.time() - epoch_time)\
                  , file=sys.stderr, flush=True)
            train_loss = np.average(train_loss)
            best_model_path_flow = path + '/' + 'checkpoint_flow_s0.pth'
            torch.save(self.flow_network.state_dict(), best_model_path_flow)
            best_model_path_cross = path + '/' + 'checkpoint_cross_s0.pth'
            torch.save(self.cross_attn.state_dict(), best_model_path_cross)
            best_model_path_null = path + '/' + 'checkpoint_null_s0.pth'
            torch.save(self.null_news.data,best_model_path_null)
            # best_model_path_gphi= path + '/' + 'checkpoint_gphi_s0.pth'
            # torch.save(self.g_phi,best_model_path_gphi)


            vali_loss,original_loss,trend_loss = self.vali_inf(vali_data, vali_loader_a_share,vali_loader_us, criterion,path)
            test_loss,original_loss,trend_loss = self.vali_inf(test_data, test_loader_a_share,test_loader_us, criterion,path)
            # vali_inf leaves self.model in train mode; keep the frozen backbone
            # in eval mode for the next epoch.
            self.model.eval()


            print("Epoch: {0}, Steps: {1} | Train Loss: {2:.7f} Vali Loss: {3:.7f} Test Loss: {4:.7f}".format(
                epoch + 1, train_steps, train_loss, vali_loss, test_loss)\
                    , file=sys.stderr, flush=True)
            # early_stopping(vali_loss, self.flow_network, path)
            # if early_stopping.early_stop:
            #     print("Early stopping", file=sys.stderr, flush=True)
            #     break

            if self.args.lradj != 'TST':
                adjust_learning_rate(model_optim, scheduler, epoch + 1, self.args)
            else:
                print('Updating learning rate to {}'.format(scheduler.get_last_lr()[0])\
                      , file=sys.stderr, flush=True)


        best_model_path_flow = path + '/' + 'checkpoint_flow_s0.pth'
        torch.save(self.flow_network.state_dict(), best_model_path_flow)
        best_model_path_cross = path + '/' + 'checkpoint_cross_s0.pth'
        torch.save(self.cross_attn.state_dict(), best_model_path_cross)
        best_model_path_null = path + '/' + 'checkpoint_null_s0.pth'
        torch.save(self.null_news.data,best_model_path_null)
        # best_model_path_gphi= path + '/' + 'checkpoint_gphi_s0.pth'
        # torch.save(self.g_phi,best_model_path_gphi)


        return self.flow_network
    

    def test(self, setting, test=0):
        """Evaluate ``self.model`` on the test split and save metrics/predictions.

        Args:
            setting: experiment name; selects the checkpoint directory under
                ``self.args.checkpoints`` and the ``./test_results/<setting>/``
                and ``./results/<setting>/`` output folders.
            test: when truthy, first load ``checkpoint.pth`` for ``setting``.

        Side effects:
            Writes a plot every 20 batches to ``./test_results/<setting>/``,
            appends MSE/MAE/RSE to ``result.txt`` and saves the stacked
            predictions to ``./results/<setting>/pred.npy``.

        Raises:
            RuntimeError: if the test loader yields no batches.
        """
        import contextlib

        # NOTE(review): ``inference`` unpacks ``self._get_data`` into three values
        # and its loaders yield 5-tuples (with news), while this method expects
        # two values and 4-tuples — confirm which pipeline this evaluation targets.
        test_data, test_loader = self._get_data(flag='test')

        if test:
            print('loading model')
            # Use self.args.checkpoints (as the training code does) rather than a
            # hard-coded './checkpoints/' prefix; map_location allows loading a
            # GPU-saved checkpoint on the current device.
            ckpt_path = os.path.join(self.args.checkpoints, setting, 'checkpoint.pth')
            self.model.load_state_dict(torch.load(ckpt_path, map_location=self.device))

        preds = []
        trues = []
        inputx = []
        folder_path = './test_results/' + setting + '/'
        os.makedirs(folder_path, exist_ok=True)

        # For 'MS' (multivariate -> univariate) only the last channel is scored.
        f_dim = -1 if self.args.features == 'MS' else 0
        # A single forward path for both precisions: autocast when AMP is on,
        # a do-nothing context otherwise.
        amp_ctx = torch.cuda.amp.autocast if self.args.use_amp else contextlib.nullcontext

        self.model.eval()
        with torch.no_grad():
            for i, (batch_x, batch_y, batch_x_mark, batch_y_mark) in enumerate(test_loader):
                batch_x = batch_x.float().to(self.device)
                batch_y = batch_y.float().to(self.device)
                batch_x_mark = batch_x_mark.float().to(self.device)
                batch_y_mark = batch_y_mark.float().to(self.device)

                # Decoder input: label_len known steps followed by pred_len zeros.
                dec_inp = torch.zeros_like(batch_y[:, -self.args.pred_len:, :]).float()
                dec_inp = torch.cat([batch_y[:, :self.args.label_len, :], dec_inp], dim=1).float().to(self.device)

                with amp_ctx():
                    if 'Linear' in self.args.model or 'TST' in self.args.model:
                        outputs = self.model(batch_x)
                    elif self.args.output_attention:
                        outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)[0]
                    else:
                        outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)

                pred = outputs[:, -self.args.pred_len:, f_dim:].detach().cpu().numpy()
                true = batch_y[:, -self.args.pred_len:, f_dim:].detach().cpu().numpy()
                input_np = batch_x.detach().cpu().numpy()

                preds.append(pred)
                trues.append(true)
                inputx.append(input_np)
                if i % 20 == 0:
                    # Plot the last channel: input history followed by truth / prediction.
                    gt = np.concatenate((input_np[0, :, -1], true[0, :, -1]), axis=0)
                    pd_seq = np.concatenate((input_np[0, :, -1], pred[0, :, -1]), axis=0)
                    visual(gt, pd_seq, os.path.join(folder_path, str(i) + '.pdf'))

        if not preds:
            raise RuntimeError('test loader produced no batches')

        if self.args.test_flop:
            # NOTE(review): utils.tools.test_params_flop is assumed to take
            # (model, x_shape); the previous call omitted the model — confirm.
            test_params_flop(self.model, (batch_x.shape[1], batch_x.shape[2]))
            sys.exit()

        # Stack along the batch axis. Unlike np.array(...).reshape(...), this also
        # works when the final batch is smaller than the others.
        preds = np.concatenate(preds, axis=0)
        trues = np.concatenate(trues, axis=0)
        inputx = np.concatenate(inputx, axis=0)

        # result save
        folder_path = './results/' + setting + '/'
        os.makedirs(folder_path, exist_ok=True)

        mae, mse, rmse, mape, mspe, rse, corr = metric(preds, trues)
        print('mse:{}, mae:{}, rse:{}'.format(mse, mae, rse))
        with open("result.txt", 'a') as f:
            f.write(setting + "  \n")
            f.write('mse:{}, mae:{}, rse:{}'.format(mse, mae, rse))
            f.write('\n')
            f.write('\n')

        np.save(folder_path + 'pred.npy', preds)
        return

    def predict(self, setting, load=False):
        """Run ``self.model`` over the 'pred' split and save the raw predictions.

        Args:
            setting: experiment name; selects the checkpoint directory under
                ``self.args.checkpoints`` and the ``./results/<setting>/`` folder.
            load: when True, first load ``checkpoint.pth`` for ``setting``.

        Side effects:
            Saves the stacked model outputs to
            ``./results/<setting>/real_prediction.npy``.

        Raises:
            RuntimeError: if the prediction loader yields no batches.
        """
        import contextlib

        pred_data, pred_loader = self._get_data(flag='pred')

        if load:
            path = os.path.join(self.args.checkpoints, setting)
            best_model_path = path + '/' + 'checkpoint.pth'
            # map_location allows loading a GPU-saved checkpoint on the current device.
            self.model.load_state_dict(torch.load(best_model_path, map_location=self.device))

        # A single forward path for both precisions: autocast when AMP is on,
        # a do-nothing context otherwise.
        amp_ctx = torch.cuda.amp.autocast if self.args.use_amp else contextlib.nullcontext
        preds = []

        self.model.eval()
        with torch.no_grad():
            for batch_x, batch_y, batch_x_mark, batch_y_mark in pred_loader:
                batch_x = batch_x.float().to(self.device)
                batch_y = batch_y.float()
                batch_x_mark = batch_x_mark.float().to(self.device)
                batch_y_mark = batch_y_mark.float().to(self.device)

                # Decoder input: label_len known steps followed by pred_len zeros.
                dec_inp = torch.zeros([batch_y.shape[0], self.args.pred_len, batch_y.shape[2]]).float().to(batch_y.device)
                dec_inp = torch.cat([batch_y[:, :self.args.label_len, :], dec_inp], dim=1).float().to(self.device)

                with amp_ctx():
                    if 'Linear' in self.args.model or 'TST' in self.args.model:
                        outputs = self.model(batch_x)
                    elif self.args.output_attention:
                        outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)[0]
                    else:
                        outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)
                preds.append(outputs.detach().cpu().numpy())

        if not preds:
            raise RuntimeError('pred loader produced no batches')

        # Stack along the batch axis; also correct when batches differ in size,
        # where np.array(...).reshape(...) would fail.
        preds = np.concatenate(preds, axis=0)

        # result save
        folder_path = './results/' + setting + '/'
        os.makedirs(folder_path, exist_ok=True)

        np.save(folder_path + 'real_prediction.npy', preds)

        return





