import os
import pandas as pd
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from utils.timefeatures import time_features
import warnings
from utils.constants_LamaH import Constants
from sklearn.preprocessing import StandardScaler as Scaler
import math
from joblib import Parallel, delayed
from tqdm import tqdm

# NOTE(review): this installs a process-wide "ignore" filter for every warning category as
# soon as this module is imported, so it also hides pandas/NumPy deprecation warnings raised
# anywhere else; consider narrowing it by category or scoping it with warnings.catch_warnings().
warnings.filterwarnings('ignore')


class CamelsDataModule:
    """Load CAMELS data for all stations once and build train / val / test samplers.

    Dynamic data for every station in ``Constants(args).all_stations`` is held in
    ``data_tensor`` with shape [T, N, C]: the five meteorological forcings read by
    ``__preprocess_other_features`` followed by streamflow (``qobs``) as the last channel
    (in DI mode a lagged copy of streamflow is inserted just before it). Streamflow is
    basin-normalized and log-transformed, scalers are fit on the training span only, and
    :meth:`get_dataset` returns a sampler for one split.
    """

    def __init__(self, args, target_station='01', station_list=None, size=None,
                 features='M', target='Value', scaler_dict=None, timeenc=0, freq='d', batch_flag='mini_batch',
                 is_GNN=False):
        """
        Args:
            args: experiment config. Read here and in the loaders: cycle, use_di, di_window,
                seq_len, dataset_path, data_root_path, other_features_path, basin_feature_path,
                run_wise, verbose and the train/vali/test *_start_time / *_end_time boundaries.
            target_station: station id, basin id, or 'all', depending on args.run_wise.
            station_list: explicit stations to model; overrides the run_wise selection.
            size: (seq_len, label_len, pred_len); defaults to (384, 96, 96).
            features, target, scaler_dict: stored only, not read elsewhere in this class.
            timeenc: 0 for integer calendar features, 1 for utils.timefeatures encoding.
            freq: pandas frequency string of the time index.
            batch_flag: 'mini_batch' (one station per sample) or 'full_batch' (all stations).
            is_GNN: feed only the streamflow channel and drop static features.

        Raises:
            ValueError: unknown station / basin / run_wise mode.
        """
        self.args = args
        self.target_station = target_station
        self.features = features
        self.target = target
        self.scaler_dict = scaler_dict
        self.scaler = Scaler()
        self.timeenc = timeenc
        self.freq = freq
        self.cycle = self.args.cycle
        self.batch_flag = batch_flag
        self.is_GNN = is_GNN
        # Validate DI (lagged-observation input) parameters up front.
        if self.args.use_di:
            assert self.args.di_window > 0, "DI window must be positive integer"
            assert self.args.di_window < self.args.seq_len, "DI window cannot exceed sequence length"

        if size is None:
            self.seq_len = 24 * 4 * 4
            self.label_len = 24 * 4
            self.pred_len = 24 * 4
        else:
            self.seq_len, self.label_len, self.pred_len = size

        # Data locations.
        self.root_path = args.dataset_path
        self.data_path = args.data_root_path
        self.other_features_path = args.other_features_path
        self.basin_feature_path = args.basin_feature_path
        constants = Constants(args)
        self.all_stations = constants.all_stations
        self.basin_dict = constants.basin_dict
        self.station = target_station
        self.__check_split_time()

        # Stations to model.
        if station_list is not None:
            self.station_list = station_list
            if self.station not in self.all_stations:
                raise ValueError(f"Station {self.station} not in station list")
        elif self.args.run_wise == 'all':
            if self.station != 'all':
                raise ValueError(f"Station only support 'all' when run_wise is 'all', but get station {self.station}")
            self.station_list = self.all_stations
        elif self.args.run_wise in ['basin', 'basin_station']:
            if self.station not in self.basin_dict.keys():
                raise ValueError(f"Basin {self.station} not in {self.basin_dict.keys()}")
            self.station_list = self.basin_dict[self.station]
        elif self.args.run_wise == 'station':
            if self.station not in self.all_stations:
                raise ValueError(f"Station {self.station} not in station list")
            self.station_list = [self.station]
        else:
            # Previously this fell through and crashed on the undefined station_list below.
            raise ValueError(f"Unknown run_wise {self.args.run_wise}")

        self.station_ids = list(range(len(self.station_list)))

        # Read everything once.
        self.__read_all_data()

    def __check_split_time(self):
        """Assert that the train / vali / test boundaries are ordered and within 1985-01-01..2014-12-31."""
        assert pd.Timestamp(self.args.train_start_time) >= pd.Timestamp('1985-01-01'), \
            'train start time must be greater than 1985-01-01'
        assert pd.Timestamp(self.args.train_end_time) >= pd.Timestamp(self.args.train_start_time), \
            'train end time must be greater than train start time'
        assert pd.Timestamp(self.args.vali_end_time) >= pd.Timestamp(self.args.vali_start_time), \
            'vali end time must be greater than vali start time'
        assert pd.Timestamp(self.args.test_end_time) >= pd.Timestamp(self.args.test_start_time), \
            'test end time must be greater than test start time'
        assert pd.Timestamp(self.args.vali_start_time) >= pd.Timestamp(self.args.train_start_time), \
            'vali start time must be greater than train start time'
        assert pd.Timestamp(self.args.test_start_time) >= pd.Timestamp(self.args.vali_start_time), \
            'test start time must be greater than vali start time'
        assert pd.Timestamp(self.args.test_end_time) <= pd.Timestamp('2014-12-31'), \
            'test end time must be smaller than 2014-12-31'
        assert pd.Timestamp(self.args.vali_end_time) <= pd.Timestamp(self.args.test_start_time), \
            'vali end time must be smaller than test start time'
        assert pd.Timestamp(self.args.train_end_time) <= pd.Timestamp(self.args.vali_start_time), \
            'train end time must be smaller than vali start time'

    def __read_static_data(self):
        """Read the selected CAMELS static basin attributes.

        Reads ``camels_{clim,geol,hydro,soil,topo,vege}.txt`` (';'-separated) under
        root_path/basin_feature_path. String-valued attributes are label-encoded with
        sorted codes.

        Returns:
            df_attr: DataFrame indexed by the 8-character zero-padded gauge_id, one column
                per attribute found, NaN filled with 0.
            static_tensor: df_attr values as a float array (attribute-file row order).
            f_dict: {attribute: category labels} for the label-encoded attributes.
        """
        attr_list = [
            'p_mean', 'pet_mean', 'p_seasonality', 'frac_snow', 'aridity', 'high_prec_freq', 'high_prec_dur',
            'low_prec_freq', 'low_prec_dur', 'elev_mean', 'slope_mean', 'area_gages2', 'frac_forest',
            'lai_max', 'lai_diff', 'gvf_max', 'gvf_diff', 'dom_land_cover_frac', 'dom_land_cover',
            'root_depth_50', 'soil_depth_pelletier', 'soil_depth_statsgo', 'soil_porosity',
            'soil_conductivity', 'max_water_content', 'sand_frac', 'silt_frac', 'clay_frac', 'geol_1st_class',
            'glim_1st_class_frac', 'geol_2nd_class', 'glim_2nd_class_frac', 'carbonate_rocks_frac',
            'geol_porostiy', 'geol_permeability'
        ]
        attr_file_list = ['clim', 'geol', 'hydro', 'soil', 'topo', 'vege']

        data_dict = {}  # {attr: Series indexed by gauge_id}
        f_dict = {}  # {attr: category labels} for string attributes

        for attr_file in attr_file_list:
            file_path = os.path.join(
                str(self.root_path), str(self.basin_feature_path), f"camels_{attr_file}.txt"
            )
            df = pd.read_csv(file_path, sep=";")
            df["gauge_id"] = df["gauge_id"].astype(str).str.zfill(8)
            df = df.set_index("gauge_id")

            for attr in attr_list:
                if attr in df.columns:
                    col = df[attr]
                    if col.dtype == "object":
                        values, ref = pd.factorize(col, sort=True)
                        data_dict[attr] = pd.Series(values, index=df.index)
                        f_dict[attr] = ref.tolist()
                    else:
                        data_dict[attr] = col

        # Assemble once; rows are the union of gauge ids across files.
        df_attr = pd.DataFrame(data_dict).fillna(0)
        static_tensor = df_attr.values.astype(float)

        return df_attr, static_tensor, f_dict

    def __find_basin_id(self, station):
        """Return the first key of basin_dict whose station collection contains ``station``.

        Raises:
            ValueError: the station belongs to no basin.
        """
        for basin_id, stations in self.basin_dict.items():
            if station in stations:
                return basin_id
        raise ValueError(f"Station {station} not found in any basin of basin_dict")

    def __preprocess_steamflow(self, station, t_index):
        """Read one station's streamflow and align it to ``t_index``.

        Returns:
            float array of shape [len(t_index)]; days absent from the file are NaN.
        """
        basin_id = self.__find_basin_id(station)
        data_path = os.path.join(str(self.root_path), str(self.data_path), basin_id,
                                 f"{station}_streamflow_qc.txt")

        df = pd.read_csv(
            data_path,
            sep=r"\s+",
            header=None,
            names=['id', 'year', 'month', 'day', 'qobs', 'q_tag'],
            engine="python"
        )

        date = pd.to_datetime(df[['year', 'month', 'day']]).values.astype("datetime64[D]")
        obs = df['qobs'].values.astype(float)

        # Flows tagged 'M' are set to 0; any other negative value becomes NaN.
        # NOTE(review): 'M'-tagged days therefore enter training as real zero flow rather than
        # as missing (NaN) like other negative values — confirm this is intended.
        obs[df['q_tag'].values == 'M'] = 0
        obs[obs < 0] = np.nan

        # Align onto the global time axis.
        out = np.full(len(t_index), np.nan, dtype=float)
        _, ind1, ind2 = np.intersect1d(date, t_index, return_indices=True)
        out[ind2] = obs[ind1]

        return out  # shape = [T]

    def __preprocess_other_features(self, station, t_index):
        """Read one station's meteorological forcings and align them to ``t_index``.

        Returns:
            float array [len(t_index), 5] with columns dayl, prcp, srad, tmax, vp; missing days are NaN.
        """
        attr_list = ['dayl(s)', 'prcp(mm/day)', 'srad(W/m2)', 'tmax(C)', 'vp(Pa)']
        basin_id = self.__find_basin_id(station)
        data_path = os.path.join(str(self.root_path), str(self.other_features_path), basin_id,
                                 f"{station}_lump_cida_forcing_leap.txt")

        df = pd.read_csv(
            data_path, sep=r"\s+", engine="python", skiprows=3, header=0
        )

        date = pd.to_datetime(df[['Year', 'Mnth', 'Day']].rename(
            columns={'Year': 'year', 'Mnth': 'month', 'Day': 'day'}
        )).values.astype("datetime64[D]")

        data = df[attr_list].values.astype(float)

        # Align onto the global time axis.
        out = np.full((len(t_index), len(attr_list)), np.nan, dtype=float)
        _, ind1, ind2 = np.intersect1d(date, t_index, return_indices=True)
        out[ind2, :] = data[ind1, :]

        return out  # shape = [T, C]

    @staticmethod
    def _fit_ignoring_nan(values):
        """Fit a Scaler on the non-NaN entries of ``values`` (any shape, treated as one feature)."""
        flat = np.asarray(values, dtype=float).reshape(-1, 1)
        scaler = Scaler()
        scaler.fit(flat[~np.isnan(flat[:, 0])])
        return scaler

    @staticmethod
    def _transform_ignoring_nan(scaler, values):
        """Apply ``scaler.transform`` to the non-NaN entries of ``values``; NaNs pass through.

        Returns an array with the same shape as ``values``. If every entry is NaN the scaler
        is not called at all (transforming zero samples raises in scikit-learn).
        """
        flat = np.asarray(values, dtype=float).reshape(-1, 1)
        mask = ~np.isnan(flat[:, 0])
        out = flat.copy()
        if mask.any():
            out[mask] = scaler.transform(flat[mask])
        return out.reshape(np.shape(values))

    def _add_di_feature(self, data_tensor, scaler_list):
        """Insert a lagged copy of streamflow as an extra input channel (DI mode).

        The last (streamflow) channel is shifted forward in time by args.di_window steps and
        inserted just before it, giving channels [forcings..., lagged qobs, qobs]. The first
        di_window steps of the lagged channel have no earlier observation; they are NaN, so
        after scaling they are filled with 0 like any other missing value.

        Args:
            data_tensor: array [T, N, C] whose last channel is streamflow.
            scaler_list: per-channel scalers of length C (None for the streamflow channel).

        Returns:
            (data_tensor [T, N, C + 1], scaler_list of length C + 1, target_idx, qobs_idx):
            target_idx is the streamflow channel (still last); qobs_idx lists both streamflow
            channels, which are scaled per station with the streamflow scalers.
        """
        window = self.args.di_window
        lagged = np.full_like(data_tensor[..., -1:], np.nan)
        lagged[window:] = data_tensor[:-window, :, -1:]
        data_tensor = np.concatenate([
            data_tensor[..., :-1],  # meteorological forcings
            lagged,  # lagged observations
            data_tensor[..., -1:],  # streamflow target (kept last)
        ], axis=-1)
        # The lagged channel reuses the per-station streamflow scalers, so it gets a None placeholder.
        scaler_list = list(scaler_list[:-1]) + [None] + list(scaler_list[-1:])
        target_idx = data_tensor.shape[-1] - 1
        return data_tensor, scaler_list, target_idx, [target_idx - 1, target_idx]

    def __read_all_data(self, n_jobs=8):
        """Read (or load from cache), normalize and fit scalers for all stations.

        Sets data_tensor [T, N, C], static_tensor [N, C_static], basin_norm_data [N, 2]
        (area_gages2, p_mean), df_example, scaler_list, qobs_scaler_list, target_idx, qobs_idx
        and num_train / num_vali / num_test. Row / station i of every per-station array
        corresponds to self.all_stations[i].

        Args:
            n_jobs: joblib workers used when the dynamic-data cache has to be rebuilt.

        Raises:
            ValueError: a station in all_stations has no row in the static attribute files.
        """
        if self.args.verbose:
            print("loading dataset...")

        # ===== global time axis =====
        target_start = pd.Timestamp(self.args.train_start_time)
        target_end = pd.Timestamp(self.args.test_end_time)
        t_index = pd.date_range(start=target_start, end=target_end, freq=self.freq).values.astype("datetime64[D]")

        # ===== static attributes =====
        # Reorder rows to follow all_stations so that row i describes the same station as
        # data_tensor[:, i]; a boolean filter would silently keep the attribute-file order.
        df_attr, _, _ = self.__read_static_data()
        station_keys = [str(s) for s in self.all_stations]
        missing = [s for s in station_keys if s not in df_attr.index]
        if missing:
            raise ValueError(f"Static attributes missing for stations: {missing[:10]}")
        df_attr = df_attr.loc[station_keys]
        static_tensor = df_attr.values.astype(float)
        basin_norm_data = df_attr[['area_gages2', 'p_mean']].values

        # ===== dynamic-data cache =====
        # The file name encodes time range and frequency so a cache built for another split
        # configuration is never reused; the shape check also catches station-set changes.
        cache_tag = f"{target_start:%Y%m%d}_{target_end:%Y%m%d}_{self.freq}"
        cache_path = os.path.join(self.root_path, f"data_tensor_{cache_tag}.npy")
        static_cache_path = os.path.join(self.root_path, f"static_tensor_{cache_tag}.npy")
        expected_shape = (len(t_index), len(self.all_stations))

        data_tensor = None
        if os.path.exists(cache_path):
            cached = np.load(cache_path)
            if cached.shape[:2] == expected_shape:
                if self.args.verbose:
                    print(f"Loading cached data from {cache_path}")
                data_tensor = cached
            elif self.args.verbose:
                print(f"Ignoring stale cache {cache_path}: shape {cached.shape[:2]} != {expected_shape}")

        if data_tensor is None:
            def process_station(station):
                qobs = self.__preprocess_steamflow(station, t_index).reshape(-1, 1)
                feats = self.__preprocess_other_features(station, t_index)
                return np.concatenate([feats, qobs], axis=1)  # [T, C]

            results = Parallel(n_jobs=n_jobs)(
                delayed(process_station)(station) for station in tqdm(self.all_stations)
            )
            data_tensor = np.stack(results, axis=1)  # [T, N, C]
            # The cache holds the un-normalized tensor; normalization below runs on every load.
            np.save(cache_path, data_tensor)
            np.save(static_cache_path, static_tensor)

        # ===== split sizes =====
        t_train_end = pd.Timestamp(self.args.train_end_time)
        t_vali_start = pd.Timestamp(self.args.vali_start_time)
        t_vali_end = pd.Timestamp(self.args.vali_end_time)
        t_test_start = pd.Timestamp(self.args.test_start_time)
        t_test_end = pd.Timestamp(self.args.test_end_time)

        mask_train = (t_index >= target_start) & (t_index <= t_train_end)
        mask_vali = (t_index >= t_vali_start) & (t_index <= t_vali_end)
        mask_test = (t_index >= t_test_start) & (t_index <= t_test_end)

        self.num_train = mask_train.sum()
        self.num_vali = mask_vali.sum()
        self.num_test = mask_test.sum()

        # ===== example frame: timestamps + first station's streamflow column =====
        df_example = pd.DataFrame({"Timestamp": t_index})
        df_example["example_flow"] = data_tensor[:, 0, -1]

        # Basin normalization + transform from the paper "Meta-LSTM in hydrology: Advancing
        # runoff predictions through model-agnostic meta-learning":
        # https://github.com/AXuaner/MetaLSTM/blob/main/hydroDL/data/camels.py
        # 0.0283168 converts ft^3 to m^3; the result is (m^3/day)/(m^3/day).
        area = basin_norm_data[:, 0]
        prep = basin_norm_data[:, 1]
        flow = data_tensor[:, :, -1]
        flow = (flow * 0.0283168 * 3600 * 24) / (
                (area * (10 ** 6)) * (prep * 10 ** (-3))
        )
        data_tensor[:, :, -1] = np.log10(np.sqrt(flow) + 0.1)

        # ===== scalers: fit on the training span only =====
        train_data = data_tensor[:self.num_train]
        C = data_tensor.shape[-1]
        self.target_idx = C - 1  # qobs is the last channel

        # Non-target channels: one scaler per channel, shared by all stations.
        scaler_list = [
            None if i == self.target_idx else self._fit_ignoring_nan(train_data[:, :, i])
            for i in range(C)
        ]
        # Target qobs: one scaler per station.
        qobs_scaler_list = [
            self._fit_ignoring_nan(train_data[:, n, self.target_idx])
            for n in range(train_data.shape[1])
        ]

        # Channels scaled per station with qobs_scaler_list.
        self.qobs_idx = [self.target_idx]
        if self.args.use_di:
            data_tensor, scaler_list, self.target_idx, self.qobs_idx = self._add_di_feature(
                data_tensor, scaler_list)

        self.data_tensor = data_tensor
        self.static_tensor = static_tensor
        self.basin_norm_data = basin_norm_data
        self.df_example = df_example
        self.scaler_list = scaler_list
        self.qobs_scaler_list = qobs_scaler_list

    def _scale_data(self):
        """Return a standardized copy of data_tensor with NaNs replaced by 0.

        Channels in qobs_idx (streamflow and, in DI mode, its lagged copy) are scaled per
        station with qobs_scaler_list[n]; every other channel i is scaled across all stations
        with scaler_list[i]. Missing values stay NaN through scaling and are filled with 0.
        """
        src = self.data_tensor
        data = np.empty_like(src, dtype=float)
        num_stations, num_channels = src.shape[1], src.shape[2]
        for i in range(num_channels):
            if i in self.qobs_idx:
                for n in range(num_stations):
                    data[:, n, i] = self._transform_ignoring_nan(self.qobs_scaler_list[n], src[:, n, i])
            else:
                data[:, :, i] = self._transform_ignoring_nan(self.scaler_list[i], src[:, :, i])
        return np.nan_to_num(data, nan=0.0)

    def _encode_time_stamps(self, timestamps):
        """Encode timestamps as time features according to self.timeenc.

        timeenc == 0: int64 array [T, 4] of [month, day, weekday, hour].
        timeenc == 1: time_features(..., freq=self.freq) transposed to [T, D].

        Raises:
            NotImplementedError: any other timeenc.
        """
        stamps = pd.DatetimeIndex(pd.to_datetime(np.asarray(timestamps)))
        if self.timeenc == 0:
            # Vectorized field access; the former DataFrame.drop(labels, 1) positional-axis
            # call is rejected by pandas >= 2.0.
            return np.column_stack(
                [stamps.month, stamps.day, stamps.weekday, stamps.hour]
            ).astype(np.int64)
        if self.timeenc == 1:
            return time_features(stamps, freq=self.freq).transpose(1, 0)
        raise NotImplementedError

    def get_dataset(self, flag):
        """Return the sampler for one split.

        Args:
            flag: 'train', 'val' or 'test'.

        Returns:
            MiniBatchSampler or FullBatchSampler, depending on batch_flag.

        Raises:
            KeyError: unknown flag.
            NotImplementedError: unknown batch_flag or timeenc.
        """
        type_map = {'train': 0, 'val': 1, 'test': 2}
        self.set_type = type_map[flag]

        # Split borders; val/test start seq_len steps before their first day so the input
        # window may reach back into the preceding split.
        len_df_raw = self.data_tensor.shape[0]
        assert len_df_raw == self.num_train + self.num_vali + self.num_test, \
            f"len_df_raw {len_df_raw} != num_train + num_vali + num_test {self.num_train + self.num_vali + self.num_test}"
        border1s = [0, self.num_train - self.seq_len, len_df_raw - self.num_test - self.seq_len]
        border2s = [self.num_train, self.num_train + self.num_vali, len_df_raw]
        border1, border2 = border1s[self.set_type], border2s[self.set_type]

        data = self._scale_data()
        data_stamp = self._encode_time_stamps(self.df_example['Timestamp'].values[border1:border2])

        data_x = data[border1:border2]
        data_y = data[border1:border2]

        # Position of every step within the cycle, counted from the start of the full series.
        cycle_index = (np.arange(len(data)) % self.cycle)[border1:border2]

        # Per-station metadata handed to the sampler; it must stay aligned with data_x's station axis.
        static_tensor = self.static_tensor
        basin_norm_data = self.basin_norm_data
        qobs_scaler_list = self.qobs_scaler_list
        all_stations = self.all_stations
        if self.args.run_wise == 'station':
            # Station-wise modelling: keep only the modelled stations, and subset the
            # per-station metadata the same way (the samplers index it by position).
            station_indices = np.array([self.all_stations.index(s) for s in self.station_list])
            data_x = data_x[:, station_indices, :]
            data_y = data_y[:, station_indices, :]
            if static_tensor is not None:
                static_tensor = static_tensor[station_indices]
            basin_norm_data = basin_norm_data[station_indices]
            qobs_scaler_list = [qobs_scaler_list[k] for k in station_indices]
            all_stations = list(self.station_list)
        if self.is_GNN:
            # GNN models use the streamflow channel only.
            data_x = data_x[:, :, -1:]
            data_y = data_y[:, :, -1:]
            self.static_tensor = None
            static_tensor = None

        if self.batch_flag == 'mini_batch':
            return MiniBatchSampler(self.args, data_x, data_y, data_stamp,
                                    static_tensor=static_tensor,
                                    basin_norm_data=basin_norm_data,
                                    seq_len=self.seq_len,
                                    label_len=self.label_len,
                                    pred_len=self.pred_len,
                                    station_ids=self.station_ids,
                                    qobs_scaler_list=qobs_scaler_list,
                                    data_flag=flag,
                                    cycle_index=cycle_index)
        elif self.batch_flag == 'full_batch':
            return FullBatchSampler(self.args, data_x, data_y, data_stamp,
                                    static_tensor=static_tensor,
                                    basin_norm_data=basin_norm_data,
                                    seq_len=self.seq_len,
                                    label_len=self.label_len,
                                    pred_len=self.pred_len,
                                    station_ids=self.station_ids,
                                    qobs_scaler_list=qobs_scaler_list,
                                    data_flag=flag,
                                    cycle_index=cycle_index,
                                    target_station=self.station,
                                    all_stations=all_stations)
        else:
            raise NotImplementedError(f"Unknown batch flag {self.batch_flag}")


