import jax
import optax
import numpy as np
from functools import partial
from scipy.special import softmax
from src.neural_nets.non_linear_nets.training_modules import cross_entropy_loss, bin_cross_entropy_loss
import jax.numpy as jnp
from jax import random

"""this file holds the metric modules for all non-linear flax models"""

@jax.jit
def compute_metrics(*, state, batch, key=None, mean_labels=None):
    """
    Compute loss, per-level TPR/TNR and OCS-distance metrics for one batch and
    merge them into ``state.metrics``.

    Args:
        state: Train state exposing ``apply_fn``, ``params`` and ``metrics``.
        batch: Indexable batch; ``batch[0]`` is the model input and
            ``batch[1]`` the label matrix (column ranges [0, 2), [2, 6) and
            [6, 14) are scored separately).
        key: Optional PRNG key forwarded to ``compute_choices``; when None,
            ``compute_choices`` falls back to a fixed key.
        mean_labels: Optimal-constant-solution target passed to
            ``distance_from_OCS``.

    Returns:
        tuple: ``(state, logits)`` -- the state with its metrics collection
        updated, and the raw model outputs for the batch.
    """
    inputs, labels = batch[0], batch[1]
    logits = state.apply_fn({'params': state.params}, inputs)
    # half of the squared error summed over the whole batch
    loss = 1/2 * optax.squared_error(predictions=logits, targets=labels).sum()

    # discretised predictions: 3 classes sampled per example
    responses = compute_choices(logits, labels, key=key, n_choices=3)

    # label column range scored for each level; the generated names match the
    # keyword arguments passed to the metrics collection below
    level_partitions = {'top': (0, 2), 'mid': (2, 6), 'bottom': (6, 14)}
    rates = {}
    for level, part in level_partitions.items():
        rates[f'tpr_{level}_level'] = TPR(responses=responses, labels=labels, split_partition=part)
        rates[f'tnr_{level}_level'] = TNR(responses=responses, labels=labels, split_partition=part)
        # "continuous" variants score the raw logits instead of sampled responses
        rates[f'tpr_{level}_level_cont'] = TPR(logits, labels, split_partition=part)
        rates[f'tnr_{level}_level_cont'] = TNR(logits, labels, split_partition=part)

    mean_distance_OCS = distance_from_OCS(logits, mean_labels)

    metric_updates = state.metrics.single_from_model_output(
        loss=loss, mean_distance_OCS=mean_distance_OCS, **rates)
    state = state.replace(metrics=state.metrics.merge(metric_updates))
    return state, logits

