import os
import re
import logging
import zipfile
from typing import Any, Dict, Optional

import pandas as pd
import numpy as np
import scipy.io as sio
from scipy.signal import resample
from joblib import Parallel, delayed
from sklearn.preprocessing import LabelEncoder

from utils.path_utils import get_directory_path

# Module-level logger used by every function in this file.
logger = logging.getLogger(__name__)
# NOTE: this configures the root logger at INFO level as a side effect of importing the module.
logging.basicConfig(level=logging.INFO)

def split_string(s, depth=2):
    """
    Split ``s`` on underscores into a tuple of exactly ``depth + 1`` fields.

    Missing trailing fields are filled with empty strings and surplus fields are discarded.

    :param s: Underscore-separated name, e.g. ``"BACK_acc_x"``
    :param depth: Index of the last field to keep (the result has ``depth + 1`` entries)
    :return: Tuple of string fields
    """
    width = depth + 1
    fields = s.split("_")[:width]
    fields += [""] * (width - len(fields))
    return tuple(fields)

# ---------------------------
# Helper function: Resample data segment by segment (FFT-based scipy.signal.resample)
# ---------------------------
def _resample_data(
    data: pd.DataFrame,
    original_sample_rate: float,
    target_sample_rate: float,
) -> pd.DataFrame:
    """
    Resample the data using polyphase filtering.

    :param data: DataFrame containing time series data with multi-index label columns (e.g., ("segment", "", ""), etc.)
    :param original_sample_rate: Original sample rate
    :param target_sample_rate: Target sample rate
    :return: DataFrame containing the resampled data
    """
    segments = data[("segment", "", "")].unique()
    resampled_segments = []
    for segment in segments:
        segment_data = data[data[("segment", "", "")] == segment]
        # Extract labels (assumes labels are consistent within a segment)
        segment_labels = segment_data[[("segment", "", ""), ("subject_id", "", ""), ("activity_id", "", "")]].iloc[0]
        # Drop label columns to keep only numerical data for resampling
        segment_values = segment_data.drop(columns=[("segment", "", ""), ("subject_id", "", ""), ("activity_id", "", "")])
        num_samples = int(len(segment_values) * (target_sample_rate / original_sample_rate))
        try:
            resampled_array = resample(segment_values, num_samples)
        except Exception as e:
            logger.error("Error resampling segment %s: %s", segment, e)
            continue
        resampled_df = pd.DataFrame(resampled_array, columns=segment_values.columns)
        # Reattach original labels to the resampled data
        resampled_df[("segment", "", "")] = segment_labels[("segment", "", "")]
        resampled_df[("subject_id", "", "")] = segment_labels[("subject_id", "", "")]
        resampled_df[("activity_id", "", "")] = segment_labels[("activity_id", "", "")]
        resampled_segments.append(resampled_df)
    if resampled_segments:
        return pd.concat(resampled_segments).reset_index(drop=True)
    else:
        logger.warning("No segments were successfully resampled.")
        # Return original data if no segments could be resampled
        return data.copy() # Return a copy to avoid modifying the original DataFrame
# ---------------------------
# Helper function: Resample a single segment (for parallel processing)
# ---------------------------
def _resample_segment(
    segment_id: Any,
    segment_data_group: pd.DataFrame,
    original_sample_rate: float,
    target_sample_rate: float,
) -> Optional[pd.DataFrame]:
    """
    Resamples a single data segment. Designed to be used with parallel processing.

    :param segment_id: The ID of the segment being processed.
    :param segment_data_group: DataFrame containing data for the specific segment.
    :param original_sample_rate: Original sample rate.
    :param target_sample_rate: Target sample rate.
    :return: DataFrame containing the resampled segment data, or None if resampling fails.
    """
    try:
        # Extract labels (assumes labels are consistent within a segment)
        # Use .iloc[0] safely as groupby guarantees at least one row per group
        segment_labels = segment_data_group[[("segment", "", ""), ("subject_id", "", ""), ("activity_id", "", "")]].iloc[0]

        # Drop label columns to keep only numerical data for resampling
        segment_values = segment_data_group.drop(columns=[("segment", "", ""), ("subject_id", "", ""), ("activity_id", "", "")])

        # Calculate the number of samples for the resampled data
        num_samples = int(len(segment_values) * (target_sample_rate / original_sample_rate))

        # Perform resampling
        resampled_array = resample(segment_values, num_samples)

        # Create a new DataFrame for the resampled data
        resampled_df = pd.DataFrame(resampled_array, columns=segment_values.columns)

        # Reattach original labels to the resampled data
        resampled_df[("segment", "", "")] = segment_labels[("segment", "", "")]
        resampled_df[("subject_id", "", "")] = segment_labels[("subject_id", "", "")]
        resampled_df[("activity_id", "", "")] = segment_labels[("activity_id", "", "")]

        return resampled_df
    except Exception as e:
        logger.error("Error resampling segment %s: %s", segment_id, e)
        return None

