import copy
from Helpers.helper import bert_emb_col
import pickle
import numpy as np
import pandas as pd
from sklearn.preprocessing import OneHotEncoder
from xgboost import XGBClassifier
import os

curdir = "."
# Every path below is built from the *current working directory* rather than
# this file's location. NOTE(review): presumably the script is expected to be
# launched from its own directory so the "/../" prefixes resolve correctly.
source_path = os.path.abspath(os.curdir)
# Pickled competency-group dictionaries; callers prepend ``source_path``.
skill_cluster_path = "/../feature_data/competency_grp_data/skill_cluster_dict_40.pkl"
cluster_skill_path = "/../feature_data/competency_grp_data/cluster_skill_dict_40.pkl"
# Folder the trained XGBoost pickle is written to / read from.
model_file_path = f"{source_path}/../model/bert_pre_train_only/"
# Folder holding the train/test CSV splits produced by training().
test_path_folder = f"{source_path}/../dataset/bert_pre_train_only/"


def embedding_feature(model):
    """Build the skill x competency-group candidate table with BERT features.

    Every skill listed in ``check.csv`` (column ``rhs``) that also appears in
    the BERT feature file is paired with every group key of the
    cluster->skill dictionary. The pairing is a cross join implemented
    through a constant ``key`` column.

    Args:
        model: Name of the embedding source. Only ``"bert_pre_train"`` is
            supported.

    Returns:
        DataFrame with ``skill_name``, ``label`` (comma-joined groups the
        skill belongs to), the BERT feature columns, ``key`` and
        ``group_name``. Rows containing any NaN are dropped and the index is
        reset.

    Raises:
        ValueError: if ``model`` is not a supported embedding source.
    """
    if model != "bert_pre_train":
        # Without this guard ``bert_128`` is never bound and the function
        # later crashes with an unhelpful NameError.
        raise ValueError(f"Unsupported model: {model!r}")

    with open(source_path + skill_cluster_path, "rb") as fh:
        skill_cluster_dict = pickle.load(fh)
    with open(source_path + cluster_skill_path, "rb") as fh:
        cluster_skill_dict = pickle.load(fh)
    rhs_embedding = pd.read_csv(source_path + "/../feature_data/check.csv")
    bert_128 = pd.read_csv(source_path + "/../feature_data/bert_pre_trained_features/pre_train_bert_128.csv")

    # One row per group key of the cluster->skill dictionary.
    grp_40_df = pd.DataFrame(list(cluster_skill_dict), columns=["group_name"])

    # ``rhs`` values use "$" between words; the skill->cluster dictionary is
    # keyed by the space-separated form.
    labels = rhs_embedding["rhs"].values
    dependent = []
    for label in labels:
        groups = skill_cluster_dict.get(label.replace("$", " "))
        dependent.append(",".join(groups) if groups else None)

    training_dataframe = pd.DataFrame({"skill_name": labels, "label": dependent})
    # Align BERT skill names with the "$"-joined form used in ``rhs``.
    bert_128["skill_name"] = bert_128["skill_name"].str.replace(" ", "$", regex=False)
    training_dataframe = pd.merge(training_dataframe, bert_128, on="skill_name")

    # Cross join skills x groups through a constant key. The ``key`` column is
    # intentionally kept: training()/testing() drop it explicitly by name.
    grp_40_df["key"] = 0
    training_dataframe["key"] = 0
    training_dataframe_new1 = pd.merge(training_dataframe, grp_40_df, on="key", how="outer")
    print("skill_count", len(training_dataframe_new1["skill_name"].drop_duplicates()))
    print("grp_count", len(training_dataframe_new1["label"].drop_duplicates()))
    return training_dataframe_new1.dropna().reset_index(drop=True)


