import numpy as np
from function import *  # 引入检查和补充函数
import copy
from scipy.stats import multivariate_normal
from normalize_functions import get_flattened_normalized_expression
def positional_encoding(pos, idx, d, base=64):
    """
    Compute the positional-encoding value of a single symbol.

    :param pos: position of the symbol in the expression (0-based).
    :param idx: index of the symbol in the symbol library.
    :param d: size of the symbol library (must be non-zero).
    :param base: base of the frequency scale. The original author found 64 to
        work best, so it remains the default.
    :return: ``sin(pos / base**(idx/d))`` for even ``idx``,
        ``cos(pos / base**(idx/d))`` for odd ``idx``.
    """
    scaled = pos / (base ** (idx / d))
    # Even library slots use sine, odd slots use cosine.
    return np.sin(scaled) if idx % 2 == 0 else np.cos(scaled)


class Program:
    """
    A symbolic-regression candidate given as a pre-order sequence of symbol indices.

    Symbol tables and the dataset live on the class (set via ``initialize``) and
    are shared by every instance. Instances are memoized per index sequence in
    ``instance_cache``: ``Program(seq)`` called twice returns the same object.
    ``__init__`` still re-runs on that object, rebuilding and re-evaluating it
    against the current class-level data.
    """
    operator_arities = []  # arity of each symbol, aligned with the key order of operator_dict
    operator_dict = {}  # symbol name -> info dict; operators carry a 'function' entry
    operator_indices = []  # operator indices as passed to initialize()
    input_data = None  # input matrix; column i is bound to variable 'x_{i+1}'
    label_data = None  # target values
    const_value = None  # value bound to the 'const' symbol, or None when unused
    original_operators_dict = {}  # operator-only dictionary passed to initialize()
    operator_priority = []  # operator names (key order of original_operators_dict), used for normalization
    instance_cache = {}  # tuple(index_sequence) -> Program; unbounded, emptied by clear_cache()

    @classmethod
    def initialize(cls, operators_dict, all_operator_dict, operator_indices, operator_arities,
                   input_data, label_data, const_value=None):
        """
        Configure the symbol tables and dataset shared by all Program instances.

        Parameters
        ----------
        operators_dict : dict
            Operator-only dictionary; its key order defines ``operator_priority``.
        all_operator_dict : dict
            Operators and variables combined; its key order defines symbol indices.
        operator_indices : list
            Operator indices (stored as-is).
        operator_arities : list
            Arity of every symbol, aligned with the keys of ``all_operator_dict``.
        input_data : np.ndarray
            Input matrix of shape (n_samples, n_variables), e.g. 100x3.
        label_data : np.ndarray
            Targets of shape (n_samples,) or (n_samples, 1).
        const_value : float or None, optional
            Value bound to the ``const`` symbol. Defaults to None (no constant),
            so callers that omit it keep working.
        """
        cls.original_operators_dict = operators_dict
        cls.operator_dict = all_operator_dict
        cls.operator_indices = operator_indices
        cls.operator_arities = operator_arities
        cls.input_data = input_data
        cls.label_data = label_data
        cls.const_value = const_value
        cls.operator_priority = list(operators_dict.keys())

    def __new__(cls, index_sequence):
        """
        Return the cached instance for ``index_sequence``, creating it on first use.
        """
        # Lists are unhashable; the tuple form is the cache key.
        index_sequence_key = tuple(index_sequence)

        if index_sequence_key in cls.instance_cache:
            return cls.instance_cache[index_sequence_key]

        instance = super().__new__(cls)
        instance.instance_cache = cls.instance_cache  # same dict object as the class cache
        cls.instance_cache[index_sequence_key] = instance
        return instance

    def __init__(self, index_sequence):
        """
        Build the symbolic expression for ``index_sequence`` and evaluate it.

        Parameters
        ----------
        index_sequence : list
            Symbol indices in pre-order (prefix) traversal order.
        """
        self.index_sequence = copy.deepcopy(index_sequence)
        self.expression = self.build_expression()
        self.normalized_expression, _ = get_flattened_normalized_expression(self.expression, self.operator_priority)
        self.mse = None
        self.reward = 0

        # Computes self.mse and self.reward against the current dataset.
        self.evaluate()

    @classmethod
    def clear_cache(cls):
        """
        Remove every cached Program instance.
        """
        cls.instance_cache.clear()
        print("Cache cleared.")

    def build_expression(self):
        """
        Map ``self.index_sequence`` to symbol names via the key order of ``operator_dict``.

        Returns
        -------
        expression : list
            Symbol names in the same pre-order as the index sequence.
        """
        symbols = list(self.operator_dict)  # built once instead of once per index
        return [symbols[idx] for idx in self.index_sequence]

    def update_data(self, input_data, label_data):
        """
        Override the dataset for this instance only.

        Copies are stored as instance attributes, shadowing the class-level data.
        """
        self.input_data = input_data.copy()
        self.label_data = label_data.copy()

    def evaluate(self):
        """
        Evaluate the expression on ``input_data`` and score it against ``label_data``.

        Sets ``self.mse`` to the mean squared error. It is ``np.inf`` when the
        prediction contains NaN/inf or evaluation fails. ``self.reward`` is set to
        ``1 / (1 + mse)``, or 0 in the failure case.

        Returns
        -------
        mse : float
            The value stored in ``self.mse``.
        """
        # One array per variable: column i of input_data is bound to 'x_{i+1}'.
        variable_values = {f'x_{i + 1}': self.input_data[:, i] for i in range(self.input_data.shape[1])}
        # `is not None` so that a constant equal to 0 is still bound.
        if self.const_value is not None:
            variable_values['const'] = np.full(self.input_data.shape[0], self.const_value)

        # All rows are evaluated at once (vectorized over samples).
        predictions = np.array(self.eval_expression(self.expression, variable_values))

        if np.any(np.isinf(predictions)) or np.any(np.isnan(predictions)):
            self.mse = np.inf
            self.reward = 0
        else:
            # Flatten both sides. Otherwise an (n,) prediction minus an (n, 1)
            # label broadcasts to (n, n) and averages all pairwise differences.
            residuals = np.ravel(predictions) - np.ravel(self.label_data)
            self.mse = np.mean(np.square(residuals))
            self.reward = 1 / (1 + self.mse)
        return self.mse

    def eval_expression(self, expression, variable_values):
        """
        Evaluate a prefix-order symbolic expression with a stack.

        Tokens starting with "x" or "con" are looked up in ``variable_values``.
        Every other token is treated as an operator. Its arity comes from
        ``operator_arities``, and it is called via ``operator_dict[token]['function']``.

        Parameters
        ----------
        expression : list
            Operators and variables in prefix (pre-order) form.
        variable_values : dict
            Maps variable names (e.g. x_1, x_2, const) to their values.

        Returns
        -------
        result
            The value of the expression. Returns ``np.nan`` when the expression
            is empty or evaluation raises, so callers score it as invalid
            instead of using a partial result.
        """
        if len(expression) == 0:
            print(expression)
            return np.nan

        position_of = {symbol: i for i, symbol in enumerate(self.operator_dict)}
        stack = []
        try:
            # Scanning a prefix expression right-to-left lets each operator pop
            # its already-evaluated operands, first operand first.
            for token in reversed(expression):
                if not token.startswith("x") and not token.startswith("con"):  # operator
                    arity = self.operator_arities[position_of[token]]
                    operands = [stack.pop() for _ in range(arity)]
                    stack.append(self.operator_dict[token]['function'](*operands))
                else:  # variable or constant
                    stack.append(variable_values[token])
        except Exception as e:
            # Malformed or failing expressions are reported and marked invalid.
            print(expression)
            print(f"错误类型: {type(e).__name__}")
            print(f"错误信息: {e}")
            return np.nan

        return stack[-1]

    def expression_to_vector(self):
        """
        Encode this program's normalized expression as a fixed-length vector.

        Each token found in ``operator_dict`` adds
        ``positional_encoding(pos, idx, d)`` to slot ``idx``. Here ``idx`` is the
        token's index among the dictionary keys, ``pos`` is its position in the
        expression, and ``d`` is ``len(operator_arities)``. The result is
        L2-normalized, or left as all zeros when no token matched.

        :return: np.ndarray of shape (len(operator_arities),).
        """
        d = len(self.operator_arities)
        vector = np.zeros(d)
        index_of = {symbol: i for i, symbol in enumerate(self.operator_dict)}
        for pos, token in enumerate(self.normalized_expression):
            idx = index_of.get(token)
            if idx is not None:
                vector[idx] += positional_encoding(pos, idx, d)

        # L2 normalization to unit length.
        norm = np.linalg.norm(vector)
        return vector / norm if norm > 0 else vector



