from pathlib import Path
import os

import catboost as cat
import lightgbm as lgb
import numpy as np
import xgboost as xgb
from sklearn.exceptions import NotFittedError
import inspect

from TabZilla.models.basemodel import BaseModel

"""
    Define all Gradient Boosting Decision Tree Models:
    XGBoost, CatBoost, LightGBM
"""

"""
    XGBoost (https://xgboost.readthedocs.io/en/stable/)
"""


class XGBoost(BaseModel):
    """XGBoost model trained through the native ``xgb.train`` API.

    Only hyperparameters listed in ``_XGB_PARAM_KEYS`` are forwarded from
    ``params``; the objective, evaluation metric and (for multi-class)
    ``num_class`` are derived from ``args.objective``.
    """

    # TabZilla: add default number of boosting rounds
    # default_epochs = 500

    # Hyperparameter names forwarded to XGBoost; every other key is dropped.
    _XGB_PARAM_KEYS = frozenset(
        {
            "max_depth", "alpha", "lambda", "eta", "learning_rate",
            "verbosity", "device", "tree_method", "objective", "eval_metric",
        }
    )

    def __init__(self, params, args):
        """Filter ``params`` to XGBoost keys and configure objective/metric.

        Args:
            params: Hyperparameter dict; keys outside ``_XGB_PARAM_KEYS`` are ignored.
            args: Run configuration. Reads ``objective``, ``num_classes`` (for
                ``"classification"``) and, if present, ``use_gpu``.
        """
        xgb_params = {k: v for k, v in params.items() if k in self._XGB_PARAM_KEYS}
        super().__init__(xgb_params, args)

        if getattr(args, "use_gpu", False):
            # Update self.params -- the dict actually handed to xgb.train in fit().
            # Mutating the local ``xgb_params`` only worked if BaseModel happened to
            # store it without copying.
            self.params.update({"device": "cuda", "tree_method": "hist"})

        if args.objective == "regression":
            self.params["objective"] = "reg:squarederror"
            self.params["eval_metric"] = "rmse"
        elif args.objective == "classification":
            self.params["objective"] = "multi:softprob"
            self.params["num_class"] = args.num_classes
            self.params["eval_metric"] = "mlogloss"
            self.n_classes_ = args.num_classes
        elif args.objective == "binary":
            self.params["objective"] = "binary:logistic"
            self.params["eval_metric"] = "auc"
            self.n_classes_ = 2

    @staticmethod
    def _to_dmatrix(X):
        """Return ``X`` as an ``xgb.DMatrix``, wrapping it only if it is not one already.

        Any input ``xgb.DMatrix`` accepts (NumPy, pandas, SciPy sparse, ...) is
        converted; unsupported types make ``xgb.DMatrix`` raise ``TypeError``.
        """
        return X if isinstance(X, xgb.DMatrix) else xgb.DMatrix(X)

    def fit(self, X, y, X_val=None, y_val=None):
        """Train with ``xgb.train`` for ``args.epochs`` rounds.

        When both ``X_val`` and ``y_val`` are given they are used as the ``"eval"``
        set with early stopping after ``args.early_stopping_rounds``; otherwise
        training runs without evaluation or early stopping (early stopping needs
        at least one eval set).

        Returns:
            Two empty lists (no per-epoch loss history is recorded).
        """
        train = xgb.DMatrix(X, label=y)

        evals = []
        early_stopping_rounds = None
        if X_val is not None and y_val is not None:
            evals = [(xgb.DMatrix(X_val, label=y_val), "eval")]
            early_stopping_rounds = self.args.early_stopping_rounds

        self.model = xgb.train(
            self.params,
            train,
            num_boost_round=self.args.epochs,
            evals=evals,
            early_stopping_rounds=early_stopping_rounds,
            verbose_eval=self.args.logging_period,
        )

        return [], []

    def predict(self, X):
        """Convert ``X`` to a DMatrix (if needed) and delegate to ``BaseModel.predict``."""
        return super().predict(self._to_dmatrix(X))

    def predict_proba(self, X):
        """Return class probabilities with shape ``(n_samples, n_classes)``.

        A 1-D output (or ``(n, 1)`` column) is treated as P(y=1) of a binary model
        and expanded to ``[1 - p, p]``; a 2-D output with ``n_classes_`` columns is
        returned as-is.

        Raises:
            NotFittedError: if ``fit`` has not produced a model yet.
            ValueError: if the booster output has any other shape.
        """
        if getattr(self, "model", None) is None:
            raise NotFittedError("This XGBoost instance is not fitted yet.")

        raw_probabilities = np.asarray(self.model.predict(self._to_dmatrix(X)))

        # Normalize a single-column 2-D output to 1-D before the binary expansion.
        if raw_probabilities.ndim == 2 and raw_probabilities.shape[1] == 1:
            raw_probabilities = raw_probabilities.ravel()

        if raw_probabilities.ndim == 1:
            return np.column_stack([1 - raw_probabilities, raw_probabilities])
        if raw_probabilities.ndim == 2 and raw_probabilities.shape[1] == getattr(self, "n_classes_", None):
            return raw_probabilities

        raise ValueError(f"Unexpected shape for XGBoost probabilities: {raw_probabilities.shape}")

    @classmethod
    def define_trial_parameters(cls, trial, args):
        """Sample a hyperparameter configuration from an Optuna-style ``trial``."""
        params = {
            "max_depth": trial.suggest_int("max_depth", 2, 12, log=True),
            "alpha": trial.suggest_float("alpha", 1e-8, 1.0, log=True),
            "lambda": trial.suggest_float("lambda", 1e-8, 1.0, log=True),
            "eta": trial.suggest_float("eta", 0.01, 0.3, log=True),
        }
        return params

    # TabZilla: add function for seeded random params and default params
    @classmethod
    def get_random_parameters(cls, seed):
        """Draw a reproducible random configuration from ``np.random.RandomState(seed)``.

        ``max_depth`` is in [2, 12] (log scale), ``alpha``/``lambda`` in
        [1e-8, 1], ``eta`` in [0.03, 0.3].
        """
        rs = np.random.RandomState(seed)
        params = {
            "max_depth": int(np.round(np.power(2, rs.uniform(1, np.log2(12))))),
            "alpha": np.power(10, rs.uniform(-8, 0)),
            "lambda": np.power(10, rs.uniform(-8, 0)),
            "eta": 3.0 * np.power(10, rs.uniform(-2, -1)),
        }
        return params

    @classmethod
    def default_parameters(cls):
        """Return the fixed default hyperparameter configuration."""
        params = {
            "max_depth": 5,
            "alpha": 1e-4,
            "lambda": 1e-4,
            "eta": 0.08,
        }
        return params


