import torch

import torch
from torch.func import vmap, jacrev

class ProjectWeightedImplicit(torch.autograd.Function):
    """Weighted projection onto {u : h(u, m) = 0} with an implicit-diff backward.

    Forward approximately solves, per sample,

        u* = argmin_u  sum_i (u_i - xi_i)^2 / sigma_i   s.t.  h(u, m) = 0,

    by Newton steps in the whitened coordinates ũ = u / sqrt(sigma) (where the
    weighted distance becomes Euclidean). Every iterate, and the result, is
    clamped to be non-negative.

    Backward differentiates the KKT system of the equality-constrained problem;
    the non-negativity clamp is not accounted for in the gradient.
    """

    # Floor applied to sigma in BOTH forward and backward so the two passes use
    # the same metric and 1/sigma stays finite for zero weights.
    _SIGMA_EPS = 1e-6

    @staticmethod
    def forward(ctx, xi_batch, sigma_batch, m_batch, h_func, max_iter=30):
        """
        Batched Newton‐projection with implicit‐diff backward pass.

        Args:
            xi_batch:    Tensor of shape (B, p), unconstrained points
            sigma_batch: Tensor of shape (B, p), positive weights (values below
                         1e-6 are clamped up to 1e-6)
            m_batch:     Tensor of extra parameters per sample (e.g. (B, …))
            h_func:      callable h(u, m) → Tensor of shape (k,) or scalar
            max_iter:    number of Newton steps

        Returns:
            u_proj: Tensor of shape (B, p), approximate projections with
                    h(u_proj, m) ≈ 0, clamped to be >= 0
        """
        eps = ProjectWeightedImplicit._SIGMA_EPS

        def project_one(xi, sigma, m):
            # ensure positivity so sqrt / reciprocal below are finite
            sigma = sigma.clamp(min=eps)
            sqrt_sigma = sigma.sqrt()
            inv_sqrt = 1.0 / sqrt_sigma

            # transform to tilde‐space
            xi_tilde = xi * inv_sqrt
            u_tilde = xi_tilde.clone()

            def newton_step(u_tilde):
                # back to u‐space
                u = u_tilde * sqrt_sigma
                h = h_func(u, m)
                # make h a vector of shape (k,)
                if h.ndim == 0:
                    h = h.unsqueeze(0)

                # Jacobian ∂h/∂u at u: shape (k, p)
                J_u = jacrev(lambda u_: h_func(u_, m))(u)
                if J_u.ndim == 1:
                    J_u = J_u.unsqueeze(0)
                # chain‐rule for tilde: dh/dũ = J_u * diag(sqrt_sigma)
                J = J_u * sqrt_sigma.unsqueeze(0)

                # Linearised projection in tilde-space:
                #   Δ = δ - Jᵀλ,  (J Jᵀ + 1e-6·I) λ = J δ + h
                delta = (xi_tilde - u_tilde).unsqueeze(-1)              # (p,1)
                JJt = J @ J.transpose(-2, -1)                           # (k,k)
                rhs = J @ delta + h.unsqueeze(-1)                       # (k,1)
                eye_k = torch.eye(JJt.shape[-1], device=JJt.device, dtype=JJt.dtype)
                lam = torch.linalg.solve(JJt + 1e-6 * eye_k, rhs)       # (k,1)

                du = delta - J.transpose(-2, -1) @ lam                  # (p,1)
                return torch.clamp(u_tilde + du.squeeze(-1), min=0.0)

            # run Newton iterations
            for _ in range(max_iter):
                u_tilde = newton_step(u_tilde)

            # back to original space and clamp
            return torch.clamp(u_tilde * sqrt_sigma, min=0.0)

        # vectorize over batch
        u_proj = vmap(project_one)(xi_batch, sigma_batch, m_batch)

        # save for backward
        ctx.save_for_backward(sigma_batch, m_batch, u_proj)
        ctx.h_func = h_func
        return u_proj

    @staticmethod
    def backward(ctx, grad_u):
        """
        Implicit differentiation via KKT system:
          A = [[Σ⁻¹,    Jᵀ],
               [  J,      0]]
        Stationarity & primal feasibility → solve Aᵀ w = [grad_u; 0],
        then ∂u/∂xi^T g = Σ⁻¹ w_u.

        Σ is built from sigma clamped exactly as in forward, so zero weights
        yield finite gradients. Only xi receives a gradient.
        """
        sigma_batch, m_batch, u_proj = ctx.saved_tensors
        h_func = ctx.h_func
        eps = ProjectWeightedImplicit._SIGMA_EPS
        p = u_proj.shape[-1]

        grad_xi_list = []
        for sigma, u, m, g in zip(sigma_batch, u_proj, m_batch, grad_u):
            # J_u = ∂h/∂u at the solution, as (k, p)
            J_u = jacrev(lambda u_: h_func(u_, m))(u)
            if J_u.ndim == 1:
                J_u = J_u.unsqueeze(0)
            k = J_u.shape[0]

            # same floor as forward: otherwise sigma == 0 gives 1/0 = inf
            sigma_inv = 1.0 / sigma.clamp(min=eps)

            # top block: [Σ⁻¹, Jᵀ]; bottom block: [J, 0]
            top = torch.cat([torch.diag(sigma_inv), J_u.transpose(0, 1)], dim=1)  # (p, p+k)
            bottom = torch.cat([J_u, J_u.new_zeros(k, k)], dim=1)                  # (k, p+k)
            A = torch.cat([top, bottom], dim=0)                                    # (p+k, p+k)

            rhs = torch.cat([g, g.new_zeros(k)], dim=0)                           # (p+k,)
            w = torch.linalg.solve(A.transpose(0, 1), rhs)                         # (p+k,)

            # gradient wrt xi: Σ⁻¹ * w_u
            grad_xi_list.append(sigma_inv * w[:p])

        grad_xi_batch = torch.stack(grad_xi_list, dim=0)

        # one gradient per forward argument (xi, sigma, m, h_func, max_iter)
        return grad_xi_batch, None, None, None, None