def compute_statistics_and_nll(vectors):
    """
    计算编码向量的均值、协方差矩阵，并计算每个向量的负对数似然（NLL）。
    :param vectors: 二维数组，形状为 (n_samples, n_features)，表示 n_samples 个样本，每个样本有 n_features 个特征。
    :return: 一维数组，形状为 (n_samples,)，表示每个样本的 NLL 分数。
    """
    # 检查输入是否为空
    if vectors.size == 0:
        raise ValueError("Input vectors array is empty.")

    # 删除全零维度
    non_zero_dims = np.any(vectors != 0, axis=0)
    if not np.any(non_zero_dims):
        raise ValueError("All dimensions are zero. Check input data for valid encoding.")
    vectors_reduced = vectors[:, non_zero_dims]

    # 计算均值和协方差矩阵
    mean_vector = np.mean(vectors_reduced, axis=0)
    cov_matrix = np.cov(vectors_reduced, rowvar=False)

    # 定义多元正态分布
    try:
        mvn = multivariate_normal(mean=mean_vector, cov=cov_matrix, allow_singular=True)
        # 计算每个向量的负对数似然（NLL）
        nll_scores = -mvn.logpdf(vectors_reduced)
    except:
        nll_scores = np.ones_like(vectors[0]) * -1

    # 筛查无效值（如 NaN 或 inf）
    if np.any(np.isnan(nll_scores)) or np.any(np.isinf(nll_scores)):
        raise ValueError("Invalid NLL values detected (NaN or inf). Check the covariance matrix and input data.")

    return nll_scores


