from functools import partial
from typing import Callable

import torch 
import chex
import jax
import jax.numpy as jnp
from ott.geometry import pointcloud
from ott.math.matrix_square_root import sqrtm
from ott.tools import sinkhorn_divergence
from jax import device_get
import numpy as np

import ot as pot  # Python Optimal Transport package
import scipy
from sklearn.metrics.pairwise import pairwise_distances

@jax.jit
def symmetrize(X: jnp.ndarray) -> jnp.ndarray:
    """Return the symmetric part of ``X``, i.e. the average of ``X`` and its transpose.

    Used to remove small asymmetric round-off from matrix square roots.
    """
    transposed = jnp.transpose(X)
    return (X + transposed) / 2


@jax.jit
def compute_BW_UVP_with_gt_stats(
    model_samples: jnp.ndarray, true_samples_mu: jnp.ndarray, true_samples_covariance: jnp.ndarray
) -> jnp.ndarray:
    """Bures-Wasserstein UVP of model samples against a known target mean/covariance.

    The model moments are estimated from ``model_samples`` (rows are samples)
    with ``mean(axis=0)`` and ``jnp.cov``. The squared Bures-Wasserstein term is

        0.5 * ||mu_true - mu_model||^2
        + 0.5 * tr(S_model) + 0.5 * tr(S_true)
        - tr( (S_model^{1/2} S_true S_model^{1/2})^{1/2} )

    It is reported as a percentage of ``0.5 * tr(S_true)``.
    """
    mu_model = model_samples.mean(axis=0)
    sigma_model = jnp.cov(model_samples.T)
    sigma_true = true_samples_covariance

    # Symmetrize after each sqrtm to discard numerical asymmetry.
    root_model = symmetrize(sqrtm(sigma_model)[0])
    cross = root_model @ sigma_true @ root_model
    cross_root_trace = jnp.trace(symmetrize(sqrtm(cross)[0]))

    mean_gap = 0.5 * jnp.sum(jnp.square(true_samples_mu - mu_model))
    cov_gap = 0.5 * jnp.trace(sigma_model) + 0.5 * jnp.trace(sigma_true) - cross_root_trace

    bw_squared = mean_gap + cov_gap
    return 100 * (bw_squared / (0.5 * jnp.trace(sigma_true)))


@jax.jit
def compute_BW_UVP_by_gt_samples(model_samples: jnp.ndarray, true_samples: jnp.ndarray) -> jnp.ndarray:
    """BW-UVP where the target moments are estimated from ``true_samples`` (rows are samples)."""
    gt_mean = true_samples.mean(axis=0)
    gt_cov = jnp.cov(true_samples.T)
    return compute_BW_UVP_with_gt_stats(model_samples, gt_mean, gt_cov)


#### Compute MMD with Gaussian kernel ####
@jax.jit
def mmd(x: jnp.ndarray, y: jnp.ndarray, sigma: float = 10.0, scale: float = 1.0) -> jnp.ndarray:
    """Squared MMD between two sample sets with a Gaussian RBF kernel, in JAX.

    This implements the minimum-variance/biased version of the estimator described
    in Eq.(5) of
    https://jmlr.csail.mit.edu/papers/volume13/gretton12a/gretton12a.pdf.
    As described in Lemma 6's proof in that paper, the unbiased estimate and the
    minimum-variance estimate for MMD are almost identical.

    The kernel is ``k(a, b) = exp(-||a - b||^2 / (2 * sigma^2))``. The full
    x-x, x-y and y-y kernel matrices are materialized, so memory grows
    quadratically with the number of samples.

    Note that the first invocation of this function will be considerably slow due
    to JAX JIT compilation.

    Args:
      x: The first set of embeddings of shape (n, embedding_dim).
      y: The second set of embeddings of shape (m, embedding_dim); m may differ
        from n.
      sigma: Bandwidth of the Gaussian RBF kernel (default 10).
      scale: Multiplicative factor applied to the result, to make the metric
        more human readable (default 1).

    Returns:
      Scalar ``scale * MMD^2``, the (biased) squared MMD between x and y.
    """
    gamma = 1.0 / (2.0 * sigma**2)

    # Squared row norms, shapes (n, 1) and (m, 1).
    x_sqnorms = jnp.sum(x**2, axis=1, keepdims=True)
    y_sqnorms = jnp.sum(y**2, axis=1, keepdims=True)

    def _rbf(a, a_sqnorms, b, b_sqnorms):
        # ||a - b||^2 via the expansion |a|^2 + |b|^2 - 2 a.b; cancellation can
        # make it slightly negative, which would push the kernel above 1.
        sq_dist = -2 * jnp.dot(a, b.T) + a_sqnorms + b_sqnorms.T
        sq_dist = jnp.maximum(sq_dist, 0.0)
        return jnp.exp(-gamma * sq_dist)

    k_xx = _rbf(x, x_sqnorms, x, x_sqnorms)
    k_xy = _rbf(x, x_sqnorms, y, y_sqnorms)
    k_yy = _rbf(y, y_sqnorms, y, y_sqnorms)

    # Biased MMD^2: means include the diagonal terms.
    return scale * (jnp.mean(k_xx) + jnp.mean(k_yy) - 2 * jnp.mean(k_xy))


