import numpy as np
import pandas as pd
import torch
from torch.utils.data import TensorDataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler, OneHotEncoder
from copy import deepcopy


def load_data(source, target, batch_size, seed, train_drop_last=True, return_dset=False):
    """Load the all/train/val/test splits of a source and a target table dataset.

    Every split is produced by ``TableData(name, seed, mode, batch_size, ...)``.
    Calls are made in the order source all/train/val/test, then target
    all/train/val/test.

    Args:
        source: Source dataset name, e.g. ``'Dutch_0'`` (dataset prefix plus the
            sensitive-attribute value that selects the sub-population).
        target: Target dataset name, same format as ``source``.
        batch_size: Batch size forwarded to ``TableData``.
        seed: Seed forwarded to ``TableData`` (controls the train/test permutation).
        train_drop_last: Forwarded to ``TableData``.
        return_dset: If False, return data loaders; if True, return raw tensors.

    Returns:
        If ``return_dset`` is False, a 5-tuple:
            ``(source_loaders, target_loaders, source_categorical_id_sets,
            target_categorical_id_sets, (input_dim, all_weights))``.
            Each ``*_loaders`` / ``*_categorical_id_sets`` entry is a 4-tuple
            ordered (all, train, val, test). ``all_weights`` is the weight
            vector returned for the *target* train split.
        If ``return_dset`` is True, a 3-tuple ``(source_dsets, target_dsets, input_dim)``.
            Each ``*_dsets`` is the flat 8-tuple
            ``(all_features, all_labels, train_features, train_labels,
            val_features, val_labels, test_features, test_labels)``.

    Raises:
        ValueError: If ``source`` is not a Dutch dataset. No input dimension is
            known for any other source, so the function would otherwise fail
            with ``UnboundLocalError``.
    """
    print('[Info] Loading data')

    if 'Dutch' not in source:
        raise ValueError(f"Unsupported source dataset {source!r}: only 'Dutch' datasets are supported")
    # Hard-coded input dimension for the Dutch dataset; presumably equals the
    # number of feature columns produced by TableData -- TODO confirm.
    input_dim = 58

    modes = ('all', 'train', 'val', 'test')

    def _load_splits(name):
        # One TableData call per mode, in the canonical (all, train, val, test) order.
        return [
            TableData(name, seed, mode, batch_size, train_drop_last=train_drop_last, return_dset=return_dset)
            for mode in modes
        ]

    source_results = _load_splits(source)
    target_results = _load_splits(target)

    if not return_dset:
        # Each result is (dataloader, all_weights, categorical_id_set).
        source_loaders = tuple(result[0] for result in source_results)
        target_loaders = tuple(result[0] for result in target_results)
        source_categorical_id_sets = tuple(result[2] for result in source_results)
        target_categorical_id_sets = tuple(result[2] for result in target_results)
        # Weights from the target train split (the value the original returned).
        all_weights = target_results[modes.index('train')][1]
        return source_loaders, target_loaders, source_categorical_id_sets, target_categorical_id_sets, (input_dim, all_weights)

    # Each result is (features, labels); flatten to (f_all, l_all, f_train, l_train, ...).
    source_dsets = tuple(item for result in source_results for item in result)
    target_dsets = tuple(item for result in target_results for item in result)
    return source_dsets, target_dsets, input_dim


