import numpy as np
from scipy.optimize import linprog
import itertools
from itertools import chain, combinations


def select_top_K_approximation_samples(lower_probs, upper_probs, K):
    """Collapse each sample's probability intervals onto its K top-ranked classes.

    Classes are ranked per sample by descending *lower* probability. The
    top-ranked classes keep their own [lower, upper] intervals. All other
    classes are merged into one extra trailing column whose interval is

        lower = max(sum of unselected lowers, 1 - sum of selected uppers)
        upper = min(sum of unselected uppers, 1 - sum of selected lowers)

    Args:
        lower_probs: array of shape (N, C) with per-class lower probabilities.
        upper_probs: array of shape (N, C) with per-class upper probabilities.
        K: number of classes to keep per sample (for non-negative K, at most
            C columns are actually kept).

    Returns:
        Tuple ``(lower, upper)`` of arrays with shape (N, min(K, C) + 1). The
        kept classes come first, in ranking order; the last column is the
        merged interval of the unselected classes.
    """
    row_idx = np.arange(lower_probs.shape[0])[:, None]

    # Per-sample ranking on the lower probabilities, largest first.
    top_idx = np.argsort(-lower_probs, axis=-1)[:, :K]

    # Float indicator masks over the class axis: 1.0 marks a kept class.
    keep = np.zeros(lower_probs.shape)
    keep[row_idx, top_idx] = 1.0
    drop = np.ones(lower_probs.shape) - keep

    # Interval of all dropped classes taken together, tightened by the
    # complementary bound implied by the kept classes.
    rest_lower = np.maximum(
        np.sum(lower_probs * drop, axis=-1),
        1 - np.sum(upper_probs * keep, axis=-1),
    )
    rest_upper = np.minimum(
        np.sum(upper_probs * drop, axis=-1),
        1 - np.sum(lower_probs * keep, axis=-1),
    )

    lower_out = np.concatenate(
        (lower_probs[row_idx, top_idx], rest_lower[:, None]), axis=-1
    )
    upper_out = np.concatenate(
        (upper_probs[row_idx, top_idx], rest_upper[:, None]), axis=-1
    )
    return lower_out, upper_out

def get_power_set(C):
    """Return every subset of ``range(C)`` that has at least two elements.

    Each subset is a list. Subsets are ordered by size, then in
    ``itertools.combinations`` order (lexicographic). The empty set and
    singletons are excluded, so the result is ``[]`` whenever C < 2.
    """
    return [
        list(subset)
        for size in range(2, C + 1)
        for subset in itertools.combinations(range(C), size)
    ]


def get_all_subsets(s):
    """Return all non-empty subsets of the sized sequence ``s`` as lists.

    Subsets are ordered by size and then in ``combinations`` order, so the
    elements inside each subset keep the order they have in ``s``.
    """
    subsets = []
    for size in range(1, len(s) + 1):
        subsets.extend(list(combo) for combo in combinations(s, size))
    return subsets


def compute_gh_measure(lower_probabilities, upper_probabilities):
    """Compute the generalized Hartley (GH) measure of per-sample probability intervals.

    Each row of the inputs gives lower and upper bounds on the probability of
    every class. These bounds induce a lower probability on non-empty class
    subsets A,

        v(A) = max(sum_{i in A} l_i, 1 - sum_{i not in A} u_i),

    whose Moebius mass on a subset B is

        m(B) = sum over non-empty A subset of B of (-1)**(|B| - |A|) * v(A)

    (the empty set contributes nothing because v(empty) = 0). The returned
    measure is GH = sum over subsets B with |B| >= 2 of m(B) * log2(|B|).

    Args:
        lower_probabilities: array-like of shape (N, C), per-class lower bounds.
        upper_probabilities: array-like of shape (N, C), per-class upper bounds.

    Returns:
        Float ndarray of shape (N,) holding the GH value of each sample. With
        fewer than two classes no subset qualifies and the result is all
        zeros. It is still an array, so the return type does not depend on C.

    Note:
        Work grows exponentially in C because every pair A subset of B is
        visited. v is cached for all 2**C - 1 non-empty subsets, which means
        that many arrays of length N are held in memory. For many classes, the
        inputs can first be reduced with
        select_top_K_approximation_samples(lower, upper, K).
    """
    lower_probs = np.asarray(lower_probabilities, dtype=float)
    upper_probs = np.asarray(upper_probabilities, dtype=float)
    num_sample, num_class = lower_probs.shape
    classes = tuple(range(num_class))

    # v(A) depends only on A, yet each A is a subset of many B below. Evaluate
    # it once per subset instead of once per (B, A) pair. Keys are ascending
    # tuples, which is exactly what combinations() yields for subsets of an
    # ascending B.
    lower_prob_of = {}
    for size in range(1, num_class + 1):
        for set_a in combinations(classes, size):
            complement = [i for i in classes if i not in set_a]
            lower_prob_of[set_a] = np.maximum(
                lower_probs[:, list(set_a)].sum(axis=-1),
                1.0 - upper_probs[:, complement].sum(axis=-1),
            )

    gh = np.zeros(num_sample)
    for size_b in range(2, num_class + 1):
        weight = np.log2(size_b)
        for set_b in combinations(classes, size_b):
            # Moebius mass m(B) via inclusion-exclusion over non-empty A in B.
            mass_b = np.zeros(num_sample)
            for size_a in range(1, size_b + 1):
                sign = -1.0 if (size_b - size_a) % 2 else 1.0
                for set_a in combinations(set_b, size_a):
                    mass_b = mass_b + sign * lower_prob_of[set_a]
            gh = gh + mass_b * weight
    return gh

