import sys
import os
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), "..")))
from utils import read_data, check_data, adjust_dataset_size, split_labels, ColInfo
import numpy as np
from sklearn.preprocessing import LabelEncoder


def data_preprocess_global(dst, selected_labels, y_name):
    """Filter rows by label, drop incomplete rows, and remove unused columns.

    The input frame is never modified; a new DataFrame is returned.

    Args:
        dst: pandas DataFrame holding the raw data. It must contain the
            column ``y_name`` (when ``selected_labels`` is given) and every
            column listed in ``dropped_columns`` below.
        selected_labels: iterable of label values to keep in ``y_name``, or
            ``None`` to keep every row.
        y_name: name of the label column.

    Returns:
        A DataFrame containing only the rows whose label is in
        ``selected_labels``, with every row that held an ``'unknown'`` or
        missing value removed, and without the dropped columns. The original
        index values are preserved (no re-indexing is done).

    Raises:
        KeyError: if ``y_name`` or one of the dropped columns is absent.
    """
    if selected_labels is not None:
        # isin() handles any number of labels, not just exactly two.
        dst = dst[dst[y_name].isin(list(selected_labels))]

    # 'unknown' is treated as a missing value. Rows are removed *before* the
    # columns are dropped, so an 'unknown' in a to-be-dropped column still
    # discards the row (same order as the historical behaviour).
    # np.nan is used because the np.NaN alias no longer exists in NumPy 2.x.
    cleaned = dst.replace('unknown', np.nan).dropna()

    # 'campaign' and 'previous' are removed outright; an earlier variant of
    # this code capped their values instead of dropping them.
    dropped_columns = [
        'default',
        'duration',
        'emp.var.rate',
        'nr.employed',
        'campaign',
        'previous',
        'pdays',
    ]
    return cleaned.drop(columns=dropped_columns)


def get_bank_marketing_data(file_path):
    """Load the bank-marketing file and return features, binary labels and column info.

    Steps, in order: read the ';'-delimited file, run
    ``data_preprocess_global`` keeping the 'no'/'yes' rows, label-encode the
    categorical columns, build a ``ColInfo`` describing the feature layout,
    run ``check_data``, resize the data set via ``adjust_dataset_size``, and
    split off the label column.

    Args:
        file_path: location of the data file, forwarded to ``read_data``.

    Returns:
        A tuple ``(dst_x, dst_y, col_info)`` where ``dst_x`` and ``dst_y``
        come from ``split_labels`` (``dst_y`` mapped to 0 for 'no' and 1 for
        anything else) and ``col_info`` is the ``ColInfo`` instance.
    """
    label_col = 'y'
    kept_labels = ['no', 'yes']
    numeric_cols = ["age", "cons.price.idx", "cons.conf.idx", "euribor3m"]
    categorical_columns = ['contact', 'poutcome', 'job', 'month', 'marital', 'day_of_week', 'housing', 'loan',
                           'education']

    # ================= preprocessing ===========================
    raw = read_data(file_path, delimiter=';')
    frame = data_preprocess_global(raw, kept_labels, y_name=label_col)

    # Replace each categorical column by integer codes and remember how many
    # distinct classes the encoder saw for it.
    class_counts = {}
    for name in categorical_columns:
        encoder = LabelEncoder()
        frame[name] = encoder.fit_transform(frame[name].values)
        class_counts[name] = len(encoder.classes_)

    # Every column except the label is a feature; positions are relative to
    # that feature list.
    non_features = [label_col]
    feature_names = [name for name in frame.columns if name not in non_features]
    cat_positions = []
    cat_sizes = []
    for position, name in enumerate(feature_names):
        if name in categorical_columns:
            cat_positions.append(position)
            cat_sizes.append(class_counts[name])

    # NOTE(review): cate_name is passed in declaration order, whereas
    # cate_idxs/cate_dims follow the frame's column order -- confirm ColInfo
    # does not assume the three sequences line up element by element.
    col_info = ColInfo(
        cate_idxs=cat_positions,
        cate_dims=cat_sizes,
        cont_name=numeric_cols,
        cate_name=categorical_columns,
    )
    check_data(frame)

    # action_type / sample_rate semantics are defined by adjust_dataset_size in utils.
    frame = adjust_dataset_size(frame, action_type=1, y_name=label_col, sample_rate=0.1)
    dst_x, dst_y = split_labels(frame, y_name=label_col)

    def _as_binary(value):
        # First kept label -> 0, everything else -> 1.
        if value == kept_labels[0]:
            return 0
        return 1

    dst_y = dst_y.map(_as_binary)
    return dst_x, dst_y, col_info