# ------------------------------------------------------------------------------
# Usage example:

# Suppose:
#   xi_batch   = torch.randn(B, p, requires_grad=True)
#   sigma_batch= torch.rand(B, p)
#   m_batch    = some tensor of shape (B, …)
#   def full_residual(u, m):  return ...  # constraint residual of shape (k,), or a scalar

# Then project + differentiate:
# u_proj = ProjectWeightedImplicit.apply(xi_batch, sigma_batch, m_batch, full_residual, 30)
# loss = some_loss(u_proj)
# loss.backward()
# ------------------------------------------------------------------------------


class ProjectOrthoImplicit(torch.autograd.Function):
    """Unit-weight (Euclidean) projection onto {u : h(u, m) = 0}, u >= 0.

    Same interface as ProjectWeightedImplicit, but the weights ``sigma_batch``
    are ignored: every coordinate gets weight 1. Forward runs Newton steps on

        u* = argmin_u ||u - xi||^2   s.t.  h(u, m) = 0,

    clamping each iterate to be non-negative. Backward differentiates the
    unit-weight KKT system; the non-negativity clamp is not accounted for.
    """

    @staticmethod
    def forward(ctx, xi_batch, sigma_batch, m_batch, h_func, max_iter=30):
        """
        Batched Newton‐projection with implicit‐diff backward pass.

        Args:
            xi_batch:    Tensor of shape (B, p), unconstrained points
            sigma_batch: Tensor of shape (B, p); accepted for interface
                         compatibility with ProjectWeightedImplicit, ignored
            m_batch:     Tensor of extra parameters per sample (e.g. (B, …))
            h_func:      callable h(u, m) → Tensor of shape (k,) or scalar
            max_iter:    number of Newton steps

        Returns:
            u_proj: Tensor of shape (B, p), approximate projections with
                    h(u_proj, m) ≈ 0, clamped to be >= 0
        """

        def project_one(xi, m):
            u = xi.clone()

            def newton_step(u):
                h = h_func(u, m)
                # make h a vector of shape (k,)
                if h.ndim == 0:
                    h = h.unsqueeze(0)

                # Jacobian ∂h/∂u at u: shape (k, p)
                J = jacrev(lambda u_: h_func(u_, m))(u)
                if J.ndim == 1:
                    J = J.unsqueeze(0)

                # Linearised projection: Δ = δ - Jᵀλ, (J Jᵀ + 1e-6·I) λ = J δ + h
                delta = (xi - u).unsqueeze(-1)                          # (p,1)
                JJt = J @ J.transpose(-2, -1)                           # (k,k)
                rhs = J @ delta + h.unsqueeze(-1)                       # (k,1)
                eye_k = torch.eye(JJt.shape[-1], device=JJt.device, dtype=JJt.dtype)
                lam = torch.linalg.solve(JJt + 1e-6 * eye_k, rhs)       # (k,1)

                du = delta - J.transpose(-2, -1) @ lam                  # (p,1)
                return torch.clamp(u + du.squeeze(-1), min=0.0)

            for _ in range(max_iter):
                u = newton_step(u)
            return torch.clamp(u, min=0.0)

        # vectorize over batch (sigma_batch intentionally unused)
        u_proj = vmap(project_one)(xi_batch, m_batch)

        ctx.save_for_backward(m_batch, u_proj)
        ctx.h_func = h_func
        return u_proj

    @staticmethod
    def backward(ctx, grad_u):
        """
        Implicit differentiation via the unit-weight KKT system:
          A = [[I,  Jᵀ],
               [J,  0 ]]
        solve Aᵀ w = [grad_u; 0]; the xi-gradient is w_u.

        Using the identity directly (instead of sigma * (1/sigma)) keeps the
        gradient finite when sigma contains zeros or non-finite values.
        """
        m_batch, u_proj = ctx.saved_tensors
        h_func = ctx.h_func
        p = u_proj.shape[-1]

        grad_xi_list = []
        for u, m, g in zip(u_proj, m_batch, grad_u):
            # J_u = ∂h/∂u at the solution, as (k, p)
            J_u = jacrev(lambda u_: h_func(u_, m))(u)
            if J_u.ndim == 1:
                J_u = J_u.unsqueeze(0)
            k = J_u.shape[0]

            eye_p = torch.eye(p, device=J_u.device, dtype=J_u.dtype)
            top = torch.cat([eye_p, J_u.transpose(0, 1)], dim=1)        # (p, p+k)
            bottom = torch.cat([J_u, J_u.new_zeros(k, k)], dim=1)       # (k, p+k)
            A = torch.cat([top, bottom], dim=0)                         # (p+k, p+k)

            rhs = torch.cat([g, g.new_zeros(k)], dim=0)                # (p+k,)
            w = torch.linalg.solve(A.transpose(0, 1), rhs)              # (p+k,)
            grad_xi_list.append(w[:p])

        grad_xi_batch = torch.stack(grad_xi_list, dim=0)

        # one gradient per forward argument (xi, sigma, m, h_func, max_iter)
        return grad_xi_batch, None, None, None, None


