import numpy as np
from typing import Tuple, Dict, Callable

def _pairwise_sq_dists(X: np.ndarray) -> np.ndarray:
    """Pairwise squared Euclidean distances between rows of X (n×n)."""
    G = X @ X.T
    sq_norms = np.sum(X * X, axis=1, keepdims=True)
    D2 = sq_norms + sq_norms.T - 2.0 * G
    np.maximum(D2, 0.0, out=D2)
    return D2

def _rbf_kernel(X: np.ndarray, gamma: float) -> np.ndarray:
    """Evaluate the RBF kernel on all pairs of rows of X.

    Entry (i, j) of the result is exp(-gamma * ||x_i - x_j||^2), with the
    squared distances supplied by ``_pairwise_sq_dists``.
    """
    exponent = -gamma * _pairwise_sq_dists(X)
    return np.exp(exponent)

def _kernel_width(X: np.ndarray) -> float:
    """Median heuristic for Gaussian kernel width (with guard against degenerate cases)."""
    X_diff = np.expand_dims(X, axis=1) - X
    D = np.linalg.norm(X_diff, axis=2)
    s = np.median(D.flatten())
    if not np.isfinite(s) or s <= 0.0:
        s = 1.0
    return s

class SteinMixin:
    """Stein estimators for the score and Hessian of log p(x), plus kernel helpers.

    The class keeps no state: every method works only on the arrays passed
    in.  The public estimators build a Gaussian kernel whose width comes
    from ``_kernel_width`` and then solve linear systems in the kernel
    matrix with a multiple of the identity added to it.
    """

    def hessian(self, X: np.ndarray, eta_G: float, eta_H: float) -> np.ndarray:
        """Estimate the full Hessian of log p(x) at every sample.

        Args:
            X: Sample matrix of shape (n, d).
            eta_G: Coefficient of the identity added to the kernel matrix in
                the score solve.
            eta_H: Coefficient of the identity added to the kernel matrix in
                each Hessian-column solve.

        Returns:
            Array of shape (n, d, d); ``result[:, c, :]`` is the output of
            ``_hessian_col`` for column ``c``.
        """
        _, n_features = X.shape
        width = _kernel_width(X)
        gram = self._evaluate_kernel(X, width)
        gram_grad = self._evaluate_nablaK(gram, X, width)
        score_est = self.score(X, eta_G, gram, gram_grad)
        columns = []
        for col in range(n_features):
            columns.append(self._hessian_col(X, score_est, col, eta_H, gram, width))
        return np.stack(columns, axis=1)

    def score(self, X: np.ndarray, eta_G: float, K: np.ndarray = None, nablaK: np.ndarray = None) -> np.ndarray:
        """Estimate the score (gradient of log p) at every sample.

        Args:
            X: Sample matrix of shape (n, d).
            eta_G: Coefficient of the identity added to the kernel matrix.
            K: Optional precomputed kernel matrix.
            nablaK: Optional precomputed kernel-gradient term.

        Returns:
            Array of shape (n, d): the solution of (K + eta_G * I) G = nablaK.
        """
        n_samples, _ = X.shape
        if K is None or nablaK is None:
            # Both terms are rebuilt together if either one is missing.
            width = _kernel_width(X)
            K = self._evaluate_kernel(X, width)
            nablaK = self._evaluate_nablaK(K, X, width)
        system = K + eta_G * np.eye(n_samples)
        return np.linalg.solve(system, nablaK)

    def _hessian_col(self, X: np.ndarray, G: np.ndarray, c: int, eta: float, K: np.ndarray, s: float) -> np.ndarray:
        """Estimate the c-th Hessian column at every sample. Returns (n, d).

        Args:
            X: Sample matrix of shape (n, d).
            G: Score estimate of shape (n, d).
            c: Index of the column being estimated.
            eta: Coefficient of the identity added to the kernel matrix.
            K: Kernel matrix of shape (n, n).
            s: Kernel width.
        """
        diffs = self._X_diff(X)
        n_samples, _, _ = diffs.shape
        # Row-wise product G[i, c] * G[i, j].
        score_products = G[:, c][:, None] * G
        rhs = np.einsum('ik,ikj,ik->ij', diffs[:, :, c], diffs, K) / (s**4)
        # Extra term that only affects component c of the right-hand side.
        kernel_row_sums = np.einsum('ik->i', K)
        rhs[:, c] -= kernel_row_sums / (s**2)
        system = K + eta * np.eye(n_samples)
        return -score_products + np.linalg.solve(system, rhs)

    def hessian_diagonal(self, X: np.ndarray, eta_G: float, eta_H: float) -> np.ndarray:
        """Estimate only the diagonal of the Hessian of log p(x).

        Args:
            X: Sample matrix of shape (n, d).
            eta_G: Coefficient of the identity used in the score solve.
            eta_H: Coefficient of the identity used in the diagonal solve.

        Returns:
            Array of shape (n, d).
        """
        n_samples, _ = X.shape
        width = _kernel_width(X)
        gram = self._evaluate_kernel(X, width)
        gram_grad = self._evaluate_nablaK(gram, X, width)
        score_est = self.score(X, eta_G, gram, gram_grad)
        diffs = self._X_diff(X)
        second_order_factor = -1 / (width**2) + diffs**2 / (width**4)
        rhs = np.einsum('kij,ik->kj', second_order_factor, gram)
        system = gram + eta_H * np.eye(n_samples)
        return -(score_est**2) + np.linalg.solve(system, rhs)

    def _evaluate_kernel(self, X: np.ndarray, s: float) -> np.ndarray:
        """Return the RBF kernel with gamma = 1 / (2 s^2), divided by s."""
        return _rbf_kernel(X, gamma=1.0 / (2.0 * (s**2))) / s

    def _evaluate_nablaK(self, K: np.ndarray, X: np.ndarray, s: float) -> np.ndarray:
        """Contract the pairwise differences with K and scale by -1 / s^2. Returns (n, d)."""
        contracted = np.einsum('kij,ik->kj', self._X_diff(X), K)
        return -contracted / (s**2)

    def _X_diff(self, X: np.ndarray) -> np.ndarray:
        """Return all pairwise row differences; entry [i, k] is X[i] - X[k]."""
        as_columns = np.expand_dims(X, axis=1)
        return as_columns - X
    

