import torch
from torch.utils.data import Dataset

class UnsupervisedDatasets(Dataset):
    """
    Sliding-window dataset over per-segment, per-dimension time series.

    Every item is one window of ``window_size`` consecutive rows taken from a
    single signal column ("dimension") inside a single segment. Windows never
    mix rows of different segments, and each dimension is served independently.
    """

    # Column holding the segment label of every row.
    _SEGMENT_COLUMN = ("segment", "", "")
    # Bookkeeping columns that are never served as signal dimensions.
    _META_COLUMNS = (
        ("segment", "", ""),
        ("activity_id", "", ""),
        ("subject_id", "", ""),
    )

    def __init__(self, data, window_size: int, stride: int):
        """
        Initialize the dataset
        :param data: DataFrame containing time series data; it must provide a
            ``("segment", "", "")`` column, and every column other than the
            segment / activity_id / subject_id columns is treated as a signal
            dimension and converted to float32
        :param window_size: Size of the sliding window (number of rows, >= 1)
        :param stride: Stride of the sliding window (number of rows, >= 1)
        :raises ValueError: if ``window_size`` or ``stride`` is smaller than 1
        """
        super().__init__()
        # A zero stride used to surface as a ZeroDivisionError deep inside the
        # index computation, and non-positive sizes produced empty windows;
        # reject them up front with a clear message instead.
        if window_size < 1:
            raise ValueError(f"window_size must be >= 1, got {window_size!r}")
        if stride < 1:
            raise ValueError(f"stride must be >= 1, got {stride!r}")

        self.window_size = window_size
        self.stride = stride
        self.segments = data[self._SEGMENT_COLUMN].unique()  # Get all unique segments
        self.dimensions = [col for col in data.columns if col not in self._META_COLUMNS]

        # Reset the index so that row labels coincide with the positional
        # offsets used to slice the tensors below.
        data = data.reset_index(drop=True)

        # Pre-convert all data into tensors by dimensions
        self.tensor_data = {
            dim: torch.tensor(data[dim].values, dtype=torch.float32)
            for dim in self.dimensions
        }

        self.indices = self._prepare_indices(data)

    @staticmethod
    def _contiguous_runs(rows):
        """
        Split sorted row positions into runs of consecutive positions
        :param rows: 1-D integer array of ascending row positions
        :return: List of ``(first_row, run_length)`` tuples, one per run
        """
        if len(rows) == 0:
            return []
        # A new run starts wherever two neighbouring positions are not adjacent.
        breaks = ((rows[1:] - rows[:-1]) != 1).nonzero()[0] + 1
        bounds = [0, *breaks.tolist(), len(rows)]
        return [(int(rows[lo]), hi - lo) for lo, hi in zip(bounds, bounds[1:])]

    def _prepare_indices(self, data):
        """
        Prepare indices for all sliding windows, treating each dimension independently
        :param data: DataFrame whose rows line up positionally with ``tensor_data``
        :return: List of ``(start_idx, end_idx, dimension)`` tuples with
            inclusive row bounds, ordered by segment, then dimension, then
            window position
        """
        segment_labels = data[self._SEGMENT_COLUMN]
        indices = []
        for segment in self.segments:
            rows = (segment_labels == segment).to_numpy().nonzero()[0]
            # Windows are laid out inside each contiguous run of the segment.
            # If the segment's rows are interrupted by rows of another segment,
            # no window may bridge the gap. Runs shorter than the window yield
            # an empty range and therefore contribute nothing.
            windows = [
                (start, start + self.window_size - 1)
                for first_row, run_length in self._contiguous_runs(rows)
                for start in range(
                    first_row,
                    first_row + run_length - self.window_size + 1,
                    self.stride,
                )
            ]
            indices.extend(
                (start, end, dimension)
                for dimension in self.dimensions
                for start, end in windows
            )
        return indices

    def __len__(self):
        """Return the total number of (window, dimension) samples."""
        return len(self.indices)

    def __getitem__(self, idx):
        """
        Fetch one window
        :param idx: Position of the sample in ``self.indices``
        :return: float32 tensor of shape ``(1, window_size)``
        """
        start_idx, end_idx, dimension = self.indices[idx]
        # end_idx is inclusive, hence the +1 on the slice bound.
        x_tensor = self.tensor_data[dimension][start_idx:end_idx + 1]
        # Add a leading axis of size one in front of the time axis.
        return x_tensor.unsqueeze(0)


if __name__ == "__main__":
    # Manual smoke run: load a dataset, build the windowed dataset and wrap it
    # in a DataLoader.
    import os
    import sys
    from pathlib import Path

    # Prepend the parent of this file's directory to the module search path so
    # the `src.*` import below can be resolved (presumably that parent is the
    # project root -- confirm against the repository layout).
    this_dir = Path(os.path.dirname(__file__))
    sys.path.insert(0, os.path.abspath(this_dir / ".."))

    from src.data.data_import import load_data
    from torch.utils.data import DataLoader

    # Only the first two return values are used here; the second one exposes
    # `classes_` (looks like a fitted label encoder -- verify in load_data).
    data, encoder, _unused = load_data("pamap2")
    print(len(encoder.classes_))

    dataset = UnsupervisedDatasets(data, window_size=50, stride=25)

    # DataLoader configuration
    dataloader = DataLoader(
        dataset,
        batch_size=32,  # number of samples per batch
        shuffle=True,  # reshuffle the data at the start of every epoch
        num_workers=4,  # number of worker subprocesses used for loading
    )
