"""
This module contains functions for estimating the squared bias
of a deep ensemble's predictions.

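Three different approaches are implemented:
- Cheat (ground truth): compare the ensemble mean against the true targets or against averaged oracle queries.
- Direct: model-based prediction of the squared bias for each observation.
- Quadratic: model-based prediction of the bias from the cross-products of errors between observations.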

"""

import numpy as np
from aiau.strategy.learn_and_predict_nans import learn_and_predict_nans_direct, learn_and_predict_nans_quadratic

def estimate_bias_squared(data_manager, task_model, estimation_approach, oracle_type, oracle, batch_strategy="top-k"):
    """
    Estimate the bias squared using the specified estimation approach.

    Args:
        data_manager: The data manager containing the data.
        task_model: The model used for predictions. The model we aim to estimate the bias for.
        estimation_approach: One of "cheat", "direct" or "quadratic".
        oracle_type: The type of oracle used for querying (only used by "cheat").
        oracle: The oracle instance used for querying (only used by "cheat").
        batch_strategy: Forwarded to the "quadratic" approach; ignored by the others.

    Returns:
        For "cheat", "direct", and "quadratic" with batch_strategy == "top-k":
            np.ndarray of estimated bias squared per observation.
        For "quadratic" with any other batch_strategy:
            a (diagonal, cobias) tuple, passed through from the quadratic estimator.

    Raises:
        ValueError: If estimation_approach is not recognized.
    """
    if estimation_approach == "cheat":
        return estimate_bias_squared_cheat(data_manager, task_model, oracle_type, oracle)
    if estimation_approach == "direct":
        return estimate_bias_squared_direct(data_manager, task_model)
    if estimation_approach == "quadratic":
        # For non-"top-k" strategies the quadratic estimator already returns a
        # (diagonal, cobias) tuple, so its result is returned unchanged.
        return estimate_bias_squared_quadratic(data_manager, task_model, batch_strategy)
    raise ValueError(f"Unknown estimation approach: {estimation_approach}")
    

def estimate_bias_squared_cheat(data_manager, task_model, oracle_type, oracle, *, num_oracle_queries=10) -> np.ndarray:
    """
    Compute the reference ("cheat") bias squared of the ensemble mean prediction.

    Args:
        data_manager: Provides ``full_X`` (inputs for every observation) and,
            for oracle type "type_1", ``full_y`` (used directly as the reference target).
        task_model: The ensemble whose ``predict(full_X)`` output is averaged over axis 0.
        oracle_type: "type_1" compares against ``data_manager.full_y``;
            "type_2" / "type_3" compare against the mean of repeated oracle queries.
        oracle: Object exposing ``query_target_value(data_manager, idx=...)``;
            only used for "type_2" / "type_3".
        num_oracle_queries: Number of oracle queries averaged for "type_2" / "type_3".
            Defaults to 10 (the previously hard-coded value).

    Returns:
        np.ndarray: Squared difference between the ensemble mean prediction and the
        reference target, per observation.

    Raises:
        ValueError: If oracle_type is unknown, or if num_oracle_queries < 1 for an
            oracle-querying type.
    """
    # Validate before running the (potentially expensive) ensemble prediction.
    if oracle_type not in ("type_1", "type_2", "type_3"):
        raise ValueError(f"Unknown oracle type: {oracle_type}")
    if oracle_type != "type_1" and num_oracle_queries < 1:
        # Averaging zero queries would silently produce NaN.
        raise ValueError(f"num_oracle_queries must be >= 1, got {num_oracle_queries}")

    # Predict the target values for all the observations and average over the ensemble.
    ensemble_predictions = task_model.predict(data_manager.full_X)
    ensemble_mean = np.mean(ensemble_predictions, axis=0)

    if oracle_type == "type_1":
        # The true targets are available directly: this is the true bias squared.
        return np.square(ensemble_mean - data_manager.full_y)

    # "type_2" / "type_3": the reference target is the mean of repeated oracle queries.
    num_observations = data_manager.full_X.shape[0]
    oracle_predictions = np.array([
        # A fresh index list per query, in case the oracle mutates its argument.
        np.array(oracle.query_target_value(data_manager, idx=list(range(num_observations))))
        for _ in range(num_oracle_queries)
    ])  # shape (num_oracle_queries, num_observations)
    mean_oracle_predictions = np.mean(oracle_predictions, axis=0)  # shape (num_observations,)
    return np.square(ensemble_mean - mean_oracle_predictions)


def estimate_bias_squared_direct(data_manager, task_model):
    """
    Estimate the bias squared using the direct approach.

    For every observation that has noisy targets, the estimate is the squared gap
    between the ensemble's mean prediction and the mean of those noisy targets.
    Observations without noisy targets start out as NaN placeholders and are then
    filled in by ``learn_and_predict_nans_direct``.

    Args:
        data_manager: Supplies ``full_X`` and ``noisy_targets_dict`` (observation
            index -> noisy target values).
        task_model: The model whose bias we estimate; its ``predict(full_X)`` output
            is indexed as ``[:, observation]``.

    Returns:
        np.ndarray: The estimated bias squared, as returned by
        ``learn_and_predict_nans_direct``.
    """
    predictions = task_model.predict(data_manager.full_X)
    targets_by_obs = data_manager.noisy_targets_dict

    def squared_gap(obs):
        # NaN marks an observation with no noisy targets; it is predicted afterwards.
        if obs not in targets_by_obs:
            return np.nan
        gap = np.mean(predictions[:, obs]) - np.mean(list(targets_by_obs[obs]))
        return np.square(gap)

    partial_estimate = np.array([squared_gap(obs) for obs in range(predictions.shape[1])])
    return learn_and_predict_nans_direct(data_manager, predictions, partial_estimate)


def estimate_bias_squared_quadratic(data_manager, task_model, batch_strategy="top-k"):
    """
    Estimate the bias squared using the quadratic approach.

    The per-observation bias (ensemble mean prediction minus the mean of that
    observation's noisy targets) is computed wherever noisy targets exist, and is
    NaN elsewhere. Its outer product (cross-products of errors between observations)
    is passed, together with the bias vector, to ``learn_and_predict_nans_quadratic``.

    Args:
        data_manager: Supplies ``full_X`` and ``noisy_targets_dict`` (observation
            index -> noisy target values).
        task_model: The model used for predictions. The model we aim to estimate the
            bias for; ``predict(full_X)`` output is indexed as ``[:, observation]``.
        batch_strategy: Forwarded to ``learn_and_predict_nans_quadratic``.

    Returns:
        If batch_strategy == "top-k": the result of ``learn_and_predict_nans_quadratic``
        (the estimated bias squared).
        Otherwise: a ``(diagonal, cobias)`` tuple.

    Raises:
        AssertionError: For non-"top-k" strategies, if the returned diagonal contains
            NaN or negative values.
    """
    # Predict the target values for all the observations
    ensemble_predictions = task_model.predict(data_manager.full_X)

    # Empirical bias per observation from the noisy targets (not the true MSE).
    # Observations without noisy targets get NaN and are estimated by the helper below.
    estimated_bias = []
    for i in range(ensemble_predictions.shape[1]):
        if i in data_manager.noisy_targets_dict:
            i_prediction_mean = np.mean(ensemble_predictions[:, i])
            i_observed_mean = np.mean(list(data_manager.noisy_targets_dict[i]))
            estimated_bias.append(i_prediction_mean - i_observed_mean)
        else:
            estimated_bias.append(np.nan)
    estimated_bias = np.array(estimated_bias)  # shape (num_observations,)

    # Cross-products b_i * b_j, shape (num_observations, num_observations).
    # Rows/columns of observations without noisy targets are NaN.
    estimated_bias_squared_matrix = np.outer(estimated_bias, estimated_bias)

    estimation_result = learn_and_predict_nans_quadratic(data_manager,
                                                         ensemble_predictions,
                                                         estimated_bias,
                                                         estimated_bias_squared_matrix,
                                                         batch_strategy=batch_strategy)

    if batch_strategy == "top-k":
        return estimation_result

    diagonal, cobias = estimation_result
    # Explicit checks instead of a bare ``assert`` so they still run under ``python -O``.
    # AssertionError is kept so callers see the same exception type as before.
    diagonal_values = np.asarray(diagonal)
    if np.any(np.isnan(diagonal_values)):
        raise AssertionError("Estimated bias diagonal contains NaN values, which is not expected.")
    if np.any(diagonal_values < 0):
        raise AssertionError("Estimated bias diagonal contains negative values, which is not expected.")
    return diagonal, cobias