#########################################################################
##   This file is part of the auto_LiRPA library, a core part of the   ##
##   α,β-CROWN (alpha-beta-CROWN) neural network verifier developed    ##
##   by the α,β-CROWN Team                                             ##
##                                                                     ##
##   Copyright (C) 2020-2025 The α,β-CROWN Team                        ##
##   Primary contacts: Huan Zhang <huan@huan-zhang.com> (UIUC)         ##
##                     Zhouxing Shi <zshi@cs.ucla.edu> (UCLA)          ##
##                     Xiangru Zhong <xiangru4@illinois.edu> (UIUC)    ##
##                                                                     ##
##    See CONTRIBUTORS for all author contacts and affiliations.       ##
##                                                                     ##
##     This program is licensed under the BSD 3-Clause License,        ##
##        contained in the LICENCE file in this directory.             ##
##                                                                     ##
#########################################################################
import time
import torch

from math import floor, ceil
from .utils import eyeC
import warnings

@torch.jit.script
def _sort_out_constraints(A, b, x_0, epsilon):
    r"""
    Filter out some batches with constraints not intersecting with input region

    Args:
        A (Tensor): A matrix of constraints with shape of (batchsize, n_constraints, x_dim)
        b (Tensor): Bias term of constraints with shape of (batchsize, n_constraints)
        x_0 (Tensor): Centroid of the input space with shape of (batchsize, x_dim, 1)
        epsilon (Tensor): Offset from the centroid to the input space boundary with shape of (batchsize, x_dim, 1)
    Return:
        no_intersection (Tensor): A boolean tensor with shape (batchsize, ), indicating if certain batch is infeasible
            because a constraint does not intersect with input space
        fully_covered (Tensor): A boolean tensor with shape (batchsize, ), indicating if all the constraints in a certain 
            batch fully covers the corresponding input region. In this case, we can simply the batch as if it has no constraints
    """
    batch_size = x_0.shape[0]
    x_dim = x_0[0].numel()

    x_0 = x_0.view((batch_size, x_dim, 1))
    epsilon = epsilon.view((batch_size, x_dim, 1))
    A = A.view((batch_size, -1, x_dim))
    b = b.view((batch_size, -1))

    # minimal and maximal value of a 
    # minimal_val: (bs, n_constraints)
    minimal_val = A.bmm(x_0).squeeze(-1) + b - A.abs().bmm(epsilon).squeeze(-1)
    maximal_val = A.bmm(x_0).squeeze(-1) + b + A.abs().bmm(epsilon).squeeze(-1)

    # for any constrains: A * x + b <= 0, if its min(A * x + b) > 0, it has no intersection with x0 +- epsilon
    # for any constrains: A * x + b <= 0, if its max(A * x + b) <= 0, it fully covers x0 +- epsilon 

    # no_intersection, fully_covered: (bs, )
    no_intersection = (minimal_val > 0).any(1)
    fully_covered = (maximal_val <= 0).all(1) | (A == 0).all((1, 2))
    return no_intersection, fully_covered

@torch.jit.script
def _dist_rearrange(constraints_A, constraints_b, x_prime):
    r"""
    Sort the constraints of every batch by their normalized value at x_prime.

    For each constraint the score ``(A x_prime + b) / ||A||_2`` is computed and the
    constraints are reordered so that the largest score comes first. Rows of
    ``constraints_A`` with zero norm are not special-cased.

    Args:
        constraints_A (Tensor): Constraint matrix with shape (batchsize, n_constraints, x_dim)
        constraints_b (Tensor): Constraint bias with shape (batchsize, n_constraints)
        x_prime (Tensor): Reference point with shape (batchsize, x_dim, 1). Depending on
            the heuristic this can be the input space centroid x_0, or the original
            optimal point x_prime
    Return:
        rearranged_A (Tensor): Reordered constraint matrix, shape (batchsize, n_constraints, x_dim)
        rearranged_b (Tensor): Reordered constraint bias, shape (batchsize, n_constraints)
    """
    # residual, row_norm, score: (bs, n_constraints)
    residual = constraints_A.bmm(x_prime).squeeze(-1) + constraints_b
    row_norm = constraints_A.norm(p=2, dim=-1)
    score = residual / row_norm

    # Only the permutation is needed, not the sorted scores themselves.
    sorted_score, perm = torch.sort(score, descending=True, dim=1)

    # Broadcast the permutation over x_dim so whole rows of A move together.
    x_dim = constraints_A.size(-1)
    perm_rows = perm.unsqueeze(-1).expand(-1, -1, x_dim)

    rearranged_A = torch.gather(constraints_A, 1, perm_rows)
    rearranged_b = torch.gather(constraints_b, 1, perm)
    return rearranged_A, rearranged_b

