import itertools
import math
import numpy as np
import random
import pandas as pd

from numpy.linalg import lstsq
from scipy.optimize import nnls, lsq_linear
from abc import ABC, abstractmethod
from typing import List, Dict, Optional, Tuple, Union, Any, Literal
from utils.agg_utils import plot_all


class EnsembleAggregator(ABC):
    """
    Abstract base class for ensemble model aggregation strategies.

    Subclasses implement ``_fit``. The public ``fit`` wrapper handles the
    re-fit policy, input validation, removal of NaN rows and optional
    plotting; ``predict`` returns ``G @ coef_``.

    Attributes
    ----------
    coef_ : np.ndarray
        Array of coefficients for each model
    normalizer : str or None
        Method to normalize coefficients ('softmax', 'sum', None)
    is_fit : bool
        Whether the aggregator has been fitted
    """

    def __init__(
        self,
        num_models: int = None,
        normalizer: Optional[Literal['softmax', 'sum']] = None,
        visualize: bool = False,
        allow_fit_again: bool = False,
    ):
        """
        Initialize the aggregator.

        Parameters
        ----------
        num_models : int, optional
            Number of models in the ensemble. If None, it is inferred from the
            number of columns of ``G`` on the first call to ``fit``.
        normalizer : str or None, optional
            Method applied every time ``coef_`` is assigned:
            - 'softmax': Use softmax normalization (exponential then normalize)
            - 'sum': Simple normalization (divide by sum; an all-zero vector
              becomes uniform weights)
            - None: No normalization applied
            Default is None.
        visualize : bool, optional
            If True, ``fit`` calls ``plot_all`` (with ``debug_dir="./results"``)
            after fitting. Default is False.
        allow_fit_again : bool, optional
            If False (default), calling ``fit`` on an already-fitted aggregator
            prints a message and returns without refitting.
        """
        self.num_models = num_models
        self._coef = None
        self.normalizer = normalizer
        self.is_fit = False
        self.visualize = visualize
        self.allow_fit_again = allow_fit_again
        if num_models is not None:
            # Uniform weights until a subclass fits real ones.
            self._coef = np.ones(num_models) / num_models

    @property
    def coef_(self) -> np.ndarray:
        """Get the model coefficients."""
        return self._coef

    @coef_.setter
    def coef_(self, new_coef: Union[List[float], np.ndarray]) -> None:
        """
        Set model coefficients, applying ``self.normalizer`` if one is set.

        Parameters
        ----------
        new_coef : Union[List[float], np.ndarray]
            New coefficient values
        """
        coef_array = np.array(new_coef, dtype=float)

        if self.normalizer is not None:
            assert self.normalizer in ['softmax', 'sum'], "Normalizer must be either 'softmax' or 'sum'"

        if self.normalizer == 'softmax':
            # Subtract the max before exponentiating for numerical stability.
            coef_shifted = coef_array - np.max(coef_array)
            exp_coef = np.exp(coef_shifted)
            self._coef = exp_coef / np.sum(exp_coef)

        elif self.normalizer == 'sum':
            total = np.sum(coef_array)
            if total == 0:
                # Cannot divide by zero: fall back to uniform weights.
                self._coef = np.ones_like(coef_array) / len(coef_array)
            else:
                self._coef = coef_array / total

        else:
            # No normalization
            self._coef = coef_array

    def fit(self, G: np.ndarray, y: np.ndarray):
        """
        Fit the aggregation strategy to the provided data.

        Rows where ``y`` or any column of ``G`` is NaN are removed before
        ``_fit`` is called.

        Parameters
        ----------
        G : np.ndarray
            2D array of shape (n_samples, n_models) containing predictions from each model
        y : np.ndarray
            1D array of shape (n_samples,) containing the target values

        Returns
        -------
        Whatever the subclass ``_fit`` returns (the aggregator itself for all
        strategies in this module); ``self`` when refitting is skipped.

        Raises
        ------
        AssertionError
            If G and y have different numbers of rows, or G's column count
            differs from ``num_models``.
        ValueError
            If every row contains a NaN value.
        """
        # Check if already fitted
        if self.is_fit:
            if not self.allow_fit_again:
                print(f"{self.__class__.__name__} is already fitted. Skipping fitting.")
                return self
            print("already fitted, fit again, not skipping fitting")

        # Work on float ndarrays so lists / pandas objects behave like arrays
        # in the NaN handling below.
        G = np.asarray(G, dtype=float)
        y = np.asarray(y, dtype=float)

        assert G.shape[0] == y.shape[0], "G and y must have the same number of samples"

        # Infer num_models BEFORE the column check; otherwise an aggregator
        # created without num_models could never pass the assertion below.
        if self.num_models is None:
            self.num_models = G.shape[1]
            if self._coef is None:
                self._coef = np.ones(self.num_models) / self.num_models

        assert G.shape[1] == self.num_models, "G must have the same number of models as the number of models in the aggregator"

        # Keep only rows where y and every model prediction are finite numbers.
        valid_mask = ~np.isnan(y) & ~np.isnan(G).any(axis=1)
        removed_count = len(y) - int(np.sum(valid_mask))

        if removed_count > 0:
            print(f"Warning: Removed {removed_count} rows containing NaN values ({removed_count/len(y)*100:.1f}% of data)")

            # If all data would be removed, raise an error
            if not valid_mask.any():
                raise ValueError("All data contains NaN values. Cannot fit the model.")

            G = G[valid_mask]
            y = y[valid_mask]

        print(f"Info: Data for aggregator fitting: G.shape: {G.shape}, y.shape: {y.shape}")
        result = self._fit(G, y)

        # Mark as fitted before plotting: predict() below requires it.
        self.is_fit = True

        if self.visualize:
            plot_all(G=G, y=y, coef=self.coef_, g_pred=self.predict(G), debug_dir="./results")

        return result

    @abstractmethod
    def _fit(self, G: np.ndarray, y: np.ndarray):
        """
        Strategy-specific fitting on NaN-free data; must set ``coef_``.
        """
        pass

    def predict(self, G: np.ndarray) -> np.ndarray:
        """
        Generate predictions by aggregating model outputs with learned coefficients.

        NaN entries in ``G`` are treated as 0.0 (the input is not modified).

        Parameters
        ----------
        G : np.ndarray
            2D array of shape (n_samples, n_models) containing predictions from each model

        Returns
        -------
        np.ndarray
            ``G @ coef_`` (a weighted average when the coefficients sum to 1)

        Raises
        ------
        ValueError
            If the aggregator has not been fitted.
        """
        if not self.is_fit:
            raise ValueError(f"{self.__class__.__name__} has not been fitted yet. Call fit() before predict().")

        G = np.asarray(G, float)
        nan_mask = np.isnan(G)
        if nan_mask.any():
            # np.where builds a new array, so the caller's data is untouched.
            G = np.where(nan_mask, 0.0, G)
        return G @ self.coef_


