import numpy as np
import pandas as pd

from src.human_metrics.calculate_accuracy import generate_response_df, check_array_dims


def compute_bias(y_true:np.ndarray, y_pred:np.ndarray, rel_indices:tuple=(0, 2),
                 test_flag_index:int=14):
    """Calculate the response bias of the participants for each block.

    Test trials and trials without any response are discarded. For every
    remaining trial the response vector is restricted to the slice
    ``rel_indices[0]:rel_indices[1]`` and summed; the bias of a
    (subject, block) pair is the mean of those per-trial sums.

    Args:
        y_true: numpy array of true values; must be 4D and have the same
            dimensions as ``y_pred``
        y_pred: numpy array of predicted values (participant responses);
            must be 4D
        rel_indices: ``(start, stop)`` pair selecting the response entries
            ``start:stop`` that are included in the bias calculation
        test_flag_index: position in the ground-truth vector that flags a
            test trial; trials whose entry at this position is non-zero are
            excluded (defaults to 14, the previously hard-coded index)
    Returns:
        2D numpy array of bias values with one row per subject and one
        column per block (both sorted by their group key, as done by
        ``groupby``). Subject/block combinations without any remaining
        trial are NaN.
    Raises:
        NotImplementedError: if ``y_true`` or ``y_pred`` is not 4D
    """
    # check that the dimensions of the input arrays are equal
    check_array_dims(y_true, y_pred)

    # only 4D input is supported by the dataframe conversion below
    if y_true.ndim != 4 or y_pred.ndim != 4:
        raise NotImplementedError("y_true and y_pred should be 4D arrays")

    # convert the arrays to a dataframe of trials
    response_df = generate_response_df(y_true, y_pred)

    start, stop = rel_indices[0], rel_indices[1]

    # remove the test trials (non-zero flag in the ground-truth vector)
    filtered_df = response_df[
        response_df['ground_truth'].apply(lambda x: x[test_flag_index] == 0)
    ]
    # remove all trials where no response is given; copy because a column
    # is overwritten below
    filtered_df = filtered_df[
        filtered_df['response_value'].apply(lambda x: np.sum(x) != 0)
    ].copy()
    # keep only the relevant part of each response vector
    filtered_df['response_value'] = filtered_df['response_value'].apply(
        lambda x: x[start:stop]
    )

    # calculate the bias: mean over trials of the summed (sliced) responses.
    # Selecting the column before .apply avoids the pandas deprecation of
    # DataFrameGroupBy.apply operating on the grouping columns.
    bias_series = filtered_df.groupby(['subject', 'block'])['response_value'].apply(
        lambda responses: np.mean([np.sum(r) for r in responses])
    )
    # subjects as rows, blocks as columns
    bias_df = bias_series.unstack(level=-1)

    return bias_df.values

def compute_bias_three_levels(y_true, y_pred, partition_indices:tuple=(2, 4, 8)):
    """Calculate the response bias for every level of the hierarchy.

    The response vector is split into consecutive slices whose widths are
    given by ``partition_indices``. With the default ``(2, 4, 8)`` the
    slices are ``0:2``, ``2:6`` and ``6:14``. ``compute_bias`` is evaluated
    on each slice.

    Args:
        y_true (np.ndarray): of true values
        y_pred (np.ndarray): of predicted values
        partition_indices (tuple): widths (not boundaries) of the
            consecutive response slices, one entry per level
    Returns:
        np.ndarray: per-level bias arrays (as returned by ``compute_bias``)
        stacked along a new first axis, i.e. one 2D subject-by-block array
        per level
    """
    level_list = []
    start_idx = 0
    for width in partition_indices:
        stop_idx = start_idx + width
        level_list.append(
            compute_bias(y_true, y_pred, rel_indices=(start_idx, stop_idx))
        )
        # the next level's slice begins where this one ended
        start_idx = stop_idx
    return np.stack(level_list, axis=0)