import os
import pandas as pd
import numpy as np
from torch.utils.data import Dataset, DataLoader
from utils.timefeatures import time_features
import warnings
from utils.constants_LamaH import Constants
from sklearn.preprocessing import StandardScaler as Scaler
import torch
from joblib import Parallel, delayed
from tqdm import tqdm

# Suppress every warning for the whole process; the filter is installed as a
# side effect of importing this module.
# NOTE(review): this also hides pandas/numpy deprecation warnings raised by the
# code below — consider scoping it (e.g. warnings.catch_warnings) instead.
warnings.filterwarnings('ignore')


class LamaHDataModule:
    """Load, scale and split LamaH streamflow plus meteorological data.

    All stations in ``Constants(args).all_stations`` are read once into a
    ``[T, N, C]`` tensor whose last channel is the target flow column
    ``self.target``. Scalers are fitted on the training span only, and
    :meth:`get_dataset` returns the train/val/test ``Dataset``.
    """

    def __init__(self, args, target_station='01', station_list=None, size=None,
                 features='M', target='Value', scaler_dict=None, timeenc=0, freq='d',
                 cycle=None, batch_flag='mini_batch', is_GNN=False):
        self.args = args
        self.target_station = target_station
        self.features = features
        self.target = target
        self.scaler_dict = scaler_dict
        self.scaler = Scaler()
        self.timeenc = timeenc
        self.freq = freq
        self.batch_flag = batch_flag
        self.is_GNN = is_GNN
        # An explicit ``cycle`` argument wins; otherwise fall back to args.cycle.
        self.cycle = cycle if cycle is not None else self.args.cycle

        if size is None:
            self.seq_len = 24 * 4 * 4
            self.label_len = 24 * 4
            self.pred_len = 24 * 4
        else:
            self.seq_len, self.label_len, self.pred_len = size

        # Paths and station metadata
        self.root_path = args.dataset_path
        self.data_path = args.data_root_path
        self.other_features_path = args.other_features_path
        constants = Constants(args)
        self.all_stations = constants.all_stations
        self.basin_dict = constants.basin_dict
        self.station = target_station
        self.__check_split_time()

        # Stations this module models
        self.station_list = self.__resolve_station_list(station_list)
        self.station_ids = list(range(len(self.station_list)))

        # Read all stations once; the splits are slices of this tensor.
        self.__read_all_data()

    def __resolve_station_list(self, station_list):
        """Return the list of stations to model and validate ``self.station``.

        Raises:
            ValueError: unknown station/basin or unknown ``args.run_wise`` mode.
        """
        if station_list is not None:
            if self.station not in self.all_stations:
                raise ValueError(f"Station {self.station} not in station list")
            return station_list

        run_wise = self.args.run_wise
        if run_wise == 'all':
            if self.station == 'all':
                return self.all_stations
            if self.station not in self.all_stations:
                raise ValueError(f"Station {self.station} not in station list")
            return [self.station]
        if run_wise in ['basin', 'basin_station']:
            if self.station not in self.basin_dict.keys():
                raise ValueError(f"Basin {self.station} not in {self.basin_dict.keys()}")
            return self.basin_dict[self.station]
        if run_wise == 'station':
            if self.station not in self.all_stations:
                raise ValueError(f"Station {self.station} not in station list")
            return [self.station]
        raise ValueError(f"Unknown run_wise mode {run_wise!r}")

    def __check_split_time(self):
        """Assert that the configured train/val/test boundaries are ordered and in range."""
        if self.args.data == 'LamaH_daily':
            assert pd.Timestamp(self.args.train_start_time) >= pd.Timestamp('2000-01-01'), \
                'train start time must be greater than 2000-01-01'
            assert pd.Timestamp(self.args.test_end_time) <= pd.Timestamp('2017-12-31'), \
                'test end time must be smaller than 2017-12-31'
        elif self.args.data == 'LamaH_hourly':
            assert pd.Timestamp(self.args.train_start_time) >= pd.Timestamp('2000-01-01-00'), \
                'train start time must be greater than 2000-01-01-00'
            assert pd.Timestamp(self.args.test_end_time) <= pd.Timestamp('2017-12-31-23'), \
                'test end time must be smaller than 2017-12-31-23'

        assert pd.Timestamp(self.args.train_end_time) >= pd.Timestamp(self.args.train_start_time), \
            'train end time must be greater than train start time'
        assert pd.Timestamp(self.args.vali_end_time) >= pd.Timestamp(self.args.vali_start_time), \
            'vali end time must be greater than vali start time'
        assert pd.Timestamp(self.args.test_end_time) >= pd.Timestamp(self.args.test_start_time), \
            'test end time must be greater than test start time'
        assert pd.Timestamp(self.args.vali_start_time) >= pd.Timestamp(self.args.train_start_time), \
            'vali start time must be greater than train start time'
        assert pd.Timestamp(self.args.test_start_time) >= pd.Timestamp(self.args.vali_start_time), \
            'test start time must be greater than vali start time'
        assert pd.Timestamp(self.args.vali_end_time) <= pd.Timestamp(self.args.test_start_time), \
            'vali end time must be smaller than test start time'
        assert pd.Timestamp(self.args.train_end_time) <= pd.Timestamp(self.args.vali_start_time), \
            'train end time must be smaller than vali start time'

    def __datetime_unit(self):
        """Return the numpy datetime64 resolution matching ``args.data``."""
        if self.args.data == 'LamaH_daily':
            return 'datetime64[D]'
        if self.args.data == 'LamaH_hourly':
            return 'datetime64[h]'
        raise NotImplementedError(f"Unsupported dataset {self.args.data!r}")

    def __load_station_csv(self, sub_dir, station, columns):
        """Read ``<root>/<sub_dir>/ID_<station>.csv`` (';'-separated).

        Returns:
            dates: datetime64 array at the resolution of ``args.data``.
            values: float array of shape ``[rows, len(columns)]``.
        """
        csv_path = os.path.join(str(self.root_path), str(sub_dir), 'ID_' + str(station) + '.csv')
        df = pd.read_csv(csv_path, sep=';')
        if self.args.data == 'LamaH_daily':
            date_parts = {'YYYY': 'year', 'MM': 'month', 'DD': 'day'}
        elif self.args.data == 'LamaH_hourly':
            date_parts = {'YYYY': 'year', 'MM': 'month', 'DD': 'day', 'hh': 'hour'}
        else:
            raise NotImplementedError
        # Assemble timestamps from the numeric columns instead of a
        # "YYYY-MM-DD-hh" string, which is not ISO-8601 and would have to go
        # through pandas' per-element fallback parser.
        stamps = pd.to_datetime(df[list(date_parts)].rename(columns=date_parts))
        dates = stamps.values.astype(self.__datetime_unit())
        values = df[columns].values.astype(float)
        return dates, values

    @staticmethod
    def __align_to_index(dates, values, t_index):
        """Scatter ``values`` (rows keyed by ``dates``) onto the global ``t_index`` grid.

        Rows whose date is outside ``t_index`` are dropped, and grid steps with
        no matching row stay NaN. The output keeps the trailing dimensions of
        ``values``.
        """
        out = np.full((len(t_index),) + values.shape[1:], np.nan, dtype=float)
        _, src, dst = np.intersect1d(dates, t_index, return_indices=True)
        out[dst] = values[src]
        return out

    def __preprocess_steamflow(self, station, t_index):
        """Read one station's target flow, aligned to ``t_index``; returns ``[T]`` (NaN = missing)."""
        dates, obs = self.__load_station_csv(self.data_path, station, [self.target])
        return self.__align_to_index(dates, obs[:, 0], t_index)  # [T]

    def __preprocess_other_features(self, station, t_index):
        """Read one station's meteorological covariates, aligned to ``t_index``; returns ``[T, C]``."""
        if self.args.data == 'LamaH_daily':
            attr_list = [
                "prec",  # precipitation
                "volsw_123",  # topsoil moisture
                "2m_temp_max",  # air temperature
                "surf_press",  # surface pressure
            ]
        elif self.args.data == 'LamaH_hourly':
            attr_list = [
                "prec",  # precipitation
                "volsw_123",  # topsoil moisture
                "2m_temp",  # air temperature
                "surf_press",  # surface pressure
            ]
        else:
            raise NotImplementedError
        dates, data = self.__load_station_csv(self.other_features_path, station, attr_list)
        return self.__align_to_index(dates, data, t_index)  # [T, C]

    def __load_or_build_tensor(self, t_index, n_jobs):
        """Return the ``[T, N, C]`` tensor from ``<root>/data_tensor.npy``, rebuilding it if absent or stale."""
        cache_path = os.path.join(self.root_path, "data_tensor.npy")
        expected = (len(t_index), len(self.all_stations))
        if os.path.exists(cache_path):
            if self.args.verbose:
                print(f"Loading cached data from {cache_path}")
            cached = np.load(cache_path)
            # The cache does not record the time range it was built for, so at
            # least make sure it matches the current grid and station count.
            if cached.shape[:2] == expected:
                return cached
            if self.args.verbose:
                print(f"Cached tensor shape {cached.shape} does not match {expected}; rebuilding cache")

        def process_station(station):
            qobs = self.__preprocess_steamflow(station, t_index).reshape(-1, 1)
            feats = self.__preprocess_other_features(station, t_index)
            return np.concatenate([feats, qobs], axis=1)  # [T, C], target flow last

        results = Parallel(n_jobs=n_jobs)(
            delayed(process_station)(station) for station in tqdm(self.all_stations)
        )
        data_tensor = np.stack(results, axis=1)  # [T, N, C]
        np.save(cache_path, data_tensor)
        return data_tensor

    @staticmethod
    def __fit_scaler(values):
        """Fit a fresh Scaler on the non-NaN entries of ``values`` (treated as one feature)."""
        flat = values.reshape(-1)
        scaler = Scaler()
        scaler.fit(flat[~np.isnan(flat)].reshape(-1, 1))
        return scaler

    def __read_all_data(self, n_jobs=8):
        """Load the data tensor, compute the split sizes and fit scalers on the training span.

        Sets ``data_tensor``, ``df_example``, ``num_train``/``num_vali``/``num_test``,
        ``target_idx``, ``scaler_list`` and ``qobs_scaler_list``.
        """
        if self.args.verbose:
            print('loading dataset...')

        # Global time axis: train start .. test end (inclusive)
        target_start = pd.Timestamp(self.args.train_start_time)
        target_end = pd.Timestamp(self.args.test_end_time)
        t_index = pd.date_range(start=target_start, end=target_end,
                                freq=self.freq).values.astype(self.__datetime_unit())

        # A tensor supplied through args.data_tensor takes precedence over the cache.
        data_tensor = getattr(self.args, 'data_tensor', None)
        if data_tensor is not None:
            df_example = getattr(self.args, 'df_example', None)
        else:
            data_tensor = self.__load_or_build_tensor(t_index, n_jobs)
            df_example = None
        if df_example is None:
            df_example = pd.DataFrame({"Timestamp": t_index})
            df_example["example_flow"] = data_tensor[:, 0, -1]

        # Split sizes
        t_train_end = pd.Timestamp(self.args.train_end_time)
        t_vali_start = pd.Timestamp(self.args.vali_start_time)
        t_vali_end = pd.Timestamp(self.args.vali_end_time)
        t_test_start = pd.Timestamp(self.args.test_start_time)
        t_test_end = pd.Timestamp(self.args.test_end_time)

        mask_train = (t_index >= target_start) & (t_index <= t_train_end)
        mask_vali = (t_index >= t_vali_start) & (t_index <= t_vali_end)
        mask_test = (t_index >= t_test_start) & (t_index <= t_test_end)

        self.num_train = mask_train.sum()
        self.num_vali = mask_vali.sum()
        self.num_test = mask_test.sum()

        # Fit scalers on the training span only; it is the leading num_train
        # steps because t_index starts at train_start_time.
        train_data = data_tensor[:self.num_train]
        C = data_tensor.shape[-1]
        self.target_idx = C - 1  # target flow is the last channel

        # Covariates: one scaler per channel, pooled over all stations
        # (None is a placeholder at the target position).
        scaler_list = [
            None if i == self.target_idx else self.__fit_scaler(train_data[:, :, i])
            for i in range(C)
        ]
        # Target flow: one scaler per station.
        qobs_scaler_list = [
            self.__fit_scaler(train_data[:, n, self.target_idx])
            for n in range(train_data.shape[1])
        ]

        self.data_tensor = data_tensor
        self.df_example = df_example
        self.scaler_list = scaler_list
        self.qobs_scaler_list = qobs_scaler_list

    @staticmethod
    def __apply_scaler(scaler, values):
        """Transform the non-NaN entries of ``values`` with ``scaler``; NaNs are left in place."""
        flat = values.reshape(-1, 1)
        mask = ~np.isnan(flat[:, 0])
        scaled = flat.copy()
        if mask.any():
            scaled[mask] = scaler.transform(flat[mask])
        return scaled.reshape(values.shape)

    def __scaled_data(self):
        """Return a scaled copy of ``self.data_tensor`` with missing values filled with 0."""
        data = np.empty_like(self.data_tensor, dtype=float)
        C = self.data_tensor.shape[-1]
        N = self.data_tensor.shape[1]
        for i in range(C):
            if i == self.target_idx:
                # Target flow: per-station scaler
                for n in range(N):
                    data[:, n, i] = self.__apply_scaler(self.qobs_scaler_list[n], self.data_tensor[:, n, i])
            else:
                # Covariates: one scaler shared by all stations
                data[:, :, i] = self.__apply_scaler(self.scaler_list[i], self.data_tensor[:, :, i])
        # Missing observations are filled with 0 after scaling.
        return np.nan_to_num(data, nan=0.0)

    def get_dataset(self, flag):
        """Return the train/val/test sub-dataset.

        Args:
            flag: 'train', 'val' or 'test'.

        Returns:
            A ``MiniBatchSampler`` or ``FullBatchSampler``, depending on ``batch_flag``.

        Raises:
            KeyError: unknown ``flag``.
            ValueError: ``seq_len`` needs more history than precedes the split.
            NotImplementedError: unknown ``batch_flag``.
        """
        type_map = {'train': 0, 'val': 1, 'test': 2}
        self.set_type = type_map[flag]

        # Dataset splitting. Val/test windows start seq_len steps early so
        # their first sample has a full encoder history.
        len_df_raw = self.data_tensor.shape[0]
        assert len_df_raw == self.num_train + self.num_vali + self.num_test, \
            f"len_df_raw {len_df_raw} != num_train + num_vali + num_test {self.num_train + self.num_vali + self.num_test}"
        border1s = [0, self.num_train - self.seq_len, len_df_raw - self.num_test - self.seq_len]
        border2s = [self.num_train, self.num_train + self.num_vali, len_df_raw]
        border1, border2 = border1s[self.set_type], border2s[self.set_type]
        if border1 < 0:
            # A negative start would silently wrap around to the end of the array.
            raise ValueError(
                f"seq_len={self.seq_len} needs more history than is available before "
                f"the '{flag}' split (start index {border1})")

        data = self.__scaled_data()

        # Calendar features. Copy the slice so adding columns does not write
        # into (a view of) df_example.
        df_stamp = self.df_example[['Timestamp']][border1:border2].copy()
        df_stamp['Timestamp'] = pd.to_datetime(df_stamp['Timestamp'])
        if self.timeenc == 0:
            stamps = df_stamp['Timestamp'].dt
            df_stamp['month'] = stamps.month
            df_stamp['day'] = stamps.day
            df_stamp['weekday'] = stamps.weekday
            df_stamp['hour'] = stamps.hour
            data_stamp = df_stamp.drop(columns=['Timestamp']).values
        elif self.timeenc == 1:
            data_stamp = time_features(pd.to_datetime(df_stamp['Timestamp'].values), freq=self.freq)
            data_stamp = data_stamp.transpose(1, 0)
        else:
            raise NotImplementedError

        data_x = data[border1:border2]
        data_y = data[border1:border2]

        # Position of each step within the cycle, indexed on the global time axis.
        cycle_index = (np.arange(len(data)) % self.cycle)[border1:border2]

        # MiniBatchSampler indexes qobs scalers by position along the station
        # axis of data_x, so the list must follow any station subsetting.
        mini_batch_scalers = self.qobs_scaler_list
        if self.args.run_wise == 'station':
            # Station-wise modelling: keep only the stations related to the target.
            station_indices = np.array([self.all_stations.index(s) for s in self.station_list])
            data_x = data_x[:, station_indices, :]
            data_y = data_y[:, station_indices, :]
            mini_batch_scalers = [self.qobs_scaler_list[i] for i in station_indices]
        if self.is_GNN:
            # GNN models use the flow channel only.
            data_x = data_x[:, :, -1:]
            data_y = data_y[:, :, -1:]
            self.static_tensor = None

        if self.batch_flag == 'mini_batch':
            return MiniBatchSampler(self.args, data_x, data_y, data_stamp,
                                    seq_len=self.seq_len,
                                    label_len=self.label_len,
                                    pred_len=self.pred_len,
                                    station_ids=self.station_ids,
                                    qobs_scaler_list=mini_batch_scalers,
                                    data_flag=flag,
                                    cycle_index=cycle_index)
        elif self.batch_flag == 'full_batch':
            # FullBatchSampler looks scalers up via all_stations.index(...), so
            # it keeps the full, all_stations-ordered list.
            return FullBatchSampler(self.args, data_x, data_y, data_stamp,
                                    seq_len=self.seq_len,
                                    label_len=self.label_len,
                                    pred_len=self.pred_len,
                                    station_ids=self.station_ids,
                                    qobs_scaler_list=self.qobs_scaler_list,
                                    data_flag=flag,
                                    cycle_index=cycle_index,
                                    target_station=self.station,
                                    all_stations=self.all_stations)
        else:
            raise NotImplementedError(f"Unknown batch flag {self.batch_flag}")