# ---------------------------
# Helper function: Resample data segment by segment with FFT-based scipy.signal.resample (Parallel Version)
# ---------------------------
def _resample_data_parallel(
    data: pd.DataFrame,
    original_sample_rate: float,
    target_sample_rate: float,
    n_jobs: int = 32,  # Number of joblib workers; pass -1 to use every available core
) -> pd.DataFrame:
    """
    Resample the data segment by segment, running the segments in parallel with joblib.

    Each segment is handed to ``_resample_segment``, which uses ``scipy.signal.resample``
    (Fourier method). Segments that fail are dropped; the remaining ones are concatenated
    and their columns are put back into the order of ``data``.

    :param data: DataFrame containing time series data with multi-index label columns.
    :param original_sample_rate: Original sample rate.
    :param target_sample_rate: Target sample rate.
    :param n_jobs: Number of parallel jobs to run.
    :return: DataFrame containing the resampled data, or a copy of ``data`` when the input
             is not usable or no segment could be resampled.
    """
    if not isinstance(data.columns, pd.MultiIndex):
        logger.error("Input data must have a MultiIndex columns structure for resampling.")
        return data.copy()

    # Every label column is needed: "segment" for grouping and all three inside
    # _resample_segment. Checking only that at least one exists would let a missing
    # "segment" column surface as a KeyError in groupby.
    required_labels = [("segment", "", ""), ("subject_id", "", ""), ("activity_id", "", "")]
    missing_labels = [col for col in required_labels if col not in data.columns]
    if missing_labels:
        logger.error(
            "Required label columns ('segment', 'subject_id', 'activity_id') not found in data. Missing: %s",
            missing_labels,
        )
        return data.copy()

    # Group data by segment
    grouped_data = data.groupby(("segment", "", ""))

    # Use joblib for parallel processing
    # Pass only the necessary data (group name and group DataFrame) to the delayed function
    resampled_segments_list = Parallel(n_jobs=n_jobs)(
        delayed(_resample_segment)(
            segment_id,
            segment_data_group,
            original_sample_rate,
            target_sample_rate,
        )
        for segment_id, segment_data_group in grouped_data
    )

    # Filter out None results (segments that failed resampling)
    successful_segments = [df for df in resampled_segments_list if df is not None]

    if not successful_segments:
        logger.warning("No segments were successfully resampled in parallel.")
        # Return a copy of the original data if no segments could be resampled
        return data.copy()

    # Concatenate the successfully resampled segments
    resampled_data = pd.concat(successful_segments).reset_index(drop=True)
    # Re-apply MultiIndex structure if lost during concat
    if not isinstance(resampled_data.columns, pd.MultiIndex):
        resampled_data.columns = pd.MultiIndex.from_tuples(resampled_data.columns)
    # _resample_segment appends the label columns last; restore the original column order
    try:
        original_columns = data.columns.tolist()
        # Keep the original columns that survived resampling, in their original order
        ordered_columns = [col for col in original_columns if col in resampled_data.columns]
        # Add any new columns created during resampling (shouldn't happen here)
        ordered_columns.extend([col for col in resampled_data.columns if col not in original_columns])
        resampled_data = resampled_data[ordered_columns]
    except Exception as e:
        logger.warning(f"Could not reorder columns perfectly: {e}")

    return resampled_data



