import torch
import numpy as np


def compute_rbf_kernel_from_cost(C: torch.Tensor, eps: float) -> torch.Tensor:
    """Return the elementwise kernel ``exp(-C / eps)`` for a cost matrix ``C``.

    The cost is expected (per the original note) to be ``1/2 * ||y_i - y_j||^2``,
    in which case the result is a Gaussian kernel with variance ``eps``.
    NOTE(review): that convention is the caller's responsibility — not checked here.
    """
    scaled_cost = C / eps
    return torch.exp(-scaled_cost)

def rbf_kernel(X, Y, eps=None):
    """Gaussian (RBF) Gram matrix between the rows of ``X`` and the rows of ``Y``.

        K[i, j] = exp(-||X[i] - Y[j]||^2 / (2 * eps))

    Args:
        X: (m, D) tensor.
        Y: (n, D) tensor.
        eps: bandwidth scale (> 0). Defaults to the feature dimension D = X.size(1).

    Returns:
        (m, n) tensor of kernel values in [0, 1].
    """
    sq_x = (X**2).sum(dim=1, keepdim=True)  # (m, 1)
    sq_y = (Y**2).sum(dim=1, keepdim=True)  # (n, 1)
    # The expansion ||x||^2 - 2 x.y + ||y||^2 can come out slightly negative through
    # floating-point cancellation (e.g. on the diagonal of K(X, X)); clamping keeps
    # every kernel value <= 1 and avoids exp() blowing up for small eps.
    dist = (sq_x - 2 * (X @ Y.t()) + sq_y.t()).clamp(min=0)

    if eps is None:
        eps = X.size(1)  # heuristic
    return torch.exp(-dist / (2 * eps))

def mmd_rbf(X, Y, eps=None):
    """Biased (V-statistic) estimate of the squared MMD between samples X and Y.

    Uses ``rbf_kernel`` with bandwidth ``eps`` and returns

        sum(Kxx)/m^2 + sum(Kyy)/n^2 - 2 * sum(Kxy)/(m*n)

    where m = X.size(0) and n = Y.size(0). Diagonal terms are included, so this
    is the biased estimator.
    """
    n_x, n_y = X.size(0), Y.size(0)

    within_x = rbf_kernel(X, X, eps).sum() / (n_x * n_x)
    within_y = rbf_kernel(Y, Y, eps).sum() / (n_y * n_y)
    cross_sum = rbf_kernel(X, Y, eps).sum()

    return within_x + within_y - 2 * cross_sum / (n_x * n_y)

def squared_mmd_gaussians(
    m1: torch.Tensor, S1: torch.Tensor,
    m0: torch.Tensor, S0: torch.Tensor,
    eps: torch.Tensor | float,
    clamp_min: float = 0.0,
    jitter: float = 1e-9,
) -> torch.Tensor:
    """
    True (population) MMD between two multivariate Gaussians under the RBF kernel

        k(x,y) = exp( -||x-y||^2 / (2 * eps) )

    where eps > 0.

    Args:
        m1, m0: (k,) means
        S1, S0: (k,k) covariance matrices (symmetric PSD)
        eps: scalar > 0 (can be float or 0-dim tensor)
        clamp_min: clamp MMD^2 to this minimum before sqrt (numerical safety)
        jitter: diagonal jitter added for numerical stability in Cholesky/logdet/solve

    Returns:
        Scalar tensor: MMD (not squared).
    """
    if m1.ndim != 1 or m0.ndim != 1 or m1.shape != m0.shape:
        raise ValueError(f"m1 and m0 must both be shape (k,), got {m1.shape} and {m0.shape}")
    if S1.ndim != 2 or S0.ndim != 2 or S1.shape != S0.shape or S1.shape[0] != S1.shape[1]:
        raise ValueError(f"S1 and S0 must both be shape (k,k), got {S1.shape} and {S0.shape}")

    k = m1.shape[0]
    device = m1.device
    dtype = m1.dtype
    I = torch.eye(k, device=device, dtype=dtype)

    eps_t = eps if torch.is_tensor(eps) else torch.tensor(eps, device=device, dtype=dtype)
    if eps_t.ndim != 0:
        eps_t = eps_t.squeeze()
        if eps_t.ndim != 0:
            raise ValueError("eps must be a scalar (float or 0-dim tensor).")

    # Helper: compute |eps*I + A|^{-1/2} robustly via logdet:
    # term = exp(-0.5 * logdet(I + A/eps))
    def det_term(A: torch.Tensor) -> torch.Tensor:
        M = I + (A / eps_t)
        M = M + jitter * I
        sign, logabsdet = torch.linalg.slogdet(M)
        if not torch.all(sign > 0):
            raise ValueError("Matrix inside logdet is not PD; increase jitter or check covariances/eps.")
        return torch.exp(-0.5 * logabsdet)

    # E[k(X,X')] where X~N(m1,S1): depends on X-X' ~ N(0,2S1)
    t11 = det_term(2.0 * S1)

    # E[k(Y,Y')] where Y~N(m0,S0): depends on Y-Y' ~ N(0,2S0)
    t00 = det_term(2.0 * S0)

    # Cross term E[k(X,Y)] with X-Y ~ N(m1-m0, S1+S0)
    A = S1 + S0
    t10_det = det_term(A)

    # exp( -1/2 * (m1-m0)^T (eps I + A)^{-1} (m1-m0) )
    delta = (m1 - m0).reshape(k, 1)
    M = (eps_t * I) + A
    M = M + jitter * I
    L = torch.linalg.cholesky(M)
    sol = torch.cholesky_solve(delta, L)  # M^{-1} delta
    quad = (delta.transpose(0, 1) @ sol).squeeze()  # scalar
    t10_exp = torch.exp(-0.5 * quad)

    mmd2 = t11 + t00 - 2.0 * (t10_det * t10_exp)
    mmd2 = torch.clamp(mmd2, min=clamp_min)  # numerical safety
    mmd2 =  float(mmd2.item()) if hasattr(mmd2, "item") else float(mmd2)
    return mmd2 / 2

