import os
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
from lib import dataload as dl
from lib import utils
from lib.config import DATA_TASK_MAP, PREDEFINED_TEST, XGB_SAVE_DIR, DATA_DIR, NUM_CPU
from lib.shap import get_shap_fi

import numpy as np
import random
import xgboost as xgb
from hyperopt import fmin, hp, STATUS_OK, tpe, Trials
from hyperopt.early_stop import no_progress_loss
from hyperopt.pyll.stochastic import sample
from lib.hpot import EarlystopWithWarmup
from functools import partial
from sklearn import metrics
from time import time
import pandas as pd


def get_xgb_fi(booster, column_names=None):
    """Return mean split gain per feature, sorted from highest to lowest.

    Args:
        booster: Object exposing ``trees_to_dataframe()`` that returns a frame
            with at least ``Feature`` and ``Gain`` columns.
        column_names: Optional sequence of names; the i-th entry replaces the
            feature label ``f{i}``. Labels without a match are left unchanged.

    Returns:
        DataFrame with a ``Feature`` column and a ``Value`` column (mean gain),
        ordered by descending ``Value``.
    """
    tree_df = booster.trees_to_dataframe()
    # Average the gain of every split made on each feature; leaf rows are
    # labelled "Leaf" in the Feature column and are not features.
    mean_gain = tree_df.groupby("Feature")["Gain"].mean()
    mean_gain = mean_gain.drop("Leaf", errors="ignore").sort_index()
    if column_names is not None:
        # Map the positional labels (f0, f1, ...) onto the supplied names.
        label_map = {f"f{idx}": name for idx, name in enumerate(column_names)}
        mean_gain = mean_gain.rename(label_map)
    mean_gain = mean_gain.sort_values(ascending=False)
    return mean_gain.reset_index().rename({"Gain": "Value"}, axis=1)

def booster_eval_to_list(string, return_metric=False):
    """Parse a tab-separated evaluation string into scores (and metric names).

    The first tab-separated field is skipped. Every remaining field is read as
    ``<prefix>-<metric>:<value>``: the metric name is the text after the last
    ``-`` that precedes the first ``:``, and the score is the float after the
    last ``:``.

    Args:
        string: Evaluation string, e.g. ``"[0]\\teval-rmse:0.5"``.
        return_metric: When true, also return the list of metric names.

    Returns:
        The list of scores, or ``(scores, metric_names)`` if ``return_metric``.
    """
    entries = string.split("\t")[1:]
    names = [entry.partition(":")[0].rpartition("-")[2] for entry in entries]
    values = [float(entry.rpartition(":")[2]) for entry in entries]
    return (values, names) if return_metric else values
    
