from abc import ABC, abstractmethod
import torch
import numpy as np
import random
from sklearn.model_selection import train_test_split
from sklearn.model_selection import StratifiedKFold
from sklearn.model_selection import GridSearchCV
from sklearn.model_selection import cross_val_score
from sklearn.metrics import f1_score
from datetime import datetime
import optuna

import scikitplot as skplt
import matplotlib.pyplot as plt

from sklearn.metrics import f1_score, recall_score, accuracy_score, confusion_matrix, accuracy_score
from sklearn.metrics import roc_auc_score, precision_score, recall_score
from sklearn import metrics

import os
from BorutaShap import BorutaShap

# from misc.likelihood import Likelihood
import pandas as pd
import logging
import sys
from pandarallel import pandarallel
import pickle
from misc import custom_metric

from sklearn.manifold import TSNE
import matplotlib.colors as mcolors

from types import SimpleNamespace
from typing import Tuple, List, Dict, Union, Any
from numpy.typing import NDArray

from torch.nn import Module
from runners.model_types import PTs, GBDTs, DreamquarkTabNet

import torchmetrics

import logging

import gc

import importlib

import prior_knowledge
import data_editor
from torchmetrics import Precision, Recall, F1Score, Accuracy

class Runner(ABC):
    """The object that run the given experiment

    Attributes:
        config: A configuration of the given experiment.
        data: A pandas dataframe of the given data's features.
        label: A pandas series of the given data's labels.
        random_seed: The random seed of the given experiment.
        numeric_cols: The list of numercial columns' name.
        logger: The logger of the given experiment.
        start_time: The start time of the given experiment.
        scorer_dict: A dictionary of parameters for scorer.
        scorer: The performance metric for the given experiment.
        model_name: A name of model of the given experiment.
        likelihoods: The likelihood modules for each fold for the given experiment.
    """
    def __init__(self,
                 config: SimpleNamespace,
                 data: pd.DataFrame,
                 label: pd.Series,
                 numeric_cols: List[str],
                 logger: logging.Logger
                 ) -> None:
        """Inits Runner.

        Seeds every RNG, splits the data (see ``init_data``), resolves the evaluation
        metric, allocates per-split slots for priors / data editors / calibrators and
        creates the temporary checkpoint directory (plus the model directory when
        ``save_model`` is on).

        Args:
            config: A configuration of the given experiment.
            data: Features of data, or a ``(train, test)`` tuple of feature frames.
            label: Labels for the given data, or a ``(train, test)`` tuple of label series.
            numeric_cols: Names of the numerical columns.
            logger: The logger used throughout the experiment.
        """
        self.config = config
        options = self.config.runner_option
        self.random_seed = options.random_seed
        self.set_random_seed(self.random_seed)

        self.init_data(data, label, numeric_cols)

        self.logger = logger

        self.start_time = datetime.now().strftime("%Y_%m_%d-%H_%M_%S")
        self.model_name = None

        self.init_eval_metric()
        # pandarallel.initialize(progress_bar=False, nb_workers = self.config.n_jobs)

        def per_split_slots():
            # One empty slot per CV fold; a single empty slot (None) without CV.
            if options.use_CV:
                return [None] * self.config.KFold.n_splits
            return None

        if self.config.self_training.alpha > 0 or options.auto_alpha:
            self.priors = per_split_slots()

        if options.data_editor is not None:
            self.editors = per_split_slots()

        self.init_calibrator()

        # makedirs(..., exist_ok=True) creates missing parents and is a no-op when the
        # directory already exists, replacing the racy exists()/mkdir() checks.
        run_key = f'{self.config.data.target}/{options.model}/{options.random_seed}'
        os.makedirs(f'temporary_ckpt_data/{run_key}/{self.start_time}', exist_ok=True)
        if options.save_model:
            os.makedirs(f"model/{run_key}", exist_ok=True)
    
    def init_data(self, data: pd.DataFrame, label: pd.Series, numeric_cols: List[str]) -> None:
        """Split the given data into the sets used by the experiment.

        ``data``/``label`` are either single objects, which are split here with a
        stratified split, or ``(train, test)`` tuples whose test part is used as-is.

        * With CV: sets ``self.data``/``self.label`` (train) and ``self.X_test``/``self.y_test``.
        * Without CV: additionally carves ``self.X_valid``/``self.y_valid`` out of the
          non-test data, using ``config.data.valid_size``.

        Args:
            data (pd.DataFrame,): Features of data (or a ``(train, test)`` tuple).
            label (pd.Series): Labels for the given data (or a ``(train, test)`` tuple).
            numeric_cols (List[str]): Names of the numerical columns.
        """
        def stratified_index_split(labels: pd.Series, size: float) -> Tuple[np.ndarray, np.ndarray]:
            # Split index *labels* (not positions) so results can be fed to .loc.
            kept, held_out, _, _ = train_test_split(
                np.array([labels.index]).reshape((-1, 1)),
                labels.to_numpy(),
                test_size=size,
                random_state=self.random_seed,
                stratify=labels.to_numpy(),
            )
            return kept.ravel(), held_out.ravel()

        sizes = self.config.data
        presplit = isinstance(data, tuple)

        if self.config.runner_option.use_CV:
            if presplit:
                self.data, self.label = data[0], label[0]
                self.X_test, self.y_test = data[1], label[1]
            else:
                fit_rows, test_rows = stratified_index_split(label, sizes.test_size)
                self.X_test, self.y_test = data.loc[test_rows], label.loc[test_rows]
                self.data = data.loc[fit_rows]
                self.label = label.loc[fit_rows]
        elif presplit:
            fit_rows, valid_rows = stratified_index_split(label[0], sizes.valid_size)
            self.data, self.label = data[0].loc[fit_rows], label[0].loc[fit_rows]
            self.X_valid, self.y_valid = data[0].loc[valid_rows], label[0].loc[valid_rows]
            self.X_test, self.y_test = data[1], label[1]
        else:
            # First hold out test+valid together, then split that holdout into valid / test.
            holdout_size = sizes.test_size + sizes.valid_size
            fit_rows, holdout_rows = stratified_index_split(label, holdout_size)
            valid_rows, test_rows = stratified_index_split(label.loc[holdout_rows],
                                                           sizes.test_size / holdout_size)

            self.X_test, self.y_test = data.loc[test_rows], label.loc[test_rows]
            self.X_valid, self.y_valid = data.loc[valid_rows], label.loc[valid_rows]
            self.data = data.loc[fit_rows]
            self.label = label.loc[fit_rows]

        self.numeric_cols = numeric_cols
        
    def init_eval_metric(self) -> None:
        """Resolve the evaluation metric named by ``config.data.metric``.

        ``config.data.metric_params`` is a sequence of ``(name, value)`` pairs that are
        passed as extra keyword arguments to the scorer (see ``get_score``). The metric
        name is looked up, in order, in ``sklearn.metrics``, ``misc.custom_metric`` and
        ``torchmetrics``.

        Raises:
            ValueError: If none of those modules provides a metric with that name.
        """
        self.scorer_dict = {param[0]: param[1] for param in self.config.data.metric_params}

        metric_name = self.config.data.metric
        if hasattr(metrics, metric_name):
            self.scorer = getattr(metrics, metric_name)
        elif hasattr(custom_metric, metric_name):
            # Custom metrics are constructed with the labels first. The previous code
            # passed an undefined name ``label`` (NameError); the training labels stored
            # by init_data are the only labels in scope here.
            self.scorer = getattr(custom_metric, metric_name)(self.label)
        elif hasattr(torchmetrics, metric_name):
            self.scorer = getattr(torchmetrics, metric_name)
        else:
            # ``raise("...")`` raised a TypeError instead of a meaningful error.
            raise ValueError(f"Unknown Scorer: {metric_name}")
        
    def init_calibrator(self) -> None:
        """Allocate storage for the spline and/or Gaussian-process calibrators.

        With CV there is one slot per fold (``config.KFold.n_splits``); without CV the
        attribute is a single slot initialised to ``None``.
        """
        options = self.config.runner_option

        if options.use_spline_calibrator:
            if options.use_CV:
                self.splines = [None for _ in range(self.config.KFold.n_splits)]
            else:
                # Was misspelled ``self.splices``, leaving ``self.splines`` undefined
                # in the non-CV case.
                self.splines = None

        if options.use_gaussian_process:
            if options.use_CV:
                self.gps = [None for _ in range(self.config.KFold.n_splits)]
            else:
                self.gps = None
        
    @abstractmethod
    def get_model(self,
                  hparams: Dict[str, Any]
                  ) -> Union[GBDTs, PTs, DreamquarkTabNet, Module]:
        """Build an untrained model from the given hyperparameters.

        Implemented by each concrete runner for its model family.

        Args:
            hparams: Hyperparameters for the model, e.g. the values suggested by one
                optuna trial in ``objective``.

        Returns:
            The constructed (not yet fitted) model.
        """

    
    @abstractmethod
    def fit_model(self,
                  model: Union[GBDTs, PTs, DreamquarkTabNet, Module],
                  train_idx: NDArray[np.int_],
                  test_idx: NDArray[np.int_],
                  hparams: Dict[str, Any] = None,
                  pseudo_data: pd.DataFrame = None,
                  pseudo_label: NDArray[np.int_] = None,
                  fold_idx: int = 0
                  ) -> Union[GBDTs, PTs, DreamquarkTabNet, Module]:
        """Fit ``model`` on the rows selected by ``train_idx``.

        Implemented by each concrete runner; also used as the plain (non self-training)
        fitting strategy in ``fit_and_predict``.

        Args:
            model: The model to train.
            train_idx: Index labels of the training rows in ``self.data``/``self.label``.
            test_idx: Index labels of the evaluation rows in ``self.data``/``self.label``.
            hparams: Hyperparameters of the model.
            pseudo_data: Unlabelled rows to train on in addition to the labelled ones.
            pseudo_label: Pseudo labels for ``pseudo_data``.
            fold_idx: Which fold of the k-fold cross-validation this fit belongs to.

        Returns:
            The trained model.
        """
    
    @abstractmethod
    def predict(self,
                model: Union[GBDTs, PTs, DreamquarkTabNet, Module],
                test_idx: Union[NDArray[np.int_], None] = None,
                X_test: pd.DataFrame = None,
                return_prob: bool = False,
                return_logits: bool = False,
                fold_idx: int = 0,
                is_test: bool = False,
                ) -> Union[Tuple[NDArray[np.int_], NDArray[np.float64]], NDArray[np.int_]]:
        """Predict labels of the given data.

        When ``test_idx`` is given and ``X_test`` is None, predict the rows of
        ``self.data`` selected by ``test_idx``; when ``test_idx`` is None and ``X_test``
        is given, predict ``X_test``.

        Args:
            model: A trained model.
            test_idx: Index labels of the rows of ``self.data`` to predict.
            X_test: A dataframe to predict instead of rows of ``self.data``.
            return_prob: Whether to also return the predicted probabilities.
            return_logits: Whether to return logits; semantics are defined by the
                concrete runner.
            fold_idx: Fold index under k-fold CV; ``fit_and_predict`` always passes it.
            is_test: Whether the prediction is on the held-out test set;
                ``fit_and_predict`` passes it for the test / validation-set paths.

        Returns:
            When ``return_prob`` is True, the predicted labels and their probabilities;
            otherwise only the predicted labels.
        """
        # NOTE: ``np.float_`` (used previously in the return annotation) was removed in
        # NumPy 2.0 and made class creation fail; ``np.float64`` is the same dtype.

    @abstractmethod
    def save_model(self,
                   model: Union[GBDTs, PTs, DreamquarkTabNet, Module],
                   saving_path: str = None,
                   fold_idx: int = None
                   ) -> str:
        """Persist a trained model.

        Implemented by each concrete runner.

        Args:
            model: The model to save.
            saving_path: Destination path; the default location is chosen by the
                concrete runner when omitted.
            fold_idx: Which fold of the k-fold cross-validation the model belongs to.

        Returns:
            The path the model was written to.
        """
    
    @abstractmethod
    def load_model(self,
                   model_path: str,
                   fold_idx: int = None
                   ) -> Union[GBDTs, PTs, DreamquarkTabNet, Module]:
        """Load a previously saved model.

        Implemented by each concrete runner.

        Args:
            model_path: Where the model was saved.
            fold_idx: Which fold of the k-fold cross-validation the model belongs to.

        Returns:
            The loaded model.
        """

    @abstractmethod
    def clear_model_cache(self,
                          path: str
                          ) -> None:
        """Remove cached artefacts stored under ``path``.

        Implemented by each concrete runner.

        Args:
            path: The location to clean up.
        """
    
    def get_score(self,
                  y_true: NDArray[np.int_],
                  y_pred: NDArray[np.int_]
                  ) -> float:
        """Score predictions with ``self.scorer``.

        The configured metric parameters (``self.scorer_dict``) are forwarded as keyword
        arguments together with ``y_true`` and ``y_pred``.

        Args:
            y_true: Ground truth labels.
            y_pred: Predicted labels.

        Returns:
            The value returned by ``self.scorer``.
        """
        # Build per-call kwargs instead of writing y_true/y_pred into self.scorer_dict:
        # that mutated shared state and kept the last label/prediction arrays alive.
        # The explicit y_true / y_pred still override any same-named configured param.
        return self.scorer(**{**self.scorer_dict, 'y_true': y_true, 'y_pred': y_pred})

    def run(self) -> None:
        """Run the experiment.

        Behaviour is driven by ``config.runner_option``:

        * ``report``: only call ``self.report()`` and return.
        * ``save_best_hparams`` / ``save_model``: make sure ``hparams/`` / ``model/`` exist.
        * ``hparam_search_only``: stop right after ``self.get_hparams()``.
        * ``use_CV``: score the chosen hyperparameters with stratified k-fold
          (``get_kfold_score``); otherwise train once and score on ``self.X_test``.

        The final score and its standard deviation (0 without CV) are logged.
        """
        optuna.logging.set_verbosity(self.config.optuna.verbosity)
        options = self.config.runner_option

        if options.report:
            self.report()
            return

        if options.save_best_hparams:
            os.makedirs('hparams', exist_ok=True)
        if options.save_model:
            os.makedirs('model', exist_ok=True)

        hparams = self.get_hparams()
        if options.hparam_search_only:
            return

        std = 0
        if options.use_CV:
            score, std = self.get_kfold_score(hparams)
        else:
            self.set_random_seed(self.random_seed)
            model = self.get_model(hparams)
            train_idx, _ = self.train_test_split_idx()

            if options.feature_corruption > 0:
                # Same scheme as get_kfold_score: for every training row, overwrite a
                # random ``feature_corruption`` fraction of its features with the values
                # of one randomly chosen training row. The old version looked rows up as
                # columns, used the removed ``np.bool`` alias, treated the fraction as a
                # count and assigned to ``DataFrame.values`` (a no-op).
                n_features = self.data.shape[1]
                corruption_len = int(options.feature_corruption * n_features)
                donor = train_idx[np.random.randint(0, len(train_idx))]
                donor_row = torch.from_numpy(self.data.loc[donor].values).to(torch.float)
                train_values = torch.from_numpy(self.data.loc[train_idx].values)
                corruption_mask = torch.zeros_like(train_values, dtype=torch.bool)
                for row in range(len(train_idx)):
                    corruption_mask[row, torch.randperm(n_features)[:corruption_len]] = True
                self.data.loc[train_idx] = torch.where(corruption_mask, donor_row, train_values).numpy()

            if options.limited_training_sample:
                train_idx = self.get_limited_samples_idx(self.data.index)
            model, _, score = self.fit_and_predict(model, train_idx, test_idx=None, is_test=True, hparams=hparams)

            if options.save_model:
                gc.disable()
                try:
                    self.save_model(model)
                finally:
                    # Never leave the garbage collector disabled if saving fails.
                    gc.enable()
        self.logger.info("Best Score = %f , Std = %f", score, std)
        
    def objective(self,
                  trial: optuna.trial.Trial,
                  train_idx: NDArray[np.int_] = None,
                  test_idx: NDArray[np.int_] = None,
                  fold_idx: int = 0
                  ) -> float:
        """Optuna objective: sample hyperparameters, fit a model and return its score.

        Each entry of ``config.optuna.hparams`` maps a hyperparameter name to
        ``(trial_method_name, args)``; the value is ``getattr(trial, method)(*args)``.

        Args:
            trial: The optuna trial providing hyperparameter suggestions.
            train_idx: Index labels of the training rows in ``self.data``/``self.label``.
            test_idx: Index labels of the evaluation rows in ``self.data``/``self.label``.
            fold_idx: Which fold of the k-fold cross-validation is being evaluated.

        Returns:
            The score of the sampled hyperparameters.
        """
        hparams = {
            name: getattr(trial, spec[0])(*spec[1])
            for name, spec in self.config.optuna.hparams.items()
        }

        model = self.get_model(hparams)

        split_missing = train_idx is None or test_idx is None
        if split_missing and self.config.runner_option.limited_training_sample:
            train_idx = self.get_limited_samples_idx(self.data.index)

        _, _, score = self.fit_and_predict(model, train_idx, test_idx, hparams=hparams, fold_idx=fold_idx)
        return score

    def get_train_test_from_idx(self,
                                train_idx: NDArray[np.int_] = None,
                                test_idx: NDArray[np.int_] = None
                                ) -> Tuple[Tuple[np.array, np.array], Tuple[np.array, np.array]]:
        """Look up ``(features, labels)`` for the given training and test index labels.

        Args:
            train_idx: Index labels of the training rows in ``self.data``/``self.label``.
            test_idx: Index labels of the test rows in ``self.data``/``self.label``.

        Returns:
            ``((X_train, y_train), (X_test, y_test))``; a side whose indices are None
            is returned as ``(None, None)``.
        """
        def select(rows):
            if rows is None:
                return None, None
            return self.data.loc[rows], self.label.loc[rows]

        return select(train_idx), select(test_idx)

    def set_random_seed(self,
                        random_seed: int
                        ) -> None:
        """Seed Python's, NumPy's and PyTorch's (CPU and CUDA) random generators.

        Args:
            random_seed: The seed applied to every generator.
        """
        seeders = (
            random.seed,
            np.random.seed,
            torch.manual_seed,
            torch.cuda.manual_seed,
            torch.cuda.manual_seed_all,
        )
        for seed_fn in seeders:
            seed_fn(random_seed)
        # Ask cuDNN for deterministic kernels so GPU runs are reproducible.
        torch.backends.cudnn.deterministic = True

    def save_hparams(self,
                     hparams: Dict[str, Any]
                     ) -> None:
        """Pickle ``hparams`` into the ``hparams/`` directory.

        The file is named ``<start_time>-<target>-<model>.pickle``. The directory is
        expected to exist already (``run`` creates it when ``save_best_hparams`` is on).

        Args:
            hparams: The hyperparameters to save.
        """
        file_name = f'{self.start_time}-{self.config.data.target}-{self.config.runner_option.model}.pickle'
        with open(f'hparams/{file_name}', 'wb') as handle:
            pickle.dump(hparams, handle)

    def train_test_split_idx(self, test_size: float = 0.33) -> Tuple[NDArray[np.int_], NDArray[np.int_]]:
        """Stratified split of ``self.label``'s index into training and test index labels.

        Args:
            test_size: Fraction (float) or absolute number (int) of samples to hold out,
                forwarded to ``sklearn.model_selection.train_test_split``. Defaults to
                the previously hard-coded 0.33.

        Returns:
            Training index labels and test index labels, as 1-D arrays.
        """
        labels = self.label.to_numpy()
        train_idx, test_idx, _, _ = train_test_split(np.array([self.label.index]).reshape((-1, 1)),
                                                     labels,
                                                     test_size=test_size,
                                                     random_state=self.random_seed,
                                                     stratify=labels)
        return train_idx.ravel(), test_idx.ravel()

    def init_prior(self,
                   train_idx: NDArray[np.int_],
                   fold_idx: int
                   ) -> None:
        """Build the prior-knowledge object for one split and store it.

        The class is looked up by name (``config.runner_option.prior``) in the
        ``prior_knowledge`` module. It receives all feature rows, the labelled training
        rows and their labels, restricted to ``config.data.minimum_cols`` when that list
        is non-empty.

        Args:
            train_idx: Index labels of the training rows in ``self.data``/``self.label``.
            fold_idx: Which fold of the k-fold cross-validation this prior belongs to.
        """
        prior_cls = getattr(prior_knowledge, self.config.runner_option.prior)
        cols = self.config.data.minimum_cols
        labelled = self.data.loc[self.label.index]

        if len(cols) > 0:
            all_features = self.data.loc[:, cols]
            train_features = labelled.loc[train_idx, cols]
        else:
            all_features = self.data
            train_features = labelled.loc[train_idx, :]

        prior = prior_cls(all_features,
                          train_features,
                          self.label.loc[train_idx],
                          continuous_cols=self.numeric_cols,
                          n_jobs=self.config.n_jobs)

        if not self.config.runner_option.use_CV:
            self.priors = prior
        else:
            self.priors[fold_idx] = prior
            
    def init_data_editer(self,
                         train_idx: NDArray[np.int_],
                         fold_idx: int
                         ) -> None:
        """Build (or load from cache) the data editor for one split and store it.

        If ``config.data.cached_data_editor_path`` is set and the cache file for this
        split exists, the editor is unpickled from it. Otherwise an editor of class
        ``config.runner_option.data_editor`` (looked up in ``data_editor``) is built from
        all feature rows, the labelled training rows and their labels, restricted to
        ``config.data.minimum_cols`` when that list is non-empty. With
        ``save_data_editor`` the editor is pickled to the cache path.

        The cache path is a %-format template filled with ``(random_seed, fold_idx)``
        under CV and with ``random_seed`` otherwise.

        Args:
            train_idx: Index labels of the training rows in ``self.data``/``self.label``.
            fold_idx: Which fold of the k-fold cross-validation this editor belongs to.

        Raises:
            AttributeError: If ``save_data_editor`` is on but no cache path is configured.
        """
        options = self.config.runner_option
        use_cv = options.use_CV

        template = getattr(self.config.data, 'cached_data_editor_path', None)
        cache_path = None
        if template is not None:
            cache_path = template % ((self.random_seed, fold_idx) if use_cv else self.random_seed)

        if cache_path is not None and os.path.exists(cache_path):
            # NOTE(review): unpickling can execute arbitrary code; only load trusted caches.
            # The previous non-CV branch called pickle.load(open(path), 'rb'), which opened
            # the file in text mode and passed 'rb' to pickle.load (TypeError); files were
            # also never closed.
            with open(cache_path, 'rb') as f:
                editor = pickle.load(f)
        else:
            editor_cls = getattr(data_editor, options.data_editor)
            cols = self.config.data.minimum_cols
            labelled = self.data.loc[self.label.index]
            if len(cols) > 0:
                editor = editor_cls(self.data.loc[:, cols], labelled.loc[train_idx, cols], self.label.loc[train_idx], continuous_cols=self.numeric_cols, n_jobs=self.config.n_jobs)
            else:
                editor = editor_cls(self.data, labelled.loc[train_idx, :], self.label.loc[train_idx], continuous_cols=self.numeric_cols, n_jobs=self.config.n_jobs)

        if use_cv:
            self.editors[fold_idx] = editor
        else:
            self.editors = editor

        if options.save_data_editor:
            if cache_path is None:
                raise AttributeError("save_data_editor requires config.data.cached_data_editor_path")
            with open(cache_path, 'wb') as f:
                pickle.dump(editor, f)
            
    def fit_and_predict(self,
                        model: Union[GBDTs, PTs, DreamquarkTabNet, Module],
                        train_idx: NDArray[np.int_],
                        test_idx: Union[NDArray[np.int_], None],
                        hparams: Dict[str, Any] = None,
                        fold_idx: int = 0,
                        is_test: bool = False,
                        ) -> Tuple[Union[GBDTs, PTs, DreamquarkTabNet, Module], NDArray[np.int_], float]:
        """Fit a model with the configured self-training strategy and score it.

        The strategy is chosen from ``config.runner_option.self_training``:
        ``"naive"``, anything containing ``"curriculum"`` or ``"fixed"``; otherwise
        (including when unset) plain ``fit_model``. Priors / data editors are built
        lazily for the split when they are needed and not built yet.

        Args:
            model: A model to train.
            train_idx: Index labels of the training rows in ``self.data``/``self.label``.
            test_idx: Index labels of the evaluation rows in ``self.data``/``self.label``.
            hparams: Parameters of a model.
            fold_idx: Which fold of the k-fold cross-validation this fit belongs to.
            is_test: Score on ``self.X_test``/``self.y_test`` instead of the CV fold
                (with CV) or the validation set (without CV).

        Returns:
            The trained model, its predictions on the evaluation data, and their score.
        """
        options = self.config.runner_option

        # ``or ""`` keeps the substring checks from raising TypeError when
        # self-training is disabled with None/False.
        mode = options.self_training or ""
        if mode == "naive":
            self.fit = self.naive_self_training
        elif "curriculum" in mode:
            self.fit = self.curriculum_self_training
        elif "fixed" in mode:
            self.fit = self.fixed_threshold_self_training
        else:
            self.fit = self.fit_model

        def needs_init(slots):
            # Under CV ``slots`` is a per-fold list; without CV it is one object (or
            # None) and must not be subscripted. The old check indexed it with fold_idx
            # unconditionally, which broke once a non-CV prior/editor had been built.
            if slots is None:
                return True
            return options.use_CV and fold_idx is not None and slots[fold_idx] is None

        if options.auto_alpha or (self.config.self_training.alpha > 0 and needs_init(self.priors)):
            self.init_prior(train_idx, fold_idx)
        if options.data_editor is not None and needs_init(self.editors):
            self.init_data_editer(train_idx, fold_idx)
        model = self.fit(model=model, train_idx=train_idx, test_idx=test_idx, hparams=hparams, fold_idx=fold_idx)

        if is_test:
            preds = self.predict(model, X_test=self.X_test, fold_idx=fold_idx, is_test=is_test)
            score = self.get_score(self.y_test, preds)
        elif options.use_CV:
            preds = self.predict(model, test_idx, fold_idx=fold_idx)
            score = self.get_score(self.label.loc[test_idx], preds)
        else:
            preds = self.predict(model, X_test=self.X_valid, fold_idx=fold_idx, is_test=is_test)
            score = self.get_score(self.y_valid, preds)
        return model, preds, score
    
    def get_kfold_score(self,
                        hparams: Dict[str, Any],
                        ) -> Tuple[float, float]:
        """Train one model per stratified fold and score each on ``self.X_test``.

        Optionally corrupts a fraction of the training features
        (``feature_corruption``), limits the number of training / validation samples,
        saves each fold's model and clears pretrained checkpoints between folds.

        Args:
            hparams: Parameters of a model.

        Returns:
            Mean and standard deviation of the per-fold scores.
        """
        options = self.config.runner_option
        fold = StratifiedKFold(n_splits=self.config.KFold.n_splits, shuffle=True, random_state=self.random_seed)
        scores = []

        if options.feature_corruption > 0:
            self.base_data = self.data.copy(deep=True)
            self.corruption_len = int(options.feature_corruption * self.data.shape[1])

        labelled = self.data.loc[self.label.index]
        # Maps StratifiedKFold's positional indices back to index labels. Corruption only
        # rewrites values, so the index is loop-invariant and computed once.
        labelled_index = labelled.index

        for fold_idx, (train_pos, test_pos) in enumerate(fold.split(labelled, self.label.to_numpy())):
            try:
                train_idx = labelled_index[train_pos]

                if options.feature_corruption > 0:
                    # Overwrite ``corruption_len`` random features of every training row
                    # with the values of one randomly chosen training row.
                    donor = train_idx[np.random.randint(0, len(train_idx))]
                    # ``.to`` avoids torch.tensor(<tensor>)'s copy-construct warning.
                    donor_row = torch.from_numpy(self.data.loc[donor].values).to(torch.float)
                    train_values = torch.from_numpy(self.data.loc[train_idx].values)

                    corruption_mask = torch.zeros_like(train_values, dtype=torch.bool)
                    for row in range(len(train_idx)):
                        corruption_idx = torch.randperm(self.data.values.shape[1])[: self.corruption_len]
                        corruption_mask[row, corruption_idx] = True

                    self.data.loc[train_idx] = torch.where(corruption_mask, donor_row, train_values).numpy()

                if options.limited_training_sample:
                    train_idx = self.get_limited_samples_idx(train_idx)
                test_idx = labelled_index[test_pos]
                if options.limited_validation_sample:
                    test_idx = self.get_limited_samples_idx(test_idx, int(options.limited_validation_sample))

                model = self.get_model(hparams)
                model, _, score = self.fit_and_predict(model, train_idx, test_idx=test_idx, hparams=hparams, fold_idx=fold_idx, is_test=True)
                scores.append(score)

                if options.save_model:
                    gc.disable()
                    try:
                        self.save_model(model, fold_idx=fold_idx)
                    finally:
                        # Never leave the garbage collector disabled if saving fails.
                        gc.enable()

                if hasattr(self, 'pretrained_model_path'):
                    self.clear_pretrained_model()
            finally:
                # Undo this fold's corruption even if the fold raised.
                if options.feature_corruption > 0:
                    self.data = self.base_data.copy(deep=True)

        score = np.mean(scores)
        std = np.std(scores)
        return score, std

    def get_limited_samples_idx(self,
                                indice: NDArray[np.int_],
                                size: int = None
                                ) -> NDArray[np.int_]:
        """Draw a stratified subset of at most ``size`` index labels from ``indice``.

        Used to limit the number of labelled training (or validation) samples.

        Args:
            indice: Index labels of ``self.data``/``self.label`` to subsample.
            size: Number of labels to keep. Defaults to
                ``int(config.runner_option.limited_training_sample)``.

        Returns:
            The kept index labels as a 1-D array. When ``size`` is not smaller than
            ``len(indice)`` all labels are returned, since ``train_test_split`` rejects
            an integer ``test_size >= n_samples``.
        """
        if size is None:
            size = int(self.config.runner_option.limited_training_sample)

        candidates = np.asarray(indice)
        if size >= len(candidates):
            return candidates.ravel()

        # The kept subset is the *test* side of the split: an integer test_size keeps
        # exactly ``size`` labels while preserving the class ratio. ``.loc`` makes the
        # label lookup explicit (plain ``[]`` can fall back to positional indexing).
        _, kept, _, _ = train_test_split(candidates.reshape((-1, 1)),
                                         candidates,
                                         test_size=size,
                                         random_state=self.random_seed,
                                         stratify=self.label.loc[candidates].to_numpy())

        return kept.ravel()

    def get_hparams(self) -> Dict[str, Any]:
        """Search the best hyperparameters for a given model.

        When ``runner_option.load_hparams`` is set, the hyperparameters are unpickled from
        ``config.model.hparams`` and no search is run. Otherwise an optuna study (TPE sampler seeded with
        ``self.random_seed``) runs ``config.optuna.n_trials`` trials. With ``runner_option.use_CV`` every
        trial is scored with stratified k-fold cross validation (``objective_cv``), otherwise
        ``self.objective`` is handed to optuna directly. The best parameters are optionally saved with
        ``self.save_hparams`` and always logged.

        Returns:
            The best hyperparameters.
        """
        print("Get parameters ...")

        def objective_cv(trial : optuna.trial.Trial) -> float:
            """Objective function of optuna with k-fold cross validation.

            If ``runner_option.feature_corruption > 0``, the training rows of ``self.data`` are corrupted in
            place for each fold and restored from ``self.base_data`` afterwards. The restore runs in a
            ``finally`` clause, so a fold that raises (e.g. a pruned trial) cannot leave corrupted features
            behind for later folds or trials.

            Args:
                trial: An object which returns hyperparameters of a model for this search trial.

            Returns:
                Average score of the hyperparameters over the k folds.
            """
            runner_option = self.config.runner_option
            use_corruption = runner_option.feature_corruption > 0
            fold = StratifiedKFold(n_splits=self.config.KFold.n_splits, shuffle=True, random_state=self.random_seed)
            scores = []

            if use_corruption:
                # Pristine copy of the features: every fold corrupts self.data in place and is restored from it.
                self.base_data = self.data.copy(deep=True)
                self.corruption_len = int(runner_option.feature_corruption * self.data.shape[1])

            # Index labels of the labeled rows; fold.split yields positions into this index.
            labeled_index = self.data.loc[self.label.index].index
            fold_splits = fold.split(self.data.loc[self.label.index].to_numpy(), self.label.to_numpy())
            for fold_idx, (train_pos, test_pos) in enumerate(fold_splits):
                train_idx = labeled_index[train_pos]
                try:
                    if use_corruption:
                        # Overwrite a random subset of `corruption_len` features in every training row with
                        # the values of one randomly chosen training row.
                        donor_idx = train_idx[np.random.randint(0, len(train_idx))]
                        random_sample = torch.tensor(self.data.loc[donor_idx].values, dtype=torch.float)

                        n_features = self.data.shape[1]
                        corruption_mask = torch.zeros((len(train_idx), n_features), dtype=torch.bool)
                        for row in range(len(train_idx)):
                            corruption_mask[row, torch.randperm(n_features)[: self.corruption_len]] = True

                        original_rows = torch.from_numpy(self.data.loc[train_idx].values)
                        self.data.loc[train_idx] = torch.where(corruption_mask, random_sample, original_rows).numpy()

                    if runner_option.limited_training_sample:
                        train_idx = self.get_limited_samples_idx(train_idx)
                    test_idx = labeled_index[test_pos]
                    if runner_option.limited_validation_sample:
                        test_idx = self.get_limited_samples_idx(test_idx, int(runner_option.limited_validation_sample))
                    if self.config.self_training.alpha > 0 and self.priors[fold_idx] is None:
                        self.init_prior(train_idx, fold_idx)

                    scores.append(self.objective(trial, train_idx, test_idx, fold_idx))

                    if hasattr(self, 'pretrained_model_path'):
                        self.clear_pretrained_model()
                finally:
                    if use_corruption:
                        self.data = self.base_data.copy(deep=True)

            return np.mean(scores)

        if self.config.runner_option.load_hparams:
            # NOTE(review): pickle.load can execute arbitrary code; only point config.model.hparams at trusted files.
            with open(self.config.model.hparams, 'rb') as f:
                hparams = pickle.load(f)
        else:
            optuna_logger = optuna.logging.get_logger("optuna")
            # Attach a stdout handler only once so repeated calls do not print every optuna log line twice.
            if not any(isinstance(h, logging.StreamHandler) and getattr(h, 'stream', None) is sys.stdout
                       for h in optuna_logger.handlers):
                optuna_logger.addHandler(logging.StreamHandler(sys.stdout))
            opt = optuna.create_study(direction=self.config.optuna.direction, sampler=optuna.samplers.TPESampler(seed=self.random_seed))

            objective = objective_cv if self.config.runner_option.use_CV else self.objective
            opt.optimize(objective, n_trials=self.config.optuna.n_trials)

            hparams = dict(opt.best_trial.params)

            if self.config.runner_option.save_best_hparams:
                self.save_hparams(hparams)

        self.logger.info("Best Parameters")
        self.logger.info(hparams)

        return hparams
    
    def temporary_save_model(self, 
                            model: Union[GBDTs, PTs, DreamquarkTabNet, Module], 
                            fold_idx: int, 
                            return_model: bool = True
        ) -> Union[Tuple[str, Union[GBDTs, PTs, DreamquarkTabNet, Module]], str]:
        """Save the model to a temporary checkpoint location.

        The directory handed to ``self.save_model`` is
        ``temporary_ckpt_data/<data.target>/<runner_option.model>/<runner_option.random_seed>/<start_time>``.
        The final file path is whatever ``self.save_model`` returns. The garbage collector is paused while
        saving. Its previous state is restored afterwards, even if saving raises.

        Args:
            model: The model that has to be saved.
            fold_idx: A fold index that denotes which fold under the given k-fold cross validation.
            return_model: If True, reload the saved checkpoint with ``self.load_model`` and return it as well.

        Returns:
            ``(model_path, reloaded_model)`` when ``return_model`` is True, otherwise just ``model_path``.
        """
        saving_path = f'temporary_ckpt_data/{self.config.data.target}/{self.config.runner_option.model}/{self.config.runner_option.random_seed}/{self.start_time}'

        gc_was_enabled = gc.isenabled()
        gc.disable()
        try:
            model_path = self.save_model(model, saving_path, fold_idx)
        finally:
            # Never leave the collector disabled for the rest of the process if save_model fails.
            if gc_was_enabled:
                gc.enable()

        if return_model:
            return model_path, self.load_model(model_path, fold_idx)
        return model_path

    def get_remained(self, 
                    train_idx: NDArray[np.int_], 
                    test_idx: NDArray[np.int_]
        ) -> pd.DataFrame:
        """Generate a dataframe of unlabeled data.

        Returns every row of ``self.data`` whose index label appears in neither ``train_idx`` nor
        ``test_idx``. As a side effect, ``self.train_size`` is set to ``len(train_idx)``.

        Args:
            train_idx: Index labels of training data in self.data and self.label.
            test_idx: Index labels of test data in self.data and self.label, or None if there is no test split.
        Returns:
            A copy of the remaining rows of self.data, in their original order.
        """
        self.train_size = len(train_idx)
        if test_idx is not None:
            used_idx = np.concatenate((train_idx, test_idx))
        else:
            used_idx = train_idx
        # Hash-based membership: O(len(data) + len(used_idx)) instead of scanning used_idx for every row.
        is_used = self.data.index.isin(used_idx)
        return self.data.loc[~is_used, :].copy()

    def reset_pl_info(self, fold_idx = None):
        """Reset the pseudo-label quality metrics to NaN.

        Args:
            fold_idx: When given, only entry ``fold_idx`` of each per-fold metric container
                (``pl_precision``, ``pl_recall``, ``pl_f1``, ``pl_acc``) is set to NaN. When None, each
                metric attribute itself is replaced by a scalar NaN.
        """
        for metric_name in ('pl_precision', 'pl_recall', 'pl_f1', 'pl_acc'):
            if fold_idx is None:
                setattr(self, metric_name, np.nan)
                continue
            getattr(self, metric_name)[fold_idx] = np.nan
            
    def get_pseudo_label(self, 
                        model: Union[GBDTs, PTs, DreamquarkTabNet, Module], 
                        threshold: float, 
                        remained: pd.DataFrame, 
                        fold_idx: int, 
                        use_percentiles : bool = True
        ) -> Tuple[NDArray[np.bool_], NDArray[np.int_]]:
        """Return pseudo-labels that satisfy the given confidence condition.

        Steps:
        1. Predict class probabilities for ``remained``.
        2. Optionally mix in the prior (``self_training.alpha > 0`` or ``runner_option.auto_alpha``).
        3. Compute a per-sample confidence, optionally through a calibrator.
        4. Keep samples whose confidence is ``>= threshold``.
        5. When ``runner_option.data_editor`` is set, additionally require the editor criteria to pass.

        Args:
            model: A model object that you want to evaluate.
            threshold: A threshold for fixed threshold pseudo-labeling or curriculum pseudo-labeling.
                For fixed threshold pseudo-labeling it is a score threshold. For curriculum
                pseudo-labeling it is a percentile threshold in [0, 1].
            remained: A DataFrame of unlabeled data.
            fold_idx: The fold idx of the current experiment.
            use_percentiles: If True, ``threshold`` is interpreted as a percentile of the confidences.

        Returns:
            A boolean mask over the rows of ``remained`` selecting the pseudo-labelled samples, and the
            predicted class (argmax of the possibly prior-adjusted probabilities) of each selected sample.
        """
        runner_option = self.config.runner_option

        _, preds_proba = self.predict(model, X_test = remained, return_prob=True, fold_idx = fold_idx)

        if self.config.self_training.alpha > 0 or runner_option.auto_alpha:
            prior_source = self.priors[fold_idx] if runner_option.use_CV else self.priors
            prior = np.nan_to_num(prior_source.get_prior(remained.index.array))
            alpha = prior_source.alpha if runner_option.auto_alpha else self.config.self_training.alpha
            preds_proba = alpha * np.multiply(prior, preds_proba) + (1 - alpha) * preds_proba

        # Confidence of each sample, optionally produced by a fold-specific calibrator.
        preds_max = None
        is_multiclass = self.config.data.task == "multiclass"
        if runner_option.use_temperature:
            calibrated = self.temperature[fold_idx].transform(preds_proba.astype(np.float32))
            preds_max = calibrated.max(1) if is_multiclass else calibrated
        elif runner_option.use_histogram_binning:
            calibrated = self.hb[fold_idx].transform(preds_proba)
            preds_max = calibrated.max(1) if is_multiclass else calibrated
        elif runner_option.use_spline_calibrator:
            preds_max = self.splines[fold_idx](preds_proba)
        elif runner_option.use_gaussian_process:
            preds_max = self.gps[fold_idx].predict_proba(preds_proba).max(1)

        if preds_max is None:
            preds_max = preds_proba.max(1)

        if use_percentiles:
            threshold = np.percentile(preds_max, threshold * 100)

        # Presumably avoids a trailing mini-batch of size 1 once pseudo-labels are appended to the
        # training set (train_size is set by get_remained).
        has_batch_size = hasattr(self.config.model, 'batch_size')

        if runner_option.data_editor is None:
            pl_idx = (preds_max >= threshold)
            n_selected = int(pl_idx.sum())
            # Add the next most confident sample -- only possible while some sample is still unselected.
            if (has_batch_size and n_selected < len(pl_idx)
                    and (n_selected + self.train_size) % self.config.model.batch_size == 1):
                pl_idx[np.argsort(-preds_max)[n_selected]] = True
        else:
            editor = self.editors[fold_idx] if runner_option.use_CV else self.editors
            criteria = np.asarray(editor.get_criteria(remained.index.array))

            dist_threshold = np.percentile(criteria, self.config.data_editor.dist_threshold * 100)
            # A sample passes the editor if any of its criteria columns exceeds the distance threshold.
            pl_idx = np.any(criteria > dist_threshold, axis=1)
            pl_idx = pl_idx & (preds_max >= threshold)

            n_selected = int(pl_idx.sum())
            # Drop the least confident selected sample -- only possible when something is selected.
            if (has_batch_size and n_selected > 0
                    and (n_selected + self.train_size) % self.config.model.batch_size == 1):
                min_conf = preds_max[pl_idx].min()
                candidates = np.flatnonzero(pl_idx & (preds_max == min_conf))
                if candidates.size:
                    pl_idx[candidates[0]] = False

        preds = preds_proba.argmax(1)[pl_idx]

        return pl_idx, preds
    
    def naive_self_training(self, 
                            model: Union[GBDTs, PTs, DreamquarkTabNet, Module], 
                            train_idx: NDArray[np.int_], 
                            test_idx: NDArray[np.int_], 
                            hparams: Dict[str, Any], 
                            fold_idx: int = 0
        ) -> Union[GBDTs, PTs, DreamquarkTabNet, Module]:
        """Pseudo-Labeling according to the Lee et al. (https://www.kaggle.com/blobs/download/forum-message-attachment-files/746/pseudo_label_final.pdf).
            Uses the pseudo-labels of all unlabeled data (threshold -1 without percentiles). A configured
            data editor may still filter some of them out.

        The model is retrained from scratch on the labeled data plus pseudo-labels for as long as the
        validation score keeps strictly improving. The model from before the first non-improving round is
        returned.

        Args:
            model: A model to train
            train_idx: Indices of training data in self.data and self.label.
            test_idx: Indices of test data in self.data and self.label, or None to score against self.y_valid.
            hparams: Parameters of a model.
            fold_idx: A fold index that denotes which fold under the given k-fold cross validation.

        Returns:
            The best model during the self-training cycle.
        """
        def _score(preds):
            # Same scoring rule as the other self-training variants: self.y_valid when there is no test split.
            if test_idx is None:
                return self.get_score(self.y_valid, preds)
            return self.get_score(self.label.loc[test_idx], preds)

        model = self.fit_model(model, train_idx, test_idx, hparams=hparams, fold_idx = fold_idx)
        # Checkpoint the base model immediately so there is always something to reload below,
        # even when the self-training cycle is skipped.
        saving_path = self.temporary_save_model(model, fold_idx, False)

        remained = self.get_remained(train_idx, test_idx)

        pre_score = _score(self.predict(model, test_idx = test_idx, fold_idx = fold_idx))

        if pre_score >= 1e-10:
            score = pre_score
            pre_score -= 1e-10

            while score > pre_score:
                pre_score = score
                # Pseudo-label with the current (best so far) model, then checkpoint it before retraining.
                pl_idx, pl_preds = self.get_pseudo_label(model, -1, remained, fold_idx, use_percentiles = False)
                saving_path = self.temporary_save_model(model, fold_idx, False)

                model = self.get_model(hparams)
                # Filter rows with pl_idx so pseudo_data and pseudo_label always have the same length.
                model = self.fit_model(model, train_idx, test_idx, pseudo_data = remained[pl_idx], pseudo_label = pl_preds, hparams = hparams, fold_idx = fold_idx)

                score = _score(self.predict(model, test_idx = test_idx, fold_idx = fold_idx))
        else:
            print("Skipped the self-training cycle because the score of the base model is almost 0")

        best_model = self.load_model(saving_path, fold_idx)

        if self.model_name == 'XGB':
            saving_path += '.json'
        elif self.model_name == 'LGBM':
            saving_path += '.pkl'

        self.clear_model_cache(saving_path)

        if hasattr(self, 'pretrained_model_path'):
            self.clear_pretrained_model()

        return best_model
    
    def get_alpha(self, model, test_idx, fold_idx, threshold = None, use_percentiles = False):
        """Grid-search the prior-mixing weight ``alpha`` on the held-out split.

        Tries 1000 evenly spaced values of ``a`` in [0, 2.5]. For each, the prior-adjusted probabilities are
        ``(a * prior * p + p) / (a + 1)``. Every pseudo-label whose adjusted confidence reaches ``threshold``
        (or the ``threshold`` percentile when ``use_percentiles``) is scored against the true labels. The
        baseline to beat is the unadjusted prediction filtered by ``config.self_training.threshold``. The
        best ``a`` (0 if none beats the baseline) is stored in ``.alpha`` of the fold's prior object.

        NOTE(review): ``get_pseudo_label`` mixes the prior as ``alpha * prior * p + (1 - alpha) * p``, not with
        the formula searched here -- confirm which one is intended before relying on auto_alpha.

        Args:
            model: The model whose predictions are adjusted.
            test_idx: Index labels of the held-out samples in self.data and self.label.
            fold_idx: A fold index; selects ``self.priors[fold_idx]`` when ``runner_option.use_CV``.
            threshold: Confidence threshold, or a percentile in [0, 1] when ``use_percentiles``.
                Must be given by the caller; the default None is not usable.
            use_percentiles: Whether ``threshold`` is a percentile of the adjusted confidences.
        """
        _, preds_proba = self.predict(model, test_idx = test_idx, fold_idx = fold_idx, return_prob=True)
        true_labels = self.label.loc[test_idx].values

        best_alpha = 0
        pl_idx = (preds_proba.max(1) >= self.config.self_training.threshold)
        best_score = self.get_score(true_labels[pl_idx], preds_proba.argmax(1)[pl_idx])

        prior_source = self.priors[fold_idx] if self.config.runner_option.use_CV else self.priors
        # The prior does not depend on the candidate alpha, so fetch it once instead of once per candidate.
        base_prior = prior_source.get_prior(test_idx)

        for a in np.linspace(0, 2.5, 1000):
            prior = base_prior * a
            _preds_proba = (np.multiply(prior, preds_proba) + preds_proba) / (a + 1)

            preds_max = _preds_proba.max(1)
            _threshold = np.percentile(preds_max, threshold * 100) if use_percentiles else threshold

            pl_idx = (preds_max >= _threshold)
            score = self.get_score(true_labels[pl_idx], _preds_proba.argmax(1)[pl_idx])

            if best_score < score:
                best_score = score
                best_alpha = a

        prior_source.alpha = best_alpha
        print("Best Alpha :", best_alpha)
        return
    
    def fixed_threshold_self_training(self, 
                                        model: Union[GBDTs, PTs, DreamquarkTabNet, Module], 
                                        train_idx: NDArray[np.int_], 
                                        test_idx: NDArray[np.int_], 
                                        hparams: Dict[str, Any], 
                                        fold_idx: int = 0
        ) -> Union[GBDTs, PTs, DreamquarkTabNet, Module]:
        """Pseudo-Labeling according to the Oliver et al. (https://proceedings.neurips.cc/paper/2018/hash/c1fea270c48e8079d8ddf7d06d26ab52-Abstract.html).
            Use pseudo-labels that have a higher confidence score over a certain threshold.

        ``config.self_training.threshold`` is a score threshold, or a percentile when
        ``runner_option.self_training`` contains ``'percentiles'``. The model is retrained from scratch on
        labeled data plus the selected pseudo-labels for as long as the validation score keeps strictly
        improving. The model from before the first non-improving round is returned.

        Args:
            model: A model to train
            train_idx: Indices of training data in self.data and self.label.
            test_idx: Indices of test data in self.data and self.label, or None to score against self.y_valid.
            hparams: Parameters of a model.
            fold_idx: A fold index that denotes which fold under the given k-fold cross validation.

        Returns:
            The best model during the self-training cycle.
        """
        def _score(preds):
            # Score on the held-out split; without test indices the validation labels in self.y_valid are used.
            if test_idx is None:
                return self.get_score(self.y_valid, preds)
            return self.get_score(self.label.loc[test_idx], preds)

        threshold = self.config.self_training.threshold

        model = self.fit_model(model, train_idx, test_idx, hparams=hparams, fold_idx = fold_idx)
        saving_path = self.temporary_save_model(model, fold_idx, False)

        remained = self.get_remained(train_idx, test_idx)

        pre_score = _score(self.predict(model, test_idx = test_idx, fold_idx = fold_idx))

        if pre_score >= 1e-10:
            score = pre_score
            pre_score -= 1e-10

            use_percentiles = 'percentiles' in self.config.runner_option.self_training

            while score > pre_score:
                if self.config.runner_option.auto_alpha:
                    self.get_alpha(model, test_idx, fold_idx, threshold=threshold)

                pre_score = score
                # Pseudo-label with the current (best so far) model, then checkpoint it before retraining.
                pl_idx, pl_preds = self.get_pseudo_label(model, threshold, remained, fold_idx, use_percentiles = use_percentiles)
                saving_path = self.temporary_save_model(model, fold_idx, False)

                model = self.get_model(hparams)
                model = self.fit_model(model, train_idx, test_idx, pseudo_data = remained[pl_idx], pseudo_label = pl_preds, hparams = hparams, fold_idx = fold_idx)

                score = _score(self.predict(model, test_idx = test_idx, fold_idx = fold_idx))
        else:
            print("Skipped the self-training cycle because the score of the base model is almost 0")

        best_model = self.load_model(saving_path, fold_idx)

        if self.model_name == 'XGB':
            saving_path += '.json'
        elif self.model_name == 'LGBM':
            saving_path += '.pkl'
        self.clear_model_cache(saving_path)

        if hasattr(self, 'pretrained_model_path'):
            self.clear_pretrained_model()

        return best_model

    def curriculum_self_training(self, 
                                model: Union[GBDTs, PTs, DreamquarkTabNet, Module], 
                                train_idx: NDArray[np.int_], 
                                test_idx: NDArray[np.int_], 
                                hparams: Dict[str, Any], 
                                fold_idx: int = 0
        ) -> Union[GBDTs, PTs, DreamquarkTabNet, Module]:
        """Pseudo-Labeling according to the Cascante-Bonilla et al. (https://ojs.aaai.org/index.php/AAAI/article/view/16852).
            Use pseudo-labels that have a higher confidence score over a certain percentile.

        The percentile threshold walks down from ``1 - delta`` to 0 in steps of
        ``config.self_training.delta``. At each step a fresh model is trained on labeled data plus the
        pseudo-labels produced by the previous step's model. The best-scoring model (including the base
        model) is returned.

        Args:
            model: A model to train
            train_idx: Indices of training data in self.data and self.label.
            test_idx: Indices of test data in self.data and self.label, or None to score against self.y_valid.
            hparams: Parameters of a model.
            fold_idx: A fold index that denotes which fold under the given k-fold cross validation.

        Returns:
            The best model during the self-training cycle.

        Raises:
            ValueError: If ``config.self_training.delta`` is not positive (the schedule would never end).
        """
        def _score(preds):
            # Score on the held-out split; without test indices the validation labels in self.y_valid are used.
            if test_idx is None:
                return self.get_score(self.y_valid, preds)
            return self.get_score(self.label.loc[test_idx], preds)

        delta = self.config.self_training.delta
        if delta <= 0:
            raise ValueError(f"self_training.delta must be positive, got {delta}")

        # Percentile schedule 1 - delta, 1 - 2*delta, ..., computed from integer step counts instead of
        # repeated subtraction so float drift can neither add nor drop the final 0-percentile step.
        n_steps = int(1.0 / delta + 1e-9)
        percentile_schedule = [max(1.0 - step * delta, 0.0) for step in range(1, n_steps + 1)]

        model = self.fit_model(model, train_idx, test_idx, hparams=hparams, fold_idx = fold_idx)
        saving_path = self.temporary_save_model(model, fold_idx, False)
        remained = self.get_remained(train_idx, test_idx)

        best_score = _score(self.predict(model, test_idx = test_idx, fold_idx = fold_idx))

        if best_score >= 1e-10:
            for percentiles_holder in percentile_schedule:
                if self.config.runner_option.auto_alpha:
                    self.get_alpha(model, test_idx, fold_idx, threshold = percentiles_holder, use_percentiles=True)

                pl_idx, pl_preds = self.get_pseudo_label(model, percentiles_holder, remained, fold_idx)

                model = self.get_model(hparams)
                model = self.fit_model(model, train_idx, test_idx, pseudo_data = remained[pl_idx], pseudo_label = pl_preds, hparams = hparams, fold_idx = fold_idx)

                score = _score(self.predict(model, test_idx = test_idx, fold_idx = fold_idx))

                # Only checkpoint strict improvements; the latest model still drives the next labeling round.
                if best_score < score:
                    saving_path = self.temporary_save_model(model, fold_idx, False)
                    best_score = score
        else:
            print("Skipped the self-training cycle because the score of the base model is almost 0")

        best_model = self.load_model(saving_path, fold_idx)

        if self.model_name == 'XGB':
            saving_path += '.json'
        elif self.model_name == 'LGBM':
            saving_path += '.pkl'
        self.clear_model_cache(saving_path)

        if hasattr(self, 'pretrained_model_path'):
            self.clear_pretrained_model()

        return best_model


        
    def report(self) -> None:
        """Report a result of trained model.
        """
        if not os.path.exists('figures'):
            os.mkdir('figures')
        if self.config.runner_option.use_CV:
            report_path = self.config.model.path.split('/')[-1].split('-fold')[0]
        else:
            report_path = self.config.model.path.split('/')[-1]

        if not os.path.exists(f'report/{report_path}/figure'):
            os.mkdir(f'report/{report_path}/figure')
                

        if self.config.runner_option.use_borutashap:
            
            if self.config.runner_option.use_CV:
                model = self.load_model(self.config.model.path, self.config.KFold.n_splits - 1)
            else:
                model = self.load_model(self.config.model.path)

            Feature_Selector = BorutaShap(
                model = model,
                importance_measure='shap',
                classification=True,
            )
            if hasattr(self.config.data, 'valid_size'):
                self.set_random_seed(self.random_seed)
                train_idx, _ = self.train_test_split_idx()
                if self.config.runner_option.limited_training_sample:
                    train_idx = self.get_limited_samples_idx(self.data.index)
            else:
                # train_idx = self.label.index
                fold = StratifiedKFold(n_splits=self.config.KFold.n_splits, shuffle=True, random_state=self.random_seed)
                for fold_idx, (train_idx, test_idx) in enumerate(fold.split(self.data.loc[self.label.index].to_numpy(), self.label.to_numpy())):
                    train_idx = self.data.loc[self.label.index].index[train_idx]
                
            Feature_Selector.fit(X = self.data.loc[train_idx], y = self.label.loc[train_idx], n_trials=100, sample=False,
                                    train_or_test = 'test', normalize=False, verbose = True, random_state = self.random_seed,
                                    stratify=self.label.loc[train_idx])

            Feature_Selector.plot(which_features='all')
            plt.savefig(f"report/{plt_path.split('-fold')[0]}/figure/feature_importance_all.png", dpi=300)

            Feature_Selector.plot(which_features='accepted')
            plt.savefig(f"report/{plt_path.split('-fold')[0]}/figure/feature_importance_accepted.png", dpi=300)

            Feature_Selector.plot(which_features='tentative')
            plt.savefig(f"report/{plt_path.split('-fold')[0]}/figure/feature_importance_tentative.png", dpi=300)

            Feature_Selector.plot(which_features='rejected')
            plt.savefig(f"report/{plt_path.split('-fold')[0]}/figure/feature_importance_rejected.png", dpi=300)

            print("Accepted Features :", Feature_Selector.accepted)
            print("Tentative Features :", Feature_Selector.tentative)
            print("Rejected Features :", Feature_Selector.rejected)

            Feature_Selector.results_to_csv(filename=f"report/{plt_path.split('-fold')[0]}/feature_importance")