# @torch.compile()
# @torch.jit.script
def _solve(a, c, b, epsilon):
    r"""
    Solve the following optimization problem for a single constraint:

    Primal:         min_x   c^T x
                    s.t.    a^T x + b <= 0,
                            x0-epsilon <= x <= x0+epsilon

    Dual:           min_x max_beta  c^T x + beta * (a^T x + b)
                    s.t.            x0 - epsilon <= x <= x0 + epsilon
                                    beta >= 0

    Strong duality:
                    max_{beta >= 0} min_{x \in X} c^T x + beta * (a^T x + b)

    Dual norm:
                    max_{beta >= 0} - |c + beta * a|^T epsilon + beta * (a^T x0 + b) + c^T x_0

    The remaining problem is piece-wise linear and concave in beta, so it is enough
    to inspect the turning points ``-c_i / a_i`` together with the end points of
    beta (0 and +inf).

    Entries of ``a`` that are exactly zero yield infinite (or NaN, when the matching
    entry of ``c`` is zero as well) turning points; they are not special-cased here.

    Args:
        a (Tensor): A vector of constraint with shape of (batchsize, x_dim)
        c (Tensor): A vector of objective with shape of (batchsize, x_dim)
        b (Tensor): Offset term with shape of (batchsize, ). It is added unchanged to
            the gradient w.r.t. beta, so it plays the role of the ``a^T x0 + b`` term
            of the dual-norm formula above.
        epsilon (Tensor): Offset from the centroid to the input space boundary with shape of (batchsize, x_dim)
    Return:
        optimal_beta (Tensor): The optimal beta value with shape of (batchsize, ).
            It is ``inf`` where the gradient stays positive for every beta, i.e. the
            maximization is unbounded.
    """

    batch_size = a.size(0)
    device = a.device
    dtype = a.dtype

    # Helper columns are created in the input dtype so that concatenation never
    # promotes to (or depends on) the global default dtype.
    zero_col = torch.zeros((batch_size, 1), dtype=dtype, device=device)
    inf_col = torch.full((batch_size, 1), fill_value=torch.inf, dtype=dtype, device=device)

    # q: (bs, x_dim), the turning points of the piece-wise linear function
    q = - c/a

    # idx: (bs, x_dim), the ascending order of these turning points
    q_sort, idx = q.sort(dim=-1)

    # Gradient w.r.t. beta within each interval between consecutive turning points.
    a_sort = torch.gather(a*epsilon, dim=-1, index=idx)

    a_neg_cumsum = -a_sort.abs().cumsum(dim=-1) # shape: (bs, x_dim)
    a_neg_cumsum = torch.cat((zero_col, a_neg_cumsum), dim=-1) # shape: (bs, 1+x_dim)
    a_pos_cumsum = a_neg_cumsum - a_neg_cumsum[:, -1:] # shape: (bs, 1+x_dim)
    grad_beta = a_pos_cumsum + a_neg_cumsum + b.unsqueeze(-1) # shape: (bs, 1+x_dim)

    # grad_beta is non-increasing along the last dimension, so -grad_beta is sorted
    # and searchsorted returns the first interval whose gradient is <= 0. The left
    # end of that interval is the optimal beta.
    sign_change = torch.searchsorted(-grad_beta, zero_col, right=False)

    # It might be the case that grad_beta is always positive when beta > 0.
    # This means the maximization objective is ever-increasing, hence it is unbounded;
    # sign_change then equals 1+x_dim and selects the +inf padding below.

    # Example of sign_change where all the turning points q are positive:
    # (g stands for grad_beta, q stands for turning points)
    #    g[0] = 2       g[1] = 1       g[2] = -1       g[3] = -3
    # 0 --------- q[0] --------- q[1] ----------- q[2] ----------- ... --------> +inf
    #                             ^
    #                      sign_change=2
    #
    # q should represent the interval endpoints,
    # hence it is padded with 0 on the left and +inf on the right.
    q_new = torch.cat((zero_col, q_sort, inf_col), dim=-1) # shape: (bs, 2+x_dim)
    # Negative turning points are infeasible for beta >= 0, so clamp them to 0.
    optimal_beta = torch.gather(q_new, dim=-1, index=sign_change).clamp(min=0).squeeze(-1)

    return optimal_beta

