from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
from .dataset import TimeSeriesDataset

def load_and_split_data(data, timestamps, train_ratio=0.8, val_ratio=0.1):
    """Split time series samples into train, validation and test sets.

    The samples are partitioned with two successive calls to
    ``sklearn.model_selection.train_test_split`` using a fixed
    ``random_state`` of 42, so the partition is reproducible. The test set
    receives whatever share is left after the training and validation
    ratios, i.e. ``1 - train_ratio - val_ratio``.

    Args:
        data: numpy.ndarray of shape (num_samples, seq_len, input_dim),
            the input time series data sequences
        timestamps: numpy.ndarray of shape (num_samples, seq_len + 1),
            the timestamp sequences; split with the same indices as ``data``
        train_ratio: float, share of all samples used for training,
            defaults to 0.8
        val_ratio: float, share of all samples used for validation,
            defaults to 0.1

    Returns:
        tuple:
            - train_data: training samples
            - val_data: validation samples
            - test_data: test samples (empty when the two ratios sum to 1)
            - train_timestamps: timestamps for training data
            - val_timestamps: timestamps for validation data
            - test_timestamps: timestamps for test data

    Raises:
        ValueError: if ``train_ratio + val_ratio`` exceeds 1, or if
            ``train_test_split`` rejects the resulting split sizes.
    """
    # Tolerance for float noise, e.g. 1 - 0.9 - 0.1 evaluates to about -2.8e-17.
    eps = 1e-9
    test_ratio = 1.0 - train_ratio - val_ratio
    if test_ratio < -eps:
        raise ValueError(
            "train_ratio + val_ratio must not exceed 1, "
            f"got {train_ratio} + {val_ratio}"
        )

    if test_ratio > eps:
        # Hold out only the test share first, so that the remainder is
        # exactly the train + validation share assumed by the second split.
        remaining_data, test_data, remaining_timestamps, test_timestamps = (
            train_test_split(
                data, timestamps, test_size=test_ratio, random_state=42
            )
        )
    else:
        # train_test_split rejects test_size == 0, so produce the empty
        # test split by slicing instead.
        remaining_data, remaining_timestamps = data, timestamps
        test_data, test_timestamps = data[:0], timestamps[:0]

    # Validation share relative to the remaining (train + validation) pool.
    relative_val_ratio = val_ratio / (train_ratio + val_ratio)
    train_data, val_data, train_timestamps, val_timestamps = train_test_split(
        remaining_data, remaining_timestamps,
        test_size=relative_val_ratio, random_state=42
    )
    return (
        train_data, val_data, test_data,
        train_timestamps, val_timestamps, test_timestamps
    )

def create_data_loaders(train_data, val_data, test_data,
                       train_timestamps, val_timestamps, test_timestamps,
                       input_len, output_len, d_inner, batch_size):
    """Build the train, validation and test data loaders.

    Each (data, timestamps) pair is wrapped in a ``TimeSeriesDataset``
    together with ``input_len``, ``output_len`` and ``d_inner``, and each
    dataset is then wrapped in a ``DataLoader``. Only the training loader
    shuffles its samples.

    Args:
        train_data: numpy.ndarray, training dataset
        val_data: numpy.ndarray, validation dataset
        test_data: numpy.ndarray, test dataset
        train_timestamps: numpy.ndarray, timestamps for training data
        val_timestamps: numpy.ndarray, timestamps for validation data
        test_timestamps: numpy.ndarray, timestamps for test data
        input_len: int, length of input sequence
        output_len: int, length of output sequence
        d_inner: int, dimension of inner features
        batch_size: int, size of each batch

    Returns:
        tuple:
            - train_loader: DataLoader over the training set (shuffled)
            - val_loader: DataLoader over the validation set (not shuffled)
            - test_loader: DataLoader over the test set (not shuffled)
    """
    # Ordered train, validation, test; the same order is kept throughout.
    split_inputs = (
        (train_data, train_timestamps),
        (val_data, val_timestamps),
        (test_data, test_timestamps),
    )
    shuffle_flags = (True, False, False)

    # All datasets are constructed before any loader is created.
    datasets = [
        TimeSeriesDataset(values, stamps, input_len, output_len, d_inner)
        for values, stamps in split_inputs
    ]

    loaders = tuple(
        DataLoader(dataset, batch_size=batch_size, shuffle=should_shuffle)
        for dataset, should_shuffle in zip(datasets, shuffle_flags)
    )
    train_loader, val_loader, test_loader = loaders
    return train_loader, val_loader, test_loader