import pytest
import numpy as np
from aiau.oracles.correlated_noise_oracle import CorrelatedNoiseOracle
from aiau.data.data_manager import DataManager
from aiau.data.synthetic_datasets import generate_2d_synthetic_data


class DummyDataManager(DataManager):
    """Minimal ``DataManager`` stand-in holding three fixed 2-D points.

    ``DataManager.__init__`` is not called, so the only per-instance
    attributes are the three assigned in ``__init__`` below.
    """

    def __init__(self):
        # Three points lying on the diagonal x1 == x2, one target per point.
        self.full_X = np.array([[0.0, 0.0], [1.0, 1.0], [2.0, 2.0]])
        self.full_y = np.array([0.2, 0.6, 0.9])
        # Row position -> target value, mirroring full_y.
        self.targets_dict = dict(enumerate(self.full_y))


def test_correlated_noise_oracle_returns_expected_shape():
    """Querying k indices must yield a 1-D ndarray holding k entries."""
    manager = DummyDataManager()
    oracle = CorrelatedNoiseOracle(data_manager=manager)

    query_indices = [0, 1]
    answers = oracle.query_target_value(manager, query_indices)

    assert isinstance(answers, np.ndarray)
    expected_shape = (len(query_indices),)
    assert answers.shape == expected_shape


def test_correlated_noise_oracle_output_variability():
    """Repeatedly querying one index must not return a (near-)constant value."""
    manager = DummyDataManager()
    oracle = CorrelatedNoiseOracle(
        data_manager=manager, noise_level=1.0, eps_noise_level=1.0
    )

    query_indices = [1]
    draws = []
    for _ in range(10):
        draws.append(oracle.query_target_value(manager, query_indices)[0])

    # A population std above 0.01 confirms noise is applied.
    assert np.std(draws) > 0.01

def create_small_data_manager(x1_size=3, x2_size=3, n_labelled=3, rng=None):
    """Create a small ``DataManager`` over a 2-D synthetic grid for testing.

    Args:
        x1_size: Forwarded to ``generate_2d_synthetic_data`` as ``x1_size``.
        x2_size: Forwarded to ``generate_2d_synthetic_data`` as ``x2_size``.
        n_labelled: Number of distinct indices marked as initially labelled.
            Must not exceed the number of generated observations, otherwise
            NumPy raises ``ValueError`` (sampling is without replacement).
        rng: Optional ``numpy.random.Generator`` used to choose the labelled
            indices, for reproducible tests. When ``None`` the global NumPy
            RNG is used, exactly as before (seed it via ``np.random.seed``).

    Returns:
        A ``DataManager`` whose ``indices`` cover every generated observation.
    """
    X, y = generate_2d_synthetic_data(x1_size=x1_size, x2_size=x2_size)
    indices = np.arange(X.shape[0])
    # Keep the legacy global-RNG path as the default so existing callers see
    # identical behaviour; callers wanting determinism pass a Generator.
    choose = np.random.choice if rng is None else rng.choice
    initially_labelled_indices = choose(indices, size=n_labelled, replace=False)

    return DataManager(indices=indices,
                       observations=X,
                       targets=y,
                       initially_labelled_indices=initially_labelled_indices)

def test_correlated_noise_oracle_exhibits_correlation():
    """Sample correlation of repeated queries should track ``oracle.L @ oracle.L.T``.

    Draws 100 query rounds over every point, builds the empirical correlation
    matrix, and requires its mean absolute deviation from the normalised
    ``L @ L.T`` to stay below 0.15.
    """
    manager = create_small_data_manager()
    oracle = CorrelatedNoiseOracle(
        data_manager=manager, noise_level=0.5, eps_noise_level=1.0
    )
    manager.initialise(oracle)

    all_indices = list(range(len(manager.full_y)))
    n_rounds = 100
    # One row per round; columns correspond to the queried indices.
    samples = np.array(
        [oracle.query_target_value(manager, all_indices) for _ in range(n_rounds)]
    )
    empirical_corr = np.corrcoef(samples, rowvar=False)

    # Normalise L @ L.T to unit diagonal so it is comparable to a correlation.
    expected_corr = oracle.L @ oracle.L.T
    diagonal = np.diag(expected_corr)
    scale = np.sqrt(np.outer(diagonal, diagonal))
    expected_corr /= scale

    mean_abs_diff = np.mean(np.abs(empirical_corr - expected_corr))
    assert mean_abs_diff < 0.15