import torch
from torch.utils.data import DataLoader, Dataset
import math
import time

import torch
import torch.nn as nn
import numpy as np;

from data.traffic.regr_data_utils import Data_utility;

class CustomDataset(Dataset):
    """Map-style dataset pairing feature rows with target rows as float32 tensors.

    Indexing returns an ``(input, target)`` pair. The optional ``transform`` is
    applied to the input only; targets are returned as stored.
    """

    def __init__(self, X, y, transform=None):
        """
        Args:
            X (array-like or torch.Tensor): Features data (e.g., X_train or X_test).
                Samples are indexed along the first dimension.
            y (array-like or torch.Tensor): Target data (e.g., y_train or y_test).
                Must have the same length as ``X``.
            transform (callable, optional): Applied to each input sample in
                ``__getitem__``; targets are not transformed.

        Raises:
            ValueError: If ``X`` and ``y`` do not have the same length.
        """
        # Both are stored as independent float32 copies, so later mutation of
        # the caller's arrays/tensors does not leak into the dataset.
        self.X = self._to_float32(X)
        self.y = self._to_float32(y)  # adjust dtype here if targets are not regression values
        if len(self.X) != len(self.y):
            # Fail early instead of raising IndexError midway through an epoch.
            raise ValueError(
                f"X and y must have the same length, got {len(self.X)} and {len(self.y)}"
            )
        self.transform = transform

    @staticmethod
    def _to_float32(data):
        """Return a detached float32 copy of ``data`` as a torch tensor."""
        if isinstance(data, torch.Tensor):
            # torch.tensor(tensor) emits a UserWarning recommending
            # clone().detach(); this is the equivalent copy without the warning.
            return data.detach().to(dtype=torch.float32, copy=True)
        return torch.tensor(data, dtype=torch.float32)

    def __len__(self):
        """Return the number of samples (the length of ``X``)."""
        return len(self.X)

    def __getitem__(self, idx):
        """Return the ``(input, target)`` pair at ``idx``, transforming the input if set."""
        input_data = self.X[idx]
        target_data = self.y[idx]

        # Compare with None explicitly: any callable is a valid transform.
        if self.transform is not None:
            input_data = self.transform(input_data)

        return input_data, target_data
        
def get_traffic_dataset(
    data_address='./data/traffic/traffic.txt',
    *,
    train_ratio=0.6,
    valid_ratio=0.2,
    cuda=True,
    horizon=24,
    window=24 * 7,
    normalize=2,
):
    """Build train/test ``CustomDataset`` objects from the traffic data via ``Data_utility``.

    All arguments are forwarded positionally to ``Data_utility`` in the order
    (data_address, train_ratio, valid_ratio, cuda, horizon, window, normalize).
    The defaults reproduce the previously hard-coded configuration, so calling
    ``get_traffic_dataset()`` with no arguments behaves as before.

    Args:
        data_address: Path to the raw traffic file. The default is relative, so
            it presumably depends on the current working directory — confirm in
            ``regr_data_utils``.
        train_ratio: Second argument to ``Data_utility``; presumably the
            fraction of the series used for training (TODO confirm).
        valid_ratio: Third argument to ``Data_utility``; presumably the
            fraction used for validation (TODO confirm).
        cuda: Flag forwarded to ``Data_utility``; its effect is not visible here.
        horizon: Forwarded to ``Data_utility``; presumably how many steps ahead
            the target lies.
        window: Forwarded to ``Data_utility``; presumably the input look-back
            length (24 * 7 would be one week if the series is hourly — TODO confirm).
        normalize: Normalization mode code forwarded to ``Data_utility``.

    Returns:
        A ``(train_dataset, test_dataset)`` tuple of ``CustomDataset`` objects
        built from ``data.train`` and ``data.test``. Element ``[0]`` of each
        split is used as features and element ``[1]`` as targets. Only the
        train and test splits are returned.
    """
    data = Data_utility(
        data_address, train_ratio, valid_ratio, cuda, horizon, window, normalize
    )
    X_train, y_train = data.train[0], data.train[1]
    X_test, y_test = data.test[0], data.test[1]

    train_dataset = CustomDataset(X_train, y_train)
    test_dataset = CustomDataset(X_test, y_test)
    return train_dataset, test_dataset

