"""
This entropy estimator upper-bounds the entropy of Z by fitting a Gaussian to the activations
(for a fixed covariance, the Gaussian has the largest differential entropy).
"""
import math
import torch
import numpy as np

from XXX.uib.utils import torch_empirical_covariance
from XXX.uib import information_quantities as iq

# NOTE(review): not referenced anywhere in this chunk; presumably a floor for
# near-zero probabilities/eigenvalues used elsewhere in the file — confirm.
null_threshold = 1e-300


def diff_entropy(cov: torch.Tensor):
    """Return the differential entropy (in nats) of a Gaussian with covariance ``cov``.

    Uses H = 0.5 * (log det(cov) + d * (log(2 * pi) + 1)) for a d x d covariance.

    Args:
        cov: symmetric covariance matrix of shape (d, d), or a batch of such
            matrices with shape (..., d, d). ``eigvalsh`` reads only the lower
            triangle.

    Returns:
        A 0-dim tensor for a single matrix, or a tensor of shape (...) for a
        batch. Singular or indefinite inputs yield ``-inf``/``nan``, because the
        log-determinant is computed as the sum of log-eigenvalues.
    """
    # torch.symeig was deprecated and then removed from PyTorch. eigvalsh is its
    # replacement and, unlike the old call, skips the unused eigenvectors.
    eigenvalues = torch.linalg.eigvalsh(cov)
    # log det(cov) = sum of log-eigenvalues; reduce only over the last dim so
    # batched inputs keep one entropy per matrix.
    logdet = torch.log(eigenvalues).sum(dim=-1)
    dim = eigenvalues.shape[-1]
    return 0.5 * (logdet + dim * (math.log(2 * math.pi) + 1))


def get_central_limit_covariance(sample_x_features):
    """Estimate the covariance of ``sample_x_features`` with the OAS shrinkage estimator.

    Delegates to ``torch_empirical_covariance.torch_oas`` and returns only the
    covariance matrix. The shrinkage coefficient it also returns is discarded.
    """
    # TODO: an earlier version divided the result by len(sample_x_features). The
    # reason is unclear, and it seemed inconsistent with the other estimators, so
    # no such scaling is applied here.
    oas_covariance, _unused_shrinkage = torch_empirical_covariance.torch_oas(sample_x_features)
    return oas_covariance


def onehot_continuous_information(
    features_X_Z: torch.Tensor, labels_x: torch.Tensor, information_quantity: torch.Tensor
):
    """Evaluate a weighted sum of entropies for categorical Y and continuous Z.

    ``information_quantity`` is a coefficient vector in the basis defined by the
    ``iq`` module. Only H(Y), H(Z) and H(Z|Y) are estimated here:

    * The weight along ``iq.H_Y_Z`` is used as the weight of H(Z|Y).
    * That multiple of ``iq.H_Z__Y`` is then subtracted from a copy of the
      coefficients.
    * The H(Y) and H(Z) weights are read from what remains.

    NOTE(review): this presumably relies on iq encoding H(Y, Z) = H(Y) + H(Z|Y).
    Confirm in information_quantities.

    H(Y|X) and H(Yhat|X) are ignored because no BNN is in use.

    Args:
        features_X_Z: continuous activations Z. Rows are selected per class via
            ``labels_x == y``, so samples are presumably along dim 0 (TODO confirm).
        labels_x: per-sample labels. The comment in the original speaks of
            one-hot Y, but ``torch.unique`` and the equality mask only make sense
            for 1-D class indices (NOTE(review): confirm with callers).
        information_quantity: coefficient vector. It is cloned, so the caller's
            tensor is not modified.

    Returns:
        factor_h_Y * H(Y) + factor_h_Z * H(Z) + factor_h_Z__Y * H(Z|Y), computed
        with weights moved to ``labels_x.device``. In that expression:

        * H(Y) is the plug-in entropy of the label frequencies (nats).
        * H(Z) and H(Z|Y) are Gaussian entropies of OAS covariance estimates.
    """
    coefficients = information_quantity.clone()

    # Projection: read the H(Z|Y) weight off the H(Y, Z) direction, remove that
    # component, then read the remaining H(Y) and H(Z) weights. Order matters.
    factor_h_Z__Y = iq.H_Y_Z @ coefficients
    coefficients -= factor_h_Z__Y * iq.H_Z__Y
    factor_h_Y = iq.H_Y @ coefficients
    factor_h_Z = iq.H_Z @ coefficients

    # Plug-in entropy of the empirical label distribution.
    classes, class_counts = torch.unique(labels_x, return_counts=True)
    class_counts = class_counts.double()
    p_Y = class_counts / class_counts.sum()
    h_Y = -torch.sum(p_Y * torch.log(p_Y))

    # H(Z|Y) = sum_y p(y) * H(Z | Y=y), each term from a per-class Gaussian fit.
    h_Z__Y = 0.0
    if factor_h_Z__Y != 0.0:
        z_double = features_X_Z.double()
        for label, p_label in zip(classes, p_Y):
            class_cov = get_central_limit_covariance(z_double[labels_x == label])
            h_Z__Y += diff_entropy(class_cov) * p_label

    # H(Z) from a single Gaussian fit over all samples.
    h_Z = 0.0
    if factor_h_Z != 0.0:
        h_Z = diff_entropy(get_central_limit_covariance(features_X_Z.double()))

    target_device = labels_x.device
    return (
        factor_h_Y.to(device=target_device) * h_Y
        + factor_h_Z.to(device=target_device) * h_Z
        + factor_h_Z__Y.to(device=target_device) * h_Z__Y
    )