def process_df(df):
    """One-hot encode ``label`` and sum the indicators per skill.

    Rows with any NaN are dropped first. The ``label`` column is one-hot
    encoded into one column per distinct label, named ``0..k-1`` in the
    order of the returned categories. The indicator columns are then summed
    for each unique combination of ``skill_name`` and ``d1`` .. ``d30``.

    Args:
        df: DataFrame containing ``label``, ``skill_name`` and ``d1``..``d30``.

    Returns:
        Tuple ``(aggregated_df, categories)``, where ``categories`` is the
        encoder's category array for ``label``. Entry ``i`` corresponds to
        indicator column ``i``.
    """
    dataframe = df.dropna()
    label = dataframe[["label"]]
    enc = OneHotEncoder(handle_unknown='ignore')
    y_train_new = enc.fit_transform(label).toarray()
    cat = enc.categories_[0]
    # Size the indicator block from the encoder output. The former hard-coded
    # 40 columns raised a shape error for any other number of distinct labels.
    y_train_df = pd.DataFrame(data=y_train_new, columns=np.arange(y_train_new.shape[1]))

    train_new_check = pd.concat([dataframe.reset_index(drop=True), y_train_df], axis=1)
    group_columns = ["skill_name"] + ["d%d" % i for i in range(1, 31)]
    train_df_1 = train_new_check.groupby(group_columns).sum().reset_index()
    print("check")
    return train_df_1, cat


def process_cluster_data(cluster_data):
    """Append one-hot indicator columns for ``cluster`` to ``cluster_data``.

    Args:
        cluster_data: DataFrame with a ``cluster`` column.

    Returns:
        ``cluster_data`` (index reset) concatenated column-wise with one
        indicator column per distinct ``cluster`` value. The indicator
        columns are named ``0..k-1`` in ``enc.categories_`` order.
    """
    enc = OneHotEncoder(handle_unknown='ignore')
    cluster_new = enc.fit_transform(cluster_data[['cluster']]).toarray()
    # Width comes from the encoder. The former hard-coded 100 columns raised
    # a shape mismatch whenever the number of distinct clusters differed.
    cluster_new_df = pd.DataFrame(data=cluster_new, columns=np.arange(cluster_new.shape[1]))
    return pd.concat([cluster_data.reset_index(drop=True), cluster_new_df], axis=1)


def top_features(model):
    """Load the top-1/2/3 similarity distance tables and join them.

    Each CSV is melted from wide to long format with columns
    ``skill_name``, ``variable`` (the former column name) and ``value``.
    The three long tables are then inner-joined on ``skill_name`` and
    ``variable``. All three share the ``value`` column name, so the result
    holds ``value_x`` (top-1), ``value_y`` (top-2) and ``value`` (top-3).

    Args:
        model: Name of the embedding source. Only ``"bert_pre_train"`` is
            supported.

    Raises:
        ValueError: if ``model`` is not a supported embedding source.
    """
    if model != "bert_pre_train":
        # Previously fell through to a NameError on ``top1``.
        raise ValueError(f"Unsupported model: {model!r}")

    similarity_folder = source_path + "/../feature_data/bert_pre_trained_features/similarity_based_feature/"
    top1_tr, top2_tr, top3_tr = (
        pd.read_csv(similarity_folder + "bert_pre_train_top%d_dist.csv" % rank).melt(id_vars=['skill_name'])
        for rank in (1, 2, 3)
    )

    top1_top2 = pd.merge(top1_tr, top2_tr, on=["skill_name", 'variable'])
    return pd.merge(top1_top2, top3_tr, on=["skill_name", 'variable'])


def dist_features(model):
    """Load the skill-to-group similarity table.

    Args:
        model: Name of the embedding source. Only ``"bert_pre_train"`` is
            supported.

    Returns:
        DataFrame read from ``bert_pre_train_skill_grp_similarity.csv``, with
        every "$" in ``skill_name`` replaced by a space.

    Raises:
        ValueError: if ``model`` is not a supported embedding source.
    """
    if model != "bert_pre_train":
        # Previously fell through to a NameError on ``group_similarity``.
        raise ValueError(f"Unsupported model: {model!r}")

    similarity_folder = "/../feature_data/bert_pre_trained_features/similarity_based_feature/"
    group_similarity = pd.read_csv(source_path + similarity_folder + "bert_pre_train_skill_grp_similarity.csv")
    # "$" is a regex anchor; regex=False makes the literal replacement explicit
    # regardless of the pandas version's default.
    group_similarity["skill_name"] = group_similarity["skill_name"].str.replace("$", " ", regex=False)
    return group_similarity


