import numpy as np
import jax.numpy as jnp
from scipy.special import softmax
from src.neural_nets.linear_nets.mlp_model import batched_forward


def accuracy(params, inputs, targets, temprature=1.0,
             split_partitions=None, response=None, probabilistic=True, bias=False):
    """
    Compute the fraction of response entries that are equal to the targets.

    Args:
        params: Model parameters; only used (forwarded to ``compute_choices``)
            when ``response`` is not supplied.
        inputs: Input data; only used (forwarded to ``compute_choices``)
            when ``response`` is not supplied.
        targets: Target array, compared element-wise with ``response``.
        temprature (float, optional): Forwarded to ``compute_choices``. Defaults to 1.0.
        split_partitions (sequence of int, optional): Widths of consecutive column
            blocks (along axis 1). When given, one accuracy per block is returned.
            Defaults to None.
        response (array, optional): Precomputed choices. When None they are
            obtained from ``compute_choices``. Defaults to None.
        probabilistic (bool, optional): Forwarded to ``compute_choices``. Defaults to True.
        bias (bool, optional): Forwarded to ``compute_choices``. Defaults to False.

    Returns:
        np.ndarray or tuple: If ``split_partitions`` is provided, an array with
        one accuracy value per partition. Otherwise a tuple
        ``(model_outputs, response, accuracy)``; ``model_outputs`` is None when
        a precomputed ``response`` was passed in, because no forward pass is run
        in that case.
    """
    # Stays None when the caller supplies `response`; previously this name was
    # left unbound and the non-partitioned return raised UnboundLocalError.
    model_outputs = None
    if response is None:
        model_outputs, response = compute_choices(
            params, inputs, targets,
            temprature=temprature, probabilistic=probabilistic, bias=bias,
        )

    matches = response == targets

    if split_partitions is not None:
        metric_list = np.full((len(split_partitions),), np.nan)
        idx = 0
        for i, partition in enumerate(split_partitions):
            # each partition is the next `partition` columns after the previous one
            metric_list[i] = jnp.mean(matches[:, idx:idx + partition])
            idx += partition
        return metric_list

    return model_outputs, response, jnp.mean(matches)

def compute_singular_values(params):
    """
    Return the singular values associated with the given parameter matrices.

    Args:
        params (Sequence[np.ndarray]): The weight matrices. When the sequence has
            exactly two entries, the matrix ``params[1].T @ params[0].T`` is
            decomposed; for any other length only ``params[0]`` is decomposed.

    Returns:
        np.ndarray: The singular values as returned by ``np.linalg.svd``
        (reduced decomposition, ``full_matrices=False``).
    """
    if len(params) == 2:
        first, second = params[0], params[1]
        matrix = second.T @ first.T
    else:
        matrix = params[0]

    # svd returns (U, S, Vh); only the singular values S are of interest
    return np.linalg.svd(matrix, full_matrices=False)[1]


def true_positive_rate(params, inputs, targets, response=None,
                       temprature=1.0, split_partitions=(2,4,8), probabilistic=True,
                       model_outputs=None, discreet_choices=True):
    """
    Calculates the true positive rate (TPR): the summed agreement with the
    positive targets divided by the summed targets.

    Args:
        params (array-like): The parameters of the model.
        inputs (array-like): The input data.
        targets (array-like): The target labels.
        response (array-like, optional): Precomputed model choices. Only used when
            ``discreet_choices`` is True; computed via ``compute_choices`` if None.
        temprature (float, optional): Forwarded to ``compute_choices``. Defaults to 1.0.
        split_partitions (sequence of int, optional): Widths of consecutive column
            blocks for which a separate TPR is reported. Pass None for a single
            overall value. Defaults to (2, 4, 8).
        probabilistic (bool, optional): Forwarded to ``compute_choices``. Defaults to True.
        model_outputs (array-like, optional): Precomputed model outputs. Only used
            when ``discreet_choices`` is False; computed via ``batched_forward`` if None.
        discreet_choices (bool, optional): If True the TPR is based on ``response``,
            otherwise on the raw ``model_outputs``. Defaults to True.

    Returns:
        np.ndarray or scalar: One TPR per partition when ``split_partitions`` is
        given, otherwise a single overall TPR.
    """
    # element-wise agreement with the positive targets
    if discreet_choices:
        if response is None:
            _, response = compute_choices(params, inputs, targets,
                                          temprature=temprature, probabilistic=probabilistic)
        hits = response * targets
    else:
        if model_outputs is None:
            model_outputs = batched_forward(params, inputs)
        hits = np.multiply(model_outputs, targets)

    if split_partitions is None:
        return jnp.sum(hits) / jnp.sum(targets)

    rates = np.full((len(split_partitions),), np.nan)
    start = 0
    for level, width in enumerate(split_partitions):
        stop = start + width
        rates[level] = jnp.sum(hits[:, start:stop]) / jnp.sum(targets[:, start:stop])
        start = stop
    return rates


