import os
from Helpers.helper import py_torch_emb_col, w2v_emb_col, bert_emb_col
import pickle
import numpy as np
import pandas as pd
from gensim.models import Word2Vec
from sklearn.preprocessing import OneHotEncoder
from xgboost import XGBClassifier

# All data/model locations are resolved relative to the current working
# directory (presumably the script's own folder -- the paths climb one level
# with "/../" to reach feature_data/, model/ and dataset/).
curdir = '.'
source_path = os.path.abspath(curdir)
# Pickled skill -> groups and group -> skills dictionaries (40 competency
# groups); these are suffixes appended to source_path at load time.
skill_cluster_path = "/../feature_data/competency_grp_data/skill_cluster_dict_40.pkl"
cluster_skill_path = "/../feature_data/competency_grp_data/cluster_skill_dict_40.pkl"
# Absolute folders for the trained model pickle and the train/test CSV splits.
model_file_path = f"{source_path}/../model/3_class_model/"
test_path_folder = f"{source_path}/../dataset/3_class/"


def embedding_feature(model):
    """Build the skill x competency-group cross join with per-skill embeddings.

    Skills are read from column ``rhs`` of ``feature_data/check.csv``; names
    are presumably ``$``-encoded (``$`` is turned back into a space for every
    dictionary lookup). One embedding vector is attached per skill:

    * ``"pytorch"`` -- parsed from column ``"1"``, which stores the vector as a
      bracketed, whitespace-separated string such as ``"[0.1 0.2 ...]"``.
    * ``"bert"`` -- looked up in the pickled SkillBERT PCA embeddings; skills
      missing from that pickle get a 128-element zero vector.

    Each skill row is labelled with the comma-joined groups it belongs to (from
    ``skill_cluster_path``), cross-joined with every group name from
    ``cluster_skill_path`` through a constant ``key`` column, and any row with a
    missing value (e.g. a skill with no known group) is dropped.

    Args:
        model: ``"pytorch"`` or ``"bert"``.

    Returns:
        DataFrame with ``skill_name``, ``label``, the embedding columns named by
        ``py_torch_emb_col`` / ``bert_emb_col``, ``key`` and ``group_name``.

    Raises:
        ValueError: if ``model`` is not ``"pytorch"`` or ``"bert"``.
    """
    if model not in ("pytorch", "bert"):
        raise ValueError(f"Unsupported embedding model: {model!r}")

    with open(source_path + skill_cluster_path, "rb") as fh:
        skill_cluster_dict = pickle.load(fh)
    with open(source_path + cluster_skill_path, "rb") as fh:
        cluster_skill_dict = pickle.load(fh)
    rhs_embedding = pd.read_csv(source_path + "/../feature_data/check.csv")
    grp_40_df = pd.DataFrame(list(cluster_skill_dict), columns=["group_name"])
    labels = rhs_embedding["rhs"].values

    if model == "pytorch":
        col_name = py_torch_emb_col
        # Column "1" stores each vector as text, e.g. "[0.1 0.2 ...]".
        final_list = [
            [float(value) for value in raw.strip("[]").split()]
            for raw in rhs_embedding["1"].values
        ]
    else:
        col_name = bert_emb_col
        print("source_path", source_path)
        bert_path = source_path + "/../feature_data/SkillBERT_features/bert_pca_embeddings_128.pkl"
        with open(bert_path, "rb") as fh:
            bert_30 = pickle.load(fh)
        final_list = []
        for label in labels:
            lookup_name = label.replace("$", " ")
            if lookup_name in bert_30:
                final_list.append(list(bert_30[lookup_name]))
            else:
                # Unknown skill: fall back to a zero vector of the PCA width.
                final_list.append(np.zeros(128))

    # Comma-joined group membership per skill; None marks "no known group" and
    # is removed by the final dropna().
    dependent = []
    for skill in labels:
        groups = skill_cluster_dict.get(skill.replace("$", " "))
        dependent.append(",".join(groups) if groups else None)

    embedding_df = pd.DataFrame(final_list, columns=col_name)
    training_dataframe = pd.DataFrame({"skill_name": labels, "label": dependent})
    training_dataframe = pd.concat([training_dataframe, embedding_df], axis=1)
    # A constant join key turns the merge into a skill x group cross join.
    grp_40_df['key'] = 0
    training_dataframe['key'] = 0
    training_dataframe_new1 = pd.merge(training_dataframe, grp_40_df, on="key", how="outer")
    print("skill_count", len(training_dataframe_new1["skill_name"].drop_duplicates()))
    print("grp_count", len(training_dataframe_new1["label"].drop_duplicates()))
    return training_dataframe_new1.dropna().reset_index(drop=True)