def stein_score_hess_all(X: np.ndarray, eta_G: float = 1e-3, eta_H: float = 1e-3) -> Tuple[np.ndarray, np.ndarray]:
    """Run the Stein score and Hessian estimators separately on each environment.

    Args:
        X: Array of shape (k+1, n, d): environments x samples x features.
        eta_G: Passed to both estimators as the score regulariser.
        eta_H: Passed to the Hessian estimator as its regulariser.

    Returns:
        Tuple ``(scores, hessians)`` of shapes (k+1, n, d) and (k+1, n, d, d).
    """
    n_envs, n_samples, n_features = X.shape
    scores = np.zeros((n_envs, n_samples, n_features), dtype=float)
    hessians = np.zeros((n_envs, n_samples, n_features, n_features), dtype=float)
    estimator = SteinMixin()
    for env, samples in enumerate(X):
        scores[env] = estimator.score(samples, eta_G=eta_G)
        hessians[env] = estimator.hessian(samples, eta_G=eta_G, eta_H=eta_H)
    return scores, hessians


def gaussian_oracle_score_hess(X: np.ndarray, env_params: Dict[int, Dict[str, np.ndarray]]) -> Tuple[np.ndarray, np.ndarray]:
    """Closed-form Gaussian score and Hessian for each environment.

    Args:
        X: Array of shape (k+1, n, d): environments x samples x features.
        env_params: Maps each environment index to a dict holding ``"mu"``
            (mean) and ``"Sigma"`` (covariance) for that environment.

    Returns:
        Tuple ``(scores, hessians)`` of shapes (k+1, n, d) and (k+1, n, d, d).
        For environment e the score rows are ``-(x - mu_e) @ inv(Sigma_e).T``
        and every sample shares the Hessian ``-inv(Sigma_e)``.
    """
    n_envs, n_samples, n_features = X.shape
    scores = np.zeros((n_envs, n_samples, n_features), dtype=float)
    hessians = np.zeros((n_envs, n_samples, n_features, n_features), dtype=float)
    for env in range(n_envs):
        mean = env_params[env]["mu"]
        precision = np.linalg.inv(env_params[env]["Sigma"])
        neg_centered = -(X[env] - mean)
        scores[env] = neg_centered @ precision.T
        # The (d, d) matrix broadcasts over the sample axis.
        hessians[env] = -precision
    return scores, hessians


