import torch
import numpy as np
from collections import deque


class DynamicsBuffer:
    """
    In-memory store of transitions (state, action, next_state, reward, terminal flag)
    that also keeps a running mean and (population) standard deviation of the rewards
    added since construction or the last ``clear()``.
    """

    def __init__(self):
        self.actions = []
        self.states = []
        self.rewards = []
        self.next_states = []

        self.is_terminals = []

        # Running reward statistics (Welford's online algorithm). rewards_std stays at 1
        # until at least two rewards have been seen so that dividing by it is safe early on.
        self.rewards_mean = 0
        self.rewards_std = 1
        # Welford's running sum of squared deviations from the mean ("M2").
        self._rewards_m2 = 0.0

    def clear(self):
        """Drop all stored transitions and reset the running reward statistics."""
        del self.actions[:]
        del self.states[:]
        del self.rewards[:]
        del self.next_states[:]

        del self.is_terminals[:]

        # The statistics are keyed on len(self.rewards), so they must restart with the data.
        self.rewards_mean = 0
        self.rewards_std = 1
        self._rewards_m2 = 0.0

    def add(self, state, action, next_state, reward, is_terminal):
        """
        Append one transition and fold ``reward`` into the running reward statistics.

        Note: if every reward seen so far is identical, ``rewards_std`` becomes 0.
        """
        self.actions.append(action)
        self.states.append(state)
        self.rewards.append(reward)
        self.next_states.append(next_state)

        self.is_terminals.append(is_terminal)

        # Welford update: M2 += (x - mean_old) * (x - mean_new)
        n = len(self.rewards)
        delta_old = reward - self.rewards_mean
        self.rewards_mean = self.rewards_mean + delta_old / n
        self._rewards_m2 = self._rewards_m2 + delta_old * (reward - self.rewards_mean)
        if n > 1:
            self.rewards_std = (self._rewards_m2 / n) ** 0.5


class RollingUncertaintyNormalizer:
    """
    A class that maintains running statistics for normalizing uncertainty measures.
    This helps standardize different uncertainty measures (variance, entropy, IG)
    to comparable scales for use as intrinsic rewards.

    Mean and standard deviation are computed over a sliding window holding the most
    recent ``max_size`` values. Min/max are tracked over every value ever passed to
    ``update``.
    """

    def __init__(
            self,
            max_size=10_000,
            epsilon=1e-8,
            clip_range=(-5, 5),
            device="cpu"
    ):
        """
        Initialize the normalizer with running statistics tracking.

        Args:
            max_size: Maximum size of the rolling window for statistics
            epsilon: Small constant for numerical stability
            clip_range: Tuple of (min, max) values to clip normalized uncertainties
            device: The device to store and compute tensors on
        """
        self.max_size = max_size
        self.epsilon = epsilon
        self.clip_range = clip_range
        self.device = device

        # Initialize rolling window for uncertainty values
        self.uncertainty_window = deque(maxlen=max_size)

        # Initialize running statistics
        self.mean = 0.0
        self.std = 1.0
        self.count = 0

        # Track minimum and maximum seen values
        self.min_val = float('inf')
        self.max_val = float('-inf')

    def update(self, uncertainty_values):
        """
        Update running statistics with new uncertainty values.

        Args:
            uncertainty_values: Tensor, array, list or scalar of uncertainty measures.
                Any shape is accepted; values are flattened. Empty input is ignored.
        """
        if isinstance(uncertainty_values, torch.Tensor):
            # detach() so tensors that are part of an autograd graph can be converted
            uncertainty_values = uncertainty_values.detach().cpu().numpy()

        # reshape(-1) also turns 0-d inputs (scalars) into an iterable 1-D array
        values = np.asarray(uncertainty_values).reshape(-1)
        if values.size == 0:
            # min()/max() are undefined on empty arrays and there is nothing to record
            return

        # Update min/max tracking
        self.min_val = min(self.min_val, float(values.min()))
        self.max_val = max(self.max_val, float(values.max()))

        # Add new values to rolling window (deque drops the oldest beyond max_size)
        self.uncertainty_window.extend(values)

        # Update count of total samples seen
        self.count += int(values.size)

        # Recompute statistics on current window
        window = np.asarray(self.uncertainty_window)
        self.mean = float(np.mean(window))
        self.std = float(np.std(window)) + self.epsilon

    def normalize(self, uncertainty_values, method='standardize'):
        """
        Normalize uncertainty values using specified method.

        Args:
            uncertainty_values: Tensor, array, list or scalar of uncertainty measures.
                Non-tensor inputs are converted to float32 tensors.
            method: Normalization method ('standardize' or 'minmax')

        Returns:
            Normalized uncertainty values as tensor on specified device, clipped to
            ``clip_range``. With 'minmax' before any ``update`` call, min/max are still
            +/-inf and the result is NaN.

        Raises:
            ValueError: If ``method`` is not recognised.
        """
        if not isinstance(uncertainty_values, torch.Tensor):
            uncertainty_values = torch.as_tensor(uncertainty_values, dtype=torch.float32)

        uncertainty_values = uncertainty_values.to(self.device)

        if method == 'standardize':
            # Z-score normalization
            normalized = (uncertainty_values - self.mean) / self.std
        elif method == 'minmax':
            # Min-max scaling; values inside the observed range map to [0, 1]
            denominator = (self.max_val - self.min_val) + self.epsilon
            normalized = (uncertainty_values - self.min_val) / denominator
        else:
            raise ValueError(f"Unknown normalization method: {method}")

        # Clip normalized values
        normalized = torch.clamp(normalized, self.clip_range[0], self.clip_range[1])

        return normalized

    def get_stats(self):
        """Return current statistics as a dictionary."""
        return {
            'mean': self.mean,
            'std': self.std,
            'min': self.min_val,
            'max': self.max_val,
            'count': self.count,
            'window_size': len(self.uncertainty_window)
        }


