import numpy as np
from scipy.stats import entropy


'''
Contains functions for calculating costs and entropies for discrete data
Implements Lovasz extension functions and subgradient calculations
'''

# Module-level memoization caches. Every cache is keyed by (i, str(J)), i.e. the
# node index and the string form of the parent-index collection.
# NOTE: the key does not identify the dataset, so entries computed for one
# dataset are returned for any other dataset queried with the same (i, J).
entropy_H_dict = {}  # entropy_H(d, J, i) values, filled by h()
entropy_G_dict = {}  # entropy_G(d, J, i) values, filled by g()
cost_dict = {}  # conditional-entropy costs, filled by cost()

''' Define cost (-1/N score)'''
def cost(data, J, i):
    '''Return the memoized cost of node ``i`` with parent set ``J``.

    The cost is the empirical conditional entropy H(i | J) = H(i, J) - H(J),
    estimated from the rows of ``data.data`` with the natural logarithm
    (SciPy's default base). For an empty ``J`` it is simply H(i).

    Parameters:
        data: object whose ``data`` attribute is a 2-D array (rows = samples,
            columns = variables).
        J: collection of column indices of the parents. Unlike entropy_H and
            entropy_G, these index the columns of ``data.data`` directly.
        i: column index of the node.

    Returns:
        The conditional entropy as returned by ``scipy.stats.entropy``
        arithmetic.

    Side effects:
        Stores the result in the module-level ``cost_dict`` under the key
        ``(i, str(J))``. The key does not depend on ``data``, so a cached
        value is returned even if a different dataset is passed later.
    '''
    key = (i, str(J))
    try:
        return cost_dict[key]
    except KeyError:
        pass  # not cached yet; compute below

    def joint_entropy(columns):
        # Entropy of the empirical distribution over the distinct rows.
        _, counts = np.unique(columns, axis=0, return_counts=True)
        return entropy(counts / counts.sum())

    samples = data.data
    node_column = samples[:, [i]]

    if len(J) == 0:
        result = joint_entropy(node_column)
    else:
        parent_columns = samples[:, J]
        node_and_parents = np.concatenate((node_column, parent_columns), axis=1)
        parents_entropy = joint_entropy(parent_columns)
        result = joint_entropy(node_and_parents) - parents_entropy

    cost_dict[key] = result
    return result

''' Define entropy functions'''
def entropy_H(d, J, i):
    '''Return the empirical joint entropy of the parent columns ``J``.

    The indices in ``J`` refer to ``d.data`` with column ``i`` removed, so
    they run from 0 to n-2. An empty ``J`` yields 0.

    Parameters:
        d: object whose ``data`` attribute is a 2-D array (rows = samples).
        J: collection of column indices into the data without column ``i``.
        i: index of the column to drop before selecting ``J``.

    Returns:
        0 for an empty ``J``; otherwise the natural-log entropy of the
        empirical distribution over the distinct rows of the selected columns.
    '''
    without_node = np.delete(d.data, i, 1)  # drop column i so J lines up
    if len(J) == 0:
        return 0
    _, counts = np.unique(without_node[:, J], axis=0, return_counts=True)
    return entropy(counts / counts.sum())
    
            
def entropy_G(d, J, i):
    '''Return the empirical joint entropy of node ``i`` together with ``J``.

    The indices in ``J`` refer to ``d.data`` with column ``i`` removed, so
    they run from 0 to n-2. With an empty ``J`` only column ``i`` is used.

    Parameters:
        d: object whose ``data`` attribute is a 2-D array (rows = samples).
        J: collection of column indices into the data without column ``i``.
        i: column index of the node.

    Returns:
        The natural-log entropy of the empirical distribution over the
        distinct rows of column ``i`` stacked with the selected columns.
    '''
    samples = d.data
    node_column = samples[:, [i]]
    without_node = np.delete(samples, i, 1)  # drop column i so J lines up
    stacked = (
        node_column
        if len(J) == 0
        else np.concatenate((node_column, without_node[:, J]), axis=1)
    )
    _, counts = np.unique(stacked, axis=0, return_counts=True)
    return entropy(counts / counts.sum())

# Evaluate the Lovasz extension functions and calculate their subgradients