class EqualWeightAggregator(EnsembleAggregator):
    """
    Ensemble aggregator that gives every model the same weight.

    Fitting ignores the data values themselves: ``coef_`` is simply set to
    ``1 / num_models`` for every model (a uniform vector stays uniform under
    either normalizer).
    """

    def __init__(self, **kwargs):
        # Every option (num_models, normalizer, ...) is handled by the base class.
        super().__init__(**kwargs)

    def _fit(self, G: np.ndarray, y: np.ndarray):
        share = 1.0 / self.num_models
        self.coef_ = np.full(self.num_models, share)
        return self


class PerformanceWeightAggregator(EnsembleAggregator):
    """
    Performance-based weighting strategy for ensemble models.

    Each model's weight is proportional to the inverse of its error on the
    fitting data (lower error means higher weight).
    """

    def __init__(
        self,
        num_models: int = None,
        error_metric: str = 'mae',
        normalizer: Optional[Literal['softmax', 'sum']] = None,
        **kwargs,
    ):
        """
        Initialize the performance-weighted aggregator.

        Parameters
        ----------
        num_models : int, optional
            Number of models in the ensemble
        error_metric : str, optional
            'mae', 'mse' or 'rmse' (case-insensitive); any other value falls
            back to 'mae'. Default 'mae'.
        normalizer : Optional[Literal['softmax', 'sum']], optional
            Normalizer applied to the raw inverse errors. With None (default)
            the inverse errors are rescaled to sum to 1 so that ``predict``
            returns a weighted average.
        **kwargs
            Forwarded to ``EnsembleAggregator`` (e.g. ``visualize``,
            ``allow_fit_again``).
        """
        super().__init__(num_models, normalizer, **kwargs)
        self.error_metric = error_metric

    def _fit(self, G: np.ndarray, y: np.ndarray):
        """
        Fit the weighting strategy based on model performance.

        Parameters
        ----------
        G : np.ndarray
            2D array of shape (n_samples, n_models) containing predictions from each model
        y : np.ndarray
            1D array of shape (n_samples,) containing the target values

        Returns
        -------
        self : PerformanceWeightAggregator
            The fitted aggregator instance
        """
        # Column j holds the residuals of model j.
        residuals = np.reshape(y, (-1, 1)) - G

        metric = self.error_metric.lower()
        if metric == 'mse':
            errors = np.mean(residuals ** 2, axis=0)
        elif metric == 'rmse':
            errors = np.sqrt(np.mean(residuals ** 2, axis=0))
        else:
            # 'mae' and any unrecognized metric
            errors = np.mean(np.abs(residuals), axis=0)

        # Lower error means higher weight; epsilon avoids division by zero
        # for a perfect model.
        epsilon = 1e-10
        weights = 1.0 / (errors + epsilon)

        # Without a normalizer the raw inverse errors would scale predictions
        # arbitrarily; rescale them into a convex combination.
        if self.normalizer is None:
            weights = weights / weights.sum()

        self.coef_ = weights
        return self