class ProjectOrthImplicit(torch.autograd.Function):
    """Euclidean projection onto {u : h(u, m) = 0} with an implicit-diff VJP.

    No non-negativity clamp is applied (unlike ProjectOrthoImplicit).
    """

    @staticmethod
    def _jacobian_2d(h_func, u, m):
        """Return ∂h/∂u at ``u`` as a 2-D (k, n) tensor; a scalar h gives k = 1."""
        jac = jacrev(lambda v: h_func(v, m))(u)
        return jac.unsqueeze(0) if jac.ndim == 1 else jac

    @staticmethod
    def forward(ctx, xi_batch, m_batch, h_func, max_iter: int = 10, reg: float = 1e-6):
        """
        Orthogonal projection via eliminated‐lambda Newton steps:

            u* = argmin ||u - xi||^2   s.t. h(u)=0

        Each of the ``max_iter`` steps, starting from u = xi, does:
            gap    = xi - u
            schur  = J Jᵀ + reg·I                (k × k)
            mult   = schur⁻¹ (J gap + h)        (pinv fallback if solve fails)
            u     += gap - Jᵀ mult
        """

        def newton_project(xi, m):
            u = xi.clone()
            for _ in range(max_iter):
                residual = h_func(u, m)
                if residual.ndim == 0:
                    residual = residual.unsqueeze(0)              # (k,)
                jac = ProjectOrthImplicit._jacobian_2d(h_func, u, m)  # (k, n)

                gap = (xi - u).unsqueeze(-1)                      # (n,1)
                schur = jac @ jac.transpose(0, 1)
                schur = schur + reg * torch.eye(
                    schur.shape[0], device=schur.device, dtype=schur.dtype
                )
                target = (jac @ gap) + residual.unsqueeze(-1)     # (k,1)

                try:
                    multiplier = torch.linalg.solve(schur, target)
                except RuntimeError:
                    # singular / ill-conditioned Schur matrix
                    multiplier = torch.linalg.pinv(schur) @ target

                step = gap - jac.transpose(0, 1) @ multiplier     # (n,1)
                u = u + step.squeeze(-1)
            return u

        projected = vmap(newton_project)(xi_batch, m_batch)
        ctx.save_for_backward(projected, m_batch)
        ctx.h_func = h_func
        ctx.max_iter = max_iter
        ctx.reg = reg
        return projected

    @staticmethod
    def backward(ctx, grad_u):
        """
        Implicit‐diff VJP for the orthogonal projector.

        With P = I - Jᵀ S⁻¹ J and S = J Jᵀ + reg·I (J evaluated at u*),
        P is symmetric, so  grad_ξ = P grad_u = grad_u - Jᵀ [S⁻¹ (J grad_u)].
        """
        projected, m_batch = ctx.saved_tensors
        h_func = ctx.h_func
        reg = ctx.reg

        def project_grad(u, m, g):
            jac = ProjectOrthImplicit._jacobian_2d(h_func, u, m)  # (k, n)
            schur = jac @ jac.transpose(0, 1)
            schur = schur + reg * torch.eye(
                schur.shape[0], device=schur.device, dtype=schur.dtype
            )
            try:
                schur_inv = torch.linalg.inv(schur)
            except RuntimeError:
                schur_inv = torch.linalg.pinv(schur)

            col = g.unsqueeze(-1)                                 # (n,1)
            alpha = schur_inv @ (jac @ col)                       # (k,1)
            return (col - jac.transpose(0, 1) @ alpha).squeeze(-1)

        per_sample = [
            project_grad(u, m, g) for u, m, g in zip(projected, m_batch, grad_u)
        ]
        return torch.stack(per_sample, dim=0), None, None, None, None


