import pandas as pd
import numpy as np
from util.data_util import NetDataset, RealNetDataset
import torch
from sklearn.preprocessing import StandardScaler, OneHotEncoder, MinMaxScaler
import os
import torch_geometric.transforms as transforms

from sklearn.utils import resample


class CreditReal:
    """Loader that turns a credit-default CSV into tensors and a graph dataset.

    The CSV is expected to contain the columns ``ID``,
    ``default.payment.next.month``, ``SEX``, ``AGE``, ``EDUCATION`` and
    ``MARRIAGE``; every other column is treated as a numerical feature.
    """

    def __init__(self, args):
        """Store the run configuration.

        Args:
            args: Configuration object. ``data_loaders`` reads ``args.root``
                and ``args.filename`` from it and forwards the whole object
                to ``RealNetDataset``.
        """
        self.args = args

    @staticmethod
    def _make_dense_encoder():
        """Return a ``OneHotEncoder`` that yields dense arrays.

        scikit-learn renamed the ``sparse`` keyword to ``sparse_output`` in
        1.2 and removed the old name in 1.4, so the new spelling is tried
        first and the legacy one is used only when it is rejected.
        """
        try:
            return OneHotEncoder(sparse_output=False)
        except TypeError:
            # Older scikit-learn (< 1.2) does not know `sparse_output`.
            return OneHotEncoder(sparse=False)

    def _load_data(self, path):
        """Read the CSV at *path* and build the feature/label tensors.

        Args:
            path: Location of the CSV file.

        Returns:
            A tuple ``(node_features, z_features, sa, node_labels)`` of
            float tensors:

            * ``node_features`` - standardized numerical columns (everything
              except ``ID``, ``SEX``, ``AGE``, ``Target``, ``EDUCATION`` and
              ``MARRIAGE``).
            * ``z_features`` - one-hot encoded ``EDUCATION``/``MARRIAGE``
              followed by the standardized ``AGE`` column.
            * ``sa`` - the remapped ``SEX`` column as a column vector.
            * ``node_labels`` - the ``Target`` column.
        """
        frame = pd.read_csv(path)
        frame = frame.drop(['ID'], axis=1)
        frame = frame.rename(columns={'default.payment.next.month': 'Target'})

        # Remap SEX via 2 - SEX (1 -> 1, 2 -> 0).
        # NOTE(review): presumably the raw column is coded {1, 2}; confirm.
        frame['SEX'] = 2 - frame['SEX']
        sa = frame['SEX'].values
        frame = frame.drop(['SEX'], axis=1)

        categorical_list = ['EDUCATION', 'MARRIAGE']
        features = frame.drop(['Target', 'AGE'], axis=1)

        # One-hot encode the categorical columns into a dense array.
        categorical_features = self._make_dense_encoder().fit_transform(
            features[categorical_list].values
        )

        # Standardize the remaining (numerical) columns; these become the
        # node features.
        node_features = StandardScaler().fit_transform(
            features.drop(categorical_list, axis=1)
        )

        # AGE is standardized on its own (with its own scaler) and appended
        # after the one-hot block.
        age = frame['AGE'].values.reshape(-1, 1)
        age = StandardScaler().fit_transform(age)
        z_features = np.concatenate([categorical_features, age], axis=1)

        node_labels = frame['Target'].values

        sa = torch.FloatTensor(sa).view(-1, 1)
        node_features = torch.FloatTensor(node_features)
        z_features = torch.FloatTensor(z_features)
        node_labels = torch.FloatTensor(node_labels)

        return node_features, z_features, sa, node_labels

    def data_loaders(self, **kwargs):
        """Build the graph dataset and return its first element.

        The CSV is read from ``<args.root>/raw/<args.filename>``. The
        resulting tensors are handed to ``RealNetDataset`` together with a
        ``pre_transform`` composed of ``RemoveIsolatedNodes`` and a
        ``RandomNodeSplit`` configured with ``split='train_rest'``,
        ``num_val=0`` and ``num_test=0.2``.

        Args:
            **kwargs: Forwarded unchanged to ``RealNetDataset``.

        Returns:
            ``graphdataset[0]``, the first item of the constructed dataset.
        """
        path = os.path.join(self.args.root, 'raw', self.args.filename)
        node_features, z_features, sa, node_labels = self._load_data(path)

        pre_transform = transforms.Compose([
            transforms.RemoveIsolatedNodes(),
            transforms.RandomNodeSplit(split='train_rest', num_val=0, num_test=0.2),
        ])

        graphdataset = RealNetDataset(
            self.args,
            node_features,
            z_features,
            sa,
            node_labels,
            pre_transform=pre_transform,
            **kwargs,
        )

        return graphdataset[0]
    
        
        
        
        
        