import torch
import numpy as np
import logging
from torch.func import vmap, grad, jacrev, jvp

# Functional autodiff helpers: each one wraps a per-sample transform in vmap so
# that the resulting function is applied along the leading dimension of its input.
def batched_grad(f):
    """Return ``vmap(grad(f))``: the gradient of ``f`` evaluated per batch element."""
    return vmap(grad(f))


def batched_jacobian(f):
    """Return ``vmap(jacrev(f))``: the reverse-mode Jacobian of ``f`` per batch element."""
    return vmap(jacrev(f))

class Manifold_general:
    """
    A general manifold with equality h(x)=0 and inequality g(x)<=0 constraints.
    This class uses an internal cache and fully vectorized operations to maximize performance.
    """
    def __init__(self, dim, m, l, h, g, grad_h=None, grad_g=None, boundary_repulsion_rate=0.1):
        """
        Register the constraint functions, their derivatives and an empty geometry cache.

        Args:
            dim: Ambient dimension; also exposed as ``out_dim``.
            m: Number of equality constraints.
            l: Number of inequality constraints.
            h: Per-sample equality constraint function (wrapped in ``vmap`` when ``m > 0``).
            g: Per-sample inequality constraint function (wrapped in ``vmap`` when ``l > 0``).
            grad_h: Optional batched Jacobian of ``h``; derived with autodiff when omitted.
            grad_g: Optional batched Jacobian of ``g``; derived with autodiff when omitted.
            boundary_repulsion_rate: Stored as ``epsilon``, the offset added to the
                inequality values by the correction and projection routines.
        """
        def empty_jacobian(x):
            # Stand-in Jacobian for an absent constraint family: shape (batch, 0, dim).
            return torch.empty(x.shape[0], 0, self.dim, device=x.device, dtype=x.dtype)

        def resolve_jacobian(user_fn, single_fn, count):
            # Priority: caller-supplied Jacobian, then autodiff, then the empty stand-in.
            if user_fn is not None:
                return user_fn
            if count > 0:
                return batched_jacobian(single_fn)
            return empty_jacobian

        self.dim = dim
        self.m = m  # number of equality constraints
        self.l = l  # number of inequality constraints
        self.out_dim = dim  # kept for compatibility

        # Un-batched originals are what the autodiff transforms operate on.
        self.h_single = h
        self.g_single = g

        # Batched evaluators used internally; None when the family is empty.
        self.h = None
        if m > 0:
            self.h = vmap(h)
        self.g = None
        if l > 0:
            self.g = vmap(g)

        self.cut_off = 0.00
        self.epsilon = boundary_repulsion_rate

        self.grad_h = resolve_jacobian(grad_h, self.h_single, m)
        self.grad_g = resolve_jacobian(grad_g, self.g_single, l)

        self.hvp_h = self._create_batched_hvp_fn(self.h_single, self.m)
        self.hvp_g = self._create_batched_hvp_fn(self.g_single, self.l)

        # Geometry cache, filled lazily by _update_geometry_cache.
        for cache_slot in (
            "_cached_x",
            "_cached_nabla_J",
            "_cached_gram_matrix",
            "_cached_active_constraints_mask",
            "_cached_J_values",
            "_cached_mean_curvature",
        ):
            setattr(self, cache_slot, None)

    def _create_batched_hvp_fn(self, func, num_constraints):
        """
        Build a batched Hessian-vector-product function for a vector-valued constraint.

        Args:
            func: Per-sample function whose output is indexed with ``[j]`` for
                ``j`` in ``range(num_constraints)``.
            num_constraints: Number of output components to differentiate.

        Returns:
            A callable ``(point, tangent)``. For every batch element it stacks, over
            ``j``, the JVP of ``grad(func(.)[j])`` at ``point`` along ``tangent``.
            With zero constraints the callable returns an empty
            ``(batch, 0, point.shape[1])`` tensor instead.
        """
        if num_constraints == 0:
            def empty_hvp(point, tangent):
                return torch.empty(point.shape[0], 0, point.shape[1], device=point.device, dtype=point.dtype)
            return empty_hvp

        def component_hvp(j, point, tangent):
            # Forward-over-reverse: push `tangent` through the gradient of component j.
            def component(p):
                return func(p)[j]
            return jvp(grad(component), (point,), (tangent,))[1]

        def hvp_per_sample(point, tangent):
            # Plain Python loop over constraints; only the batch axis is vmapped.
            return torch.stack([component_hvp(j, point, tangent) for j in range(num_constraints)])

        return vmap(hvp_per_sample)


    def _update_geometry_cache(self, x):
        """
        Compute and cache the constraint geometry at the batch of points ``x``.

        The work is skipped when ``x`` matches the cached key (same shape, dtype,
        device and ``torch.allclose`` content). Otherwise the following are stored:

        * ``_cached_nabla_J``: ``(batch, m + l, dim)`` constraint Jacobians, equality
          rows first; rows of inactive constraints are zeroed.
        * ``_cached_J_values``: ``(batch, m + l)`` raw values of ``h`` and ``g``.
        * ``_cached_active_constraints_mask``: boolean mask; equality constraints are
          always active, inequality ones where ``g >= -cut_off``.
        * ``_cached_gram_matrix``: ``nabla_J @ nabla_J^T`` with ``1e-6`` added to the
          diagonal.

        The cached mean curvature is invalidated on every recomputation.

        Args:
            x: Tensor of shape ``(batch, dim)``.
        """
        cached_x = self._cached_x
        # dtype/device are compared first: allclose would raise on a mismatch
        # rather than report "different point".
        if (
            cached_x is not None
            and cached_x.shape == x.shape
            and cached_x.dtype == x.dtype
            and cached_x.device == x.device
            and torch.allclose(cached_x, x, atol=1e-7)
        ):
            return

        bsz, dim = x.shape
        total_constraints = self.m + self.l

        nabla_J = torch.zeros(bsz, total_constraints, dim, device=x.device, dtype=x.dtype)
        J_values = torch.zeros(bsz, total_constraints, device=x.device, dtype=x.dtype)
        active_mask = torch.zeros(bsz, total_constraints, dtype=torch.bool, device=x.device)

        if self.m > 0:
            nabla_J[:, :self.m, :] = self.grad_h(x)
            J_values[:, :self.m] = self.h(x)
            active_mask[:, :self.m] = True

        if self.l > 0:
            g_vals = self.g(x)
            nabla_J[:, self.m:, :] = self.grad_g(x)
            J_values[:, self.m:] = g_vals  # keep the true g(x), even when inactive
            active_mask[:, self.m:] = g_vals >= -self.cut_off

        # Inactive constraints must not contribute to the Gram matrix or projections.
        nabla_J[~active_mask] = 0.0

        gram_matrix = torch.bmm(nabla_J, nabla_J.transpose(1, 2))
        # Small diagonal shift keeps the system solvable when rows are zero or dependent.
        gram_matrix.diagonal(dim1=-2, dim2=-1).add_(1e-6)

        # Store a private copy of the key: if the caller later mutates `x` in place,
        # an aliased key would still compare equal and serve stale geometry.
        self._cached_x = x.detach().clone()
        self._cached_nabla_J = nabla_J
        self._cached_gram_matrix = gram_matrix
        self._cached_active_constraints_mask = active_mask
        self._cached_J_values = J_values
        self._cached_mean_curvature = None

    def compute_mean_curvature(self, x, n_hutchinson_samples=5):
        """
        Estimate the mean curvature vector at ``x`` with random tangent probes.

        For each probe a Gaussian vector is projected onto the tangent space and the
        quadratic form ``z^T Hess(J_k) z`` is accumulated for every constraint ``k``
        via the stored HVP functions. The averaged values (zeroed for inactive
        constraints) are solved against the cached Gram matrix and mapped back
        through ``-nabla_J^T``.

        Args:
            x: Tensor of shape ``(batch, dim)``.
            n_hutchinson_samples: Number of random probes to average over.

        Returns:
            Tensor shaped like ``x``. The result is memoised in the geometry cache
            and returned as-is while the cache stays valid for ``x``.
        """
        self._update_geometry_cache(x)

        memoised = self._cached_mean_curvature
        if memoised is not None:
            return memoised

        jac = self._cached_nabla_J
        gram = self._cached_gram_matrix
        active = self._cached_active_constraints_mask

        # No active constraint anywhere in the batch: the curvature is zero.
        if not active.any():
            self._cached_mean_curvature = torch.zeros_like(x)
            return self._cached_mean_curvature

        n_batch, n_dim = x.shape
        inactive = ~active
        trace_sum = torch.zeros(n_batch, jac.shape[1], device=x.device, dtype=x.dtype)

        for _ in range(n_hutchinson_samples):
            noise = torch.randn(n_batch, n_dim, device=x.device, dtype=x.dtype)
            tangent = self.project_onto_tangent_space(noise, x)

            hvps = torch.zeros_like(jac)
            if self.m > 0:
                hvps[:, :self.m, :] = self.hvp_h(x, tangent)
            if self.l > 0:
                hvps[:, self.m:, :] = self.hvp_g(x, tangent)
            hvps[inactive] = 0.0

            # tangent^T (Hess_k tangent) for every constraint k.
            trace_sum += (tangent.unsqueeze(1) * hvps).sum(dim=-1)

        mean_traces = trace_sum / n_hutchinson_samples
        mean_traces[inactive] = 0.0

        weights = torch.linalg.lstsq(gram, mean_traces.unsqueeze(-1)).solution
        curvature = -torch.bmm(jac.transpose(1, 2), weights).squeeze(-1)

        self._cached_mean_curvature = curvature
        return curvature
    
    def project_onto_tangent_space(self, y, base_point):
        """
        Remove from ``y`` its component along the active constraint gradients at ``base_point``.

        Uses the cached Jacobian and (regularised) Gram matrix:
        ``y - nabla_J^T G^{-1} nabla_J y``, with the inner system solved by least squares.

        Args:
            y: Batch of vectors, shape ``(batch, dim)``.
            base_point: Batch of points at which the geometry is evaluated.

        Returns:
            The projected vectors; ``y`` itself when no constraint is active
            anywhere in the batch.
        """
        self._update_geometry_cache(base_point)
        if not self._cached_active_constraints_mask.any():
            return y

        jac = self._cached_nabla_J
        normal_coords = torch.bmm(jac, y.unsqueeze(-1))
        multipliers = torch.linalg.lstsq(self._cached_gram_matrix, normal_coords).solution
        normal_part = torch.bmm(jac.transpose(1, 2), multipliers).squeeze(-1)
        return y - normal_part

    def constrain_fn(self, samples):
        """
        Evaluate the batched equality constraints ``h`` on ``samples``.

        Returns an empty ``(batch, 0)`` tensor on the device/dtype of ``samples``
        when the manifold has no equality constraints.
        """
        if self.m <= 0:
            return torch.empty(samples.shape[0], 0, device=samples.device, dtype=samples.dtype)
        return self.h(samples)

    def adding_correction_decaying(self, y, base_point, delta_t, alpha, sigma_sq, skip_mean_curvature=True):
        """
        Return ``base_point + y`` plus a constraint-restoring correction.

        The correction is ``(decaying_term + mean_curvature) * sigma_sq * |delta_t|`` where
        ``decaying_term = -alpha * nabla_J^T G^{-1} J~`` and ``J~`` holds the active
        constraint values, with ``epsilon`` added to the inequality entries.

        Args:
            y: Displacement to add to ``base_point``, shape ``(batch, dim)``.
            base_point: Points at which the geometry is evaluated.
            delta_t: Step size, as a tensor or a Python number. A 1-D result of
                ``sigma_sq * |delta_t|`` is treated as per-sample and broadcast
                over the feature axis.
            alpha: Strength of the decaying (constraint-restoring) term.
            sigma_sq: Multiplier applied together with ``|delta_t|``.
            skip_mean_curvature: When True (default) the mean curvature term is
                replaced by zeros instead of being estimated.

        Returns:
            Tensor of shape ``(batch, dim)``; just ``base_point + y`` when no
            constraint is active anywhere in the batch.
        """
        self._update_geometry_cache(base_point)  # make sure the cache matches base_point
        if skip_mean_curvature:
            mean_curvature = torch.zeros_like(base_point)
        else:
            mean_curvature = self.compute_mean_curvature(base_point)

        nabla_J = self._cached_nabla_J
        gram_matrix = self._cached_gram_matrix
        J_values = self._cached_J_values
        active_mask = self._cached_active_constraints_mask

        if not active_mask.any():
            return base_point + y

        J_values_decay = J_values.clone()
        if self.l > 0:
            # Aim slightly inside the feasible region for inequality constraints.
            J_values_decay[:, self.m:] += self.epsilon

        masked_J_decay = torch.where(active_mask, J_values_decay, torch.zeros_like(J_values_decay))
        z = torch.linalg.lstsq(gram_matrix, masked_J_decay.unsqueeze(-1)).solution
        decaying_term = -alpha * torch.bmm(nabla_J.transpose(1, 2), z).squeeze(-1)

        final_correction = decaying_term + mean_curvature

        # torch.abs only accepts tensors; promote plain numbers so scalar step
        # sizes work too. Tensor inputs are left untouched.
        if not torch.is_tensor(delta_t):
            delta_t = torch.as_tensor(delta_t, dtype=base_point.dtype, device=base_point.device)
        scaling_factor = sigma_sq * torch.abs(delta_t)
        if scaling_factor.ndim == 1:
            scaling_factor = scaling_factor.unsqueeze(1)

        return base_point + y + final_correction * scaling_factor
    

    @torch.no_grad()
    def project_onto_manifold_with_base(self, y, base_point, threshold=1e-3, n_iters=30, **kwargs): # originally threshold 1e-5, n_iter = 30
        """
        Project ``y + base_point`` onto the manifold with Newton-type iterations.

        Each iteration solves the regularised Gram system for the active constraint
        values (inequality entries shifted by ``epsilon``) and steps along
        ``-nabla_J^T lambda``. Iteration stops early when no constraint is active or
        when every sample's largest active violation is below ``threshold``.

        Args:
            y: Displacement added to ``base_point``, shape ``(batch, dim)``.
            base_point: Fallback points; samples that fail to converge (or become
                non-finite) are reset to their ``base_point`` row.
            threshold: Maximum allowed absolute value of an active constraint.
            n_iters: Maximum number of Newton iterations (0 is allowed).
            **kwargs: ``keep_quiet=True`` suppresses the summary log line.

        Returns:
            ``(x_proj, converged)``: the detached projected points and a per-sample
            convergence indicator cast to the dtype/device of ``y``.
        """
        keep_quiet = kwargs.get("keep_quiet", False)
        x_proj = y + base_point
        # Tracked explicitly so the log line is well defined even when n_iters == 0.
        iterations_run = 0

        for i in range(n_iters):
            iterations_run = i + 1
            self._update_geometry_cache(x_proj)
            nabla_J = self._cached_nabla_J
            gram_matrix = self._cached_gram_matrix
            J_values = self._cached_J_values
            active_mask = self._cached_active_constraints_mask

            if not active_mask.any():
                break

            J_values_decay = J_values.clone()
            if self.l > 0:
                J_values_decay[:, self.m:] += self.epsilon  # bouncing projection

            zeros = torch.zeros_like(J_values)
            masked_J_values = torch.where(active_mask, J_values, zeros)
            masked_J_values_decay = torch.where(active_mask, J_values_decay, zeros)

            # Convergence is judged on the true values, the step on the shifted ones.
            max_violation = torch.max(torch.abs(masked_J_values), dim=1).values
            if torch.all(max_violation < threshold):
                break

            lambda_sol = torch.linalg.lstsq(gram_matrix, masked_J_values_decay.unsqueeze(-1)).solution
            correction = -torch.bmm(nabla_J.transpose(1, 2), lambda_sol).squeeze(-1)
            x_proj = x_proj + correction

        self._update_geometry_cache(x_proj)
        J_values_final = self._cached_J_values
        active_mask_final = self._cached_active_constraints_mask

        final_violations = torch.abs(J_values_final)
        final_violations[~active_mask_final] = 0.0

        non_converged_flag = torch.any(final_violations > threshold, dim=1) | ~torch.all(torch.isfinite(x_proj), dim=1)

        # Reset failed samples out of place: x_proj was just handed to the geometry
        # cache, and writing into it would leave the cached geometry describing a
        # different point than its key.
        x_proj = torch.where(non_converged_flag.unsqueeze(1), base_point, x_proj)

        if not keep_quiet:
            # Final max violations, measured at the pre-reset iterate.
            max_h_violation = 0.0
            if self.m > 0:
                h_violations = torch.abs(J_values_final[:, :self.m])
                if h_violations.numel() > 0:
                    max_h_violation = torch.max(h_violations).item()

            max_g_violation = 0.0
            if self.l > 0:
                g_violations = J_values_final[:, self.m:]  # raw g values
                active_g_mask = active_mask_final[:, self.m:]
                # Only active inequality constraints are reported.
                if active_g_mask.any():
                    # A violation means g(x) > 0, so take the largest positive value.
                    max_g_violation = torch.max(torch.relu(g_violations[active_g_mask])).item()

            logging.info(f'Projection complete. Iterations: {iterations_run}, '
                         f'Non-converged: {non_converged_flag.sum().item()}, '
                         f'Max Eq Violation: {max_h_violation:.2e}, '
                         f'Max Ineq Violation: {max_g_violation:.2e}')
        return x_proj.detach(), torch.logical_not(non_converged_flag).to(y)

    # @torch.no_grad()
    # def project_onto_manifold_with_base(self, y, base_point, threshold=1e-3, n_iters=10, **kwargs): # originally threshold 1e-5, n_iter = 30
    #     """Projects a point y + base_point onto the manifold using Newton's method."""
    #     keep_quiet = kwargs.get("keep_quiet", False)
    #     x_proj = y + base_point
        
    #     for i in range(n_iters):
    #         self._update_geometry_cache(x_proj)
    #         nabla_J, gram_matrix, J_values, active_mask = self._cached_nabla_J, self._cached_gram_matrix, self._cached_J_values, self._cached_active_constraints_mask

    #         if not active_mask.any():
    #             break

    #         J_values_decay = J_values.clone()
    #         if self.l > 0: 
    #             J_values_decay[:, self.m:] += self.epsilon # Bouncing projection

    #         # if self.l > 0:
    #         #     J_values_decay[:, self.m:] += 0.0 # No bouncing projection for inequality constraints

    #         masked_J_values = torch.where(active_mask, J_values, torch.zeros_like(J_values))
    #         masked_J_values_decay = torch.where(active_mask, J_values_decay, torch.zeros_like(J_values_decay))

    #         max_violation = torch.max(torch.abs(masked_J_values), dim=1).values
    #         if torch.all(max_violation < threshold):
    #             break

    #         lambda_sol = torch.linalg.lstsq(gram_matrix, masked_J_values_decay.unsqueeze(-1)).solution
            
    #         correction = -torch.bmm(nabla_J.transpose(1, 2), lambda_sol).squeeze(-1)
            
    #         x_proj = x_proj + correction

    # @torch.no_grad()
    # def project_onto_manifold_with_base(self, y, base_point, threshold=1e-5, n_iters=30, **kwargs): # originally threshold 1e-5, n_iter = 30
    #     """Projects a point y + base_point onto the manifold using Newton's method."""
    #     keep_quiet = kwargs.get("keep_quiet", False)
    #     x_proj = y + base_point

    #     I_prev = torch.zeros(x_proj.size(0), self.l, dtype=torch.bool, device=x_proj.device)
    #     for i in range(n_iters):
    #         self._update_geometry_cache(x_proj)
    #         nabla_J, gram, J_vals, active_mask = (
    #             self._cached_nabla_J, self._cached_gram_matrix,
    #             self._cached_J_values, self._cached_active_constraints_mask
    #         )

    #         # Split blocks
    #         h_vals = J_vals[:, :self.m]                                   # [B,m]
    #         g_vals = J_vals[:, self.m:]                                   # [B,l]

    #         # --- Active-set update: add violated or near-active inequalities
    #         near_active = (g_vals >= -self.cut_off)
    #         violated    = (g_vals > 0)
    #         I = near_active | violated | I_prev                           # [B,l]

    #         # Build equality Jacobian J_I = [∇h; ∇g_I]
    #         # We’ll keep your masked formulation but recover multipliers explicitly.
    #         mask = torch.zeros_like(J_vals, dtype=torch.bool)
    #         mask[:, :self.m] = True
    #         mask[:, self.m:] = I

    #         nabla_J_eff = nabla_J.clone()
    #         nabla_J_eff[~mask] = 0.0
    #         gram_eff = torch.bmm(nabla_J_eff, nabla_J_eff.transpose(1,2))
    #         gram_eff.diagonal(dim1=-2, dim2=-1).add_(1e-6)

    #         # Right-hand side is the equality residuals [h; g_I]
    #         rhs = torch.where(mask, J_vals, torch.zeros_like(J_vals)).unsqueeze(-1)

    #         # Solve for multipliers (λ, μ_I)
    #         lambda_mu = torch.linalg.lstsq(gram_eff, rhs).solution                    # [B, m+|I|, 1]

    #         # Recover correction: x^{+} = x - ∇J^T (λ, μ_I)
    #         correction = -torch.bmm(nabla_J_eff.transpose(1,2), lambda_mu).squeeze(-1)
    #         x_new = x_proj + correction

    #         # Recompute to get multipliers’ signs at the new point
    #         self._update_geometry_cache(x_new)
    #         # Rebuild and resolve once (cheap) to evaluate μ_I sign
    #         nabla_J_eff = self._cached_nabla_J.clone()
    #         nabla_J_eff[~mask] = 0.0
    #         gram_eff = torch.bmm(nabla_J_eff, nabla_J_eff.transpose(1,2))
    #         gram_eff.diagonal(dim1=-2, dim2=-1).add_(1e-6)
    #         rhs = torch.where(mask, self._cached_J_values, torch.zeros_like(J_vals)).unsqueeze(-1)
    #         lambda_mu = torch.linalg.lstsq(gram_eff, rhs).solution

    #         # Extract μ_I (inequality multipliers) and drop those with negative sign
    #         mu_block = lambda_mu[:, self.m:, 0]                           # [B, l(masked)]
    #         drop = (mu_block < -1e-8)
    #         I[drop] = False                                               # per-batch drop
    #         I_prev = I.clone()

    #         # Accept step
    #         x_proj = x_new

    #         # KKT stopping: eq feas, ineq feas, comp residual small
    #         eq_feas = (h_vals.abs().max(dim=1).values < 1e-6)
    #         ineq_feas = (g_vals.max(dim=1).values <= 1e-6)
    #         comp_res = torch.maximum(g_vals, torch.zeros_like(g_vals)) * torch.clamp(mu_block, min=0).max(dim=1).values
    #         done = eq_feas & ineq_feas & (comp_res < 1e-6)
    #         if torch.all(done):
    #             break

    #     self._update_geometry_cache(x_proj)
    #     J_values_final, active_mask_final = self._cached_J_values, self._cached_active_constraints_mask

    #     final_violations = torch.abs(J_values_final)
    #     final_violations[~active_mask_final] = 0.0

    #     non_converged_flag = torch.any(final_violations > threshold, dim=1) | ~torch.all(torch.isfinite(x_proj), dim=1)

    #     x_proj[non_converged_flag] = base_point[non_converged_flag]

    #     # --- MODIFIED LOGGING BLOCK ---
    #     if not keep_quiet:
    #         # Calculate final max violations for logging
    #         max_h_violation = 0.0
    #         if self.m > 0:
    #             h_violations = torch.abs(J_values_final[:, :self.m])
    #             if h_violations.numel() > 0:
    #                 max_h_violation = torch.max(h_violations).item()

    #         max_g_violation = 0.0
    #         if self.l > 0:
    #             g_violations = J_values_final[:, self.m:] # Use original values for g
    #             active_g_mask = active_mask_final[:, self.m:]
    #             # Only consider active inequality constraints for violation reporting
    #             if active_g_mask.any():
    #                 # For g, violation is g(x) > 0, so we look at the max positive value
    #                 max_g_violation = torch.max(torch.relu(g_violations[active_g_mask])).item()

    #         logging.info(f'Projection complete. Iterations: {i+1}, '
    #                         f'Non-converged: {non_converged_flag.sum().item()}, '
    #                         f'Max Eq Violation: {max_h_violation:.2e}, '
    #                         f'Max Ineq Violation: {max_g_violation:.2e}')
    #     # --- END OF MODIFICATION ---
    #     return x_proj.detach(), torch.logical_not(non_converged_flag).to(y)


    def sample(self, num_samples):
        """
        Sampling hook; this class provides no implementation.

        Always raises ``NotImplementedError``. Presumably meant to be supplied by
        manifold-specific code (per the error message) -- confirm against subclasses.
        """
        raise NotImplementedError("Manifold-specific sampling is required.")

    def report_violations(self, x):
        """
        Build a one-line summary of the constraint violations at ``x``.

        Args:
            x: Tensor of shape ``(batch, dim)``.

        Returns:
            A string with the largest absolute equality value, the largest positive
            value among active inequality constraints, and the number of samples
            (out of the batch size) for which at least one inequality is positive.
        """
        # Make sure the cached values correspond to x.
        self._update_geometry_cache(x)
        J_values = self._cached_J_values
        active_mask = self._cached_active_constraints_mask

        total_samples = x.shape[0]

        max_h_violation = 0.0
        if self.m > 0:
            h_violations = torch.abs(J_values[:, :self.m])
            if h_violations.numel() > 0:
                max_h_violation = torch.max(h_violations).item()

        max_g_violation = 0.0
        num_g_violated = 0
        if self.l > 0:
            g_values = J_values[:, self.m:]
            active_g_mask = active_mask[:, self.m:]

            # Count samples, not (sample, constraint) entries: a sample is violated
            # when any of its inequality values is positive. Summing the raw mask
            # would over-count when l > 1 and could exceed total_samples.
            num_g_violated = (g_values > 0).any(dim=1).sum().item()

            if active_g_mask.any():
                # A violation means g(x) > 0, so report the largest positive value.
                max_g_violation = torch.max(torch.relu(g_values[active_g_mask])).item()

        return (f'Eq_viol: {max_h_violation:.2e}, '
                f'Ineq_viol: {max_g_violation:.2e} '
                f'({num_g_violated}/{total_samples} violated out of {total_samples})')