class RollingNormalizer:
    def __init__(self, state_dim, action_dim):
        """
        Maintain moving mean and standard deviation of state, action, state_delta and reward
        using Welford's online algorithm.
        For the formulas see: https://www.johndcook.com/blog/standard_deviation/

        Args:
            state_dim: Dimensionality of the state vector (kept for reference/serialization).
            action_dim: Number of discrete actions; integer actions are one-hot encoded
                with this many classes.
        """

        self.state_dim = state_dim
        self.action_dim = action_dim

        # For each quantity: running mean, Welford sum of squared deviations ("sk") and
        # population standard deviation sqrt(sk / count). All are None until the first
        # update and are stored with a leading dimension of size 1.
        self.state_mean = None
        self.state_sk = None
        self.state_stdev = None

        self.action_mean = None
        self.action_sk = None
        self.action_stdev = None

        self.state_delta_mean = None
        self.state_delta_sk = None
        self.state_delta_stdev = None

        self.reward_mean = None
        self.reward_sk = None
        self.reward_stdev = None

        self.count = 0

    @staticmethod
    def update_mean(mu_old, addendum, n):
        """Welford mean update: fold the n-th sample ``addendum`` into ``mu_old``."""
        mu_new = mu_old + (addendum - mu_old) / n
        return mu_new

    @staticmethod
    def update_sk(sk_old, mu_old, mu_new, addendum):
        """Welford update of the sum of squared deviations, given the means before and after."""
        sk_new = sk_old + (addendum - mu_old) * (addendum - mu_new)
        return sk_new

    def update(self, state, action, state_delta, reward):
        """
        Fold one transition into the running statistics.

        Args:
            state: State (array-like); stored as a float32 tensor with a leading dim of 1.
            action: An integer (Python int or any NumPy integer type), which is one-hot
                encoded with ``action_dim`` classes, or an array-like of continuous values.
            state_delta: Change in state (array-like).
            reward: Scalar reward; stored with shape (1, 1).
        """
        self.count += 1

        # Transform to torch tensors
        state = torch.tensor(state, dtype=torch.float32).unsqueeze(0)

        # Discrete actions: accept int and every NumPy integer type (int32, int64, ...)
        if isinstance(action, (int, np.integer)):
            action = torch.tensor(action, dtype=torch.long)
            # Cast to float so the statistics are floating point from the first sample on
            action = torch.nn.functional.one_hot(action, num_classes=self.action_dim).float().unsqueeze(0)
        else:
            action = torch.tensor(action, dtype=torch.float32).unsqueeze(0)

        state_delta = torch.tensor(state_delta, dtype=torch.float32).unsqueeze(0)
        reward = torch.tensor(reward, dtype=torch.float32).unsqueeze(0).unsqueeze(1)

        if self.count == 1:
            # first element, initialize
            self.state_mean = state.clone()
            self.state_sk = torch.zeros_like(state)
            self.state_stdev = torch.zeros_like(state)
            self.action_mean = action.clone()
            self.action_sk = torch.zeros_like(action)
            self.action_stdev = torch.zeros_like(action)
            self.state_delta_mean = state_delta.clone()
            self.state_delta_sk = torch.zeros_like(state_delta)
            self.state_delta_stdev = torch.zeros_like(state_delta)

            self.reward_mean = reward.clone()
            self.reward_sk = torch.zeros_like(reward)
            self.reward_stdev = torch.ones_like(reward)
            return

        state_mean_old = self.state_mean.clone()
        action_mean_old = self.action_mean.clone()
        state_delta_mean_old = self.state_delta_mean.clone()
        reward_mean_old = self.reward_mean.clone()

        self.state_mean = self.update_mean(self.state_mean, state, self.count)
        self.action_mean = self.update_mean(self.action_mean, action, self.count)
        self.state_delta_mean = self.update_mean(self.state_delta_mean, state_delta, self.count)
        self.reward_mean = self.update_mean(self.reward_mean, reward, self.count)

        self.state_sk = self.update_sk(self.state_sk, state_mean_old, self.state_mean, state)
        self.action_sk = self.update_sk(self.action_sk, action_mean_old, self.action_mean, action)
        self.state_delta_sk = self.update_sk(self.state_delta_sk, state_delta_mean_old, self.state_delta_mean,
                                             state_delta)
        # Same (old mean, new mean, sample) argument order as the other quantities; the
        # previous call passed (new mean, sample, sample), which kept reward_sk at 0.
        self.reward_sk = self.update_sk(self.reward_sk, reward_mean_old, self.reward_mean, reward)

        self.state_stdev = torch.sqrt(self.state_sk / self.count)
        self.action_stdev = torch.sqrt(self.action_sk / self.count)
        self.state_delta_stdev = torch.sqrt(self.state_delta_sk / self.count)
        self.reward_stdev = torch.sqrt(self.reward_sk / self.count)

    @staticmethod
    def setup_vars(x, mean, stdev):
        """
        Return detached copies of ``mean``/``stdev`` on ``x``'s device, reshaped to ``x``'s rank.

        Entries of ``stdev`` that are not strictly positive (features that never varied, or
        NaN from rounding) are replaced by 1, so normalization and denormalization stay
        finite and remain exact inverses of each other.
        """
        assert x.size(-1) == mean.size(-1), f'sizes: {x.size()}, {mean.size()}'

        mean, stdev = mean.clone().detach(), stdev.clone().detach()
        mean, stdev = mean.to(x.device), stdev.to(x.device)

        # Stats carry a leading dim of 1: drop it for lower-rank x, add dims for higher-rank x
        while mean.dim() > x.dim() and mean.size(0) == 1:
            mean, stdev = mean.squeeze(0), stdev.squeeze(0)
        while mean.dim() < x.dim():
            mean, stdev = mean.unsqueeze(0), stdev.unsqueeze(0)

        stdev = torch.where(stdev > 0, stdev, torch.ones_like(stdev))
        return mean, stdev

    def _normalize(self, x, mean, stdev):
        """Return (x - mean) / stdev with the statistics prepared by ``setup_vars``."""
        mean, stdev = self.setup_vars(x, mean, stdev)
        n = x - mean
        n = n / stdev
        return n

    def normalize_states(self, states):
        return self._normalize(states, self.state_mean, self.state_stdev)

    def normalize_actions(self, actions):
        return self._normalize(actions, self.action_mean, self.action_stdev)

    def normalize_state_deltas(self, state_deltas):
        return self._normalize(state_deltas, self.state_delta_mean, self.state_delta_stdev)

    def normalize_rewards(self, rewards):
        return self._normalize(rewards, self.reward_mean, self.reward_stdev)

    def denormalize_state_delta_means(self, state_deltas_means):
        mean, stdev = self.setup_vars(state_deltas_means, self.state_delta_mean, self.state_delta_stdev)
        return state_deltas_means * stdev + mean

    def denormalize_state_delta_vars(self, state_delta_vars):
        mean, stdev = self.setup_vars(state_delta_vars, self.state_delta_mean, self.state_delta_stdev)
        return state_delta_vars * (stdev ** 2)

    def denormalize_rewards_means(self, rewards):
        mean, stdev = self.setup_vars(rewards, self.reward_mean, self.reward_stdev)
        return rewards * stdev + mean

    def denormalize_rewards_vars(self, rewards):
        mean, stdev = self.setup_vars(rewards, self.reward_mean, self.reward_stdev)
        return rewards * (stdev ** 2)

    def renormalize_state_delta_means(self, state_deltas_means):
        mean, stdev = self.setup_vars(state_deltas_means, self.state_delta_mean, self.state_delta_stdev)
        return (state_deltas_means - mean) / stdev

    def renormalize_rewards_mean(self, rewards):
        mean, stdev = self.setup_vars(rewards, self.reward_mean, self.reward_stdev)
        return (rewards - mean) / stdev

    def renormalize_state_delta_vars(self, state_delta_vars):
        mean, stdev = self.setup_vars(state_delta_vars, self.state_delta_mean, self.state_delta_stdev)
        return state_delta_vars / (stdev ** 2)

    @staticmethod
    def _clone_or_none(t):
        """Clone a tensor, passing None through (stats are None before the first update)."""
        return None if t is None else t.clone()

    def get_state(self):
        """
        Snapshot every statistic needed to resume updating and normalizing.

        The original keys (means, stdevs, count) are kept; reward statistics, Welford
        accumulators and dimensions are included as well.
        """
        state = {'count': self.count,
                 'state_dim': self.state_dim,
                 'action_dim': self.action_dim}
        for name in ('state', 'action', 'state_delta', 'reward'):
            for suffix in ('mean', 'stdev', 'sk'):
                key = f'{name}_{suffix}'
                state[key] = self._clone_or_none(getattr(self, key))
        return state

    def set_state(self, state):
        """
        Restore statistics produced by ``get_state``.

        States saved by older versions (without ``*_sk``, reward stats or dims) are accepted:
        missing accumulators are rebuilt from stdev via sk = stdev**2 * count. Missing
        reward stats become None.
        """
        self.count = state['count']
        self.state_dim = state.get('state_dim', getattr(self, 'state_dim', None))
        self.action_dim = state.get('action_dim', getattr(self, 'action_dim', None))

        for name in ('state', 'action', 'state_delta', 'reward'):
            mean = state.get(f'{name}_mean')
            stdev = state.get(f'{name}_stdev')
            sk = state.get(f'{name}_sk')
            if sk is None and stdev is not None:
                # Invert stdev = sqrt(sk / count) so future Welford updates stay correct
                sk = stdev ** 2 * self.count
            setattr(self, f'{name}_mean', self._clone_or_none(mean))
            setattr(self, f'{name}_stdev', self._clone_or_none(stdev))
            setattr(self, f'{name}_sk', self._clone_or_none(sk))

    def __getstate__(self):
        return self.get_state()

    def __setstate__(self, state):
        self.set_state(state)


