import torch
import torch.nn.functional as F
from program import Program, compute_statistics_and_nll
from sr_gen import SymbolicExpressionGenerator
from function import *
from normalize_functions import *
from sympy import symbols, lambdify, parse_expr
from sympy.parsing.sympy_parser import standard_transformations, implicit_multiplication_application, function_exponentiation
from deap import gp, algorithms, tools, creator, base
from gp import prefix_to_internal_expression, expr_to_sequence, sequence_to_operator_sequence, SymbolicRegression
from simple_expr import parse_expression
def count_reward(nll_scores):
    """Turn normalized NLL scores into count rewards.

    This is currently the identity mapping: the very same object that was
    passed in is handed back unchanged.
    """
    rewards = nll_scores
    return rewards

class SymbolicExpressionTrainer:
    """Policy-gradient trainer for a symbolic-expression generator.

    Every epoch samples a batch of expressions from ``generator``, scores them
    (directly, or by seeding a GP run when ``gp_module`` is given) and updates
    the generator's model with either ``compute_loss`` (optionally PPO-style)
    or ``compute_risk_seeking_policy_loss``.
    """

    def __init__(self, generator, gp_module, learning_rate=0.01, entropy_weight=0.1, count_reward_weight=0.1,
                 use_risk_seeking=False, ppo=False, clip_epsilon=0.01, quantile=0.75, log_file='./only_ppo_count'):
        """
        :param generator: SymbolicExpressionGenerator instance; ``generator.model``'s
            parameters are optimized with Adam.
        :param gp_module: GP module used by ``gp_reward_function`` (``ga_run``, ``pset``,
            ``operator_functions``), or ``None`` to score with ``reward_function``.
        :param learning_rate: Adam learning rate.
        :param entropy_weight: weight of the maximum-entropy regularizer.
        :param count_reward_weight: weight of the NLL-based count reward.
        :param use_risk_seeking: use the risk-seeking policy gradient.
        :param ppo: in ``compute_loss``, use a clipped ratio (PPO-style) objective.
        :param clip_epsilon: clipping range for the PPO-style ratio.
        :param quantile: reward quantile threshold of the risk-seeking loss (default 0.75).
        :param log_file: file the end-of-training summary is appended to.
        """
        self.generator = generator
        self.gp_sr = gp_module
        self.entropy_weight = entropy_weight
        self.use_risk_seeking = use_risk_seeking
        self.quantile = quantile
        self.best_one = None
        self.best_one_mse = float(np.inf)
        self.best_gp_mse = float(np.inf)
        self.best_gp_expr = None
        self.count_reward_weight = count_reward_weight
        self.ppo = ppo
        self.clip_epsilon = clip_epsilon

        # Optimizer over the generator model's parameters.
        self.optimizer = torch.optim.Adam(self.generator.model.parameters(), lr=learning_rate)
        self.log_file = log_file
        self.old_log_probs = None

        self.success_flag = False
        self.global_flag = False
        self.global_expr = None

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _sanitize_rewards(programs):
        """Return each program's ``reward``, with NaN/inf replaced by 0.0."""
        return [
            0.0 if np.isnan(reward) or np.isinf(reward) else reward
            for reward in (program.reward for program in programs)
        ]

    @staticmethod
    def _count_rewards(programs):
        """Compute the count reward tensor for a batch of programs.

        NLL scores from ``compute_statistics_and_nll`` over the programs'
        encoding vectors are clipped at 0. The strictly positive scores are then
        min-max scaled (all set to 1 when they are all equal). Zeros stay 0.
        """
        encoding_vectors = np.array([program.expression_to_vector() for program in programs])
        nll_scores = compute_statistics_and_nll(encoding_vectors)
        nll_scores[nll_scores < 0] = 0

        # Normalize only the non-zero part.
        non_zero_mask = nll_scores > 0
        non_zero_values = nll_scores[non_zero_mask]
        if non_zero_values.size > 0:
            low, high = np.min(non_zero_values), np.max(non_zero_values)
            if high > low:
                nll_scores[non_zero_mask] = (non_zero_values - low) / (high - low)
            else:
                nll_scores[non_zero_mask] = np.ones_like(non_zero_values)

        return torch.tensor(count_reward(nll_scores=nll_scores))

    @staticmethod
    def _standardize(values):
        """Return ``(values - mean) / (std + 1e-8)``.

        ``Tensor.std`` (unbiased) is NaN for fewer than two elements. A NaN loss
        would propagate into the model weights on ``optimizer.step()``. In that
        case the centered values (zeros) are returned instead.
        """
        centered = values - values.mean()
        if values.numel() < 2:
            return centered
        return centered / (values.std() + 1e-8)

    @staticmethod
    def _best_program(programs):
        """Return the program with the highest ``reward`` (the first one on ties).

        NaN rewards rank lowest. ``list.sort`` gives no meaningful order when
        NaN keys are present.
        """
        def key(program):
            reward = program.reward
            return float("-inf") if np.isnan(reward) else reward

        return max(programs, key=key)

    # ------------------------------------------------------------------ rewards

    def reward_function(self, expressions):
        """Score a batch of expressions directly (no GP refinement).

        :param expressions: generated expressions, one token sequence per sample.
        :return: ``(programs, rewards, count_rewards)``. ``rewards`` holds each
            program's reward with NaN/inf replaced by 0.0. ``count_rewards`` is
            the normalized NLL-based count reward.
        """
        if any(len(item) == 0 for item in expressions):
            # An empty expression is unexpected here; dump the batch for debugging.
            print(expressions)
        programs = [Program(item) for item in expressions]
        rewards = torch.tensor(self._sanitize_rewards(programs))
        count_rewards = self._count_rewards(programs)
        return programs, rewards, count_rewards

    def gp_reward_function(self, expressions):
        """Score a batch by using it as the initial population of a GP run.

        ``self.gp_sr.ga_run`` returns the best individual and its fitness. Each
        expression's reward is ``1 / (1 + best_fitness)`` scaled by the share of
        positions marked in the LCS mask from ``mark_common_elements_in_original``
        against the normalized best individual. A fitness below 1e-10 sets
        ``self.global_flag`` and stores the individual in ``self.global_expr``.
        If the fitness is NaN/inf, the programs' own sanitized rewards are
        returned instead.

        :param expressions: generated expressions, one token sequence per sample.
        :return: ``(programs, rewards, count_rewards)``.
        """
        programs = [Program(item) for item in expressions]
        original_rewards = self._sanitize_rewards(programs)
        count_rewards = self._count_rewards(programs)

        exprs = [program.expression for program in programs]
        initial_population = []
        for prefix_expr in exprs:
            expression_str = prefix_to_internal_expression(prefix_expr, self.gp_sr.operator_functions)
            initial_expr = gp.PrimitiveTree.from_string(expression_str, self.gp_sr.pset)
            initial_population.append(creator.Individual(initial_expr))

        best_individual, best_fitness = self.gp_sr.ga_run(initial_pop=initial_population)
        if best_fitness < 1e-10:
            self.global_flag = True
            self.global_expr = expr_to_sequence(str(best_individual))

        if np.isnan(best_fitness) or np.isinf(best_fitness):
            # GP failed to produce a usable individual: fall back to the raw rewards.
            return programs, torch.tensor(original_rewards), count_rewards

        gp_reward = 1 / (1 + best_fitness)

        # Individual -> prefix sequence -> operator sequence -> normalized form.
        sequence = expr_to_sequence(str(best_individual))
        best_sequence = sequence_to_operator_sequence(sequence, Program.original_operators_dict)
        best_normalized, _ = get_flattened_normalized_expression(best_sequence, Program.operator_priority)

        rewards_weight = []
        for item in exprs:
            expr, mapping = get_flattened_normalized_expression(item, Program.operator_priority)
            lcs_mask, _ = mark_common_elements_in_original(item, expr, mapping, best_normalized)
            # Guard against empty expressions (would be a ZeroDivisionError).
            rewards_weight.append(int(sum(lcs_mask)) / len(item) if len(item) else 0.0)

        rewards = torch.tensor(gp_reward * np.array(rewards_weight))
        return programs, rewards, count_rewards

    def weight_log_probs(self, log_probs, lcss):
        """Placeholder; not used by this class. Always returns 0."""
        return 0

    # ------------------------------------------------------------------ losses

    def compute_loss(self, rewards, log_probs, entropies, count_rewards, idx):
        """Standard (or PPO-style, when ``self.ppo``) policy-gradient loss.

        :param rewards: per-sample rewards (batch_size,).
        :param log_probs: per-step log-probabilities (batch_size, num_nodes).
        :param entropies: per-step entropies (batch_size, num_nodes).
        :param count_rewards: per-sample count rewards (batch_size,).
        :param idx: epoch index; the count reward is added whenever ``idx >= 0``.
        :return: scalar total loss (policy loss + entropy regularizer).
        """
        count_rewards_flag = 1 if idx >= 0 else 0

        # Only the task reward is standardized; count rewards are already in [0, 1].
        standardized_rewards = self._standardize(rewards)
        combined_rewards = standardized_rewards + count_rewards_flag * self.count_reward_weight * count_rewards
        sum_log_probs = log_probs.sum(dim=-1)  # (batch_size,)

        if self.ppo:
            # NOTE(review): old_log_probs is overwritten at the end of every call.
            # The ratio therefore compares this batch's log-probs with those of the
            # previous, independently sampled batch, not with the same actions under
            # the old policy. It also requires equal batch sizes. TODO confirm intent.
            advantages = combined_rewards - combined_rewards.mean()
            if self.old_log_probs is None:
                self.old_log_probs = log_probs.detach()
            ratios = torch.exp(sum_log_probs - self.old_log_probs.sum(dim=-1))
            surr1 = ratios * advantages
            surr2 = torch.clamp(ratios, 1 - self.clip_epsilon, 1 + self.clip_epsilon) * advantages
            policy_loss = -torch.min(surr1, surr2).mean()
            self.old_log_probs = log_probs.detach()
        else:
            policy_loss = -(combined_rewards * sum_log_probs).mean()

        # Maximum-entropy regularizer.
        entropy_loss = -self.entropy_weight * entropies.sum(dim=-1).mean()
        return policy_loss + entropy_loss

    def compute_risk_seeking_policy_loss(self, rewards, log_probs, entropies, count_rewards, idx):
        """Risk-seeking policy-gradient loss over samples above the reward quantile.

        :param rewards: per-sample rewards (batch_size,).
        :param log_probs: per-step log-probabilities (batch_size, num_nodes).
        :param entropies: per-step entropies (batch_size, num_nodes).
        :param count_rewards: per-sample count rewards (batch_size,).
        :param idx: epoch index; the count reward is added whenever ``idx >= 0``.
        :return: scalar total loss. A zero tensor is returned when the quantile
            cannot be computed or no sample exceeds it.
        """
        count_rewards_flag = 1 if idx >= 0 else 0
        try:
            threshold = torch.quantile(rewards, self.quantile)
        except (RuntimeError, TypeError, ValueError):
            # e.g. empty input or a non-floating dtype
            return torch.tensor(0.0, requires_grad=True)

        high_reward_mask = rewards > threshold  # (batch_size,)
        if not high_reward_mask.any():
            return torch.tensor(0.0, requires_grad=True)

        high_rewards = rewards[high_reward_mask]
        high_log_probs = log_probs[high_reward_mask].sum(dim=-1)
        high_count_rewards = count_rewards[high_reward_mask]

        standardized_high_rewards = self._standardize(high_rewards)
        standardized_high_count_rewards = self._standardize(high_count_rewards)
        complete_rewards = (standardized_high_rewards
                            + count_rewards_flag * self.count_reward_weight * standardized_high_count_rewards)

        # NOTE(review): complete_rewards is standardized while threshold is on the
        # raw reward scale; subtracting one from the other mixes scales. Confirm intent.
        policy_loss = -((complete_rewards - threshold) * high_log_probs).mean()
        entropy_loss = -self.entropy_weight * entropies.sum(dim=-1).mean()
        return policy_loss + entropy_loss

    # ------------------------------------------------------------------ training

    def train_step(self, cleaned_expressions, entropies, log_probs, idx):
        """Score one batch, take one optimizer step and return the best program.

        :param cleaned_expressions: cleaned expressions (batch_size, variable length).
        :param entropies: per-step entropies (batch_size, num_nodes).
        :param log_probs: per-step log-probabilities (batch_size, num_nodes).
        :param idx: epoch index, forwarded to the loss.
        :return: ``(loss value as float, highest-reward program of the batch)``.
        """
        if self.gp_sr is None:
            programs, rewards, count_rewards = self.reward_function(cleaned_expressions)
        else:
            programs, rewards, count_rewards = self.gp_reward_function(cleaned_expressions)
        best_item = self._best_program(programs)

        if self.use_risk_seeking:
            loss = self.compute_risk_seeking_policy_loss(rewards, log_probs, entropies, count_rewards, idx)
        else:
            loss = self.compute_loss(rewards, log_probs, entropies, count_rewards, idx)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        return loss.item(), best_item

    def train(self, num_epochs, name):
        """Run the sample/score/update loop and append a summary to ``self.log_file``.

        Training stops early when the best program's mse drops below 1e-8, or
        when the GP module reports a near-zero fitness (``self.global_flag``).
        In both cases ``self.success_flag`` is set.

        :param num_epochs: maximum number of epochs.
        :param name: task name, used only in the final console message.
        """
        last_log_content = ""  # progress line of the last completed epoch
        for epoch in range(num_epochs):
            expressions, expression_tensors = self.generator.generate()
            cleaned_expressions = self.generator.remove_invalid_actions_parallel(expressions)
            log_probs, entropies = self.generator.evaluate_sequence(expression_tensors)

            loss, best_item = self.train_step(cleaned_expressions, entropies, log_probs, epoch)

            if best_item.mse < self.best_one_mse:
                self.best_one = best_item
                self.best_one_mse = best_item.mse
                print("\n")
                print(f"New expression is found: {parse_expression(self.best_one.expression.copy())} with mse:{self.best_one_mse}")

            # best_one stays None until some sampled program has a finite, non-NaN mse.
            if self.best_one is not None and self.best_one.mse < 1e-8:
                print("\n")
                print(f"The True one was found by GCN:{parse_expression(self.best_one.expression.copy())}")
                self.success_flag = True
                break

            if self.global_flag:
                print("\n")
                self.success_flag = True
                print(f"The True one was found by GP: {parse_expression(self.global_expr.copy())}")
                break

            # In-place console progress line.
            print(f"\rEpoch {epoch + 1}/{num_epochs}, Loss: {loss:.4f}", end="")
            last_log_content = f"Epoch {epoch + 1}/{num_epochs}, Loss: {loss:.4f}, "

        Program.clear_cache()

        with open(self.log_file, "a", encoding="utf-8") as f:
            f.write(last_log_content)
            if self.best_one is not None:
                f.write(f"The best expression in this running is {parse_expression(self.best_one.expression.copy())}, with mse: {self.best_one.mse:.4e}\n")
            else:
                f.write("No expression with a finite mse was found in this running.\n")
            if self.global_flag:
                f.write(f"The gp found the best expression in this running is {parse_expression(self.global_expr.copy())}\n")

        print(f"\nFinished training for task: {name}. Log saved to {self.log_file}")
