import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from function import *

def check_for_nan(tensor, name="Tensor"):
    """
    Validate that a tensor holds only finite values.

    :param tensor: tensor to inspect (torch.Tensor)
    :param name: label used in the diagnostic output and in the error message
    :return: the very same tensor object when every element is finite
    :raises ValueError: if any element is NaN or +/-Inf
    """
    if bool(torch.isfinite(tensor).all()):
        return tensor

    # Dump the offending tensor before aborting so the failure can be debugged.
    print(f"Error: {name} contains NaN or Inf values!")
    print(f"Tensor shape: {tensor.shape}")
    print(f"Tensor content:\n{tensor}")

    raise ValueError(f"{name} contains invalid values (NaN or Inf).")
class SymbolicExpressionGenerator:
    def __init__(self, hidden_size, n, operator_arities, batch_size, trig, exp_log, epsilon, include_self_loops=True, tree_type="full"):
        """
        Build the tree layout, the per-node constraint tables and the GCN model.

        :param hidden_size: hidden width handed to the GCN as ``nhid``
        :param n: maximum depth; the tree has ``2 ** n - 1`` node slots
        :param operator_arities: arity of every token, e.g. [2, 2, 1, 0]
            (tokens with arity 0 are collected as variable actions)
        :param batch_size: number of expressions processed in parallel
        :param trig: token indices that must not be nested inside each other
            (used by ``_avoid_trig_nesting``)
        :param exp_log: token indices used by ``_avoid_exp_log_constraints``;
            entries 0 and 1 are read explicitly - presumably [exp, log],
            confirm against the caller
        :param epsilon: threshold compared against ``torch.rand`` in
            ``generate`` to decide when a uniformly random legal action is used
        :param include_self_loops: whether self-connections are added to the
            adjacency matrix before normalisation
        :param tree_type: "full" (full-binary-tree constraint) or "partial"
            (prefer-complete-tree constraint)
        """
        self.n = n
        self.operator_arities = operator_arities
        self.m = len(operator_arities)  # number of tokens
        self.hidden = hidden_size
        self.batch_size = batch_size
        self.depth = n
        self.num_nodes = 2 ** self.depth - 1
        self.include_self_loops = include_self_loops
        self.tree_type = tree_type
        self.empty_token = self.m  # index one past the last real token marks an empty node
        self.trig = trig
        self.exp_log = exp_log
        self.epsilon = epsilon

        # Build the adjacency matrix (heap numbering first, then re-indexed to pre-order)
        self.adj_matrix = self._generate_binary_tree_adjacency_matrix()
        self.preorder_indices = self._get_preorder_indices()  # node order of a pre-order traversal
        self.adj_matrix = self._reorder_adj_matrix(self.preorder_indices)
        self.adj_norm = self._normalize_adj_matrix()

        # Derive helper tables from the tree (all re-indexed by pre-order position)
        self.ancestor_list = self._generate_ancestor_list(self.preorder_indices)  # (num_nodes, max_ancestors)
        self.left_right_flags = self._generate_left_right_flags(self.preorder_indices)  # (num_nodes,)
        self.is_leaf = self._generate_is_leaf(self.preorder_indices)  # (num_nodes,)

        # Parent position of every node in pre-order numbering
        self.adjusted_parent_indices = self._generate_adjusted_parent_indices(self.preorder_indices)

        # Pre-compute the per-node token-type constraints (needs self.is_leaf)
        self.type_constraints = self._precompute_type_constraints()

        self.variable_actions = [
            idx for idx, arity in enumerate(self.operator_arities) if arity == 0
        ]

        # Model parameters
        # NOTE(review): for n <= 2 this is <= 0, so GCN builds no graph-conv layers and its
        # Linear(nhid, nclass) head would receive (m + 1)-wide features - confirm n >= 3 is required.
        self.nlayers = self.depth - 2
        self.model = self._build_model()

    def _generate_binary_tree_adjacency_matrix(self):
        """
        Build the symmetric adjacency matrix of a heap-numbered binary tree.

        Node ``k`` (k >= 1) is connected to its parent ``(k - 1) // 2``.

        :return: float array of shape (num_nodes, num_nodes) with 1 on every
            parent/child pair and 0 elsewhere
        """
        total = self.num_nodes
        adjacency = np.zeros((total, total))
        # Walking over children instead of parents touches exactly the same edges.
        for child in range(1, total):
            parent = (child - 1) // 2
            adjacency[parent, child] = 1
            adjacency[child, parent] = 1
        return adjacency

    def _get_preorder_indices(self):
        """
        List the heap-numbered nodes in pre-order (node, left subtree, right subtree).

        :return: list of heap indices, starting with the root 0
        """
        total = self.num_nodes
        order = []

        def visit(node):
            order.append(node)
            # Left child first, then right child; children outside the tree are skipped.
            for child in (2 * node + 1, 2 * node + 2):
                if child < total:
                    visit(child)

        visit(0)
        return order

    def _reorder_adj_matrix(self, preorder_indices):
        new_adj_matrix = self.adj_matrix[preorder_indices][:, preorder_indices]
        return new_adj_matrix

    def _normalize_adj_matrix(self):
        adj_matrix = torch.FloatTensor(self.adj_matrix)
        if self.include_self_loops:
            adj_matrix = adj_matrix + torch.eye(self.num_nodes)  # 添加自连接
        degree_matrix = torch.diag(torch.pow(adj_matrix.sum(dim=1), -0.5))
        adj_norm = torch.mm(torch.mm(degree_matrix, adj_matrix), degree_matrix)
        adj_norm = adj_norm.unsqueeze(0).expand(self.batch_size, -1, -1)
        return adj_norm

    def _generate_ancestor_list(self, preorder_indices):
        ancestor_list = []
        for node in preorder_indices:
            ancestors = []
            current_node = node
            while current_node != 0:
                parent = (current_node - 1) // 2
                ancestors.append(parent)
                current_node = parent
            ancestor_list.append([preorder_indices.index(a) for a in ancestors[::-1]])
        max_ancestors = max(len(ancestors) for ancestors in ancestor_list)
        padded_ancestor_list = [
            ancestors + [-1] * (max_ancestors - len(ancestors)) for ancestors in ancestor_list
        ]
        return np.array(padded_ancestor_list)

    def _generate_left_right_flags(self, preorder_indices):
        """
        Mark every node as a left child (0), a right child (1) or the root (-1).

        In heap numbering left children have odd indices and right children
        even ones, so no parent lookup is needed.

        :param preorder_indices: heap indices in pre-order
        :return: integer array of shape (num_nodes,), indexed by pre-order position
        """
        left_right_flags = np.full(self.num_nodes, -1)
        for position, heap_id in enumerate(preorder_indices):
            if heap_id == 0:  # the root keeps -1
                continue
            left_right_flags[position] = 0 if heap_id % 2 == 1 else 1
        return left_right_flags

    def _generate_is_leaf(self, preorder_indices):
        """
        Flag the nodes that have no child inside the tree.

        :param preorder_indices: heap indices in pre-order
        :return: boolean array of shape (num_nodes,), indexed by pre-order position
        """
        total = self.num_nodes
        leaf_flags = np.ones(total, dtype=bool)
        for position, heap_id in enumerate(preorder_indices):
            children = (2 * heap_id + 1, 2 * heap_id + 2)
            if any(child < total for child in children):
                leaf_flags[position] = False
        return leaf_flags

    def _generate_adjusted_parent_indices(self, preorder_indices):
        """
        生成调整后的父节点索引数组
        :return: 调整后的父节点索引数组 (num_nodes,)
        """
        adjusted_parent_indices = []
        for i, node in enumerate(preorder_indices):
            if node == 0:  # 根节点没有父节点
                adjusted_parent_indices.append(-1)
            else:
                parent = (node - 1) // 2
                parent_index = preorder_indices.index(parent)
                adjusted_parent_indices.append(parent_index)
        return np.array(adjusted_parent_indices)

    def _precompute_type_constraints(self):
        """
        预先计算类型限制
        :return: 类型限制 (num_nodes, num_classes)
        """
        type_constraints = torch.ones((self.num_nodes, self.m), dtype=torch.bool)
        if self.tree_type == "full":  # 满二叉树限制
            for idx, arity in enumerate(self.operator_arities):
                if arity == 0:  # 变量符号
                    type_constraints[:, idx] &= torch.tensor(self.is_leaf, dtype=torch.bool)
                elif arity > 0:  # 运算符
                    type_constraints[:, idx] &= ~torch.tensor(self.is_leaf, dtype=torch.bool)
        elif self.tree_type == "partial":  # 优先完备树限制
            type_constraints[0, :] = torch.tensor(
                [arity > 0 for arity in self.operator_arities], dtype=torch.bool
            )  # 根节点只能采样运算符
            type_constraints[1:, :] = True  # 其他节点可以采样任意符号
            type_constraints[self.is_leaf, :] = torch.tensor(
                [arity == 0 for arity in self.operator_arities], dtype=torch.bool
            )  # 叶子节点只能采样变量符号
        return type_constraints

    def _build_model(self):
        """
        Instantiate the GCN that scores tokens for every node.

        :return: a ``GCN`` with ``m + 1`` input features (one-hot tokens plus
            the empty token) and ``m`` output classes
        """
        num_tokens = self.m
        return GCN(
            nfeat=num_tokens + 1,
            nhid=self.hidden,
            nlayers=self.nlayers,
            nclass=num_tokens,
        )

    def apply_rule_constraints(self, valid_actions, features, i):
        """
        应用规则限制
        :param valid_actions: 初始的有效动作掩码 (batch_size, num_classes)
        :param features: 当前的节点特征 (batch_size, num_nodes, feature_dim)
        :param i: 当前采样步索引
        :return: 更新后的有效动作掩码 (batch_size, num_classes)
        """
        batch_size, num_classes = valid_actions.shape

        # 获取当前节点的父节点索引
        parent_indices = self.adjusted_parent_indices[i]  # 当前节点的父节点索引
        if parent_indices == -1:  # 如果是根节点，没有父节点
            return valid_actions

        # 获取父节点的动作
        parent_actions = features[torch.arange(batch_size), parent_indices].argmax(dim=-1)  # (batch_size,)

        # 如果父节点是空节点（self.m），则当前节点必须为空节点
        parent_is_empty = parent_actions == self.empty_token
        # 创建一个布尔掩码，用于选择符合条件的样本
        empty_mask = parent_is_empty.unsqueeze(-1).expand(-1, num_classes)  # (batch_size, num_classes)

        # 将符合条件的样本的所有动作设置为 False
        valid_actions = valid_actions.clone()  # 避免原地修改警告
        valid_actions[empty_mask] = False

        # 如果父节点是元数为 1 的运算符，并且当前节点是右子节点，则必须为空节点
        is_right_child = self.left_right_flags[i] == 1  # 当前节点是否是右子节点
        parent_is_unary_op = torch.zeros_like(parent_actions, dtype=torch.bool)
        for action in range(self.m):
            if self.operator_arities[action] == 1:
                parent_is_unary_op |= (parent_actions == action)

        if is_right_child:
            # 将满足条件的批次的所有动作设置为 False
            mask = parent_is_unary_op.unsqueeze(-1).expand(-1, num_classes)  # (batch_size, num_classes)

            # 将符合条件的样本的所有动作设置为 False
            valid_actions = valid_actions.clone()  # 避免原地修改警告
            valid_actions[mask] = False

        return valid_actions

    def _avoid_trig_nesting(self, valid_actions, features, i):
        """
        避免三角函数互相嵌套
        :param valid_actions: 初始的有效动作掩码 (batch_size, num_classes)
        :param features: 当前的节点特征 (batch_size, num_nodes, feature_dim)
        :param i: 当前采样步索引
        :return: 更新后的有效动作掩码 (batch_size, num_classes)
        """
        batch_size, num_classes = valid_actions.shape

        # 获取当前节点的所有祖先节点
        ancestor_indices = self.ancestor_list[i]  # (max_ancestors,)
        ancestor_indices = ancestor_indices[ancestor_indices != -1]  # 剔除填充的 -1

        # 获取祖先节点的动作
        ancestor_actions = features[:, ancestor_indices].argmax(dim=-1)  # (batch_size, max_ancestors)

        # 确保 ancestor_actions 是整数张量
        ancestor_actions = ancestor_actions.long()

        # 检查祖先节点是否有三角函数
        has_trig_ancestor = torch.any(torch.isin(ancestor_actions, torch.tensor(self.trig)), dim=1)  # (batch_size,)

        # 创建一个布尔掩码，用于选择符合条件的样本
        mask = has_trig_ancestor.unsqueeze(-1).expand(-1, num_classes)  # (batch_size, num_classes)

        # 创建一个布尔掩码，禁止当前动作
        trig_mask = torch.isin(torch.arange(num_classes).unsqueeze(0).expand(batch_size, -1),
                               torch.tensor(self.trig))  # (batch_size, num_classes)

        # 更新有效动作掩码
        valid_actions = valid_actions.clone()  # 避免原地修改警告
        valid_actions[mask] &= ~trig_mask[mask]

        return valid_actions

    def _avoid_sampling_if_parent_is_variable(self, valid_actions, features, i):
        """
        如果当前节点的父节点是变量节点，则禁止采样任何动作
        :param valid_actions: 初始的有效动作掩码 (batch_size, num_classes)
        :param features: 当前的节点特征 (batch_size, num_nodes, feature_dim)
        :param i: 当前采样步索引
        :return: 更新后的有效动作掩码 (batch_size, num_classes)
        """
        batch_size, num_classes = valid_actions.shape

        # 获取当前节点的直接父节点索引
        parent_index = self.adjusted_parent_indices[i]  # 当前节点的父节点索引

        # 获取父节点的动作
        parent_actions = features[torch.arange(batch_size), parent_index].argmax(dim=-1)  # (batch_size,)

        # 确保 parent_actions 是整数张量
        parent_actions = parent_actions.long()

        # 检查父节点是否为变量节点
        is_variable_parent = torch.isin(parent_actions, torch.tensor(self.variable_actions))  # (batch_size,)

        # 创建一个布尔掩码，用于选择符合条件的样本
        mask = is_variable_parent.unsqueeze(-1).expand(-1, num_classes)  # (batch_size, num_classes)

        # 更新有效动作掩码
        valid_actions = valid_actions.clone()  # 避免原地修改警告
        valid_actions[mask] = False  # 禁止所有动作

        return valid_actions

    def _avoid_exp_log_constraints(self, valid_actions, features, i):
        """
        避免 exp 和 log 的自身嵌套以及直接相邻
        :param valid_actions: 初始的有效动作掩码 (batch_size, num_classes)
        :param features: 当前的节点特征 (batch_size, num_nodes, feature_dim)
        :param i: 当前采样步索引
        :return: 更新后的有效动作掩码 (batch_size, num_classes)
        """
        batch_size, num_classes = valid_actions.shape

        # 获取当前节点的所有祖先节点
        ancestor_indices = self.ancestor_list[i]  # (max_ancestors,)
        ancestor_indices = ancestor_indices[ancestor_indices != -1]  # 剔除填充的 -1

        # 获取祖先节点的动作
        ancestor_actions = features[:, ancestor_indices].argmax(dim=-1)  # (batch_size, max_ancestors)

        # 确保 ancestor_actions 是整数张量
        ancestor_actions = ancestor_actions.long()

        # 禁止自身嵌套
        for action in self.exp_log:
            # 找到祖先节点中等于当前动作的样本
            mask_self_nesting = torch.any(ancestor_actions == action, dim=1)  # (batch_size,)
            mask = mask_self_nesting.unsqueeze(-1).expand(-1, num_classes)  # (batch_size, num_classes)

            # 创建一个布尔掩码，禁止当前动作
            action_mask = torch.arange(num_classes).unsqueeze(0).expand(batch_size,
                                                                        -1) == action  # (batch_size, num_classes)

            # 更新有效动作掩码
            valid_actions = valid_actions.clone()  # 避免原地修改警告
            valid_actions[mask] &= ~action_mask[mask]

        # 获取当前节点的直接父节点索引
        parent_index = self.adjusted_parent_indices[i]  # 当前节点的父节点索引
        # 获取父节点的动作
        parent_actions = features[torch.arange(batch_size), parent_index].argmax(dim=-1)  # (batch_size,)

        # 禁止直接相邻的情况（如 exp(log(x)) 或 log(exp(x))）
        for exp_action, log_action in [(self.exp_log[0], self.exp_log[1]), (self.exp_log[1], self.exp_log[0])]:
            # 找到父节点等于 exp_action 的样本
            mask_adjacent = (parent_actions == exp_action)  # (batch_size,)
            mask = mask_adjacent.unsqueeze(-1).expand(-1, num_classes)  # (batch_size, num_classes)

            # 创建一个布尔掩码，禁止当前节点采样 log_action
            action_mask = torch.arange(num_classes).unsqueeze(0).expand(batch_size,
                                                                            -1) == log_action  # (batch_size, num_classes)

            # 更新有效动作掩码
            valid_actions = valid_actions.clone()  # 避免原地修改警告
            valid_actions[mask] &= ~action_mask[mask]

        return valid_actions

    def modify_sampling_distribution(self, adj, features, hidden_state, i):
        """
        Turn the model scores of node ``i`` into a distribution over legal actions.

        :param adj: adjacency tensor (batch_size, num_nodes, num_nodes); not
            used by this method
        :param features: one-hot node features (batch_size, num_nodes, feature_dim)
        :param hidden_state: model scores for node ``i`` (batch_size, num_classes)
        :param i: index of the node being sampled (pre-order position)
        :return: tuple ``(modified_probs, valid_actions)``, both of shape
            (batch_size, num_classes); a sample without any legal action gets
            an all-NaN probability row (softmax over -inf only)
        :raises ValueError: if ``hidden_state`` contains NaN/Inf, or if the
            mask of a leaf node is neither "variables only" nor "nothing"
        """
        batch_size, num_classes = hidden_state.shape

        # Abort early on NaN/Inf scores
        hidden_state = check_for_nan(hidden_state, name="Sampling Probabilities")

        # Start from the pre-computed type constraints of this node
        valid_actions = self.type_constraints[i].unsqueeze(0).expand(batch_size, -1)  # (batch_size, num_classes)

        # Ancestor-based rules are skipped for the root; parent-based rules always run.
        rules = []
        if i != 0:
            rules.append(self._avoid_trig_nesting)
            rules.append(self._avoid_exp_log_constraints)
        rules.append(self._avoid_sampling_if_parent_is_variable)
        rules.append(self.apply_rule_constraints)
        for rule in rules:
            valid_actions = rule(valid_actions, features, i)

        # Sanity check on leaves (only the first sample is inspected)
        if self.is_leaf[i]:
            first_row = valid_actions[0]
            variables_only = torch.tensor([arity == 0 for arity in self.operator_arities], dtype=torch.bool)
            nothing_allowed = torch.zeros(len(self.operator_arities), dtype=torch.bool)
            if not ((first_row == variables_only).all() or (first_row == nothing_allowed).all()):
                raise ValueError("Leaf node constraints were violated!")

        # Softmax restricted to the legal actions: illegal ones are pushed to -inf
        masked_scores = hidden_state.masked_fill(~valid_actions, float('-inf'))
        modified_probs = F.softmax(masked_scores, dim=-1)

        return modified_probs, valid_actions

    def sample_random_legal_actions(self, valid_actions):
        """
        从合法动作中随机采样
        :param valid_actions: 合法动作掩码 (batch_size, m)
        :return: 随机采样的合法动作索引 (batch_size,)
        """
        # 创建一个均匀分布的概率分布，仅保留合法动作
        uniform_probs = valid_actions.float()
        row_sums = uniform_probs.sum(dim=-1, keepdim=True)  # (batch_size, 1)

        # 检查是否存在没有合法动作的样本
        has_valid_action = row_sums.squeeze(-1) > 0  # (batch_size,)

        if not has_valid_action.all():
            invalid_samples = torch.where(~has_valid_action)[0].tolist()  # 找到没有合法动作的样本索引
            print("Error: Some samples have no valid actions!")
            print("Invalid samples:", invalid_samples)
            print("Valid actions for invalid samples:")
            for sample_idx in invalid_samples:
                print(f"Sample {sample_idx}:", valid_actions[sample_idx])
            raise ValueError("Some samples have no valid actions!")

        # 归一化概率分布
        uniform_probs = uniform_probs / row_sums.clamp(min=1e-10)

        # 从合法动作中随机采样
        return torch.multinomial(uniform_probs, 1).squeeze(-1)

    def check_actions(self, actions, valid_actions, action_type):
        """
        检查动作是否符合合法性约束
        :param actions: 动作索引 (batch_size,)
        :param valid_actions: 合法动作掩码 (batch_size, m)
        :param action_type: 动作类型（用于调试信息）
        """
        illegal_samples = []
        for sample_idx in range(actions.size(0)):
            if not valid_actions[sample_idx, actions[sample_idx]]:
                illegal_samples.append(sample_idx)

        if illegal_samples:
            print(f"Error: Illegal {action_type} detected!")
            print("Illegal samples:", illegal_samples)
            print(f"{action_type} for illegal samples:")
            for sample_idx in illegal_samples:
                print(
                    f"Sample {sample_idx}: Action = {actions[sample_idx]}, Valid actions = {valid_actions[sample_idx]}")
            raise ValueError(f"Illegal {action_type} detected!")

    def check_final_actions(self, actions, valid_actions, i):
        """
        检查最终的动作是否符合合法性约束
        :param actions: 最终采样的动作索引 (batch_size,)
        :param valid_actions: 合法动作掩码 (batch_size, m)
        :param i: 当前采样步索引
        """
        illegal_samples = []
        for sample_idx in range(actions.size(0)):
            if not valid_actions[sample_idx, actions[sample_idx]]:
                illegal_samples.append(sample_idx)

        if illegal_samples:
            print(f"Error: Illegal final actions detected at step {i}:")
            print("Illegal samples:", illegal_samples)
            print("Final actions for illegal samples:")
            for sample_idx in illegal_samples:
                print(
                    f"Sample {sample_idx}: Action = {actions[sample_idx]}, Valid actions = {valid_actions[sample_idx]}")
            raise ValueError("Illegal final actions detected!")


    def generate(self):
        """
        Sample one expression tree per batch entry, node by node in pre-order.

        At every node the model scores are restricted to the legal actions.
        Samples that have no legal action at a node (their masked softmax is
        all NaN) receive the empty token. Every other sample picks its action
        epsilon-greedily: with probability ``self.epsilon`` a uniformly random
        legal action, otherwise a draw from the masked model distribution.

        Exploration is applied per sample. Previously it was silently skipped
        for the whole batch as soon as a single sample had no legal action,
        which happens at most nodes below the first few levels.

        :return: tuple ``(all_expressions, all_node_features)`` where
            ``all_expressions`` is a nested list (batch_size x num_nodes) of
            token indices and ``all_node_features`` the same data as a long
            tensor (batch_size, num_nodes); empty nodes hold ``self.m``
        """
        all_node_features = torch.full((self.batch_size, self.num_nodes), self.m, dtype=torch.long)
        all_expressions = []
        for i in range(self.num_nodes):
            one_hot_features = self._to_one_hot(all_node_features)
            output = self.model(one_hot_features, self.adj_norm)

            modified_probs, valid_actions = self.modify_sampling_distribution(
                adj=self.adj_norm,
                features=one_hot_features,
                hidden_state=output[:, i],
                i=i
            )

            # The scores were checked to be finite, so a NaN row can only come from
            # a softmax over -inf everywhere, i.e. a sample without legal actions.
            nan_mask = torch.isnan(modified_probs).any(dim=-1)  # (batch_size,)
            live = ~nan_mask

            # Default: empty token; overwritten below for samples that can act.
            actions = torch.full((self.batch_size,), self.empty_token, dtype=torch.long)
            if live.any():  # also avoids calling multinomial on an empty selection
                sampled_actions = torch.multinomial(modified_probs[live], 1).squeeze(-1)
                random_actions = self.sample_random_legal_actions(valid_actions[live])

                # Epsilon-greedy choice between exploration and the model's draw
                explore_mask = torch.rand(sampled_actions.shape[0]) < self.epsilon
                actions[live] = torch.where(explore_mask, random_actions, sampled_actions)

            all_node_features[:, i] = actions
            all_expressions.append(actions.cpu().numpy())

        # (num_nodes, batch_size) -> (batch_size, num_nodes)
        all_expressions = np.array(all_expressions).T.tolist()

        return all_expressions, all_node_features
    def evaluate_sequence(self, all_expressions):
        """
        基于逐步采样的方式重新计算对数概率和熵值。
        参数:
        - all_expressions: 动作序列 (batch_size, num_nodes)。
        返回:
        - log_probs: 对数概率 (batch_size, num_nodes)，空节点处为 0。
        - entropies: 熵值 (batch_size, num_nodes)，空节点处为 0。
        """
        # 初始化所有节点特征为 self.m（空节点）
        all_node_features = torch.full((self.batch_size, self.num_nodes), self.m, dtype=torch.long)

        # 存储对数概率和熵值
        log_probs_list = []
        entropies_list = []

        for i in range(self.num_nodes):
            # 获取当前时间步的动作
            actions = all_expressions[:, i]  # (batch_size,)

            # 将当前节点特征转换为 one-hot 编码
            one_hot_features = self._to_one_hot(all_node_features)  # (batch_size, num_nodes, num_classes)

            # 使用模型预测当前时间步的动作分布
            output = self.model(one_hot_features, self.adj_norm)
            probs = F.softmax(output[:, i], dim=1)  # 当前时间步的概率分布 (batch_size, num_classes)

            # 计算对数概率
            safe_probs = probs + 1e-9  # 避免 log(0)
            local_actions = actions.clone()
            invalid_mask = local_actions == self.empty_token  # 找到无效动作的掩码

            # 初始化对数概率和熵值
            log_probs = torch.zeros_like(local_actions, dtype=torch.float32)  # (batch_size,)
            entropy = torch.zeros_like(local_actions, dtype=torch.float32)  # (batch_size,)

            # 仅对非空符号的动作计算对数概率和熵值
            valid_indices = ~invalid_mask  # 非空符号的索引
            if valid_indices.any():
                valid_actions = local_actions[valid_indices]  # 非空符号的动作
                valid_probs = safe_probs[valid_indices]  # 非空符号的概率分布

                # 计算对数概率
                log_probs[valid_indices] = torch.log(valid_probs[torch.arange(len(valid_actions)), valid_actions])

                # 计算熵值
                entropy[valid_indices] = -(valid_probs * torch.log(valid_probs)).sum(dim=-1)

            # 更新节点特征
            all_node_features[:, i] = actions

            # 存储对数概率和熵值
            log_probs_list.append(log_probs)
            entropies_list.append(entropy)

        # 将结果堆叠为张量
        log_probs = torch.stack(log_probs_list, dim=1)  # (batch_size, num_nodes)
        entropies = torch.stack(entropies_list, dim=1)  # (batch_size, num_nodes)

        return log_probs, entropies

    def _to_one_hot(self, indices):
        """
        将节点特征转换为 one-hot 编码
        :param indices: 节点特征索引 (batch_size, num_nodes)
        :return: one-hot 编码 (batch_size, num_nodes, num_classes)
        """
        one_hot = F.one_hot(indices, num_classes=self.m + 1).float()
        return one_hot

    def remove_invalid_actions_parallel(self, expressions):
        """
        Drop the empty token (``self.m``) from every expression.

        :param expressions: token indices (batch_size, num_nodes)
        :return: list of batch_size lists of varying length holding only the
            non-empty tokens, in their original order
        """
        tokens = torch.tensor(expressions, dtype=torch.long)  # (batch_size, num_nodes)
        empty = self.m
        # Boolean-mask each row on its own because the kept lengths differ.
        return [row[row != empty].tolist() for row in tokens]

    def generate_valid_mask(self, expressions):
        """
        生成布尔掩码，标记有效值和无效值
        :param expressions: 表达式数组 (batch_size, num_nodes)
        :return: 布尔掩码 (batch_size, num_nodes)，其中 self.m 的值为 False，其他值为 True
        """
        # 将表达式转换为 PyTorch 张量
        expressions_tensor = torch.tensor(expressions, dtype=torch.long)  # (batch_size, num_nodes)

        # 创建布尔掩码，标记有效值
        valid_mask = expressions_tensor != self.m  # (batch_size, num_nodes)

        return valid_mask