"""
    CatBoost (https://catboost.ai/)
"""


class CatBoost(BaseModel):
    """CatBoost model wrapping ``CatBoostRegressor`` or ``CatBoostClassifier``.

    The regressor is used when ``args.objective == "regression"``, the classifier
    otherwise. Parameters are filtered against the chosen estimator's constructor
    signature, and TabZilla-style names are mapped onto CatBoost ones
    (``epochs`` -> ``iterations``, ``patience`` -> ``early_stopping_rounds``).
    """

    # TabZilla: add default number of boosting rounds
    # default_epochs = 500

    # CatBoost rejects configurations that set more than one name from the same
    # synonym group, so mapped/default values are only injected when none of the
    # group's names is already present.
    _ITERATION_SYNONYMS = ("iterations", "n_estimators", "num_boost_round", "num_trees")
    _VERBOSITY_SYNONYMS = ("verbose", "logging_level", "verbose_eval", "silent")
    _SEED_SYNONYMS = ("random_seed", "random_state")

    def __init__(self, params, args):
        """Build the CatBoost estimator from ``params`` and ``args``.

        Side effects: creates ``output/CatBoost/<dataset or task_<id>>/catboost_info``
        and uses it as CatBoost's ``train_dir``.
        """
        super().__init__(params, args)

        self._is_regression = getattr(self.args, "objective", None) == "regression"
        estimator_cls = cat.CatBoostRegressor if self._is_regression else cat.CatBoostClassifier

        # Names accepted by the chosen estimator's constructor.
        valid_constructor_params = set(inspect.signature(estimator_cls).parameters)

        current_params = dict(self.params)
        catboost_final_params = {}

        # 1. Map TabZilla names onto CatBoost names. Explicit CatBoost names in
        #    ``params`` still take precedence (they are copied in step 2).
        epochs = current_params.pop("epochs", None)
        if (
            epochs is not None
            and "iterations" in valid_constructor_params
            and not any(name in current_params for name in self._ITERATION_SYNONYMS)
        ):
            catboost_final_params["iterations"] = epochs

        patience = current_params.pop("patience", None)
        if patience is not None and "early_stopping_rounds" in valid_constructor_params:
            catboost_final_params["early_stopping_rounds"] = patience

        # 2. Keep only the remaining params the constructor understands.
        catboost_final_params.update(
            {name: value for name, value in current_params.items() if name in valid_constructor_params}
        )

        # 3. Settings derived from self.args or defaults.
        dataset = getattr(self.args, "dataset", None)
        task = getattr(self.args, "task", None)
        if dataset:
            dataset_identifier = dataset
        elif task is not None:
            dataset_identifier = "task_" + str(task)
        else:
            dataset_identifier = "unknown_run"
        catboost_train_dir = os.path.join("output", "CatBoost", dataset_identifier, "catboost_info")
        os.makedirs(catboost_train_dir, exist_ok=True)
        catboost_final_params["train_dir"] = catboost_train_dir
        catboost_final_params["allow_writing_files"] = True

        # Quiet by default, unless the caller chose any verbosity setting.
        if not any(name in catboost_final_params for name in self._VERBOSITY_SYNONYMS):
            catboost_final_params["verbose"] = 0

        if not any(name in catboost_final_params for name in self._SEED_SYNONYMS):
            catboost_final_params["random_seed"] = getattr(self.args, "seed", 42)

        # Categorical columns come from args.cat_idx when available.
        if "cat_features" in valid_constructor_params:
            cat_idx = getattr(self.args, "cat_idx", None)
            if cat_idx is not None:
                catboost_final_params["cat_features"] = cat_idx
            elif "cat_features" not in catboost_final_params:
                catboost_final_params["cat_features"] = None

        self.model = estimator_cls(**catboost_final_params)
        self.n_classes_ = None  # set in fit() for classification

    def _check_is_fitted(self):
        """Raise ``NotFittedError`` unless the underlying CatBoost model is fitted."""
        if not self.model.is_fitted():
            raise NotFittedError("This CatBoost instance is not fitted yet.")

    def fit(self, X, y, X_val=None, y_val=None):
        """Fit the estimator, using ``(X_val, y_val)`` as eval set when both are given.

        ``cat_features`` is already part of the constructor params, so it is not
        passed again here.

        Returns:
            ``self``. NOTE(review): XGBoost/LightGBM ``fit`` in this file return
            ``([], [])``; confirm which convention callers expect.
        """
        if not self._is_regression:
            self.n_classes_ = len(np.unique(y))

        eval_set = None
        if X_val is not None and y_val is not None:
            eval_set = [(X_val, y_val)]

        self.model.fit(X, y, eval_set=eval_set)
        return self

    def predict(self, X):
        """Predict targets/labels for ``X``; a single-column 2-D result is flattened to 1-D."""
        self._check_is_fitted()
        predictions = np.asarray(self.model.predict(X))
        # Multi-class label predictions may come back as (n, 1); flatten them.
        if predictions.ndim == 2 and predictions.shape[1] == 1:
            predictions = predictions.ravel()
        return predictions

    def predict_proba(self, X):
        """Return class probabilities with shape ``(n_samples, n_classes)``.

        Only meaningful for classification; assumes the classifier estimator.
        """
        self._check_is_fitted()

        raw_probabilities = self.model.predict_proba(X)

        # Defensive: expand a 1-D binary output into [1 - p, p] columns.
        if raw_probabilities.ndim == 1 and self.n_classes_ == 2:
            return np.column_stack([1 - raw_probabilities, raw_probabilities])
        return raw_probabilities

    @classmethod
    def define_trial_parameters(cls, trial, args):
        """Sample a hyperparameter configuration from an Optuna-style ``trial``."""
        params = {
            "learning_rate": trial.suggest_float("learning_rate", 0.01, 0.3, log=True),
            "max_depth": trial.suggest_int("max_depth", 2, 12, log=True),
            "l2_leaf_reg": trial.suggest_float("l2_leaf_reg", 0.5, 30, log=True),
        }
        return params

    # TabZilla: add function for seeded random params and default params
    @classmethod
    def get_random_parameters(cls, seed):
        """Draw a reproducible random configuration from ``np.random.RandomState(seed)``.

        ``learning_rate`` is in [0.03, 0.3], ``max_depth`` in [2, 12],
        ``l2_leaf_reg`` in [0.5, 30].
        """
        rs = np.random.RandomState(seed)
        params = {
            "learning_rate": 3.0 * np.power(10, rs.uniform(-2, -1)),
            "max_depth": int(np.round(np.power(2, rs.uniform(1, np.log2(12))))),
            "l2_leaf_reg": 0.5 * np.power(60, rs.uniform(0, 1)),
        }
        return params

    @classmethod
    def default_parameters(cls):
        """Return the fixed default hyperparameter configuration."""
        params = {
            "learning_rate": 0.08,
            "max_depth": 5,
            "l2_leaf_reg": 5,
        }
        return params