def true_negative_rate(params, inputs, targets, response=None,
                       temprature=1.0, split_partitions=(2,4,8), probabilistic=True,
                       model_outputs=None, discreet_choices=True):
    """
    Calculates the true negative rate (TNR): the summed agreement with the
    negative targets (``1 - targets``) divided by the summed negative targets.

    Args:
        params (array-like): The parameters of the model.
        inputs (array-like): The input data.
        targets (array-like): The target labels.
        response (array-like, optional): Precomputed model choices. Only used when
            ``discreet_choices`` is True; computed via ``compute_choices`` if None.
        temprature (float, optional): Forwarded to ``compute_choices``. Defaults to 1.0.
        split_partitions (sequence of int, optional): Widths of consecutive column
            blocks for which a separate TNR is reported. Pass None for a single
            overall value. Defaults to (2, 4, 8).
        probabilistic (bool, optional): Forwarded to ``compute_choices``. Defaults to True.
        model_outputs (array-like, optional): Precomputed model outputs. Only used
            when ``discreet_choices`` is False; computed via ``batched_forward`` if None.
        discreet_choices (bool, optional): If True the TNR is based on ``response``,
            otherwise on the raw ``model_outputs``. Defaults to True.

    Returns:
        np.ndarray or scalar: One TNR per partition when ``split_partitions`` is
        given, otherwise a single overall TNR.
    """
    # element-wise agreement with the negative targets
    if discreet_choices:
        if response is None:
            _, response = compute_choices(params, inputs, targets,
                                          temprature=temprature, probabilistic=probabilistic)
        correct_rejections = (1 - response) * (1 - targets)
    else:
        if model_outputs is None:
            model_outputs = batched_forward(params, inputs)
        correct_rejections = np.multiply(1 - model_outputs, 1 - targets)

    if split_partitions is None:
        return jnp.sum(correct_rejections) / jnp.sum(1 - targets)

    rates = np.full((len(split_partitions),), np.nan)
    start = 0
    for level, width in enumerate(split_partitions):
        stop = start + width
        negatives = 1 - targets[:, start:stop]
        rates[level] = jnp.sum(correct_rejections[:, start:stop]) / jnp.sum(negatives)
        start = stop
    return rates