class XGBObjWrapper:
    """Hyperopt objective that trains one XGBoost booster per trial.

    The wrapper fixes the XGBoost objective and evaluation metrics from the
    task type, trains with early stopping on the validation set, and remembers
    the booster and parameters of the trial with the lowest score.

    Args:
        task: ``"bc"`` (binary classification), ``"mc"`` (multi-class) or
            ``"r"`` (regression).
        train_data: ``(X, y)`` pair used for training.
        metrics: Stored as ``_metrics``; the validation path does not read it
            (early stopping monitors the last entry of ``eval_metric``).
        early_stop: Number of early-stopping rounds.
        valid_data: ``(X, y)`` pair monitored for early stopping. Leaving it
            as ``None`` selects cross-validation, which is not implemented.
        test_data: Optional ``(X, y)`` pair that is only added to the
            evaluation list.
        nfold: Stored for the (unimplemented) cross-validation mode.
        seed: Seed for ``random``, ``numpy`` and, unless overridden by the
            trial parameters, XGBoost.
        maximize: Forwarded to the early-stopping callback.

    Raises:
        ValueError: If ``task`` is not one of the supported codes.
    """

    def __init__(self, task, train_data, metrics, early_stop, valid_data=None, test_data=None, nfold=None, seed=None, maximize=False):
        if task == "bc":
            self._loss_func = "binary:logistic"
            num_class = len(np.unique(train_data[-1]))
            assert num_class == 2, "task 'bc' requires exactly 2 classes"
            self._adtnl_param = {"eval_metric": ["logloss", "auc", "error"]}  # temporary change
        elif task == "mc":
            self._loss_func = "multi:softmax"
            num_class = len(np.unique(train_data[-1]))
            assert num_class > 2, "task 'mc' requires more than 2 classes"
            self._adtnl_param = {"eval_metric": ["mlogloss", "merror"], "num_class": num_class}
        elif task == "r":
            self._loss_func = "reg:squarederror"
            self._adtnl_param = {"eval_metric": ["mae", "rmse"]}
        else:
            raise ValueError(f"Valid task is one of 'r', 'bc', 'mc' : {task}")

        assert len(train_data) == 2, "train_data must be an (X, y) pair"
        # valid_data=None is the signature default; only check the shape when
        # a validation set is actually given.
        if valid_data is not None:
            assert len(valid_data) == 2, "valid_data must be an (X, y) pair"

        self._train_data = train_data
        self._valid_data = valid_data
        self._test_data = test_data

        self._nfold = nfold
        self._early_stop = early_stop
        self._metrics = metrics
        self._seed = seed
        self._maximize = maximize

        # Best trial seen so far. NOTE(review): lower is treated as better
        # regardless of `maximize` -- confirm before using maximize=True.
        self._best_loss = np.inf
        self._best_model = None
        self._best_param = None

    def objective(self, params):
        """Train a booster with ``params`` and return the hyperopt result dict.

        Args:
            params: Trial parameters. ``n_estimators`` (default 10000) is used
                as the boosting-round cap, and an optional ``reg`` value is
                applied to both ``reg_alpha`` and ``reg_lambda``.

        Returns:
            Dict with ``loss`` (best validation score), ``params`` (parameters
            passed to XGBoost plus ``n_estimators`` set to the best
            iteration), ``stopped_round``, ``train_time``, ``status`` and
            ``test`` (always ``None``).

        Raises:
            NotImplementedError: If no validation data was supplied.
        """
        if self._valid_data is None:
            # A cross-validation variant was drafted but never enabled; fail
            # explicitly instead of relying on an assert that -O removes.
            raise NotImplementedError("CrossValidation is not implemented")

        start = time()
        random.seed(self._seed)
        np.random.seed(self._seed)

        _params = dict(params)
        _params.update(self._adtnl_param)
        _params.setdefault("seed", self._seed)
        _params.setdefault("objective", self._loss_func)
        _params.setdefault("nthread", NUM_CPU)

        train_data = xgb.DMatrix(*self._train_data)

        num_boost_round = int(_params.pop("n_estimators", 10000))
        if "reg" in _params:
            # A single search dimension drives both L1 and L2 regularisation.
            reg = _params.pop("reg")
            _params["reg_alpha"] = reg
            _params["reg_lambda"] = reg

        valid_data = xgb.DMatrix(*self._valid_data)
        evals = [(train_data, "train")]
        if self._test_data is not None:
            evals.append((xgb.DMatrix(*self._test_data), "test"))
        evals.append((valid_data, "valid"))

        # Early stopping watches the last configured eval metric on "valid".
        es_callback = xgb.callback.EarlyStopping(
            rounds=self._early_stop,
            metric_name=self._adtnl_param["eval_metric"][-1],
            maximize=self._maximize,
            data_name="valid",
            save_best=True,
        )
        result = xgb.train(
            _params,
            train_data,
            num_boost_round=num_boost_round,
            evals=evals,
            callbacks=[es_callback],
            verbose_eval=0,
        )

        stopped_round = result.best_iteration
        loss = result.best_score

        test_result = None
        # NOTE(review): best_iteration looks 0-based, so the tree count may be
        # best_iteration + 1 -- confirm before retraining with this value.
        _params["n_estimators"] = stopped_round

        if self._best_loss > loss:
            self._best_loss = loss
            self._best_model = result
            self._best_param = _params

        return {
            "loss": loss,
            "params": _params,
            'stopped_round': stopped_round,
            'train_time': time() - start,
            'status': STATUS_OK,
            "test": test_result,
        }