# ---------------------------
# Main function: Load dataset
# ---------------------------
def load_data(
    dataset_name: str = "opportunity",
    min_segment_length: int = 30,
    original_sample_rate: Optional[float] = None,
    target_sample_rate: Optional[float] = None,
    **kwargs: Any,
) -> Dict[str, Any]:
    """
    Load the specified dataset and perform resampling and segment filtering if needed.

    The processed result is cached as a pickle under ``cache/<dataset>/processed/``. The cache
    file name encodes both sample rates, ``min_segment_length`` and any loader ``kwargs``.
    The fitted label-encoder classes are stored next to it so that a cache hit returns
    encoders that still map back to the original labels.

    :param dataset_name: Name of the dataset; supports "opportunity", "mhealth", "pamap2", "dsads",
                         "realworld2016", "uschad", "ucihar" (case-insensitive)
    :param min_segment_length: Minimum number of samples per segment; segments that do not meet this requirement will be filtered out
    :param original_sample_rate: Original sample rate (used for resampling)
    :param target_sample_rate: Target sample rate (used for resampling)
    :param kwargs: Additional parameters passed to the specific dataset loading functions
    :return: Dictionary containing:
             "data": Processed DataFrame,
             "activity_encoder": LabelEncoder for activity_id,
             "subject_encoder": LabelEncoder for subject_id
    :raises ValueError: If the dataset_name is not supported
    """
    dataset_name_lower = dataset_name.lower()
    base_path = get_directory_path("Dataset")
    cache_path = get_directory_path("cache")

    # --- Cache Configuration ---
    # Directory layout: cache / dataset_name / processed /
    processed_cache_dir = os.path.join(cache_path, dataset_name_lower, "processed")

    # Resampling only happens when both rates are given and differ
    resampling_needed = (
        target_sample_rate is not None
        and original_sample_rate is not None
        and target_sample_rate != original_sample_rate
    )

    # The cache file name must contain every parameter that influences the final data.
    # 'p' replaces '.' so float sample rates stay file-name friendly.
    orig_sr_str = f"orig{str(original_sample_rate).replace('.', 'p')}" if original_sample_rate is not None else "origNA"
    target_sr_str = f"target{str(target_sample_rate).replace('.', 'p')}" if target_sample_rate is not None else "targetNA"
    min_len_str = f"minlen{min_segment_length}"

    # Loader options (e.g. PAMAP2's dataset_type) change the loaded data, so they belong in
    # the cache key too. Without kwargs the file name is unchanged, keeping old caches valid.
    kwargs_str = ""
    if kwargs:
        raw_kwargs = "_".join(f"{key}-{kwargs[key]}" for key in sorted(kwargs))
        kwargs_str = "_" + re.sub(r"[^A-Za-z0-9_.\-]", "_", raw_kwargs)

    cache_filename = f"data_{orig_sr_str}_{target_sr_str}_{min_len_str}{kwargs_str}.pkl"
    cache_filepath = os.path.join(processed_cache_dir, cache_filename)
    # Sidecar file holding the fitted LabelEncoder classes for this cache entry
    encoder_cache_filepath = os.path.splitext(cache_filepath)[0] + "_encoders.pkl"
    logger.info(f"Target processed data cache file: {cache_filepath}")
    # --- End Cache Configuration ---

    # --- Cache lookup ---
    if os.path.exists(cache_filepath):
        logger.info(f"Found cached processed data. Loading from: {cache_filepath}")
        try:
            # NOTE: pickle is only safe for trusted, locally written cache files.
            processed_data = pd.read_pickle(cache_filepath)
            logger.info("Successfully loaded data from Pickle cache.")

            activity_id_le = LabelEncoder()
            subject_id_le = LabelEncoder()
            if os.path.exists(encoder_cache_filepath):
                # Restore the classes fitted on the raw labels so inverse_transform
                # yields the original activity / subject identifiers.
                encoder_classes = pd.read_pickle(encoder_cache_filepath)
                activity_id_le.classes_ = encoder_classes["activity_classes"]
                subject_id_le.classes_ = encoder_classes["subject_classes"]
                logger.info("Restored label encoders from cache.")
            else:
                # Legacy cache without sidecar: the cached columns are already encoded, so
                # re-fitting only reproduces the integer codes, not the original labels.
                activity_col = ("activity_id", "", "") if isinstance(processed_data.columns, pd.MultiIndex) else "activity_id"
                subject_col = ("subject_id", "", "") if isinstance(processed_data.columns, pd.MultiIndex) else "subject_id"
                activity_id_le.fit(processed_data[activity_col])
                subject_id_le.fit(processed_data[subject_col])
                logger.info("Re-fitted label encoders based on cached data.")

            return {
                "data": processed_data,
                "activity_encoder": activity_id_le,
                "subject_encoder": subject_id_le,
            }
        except Exception as e:
            # A broken cache must not be fatal: fall through to the full pipeline below
            logger.error(f"Failed to load or process data from cache file {cache_filepath}: {e}. Proceeding with full data processing.")
    else:
        logger.info("Processed data cache not found. Starting full data processing...")
    # --- End cache lookup ---

    # --- Cache miss: run the full processing pipeline ---
    dataset_loaders = {
         "opportunity": (_load_opportunity_dataset, os.path.join("OpportunityUCIDataset", "dataset")),
         "mhealth": (_load_mhealth_dataset, "MHEALTHDATASET"),
         "pamap2": (_load_pamap2_dataset, "PAMAP2"),
         "dsads": (_load_dsads_dataset, "DSADS"),
         "realworld2016": (_load_realworld2016_dataset, "realworld2016"),
         "uschad": (_load_uschad_dataset, "USC-HAD"),
         "ucihar": (_load_ucihar_dataset, "UCI HAR Dataset"),
    }
    if dataset_name_lower not in dataset_loaders:
        raise ValueError(f"Unknown dataset name: {dataset_name}")
    loader_func, relative_dataset_path = dataset_loaders[dataset_name_lower]
    dataset_path = os.path.abspath(os.path.join(base_path, relative_dataset_path))

    # 1. Load the raw data (each loader keeps its own parsed-data pickle)
    raw_pickle_dir = os.path.join(cache_path, dataset_name_lower)
    raw_pickle_path = os.path.join(raw_pickle_dir, "data_tmp.pkl")
    try:
        data = loader_func(dataset_path, raw_pickle_path, **kwargs)
        logger.info(f"Loaded raw data for {dataset_name}.")
    except Exception as e:
        logger.error("Failed to load raw dataset %s: %s", dataset_name, e)
        raise

    # 2. Label encoding
    activity_id_le = LabelEncoder()
    activity_col = ("activity_id", "", "") if isinstance(data.columns, pd.MultiIndex) else "activity_id"
    data[activity_col] = activity_id_le.fit_transform(data[activity_col])
    subject_id_le = LabelEncoder()
    subject_col = ("subject_id", "", "") if isinstance(data.columns, pd.MultiIndex) else "subject_id"
    data[subject_col] = subject_id_le.fit_transform(data[subject_col])
    logger.info("Performed label encoding.")

    # 3. Resampling
    if resampling_needed:
        logger.info(f"Resampling data from {original_sample_rate} Hz to {target_sample_rate} Hz...")
        processed_data = _resample_data_parallel(data, original_sample_rate, target_sample_rate)
        logger.info("Resampling complete.")
    else:
        logger.info("Resampling not required or rates are the same.")
        # `data` is local to this call and not used again, so no defensive copy is needed
        processed_data = data

    # 4. Filter by minimum segment length
    if processed_data is not None and min_segment_length > 0:
        logger.info(f"Filtering segments shorter than {min_segment_length} samples...")
        try:
            segment_col = ("segment", "", "") if isinstance(processed_data.columns, pd.MultiIndex) else "segment"
            if segment_col not in processed_data.columns:
                 logger.warning(f"Segment column '{segment_col}' not found. Skipping segment length filtering.")
            else:
                segment_lengths = processed_data.groupby(segment_col).size()
                valid_segments = segment_lengths[segment_lengths >= min_segment_length].index
                initial_rows = len(processed_data)
                processed_data = processed_data[processed_data[segment_col].isin(valid_segments)].reset_index(drop=True)
                final_rows = len(processed_data)
                logger.info(f"Filtering complete. Removed {initial_rows - final_rows} rows from {initial_rows} total.")
        except Exception as e:
            logger.error("Error filtering segments: %s", e)
            logger.warning("Returning data without segment length filtering due to error.")
    elif processed_data is None:
         logger.error("processed_data is None before filtering. This should not happen.")
         return {}

    # --- 5. Save to cache (only reached on a cache miss) ---
    try:
        logger.info(f"Saving processed data to cache: {cache_filepath}")
        os.makedirs(os.path.dirname(cache_filepath), exist_ok=True)
        # Write the encoder classes first so a data cache never pairs with a stale sidecar
        pd.to_pickle(
            {
                "activity_classes": activity_id_le.classes_,
                "subject_classes": subject_id_le.classes_,
            },
            encoder_cache_filepath,
        )
        processed_data.to_pickle(cache_filepath)
        logger.info("Successfully saved processed data to Pickle cache.")
    except Exception as e:
        # Caching is best effort: the processed data is still returned when saving fails
        logger.error(f"Failed to save processed data to cache file {cache_filepath}: {e}")
    # --- End save ---

    # Return the freshly processed data together with the fitted encoders
    return {
        "data": processed_data,
        "activity_encoder": activity_id_le,
        "subject_encoder": subject_id_le,
    }


