"""Approximate regularizers that compute Var[Z], Var[Z|X], Var[Z|Y] and Var[Z_x].

The resulting per-dimension variances can also be used to estimate entropies under a
diagonal-Gaussian approximation."""
import math
from typing import Callable, Optional

import numpy as np
import torch

from XXX.uib.utils.constants import dbl_null_threshold, flt_null_threshold


def flatten_x_k_z_(latent_x_k_):
    """Collapse every axis after the first two into a single trailing axis.

    An (x, K, d1, d2, ...) tensor becomes (x, K, d1 * d2 * ...); a 3-D input is
    returned with its shape unchanged.
    """
    return torch.flatten(latent_x_k_, start_dim=2)


def logsum(values) -> torch.Tensor:
    """Sum the natural logs of the entries of ``values`` that exceed ``dbl_null_threshold``.

    Entries at or below the threshold (e.g. zero variances) are filtered out before
    ``log`` is applied, so they contribute neither ``-inf`` to the sum nor any gradient.
    """
    above_threshold = values[values > dbl_null_threshold]
    return above_threshold.log().sum()


def reduce_(values, reduction: Optional[str]):
    """Reduce ``values`` according to ``reduction``.

    "sum" -> ``torch.sum``; "mean" -> ``torch.mean``; "logsum" -> :func:`logsum`
    (sum of logs of entries above the null threshold). Any other value, including
    ``None``, returns ``values`` itself, unreduced.
    """
    if reduction == "sum":
        return torch.sum(values)
    if reduction == "mean":
        return torch.mean(values)
    if reduction == "logsum":
        return logsum(values)
    return values


def reduce(func, reduction):
    """Compose ``func`` with :func:`reduce_`.

    The returned callable forwards its positional arguments to ``func`` and then
    reduces the result according to ``reduction`` ("sum", "mean", "logsum", or
    ``None``/any other value for no reduction).
    """

    def apply_then_reduce(*args):
        raw_result = func(*args)
        return reduce_(raw_result, reduction)

    return apply_then_reduce


def force_stochastic(func, *, stochastic: bool):
    """Adapt ``func`` so it can also be fed inputs without a sample axis.

    The estimators in this module take their first argument laid out as (x, K, ...),
    with K samples per x. When ``stochastic`` is true, ``func`` is returned unchanged.
    Otherwise, the returned wrapper inserts a singleton axis at position 1, turning
    (x, ...) into (x, 1, ...). It then calls ``func`` with the remaining positional
    arguments untouched.
    """

    def add_sample_axis(value_x_, *rest):
        value_x_1_ = value_x_[:, None, ...]
        return func(value_x_1_, *rest)

    return func if stochastic else add_sample_axis


def squared_sum(*, stochastic=False, reduction: Optional[str] = "mean"):
    """Build a regularizer from the squared L2 norm of every latent sample.

    The inner function maps a (x, K, ...) tensor to an (x, K) tensor of
    ``<z, z>`` values, one per flattened sample. The result is then reduced with
    ``reduction`` (default: mean over all x and K). With ``stochastic=False``, the
    input is expected without the K axis; a singleton one is inserted.
    """

    def squared_(latent_x_k_: torch.Tensor, labels_x):
        rows = flatten_x_k_z_(latent_x_k_).flatten(0, 1)  # (x * K, z)
        # Batched dot product of each row with itself: (b, 1, z) @ (b, z, 1) -> (b, 1, 1)
        self_dots = torch.bmm(rows.unsqueeze(1), rows.unsqueeze(2))
        lead_shape = latent_x_k_.shape[0:2]
        return self_dots[:, 0].reshape(lead_shape)

    reduced = reduce(squared_, reduction=reduction)
    return force_stochastic(reduced, stochastic=stochastic)


def mean_by_x_(latent_x_k_: torch.Tensor):
    """Average the K samples of each x, keeping a singleton sample axis.

    A (x, K, ...) input is flattened to (x, K, z) and averaged over axis 1, giving
    an (x, 1, z) tensor.
    """
    flat_x_k_z = latent_x_k_.flatten(2)
    return flat_x_k_z.mean(dim=1, keepdim=True)


def mean_by_x(func):
    """Wrap ``func`` so it receives per-x means instead of raw samples.

    The first argument is collapsed by :func:`mean_by_x_` to shape (x, 1, z);
    the remaining positional arguments are passed through unchanged.
    """

    def on_x_means(latent_x_k_, *rest):
        collapsed_x_1_z = mean_by_x_(latent_x_k_)
        return func(collapsed_x_1_z, *rest)

    return on_x_means


def squared_mean_by_x(*, stochastic=False, reduction="mean"):
    """Build a regularizer from the squared L2 norm of each x's mean latent.

    The K samples of each x are averaged first. The squared norm of that mean is
    computed per x and then reduced with ``reduction`` (default: mean).
    """
    per_sample_norms = squared_sum(stochastic=True, reduction=None)
    norms_of_means = mean_by_x(per_sample_norms)
    reduced = reduce(norms_of_means, reduction=reduction)
    return force_stochastic(reduced, stochastic=stochastic)