# @torch.jit.script
def _solve_with_h(a, c, b, epsilon):
    r"""
    Solve the following optimization problem for ``h_dim`` objectives at once:

    Primal:         min_x   c^T x
                    s.t.    a^T x + b <= 0,
                            x0-epsilon <= x <= x0+epsilon

    Dual:           min_x max_beta  c^T x + beta * (a^T x + b)
                    s.t.            x0 - epsilon <= x <= x0 + epsilon
                                    beta >= 0

    Strong duality:
                    max_{beta >= 0} min_{x \in X} c^T x + beta * (a^T x + b)

    Dual norm:
                    max_{beta >= 0} - |c + beta * a|^T epsilon + beta * (a^T x0 + b) + c^T x_0

    The remaining problem is piece-wise linear and concave in beta, so it is enough
    to inspect the turning points ``-c_i / a_i`` together with the end points of
    beta (0 and +inf). The same constraint ``(a, b)`` and box half-width ``epsilon``
    of a batch are shared by all ``h_dim`` objectives of that batch.

    Entries of ``a`` that are exactly zero yield infinite (or NaN, when the matching
    entry of ``c`` is zero as well) turning points; they are not special-cased here.

    Args:
        a (Tensor): A vector of constraint with shape of (batchsize, x_dim)
        c (Tensor): A vector of objective with shape of (batchsize, h_dim, x_dim)
        b (Tensor): Offset term with shape of (batchsize, ). It is added unchanged to
            the gradient w.r.t. beta, so it plays the role of the ``a^T x0 + b`` term
            of the dual-norm formula above.
        epsilon (Tensor): Offset from the centroid to the input space boundary with shape of (batchsize, x_dim, 1)
    Return:
        optimal_beta (Tensor): The optimal beta value with shape of (batchsize, h_dim).
            It is ``inf`` where the gradient stays positive for every beta, i.e. the
            maximization is unbounded.
    """

    batch_size = a.size(0)
    h_dim = c.size(1)
    device = a.device
    dtype = a.dtype

    # Helper columns are created in the input dtype so that concatenation never
    # promotes to (or depends on) the global default dtype.
    zero_col = torch.zeros((batch_size, h_dim, 1), dtype=dtype, device=device)
    inf_col = torch.full((batch_size, h_dim, 1), fill_value=torch.inf, dtype=dtype, device=device)

    a_reshape = a.unsqueeze(1) # Shape (bs, 1, x_dim)

    epsilon_reshape = epsilon.view((batch_size, 1, -1)) # Shape (bs, 1, x_dim)
    b_reshape = b.view((-1, 1, 1)) # Shape (bs, 1, 1)

    # q: (bs, h_dim, x_dim), the turning points of the piece-wise linear function
    q = - c/a_reshape

    # idx: (bs, h_dim, x_dim), the ascending order of these turning points
    q_sort, idx = q.sort(dim=-1)

    # Gradient w.r.t. beta within each interval between consecutive turning points.
    # a * epsilon is shared across h_dim, but every objective has its own ordering.
    a_mul_e = (a_reshape * epsilon_reshape).expand(-1, h_dim, -1)
    a_sort = torch.gather(a_mul_e, dim=-1, index=idx)

    a_neg_cumsum = -a_sort.abs().cumsum(dim=-1) # shape: (bs, h_dim, x_dim)
    a_neg_cumsum = torch.cat((zero_col, a_neg_cumsum), dim=-1) # shape: (bs, h_dim, 1+x_dim)
    a_pos_cumsum = a_neg_cumsum - a_neg_cumsum[:, :, -1:] # shape: (bs, h_dim, 1+x_dim)
    grad_beta = a_pos_cumsum + a_neg_cumsum + b_reshape # shape: (bs, h_dim, 1+x_dim)

    # grad_beta is non-increasing along the last dimension, so -grad_beta is sorted
    # and searchsorted returns the first interval whose gradient is <= 0. The left
    # end of that interval is the optimal beta.
    sign_change = torch.searchsorted(-grad_beta, zero_col, right=False)

    # It might be the case that grad_beta is always positive when beta > 0.
    # This means the maximization objective is ever-increasing, hence it is unbounded;
    # sign_change then equals 1+x_dim and selects the +inf padding below.

    # Example of sign_change where all the turning points q are positive:
    # (g stands for grad_beta, q stands for turning points)
    #    g[0] = 2       g[1] = 1       g[2] = -1       g[3] = -3
    # 0 --------- q[0] --------- q[1] ----------- q[2] ----------- ... --------> +inf
    #                             ^
    #                      sign_change=2
    #
    # q should represent the interval endpoints,
    # hence it is padded with 0 on the left and +inf on the right.
    q_new = torch.cat((zero_col, q_sort, inf_col), dim=-1) # shape: (bs, h_dim, 2+x_dim)
    # Negative turning points are infeasible for beta >= 0, so clamp them to 0.
    optimal_beta = torch.gather(q_new, dim=-1, index=sign_change).clamp(min=0).squeeze(-1)

    return optimal_beta

def sort_out_constr_batches(x_L, x_U, constraints):
    r"""
    Pre-filter batches according to how the constraints relate to the input box.

    Args:
        x_L (Tensor): Lower bound of the input region.
        x_U (Tensor): Upper bound of the input region.
        constraints: ``None`` or a pair ``(constraints_A, constraints_b)`` describing
            constraints of the form ``A x + b <= 0``.
    Return:
        constraints: ``None`` when there is nothing to solve (no constraints were
            given, or every batch is fully covered); otherwise the input pair unchanged.
        sorted_out_batches (dict): Empty when no constraints were given. Otherwise it
            holds the keys ``"infeasible_batch"`` and ``"fully_covered"`` (the boolean
            masks returned by ``_sort_out_constraints``) and ``"topk_mask"`` (``None``).
    """
    sorted_out_batches = {}
    # Nothing to filter when there are no constraints at all.
    if constraints is None or constraints[0] is None or constraints[0].numel() == 0:
        return None, sorted_out_batches

    # Describe the input box by its centre and half-width.
    x_0 = (x_L + x_U) / 2
    epsilon = (x_U - x_L) / 2
    constraints_A, constraints_b = constraints
    no_intersection, fully_covered = _sort_out_constraints(constraints_A, constraints_b, x_0, epsilon)
    if fully_covered.all():
        # Every batch satisfies its constraints on the whole box, so the
        # constraints can be dropped altogether.
        print("All the added constraints fully cover the input space. No need to calculate beta.")
        constraints = None
    sorted_out_batches["infeasible_batch"] = no_intersection
    sorted_out_batches["fully_covered"] = fully_covered
    sorted_out_batches["topk_mask"] = None
    return constraints, sorted_out_batches