def compute_uncertainty(states, mu_, var_, model, method="Variance", diff_states=None):
    '''
    Compute an uncertainty score (or a prediction error) for an ensemble's next-state predictions.

    For "Variance", "Entropy" and "IG", the following steps are applied:
    - ``states`` is subtracted from ``mu_`` to obtain state deltas.
    - The deltas and ``var_`` are mapped into the normalizer's state-delta space.
    - The variances are passed through ``rescale_var`` using the model's log-variance bounds.
    - The requested measure is computed.

    :param states: Current states. They are subtracted from ``mu_`` after ``unsqueeze(0)``, so
        presumably (batch_size, D) with ``mu_`` of shape (N_models, batch_size, D) - TODO confirm.
    :param mu_: Ensemble next-state mean predictions in un-normalized state space.
    :param var_: Ensemble state-delta variance predictions in un-normalized space.
    :param model: Object exposing ``normalizer`` (with ``renormalize_state_delta_means`` /
        ``renormalize_state_delta_vars``) and ``min_logvar`` / ``max_logvar`` tensors.
    :param method: One of "Variance", "Entropy", "IG" or "Error".
    :param diff_states: Observed state deltas (raw space); required when ``method == "Error"``.
    :return: For "Variance", "Entropy" and "IG", the score produced by ``mean_vars_ens``,
        ``compute_entropy`` or ``compute_ig`` respectively. For "Error", the Euclidean distance
        between each ensemble member's predicted next-state mean and the observed next state.
    :raises ValueError: If ``method`` is unknown, or ``method == "Error"`` without ``diff_states``.
    '''
    if method not in ("Variance", "Entropy", "IG", "Error"):
        raise ValueError(f"Unknown uncertainty method: {method}")

    states = states.to(mu_.device)

    if method == "Error":
        if diff_states is None:
            raise ValueError("method='Error' requires diff_states")
        # Observed next states in raw (un-normalized) state space
        next_states = states + torch.as_tensor(diff_states, device=mu_.device)
        # mu_ already holds every member's predicted next-state mean in raw space. The
        # earlier code denormalized (normalized delta + raw state), mixing two spaces.
        return torch.norm(next_states.unsqueeze(0) - mu_, dim=-1)

    # Predicted state deltas and their variances, expressed in normalized state-delta space
    mu = model.normalizer.renormalize_state_delta_means(mu_ - states.unsqueeze(0))
    var = model.normalizer.renormalize_state_delta_vars(var_)
    var = rescale_var(var, model.min_logvar.clone().detach(), model.max_logvar.clone().detach())

    # The helpers below take per-dimension (diagonal) variances directly
    if method == "Variance":
        return mean_vars_ens(mu, var)
    if method == "Entropy":
        return compute_entropy(mu, var)
    return compute_ig(mu, var)


