import numpy as np
import pandas as pd

from src.human_metrics.calculate_accuracy import generate_response_df, check_array_dims

def compute_sibling_metric(group):
    """
    Function to compute the sibling metric.

    For every trial (row) only the consecutive pairs of entries whose ground
    truth is exactly ``[1, 0]`` or ``[0, 1]`` are kept. Trials whose kept
    responses contain no click are skipped. The metric is the fraction of kept
    entries, pooled over all remaining trials, where the response equals the
    ground truth.

    Args:
        group (pandas DataFrame): the data frame of a single group; needs the
            columns 'ground_truth' and 'response_value', each holding one
            numpy array per row (presumably 1-D and of even length - TODO confirm)
    Returns:
        float: the sibling metric, or NaN if no trial contributes
    """
    desired_entries = [np.array([1, 0]), np.array([0, 1])]

    responses = []
    targets = []

    for _, row in group.iterrows():
        rel_targets = row['ground_truth']
        rel_response = row['response_value']

        # One flag per consecutive pair: True when its ground truth is [1, 0] or [0, 1]
        pair_flags = [
            np.any([np.all(rel_targets[i:i + 2] == entry) for entry in desired_entries])
            for i in range(0, rel_targets.shape[0], 2)
        ]
        # dtype=bool keeps the mask a valid boolean index even when it is empty
        mask = np.repeat(np.array(pair_flags, dtype=bool), 2)

        filtered_targets = rel_targets[mask]
        filtered_response = rel_response[mask]
        # skip trials in which none of the kept entries was clicked
        if np.sum(filtered_response) != 0:
            targets.append(filtered_targets)
            responses.append(filtered_response)

    if not responses:
        return np.nan
    # Concatenate instead of np.asarray: the number of kept entries can differ
    # between trials, and a ragged list cannot form a 2-D array. For equal
    # lengths this pools exactly the same elements.
    return np.mean(np.concatenate(responses) == np.concatenate(targets))

def calculate_n_sibling_clicks(group):
    """
    Function to compute the number of sibling clicks, P(1 click allocated to the siblings on each level)

    Each trial (row) is reduced to the consecutive entry pairs whose ground
    truth is exactly [1, 0] or [0, 1]. Let ``agreement`` be the fraction of
    those kept entries where response equals ground truth. The trial scores:
      * 1 if agreement is 0,
      * otherwise 0 if the kept responses sum to 2,
      * otherwise 0 if the kept responses sum to 0 and agreement is 0.5,
      * otherwise the agreement value itself.
    For a single kept pair, this gives 1 exactly when one of the two siblings
    was clicked, and 0 otherwise. A trial with no kept entries scores NaN.

    Args:
        group (pandas DataFrame): the data frame of a single group
    Returns:
        float: the mean trial score over the group
    """
    sibling_patterns = (np.array([1, 0]), np.array([0, 1]))

    def trial_score(truth, answer):
        # keep both entries of every pair whose ground truth matches a pattern
        keep = np.repeat(
            np.array([
                any(np.all(truth[k:k + 2] == pattern) for pattern in sibling_patterns)
                for k in range(0, truth.shape[0], 2)
            ]),
            2,
        )
        kept_truth = truth[keep]
        kept_answer = answer[keep]
        agreement = np.mean(kept_answer == kept_truth)
        n_clicks = np.sum(kept_answer)
        if agreement == 0:
            return 1
        if n_clicks == 2:
            return 0
        if n_clicks == 0 and agreement == 0.5:
            return 0
        return agreement

    scores = [trial_score(row['ground_truth'], row['response_value'])
              for _, row in group.iterrows()]
    return np.mean(scores)


def compute_chris_sibling_metric(group):
    """
    Function to compute the sibling metric in Chris's style.

    Each trial (row) is reduced to the consecutive entry pairs whose ground
    truth is exactly [1, 0] or [0, 1]. The fraction of kept entries where
    response equals ground truth is then recoded: 0 becomes -1, 0.5 becomes 0,
    and any other value is kept unchanged.

    Args:
        group (pandas DataFrame): the data frame of a single group
    Returns:
        float: the mean recoded score over the group's trials
    """
    sibling_patterns = (np.array([1, 0]), np.array([0, 1]))

    def trial_score(truth, answer):
        # keep both entries of every pair whose ground truth matches a pattern
        keep = np.repeat(
            np.array([
                any(np.all(truth[k:k + 2] == pattern) for pattern in sibling_patterns)
                for k in range(0, truth.shape[0], 2)
            ]),
            2,
        )
        agreement = np.mean(answer[keep] == truth[keep])
        # recode ala Chris
        if agreement == 0:
            return -1
        if agreement == 0.5:
            return 0
        return agreement

    return np.mean([trial_score(row['ground_truth'], row['response_value'])
                    for _, row in group.iterrows()])

def calculate_sibling_bias(group):
    """
    Function to compute the sibling bias.

    Each trial (row) is reduced to the consecutive entry pairs whose ground
    truth is exactly [1, 0] or [0, 1]. A trial then contributes:
      * 0 if every kept response equals its ground truth,
      * otherwise 0 if no kept entry was clicked,
      * otherwise 1 if the kept responses sum to 2,
      * and nothing (the trial is excluded) in every other case.

    NOTE(review): an earlier inline comment said no-click trials are "not
    counted", but they are counted as 0 here - confirm which is intended.

    Args:
        group (pandas DataFrame): the data frame of a single group
    Returns:
        float: the mean over contributing trials (NaN if none contribute)
    """
    sibling_patterns = (np.array([1, 0]), np.array([0, 1]))

    def trial_outcome(truth, answer):
        # keep both entries of every pair whose ground truth matches a pattern
        keep = np.repeat(
            np.array([
                any(np.all(truth[k:k + 2] == pattern) for pattern in sibling_patterns)
                for k in range(0, truth.shape[0], 2)
            ]),
            2,
        )
        kept_truth = truth[keep]
        kept_answer = answer[keep]
        if np.mean(kept_answer == kept_truth) == 1:
            return 0
        n_clicks = np.sum(kept_answer)
        if n_clicks == 0:
            return 0
        if n_clicks == 2:
            return 1
        return None

    outcomes = [trial_outcome(row['ground_truth'], row['response_value'])
                for _, row in group.iterrows()]
    return np.mean([outcome for outcome in outcomes if outcome is not None])


