import os
import math
import random
import logging
import copy
from typing import List, Dict, Tuple, Optional, Union

import joblib
import lightning as l
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, Sampler, ConcatDataset, random_split, Subset
from collections import defaultdict
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import StratifiedShuffleSplit

from data.data_import import load_data
from utils.path_utils import get_directory_path
from utils.sensor_config_loader import SensorConfigLoader
from utils.global_label_converter import GlobalLabelConverter

# Configure root logging at import time so the INFO-level progress messages
# emitted by the classes below are visible. basicConfig() is a no-op when the
# root logger already has handlers, so an application-level setup takes priority.
logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
# Module-level logger used by the dataset, sampler and data-module classes in this file.
logger = logging.getLogger(__name__)

# --- 1. Base Class: Handles Data Loading and Window Indexing ---

class BaseWindowDataset(Dataset):
    """
    Base dataset class for time-series data.

    Handles loading data from a specified source, validating its structure,
    and preparing indices for fixed-size sliding windows. Subclasses should
    override __getitem__ to define how raw window data is processed.
    """
    def __init__(self, config):
        """
        Initializes the dataset by loading data and preparing window indices.

        Args:
            config: Configuration object containing dataset parameters like:
                - dataset_name (str): Identifier for the dataset to load.
                - window_size (int): The number of time steps in each window.
                - stride (int): The step size for the sliding window.
                - min_segment_length (int): Minimum length required for a data segment.
                - original_sample_rate (float): Original sample rate of the data.
                - target_sample_rate (float): Target sample rate for resampling.

        Raises:
            ValueError: If window_size or stride is smaller than 1, or if the
                loaded data is empty or lacks the required metadata columns.
        """
        super().__init__()

        # Check for the expected attributes *before* reading any of them, so the
        # warning is still emitted when one of the accesses below raises.
        expected_attrs = ('dataset_name', 'window_size', 'stride', 'min_segment_length',
                          'original_sample_rate', 'target_sample_rate')
        if not all(hasattr(config, attr) for attr in expected_attrs):
            logger.warning("Config object might be missing expected attributes for BaseWindowDataset.")

        self.config = config
        self.dataset_name = config.dataset_name
        self.window_size = config.window_size
        self.stride = config.stride

        # Fail fast on unusable windowing parameters, before the data load.
        self._validate_window_params()

        logger.info("[%s] Loading data...", self.dataset_name)
        dataset_bundle = self._load_data()
        self.data: pd.DataFrame = dataset_bundle["data"]
        self.activity_label_encoder = dataset_bundle["activity_encoder"]
        self.subject_label_encoder = dataset_bundle["subject_encoder"]

        self._validate_data()

        logger.info("[%s] Preparing window indices...", self.dataset_name)
        self.segments = self.data[("segment", "", "")].unique()
        self.indices: list[tuple[int, int]] = self._prepare_indices()
        logger.info("[%s] Found %d windows.", self.dataset_name, len(self.indices))

        if not self.indices:
            logger.warning("[%s] No valid windows found with window_size=%d and stride=%d.",
                           self.dataset_name, self.window_size, self.stride)

    def _validate_window_params(self):
        """
        Ensures window_size and stride can drive the sliding-window arithmetic.

        Raises:
            ValueError: If window_size or stride is smaller than 1. (A zero stride
                would otherwise surface as a ZeroDivisionError, and a negative
                one would silently produce wrong window counts.)
        """
        if self.window_size < 1:
            raise ValueError(f"[{self.dataset_name}] window_size must be >= 1, got {self.window_size}.")
        if self.stride < 1:
            raise ValueError(f"[{self.dataset_name}] stride must be >= 1, got {self.stride}.")

    def _load_data(self):
        """
        Loads the dataset using the configured parameters.

        Returns:
            The bundle returned by load_data(); __init__ reads its "data",
            "activity_encoder" and "subject_encoder" entries.

        Raises:
            Exception: Any error raised by load_data() is logged and re-raised unchanged.
        """
        try:
            return load_data(
                dataset_name=self.config.dataset_name,
                min_segment_length=self.config.min_segment_length,
                original_sample_rate=self.config.original_sample_rate,
                target_sample_rate=self.config.target_sample_rate,
            )
        except Exception as e:
            logger.error("[%s] Failed to load data: %s", self.dataset_name, e, exc_info=True)
            raise

    def _validate_data(self):
        """
        Performs basic validation on the loaded DataFrame.

        Raises:
            ValueError: If the DataFrame is None/empty or is missing any of the
                "segment", "activity_id" or "subject_id" metadata columns.
        """
        if self.data is None or self.data.empty:
            raise ValueError(f"[{self.dataset_name}] Data loading returned an empty or None DataFrame.")

        required_meta_cols = [("segment", "", ""), ("activity_id", "", ""), ("subject_id", "", "")]
        missing_cols = [col for col in required_meta_cols if col not in self.data.columns]
        if missing_cols:
            logger.error("[%s] Loaded data is missing required metadata columns: %s", self.dataset_name, missing_cols)
            logger.error("Available columns: %s", self.data.columns.tolist())
            raise ValueError(f"[{self.dataset_name}] Missing required metadata columns.")

    def _prepare_indices(self) -> list[tuple[int, int]]:
        """
        Generates a list of (start_index, end_index) tuples for valid windows.

        Windows never cross segment boundaries: for every segment id in
        self.segments, windows of window_size rows are laid out every stride
        rows, and segments shorter than window_size contribute nothing. Both
        tuple members are labels of self.data.index (end label inclusive).

        Returns:
            The window index pairs, in segment order and then in row order.

        Raises:
            ValueError: If window_size or stride is smaller than 1.
        """
        # Re-checked here because the attributes can be changed after __init__.
        self._validate_window_params()

        indices = []
        segment_col = ("segment", "", "")
        # Fetch the segment column and the row index once; only row positions are
        # needed, so no per-segment DataFrame copy is materialised.
        segment_series = self.data[segment_col]
        row_index = self.data.index
        last_offset = self.window_size - 1

        for segment_id in self.segments:
            # Missing values in the comparison count as "not in this segment".
            in_segment = (segment_series == segment_id).to_numpy(dtype=bool, na_value=False)
            positions = np.flatnonzero(in_segment)
            segment_len = len(positions)
            if segment_len < self.window_size:
                continue

            segment_indices = row_index[positions]
            # With stride >= 1 the last window ends at or before the last row of
            # the segment, so no extra bounds check is needed inside the loop.
            num_windows = 1 + (segment_len - self.window_size) // self.stride
            for i in range(num_windows):
                start_pos = i * self.stride
                indices.append((segment_indices[start_pos], segment_indices[start_pos + last_offset]))
        return indices

    def __len__(self) -> int:
        """Returns the total number of windows in the dataset."""
        return len(self.indices)

    def __getitem__(self, idx: int) -> tuple[pd.DataFrame, int, int]:
        """
        Retrieves raw data for a specific window index.

        Args:
            idx: Position of the window in self.indices (0 <= idx < len(self)).

        Returns:
            A tuple (x_df, activity_id, subject_id) where x_df holds the window's
            rows with the metadata columns dropped, and both labels are read
            from the first row of the window and cast to int.

        Raises:
            IndexError: If idx is outside [0, len(self)).
            KeyError: If the window's row labels or label columns cannot be found.
        """
        if not 0 <= idx < len(self.indices):
            raise IndexError(f"Index {idx} out of bounds for dataset with length {len(self.indices)}")

        start_idx, end_idx = self.indices[idx]
        meta_cols_to_drop = [("segment", "", ""), ("activity_id", "", ""), ("subject_id", "", "")]

        try:
            # Label-based slice, inclusive on both ends; it only spans exactly
            # window_size rows when the row index is ordered as expected.
            window_slice = self.data.loc[start_idx:end_idx]

            if len(window_slice) != self.window_size:
                logger.warning("[%s] Window slice at index %d (%s:%s) has unexpected length %d, expected %d. Check index monotonicity.",
                               self.dataset_name, idx, start_idx, end_idx, len(window_slice), self.window_size)

            x_df = window_slice.drop(columns=meta_cols_to_drop, errors='ignore')
            activity_id = self.data.at[start_idx, ("activity_id", "", "")]
            subject_id = self.data.at[start_idx, ("subject_id", "", "")]

        except KeyError as e:
            logger.error("[%s] KeyError accessing data/labels for window index %d (%s:%s): %s",
                         self.dataset_name, idx, start_idx, end_idx, e)
            raise
        except Exception as e:
            logger.error("[%s] Unexpected error getting item at index %d (%s:%s): %s",
                         self.dataset_name, idx, start_idx, end_idx, e, exc_info=True)
            raise

        return x_df, int(activity_id), int(subject_id)