def run(run,seed,tag,save_path,
        hpot_repeat,
        xgb_es_round,
        hpot_es_round,
        n_startup=20,warmup=20, el_candid=24
       ):
    """Tune XGBoost on one dataset with hyperopt and save the best trial.

    Loads the dataset through ``lib.dataload.load_<run>``, searches the
    hyperparameter space with TPE, then writes ``model.json``, ``result.json``
    and ``hpot.csv`` into ``save_path``.

    Args:
        run: Dataset name; key into ``DATA_TASK_MAP`` and suffix of the loader
            function ``load_<run>``.
        seed: Seed passed to the loader, the objective wrapper and hyperopt.
        tag: Not used inside this function.
        save_path: Directory the output files are written into (not created
            here).
        hpot_repeat: Maximum number of hyperopt evaluations.
        xgb_es_round: Early-stopping rounds for each XGBoost training.
        hpot_es_round: Round count given to the hyperopt early-stop function.
        n_startup: ``n_startup_jobs`` for ``tpe.suggest``.
        warmup: Warm-up value given to the hyperopt early-stop function.
        el_candid: ``n_EI_candidates`` for ``tpe.suggest``.

    Returns:
        The hyperopt ``Trials`` object holding every evaluated trial.

    Raises:
        RuntimeError: If no trial produced a model to save.
    """
    data_dir = f"{DATA_DIR}/{run}"
    predefined_test = run in PREDEFINED_TEST
    task = DATA_TASK_MAP[run]

    XGB_METRIC_MAP = {"bc": "error", "r": "rmse", "mc": "merror"}
    task_metric = XGB_METRIC_MAP[task]

    data_load_func = getattr(dl, f"load_{run}")

    # Datasets listed in PREDEFINED_TEST are loaded without a test_size;
    # every other dataset is loaded with test_size=0.2.
    load_kwargs = {"seed": seed, "as_frame": False, "normalize": True}
    if not predefined_test:
        load_kwargs["test_size"] = 0.2
    X_train, X_valid, X_test, y_train, y_valid, y_test, info = data_load_func(data_dir, **load_kwargs)

    # NOTE: the hyperopt label 'colsample_by_tree' differs from its dict key;
    # it is left untouched so recorded trial labels stay the same.
    space = {
        'eta': hp.loguniform('eta', np.log(0.01), np.log(0.5)),
        'gamma': hp.qloguniform('gamma', np.log(1e-4),np.log(0.5), 1e-3),
        'min_child_weight':hp.qloguniform('min_child_weight', np.log(1e-4), np.log(5.0), 1e-3) if task!="r" else hp.quniform('min_child_weight', 0.5,30.5,1),
        'reg':hp.qloguniform('reg', np.log(1e-4), np.log(3.0), 1e-3),
        'colsample_bytree': hp.uniform('colsample_by_tree', 0.3, 1.0),
        'colsample_bylevel': hp.uniform('colsample_bylevel', 0.3, 1.0),
        'subsample':hp.uniform('subsample',0.3,1.0),
        'max_depth':hp.uniformint("max_depth",2,14),
    }

    xgb_obj_wrapper = XGBObjWrapper(task, (X_train, y_train), task_metric, xgb_es_round,
                                    valid_data=(X_valid, y_valid), test_data=(X_test, y_test),
                                    nfold=None, seed=seed, maximize=False)

    es_func = EarlystopWithWarmup(hpot_es_round, warmup, higher_better=False)
    algo = partial(tpe.suggest, n_startup_jobs=n_startup, n_EI_candidates=el_candid)
    trials = Trials()
    fmin(fn=xgb_obj_wrapper.objective, space=space, algo=algo, max_evals=hpot_repeat,
         trials=trials, early_stop_fn=es_func, rstate=np.random.default_rng(seed))

    best_model = xgb_obj_wrapper._best_model
    if best_model is None:
        # The wrapper only records a model when a trial improves on its
        # initial best loss; without one there is nothing to evaluate or save.
        raise RuntimeError("No hyperopt trial produced a model; nothing to save.")
    best_with_es = xgb_obj_wrapper._best_param

    # Re-evaluate the best booster on each split, in train/valid/test order.
    train_scores = booster_eval_to_list(best_model.eval(xgb.DMatrix(X_train, y_train)))
    valid_scores = booster_eval_to_list(best_model.eval(xgb.DMatrix(X_valid, y_valid)))
    test1_scores = booster_eval_to_list(best_model.eval(xgb.DMatrix(X_test, y_test)))

    if task != "r":
        # The last classification metric is an error rate; report accuracy.
        for scores in (train_scores, valid_scores, test1_scores):
            scores[-1] = 1.0 - scores[-1]
        metric_names = ["loss", "acc"] if task == "mc" else ["loss", "auc", "acc"]
    else:
        metric_names = ["loss", "rmse"]

    # model.json
    best_model.save_model(os.path.join(save_path, "model.json"))

    # result.json
    result_dict = {
        "param": best_with_es,
        "metric": metric_names,
        "train": train_scores,
        "valid": valid_scores,
        "test1": test1_scores,
        "test2": None,
        "info": {
            "scale": str(info.get("scale", None)),
            "task": task,
            "num_feat": str(info.get("num_feat", None)),
        },
    }
    utils.save_json(result_dict, os.path.join(save_path, "result.json"))

    # hpot.csv: one row per trial, bookkeeping columns next to the parameters.
    result_list = [elem["result"] for elem in trials.trials]
    hyperparam_df = pd.DataFrame([elem["params"] for elem in result_list])
    other_df = pd.DataFrame(
        [{key: value for key, value in elem.items() if key != "params"} for elem in result_list]
    )
    final_df = pd.concat([other_df, hyperparam_df], keys=["info", "params"], axis=1)
    final_df.reset_index(drop=False, inplace=True)
    final_df.to_csv(os.path.join(save_path, "hpot.csv"), index=False)

    # Feature importances and SHAP values are not saved by this function.
    return trials