def mean_vars_ens(means, var_s):
    '''
    Total predictive variance of an equally weighted ensemble, averaged over output dimensions.

    Uses the law of total variance, Var(Y) = E[Var(Y|X)] + Var(E[Y|X]), where X indexes the
    ensemble member (uniform weight 1/N). Both terms are therefore population statistics
    over the member axis. The same convention is used by ``compute_entropy``, and it keeps a
    single-member ensemble finite; the unbiased estimator would give NaN there.

    :param means:  (N_models, batch_size, D) predicted means
    :param var_s:  (N_models, batch_size, D) predicted (diagonal) variances
    :return: (batch_size,) total variance averaged over the last dimension
    '''
    # Aleatoric part: average of the members' own variances
    mean_vars = var_s.mean(dim=0)

    # Epistemic part: spread of the member means (divide by N, not N - 1)
    var_means = means.var(dim=0, unbiased=False)

    # Total variance per output dimension
    total_var = mean_vars + var_means  # (batch_size, D)

    # Average total variance over the dimensions
    return total_var.mean(dim=-1)  # (batch_size,)


def multivariate_gaussian_entropy(covariance_matrices):
    """
    Differential entropy of one or more multivariate Gaussians, given their covariances.

    Uses H = 1/2 * (log det(Sigma) + D * (1 + log(2*pi))); the mean does not enter.

    Parameters:
    - covariance_matrices (Tensor): shape (..., D, D)

    Returns:
    - Tensor of shape (...) with one entropy per matrix.

    Raises:
    - ValueError: if the determinant sign of any matrix is not strictly positive.
    """
    dim = covariance_matrices.size(-1)
    two_pi = torch.tensor(2 * np.pi, device=covariance_matrices.device, dtype=covariance_matrices.dtype)
    log_two_pi = torch.log(two_pi)

    signs, log_dets = torch.slogdet(covariance_matrices)

    # A positive determinant sign is required for a valid covariance
    all_positive = torch.all(signs > 0)
    if not all_positive:
        raise ValueError("Covariance matrix is not positive definite.")

    per_dim_constant = 1 + log_two_pi
    return 0.5 * (log_dets + dim * per_dim_constant)