def process_df(df):
    """One-hot encode ``label`` and aggregate the rows of each skill.

    Rows with any missing value are dropped first. The ``label`` column is
    one-hot encoded into integer-named columns ``0 .. k-1`` (one per distinct
    label, in ``OneHotEncoder`` category order). The frame is then grouped by
    ``skill_name`` plus the embedding columns ``d1`` .. ``d30`` and the
    remaining columns are summed, so each skill row carries a multi-hot label
    vector.

    Args:
        df: DataFrame with at least ``skill_name``, ``label`` and ``d1``..``d30``.

    Returns:
        Tuple ``(grouped_df, categories)``; ``categories[i]`` is the label value
        encoded by one-hot column ``i``.
    """
    dataframe = df.dropna().reset_index(drop=True)
    label = dataframe[["label"]]
    enc = OneHotEncoder(handle_unknown='ignore')
    one_hot = enc.fit_transform(label).toarray()
    cat = enc.categories_[0]
    # Size the one-hot frame from the fitted categories: a hard-coded width
    # raises a shape-mismatch ValueError for any other number of labels.
    y_train_df = pd.DataFrame(data=one_hot, columns=np.arange(len(cat)))

    train_new_check = pd.concat([dataframe, y_train_df], axis=1)
    group_columns = ["skill_name"] + [f"d{i}" for i in range(1, 31)]
    train_df_1 = train_new_check.groupby(group_columns).sum().reset_index()
    return train_df_1, cat

def process_cluster_data(cluster_data):
    """Append one-hot spectral-cluster indicator columns to ``cluster_data``.

    Args:
        cluster_data: DataFrame with a ``cluster`` column (``prepare_data`` also
            relies on it carrying ``skill_name``).

    Returns:
        ``cluster_data`` re-indexed 0..n-1, followed by integer-named columns
        ``0 .. k-1``, one per distinct cluster value in encoder category order.
    """
    enc = OneHotEncoder(handle_unknown='ignore')
    cluster_new = enc.fit_transform(cluster_data[['cluster']]).toarray()
    # The column count comes from the encoder output; a hard-coded 40 raised a
    # shape mismatch whenever the clustering produced a different k.
    cluster_new_df = pd.DataFrame(data=cluster_new, columns=np.arange(cluster_new.shape[1]))
    return pd.concat([cluster_data.reset_index(drop=True), cluster_new_df], axis=1)

def top_features(model):
    """Load the top-1/2/3 distance features for every (skill, group) pair.

    Each ``bert_top{k}_dist.csv`` is melted to long form
    (``skill_name, variable, value``), where ``variable`` is presumably the
    group name (``prepare_data`` joins it against ``group_name``). The three
    tables are then inner-joined on ``skill_name`` and ``variable``.

    Args:
        model: embedding family; only ``"bert"`` has precomputed files.

    Returns:
        DataFrame with ``skill_name``, ``variable`` and the three distances as
        ``value_x`` (top-1), ``value_y`` (top-2) and ``value`` (top-3), the
        column names pandas assigns when merging the melted tables.

    Raises:
        ValueError: for any ``model`` other than ``"bert"``.
    """
    if model != "bert":
        raise ValueError(f"No top-k distance features available for model {model!r}")

    similarity_folder = source_path + "/../feature_data/SkillBERT_features/similarity_based_feature/"
    top1_tr, top2_tr, top3_tr = [
        pd.read_csv(f"{similarity_folder}bert_top{rank}_dist.csv").melt(id_vars=['skill_name'])
        for rank in (1, 2, 3)
    ]

    top1_top2 = pd.merge(top1_tr, top2_tr, on=["skill_name", 'variable'])
    return pd.merge(top1_top2, top3_tr, on=["skill_name", 'variable'])


