import pandas as pd
import numpy as np
from util.data_util import NetDataset, RealNetDataset
import torch
from sklearn.preprocessing import StandardScaler, OneHotEncoder, MinMaxScaler
import os
import torch_geometric.transforms as transforms

from sklearn.utils import resample


class Credit:
    """Loader that turns the credit-default CSV into node tensors and a graph dataset."""

    def __init__(self, args):
        """Store the experiment configuration.

        Args:
            args: configuration object; ``data_loaders`` reads ``args.root`` and
                ``args.filename``, and the whole object is forwarded to ``NetDataset``.
        """
        self.args = args

    @staticmethod
    def _split_sensitive(df):
        """Clean the raw frame and separate out the sensitive attribute.

        Drops ``ID``, renames ``default.payment.next.month`` to ``Target`` and
        removes ``SEX`` from the frame. The input frame is not modified.

        Args:
            df: raw DataFrame as read from the CSV.

        Returns:
            (frame, sa): the cleaned DataFrame (still containing ``Target``) and a
            1-D NumPy array holding ``2 - SEX``.
        """
        df = df.drop(columns=['ID'])
        df = df.rename(columns={'default.payment.next.month': 'Target'})
        # 2 - SEX maps raw code 1 -> 1 and 2 -> 0. Presumably 1 = male and
        # 2 = female in the source data -- TODO confirm against the dataset docs.
        sa = (2 - df['SEX']).to_numpy()
        return df.drop(columns=['SEX']), sa

    @staticmethod
    def _make_one_hot_encoder():
        """Return a OneHotEncoder with dense output on any scikit-learn version."""
        try:
            # scikit-learn >= 1.2 spells the option `sparse_output`; the old
            # `sparse` keyword was removed in 1.4 and raises TypeError there.
            return OneHotEncoder(sparse_output=False)
        except TypeError:
            # scikit-learn < 1.2 only accepts the old keyword.
            return OneHotEncoder(sparse=False)

    def _load_data(self, path):
        """Read the CSV at *path* and build float32 tensors.

        Args:
            path: filesystem path of the credit CSV.

        Returns:
            Tuple ``(node_features, z_features, sa, node_labels)``:
              - node_features: 2-D tensor; one-hot ``EDUCATION``/``MARRIAGE``
                columns followed by the standardized remaining feature columns.
              - z_features: tensor of shape (N, 1); standardized ``AGE``.
              - sa: tensor of shape (N, 1); the binarized ``SEX`` column.
              - node_labels: 1-D tensor; the ``Target`` column.
        """
        df, sa = self._split_sensitive(pd.read_csv(path))

        features = df.drop(columns=['Target'])
        categorical_list = ['EDUCATION', 'MARRIAGE']
        categorical_features = self._make_one_hot_encoder().fit_transform(
            features[categorical_list].values
        )

        # Independent scalers: each fitted object only ever describes the data
        # it transformed (the original reused and silently refit one scaler).
        numerical_features = StandardScaler().fit_transform(
            features.drop(columns=categorical_list)
        )
        node_features = np.concatenate([categorical_features, numerical_features], axis=1)

        # AGE is also part of numerical_features; here it is kept separately as z.
        z_features = StandardScaler().fit_transform(df['AGE'].values.reshape(-1, 1))
        node_labels = df['Target'].values

        def to_float(x):
            return torch.as_tensor(x, dtype=torch.float32)

        return (
            to_float(node_features),
            to_float(z_features),
            to_float(sa).view(-1, 1),
            to_float(node_labels),
        )

    def data_loaders(self, **kwargs):
        """Build the graph dataset from ``<root>/raw/<filename>`` and return its first item.

        A pre-transform removes isolated nodes and then randomly splits nodes:
        20% test, no validation, the rest train.

        Args:
            **kwargs: forwarded unchanged to ``NetDataset``.

        Returns:
            ``graphdataset[0]``. Presumably a single graph ``Data`` object, since
            ``NetDataset`` is defined elsewhere -- verify.
        """
        path = os.path.join(self.args.root, 'raw', self.args.filename)
        node_features, z_features, sa, node_labels = self._load_data(path)

        pre_transform = transforms.Compose([
            transforms.RemoveIsolatedNodes(),
            transforms.RandomNodeSplit(split='train_rest', num_val=0, num_test=0.2),
        ])

        graphdataset = NetDataset(
            self.args, node_features, z_features, sa, node_labels,
            pre_transform=pre_transform, **kwargs,
        )
        return graphdataset[0]
    
        
        
        
        
        