class GraphConvolution(nn.Module):
    """
    Bias-free graph-convolution layer: ``adj @ (input @ weight)``.

    The weight matrix has shape (in_features, out_features) and is
    initialised with Xavier-uniform.
    """

    def __init__(self, in_features, out_features):
        super().__init__()
        initial = torch.empty(in_features, out_features, dtype=torch.float32)
        self.weight = nn.Parameter(nn.init.xavier_uniform_(initial))

    def forward(self, input, adj):
        """
        :param input: node features, batched so that the result of the
            projection is 3-D (``torch.bmm`` requires it)
        :param adj: batched adjacency tensor (batch, num_nodes, num_nodes)
        :return: aggregated features (batch, num_nodes, out_features)
        """
        # The weight gets a leading singleton axis and is broadcast over the batch.
        projected = input.matmul(self.weight.unsqueeze(0))
        return adj.bmm(projected)

class GCN(nn.Module):
    """
    Stack of ``nlayers`` graph convolutions (each followed by ReLU) with a
    linear read-out applied to every node.
    """

    def __init__(self, nfeat, nhid, nlayers, nclass):
        """
        :param nfeat: input feature width of the first graph convolution
        :param nhid: hidden width of every graph convolution
        :param nlayers: number of graph-convolution layers
        :param nclass: output width of the linear read-out
        """
        super().__init__()
        convolutions = []
        in_width = nfeat
        for _ in range(nlayers):
            convolutions.append(GraphConvolution(in_width, nhid))
            in_width = nhid  # every layer after the first maps nhid -> nhid
        self.gc_layers = nn.ModuleList(convolutions)
        # Read-out keeps PyTorch's default nn.Linear initialisation.
        self.mlp = nn.Linear(nhid, nclass)

    def forward(self, x, adj):
        """
        :param x: node features (batch, num_nodes, nfeat)
        :param adj: batched adjacency tensor passed to every graph convolution
        :return: per-node scores (batch, num_nodes, nclass)
        """
        hidden = x
        for layer in self.gc_layers:
            hidden = F.relu(layer(hidden, adj))
        return self.mlp(hidden)