"""
    LightGBM (https://lightgbm.readthedocs.io/en/latest/)
"""


class LightGBM(BaseModel):
    """LightGBM model trained through the native ``lgb.train`` API.

    Objective and metric are derived from ``args.objective``; categorical
    columns are taken from ``args.cat_idx``.
    """

    # TabZilla: add default number of boosting rounds
    # default_epochs = 500

    def __init__(self, params, args):
        """Configure objective/metric (and ``num_class`` for multi-class) from ``args``."""
        super().__init__(params, args)

        self.params["verbosity"] = -1

        if args.objective == "regression":
            self.params["objective"] = "regression"
            self.params["metric"] = "mse"
        elif args.objective == "classification":
            self.params["objective"] = "multiclass"
            self.params["num_class"] = args.num_classes
            self.params["metric"] = "multiclass"
        elif args.objective == "binary":
            self.params["objective"] = "binary"
            self.params["metric"] = "auc"

    def fit(self, X, y, X_val=None, y_val=None):
        """Train with ``lgb.train`` for up to ``args.epochs`` rounds.

        When both ``X_val`` and ``y_val`` are given they form the ``"eval"`` set and
        enable early stopping after ``args.early_stopping_rounds``. Otherwise
        training runs without early stopping, because LightGBM's early-stopping
        callback requires a validation set.

        Returns:
            Two empty lists (no per-epoch loss history is recorded).
        """
        train = lgb.Dataset(X, label=y, categorical_feature=self.args.cat_idx)

        valid_sets, valid_names = [], []
        callbacks = [lgb.log_evaluation(self.args.logging_period)]
        if X_val is not None and y_val is not None:
            # reference=train makes the validation set reuse the training bin mappers.
            val = lgb.Dataset(
                X_val, label=y_val, categorical_feature=self.args.cat_idx, reference=train
            )
            valid_sets, valid_names = [val], ["eval"]
            callbacks.insert(0, lgb.early_stopping(self.args.early_stopping_rounds))

        self.model = lgb.train(
            self.params,
            train,
            num_boost_round=self.args.epochs,
            valid_sets=valid_sets,
            valid_names=valid_names,
            callbacks=callbacks,
        )

        return [], []

    def predict_proba(self, X):
        """Return class probabilities; binary P(y=1) is expanded to ``[1 - p, p]`` columns.

        The result is also cached on ``self.prediction_probabilities``.

        Raises:
            NotFittedError: if ``fit`` has not produced a model yet.
        """
        if getattr(self, "model", None) is None:
            raise NotFittedError("This LightGBM instance is not fitted yet.")

        probabilities = self.model.predict(X)

        if self.args.objective == "binary":
            probabilities = probabilities.reshape(-1, 1)
            probabilities = np.concatenate((1 - probabilities, probabilities), 1)

        self.prediction_probabilities = probabilities
        return self.prediction_probabilities

    @classmethod
    def define_trial_parameters(cls, trial, args):
        """Sample a hyperparameter configuration from an Optuna-style ``trial``."""
        params = {
            "num_leaves": trial.suggest_int("num_leaves", 2, 4096, log=True),
            "lambda_l1": trial.suggest_float("lambda_l1", 1e-8, 10.0, log=True),
            "lambda_l2": trial.suggest_float("lambda_l2", 1e-8, 10.0, log=True),
            "learning_rate": trial.suggest_float("learning_rate", 0.01, 0.3, log=True),
        }
        return params

    # TabZilla: add function for seeded random params and default params
    @classmethod
    def get_random_parameters(cls, seed):
        """Draw a reproducible random configuration from ``np.random.RandomState(seed)``.

        ``num_leaves`` is in [2, 4096], ``lambda_l1``/``lambda_l2`` in [1e-8, 10],
        ``learning_rate`` in [0.03, 0.3].
        """
        rs = np.random.RandomState(seed)
        params = {
            "num_leaves": int(np.round(np.power(2, rs.uniform(1, 12)))),
            "lambda_l1": np.power(10, rs.uniform(-8, 1)),
            "lambda_l2": np.power(10, rs.uniform(-8, 1)),
            # Exponent in [-2, -1] gives rates in [0.03, 0.3], matching the trial
            # range and the XGBoost/CatBoost samplers. The previous upper exponent
            # of 1 allowed learning rates up to 30.
            "learning_rate": 3.0 * np.power(10, rs.uniform(-2, -1)),
        }
        return params

    @classmethod
    def default_parameters(cls):
        """Return the fixed default hyperparameter configuration."""
        params = {
            "num_leaves": 512,
            "lambda_l1": 1e-3,
            "lambda_l2": 1e-3,
            "learning_rate": 0.08,
        }
        return params