# --- 2. Standard Dataset: Converts Base data to Tensors ---

class WindowedTensorDataset(BaseWindowDataset):
    """
    Tensor-producing variant of BaseWindowDataset.

    Loading and window indexing are inherited unchanged; only item retrieval
    differs, turning each window and its two labels into PyTorch tensors.
    """
    def __init__(self, config):
        """
        Builds the dataset through the base class and logs completion.

        Args:
            config: Configuration object forwarded as-is to BaseWindowDataset.
        """
        super().__init__(config)
        logger.info("[%s] Initialized WindowedTensorDataset.", self.dataset_name)

    def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Retrieves a window as Tensors.

        Args:
            idx: Position of the window in the dataset.

        Returns:
            (x_tensor, activity_label, subject_label): the window values as a
            float32 tensor and the two labels as int64 scalar tensors.

        Raises:
            Exception: Conversion failures are logged and re-raised unchanged.
        """
        window_df, raw_activity, raw_subject = super().__getitem__(idx)

        try:
            features = torch.from_numpy(window_df.to_numpy(dtype=np.float32))
            activity_label, subject_label = (
                torch.tensor(raw_label, dtype=torch.int64)
                for raw_label in (raw_activity, raw_subject)
            )
        except Exception as err:
            logger.error("[%s] Error converting data/labels to tensor at index %d: %s",
                         self.dataset_name, idx, err, exc_info=True)
            raise

        return features, activity_label, subject_label


# --- 3. Specialized Dataset: Adds Sensor Embeddings and Filtering ---

class MoVQFormerDataset(WindowedTensorDataset):
    """
    A specialized dataset for the MoVQFormer model.

    On top of the tensor windows from WindowedTensorDataset it loads the
    dataset's sensor embeddings, optionally keeps only windows whose activity
    label is allowed, and returns the embeddings alongside every sample.
    """
    def __init__(self, config, allowed_activity_labels: list[int] | None = None):
        """
        Initializes the dataset, loads data, applies filtering, and loads embeddings.

        Args:
            config: Configuration object forwarded to WindowedTensorDataset.
            allowed_activity_labels: Activity ids to keep; None disables filtering.
        """
        super().__init__(config)
        self.allowed_activity_labels = allowed_activity_labels

        logger.info("[%s] Loading sensor embeddings...", self.dataset_name)
        self.sensor_embeddings: torch.Tensor = self._get_sensor_embeddings()

        if self.allowed_activity_labels is not None:
            logger.info("[%s] Filtering indices by allowed activity labels: %s",
                        self.dataset_name, self.allowed_activity_labels)
            count_before = len(self.indices)
            self.indices = self._filter_indices_by_activity()
            logger.info("[%s] Filtered indices: %d -> %d",
                        self.dataset_name, count_before, len(self.indices))
            if not self.indices:
                logger.warning("[%s] No windows remaining after filtering by activity labels.", self.dataset_name)

        logger.info("[%s] Initialized MoVQFormerDataset with %d samples.", self.dataset_name, len(self))

    def _get_sensor_embeddings(self) -> torch.Tensor:
        """
        Loads sensor embeddings for the specific dataset.

        Returns:
            The entry of SensorConfigLoader().embedded_vector for this dataset
            name, as a float32 tensor.

        Raises:
            KeyError: If no entry exists for the dataset name (logged, re-raised).
            Exception: Any other failure is logged and re-raised unchanged.
        """
        try:
            raw = SensorConfigLoader().embedded_vector[self.config.dataset_name]
            # Existing tensors are only cast; anything else is converted.
            as_tensor = raw.float() if isinstance(raw, torch.Tensor) else torch.tensor(raw, dtype=torch.float32)
            logger.info("[%s] Loaded sensor embeddings with shape: %s", self.dataset_name, as_tensor.shape)
            return as_tensor
        except KeyError:
            logger.error("[%s] Sensor embeddings not found for dataset name '%s' in SensorConfigLoader.",
                         self.dataset_name, self.config.dataset_name)
            raise
        except Exception as err:
            logger.error("[%s] Failed to load or process sensor embeddings: %s", self.dataset_name, err, exc_info=True)
            raise

    def _filter_indices_by_activity(self) -> list[tuple[int, int]]:
        """
        Filters the pre-calculated window indices based on allowed activity labels.

        A window is kept when the activity id stored on its first row is in
        allowed_activity_labels. Windows whose label cannot be read are skipped
        with a warning rather than aborting the filter.

        Returns:
            The surviving (start, end) pairs, or self.indices itself when no
            filter is configured.
        """
        if self.allowed_activity_labels is None:
            return self.indices

        label_col = ("activity_id", "", "")
        permitted = set(self.allowed_activity_labels)
        kept = []

        for start_idx, end_idx in self.indices:
            try:
                label = int(self.data.at[start_idx, label_col])
            except KeyError:
                logger.warning("[%s] Could not find activity label at index %s during filtering. Skipping index.",
                               self.dataset_name, start_idx)
                continue
            except Exception as err:
                logger.warning("[%s] Error checking activity label at index %s during filtering: %s. Skipping index.",
                               self.dataset_name, start_idx, err)
                continue

            if label in permitted:
                kept.append((start_idx, end_idx))

        return kept

    def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Retrieves a processed window and sensor embeddings.

        Args:
            idx: Position of the window in the (possibly filtered) index list.

        Returns:
            (x_tensor, sensor_embeddings, activity_label, subject_label); the
            embeddings tensor is the same object for every sample.

        Raises:
            IndexError: If idx is outside [0, len(self)).
        """
        total = len(self.indices)
        if not (0 <= idx < total):
            raise IndexError(f"Index {idx} out of bounds for filtered dataset with length {total}")

        window, activity, subject = super().__getitem__(idx)
        return window, self.sensor_embeddings, activity, subject
    