def covariance_trace_(values_x_k_, labels_x):
    """Per-dimension variance of the latents over all x and all K samples.

    Uses torch's default (unbiased) estimator across axes 0 and 1 of the flattened
    (x, K, z) tensor. Returns a length-z tensor whose sum is the trace of the
    covariance of Z. ``labels_x`` is not used.
    """
    flat_x_k_z = flatten_x_k_z_(values_x_k_)
    return flat_x_k_z.var(dim=(0, 1), keepdim=False)


def covariance_trace(*, stochastic=False, reduction="sum"):
    """Build an estimator of tr Cov[Z] (per-dimension variances reduced by ``reduction``)."""
    per_dim_variances = force_stochastic(covariance_trace_, stochastic=stochastic)
    return reduce(per_dim_variances, reduction)


def covariance_trace_given_X_x_z_(values_x_k_, labels_x):
    """Variance of each latent dimension across the K samples of every x.

    This is a per-x estimate of Var[Z_j | X=x].

    Args:
        values_x_k_: tensor laid out as (x, K, ...); trailing axes are flattened into z.
        labels_x: unused here (presumably kept so every estimator shares the
            ``(values, labels_x)`` call signature).

    Returns:
        Tensor of shape (x, z). The unbiased estimator is used when K > 1.

        With K == 1 (e.g. a deterministic encoder wrapped by ``force_stochastic``),
        the unbiased estimator divides by zero and torch returns NaN. A single sample
        has no spread, so the biased estimator is used instead. It yields exact zeros
        that stay connected to the autograd graph, with zero gradient.
    """
    values_x_k_z = values_x_k_.flatten(2)
    # Unbiased variance needs at least two samples per x; fall back to the biased form.
    unbiased = values_x_k_z.shape[1] > 1
    return torch.var(values_x_k_z, dim=1, unbiased=unbiased, keepdim=False)


def covariance_trace_given_X_(values_x_k_, labels_x):
    """Per-dimension conditional variance given X, averaged over x.

    Averages the (x, z) output of :func:`covariance_trace_given_X_x_z_` over the x
    axis, giving a length-z estimate of E_x[Var[Z_j | X=x]].
    """
    per_x_variances = covariance_trace_given_X_x_z_(values_x_k_, labels_x)
    return per_x_variances.mean(dim=0, keepdim=False)


def covariance_trace_given_X(*, stochastic=False, reduction="sum"):
    """Build an estimator of E_x[tr Cov[Z | X=x]] (variances reduced by ``reduction``)."""
    per_dim_variances = force_stochastic(covariance_trace_given_X_, stochastic=stochastic)
    return reduce(per_dim_variances, reduction)


def covariance_trace_given_Y_y_z_(values_x_k_, labels_x):
    """Per-label, per-dimension variances of the latents, plus label frequencies.

    For each distinct label ``y`` (in the sorted order returned by ``torch.unique``),
    the variance of every latent dimension is taken over all ``count_y * K`` samples
    in ``values_x_k_[labels_x == y]``. This is an estimate of Var[Z_j | Y=y].

    Args:
        values_x_k_: tensor laid out as (x, K, ...); trailing axes are flattened into z.
        labels_x: one label per x (assumed 1-D), used as a boolean mask on axis 0.

    Returns:
        ``(variances_y_z, p_y)``:
        - ``variances_y_z`` has shape (num_labels, z) and the dtype/device of the
          values. Rows for labels with a single sample overall stay zero.
        - ``p_y`` is a float64 tensor of empirical frequencies ``count_y / len(labels_x)``.
    """
    values_x_k_z = values_x_k_.flatten(2)
    num_samples_per_x = values_x_k_z.shape[1]

    labels, counts = torch.unique(labels_x, return_counts=True)

    variances_y_z = torch.zeros(
        (len(counts), values_x_k_z.shape[2]), dtype=values_x_k_z.dtype, device=values_x_k_z.device
    )

    for i, (label, count) in enumerate(zip(labels, counts)):
        # The unbiased variance (torch default) needs at least two samples. A label
        # contributes count * K of them, so a singleton label with K > 1 samples
        # still has a finite variance and must not be skipped.
        if count * num_samples_per_x > 1:
            variances_y_z[i] = torch.var(values_x_k_z[labels_x == label], dim=(0, 1), keepdim=False)

    p_y = counts.double() / len(labels_x)

    return variances_y_z, p_y


def covariance_trace_given_Y_(values_x_k_, labels_x):
    """Per-dimension conditional variance given Y, weighted by label frequency.

    Returns the length-z vector sum_y p(y) * Var[Z_j | Y=y]. Because ``p_y`` is
    float64, torch type promotion makes the result float64 for lower-precision inputs.
    """
    variances_y_z, p_y = covariance_trace_given_Y_y_z_(values_x_k_, labels_x)
    weighted_y_z = variances_y_z * p_y.unsqueeze(1)
    return weighted_y_z.sum(dim=0, keepdim=False)


