from typing import List, Literal, Optional
import numpy as np
from sklearn.multioutput import MultiOutputRegressor
from xgboost.sklearn import XGBRanker, XGBRegressor
from tsbench.config import Config
from tsbench.experiments.metrics import Performance
from tsbench.experiments.tracking import Tracker
from .base import DatasetFeaturesMixin, Surrogate
from .registry import register_surrogate
from .transformers import ConfigTransformer, PerformanceTransformer


@register_surrogate("xgboost")
class XGBoostSurrogate(Surrogate, DatasetFeaturesMixin):
    """
    The XGBoost surrogate predicts a model's performance on a new dataset by using independent
    XGBoost estimators for each performance metric. For this, models and hyperparameters are
    converted to feature vectors.

    Depending on the chosen objective, the per-metric estimators are either plain regressors or
    pairwise rankers whose query groups are the datasets of the training configurations.
    """

    def __init__(
        self,
        objective: Literal["regression", "ranking"] = "regression",
        use_simple_dataset_features: bool = True,
        use_seasonal_naive_performance: bool = False,
        use_catch22_features: bool = False,
        predict: Optional[List[str]] = None,
        tracker: Optional[Tracker] = None,
    ):
        """
        Args:
            objective: The optimization objective for the XGBoost estimators. Any value other
                than `"ranking"` results in regression estimators.
            use_simple_dataset_features: Whether to add simple dataset statistics to the feature
                vectors.
            use_seasonal_naive_performance: Whether to use the Seasonal Naïve nCRPS as dataset
                features. Requires the cacher to be set.
            use_catch22_features: Whether to use catch22 features for dataset statistics.
                NOTE(review): this was previously documented as being ignored unless dataset
                features are enabled -- confirm against `ConfigTransformer`.
            predict: The metrics to predict. All if not provided.
            tracker: An optional tracker that can be used to impute latency and number of model
                parameters into model performances.
        """
        super().__init__(tracker)

        self.use_ranking = objective == "ranking"
        self.config_transformer = ConfigTransformer(
            add_model_features=True,
            add_dataset_statistics=use_simple_dataset_features,
            add_seasonal_naive_performance=use_seasonal_naive_performance,
            add_catch22_features=use_catch22_features,
            tracker=tracker,
        )
        self.performance_transformer = PerformanceTransformer(
            metrics=predict,
        )

        if self.use_ranking:
            base_estimator = XGBRanker(objective="rank:pairwise")
        else:
            base_estimator = XGBRegressor()
        # One independent copy of the base estimator is fitted per predicted metric.
        self.estimator = MultiOutputRegressor(base_estimator)

    def fit(self, X: List[Config], y: List[Performance]) -> None:
        """
        Fits the transformers and the underlying XGBoost estimators.

        Args:
            X: The configurations (model and dataset) to train on.
            y: The performances observed for the configurations, aligned with `X`.
        """
        X_numpy = self.config_transformer.fit_transform(X)
        y_numpy = self.performance_transformer.fit_transform(y)

        if not self.use_ranking:
            self.estimator.fit(X_numpy, y_numpy)
            return

        # Assign one integer group ID per dataset. IDs are handed out in first-seen order so
        # that they do not depend on set iteration order.
        dataset_map = {}
        for x in X:
            dataset_map.setdefault(x.dataset, len(dataset_map))
        dataset_indices = np.asarray([dataset_map[x.dataset] for x in X])

        # The ranker expects rows of the same query group to be contiguous, so the *transformed*
        # features and targets must be reordered along with the group IDs. A stable sort keeps
        # the original relative order within each dataset.
        # Assumes the transformers return row-indexable arrays (e.g. `np.ndarray`) -- TODO
        # confirm.
        sorting = np.argsort(dataset_indices, kind="stable")
        self.estimator.fit(
            X_numpy[sorting],
            y_numpy[sorting],
            qid=dataset_indices[sorting],
        )

    def _predict(self, X: List[Config]) -> List[Performance]:
        """
        Predicts the performances of the provided configurations.

        Args:
            X: The configurations to predict performances for.

        Returns:
            The predicted performances, obtained by inverse-transforming the estimator outputs.
        """
        X_numpy = self.config_transformer.transform(X)
        y = self.estimator.predict(X_numpy)
        return self.performance_transformer.inverse_transform(y)