class MiniBatchSampler(Dataset):
    """Per-station sliding-window dataset.

    Each item is one (station, window) pair. For ``data_flag == 'test'`` the
    index walks all pairs in station-major order. For any other flag ``idx`` is
    ignored, and the station and window start are drawn uniformly at random
    with ``torch.randint``.
    """

    def __init__(self, args, data_x, data_y, data_stamp,
                 seq_len=365, label_len=0, pred_len=1,
                 station_ids=None, qobs_scaler_list=None, data_flag='train',
                 cycle_index=None):
        """
        Args:
            args: run configuration; ``args.batch_size`` is read here.
            data_x: np.array [T, N, C_in]
            data_y: np.array [T, N, C_out]
            data_stamp: np.array [T, D] calendar features.
            seq_len / label_len / pred_len: encoder length, decoder warm-up
                length and forecast horizon.
            station_ids: accepted for interface compatibility only; stations
                are always re-indexed as ``0..N-1`` along the N axis of data_x.
            qobs_scaler_list: target-flow scalers ordered like the N axis of
                ``data_x``; used by :meth:`inverse_transform`.
            data_flag: 'train' / 'val' / 'test'.
            cycle_index: int array [T], the cycle position of each time step.
        """
        super().__init__()
        self.args = args
        self.data_x = torch.tensor(data_x, dtype=torch.float32)
        self.data_y = torch.tensor(data_y, dtype=torch.float32)
        self.data_stamp = torch.tensor(data_stamp, dtype=torch.float32)
        self.cycle_index = torch.tensor(cycle_index, dtype=torch.long)

        self.T, self.N, self.C_in = self.data_x.shape
        _, _, self.C_out = self.data_y.shape
        self.D = self.data_stamp.shape[-1]

        self.seq_len = seq_len
        self.label_len = label_len
        self.pred_len = pred_len
        self.batch_per_epoch = self.args.batch_size
        self.qobs_scaler_list = qobs_scaler_list
        self.data_flag = data_flag  # 'train' / 'val' / 'test'

        # Valid window starts: a full seq_len + pred_len window must fit.
        self.time_indices = np.arange(self.T - seq_len - pred_len + 1)
        self.station_ids = np.arange(self.N)

    def __len__(self):
        return len(self.time_indices) * len(self.station_ids)

    def __getitem__(self, idx):
        """Return ``(seq_x, seq_y, seq_x_mark, seq_y_mark, cycle_index, station, s_begin)``."""
        if self.data_flag == "test":
            # Deterministic walk: station-major, then time.
            n = idx // len(self.time_indices)
            s_begin = self.time_indices[idx % len(self.time_indices)]
        else:
            # Train / val: random station and window, idx is ignored.
            n = torch.randint(0, len(self.station_ids), (1,)).item()
            s_begin = torch.randint(0, len(self.time_indices), (1,)).item()

        s_end = s_begin + self.seq_len
        r_begin = s_end - self.label_len
        r_end = r_begin + self.label_len + self.pred_len

        seq_x = self.data_x[s_begin:s_end, n, :]  # [L, C_in]
        seq_y = self.data_y[r_begin:r_end, n, :]  # [label_len + H, C_out]

        seq_x_mark = self.data_stamp[s_begin:s_end]  # [L, D]
        seq_y_mark = self.data_stamp[r_begin:r_end]  # [label_len + H, D]

        # Cycle position of the first forecast step.
        cycle_index = self.cycle_index[s_end]

        return seq_x, seq_y, seq_x_mark, seq_y_mark, cycle_index, n, s_begin

    def inverse_transform(self, data, input_station_ids=None):
        """Undo the per-station target scaling on the last channel of ``data``.

        Args:
            data: np.ndarray [batch, pred_len, C_out]; the last channel is the
                scaled target flow (for C_out == 1 it is the only channel).
            input_station_ids: length-``batch`` sequence with each sample's
                station index (the ``n`` returned by ``__getitem__``).

        Returns:
            A copy of ``data`` with its last channel in original units.

        Raises:
            ValueError: if ``input_station_ids`` is not provided.
        """
        if input_station_ids is None:
            raise ValueError("input_station_ids is required to select each sample's scaler")
        B, _, _ = data.shape
        out = data.copy()
        for b in range(B):
            scaler = self.qobs_scaler_list[input_station_ids[b]]
            out[b, :, -1:] = scaler.inverse_transform(out[b, :, -1:])
        return out