class AdaptiveAggregator(EnsembleAggregator):
    """
    Adaptive aggregation strategy that updates weights based on recent performance.

    Every call to ``fit`` appends each model's error on the supplied data to
    a per-model history truncated to the last ``window_size`` entries. The
    weights are the inverse of an exponentially decayed average of that
    history, so more recent errors have more influence.
    """

    def __init__(
        self,
        num_models: int = None,
        window_size: int = 5,
        decay_factor: float = 0.8,
        error_metric: str = 'mae',
        normalizer: Optional[Literal['softmax', 'sum']] = None,
        **kwargs,
    ):
        """
        Initialize the adaptive aggregation weights.

        Parameters
        ----------
        num_models : int, optional
            Number of models in the ensemble
        window_size : int, optional
            Number of recent fits kept in each model's history, by default 5
        decay_factor : float, optional
            Factor to control how quickly past performance loses influence (0 to 1),
            by default 0.8
        error_metric : str, optional
            'mae', 'mse' or 'rmse' (case-insensitive); any other value falls
            back to 'mae'. Default 'mae'.
        normalizer : Optional[Literal['softmax', 'sum']], optional
            Normalizer applied to the raw inverse errors. With None (default)
            the inverse errors are rescaled to sum to 1.
        **kwargs
            Forwarded to ``EnsembleAggregator``. ``allow_fit_again`` defaults
            to True here: the history can only grow if ``fit`` is allowed to
            run more than once.
        """
        kwargs.setdefault('allow_fit_again', True)
        super().__init__(num_models, normalizer, **kwargs)
        self.window_size = window_size
        self.decay_factor = decay_factor
        self.error_metric = error_metric
        # Per-model list of recent errors; created on the first fit.
        self.history = None

    def _fit(self, G: np.ndarray, y: np.ndarray):
        """
        Update the aggregation strategy based on model performance.

        Parameters
        ----------
        G : np.ndarray
            2D array of shape (n_samples, n_models) containing predictions from each model
        y : np.ndarray
            1D array of shape (n_samples,) containing the target values

        Returns
        -------
        self : AdaptiveAggregator
            The fitted aggregator instance
        """
        if self.history is None:
            self.history = [[] for _ in range(self.num_models)]

        # Column j holds the residuals of model j.
        residuals = np.reshape(y, (-1, 1)) - G

        metric = self.error_metric.lower()
        if metric == 'mse':
            errors = np.mean(residuals ** 2, axis=0)
        elif metric == 'rmse':
            errors = np.sqrt(np.mean(residuals ** 2, axis=0))
        else:
            # 'mae' and any unrecognized metric
            errors = np.mean(np.abs(residuals), axis=0)

        weights = np.empty(self.num_models)
        for i, error in enumerate(errors):
            self.history[i].append(error)
            self.history[i] = self.history[i][-self.window_size:]

            past = np.asarray(self.history[i], dtype=float)
            # Newest entry gets decay**0 == 1, the one before decay**1, ...
            decay = self.decay_factor ** np.arange(len(past) - 1, -1, -1)
            avg_error = float(np.dot(decay, past) / decay.sum())

            # Invert because lower error is better
            weights[i] = 1.0 / (avg_error + 1e-10)

        # Without a normalizer the raw inverse errors would scale predictions
        # arbitrarily; rescale them into a convex combination.
        if self.normalizer is None:
            weights = weights / weights.sum()

        self.coef_ = weights
        return self