@jax.jit
def compute_sinkhorn_divergence(xs: jnp.ndarray, ys: jnp.ndarray, epsilon: float = 1.0) -> float:
    """
    Sinkhorn divergence between the point clouds ``xs`` and ``ys``.

    The computation is delegated to OTT-JAX: ``sinkhorn_divergence`` is run on a
    ``PointCloud`` geometry built from the two arrays. The Sinkhorn divergence
    is a debiased version of entropy-regularized optimal transport.

    Parameters
    ----------
    xs : jnp.ndarray
        Points of shape (n_samples_x, n_features).
    ys : jnp.ndarray
        Points of shape (n_samples_y, n_features); must share ``n_features`` with ``xs``.
    epsilon : float, optional
        Entropic regularization strength passed to OTT, by default 1.

    Returns
    -------
    float
        The ``divergence`` field of the OTT result.

    References
    ----------
    - JAX-OTT library documentation: https://ott-jax.readthedocs.io/en/stable/glossary.html#term-Sinkhorn-divergence
    """
    # Both inputs must be rank-2 floating arrays with the same feature dimension.
    for cloud in (xs, ys):
        chex.assert_rank(cloud, 2)
    chex.assert_axis_dimension(xs, axis=1, expected=ys.shape[1])
    chex.assert_type([xs, ys], float)

    result = sinkhorn_divergence.sinkhorn_divergence(
        pointcloud.PointCloud,
        xs,
        ys,
        epsilon=epsilon,
    )
    return result.divergence


@jax.jit
def l2_distance(x: jnp.ndarray, y: jnp.ndarray) -> float:
    """Average, over paired rows, of the squared Euclidean distance between ``x`` and ``y``.

    Row ``i`` of ``x`` is compared with row ``i`` of ``y``. Squared differences
    are summed along axis 1 and the result is averaged.
    """
    per_row_sq = jnp.sum(jnp.square(x - y), axis=1)
    return jnp.mean(per_row_sq)


@partial(jax.jit, static_argnames=["potential_func"])
def gradients_potential(potential_func: Callable[[jnp.ndarray], float], xs: jnp.ndarray) -> jnp.ndarray:
    """Evaluate the gradient of a scalar potential at every row of ``xs``.

    ``potential_func`` is a jit static argument, so it must be hashable; a new
    function object triggers recompilation.
    """
    per_point_grad = jax.grad(potential_func)
    batched_grad = jax.vmap(per_point_grad)
    return batched_grad(xs)


