import numpy as np
import pandas as pd

"""Code to calculate the accuracy."""

def calculate_accuracy(y_true, y_pred, relevant_axes=None):
    """Compute the element-wise accuracy between two numpy arrays.

    Args:
        y_true: numpy array of true values
        y_pred: numpy array of predicted values (same shape as y_true)
        relevant_axes: int or tuple of axes over which the mean is taken;
            when None the mean is taken over all elements

    Returns:
        The mean of the element-wise matches. With relevant_axes=None this is
        a single value; otherwise it is the result of reducing over
        relevant_axes. In the latter case positions where y_true is NaN are
        replaced by NaN before averaging, and because np.mean does not skip
        NaN, every reduced slice containing such a position is NaN.

    Raises:
        NotImplementedError: if check_array_dims rejects the inputs or if an
            axis is not in range(y_true.ndim) (negative axes are rejected)
    """
    check_array_dims(y_true, y_pred)

    # without axes the accuracy is the plain mean over all elements
    if relevant_axes is None:
        return np.mean(y_true == y_pred)

    # a single int is wrapped so that it can be validated like a tuple
    requested_axes = relevant_axes if not isinstance(relevant_axes, int) else [relevant_axes]
    valid_axes = range(y_true.ndim)
    for axis in requested_axes:
        if axis not in valid_axes:
            raise NotImplementedError("relevant_axes are not in array dimensions")

    # mark the positions without a true value as NaN instead of as a mismatch
    matches = np.where(np.isnan(y_true), np.nan, y_true == y_pred)
    return np.mean(matches, axis=relevant_axes)


def calculate_accuracy_no_control(y_true: np.ndarray, y_pred: np.ndarray, rel_indices: tuple = (None, None)):
    """Calculate the accuracy per subject and block between two 4D arrays.

    Only trials whose ground-truth vector has a 0 at index 14 and whose
    response vector does not sum to 0 enter the calculation.

    Args:
        y_true: 4D numpy array of true values; generate_response_df unpacks
            its shape as (n_subs, n_blocks, n_trials, n_responses)
        y_pred: 4D numpy array of predicted values with the same shape
        rel_indices: (start, stop) pair; every trial vector is sliced as
            [start:stop] before the comparison. The default (None, None)
            keeps the whole vector.

    Returns:
        2D numpy array of mean accuracies with one row per remaining subject
        and one column per remaining block. Subject/block combinations
        without any remaining trial are NaN.

    Raises:
        NotImplementedError: if check_array_dims rejects the inputs or if the
            arrays are not 4D
    """

    # check that the dimensions of the input arrays are equal
    check_array_dims(y_true, y_pred)

    # check if y_true and y_pred are 4D arrays
    if y_true.ndim != 4 or y_pred.ndim != 4:
        raise NotImplementedError("y_true and y_pred should be 4D arrays")

    # position in the ground-truth vector whose value decides if a trial is kept
    # NOTE(review): presumably the control/test trial flag - confirm against the data layout
    flag_idx = 14

    # convert the arrays to a dataframe with one row per trial
    response_df = generate_response_df(y_true, y_pred)

    # keep only the trials whose flag entry is 0
    filtered_df = response_df[response_df['ground_truth'].apply(lambda x: x[flag_idx] == 0)]
    # remove all trials where no response is given
    # (copy so that the column added below does not write into a view of response_df)
    filtered_df = filtered_df[filtered_df['response_value'].apply(lambda x: np.sum(x) != 0)].copy()

    # accuracy of every single trial, restricted to the relevant entries
    rel_slice = slice(rel_indices[0], rel_indices[1])
    filtered_df['trial_accuracy'] = [
        np.mean(response[rel_slice] == truth[rel_slice])
        for response, truth in zip(filtered_df['response_value'], filtered_df['ground_truth'])
    ]

    # average the trial accuracies per subject and block; the column is selected
    # before apply so that the grouping columns are not passed to the function
    accuracy_series = filtered_df.groupby(['subject', 'block'])['trial_accuracy'].apply(
        lambda trial_acc: np.mean(trial_acc.to_numpy()))
    accuracy_df = accuracy_series.unstack(level=-1)

    return accuracy_df.values

def calculate_accuracy_three_levels(y_true, y_pred, partition_indices:tuple=(2, 4, 8)):
    """Calculate the accuracy for consecutive segments of the trial vectors.

    Every entry of partition_indices is the width of one segment. The
    segments are placed back to back starting at index 0, so the default
    (2, 4, 8) covers the slices [0:2], [2:6] and [6:14].

    Args:
        y_true (np.ndarray): of true values
        y_pred (np.ndarray): of predicted values
        partition_indices (tuple): widths of the segments, one per level

    Returns:
        np.ndarray: the results of calculate_accuracy_no_control for every
        segment, stacked along a new first axis
    """
    # running boundaries of the segments: [0, w0, w0 + w1, ...]
    boundaries = [0]
    for width in partition_indices:
        boundaries.append(boundaries[-1] + width)

    per_level = [
        calculate_accuracy_no_control(y_true, y_pred, rel_indices=(lower, upper))
        for lower, upper in zip(boundaries[:-1], boundaries[1:])
    ]
    return np.stack(per_level, axis=0)

def generate_response_df(y_true, y_pred, planet_type:np.ndarray=None):
    """Generate a dataframe that holds one row per (subject, block, trial).

    Args:
        y_true (np.ndarray): 4d array of true values (n_subs, n_blocks, n_trials, n_responses)
        y_pred (np.ndarray): array of predicted values; it is reshaped to
            (n_subs * n_blocks * n_trials, n_responses)
        planet_type (np.ndarray, optional): array of planet types; it is reshaped to
            (n_subs * n_blocks * n_trials, 1). Defaults to None.
    Returns:
        df (pd.DataFrame): dataframe with columns ['subject',
                                                    'block',
                                                    'trial',
                                                    'response_value',
                                                    'ground_truth']
                           followed by 'planet_type' when planet_type is given.
                           'response_value', 'ground_truth' and 'planet_type'
                           hold one 1d array per row.
    """
    n_subs, n_blocks, n_trials, n_responses = y_true.shape
    n_rows = n_subs * n_blocks * n_trials

    # subject, block and trial index of every row, in the same row-major order
    # that the reshapes below use
    sub_idx, block_idx, trial_idx = np.indices((n_subs, n_blocks, n_trials)).reshape(3, n_rows)

    # flatten the leading three axes so that every row is one trial
    y_pred = y_pred.reshape(n_rows, n_responses)
    y_true = y_true.reshape(n_rows, n_responses)

    columns = {
        'subject': sub_idx,
        'block': block_idx,
        'trial': trial_idx,
        'response_value': list(y_pred),
        'ground_truth': list(y_true),
    }

    # the planet type is only added as a column if it is given
    if planet_type is not None:
        columns['planet_type'] = list(planet_type.reshape(n_rows, 1))

    return pd.DataFrame(columns)


def check_array_dims(y_true, y_pred):
    """Validate that both inputs are numpy arrays of the same shape.

    Args:
        y_true: numpy array of true values
        y_pred: numpy array of predicted values

    Returns:
        True if both inputs are numpy arrays with equal shapes

    Raises:
        NotImplementedError: if one of the inputs is not a numpy array or if
            the shapes differ
    """
    both_are_arrays = all(isinstance(arr, np.ndarray) for arr in (y_true, y_pred))
    if not both_are_arrays:
        raise NotImplementedError("inputs should be numpy arrays")

    if y_true.shape == y_pred.shape:
        return True

    raise NotImplementedError(
        "shapes of true and predicted values are not equal")
