"""Information-theoretic quantities as linear combinations of base entropies.

Every quantity defined in this module is a length-5 coefficient vector over
the following basis, in this order:

    H(Y), H(Yhat), H(Y|X), H(Yhat|X), H(Y,Yhat)

A quantity can therefore be evaluated by taking its dot product with a
vector holding those five entropies.

Naming convention used in identifiers:
    * ``Z`` stands for Yhat.
    * A single ``_`` separates the variables of a joint quantity
      (``H_Y_Z`` is H(Y,Yhat)).
    * ``__`` means "conditioned on" (``H_Y__X`` is H(Y|X)).
"""
import torch

# One unit vector per base entropy, in the order listed in the module docstring.
H_Y, H_Z, H_Y__X, H_Z__X, H_Y_Z = torch.unbind(torch.eye(5))

# Conditional entropies via the chain rule H(A|B) = H(A,B) - H(B),
# where Z denotes Yhat.
H_Y__Z, H_Z__Y = H_Y_Z - H_Z, H_Y_Z - H_Y

# Mutual informations that need no extra assumption:
#   I(Y;Z) = H(Y) + H(Z) - H(Y,Z)
#   I(X;A) = H(A) - H(A|X)
MI_Y_Z = H_Y + H_Z - H_Y_Z
MI_X_Y, MI_X_Z = H_Y - H_Y__X, H_Z - H_Z__X

# The identities below are exact only when Y and Z are conditionally
# independent given X, i.e. I(Y;Z|X) = 0 (this holds, e.g., when Z is
# produced from X alone). Under that assumption H(Y|X,Z) = H(Y|X) and
# H(Z|X,Y) = H(Z|X), so:
#   I(X;Y;Z) = I(Y;Z) - I(Y;Z|X) = I(Y;Z)
#   I(X;Y|Z) = H(Y|Z) - H(Y|X,Z) = H(Y|Z) - H(Y|X)
#   I(X;Z|Y) = H(Z|Y) - H(Z|X,Y) = H(Z|Y) - H(Z|X)
MI_X_Y_Z = MI_Y_Z
MI_X_Y__Z, MI_X_Z__Y = H_Y__Z - H_Y__X, H_Z__Y - H_Z__X

# H(X) cannot be computed here because no data distribution is available
# (it could perhaps be estimated as ln(batch_size) or ln(dataset_size)).
# Quantities involving H(X) are therefore expressed relative to H(X), i.e.
# with the H(X) term dropped:
#   H(X|Y,Z) = H(X) + relative_H_X__Y_Z
# Derivation (assumes I(Y;Z|X) = 0, so H(Z|X,Y) = H(Z|X)):
#   H(X|Y,Z) = H(X,Y,Z) - H(Y,Z) = H(X) + H(Y|X) + H(Z|X) - H(Y,Z)
relative_H_X__Y_Z = -(H_Y_Z - H_Y__X - H_Z__X)
# Original lower-case spelling, kept as an alias for backward compatibility.
relative_H_x__Y_Z = relative_H_X__Y_Z

# Variation of information between Y and Z (Yhat): H(Y|Z) + H(Z|Y).
entropy_distance = H_Y__Z + H_Z__Y

# H(Y|Z) and H(Z|Y).
# ``conditional_entropy`` duplicates ``decoder_uncertainty``.
# TODO: remove that name; it is kept for now so existing imports keep working.
conditional_entropy = decoder_uncertainty = H_Y__Z
reverse_decoder_uncertainty = H_Z__Y

# I(X;Z) and I(X;Y).
preserved_information, relevant_information = MI_X_Z, MI_X_Y

# I(Y;Z).
preserved_relevant_information = MI_Y_Z

# I(X;Y|Z) and I(X;Z|Y).
residual_information, redundant_information = MI_X_Y__Z, MI_X_Z__Y

# H(Y|X) and H(Z|X).
label_uncertainty, encoding_uncertainty = H_Y__X, H_Z__X

# H(Y) and H(Z).
label_entropy, encoding_entropy = H_Y, H_Z


def swap_Y_Z(quantity):
    """Exchange the roles of Y and Z (Yhat) in a quantity.

    Moves the coefficient of each base entropy onto its mirror image:
    H(Y) <-> H(Z) and H(Y|X) <-> H(Z|X). H(Y,Z) is symmetric and keeps its
    coefficient. For example, ``swap_Y_Z(H_Y__Z)`` equals ``H_Z__Y``.

    Args:
        quantity: Coefficient tensor whose last dimension indexes the five
            base entropies. A single quantity has shape ``(5,)``. Any leading
            (batch) dimensions are also accepted, and each entry is swapped
            independently.

    Returns:
        A new tensor of the same shape as ``quantity`` with Y and Z swapped.
    """
    # (basis vector to read the coefficient from, basis vector to write it to)
    mirrored = (
        (H_Y, H_Z),
        (H_Z, H_Y),
        (H_Y__X, H_Z__X),
        (H_Z__X, H_Y__X),
        (H_Y_Z, H_Y_Z),
    )
    # ``quantity @ src`` drops the last dimension. Restore it so the
    # coefficients broadcast against ``dst`` for both single and batched inputs.
    return sum((quantity @ src).unsqueeze(-1) * dst for src, dst in mirrored)
