import pytest
import numpy as np
from aiau.oracles.noisy_oracle import NoisyBenchmarkOracle, aleotoric_and_epistemic_noise_function
from aiau.data.data_manager import DataManager


class DummyDataManager(DataManager):
    """Minimal in-memory stand-in for ``DataManager`` used by the oracle tests.

    ``DataManager.__init__`` is intentionally not called. Only the two
    attributes read by these tests are populated.
    """

    def __init__(self):
        # Ground-truth targets the oracle adds noise to.
        self.full_y = np.array((0.1, 0.5, 0.9))
        # Positional index -> target value, mirroring ``full_y``.
        self.targets_dict = dict(enumerate(self.full_y))


def test_noisy_benchmark_oracle_returns_expected_shape():
    """Querying several indices yields an ndarray shaped like the true targets."""
    manager = DummyDataManager()
    oracle = NoisyBenchmarkOracle(noise_function=aleotoric_and_epistemic_noise_function)

    query_indices = [0, 1, 2]
    noisy_targets = oracle.query_target_value(manager, query_indices)

    assert isinstance(noisy_targets, np.ndarray)
    expected_shape = manager.full_y[query_indices].shape
    assert noisy_targets.shape == expected_shape


def test_noisy_benchmark_oracle_noise_effect():
    """Repeatedly querying one index should not return a constant value.

    With both noise levels set to 1.0, the spread of ten draws is expected
    to sit far above the 0.01 threshold.
    """
    manager = DummyDataManager()
    oracle = NoisyBenchmarkOracle(
        noise_function=aleotoric_and_epistemic_noise_function,
        noise_level=1.0,
        eps_noise_level=1.0,
    )
    query_idx = [1]

    def draw():
        # Single noisy observation for the one queried index.
        return oracle.query_target_value(manager, query_idx)[0]

    samples = [draw() for _ in range(10)]
    spread = np.std(samples)

    # A non-trivial spread shows that stochastic noise was introduced.
    assert spread > 0.01