from runners.gbdt_runner import GBDTRunner
from xgboost import XGBClassifier
import numpy as np

from types import SimpleNamespace
from typing import List, Dict, Any
from numpy.typing import NDArray
import pandas as pd
import logging

class XGBRunner(GBDTRunner):
    """``GBDTRunner`` implementation backed by ``xgboost.XGBClassifier``.

    Builds, fits, saves and loads XGBoost classifiers. Data splitting,
    the validation set (``X_valid`` / ``y_valid``) and ``start_time`` are
    read from attributes this class does not define itself; presumably
    they are provided by ``GBDTRunner`` — verify against the base class.
    """

    def __init__(self, 
                config: SimpleNamespace, 
                data: pd.DataFrame, 
                labels: pd.Series, 
                numeric_cols: List[str],
                category_cols: List[str],
                logger: logging.Logger
        ) -> None:
        """Initialise the runner.

        Args:
            config: Experiment configuration namespace.
            data: Feature frame.
            labels: Target labels.
            numeric_cols: Names of numeric feature columns.
            category_cols: Accepted for signature parity with other runners;
                it is discarded and not forwarded to ``GBDTRunner``.
            logger: Logger passed through to the base class.
        """
        # Categorical column names are not used by this runner.
        del category_cols

        super().__init__(config, data, labels, numeric_cols, logger)
        self.name = 'XGB'

    def get_model(self, 
                hparams: Dict[str, Any]
        ) -> XGBClassifier:
        """Build an unfitted ``XGBClassifier`` from ``hparams``.

        ``n_jobs`` is always taken from ``config.model.n_jobs`` and overrides
        any ``n_jobs`` entry in ``hparams``. The caller's dict is left
        unmodified.

        Args:
            hparams: Keyword arguments forwarded to ``XGBClassifier``.

        Returns:
            A new, unfitted classifier.
        """
        # Copy rather than mutate: the caller may reuse or log this dict.
        params = {**hparams, 'n_jobs': self.config.model.n_jobs}
        return XGBClassifier(**params)

    def fit_model(self, 
                    model: XGBClassifier, 
                    train_idx: NDArray[np.int_], 
                    test_idx: NDArray[np.int_], 
                    hparams: Dict[str, Any] = None, 
                    pseudo_data: pd.DataFrame = None, 
                    pseudo_label: NDArray[np.int_] = None, 
                    fold_idx: int = None
        ) -> XGBClassifier:
        """Fit ``model`` on the selected training rows with early stopping.

        Args:
            model: Classifier to fit, e.g. one returned by ``get_model``.
            train_idx: Training rows. When ``test_idx`` is ``None`` they are
                selected with ``self.data.loc[train_idx]``.
            test_idx: Held-out rows for the current fold. Required when
                ``config.runner_option.use_CV`` is true, because that fold is
                then the early-stopping eval set. May be ``None`` otherwise.
            hparams: Unused; accepted for interface compatibility.
            pseudo_data: Optional extra feature rows appended to the training set.
            pseudo_label: Labels for ``pseudo_data``. Pseudo-labelling happens
                only when both ``pseudo_data`` and ``pseudo_label`` are given.
            fold_idx: Unused; accepted for interface compatibility.

        Returns:
            The value returned by ``model.fit`` (the fitted model).

        Raises:
            ValueError: If ``use_CV`` is enabled but ``test_idx`` is ``None``.
        """
        del hparams, fold_idx

        use_cv = self.config.runner_option.use_CV
        if use_cv and test_idx is None:
            # Otherwise the eval set below would reference unbound X_test / y_test.
            raise ValueError(
                "fit_model: test_idx is required when runner_option.use_CV is enabled, "
                "because the held-out fold is used as the early-stopping eval set."
            )

        if test_idx is not None:
            (X_train, y_train), (X_test, y_test) = self.get_train_test_from_idx(train_idx, test_idx)
        else:
            X_train, y_train = self.data.loc[train_idx], self.label.loc[train_idx]

        if pseudo_data is not None and pseudo_label is not None:
            # DataFrame.append was removed in pandas 2.0; pd.concat is its documented replacement.
            X_train = pd.concat([X_train, pseudo_data])
            # np.asarray handles both a Series and a plain ndarray for y_train.
            y_train = np.concatenate((np.asarray(y_train), pseudo_label))

        # CV mode evaluates on the held-out fold; otherwise on the runner's validation set.
        eval_set = [(X_test, y_test)] if use_cv else [(self.X_valid, self.y_valid)]

        # NOTE(review): passing early_stopping_rounds to fit() is deprecated since
        # xgboost 1.6 and rejected by 2.x. Move it to the constructor / set_params
        # once the pinned xgboost version is confirmed.
        return model.fit(
            X_train,
            y_train,
            eval_set=eval_set,
            early_stopping_rounds=self.config.model.early_stopping_patience,
            verbose=self.config.model.verbosity,
        )

    def save_model(self, 
                    model: XGBClassifier, 
                    saving_path: str = None, 
                    fold_idx: int = None
        ) -> str:
        """Save ``model`` in XGBoost JSON format and return the path string.

        If ``saving_path`` is ``None``, the default stem
        ``model/{start_time}-{config.data.target}-{config.runner_option.model}``
        is used.

        - With ``fold_idx``: the model is written to ``<stem>-fold<fold_idx>.json``.
          The returned string is the unformatted template ``<stem>-fold%d.json``,
          so ``load_model`` can apply ``% fold_idx`` itself.
        - Without ``fold_idx``: the model is written to ``<stem>.json``, and
          that concrete path is returned.
        """
        if saving_path is None:
            saving_path = f"model/{self.start_time}-{self.config.data.target}-{self.config.runner_option.model}"

        if fold_idx is not None:
            # Keep the '%d' placeholder in the returned path; only the written file is formatted.
            saving_path += '-fold%d.json'
            model.save_model(saving_path % fold_idx)
        else:
            saving_path += '.json'
            model.save_model(saving_path)

        return saving_path

    def load_model(self, 
                    model_path: str, 
                    fold_idx: int = None
        ) -> XGBClassifier:
        """Load an ``XGBClassifier`` from a path produced by ``save_model``.

        Args:
            model_path: Model file path. When ``fold_idx`` is given, this must
                be a ``%d`` template (as returned by ``save_model``) and is
                formatted with ``fold_idx``.
            fold_idx: Fold number to load, or ``None`` for a single model.

        Returns:
            The loaded classifier, with ``n_jobs`` set from the current config.
        """
        model = XGBClassifier()
        path = model_path % fold_idx if fold_idx is not None else model_path
        model.load_model(path)
        # Use the current config's thread count, whatever the file contained.
        model.n_jobs = self.config.model.n_jobs
        return model