import pandas as pd
import numpy as np

from copy import deepcopy


def pre_process_dataset(random_seed, dataset, DATA_SIZE, training=True):
    """Load and pre-process one of the supported tabular datasets.

    :param random_seed: Currently unused; kept for interface compatibility.
    :param dataset: One of ``'heloc'``, ``'adult'`` or ``'german_credit'``.
    :param DATA_SIZE: Row count passed through to the dataset-specific loader.
    :param training: Passed through to the loader; selects the training
        slice (True) or the held-out slice (False).
    :return: ``(X, y, means, std, encoder)`` exactly as returned by the loader.
    :raises TypeError: If ``dataset`` is not a recognised name (same error
        type and message as ``normalize``).
    """
    if dataset == 'heloc':
        return pre_process_heloc(DATA_SIZE, training)
    elif dataset == 'adult':
        return pre_process_adult(DATA_SIZE, training)
    elif dataset == 'german_credit':
        return pre_process_german_credit(DATA_SIZE, training)
    # Previously an unknown name fell through and crashed with UnboundLocalError.
    raise TypeError('Not a valid dataset name')


class ReversibleOneHotEncoder:
    """One-hot encoder for pandas DataFrames that can undo its own encoding.

    Dummy columns are named ``f"{column}_{category}"``, the naming used by
    ``pd.get_dummies(..., prefix=column)``.
    """

    def __init__(self):
        # column name -> list of categories observed in that column during ``fit``
        self.columns_info = {}
        # column order of the DataFrame passed to ``fit``; restored by ``inverse_transform``
        self.original_columns = None

    def fit(self, df, columns):
        """
        Fits the encoder to the DataFrame.

        Any state from a previous ``fit`` call is discarded.

        :param df: Input DataFrame.
        :param columns: List of columns to one-hot encode.
        """
        self.original_columns = df.columns.tolist()
        # Rebuild from scratch so a refit does not keep columns from an earlier fit.
        self.columns_info = {column: df[column].unique().tolist() for column in columns}

    def transform(self, df):
        """
        Transforms the DataFrame by one-hot encoding the fitted columns.

        Each fitted column is replaced by its dummy columns, which are appended
        at the end. Categories seen during ``fit`` but absent from ``df`` still
        get a dummy column filled with 0.

        :param df: Input DataFrame (not modified).
        :return: DataFrame with one-hot encoded columns.
        """
        df_encoded = df.copy()
        for column, categories in self.columns_info.items():
            dummies = pd.get_dummies(df_encoded[column], prefix=column)
            # Ensure all categories are present in the dummies DataFrame
            for category in categories:
                dummy_col = f"{column}_{category}"
                if dummy_col not in dummies.columns:
                    dummies[dummy_col] = 0
            df_encoded = pd.concat([df_encoded, dummies], axis=1)
            df_encoded.drop(column, axis=1, inplace=True)

        return df_encoded

    def inverse_transform(self, df_encoded):
        """
        Reverses the one-hot encoding.

        For every fitted column, the dummy column with the largest value in a
        row determines that row's category. Ties, including all-zero rows,
        resolve to the first fitted category.

        :param df_encoded: One-hot encoded DataFrame (not modified).
        :return: DataFrame with the categorical columns restored, in the column
            order seen by ``fit``.
        """
        df_decoded = df_encoded.copy()

        for column, categories in self.columns_info.items():
            dummy_to_category = {f"{column}_{cat}": cat for cat in categories}
            dummy_cols = list(dummy_to_category)
            # Cast to float: the dummies may mix bool columns (from get_dummies) with
            # int 0 columns (added for absent categories), which idxmax cannot handle
            # as an object block.
            winners = df_encoded[dummy_cols].astype(float).idxmax(axis=1)
            # Map the winning dummy name back to its category by lookup. Stripping the
            # prefix with str.replace also removed "<column>_" occurring inside category
            # names, and always produced strings.
            df_decoded[column] = winners.map(dummy_to_category)
            df_decoded.drop(columns=dummy_cols, inplace=True)

        return df_decoded[self.original_columns]