def generate_test_data(expression_str, variable_range_str, num_samples):
    """Sample an expression on an evenly spaced grid over the given variable ranges.

    :param expression_str: expression string, e.g. "sin(x) + cos(y) + log(z)".
        Implicit multiplication and function exponentiation are accepted.
    :param variable_range_str: dict literal mapping each variable name to
        ``[low, high]``, e.g. "{'x': [-1, 1], 'y': [0, 2], 'z': [1, 10]}".
    :param num_samples: number of sample points. Every variable is sampled with
        ``np.linspace(low, high, num_samples)``, so all variables advance together.
    :return: ``(input_data, label_data)`` with shapes
        ``(num_samples, num_variables)`` and ``(num_samples,)``.
    """
    # NOTE(review): eval() executes arbitrary code; only pass trusted range strings.
    # ast.literal_eval would be safer but rejects range expressions like "[0, 2*3.14]".
    variable_ranges = eval(variable_range_str)
    variables = list(variable_ranges.keys())
    ranges = [variable_ranges[var] for var in variables]

    expr_symbols = symbols(variables)

    transformations = standard_transformations + (implicit_multiplication_application, function_exponentiation)
    expr = parse_expr(
        expression_str,
        local_dict={var: expr_symbols[i] for i, var in enumerate(variables)},
        transformations=transformations,
    )
    func = lambdify(expr_symbols, expr, modules="numpy")

    input_data = np.array([np.linspace(r[0], r[1], num_samples) for r in ranges]).T
    label_data = np.asarray(func(*[input_data[:, i] for i in range(len(variables))]))

    # A constant expression makes lambdify return a scalar; broadcast it so the
    # labels always line up with the input rows.
    if label_data.ndim == 0:
        label_data = np.full(num_samples, label_data)

    return input_data, label_data
