import pandas as pd
import numpy as np
from util.data_util import NetDataset, RealNetDataset
import torch
from sklearn.preprocessing import StandardScaler, OneHotEncoder, MinMaxScaler
import os
import torch_geometric.transforms as transforms

from sklearn.utils import resample 


class GermanReal:
    """Build graph-dataset inputs from the German credit CSV.

    ``_load_data`` turns the CSV into four float tensors:
    - node features;
    - auxiliary ``z`` features (``Z_COLUMNS``);
    - ``sa``, the Gender column encoded as 0/1 (presumably the sensitive
      attribute);
    - binary labels taken from ``GoodCustomer``.

    ``data_loaders`` wraps those tensors in a ``RealNetDataset``.
    """

    # Columns split out of the node features into the separate ``z`` tensor.
    Z_COLUMNS = ['YearsAtCurrentJob_lt_1', 'YearsAtCurrentJob_geq_4', 'JobClassIsSkilled']
    # Columns never used as node features.
    DROPPED_COLUMNS = ['OtherLoansAtStore', 'PurposeOfLoan']
    # Column holding Gender strings; encoded into ``sa``.
    SENSITIVE_COLUMN = 'Gender'
    # Target column; values of -1 are remapped to 0.
    LABEL_COLUMN = 'GoodCustomer'

    def __init__(self, args):
        """Store ``args``.

        ``data_loaders`` reads ``args.root`` and ``args.filename``. The whole
        object is also forwarded to ``RealNetDataset``.
        """
        self.args = args

    def _feature_norm(self, features):
        """Min-max scale each column of ``features`` into [-1, 1].

        ``features`` may be a torch tensor or anything ``torch.as_tensor``
        accepts. The result is a torch tensor. A constant column, where
        max == min, maps to -1 instead of producing NaN from a zero division.
        """
        features = torch.as_tensor(features)
        min_values = features.min(dim=0).values
        max_values = features.max(dim=0).values
        span = max_values - min_values
        # Avoid 0/0 for constant columns; their numerator is already 0.
        span = torch.where(span == 0, torch.ones_like(span), span)
        return 2 * (features - min_values) / span - 1

    def _load_data(self, path):
        """Read the CSV at ``path`` and return standardized float tensors.

        Returns:
            node_features: (n_rows, n_features) float tensor. It holds every
                column except ``DROPPED_COLUMNS``, ``SENSITIVE_COLUMN``,
                ``LABEL_COLUMN`` and ``Z_COLUMNS``, z-scored per column.
            z_features: (n_rows, len(Z_COLUMNS)) float tensor, z-scored per
                column.
            sa: (n_rows, 1) float tensor with 'Female' -> 0.0 and 'Male' -> 1.0.
                Values that are neither string and not float-convertible raise
                ``ValueError``.
            node_labels: (n_rows,) float tensor of ``LABEL_COLUMN`` with -1
                mapped to 0.
        """
        df = pd.read_csv(path)

        # The label column must be excluded from the inputs; otherwise the
        # target leaks straight into node_features.
        excluded = self.DROPPED_COLUMNS + [self.SENSITIVE_COLUMN, self.LABEL_COLUMN]
        feature_df = df.drop(excluded, axis=1)
        node_features = feature_df.drop(list(self.Z_COLUMNS), axis=1).to_numpy()
        z_features = df[list(self.Z_COLUMNS)].to_numpy()

        # Copy so the in-place remap never writes back into ``df``. The copy
        # also keeps this working when pandas returns a read-only view.
        node_labels = df[self.LABEL_COLUMN].to_numpy(copy=True)
        node_labels[node_labels == -1] = 0

        # Work on an object array so assigning floats cannot clash with a
        # string column dtype. Any other non-numeric value makes astype raise.
        gender = df[self.SENSITIVE_COLUMN].to_numpy(dtype=object, copy=True)
        gender[gender == 'Female'] = 0.0
        gender[gender == 'Male'] = 1.0
        sa = gender.astype(float)

        # Independent scalers: each fit_transform learns its own mean/std.
        node_features = StandardScaler().fit_transform(node_features)
        z_features = StandardScaler().fit_transform(z_features)

        return (
            torch.FloatTensor(node_features),
            torch.FloatTensor(z_features),
            torch.FloatTensor(sa).unsqueeze(1),
            torch.FloatTensor(node_labels),
        )

    def data_loaders(self, **kwargs):
        """Load ``<args.root>/raw/<args.filename>`` into a ``RealNetDataset``.

        Extra keyword arguments are forwarded to ``RealNetDataset``.

        Returns:
            ``graphdataset[0]``, the first item of the constructed dataset.
            This is presumably a single PyG graph object; confirm against
            ``RealNetDataset``.
        """
        path = os.path.join(self.args.root, 'raw', self.args.filename)
        node_features, z_features, sa, node_labels = self._load_data(path)

        # Passed as the dataset's pre_transform:
        # 1. drop isolated nodes;
        # 2. randomly split nodes into 20% test and the rest train, with no
        #    validation nodes.
        pre_transform = transforms.Compose([
            transforms.RemoveIsolatedNodes(),
            transforms.RandomNodeSplit(split='train_rest', num_val=0, num_test=0.2),
        ])

        graphdataset = RealNetDataset(
            self.args,
            node_features,
            z_features,
            sa,
            node_labels,
            pre_transform=pre_transform,
            **kwargs,
        )
        return graphdataset[0]
    
        
        
        
        
        