def is_valid_expression(expressions, operator_arities):
    """
    Check the operator/operand balance of every expression.

    Each token contributes ``arity - 1`` (binary operator +1, unary 0,
    variable -1); a complete expression sums to exactly -1. The empty token
    (index ``len(operator_arities)``) is a placeholder and contributes
    nothing, so padded and cleaned expressions give the same answer.

    Only the total is checked, not the token order, so this is a necessary
    condition rather than a full syntax check.

    :param expressions: iterable of token-index sequences, e.g. shape
        (batch_size, num_nodes) or a list of variable-length lists
    :param operator_arities: arity of every token, e.g. [2, 2, 1, 0]; any
        sequence type is accepted
    :return: boolean tensor (batch_size,), True where the balance is -1
    """
    # Copy into a list so tuples/arrays work and the caller's data stays untouched.
    arities = list(operator_arities)
    empty_token = len(arities)

    valid_flags = []
    for expr in expressions:
        balance = sum(arities[action] - 1 for action in expr if action != empty_token)
        valid_flags.append(balance == -1)

    return torch.tensor(valid_flags, dtype=torch.bool)

if __name__ == "__main__":
    # Example run: build a generator, sample a batch of expressions and inspect them.
    n = 5  # maximum depth
    operator_arities = [2, 2, 1, 0]  # placeholder arities [-, +, sin, x1]; replaced below
    batch_size = 5  # batch size
    epsilon = 0.1  # exploration probability of the epsilon-greedy sampler (example value)
    include_self_loops = True  # whether to add self-connections
    tree_type = "full"  # tree type ("full" or "partial")

    # Example usage
    operator_functions, _, trig, exp_log = register_operator_functions()  # Get operator functions
    input_array = np.random.rand(100, 3)  # Generate a random (100, 3) input array
    variable_dict = generate_variable_info(input_array)  # Generate variable info

    combined_dict, operator_functions_only, operator_arities = merge_operator_and_variable_dict(operator_functions, variable_dict)

    # epsilon has no default in the constructor. It used to be left out, which bound
    # include_self_loops to epsilon and tree_type to include_self_loops. The optional
    # settings are passed by keyword so they cannot shift again.
    generator = SymbolicExpressionGenerator(
        32, n, operator_arities, batch_size, trig, exp_log, epsilon,
        include_self_loops=include_self_loops, tree_type=tree_type,
    )

    # Print the helper tables
    print("Ancestor List:")
    print(generator.ancestor_list)
    print("\nLeft/Right Flags:")
    print(generator.left_right_flags)
    print("\nIs Leaf:")
    print(generator.is_leaf)
    print("\nAdjusted Parent Indices:")
    print(generator.adjusted_parent_indices)

    expressions, nodes = generator.generate()
    print(expressions)
    logs, entropies = generator.evaluate_sequence(nodes)
    print("only_ppo_count")
    print(logs)
    print("entropies")
    print(entropies)

    cleaned_expressions = generator.remove_invalid_actions_parallel(expressions)
    print("Cleaned Expressions:")
    for expr in cleaned_expressions:
        print(expr)

    fixed_expressions = [complete_tokens(operator_arities, cleaned_expressions[i]) for i in range(batch_size)]
    print("Fixed Expressions:")
    for expr in fixed_expressions:
        print(expr)

    # Check whether the cleaned expressions are balanced
    valid_flags = is_valid_expression(cleaned_expressions, operator_arities)
    print("Valid Flags for Cleaned Expressions:", valid_flags)

    # Build the boolean mask of non-empty nodes
    valid_mask = generator.generate_valid_mask(expressions)
    print("Valid Mask:")
    print(valid_mask)