def normalize(data, dataset, means, std, encoder, reverse=False):
    """
    Standardise (or un-standardise) continuous features and one-hot encode
    (or decode) categorical ones.

    Forward (``reverse=False``): each continuous column at position ``i``
    becomes ``(x - means[i]) / std[i]``, then ``encoder.transform`` is applied.
    Reverse: ``encoder.inverse_transform`` restores the original columns first,
    then each continuous column becomes ``x * std[i] + means[i]``.

    :param data: DataFrame to convert; it is deep-copied, never modified.
    :param dataset: ``'heloc'``, ``'adult'`` or ``'german_credit'``; selects
        which column positions are categorical.
    :param means: Per-column means, indexed by column position.
    :param std: Per-column standard deviations, indexed by column position.
    :param encoder: Object providing ``transform`` / ``inverse_transform``.
    :param reverse: If True, undo the normalisation instead of applying it.
    :return: The converted DataFrame.
    :raises TypeError: If ``dataset`` is not a recognised name.
    """
    # dataset -> (positions of categorical columns, positions of all columns)
    layouts = {
        'heloc': ([], [0, 1, 2, 3]),
        'adult': ([0, 2, 3, 6, 7], [0, 1, 2, 3, 4, 5, 6, 7]),
        'german_credit': ([0, 2, 3], [0, 1, 2, 3, 4]),
    }
    if dataset not in layouts:
        raise TypeError('Not a valid dataset name')
    categorical_idxs, col_idxs = layouts[dataset]
    continuous_idxs = [i for i in col_idxs if i not in categorical_idxs]

    x = deepcopy(data)

    # Vectorised column arithmetic (same operations as the former row-wise
    # apply, but much faster and also valid on empty frames).
    if reverse:
        x = encoder.inverse_transform(x)
        for idx in continuous_idxs:
            name = x.columns[idx]
            x[name] = x[name].astype('float64') * std[idx] + means[idx]
    else:
        for idx in continuous_idxs:
            name = x.columns[idx]
            x[name] = (x[name].astype('float64') - means[idx]) / std[idx]
        x = encoder.transform(x)
    return x


def pre_process_heloc(DATA_SIZE, training):
    """Load the HELOC dataset restricted to four integer features.

    Rows with any negative value (in a feature or the label) are removed. Then
    either the first ``DATA_SIZE`` rows (``training=True``) or the following
    ``DATA_SIZE`` rows are kept.

    :param DATA_SIZE: Number of rows in each split.
    :param training: Select the training split (True) or the next one (False).
    :return: ``(X, y, means, std, encoder)``.
        - ``X``: float feature columns.
        - ``y``: pandas Series of 0./1. labels (Good = 1.), indexed like ``X``.
        - ``means`` / ``std``: per-column statistics of the selected rows.
        - ``encoder``: a ReversibleOneHotEncoder fitted with no categorical columns.
    """
    heloc_df = pd.read_csv('datasets/heloc_dataset_v1.csv')

    feature_names = ['MSinceMostRecentInqexcl7days', 'NumRevolvingTradesWBalance',
                     'NumTradesOpeninLast12M', 'NumInqLast6M']
    # astype returns a new frame, so later column assignments do not hit a view.
    X = heloc_df[feature_names].astype('int')

    # Favourable outcome is Good
    y = heloc_df.RiskPerformance.replace('Bad', 0.).replace('Good', 1.).values

    # Attach the label so the negative-value filter and the slicing keep
    # features and labels aligned.
    X['label'] = y
    X = X[(X >= 0).all(axis=1)]

    if training:
        X = X.iloc[:DATA_SIZE]
    else:
        X = X.iloc[DATA_SIZE:DATA_SIZE * 2]
    # Explicit copy so the assignments below do not raise SettingWithCopyWarning.
    X = X.copy()

    y = X.pop('label')

    means = [0 for _ in range(X.shape[-1])]
    std = [1 for _ in range(X.shape[-1])]

    # Standardisation statistics for continuous features (all HELOC features here)
    heloc_categorical_names = []
    for col_idx, col in enumerate(X.columns):
        if col not in heloc_categorical_names:
            means[col_idx] = X[col].mean(axis=0)
            std[col_idx] = X[col].std(axis=0)
            X[col] = X[col].astype('float')

    encoder = ReversibleOneHotEncoder()
    encoder.fit(X, columns=heloc_categorical_names)

    return X, y, means, std, encoder


