"""
This module contains functions for estimating the expected mean squared error (EMSE)
of a deep ensemble's predictions.

We explore three different approaches:
- Cheat (ground truth): compare predictions against the ground truth / oracle directly.
- Direct: model-based prediction of the error.
- Quadratic: model-based prediction of the error using the cross-products of errors between pairs of observations.

"""

import numpy as np
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import RBF, ConstantKernel as C
from aiau.strategy.learn_and_predict_nans import learn_and_predict_nans_direct, learn_and_predict_nans_quadratic

def estimate_pemse(data_manager, task_model, estimation_approach, oracle_type, oracle):
    """
    Dispatch to the PEMSE estimator selected by ``estimation_approach``.

    Args:
        data_manager: The data manager containing the data.
        task_model: The model whose prediction error is being estimated.
        estimation_approach: One of "cheat", "direct" or "quadratic".
        oracle_type: The type of oracle used for querying (used by "cheat" and "quadratic").
        oracle: The oracle object (used only by "cheat").

    Returns:
        np.ndarray: Whatever the selected estimator returns.

    Raises:
        ValueError: If ``estimation_approach`` is not a supported approach.
    """
    if estimation_approach not in ("cheat", "direct", "quadratic"):
        raise ValueError(f"Unknown estimation approach: {estimation_approach}")

    if estimation_approach == "cheat":
        return estimate_pemse_cheat(data_manager, task_model, oracle_type, oracle)
    if estimation_approach == "direct":
        return estimate_pemse_direct(data_manager, task_model)
    # Only "quadratic" remains at this point.
    return estimate_pemse_quadratic(data_manager, task_model, oracle_type)


def estimate_pemse_cheat(data_manager, task_model, oracle_type, oracle, r=10):
    """
    Estimate the expected mean squared error (EMSE) using the ground truth.

    For "type_1" oracles the ensemble predictions are compared against
    ``data_manager.full_y``. For "type_2"/"type_3" oracles the oracle is queried
    ``r`` times for every observation and the squared error is averaged over all
    (ensemble member, oracle draw) pairs.

    Args:
        data_manager: The data manager containing the data (``full_X`` and, for
            "type_1", ``full_y`` are read here).
        task_model: The model used for predictions. The model we aim to estimate the EMSE for.
        oracle_type: The type of oracle used for querying: "type_1", "type_2" or "type_3".
        oracle: Object exposing ``query_target_value(data_manager, idx=...)``; only
            used for "type_2"/"type_3".
        r: Number of repeated oracle queries for "type_2"/"type_3" (default 10).

    Returns:
        np.ndarray: The estimated EMSE, one value per observation.

    Raises:
        ValueError: If ``oracle_type`` is unknown, or if ``r < 1`` for an oracle
            type that needs repeated queries.
    """
    if oracle_type not in ("type_1", "type_2", "type_3"):
        raise ValueError(f"Unknown oracle type: {oracle_type}")

    # Predict the target values for all the observations
    ensemble_predictions = task_model.predict(data_manager.full_X)  # shape (ensemble_size, num_observations)

    if oracle_type == "type_1":
        # Squared error against the stored targets, averaged over ensemble members
        return np.mean(np.square(data_manager.full_y - ensemble_predictions), axis=0)

    # "type_2" / "type_3": average over r independent oracle draws.
    if r < 1:
        # With zero draws the mean below would be over an empty axis (NaN).
        raise ValueError(f"r must be a positive integer, got {r}")

    num_observations = data_manager.full_X.shape[0]
    # A fresh index list per call, as before, in case the oracle mutates it.
    oracle_predictions = np.array([
        np.array(oracle.query_target_value(data_manager, idx=list(range(num_observations))))
        for _ in range(r)
    ])  # shape (r, num_observations)

    # Squared error for every (ensemble member, oracle draw, observation) triple
    true_diff = (ensemble_predictions[:, None, :] - oracle_predictions[None, :, :]) ** 2  # (ensemble_size, r, num_observations)
    # Average over all (member, draw) pairs for each observation
    return np.mean(true_diff.reshape(-1, true_diff.shape[2]), axis=0)  # (num_observations,)