class SPAAggregator(EnsembleAggregator):
    """
    Sparsity–Pattern Aggregation (SPA) estimator for ensemble aggregation.

    For every sparsity pattern p (a subset of the M models) a least-squares
    coefficient vector ω̂_p is fitted on the selected columns of G. The final
    coefficients are the average of the ω̂_p under exponential weights that
    combine each pattern's residual sum of squares with a Rigollet–Tsybakov
    sparsity prior. All 2**M patterns are enumerated when M ≤ max_enum;
    otherwise the pattern posterior is sampled with Metropolis–Hastings.
    """

    def __init__(
        self,
        sigma: float = None,
        max_enum: int = 20,
        n_iter: int = 10_000,
        burn_in: int = 2_000,
        random_state: int = 0,
        positive: bool = False,
        constraint: Optional[Tuple[float, float]] = None,
        **kwargs
    ):
        """
        Initialize the SPA aggregator.

        Parameters
        ----------
        sigma : float or None, optional
            Known noise standard deviation. If None, it is estimated as the
            root mean squared residual of the equal-weight average of the models.
        max_enum : int, default 20
            If the number of experts M ≤ max_enum we enumerate every pattern;
            otherwise Metropolis–Hastings is used.
        n_iter : int, default 10_000
            Total iterations for the Metropolis sampler (ignored if enumerating).
        burn_in : int, default 2_000
            Samples discarded at the beginning of the chain; must be smaller
            than ``n_iter`` when the sampler is used.
        random_state : int, default 0
            Seed for the sampler's random generator.
        positive : bool, default False
            If True, per-pattern coefficients are fitted with
            ``scipy.optimize.nnls`` (non-negative), or with
            ``scipy.optimize.lsq_linear`` bounded by ``constraint`` if given.
        constraint : tuple of (float, float) or None, default None
            (lower, upper) bounds for the coefficients. Only used when
            ``positive`` is True.
        **kwargs
            Forwarded to ``EnsembleAggregator`` (e.g. ``num_models``,
            ``normalizer``, ``visualize``, ``allow_fit_again``).
        """
        super().__init__(**kwargs)
        self.sigma = sigma
        self.sigma_epsilon = 1e-8
        self.max_enum = max_enum
        self.n_iter = n_iter
        self.burn_in = burn_in
        self.rng = np.random.default_rng(random_state)
        self._method = None
        self.positive = positive
        self.constraint = constraint

    # ------------------------------------------------------------------ #
    # Public API                                                         #
    # ------------------------------------------------------------------ #
    def _fit(self, G: np.ndarray, y: np.ndarray) -> "SPAAggregator":
        """
        Fit SPA from a matrix of expert predictions and the targets.

        Parameters
        ----------
        G : ndarray, shape (n_samples, M)
            Column j contains the j-th expert's point forecasts g_j(x_i).
        y : array-like, shape (n_samples,)
            Observed responses.

        Returns
        -------
        self : SPAAggregator
            The fitted aggregator instance

        Raises
        ------
        ValueError
            If the Metropolis–Hastings sampler is needed and
            ``burn_in >= n_iter`` (no samples would be collected).
        """
        G = np.asarray(G, float)
        y = np.asarray(y, float)
        n, M = G.shape

        # Noise standard deviation. A user-supplied sigma is already a std-dev;
        # only the fallback estimate (mean squared residual of the equal-weight
        # average) is a variance and needs a square root.
        if self.sigma is not None:
            sigma = float(self.sigma)
        else:
            sigma = float(np.sqrt(np.sum((y - G.mean(axis=1)) ** 2) / n))
        # Keep the temperature 4σ² strictly positive.
        sigma += self.sigma_epsilon

        # Rank of the design matrix for the Rigollet–Tsybakov prior; logH is
        # the log of the prior's normalizing constant.
        R = np.linalg.matrix_rank(G)
        logH = math.log(2 * sum(
            math.comb(M, k) * (k / (2 * math.e * M)) ** k
            for k in range(R + 1)
        ))

        # -------------------------helpers------------------------------ #
        def lsq_coeff(mask):
            """Least-squares ω̂_p for a pattern mask (boolean length-M)."""
            idx = np.flatnonzero(mask)
            if idx.size == 0:
                return np.zeros(M)
            A = G[:, idx]

            if not self.positive:
                # Unconstrained
                beta, *_ = lstsq(A, y, rcond=None)
            elif self.constraint is None:
                # Non-negative
                beta, _ = nnls(A, y)
            else:
                # Box-constrained
                res = lsq_linear(A, y, bounds=(self.constraint[0], self.constraint[1]))
                beta = res.x

            w = np.zeros(M)
            w[idx] = beta
            return w

        def log_prior(mask):
            k = mask.sum()
            if k == 0:
                return -logH
            if k < R:
                return -logH + k * (math.log(k) - math.log(2 * math.e * M))
            if k == M:
                return math.log(0.5)
            return -np.inf  # prior probability 0

        def log_weight(mask, rss):
            k = mask.sum()
            # sigma > 0 is guaranteed by the epsilon added above.
            return -rss / (4 * sigma ** 2) - k / 2 + log_prior(mask)

        # ---------------------------------------------------------------- #
        # 1.  Exact enumeration when feasible                              #
        # ---------------------------------------------------------------- #
        if M <= self.max_enum:
            logws, coeffs = [], []
            for bits in itertools.product([0, 1], repeat=M):
                mask = np.array(bits, dtype=bool)
                # Zero-prior patterns get weight exp(-inf) = 0 and contribute
                # nothing to the average, so skip their least-squares solve.
                if np.isneginf(log_prior(mask)):
                    continue
                w = lsq_coeff(mask)
                rss = np.sum((y - G @ w) ** 2)
                logws.append(log_weight(mask, rss))
                coeffs.append(w)

            # Numerically stable normalisation (the empty pattern always has
            # finite weight, so logws is never empty).
            logws = np.array(logws)
            w_norm = np.exp(logws - logws.max())
            w_norm /= w_norm.sum()
            # Posterior mean of ω̂_p over patterns.
            self.coef_ = np.sum(np.stack(coeffs) * w_norm[:, None], axis=0)
            self._method = "enumeration"
            return self

        # ---------------------------------------------------------------- #
        # 2.  Metropolis-Hastings for large M                              #
        # ---------------------------------------------------------------- #
        if self.burn_in >= self.n_iter:
            raise ValueError(
                f"burn_in ({self.burn_in}) must be smaller than n_iter ({self.n_iter}) "
                "so that at least one Metropolis-Hastings sample is collected."
            )

        # The chain revisits patterns often; cache (ω̂_p, rss_p) per pattern so
        # each least-squares problem is solved once.
        cache = {}

        def evaluate(mask):
            key = mask.tobytes()
            if key not in cache:
                w_p = lsq_coeff(mask)
                cache[key] = (w_p, float(np.sum((y - G @ w_p) ** 2)))
            return cache[key]

        # Start from the empty pattern. Its log-weight is deliberately lowered
        # so that the chain readily leaves the starting point.
        mask = np.zeros(M, dtype=bool)
        w, rss = evaluate(mask)
        logw_curr = log_weight(mask, rss) - 10000

        coef_sum = np.zeros(M)
        n_samples = 0

        for it in range(self.n_iter):
            # Symmetric proposal: flip one random coordinate.
            j = self.rng.integers(0, M)
            mask_prop = mask.copy()
            mask_prop[j] = ~mask_prop[j]

            w_prop, rss_prop = evaluate(mask_prop)
            logw_prop = log_weight(mask_prop, rss_prop)

            # Accept with probability min(1, π_prop / π_curr), compared in log
            # space to avoid overflow. rng.random() is in [0, 1), so 1 - U is
            # in (0, 1] and its log is always finite.
            if math.log(1.0 - self.rng.random()) < (logw_prop - logw_curr):
                mask, w, logw_curr = mask_prop, w_prop, logw_prop

            # Chain states are already draws from the pattern posterior, so
            # they are averaged uniformly. Re-weighting by exp(logw) would
            # double-count the posterior and can under/overflow.
            if it >= self.burn_in:
                coef_sum += w
                n_samples += 1

        self.coef_ = coef_sum / n_samples
        self._method = "metropolis"
        return self