def dist_features(model):
    """Load the skill-to-group similarity features.

    Args:
        model: embedding family; only ``"bert"`` has a precomputed file.

    Returns:
        DataFrame read from ``bert_skill_grp_similarity.csv`` with every literal
        ``$`` in ``skill_name`` replaced by a space.

    Raises:
        ValueError: for any ``model`` other than ``"bert"``.
    """
    if model != "bert":
        raise ValueError(f"No distance features available for model {model!r}")

    similarity_folder = "/../feature_data/SkillBERT_features/similarity_based_feature/"
    group_similarity = pd.read_csv(source_path + similarity_folder + "bert_skill_grp_similarity.csv")
    # "$" is a regex anchor; regex=False makes the literal replacement explicit
    # and independent of pandas' version-dependent default for `regex`.
    group_similarity["skill_name"] = group_similarity["skill_name"].str.replace("$", " ", regex=False)
    return group_similarity


def get_labels(n_classes):
    """Return the (skill, group) target table.

    Args:
        n_classes: ``2`` builds a binary table from the skill -> groups pickle,
            giving every listed (skill, group) pair ``label = 1``. Any other
            value reads the precomputed ``3_class_label.csv``, which
            ``prepare_data`` expects to hold ``group_name`` and a ``label`` of
            ``"prim"`` / ``"sec"``.

    Returns:
        DataFrame whose ``skill_name`` has spaces replaced by ``$``.
    """
    if n_classes != 2:
        labels_df = pd.read_csv(source_path + "/../feature_data/3_class_label.csv")
        labels_df["skill_name"] = labels_df["skill_name"].str.replace(" ", "$", regex=False)
        return labels_df

    # The skill -> groups pickle is only needed for the binary table.
    with open(source_path + skill_cluster_path, "rb") as fh:
        skill_cluster_dict = pickle.load(fh)
    pairs = [
        (skill.replace(" ", "$"), group)
        for skill, groups in skill_cluster_dict.items()
        for group in groups
    ]
    label_df = pd.DataFrame(pairs, columns=["skill_name", "group_name"])
    label_df["label"] = 1
    return label_df

def group_wise_prim_sec():
    """Count primary and secondary skills per competency group.

    Loads ``primary_skill.pkl`` and ``secondary_skill.pkl`` (presumably dicts
    mapping group name -> collection of skills) and counts the entries of each
    group.

    Returns:
        DataFrame with ``prim_count``, ``group_name`` and ``sec_count``. Only
        groups present in both pickles are kept (inner merge).
    """
    competency_grp_path = source_path + "/../feature_data/competency_grp_data/"
    with open(competency_grp_path + 'primary_skill.pkl', "rb") as fh:
        primary_skill_group = pickle.load(fh)
    with open(competency_grp_path + 'secondary_skill.pkl', "rb") as fh:
        secondary_skill_group = pickle.load(fh)

    primary_skill_count = {group: len(skills) for group, skills in primary_skill_group.items()}
    secondary_skill_count = {group: len(skills) for group, skills in secondary_skill_group.items()}

    primary_skill_count_df = pd.DataFrame.from_dict(primary_skill_count, orient='index',
                                                    columns=['prim_count'])
    primary_skill_count_df["group_name"] = primary_skill_count_df.index
    secondary_skill_count_df = pd.DataFrame.from_dict(secondary_skill_count, orient='index',
                                                      columns=['sec_count'])
    secondary_skill_count_df["group_name"] = secondary_skill_count_df.index
    return pd.merge(primary_skill_count_df, secondary_skill_count_df, on=["group_name"])


def accuracy_prediction_and_skill_classification(model, test_data, test_data_label, test, type):
    """Print a 4-digit ``classification_report`` for ``model`` on ``test_data``.

    The predicted class of each row is the column index of its highest
    ``predict_proba`` score. ``test`` and ``type`` are accepted for call
    compatibility but are not used.
    """
    class_scores = model.predict_proba(test_data)
    from sklearn.metrics import classification_report

    predicted_classes = np.argmax(class_scores, axis=1)
    report = classification_report(test_data_label, predicted_classes, digits=4)
    print(report)