def find_sigma_per_group(
    X: np.ndarray,
    score_fn: Callable[[np.ndarray], Tuple[np.ndarray, np.ndarray]],
    first_group_index: int = 1,
    *, verbose: bool = False
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Compute per-group sums of Hessian differences using score-guided pairing (Algorithm 1).

    For every environment e >= 1, each sample of environment 0 is matched to
    its nearest neighbour (Euclidean distance) in environment e. Among those
    matches, the one with the smallest score difference is kept, and the
    Hessian difference ``H_hat[0, i0] - H_hat[e, j_e]`` at that pair is added
    to the sum of the group that e belongs to.

    Args:
        X (np.ndarray): Array of shape (k+1, n, d), environments × samples × features.
        score_fn (Callable): Function that takes X and returns (S_hat, H_hat),
            where S_hat is (k+1, n, d) and H_hat is (k+1, n, d, d).
        first_group_index (int): Index (0-based) of the last environment of the
            first group. Group 1 is environments 1..first_group_index and
            group 2 is environments first_group_index+1..k. Must satisfy
            0 <= first_group_index <= k; either group may be empty, in which
            case its sum is all zeros.
        verbose (bool): If True, print pairing details.

    Returns:
        Tuple[np.ndarray, np.ndarray]:
            - Sigma_hats: Array of shape (2, d, d); ``Sigma_hats[0]`` is the
              summed Hessian difference for group 1 and ``Sigma_hats[1]`` for
              group 2.
            - mean_pairs: Array of shape (k, 2), where row e-1 contains the
              (i0, j_e) indices selected for environment e.

    Raises:
        AssertionError: If fewer than three environments are supplied.
        IndexError: If first_group_index is outside [0, k].
    """
    k1, _, d = X.shape
    assert k1 >= 3, "Need at least three environments (k+1 ≥ 3)."
    # Reject bad split points before the (potentially expensive) score_fn call.
    # A negative value would otherwise make `mean_pairs[e - 1]` wrap around for
    # e = 0 and silently produce a wrong result.
    if not 0 <= first_group_index <= k1 - 1:
        raise IndexError(
            f"first_group_index={first_group_index} is out of range; "
            f"expected a value in [0, {k1 - 1}]."
        )
    # 1) per-env scores and Hessians
    S_hat, H_hat = score_fn(X)
    # 2) pairing across envs by nearest neighbor + minimal score-difference
    mean_pairs = np.zeros((k1 - 1, 2), dtype=int)
    for e in range(1, k1):
        # dists[i, j] = ||X[0, i] - X[e, j]||
        dists = np.linalg.norm(X[0, :, None, :] - X[e, None, :, :], axis=-1)
        pairs = np.argmin(dists, axis=1)  # nearest env-e sample for each env-0 sample
        # diffs[i] = ||S_hat[0, i] - S_hat[e, pairs[i]]||
        diffs = np.linalg.norm(S_hat[0] - S_hat[e, pairs, :], axis=1)
        i0 = int(np.argmin(diffs))  # env-0 sample whose match has the closest score
        j_e = int(pairs[i0])  # its matched sample in env e
        mean_pairs[e - 1] = (i0, j_e)  # row e-1 because env 0 has no row
        if verbose:
            print(f"env {e}: picked (i0={i0}, j_e={j_e}); ||Δscore||={diffs[i0]:.3e}")
    # 3) accumulate Hessian diffs, one running sum per group of extra envs
    Sigma_hats = np.zeros((2, d, d), dtype=float)
    group_envs = (
        range(1, first_group_index + 1),   # group 1: envs 1..first_group_index
        range(first_group_index + 1, k1),  # group 2: envs first_group_index+1..k
    )
    for g, envs in enumerate(group_envs):
        for e in envs:
            i0, j_e = mean_pairs[e - 1]
            Sigma_hats[g] += H_hat[0, i0] - H_hat[e, j_e]
    return Sigma_hats, mean_pairs



def sigma_closed_form_per_group(env_params: Dict[int, Dict[str, np.ndarray]], first_group_index:int =1) -> np.ndarray:
    """
    Closed-form per-group Sigma for Gaussian data with identity mixing function.

    For each environment e >= 1 the term ``inv(Sigma_e) - inv(Sigma_0)`` is
    formed, and the terms are summed separately over the two groups of
    environments.

    Args:
        env_params (Dict[int, Dict[str, np.ndarray]]): Maps environment index to a dict with keys
            'mu' (mean, shape (d,)) and 'Sigma' (covariance, shape (d, d)).
        first_group_index (int): Index (0-based) of the last environment of the first group.
            Group 1 is environments 1..first_group_index; group 2 is the remaining ones.

    Returns:
        np.ndarray: Array of shape (2, d, d) with the group-1 sum at index 0
        and the group-2 sum at index 1.
    """
    k1 = len(env_params)
    d = env_params[0]["mu"].shape[0]
    assert k1 >= 3, "Need at least three environments (k+1 ≥ 3)."
    reference_precision = np.linalg.inv(env_params[0]["Sigma"])
    per_group = np.zeros((2, d, d), dtype=float)
    group_envs = (
        range(1, first_group_index + 1),
        range(first_group_index + 1, k1),
    )
    for slot, envs in enumerate(group_envs):
        running = np.zeros_like(reference_precision)
        for e in envs:
            running += np.linalg.inv(env_params[e]["Sigma"]) - reference_precision
        per_group[slot] = running
    return per_group