class MiniBatchSampler(Dataset):
    """Single-station windows: each sample is one (station, start time) pair.

    In 'test' mode ``idx`` enumerates every station and start time in order; in 'train' /
    'val' mode ``idx`` is ignored and a station and a start time are drawn at random.
    """

    def __init__(self, args, data_x, data_y, data_stamp,
                 static_tensor=None, basin_norm_data=None,
                 seq_len=365, label_len=0, pred_len=1,
                 station_ids=None, qobs_scaler_list=None, data_flag='train',
                 cycle_index=None):
        """
        Args:
            args: config; args.batch_size is stored as batch_per_epoch.
            data_x: np.array [T, N, C_in] model inputs.
            data_y: np.array [T, N, C_out] targets.
            data_stamp: np.array [T, D] time features.
            static_tensor: np.array [N, C_static] or None; prepended to every input step.
            basin_norm_data: np.array [N, 2]; column 0 and 1 are the area and mean-precipitation
                factors used to undo the basin normalization in inverse_transform.
            seq_len, label_len, pred_len: input length, decoder warm-up length, horizon.
            station_ids: accepted for interface compatibility; replaced by arange(N).
            qobs_scaler_list: N per-station scalers exposing inverse_transform.
            data_flag: 'train', 'val' or 'test'.
            cycle_index: array [T]; the entry at the first step after the input window is returned.
        """
        super().__init__()
        self.args = args
        self.data_x = torch.tensor(data_x, dtype=torch.float32)
        self.data_y = torch.tensor(data_y, dtype=torch.float32)
        self.data_stamp = torch.tensor(data_stamp, dtype=torch.float32)
        self.static_tensor = None if static_tensor is None else torch.tensor(static_tensor, dtype=torch.float32)
        self.basin_norm_data = basin_norm_data
        self.cycle_index = cycle_index

        self.T, self.N, self.C_in = self.data_x.shape
        _, _, self.C_out = self.data_y.shape
        self.D = self.data_stamp.shape[-1]

        self.seq_len = seq_len
        self.label_len = label_len
        self.pred_len = pred_len
        self.batch_per_epoch = self.args.batch_size
        self.qobs_scaler_list = qobs_scaler_list
        self.data_flag = data_flag  # 'train' / 'val' / 'test'

        # Every start index whose input window plus prediction horizon fits inside T.
        self.time_indices = np.arange(self.T - seq_len - pred_len + 1)
        # Station positions along the N axis (the station_ids argument is not used).
        self.station_ids = np.arange(self.N)

    def __len__(self):
        return len(self.time_indices) * len(self.station_ids)

    def __getitem__(self, idx):
        """Return (seq_x, seq_y, seq_x_mark, seq_y_mark, cycle_index, station, s_begin).

        seq_x is [seq_len, C_static + C_in] (static features first, if any) and seq_y is
        [label_len + pred_len, C_out].
        """
        if self.data_flag == "test":
            # Deterministic order: station-major, then start time.
            n = idx // len(self.time_indices)
            s_begin = self.time_indices[idx % len(self.time_indices)]
        else:
            # Random draws use torch's RNG, which DataLoader reseeds per worker; numpy's
            # global RNG is copied into each worker and would repeat the same samples.
            n = torch.randint(0, len(self.station_ids), (1,)).item()
            s_begin = torch.randint(0, len(self.time_indices), (1,)).item()

        s_end = s_begin + self.seq_len
        r_begin = s_end - self.label_len
        r_end = r_begin + self.label_len + self.pred_len

        seq_x = self.data_x[s_begin:s_end, n, :]  # [L, C_in]
        seq_y = self.data_y[r_begin:r_end, n, :]  # [label_len + pred_len, C_out]

        if self.static_tensor is not None:
            c_feat = self.static_tensor[n].expand(self.seq_len, -1)
            seq_x = torch.cat([c_feat, seq_x], dim=-1)

        seq_x_mark = self.data_stamp[s_begin:s_end]  # [L, D]
        seq_y_mark = self.data_stamp[r_begin:r_end]  # [label_len + pred_len, D]

        cycle_index = self.cycle_index[s_end]

        return seq_x, seq_y, seq_x_mark, seq_y_mark, cycle_index, n, s_begin

    def inverse_transform(self, data, input_station_ids=None):
        """Undo target scaling and basin normalization for a batch of predictions.

        De-normalization follows the paper "Meta-LSTM in hydrology: Advancing runoff
        predictions through model-agnostic meta-learning":
        https://github.com/AXuaner/MetaLSTM/blob/main/hydroDL/data/camels.py

        Args:
            data: array [batch, pred_len, C_out]; the scaler inverse is applied to the last
                channel. The caller's array is not modified.
            input_station_ids: per-sample station position (ints, numpy or 0-d torch values),
                e.g. the ``station`` returned by __getitem__.

        Returns:
            np.ndarray [batch, pred_len, C_out].

        Raises:
            ValueError: input_station_ids is None.
        """
        if input_station_ids is None:
            raise ValueError("input_station_ids is required to select each sample's scaler")
        # Work on a copy: the scaler inverse is written back into the array below.
        data = np.array(data, copy=True)

        out = []
        for b in range(data.shape[0]):
            sid = int(input_station_ids[b])
            data[b, :, -1:] = self.qobs_scaler_list[sid].inverse_transform(data[b, :, -1:])
            temp_area = self.basin_norm_data[sid, 0]
            temp_prep = self.basin_norm_data[sid, 1]

            # Invert log10(sqrt(q) + 0.1), then the basin normalization factor.
            d = np.power(10, data[b]) - 0.1
            d[d < 0] = 0  # set negative as zero
            d = d ** 2
            d = d * ((temp_area * (10 ** 6)) * (temp_prep * 10 ** (-3))) / (0.0283168 * 3600 * 24)
            out.append(d)
        return np.stack(out, axis=0)  # [batch, pred_len, C_out]