def compute_median_bandwidth(X, Y):
    """Median-heuristic precision for an RBF kernel on the pooled samples of X and Y.

    Pools the rows of ``X`` and ``Y``, takes the median ``med`` of the squared
    Euclidean distances over all distinct pairs of rows, and returns
    ``gamma = 1 / (2 * med)``, so that ``exp(-gamma * d^2) = exp(-d^2 / (2 * med))``.
    ``torch.median`` returns the lower of the two middle values when the number
    of pairs is even.

    Args:
        X: (N1, D) tensor.
        Y: (N2, D) tensor.

    Returns:
        0-dim tensor ``gamma`` (computed without autograd tracking).

    Raises:
        ValueError: if fewer than two rows are pooled (no pairs to take a median of).
    """
    Z = torch.cat([X, Y], dim=0)  # shape (N_total, D)
    n_total = Z.size(0)
    if n_total < 2:
        raise ValueError("compute_median_bandwidth needs at least two samples in total.")
    with torch.no_grad():
        # ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b ; clamp the small negative values
        # that floating-point cancellation can produce.
        sq_norms = (Z**2).sum(dim=1, keepdim=True)
        dist2 = (sq_norms - 2 * (Z @ Z.t()) + sq_norms.t()).clamp(min=0)
        # Strict upper triangle: each unordered pair once, self-distances excluded.
        rows, cols = torch.triu_indices(n_total, n_total, offset=1, device=Z.device)
        median_dist = dist2[rows, cols].median()
        gamma = 1.0 / (2 * median_dist)
    return gamma


def ipw_kernel_mean_embedding(Y, A, propensity_scores, gamma=None):
    """
    Self-normalised inverse-propensity weights for the treated (A=1) and control
    (A=0) samples. Together with a kernel, these weights define the IPW kernel
    mean embeddings; no Gram matrices are computed here.

    Treated unit i gets weight proportional to 1 / e_i and control unit j gets
    weight proportional to 1 / (1 - e_j); each group's weights are normalised
    to sum to 1.

    Args:
        Y: (N, D) tensor of vectorized outcomes
        A: (N,) tensor of 0/1 treatment indicators
        propensity_scores: (N,) tensor or numpy array of scores e_i
        gamma: optional kernel parameter, returned unchanged when given. When
            None it is computed with compute_median_bandwidth on the treated
            and control outcomes.

    Returns:
        w_treated: (N1,) weights for Y[A == 1], summing to 1
        w_control: (N0,) weights for Y[A == 0], summing to 1
        gamma: the provided or median-heuristic value

    Raises:
        ValueError: if either the treated or the control group is empty
            (normalisation would divide by zero).
    """
    treated = A == 1
    control = A == 0
    if not bool(treated.any()) or not bool(control.any()):
        raise ValueError("Both treated (A==1) and control (A==0) groups must be non-empty.")

    # as_tensor accepts numpy arrays and tensors alike (no copy warning for
    # tensors) and places the scores on A's device so the masks can index them.
    e = torch.as_tensor(propensity_scores, dtype=torch.float32, device=A.device)

    w_treated = 1.0 / e[treated]
    w_control = 1.0 / (1.0 - e[control])

    # Normalize weights to sum to 1 (empirical KME)
    w_treated = w_treated / w_treated.sum()
    w_control = w_control / w_control.sum()

    if gamma is None:
        gamma = compute_median_bandwidth(Y[treated], Y[control])

    return w_treated, w_control, gamma

