"""Noisy benchmark oracle implementation
"""
import numpy as np
from typing import Any
from aiau.data.data_manager import DataManager
from aiau.oracles.oracle import Oracle

class NoisyBenchmarkOracle(Oracle):
    """Benchmark oracle that perturbs stored targets before returning them.

    Meant for evaluating strategies in R&D settings, under the assumption
    that every observation is already annotated. Each query reads the stored
    target from the data manager and passes it through a user-supplied
    noise function.

    Attributes:
        noise_function: Callable invoked as
            ``noise_function(target, noise_level, eps_noise_level)``.
        noise_level: Value forwarded as the second positional argument of
            ``noise_function``.
        eps_noise_level: Value forwarded as the third positional argument of
            ``noise_function``.

    Methods:
        __init__(self, noise_function, noise_level, eps_noise_level):
            Stores the noise function and its two parameters.
        query_target_value(self, data_manager, idx):
            Returns the noised target value for a given observation.
    """

    def __init__(self, noise_function, noise_level: float = 0.5, eps_noise_level: float = 1.0) -> None:
        """Store the noise function together with the parameters passed to it.

        Args:
            noise_function: Callable applied to each queried target.
            noise_level (float): Forwarded unchanged to ``noise_function``.
            eps_noise_level (float): Forwarded unchanged to ``noise_function``.
        """
        super().__init__()
        self.noise_function = noise_function
        self.noise_level = noise_level
        self.eps_noise_level = eps_noise_level

    def query_target_value(self, data_manager: DataManager, idx: int) -> Any:
        """Look up the stored target for ``idx`` and return it after noising.

        Args:
            data_manager (DataManager): Data manager whose ``full_y`` is indexed
                to obtain the clean target.
            idx (int): Index of the observation for which an annotation is queried.

        Returns:
            Any: Whatever ``noise_function`` returns for the clean target
            (the noised target value).
        """
        clean_target = data_manager.full_y[idx]
        noise_args = (self.noise_level, self.eps_noise_level)
        return self.noise_function(clean_target, *noise_args)


# Noise function with the (target, noise_level, eps_noise_level) signature that NoisyBenchmarkOracle calls.
def aleotoric_and_epistemic_noise_function(g: np.ndarray, noise_level: float = 0.5, eps_noise_level: float = 1.0) -> np.ndarray:
    """Add heteroscedastic Gaussian noise to the target value(s) ``g``.

    The result is ``g + W * eps`` where

    * ``eps`` is drawn from ``N(0, eps_noise_level**2)`` with the same shape
      as ``g`` (using NumPy's global random state), and
    * ``W = noise_level * sqrt(max(1 - g**2, 0))`` is an input-dependent scale
      that is largest at ``g == 0`` and vanishes at ``|g| == 1``.

    Values with ``|g| > 1`` would make the square root undefined; the radicand
    is clamped at zero so such targets are returned without noise instead of
    becoming NaN. For inputs within ``[-1, 1]`` the clamp has no effect.

    Note: the function name keeps its historical spelling ("aleotoric" for
    "aleatoric") so existing callers continue to work.

    Args:
        g: Clean target value(s); a NumPy array or a scalar.
        noise_level: Multiplier applied to the input-dependent scale ``W``.
        eps_noise_level: Standard deviation of the Gaussian draw ``eps``.

    Returns:
        The noised target value(s), with the same shape as ``g``.
    """
    # np.shape (rather than g.shape) also accepts plain Python scalars.
    eps = np.random.normal(0, eps_noise_level, np.shape(g))
    # Clamp at 0 so |g| > 1 (or tiny rounding overshoots) cannot yield NaN.
    radicand = np.maximum(1 - (g ** 2), 0)
    # Input-dependent (heteroscedastic) noise scale.
    W = noise_level * np.sqrt(radicand)
    return g + (W * eps)
