from function import register_operator_functions
def preorder_to_nested_with_indices(preorder):
    """
    将前序表达式转换为嵌套数组形式的中序表达式，并记录每个元素的原始索引。

    :param preorder: 前序遍历的表达式列表（如 ["add", "x1", "x2"]）
    :return: 嵌套数组形式的中序表达式和索引数组
    """
    stack = []
    index_stack = []
    # 反转前序序列，按后序方式处理
    for i, token in enumerate(reversed(preorder)):
        if token.startswith('x_') or token.startswith('con'):
            # 操作数直接压入栈（不包装成列表）
            stack.append(token)
            index_stack.append(len(preorder) - i - 1)  # 记录原始索引
        elif token in {"sin", "cos", "exp", "log", "sqrt", "asin", "acos"}:
            # 单目操作符，弹出一个操作数并组合
            if not stack:
                raise ValueError(f"Insufficient operands for unary operator: {token}")
            operand = stack.pop()
            operand_index = index_stack.pop()
            stack.append([token, operand])
            index_stack.append([len(preorder) - i - 1, operand_index])  # 记录原始索引
        elif token in {"+", "*", "-", "/"}:
            # 双目操作符，弹出两个操作数并组合
            if len(stack) < 2:
                raise ValueError(f"Insufficient operands for binary operator: {token}")

            # 注意：由于前序表达式被反转，左操作数先弹出，右操作数后弹出
            left = stack.pop()  # 先弹出的是左操作数
            right = stack.pop()  # 后弹出的是右操作数
            left_index = index_stack.pop()
            right_index = index_stack.pop()

            # 确保操作数顺序与原始前序表达式一致
            stack.append([token, left, right])
            index_stack.append([len(preorder) - i - 1, left_index, right_index])  # 记录原始索引
        else:
            # 其他情况（保留原样）
            stack.append(token)
            index_stack.append(len(preorder) - i - 1)  # 记录原始索引

    # 最终栈中应该只剩下一个元素，即嵌套数组
    if len(stack) != 1 or len(index_stack) != 1:
        raise ValueError("Invalid expression: leftover elements in stack")
    return stack[0], index_stack[0]


def normalize_expression_with_mapping(expr, indices, operator_priority):
    """
    根据优先级对嵌套数组形式的表达式进行标准化排序，并生成映射数组。

    :param expr: 嵌套数组形式的表达式（如 ['add', 'x1', 'x2']）
    :param indices: 对应的索引数组（如 [0, 2, 1]）
    :param operator_priority: 操作符优先级列表（如 ["add", "mul", "sin", "cos"]）
    :return: 标准化后的表达式和映射数组
    """
    # 定义优先级字典
    priority_map = {op: i for i, op in enumerate(operator_priority)}
    priority_map.update({f"x_{i}": len(operator_priority) + int(i) for i in range(20)})  # 变量优先级最低
    priority_map.update({'const': len(operator_priority)+20})

    def sort_subexpression(subexpr, subindices):
        """
        对子表达式进行排序，并同步调整映射数组。
        """
        if isinstance(subexpr, list):
            # 如果是嵌套数组，递归处理
            normalized, mapping = normalize_expression_with_mapping(subexpr, subindices, operator_priority)
            return normalized, mapping
        return subexpr, subindices

    if isinstance(expr, list):
        operator = expr[0]
        operands = expr[1:]
        operand_indices = indices[1:]

        # 提取映射数组
        sorted_pairs = [
            sort_subexpression(op, idx)
            for op, idx in zip(operands, operand_indices)
        ]
        sorted_operands, sorted_mappings = zip(*sorted_pairs)

        # 仅对 add 和 mul 的操作数进行排序
        if operator in {"+", "*"}:
            # 自定义排序规则
            sorted_pairs = sorted(
                zip(sorted_operands, sorted_mappings),
                key=lambda x: (lambda subexpr: [
                    priority_map[token] if token in priority_map else float('inf')
                    for token in flatten_expression(subexpr)
                ])(x[0])
            )
            sorted_operands, sorted_mappings = zip(*sorted_pairs)

        # 返回排序后的表达式和映射数组
        return [operator] + list(sorted_operands), [indices[0]] + list(sorted_mappings)
    else:
        # 如果是变量或单个操作符，直接返回
        return expr, indices