def _load_opportunity_dataset(dataset_path: str, pickle_path: str) -> pd.DataFrame:
    """
    从指定路径加载Opportunity数据集，返回一个pandas DataFrame对象。该函数会对原始数据进行
    预处理，包括删除不必要的列、插值、填充缺失值、添加subject_id等，以使得数据更易于处理。

    Args:
        dataset_path (str): 数据集文件路径。

    Returns:
        pd.DataFrame: 包含处理后的Opportunity数据集的pandas DataFrame对象。
    """

    file_pattern = re.compile(r"\S*.dat$")
    files = [
        os.path.join(dataset_path, f)
        for f in os.listdir(dataset_path)
        if file_pattern.match(f)
    ]
    # pickle_path = os.path.join(dataset_path, "data_tmp.pkl")

    # label_seq = {
    #     406516: ['Open Door 1', 0],  # 文件中类别对应编号: [ 类别名, 预处理后label ]
    #     406517: ['Open Door 2', 1],
    #     404516: ['Close Door 1', 2],
    #     404517: ['Close Door 2', 3],
    #     406520: ['Open Fridge', 4],
    #     404520: ['Close Fridge', 5],
    #     406505: ['Open Dishwasher', 6],
    #     404505: ['Close Dishwasher', 7],
    #     406519: ['Open Drawer 1', 8],
    #     404519: ['Close Drawer 1', 9],
    #     406511: ['Open Drawer 2', 10],
    #     404511: ['Close Drawer 2', 11],
    #     406508: ['Open Drawer 3', 12],
    #     404508: ['Close Drawer 3', 13],
    #     408512: ['Clean Table', 14],
    #     407521: ['Drink from Cup', 15],
    #     405506: ['Toggle Switch', 16]
    # }

    columns_index = (
        [*range(37, 46)]
        + [*range(50, 59)]
        + [*range(63, 72)]
        + [*range(76, 85)]
        + [*range(89, 98)]
        + [*range(108, 114)]
        + [*range(124, 130)]
        + [249]
    )

    column_names = [
        # 'RKN_up_acc_x', 'RKN_up_acc_y', 'RKN_up_acc_z',
        # 'HIP_acc_x','HIP_acc_y', 'HIP_acc_z',
        # 'LUA_up_acc_x', 'LUA_up_acc_y','LUA_up_acc_z',
        # 'RUA_down_acc_x', 'RUA_down_acc_y', 'RUA_down_acc_z',
        # 'LH_acc_x', 'LH_acc_y', 'LH_acc_z',
        # 'BACK_acc_x', 'BACK_acc_y','BACK_acc_z',
        # 'RKN_down_acc_x', 'RKN_down_acc_y', 'RKN_down_acc_z',
        # 'RWR_acc_x', 'RWR_acc_y', 'RWR_acc_z',
        # 'RUA_up_acc_x', 'RUA_up_acc_y','RUA_up_acc_z',
        # 'LUA_down_acc_x', 'LUA_down_acc_y', 'LUA_down_acc_z',
        # 'LWR_acc_x', 'LWR_acc_y', 'LWR_acc_z',
        # 'RH_acc_x', 'RH_acc_y','RH_acc_z',
        "BACK_acc_x",
        "BACK_acc_y",
        "BACK_acc_z",
        "BACK_gyro_x",
        "BACK_gyro_y",
        "BACK_gyro_z",
        "BACK_magne_x",
        "BACK_magne_y",
        "BACK_magne_z",
        "RUA_acc_x",
        "RUA_acc_y",
        "RUA_acc_z",
        "RUA_gyro_x",
        "RUA_gyro_y",
        "RUA_gyro_z",
        "RUA_magne_x",
        "RUA_magne_y",
        "RUA_magne_z",
        "RLA_acc_x",
        "RLA_acc_y",
        "RLA_acc_z",
        "RLA_gyro_x",
        "RLA_gyro_y",
        "RLA_gyro_z",
        "RLA_magne_x",
        "RLA_magne_y",
        "RLA_magne_z",
        "LUA_acc_x",
        "LUA_acc_y",
        "LUA_acc_z",
        "LUA_gyro_x",
        "LUA_gyro_y",
        "LUA_gyro_z",
        "LUA_magne_x",
        "LUA_magne_y",
        "LUA_magne_z",
        "LLA_acc_x",
        "LLA_acc_y",
        "LLA_acc_z",
        "LLA_gyro_x",
        "LLA_gyro_y",
        "LLA_gyro_z",
        "LLA_magne_x",
        "LLA_magne_y",
        "LLA_magne_z",
        "L-SHOE_acc_x",
        "L-SHOE_acc_y",
        "L-SHOE_acc_z",
        "L-SHOE_gyro_x",
        "L-SHOE_gyro_y",
        "L-SHOE_gyro_z",
        "R-SHOE_acc_x",
        "R-SHOE_acc_y",
        "R-SHOE_acc_z",
        "R-SHOE_gyro_x",
        "R-SHOE_gyro_y",
        "R-SHOE_gyro_z",
    ]
    column_names = [split_string(name) for name in column_names] + [
        ("activity_id", "", "")
    ]
    column_names = pd.MultiIndex.from_tuples(
        column_names, names=["body_part", "sensor_type", "axis"]
    )

    if os.path.exists(pickle_path):
        data = pd.read_pickle(pickle_path)
    else:
        # 定义读取数据的函数
        def _read_data(file):
            data = pd.read_table(file, header=None, sep="\s+", usecols=columns_index)
            data.columns = column_names
            data = data.interpolate().fillna(0)
            data[("subject_id", "", "")] = int(
                re.search(r"S(\d+)-\S*.dat", file).group(1)
            )
            data[("segment", "", "")] = (
                data[("activity_id", "", "")]
                .ne(data[("activity_id", "", "")].shift())
                .cumsum()
            )
            # .where(lambda x: data[('activity_id','','')].ne(0), None))
            return data

        # data = pd.concat(
        #     [pd.read_table(
        #         f, header=None, sep='\s+', usecols=columns_index).
        #      interpolate().fillna(0).assign(
        #          subject_id=int(re.search(r'S(\d+)-\S*.dat', f).group(1)))
        #      for f in files],
        #     ignore_index=True)

        data = pd.concat([_read_data(file) for file in files], ignore_index=True)
        data[("segment", "", "")] = (
            data[("segment", "", "")].ne(data[("segment", "", "")].shift()).cumsum()
        )
        data = data[data[("activity_id", "", "")] != 0].reset_index(drop=True)
        data[("segment", "", "")] = (
            data[("segment", "", "")].ne(data[("segment", "", "")].shift()).cumsum()
        )
        data.to_pickle(pickle_path)

    # data.drop(["L-SHOE", "R-SHOE"], level=0, axis=1, inplace=True)
    return data


