"""This is the entropy example presented in the paper."""



import numpy as np


np.set_printoptions(precision=4)

# Problem size: features are grouped contiguously, one group per class.
num_classes = 3
num_features_per_class = 2
num_features = num_classes * num_features_per_class


# X enumerates all 2**num_features binary vectors: entry (r, j) is bit j of r,
# so column j alternates runs of 2**j zeros and 2**j ones.
X = np.empty((2**num_features, num_features))
for j in range(num_features):
  X[:, j] = (np.arange(2**num_features) >> j) & 1

# y: per-row class distribution, i.e. each class's share of the active
# features. The 1e-10 offset keeps the all-zero row from dividing by zero
# (that row ends up uniform).
X_temp = X.reshape(2**num_features, num_classes, num_features_per_class) + 1e-10
y = X_temp.sum(axis=-1) / X_temp.sum(axis=(1, 2))[:, np.newaxis]



# Condition on feature 1 (column 0) having been observed as zero.
ids = np.flatnonzero(X[:, 0] == 0.0)
X_reduced, y_reduced = X[ids], y[ids]

# Class distribution before acquiring any further feature.
curr_probs = y_reduced.mean(axis=0)
print("Current Distribution: ", curr_probs)



def _entropy(p):
  """Return the Shannon entropy (in nats) of probability vector ``p``.

  Zero entries are skipped, i.e. 0 * log(0) is treated as 0 instead of NaN.
  """
  p = np.asarray(p)
  p = p[p > 0]
  return float(-np.sum(p * np.log(p)))


def get_info_full(f, X_data=None, y_data=None):
  """Return the conditional mutual information between feature ``f`` and the class.

  Computes H(Y) - sum_x p(x) * H(Y | X_f = x), where p(x) is the fraction of
  rows of ``X_data`` whose column ``f`` equals ``x``, and each class
  distribution is the average of the matching rows of ``y_data``. Entropies
  use the natural log, so the result is in nats.

  Args:
    f: Column index of the candidate feature.
    X_data: 2-D array of feature rows. Defaults to the module-level
      ``X_reduced``.
    y_data: 2-D array of per-row class probabilities aligned with
      ``X_data``. Defaults to the module-level ``y_reduced``.

  Returns:
    The CMI as a float.
  """
  if X_data is None:
    X_data = X_reduced
  if y_data is None:
    y_data = y_reduced

  # Prior class distribution: mean of the per-row distributions (identical to
  # the module-level ``curr_probs`` when the defaults are used).
  curr_entropy = _entropy(np.mean(y_data, axis=0))

  ans = 0.0
  for x_val in np.unique(X_data[:, f]):
    mask = X_data[:, f] == x_val
    p_x = np.count_nonzero(mask) / len(X_data)
    next_ent = _entropy(np.mean(y_data[mask], axis=0))
    ans += p_x * (curr_entropy - next_ent)

  # Return only after the loop so every value of the feature contributes.
  return ans
  
# Print the CMI of every feature (1-based in the output), given the current
# observation; iterate over all features rather than a hard-coded count.
for i in range(num_features):
  print(f"Feature: {i+1}, CMI: {get_info_full(i):.4f}")
  
  
# Show the two possible class distributions after acquiring a feature.
# Feature 2 (index num_features_per_class - 1) is in the same class group as
# the observed feature 1; features 3-6 belong to the other classes and feature
# 3 (index num_features_per_class) stands in for them, per the message text.
for feature_idx, label in (
    (num_features_per_class - 1, "2"),
    (num_features_per_class, "3, 4, 5 or 6"),
):
  print(f"\nIf we acquire feature {label}, the two possible distributions are:")
  for x_val in (0.0, 1.0):
    ids = np.where(X_reduced[:, feature_idx] == x_val)[0]
    py = np.mean(y_reduced[ids], axis=0)
    print(py)