class FullBatchSampler(Dataset):
    """Sliding-window dataset that yields every station of one window together.

    With ``data_flag == 'test'`` item ``idx`` is the window starting at
    ``time_indices[idx]``. Otherwise ``idx`` is ignored and the start is drawn
    with ``np.random.choice``.
    """

    def __init__(self, args, data_x, data_y, data_stamp,
                 seq_len=365, label_len=0, pred_len=1,
                 station_ids=None, qobs_scaler_list=None, data_flag='train',
                 cycle_index=None,
                 target_station=None, all_stations=None):
        """
        Args:
            args: run configuration; ``args.batch_size`` is read here.
            data_x: np.array [T, N, C_in]
            data_y: np.array [T, N, C_out]
            data_stamp: np.array [T, D] calendar features.
            seq_len / label_len / pred_len: encoder length, decoder warm-up
                length and forecast horizon.
            station_ids: stored as given.
            qobs_scaler_list: target-flow scalers used by :meth:`inverse_transform`.
            data_flag: 'train' / 'val' / 'test'.
            cycle_index: int array [T], the cycle position of each time step.
            target_station / all_stations: used to look up the target
                station's scaler via ``all_stations.index(target_station)``.
        """
        super().__init__()
        self.args = args

        self.data_x, self.data_y, self.data_stamp = (
            torch.tensor(array, dtype=torch.float32)
            for array in (data_x, data_y, data_stamp)
        )
        self.cycle_index = torch.tensor(cycle_index, dtype=torch.long)

        self.T, self.N, self.C_in = self.data_x.shape
        _steps, _nodes, self.C_out = self.data_y.shape
        self.D = self.data_stamp.shape[-1]

        self.seq_len, self.label_len, self.pred_len = seq_len, label_len, pred_len
        self.batch_per_epoch = args.batch_size
        self.station_ids = station_ids
        self.qobs_scaler_list = qobs_scaler_list
        self.data_flag = data_flag  # 'train' / 'val' / 'test'
        self.target_station = target_station
        self.all_stations = all_stations

        # One window per start that leaves room for seq_len + pred_len steps.
        window_count = self.T - seq_len - pred_len + 1
        self.time_indices = np.arange(window_count)

    def __len__(self):
        return len(self.time_indices)

    def __getitem__(self, idx):
        """Return ``(seq_x, seq_y, seq_x_mark, seq_y_mark, cycle_index, 0, s_begin)``."""
        if self.data_flag != "test":
            start = np.random.choice(self.time_indices)
        else:
            start = self.time_indices[idx]

        enc_end = start + self.seq_len
        encoder_span = slice(start, enc_end)
        decoder_span = slice(enc_end - self.label_len, enc_end + self.pred_len)

        return (
            self.data_x[encoder_span],      # [L, N, C_in]
            self.data_y[decoder_span],      # [label_len + H, N, C_out]
            self.data_stamp[encoder_span],  # [L, D]
            self.data_stamp[decoder_span],  # [label_len + H, D]
            self.cycle_index[enc_end],      # cycle position of the first forecast step
            0,
            start,
        )

    def inverse_transform(self, data, input_station_ids=None, is_GNN=False):
        """Map scaled target-flow predictions back to original units.

        Args:
            data: np.ndarray [batch, pred_len, C].
            input_station_ids: unused.
            is_GNN: if True, column ``i`` of the last axis is inverted with
                ``qobs_scaler_list[i]`` (presumably one column per station —
                confirm with the caller). Otherwise all values are inverted with
                the target station's scaler, which expects C == 1.

        Returns:
            A new array of shape [batch, pred_len, C].
        """
        batch, horizon, channels = data.shape
        flat = data.copy().reshape(-1, channels)

        if not is_GNN:
            target_pos = self.all_stations.index(self.target_station)
            restored = self.qobs_scaler_list[target_pos].inverse_transform(flat)
            return restored.reshape(batch, horizon, channels)

        for col in range(channels):
            column = flat[:, col].reshape(-1, 1)
            flat[:, col] = self.qobs_scaler_list[col].inverse_transform(column).squeeze(-1)
        return flat.reshape(batch, horizon, channels)