def calculate_siblings_no_control(y_true: np.ndarray,
                                  y_pred: np.ndarray,
                                  rel_indices: tuple = (None, None),
                                  chris_metric: bool = False,
                                  sibling_bias: bool = False,
                                  sibling_clicks: bool = False):
    """Calculate a sibling metric per (subject, block) while excluding control trials.

    Exactly one metric is computed. Flags are checked in the order
    chris_metric, sibling_bias, sibling_clicks, and the first True one wins.
    If none is True, ``compute_sibling_metric`` is used.

    Args:
        y_true: 4-D numpy array of true values
        y_pred: 4-D numpy array of predicted values (same dimensions as y_true)
        rel_indices: ``(start, stop)`` slice bounds selecting the entries of
            each trial vector that belong to this level of the hierarchy;
            ``(None, None)`` keeps the whole vector
        chris_metric: use the metric as defined by Chris
            (``compute_chris_sibling_metric``)
        sibling_bias: calculate the sibling bias (``calculate_sibling_bias``)
        sibling_clicks: calculate the proportion of sibling clicks
            (``calculate_n_sibling_clicks``)

    Returns:
        np.ndarray: 2-D array of metric values with one row per subject and
        one column per block (the 'block' group level is unstacked into columns)

    Raises:
        NotImplementedError: if y_true or y_pred is not 4-D
    """
    # check that the dimensions of the input arrays are equal
    check_array_dims(y_true, y_pred)

    if y_true.ndim != 4 or y_pred.ndim != 4:
        raise NotImplementedError("y_true and y_pred should be 4D arrays")

    # convert the arrays to a long-format dataframe (one row per trial)
    response_df = generate_response_df(y_true, y_pred)

    # Drop trials whose ground truth is non-zero at index 14 (test/control
    # trials - TODO confirm what index 14 encodes)
    filtered_df = response_df[response_df['ground_truth'].apply(lambda x: x[14] == 0)].copy()
    # Drop trials without any response. This is checked on the full vector,
    # before slicing, so a level slice can still contain no clicks.
    filtered_df = filtered_df[filtered_df['response_value'].apply(lambda x: np.sum(x) != 0)].copy()

    # keep only the entries belonging to the requested level
    level_slice = slice(rel_indices[0], rel_indices[1])
    filtered_df['ground_truth'] = filtered_df['ground_truth'].apply(lambda x: x[level_slice])
    filtered_df['response_value'] = filtered_df['response_value'].apply(lambda x: x[level_slice])

    # pick the per-group metric function (first True flag wins)
    if chris_metric:
        metric_fn = compute_chris_sibling_metric
    elif sibling_bias:
        metric_fn = calculate_sibling_bias
    elif sibling_clicks:
        metric_fn = calculate_n_sibling_clicks
    else:
        metric_fn = compute_sibling_metric

    # one metric value per (subject, block); then move 'block' into columns
    sibling_series = filtered_df.groupby(['subject', 'block']).apply(metric_fn)
    sibling_df = sibling_series.unstack(level=-1)

    return sibling_df.values


def calculate_siblings_three_levels(y_true,
                                    y_pred,
                                    partition_indices: tuple = (2, 4, 8),
                                    chris_metric: bool = False,
                                    sibling_bias: bool = False,
                                    sibling_clicks: bool = False):
    """Calculate the sibling metric for every level of the hierarchy.

    The trial vectors are split into consecutive segments whose widths are
    given by ``partition_indices``. Level k covers entries
    ``[sum(widths[:k]), sum(widths[:k + 1]))``, and
    ``calculate_siblings_no_control`` is run on each segment.

    Args:
        y_true (np.ndarray): true values
        y_pred (np.ndarray): predicted values
        partition_indices (tuple): width (number of entries) of each level's
            segment, in order; these are segment lengths, not absolute indices
        chris_metric (bool): whether to use the metric as defined by Chris
        sibling_bias (bool): whether to calculate the sibling bias
        sibling_clicks (bool): whether to calculate the proportion of sibling clicks

    Returns:
        np.ndarray: per-level results stacked along a new first axis, so the
        output has one more dimension than each level's result
    """
    # running start offsets: bounds[k] is where level k begins
    bounds = [0]
    for width in partition_indices:
        bounds.append(bounds[-1] + width)

    per_level = [
        calculate_siblings_no_control(y_true, y_pred, rel_indices=(lo, hi),
                                      chris_metric=chris_metric,
                                      sibling_bias=sibling_bias,
                                      sibling_clicks=sibling_clicks)
        for lo, hi in zip(bounds[:-1], bounds[1:])
    ]
    return np.stack(per_level, axis=0)

def safe_mean(a, b):
    """
    Return the fraction of positions at which two arrays are element-wise equal.

    Args:
        a: The first array.
        b: The second array (compared element-wise with ``a``).

    Returns:
        ``np.mean(a == b)``, or the integer 0 when both arrays are empty.
    """
    both_empty = a.size == 0 and b.size == 0
    if both_empty:
        return 0
    return np.mean(a == b)