class SingleBestAggregator(EnsembleAggregator):
    """
    Single best model selection strategy.

    This strategy selects only the best-performing model based on the error metric
    and assigns it a weight of 1.0, while all other models get a weight of 0.0.
    """

    def __init__(
        self,
        error_metric: str = 'mae',
        **kwargs
    ):
        """
        Initialize the single best model selector.

        Parameters
        ----------
        error_metric : str, optional
            'mae', 'mse' or 'rmse' (case-insensitive); any other value falls
            back to 'mae'. Default 'mae'.
        **kwargs
            Forwarded to ``EnsembleAggregator`` (e.g. ``num_models``,
            ``visualize``, ``allow_fit_again``). A ``normalizer`` is accepted
            but not applied to the fitted one-hot weights.
        """
        super().__init__(**kwargs)
        self.error_metric = error_metric
        self.best_model_index = None

    def _fit(self, G: np.ndarray, y: np.ndarray):
        """
        Identify the best-performing model and set its weight to 1.0.

        Parameters
        ----------
        G : np.ndarray
            2D array of shape (n_samples, n_models) containing predictions from each model
        y : np.ndarray
            1D array of shape (n_samples,) containing the target values

        Returns
        -------
        self : SingleBestAggregator
            The fitted aggregator instance
        """
        # Column j holds the residuals of model j.
        residuals = np.reshape(y, (-1, 1)) - G

        metric = self.error_metric.lower()
        if metric == 'mse':
            errors = np.mean(residuals ** 2, axis=0)
        elif metric == 'rmse':
            errors = np.sqrt(np.mean(residuals ** 2, axis=0))
        else:
            # 'mae' and any unrecognized metric
            errors = np.mean(np.abs(residuals), axis=0)

        # Lowest error wins (first index on ties).
        self.best_model_index = np.argmin(errors)

        weights = np.zeros(self.num_models)
        weights[self.best_model_index] = 1.0

        # Store directly: routing the one-hot vector through the coef_ setter
        # would let normalizer='softmax' spread weight onto the other models.
        self._coef = weights
        return self