def pre_process_adult(DATA_SIZE, training):
    """Load the Adult dataset restricted to the variables of a causal graph.

    The raw CSV is filtered to two ``workclass`` and two ``race`` values. Then
    either the first ``DATA_SIZE`` rows (``training=True``) or the following
    ``DATA_SIZE`` rows are kept. Five categorical attributes are binarised
    into ``'yes'``/``'no'`` flags.

    :param DATA_SIZE: Number of rows in each split.
    :param training: Select the training split (True) or the next one (False).
    :return: ``(X, y, means, std, encoder)``.
        - ``X``: the 8 columns in causal-graph order.
        - ``y``: numpy array of the label with ``' <=50K'`` -> 0.0 and ``' >50K'`` -> 1.0.
        - ``means`` / ``std``: statistics for the continuous columns (0/1 elsewhere).
        - ``encoder``: fitted on the five flag columns.
    """
    adult_df = pd.read_csv('datasets/adult.csv').reset_index(drop=True)
    adult_df.columns = ['age', 'workclass', 'fnlwgt', 'education', 'education-num', 'marital-status', 'occupation',
                        'relationship', 'race', 'sex', 'capital-gain', 'capital-loss', 'hours-per-week',
                        'native-country', 'label']  # proper name of each of the features

    adult_df = adult_df.dropna()

    #  We use the variables in the causal graph of Nabi & Shpitser, 2018
    adult_df = adult_df.drop(['fnlwgt', 'education', 'occupation', 'relationship', 'capital-gain',
                              'capital-loss'], axis=1)

    # Filter dataset to binary choice for these features
    adult_df = adult_df[adult_df['workclass'].isin([' Private', ' Self-emp-not-inc'])]
    adult_df = adult_df[adult_df['race'].isin([' White', ' Black'])]

    # First DATA_SIZE rows for training, the next DATA_SIZE rows otherwise
    if training:
        adult_df = adult_df.iloc[:DATA_SIZE]
    else:
        adult_df = adult_df.iloc[DATA_SIZE:DATA_SIZE * 2]
    # Explicit copy so the column assignments below do not act on a slice.
    adult_df = adult_df.copy()

    # (new flag column, source column, substring that means 'yes'); vectorised
    # substring tests replace the former row-wise apply calls.
    binary_flags = [
        ('native-country-United-States', 'native-country', 'United-States'),
        ('marital-status-Married', 'marital-status', 'Married'),
        ('isMale', 'sex', 'Male'),
        ('workclass-Private', 'workclass', ' Private'),
        ('isWhite', 'race', ' White'),
    ]
    for new_col, source_col, token in binary_flags:
        has_token = adult_df[source_col].str.contains(token, regex=False)
        adult_df[new_col] = has_token.map({True: 'yes', False: 'no'})

    adult_df = adult_df.drop(['native-country', 'marital-status', 'sex', 'workclass', 'race'], axis=1)
    X = adult_df.drop('label', axis=1)

    # Target is whether the individual has a yearly income greater than 50k
    y = adult_df['label'].replace(' <=50K', 0.0)
    y = y.replace(' >50K', 1.0).values

    # Re-arrange to follow the causal graph
    columns = ['isMale', 'age', 'native-country-United-States', 'marital-status-Married', 'education-num', 'hours-per-week', 'workclass-Private', 'isWhite']
    X = X[columns].copy()

    # Standardisation statistics for continuous features
    means = [0 for _ in range(X.shape[-1])]
    std = [1 for _ in range(X.shape[-1])]
    adult_categorical_names = ['isMale', 'native-country-United-States', 'marital-status-Married', 'workclass-Private', 'isWhite']
    for col_idx, col in enumerate(X.columns):
        if col not in adult_categorical_names:
            means[col_idx] = X[col].mean(axis=0)
            std[col_idx] = X[col].std(axis=0)
            X[col] = X[col].astype('float')

    encoder = ReversibleOneHotEncoder()
    encoder.fit(X, columns=adult_categorical_names)

    return X, y, means, std, encoder