def g(d, x, i, C_set, lambda_c, regu_Lambda):
    '''Evaluate the Lovasz extension of the first submodular function g.

    For a parent set J (indexed without node i, i.e. 0..n-2) the set function
    is entropy_G(d, J, i) plus, for every C in ``C_set`` that contains ``i``
    and shares at least one node with J, the term ``lambda_c[c] / d.ndata``.
    The extension is evaluated by ordering the components of ``x`` in
    decreasing order and taking marginal gains along that chain of sets.

    Parameters:
        d: object providing ``n`` (number of nodes), ``ndata`` and whatever
            entropy_G reads.
        x: NumPy vector with one component per candidate parent (n-1 entries).
        i: index of the node.
        C_set: sequence of node collections, expressed in original node
            indices (column i included).
        lambda_c: weights indexed like ``C_set``.
        regu_Lambda: not used by this function.

    Returns:
        (value, s): the extension value at ``x`` and the subgradient ``s``.

    Side effects:
        Fills the module-level ``entropy_G_dict`` under keys ``(i, str(J))``.
    '''
    n = d.n
    ndata = d.ndata

    def cached_G(parents):
        # Memoized entropy_G lookup keyed by the string form of the parent list.
        key = (i, str(parents))
        if key not in entropy_G_dict:
            entropy_G_dict[key] = entropy_G(d, parents, i)
        return entropy_G_dict[key]

    def in_original(col, C):
        # Parent indices skip column i, so shift indices >= i back by one
        # before testing membership in C.
        return col in C if col < i else col + 1 in C

    order = np.argsort(-x)  # parent indices by decreasing x-component
    s = np.zeros(n - 1)
    # Only the sets containing node i ever contribute a penalty.
    sets_with_i = [(c_index, C) for c_index, C in enumerate(C_set) if i in C]

    # Marginal gain of the first (largest) component over the empty set.
    top = order[0]
    previous = cached_G([top])
    for c_index, C in sets_with_i:
        if in_original(top, C):
            previous += (lambda_c[c_index]) / ndata
    empty_value = cached_G([])
    s[top] = previous - empty_value

    # Marginal gains of the remaining components along the chain.
    for k in range(1, n - 1):
        prefix = sorted(order[np.arange(k + 1)])
        current = cached_G(prefix)
        for c_index, C in sets_with_i:
            # 1 when the prefix intersects C, else 0.
            touched = int(any(in_original(col, C) for col in prefix))
            current += touched * (lambda_c[c_index]) / ndata
        s[order[k]] = current - previous
        previous = current

    return np.dot(x, s) + empty_value, s  # function value and subgradient

def h(d, x, i):
    '''Evaluate the Lovasz extension of the submodular function h.

    For a parent set J (indexed without node i, i.e. 0..n-2) the set function
    is entropy_H(d, J, i) minus a complexity penalty
    ``regu_Lambda / ndata * (arity[i] - 1) * prod(arity of the parents)``,
    with ``regu_Lambda = 0.5 * log(ndata)`` (BIC score). The extension is
    evaluated by ordering the components of ``x`` in decreasing order and
    taking marginal gains along that chain of sets.

    Parameters:
        d: object providing ``n`` (number of nodes), ``ndata``, ``arity``
            (indexed by original node index) and whatever entropy_H reads.
        x: NumPy vector with one component per candidate parent (n-1 entries).
        i: index of the node.

    Returns:
        (value, y): the extension value at ``x`` and the subgradient ``y``.

    Side effects:
        Fills the module-level ``entropy_H_dict`` under keys ``(i, str(J))``.
    '''
    n = d.n
    ndata = d.ndata
    regu_Lambda = 0.5 * np.log(ndata)  # for BIC score

    def arity_of(col):
        # Parent indices skip column i, so indices >= i map to col + 1.
        return d.arity[col] if col < i else d.arity[col + 1]

    def cached_H(parents):
        # Memoized entropy_H lookup keyed by the string form of the parent list.
        key = (i, str(parents))
        if key not in entropy_H_dict:
            entropy_H_dict[key] = entropy_H(d, parents, i)
        return entropy_H_dict[key]

    order = np.argsort(-x)  # parent indices by decreasing x-component
    y = np.zeros(n - 1)
    # Penalty of the empty parent set; every other penalty is a multiple of it.
    empty_penalty = regu_Lambda / ndata * (d.arity[i] - 1)

    # Marginal gain of the first (largest) component over the empty set.
    top = order[0]
    penalty = empty_penalty * arity_of(top)
    previous = cached_H([top]) - penalty
    y[top] = previous + empty_penalty

    # Marginal gains of the remaining components along the chain.
    for k in range(1, n - 1):
        added = order[k]
        prefix = sorted(order[np.arange(k + 1)])
        penalty = arity_of(added) * penalty  # penalty after adding order[k]
        current = cached_H(prefix) - penalty
        y[added] = current - previous
        previous = current

    return np.dot(x, y) - empty_penalty, y  # function value and subgradient