def compute_information_gain(means, covariances):
    """
    Compute the Information Gain (IG) for a set of means and covariances from ensemble models.

    IG = H(moment-matched mixture Gaussian) - mean_i H(member i). The mixture covariance is
    the average member covariance plus the (population) covariance of the member means.
    A 1e-6 * I jitter is added to every covariance before taking log-determinants.

    Parameters:
    - means (Tensor): Mean predictions with shape (N_models, batch_size, D)
    - covariances (Tensor): Covariance matrices with shape (N_models, batch_size, D, D).
      This tensor is not modified.

    Returns:
    - IG (Tensor): Information Gain values with shape (batch_size,)

    Raises:
    - ValueError: propagated from ``multivariate_gaussian_entropy`` if a (jittered)
      covariance has a non-positive determinant sign.
    """
    N_models, batch_size, D = means.shape

    # Compute the average mean over models
    mu_bar = means.mean(dim=0)  # Shape: (batch_size, D)

    # Deviations from the average mean, batch-first for bmm
    delta_mu = (means - mu_bar.unsqueeze(0)).permute(1, 0, 2)  # Shape: (batch_size, N_models, D)

    # Covariance of the means (epistemic part)
    Cov_mu = (1 / N_models) * torch.bmm(delta_mu.transpose(1, 2), delta_mu)  # Shape: (batch_size, D, D)

    # Mean covariance over models (aleatoric part)
    Sigma_mean = covariances.mean(dim=0)  # Shape: (batch_size, D, D)

    # Small diagonal jitter for numerical stability; (D, D) broadcasts over leading dims
    epsilon = 1e-6
    eye = torch.eye(D, device=covariances.device, dtype=covariances.dtype)

    Sigma_total = Sigma_mean + Cov_mu + epsilon * eye  # Shape: (batch_size, D, D)

    # Compute the entropy of the mixture distribution
    H_mixture = multivariate_gaussian_entropy(Sigma_total)  # Shape: (batch_size,)

    # Out-of-place: the previous `covariances += ...` silently modified the caller's tensor.
    # reshape (not view) also accepts non-contiguous inputs.
    jittered = covariances + epsilon * eye  # Shape: (N_models, batch_size, D, D)
    H_i = multivariate_gaussian_entropy(jittered.reshape(-1, D, D))  # Shape: (N_models * batch_size,)
    H_i = H_i.reshape(N_models, batch_size)
    mean_H = H_i.mean(dim=0)  # Shape: (batch_size,)

    # Compute Information Gain
    return H_mixture - mean_H  # Shape: (batch_size,)