def _load_dutch(seed):
    """Read ``data/dutch.csv``, one-hot encode it, split it, and min-max scale the features.

    Returns:
        ``(splits, all_weights, categorical_ids)`` where
        - ``splits`` maps ``'train'``, ``'val'``, ``'test'`` to
          ``(features, sensitives, labels)`` torch tensors. Features are float32;
          sensitives (1.0 where ``sex == 'male'``) and labels are float64.
        - ``all_weights`` is a float numpy array with one entry per feature
          column: 1.0 for continuous columns, 1/k for each of the k one-hot
          columns of a categorical feature.
        - ``categorical_ids`` maps each categorical feature name to the list of
          its one-hot column indices in the feature matrix.
    """
    # Seeds the global NumPy RNG; the train/test permutation below draws from it.
    np.random.seed(seed)
    raw_data = pd.read_csv('data/' + 'dutch.csv')

    categorical_features = ['age', 'household_position', 'household_size', 'citizenship', 'country_birth', 'edu_level',
                            'economic_status', 'cur_eco_activity', 'marital_status']

    # One-hot encode every categorical column. Column names follow the encoder's
    # own category order (``categories_``) so that each name labels the right
    # column. The default (sparse) output is densified with ``toarray()``, which
    # works both before and after scikit-learn renamed ``sparse`` to ``sparse_output``.
    new_categorical_data = {}
    for feature in categorical_features:
        oh_enc = OneHotEncoder()
        onehot = oh_enc.fit_transform(raw_data[[feature]]).toarray()
        for i, value in enumerate(oh_enc.categories_[0]):
            new_categorical_data[f'{feature}={value}'] = onehot[:, i]
    raw_data = pd.concat([raw_data, pd.DataFrame(new_categorical_data)], axis=1)
    raw_data = raw_data.drop(columns=categorical_features)

    # Model inputs exclude the sensitive attribute ('sex') and the label ('occupation').
    feature_data = raw_data.drop(columns=['sex', 'occupation'])
    feature_data['prev_residence_place'] -= 1

    all_features = feature_data.to_numpy(dtype=float)
    all_sensitives = (raw_data['sex'] == 'male').to_numpy(dtype=float)
    all_labels = raw_data['occupation'].to_numpy(dtype=float)

    # Group one-hot columns by the categorical feature they came from ("feature=value").
    categorical_ids = {feature: [] for feature in categorical_features}
    for i, feature_name in enumerate(feature_data.columns):
        prefix = feature_name.split('=')[0]
        if prefix in categorical_ids:
            categorical_ids[prefix].append(i)
    categorical_column_count = sum(len(ids) for ids in categorical_ids.values())
    assert categorical_column_count == len({i for ids in categorical_ids.values() for i in ids})

    # Each categorical feature contributes total weight 1 spread over its one-hot columns.
    all_weights = np.ones(all_features.shape[1], dtype=float)
    for ids in categorical_ids.values():
        all_weights[ids] *= 1.0 / len(ids)

    # Fixed-size split: the first 48336 permuted rows are train(+val), the rest test.
    random_ids = np.random.permutation(len(raw_data))
    train_ids, test_ids = random_ids[:48336], random_ids[48336:]

    train_features, train_sensitives, train_labels = all_features[train_ids], all_sensitives[train_ids], all_labels[train_ids]
    test_features, test_sensitives, test_labels = all_features[test_ids], all_sensitives[test_ids], all_labels[test_ids]
    # The validation split uses a fixed random_state, independent of `seed`.
    train_features, val_features, train_sensitives, val_sensitives, train_labels, val_labels = train_test_split(
        train_features, train_sensitives, train_labels, test_size=0.2, random_state=2022)

    # Fit the scaler on train only, then apply it to every split.
    scaler = MinMaxScaler()
    scaler.fit(train_features)
    train_features = scaler.transform(train_features)
    val_features = scaler.transform(val_features)
    test_features = scaler.transform(test_features)
    assert (round(train_features.max(), 4) == 1.0) and (round(train_features.min(), 4) == 0.0)

    splits = {
        'train': (torch.from_numpy(train_features).float(), torch.from_numpy(train_sensitives), torch.from_numpy(train_labels)),
        'val': (torch.from_numpy(val_features).float(), torch.from_numpy(val_sensitives), torch.from_numpy(val_labels)),
        'test': (torch.from_numpy(test_features).float(), torch.from_numpy(test_sensitives), torch.from_numpy(test_labels)),
    }
    assert splits['train'][0].size(0) == splits['train'][2].size(0)
    return splits, all_weights, categorical_ids


def TableData(name, seed, mode, batch_size, train_drop_last=True, return_dset=False):
    """Build one split of a table dataset, restricted to one sensitive-attribute group.

    Args:
        name: ``'<Dataset>_<sensitive value>'``, e.g. ``'Dutch_1'``. Only the
            ``Dutch`` dataset is supported. The suffix is compared against the
            sensitive column to select rows.
        seed: Seed for the train/test permutation.
        mode: One of ``'all'``, ``'train'``, ``'val'``, ``'test'``. ``'all'``
            concatenates train, val and test (in that order).
        batch_size: DataLoader batch size.
        train_drop_last: If True, only the train loader shuffles and drops the
            last partial batch. If False, every loader shuffles and none drops.
        return_dset: If True, return ``(features, labels)`` tensors instead of a loader.

    Returns:
        ``(dataloader, all_weights, categorical_id_lists)``, or ``(features, labels)``
        when ``return_dset`` is True.

    Raises:
        ValueError: For an unsupported dataset or an unknown ``mode``.
    """
    valid_modes = ('all', 'train', 'val', 'test')
    dataset_name = name.split('_')[0]
    if dataset_name != 'Dutch':
        raise ValueError(f"Unsupported dataset {dataset_name!r} (from name {name!r}); only 'Dutch' is supported")
    if mode not in valid_modes:
        raise ValueError(f"Unknown mode {mode!r}; expected one of {valid_modes}")
    # Parse the group selector before the (expensive) load so bad names fail fast.
    sensitive_id = int(name.split('_')[1])

    splits, all_weights, categorical_ids = _load_dutch(seed)

    def _select(split):
        # Keep only the rows whose sensitive attribute equals the requested group.
        split_features, split_sensitives, split_labels = splits[split]
        mask = split_sensitives == sensitive_id
        return split_features[mask], split_labels[mask]

    if mode == 'all':
        parts = [_select(split) for split in ('train', 'val', 'test')]
        features = torch.cat([part[0] for part in parts])
        labels = torch.cat([part[1] for part in parts])
    else:
        features, labels = _select(mode)

    dset = TensorDataset(features, labels)
    print(f'[Info] {name} {mode} data sample size: {len(dset)}')

    if return_dset:
        return features, labels

    if train_drop_last:
        is_train = mode == 'train'
        dloader = DataLoader(dset, batch_size, num_workers=4, shuffle=is_train, drop_last=is_train)
    else:
        # Every split (including val/test) is shuffled in this configuration.
        dloader = DataLoader(dset, batch_size, num_workers=4, shuffle=True, drop_last=False)
    return dloader, all_weights, list(categorical_ids.values())