def level_bias(params, inputs, targets, response=None,
                temprature=1.0, split_partitions=(2,4,8),
                probabilistic=True, adjust:bool=False, expected_bias:tuple=(.42, .857, 1.71)):
    """
    Calculates the level bias metric: for every partition of the response
    columns, the number of responses given in that partition divided by the
    number of rows in ``targets``.

    Args:
        params (Any): The parameters of the neural network model.
        inputs (Any): The input data for the neural network model.
        targets (Any): The target data; only its first dimension (row count) and,
            when ``response`` is None, the call to ``compute_choices`` use it.
        response (Optional[Any]): Precomputed model choices. Computed via
            ``compute_choices`` if None. Defaults to None.
        temprature (float): Forwarded to ``compute_choices``. Defaults to 1.0.
        split_partitions (Tuple[int]): Widths of the consecutive column blocks
            (levels). Defaults to (2, 4, 8).
        probabilistic (bool): Forwarded to ``compute_choices``. Defaults to True.
        adjust (bool): Whether to rescale each level as
            ``1 - |value - 1| / |expected - 1|`` using the matching entry of
            ``expected_bias``. Defaults to False.
        expected_bias (tuple): One expected bias per level, paired with
            ``split_partitions`` by position. If the two differ in length, only
            the levels that have a counterpart are adjusted.
            Defaults to (.42, .857, 1.71).

    Returns:
        np.ndarray: The level bias metric for each partition.
    """
    if response is None:
        _, response = compute_choices(params, inputs, targets, temprature=temprature, probabilistic=probabilistic)

    n_trials = targets.shape[0]
    metric_list = np.full((len(split_partitions),), np.nan)
    idx = 0
    for i, partition in enumerate(split_partitions):
        metric_list[i] = jnp.sum(response[:, idx:idx + partition]) / n_trials
        idx += partition

    # here we adjust by fraction of "learnable bias"
    if adjust:
        # Pair levels and expectations by position. zip stops at the shorter
        # sequence, so a shorter `split_partitions` no longer indexes past the
        # end of `metric_list`, and the parameter is no longer shadowed.
        for i, expected in zip(range(len(metric_list)), expected_bias):
            metric_list[i] = 1 - np.absolute(metric_list[i] - 1) / abs(expected - 1)
    return metric_list

def _sibling_pairs(response, targets, start, width):
    """
    Collect the sibling pairs of one level in which exactly one sibling is the target.

    The level occupies columns ``start:start + width``; adjacent columns
    (0, 1), (2, 3), ... within it are treated as sibling pairs. Only pairs whose
    target is ``[1, 0]`` or ``[0, 1]`` are kept.

    Args:
        response: 2-D array of model choices.
        targets: 2-D array of target labels, same layout as ``response``.
        start (int): First column of the level.
        width (int): Number of columns of the level.

    Returns:
        tuple: ``(pair_response, pair_targets)``, each of shape ``(n_pairs, 2)``,
        with rows aligned between the two arrays.
    """
    rel_response = response[:, start:start + width]
    rel_targets = targets[:, start:start + width]

    # target patterns for which a response on exactly one sibling is expected
    desired_entries = [np.array([1, 0]), np.array([0, 1])]

    # mask[r, p] is True when pair p of row r matches one of the desired patterns
    mask = np.column_stack(
        [np.any([np.all(rel_targets[:, col:col + 2] == entry, axis=1) for entry in desired_entries], axis=0)
         for col in range(0, rel_targets.shape[1], 2)]
    )
    # one mask entry per pair -> one per column, flattened to index the flat data
    expanded_mask = np.repeat(mask, 2, axis=1).flatten()

    pair_targets = rel_targets.flatten()[expanded_mask].reshape(-1, 2)
    pair_response = rel_response.flatten()[expanded_mask].reshape(-1, 2)
    return pair_response, pair_targets