def training(final_df):
    """Train the 3-class XGBoost classifier and persist its artifacts.

    ``final_df`` (as built by ``prepare_data``) is split into train and test
    rows by inner-joining it with the precomputed TF-IDF tables on
    ``skill_name``. The steps are:

    * Both splits are written to ``test_path_folder`` as ``train_2_class.csv``
      and ``test_2_class.csv``; ``testing`` reads the latter back.
    * Debug copies ``train_new1.csv`` / ``test_new1.csv`` are written to the
      current working directory.
    * An ``XGBClassifier`` is fitted with ``label_new`` as the target and
      pickled to ``model_file_path + "3_class_model_xgb.pkl"``.
    * Per-feature importances are printed and saved to
      ``feature_imp_s_class.csv``.

    Args:
        final_df: feature table containing ``skill_name``, ``label``,
            ``variable``, ``group_name``, ``group_name_x``, ``group_name_y`` and
            ``label_new`` plus the numeric feature columns.
    """
    feature_importance_path = "/../feature_data/feature_importance/"
    # TF-IDF values are pre-calculated for the train and test sets to reduce
    # data preparation time. To compute test-set TF-IDF from the training set,
    # see get_tf_idf_feature() in Word2vec_only.py.
    TFIDF_path = "/../feature_data/TFIDF/"
    tf_idf_train_py = pd.read_csv(source_path + TFIDF_path + "train_data_tf_idf.csv")
    tf_idf_test_py = pd.read_csv(source_path + TFIDF_path + "test_data_tf_idf.csv")
    train = pd.merge(final_df, tf_idf_train_py, on="skill_name")
    test = pd.merge(final_df, tf_idf_test_py, on="skill_name")

    train.to_csv(test_path_folder + "train_2_class.csv", index=False)
    test.to_csv(test_path_folder + "test_2_class.csv", index=False)
    train.to_csv("train_new1.csv", index=False)
    test.to_csv("test_new1.csv", index=False)

    # Identifier / raw-label columns that must not be used as model inputs.
    non_feature_cols = ["skill_name", "label", "variable", "group_name", "group_name_x", "group_name_y"]
    train_new = train.drop(non_feature_cols, axis=1)
    X_train = train_new.drop(columns=["label_new"])
    # 1-D target, as scikit-learn-style estimators expect.
    Y_train = train_new["label_new"]

    xgb_model = XGBClassifier(n_jobs=-1, max_depth=5, n_estimators=800,
                              objective='multi:softprob', num_class=3)
    xgb_model.fit(X_train, Y_train)
    with open(model_file_path + "3_class_model_xgb.pkl", "wb") as fh:
        pickle.dump(xgb_model, fh)

    importances = list(zip(X_train.columns, xgb_model.feature_importances_))
    for col, score in importances:
        print(col, score)
    label_df = pd.DataFrame(importances, columns=["feature_name", "score"])
    label_df.to_csv(source_path + feature_importance_path + "feature_imp_s_class.csv", index=False)
    print("-----------------" + "Training Completed Successfully" + "-----------------")