def covariance_trace_given_Y(*, stochastic=False, reduction="sum"):
    """Build an estimator of E_y[tr Cov[Z | Y=y]] (variances reduced by ``reduction``)."""
    per_dim_variances = force_stochastic(covariance_trace_given_Y_, stochastic=stochastic)
    return reduce(per_dim_variances, reduction)


def covariance_trace_mean_by_X(*, stochastic=False, reduction="sum"):
    """Build an estimator of the variance, across x, of each x's mean latent.

    The K samples of each x are averaged first (:func:`mean_by_x`). The
    per-dimension variance of those means across the x axis is then reduced with
    ``reduction``.
    """
    variances_of_means = mean_by_x(covariance_trace_)
    return reduce(force_stochastic(variances_of_means, stochastic=stochastic), reduction)


def get_entropy(variances_z) -> torch.Tensor:
    """Entropy of a Gaussian with diagonal covariance ``diag(variances_z)``.

    Computes ``0.5 * (k * log(2*pi) + k + sum_j log(variances_z[j]))`` with
    ``k = len(variances_z)``. The log-sum goes through :func:`logsum`, so variances
    at or below the null threshold are left out of the sum but still counted in ``k``.
    """
    num_dims = len(variances_z)
    gaussian_constant = num_dims * np.log(2 * np.pi) + num_dims
    return 0.5 * (gaussian_constant + logsum(variances_z))


def estimate_entropy(covariance_trace_estimator: Callable, *, stochastic=False):
    """Turn a covariance-trace estimator factory into a Gaussian entropy estimator.

    ``covariance_trace_estimator`` is called with ``reduction=None`` so it yields
    per-dimension variances. The returned function forwards its positional
    arguments to that estimator and passes the variances to :func:`get_entropy`.
    """
    variance_fn = covariance_trace_estimator(stochastic=stochastic, reduction=None)

    def estimate_entropy_(*args):
        return get_entropy(variance_fn(*args))

    return estimate_entropy_


def get_batch_entropy(variances_i_z: torch.Tensor, *, threshold: Optional[float] = None):
    """Diagonal-Gaussian entropy for each row of a batch of per-dimension variances.

    For row ``i`` this returns ``0.5 * (k * log(2*pi) + k + sum_j log(v_ij))``,
    with ``k = variances_i_z.shape[1]``. Variances at or below ``threshold``
    contribute ``0`` to the log-sum (as ``log(1)`` would) instead of ``-inf``.

    Args:
        variances_i_z: tensor of variances, one row per group (e.g. per x or per
            label) and one column per latent dimension.
        threshold: variances ``<= threshold`` are treated as negligible. Defaults
            to ``flt_null_threshold``.

    Returns:
        Tensor with one entropy per row.
    """
    if threshold is None:
        threshold = flt_null_threshold
    k = variances_i_z.shape[1]
    negligible = variances_i_z <= threshold
    # Substitute 1 *before* the log rather than overwriting log's -inf afterwards.
    # Overwriting gives the same forward value, but log's backward divides the zero
    # upstream gradient by the zero variance (0/0 = NaN). That NaN then flows back
    # into whatever produced the variances.
    safe_variances = torch.where(negligible, torch.ones_like(variances_i_z), variances_i_z)
    log_i_z = torch.log(safe_variances)
    entropy_i = 0.5 * (k * np.log(2 * np.pi) + k + torch.sum(log_i_z, dim=1))
    return entropy_i


def estimate_entropy_Z__Y(*, stochastic):
    """Build an estimator of H(Z|Y) under a per-label diagonal-Gaussian approximation.

    The estimate is ``sum_y p(y) * H_y``, where ``H_y`` is the Gaussian entropy from
    the per-label variances of :func:`covariance_trace_given_Y_y_z_`.
    """

    def estimate_entropy_(values_x_k_, labels_x):
        variances_y_z, p_y = covariance_trace_given_Y_y_z_(values_x_k_, labels_x)
        weighted_entropies_y = get_batch_entropy(variances_y_z) * p_y
        return weighted_entropies_y.sum()

    return force_stochastic(estimate_entropy_, stochastic=stochastic)


def estimate_entropy_Z__X(*, stochastic):
    """Build an estimator of H(Z|X) under a per-x diagonal-Gaussian approximation.

    Each x's Gaussian entropy is computed from the variance of its K samples
    (:func:`covariance_trace_given_X_x_z_`). The estimate is the plain mean of these
    entropies over x.
    """

    def estimate_entropy_(values_x_k_, labels_x):
        variances_x_z = covariance_trace_given_X_x_z_(values_x_k_, labels_x)
        entropies_x = get_batch_entropy(variances_x_z)
        return entropies_x.mean()

    return force_stochastic(estimate_entropy_, stochastic=stochastic)