def compute_entropy(means, diag_covariances):
    """
    Entropy of a Gaussian moment-matched to an ensemble with diagonal covariances.

    Per output dimension, the total variance is the average of the members' own variances
    plus the spread of the member means (normalised by 1/N), plus a 1e-6 jitter. The
    entropy of a diagonal Gaussian is 0.5 * sum_i log(2*pi*e*sigma_i^2).

    Parameters:
    - means (Tensor): Mean predictions with shape (N_models, batch_size, D)
    - diag_covariances (Tensor): Diagonal covariance entries with shape (N_models, batch_size, D)

    Returns:
    - Tensor of shape (batch_size,) with the mixture entropy.
    """
    # Unpacking enforces a 3-D input, as the shapes above require
    n_models, _, _ = means.shape

    # Epistemic part: spread of the member means around their average
    centred = means - means.mean(dim=0).unsqueeze(0)
    epistemic = (1 / n_models) * centred.pow(2).sum(dim=0)

    # Aleatoric part: the members' own variances, averaged
    aleatoric = diag_covariances.mean(dim=0)

    jitter = 1e-6
    total = aleatoric + epistemic + jitter

    log_terms = torch.log(2 * torch.pi * torch.e * total)
    return 0.5 * log_terms.sum(dim=-1)


def multivariate_gaussian_kl(mu1, sigma1, mu2, sigma2):
    """
    KL(N(mu1, sigma1) || N(mu2, sigma2)) for a batch of multivariate Gaussians.

    KL = 1/2 * [tr(S2^-1 S1) + (mu2 - mu1)^T S2^-1 (mu2 - mu1) + ln|S2| - ln|S1| - D]

    Input dimensions:
    - mu1, mu2: (batch_size, D)
    - sigma1, sigma2: (batch_size, D, D)

    Returns:
    - Tensor of shape (batch_size,). The quadratic term is ``squeeze()``d, so with
      batch_size == 1 it is 0-d and broadcasts against the other (1,)-shaped terms.
    """
    dim = mu1.size(-1)
    precision2 = torch.linalg.inv(sigma2)

    # tr(S2^-1 S1), batched: diagonal of the product, summed over D
    product = torch.matmul(precision2, sigma1)
    trace_term = product.diagonal(dim1=-2, dim2=-1).sum(-1)

    # Mahalanobis term as (1 x D) @ (D x D) @ (D x 1) per batch element
    diff = mu2 - mu1
    row = diff.unsqueeze(1)
    col = diff.unsqueeze(2)
    mahalanobis = torch.bmm(torch.bmm(row, precision2), col).squeeze()

    log_det_ratio = torch.logdet(sigma2) - torch.logdet(sigma1)

    return 0.5 * (trace_term + mahalanobis + log_det_ratio - dim)


