import torch
from torch.utils.data import Dataset
import pandas as pd
from ast import literal_eval


# Dataset over a CSV of samples, intended to be wrapped in a torch DataLoader.

class TrafficData(Dataset):
    """Map-style PyTorch dataset backed by a CSV file.

    Each CSV row is one sample. The ``z`` column holds a Python literal
    (e.g. ``"[0.1, 0.2]"``) that is parsed with ``ast.literal_eval`` when
    the file is read. The ``shape_label`` and ``color_label`` columns are
    scalars that are returned as float tensors.
    """

    def __init__(self, filename):
        """Read the CSV at *filename* into ``self.df``.

        Args:
            filename: Anything ``pandas.read_csv`` accepts (a path or a
                file-like object).
        """
        # literal_eval only evaluates Python literals, so parsing the 'z'
        # strings cannot execute arbitrary code.
        self.df = pd.read_csv(filename, converters={'z': literal_eval})

    def __len__(self):
        """Return the number of samples (rows in the CSV)."""
        return len(self.df)

    def __getitem__(self, index):
        """Return ``(z, shape_label, color_label)`` for the row at *index*.

        *index* is a row position in ``0 .. len(self) - 1``. Negative
        positions count from the end. The position is independent of the
        DataFrame's index labels.

        Returns:
            A tuple of three tensors:
            - ``z``, whose dtype is inferred by ``torch.tensor`` from the
              parsed values;
            - ``shape_label`` and ``color_label``, as float32 scalars.
        """
        # Use positional lookup (.iloc), not label lookup (.loc). The Dataset
        # contract is positional, and .loc only matches it while the
        # DataFrame keeps a default 0..n-1 RangeIndex.
        row = self.df.iloc[index]
        input_z = row['z']
        shape_label = row['shape_label']
        color_label = row['color_label']

        return (
            torch.tensor(input_z),
            torch.tensor(shape_label, dtype=torch.float),
            torch.tensor(color_label, dtype=torch.float),
        )