def flatten_expression(expr):
    """
    将标准化后的嵌套数组表达式转换为扁平化的表达式数组。

    :param expr: 标准化后的嵌套数组表达式（如 ['add', 'x1', 'x2']）
    :return: 扁平化的表达式数组（如 ['add', 'x1', 'x2']）
    """
    result = []

    def flatten(subexpr):
        if isinstance(subexpr, list):
            # 如果是嵌套数组，递归展开
            for item in subexpr:
                flatten(item)
        else:
            # 如果是单个值，直接加入结果列表
            result.append(subexpr)

    # 调用递归展开函数
    flatten(expr)
    return result


def get_flattened_normalized_expression(preorder, operator_priority):
    """
    封装函数：接受一个前序表达式和操作符优先级列表，返回扁平化后的标准化表达式和映射数组。

    :param preorder: 前序遍历的表达式列表（如 ["add", "x1", "x2"]）
    :param operator_priority: 操作符优先级列表（如 ["add", "mul", "sin", "cos"]）
    :return: 扁平化后的标准化表达式和映射数组
    """
    # 1. 将前序表达式转换为嵌套数组形式，并记录索引
    nested_expr, nested_indices = preorder_to_nested_with_indices(preorder)

    # 2. 根据优先级对嵌套数组进行标准化排序，并生成映射数组
    normalized_expr, mapping_expr = normalize_expression_with_mapping(nested_expr, nested_indices, operator_priority)

    # 3. 将标准化后的嵌套数组和映射数组展平
    flattened_normalized_expr = flatten_expression(normalized_expr)
    flattened_mapping = flatten_expression(mapping_expr)

    # 返回结果
    return flattened_normalized_expr, flattened_mapping


def longest_common_subsequence_with_bool_arrays(list1, list2):
    """
    计算两个字符数组的最长公共子序列 (LCS)，并返回对应的布尔数组。

    参数:
        list1: 第一个字符数组 (list of str)
        list2: 第二个字符数组 (list of str)

    返回:
        一个字典，包含以下内容：
        - "lcs": 最长公共子序列 (list of str)
        - "bool_array_list1": 标记 list1 中哪些位置属于 LCS 的布尔数组 (list of bool)
        - "bool_array_list2": 标记 list2 中哪些位置属于 LCS 的布尔数组 (list of bool)
    """
    # 获取两个数组的长度
    m, n = len(list1), len(list2)

    # 创建一个二维 DP 表，初始化为 0
    dp = [[0] * (n + 1) for _ in range(m + 1)]

    # 填充 DP 表
    for i in range(1, m + 1):
        for j in range(1, n + 1):
            if list1[i - 1] == list2[j - 1]:  # 如果字符相等
                dp[i][j] = dp[i - 1][j - 1] + 1
            else:  # 如果字符不相等
                dp[i][j] = max(dp[i - 1][j], dp[i][j - 1])

    # 回溯找出 LCS，并记录其在 list1 和 list2 中的位置
    lcs = []
    positions_list1 = []
    positions_list2 = []
    i, j = m, n
    while i > 0 and j > 0:
        if list1[i - 1] == list2[j - 1]:  # 字符相等时加入 LCS
            lcs.append(list1[i - 1])
            positions_list1.append(i - 1)  # 记录在 list1 中的位置
            positions_list2.append(j - 1)  # 记录在 list2 中的位置
            i -= 1
            j -= 1
        elif dp[i - 1][j] > dp[i][j - 1]:  # 从上方来
            i -= 1
        else:  # 从左方来
            j -= 1

    # 因为回溯是从后往前找的，所以需要反转结果
    lcs.reverse()
    positions_list1.reverse()
    positions_list2.reverse()

    # 构造布尔数组
    bool_array_list1 = [False] * m
    bool_array_list2 = [False] * n
    for pos in positions_list1:
        bool_array_list1[pos] = True
    for pos in positions_list2:
        bool_array_list2[pos] = True

    # 返回结果
    return {
        "lcs": lcs,
        "bool_array_list1": bool_array_list1,
        "bool_array_list2": bool_array_list2
    }