def rescale_var(var, min_log_var, max_log_var, decay=0.1):
    """
    Pull variances toward the model's maximum variance.

    Returns ``max_var - decay * (max_var - var)`` with ``max_var = exp(max_log_var)``, i.e.
    ``(1 - decay) * max_var + decay * var``.

    Args:
        var: Variance tensor.
        min_log_var: Kept for interface compatibility; it does not affect the result. (The
            previous version computed ``exp(min_log_var)`` but never used it.)
        max_log_var: Log of the maximum variance (tensor, passed to ``torch.exp``).
        decay: Weight kept on ``var``.

    Returns:
        Tensor broadcast from ``var`` and ``exp(max_log_var)``.
    """
    max_var = torch.exp(max_log_var)
    return max_var - decay * (max_var - var)


def compute_ig(state_delta_means, next_state_vars):
    """
    Ensemble-disagreement score for diagonal Gaussian predictions; used by
    ``compute_uncertainty`` for method "IG".

    Returns entropy(mixture) - mean(entropy(member)). Both entropies match the closed-form
    Renyi quadratic (alpha = 2) entropy of Gaussians / Gaussian mixtures. The additive
    (d_state / 2) * log(2*pi) constant is omitted from both terms, so it cancels in the
    difference.

    Args:
        state_delta_means: ensemble means, (ens_size, n_actors, d_state) per the comment below.
        next_state_vars: ensemble diagonal variances, same layout; assumed strictly positive
            because ``torch.log`` is applied to them.

    Returns:
        Tensor of shape (n_actors,).
    """

    # Data are in (ens_size, n_actors, d_state). Need to transpose to (n_actors, ens_size, d_state)
    state_delta_means = state_delta_means.transpose(0, 1)
    next_state_vars = next_state_vars.transpose(0, 1)

    mu, var = state_delta_means, next_state_vars                         # shape: both (n_actors, ensemble_size, d_state)
    n_act, es, d_s = mu.size()                                            # shape: (n_actors, ensemble_size, d_state)

    # entropy of the mixture: -log of the mean pairwise overlap integral of member Gaussians
    mu_diff = mu.unsqueeze(1) - mu.unsqueeze(2)                           # shape: (n_actors, ensemble_size, ensemble_size, d_state)
    var_sum = var.unsqueeze(1) + var.unsqueeze(2)                         # shape: (n_actors, ensemble_size, ensemble_size, d_state)

    # squared mean differences scaled by the summed variances (Mahalanobis term per dim)
    err = (mu_diff * 1 / var_sum * mu_diff)                               # shape: (n_actors, ensemble_size, ensemble_size, d_state)
    err = torch.sum(err, dim=-1)                                          # shape: (n_actors, ensemble_size, ensemble_size)
    det = torch.sum(torch.log(var_sum), dim=-1)                           # shape: (n_actors, ensemble_size, ensemble_size)

    # log overlap of each member pair (without the -d/2 * log(2*pi) constant)
    log_z = -0.5 * (err + det)                                            # shape: (n_actors, ensemble_size, ensemble_size)
    log_z = log_z.reshape(n_act, es * es)                                 # shape: (n_actors, ensemble_size * ensemble_size)
    # log-mean-exp with the max subtracted first, for numerical stability
    mx, _ = log_z.max(dim=1, keepdim=True)                                # shape: (n_actors, 1)
    log_z = log_z - mx                                                    # shape: (n_actors, ensemble_size * ensemble_size)
    exp = torch.exp(log_z).mean(dim=1, keepdim=True)                      # shape: (n_actors, 1)
    entropy_mean = -mx - torch.log(exp)                                   # shape: (n_actors, 1)
    entropy_mean = entropy_mean[:, 0]                                     # shape: (n_actors)

    # mean of entropies: 0.5 * sum(log var) + (d/2) * log 2 per member, averaged over members
    total_entropy = torch.sum(torch.log(var), dim=-1)                     # shape: (n_actors, ensemble_size)
    mean_entropy = total_entropy.mean(dim=1) / 2 + d_s * np.log(2.) / 2    # shape: (n_actors)

    # jensen-shannon divergence
    jsd = entropy_mean - mean_entropy                                 # shape: (n_actors)

    return jsd