class CommonDataModule(l.LightningDataModule):
    """
    LightningDataModule that loads one dataset and randomly splits it into
    train/val/test subsets according to config.split_ratio.
    """
    def __init__(self, config):
        """
        Stores the config and declares the attributes filled in by setup().

        Args:
            config: Configuration object; split_ratio, batch_size and
                num_workers are read here, so a missing one fails immediately.
        """
        super().__init__()
        self.config = config
        # Self-assignments kept from the original: they touch the attributes
        # this module relies on, so a missing one raises at construction time.
        self.config.split_ratio = config.split_ratio
        self.config.batch_size = config.batch_size
        self.config.num_workers = config.num_workers

        self.data = None
        self.activity_label_encoder = None
        self.subject_label_encoder = None
        self.dataset = None
        self.train_dataset = None
        self.val_dataset = None
        self.test_dataset = None

    def prepare_data(self):
        """No download or one-off preparation is performed."""
        pass

    def setup(self, stage=None):
        """
        Loads the data and creates the train/val/test split exactly once.

        Lightning invokes setup() again for later stages. Re-running the random
        split there would hand out a different partition per stage and let
        training samples leak into the validation/test subsets, so repeated
        calls are no-ops once the split exists.

        Args:
            stage: Stage name passed by Lightning; not used to select work.
        """
        if self.dataset is not None:
            return

        # Load the data
        dataset_bundle = load_data(
            dataset_name=self.config.dataset_name,
            min_segment_length=self.config.min_segment_length,
            original_sample_rate=self.config.original_sample_rate,
            target_sample_rate=self.config.target_sample_rate,
        )

        self.data = dataset_bundle["data"]
        self.activity_label_encoder = dataset_bundle["activity_encoder"]
        self.subject_label_encoder = dataset_bundle["subject_encoder"]

        # Create the dataset instance
        # NOTE(review): CommonDataset is not defined in the visible part of this
        # file - confirm it is defined or imported elsewhere.
        self.dataset = CommonDataset(self.config)

        # Split the dataset; the test split absorbs the rounding remainder.
        total = len(self.dataset)
        train_size = int(self.config.split_ratio[0] * total)
        val_size = int(self.config.split_ratio[1] * total)
        test_size = total - train_size - val_size

        # Optional reproducible split: only when the config provides split_seed;
        # otherwise the global torch RNG is used, as before.
        split_kwargs = {}
        split_seed = getattr(self.config, "split_seed", None)
        if split_seed is not None:
            split_kwargs["generator"] = torch.Generator().manual_seed(int(split_seed))

        self.train_dataset, self.val_dataset, self.test_dataset = random_split(
            self.dataset, [train_size, val_size, test_size], **split_kwargs
        )

    def _build_dataloader(self, dataset, shuffle):
        """
        Creates a DataLoader with the settings shared by all three splits.

        persistent_workers and prefetch_factor are only forwarded when worker
        processes are used; DataLoader rejects them when num_workers == 0.
        """
        num_workers = self.config.num_workers
        loader_kwargs = {
            "batch_size": self.config.batch_size,
            "shuffle": shuffle,
            "num_workers": num_workers,
            "pin_memory": self.config.pin_memory,
        }
        if num_workers > 0:
            loader_kwargs["persistent_workers"] = self.config.persistent_workers
            loader_kwargs["prefetch_factor"] = self.config.prefetch_factor
        return DataLoader(dataset, **loader_kwargs)

    def train_dataloader(self):
        """Returns the training DataLoader; shuffling follows config.train_shuffle."""
        return self._build_dataloader(self.train_dataset, self.config.train_shuffle)

    def val_dataloader(self):
        """Returns the validation DataLoader (never shuffled)."""
        return self._build_dataloader(self.val_dataset, False)

    def test_dataloader(self):
        """Returns the test DataLoader (never shuffled)."""
        return self._build_dataloader(self.test_dataset, False)