def constraints_solving(
    x_L, x_U, objective, constraints, sign=-1.0,
    sorted_out_batches={}, timer=None,
    constraints_enable=True, rearrange_constraints=False, use_x0=False, 
    internal_dtype=torch.float32,
    max_chunk_size=None,
    safety_factor=0.8,
    solver_memory_factor=2.0,
    objective_mask=None,
    aux_bounds=None
    ):
    r"""
    Combined constraint solving function with conditional logic based on objective shape.

    - If objective is eyeC or broadcastable (shape[0]=1), uses a vectorized,
        auto-chunked approach.
    - If objective has batch dim matching input (shape[0]=N_batch), uses the
        original approach (repeating inputs, no chunking).

    Solves LP: max / min A_t * x, s.t. A_c * x + b_c <= 0, x_L <= x <= x_U

    Args:
        x_L, x_U: Input bounds tensors.
        objective: Target coefficients (Tensor or eyeC).
            - Tensor shape: (H, X), (1, H, X), or (N_batch, H, X).
            - eyeC: Represents identity matrix.
        constraints: Tuple (A_c, b_c) or None.
        sign: -1.0 for lower bound, +1.0 for upper bound.
        sorted_out_batches: Dict with pre-filtered batch masks.
        timer: Optional Timer object.
        constraints_enable: flag for enabling constraints solving, this is set for heuristic hybrid solving, should be True by default
        rearrange_constraints: flag for enabling rearranging constraints
        use_x0: Heuristic flag for constraint rearrangement.
        internal_dtype: Data type for internal computations.
        max_chunk_size, safety_factor, solver_memory_factor: Params for chunking path.
        objective_mask (Tensor, optional): Boolean mask of shape (N_batch, H) indicating
            which objectives to compute. If None, all are computed.
        aux_bounds (Tensor, optional): When hybrid constraint solving is enbaled, constrains_solving function will be called twice.
            For its second run, we will load the result from the first run to save time computing naive results.
    Returns:
        bound (Tensor): Computed bounds (N_batch, H, 1), dtype=ori_dtype.
    """
    if timer: timer.start('init')
    if timer: timer.start("concretize")

    ori_dtype = x_L.dtype
    device = x_L.device
    N_batch = x_L.size(0)
    x_L = x_L.reshape((N_batch, -1, 1))
    x_U = x_U.reshape((N_batch, -1, 1))

    epsilon = (x_U - x_L) / 2.0
    x0 = (x_U + x_L) / 2.0
    is_eyeC = isinstance(objective, eyeC)

    # --- Naive Case (No Constraints) ---
    if constraints is None or constraints[0] is None or constraints[0].numel() == 0 or (not constraints_enable):
        if is_eyeC:
            solved_obj = x0 + sign * epsilon # Shape: (N_batch, X, 1)
        else:
            # Use objective_tensor (N_batch, H, X)
            base_term = torch.einsum('bhx,bxo->bho', objective, x0)
            eps_term = torch.einsum('bhx,bxo->bho', objective.abs(), epsilon)
            solved_obj = base_term + sign * eps_term # Shape: (N_batch, H, 1)
        if timer: timer.add("init")
        if timer: timer.add("concretize")
        return solved_obj, None


    if internal_dtype != ori_dtype:
        x_L = x_L.to(internal_dtype)
        x_U = x_U.to(internal_dtype)
        objective = objective.to(internal_dtype)
        x0 = x0.to(internal_dtype)
        epsilon = epsilon.to(internal_dtype)

    is_broadcastable = False
    is_batch_specific = False
    H = -1 # Hidden dim
    X = x_L.size(1) # x_dim
    if is_eyeC:
        is_broadcastable = True
        H = X
        # Internally represent eyeC as identity matrix for broadcastable path
        objective_tensor = torch.eye(X, device=device, dtype=internal_dtype).unsqueeze(0) # Shape (1, X, X)
    else:
        if objective.shape[0] != N_batch:
            is_broadcastable = True
        else:
            is_batch_specific = True
        H = objective.shape[1]
        objective_tensor = objective
        if objective.shape[2] != X: raise ValueError("Objective shape mismatch")

    # --- Constrained Case Setup ---
    if internal_dtype != ori_dtype:
        x_L_flat = x_L_flat.to(internal_dtype)
        x_U_flat = x_U_flat.to(internal_dtype)
        epsilon = epsilon.to(internal_dtype)
        x0 = x0.to(internal_dtype)
        # objective_tensor is already converted

    constraints_A, constraints_b = constraints
    constraints_A = constraints_A.reshape((N_batch, -1, X)).to(internal_dtype)
    constraints_b = constraints_b.reshape((N_batch, -1)).to(internal_dtype)
    n_constraints = constraints_A.size(1)

    # --- Initial Batch Filtering (Common Logic) ---
    infeasible_batches = sorted_out_batches.get("infeasible_batch", torch.zeros(N_batch, dtype=torch.bool, device=device))
    fully_covered = sorted_out_batches.get("fully_covered", torch.zeros(N_batch, dtype=torch.bool, device=device))
    if objective_mask is None:
        objective_mask = sorted_out_batches.get("topk_mask", None)

    initial_mask = infeasible_batches # Batches to skip entirely
    # Final bounds tensor initialized
    final_bounds = torch.zeros(N_batch, H, 1, device=device, dtype=internal_dtype)
    fill_value_inf = torch.tensor(torch.inf if sign == -1.0 else -torch.inf, dtype=internal_dtype, device=device)

    # --- Calculate Naive Bounds (used as default/fallback) ---
    # This needs to be calculated *before* the main branches for all batches and all H.
    naive_bounds = torch.zeros(N_batch, H, 1, device=device, dtype=internal_dtype)
    if aux_bounds is not None:
        naive_bounds_all = aux_bounds.flatten(1).unsqueeze(-1)
    elif is_eyeC:
        naive_bounds_all = x0 + sign * epsilon # Shape (N_batch, X, 1) -> (N_batch, H, 1)
    elif is_broadcastable:
        # obj_tensor is (1, H, X)
        base_term_naive = torch.einsum('shx,bxo->bho', objective_tensor, x0)
        eps_term_naive = torch.einsum('shx,bxo->bho', objective_tensor.abs(), epsilon)
        naive_bounds_all = base_term_naive + sign * eps_term_naive # Shape (N_batch, H, 1)
    elif is_batch_specific:
        # obj_tensor is (N, H, X)
        base_term_naive = torch.einsum('bhx,bxo->bho', objective_tensor, x0)
        eps_term_naive = torch.einsum('bhx,bxo->bho', objective_tensor.abs(), epsilon)
        naive_bounds_all = base_term_naive + sign * eps_term_naive # Shape (N_batch, H, 1)
    else:
        raise RuntimeError("Internal logic error in naive bound calculation")
    naive_bounds = naive_bounds_all # Assign calculated bounds
    # Fill final_bounds for initially skipped/covered batches using naive bounds or Inf
    if initial_mask.any():
        # final_bounds[initial_mask] = fill_value_inf # Keep Inf for infeasible
        final_bounds[initial_mask] = naive_bounds[initial_mask]
    fully_covered_active_mask = fully_covered & (~initial_mask)
    if fully_covered_active_mask.any():
        final_bounds[fully_covered_active_mask] = naive_bounds[fully_covered_active_mask] # Use naive bounds

    active_batches_mask = ~initial_mask & ~fully_covered # Batches requiring solver
    active_indices = torch.nonzero(active_batches_mask, as_tuple=True)[0]
    B_act = active_indices.numel() # Number of batches needing the solver
    if timer: timer.add('init') # Combined timing for setup

    # --- Early Exit if No Active Batches ---
    if B_act == 0:
        print(f"Constrained concretize: No active batches after filtering.")
        # Ensure non-active parts have naive bounds before returning
        # (already done above by initializing with naive/inf)
        if timer: timer.add("concretize")
        final_bounds = naive_bounds
        return final_bounds.to(ori_dtype), initial_mask

    # ====================================================================
    # === BRANCH 1: Broadcastable Objective -> Auto Chunked/Vectorized ===
    # ====================================================================
    if is_broadcastable:
        if timer: timer.start('chunking')
        # --- Get Reference Objective (H, X) ---
        objective_ref = objective_tensor # Shape (H, X)

        # --- Dynamic Chunk Size Calculation ---
        calculated_chunk_size = B_act
        try:
            free_mem, total_mem = torch.cuda.mem_get_info()
            usable_mem = free_mem * safety_factor
            dtype_size = torch.finfo(internal_dtype).bits // 8
            mem_constraints_per_item = (n_constraints * X + n_constraints) * dtype_size
            mem_x0eps_per_item = 2 * X * dtype_size
            mem_ori_c_per_item = H * X * dtype_size
            mem_dual_obj_per_item = H * dtype_size
            mem_solver_per_item_bh = H * (X + X + 1 + X + 1) * dtype_size * solver_memory_factor
            mem_masks_temps_per_item = H * 2 # approx
            mem_per_item_est = (mem_constraints_per_item + mem_x0eps_per_item +
                                mem_ori_c_per_item + mem_dual_obj_per_item +
                                mem_solver_per_item_bh + mem_masks_temps_per_item) * 5
            if mem_per_item_est > 0:
                estimated_max_chunk = max(1, floor(usable_mem / mem_per_item_est))
                calculated_chunk_size = min(B_act, estimated_max_chunk)
        except Exception as e:
            print(f"Warning: GPU memory calculation failed: {e}. Using B_act as chunk size.")
            calculated_chunk_size = B_act # Fallback
        if max_chunk_size is not None and max_chunk_size > 0:
            final_chunk_size = min(calculated_chunk_size, max_chunk_size)
        else:
            final_chunk_size = calculated_chunk_size
        final_chunk_size = max(1, final_chunk_size) # Ensure chunk size is at least 1
        num_chunks = ceil(B_act / final_chunk_size)
        if timer: timer.add('chunking')

        # --- Loop Over Chunks ---
        if timer: timer.start('chunking_loop')
        for i_chunk in range(num_chunks):
            chunk_start_idx_rel = i_chunk * final_chunk_size
            chunk_end_idx_rel = min(chunk_start_idx_rel + final_chunk_size, B_act)
            current_chunk_size = chunk_end_idx_rel - chunk_start_idx_rel
            if current_chunk_size == 0: continue
            chunk_indices_abs = active_indices[chunk_start_idx_rel:chunk_end_idx_rel]

            A_chunk = constraints_A[chunk_indices_abs]
            b_chunk = constraints_b[chunk_indices_abs]
            x0_chunk = x0[chunk_indices_abs]
            eps_chunk = epsilon[chunk_indices_abs]

            # --- Get Objective Mask for the Chunk ---
            # Shape: (current_chunk_size, H)
            if objective_mask is not None:
                # Select the mask rows corresponding to the active batches in this chunk
                current_objective_mask = objective_mask[chunk_indices_abs]
            else:
                # If no mask provided, assume all objectives are needed
                current_objective_mask = torch.ones(current_chunk_size, H, dtype=torch.bool, device=device)

            # --- Rearrangement Heuristic ---
            if use_x0 or is_eyeC:
                x_prime_chunk = x0_chunk
            else:
                x_prime_chunk = x0_chunk # Default heuristic for chunked mode
            A_chunk, b_chunk = _dist_rearrange(A_chunk, b_chunk, x_prime_chunk)

            # --- Initialize State for Vectorized Loop (Chunk) ---
            ori_c_chunk = objective_ref.expand(current_chunk_size, H, X).clone()
            base_objective_term_chunk = torch.einsum('bhx,bxo->bh', ori_c_chunk, x0_chunk)

            if sign == 1.0: # Adjust for minimization problem solved by _solve
                ori_c_chunk *= -1.0
                base_objective_term_chunk *= -1.0

            # Initialize dual part to 0
            dual_objective_part_chunk = torch.zeros(current_chunk_size, H, device=device, dtype=internal_dtype)
            # Unfinished mask starts as the objective mask for this chunk
            unfinished_mask_chunk = current_objective_mask.clone()

            # --- Vectorized Constraint Loop (Operating on Chunk) ---
            for k in range(n_constraints):
                # Only consider pairs where the objective is requested AND not yet finished (inf)
                active_bh_mask_chunk = unfinished_mask_chunk
                if not active_bh_mask_chunk.any(): break

                # Get indices relative to the chunk *and* the H dimension
                batch_idx_chunk, h_idx_chunk = torch.where(active_bh_mask_chunk)
                num_active_bh_chunk = batch_idx_chunk.numel()
                if num_active_bh_chunk == 0: continue

                # Gather data only for active (batch_idx_chunk, h_idx_chunk) pairs
                a_k_chunk_all = A_chunk[:, k, :] # (chunk_size, X)
                b_k_chunk_all = b_chunk[:, k]   # (chunk_size,)
                d_k_chunk_all = torch.einsum('bx,bxo->b', a_k_chunk_all, x0_chunk) + b_k_chunk_all # (chunk_size,)

                # Select based on batch_idx_chunk
                a_solve = a_k_chunk_all[batch_idx_chunk] # (num_active_bh_chunk, X)
                d_solve = d_k_chunk_all[batch_idx_chunk] # (num_active_bh_chunk,)
                e_solve = eps_chunk[batch_idx_chunk, :, 0] # (num_active_bh_chunk, X)

                # Select objective based on (batch_idx_chunk, h_idx_chunk)
                c_solve = ori_c_chunk[batch_idx_chunk, h_idx_chunk] # (num_active_bh_chunk, X)

                optimal_beta = _solve(a_solve, c_solve, d_solve, e_solve) # (num_active_bh_chunk,)

                with torch.no_grad(): # Update state using scatter/indexing
                    finite_beta_mask = torch.isfinite(optimal_beta)
                    infinite_beta_mask = ~finite_beta_mask

                    # Indices relative to the `optimal_beta` tensor
                    inf_indices_relative = torch.where(infinite_beta_mask)[0]
                    finite_indices_relative = torch.where(finite_beta_mask)[0]

                    # Get the corresponding (batch, h) indices in the chunk
                    inf_batch_idx_chunk = batch_idx_chunk[inf_indices_relative]
                    inf_h_idx_chunk = h_idx_chunk[inf_indices_relative]
                    upd_batch_idx_chunk = batch_idx_chunk[finite_indices_relative]
                    upd_h_idx_chunk = h_idx_chunk[finite_indices_relative]

                    # Mark infinite betas as finished
                    if inf_indices_relative.numel() > 0:
                        dual_objective_part_chunk[inf_batch_idx_chunk, inf_h_idx_chunk] = torch.inf
                        unfinished_mask_chunk[inf_batch_idx_chunk, inf_h_idx_chunk] = False

                    # Update state for finite betas
                    if finite_indices_relative.numel() > 0:
                        finite_opt_beta = optimal_beta[finite_indices_relative]
                        finite_a_solve = a_solve[finite_indices_relative]
                        finite_d_solve = d_solve[finite_indices_relative]

                        ori_c_chunk[upd_batch_idx_chunk, upd_h_idx_chunk] += finite_opt_beta.unsqueeze(-1) * finite_a_solve
                        dual_objective_part_chunk[upd_batch_idx_chunk, upd_h_idx_chunk] += finite_opt_beta * finite_d_solve
            # --- End of k loop ---

            # --- Final Objective Calculation for Unfinished Items in Chunk ---
            final_unfinished_mask_chunk = unfinished_mask_chunk # Still respects the original objective_mask
            if final_unfinished_mask_chunk.any():
                final_batch_idx_chunk, final_h_idx_chunk = torch.where(final_unfinished_mask_chunk)

                final_ori_c_chunk = ori_c_chunk[final_batch_idx_chunk, final_h_idx_chunk]
                final_eps_chunk = eps_chunk[final_batch_idx_chunk, :, 0]

                final_eps_term_chunk = -torch.einsum('nx,nx->n', final_ori_c_chunk.abs(), final_eps_chunk)

                dual_objective_part_chunk[final_batch_idx_chunk, final_h_idx_chunk] += final_eps_term_chunk

            # --- Combine terms and handle mask ---
            final_obj_chunk_minimized = base_objective_term_chunk + dual_objective_part_chunk
            if sign == 1.0: final_obj_chunk = -final_obj_chunk_minimized # Flip sign back if maximizing
            else: final_obj_chunk = final_obj_chunk_minimized

            # Handle NaN/Inf introduced by solver or calculations
            final_obj_chunk = torch.nan_to_num(final_obj_chunk, nan=fill_value_inf.item(), posinf=fill_value_inf.item(), neginf=-fill_value_inf.item())

            # --- Use Naive Bounds where mask was False ---
            naive_bounds_chunk = naive_bounds[chunk_indices_abs].squeeze(-1) # Shape (chunk_size, H)
            # Where the objective mask was false, use the naive bound, otherwise use the calculated bound
            final_obj_chunk_masked = torch.where(
                current_objective_mask,
                final_obj_chunk,
                naive_bounds_chunk
            )

            # --- Store results back into the main bounds tensor ---
            final_bounds[chunk_indices_abs] = final_obj_chunk_masked.unsqueeze(-1)
        # --- End of chunk loop ---
        print(f"Chunked Summary: batches: {N_batch}, chunks={num_chunks}, chunk_size={final_chunk_size}, active_batches rate={B_act/N_batch:.2f}")
        if timer: timer.add('chunking_loop')
        
        return final_bounds

    # ===================================================================================
    # === BRANCH 2: Batch-Specific Objective (Neuron Mask) -> Original/Repeated Input ===
    # ===================================================================================
    elif is_batch_specific:
        if timer: timer.start('chunking')
        # --- Get Reference Objective (H, X) ---
        objective_ref = objective_tensor # Shape (B, H, X)

        final_chunk_size = B_act
        num_chunks = 1
        if timer: timer.add('chunking')

        # --- Loop Over Chunks ---
        if timer: timer.start('chunking_loop')
        for i_chunk in range(num_chunks):
            chunk_start_idx_rel = i_chunk * final_chunk_size
            chunk_end_idx_rel = min(chunk_start_idx_rel + final_chunk_size, B_act)
            current_chunk_size = chunk_end_idx_rel - chunk_start_idx_rel
            if current_chunk_size == 0: continue
            chunk_indices_abs = active_indices[chunk_start_idx_rel:chunk_end_idx_rel]

            A_chunk = constraints_A[chunk_indices_abs]
            b_chunk = constraints_b[chunk_indices_abs]
            x0_chunk = x0[chunk_indices_abs]
            eps_chunk = epsilon[chunk_indices_abs]

            # --- Get Objective Mask for the Chunk ---
            # Shape: (current_chunk_size, H)
            if objective_mask is not None:
                # Select the mask rows corresponding to the active batches in this chunk
                current_objective_indices = objective_mask[chunk_indices_abs]
                current_objective_mask = torch.zeros((current_chunk_size, H), dtype=torch.bool, device=objective_mask.device)
                current_objective_mask.scatter_(dim=1, index=current_objective_indices, value=True)
            else:
                # If no mask provided, assume all objectives are needed
                current_objective_mask = torch.ones(current_chunk_size, H, dtype=torch.bool, device=device)

            # --- Rearrangement Heuristic ---
            if use_x0 or is_eyeC:
                x_prime_chunk = x0_chunk
            else:
                x_prime_chunk = x0_chunk # Default heuristic for chunked mode
            if rearrange_constraints:
                A_chunk, b_chunk = _dist_rearrange(A_chunk, b_chunk, x_prime_chunk)

            # --- Initialize State for Vectorized Loop (Chunk) ---
            ori_c_chunk = objective_ref[chunk_indices_abs].expand(current_chunk_size, H, X).clone()
            base_objective_term_chunk = torch.einsum('bhx,bxo->bh', ori_c_chunk, x0_chunk)

            if sign == 1.0: # Adjust for minimization problem solved by _solve
                ori_c_chunk *= -1.0
                base_objective_term_chunk *= -1.0

            # Initialize dual part to 0
            dual_objective_part_chunk = torch.zeros(current_chunk_size, H, device=device, dtype=internal_dtype)
            # Unfinished mask starts as the objective mask for this chunk
            unfinished_mask_chunk = current_objective_mask.clone()
            # --- Vectorized Constraint Loop (Operating on Chunk) ---
            for k in range(n_constraints):
                # Only consider pairs where the objective is requested AND not yet finished (inf)
                active_bh_mask_chunk = unfinished_mask_chunk
                if not active_bh_mask_chunk.any(): break

                # Get indices relative to the chunk *and* the H dimension
                batch_idx_chunk, h_idx_chunk = torch.where(active_bh_mask_chunk)
                num_active_bh_chunk = batch_idx_chunk.numel()
                if num_active_bh_chunk == 0: continue

                # Gather data only for active (batch_idx_chunk, h_idx_chunk) pairs
                a_k_chunk_all = A_chunk[:, k, :] # (chunk_size, X)
                b_k_chunk_all = b_chunk[:, k]   # (chunk_size,)
                d_k_chunk_all = torch.einsum('bx,bxo->b', a_k_chunk_all, x0_chunk) + b_k_chunk_all # (chunk_size,)

                # Select based on batch_idx_chunk
                a_solve = a_k_chunk_all[batch_idx_chunk] # (num_active_bh_chunk, X)
                d_solve = d_k_chunk_all[batch_idx_chunk] # (num_active_bh_chunk,)
                e_solve = eps_chunk[batch_idx_chunk, :, 0] # (num_active_bh_chunk, X)

                # Select objective based on (batch_idx_chunk, h_idx_chunk)
                c_solve = ori_c_chunk[batch_idx_chunk, h_idx_chunk] # (num_active_bh_chunk, X)

                if timer: timer.start("solve_constraints")
                optimal_beta = _solve(a_solve, c_solve, d_solve, e_solve) # (num_active_bh_chunk,)
                # print(f"Solver {num_active_bh_chunk} problems")
                if timer: timer.add("solve_constraints")

                with torch.no_grad(): # Update state using scatter/indexing
                    finite_beta_mask = torch.isfinite(optimal_beta)
                    infinite_beta_mask = ~finite_beta_mask

                    # Indices relative to the `optimal_beta` tensor
                    inf_indices_relative = torch.where(infinite_beta_mask)[0]
                    finite_indices_relative = torch.where(finite_beta_mask)[0]

                    # Get the corresponding (batch, h) indices in the chunk
                    inf_batch_idx_chunk = batch_idx_chunk[inf_indices_relative]
                    inf_h_idx_chunk = h_idx_chunk[inf_indices_relative]
                    upd_batch_idx_chunk = batch_idx_chunk[finite_indices_relative]
                    upd_h_idx_chunk = h_idx_chunk[finite_indices_relative]

                    # Mark infinite betas as finished
                    if inf_indices_relative.numel() > 0:
                        dual_objective_part_chunk[inf_batch_idx_chunk, inf_h_idx_chunk] = torch.inf
                        unfinished_mask_chunk[inf_batch_idx_chunk, inf_h_idx_chunk] = False

                    # Update state for finite betas
                    if finite_indices_relative.numel() > 0:
                        finite_opt_beta = optimal_beta[finite_indices_relative]
                        finite_a_solve = a_solve[finite_indices_relative]
                        finite_d_solve = d_solve[finite_indices_relative]

                        ori_c_chunk[upd_batch_idx_chunk, upd_h_idx_chunk] += finite_opt_beta.unsqueeze(-1) * finite_a_solve
                        dual_objective_part_chunk[upd_batch_idx_chunk, upd_h_idx_chunk] += finite_opt_beta * finite_d_solve
            # --- End of k loop ---

            # --- Final Objective Calculation for Unfinished Items in Chunk ---
            final_unfinished_mask_chunk = unfinished_mask_chunk # Still respects the original objective_mask
            if final_unfinished_mask_chunk.any():
                final_batch_idx_chunk, final_h_idx_chunk = torch.where(final_unfinished_mask_chunk)

                final_ori_c_chunk = ori_c_chunk[final_batch_idx_chunk, final_h_idx_chunk]
                final_eps_chunk = eps_chunk[final_batch_idx_chunk, :, 0]

                final_eps_term_chunk = -torch.einsum('nx,nx->n', final_ori_c_chunk.abs(), final_eps_chunk)

                dual_objective_part_chunk[final_batch_idx_chunk, final_h_idx_chunk] += final_eps_term_chunk

            # --- Combine terms and handle mask ---
            final_obj_chunk_minimized = base_objective_term_chunk + dual_objective_part_chunk
            if sign == 1.0: final_obj_chunk = -final_obj_chunk_minimized # Flip sign back if maximizing
            else: final_obj_chunk = final_obj_chunk_minimized

            # Handle NaN/Inf introduced by solver or calculations
            final_obj_chunk = torch.nan_to_num(final_obj_chunk, nan=fill_value_inf.item(), posinf=fill_value_inf.item(), neginf=-fill_value_inf.item())

            # --- Use Naive Bounds where mask was False ---
            naive_bounds_chunk = naive_bounds[chunk_indices_abs].squeeze(-1) # Shape (chunk_size, H)
            # Where the objective mask was false, use the naive bound, otherwise use the calculated bound
            final_obj_chunk_masked = torch.where(
                current_objective_mask,
                final_obj_chunk,
                naive_bounds_chunk
            )

            # --- Store results back into the main bounds tensor ---
            final_bounds[chunk_indices_abs] = final_obj_chunk_masked.unsqueeze(-1)
            infeasible_batches = final_bounds.flatten(1).isinf().any(dim=1)
            final_bounds[infeasible_batches] = naive_bounds[infeasible_batches]
            # --- End of chunk loop ---
            print(f"Chunked Summary: batches: {N_batch}, chunks={num_chunks}, chunk_size={final_chunk_size}, active_batches rate={B_act/N_batch:.2f}")
            if timer: timer.add('chunking_loop')
            if timer: timer.add("concretize")
            return final_bounds, infeasible_batches

    else: # Should not be reached
        raise RuntimeError("Internal error: Objective type not handled.")

    if timer: timer.add("concretize")
    # --- Return Final Bounds ---
    return final_bounds.to(ori_dtype), infeasible_batches

