import torch

import torch.distributions as dist


def _gaussian_kernel_matrix(samples, scale):
    """Return K with K[i, j] = Normal(mean=samples[i], std=scale) density at samples[j]."""
    return (
        dist.Normal(samples.unsqueeze(1), scale)
        .log_prob(samples.unsqueeze(0))
        .exp()
    )


def mutual_information_kde(x, y, bandwidth=0.1, device=None):
    """
    Estimate the mutual information between two continuous tensors using KDE.

    Both inputs are flattened and truncated to the shorter length, so sample
    ``i`` of ``x`` is paired with sample ``i`` of ``y``. Marginal and joint
    densities are estimated at the sample points with Gaussian kernels, and
    the result is the resubstitution estimate ``H(X) + H(Y) - H(X, Y)`` in
    nats.

    Args:
        x: Tensor of samples of the first variable (any shape; flattened).
        y: Tensor of samples of the second variable (any shape; flattened).
        bandwidth: Standard deviation of the Gaussian kernel, used for both
            the marginal and the joint estimates.
        device: Device for the returned tensor. Defaults to ``x.device``.

    Returns:
        A 0-dim tensor on ``device`` holding the estimate. The value is NaN
        when the inputs contain no samples.

    Note:
        Memory and time are O(n^2) in the number of paired samples because
        full pairwise kernel matrices are built.
    """
    # Remember where the caller wants the result before moving anything.
    if device is None:
        device = x.device

    # Compute on CPU (some ops used here are not supported by MPS) and work
    # on 1D views of the data.
    x_cpu = x.cpu().flatten()
    y_cpu = y.cpu().flatten()

    # Pair samples by position; drop the unpaired tail of the longer input.
    min_length = min(x_cpu.shape[0], y_cpu.shape[0])
    x_cpu = x_cpu[:min_length]
    y_cpu = y_cpu[:min_length]

    # Use one floating dtype for both inputs so the kernel matrices can be
    # combined; integer/bool inputs are promoted to the default float dtype.
    dtype = torch.promote_types(x_cpu.dtype, y_cpu.dtype)
    if not dtype.is_floating_point:
        dtype = torch.get_default_dtype()
    x_cpu = x_cpu.to(dtype)
    y_cpu = y_cpu.to(dtype)

    epsilon = 1e-10
    scale = bandwidth + epsilon  # guard against a zero bandwidth

    # Pairwise kernel evaluations: entry [i, j] is the kernel centred on
    # sample i evaluated at sample j.
    kernel_x = _gaussian_kernel_matrix(x_cpu, scale)
    kernel_y = _gaussian_kernel_matrix(y_cpu, scale)

    # Marginal densities at every sample point (average over kernel centres).
    p_x = kernel_x.mean(dim=0) + epsilon
    p_y = kernel_y.mean(dim=0) + epsilon

    # Joint density at every sample point. An isotropic 2D Gaussian kernel
    # with covariance scale^2 * I factorises into the product of the two 1D
    # kernels, so the joint KDE is the mean of the elementwise product. This
    # must be evaluated across all pairs (i, j); evaluating each point only
    # against its own kernel would yield a data-independent constant.
    p_xy = (kernel_x * kernel_y).mean(dim=0) + epsilon

    # Resubstitution entropy estimates: H = -mean(log p(sample)).
    x_entropy = -torch.log(p_x).mean()
    y_entropy = -torch.log(p_y).mean()
    joint_entropy = -torch.log(p_xy).mean()

    # MI = H(X) + H(Y) - H(X,Y)
    mi = x_entropy + y_entropy - joint_entropy

    # Move result back to the requested device
    return mi.to(device)