def _decode_german_credit_codes(df):
    """Replace the integer codes of the South German Credit columns with readable labels.

    Modifies ``df`` in place and returns it.
    """
    code_maps = {
        'status': {1: 'no checking account',
                   2: '< 0 DM',
                   3: '0 <= ... <= 200 DM',
                   4: '>= 200 DM / salary for at least 1 year'},
        'credit_history': {0: 'Delay in paying off in the past',
                           1: 'Critical account/other credits elsewhere',
                           2: 'No credits taken/all credits paid back duly',
                           3: 'Existing credits paid back duly till now',
                           4: 'All credits at this bank paid back duly'},
        'purpose': {0: 'Others',
                    1: 'Car (new)',
                    2: 'Car (Used)',
                    3: 'Furniture/equipment',
                    4: 'Radio/television',
                    5: 'Domestic Applicances',
                    6: 'Repairs',
                    7: 'Education',
                    8: 'Vacation',
                    9: 'Retraining',
                    10: 'Business'},
        'savings': {1: 'unknown/no savings account',
                    2: '... <  100 DM',
                    3: '100 <= ... <  500 DM',
                    4: '500 <= ... < 1000 DM',
                    5: '... >= 1000 DM'},
        'employment_duration': {1: 'unemployed',
                                2: '< 1 yr',
                                3: '1 <= ... < 4 yrs',
                                4: '4 <= ... < 7 yrs',
                                5: '>= 7 yrs '},
        'installment_rate': {1: '>= 35',
                             2: '25 <= ... < 35',
                             3: '20 <= ... < 25',
                             4: '< 20'},
        'personal_status_sex': {1: 'divorced/separated',
                                2: 'non-single or male : single',
                                3: 'married/widowed',
                                4: 'single'},
        'other_debtors': {1: 'none',
                          2: 'co-applicant',
                          3: 'guarantor'},
        'present_residence': {1: '< 1 yr',
                              2: '1 <= ... < 4 yrs',
                              3: '4 <= ... < 7 yrs',
                              4: '>= 7 yrs'},
        'property': {1: 'unknown / no property',
                     2: 'car or other',
                     3: 'building soc. savings agr./life insurance',
                     4: 'real estate'},
        'other_installment_plans': {1: 'bank',
                                    2: 'stores',
                                    3: 'none'},
        'housing': {1: 'for free',
                    2: 'rent',
                    3: 'own'},
        'number_credits': {1: '1',
                           2: '2-3',
                           3: '4-5',
                           4: '>= 6'},
        'job': {1: 'unemployed/unskilled - non-resident',
                2: 'unskilled - resident',
                3: 'skilled employee/official',
                4: 'manager/self-empl./highly qualif. employee'},
        'people_liable': {1: '3 or more',
                          2: '0 to 2'},
        'telephone': {1: 'no',
                      2: 'yes (under customer name)'},
        'foreign_worker': {1: 'yes',
                           2: 'no'},
    }
    for column, mapping in code_maps.items():
        df[column] = df[column].replace(mapping)
    return df


def pre_process_german_credit(DATA_SIZE, training):
    """Load the South German Credit dataset restricted to five features.

    NOTE: ``DATA_SIZE`` is currently ignored. The split is fixed: the first
    500 rows for training, all remaining rows otherwise.

    :param DATA_SIZE: Unused (see note above).
    :param training: Select the training split (True) or the remainder (False).
    :return: ``(X, y, means, std, encoder)``.
        - ``X``: columns ``status``, ``duration``, ``credit_history``,
          ``purpose``, ``amount``; categorical columns hold the decoded labels
          and ``duration``/``amount`` are integers.
        - ``y``: numpy array of ``credit_risk``.
        - ``means`` / ``std``: statistics for the two continuous columns at
          their positions (0/1 elsewhere).
        - ``encoder``: fitted on the three categorical columns.
    """
    df = pd.read_csv('datasets/SouthGermanCredit.asc', sep=" ")
    df = _decode_german_credit_codes(df)

    # Fixed split: first 500 rows for training, the rest otherwise
    if training:
        df = df.iloc[:500]
    else:
        df = df.iloc[500:]

    y = df.credit_risk.values

    # Re-arrange to follow the causal graph; copy so later assignments do not act on a slice
    columns = ['status', 'duration', 'credit_history', 'purpose', 'amount']
    X = df[columns].copy()

    # Standardisation statistics for continuous features
    means = [0] * len(columns)
    std = [1] * len(columns)

    german_categorical_names = ['status', 'credit_history', 'purpose']
    for col_idx, col in enumerate(columns):
        if col not in german_categorical_names:
            means[col_idx] = X[col].mean(axis=0)
            std[col_idx] = X[col].std(axis=0)

    encoder = ReversibleOneHotEncoder()
    encoder.fit(X, columns=german_categorical_names)

    # Return the continuous columns as integers (the earlier float cast was
    # immediately undone here, so it is no longer performed).
    for col in ['amount', 'duration']:
        X[col] = X[col].astype('int')

    return X, y, means, std, encoder