def sibling_metric(params, inputs, targets, response=None,
                   temprature=1.0, split_partitions=(2,4,8),
                   probabilistic=True):
    """
    Computes the sibling metric for a given model.

    For every level, the sibling pairs with exactly one target are selected,
    pairs in which the response chose both siblings or neither are discarded,
    and the fraction of matching entries among the remaining pairs is reported.

    Args:
        params: The parameters of the model.
        inputs: The input data.
        targets: The target labels.
        response (array, optional): Precomputed model choices. Computed via
            ``compute_choices`` if None. Defaults to None.
        temprature (float, optional): Forwarded to ``compute_choices``. Defaults to 1.0.
        split_partitions (tuple, optional): Column widths of the levels. Defaults to (2, 4, 8).
        probabilistic (bool, optional): Forwarded to ``compute_choices``. Defaults to True.

    Returns:
        array: One sibling metric value per partition (0 when no pair remains).
    """
    if response is None:
        _, response = compute_choices(params, inputs, targets, temprature=temprature, probabilistic=probabilistic)

    metric_list = np.full((len(split_partitions),), np.nan)

    idx = 0
    for i, partition in enumerate(split_partitions):
        pair_response, pair_targets = _sibling_pairs(response, targets, idx, partition)

        # drop pairs where both siblings or no sibling were chosen
        rows_with_too_many = np.where(np.sum(pair_response, axis=1) > 1)[0]
        rows_with_no_ones = np.where((pair_response == 0).all(axis=1))[0]
        rows_to_drop = np.concatenate((rows_with_too_many, rows_with_no_ones))

        pair_targets = np.delete(pair_targets, rows_to_drop, axis=0)
        pair_response = np.delete(pair_response, rows_to_drop, axis=0)
        metric_list[i] = safe_mean(pair_response, pair_targets)
        idx += partition

    return metric_list


def chris_sibling_metric(params, inputs, targets, response=None,
                   temprature=1.0, split_partitions=(2,4,8),
                   probabilistic=True):
    """
    Computes a signed sibling score per level.

    For every sibling pair with exactly one target, the pair scores 1 when both
    response entries match the target, 0 when exactly one matches and -1 when
    none matches. The per-level value is the mean of these scores.

    Args:
        params: The parameters of the model.
        inputs: The input data.
        targets: The target labels.
        response (array, optional): Precomputed model choices. Computed via
            ``compute_choices`` if None. Defaults to None.
        temprature (float, optional): Forwarded to ``compute_choices``. Defaults to 1.0.
        split_partitions (tuple, optional): Column widths of the levels. Defaults to (2, 4, 8).
        probabilistic (bool, optional): Forwarded to ``compute_choices``. Defaults to True.

    Returns:
        array: One score per partition. A level without any qualifying pair
        yields NaN (mean of an empty array).
    """
    if response is None:
        _, response = compute_choices(params, inputs, targets, temprature=temprature, probabilistic=probabilistic)

    metric_list = np.full((len(split_partitions),), np.nan)

    idx = 0
    for i, partition in enumerate(split_partitions):
        pair_response, pair_targets = _sibling_pairs(response, targets, idx, partition)

        # per pair: 1.0 (both entries match), 0.5 (one matches) or 0.0 (none)
        metric = np.mean(pair_response == pair_targets, axis=1)
        # replace entries ala Chris; order matters: 0 -> -1 must precede 0.5 -> 0
        metric[metric == 0] = -1
        metric[metric == 0.5] = 0
        metric_list[i] = np.mean(metric)
        idx += partition

    return metric_list


def sibling_bias(params, inputs, targets, response=None,
                   temprature=1.0, split_partitions=(2,4,8),
                   probabilistic=True):
    """
    Computes, per level, how often exactly one entry of a sibling pair matches.

    For every sibling pair with exactly one target, pairs whose response is all
    zero are discarded. A remaining pair scores 1 when exactly one of its two
    response entries matches the target and 0 otherwise; the per-level value is
    the mean of these scores.

    Args:
        params: The parameters of the model.
        inputs: The input data.
        targets: The target labels.
        response (array, optional): Precomputed model choices. Computed via
            ``compute_choices`` if None. Defaults to None.
        temprature (float, optional): Forwarded to ``compute_choices``. Defaults to 1.0.
        split_partitions (tuple, optional): Column widths of the levels. Defaults to (2, 4, 8).
        probabilistic (bool, optional): Forwarded to ``compute_choices``. Defaults to True.

    Returns:
        array: One value per partition (0 when no pair remains).
    """
    if response is None:
        _, response = compute_choices(params, inputs, targets, temprature=temprature, probabilistic=probabilistic)

    metric_list = np.full((len(split_partitions),), np.nan)

    idx = 0
    for i, partition in enumerate(split_partitions):
        pair_response, pair_targets = _sibling_pairs(response, targets, idx, partition)

        # drop pairs in which no sibling was chosen
        rows_with_no_ones = np.where((pair_response == 0).all(axis=1))[0]
        pair_targets = np.delete(pair_targets, rows_with_no_ones, axis=0)
        pair_response = np.delete(pair_response, rows_with_no_ones, axis=0)

        # per pair: 1.0 (both entries match), 0.5 (one matches) or 0.0 (none)
        metric = np.mean(pair_response == pair_targets, axis=1)
        # replace entries; order matters: 1 -> 0 must precede 0.5 -> 1
        metric[metric == 1] = 0
        metric[metric == 0.5] = 1
        # an empty selection counts as 0 instead of NaN
        metric_list[i] = 0 if metric.size == 0 else np.mean(metric)
        idx += partition

    return metric_list