# ----------------------------------------------------------------------------
# Usage:

# xi_batch:  (B, n) input points (requires_grad=True)
# m_batch:   (B, …) extra parameters for h_func
# def h_func(u, m): return constraints of shape (k,)

# project:
# u_proj = ProjectOrthImplicit.apply(xi_batch, m_batch, h_func, max_iter=10)

# # backprop:
# loss = some_loss(u_proj)
# loss.backward()



def safe_diag_JSigmaJT(sigma: torch.Tensor,
                       J_u: torch.Tensor,
                       reg: float = 1e-1,
                       eps: float = 1e-6) -> torch.Tensor:
    """
    Robustly compute the diagonal of  Σ - Σ Jᵀ S⁻¹ J Σ  for one sample, where
      Σ = diag(sigma)  and  S = J Σ Jᵀ + (reg + eps)·I,  J = J_u.

    Note: despite the name, the result is NOT diag(J Σ Jᵀ). It is Σ with the
    directions constrained by J removed (a Schur complement). If Σ is read as a
    Gaussian covariance, this is the covariance conditioned on the linearised
    constraint.

    Args:
      sigma: (p,) — diagonal of Σ (prior variances); clamped to >= eps
      J_u:   (k, p) — ∂h/∂u at the solution
      reg:   jitter added (together with eps) to the diagonal of S
      eps:   minimum allowed variance on sigma

    Returns:
      (p,) — diag(Σ - Σ Jᵀ S⁻¹ J Σ). Not clamped: round-off can make entries
      slightly negative.
    """
    # 1) clamp sigma so no zeros
    sigma = sigma.clamp(min=eps)

    # 2) build and regularize small matrix S = J_u Σ J_u^T
    S = (J_u * sigma.unsqueeze(0)) @ J_u.transpose(0, 1)     # (k, k)
    S = S + (reg + eps) * torch.eye(S.shape[0], device=S.device, dtype=S.dtype)

    # 3) T = S⁻¹ J_u via a linear solve (more stable than forming S⁻¹),
    #    with a pseudo-inverse fallback if the solve fails
    try:
        T = torch.linalg.solve(S, J_u)                        # (k, p)
    except RuntimeError:
        T = torch.linalg.pinv(S) @ J_u

    # 4) diag(Jᵀ S⁻¹ J)_i = sum_r J[r, i] * T[r, i];  result = σ - σ² ⊙ that
    quad = (J_u * T).sum(dim=0)                               # (p,)
    return sigma - sigma ** 2 * quad