def estimate_pemse_direct(data_manager, task_model):
    """
    Estimate the expected mean squared error (EMSE) by directly learning a model over the observed losses.

    For every observation that has entries in ``data_manager.noisy_targets_dict``,
    the empirical MSE is the squared error averaged over all ensemble members and
    all recorded noisy targets. Observations without targets start as NaN and are
    then filled by ``learn_and_predict_nans_direct``.

    Args:
        data_manager: The data manager containing the data.
        task_model: The model used for predictions. The model we aim to estimate the EMSE for.

    Returns:
        np.ndarray: The estimated EMSE.
    """
    predictions = task_model.predict(data_manager.full_X)
    targets_by_index = data_manager.noisy_targets_dict

    def _empirical_mse(index):
        # Mean over every (noisy target, ensemble member) squared error for this observation.
        column = predictions[:, index]
        return np.mean([np.square(column - target) for target in targets_by_index[index]])

    observed_emse = np.array([
        _empirical_mse(index) if index in targets_by_index else np.nan
        for index in range(predictions.shape[1])
    ])

    # NaN entries (no noisy targets yet) are replaced with model-based predictions.
    return learn_and_predict_nans_direct(data_manager, predictions, observed_emse)


def estimate_pemse_quadratic(data_manager, task_model, oracle_type):
    """
    Estimate the expected mean squared error (EMSE) using a quadratic approach.

    Per-observation empirical MSEs are computed from the noisy targets collected
    so far (NaN where an observation has none). A pairwise cross-error matrix is
    built from the observed errors. Both are passed to
    ``learn_and_predict_nans_quadratic``, which presumably fills the NaN entries
    (see that helper for details).

    Args:
        data_manager: The data manager containing the data.
        task_model: The model used for predictions. The model we aim to estimate the EMSE for.
        oracle_type: The type of oracle used for querying. "type_1" and "type_2"
            use the uncorrelated cross-error matrix; "type_3" uses the correlated one.

    Returns:
        np.ndarray: The estimated EMSE.

    Raises:
        ValueError: If ``oracle_type`` is not "type_1", "type_2" or "type_3".
    """
    # Pick the cross-error estimator first. An unsupported oracle type then fails
    # immediately with a clear error. Previously it reached the final call with
    # ``estimated_emse_matrix`` unbound (UnboundLocalError).
    if oracle_type in ("type_1", "type_2"):
        compute_emse_matrix = compute_emse_matrix_uncorrelated
    elif oracle_type == "type_3":
        compute_emse_matrix = compute_emse_matrix_correlated
    else:
        raise ValueError(f"Unknown oracle type: {oracle_type}")

    # Predict the target values for all the observations
    ensemble_predictions = task_model.predict(data_manager.full_X)
    noisy_targets = data_manager.noisy_targets_dict

    # Empirical MSE against the noisy targets observed so far (NaN where none exist),
    # averaged over ensemble members and noisy targets.
    estimated_emse = np.array([
        np.mean([np.square(ensemble_predictions[:, i] - target) for target in noisy_targets[i]])
        if i in noisy_targets else np.nan
        for i in range(ensemble_predictions.shape[1])
    ])

    # Matrix of products of prediction errors across pairs of observations
    estimated_emse_matrix = compute_emse_matrix(data_manager, ensemble_predictions)

    return learn_and_predict_nans_quadratic(data_manager,
                                            ensemble_predictions,
                                            estimated_emse,
                                            estimated_emse_matrix)


