import torch
import numpy as np
import ot
from scipy.spatial.distance import jensenshannon

def compute_js_distance_2d_histogram(xs, xt, bins=64, ranges=None):
    """
    Computes the Jensen-Shannon distance between the 2D histograms of two sets of samples.

    Both sample sets are binned on the SAME grid so that corresponding bins
    cover the same region of the plane; otherwise the comparison is
    meaningless (e.g. two shifted copies of one cloud would look identical).

    Args:
        xs (np.ndarray): Source samples in (theta, phi) space, shape (n_samples_s, 2).
        xt (np.ndarray): Target samples in (theta, phi) space, shape (n_samples_t, 2).
        bins (int or array_like): The number of bins to use for the 2D histogram
                                  (passed through to ``np.histogram2d``).
        ranges (array_like, optional): The lower and upper range of the bins.
                                       Format: [[xmin, xmax], [ymin, ymax]].
                                       If None, a single range spanning the union
                                       of both sample sets is used.

    Returns:
        float: The Jensen-Shannon distance between the two distributions
        (as returned by ``scipy.spatial.distance.jensenshannon``), or NaN if
        either sample set contributes no counts to the histogram.
    """
    xs = np.asarray(xs)
    xt = np.asarray(xt)

    if ranges is None:
        # Derive one common range from the union of both sets. Letting
        # np.histogram2d auto-range each set separately would put the two
        # histograms on different grids.
        combined = np.concatenate([xs[:, :2], xt[:, :2]], axis=0)
        if combined.shape[0] > 0:
            lower = combined.min(axis=0)
            upper = combined.max(axis=0)
            ranges = [[lower[0], upper[0]], [lower[1], upper[1]]]

    # Create 2D histograms for both sample sets on the shared grid
    p_hist, _, _ = np.histogram2d(xs[:, 0], xs[:, 1], bins=bins, range=ranges)
    q_hist, _, _ = np.histogram2d(xt[:, 0], xt[:, 1], bins=bins, range=ranges)

    # Flatten the histograms to create 1D arrays
    p_flat = p_hist.ravel()
    q_flat = q_hist.ravel()

    p_total = p_flat.sum()
    q_total = q_flat.sum()
    if p_total == 0 or q_total == 0:
        # No samples fell inside the histogram range: the distribution is
        # undefined. Return NaN explicitly instead of dividing 0 by 0.
        return float("nan")

    # Normalize the histograms to form probability distributions
    p_dist = p_flat / p_total
    q_dist = q_flat / q_total

    # Add a small epsilon so empty bins do not produce log(0), then
    # renormalize so each vector sums to one again.
    epsilon = 1e-10
    p_dist += epsilon
    q_dist += epsilon
    p_dist /= p_dist.sum()
    q_dist /= q_dist.sum()

    # Compute the Jensen-Shannon distance
    return jensenshannon(p_dist, q_dist)

def compute_w2_distance_theta_phi(xs, xt, reg=0.01):
    """
    Computes an entropic-regularized approximation of the Wasserstein-2
    distance between two sets of samples in the (theta, phi) space using
    the Sinkhorn algorithm.

    Every sample is given equal mass, and the ground cost is the squared
    Euclidean distance between the raw coordinates (the coordinates are
    treated as plain Euclidean values; no wrapping is applied).

    Args:
        xs (np.ndarray): Source samples, shape (n_samples_s, 2).
        xt (np.ndarray): Target samples, shape (n_samples_t, 2).
        reg (float): Regularization parameter for the Sinkhorn algorithm.

    Returns:
        float: The square root of the Sinkhorn transport cost, i.e. the
        approximate Wasserstein-2 distance.
    """
    n_samples_s = xs.shape[0]
    n_samples_t = xt.shape[0]

    # Uniform weights for the samples
    a = np.full(n_samples_s, 1.0 / n_samples_s)
    b = np.full(n_samples_t, 1.0 / n_samples_t)

    # Cost matrix (squared Euclidean distance)
    M = ot.dist(xs, xt, metric='sqeuclidean')

    # Squared-distance transport cost from Sinkhorn. The solver's log
    # dictionary was never used, so it is no longer requested.
    w2_dist_sq = ot.sinkhorn2(a, b, M, reg=reg)

    return np.sqrt(w2_dist_sq)


def frobenius_norm_of_jacobian(model, x, t):
    """
    Computes the batch-averaged Frobenius norm of the Jacobian of the model
    output with respect to ``x``.

    The model is evaluated in eval mode; its previous training mode is
    restored before returning. ``x`` itself is not modified.

    Args:
        model: Callable module invoked as ``model(x, t)``. The indexing below
            assumes it maps a (batch_size, input_dim) tensor to a
            (batch_size, output_dim) tensor.
        x (torch.Tensor): Input batch, expected shape (batch_size, input_dim).
        t: Second argument forwarded unchanged to the model.

    Returns:
        float: Mean over the batch of the per-sample Frobenius norm of
        d model(x, t)[i] / d x[i].
    """
    # Remember the current mode so this metric does not silently leave a
    # model that is being trained stuck in eval mode.
    was_training = getattr(model, "training", False)
    model.eval()

    def forward(inp):
        return model(inp, t)

    try:
        # jacobian() tracks gradients on its own detached copy of the input,
        # so the caller's tensor does not need requires_grad set on it.
        jac = torch.autograd.functional.jacobian(forward, x)
    finally:
        if was_training:
            model.train()

    # The output of jacobian is of shape (batch_size, output_dim, batch_size, input_dim).
    # Keep only the same-sample blocks: diagonal() yields
    # (output_dim, input_dim, batch_size), which is permuted to
    # (batch_size, output_dim, input_dim).
    jac = jac.diagonal(dim1=0, dim2=2).permute(2, 0, 1)

    # Compute Frobenius norm for each item in the batch and then average
    fro_norm = torch.linalg.norm(jac, ord='fro', dim=(1, 2))

    return fro_norm.mean().item()