from abc import ABC, abstractmethod
from runners.runner import Runner
from pytorch_tabular.config import (
    DataConfig,
    TrainerConfig,
    ExperimentRunManager
)
from pytorch_tabular import TabularModel
from pytorch_tabular.utils import get_balanced_sampler, get_class_weighted_cross_entropy
import optuna
import torch
import os
import shutil
import numpy as np
from pytorch_tabular import TabularDatamodule
from torch.utils.data import DataLoader
# from overridings import *

from sklearn.metrics import f1_score
from copy import deepcopy

import pandas as pd
import logging
from types import SimpleNamespace
from pytorch_tabular.config import OptimizerConfig
from pytorch_tabular.config import ModelConfig
from typing import List, Dict, Tuple, Union, Any
from numpy.typing import NDArray
from runners.model_types import PTs
from netcal.scaling import TemperatureScaling
from netcal.binning import HistogramBinning
from netcal.metrics import ECE

from misc.spline_calibrator import SplineCalibrator
from misc.gp_calibrator import GPCalibration
class PT_Runner(Runner, ABC):

    def __init__(self, 
                config: SimpleNamespace, 
                data: pd.DataFrame, 
                labels: pd.Series, 
                logger: logging.Logger, 
                numeric_cols: List[str], 
                category_cols: List[str]
        ) -> None:
        """Prepare the data and trainer configuration shared by all models.

        Args:
            config: Experiment configuration. ``config.model.*`` supplies the
                trainer settings; ``data.metric``, ``data.target``,
                ``optuna.direction`` and ``runner_option.model`` are read
                through ``self.config``.
            data: Feature table, forwarded to the base initializer.
            labels: Target values, forwarded to the base initializer.
            logger: Logger, forwarded to the base initializer.
            numeric_cols: Names of the continuous feature columns.
            category_cols: Names of the categorical feature columns.

        Raises:
            ValueError: If ``optuna.direction`` is neither ``'maximize'`` nor
                ``'minimize'``, or if ``data.metric`` is not one of the
                supported names and no ``metric`` attribute exists yet.
        """
        # The base initializer is called with numeric_cols *before* logger.
        # Presumably it sets self.config, self.start_time and self.random_seed,
        # which are used below -- confirm against Runner.
        super().__init__(config, data, labels, numeric_cols, logger)

        self.category_cols = category_cols

        self.data_config = DataConfig(
            target=['target'],
            continuous_cols=numeric_cols,
            categorical_cols=category_cols,
            continuous_feature_transform=None,
            normalize_continuous_features=True,
        )

        # Checkpoint directory for this run; removed again in __del__.
        self.saved_model_path = f'temporary_ckpt_data/{self.start_time}-{self.config.data.target}-{self.config.runner_option.model}'

        # Configured metric name -> name used for the 'valid_<metric>' monitor.
        # Balanced accuracy is monitored as plain accuracy, as before.
        metric_aliases = {
            "accuracy_score": "accuracy",
            "balanced_accuracy_score": "accuracy",
            "f1_score": "f1_score",
        }
        requested_metric = self.config.data.metric
        if requested_metric in metric_aliases:
            self.metric = metric_aliases[requested_metric]
        elif not hasattr(self, "metric"):
            # A value already provided elsewhere (e.g. by the base class) is
            # kept; otherwise fail here instead of with an AttributeError below.
            raise ValueError(
                f"Unsupported metric {requested_metric!r}; "
                f"expected one of {sorted(metric_aliases)}"
            )

        # Optuna direction -> checkpoint / early-stopping mode.
        direction_modes = {"maximize": "max", "minimize": "min"}
        direction = self.config.optuna.direction
        if direction not in direction_modes:
            raise ValueError(
                f"Unsupported optuna direction {direction!r}; "
                f"expected one of {sorted(direction_modes)}"
            )
        metric_mode = direction_modes[direction]

        monitored = 'valid_' + self.metric
        self.trainer_config = TrainerConfig(
            accelerator='gpu',
            devices=-1,
            auto_select_gpus=config.model.auto_select_gpus,
            fast_dev_run=config.model.fast_dev_run,
            max_epochs=config.model.max_epochs,
            batch_size=config.model.batch_size,
            early_stopping_patience=config.model.early_stopping_patience,
            gradient_clip_val=1,
            checkpoints=monitored,
            checkpoints_path=self.saved_model_path,
            checkpoints_mode=metric_mode,
            early_stopping=monitored,
            early_stopping_mode=metric_mode,
            deterministic=True,
            seed=self.random_seed,
        )

    @abstractmethod
    def get_model_config(self, 
                        hparams: Dict[str, Any]
        ) -> ModelConfig:
        pass

    @abstractmethod
    def get_optimizer_config(self, 
                            hparams: Dict[str, Any]
        ) -> OptimizerConfig:
        pass
        
    def __del__(self) -> None:

        if os.path.exists(self.saved_model_path):
            shutil.rmtree(self.saved_model_path)

    def get_model(self, 
                hparams: Dict[str, Any]
        ) -> PTs:
        """Build a new ``TabularModel`` for the given hyper-parameters.

        The model and optimizer configurations come from the subclass hooks
        ``get_model_config`` and ``get_optimizer_config`` (called in that
        order); the data and trainer configurations are the ones created in
        ``__init__``.
        """
        per_trial_configs = {
            "model_config": self.get_model_config(hparams),
            "optimizer_config": self.get_optimizer_config(hparams),
        }
        return TabularModel(
            data_config=self.data_config,
            trainer_config=self.trainer_config,
            **per_trial_configs,
        )
    
    def load_model(self, 
                    model_path: str, 
                    fold_idx: int = None
        ) -> PTs:
        """Load a saved ``TabularModel`` (non-strict).

        Args:
            model_path: Checkpoint location. When ``fold_idx`` is given it is
                treated as a ``%``-format template and filled with the fold
                number; otherwise it is used verbatim.
            fold_idx: Fold number to substitute into ``model_path``, or
                ``None`` for a plain path.

        Returns:
            The model returned by ``TabularModel.load_from_checkpoint``.
        """
        checkpoint = model_path if fold_idx is None else model_path % fold_idx
        return TabularModel.load_from_checkpoint(checkpoint, strict=False)

    def fit_model(self, 
                    model: PTs, 
                    train_idx: NDArray[np.int_], 
                    test_idx: NDArray[np.int_], 
                    hparams: Dict[str, Any] = None, 
                    pseudo_data: pd.DataFrame = None, 
                    pseudo_label: NDArray[np.int_] = None, 
                    fold_idx: int = 0
        ) -> PTs:
        del hparams

        if test_idx is not None:
            train_test_idx = np.concatenate((train_idx, test_idx))

            data = self.data.loc[train_test_idx,:]
            data['target'] = self.label[train_test_idx]
            
            train = data.loc[train_idx,:]
            valid = data.loc[test_idx,:]
        else:
            train = self.data.loc[train_idx, :]
            train["target"] = self.label[train_idx]
            
            valid = self.X_valid.copy()
            valid["target"] = self.y_valid
        
        
        
        
        if pseudo_data is not None and pseudo_label is not None:
            pseudo_data['target'] = pseudo_label
            train = train.append(pseudo_data)

        if self.config.model.use_balanced_sampler:
            sampler = get_balanced_sampler(train['target'].values.ravel())
        else:
            sampler = torch.utils.data.RandomSampler(train['target'].values.ravel())
        fit_hparams = {"train" : train, "validation" : valid, "train_sampler" : sampler, "min_epochs" : 10, "seed" : self.random_seed}
        if self.config.model.use_weighted_loss:
            fit_hparams["loss"] = get_class_weighted_cross_entropy(train["target"].values.ravel(), mu=self.config.model.mu)

        model.fit(**fit_hparams)
        
        return model

    def save_model(self, 
                    model: PTs, 
                    saving_path: str = None, 
                    fold_idx: int = None
        ) -> str:
        if saving_path == None:
            saving_path = f"model/{self.start_time}-{self.config.data.target}-{self.config.runner_option.model}"
        if fold_idx != None:
            saving_path += '-fold%d'
        model.save_model(saving_path % fold_idx)
        return saving_path

    def rename_prob_cols(self, 
                        data: pd.DataFrame
        ) -> pd.DataFrame:

        for y in self.label.unique():
            if ("%d.0_probability" % y) in data.columns:
                data.rename(columns={('%d.0_probability' % y) : ('%d_probability' % y)}, inplace=True)
        
        return data
    
    def clear_model_cache(self, 
                            path: str
        ) -> None:
        for fold_idx in range(self.config.KFold.n_splits):
            if os.path.exists(path % fold_idx):
                shutil.rmtree(path % fold_idx)
    
    def predict(self, 
                model: PTs, 
                test_idx: NDArray[np.int_] = None, 
                X_test: pd.DataFrame = None, 
                return_prob: bool = False, 
                fold_idx: int = 0,
                return_logits = False,
                is_test: bool = False
        ) -> Union[Tuple[NDArray[np.int_], NDArray[np.float_]], NDArray[np.int_]]:

        prob_cols = []
        for y in self.label.unique():
            prob_cols.append("%d_probability" % y) 
            
        if test_idx is not None:
            if is_test:
                X_test = self.X_test.copy()
                X_test['target'] = self.y_test
            else:
                data = self.data.loc[self.label.index,:]
                X_test = data.loc[test_idx, :]
                X_test['target'] = self.label.loc[test_idx]
            if self.config.runner_option.use_temperature:
                if not hasattr(self, 'temperature'):
                    self.temperature = [None for _ in range(self.config.KFold.n_splits)]
                if self.temperature[fold_idx] is None:
                    preds = model.predict(X_test)
                    preds_proba = preds[prob_cols].values
                    self.temperature[fold_idx] = TemperatureScaling()
                    self.temperature[fold_idx].fit(preds_proba, self.label.loc[test_idx].values)
                    
                    
            elif self.config.runner_option.use_histogram_binning:
                if not hasattr(self, 'n_bins'):
                    self.hb = [None for _ in range(self.config.KFold.n_splits)]
                if self.hb[fold_idx] is None:
                    preds = model.predict(X_test)
                    preds_proba = preds[prob_cols].values
                    min_b, min_ece = 7, 987654321
                    
                    n_bins = 10
                    ece = ECE(n_bins)
                    
                    
                    for b in range(7, 16):
                        hb = HistogramBinning(bins = b)
                        hb.fit(preds_proba, self.label.loc[test_idx].values)
                        calibrated = hb.transform(preds_proba)
                        
                        calibrated_score = ece.measure(calibrated, self.label.loc[test_idx].values)
                        if calibrated_score < min_ece:
                            min_ece = calibrated_score
                            min_b = b
                            self.hb[fold_idx] = hb
                    
                    
                        
                    
            elif self.config.runner_option.use_spline_calibrator:
                if self.splines[fold_idx] is None:
                    preds = model.predict(X_test)
                    preds_proba = preds[prob_cols].values
                    label = self.label.loc[test_idx].values
                    
                    
                    self.splines[fold_idx] = SplineCalibrator(preds_proba, label)
                    
            
            elif self.config.runner_option.use_gaussian_process:
                if self.gps[fold_idx] is None:
                    preds = model.predict(X_test)
                    preds_proba = preds[prob_cols].values
                    label = self.label.loc[test_idx].values
                    
                    
                    self.gps[fold_idx] = GPCalibration(n_classes=self.config.data.out_dim, inf_mean_approx = True)
                    self.gps[fold_idx].fit(preds_proba, label)
                    
                    
        
        X_test['target'] = [0 for i in range(len(X_test))]
        
        preds = model.predict(X_test, ret_logits = return_logits)
        preds = self.rename_prob_cols(preds)

        if return_prob:
            return preds['prediction'], preds[prob_cols].values
        elif return_logits:
            return preds['prediction'], preds[['logits_0', 'logits_1']].values
        else:
            return preds['prediction']
    

    
