from runners.runner import *
from abc import ABC, abstractmethod
from copy import deepcopy
import pandas as pd
import logging
from types import SimpleNamespace
from typing import Tuple, Union
from typing import List, Dict, Any
from numpy.typing import NDArray
from runners.model_types import GBDTs
from netcal.scaling import TemperatureScaling
from netcal.binning import HistogramBinning
from netcal.metrics import ECE

from misc.spline_calibrator import SplineCalibrator
from misc.gp_calibrator import GPCalibration
class GBDTRunner(Runner, ABC):
    """Abstract runner for models of the ``GBDTs`` type.

    Subclasses provide model construction (``get_model``) and persistence
    (``save_model``). This base class provides removal of per-fold model
    files and prediction, including lazy per-fold fitting of an optional
    probability calibrator selected through ``config.runner_option``.
    """

    def __init__(self, 
                config: SimpleNamespace, 
                data: pd.DataFrame, 
                labels: pd.Series, 
                numeric_cols: List[str], 
                logger: logging.Logger
        ) -> None:
        """Forward all arguments unchanged to the parent ``Runner``."""
        super().__init__(config, data, labels, numeric_cols, logger)

    @abstractmethod
    def get_model(self, 
                hparams: Dict[str, Any]
        ):
        """Build a model from ``hparams``; implemented by subclasses."""
        pass

    @abstractmethod
    def save_model(self, 
                    model: GBDTs, 
                    saving_path: str, 
                    fold_idx: int
        ) -> str:
        """Persist ``model`` for ``fold_idx``; implemented by subclasses."""
        pass  

    def clear_model_cache(self, 
                        path: str
        ) -> None:
        """Delete the per-fold files addressed by ``path``.

        Args:
            path: A ``%``-format template that is formatted with each fold
                index in ``range(config.KFold.n_splits)``. Files that do
                not exist are skipped.
        """
        for fold_idx in range(self.config.KFold.n_splits):
            fold_path = path % fold_idx
            if os.path.exists(fold_path):
                os.remove(fold_path)

    def _fold_slots(self, attr: str) -> list:
        """Return the per-fold list stored under ``attr``, creating it if absent.

        A newly created list holds one ``None`` per fold
        (``config.KFold.n_splits`` entries). An attribute that already
        exists is returned untouched.
        """
        if not hasattr(self, attr):
            setattr(self, attr, [None for _ in range(self.config.KFold.n_splits)])
        return getattr(self, attr)

    def _fit_calibrator(self, model: GBDTs, X_test, test_idx, fold_idx: int) -> None:
        """Fit and cache the configured calibrator for ``fold_idx`` if missing.

        The first truthy flag among ``use_temperature``,
        ``use_histogram_binning``, ``use_spline_calibrator`` and
        ``use_gaussian_process`` on ``config.runner_option`` selects the
        calibrator; later flags are not read. A fold whose slot is already
        filled is left alone. Calibrators are stored only after a successful
        fit so that a failed fit is retried on the next call.
        """
        option = self.config.runner_option

        if option.use_temperature:
            slots = self._fold_slots('temperature')
            if slots[fold_idx] is None:
                preds_proba = model.predict_proba(X_test)
                calibrator = TemperatureScaling()
                calibrator.fit(preds_proba, self.label.loc[test_idx].values)
                slots[fold_idx] = calibrator

        elif option.use_histogram_binning:
            # The cache attribute is ``hb``; guarding on any other name would
            # either reset the cache on every call or never create it.
            slots = self._fold_slots('hb')
            if slots[fold_idx] is None:
                preds_proba = model.predict_proba(X_test)
                labels = self.label.loc[test_idx].values

                n_bins = 10
                ece = ECE(n_bins)

                # Try bin counts 7..15 and keep the binning whose calibrated
                # output has the lowest ECE on the data it was fitted on.
                best_hb, min_ece = None, float("inf")
                for b in range(7, 16):
                    hb = HistogramBinning(bins=b)
                    hb.fit(preds_proba, labels)
                    calibrated_score = ece.measure(hb.transform(preds_proba), labels)
                    if calibrated_score < min_ece:
                        min_ece = calibrated_score
                        best_hb = hb
                slots[fold_idx] = best_hb

        elif option.use_spline_calibrator:
            slots = self._fold_slots('splines')
            if slots[fold_idx] is None:
                preds_proba = model.predict_proba(X_test)
                labels = self.label.loc[test_idx].values
                # SplineCalibrator receives its data at construction time.
                slots[fold_idx] = SplineCalibrator(preds_proba, labels)

        elif option.use_gaussian_process:
            slots = self._fold_slots('gps')
            if slots[fold_idx] is None:
                preds_proba = model.predict_proba(X_test)
                labels = self.label.loc[test_idx].values
                calibrator = GPCalibration(n_classes=self.config.data.out_dim, inf_mean_approx=True)
                calibrator.fit(preds_proba, labels)
                slots[fold_idx] = calibrator

    def predict(self, 
                model: GBDTs, 
                test_idx: np.array = None, 
                X_test: pd.DataFrame = None, 
                return_prob: bool = False, 
                fold_idx: int = 0,
                is_test: bool = False,
        ) -> Union[Tuple[NDArray[np.int_], NDArray[np.float64]], NDArray[np.int_]]:
        """Predict with ``model`` and, when ``test_idx`` is given, fit a calibrator.

        The features to predict on are chosen in this order of precedence:
        ``self.X_test`` when ``is_test`` is true; otherwise the split returned
        by ``get_train_test_from_idx(test_idx=...)`` when ``test_idx`` is
        given; otherwise the ``X_test`` argument; otherwise ``self.X_valid``.

        Whenever ``test_idx`` is not ``None`` the calibrator selected by
        ``config.runner_option`` is fitted for ``fold_idx`` (once per fold)
        using labels ``self.label.loc[test_idx]``. The calibrator is only
        cached on the instance; the values returned here come straight from
        ``model`` and are not calibrated by this method.

        NOTE(review): if ``is_test`` is true and ``test_idx`` is also passed,
        the calibrator is fitted on ``self.X_test`` predictions against
        labels taken at ``test_idx`` -- confirm callers never combine them.

        Args:
            model: Fitted model exposing ``predict`` and ``predict_proba``.
            test_idx: Index used to select features and labels.
            X_test: Explicit features, used only when ``is_test`` is false
                and ``test_idx`` is ``None``.
            return_prob: When true, also return ``model.predict_proba``.
            fold_idx: Fold whose calibrator slot is filled.
            is_test: When true, predict on ``self.X_test``.

        Returns:
            ``model.predict(X_test)``, or the tuple
            ``(model.predict(X_test), model.predict_proba(X_test))`` when
            ``return_prob`` is true.
        """
        if is_test:
            X_test = self.X_test
        elif test_idx is not None:
            _, (X_test, _) = self.get_train_test_from_idx(test_idx = test_idx)
        elif X_test is None:
            X_test = self.X_valid

        if test_idx is not None:
            self._fit_calibrator(model, X_test, test_idx, fold_idx)

        if return_prob:
            return model.predict(X_test), model.predict_proba(X_test)
        return model.predict(X_test)