@partial(jax.jit, static_argnames=["interaction_func"])
def gradients_interaction(interaction_func: Callable[[jnp.ndarray], float], xs: jnp.ndarray) -> jnp.ndarray:
    """Mean symmetrized interaction gradient acting on every particle.

    With ``W = interaction_func`` and ``n = xs.shape[0]``, row ``i`` of the result is

        sum_j [grad W(x_i - x_j) - grad W(x_j - x_i)] / (n - 1)

    When ``grad W(0)`` is finite, the self pair ``j == i`` (and any other pair
    at the same location) contributes exactly zero, so dividing by ``n - 1``
    averages over the other particles. Such zero-difference pairs are masked
    out explicitly. Otherwise a kernel whose JAX gradient at the origin is NaN
    (e.g. ``W(z) = ||z||``) would turn every output row into NaN. A single
    particle yields zeros instead of ``0 / 0``.

    Args:
      interaction_func: Scalar function ``W`` of one difference vector. It is a
        jit static argument, so it must be hashable.
      xs: Particles stacked along axis 0.

    Returns:
      Array with the same shape as ``xs``.
    """
    grad_w = jax.vmap(jax.grad(interaction_func))
    num_particles = xs.shape[0]
    # Static Python value: number of *other* particles, floored at 1 so n == 1
    # does not divide by zero.
    normalizer = float(max(num_particles - 1, 1))

    def per_particle(p):
        diffs = p - xs
        forw = grad_w(diffs)  # grad W(x_i - x_j) for every j
        back = -grad_w(xs - p)  # -grad W(x_j - x_i) for every j
        pair_terms = forw + back

        # Flag pairs whose difference is exactly zero (always includes j == i)
        # and drop them; for finite grad W(0) their contribution is 0 anyway.
        feature_axes = tuple(range(1, diffs.ndim))
        coincident = jnp.all(diffs == 0, axis=feature_axes)
        coincident = coincident.reshape((num_particles,) + (1,) * (diffs.ndim - 1))
        pair_terms = jnp.where(coincident, 0.0, pair_terms)

        total = jnp.sum(pair_terms, axis=0)
        assert total.shape == p.shape
        return total / normalizer

    return jax.vmap(per_particle)(xs)


@partial(jax.jit, static_argnames=["grad_func", "gt_func", "func"])
def l2_uvp_backward(
    rho: jnp.ndarray,
    rho_prev: jnp.ndarray,
    grad_func: Callable[[Callable[[jnp.ndarray], float], jnp.ndarray], jnp.ndarray],
    gt_func: Callable[[jnp.ndarray], float],
    func: Callable[[jnp.ndarray], float],
) -> float:
    """L2-UVP (percent) between the gradients of ``gt_func`` and ``func`` evaluated at ``rho``.

    ``grad_func(f, rho)`` is applied to both functions. The mean over rows of
    the squared gradient error is divided by the total variance of
    ``rho_prev`` (per-column sample variance with ``ddof=1``, summed) and
    multiplied by 100.
    """
    # rho and rho_prev are indexed as (num_samples, dim) arrays below.
    reference_grads = grad_func(gt_func, rho)
    model_grads = grad_func(func, rho)
    residual = reference_grads - model_grads

    total_variance = jnp.sum(jnp.var(rho_prev, axis=0, ddof=1))
    mean_sq_error = jnp.mean(jnp.sum(jnp.square(residual), axis=1))
    return 100 * mean_sq_error / total_variance


def _normalize_emd_weights(weights, n_points):
    """Return float64 point masses summing to one; uniform over ``n_points`` when ``weights`` is None."""
    if weights is None:
        return np.ones(n_points) / n_points
    weights = np.asarray(weights, dtype=np.float64)
    return weights / weights.sum()


def earth_mover_distance(
    p,
    q,
    eigenvals=None,
    weights1=None,
    weights2=None,
    return_matrix=False,
    metric="sqeuclidean",
):
    """
    Returns the earth mover's distance between two point clouds.

    The ground cost is ``sklearn.metrics.pairwise_distances(p, q, metric=metric)``.
    The exact optimal transport cost is solved with POT's ``emd2``, and its
    square root is returned. With the default ``metric="sqeuclidean"`` this is
    the 2-Wasserstein distance.

    Parameters
    ----------
    p : 2-D array or scipy sparse matrix/array
        First point cloud, one point per row. Sparse input is densified.
    q : 2-D array or scipy sparse matrix/array
        Second point cloud, one point per row. Sparse input is densified.
    eigenvals : array, optional
        If given, both clouds are projected via ``p.dot(eigenvals)`` and
        ``q.dot(eigenvals)`` before distances are computed.
    weights1, weights2 : array-like, optional
        Masses of the points of ``p`` / ``q``. They are converted to float64
        and normalised to sum to one. Uniform when omitted.
    return_matrix : bool, optional
        Forwarded to ``ot.emd2``. When True, the POT log dictionary is returned
        alongside the distance.
    metric : str, optional
        Any metric accepted by ``sklearn.metrics.pairwise_distances``.

    Returns
    -------
    distance : float
        Square root of the optimal transport cost.
    log : dict
        Only when ``return_matrix`` is True: the second value returned by ``ot.emd2``.
    """
    # Make sure the submodule is loaded; `import scipy` alone does not guarantee it.
    import scipy.sparse

    # issparse (unlike isspmatrix) also recognises the newer sparse *array* types.
    p = p.toarray() if scipy.sparse.issparse(p) else p
    q = q.toarray() if scipy.sparse.issparse(q) else q
    if eigenvals is not None:
        p = p.dot(eigenvals)
        q = q.dot(eigenvals)

    p_weights = _normalize_emd_weights(weights1, len(p))
    q_weights = _normalize_emd_weights(weights2, len(q))

    # POT's exact solver expects a C-contiguous cost matrix.
    pairwise_dist = np.ascontiguousarray(pairwise_distances(p, Y=q, metric=metric, n_jobs=-1))

    # numItermax is an iteration count: pass an int rather than the float 1e7.
    result = pot.emd2(p_weights, q_weights, pairwise_dist, numItermax=10_000_000, return_matrix=return_matrix)
    if return_matrix:
        square_emd, log_dict = result
        return np.sqrt(square_emd), log_dict
    return np.sqrt(result)
    