def _load_mhealth_dataset(dataset_path: str, pickle_path: str) -> pd.DataFrame:
    """
    Load the MHEALTH dataset into a DataFrame with three-level column labels.

    When ``pickle_path`` exists it is loaded and the raw files are not touched. Otherwise every
    ``mHealth_subject<N>.log`` file in ``dataset_path`` is parsed (in sorted order), missing
    values are interpolated and then zero-filled, rows whose activity label is 0 are removed,
    segments are renumbered, and the result is written to ``pickle_path``. The
    electrocardiogram columns are kept in the cache but dropped from the returned frame.

    :param dataset_path: Directory containing the raw ``.log`` files.
    :param pickle_path: Path of the parsed-data cache (its directory is created if needed).
    :return: DataFrame with sensor columns plus ("activity_id", "", ""),
             ("subject_id", "", "") and ("segment", "", "").
    """
    activity_col = ("activity_id", "", "")
    subject_col = ("subject_id", "", "")
    segment_col = ("segment", "", "")

    if os.path.exists(pickle_path):
        # NOTE: pickle is only safe for trusted, locally written cache files.
        data = pd.read_pickle(pickle_path)
    else:
        # Escaped dot and end anchor: names like "mHealth_subject1.log.bak" are not picked up
        subject_file_pattern = re.compile(r"mHealth_subject(\d+)\.log$")
        # Sorted so that row order and segment numbering do not depend on os.listdir order
        file_names = sorted(
            name for name in os.listdir(dataset_path) if subject_file_pattern.match(name)
        )

        column_names = [
            "chest_acc_x",
            "chest_acc_y",
            "chest_acc_z",
            "electrocardiogram_1",
            "electrocardiogram_2",
            "ankle_acc_x",
            "ankle_acc_y",
            "ankle_acc_z",
            "ankle_gyro_x",
            "ankle_gyro_y",
            "ankle_gyro_z",
            "ankle_magne_x",
            "ankle_magne_y",
            "ankle_magne_z",
            "arm_acc_x",
            "arm_acc_y",
            "arm_acc_z",
            "arm_gyro_x",
            "arm_gyro_y",
            "arm_gyro_z",
            "arm_magne_x",
            "arm_magne_y",
            "arm_magne_z",
        ]
        column_names = [split_string(name) for name in column_names] + [activity_col]
        column_names = pd.MultiIndex.from_tuples(
            column_names, names=["body_part", "sensor_type", "axis"]
        )

        def _read_data(file_name):
            """Parse one subject log and tag it with its subject id and per-file segment ids."""
            frame = pd.read_table(
                os.path.join(dataset_path, file_name), header=None, sep=r"\s+"
            )
            frame.columns = column_names
            frame = frame.interpolate().fillna(0)
            frame[subject_col] = int(subject_file_pattern.match(file_name).group(1))
            # A new segment starts whenever the activity label changes
            frame[segment_col] = frame[activity_col].ne(frame[activity_col].shift()).cumsum()
            return frame

        data = pd.concat([_read_data(name) for name in file_names], ignore_index=True)

        # Drop rows with activity label 0, then renumber segments over the whole frame
        data = data[data[activity_col] != 0].reset_index(drop=True)
        data[segment_col] = data[segment_col].ne(data[segment_col].shift()).cumsum()

        # The cache directory may not exist yet on the first run
        pickle_dir = os.path.dirname(pickle_path)
        if pickle_dir:
            os.makedirs(pickle_dir, exist_ok=True)
        data.to_pickle(pickle_path)

    # The ECG channels are cached but not part of the returned sensor set
    data.drop(["electrocardiogram"], level=0, axis=1, inplace=True)

    return data


def _load_pamap2_dataset(
    dataset_path: str,
    pickle_path: str,
    dataset_type: str = "protocol",
) -> pd.DataFrame:
    """
    Load the PAMAP2 dataset into a DataFrame with three-level column labels.

    The parsed data is cached per dataset type in a sub-directory next to ``pickle_path``.
    On a cache miss the ``subject10<N>.dat`` files of the requested part(s) are parsed,
    orientation, temperature, timestamp and heart-rate columns are dropped, rows whose
    activity label is 0 are removed, and segments are numbered by runs of identical
    (activity, subject).

    :param dataset_path: PAMAP2 root directory containing ``Protocol`` and/or ``Optional``.
    :param pickle_path: Base path of the parsed-data cache; the lower-cased dataset type is
                        inserted as a sub-directory.
    :param dataset_type: "protocol", "optional" or "all" (case-insensitive).
    :return: DataFrame with sensor columns plus ("activity_id", "", ""),
             ("subject_id", "", "") and ("segment", "", "").
    :raises ValueError: If ``dataset_type`` is not one of the supported values.
    """
    dataset_type_lower = dataset_type.lower()

    if dataset_type_lower not in ["protocol", "optional", "all"]:
        raise ValueError(
            "Invalid dataset_type. Choose from 'protocol', 'optional', or 'all'."
        )

    # One cache per dataset type; the normalised name keeps "Protocol" and "protocol"
    # from producing two separate caches.
    pickle_dir = os.path.join(os.path.dirname(pickle_path), dataset_type_lower)
    os.makedirs(pickle_dir, exist_ok=True)
    pickle_path = os.path.join(pickle_dir, os.path.basename(pickle_path))

    if os.path.exists(pickle_path):
        # NOTE: pickle is only safe for trusted, locally written cache files.
        return pd.read_pickle(pickle_path)

    activity_col = ("activity_id", "", "")
    subject_col = ("subject_id", "", "")
    segment_col = ("segment", "", "")

    subject_id_regex = re.compile(r"subject10(\d+)\.dat$")

    # List only the directories that were requested, so "protocol" works without Optional/
    source_dirs = {
        "protocol": ["Protocol"],
        "optional": ["Optional"],
        "all": ["Protocol", "Optional"],
    }[dataset_type_lower]
    list_of_files = []
    for source_dir in source_dirs:
        source_path = os.path.join(dataset_path, source_dir)
        # Skip anything that is not a subject file; sort for a deterministic row order
        list_of_files.extend(
            sorted(
                os.path.join(source_path, name)
                for name in os.listdir(source_path)
                if subject_id_regex.search(name)
            )
        )

    colNames = ["timestamp", "activity_id", "heartrate"]

    IMUhand = [
        "handTemperature",
        "hand_acc_x",
        "hand_acc_y",
        "hand_acc_z",
        "hand_acc-6_x",
        "hand_acc-6_y",
        "hand_acc-6_z",
        "hand_gyro_x",
        "hand_gyro_y",
        "hand_gyro_z",
        "hand_magne_x",
        "hand_magne_y",
        "hand_magne_z",
        "handOrientation_1",
        "handOrientation_2",
        "handOrientation_3",
        "handOrientation_4",
    ]

    IMUchest = [
        "chestTemperature",
        "chest_acc_x",
        "chest_acc_y",
        "chest_acc_z",
        "chest_acc-6_x",
        "chest_acc-6_y",
        "chest_acc-6_z",
        "chest_gyro_x",
        "chest_gyro_y",
        "chest_gyro_z",
        "chest_magne_x",
        "chest_magne_y",
        "chest_magne_z",
        "chestOrientation_1",
        "chestOrientation_2",
        "chestOrientation_3",
        "chestOrientation_4",
    ]

    IMUankle = [
        "ankleTemperature",
        "ankle_acc_x",
        "ankle_acc_y",
        "ankle_acc_z",
        "ankle_acc-6_x",
        "ankle_acc-6_y",
        "ankle_acc-6_z",
        "ankle_gyro_x",
        "ankle_gyro_y",
        "ankle_gyro_z",
        "ankle_magne_x",
        "ankle_magne_y",
        "ankle_magne_z",
        "ankleOrientation_1",
        "ankleOrientation_2",
        "ankleOrientation_3",
        "ankleOrientation_4",
    ]

    column_names = IMUhand + IMUchest + IMUankle
    column_names = [tuple([name, "", ""]) for name in colNames] + [
        split_string(name) for name in column_names
    ]
    column_names = pd.MultiIndex.from_tuples(
        column_names, names=["body_part", "sensor_type", "axis"]
    )

    def _read_data(file):
        """Parse one subject file and tag it with the subject number from its name."""
        frame = pd.read_table(file, header=None, sep=r"\s+")
        frame.columns = column_names
        frame = frame.interpolate().fillna(0)
        frame[subject_col] = int(subject_id_regex.search(file).group(1))
        return frame

    data = pd.concat([_read_data(f) for f in list_of_files], ignore_index=True)

    # Keep only the accelerometer / gyroscope / magnetometer channels and the labels
    data.drop(
        [
            "handOrientation",
            "chestOrientation",
            "ankleOrientation",
            "handTemperature",
            "chestTemperature",
            "ankleTemperature",
            "timestamp",
            "heartrate",
        ],
        axis=1,
        level=0,
        inplace=True,
    )

    # Drop rows with activity label 0
    data = data[data[activity_col] != 0].reset_index(drop=True)
    data = data.apply(pd.to_numeric, errors="coerce").interpolate()

    # Activity 5 of subject 4 is broken (it contains a single value), so it is removed.
    # reset_index yields a fresh frame with a contiguous index for the assignment below.
    broken_rows = (data[subject_col] == 4) & (data[activity_col] == 5)
    data = data.loc[~broken_rows].reset_index(drop=True)

    # A new segment starts whenever the activity or the subject changes
    bool_series = (data[activity_col] != data[activity_col].shift()) | (
        data[subject_col] != data[subject_col].shift()
    )
    data[segment_col] = bool_series.cumsum()

    # Save the processed data to the pickle cache
    data.to_pickle(pickle_path)

    return data