# --- 4. Multi-Dataset Batch Sampler and DataModule ---

class ProportionalSubsetSampler(Sampler[list[int]]):
    """
    Samples batches proportionally from multiple original datasets within a Subset.

    Ensures each batch contains indices (from the Subset's index list)
    that correspond to only one original underlying dataset. Mixes batches
    from different original datasets according to their relative sizes *within the subset*.

    The yielded values are LOCAL indices, i.e. positions inside
    `subset_indices` (0 .. len(subset_indices) - 1), not indices into the
    original ConcatDataset.

    Args:
        subset_indices (list[int]): The list of indices defining the Subset
                                    (e.g., `train_subset.indices`). These are indices
                                    relative to the original ConcatDataset.
        original_cumulative_lengths (list[int]): Cumulative lengths of the
                                                 original datasets in the ConcatDataset
                                                 (e.g., [0, len(ds0), len(ds0)+len(ds1), ...]).
                                                 Must start with 0 and be non-decreasing.
        batch_size (int): Size of batches.
        shuffle (bool): If True, shuffle indices within each original dataset group
                        and the order of batches.
        drop_last (bool): If True, drop the last incomplete batch from each group.

    Raises:
        TypeError: If an argument has the wrong type.
        ValueError: If `original_cumulative_lengths` is empty, does not start
            with 0 or decreases, or if `batch_size` is not positive.
    """
    def __init__(self,
                 subset_indices: list[int],
                 original_cumulative_lengths: list[int],
                 batch_size: int,
                 shuffle: bool = True,
                 drop_last: bool = False):
        # Imported locally so the class does not depend on a file-level import.
        from bisect import bisect_right

        # --- Input Validation ---
        if not isinstance(subset_indices, list):
            raise TypeError("subset_indices should be a list of integers")
        if not isinstance(original_cumulative_lengths, list) or not all(
                isinstance(length, int) for length in original_cumulative_lengths):
            raise TypeError("original_cumulative_lengths should be a list of integers")
        if not original_cumulative_lengths or original_cumulative_lengths[0] != 0:
            raise ValueError("original_cumulative_lengths must start with 0")
        if any(later < earlier for earlier, later in
               zip(original_cumulative_lengths, original_cumulative_lengths[1:])):
            raise ValueError("original_cumulative_lengths must be non-decreasing")
        if batch_size <= 0:
            raise ValueError("batch_size should be a positive integer value")
        if not isinstance(shuffle, bool):
            raise TypeError("shuffle should be a boolean value")
        if not isinstance(drop_last, bool):
            raise TypeError("drop_last should be a boolean value")

        self.subset_indices_global = subset_indices  # Global indices, in subset order
        self.original_cumulative_lengths = original_cumulative_lengths
        self.num_original_datasets = len(original_cumulative_lengths) - 1
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.drop_last = drop_last

        # Kept for callers that read it. With repeated global indices it only
        # remembers the LAST position, so it is not used for the grouping below.
        self.global_to_local_map = {global_idx: local_idx
                                    for local_idx, global_idx in enumerate(self.subset_indices_global)}
        self.subset_size = len(self.subset_indices_global)

        # --- Grouping LOCAL Indices by Source ---
        self.indices_by_source: dict[int, list[int]] = defaultdict(list)  # Stores LOCAL indices
        unmapped_count = 0
        # enumerate() gives every subset position its own local index, so a
        # subset that repeats a global index still covers all of its positions.
        for local_idx, global_idx in enumerate(self.subset_indices_global):
            # Index i such that cumulative[i] <= global_idx < cumulative[i + 1].
            source_ds_idx = bisect_right(self.original_cumulative_lengths, global_idx) - 1
            if 0 <= source_ds_idx < self.num_original_datasets:
                self.indices_by_source[source_ds_idx].append(local_idx)
            else:
                unmapped_count += 1

        if unmapped_count > 0:
            logger.warning("%d indices from the subset could not be mapped to an original dataset source. "
                           "Check if subset indices are valid w.r.t original_cumulative_lengths (max index: %d).",
                           unmapped_count, self.original_cumulative_lengths[-1] - 1)

        # --- Calculate Lengths and Batches Per Source (within this subset) ---
        self.lengths_by_source = [len(self.indices_by_source.get(i, []))
                                  for i in range(self.num_original_datasets)]
        self.num_batches_per_source = []
        for i, length in enumerate(self.lengths_by_source):
            if length == 0:
                num_batches = 0
            elif self.drop_last:
                num_batches = length // self.batch_size
            else:
                num_batches = math.ceil(length / self.batch_size)
            self.num_batches_per_source.append(num_batches)
            # A non-empty source can still yield nothing when drop_last discards
            # its only (incomplete) batch.
            if self.drop_last and length > 0 and num_batches == 0:
                logger.debug("Source dataset %d has %d samples but yields 0 batches due to drop_last=True and batch_size=%d.",
                             i, length, self.batch_size)

        # --- Calculate Total Batches ---
        self.total_batches = sum(self.num_batches_per_source)

    def __iter__(self):
        """
        Generates an iterator yielding lists of LOCAL indices (0 to N-1) for each batch.

        Every batch is cut from the indices of a single source dataset. With
        shuffle=True the indices inside each source and the final batch order
        are randomised with the global `random` module.
        """
        all_batches = []  # Batches of LOCAL indices, across all sources

        for source_ds_idx in range(self.num_original_datasets):
            source_local_indices = self.indices_by_source.get(source_ds_idx, [])
            current_len = len(source_local_indices)
            if current_len == 0:
                continue

            # random.sample returns a shuffled copy; the stored grouping is untouched.
            if self.shuffle:
                indices_to_batch = random.sample(source_local_indices, current_len)
            else:
                indices_to_batch = source_local_indices

            for start in range(0, current_len, self.batch_size):
                batch_local_indices = indices_to_batch[start:start + self.batch_size]
                # Only the trailing batch of a source can be short.
                if self.drop_last and len(batch_local_indices) < self.batch_size:
                    continue
                all_batches.append(batch_local_indices)

        # Interleave the sources by shuffling whole batches, never single indices.
        if self.shuffle:
            random.shuffle(all_batches)

        yield from all_batches

    def __len__(self) -> int:
        """
        Returns the total number of batches this sampler will yield.
        Accurately reflects the effect of drop_last.
        """
        # Computed once in __init__ with the same per-source rule __iter__ applies.
        return self.total_batches