if __name__ == "__main__":

    import argparse

    # Command-line options as (flag, type, default); a default of None means
    # the option stays None when it is omitted.
    cli_options = (
        ("--run", str, None),
        ("--seed", int, None),
        ("--tag", str, None),
        ("--hpot_repeat", int, 100),
        ("--xgb_early_stop", int, 50),
        ("--hpot_early_stop", int, 20),
        ("--desc", str, ""),
        ("--n_startup", int, 20),
        ("--warmup", int, 0),
        ("--el_candid", int, 24),
    )
    parser = argparse.ArgumentParser()
    for flag, kind, fallback in cli_options:
        parser.add_argument(flag, type=kind, default=fallback)

    args = parser.parse_args()

    # Seed numpy's global generator only when a seed was given.
    if args.seed is not None:
        np.random.seed(args.seed)

    # Outputs go to <XGB_SAVE_DIR>/<run>/<tag>-<seed>.
    RESULT_SAVE_DIR = f"{XGB_SAVE_DIR}/{args.run}/{args.tag}-{args.seed}"
    os.makedirs(RESULT_SAVE_DIR, exist_ok=True)
    print("XGBoost : {}-{}-{} : {}".format(args.run, args.seed, args.tag, args.desc))
    trials = run(
        run=args.run,
        seed=args.seed,
        tag=args.tag,
        save_path=RESULT_SAVE_DIR,
        hpot_repeat=args.hpot_repeat,
        xgb_es_round=args.xgb_early_stop,
        hpot_es_round=args.hpot_early_stop,
        n_startup=args.n_startup,
        warmup=args.warmup,
        el_candid=args.el_candid,
    )