import torch
from torch import nn
from typing import Callable, Optional

def _landmark_rows_leverage(
        E: torch.Tensor,                # (n, e) on CUDA
        l: int,                         # how many landmarks
        seed: Optional[int] = None,     # reproducible sampling
        rank: Optional[int] = None,     # target spectral rank (defaults ↓)
        oversample: int = 5,            # extra basis vectors for safety
        n_iter: int = 2                 # power iterations in svd_lowrank
) -> torch.LongTensor:
    if seed is not None:
        torch.manual_seed(seed)

    n, e = E.shape
    if rank is None:
        rank = min(l // 2, 256, e)

    U, _, _ = torch.svd_lowrank(E, q=rank + oversample, niter=n_iter)

    lev = U.pow(2).sum(dim=1)
    lev /= lev.sum()

    idx = torch.multinomial(lev, l, replacement=False)
    return idx.detach().cpu().numpy()

def _pairwise_sq_dists(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Compute ‖a_i‑b_j‖² without forming (m,m,d) tensors (GPU‑friendly)."""
    # a: (m,d), b: (n,d) (contiguous not required)
    a_norm = (a ** 2).sum(-1, keepdim=True)          # (m,1)
    b_norm = (b ** 2).sum(-1, keepdim=True).T         # (1,n)
    # dist² = |a|² + |b|² − 2a·b
    return a_norm + b_norm - 2 * (a @ b.T)            # (m,n)


def _rbf_kernel(
    a: torch.Tensor,
    b: torch.Tensor,
    sigma: float,
    *,
    chunk: Optional[int] = None,
) -> torch.Tensor:
    """Gaussian (RBF) kernel matrix K[i, j] = exp(-‖a_i − b_j‖² / (2σ²)).

    When ``chunk`` is given and ``a`` has more rows than ``chunk``, the rows of
    ``a`` are processed in consecutive blocks of at most ``chunk`` rows and the
    per-block results are concatenated along dim 0; otherwise the kernel is
    evaluated in one shot.
    """
    two_sigma_sq = 2 * sigma ** 2

    def _gauss(rows: torch.Tensor) -> torch.Tensor:
        # Kernel values between the given rows of `a` and every row of `b`.
        return torch.exp(-_pairwise_sq_dists(rows, b) / two_sigma_sq)

    if chunk is None:
        return _gauss(a)

    total = a.size(0)
    if total <= chunk:
        return _gauss(a)

    # Slicing past the end is clamped, so the last block may be shorter.
    blocks = [_gauss(a[lo:lo + chunk]) for lo in range(0, total, chunk)]
    return torch.cat(blocks, dim=0)


def _cosine_kernel(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    a_norm = torch.nn.functional.normalize(a, dim=-1)
    b_norm = torch.nn.functional.normalize(b, dim=-1)
    return a_norm @ b_norm.T


def _median_heuristic(x: torch.Tensor, max_samples: int = 1000) -> float:
    """Return median pairwise distance (√) for RBF bandwidth heuristic."""
    m = x.size(0)
    if m > max_samples:
        idx = torch.randperm(m, device=x.device)[:max_samples]
        x = x[idx]
    d2 = _pairwise_sq_dists(x, x)
    i, j = torch.triu_indices(d2.size(0), d2.size(1), offset=1, device=x.device)
    median = torch.median(d2[i, j]).sqrt().item()
    return max(median, 1e-6)  # avoid zero


def _landmark_p_to_full_krr(
    landmark_idx: torch.Tensor, # (l, )
    embeddings: torch.Tensor, # (n, e)
    p_L: torch.Tensor, # (l, )
    damp: float = 1e-2,
    sigma: Optional[float] = None,
    chunk_size: int = 1024,
):
    """Extend landmark values ``p_L`` to every row via RBF kernel ridge regression.

    Solves ``(K + damp * mean(diag(K)) * I) v = p_L`` where ``K`` is the RBF
    kernel among the landmark embeddings, then predicts ``K(E, L) @ v`` for all
    rows of ``embeddings``.

    Args:
        landmark_idx: Indices of the landmark rows within ``embeddings``.
        embeddings: (n, e) embedding matrix.
        p_L: (l,) values observed at the landmarks.
        damp: Ridge coefficient, scaled by the mean diagonal of ``K``.
        sigma: RBF bandwidth; when ``None`` it is set by the median heuristic
            on the landmark embeddings.
        chunk_size: Maximum number of embedding rows whose kernel block is
            materialised at once. ``None`` disables chunking.

    Returns:
        (n,) tensor of predictions, cast to ``p_L.dtype``.
    """
    L = embeddings[landmark_idx].float() # (l, e)
    if sigma is None:
        sigma = _median_heuristic(L)

    K = _rbf_kernel(L, L, sigma, chunk=chunk_size)  # (l, l)
    K.diagonal().add_(damp * K.diag().mean())
    v = torch.linalg.solve(K, p_L.float())

    n = embeddings.size(0)
    if chunk_size is None or n <= chunk_size:
        p = _rbf_kernel(embeddings.float(), L, sigma) @ v
    else:
        # Multiply each kernel block by v right away so only a
        # (chunk_size, l) block is alive at any time, rather than the full
        # (n, l) cross-kernel and a full float copy of the embeddings.
        p = torch.cat(
            [
                _rbf_kernel(embeddings[s:s + chunk_size].float(), L, sigma) @ v
                for s in range(0, n, chunk_size)
            ],
            dim=0,
        )
    return p.to(p_L.dtype)

def _estimate_projected_grads_krrf(
        embds: torch.Tensor,           # (N, d_emb)
        landmark_idx: torch.Tensor,    # (m,)  indices into dim‑0 of embds
        landmark_grads: torch.Tensor,  # (m, d_proj)
        damp: float = 1e-2,           # ridge coefficient λ
        sigma: Optional[float] = None, # RBF kernel bandwidth
        chunk_size: int = 1024,        # chunk size
) -> torch.Tensor:
    """Estimate projected gradients for all rows via RBF kernel ridge regression.

    Solves ``(K + damp * mean(diag(K)) * I) V = landmark_grads`` where ``K`` is
    the RBF kernel among the landmark embeddings, then predicts
    ``K(embds, L) @ V`` for every row of ``embds``.

    Args:
        embds: (N, d_emb) embedding matrix.
        landmark_idx: Indices of the landmark rows within ``embds``.
        landmark_grads: (m, d_proj) gradients observed at the landmarks.
        damp: Ridge coefficient, scaled by the mean diagonal of ``K``.
        sigma: RBF bandwidth; when ``None`` it is set by the median heuristic
            on the landmark embeddings.
        chunk_size: Maximum number of embedding rows whose kernel block is
            materialised at once. ``None`` disables chunking.

    Returns:
        (N, d_proj) float32 tensor of estimated gradients.
    """
    L = embds[landmark_idx].float() # (m, d_emb)
    if sigma is None:
        sigma = _median_heuristic(L)

    K = _rbf_kernel(L, L, sigma, chunk=chunk_size)  # (m, m)
    K.diagonal().add_(damp * K.diag().mean())
    V = torch.linalg.solve(K, landmark_grads.float())

    n_rows = embds.size(0)
    if chunk_size is None or n_rows <= chunk_size:
        return _rbf_kernel(embds.float(), L, sigma) @ V

    # Multiply each kernel block by V right away so only a (chunk_size, m)
    # block is alive at any time, rather than the full (N, m) cross-kernel
    # and a full float copy of the embeddings.
    return torch.cat(
        [
            _rbf_kernel(embds[s:s + chunk_size].float(), L, sigma) @ V
            for s in range(0, n_rows, chunk_size)
        ],
        dim=0,
    )