def safe_mean(a, b):
    """
    Return the fraction of element-wise equal entries of two arrays.

    Args:
        a: The first array.
        b: The second array.

    Returns:
        ``np.mean(a == b)``, or 0 if both arrays are empty.
    """
    return 0 if a.size == 0 and b.size == 0 else np.mean(a == b)

def compute_choices(params, inputs, targets, temprature=1.0, probabilistic=True, epsilon=1e-10, bias=False):
    """
    Computes the choices made by a model based on the model forward pass values.
    We take a sampling approach where we compress the continuous outputs of the model in a
    probability mass function using softmax and then sample from this distribution
    to examine model choices. Note that this adds another hyper parameter,
    namely the softmax temperature.

    Args:
        params (array-like): The parameters of the model.
        inputs (array-like): The input data.
        targets (array-like): The target labels (2-D). In probabilistic mode the
            sum of each row determines how many options are sampled for that row.
        temprature (float, optional): The softmax temperature; the model outputs
            are divided by it before the softmax. Defaults to 1.0.
        probabilistic (bool, optional): If True, choices are sampled without
            replacement from the softmax distribution. If False, the single
            arg-max option of every row is chosen. Defaults to True.
        epsilon (float, optional): Constant added to all probabilities when some
            row has fewer non-zero probabilities than choices to draw. Defaults to 1e-10.
        bias (bool, optional): Forwarded to ``batched_forward``. Defaults to False.

    Returns:
        prediction(array-like): The raw model outputs.
        response(array-like): The model choices, an array shaped like ``targets``
            with 1 at the chosen options and 0 elsewhere.
    """
    prediction = batched_forward(params, inputs, bias)
    n_trials, n_options = targets.shape[0], targets.shape[1]

    if probabilistic:
        prob_distributions = softmax(prediction / temprature, axis=1)
        # number of options to draw for each trial
        n_choices = np.asarray(targets).sum(axis=1).astype(int)

        # Sampling without replacement needs at least as many non-zero
        # probabilities as draws. Check every row against its own draw count
        # (not only against the first row's) and add a small constant if not.
        if np.any(np.sum(prob_distributions != 0, axis=1) < n_choices):
            print("WARNING: not enough non-zero entries in prob_distributions adding small constant")
            prob_distributions = prob_distributions + epsilon
            # renormalise so every row is a valid distribution again;
            # np.random.choice rejects p that does not sum to 1
            prob_distributions = prob_distributions / prob_distributions.sum(axis=1, keepdims=True)

        # get the choice indices for each trial
        choice_indices = [np.random.choice(n_options,
                                           n_choices[i],
                                           replace=False,
                                           p=prob_distributions[i])
                          for i in range(n_trials)]
    else:
        # NOTE(review): the deterministic path selects exactly one option per
        # trial, independent of how many targets the trial has - confirm this
        # is intended for multi-label targets.
        choice_indices = np.argmax(prediction, axis=1)

    response = np.zeros(targets.shape)
    # set the choice indices to 1
    for i in range(n_trials):
        response[i, choice_indices[i]] = 1
    return prediction, response