if __name__ == "__main__":
    # Example usage of the Program class.
    operator_functions, operator_arities, _, _ = register_operator_functions()
    input_data = np.random.rand(100, 5)  # random input data (100 samples, 5 features)
    label_data = np.random.rand(100, 1)  # random label data (100 samples, one target column)
    variable_dict = generate_variable_info(input_data)  # variable symbols for the input columns
    combined_dict, operator_indices, operator_arities = merge_operator_and_variable_dict(operator_functions,
                                                                                         variable_dict)
    # initialize() takes a const_value argument; this demo binds no constant, so pass None explicitly.
    Program.initialize(operator_functions, combined_dict, operator_indices, operator_arities,
                       input_data, label_data, None)

    # Build a batch of Program instances from fixed index sequences.
    num_programs = 5  # how many of the sequences below to instantiate
    expr = [[1, 8, 8], [0, 8, 8], [2, 8, 8], [3, 8, 8], [5, 8]]
    # Constructing a Program evaluates it (computes MSE and reward).
    programs = [Program(index_sequence) for index_sequence in expr[:num_programs]]

    # Sort by reward, highest first.
    programs.sort(key=lambda p: p.reward, reverse=True)

    encoding_vectors = []
    print("Sorted Programs by Reward:")
    for i, program in enumerate(programs):
        print(f"Program {i + 1}: Reward = {program.reward}, MSE = {program.mse}")
        vector = program.expression_to_vector()  # computed once, printed and stored
        print(vector)
        encoding_vectors.append(vector)
    encoding_vectors = np.array(encoding_vectors)

    nll_scores = compute_statistics_and_nll(encoding_vectors)
    print("NLL Scores over all programs:")
    print(nll_scores)

    # Report whether any program already meets the stopping criterion.
    if any(program.mse < 1e-8 for program in programs):
        print("Stopping all programs because a program with MSE < 1e-8 was found.")

    # Release the cached Program instances.
    Program.clear_cache()
