import torch

# --- Constraint Definitions for Toy Example: 'sphere_mog' ---

def sphere_h(x):
    """
    Unit-sphere equality constraint: h(x) = ||x||^2 - 1, satisfied when h(x) == 0.

    Args:
        x: a single (unbatched) 1-D tensor.

    Returns:
        Tensor of shape (1,) holding the sum of squares of ``x`` minus 1.0.

    Raises:
        ValueError: if ``x`` does not have exactly one dimension.
    """
    if x.dim() != 1:
        raise ValueError("Input x must be a 1-dimensional tensor.")
    # keepdim=True keeps a trailing axis so the result is shape (1,), not a scalar.
    squared_norm = x.pow(2).sum(dim=-1, keepdim=True)
    return squared_norm - 1.0

def sphere_g(x, *, offset=0.0):
    """
    Half-space inequality constraint on the first coordinate of ``x``.

    g(x) = -x[0] - offset <= 0, i.e. the feasible set is x[0] >= -offset.
    With the default ``offset=0.0`` this is x[0] >= 0.

    NOTE: an older version of this docstring described x[0] >= -0.5, while the
    code has always used a zero offset. Pass ``offset=0.5`` to obtain the
    x[0] >= -0.5 half-space.

    Args:
        x: a single (unbatched) 1-D tensor.
        offset: keyword-only shift of the boundary; the default 0.0 preserves
            the original behavior.

    Returns:
        Tensor of shape (1,) (for non-empty ``x``) holding g(x); the constraint
        is satisfied when this value is <= 0.

    Raises:
        ValueError: if ``x`` is not 1-dimensional.
    """
    if x.ndim != 1:
        raise ValueError("Input x must be a 1-dimensional tensor.")
    # Slice (0:1) rather than index so the result keeps shape (1,), matching sphere_h.
    # Subtracting a Python float also promotes integer inputs to floating point.
    return -x[0:1] - offset

# --- Placeholder Constraints for Future Experiments ---

def robot_arm_h(x):
    """
    Placeholder for robot_arm equality constraints.

    Not implemented yet: ``x`` is ignored and None is returned.

    NOTE(review): callers get None instead of a tensor; consider raising
    NotImplementedError once it is confirmed no caller relies on None.
    """
    return None

def robot_arm_g(x):
    """
    Placeholder for robot_arm inequality constraints.

    Not implemented yet: ``x`` is ignored and None is returned.

    NOTE(review): callers get None instead of a tensor; consider raising
    NotImplementedError once it is confirmed no caller relies on None.
    """
    return None

def mol_gen_h(x):
    """
    Placeholder for mol_gen equality constraints.

    Not implemented yet: ``x`` is ignored and None is returned.

    NOTE(review): callers get None instead of a tensor; consider raising
    NotImplementedError once it is confirmed no caller relies on None.
    """
    return None

def mol_gen_g(x):
    """
    Placeholder for mol_gen inequality constraints.

    Not implemented yet: ``x`` is ignored and None is returned.

    NOTE(review): callers get None instead of a tensor; consider raising
    NotImplementedError once it is confirmed no caller relies on None.
    """
    return None


# --- Main Function to Get Constraints by Name ---

def get_constraint_functions(dataset_name):
    """
    Look up the constraint-function pair registered for a dataset.

    Args:
        dataset_name: one of "sphere_mog", "robot_arm", or "mol_gen".

    Returns:
        A 2-tuple ``(h, g)``: the equality-constraint function and the
        inequality-constraint function for that dataset. The robot_arm and
        mol_gen entries are currently placeholder functions.

    Raises:
        ValueError: if no constraints are registered under ``dataset_name``.
    """
    # Scanned in order with ``==`` (not a dict lookup) so matching behaves
    # exactly like a plain equality test, including for unhashable inputs.
    registry = (
        ("sphere_mog", (sphere_h, sphere_g)),
        ("robot_arm", (robot_arm_h, robot_arm_g)),
        ("mol_gen", (mol_gen_h, mol_gen_g)),
    )
    for name, constraint_pair in registry:
        if dataset_name == name:
            return constraint_pair
    raise ValueError(f"No constraint functions defined for dataset: {dataset_name}")