def _load_dsads_dataset(dataset_path: str, pickle_path: str) -> pd.DataFrame:
    """
    读取DSADS数据集，返回DataFrame格式的数据集，数据已经插值和填充
    """
    # pickle_path = os.path.join(dataset_path, "data_tmp.pkl")
    if os.path.exists(pickle_path):
        data = pd.read_pickle(pickle_path)
    else:
        file_pattern = re.compile(r"\S*.txt$")
        list_of_files = [
            os.path.join(root, file)
            for root, _, files in os.walk(dataset_path, topdown=False)
            for file in files
            if file_pattern.match(file)
        ]
        list_of_files = sorted(list_of_files)
        subject_activity_regex = re.compile(
            r"a(?P<activity_id>\d+)(.*)p(?P<subject_id>\d+)\S*$"
        )
        data_list = []

        sensor_types = [
            "acc_x",
            "acc_y",
            "acc_z",
            "gyro_x",
            "gyro_y",
            "gyro_z",
            "magne_x",
            "magne_y",
            "magne_z",
        ]
        units = ["T", "RA", "LA", "RL", "LL"]
        column_names = [f"{unit}_{sensor}" for unit in units for sensor in sensor_types]
        column_names = (
            [split_string(name) for name in column_names]
            + [("activity_id", "", "")]
            + [("subject_id", "", "")]
        )
        column_names = pd.MultiIndex.from_tuples(
            column_names, names=["body_part", "sensor_type", "axis"]
        )

        for file in list_of_files:
            tempdata = pd.read_table(file, header=None, sep=",")

            # 使用apply()和interpolate()插值，填充缺失值
            tempdata = tempdata.apply(lambda col: col.interpolate().fillna(0))

            matches = subject_activity_regex.search(file).groupdict()
            tempdata = tempdata.assign(
                activity_id=int(matches["activity_id"]),
                subject_id=int(matches["subject_id"]),
            )
            data_list.append(tempdata)

        # 使用pd.concat()将多个DataFrame合并为一个
        data = pd.concat(data_list, axis=0, ignore_index=True)
        data.columns = column_names
        # 生成一个布尔值Series
        bool_series = (
            data[("activity_id", "", "")] != data[("activity_id", "", "")].shift()
        ) | (data[("subject_id", "", "")] != data[("subject_id", "", "")].shift())

        # 对布尔值Series进行累加操作，生成一个新的segment列
        data[("segment", "", "")] = bool_series.cumsum()

        # # 使用sort_values()对数据进行排序
        # data['tmp'] = data.index
        # # data.sort_values(by=['subject_id', 'tmp'],
        # #                  inplace=True, ascending=True)

        # # 过滤activity_id不等于0的数据，并重置索引
        # data = data[data['activity_id'] != 0].reset_index(drop=True)
        # data.drop(['tmp'], axis=1, level=0, inplace=True)

        # 保存处理后的数据到pickle文件
        data.to_pickle(pickle_path)
    return data


