import numpy as np


# Define the operation functions
def add(a, b):
    """Return ``a + b`` using the operands' own ``+`` implementation."""
    total = a + b
    return total


def sub(a, b):
    """Return ``a - b`` using the operands' own ``-`` implementation."""
    difference = a - b
    return difference


def mul(a, b):
    """Return ``a * b`` using the operands' own ``*`` implementation."""
    product = a * b
    return product


def div(a, b):
    """
    Return ``a`` divided by ``b`` shifted by a tiny constant.

    ``1e-20`` is added to the denominator before dividing, so ``b == 0``
    yields a very large finite quotient instead of a division by zero.
    Note that ``b == -1e-20`` still produces a zero denominator.
    """
    denominator = b + 1e-20
    return a / denominator


def power(a, b):
    """Return ``a`` raised to the power ``b`` (same as ``a ** b``)."""
    return pow(a, b)

def sin(a):
    """Return the sine of ``a`` as computed by ``np.sin``."""
    result = np.sin(a)
    return result

def cos(b):
    """Return the cosine of ``b`` as computed by ``np.cos``."""
    result = np.cos(b)
    return result

def exp(a):
    """Return ``e`` raised to ``a`` as computed by ``np.exp``."""
    result = np.exp(a)
    return result

def log(a):
    """
    Return the natural logarithm of the absolute value of ``a``.

    Taking ``np.abs`` first means negative inputs are accepted and give the
    same result as their positive counterparts.
    """
    magnitude = np.abs(a)
    return np.log(magnitude)

def square(a):
    """Return ``a`` squared as computed by ``np.square``."""
    squared = np.square(a)
    return squared

def third_power(a):
    """Return ``a`` cubed (same as ``a ** 3``)."""
    return pow(a, 3)

def sqrt(a):
    """Return the square root of ``a`` as computed by ``np.sqrt``."""
    root = np.sqrt(a)
    return root

def asin(a):
    """Return the inverse sine of ``a`` as computed by ``np.arcsin``."""
    angle = np.arcsin(a)
    return angle

def acos(a):
    """Return the inverse cosine of ``a`` as computed by ``np.arccos``."""
    angle = np.arccos(a)
    return angle

# Function to register operator functions and their arity (number of required operands)
def register_operator_functions():
    """
    Build the table of operators available for expressions.

    Returns
    -------
    operator_functions : dict
        Maps each operator symbol to ``{'function': callable, 'arity': int}``.
        The insertion order defines each operator's index.
    arities : list of int
        Arity of every registered operator, in table order.
    trig_function : list of int
        Table indices of the registered trigonometric operators.
    exp_log_function : list of int
        Table indices of the 'exp' and 'log' operators.
    """
    binary_ops = (('+', add), ('-', sub), ('*', mul), ('/', div))
    unary_ops = (('sin', sin), ('cos', cos), ('exp', exp), ('log', log))
    # 'sqrt', 'asin' and 'acos' are implemented in this module but are
    # currently not registered in the table.

    operator_functions = {}
    for symbol, func in binary_ops:
        operator_functions[symbol] = {'function': func, 'arity': 2}
    for symbol, func in unary_ops:
        operator_functions[symbol] = {'function': func, 'arity': 1}

    # Symbols that count as trigonometric / exponential-logarithmic operators.
    trig_symbols = ('sin', 'cos', 'asin', 'acos')
    exp_log_symbols = ('exp', 'log')

    symbols = list(operator_functions)
    trig_function = [pos for pos, sym in enumerate(symbols) if sym in trig_symbols]
    exp_log_function = [pos for pos, sym in enumerate(symbols) if sym in exp_log_symbols]
    arities = [entry['arity'] for entry in operator_functions.values()]
    return operator_functions, arities, trig_function, exp_log_function


# Function to generate variable names and arities based on the size of the input array
def generate_variable_info(input_array, const_value=False):
    """
    Create one terminal entry per input column, named x_1, x_2, ..., x_n.

    Parameters
    ----------
    input_array : np.ndarray
        Array whose second dimension (``shape[1]``) gives the number of
        variables.
    const_value : bool, optional
        When truthy, an extra terminal named 'const' is appended after the
        variables.

    Returns
    -------
    variable_dict : dict
        Maps each terminal name to ``{'arity': 0}``.
    """
    n_columns = input_array.shape[1]

    variable_dict = {}
    for position in range(1, n_columns + 1):
        variable_dict[f'x_{position}'] = {'arity': 0}

    if const_value:
        variable_dict['const'] = {'arity': 0}

    return variable_dict


# Function to merge operator functions and variable functions into a single dictionary
def merge_operator_and_variable_dict(operator_functions,  variable_dict):
    """
    Combine the operator table and the variable table into one token table.

    Operators keep their positions and variables follow them; on a key clash
    the entry from ``variable_dict`` wins. Neither input is modified.

    Returns
    -------
    combined_dict : dict
        The merged table.
    operator_indices : list of int
        ``0 .. len(combined_dict) - 1``. Despite the name, this covers every
        entry of the merged table, variables included.
    arities : list of int
        Arity of every entry, in table order.
    """
    combined_dict = dict(operator_functions)
    combined_dict.update(variable_dict)

    operator_indices = list(range(len(combined_dict)))
    arities = [entry['arity'] for entry in combined_dict.values()]

    return combined_dict, operator_indices, arities