def get_labels(skill_cluster_dict=None):
    """Return one positive (skill, group) row per group membership.

    Args:
        skill_cluster_dict: Optional mapping from a space-separated skill name
            to an iterable of group names. When omitted, the mapping is loaded
            from the pickled skill->cluster dictionary under ``source_path``.

    Returns:
        DataFrame with ``skill_name`` (spaces replaced by "$"),
        ``group_name`` and a constant integer ``label`` of 1.
    """
    if skill_cluster_dict is None:
        with open(source_path + skill_cluster_path, "rb") as fh:
            skill_cluster_dict = pickle.load(fh)

    data_tuples = [
        (skill.replace(" ", "$"), group)
        for skill, groups in skill_cluster_dict.items()
        for group in groups
    ]
    label_df = pd.DataFrame(data_tuples, columns=["skill_name", "group_name"])
    label_df["label"] = 1
    return label_df


def group_wise_prim_sec():
    """Count primary and secondary skills per competency group.

    Loads ``primary_skill.pkl`` and ``secondary_skill.pkl`` (mappings from
    group name to a sized collection of skills) and counts the entries per
    group.

    Returns:
        DataFrame with columns ``prim_count``, ``group_name`` and
        ``sec_count``. Only groups present in both pickles are kept, because
        the two tables are inner-merged on ``group_name``.
    """
    competency_grp_path = source_path + "/../feature_data/competency_grp_data/"
    with open(competency_grp_path + 'primary_skill.pkl', "rb") as fh:
        primary_skill_group = pickle.load(fh)
    with open(competency_grp_path + 'secondary_skill.pkl', "rb") as fh:
        secondary_skill_group = pickle.load(fh)

    primary_skill_count = {key: len(members) for key, members in primary_skill_group.items()}
    secondary_skill_count = {key: len(members) for key, members in secondary_skill_group.items()}

    primary_skill_count_df = pd.DataFrame.from_dict(primary_skill_count, orient='index',
                                                    columns=['prim_count'])
    primary_skill_count_df["group_name"] = primary_skill_count_df.index
    secondary_skill_count_df = pd.DataFrame.from_dict(secondary_skill_count, orient='index',
                                                      columns=['sec_count'])
    secondary_skill_count_df["group_name"] = secondary_skill_count_df.index
    return pd.merge(primary_skill_count_df, secondary_skill_count_df, on=["group_name"])


def accuracy_prediction_and_skill_classification(model, test_data, test_data_label, test, type, cutoffs=None):
    """Print a classification report for several probability cut-offs.

    Args:
        model: Fitted estimator exposing ``predict_proba``.
        test_data: Feature matrix passed to ``model.predict_proba``.
        test_data_label: Ground-truth labels for ``classification_report``.
        test: Unused; kept so existing positional callers keep working.
        type: ``"dnn"`` thresholds the whole ``predict_proba`` output. Any
            other value thresholds only column 1 (the second class's
            probability).
        cutoffs: Probability thresholds to evaluate. Defaults to
            ``[0.05, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6]``.
    """
    from sklearn.metrics import classification_report

    if cutoffs is None:
        cutoffs = [0.05, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6]

    prediction = model.predict_proba(test_data)
    # Select the scores once; each iteration works on its own copy, so the
    # selection is loop-invariant.
    raw_scores = prediction if type == "dnn" else prediction[:, 1]
    for cutoff in cutoffs:
        print("------cut off is--------", cutoff)
        pred_check = copy.deepcopy(raw_scores)
        pred_check[pred_check >= cutoff] = 1
        pred_check[pred_check < cutoff] = 0
        scores = classification_report(test_data_label, pred_check, digits=4)
        print(scores)