def _load_realworld2016_dataset(dataset_path: str, pickle_path: str) -> pd.DataFrame:
    """
    Load the RealWorld2016 dataset, using a pickle file as a cache.

    If ``pickle_path`` exists it is read directly. Otherwise every
    ``*_csv.zip`` archive under ``dataset_path`` whose file name contains
    ``acc``, ``gyr`` or ``mag`` is extracted to
    ``<dataset_path>/extract/subject_<id>``, the CSV files of each
    (subject, activity) pair are aligned and merged column-wise, one segment
    id is assigned per pair, and the result is written to ``pickle_path``.

    Subjects and activities are processed in sorted order so that segment ids
    and row order are reproducible between runs.

    :param dataset_path: Root directory that is searched recursively for archives
    :param pickle_path: Pickle cache location (read if present, written otherwise)
    :return: DataFrame with (body_part, sensor_type, axis) columns plus
        ("activity_id", "", ""), ("subject_id", "", "") and ("segment", "", "").
        The "id" and "time" axis columns are stored in the pickle but dropped
        from the returned frame.
    """
    if os.path.exists(pickle_path):
        data = pd.read_pickle(pickle_path)
    else:
        archive_pattern = re.compile(r".*(acc|gyr|mag).*_csv\.zip$")
        archives = sorted(
            os.path.join(root, file)
            for root, _, files in os.walk(dataset_path, topdown=False)
            for file in files
            if archive_pattern.match(file)
        )
        subject_regex = re.compile(r"proband(?P<subject_id>\d+)")
        # Hard-coded exclusion list kept from the original implementation;
        # the reason for skipping these subjects is not visible in this file.
        excluded_subjects = {"2", "4", "6", "7", "14"}
        extract_root = os.path.join(dataset_path, "extract")
        for archive in archives:
            subject_match = subject_regex.search(archive)
            if subject_match is None:
                # Without a "proband<N>" path component the subject is unknown.
                logger.warning("Cannot infer subject from %s; skipping", archive)
                continue
            subject_id = subject_match.group("subject_id")
            if subject_id in excluded_subjects:
                continue
            extract_path = os.path.join(extract_root, f"subject_{subject_id}")
            os.makedirs(extract_path, exist_ok=True)
            with zipfile.ZipFile(archive, "r") as zip_ref:
                zip_ref.extractall(extract_path)

        csv_pattern = re.compile(r".*\.csv$")
        # Expected file name shape: <sensor_type>_<activity>[...]_<body_part>.csv
        name_regex = re.compile(
            r"^(?P<sensor_type>[a-zA-Z]+)_(?P<activity>[a-zA-Z]+).*_(?P<body_part>[a-zA-Z]+)\.csv"
        )

        def is_sensor_csv(file_name):
            # Only files whose name can be parsed are usable further down.
            return bool(csv_pattern.match(file_name) and name_regex.search(file_name))

        # os.listdir order is arbitrary; sort so segment numbering is stable.
        list_of_subject = sorted(
            entry
            for entry in os.listdir(extract_root)
            if os.path.isdir(os.path.join(extract_root, entry))
        )
        # A set removes duplicates; sorting removes the dependence on string
        # hash randomisation that a bare list(set(...)) has.
        activitys = sorted(
            {
                name_regex.search(file).group("activity")
                for _, _, files in os.walk(extract_root, topdown=False)
                for file in files
                if is_sensor_csv(file)
            }
        )

        def merge_data(data_dict):
            # Align the per-sensor frames of one (subject, activity) pair and
            # concatenate them column-wise, truncated to a common length.
            axis_names = ["id", "time", "x", "y", "z"]
            sensor_aliases = {
                "acc": "acc",
                "Gyroscope": "gyro",
                "MagneticField": "magne",
            }
            frames = []
            for key, value in data_dict.items():
                parsed = name_regex.search(key)
                body_part = parsed.group("body_part")
                sensor_type = sensor_aliases[parsed.group("sensor_type")]
                value.columns = pd.MultiIndex.from_tuples(
                    [(body_part, sensor_type, axis) for axis in axis_names],
                    names=["body_part", "sensor_type", "axis"],
                )
                frames.append(value)
            merged = pd.concat(frames, axis=1)
            # Latest first timestamp over all sensors: every sensor has data
            # from this moment on.
            first_time = max(merged.loc[0, (slice(None), slice(None), "time")])
            aligned = []
            for body_part in merged.columns.levels[0]:
                for sensor_type in merged.columns.levels[1]:
                    part = merged.loc[:, (body_part, sensor_type, slice(None))].dropna()
                    part = part.loc[
                        part.loc[:, (body_part, sensor_type, "time")] >= first_time
                    ].reset_index(drop=True)
                    aligned.append(part)
            length = min(part.shape[0] for part in aligned)
            return pd.concat([part.iloc[:length, :] for part in aligned], axis=1)

        segment = 0
        subjects_data = []
        for subject in list_of_subject:
            subject_dir = os.path.join(extract_root, subject)
            subject_files = sorted(
                file for file in os.listdir(subject_dir) if is_sensor_csv(file)
            )
            for activity in activitys:
                activity_datas = {
                    file: pd.read_csv(os.path.join(subject_dir, file))
                    for file in subject_files
                    if activity in file
                }
                if not activity_datas:
                    # Nothing to merge; pd.concat would raise on an empty list.
                    logger.warning(
                        "No CSV files for activity %s of %s; skipping",
                        activity,
                        subject,
                    )
                    continue
                activity_data = merge_data(activity_datas)
                activity_data[("activity_id", "", "")] = activity
                activity_data[("subject_id", "", "")] = subject
                activity_data[("segment", "", "")] = segment
                segment += 1
                subjects_data.append(activity_data)
        data = pd.concat(subjects_data, axis=0)
        data.to_pickle(pickle_path)
    # The cache keeps "id"/"time"; callers only receive the signal columns.
    data.drop(["id", "time"], axis=1, level=2, inplace=True)
    return data


def _load_uschad_dataset(dataset_path: str, pickle_path: str) -> pd.DataFrame:
    """
    Load the USC-HAD dataset, using a pickle file as a cache.

    If ``pickle_path`` exists it is returned as-is. Otherwise every ``.mat``
    file below ``dataset_path`` is read in sorted path order, its
    ``sensor_readings`` matrix becomes one segment (numbered from 1), and the
    subject and activity ids are parsed from the file path.

    :param dataset_path: Root directory that is searched recursively for .mat files
    :param pickle_path: Pickle cache location (read if present, written otherwise)
    :return: DataFrame with six "front-right-hip" signal columns followed by
        ("activity_id", "", ""), ("subject_id", "", "") and ("segment", "", "")
    """
    # Serve the cached frame when one has already been written.
    if os.path.exists(pickle_path):
        return pd.read_pickle(pickle_path)

    mat_pattern = re.compile(r".*\.mat$")
    mat_files = []
    for root, _, files in os.walk(dataset_path, topdown=False):
        for file_name in files:
            if mat_pattern.match(file_name):
                mat_files.append(os.path.join(root, file_name))
    mat_files.sort()

    id_regex = re.compile(r"Subject(?P<subject_id>\d+).*a(?P<activity_id>\d+)t")

    signal_names = [
        "front-right-hip_acc_x",
        "front-right-hip_acc_y",
        "front-right-hip_acc_z",
        "front-right-hip_gyro_x",
        "front-right-hip_gyro_y",
        "front-right-hip_gyro_z",
    ]
    label_columns = [
        ("activity_id", "", ""),
        ("subject_id", "", ""),
        ("segment", "", ""),
    ]
    # Nine names in total, so each sensor_readings matrix must provide
    # exactly six columns for the header assignment below to succeed.
    header = pd.MultiIndex.from_tuples(
        [split_string(name) for name in signal_names] + label_columns,
        names=["body_part", "sensor_type", "axis"],
    )

    frames = []
    for segment_id, mat_file in enumerate(mat_files, start=1):
        readings = pd.DataFrame(sio.loadmat(mat_file)["sensor_readings"])
        ids = id_regex.search(mat_file).groupdict()
        # Label columns are appended in the same order as in the header.
        readings["activity_id"] = int(ids["activity_id"])
        readings["subject_id"] = int(ids["subject_id"])
        readings["segment"] = segment_id
        frames.append(readings)

    # Stack all files into one frame with a fresh 0..n-1 index.
    data = pd.concat(frames, axis=0, ignore_index=True)
    data.columns = header

    data.to_pickle(pickle_path)
    return data