class MMD_loss(torch.nn.Module):
    '''
    Biased multi-bandwidth Gaussian-kernel MMD^2 between two sample sets.

    fork from: https://github.com/ZongxianLee/MMD_Loss.Pytorch

    ``forward`` draws the same number of rows from each input without
    replacement, using NumPy's global RNG. It converts them to float32 torch
    tensors and returns the MMD^2 estimate as a Python float.
    '''

    def __init__(self, kernel_mul=2.0, kernel_num=5, sample_size=1000):
        """
        Args:
            kernel_mul: Ratio between consecutive kernel bandwidths.
            kernel_num: Number of Gaussian kernels summed together.
            sample_size: Maximum number of rows drawn from each input in ``forward``.
        """
        super().__init__()
        self.kernel_num = kernel_num
        self.kernel_mul = kernel_mul
        self.fix_sigma = None
        self.sample_size = sample_size

    def guassian_kernel(self, source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None):
        """Sum of ``kernel_num`` Gaussian kernels over the stacked ``[source; target]`` rows.

        Unless ``fix_sigma`` is truthy, the base bandwidth is the mean
        off-diagonal squared distance. It is divided by
        ``kernel_mul ** (kernel_num // 2)``, and the bandwidths
        ``base * kernel_mul**i`` for ``i < kernel_num`` are used.
        """
        total = torch.cat([source, target], dim=0)
        n_samples = int(total.size(0))

        # Pairwise squared Euclidean distances, shape (n_samples, n_samples).
        L2_distance = ((total.unsqueeze(0) - total.unsqueeze(1)) ** 2).sum(2)
        if fix_sigma:
            bandwidth = fix_sigma
        else:
            bandwidth = torch.sum(L2_distance.data) / (n_samples**2 - n_samples)
        bandwidth /= kernel_mul ** (kernel_num // 2)
        bandwidth_list = [bandwidth * (kernel_mul**i) for i in range(kernel_num)]
        kernel_val = [torch.exp(-L2_distance / bandwidth_temp) for bandwidth_temp in bandwidth_list]
        return sum(kernel_val)

    # Correctly spelled alias; the original name is kept for existing callers.
    gaussian_kernel = guassian_kernel

    def forward(self, source, target):
        """Return the biased MMD^2 between random equal-size subsets of ``source`` and ``target``.

        Raises:
            ValueError: If either input has no rows.
        """
        n_source, n_target = source.shape[0], target.shape[0]
        if n_source == 0 or n_target == 0:
            raise ValueError("MMD_loss requires non-empty source and target samples")
        # Sampling without replacement cannot exceed either population, and the
        # block arithmetic below needs both subsets to have the same size.
        sample_size = min(self.sample_size, n_source, n_target)

        # RNG draw order: target indices first, then source indices.
        pred_idx = np.random.choice(n_target, sample_size, replace=False)
        ref_idx = np.random.choice(n_source, sample_size, replace=False)
        source = torch.Tensor(np.array(source[ref_idx]))
        target = torch.Tensor(np.array(target[pred_idx]))

        kernels = self.guassian_kernel(
            source, target, kernel_mul=self.kernel_mul, kernel_num=self.kernel_num, fix_sigma=self.fix_sigma
        )
        XX = kernels[:sample_size, :sample_size]
        YY = kernels[sample_size:, sample_size:]
        XY = kernels[:sample_size, sample_size:]
        YX = kernels[sample_size:, :sample_size]
        loss = torch.mean(XX + YY - XY - YX)
        return loss.item()