def compute_emse_matrix_uncorrelated(data_manager, ensemble_predictions):
    """
    Estimate the pairwise cross-error matrix for oracles with uncorrelated noise.

    Entry ``[i, j]`` is the mean, over ensemble members ``k`` and over every pair of
    noisy targets ``(a, b)`` recorded for observations ``i`` and ``j``, of
    ``(pred[k, i] - y_i[a]) * (pred[k, j] - y_j[b])``.

    For a fixed member ``k`` the double sum over ``(a, b)`` factorises. The entry
    therefore equals the mean over ``k`` of the product of the per-member *mean*
    errors. That lets the whole matrix be computed with one matrix product instead
    of an O(N^2) Python double loop.

    Args:
        data_manager: Exposes ``noisy_targets_dict``, mapping an observation index
            to the noisy target values recorded for it.
        ensemble_predictions: Array of shape (ensemble_size, num_observations).

    Returns:
        np.ndarray: (num_observations, num_observations) matrix. Any entry involving
        an observation with no recorded noisy targets (missing key or empty
        sequence) is NaN.
    """
    predictions = np.asarray(ensemble_predictions, dtype=float)
    num_members, num_observations = predictions.shape[0], predictions.shape[1]
    noisy_targets = data_manager.noisy_targets_dict

    estimated_emse_matrix = np.full((num_observations, num_observations), np.nan)

    # Observations with at least one noisy target. Empty target lists are treated
    # as unobserved rather than crashing.
    observed = [i for i in range(num_observations)
                if i in noisy_targets and len(noisy_targets[i]) > 0]
    if not observed:
        return estimated_emse_matrix

    # mean_errors[:, c]: per-member error at observation observed[c], averaged
    # over that observation's noisy targets -> shape (ensemble_size, len(observed))
    mean_errors = np.column_stack([
        np.mean([predictions[:, i] - target for target in noisy_targets[i]], axis=0)
        for i in observed
    ])

    # Mean over ensemble members of the product of mean errors, for all observed pairs at once
    estimated_emse_matrix[np.ix_(observed, observed)] = (mean_errors.T @ mean_errors) / num_members

    return estimated_emse_matrix


def compute_emse_matrix_correlated(data_manager, ensemble_predictions):
    """
    Estimate the pairwise cross-error matrix for oracles whose draws are correlated.

    Labelling rounds ``h`` in ``data_manager.labelling_history`` are processed in
    sorted key order. In each round, every pair ``(i, j)`` of observations labelled
    together contributes ``mean_k[(pred[k, i] - y_h[i]) * (pred[k, j] - y_h[j])]``,
    where ``y_h = data_manager.correlated_draws_labels[h]``. Entry ``[i, j]`` is the
    average of these contributions over the rounds in which both were labelled.

    Each round visits only its labelled pairs, and per-observation errors are
    computed once per round. This replaces re-scanning every round for every
    (i, j) pair.

    Args:
        data_manager: Exposes ``labelling_history`` (a mapping from round key to
            the container of observation indices labelled in that round) and
            ``correlated_draws_labels`` (indexed as ``[h][i]``).
        ensemble_predictions: Array of shape (ensemble_size, num_observations).

    Returns:
        np.ndarray: (num_observations, num_observations) matrix. Pairs never
        labelled in the same round are NaN.
    """
    num_observations = ensemble_predictions.shape[1]
    history = data_manager.labelling_history

    sums = np.zeros((num_observations, num_observations))
    counts = np.zeros((num_observations, num_observations), dtype=int)

    for h in sorted(history):
        members = history[h]
        # Same membership test as before, so any container type for the round works
        labelled = [i for i in range(num_observations) if i in members]
        if not labelled:
            # Nothing labelled this round; its labels are never needed
            continue
        labels = data_manager.correlated_draws_labels[h]
        errors = {i: ensemble_predictions[:, i] - labels[i] for i in labelled}  # each (ensemble_size,)
        for i in labelled:
            for j in labelled:
                # Accumulated in sorted-round order, matching the original per-pair summation
                sums[i, j] += np.mean(errors[i] * errors[j])
                counts[i, j] += 1

    estimated_emse_matrix = np.full((num_observations, num_observations), np.nan)
    seen = counts > 0
    estimated_emse_matrix[seen] = sums[seen] / counts[seen]

    # Print how many nans and non-nans
    print(f"Estimated EMSE matrix: {np.sum(np.isnan(estimated_emse_matrix))} nans, {np.sum(~np.isnan(estimated_emse_matrix))} non-nans")
    return estimated_emse_matrix