import pytest
import numpy as np
from sklearn.linear_model import LinearRegression
from aiau.models.bootstrap_ensemble import BootstrapEnsemble


@pytest.fixture
def sample_data():
    """Return a reproducible noisy linear-regression dataset.

    Returns:
        ``(X, y)`` where ``X`` has shape ``(100, 3)`` with entries in ``[0, 1)``
        and ``y = 2*x0 + 3*x1 - x2`` plus Gaussian noise with std 0.1.

    A locally seeded generator is used (instead of the global ``np.random``
    state) so the data is identical on every run and failures can be reproduced.
    """
    rng = np.random.default_rng(0)
    X = rng.random((100, 3))
    y = X @ np.array([2.0, 3.0, -1.0]) + 0.1 * rng.standard_normal(100)
    return X, y

@pytest.fixture
def small_data():
    """Return a reproducible, deliberately tiny noisy linear-regression dataset.

    Returns:
        ``(X, y)`` where ``X`` has shape ``(10, 2)`` with entries in ``[0, 1)``
        and ``y = 2*x0 + 3*x1`` plus Gaussian noise with std 0.1.

    A locally seeded generator keeps the data deterministic across runs.
    """
    rng = np.random.default_rng(1)
    X = rng.random((10, 2))
    y = X @ np.array([2.0, 3.0]) + 0.1 * rng.standard_normal(10)
    return X, y

def test_initialisation(sample_data):
    """Constructing the ensemble creates ``num_models`` separate LinearRegression members.

    Only construction is checked here; the ``sample_data`` fixture is not needed.
    """
    ensemble = BootstrapEnsemble(model=LinearRegression, num_models=5)
    assert len(ensemble.models) == 5
    assert all(isinstance(m, LinearRegression) for m in ensemble.models)
    # Members must be distinct objects: a single shared instance could not
    # hold a different fit per bootstrap resample.
    assert len({id(m) for m in ensemble.models}) == 5


def test_fit_and_predict_shape(sample_data):
    """After fitting, predict() returns one finite row of predictions per member model."""
    features, targets = sample_data
    n_models = 4
    ensemble = BootstrapEnsemble(model=LinearRegression, num_models=n_models)
    ensemble.fit(features, targets)

    per_model = ensemble.predict(features)
    expected_shape = (n_models, features.shape[0])
    assert per_model.shape == expected_shape
    assert np.isfinite(per_model).all()


def test_predict_mean_output(sample_data):
    """predict_mean() returns a finite 1-D array with one entry per input row."""
    features, targets = sample_data
    ensemble = BootstrapEnsemble(model=LinearRegression, num_models=3)
    ensemble.fit(features, targets)

    mean_pred = ensemble.predict_mean(features)
    n_rows = features.shape[0]
    assert mean_pred.shape == (n_rows,)
    assert np.isfinite(mean_pred).all()


def test_bootstrap_diversity(sample_data):
    """Members trained on different bootstrap resamples should disagree somewhere.

    The standard deviation of predictions across members is taken per input row;
    if every member predicted identically, the mean of that spread would be zero.
    """
    features, targets = sample_data
    ensemble = BootstrapEnsemble(model=LinearRegression, num_models=5)
    ensemble.fit(features, targets)

    spread = ensemble.predict(features).std(axis=0)
    assert spread.mean() > 0

def test_small_data_fitting(small_data):
    """The ensemble fits and predicts on a small dataset without producing NaN/inf."""
    features, targets = small_data
    n_models = 3
    n_rows = features.shape[0]
    ensemble = BootstrapEnsemble(model=LinearRegression, num_models=n_models)
    ensemble.fit(features, targets)

    # predict(): one row of predictions per member model.
    per_model = ensemble.predict(features)
    assert per_model.shape == (n_models, n_rows)
    assert np.isfinite(per_model).all()

    # predict_mean(): one value per input row.
    averaged = ensemble.predict_mean(features)
    assert averaged.shape == (n_rows,)
    assert np.isfinite(averaged).all()