def batch_safe_diag_JSigmaJT(u_proj: torch.Tensor,
                             sigma_batch: torch.Tensor,
                             m_batch: torch.Tensor,
                             h_func,
                             reg: float = 1e-1,
                             eps: float = 1e-6) -> torch.Tensor:
    """
    Vectorized safe_diag_JSigmaJT over a batch, with pre-/post-cleaning.

    For each sample, builds J = ∂h/∂u at u_proj and returns
    diag(Σ - Σ Jᵀ (J Σ Jᵀ + (reg + eps)·I)⁻¹ J Σ) with Σ = diag(sigma).
    This is not diag(J Σ Jᵀ), despite the function name.

    Args:
      u_proj:      (B, p) — projected outputs u*
      sigma_batch: (B, p) — input variances (diagonals of Σ)
      m_batch:     (B, …) — extra per-sample params for h_func
      h_func:      callable h(u, m) → (k,) or scalar
      reg:         jitter for small-matrix regularization
      eps:         variance floor
    Returns:
      (B, p) — per-sample variances. NaN/±Inf are replaced by eps, and every
      entry is clamped to >= eps.
    """
    # Pre-clean: replace NaN/Inf with eps and floor the input variances
    sigma_batch = sigma_batch.nan_to_num(nan=eps, posinf=eps, neginf=eps)
    sigma_batch = sigma_batch.clamp(min=eps)

    def one(sigma, u, m):
        # constraint Jacobian at the projected point, as (k, p)
        J_u = jacrev(lambda u_: h_func(u_, m))(u)
        if J_u.ndim == 1:
            J_u = J_u.unsqueeze(0)
        return safe_diag_JSigmaJT(sigma, J_u, reg=reg, eps=eps)

    # Vectorize over batch
    variances = vmap(one)(sigma_batch, u_proj, m_batch)

    # Post-clean: round-off can push entries negative, and h_func may yield
    # non-finite Jacobians
    variances = variances.nan_to_num(nan=eps, posinf=eps, neginf=eps)
    variances = variances.clamp(min=eps)

    return variances


# Example integration into your pipeline:

def project_and_stats(mu_batch: torch.Tensor,
                      var_batch: torch.Tensor,
                      m_batch: torch.Tensor,
                      h_func,
                      max_iter: int = 30,
                      reg: float = 1e-6,
                      eps: float = 1e-6):
    """
    Weighted projection of means plus safe variance update.

    1) new_mu = ProjectWeightedImplicit(mu_batch, var_batch, m_batch, h_func),
       i.e. a projection weighted by 1/var_batch, differentiable w.r.t. mu_batch.
    2) new_var = batch_safe_diag_JSigmaJT at new_mu, using the same variances.

    Args:
      mu_batch:  (B, p) — means to project
      var_batch: (B, p) — variances; used as projection weights and as Σ
      m_batch:   (B, …) — extra params for h_func(u, m)
      h_func:    callable h(u, m) → (k,) or scalar
      max_iter:  Newton steps for the projection
      reg, eps:  jitter and variance floor for the variance computation
                 (defaults match the previously hard-coded values)
    Returns:
      (new_mu, new_var), both (B, p)
    """
    # 1) Project to constraints
    new_mu = ProjectWeightedImplicit.apply(mu_batch, var_batch, m_batch, h_func, max_iter)
    # 2) Safe constrained variances at the projected point
    new_var = batch_safe_diag_JSigmaJT(new_mu, var_batch, m_batch, h_func,
                                       reg=reg, eps=eps)
    return new_mu, new_var


# And if you want a single “safe” entrypoint:
def safe_project_and_stats(mu, var, m, h_func, max_iter=30, eps_sigma=1e-6):
    """
    Call project_and_stats, then sanitize its variances.

    Non-finite variances (NaN, ±Inf) are replaced by ``eps_sigma``, and every
    variance is floored at ``eps_sigma``. The projected means are returned
    untouched.
    """
    projected_mean, projected_var = project_and_stats(mu, var, m, h_func, max_iter)
    floor = eps_sigma
    finite_var = torch.nan_to_num(projected_var, nan=floor, posinf=floor, neginf=floor)
    return projected_mean, torch.clamp(finite_var, min=floor)


