import numpy as np
import pandas as pd

from src.human_metrics.calculate_accuracy import generate_response_df, check_array_dims

def compute_group_metric(group, compute_tnr:bool=False):
    """
    Average the per-row TPR (or TNR) over all rows of a single group.

    For every row the rate is ``matches / relevant``. "Relevant" entries are
    the positives of ``ground_truth`` (TPR) or its negatives (TNR). Rows
    with no relevant entries are skipped; they are not counted as zero.

    Args:
        group (pandas DataFrame): rows of a single group; must provide the
            columns ``ground_truth`` and ``response_value`` (vectors that
            support ``np.dot`` and, for TNR, ``1 - x``)
        compute_tnr (bool): if True, compute TNR, otherwise compute TPR

    Returns:
        float: mean of the per-row rates, or ``np.nan`` if no row had any
        relevant entry (including an empty group)
    """
    rates = []

    for _, record in group.iterrows():
        truth = record['ground_truth']
        answer = record['response_value']
        if compute_tnr:
            # TNR is TPR computed on the complemented vectors.
            truth, answer = 1 - truth, 1 - answer

        hits = np.dot(truth, answer)
        relevant = np.sum(truth)
        # Rows without any relevant entry would divide by zero; leave them out.
        if relevant > 0:
            rates.append(hits / relevant)

    if not rates:
        return np.nan
    return np.mean(rates)


def calculate_metric_no_control(y_true:np.array, y_pred:np.array, 
                                rel_indices:tuple=(None, None), compute_tnr:bool=False, 
                                trial_wise:bool=False, control_idx:int=14):
    """Calculate the true positive or true negative rate between two numpy arrays
    while excluding control/test trials.

    The arrays are converted to a response dataframe. Rows whose ground truth is
    non-zero at position ``control_idx`` are dropped. The remaining ground-truth
    and response vectors are sliced to ``rel_indices``, and the metric is
    computed per (subject, block) or per (subject, trial) group with
    ``compute_group_metric``.

    Args:
        y_true: numpy array of true values (must be 4D)
        y_pred: numpy array of predicted values (must be 4D, same dims as y_true)
        rel_indices: (start, stop) slice bounds selecting which entries of each
            ground-truth/response vector enter the metric; ``(None, None)``
            keeps every entry
        compute_tnr (bool): if True, compute TNR, otherwise compute TPR
        trial_wise (bool): if True, compute the metric for each trial, otherwise
            compute the metric for each block
        control_idx (int): position in each ground-truth vector whose non-zero
            value marks a trial to exclude; defaults to 14, the value that was
            previously hard-coded

    Returns:
        np.ndarray: 2D array with one row per subject and one column per block
        (or per within-subject trial number when ``trial_wise`` is True).
        Missing subject/column combinations and groups without any relevant
        entries are NaN.

    Raises:
        NotImplementedError: if y_true or y_pred is not 4D.
        Whatever ``check_array_dims`` raises on mismatched inputs
        (NOTE(review): defined elsewhere; presumably on a shape mismatch).
    """

    # check that the dimensions of the input arrays are equal
    check_array_dims(y_true, y_pred)

    # only 4D inputs are supported
    if y_true.ndim != 4 or y_pred.ndim != 4:
        raise NotImplementedError("y_true and y_pred should be 4D arrays")

    # convert the arrays to a dataframe with 'subject', 'block', 'ground_truth'
    # and 'response_value' columns (as used below)
    response_df = generate_response_df(y_true, y_pred)

    # drop the test/control trials: those flagged by a non-zero ground truth at control_idx
    keep_mask = response_df['ground_truth'].apply(lambda x: x[control_idx] == 0)
    filtered_df = response_df[keep_mask].copy()
    # NOTE: trials without any response (all-zero response_value) are deliberately
    # kept, so they still enter the metric.

    # restrict both vectors to the relevant slice
    start, stop = rel_indices[0], rel_indices[1]
    for column in ('ground_truth', 'response_value'):
        filtered_df[column] = filtered_df[column].apply(lambda x: x[start:stop])

    # choose the grouping level: per trial (numbered within each subject) or per block
    if trial_wise:
        filtered_df['trial_index'] = filtered_df.groupby(['subject']).cumcount()
        group_keys = ['subject', 'trial_index']
    else:
        group_keys = ['subject', 'block']

    metric_series = filtered_df.groupby(group_keys).apply(
        lambda group: compute_group_metric(group, compute_tnr=compute_tnr))
    # pivot the last grouping level into columns -> subjects x (blocks | trials)
    metric_df = metric_series.unstack(level=-1)

    return metric_df.values


def calculate_metric_three_levels(y_true, y_pred, compute_tnr:bool=False, 
                                    partition_indices:tuple=(2, 4, 8), trial_wise:bool=False):
    """Compute the True Positive Rate or True Negative Rate between two numpy
    arrays separately for each level and stack the results.

    ``partition_indices`` holds the widths of consecutive slices of each
    ground-truth/response vector. For example, ``(2, 4, 8)`` yields the slices
    ``[0:2]``, ``[2:6]`` and ``[6:14]``. Each slice is scored with
    ``calculate_metric_no_control``.

    Args:
        y_true (np.ndarray): of true values
        y_pred (np.ndarray): of predicted values
        compute_tnr (bool): if True, compute TNR, otherwise compute TPR
        partition_indices (tuple): width of each consecutive level slice
        trial_wise (bool): if True, compute the metric per trial, otherwise
            per block

    Returns:
        np.ndarray: the per-level 2D results stacked along a new first axis,
        one entry per element of ``partition_indices``
    """
    # running offsets: [0, w0, w0+w1, ...] -> consecutive (start, stop) pairs
    offsets = [0]
    for width in partition_indices:
        offsets.append(offsets[-1] + width)

    per_level = [
        calculate_metric_no_control(y_true, y_pred, rel_indices=(lo, hi),
                                    compute_tnr=compute_tnr, trial_wise=trial_wise)
        for lo, hi in zip(offsets[:-1], offsets[1:])
    ]
    return np.stack(per_level, axis=0)