class LinearAggregator(EnsembleAggregator):
    """
    Linear combination aggregation strategy using elastic net regularization.

    Fits ``sklearn.linear_model.ElasticNet`` on the matrix of model
    predictions and uses its coefficients (plus an optional intercept) as the
    combination weights. ``l1_ratio`` mixes the L1 (lasso) and L2 (ridge)
    penalties.
    """

    def __init__(
        self,
        alpha: float = 0.001,  # Regularization strength
        l1_ratio: float = 0.5,  # Mix between L1 and L2 (1 = lasso, 0 = ridge)
        fit_intercept: bool = False,
        max_iter: int = 1000,
        tol: float = 1e-4,
        random_state: int = 42,
        positive: bool = True,
        **kwargs
    ):
        """
        Initialize the linear combination aggregator with elastic net regularization.

        Parameters
        ----------
        alpha : float, optional
            Regularization strength parameter, by default 0.001
        l1_ratio : float, optional
            The mixing parameter between L1 and L2 regularization:
            l1_ratio = 1 is lasso (L1), l1_ratio = 0 is ridge (L2)
            Default is 0.5 (elastic net)
        fit_intercept : bool, optional
            Whether to calculate the intercept, by default False
        max_iter : int, optional
            Maximum number of iterations for the solver, by default 1000
        tol : float, optional
            Tolerance for stopping criteria, by default 1e-4
        random_state : int, optional
            Passed to ElasticNet, by default 42
        positive : bool, optional
            Whether to enforce non-negative coefficients, by default True
        **kwargs
            Forwarded to ``EnsembleAggregator`` (e.g. ``num_models``,
            ``normalizer``, ``visualize``, ``allow_fit_again``).
        """
        super().__init__(**kwargs)
        self.alpha = alpha
        self.l1_ratio = l1_ratio
        self.fit_intercept = fit_intercept
        self.max_iter = max_iter
        self.tol = tol
        self.linear_model = None
        self.intercept_ = 0.0
        self.random_state = random_state
        self.positive = positive

    def _fit(self, G: np.ndarray, y: np.ndarray):
        """
        Fit a linear model using elastic net regularization.

        Parameters
        ----------
        G : np.ndarray
            2D array of shape (n_samples, n_models) containing predictions from each model
        y : np.ndarray
            1D array of shape (n_samples,) containing the target values

        Returns
        -------
        self : LinearAggregator
            The fitted aggregator instance

        Raises
        ------
        ImportError
            If scikit-learn is not installed.
        """
        # Imported lazily so scikit-learn is only required for this strategy.
        try:
            from sklearn.linear_model import ElasticNet
        except ImportError as err:
            raise ImportError(
                "scikit-learn is required for LinearAggregator. "
                "Please install it with 'pip install scikit-learn'."
            ) from err

        self.linear_model = ElasticNet(
            alpha=self.alpha,
            l1_ratio=self.l1_ratio,
            fit_intercept=self.fit_intercept,
            max_iter=self.max_iter,
            tol=self.tol,
            random_state=self.random_state,
            positive=self.positive,  # for positive coefficients
        )
        self.linear_model.fit(G, y)

        # Assigned through the coef_ setter, so the configured normalizer applies.
        self.coef_ = self.linear_model.coef_

        if self.fit_intercept:
            self.intercept_ = self.linear_model.intercept_

        return self

    def predict(self, G: np.ndarray) -> np.ndarray:
        """
        Generate predictions using the linear combination.

        Parameters
        ----------
        G : np.ndarray
            2D array of shape (n_samples, n_models) containing predictions from each model

        Returns
        -------
        np.ndarray
            ``G @ coef_`` (NaN entries of G treated as 0.0), plus the
            intercept when ``fit_intercept`` is True.

        Raises
        ------
        ValueError
            If the aggregator has not been fitted.
        """
        # Base predict performs the is_fit check and NaN -> 0.0 handling.
        prediction = super().predict(G)
        if self.fit_intercept:
            prediction = prediction + self.intercept_
        return prediction