def warmup_solve(batchsize=2, x_dim=3, dtype=torch.float32, device='cuda'):
    """
    Warms up the `_solve` function by running it once on small dummy inputs.

    The call is intended to pay any one-time compilation/initialization cost
    of `_solve` up front. The elapsed wall-clock time is printed; nothing is
    returned.

    Args:
        batchsize (int): The batch size for the dummy input.
        x_dim (int): The dimension of the x vector for the dummy input.
        dtype (torch.dtype): The data type for the dummy input.
        device (str or torch.device): The device for the dummy input, e.g.
            'cpu', 'cuda', 'cuda:1' or an equivalent `torch.device`.
    """
    # perf_counter is monotonic, so the reported duration cannot be skewed by
    # system clock adjustments.
    start_time = time.perf_counter()
    print(f"Warming up _solve function on device '{device}' with dtype '{dtype}'...")

    # Create dummy input tensors with the specified shape, dtype, and device
    a_dummy = torch.randn(batchsize, x_dim, dtype=dtype, device=device)
    c_dummy = torch.randn(batchsize, x_dim, dtype=dtype, device=device)
    b_dummy = torch.randn(batchsize, dtype=dtype, device=device)
    epsilon_dummy = torch.rand(batchsize, x_dim, dtype=dtype, device=device) + 1  # Ensure positive

    _solve(a_dummy, c_dummy, b_dummy, epsilon_dummy)

    # CUDA kernels run asynchronously; wait for them so the timing below is
    # meaningful. Normalizing through torch.device makes indexed devices
    # ('cuda:0', 'cuda:1', ...) and torch.device objects hit this branch too,
    # not only the exact string 'cuda'.
    target_device = torch.device(device)
    if target_device.type == 'cuda':
        torch.cuda.synchronize(target_device)
    print(f"Warmup completed in {time.perf_counter() - start_time} seconds.")