#
# def compute_entropy(means, var_s):
#     """
#     Compute entropy for diagonal multivariate Gaussian distributions.
#
#     Args:
#         means: tensor of shape (n_samples, batch_size, d_state) - ensemble of mean predictions
#         var_s: tensor of shape (n_samples, batch_size, d_state) - ensemble of variance predictions
#
#     Returns:
#         entropy: tensor of shape (batch_size,) - total entropy per sample
#     """
#     # Compute total variance (aleatoric + epistemic) for each dimension
#     # Shape: (batch_size, d_state)
#     total_var = var_s.mean(0) + means.var(0)
#
#     # Compute entropy per dimension
#     # Shape: (batch_size, d_state)
#     entropy_per_dim = 0.5 * torch.log(2 * np.pi * np.e * total_var)
#
#     # Sum entropy across dimensions to get total entropy
#     # Shape: (batch_size,)
#     total_entropy = entropy_per_dim.sum(dim=-1)
#
#     return total_entropy
#
#
# def compute_ig(means: torch.Tensor, vars: torch.Tensor) -> torch.Tensor:
#     """
#     Compute Information Gain (ensemble disagreement measure) for diagonal multivariate Gaussians.
#
#     Args:
#         means: tensor of shape (n_samples, batch_size, d_state) - ensemble of mean predictions
#         vars: tensor of shape (n_samples, batch_size, d_state) - ensemble of variance predictions
#
#     Returns:
#         ig: tensor of shape (batch_size,) - information gain per sample
#     """
#     # First compute the average entropy of individual ensemble members
#     # For each dimension, we compute 0.5 * log(2πe * σ²)
#     # Shape: (n_samples, batch_size, d_state)
#     individual_entropies = 0.5 * torch.log(2 * np.pi * np.e * vars)
#
#     # Average these entropies across ensemble members
#     # Shape: (batch_size, d_state)
#     avg_entropy = individual_entropies.mean(0)
#
#     # Sum across dimensions to get total average entropy
#     # Shape: (batch_size,)
#     total_avg_entropy = avg_entropy.sum(dim=-1)
#
#     # Compute the entropy of the mixture using our previous function
#     # Shape: (batch_size,)
#     mixture_entropy = compute_entropy(means, vars)
#
#     # The information gain is the difference between mixture entropy and average entropy
#     # Shape: (batch_size,)
#     return mixture_entropy - total_avg_entropy