# Example run
if __name__ == "__main__":
    # Test-data parameters
    name = "Nguyen-1"
    expression_str = "sin(x*x)*cos(x)-1"
    variable_range_str = "{'x': [-1, 1]}"
    num_samples = 20

    # Generate the dataset
    input_data, label_data = generate_test_data(expression_str, variable_range_str, num_samples)

    # Generator settings
    n = 5  # maximum depth
    batch_size = 500  # batch size
    include_self_loops = True  # whether to include self-loops
    tree_type = "partial"  # tree type ("full" or "partial")

    # Register operator functions
    operator_functions, operator_arities, trig, exp_log = register_operator_functions()

    # Build variable information from the input data
    variable_dict = generate_variable_info(input_data)

    # Merge operator and variable dictionaries
    combined_dict, operator_indices, operator_arities = merge_operator_and_variable_dict(operator_functions, variable_dict)

    # Initialize the Program class
    Program.initialize(combined_dict, operator_indices, operator_arities, input_data, label_data)

    # Build the generator
    generator = SymbolicExpressionGenerator(32, n, operator_arities, batch_size, trig, exp_log, include_self_loops, tree_type)

    # gp_module is a required argument; None selects the plain (GP-free) reward_function.
    trainer = SymbolicExpressionTrainer(generator, gp_module=None, learning_rate=0.0005, entropy_weight=0.005,
                                        use_risk_seeking=True, quantile=0.95)

    # Train
    trainer.train(num_epochs=3000, name='example_1')