def longest_common_substring_with_mask_arrays(list1, list2):
    """
    计算两个字符数组的最长公共子字符串 (Longest Common Substring)，并返回匹配的子字符串及其在两个列表中的位置和布尔掩码数组。

    参数:
        list1: 第一个字符数组 (list of str)
        list2: 第二个字符数组 (list of str)

    返回:
        一个字典，包含以下内容：
        - "substring": 最长公共子字符串 (list of str)
        - "start_pos_list1": 子字符串在 list1 中的起始位置 (int)
        - "end_pos_list1": 子字符串在 list1 中的结束位置 (int)
        - "start_pos_list2": 子字符串在 list2 中的起始位置 (int)
        - "end_pos_list2": 子字符串在 list2 中的结束位置 (int)
        - "mask_array_list1": 标记 list1 中哪些位置属于最长公共子字符串的布尔数组 (list of bool)
        - "mask_array_list2": 标记 list2 中哪些位置属于最长公共子字符串的布尔数组 (list of bool)
    """
    m, n = len(list1), len(list2)
    dp = [[0] * (n + 1) for _ in range(m + 1)]

    max_length = 0
    end_pos_list1 = 0

    for i in range(1, m + 1):
        for j in range(1, n + 1):
            if list1[i - 1] == list2[j - 1]:
                dp[i][j] = dp[i - 1][j - 1] + 1
                if dp[i][j] > max_length:
                    max_length = dp[i][j]
                    end_pos_list1 = i
            else:
                dp[i][j] = 0

    start_pos_list1 = end_pos_list1 - max_length
    substring = list1[start_pos_list1:end_pos_list1]

    start_pos_list2 = -1
    for j in range(n - max_length + 1):
        if list2[j:j + max_length] == substring:
            start_pos_list2 = j
            break
    end_pos_list2 = start_pos_list2 + max_length

    mask_array_list1 = [False] * m
    mask_array_list2 = [False] * n
    for i in range(start_pos_list1, end_pos_list1):
        mask_array_list1[i] = True
    for j in range(start_pos_list2, end_pos_list2):
        mask_array_list2[j] = True

    return {
        "substring": substring,
        "start_pos_list1": start_pos_list1,
        "end_pos_list1": end_pos_list1,
        "start_pos_list2": start_pos_list2,
        "end_pos_list2": end_pos_list2,
        "bool_array_list1": mask_array_list1,
        "bool_array_list2": mask_array_list2
    }


def mark_variables_and_operators(flattened_expression, operator_priority):
    """
    标记扁平化后的标准化表达式中哪些值是变量，哪些值是操作符。

    :param flattened_expression: 扁平化后的标准化表达式（如 ['add', 'x1', 'x2']）
    :param operator_priority: 操作符优先级列表（如 ["add", "mul", "sin", "cos"]）
    :return: 变量标记数组和操作符标记数组
    """
    # 初始化标记数组
    variable_mask = []
    operator_mask = []

    # 遍历扁平化表达式的每个元素
    for token in flattened_expression:
        if token.startswith('x') or token.startswith('con'):  # 如果是以 'x' 开头，则为变量
            variable_mask.append(True)
            operator_mask.append(False)
        elif token in operator_priority:  # 如果在操作符优先级列表中，则为操作符
            variable_mask.append(False)
            operator_mask.append(True)
        else:  # 其他情况（理论上不应该出现）
            raise ValueError(f"Unexpected token in flattened expression: {token}")

    return variable_mask, operator_mask



def check_reconstruction(original_expr, flattened_normalized_expr, flattened_mapping):
    """
    检查标准化后的数组是否可以根据映射数组还原回去。

    :param original_expr: 原始表达式数组
    :param flattened_normalized_expr: 扁平化的标准化表达式数组
    :param flattened_mapping: 扁平化的映射数组
    :return: 是否可以还原（布尔值）
    """
    reconstructed_expr = [None] * len(original_expr)

    # 使用映射数组重建原始表达式
    for value, index in zip(flattened_normalized_expr, flattened_mapping):
        reconstructed_expr[index] = value

    # 检查重建的表达式是否与原始表达式一致
    return reconstructed_expr == original_expr