@partial(jax.jit, static_argnames=['loss_fn'])
def compute_metrics_imbalance(*, state, batch, key=None, mean_labels=None, loss_fn=None):
    """
    Compute loss, TPR/TNR and OCS-distance metrics for a binary (two-column)
    labelling task and merge them into ``state.metrics``.

    Args:
        state: Train state exposing ``apply_fn``, ``params`` and ``metrics``.
        batch: Indexable batch; ``batch[0]`` is the model input and
            ``batch[1]`` the label matrix (columns 0 and 1 are scored).
        key: Optional PRNG key forwarded to ``compute_choices`` (only used
            when ``loss_fn`` is not ``cross_entropy_loss``).
        mean_labels: Optimal-constant-solution target passed to
            ``distance_from_OCS``.
        loss_fn: Static callable invoked as ``loss_fn(params, batch, state)``.
            If its ``__name__`` matches ``cross_entropy_loss`` the predicted
            class is the argmax of the model output and the OCS distance is
            measured on softmax probabilities; otherwise ``compute_choices``
            samples one class per row and the distance uses the raw outputs.

    Returns:
        The state with its metrics collection updated.

    Raises:
        ValueError: if ``loss_fn`` is None.
    """
    if loss_fn is None:
        # loss_fn is static, so this check runs once at trace time
        raise ValueError("compute_metrics_imbalance requires a loss_fn")

    labels = batch[1]
    logits = state.apply_fn({'params': state.params}, batch[0])
    loss = loss_fn(state.params, batch, state)

    # loss functions are matched by __name__, not identity
    is_cross_entropy = loss_fn.__name__ == cross_entropy_loss.__name__

    if is_cross_entropy:
        # softmax is monotonic, so argmax over the logits selects the same
        # class without exponentiating (unshifted exp overflows to nan)
        response_idx = jnp.argmax(logits, axis=1)
        responses = jax.nn.one_hot(response_idx, labels.shape[1])
    else:
        responses = compute_choices(logits, labels, key=key, n_choices=1)

    # binary imbalanced classes: each class column is its own partition and
    # the unused "bottom level" slots are reported as 0
    tpr_l1 = TPR(responses=responses, labels=labels, split_partition=(0, 1))
    tpr_l2 = TPR(responses=responses, labels=labels, split_partition=(1, 2))
    tpr_l3 = 0
    tnr_l1 = TNR(responses=responses, labels=labels, split_partition=(0, 1))
    tnr_l2 = TNR(responses=responses, labels=labels, split_partition=(1, 2))
    tnr_l3 = 0

    # continuous metrics: score the raw logits instead of discrete responses
    tpr_l1_cont = TPR(logits, labels, split_partition=(0, 1))
    tpr_l2_cont = TPR(logits, labels, split_partition=(1, 2))
    tpr_l3_cont = 0
    tnr_l1_cont = TNR(logits, labels, split_partition=(0, 1))
    tnr_l2_cont = TNR(logits, labels, split_partition=(1, 2))
    tnr_l3_cont = 0

    # distance from the optimal constant solution
    if is_cross_entropy:
        # jax.nn.softmax is max-shifted, i.e. numerically stable
        probs = jax.nn.softmax(logits, axis=1)
        mean_distance_OCS = distance_from_OCS(probs, mean_labels)
    else:
        mean_distance_OCS = distance_from_OCS(logits, mean_labels)

    metric_updates = state.metrics.single_from_model_output(loss=loss,
                                                            tpr_top_level=tpr_l1,
                                                            tpr_mid_level=tpr_l2,
                                                            tpr_bottom_level=tpr_l3,
                                                            tnr_top_level=tnr_l1,
                                                            tnr_mid_level=tnr_l2,
                                                            tnr_bottom_level=tnr_l3,
                                                            tpr_top_level_cont=tpr_l1_cont,
                                                            tpr_mid_level_cont=tpr_l2_cont,
                                                            tpr_bottom_level_cont=tpr_l3_cont,
                                                            tnr_top_level_cont=tnr_l1_cont,
                                                            tnr_mid_level_cont=tnr_l2_cont,
                                                            tnr_bottom_level_cont=tnr_l3_cont,
                                                            mean_distance_OCS=mean_distance_OCS)

    metrics = state.metrics.merge(metric_updates)
    state = state.replace(metrics=metrics)
    return state

@partial(jax.jit, static_argnames=['loss_fn'])
def compute_metrics_celeba(*, state, batch, key=None, mean_labels=None, loss_fn=None):
    """
    Evaluate ``loss_fn`` and the mean L1 distance to the optimal constant
    solution for one batch, then merge both into ``state.metrics``.

    Args:
        state: Train state exposing ``apply_fn``, ``params`` and ``metrics``.
        batch: Indexable batch; ``batch[0]`` is fed to the model.
        key: Not used by this function.
        mean_labels: Optimal-constant-solution target passed to
            ``distance_from_OCS``.
        loss_fn: Static callable invoked as ``loss_fn(params, batch, state)``.
            When its ``__name__`` matches ``bin_cross_entropy_loss`` the model
            outputs go through a sigmoid before the distance is measured.

    Returns:
        The state with its metrics collection updated.
    """
    model_out = state.apply_fn({'params': state.params}, batch[0])
    batch_loss = loss_fn(state.params, batch, state)

    # binary cross-entropy models emit logits: compare probabilities instead
    is_bce = loss_fn.__name__ == bin_cross_entropy_loss.__name__
    predictions = jax.nn.sigmoid(model_out) if is_bce else model_out

    updates = state.metrics.single_from_model_output(
        loss=batch_loss,
        mean_distance_OCS=distance_from_OCS(predictions, mean_labels),
    )
    return state.replace(metrics=state.metrics.merge(updates))


def distance_from_OCS(logits: jnp.array, mean_labels: jnp.array):
    """Return the mean per-row L1 (Manhattan) distance to the optimal constant solution.

    ``mean_labels`` (the OCS) is broadcast against ``logits``; absolute
    differences are summed along axis 1 and the result is averaged.
    """
    per_row_l1 = jnp.sum(jnp.abs(logits - mean_labels), axis=1)
    return jnp.mean(per_row_l1)


