import numpy as np
from tqdm import tqdm
from typing import Any
from aiau.data.data_manager import DataManager
from aiau.oracles.oracle import Oracle

class CorrelatedNoiseOracle(Oracle):
    """
    An oracle designed for evaluating strategies in R&D settings,
    assuming that all the observations are sufficiently annotated.
    When queried, it returns the target values perturbed by noise.

    In this case, the noise of nearby observations is correlated: the
    correlation between two observations decays exponentially with the
    Euclidean distance between their rows in ``data_manager.full_X``.

    Attributes:
        L (np.ndarray): Lower-triangular Cholesky factor of the correlation matrix.
        noise_level (float): Multiplier of the heteroscedastic noise amplitude.
        eps_noise_level (float): Standard deviation of the Gaussian draws that are correlated through ``L``.

    Methods:
        __init__(self, data_manager, noise_level, eps_noise_level): Initializes a CorrelatedNoiseOracle instance.
        query_target_value(self, data_manager, idx): Returns the noised target value for a given observation.
    """

    def __init__(self, data_manager: DataManager, noise_level: float = 0.5, eps_noise_level: float = 1.0) -> None:
        """
        Builds the correlation structure of the noise from the dataset.

        Args:
            data_manager (DataManager): Provides ``full_X``; indexed here as a 2-D array with one observation per row.
            noise_level (float): Multiplier of the heteroscedastic noise amplitude.
            eps_noise_level (float): Standard deviation of the underlying Gaussian noise.

        Raises:
            numpy.linalg.LinAlgError: If the correlation matrix is not positive definite
                (this may happen, for instance, when the dataset contains duplicated rows).
        """
        super().__init__()

        # Construct the correlation matrix from the x values in the dataset
        # and store the lower triangular part of the matrix computed
        # using a cholesky decomposition.
        points = np.asarray(data_manager.full_X)
        n_obs = points.shape[0]
        length_scale = np.pi / 2  # Decay length of the exponential kernel
        Sigma = np.empty((n_obs, n_obs))
        for ii in tqdm(range(n_obs), "Building correlation matrix for the oracle"):
            # Distances from observation ii to every observation jj >= ii in one
            # vectorised call; the matrix is symmetric, so the same values fill
            # both the row and the column.
            dists = np.linalg.norm(points[ii:] - points[ii], axis=1)
            kernel_row = np.exp(-dists / length_scale)
            Sigma[ii, ii:] = kernel_row
            Sigma[ii:, ii] = kernel_row
        self.L = np.linalg.cholesky(Sigma)

        self.noise_level = noise_level
        self.eps_noise_level = eps_noise_level

    def query_target_value(self, data_manager: DataManager, idx: list) -> Any:
        """
        Returns the noised target value for a given observation.

        Fresh noise is drawn for the whole dataset on every call, so querying
        the same index twice generally returns different values.

        Args:
            data_manager (DataManager): Reference to the data_manager which will load the observation if necessary.
            idx (list): Index of the observation for which we want to query an annotation.

        Returns:
            Any: The output of the oracle (the noised target value).
        """
        g = data_manager.full_y
        # Independent Gaussian draws, one per target value; they become
        # correlated once multiplied by the Cholesky factor below.
        eps = np.random.normal(0, self.eps_noise_level, g.shape)
        # Heteroscedastic amplitude derived from the target values.
        # NOTE(review): presumably targets lie in [-1, 1]; 1 - g**2 is negative
        # outside that range and np.sqrt then yields NaN -- confirm with the data.
        W = self.noise_level * np.sqrt(1 - (g ** 2))
        # The parentheses matter: `*` and `@` share precedence and group
        # left-to-right, so `W * self.L @ eps` would evaluate (W * L) @ eps,
        # which for 1-D targets scales the columns of L instead of scaling each
        # observation's correlated noise by its own amplitude. This form is also
        # cheaper, since it never materialises an n-by-n temporary.
        f = g + W * (self.L @ eps)
        # select values from f using indices
        return f[idx]