def _load_ucihar_dataset(dataset_path: str, pickle_path: str) -> pd.DataFrame:
    """
    Load the UCI HAR inertial signals, using a pickle file as a cache.

    If ``pickle_path`` exists it is returned as-is. Otherwise, for the
    "train" and "test" splits, the ``total_acc`` and ``body_gyro`` x/y/z files
    under ``<split>/Inertial Signals`` are read and flattened, the first 64
    samples of every 128-sample chunk are kept, the per-window activity and
    subject labels are repeated to match, and consecutive rows with the same
    (activity, subject) pair receive a common segment id.

    :param dataset_path: Dataset root containing the "train" and "test" folders
    :param pickle_path: Pickle cache location (read if present, written otherwise)
    :return: DataFrame with ("Waist", "acc"/"gyro", "x"/"y"/"z") columns plus
        ("activity_id", "", ""), ("subject_id", "", "") and ("segment", "", "")
    :raises RuntimeError: If one of the inertial signal files is missing
    :raises ValueError: If a split's label files do not hold one row per window
    """
    if os.path.exists(pickle_path):
        return pd.read_pickle(pickle_path)

    # The flattened signals are processed in chunks of WINDOW_LENGTH samples
    # and only the first KEPT_PER_WINDOW samples of each chunk are kept; per
    # the original author's note this removes the overlap between the
    # dataset's sliding windows.
    WINDOW_LENGTH = 128
    KEPT_PER_WINDOW = 64
    # Signal-file prefix -> (body_part, sensor_type) used in the column names.
    SENSOR_MAP = {
        "total_acc": ("Waist", "acc"),
        "body_gyro": ("Waist", "gyro"),
    }
    LEVEL_NAMES = ["body_part", "sensor_type", "axis"]

    def read_table(file_path):
        return pd.read_csv(file_path, header=None, sep=r"\s+")

    def load_signals(split_type):
        # One flattened column per (sensor, axis) file of the split.
        columns = {}
        for sensor, (location, sensor_type) in SENSOR_MAP.items():
            for axis in ("x", "y", "z"):
                file_path = os.path.join(
                    dataset_path,
                    f"{split_type}/Inertial Signals/{sensor}_{axis}_{split_type}.txt",
                )
                try:
                    values = read_table(file_path).values.flatten()
                except FileNotFoundError as err:
                    raise RuntimeError(f"Missing required file: {file_path}") from err
                columns[(location, sensor_type, axis)] = values
        return pd.DataFrame(columns)

    def drop_overlap(df):
        # Boolean mask instead of concatenating one small slice per window.
        usable = (len(df) // WINDOW_LENGTH) * WINDOW_LENGTH
        position = np.arange(len(df))
        keep = (position < usable) & (position % WINDOW_LENGTH < KEPT_PER_WINDOW)
        return df.loc[keep].reset_index(drop=True)

    def load_labels(split_type, n_windows):
        base = os.path.join(dataset_path, split_type)
        # .iloc[:, 0] always yields a Series; .squeeze() would collapse a
        # one-row file to a scalar and break the repeat/reset_index below.
        y = read_table(f"{base}/y_{split_type}.txt").iloc[:, 0]
        subject = read_table(f"{base}/subject_{split_type}.txt").iloc[:, 0]
        if len(y) != n_windows or len(subject) != n_windows:
            # Misaligned labels would otherwise be padded with NaN silently.
            raise ValueError(
                f"Label count mismatch for split '{split_type}': "
                f"{n_windows} windows, {len(y)} activity labels, "
                f"{len(subject)} subject labels"
            )
        return (
            y.repeat(KEPT_PER_WINDOW).reset_index(drop=True),
            subject.repeat(KEPT_PER_WINDOW).reset_index(drop=True),
        )

    def load_split(split_type):
        df = drop_overlap(load_signals(split_type))
        y, subject = load_labels(split_type, len(df) // KEPT_PER_WINDOW)
        df[("activity_id", "", "")] = y
        df[("subject_id", "", "")] = subject
        return df

    def add_segment_column(df, start=0):
        # A new segment starts whenever the activity or the subject changes.
        activity = df[("activity_id", "", "")]
        subject = df[("subject_id", "", "")]
        changed = activity.ne(activity.shift()) | subject.ne(subject.shift())
        df[("segment", "", "")] = changed.cumsum() + start
        return df

    train_data = add_segment_column(load_split("train"))
    # The offset is kept exactly as before (train max + 1 added to ids that
    # start at 1), so previously generated segment ids stay reproducible.
    test_data = add_segment_column(
        load_split("test"), start=train_data[("segment", "", "")].max() + 1
    )

    # Merge both splits and persist the result.
    combined_data = pd.concat([train_data, test_data], ignore_index=True)
    # Name the column levels like the other loaders in this file do.
    combined_data.columns = pd.MultiIndex.from_tuples(
        list(combined_data.columns), names=LEVEL_NAMES
    )
    combined_data.to_pickle(pickle_path)

    return combined_data

if __name__ == "__main__":
    # Manual smoke run: load the UCIHAR dataset with identical original and
    # target sample rates and a minimum segment length of 1.
    import os
    import sys
    from pathlib import Path

    # Put the parent of this file's directory at the front of sys.path.
    this_dir = Path(os.path.dirname(__file__))
    parent_dir = os.path.abspath(this_dir / "..")
    sys.path.insert(0, parent_dir)

    ucihar_options = {
        "min_segment_length": 1,
        "original_sample_rate": 50,
        "target_sample_rate": 50,
    }
    ucihar_data = load_data("UCIHAR", **ucihar_options)