def TPR(responses: jnp.array,labels: jnp.array, split_partition=(0,2)):
    """
    Return the True Positive Rate over a column range of ``labels``.

    The rate is ``sum(responses * labels) / sum(labels)``, both restricted to
    columns ``split_partition[0]:split_partition[1]``. Responses may be
    discrete (0/1) or continuous scores.

    Args:
        responses (jnp.array): Predicted values, multiplied elementwise with labels.
        labels (jnp.array): Ground-truth label matrix, sliced along axis 1.
        split_partition (tuple, optional): ``(start, stop)`` column slice. Defaults to (0, 2).

    Returns:
        float: The True Positive Rate. If the slice contains no positive
        labels the division is 0/0 and yields nan.
    """
    lo = split_partition[0]
    hi = split_partition[1]
    hits = responses * labels
    return jnp.sum(hits[:, lo:hi]) / jnp.sum(labels[:, lo:hi])

def TNR(responses: jnp.array,labels: jnp.array, split_partition=(0,2)):
    """
    Return the True Negative Rate over a column range of ``labels``.

    The rate is ``sum((1 - responses) * (1 - labels)) / sum(1 - labels)``,
    both restricted to columns ``split_partition[0]:split_partition[1]``.

    Args:
        responses (jnp.array): Predicted values; ``1 - responses`` marks negatives.
        labels (jnp.array): Ground-truth label matrix, sliced along axis 1.
        split_partition (tuple, optional): ``(start, stop)`` column slice. Defaults to (0, 2).

    Returns:
        float: The True Negative Rate. If the slice contains no negative
        labels the division is 0/0 and yields nan.
    """
    lo = split_partition[0]
    hi = split_partition[1]
    correct_negatives = (1-responses) * (1-labels)
    numerator = jnp.sum(correct_negatives[:, lo:hi])
    denominator = jnp.sum(1 - labels[:, lo:hi])
    return numerator / denominator

def compute_choices(inputs, targets, temprature=0.2, epsilon=1e-10, key=None, n_choices=3):
    """
    Sample ``n_choices`` distinct classes per row from a temperature softmax
    over ``inputs`` and return them as a multi-hot matrix.

    Args:
        inputs (ndarray): Scores of shape (rows, classes); divided by
            ``temprature`` and passed through a max-shifted softmax.
        targets (ndarray): Only its shape is used: ``targets.shape[0]`` rows
            are sampled over ``targets.shape[1]`` classes, and the output has
            ``targets.shape``.
        temprature (float, optional): Softmax temperature (misspelled name
            kept for backward compatibility). Defaults to 0.2.
        epsilon (float, optional): Probability given to classes whose softmax
            probability is exactly 0. Defaults to 1e-10.
        key (jax.Array, optional): PRNG key. When None a fixed
            ``random.PRNGKey(1000)`` is used, so results are deterministic.
        n_choices (int, optional): Number of distinct classes drawn per row
            (sampling without replacement). Defaults to 3.

    Returns:
        ndarray: Array of ``targets.shape`` with 1 at each sampled class and 0
        elsewhere.
    """
    if key is None:
        key = random.PRNGKey(1000)  # fixed fallback; pass a key for fresh samples
    n_rows, n_classes = targets.shape[0], targets.shape[1]
    # one independent key per row
    keys = random.split(key, n_rows)

    logits = inputs / temprature
    # subtract the row max before exponentiating for numerical stability
    logits = logits - logits.max(axis=1, keepdims=True)
    exp_logits = jnp.exp(logits)
    probs = exp_logits / exp_logits.sum(axis=1, keepdims=True)
    # give underflowed classes a tiny non-zero probability
    probs = jnp.where(probs == 0, epsilon, probs)

    def _sample_row(row_key, row_probs):
        """Draw n_choices distinct class indices for a single row."""
        return random.choice(row_key, n_classes, (n_choices,), replace=False, p=row_probs)

    # vmap instead of a Python loop: one batched op rather than one traced
    # op per row (the loop is unrolled under jit, growing compile time)
    choice_indices = jax.vmap(_sample_row)(keys, probs)

    response = jnp.zeros(targets.shape)
    rows = jnp.arange(n_rows)[:, None]  # column vector of row indices
    return response.at[rows, choice_indices].set(1)