def training(final_df):
    """Train the XGBoost skill/group classifier and persist its artefacts.

    TF-IDF values are pre-calculated for the train and test sets to reduce
    data preparation time. To calculate TF-IDF for a test dataset using the
    training dataset, refer to ``get_tf_idf_feature()`` inside
    ``Word2vec_only.py``.

    Steps:
      1. Split ``final_df`` into train and test rows by inner-joining it on
         ``skill_name`` with the pre-computed train/test TF-IDF tables. Both
         splits are written to ``test_path_folder`` as ``train_2_class.csv``
         and ``test_2_class.csv``.
      2. Fit ``XGBClassifier(n_jobs=-1, max_depth=1, n_estimators=100)`` on
         the train split with ``label_new`` as the target, then pickle it
         into ``model_file_path``.
      3. Print the per-feature importances and save them to CSV.

    Args:
        final_df: Feature table as produced by ``prepare_data()``.
    """
    feature_importance_path = "/../feature_data/feature_importance/"
    tfidf_path = "/../feature_data/TFIDF/"
    # Identifier / merge-helper columns that must not be used as features.
    non_feature_columns = ["skill_name", "label_x", "label_y", "variable", "group_name", "key",
                           "group_name_x", "group_name_y"]

    tf_idf_train_py = pd.read_csv(source_path + tfidf_path + "train_data_tf_idf.csv")
    tf_idf_test_py = pd.read_csv(source_path + tfidf_path + "test_data_tf_idf.csv")

    train = pd.merge(final_df, tf_idf_train_py, on="skill_name")
    test = pd.merge(final_df, tf_idf_test_py, on="skill_name")
    train.to_csv(test_path_folder + "train_2_class.csv", index=False)
    test.to_csv(test_path_folder + "test_2_class.csv", index=False)

    train_new = train.drop(non_feature_columns, axis=1)
    X_train = train_new.drop(columns=["label_new"])
    Y_train = train_new[["label_new"]]
    xgb_model = XGBClassifier(n_jobs=-1, max_depth=1, n_estimators=100)
    xgb_model.fit(X_train, Y_train)

    # ``with`` guarantees the pickle is flushed and closed before testing()
    # reloads it.
    with open(model_file_path + "bert_pre_train_only_xgb.pkl", "wb") as fh:
        pickle.dump(xgb_model, fh)

    feature_scores = list(zip(X_train.columns, xgb_model.feature_importances_))
    for col, score in feature_scores:
        print(col, score)
    label_df = pd.DataFrame(feature_scores, columns=["feature_name", "score"])
    label_df.to_csv(source_path + feature_importance_path + "feature_imp_only_bre_pre_train.csv", index=False)
    print("-----------------" + "Training Completed Successfully" + "-----------------")

