from paper import alfr_ds
import torch
import torch.utils.data as data_utils
import pandas as pd
from sklearn.preprocessing import StandardScaler, MinMaxScaler
from sklearn.ensemble import GradientBoostingClassifier
from aif360.algorithms.preprocessing import DisparateImpactRemover
from aif360.algorithms.preprocessing import LFR
from aif360.datasets.adult_dataset import AdultDataset
from aif360.datasets.bank_dataset import BankDataset
from aif360.datasets.compas_dataset import CompasDataset
from aif360.datasets.german_dataset import GermanDataset
from collections import defaultdict
from fairlearn.metrics import demographic_parity_difference
from sklearn.metrics import balanced_accuracy_score
import argparse


def debias_paper(train, test, protected_attribute_idx=0, min_max=True, perform_dampening=True,
                 stacking=3, adversary_hidden_size='auto', latent_dimensions='auto'):
    """
    Train an ALFR-DS(n) / ALFR-S(n) encoder on the train set and encode both the train and the test set with it.
    :param train: The train AIF360 set.
    :param test: The test AIF360 set.
    :param protected_attribute_idx: Index into the protected attribute names of the AIF360 dataset.
    :param min_max: If True, inputs are min-max scaled to [-1, 1] and a 'tanh' output activation is requested;
        otherwise inputs are standardised and no output activation is requested.
    :param perform_dampening: Whether to perform dampening (ALFR-DS(n)) or not (ALFR-S(n)).
    :param stacking: How many stacks to perform.
    :param adversary_hidden_size: Hidden size of the adversary, or 'auto' to derive it from the column count.
    :param latent_dimensions: Size of the latent space, or 'auto' to derive it from the column count.
    :return: Tuple (train, test) of dataframes with the encoded features, indexed like the input sets.
    """
    frame_train = train.convert_to_dataframe()[0]
    frame_test = test.convert_to_dataframe()[0]

    sensitive_column = train.protected_attribute_names[protected_attribute_idx]
    label_column = train.label_names[0]
    # Neither the protected attribute nor the target may be fed to the encoder.
    excluded_columns = [sensitive_column, label_column]

    # 'auto' heuristic: a quarter of the column count (minus one), but never fewer than 16 units.
    if adversary_hidden_size == 'auto':
        adversary_hidden_size = max((len(frame_train.columns) - 1) // 4, 16)
    if latent_dimensions == 'auto':
        latent_dimensions = max((len(frame_train.columns) - 1) // 4, 16)

    # One epoch budget per stack: the 30 epochs are shared between stacks, with at least 10 each.
    epochs = [max(10, 30 // stacking) for _ in range(stacking)]

    # Min-max scaling keeps the inputs inside the range of the tanh output activation.
    scaler = MinMaxScaler(feature_range=(-1, 1)) if min_max else StandardScaler()
    features_train = scaler.fit_transform(frame_train.drop(excluded_columns, axis=1).values)
    labels_train = frame_train[[label_column]].values
    features_test = scaler.transform(frame_test.drop(excluded_columns, axis=1).values)

    tensor_dataset = data_utils.TensorDataset(torch.Tensor(features_train), torch.Tensor(labels_train))
    loader = data_utils.DataLoader(tensor_dataset, batch_size=32, shuffle=True)

    # Only the encoder is needed to build the new representation; the other return values are discarded.
    encoder, _, _ = alfr_ds(
        loader,
        adversary_hidden_size=adversary_hidden_size,
        epochs=epochs,
        alpha=None if perform_dampening else 1,
        latent_dimensions=latent_dimensions,
        output_activation='tanh' if min_max else 'none',
    )

    def encode(features, index):
        # Run the features through the encoder and wrap the result in a dataframe with the given index.
        return pd.DataFrame(encoder(torch.Tensor(features)).detach().numpy(), index=index)

    return encode(features_train, frame_train.index), encode(features_test, frame_test.index)


def test_fairness_model(train, test, preprocess_model, scale=False):
    """
    Method which tests an AIF360 fairness model.

    The preprocessing model is fitted on (a copy of) the train set, both sets are transformed, and a
    GradientBoostingClassifier with default parameters is trained on the transformed train set.

    :param train: The train AIF360 set (copied, not modified).
    :param test: The test AIF360 set (copied, not modified).
    :param preprocess_model: The preprocessing AIF360 model to use.
    :param scale: Whether we want to standardise the features first (used for LFR).
    :return: Predictions for the test set. This is the classifier's prediction array, or a list holding one
        constant label when the classifier cannot be fitted (e.g. only one label left after preprocessing).
    """
    train = train.copy()
    test = test.copy()

    # Target is always the 0th index
    target = train.label_names[0]

    if scale:
        scaler = StandardScaler()
        train.features = scaler.fit_transform(train.features)
        test.features = scaler.transform(test.features)

    # Fit+transform preprocess and transform test
    train_transformed = preprocess_model.fit_transform(train)
    try:
        test_transformed = preprocess_model.transform(test)
    except NotImplementedError:
        # The model offers no separate transform: re-fit on the test set, assuming the same distribution.
        test_transformed = preprocess_model.fit_transform(test)

    # Get the dataframe
    df_train_transformed = train_transformed.convert_to_dataframe()[0]
    df_test_transformed = test_transformed.convert_to_dataframe()[0]

    # Extract X, y
    X_train_transformed = df_train_transformed.drop(target, axis=1)
    y_train_transformed = df_train_transformed[target]
    X_test_transformed = df_test_transformed.drop(target, axis=1)

    # Train a new gb model with default params
    gb = GradientBoostingClassifier()
    try:
        # Instance weights are passed on because preprocessors may express their effect through them.
        gb.fit(X_train_transformed, y_train_transformed, sample_weight=train_transformed.instance_weights)
        return gb.predict(X_test_transformed)
    except ValueError:
        # Raised when the transformed training labels contain a single class. Only ValueError is caught so
        # that interrupts and programming errors are no longer silently turned into constant predictions.
        # Positional access is required: the frame's index is not guaranteed to contain the label 0.
        fallback_label = y_train_transformed.iloc[0]
        return [fallback_label] * len(X_test_transformed)


def _fit_and_predict(X_train, y_train, X_test):
    """Fit a GradientBoostingClassifier with default parameters and return its predictions for X_test."""
    gb = GradientBoostingClassifier()
    gb.fit(X_train, y_train)
    return gb.predict(X_test)


def run_experiment(dataset_name, repeat=10) -> None:
    """
    Method that runs an AIF360 experiment as described in the paper "Adversarial Learned Fair Representations using
    Dampening and Stacking". For the given dataset it performs `repeat` runs for each protected variable, each on a
    fresh shuffled 80/20 split, and saves one result file per run to experiments/aif360.

    :param dataset_name: Dataset name (adult/bank/compas/german).
    :param repeat: Number of times to repeat (default is 10)
    :return: None
    :raises ValueError: If the dataset name is not one of the supported names.
    """
    import os  # local import: only needed to make sure the output directory exists

    if dataset_name == 'adult':
        dataset = AdultDataset()
    elif dataset_name == 'bank':
        dataset = BankDataset()
    elif dataset_name == 'compas':
        dataset = CompasDataset()
    elif dataset_name == 'german':
        dataset = GermanDataset()
    else:
        raise ValueError('Dataset name not supported in experiments.')

    # Create the output directory up front so a long run cannot fail at the very last step.
    output_dir = 'experiments/aif360'
    os.makedirs(output_dir, exist_ok=True)

    for protected_attribute_idx in range(len(dataset.protected_attribute_names)):
        protected_attribute = dataset.protected_attribute_names[protected_attribute_idx]
        target = dataset.label_names[0]

        # `run` is deliberately a separate name: the loop must not overwrite the `repeat` parameter.
        for run in range(1, repeat + 1):

            # Convert train/test
            train, test = dataset.split([0.8], shuffle=True)

            # Convert to dataframe (useful for all sorts of calculations)
            train_df = train.convert_to_dataframe()[0]
            test_df = test.convert_to_dataframe()[0]

            # Keep track of the idx, protected attribute and true label of every test instance
            result = {
                'idx': test_df.index.values,
                'protected': test_df[protected_attribute].values,
                'target': test_df[target].values,
            }

            print(f'Dataset: {dataset_name}, Protected attribute: {protected_attribute}, Target: {target}')

            y_train = train_df[target]

            # 1. Uncensored: all features, including the protected attribute
            print('1. Uncensored')
            result['uncensored'] = _fit_and_predict(
                train_df.drop(target, axis=1), y_train, test_df.drop(target, axis=1))

            # 2. ALFR-S (stacking without dampening)
            for stack in range(1, 4):
                print(f'2.{stack} ALFR_S({stack})')
                X_train_fair, X_test_fair = debias_paper(train, test, protected_attribute_idx, stacking=stack,
                                                         perform_dampening=False)
                result[f'alfrs{stack}'] = _fit_and_predict(X_train_fair, y_train, X_test_fair)

            # 3. ALFR-DS (stacking with dampening)
            for stack in range(1, 4):
                print(f'3.{stack}. ALFR_DS({stack})')
                X_train_fair, X_test_fair = debias_paper(train, test, protected_attribute_idx, stacking=stack,
                                                         perform_dampening=True)
                result[f'alfrds{stack}'] = _fit_and_predict(X_train_fair, y_train, X_test_fair)

            # 4. LFR
            print('4.1 LFR')
            result['lfr'] = test_fairness_model(
                train, test,
                LFR(
                    privileged_groups=[
                        {protected_attribute: dataset.privileged_protected_attributes[protected_attribute_idx]}],
                    unprivileged_groups=[
                        {protected_attribute: dataset.unprivileged_protected_attributes[protected_attribute_idx]}]
                ),
                scale=True
            )

            # 5. DisparateImpactRemover
            print('5. DisparateImpactRemover')
            result['dir'] = test_fairness_model(
                train, test,
                DisparateImpactRemover(sensitive_attribute=protected_attribute),
            )

            output_path = f'{output_dir}/{dataset_name}_{protected_attribute}_{run}.pkl'
            print(f'Saving "{output_path}"')
            pd.DataFrame(result).to_pickle(output_path)


def collect_results_for_dataset(dataset, protected_variable, repeat=10):
    """
    Method that collects all results (of all runs) for a combination of dataset and protected variable.

    For every run file and every known algorithm column in it, one row is produced with the balanced accuracy
    and the demographic parity difference of that algorithm's predictions.

    :param dataset: Dataset (bank/compas/adult/german)
    :param protected_variable: Protected variable.
    :param repeat: Number of runs to collect; the files of runs 1..repeat are read.
    :return: Combined dataframe with the columns 'algo', 'run', '% BA' and 'DP' (empty if repeat < 1).
    """
    # Result column -> display name. Columns that are not listed here are ignored.
    rename = {
        'lfr': 'LFR', 'alfrs1': 'ALFR-S(1)', 'alfrs2': 'ALFR-S(2)', 'alfrs3': 'ALFR-S(3)',
        'dir': 'DIR', 'uncensored': 'Uncensored',
        'alfrds1': 'ALFR-DS(1)', 'alfrds2': 'ALFR-DS(2)', 'alfrds3': 'ALFR-DS(3)'
    }

    data = defaultdict(list)
    for run in range(1, repeat + 1):
        # NOTE: pickle files are only safe to load when they come from a trusted source; these are expected
        # to be the files written by the experiment runs themselves.
        result_df = pd.read_pickle(f'experiments/aif360/{dataset}_{protected_variable}_{run}.pkl')

        y_true = result_df['target']
        sensitive = result_df['protected']

        # Walk the columns in file order (not through a set) so the row order is reproducible.
        for column in result_df.columns:
            if column not in rename:
                continue
            y_pred = result_df[column]

            data['algo'].append(rename[column])
            data['run'].append(run)
            data['% BA'].append(balanced_accuracy_score(y_true, y_pred))
            data['DP'].append(demographic_parity_difference(y_true, y_pred, sensitive_features=sensitive))

    return pd.DataFrame(data)


def get_all_combined_results(all_datasets):
    """
    Method that combines all results into one big table that is shown in the paper.

    Per dataset / protected variable the metrics are averaged over the runs, multiplied by 100 and rounded
    to integers. The table has one row per algorithm and a three-level column index
    (Dataset, Protected Variable, Metric) with the metrics 'BA' and 'DP'.

    :param all_datasets: Dictionary mapping each dataset name to an iterable of its protected variables.
    :return: Dataframe with all results.
    :raises ValueError: If there is no dataset / protected variable combination to combine.
    """
    # Source column of collect_results_for_dataset -> metric label used in the table. The second label is
    # 'DP' (demographic parity), matching the column it is computed from.
    metric_labels = {'% BA': 'BA', 'DP': 'DP'}

    all_results = []
    column_tuples = []
    for dataset, protected_variables in all_datasets.items():
        for protected_variable in protected_variables:
            # Collect results
            results = collect_results_for_dataset(dataset, protected_variable)
            means = results.drop('run', axis=1).groupby('algo').mean()
            # Select the metrics explicitly so the labels below cannot get out of step with the data.
            combined = (means[list(metric_labels)] * 100).round(0).astype(int)

            all_results.append(combined.transpose())
            # Build the labels in the same pass as the data, one tuple per metric row appended above.
            column_tuples.extend((dataset, protected_variable, label) for label in metric_labels.values())

    if not all_results:
        raise ValueError('No dataset / protected variable combinations to combine.')

    # Compute new results and set index
    all_results = pd.concat(all_results).transpose()
    all_results.columns = pd.MultiIndex.from_tuples(
        column_tuples, names=['Dataset', 'Protected Variable', 'Metric'])
    return all_results


def parse_bool_flag(value):
    """
    Convert a command-line value such as 'True', 'false', '1' or 'no' to a bool.

    Needed because ``type=bool`` simply calls ``bool()`` on the argument string, which is True for every
    non-empty string, including 'False'.

    :param value: The raw argument string (a bool is passed through unchanged).
    :return: The parsed boolean.
    :raises argparse.ArgumentTypeError: If the value is not a recognised boolean spelling.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ('true', 't', 'yes', 'y', '1'):
        return True
    # The empty string stays False, as it was with type=bool.
    if normalized in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(f'Expected a boolean value, got {value!r}.')


if __name__ == "__main__":
    # Create argument parser
    parser = argparse.ArgumentParser(description='Run AIF360 study.')
    parser.add_argument('--run_adult', default=True, type=parse_bool_flag,
                        help='Whether to test on Adult dataset.')
    parser.add_argument('--run_german', default=True, type=parse_bool_flag,
                        help='Whether to test on German dataset.')
    parser.add_argument('--run_bank', default=True, type=parse_bool_flag,
                        help='Whether to test on Bank dataset.')
    parser.add_argument('--run_compas', default=True, type=parse_bool_flag,
                        help='Whether to test on COMPAS dataset.')

    args = parser.parse_args()

    # Dataset name -> protected variables. Tuples (not sets) keep the column order of the final table
    # identical from one invocation to the next.
    all_datasets = dict()
    if args.run_adult:
        all_datasets['adult'] = ('race', 'sex')
    if args.run_german:
        all_datasets['german'] = ('age', 'sex')
    if args.run_bank:
        all_datasets['bank'] = ('age',)
    if args.run_compas:
        all_datasets['compas'] = ('race', 'sex')

    for dataset in all_datasets:
        print(f'Running experiment {dataset}')
        run_experiment(dataset)

    all_results = get_all_combined_results(all_datasets)
    print('Saving results!')
    all_results.to_pickle('experiments/aif360/results.pkl')
    print('Saving LateX table for paper')
    with open('experiments/aif360/results_latex.txt', 'w') as f:
        f.write(all_results.to_latex())