def safe_diag_PSigmaP(sigma: torch.Tensor,
                      J_u: torch.Tensor,
                      reg: float = 1e-6,
                      eps: float = 1e-6) -> torch.Tensor:
    """
    Diagonal of P Σ P for a single sample.

    Definitions:
      Σ = diag(max(sigma, eps))
      P = I - Jᵀ (J Jᵀ + reg·I)⁻¹ J        with J = J_u   (symmetric)
    Because P is symmetric, (P Σ P)_ii = Σ_j P_ij² σ_j.

    Args:
      sigma: (n,) — diagonal of Σ (prior variances)
      J_u:   (k, n) — ∂h/∂u at the projected solution
      reg:   jitter added to J Jᵀ before inverting
      eps:   floor for sigma and for the returned values

    Returns:
      (n,) — diag(P Σ P), clamped to >= eps
    """
    variances = sigma.clamp(min=eps)

    # Regularised Gram matrix of the constraint rows, then its inverse
    # (pseudo-inverse if the plain inverse fails)
    jac_t = J_u.transpose(0, 1)
    gram = J_u @ jac_t
    gram = gram + reg * torch.eye(gram.shape[0], device=gram.device, dtype=gram.dtype)
    try:
        gram_inv = torch.linalg.inv(gram)
    except RuntimeError:
        gram_inv = torch.linalg.pinv(gram)

    # Hat matrix Jᵀ G⁻¹ J and its complement P
    hat = jac_t @ (gram_inv @ J_u)                                 # (n, n)
    dim = variances.shape[0]
    projector = torch.eye(dim, device=hat.device, dtype=hat.dtype) - hat

    # Row-wise Σ_j P_ij² σ_j
    weighted = projector.pow(2) * variances.unsqueeze(0)
    return weighted.sum(dim=1).clamp(min=eps)


def batch_safe_diag_PSigmaP(u_proj: torch.Tensor,
                            sigma_batch: torch.Tensor,
                            m_batch: torch.Tensor,
                            h_func,
                            reg: float = 1e-6,
                            eps: float = 1e-6) -> torch.Tensor:
    """
    Vectorized safe diag(P Σ P) over the batch.

    Args:
      u_proj:      (B, n) — projected outputs u*
      sigma_batch: (B, n) — input variances σ
      m_batch:     (B, …) — extra per-sample params for h_func
      h_func:      callable h(u, m) → (k,) or scalar
      reg:         jitter for S
      eps:         variance floor
    Returns:
      (B, n) — diag(P Σ P) per sample. NaN/±Inf are replaced by eps, and
      every entry is clamped to >= eps.
    """
    # pre‐clean inputs
    sigma_batch = sigma_batch.nan_to_num(nan=eps, posinf=eps, neginf=eps)
    sigma_batch = sigma_batch.clamp(min=eps)

    def one(sigma, u, m):
        # build constraint Jacobian at u*, as (k, n)
        J_u = jacrev(lambda u_: h_func(u_, m))(u)
        if J_u.ndim == 1:
            J_u = J_u.unsqueeze(0)
        return safe_diag_PSigmaP(sigma, J_u, reg=reg, eps=eps)

    variances = vmap(one)(sigma_batch, u_proj, m_batch)

    # post‐clean (as in batch_safe_diag_JSigmaJT): a non-finite Jacobian from
    # h_func yields NaN/Inf, and clamp() alone does not remove NaN
    variances = variances.nan_to_num(nan=eps, posinf=eps, neginf=eps)
    return variances.clamp(min=eps)


def project_and_stats_orth(mu_batch: torch.Tensor,
                           var_batch: torch.Tensor,
                           m_batch: torch.Tensor,
                           h_func,
                           max_iter: int = 10,
                           reg: float = 1e-6,
                           eps: float = 1e-6):
    """
    Unit-weight projection of the means plus posterior variances diag(P Σ P).

    1) new_mu = ProjectOrthoImplicit(mu_batch, var_batch, m_batch, h_func, max_iter):
       u* ≈ argmin ||u - mu||² s.t. h(u, m) = 0, with iterates clamped to be
       non-negative. var_batch is passed through but ignored by that projection,
       and ``reg`` is NOT used there (it applies its own fixed regularisation).
       Gradients flow to mu_batch via its implicit-diff backward.
    2) new_var = batch_safe_diag_PSigmaP at new_mu, with Σ = diag(var_batch).

    Args:
      mu_batch:   (B, n) — prior means ξ
      var_batch:  (B, n) — prior variances σ (diag Σ)
      m_batch:    (B, …) — extra params for h_func(u,m)
      h_func:     callable h(u, m) → (k,) constraints
      max_iter:   Newton steps for projection
      reg, eps:   jitter & floor for the variance computation only
    Returns:
      new_mu: (B, n) — projected means u*
      new_var:(B, n) — diag(P Σ P)
    """
    # 1) project to constraint manifold
    new_mu = ProjectOrthoImplicit.apply(mu_batch, var_batch, m_batch, h_func, max_iter)
    # 2) compute safe diag(P Σ P)
    new_var = batch_safe_diag_PSigmaP(new_mu, var_batch, m_batch, h_func,
                                      reg=reg, eps=eps)
    return new_mu, new_var