def prepare_data():
    """Assemble the full skill x group feature table with a binary target.

    Merges, in order:
      1. The skill x group cross join with BERT features (``embedding_feature``).
      2. The top-1/2/3 distance features (``top_features``).
      3. The skill/group similarity features (``dist_features``).
      4. The per-group primary/secondary skill counts (``group_wise_prim_sec``).
      5. ``bert_results_exp_bert_base_1.csv``.
      6. The known memberships from ``get_labels``.

    Returns:
        DataFrame holding all merged columns plus ``label_new`` (1 where
        ``label_y == 1``, else 0). Merge suffixes leave helper columns such as
        ``variable``, ``key``, ``group_name_x``, ``group_name_y``, ``label_x``
        and ``label_y`` in the result; training()/testing() drop them by name.
    """
    # ---------- Embedding features-------------------------
    bert_embedding_df = embedding_feature("bert_pre_train")
    # -----------Merge Top1,Top2,Top3 related features
    # Skill names are normalised to the "$"-joined form before every merge.
    top_features_bert = top_features("bert_pre_train")
    top_features_bert["skill_name"] = top_features_bert["skill_name"].str.replace(" ","$")
    # The melted ``variable`` column holds the group name; it is kept in the result.
    top_features_embedding_bert = pd.merge(bert_embedding_df, top_features_bert, left_on=["skill_name", "group_name"],right_on=["skill_name","variable"])
    # -----------Merge distance related features
    dist_features_bert = dist_features("bert_pre_train")
    dist_features_bert["skill_name"] = dist_features_bert["skill_name"].str.replace(" ","$")
    bert_all_feature = pd.merge(top_features_embedding_bert, dist_features_bert, on=["skill_name", "group_name"])
    # -----------prim and secn group count related features
    grp_prim_sec_df = group_wise_prim_sec()
    # -----------Merge distance, top1-3 features, prim and sec-------------
    # Both sides carry ``group_name``, so pandas suffixes them to
    # ``group_name_x`` (left) and ``group_name_y`` (right).
    dist_top_prim_sec_feature = pd.merge(bert_all_feature, grp_prim_sec_df, left_on=["variable"],
                                         right_on=["group_name"])
    dist_top_prim_sec_feature["skill_name"] = dist_top_prim_sec_feature["skill_name"].str.replace(" ", "$")

    #-------------------Attach Bert embedding result--------------------
    bert_result = pd.read_csv(source_path+"/../feature_data/bert_pre_trained_features/bert_results_exp_bert_base_1.csv")
    # Restore a plain ``group_name`` column (lost to suffixing above) so the
    # next merges can join on it.
    dist_top_prim_sec_feature["group_name"] = dist_top_prim_sec_feature["group_name_x"]
    final_features1 = pd.merge(dist_top_prim_sec_feature, bert_result, on=["skill_name", "group_name"])
    # ----------- Attach label ----------------------------
    # get_labels() returns one row per known (skill, group) membership with label == 1.
    labels = get_labels()

    # Left merge keeps every candidate pair. The existing ``label`` column (from
    # embedding_feature) collides with get_labels()' ``label``, so the latter
    # arrives as ``label_y``: 1 for known memberships, NaN otherwise.
    # NOTE(review): assumes bert_result adds no column that alters this suffixing — confirm.
    final_df = pd.merge(final_features1, labels, on=["skill_name", "group_name"], how="left")
    final_df["label_new"] = np.where(final_df["label_y"] == 1, 1,0)
    # ------------ Classifier training is done separately by training() ----------
    return final_df

def testing(test_path):
    """Score a prepared test CSV with the saved XGBoost model.

    Args:
        test_path: File name, relative to ``test_path_folder``, of a CSV with
            the same layout as the ``test_2_class.csv`` that training() writes.

    Returns:
        The ``predict_proba`` output of the saved model for the test rows.
    """
    test_data = pd.read_csv(test_path_folder + test_path)
    # Identifier / merge-helper columns plus the target are not features.
    non_feature_columns = ["skill_name", "label_x", "label_y", "variable", "group_name", "key",
                           "group_name_x", "group_name_y", "label_new"]
    X_test = test_data.drop(non_feature_columns, axis=1)
    # NOTE: pickle.load can execute arbitrary code; only load model files you produced.
    with open(model_file_path + "bert_pre_train_only_xgb.pkl", "rb") as fh:
        xgb_model = pickle.load(fh)
    results = xgb_model.predict_proba(X_test)
    print("-----------------" + "Testing Completed Successfully" + "-----------------")
    return results


if __name__ == "__main__":
    # Step 1: build the labelled skill x group feature table.
    feature_table = prepare_data()
    # Step 2: split into train/test CSVs, fit the model and save it.
    training(feature_table)
    # Step 3: reload the model saved in step 2 and score the test split.
    # Replace "test_2_class.csv" with your own test data file if needed.
    testing("test_2_class.csv")