def is_valid_expression(operator_arities, index_sequence):
    """
    Check whether a token sequence is exactly one complete prefix expression.

    Parameters
    ----------
    operator_arities : sequence of int
        Arity of every token in the table, indexed by token index. Operators
        have positive arities (e.g. 2 for binary operators); variables and
        constants have arity 0.

    index_sequence : sequence of int
        Token indices in pre-order (prefix) traversal order.

    Returns
    -------
    bool
        True if the sequence encodes exactly one complete expression,
        False otherwise (including for an empty sequence).
    """
    # An empty sequence cannot encode an expression.
    if len(index_sequence) == 0:
        return False

    arities = np.array([operator_arities[t] for t in index_sequence])
    # dangling[i] is the number of operand slots still open after token i:
    # we start with one slot (the root); each token fills one slot and opens
    # `arity` new ones.
    dangling = 1 + np.cumsum(arities - 1)
    # The slots must run out exactly at the last token. Requiring a strictly
    # positive count before that rejects sequences whose expression is
    # already complete and is then followed by extra tokens
    # (e.g. the tokens "x_1 + x_1" read as a prefix sequence).
    return bool(dangling[-1] == 0 and np.all(dangling[:-1] > 0))

def complete_tokens(operator_arities, index_sequence):
    """
    Turn a token sequence into exactly one complete prefix expression.

    If the sequence already contains a complete expression as a prefix, the
    tokens after that point are dropped. Otherwise the first arity-0 token of
    the table is appended as many times as there are open operand slots.

    Parameters
    ----------
    operator_arities : sequence of int
        Arity of every token in the table, indexed by token index. May be a
        plain list or a NumPy array.

    index_sequence : sequence of int
        Token indices in pre-order (prefix) traversal order. It is not
        modified; a new list is returned.

    Returns
    -------
    list of int
        Token indices representing one complete expression.

    Raises
    ------
    ValueError
        If tokens must be appended but the table contains no arity-0 token.
    """
    # Work on a private copy so the caller's sequence is never mutated.
    tokens = list(index_sequence)
    # Convert once: comparing a plain list with 0 would yield a single False
    # instead of an element-wise mask, hiding every arity-0 token.
    arity_table = np.asarray(operator_arities)

    if tokens:
        arities = np.array([arity_table[t] for t in tokens])
        # Open operand slots after each token: one slot for the root, each
        # token fills one slot and opens `arity` new ones.
        dangling = 1 + np.cumsum(arities - 1)
        complete_at = np.where(dangling == 0)[0]
        if len(complete_at) > 0:
            # The expression is complete at the first zero; drop the rest.
            return tokens[:int(complete_at[0]) + 1]
        missing = int(dangling[-1])
    else:
        # Nothing given yet: only the root slot is open.
        missing = 1

    zero_arity_indices = np.where(arity_table == 0)[0]
    if len(zero_arity_indices) == 0:
        raise ValueError("No operator with arity 0 found in operator_arities.")

    # Each arity-0 token closes exactly one open slot, so `missing` copies
    # complete the expression without recomputing the running count.
    tokens.extend([int(zero_arity_indices[0])] * missing)
    return tokens

if __name__ == "__main__":
    # Example usage: build the token table, then validate and repair
    # prefix token sequences.
    operator_functions, _, trig, exp_log = register_operator_functions()
    input_array = np.random.rand(100, 3)  # 3 columns -> variables x_1..x_3
    variable_dict = generate_variable_info(input_array)
    print(trig)
    print(exp_log)

    # Merged table: indices 0-7 are the operators (+, -, *, /, sin, cos,
    # exp, log), indices 8-10 are x_1, x_2, x_3.
    combined_dict, operator_functions_only, arities = merge_operator_and_variable_dict(
        operator_functions, variable_dict)

    # Output the results
    print("Combined Dictionary:")
    print(combined_dict)

    # NOTE: the returned index list covers every entry of the merged table,
    # variables included.
    print("\nOperator Functions Only Dictionary:")
    print(operator_functions_only)

    print(arities)
    # Prefix tokens "+ cos + cos exp": no terminals, so this is incomplete.
    index_sequence = [0, 5, 0, 5, 6]

    # Check if the expression is valid
    is_valid = is_valid_expression(arities, index_sequence)
    print(f"Is the expression valid? {is_valid}")

    # Prefix tokens "+ cos + cos": three operand slots are still open.
    index_sequence = [0, 5, 0, 5]
    # Pass the arities as an array: complete_tokens locates arity-0 tokens
    # with an element-wise comparison, which a plain list does not support.
    index_sequence = complete_tokens(np.array(arities), index_sequence)
    print(index_sequence)