class FullBatchSampler(Dataset):
    """All-station windows: each sample is one start time covering every station.

    In 'test' mode ``idx`` selects the start time in order; in 'train' / 'val' mode ``idx``
    is ignored and a start time is drawn at random.
    """

    def __init__(self, args, data_x, data_y, data_stamp,
                 static_tensor=None, basin_norm_data=None,
                 seq_len=365, label_len=0, pred_len=1,
                 station_ids=None, qobs_scaler_list=None, data_flag='train',
                 cycle_index=None, target_station=None, all_stations=None):
        """
        Args:
            args: config; args.batch_size is stored as batch_per_epoch.
            data_x: np.array [T, N, C_in] model inputs.
            data_y: np.array [T, N, C_out] targets.
            data_stamp: np.array [T, D] time features.
            static_tensor: np.array [N, C_static] or None; row n is prepended to station n's inputs.
            basin_norm_data: np.array [N, 2]; column 0 and 1 are the area and mean-precipitation
                factors used to undo the basin normalization in inverse_transform.
            seq_len, label_len, pred_len: input length, decoder warm-up length, horizon.
            station_ids: stored only.
            qobs_scaler_list: per-station scalers exposing inverse_transform.
            data_flag: 'train', 'val' or 'test'.
            cycle_index: array [T]; the entry at the first step after the input window is returned.
            target_station: station whose scaler is used by the non-GNN inverse_transform.
            all_stations: station ids whose positions index qobs_scaler_list / basin_norm_data.
        """
        super().__init__()
        self.args = args
        self.data_x = torch.tensor(data_x, dtype=torch.float32)
        self.data_y = torch.tensor(data_y, dtype=torch.float32)
        self.data_stamp = torch.tensor(data_stamp, dtype=torch.float32)
        self.static_tensor = None if static_tensor is None else torch.tensor(static_tensor, dtype=torch.float32)
        self.basin_norm_data = basin_norm_data
        self.cycle_index = cycle_index

        self.T, self.N, self.C_in = self.data_x.shape
        _, _, self.C_out = self.data_y.shape
        self.D = self.data_stamp.shape[-1]

        self.seq_len = seq_len
        self.label_len = label_len
        self.pred_len = pred_len
        self.batch_per_epoch = self.args.batch_size
        self.station_ids = station_ids
        self.qobs_scaler_list = qobs_scaler_list
        self.data_flag = data_flag  # 'train' / 'val' / 'test'
        self.target_station = target_station
        self.all_stations = all_stations

        # Every start index whose input window plus prediction horizon fits inside T.
        self.time_indices = np.arange(data_x.shape[0] - seq_len - pred_len + 1)

    def __len__(self):
        return len(self.time_indices)

    def __getitem__(self, idx):
        """Return (seq_x, seq_y, seq_x_mark, seq_y_mark, cycle_index, 0, s_begin).

        seq_x is [seq_len, N, C_static + C_in] (static features first, if any) and seq_y is
        [label_len + pred_len, N, C_out]. The sixth element is always 0.
        """
        if self.data_flag == "test":
            s_begin = self.time_indices[idx]
        else:
            # torch's RNG is reseeded per DataLoader worker; numpy's global RNG is copied
            # into each worker and would repeat the same start times.
            s_begin = self.time_indices[torch.randint(0, len(self.time_indices), (1,)).item()]

        s_end = s_begin + self.seq_len
        r_begin = s_end - self.label_len
        r_end = r_begin + self.label_len + self.pred_len

        seq_x = self.data_x[s_begin:s_end, :, :]  # [L, N, C_in]
        seq_y = self.data_y[r_begin:r_end, :, :]  # [label_len + pred_len, N, C_out]

        if self.static_tensor is not None:
            # Row n of static_tensor is paired with station n of the window.
            num_stations = seq_x.shape[1]
            c_feat = self.static_tensor[:num_stations].unsqueeze(0).expand(self.seq_len, -1, -1)
            seq_x = torch.cat([c_feat, seq_x], dim=-1)  # [L, N, C_static + C_in]

        seq_x_mark = self.data_stamp[s_begin:s_end]  # [L, D]
        seq_y_mark = self.data_stamp[r_begin:r_end]  # [label_len + pred_len, D]

        cycle_index = self.cycle_index[s_end]

        return seq_x, seq_y, seq_x_mark, seq_y_mark, cycle_index, 0, s_begin

    @staticmethod
    def _denormalize_flow(values, temp_area, temp_prep):
        """Invert log10(sqrt(q) + 0.1) and the basin normalization for one station's values."""
        d = np.power(10, values) - 0.1
        d[d < 0] = 0  # set negative as zero
        d = d ** 2
        return d * ((temp_area * (10 ** 6)) * (temp_prep * 10 ** (-3))) / (0.0283168 * 3600 * 24)

    def inverse_transform(self, data, input_station_ids=None, is_GNN=False):
        """Undo target scaling and basin normalization.

        De-normalization follows the paper "Meta-LSTM in hydrology: Advancing runoff
        predictions through model-agnostic meta-learning":
        https://github.com/AXuaner/MetaLSTM/blob/main/hydroDL/data/camels.py

        Args:
            data: is_GNN=True: array [batch, pred_len, num_stations], channel c is station c.
                is_GNN=False: array [batch, pred_len, C] for target_station; every value is
                passed through target_station's scaler (C is expected to be 1).
            input_station_ids: unused; kept for interface parity with MiniBatchSampler.
            is_GNN: selects the layout described above.

        Returns:
            np.ndarray with the same [batch, pred_len, ...] shape as ``data``.
        """
        B, H, C = data.shape

        if is_GNN:
            out = []
            for c in range(data.shape[-1]):
                s_data = data[:, :, c].reshape(-1, 1)
                s_data = self.qobs_scaler_list[c].inverse_transform(s_data).squeeze(-1)
                s_data = s_data.reshape(B, H)
                out.append(self._denormalize_flow(
                    s_data, self.basin_norm_data[c, 0], self.basin_norm_data[c, 1]))
            return np.stack(out, axis=-1)  # [batch, pred_len, num_stations]

        # Non-GNN: a single target station's flow.
        sid = self.all_stations.index(self.target_station)
        flat = self.qobs_scaler_list[sid].inverse_transform(data.reshape(-1, 1))
        return self._denormalize_flow(
            flat.reshape(B, H, C), self.basin_norm_data[sid, 0], self.basin_norm_data[sid, 1])
