import pandas as pd
from util.data_util import NetDataset
import torch
from sklearn.preprocessing import StandardScaler
import os
import torch_geometric.transforms as transforms

class Adult:
    """Loader that turns the Adult CSV file into a graph dataset entry.

    The CSV is expected to contain at least the columns ``Target`` (label),
    ``Sex`` (sensitive attribute) and ``age``; every column other than
    ``Target`` and ``Sex`` is used as a node feature.
    """

    def __init__(self, args):
        """Keep the run configuration.

        Args:
            args: Configuration object. ``data_loaders`` reads ``args.root``
                and ``args.filename`` from it and forwards the whole object
                to ``NetDataset``.
        """
        self.args = args

    def _load_data(self, path):
        """Read the CSV at ``path`` and split it into float tensors.

        Args:
            path: Location of the CSV file; its first column is used as the
                row index.

        Returns:
            A 4-tuple ``(node_features, z_features, sa, node_labels)``:

            * ``node_features`` - all columns except ``Target`` and ``Sex``
              (``age`` is kept here as well).
            * ``z_features`` - the ``age`` column.
            * ``sa`` - the ``Sex`` column reshaped to a single column
              (``view(-1, 1)``).
            * ``node_labels`` - the ``Target`` column.

        Raises:
            KeyError: If one of the required columns is absent.
        """
        frame = pd.read_csv(path, index_col=0)

        # Pull every raw array out of the frame first, so a missing column
        # is reported before any tensor conversion is attempted.
        feature_array = frame.drop(columns=['Target', 'Sex']).values
        sensitive_array = frame['Sex'].values
        age_array = frame['age'].values
        label_array = frame['Target'].values

        # The sensitive attribute is the only output given an explicit
        # trailing dimension of size one.
        sa = torch.FloatTensor(sensitive_array).view(-1, 1)
        node_features = torch.FloatTensor(feature_array)
        z_features = torch.FloatTensor(age_array)
        node_labels = torch.FloatTensor(label_array)

        return node_features, z_features, sa, node_labels

    def data_loaders(self, **kwargs):
        """Build the ``NetDataset`` for the configured CSV and return its first item.

        The file is looked up at ``<args.root>/raw/<args.filename>``. The
        dataset is created with a pre-transform that first applies
        ``RemoveIsolatedNodes`` and then ``RandomNodeSplit`` with
        ``split='train_rest'``, ``num_val=0`` and ``num_test=0.2``.

        Args:
            **kwargs: Extra keyword arguments passed through to
                ``NetDataset`` unchanged.

        Returns:
            ``graphdataset[0]`` - the first element of the constructed
            dataset (presumably a single graph data object; confirm against
            ``NetDataset``).
        """
        csv_path = os.path.join(self.args.root, 'raw', self.args.filename)
        tensors = self._load_data(csv_path)

        pipeline = transforms.Compose([
            transforms.RemoveIsolatedNodes(),
            transforms.RandomNodeSplit(split='train_rest', num_val=0, num_test=0.2),
        ])

        # ``tensors`` is ordered (node_features, z_features, sa, node_labels),
        # which is the positional order NetDataset is called with.
        dataset = NetDataset(self.args, *tensors, pre_transform=pipeline, **kwargs)
        return dataset[0]