def mark_common_elements_in_original(preorder, flattened_normalized_expr, flattened_mapping,
                                    normalized_label_expr):
    """
    标记原始表达式中属于最长公共子序列 (LCS) 或最长公共子字符串的元素。
    :param preorder: 原始表达式数组（如 ["+", "+", "x_2", "x_1", "+", "x_4", "x_3"]）
    :param flattened_normalized_expr: 扁平化的标准化表达式数组
    :param flattened_mapping: 标准化表达式的映射数组
    :param label_preorder: 目标表达式数组
    :param normalized_label_expr: 目标表达式的扁平化标准化数组
    :param label_mapping: 目标表达式的映射数组
    :return: 两个布尔数组，分别标记原始表达式和目标表达式中属于公共部分的元素
    """
    # 计算最长公共子序列 (LCS)
    lcs_result = longest_common_subsequence_with_bool_arrays(flattened_normalized_expr, normalized_label_expr)
    lcs_bool_array_list1 = lcs_result["bool_array_list1"]


    # 计算最长公共子字符串
    substring_result = longest_common_substring_with_mask_arrays(flattened_normalized_expr, normalized_label_expr)
    substring_bool_array_list1 = substring_result["bool_array_list1"]


    # 初始化原始表达式的布尔数组
    original_lcs_bool_array = [False] * len(preorder)
    original_substring_bool_array = [False] * len(preorder)

    # 反向映射 LCS 的布尔数组到原始表达式
    for normalized_index, is_common in enumerate(lcs_bool_array_list1):
        if is_common:
            original_index = flattened_mapping[normalized_index]
            original_lcs_bool_array[original_index] = True

    # 反向映射最长公共子字符串的布尔数组到原始表达式
    for normalized_index, is_common in enumerate(substring_bool_array_list1):
        if is_common:
            original_index = flattened_mapping[normalized_index]
            original_substring_bool_array[original_index] = True

    return original_lcs_bool_array, original_substring_bool_array

# 测试用例
def test_full_process():

    operator_functions, _, _, _ = register_operator_functions()
    operator_priority = list(operator_functions.keys())

    # 示例表达式和操作符优先级
    preorder = ["+", "+", "arccos","x_2", "x_1", "+", "x_4", "const"]
    print("原始的表达式",preorder)
    #operator_priority = ["add", "mul", "sub", "div", "sin", "cos"]

    # 调用封装函数
    flattened_normalized_expr, flattened_mapping = get_flattened_normalized_expression(preorder, operator_priority)

    # 输出结果
    print("标准化表达式:", flattened_normalized_expr)
    print("标准化映射数组:", flattened_mapping)

    variable_mask, operator_mask = mark_variables_and_operators(flattened_normalized_expr, operator_priority)
    print("标准化表达式的变量:", variable_mask)
    print("标准化表达式的操作符:", operator_mask)

    label_preorder = ["+", "*", "x_4", "x_3","*", "x_4", "x_2"]
    print("目标表达式:", label_preorder)

    normalized_label_expr, label_mapping = get_flattened_normalized_expression(label_preorder, operator_priority)
    print("标准化目标表达式:", normalized_label_expr)

    result = longest_common_subsequence_with_bool_arrays(flattened_normalized_expr, normalized_label_expr)

    result_2 = longest_common_substring_with_mask_arrays(flattened_normalized_expr, normalized_label_expr)

    print("LCS:", result["lcs"])
    print("标准化表达式:", result["bool_array_list1"])
    print("标准化目标表达式:", result["bool_array_list2"])

    print("Longest Common Substring", result_2["substring"])
    print("list1 的布尔数组:", result_2["bool_array_list1"])
    print("list2 的布尔数组:", result_2["bool_array_list2"])

    flag_expr = check_reconstruction(preorder, flattened_normalized_expr, flattened_mapping)
    print("原表达式检查:", flag_expr)

    flag_label = check_reconstruction(label_preorder, normalized_label_expr, label_mapping)
    print("新表达式检查:", flag_label)

    # 标记原始表达式中属于公共部分的元素
    original_lcs_bool_array, original_substring_bool_array = mark_common_elements_in_original(
        preorder, flattened_normalized_expr, flattened_mapping,
        normalized_label_expr
    )

    length_lcs = int(sum(original_lcs_bool_array))
    length_substring_bool_array = int(sum(original_substring_bool_array))

    print("原始表达式中属于 LCS 的元素:", original_lcs_bool_array)
    print("原始表达式中属于最长公共子字符串的元素:", original_substring_bool_array)

    print("lcs长度", length_lcs)
    print("lcss长度", length_substring_bool_array)




if __name__ == "__main__":
    # 运行测试
    test_full_process()