def prepare_data():
    """Assemble the 3-class feature table, one row per (skill, group) pair.

    Joins the following, in order:

    * the SkillBERT embedding cross join (``embedding_feature("bert")``);
    * one-hot spectral-cluster membership (left join, so skills without a
      cluster keep NaN indicators);
    * top-1/2/3 distance features;
    * skill-group similarity features;
    * per-group primary/secondary skill counts;
    * the precomputed ``bert_results.csv`` columns.

    It then left-joins the labels from ``get_labels(3)``.

    Returns:
        DataFrame keyed by ``skill_name`` (spaces encoded as ``$``) and
        ``group_name``, with a numeric target ``label_new``: 2 for ``"prim"``,
        1 for ``"sec"`` and 0 for any other or missing label.
    """
    # ---------- Embedding features-------------------------
    bert_embedding_df = embedding_feature("bert")
    # ---------------Spectral clustering features
    cluster_data = pd.read_csv(source_path + "/../feature_data/skill_group_spectral_clustering.csv")
    # Alternative clustering input, kept for reference:
    # cluster_data = pd.read_csv(source_path + "/../cluster_analysis_graph.csv")
    cluster_df = process_cluster_data(cluster_data)
    # Normalise skill names to the "$"-for-space encoding used by the embeddings.
    cluster_df["skill_name"] = cluster_df["skill_name"].str.replace(" ", "$")
    # ----------- Merge spectral clustering and embedding features----------
    emb_sc_df = pd.merge(bert_embedding_df, cluster_df, on="skill_name", how="left")
    # -----------Drop irrelevant columns (raw cluster id, cross-join key, group-list label)----------------
    emb_sc_df.drop(['cluster', 'key', 'label'], axis=1, inplace=True)
    # -----------Merge Top1,Top2,Top3 related features
    top_features_bert = top_features("bert")
    top_features_bert["skill_name"] = top_features_bert["skill_name"].str.replace(" ", "$")
    # "variable" in the melted top-k table holds the group name.
    top_features_embedding_bert = pd.merge(emb_sc_df, top_features_bert, left_on=["skill_name", "group_name"],
                                           right_on=["skill_name", "variable"])
    # -----------Merge distance related features
    dist_features_bert = dist_features("bert")
    # dist_features() turned "$" into spaces; convert back before joining.
    dist_features_bert["skill_name"] = dist_features_bert["skill_name"].str.replace(" ", "$")
    bert_all_feature = pd.merge(top_features_embedding_bert, dist_features_bert, on=["skill_name", "variable"])
    # -----------primary and secondary group count related features
    grp_prim_sec_df = group_wise_prim_sec()
    # -----------Merge distance, top1-3 features, prim and sec-------------
    # Both sides carry "group_name", so pandas suffixes them: group_name_x is
    # the pair's group, group_name_y the count table's group.
    dist_top_prim_sec_feature = pd.merge(bert_all_feature, grp_prim_sec_df, left_on=["variable"],
                                         right_on=["group_name"])
    dist_top_prim_sec_feature["skill_name"] = dist_top_prim_sec_feature["skill_name"].str.replace(" ", "$")

    # -------------------Attach Bert embedding result--------------------
    bert_result = pd.read_csv(source_path + "/../feature_data/SkillBERT_features/bert_results.csv")
    # Restore a plain "group_name" key for the joins below.
    dist_top_prim_sec_feature["group_name"] = dist_top_prim_sec_feature["group_name_x"]
    final_features1 = pd.merge(dist_top_prim_sec_feature, bert_result, on=["skill_name", "group_name"])
    # ----------- Attach label ----------------------------
    # get_labels(2) for binary classification and get_labels(3) for prim and sec classification
    labels = get_labels(3)

    # Left join: pairs without a label row end up as class 0 below.
    final_df = pd.merge(final_features1, labels, on=["skill_name", "group_name"], how="left")
    final_df["label_new"] = np.where(final_df["label"] == "prim", 2, np.where(final_df["label"] == "sec", 1, 0))
    return final_df
def testing(test_path):
    """Predict with the saved 3-class XGBoost model on a CSV in ``test_path_folder``.

    Args:
        test_path: file name relative to ``test_path_folder``. The file must have
            the columns written by ``training`` (e.g. ``test_2_class.csv``).

    Returns:
        Array of predicted classes from ``XGBClassifier.predict``, in the same
        encoding as ``label_new``.
    """
    test_data = pd.read_csv(test_path_folder + test_path)
    # Drop identifier / raw-label columns and the target, as done in training().
    non_feature_cols = ["skill_name", "label", "variable", "group_name", "group_name_x", "group_name_y"]
    X_test = test_data.drop(non_feature_cols, axis=1).drop(columns=["label_new"])
    # Loads the model pickled by training(); only unpickle trusted files.
    with open(model_file_path + "3_class_model_xgb.pkl", "rb") as fh:
        xgb_model = pickle.load(fh)

    results = xgb_model.predict(X_test)
    print("-----------------" + "Testing Completed Successfully" + "-----------------")
    return results

if __name__ == "__main__":
    # Build the training table: embeddings, clustering, similarity and
    # primary/secondary-count features joined with the 3-class labels.
    data_df = prepare_data()
    # Train the XGBoost model. This also writes train_2_class.csv and
    # test_2_class.csv into test_path_folder.
    training(data_df)
    # Load the model saved by training() and predict on the held-out split.
    # To score your own data, replace "test_2_class.csv" with another CSV in
    # test_path_folder that has the same columns. Only the XGBoost model
    # (3_class_model_xgb.pkl) is loaded here; there is no model-selection
    # parameter.
    testing("test_2_class.csv")
