"""This script is to calculate the expected 
solutions for inference behaviour given the analytical 
outputs given some neural networks.
"""
from itertools import combinations, permutations
import numpy as np

def calculate_combination_probability(prob_vector, combination):
    """Return the probability of drawing ``combination`` without replacement.

    The items of ``combination`` are indices into ``prob_vector``. Every
    ordering of the combination is considered: within one ordering, each
    draw contributes ``prob_vector[i]`` divided by the probability mass
    that has not been drawn yet, and the per-ordering products are summed.

    Args:
        prob_vector: Sequence of non-negative weights, one per output.
        combination: Iterable of indices into ``prob_vector``.

    Returns:
        The summed probability over all orderings of ``combination``.
        An ordering that contains a zero-weight index contributes 0
        rather than dividing by a possibly exhausted remaining mass.
    """
    # The starting mass is the same for every ordering, so compute it once.
    full_mass = sum(prob_vector)
    total_prob = 0
    for perm in permutations(combination):
        prob = 1
        remaining_mass = full_mass
        for i in perm:
            weight = prob_vector[i]
            if weight == 0:
                # A zero-weight item can never be drawn; bailing out here
                # also avoids 0/0 once all remaining mass has been used up.
                prob = 0
                break
            prob *= weight / remaining_mass
            remaining_mass -= weight
        total_prob += prob
    return total_prob


def calculate_choicewise_TNR(combs, target):
    """Compute ``(1 - response) * (1 - target)`` for every combination.

    Each entry of ``combs`` is turned into an indicator row over
    ``len(target)`` positions (1 where the index is in the combination).
    The complement of that indicator is multiplied element-wise with the
    complement of ``target``.

    Args:
        combs: Sized iterable of index collections, one per combination.
        target: 2-D array; its first axis is indexed by the values in
            ``combs``.  # NOTE(review): presumably binary labels -- confirm

    Returns:
        Array of shape ``(target.shape[1], len(target), len(combs))``.
    """
    # Start from "nothing chosen" and knock out the chosen indices.
    not_chosen = np.ones((len(combs), len(target)))
    for row, comb in enumerate(combs):
        not_chosen[row, list(comb)] = 0

    # Repeat along a new trailing axis so it lines up with target's columns.
    tiled = np.stack([not_chosen for _ in range(target.shape[1])], axis=2)
    true_negatives = tiled * (1 - target)

    # Reverse the axis order: (combs, rows, cols) -> (cols, rows, combs).
    return true_negatives.transpose(2, 1, 0)


def true_positive_rate_partitions(expected_tpr_matrix, targets, split_partitions=(2, 4, 8)):
    """Compute one ratio per consecutive column partition.

    The second axis of both inputs is cut into consecutive slices whose
    widths are given by ``split_partitions``. For each slice the sum of
    ``expected_tpr_matrix`` is divided by the sum of ``1 - targets`` over
    the same slice.

    Args:
        expected_tpr_matrix: Array sliced as ``[:, start:stop]``.
        targets: Array sliced the same way.
            # NOTE(review): presumably binary labels -- confirm
        split_partitions: Widths of the consecutive slices.

    Returns:
        1-D float array with one ratio per entry of ``split_partitions``.
    """
    # Turn the widths into slice boundaries: [0, w0, w0 + w1, ...].
    edges = [0]
    for width in split_partitions:
        edges.append(edges[-1] + width)

    ratios = np.full((len(split_partitions),), np.nan)
    for slot, (start, stop) in enumerate(zip(edges[:-1], edges[1:])):
        numerator = np.sum(expected_tpr_matrix[:, start:stop])
        denominator = np.sum(1 - targets[:, start:stop])
        ratios[slot] = numerator / denominator
    return ratios