# Model wrapping a bootstrap ensemble of models

import numpy as np
from tqdm import tqdm

from sklearn.base import BaseEstimator
from joblib import Parallel, delayed

import warnings
from sklearn.exceptions import ConvergenceWarning

class BootstrapEnsemble():
    """
    Model wrapping a bootstrapped ensemble of models.

    ``num_models`` independent model instances are created at construction
    time. ``fit`` trains every member on a different subset of the data and
    ``predict`` returns the predictions of every member, so the spread across
    the first axis can be inspected by the caller.

    Args:
        model (callable): The model class (or factory) used for every member.
            It is called as ``model(**model_kwargs)`` and must return an
            object exposing ``fit(X, y)`` and ``predict(X)``.
        num_models (int): The number of models in the ensemble.
        model_kwargs (dict, optional): The keyword arguments to be passed to
            the model. Defaults to no keyword arguments.

    Attributes:
        models (list): The ensemble of models.
        num_models (int): The number of models in the ensemble.
    """

    def __init__(self, model, num_models, model_kwargs=None) -> None:
        # A ``None`` default avoids sharing one mutable dict between calls.
        kwargs = {} if model_kwargs is None else model_kwargs
        self.models = [model(**kwargs) for _ in range(num_models)]
        self.num_models = num_models

    def fit(self, X, y):
        """
        Fits the ensemble of models to the dataset.

        When there are fewer observations than models (or the ensemble has a
        single member, for which no k-fold split exists), every model is
        trained on its own bootstrap sample: ``len(X)`` indices drawn with
        replacement. Otherwise the data is split into ``num_models`` shuffled
        folds (fixed ``random_state=42``) and each model is trained on the
        training side of a different split.

        The index sets are drawn in the calling process, so the bootstrap
        samples follow the caller's ``np.random`` state and the parallel
        tasks do not need a reference to the ensemble itself.

        Args:
            X (numpy.ndarray): The observations. Must support indexing with
                an integer index array along the first axis.
            y (numpy.ndarray): The target values, indexable like ``X``.

        Returns:
            None
        """
        def fit_model_w_indices(model, X, y, indices):
            # Silence convergence/user warnings only while this model is
            # being fitted; catch_warnings restores the previous filters so
            # they do not stay installed in the process running the task.
            with warnings.catch_warnings():
                warnings.filterwarnings("ignore", category=ConvergenceWarning)
                warnings.filterwarnings("ignore", category=UserWarning)
                model.fit(X[indices], y[indices])
            return model  # Return the trained model

        n_samples = len(X)
        if n_samples < self.num_models or self.num_models == 1:
            # Too few observations for one fold per model (or only a single
            # model, which KFold cannot split for): plain bootstrap instead.
            training_indices = [
                np.random.choice(n_samples, n_samples, replace=True)
                for _ in self.models
            ]
        else:
            # If we have enough observations, we can perform a kfold based bootstrap ensemble
            from sklearn.model_selection import KFold
            kf = KFold(n_splits=self.num_models, shuffle=True, random_state=42)
            training_indices = [train_idx for train_idx, _ in kf.split(X)]

        # Parallel hands back the fitted models, which replace the unfitted
        # ones stored on the ensemble.
        self.models = Parallel(n_jobs=-1)(
            delayed(fit_model_w_indices)(model, X, y, indices)
            for model, indices in zip(self.models, training_indices)
        )

    def predict(self, X):
        """
        Predicts the target values for the observations with every model.

        Args:
            X (numpy.ndarray): The observations.

        Returns:
            numpy.ndarray: The predicted target values of shape (num_models, num_observations)
        """
        return np.array([model.predict(X) for model in self.models])

    def predict_mean(self, X):
        """
        Predicts the target values averaged over all models of the ensemble.

        Args:
            X (numpy.ndarray): The observations.

        Returns:
            numpy.ndarray: The mean over the model axis of the predictions
                returned by ``predict``.
        """
        return np.mean(self.predict(X), axis=0)
    

# Smoke test: fit an MLP ensemble on synthetic data and print its predictions
if __name__ == "__main__":
    from aiau.data.synthetic_datasets import generate_2d_synthetic_data
    from sklearn.neural_network import MLPRegressor

    # Synthetic dataset generated with x1_size=50 and x2_size=50
    features, targets = generate_2d_synthetic_data(x1_size=50, x2_size=50)

    # Ten MLP regressors, each built with the same keyword arguments
    ensemble = BootstrapEnsemble(
        model=MLPRegressor,
        num_models=10,
        model_kwargs=dict(hidden_layer_sizes=(100, 100), max_iter=1000),
    )

    # Train every member of the ensemble
    ensemble.fit(features, targets)

    # Show the per-model predictions followed by their array shape
    per_model_predictions = ensemble.predict(features)
    for item in (per_model_predictions, per_model_predictions.shape):
        print(item)