class MoVQFormerDataModule(l.LightningDataModule):
    """
    A LightningDataModule for handling multiple MoVQFormer datasets with
    flexible splitting strategies.

    Supports:
    - Loading multiple datasets specified in `config.dataset_name`.
    - Splitting strategies controlled by `config.split_strategy`:
        - 'concatenate_split': Concatenate all datasets, then perform a random
          train/val/test split based on `config.split_ratio`.
        - 'by_dataset': Assign entire datasets to train/val/test splits based
          on `config.split_dataset_assignments`.
    - Ensures each batch comes from a single original dataset using
      `ProportionalSubsetSampler`.
    """
    def __init__(self, config):
        """
        Stores the config, declares every attribute that is filled in later,
        and validates the config.

        Args:
            config: Configuration object checked by _validate_config().

        Raises:
            ValueError: If _validate_config() rejects the config.
        """
        super().__init__()
        self.config = config

        # Loaded datasets keyed by name (used by the 'by_dataset' strategy).
        self.all_datasets_dict: dict[str, Dataset] = {}

        # Per-split datasets; None until they are prepared.
        self.train_dataset: ConcatDataset | Subset | None = None
        self.val_dataset: ConcatDataset | Subset | None = None
        self.test_dataset: ConcatDataset | Subset | None = None

        # Per-split batch samplers; None until they are created.
        self.train_sampler: ProportionalSubsetSampler | None = None
        self.val_sampler: ProportionalSubsetSampler | None = None
        self.test_sampler: ProportionalSubsetSampler | None = None

        # State for the 'concatenate_split' strategy.
        self.concat_dataset_all: ConcatDataset | None = None
        self.global_cumulative_lengths: list[int] = []

        # State for the 'by_dataset' strategy.
        self.train_indices: list[int] | None = None
        self.val_indices: list[int] | None = None
        self.test_indices: list[int] | None = None
        self.train_cumulative_lengths: list[int] = []
        self.val_cumulative_lengths: list[int] = []
        self.test_cumulative_lengths: list[int] = []

        self._validate_config()

    def _validate_config(self):
        """
        Checks that the config carries everything the chosen split strategy needs.

        For 'concatenate_split' this also stores self.split_ratio (default
        [0.8, 0.1, 0.1]) and self.split_seed (default 42).

        Raises:
            ValueError: On a missing attribute, a non-list or empty
                dataset_name, an unknown split_strategy, or invalid
                strategy-specific settings.
        """
        cfg = self.config

        for required in ('dataset_name', 'batch_size', 'train_shuffle', 'split_strategy'):
            if not hasattr(cfg, required):
                raise ValueError(f"Config missing required attribute: {required}")

        if not isinstance(cfg.dataset_name, list):
            raise ValueError("config.dataset_name must be a list.")
        if not cfg.dataset_name:
            raise ValueError("config.dataset_name cannot be empty.")

        valid_strategies = ["concatenate_split", "by_dataset"]
        if cfg.split_strategy not in valid_strategies:
            raise ValueError(f"Invalid config.split_strategy '{cfg.split_strategy}'. "
                             f"Must be one of: {valid_strategies}")

        if cfg.split_strategy == "concatenate_split":
            # Stored before validation, so the attribute exists even on failure.
            self.split_ratio = getattr(cfg, "split_ratio", [0.8, 0.1, 0.1])
            ratios = self.split_ratio
            if not (isinstance(ratios, (list, tuple)) and len(ratios) == 3):
                raise ValueError("config.split_ratio must be a list/tuple of length 3 for 'concatenate_split'")
            ratio_sum = sum(ratios)
            if not math.isclose(ratio_sum, 1.0, abs_tol=1e-6):
                raise ValueError(f"config.split_ratio must sum to 1.0, got {ratio_sum}")
            for ratio in ratios:
                if ratio < 0:
                    raise ValueError("config.split_ratio elements cannot be negative.")
            # Seed for the reproducible random split.
            self.split_seed = getattr(cfg, 'split_seed', 42)

        elif cfg.split_strategy == "by_dataset":
            if not hasattr(cfg, "split_dataset_assignments"):
                raise ValueError("config.split_dataset_assignments is required for 'by_dataset' strategy.")
            assignments = cfg.split_dataset_assignments
            split_names = ("train", "val", "test")
            if not isinstance(assignments, dict) or any(key not in split_names for key in assignments):
                raise ValueError("config.split_dataset_assignments must be a dict with keys 'train', 'val', 'test'.")
            if any(not isinstance(names, list) for names in assignments.values()):
                raise ValueError("Values in config.split_dataset_assignments must be lists of dataset names.")
            # Not checked here: whether the assigned names are unique and cover config.dataset_name.

    def prepare_data(self):
        """No-op hook: nothing is downloaded or pre-processed; it only logs the call."""
        # Extension point, e.g. fetch any dataset in config.dataset_name that is
        # not available locally yet.
        logger.info("MoVQFormerDataModule: prepare_data called.")

    def _get_single_dataset_config(self, index, dataset_name):
        """
        Derives the config for one dataset from the multi-dataset config.

        The main config is deep-copied and dataset_name is replaced by the
        given name. Each per-dataset parameter that is a list in the main
        config is reduced to its element at `index`; other non-None values are
        assigned as they are, and absent parameters are left untouched.

        Args:
            index: Position of the dataset in config.dataset_name.
            dataset_name: Name stored on the returned config.

        Returns:
            The deep copy of the main config, adjusted for the single dataset.

        Raises:
            ValueError: If a list-valued parameter does not have one entry per
                dataset name, or `index` falls outside such a list.
        """
        # Parameters that may be given as one value per dataset.
        per_dataset_params = (
            "original_sample_rate", "num_channels", "min_segment_length",
            "target_sample_rate", "window_size", "stride",
        )
        dataset_count = len(self.config.dataset_name)

        single_config = copy.deepcopy(self.config)
        single_config.dataset_name = dataset_name

        for name in per_dataset_params:
            value = getattr(self.config, name, None)

            if value is None:
                # Absent from the main config: leave it to the Dataset class.
                continue

            if not isinstance(value, list):
                # A single value applies to every dataset.
                setattr(single_config, name, value)
                continue

            if len(value) != dataset_count:
                raise ValueError(
                    f"Config parameter '{name}' is a list ({value}) but its length ({len(value)}) "
                    f"does not match the number of dataset names ({dataset_count}). Cannot determine value for '{dataset_name}'."
                )

            try:
                setattr(single_config, name, value[index])
            except IndexError:
                # Only reachable when `index` is not a valid position in dataset_name.
                raise ValueError(f"Internal Error: Index {index} out of bounds for config parameter '{name}' list.")

        return single_config


    def setup(self, stage: str | None = None):
        """Load the configured datasets and build the train/val/test splits.

        The work happens in three steps:

        1. Every name in ``config.dataset_name`` is loaded as a
           ``MoVQFormerDataset``. Datasets that fail to load, or that turn out
           to be empty, are logged and skipped (best effort).
        2. The loaded datasets are split according to ``config.split_strategy``:

           * ``"concatenate_split"``: all datasets are concatenated and split
             randomly using ``self.split_ratio`` and ``self.split_seed``.
           * ``"by_dataset"``: whole datasets are assigned to splits through
             ``config.split_dataset_assignments``.
        3. Samplers from any earlier call are discarded and rebuilt for
           ``stage`` so they always describe the splits created in step 2.

        Args:
            stage: Stage requested by the caller ('fit', 'validate', 'test', ...
                or None). If the data is already prepared and ``stage`` is
                neither 'fit' nor None, nothing is reloaded and only missing
                samplers are created.

        Raises:
            RuntimeError: If none of the configured datasets could be loaded.
            ValueError: If the split ratios produce a negative split length.
            NotImplementedError: If ``config.split_strategy`` is unknown.
        """
        logger.info("MoVQFormerDataModule: setup called for stage '%s' with strategy '%s'.",
                    stage, self.config.split_strategy)

        # Prevent redundant setup if datasets are already prepared for the strategy.
        # The check relies on the truthiness of the primary dataset of each strategy.
        already_setup = False
        if self.config.split_strategy == "concatenate_split" and self.concat_dataset_all:
            already_setup = True
        elif self.config.split_strategy == "by_dataset" and self.train_dataset:  # Check one split dataset
            already_setup = True

        if already_setup and stage not in ('fit', None):  # Allow re-run for fit/None
            logger.info("Data appears to be already loaded and split. Creating samplers if needed.")
            self._create_samplers_if_needed(stage)
            return

        # --- 1. Load all datasets specified in config ---
        self.all_datasets_dict = {}
        allowed_labels = getattr(self.config, "allowed_activity_labels", None)
        dataset_names = self.config.dataset_name

        logger.info("Loading datasets: %s", dataset_names)
        for i, ds_name in enumerate(dataset_names):
            logger.info("  Loading dataset: %s (%d/%d)", ds_name, i + 1, len(dataset_names))
            single_ds_config = self._get_single_dataset_config(i, ds_name)
            try:
                dataset = MoVQFormerDataset(single_ds_config, allowed_activity_labels=allowed_labels)

                if len(dataset) > 0:
                    # Use the dataset_name attribute set within the dataset class
                    actual_ds_name = dataset.dataset_name
                    if actual_ds_name in self.all_datasets_dict:
                        logger.warning("Duplicate dataset name '%s' detected during loading. Overwriting previous entry.", actual_ds_name)
                    self.all_datasets_dict[actual_ds_name] = dataset
                    logger.info("  Loaded '%s' with %d samples.", actual_ds_name, len(dataset))
                else:
                    logger.warning("  Loaded dataset %s is empty, skipping.", ds_name)
            except Exception as e:
                # Deliberate best effort: one broken dataset must not abort the others.
                logger.error("  ERROR: Failed to load dataset %s. Error: %s", ds_name, e, exc_info=True)
                logger.warning("  Skipping dataset %s due to loading error.", ds_name)

        if not self.all_datasets_dict:
            raise RuntimeError("Failed to load any valid datasets specified in config.dataset_name.")

        # --- 2. Apply Splitting Strategy ---
        logger.info("Applying split strategy: %s", self.config.split_strategy)

        if self.config.split_strategy == "concatenate_split":
            # --- Strategy: Concatenate then Split ---
            all_datasets_list = list(self.all_datasets_dict.values())
            self.concat_dataset_all = ConcatDataset(all_datasets_list)
            dataset_lengths = [len(ds) for ds in all_datasets_list]
            # Offsets of each source dataset inside the concatenation: [0, n0, n0+n1, ...]
            self.global_cumulative_lengths = [0] + [int(x) for x in np.cumsum(dataset_lengths)]
            total_samples = len(self.concat_dataset_all)
            logger.info("Concatenated dataset created. Total samples: %d, from %d source(s). Global cumulative lengths: %s",
                        total_samples, len(all_datasets_list), self.global_cumulative_lengths)

            # Perform Train/Val/Test Split using random_split
            train_ratio, val_ratio, test_ratio = self.split_ratio
            train_len = int(np.floor(train_ratio * total_samples))
            val_len = int(np.floor(val_ratio * total_samples))
            # The test split takes the remainder so the three lengths always sum to the total.
            test_len = total_samples - train_len - val_len
            split_lengths = [train_len, val_len, test_len]

            logger.info("Splitting concatenated data (%d samples) with ratios %s -> lengths %s using seed %d",
                        total_samples, self.split_ratio, split_lengths, self.split_seed)

            if any(length < 0 for length in split_lengths):
                raise ValueError(f"Calculated split lengths resulted in negative value: {split_lengths}")

            generator = torch.Generator().manual_seed(self.split_seed)
            if total_samples > 0:
                # random_split returns Subsets; indices are relative to concat_dataset_all
                self.train_dataset, self.val_dataset, self.test_dataset = random_split(
                    self.concat_dataset_all, split_lengths, generator=generator
                )
            else:
                logger.warning("Total samples is 0, creating empty Subset objects.")
                empty_indices = []
                self.train_dataset = Subset(self.concat_dataset_all, empty_indices)
                self.val_dataset = Subset(self.concat_dataset_all, empty_indices)
                self.test_dataset = Subset(self.concat_dataset_all, empty_indices)

            logger.info("Split complete: Train=%d, Val=%d, Test=%d",
                        len(self.train_dataset), len(self.val_dataset), len(self.test_dataset))

        elif self.config.split_strategy == "by_dataset":
            # --- Strategy: Assign Datasets to Splits ---
            assignments = self.config.split_dataset_assignments

            def _names_for(split_key):
                # Accept a missing key, None, a single name, or any sequence of names.
                names = assignments.get(split_key) or []
                return [names] if isinstance(names, str) else list(names)

            train_ds_names = _names_for("train")
            val_ds_names = _names_for("val")
            test_ds_names = _names_for("test")

            # Validate assignments against loaded datasets
            assigned_names = set(train_ds_names + val_ds_names + test_ds_names)
            loaded_names = set(self.all_datasets_dict.keys())
            missing_in_assignments = loaded_names - assigned_names
            unknown_in_assignments = assigned_names - loaded_names
            # A dataset listed in more than one split leaks samples between splits.
            train_set, val_set, test_set = set(train_ds_names), set(val_ds_names), set(test_ds_names)
            shared_between_splits = (train_set & val_set) | (train_set & test_set) | (val_set & test_set)

            if missing_in_assignments:
                logger.warning("The following loaded datasets were not assigned to any split ('train', 'val', 'test') "
                               "in config.split_dataset_assignments: %s. They will be ignored.",
                               list(missing_in_assignments))
            if unknown_in_assignments:
                logger.warning("The following dataset names in config.split_dataset_assignments were not found "
                               "among the successfully loaded datasets: %s. They will be ignored.",
                               list(unknown_in_assignments))
            if shared_between_splits:
                logger.warning("The following datasets are assigned to more than one split in "
                               "config.split_dataset_assignments: %s. Their samples will be shared between those splits.",
                               sorted(shared_between_splits))

            # Create lists of dataset objects for each split
            train_datasets = [self.all_datasets_dict[name] for name in train_ds_names if name in loaded_names]
            val_datasets = [self.all_datasets_dict[name] for name in val_ds_names if name in loaded_names]
            test_datasets = [self.all_datasets_dict[name] for name in test_ds_names if name in loaded_names]

            # Create ConcatDataset for each split (even if only one dataset)
            self.train_dataset = ConcatDataset(train_datasets) if train_datasets else None
            self.val_dataset = ConcatDataset(val_datasets) if val_datasets else None
            self.test_dataset = ConcatDataset(test_datasets) if test_datasets else None

            # Calculate cumulative lengths *within each split's ConcatDataset*
            train_lengths = [len(ds) for ds in train_datasets] if self.train_dataset else []
            val_lengths = [len(ds) for ds in val_datasets] if self.val_dataset else []
            test_lengths = [len(ds) for ds in test_datasets] if self.test_dataset else []

            self.train_cumulative_lengths = [0] + [int(x) for x in np.cumsum(train_lengths)]
            self.val_cumulative_lengths = [0] + [int(x) for x in np.cumsum(val_lengths)]
            self.test_cumulative_lengths = [0] + [int(x) for x in np.cumsum(test_lengths)]

            # Generate simple range indices for each split's dataset
            self.train_indices = list(range(len(self.train_dataset))) if self.train_dataset else []
            self.val_indices = list(range(len(self.val_dataset))) if self.val_dataset else []
            self.test_indices = list(range(len(self.test_dataset))) if self.test_dataset else []

            logger.info("Split by dataset complete:")
            logger.info("  Train: %d samples from datasets %s. Cumulative lengths: %s",
                        len(self.train_dataset) if self.train_dataset else 0, train_ds_names, self.train_cumulative_lengths)
            logger.info("  Val:   %d samples from datasets %s. Cumulative lengths: %s",
                        len(self.val_dataset) if self.val_dataset else 0, val_ds_names, self.val_cumulative_lengths)
            logger.info("  Test:  %d samples from datasets %s. Cumulative lengths: %s",
                        len(self.test_dataset) if self.test_dataset else 0, test_ds_names, self.test_cumulative_lengths)

        else:
            # Should be caught by validation, but defense in depth
            raise NotImplementedError(f"Split strategy '{self.config.split_strategy}' is not implemented.")

        # --- 3. Create Samplers ---
        # The splits were just rebuilt, so samplers left over from an earlier setup()
        # may carry indices/cumulative lengths of the previous splits. Drop them so
        # they are recreated from the current state (the *_dataloader methods lazily
        # recreate the ones not needed for this stage).
        self.train_sampler = None
        self.val_sampler = None
        self.test_sampler = None
        self._create_samplers_if_needed(stage)


    def _create_samplers_if_needed(self, stage: str | None):
        """Lazily build the ProportionalSubsetSampler of every split that ``stage`` uses.

        A sampler is only created when the stage involves its split, no sampler
        exists yet, and the split's dataset is non-empty:

        * train      -> stages 'fit' and None
        * validation -> stages 'fit', 'validate' and None
        * test       -> stages 'test' and None

        Under "concatenate_split" the sampler receives the Subset's ``indices``
        and the global cumulative lengths; under "by_dataset" it receives the
        split's own range indices and cumulative lengths.

        Args:
            stage: Stage the samplers are requested for.

        Raises:
            RuntimeError: If ``config.split_strategy`` is neither
                "concatenate_split" nor "by_dataset" when a sampler is built.
        """
        datasets_by_split = {
            "train": self.train_dataset,
            "val": self.val_dataset,
            "test": self.test_dataset,
        }
        if not any(datasets_by_split.values()):
            logger.warning("Datasets not fully created, cannot reliably create samplers.")
            return

        cfg = self.config
        # Evaluation batch sizes fall back to the training batch size.
        default_batch_size = cfg.batch_size
        val_batch_size = getattr(cfg, 'val_batch_size', default_batch_size)
        test_batch_size = getattr(cfg, 'test_batch_size', default_batch_size)
        eval_drop_last = getattr(cfg, "drop_last_eval", False)
        # Only the training sampler ever shuffles.
        shuffle_train = cfg.train_shuffle

        # (split, stages that need it, batch size, shuffle, start / done / empty log messages)
        plans = (
            ("train", ('fit', None), default_batch_size, shuffle_train,
             "Creating train sampler...",
             "  Train sampler created: %d batches.",
             "Train subset is empty, skipping train sampler creation."),
            ("val", ('fit', 'validate', None), val_batch_size, False,
             "Creating validation sampler...",
             "  Validation sampler created: %d batches.",
             "Validation subset is empty, skipping validation sampler creation."),
            ("test", ('test', None), test_batch_size, False,
             "Creating test sampler...",
             "  Test sampler created: %d batches.",
             "Test subset is empty, skipping test sampler creation."),
        )

        for split, stages, size, shuffle, start_msg, done_msg, empty_msg in plans:
            split_dataset = datasets_by_split[split]
            sampler_attr = f"{split}_sampler"
            sampler_wanted = stage in stages and getattr(self, sampler_attr) is None

            if sampler_wanted and split_dataset and len(split_dataset) > 0:
                logger.info(start_msg)
                strategy = cfg.split_strategy
                if strategy == "concatenate_split":
                    # Indices produced by random_split, relative to the full concatenation.
                    indices = split_dataset.indices
                    cum_lengths = self.global_cumulative_lengths
                elif strategy == "by_dataset":
                    # Plain range indices relative to this split's own ConcatDataset.
                    indices = getattr(self, f"{split}_indices")
                    cum_lengths = getattr(self, f"{split}_cumulative_lengths")
                else:
                    raise RuntimeError("Invalid state")  # Should not happen

                if split == "train":
                    # Training may use its own drop_last setting.
                    drop_last = getattr(cfg, "drop_last_train", False)
                else:
                    drop_last = eval_drop_last

                new_sampler = ProportionalSubsetSampler(
                    subset_indices=indices,
                    original_cumulative_lengths=cum_lengths,
                    batch_size=size,
                    shuffle=shuffle,
                    drop_last=drop_last,
                )
                setattr(self, sampler_attr, new_sampler)
                logger.info(done_msg, len(new_sampler))
            elif split_dataset is not None and len(split_dataset) == 0:
                logger.info(empty_msg)


    def _get_dataloader(self, dataset: Optional[Union[ConcatDataset, Subset]],
                        sampler: Optional[ProportionalSubsetSampler]) -> DataLoader:
        """Build the DataLoader for one split from its dataset and batch sampler.

        Worker-related options are read from ``self.config`` with these defaults:
        ``num_workers`` 0, ``pin_memory`` True when CUDA is available,
        ``persistent_workers`` True and ``prefetch_factor`` 2 when workers are used.
        ``persistent_workers`` and ``prefetch_factor`` are only forwarded when
        ``num_workers > 0``, because ``DataLoader`` rejects them otherwise.

        Args:
            dataset: Dataset of the split, or None if the split does not exist.
            sampler: Batch sampler yielding lists of indices, or None.

        Returns:
            A DataLoader over ``dataset`` driven by ``sampler``; an empty
            ``DataLoader([])`` when the dataset or sampler is missing or the
            sampler yields no batches.
        """
        if dataset is None or sampler is None or len(sampler) == 0:
            logger.debug("Dataset or sampler is None or sampler length is 0, returning empty DataLoader.")
            return DataLoader([])  # Return loader for empty list

        num_workers = getattr(self.config, "num_workers", 0)
        uses_workers = num_workers > 0
        pin_memory = getattr(self.config, "pin_memory", torch.cuda.is_available())

        # Both options are only valid with worker processes: DataLoader raises a
        # ValueError for persistent_workers=True when num_workers == 0, so a config
        # that enables them is silently downgraded when running single-process.
        persistent_workers = uses_workers and bool(getattr(self.config, "persistent_workers", True))
        prefetch_factor = getattr(self.config, "prefetch_factor", 2) if uses_workers else None

        # NOTE(review): the sampler is passed as batch_sampler, so the indices it
        # yields are used to index `dataset` directly - confirm they are relative
        # to the dataset passed here.
        return DataLoader(
            dataset,
            batch_sampler=sampler,  # batch_sampler (not sampler): it yields whole batches of indices
            num_workers=num_workers,
            pin_memory=pin_memory,
            persistent_workers=persistent_workers,
            prefetch_factor=prefetch_factor,
        )

    # --- DataLoader methods ---
    def train_dataloader(self) -> DataLoader:
        """Return the training DataLoader, creating the train sampler first if it is missing."""
        # Sampler creation is lazy, so make sure the 'fit' samplers exist before reading them.
        self._create_samplers_if_needed('fit')
        dataset, sampler = self.train_dataset, self.train_sampler
        return self._get_dataloader(dataset, sampler)

    def val_dataloader(self) -> DataLoader:
        """Return the validation DataLoader, creating the validation sampler first if it is missing."""
        # Sampler creation is lazy, so make sure the 'validate' samplers exist before reading them.
        self._create_samplers_if_needed('validate')
        dataset, sampler = self.val_dataset, self.val_sampler
        return self._get_dataloader(dataset, sampler)

    def test_dataloader(self) -> DataLoader:
        self._create_samplers_if_needed('test') # Ensure sampler exists for test stage
        return self._get_dataloader(self.test_dataset, self.test_sampler)

    
if __name__ == "__main__":
    # Ad-hoc manual check: load the "pamap2" data and print how many classes its
    # subject-id label encoder holds.
    import os
    import sys
    from pathlib import Path

    # Put the parent of this file's directory at the front of sys.path; the
    # `src.`-prefixed import below is presumably meant to resolve from there.
    package_parent = os.path.abspath(Path(__file__).parent / os.pardir)
    sys.path.insert(0, package_parent)

    from src.data.data_import import load_data

    pamap2_data = load_data("pamap2")
    subject_encoder = pamap2_data["subject_id_le"]
    print(len(subject_encoder.classes_))