def ipw_mmd_rbf(Y, A, propensity_scores, gamma=None):
    """IPW-weighted (biased, V-statistic) squared MMD between treated and control outcomes.

    Computes  w_t^T K_tt w_t + w_c^T K_cc w_c - 2 w_t^T K_tc w_c  with the
    self-normalised IPW weights from ipw_kernel_mean_embedding and Gram
    matrices from rbf_kernel.

    Args:
        Y: (N, D) outcomes (cast to float32).
        A: (N,) 0/1 treatment indicators.
        propensity_scores: (N,) scores, tensor or numpy array.
        gamma: optional. When given, it is forwarded unchanged to rbf_kernel as
            its ``eps`` argument, i.e. k(x, y) = exp(-||x - y||^2 / (2 * gamma))
            (kept for backward compatibility). When None, the median-heuristic
            precision gamma = 1 / (2 * med) from compute_median_bandwidth is
            converted to rbf_kernel's scale eps = 1 / (2 * gamma) = med, giving
            k(x, y) = exp(-||x - y||^2 / (2 * med)).

    Returns:
        0-dim tensor with the weighted squared-MMD estimate.
    """
    use_median_heuristic = gamma is None
    Y = Y.float()

    w_treated, w_control, gamma = ipw_kernel_mean_embedding(Y, A, propensity_scores, gamma)

    # compute_median_bandwidth returns a precision (kernel exp(-gamma * d^2)), while
    # rbf_kernel's third argument is a scale (kernel exp(-d^2 / (2 * eps))). Passing
    # the precision straight through would give exp(-d^2 * med), a kernel that
    # collapses as the data spread grows; convert it instead.
    eps = 1.0 / (2.0 * gamma) if use_median_heuristic else gamma

    Y_treated = Y[A == 1]
    Y_control = Y[A == 0]

    # Compute Gram matrices
    K_tt = rbf_kernel(Y_treated, Y_treated, eps)
    K_cc = rbf_kernel(Y_control, Y_control, eps)
    K_tc = rbf_kernel(Y_treated, Y_control, eps)

    # Weighted sums as quadratic forms w^T K w
    mmd2 = (
        w_treated @ K_tt @ w_treated
        + w_control @ K_cc @ w_control
        - 2 * (w_treated @ K_tc @ w_control)
    )
    return mmd2

def compute_resolvent_times_K(K: torch.Tensor, P: torch.Tensor) -> torch.Tensor:
    """Return ``(I - K diag(P) K diag(P))^{-1} K``, i.e. ``(I - (K diag(P))^2)^{-1} K``.

    The inverse is never formed explicitly; the linear system
    ``(I - (K diag(P))^2) X = K`` is solved with ``torch.linalg.solve``.

    Parameters
    ----------
    K : (n, n) torch.Tensor
        Kernel Gram matrix (float32/float64).
    P : (n,) torch.Tensor
        Diagonal entries of diag(P); cast to K's dtype and device.

    Returns
    -------
    (n, n) torch.Tensor
        The solution X.

    Raises
    ------
    ValueError
        If K is not square or P does not have shape (n,).
    """
    is_square = K.dim() == 2 and K.size(0) == K.size(1)
    if not is_square:
        raise ValueError(f"K must be square (n,n). Got {tuple(K.shape)}")
    n = K.size(0)
    if not (P.dim() == 1 and P.numel() == n):
        raise ValueError(f"P must have shape (n,). Got {tuple(P.shape)}")

    p = P.to(dtype=K.dtype, device=K.device)

    # Right-multiplying by diag(p) scales column j of K by p[j].
    k_diag = K * p.unsqueeze(0)
    lhs = torch.eye(n, dtype=K.dtype, device=K.device) - k_diag @ k_diag

    return torch.linalg.solve(lhs, K)