def continuous_information(features_X_Y, features_X_Z, information_quantity):
    """Evaluate a weighted sum of Gaussian entropies for continuous Y and Z.

    The weights of H(Y), H(Z) and H(Y, Z) are the dot products of
    ``information_quantity`` with ``iq.H_Y``, ``iq.H_Z`` and ``iq.H_Y_Z``.
    Entropy terms with a zero weight are not estimated. H(Y|X) and H(Yhat|X) are
    ignored because no BNN is in use.

    When H(Y, Z) is needed, a single covariance is fitted to the concatenated
    [Y | Z] features. The marginal H(Y) and H(Z) are then read from its diagonal
    blocks. Otherwise each marginal gets its own covariance fit.

    NOTE(review): these two paths can give slightly different marginal estimates,
    because OAS shrinkage depends on the full matrix it is fitted to.

    Args:
        features_X_Y: 2-D continuous Y features; columns are concatenated along
            dim 1 (samples presumably along dim 0 — TODO confirm).
        features_X_Z: 2-D continuous Z features, same number of rows as
            ``features_X_Y``.
        information_quantity: coefficient vector in the ``iq`` basis.

    Returns:
        factor_h_Y * H(Y) + factor_h_Z * H(Z) + factor_h_Y_Z * H(Y, Z), with the
        weights moved to ``features_X_Y.device``.
    """
    factor_h_Y_Z = iq.H_Y_Z @ information_quantity
    factor_h_Y = iq.H_Y @ information_quantity
    factor_h_Z = iq.H_Z @ information_quantity

    h_Y_Z = 0.0
    h_Y = 0.0
    h_Z = 0.0

    if factor_h_Y_Z != 0.0:
        features_X_Y = features_X_Y.double()
        features_X_Z = features_X_Z.double()
        # Joint covariance over [Y | Z]: Y columns first, Z columns last.
        cov_Y_Z = get_central_limit_covariance(torch.cat((features_X_Y, features_X_Z), dim=1))
        h_Y_Z = diff_entropy(cov_Y_Z)

        num_features_Y = features_X_Y.shape[1]
        if factor_h_Y != 0.0:
            # Leading block: the Y marginal.
            h_Y = diff_entropy(cov_Y_Z[:num_features_Y, :num_features_Y])
        if factor_h_Z != 0.0:
            # Trailing block: the Z marginal. The previous slice
            # [:-num_features_Z, :-num_features_Z] picked the Y block instead.
            h_Z = diff_entropy(cov_Y_Z[num_features_Y:, num_features_Y:])
    else:
        if factor_h_Y != 0.0:
            h_Y = diff_entropy(get_central_limit_covariance(features_X_Y.double()))
        if factor_h_Z != 0.0:
            h_Z = diff_entropy(get_central_limit_covariance(features_X_Z.double()))

    device = features_X_Y.device
    return (
        factor_h_Y.to(device=device) * h_Y
        + factor_h_Z.to(device=device) * h_Z
        + factor_h_Y_Z.to(device=device) * h_Y_Z
    )