# Factory function to create aggregation strategies
def create_aggregator(
    strategy_type: str,
    num_models: int,
    **kwargs
) -> EnsembleAggregator:
    """
    Factory function to create ensemble aggregation strategies.

    Parameters
    ----------
    strategy_type : str
        Case-insensitive strategy name. Options: 'equal', 'performance',
        'adaptive', 'spa', 'singlebest', 'linear'.
    num_models : int
        Number of models in the ensemble
    **kwargs
        Additional arguments to pass to the specific aggregation strategy

    Returns
    -------
    EnsembleAggregator
        A concrete instance of EnsembleAggregator

    Raises
    ------
    ValueError
        If the strategy_type is not recognized
    """
    registry = {
        'equal': EqualWeightAggregator,
        'performance': PerformanceWeightAggregator,
        'adaptive': AdaptiveAggregator,
        'spa': SPAAggregator,
        'singlebest': SingleBestAggregator,
        'linear': LinearAggregator,
    }
    try:
        aggregator_cls = registry[strategy_type.lower()]
    except KeyError:
        raise ValueError(f"Unknown aggregation strategy: {strategy_type}") from None

    # num_models must be passed by keyword. Several constructors either accept
    # only keyword arguments or have a different first positional parameter
    # (sigma, error_metric, alpha).
    return aggregator_cls(num_models=num_models, **kwargs)