import pandas as pd
import numpy as np
import pickle
import lightgbm as lgb
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score
from sklearn.metrics import accuracy_score
from sklearn.metrics import precision_score
from sklearn.metrics import recall_score
from sklearn.metrics import f1_score
from sklearn.metrics import r2_score
from scipy.stats import spearmanr
import os


def reshape_data(data):
    """Stack the elements of *data* into a single NumPy array.

    The caller in this file passes the ``.values`` of an embedding column,
    i.e. a 1-D container whose elements are themselves vectors; rebuilding
    the array from the individual elements lets NumPy produce a regular
    (n_rows, ...) array instead of a 1-D array of objects.

    Args:
        data: Any iterable of rows (list, tuple, NumPy array, ...).

    Returns:
        ``np.ndarray`` built from the elements of *data*, in iteration order.
    """
    # Iterate the elements directly rather than indexing with range(len()):
    # this is positional for every container type and also accepts iterables
    # that do not support len() or integer indexing.
    return np.array(list(data))


def _load_gbm_model(model_path):
    """Return the booster stored at *model_path*, or ``None`` if unavailable."""
    if not os.path.isfile(model_path):
        return None
    try:
        return lgb.Booster(model_file=model_path)
    except lgb.basic.LightGBMError:
        # An unreadable/corrupt model file is treated like a missing one so
        # the caller falls back to training, as the original code did.
        return None


def _print_threshold_metrics(y_true, scores):
    """Print classification metrics for thresholds at the 5%..95% score quantiles."""
    # Loop-invariant: share of negative labels in the evaluation set.
    true_zero_ratio = 1 - np.sum(y_true) / len(y_true)
    for i in range(1, 20):
        threshold = np.quantile(scores, 0.05 * i)
        print('----threshold: ---- quantile: ', 0.05 * i, 'value: ', threshold)
        y_pred = (scores >= threshold).astype(int)
        print('acc: ', accuracy_score(y_true, y_pred))
        # NOTE(review): AUC is computed on the hard 0/1 predictions, not on
        # the raw scores, so it varies with the threshold — confirm intended.
        print('auc: ', roc_auc_score(y_true, y_pred))
        print('precision: ', precision_score(y_true, y_pred))
        print('recall: ', recall_score(y_true, y_pred))
        print('f1: ', f1_score(y_true, y_pred))
        print("pred 0", 1 - np.sum(y_pred) / len(y_pred))
        print("true 0", true_zero_ratio)
        print('----------------------------------')


def get_gbm_model(mode='parents',
                  data_path='../../data/data.pickle',
                  model_path='../../model/gbm_model_parents.pt'):
    """Load or train a LightGBM binary classifier and print its evaluation.

    The pickled data at *data_path* must provide the columns
    ``query_embedding`` and ``score`` plus, depending on *mode*,
    ``thougt_embedding`` (mode ``'thought'``) or ``parent1_embedding`` and
    ``parent2_embedding`` (mode ``'parents'``). The selected embeddings are
    concatenated with the query embedding along axis 1 to form the features;
    scores ``>= 0.5`` become label 1, everything else label 0.

    If a model can be loaded from *model_path* it is reused; otherwise a new
    one is trained on the training split and saved to *model_path*. In both
    cases the model is evaluated on the held-out split and the metrics are
    printed.

    Args:
        mode: ``'parents'`` or ``'thought'``; selects the input features.
        data_path: Path of the pickle file holding the data.
        model_path: Path the model is loaded from / saved to.

    Returns:
        The loaded or newly trained ``lgb.Booster``.

    Raises:
        ValueError: If *mode* is neither ``'parents'`` nor ``'thought'``.
    """
    print('gbm model mode: ', mode)
    if mode not in ('parents', 'thought'):
        raise ValueError(
            "mode must be 'parents' or 'thought', got {!r}".format(mode))

    # NOTE: pickle can execute arbitrary code while loading; only point
    # data_path at trusted files.
    with open(data_path, 'rb') as data_file:
        data = pickle.load(data_file)

    train_query = reshape_data(data['query_embedding'].values)
    if mode == 'thought':
        # 'thougt_embedding' (sic) is the key this code reads from the data.
        train_thought = reshape_data(data['thougt_embedding'].values)
        features = np.concatenate((train_thought, train_query), axis=1)
    else:
        train_parent1 = reshape_data(data['parent1_embedding'].values)
        train_parent2 = reshape_data(data['parent2_embedding'].values)
        features = np.concatenate(
            (train_parent1, train_parent2, train_query), axis=1)

    # turn label into 0 and 1
    train_label = (np.array(data['score']) >= 0.5).astype(int)

    # split train and test
    # NOTE(review): test_size=0.8 keeps only 20% of the rows for training —
    # confirm this split is intended.
    X_train, X_test, y_train, y_test = train_test_split(
        features, train_label, test_size=0.8, random_state=42)

    gbm = _load_gbm_model(model_path)
    if gbm is not None:
        print('Load model...')
    else:
        print('No model found, train a new one...')
        # dirname is '' for a bare file name, which os.makedirs rejects.
        model_dir = os.path.dirname(model_path)
        if model_dir:
            os.makedirs(model_dir, exist_ok=True)

        # create dataset for lightgbm
        lgb_train = lgb.Dataset(X_train, y_train)
        lgb_eval = lgb.Dataset(X_test, y_test, reference=lgb_train)

        # specify configurations
        params = {
            'boosting_type': 'gbdt',
            'objective': 'binary',
            'metric': {'auc'},
            'num_leaves': 31,
            'learning_rate': 0.005,
            'feature_fraction': 0.8,
            'bagging_fraction': 0.8,
            'bagging_freq': 2,
            'verbose': 0,
        }

        # train the model
        print('Start training...')
        gbm = lgb.train(params,
                        lgb_train,
                        num_boost_round=100,
                        valid_sets=[lgb_eval])

        # Save right away so a failure during evaluation cannot lose the
        # freshly trained model.
        gbm.save_model(model_path)

    # predict
    y_pred_orig = np.array(
        gbm.predict(X_test, num_iteration=gbm.best_iteration))

    # eval
    print('r^2:', r2_score(y_test, y_pred_orig))
    print('spearmanr:', spearmanr(y_test, y_pred_orig))
    # try different threshold
    _print_threshold_metrics(y_test, y_pred_orig)

    return gbm


if __name__ == '__main__':
    # Script entry point: build (or load) one model per feature mode from the
    # same data file, each stored under its own model path.
    gbm_parents = get_gbm_model(
        mode='parents',
        data_path="results/collect_data/med_final_v2/data.pickle",
        model_path='model/gbm_model_parents_med_v2.pt',
    )
    gbm_thought = get_gbm_model(
        mode='thought',
        data_path="results/collect_data/med_final_v2/data.pickle",
        model_path='model/gbm_model_thoughts_med_v2.pt',
    )
    