import json
import os
import warnings

import numpy as np

from sr_gen import SymbolicExpressionGenerator as ActionGenerator
from function import *  # 引入检查和补充函数
from program import Program
from utils_data import generate_data_by_name
from policy_gradient import SymbolicExpressionTrainer as PolicyGradientTrainer
from gp import GP_config, SymbolicRegression
import csv
from datetime import datetime

def extract_names_from_csv(filepath, column="name"):
    """Return every value of one column of a CSV file, in row order.

    Args:
        filepath: Path to a UTF-8 encoded CSV file whose first row is a header.
        column: Header of the column to extract (defaults to ``"name"``).

    Returns:
        A list with the column's value for each data row; an empty list for an
        empty file.

    Raises:
        KeyError: If the header row has no ``column`` entry.
    """
    # newline="" lets the csv module handle line endings itself, as its docs
    # require, so quoted fields containing newlines are parsed correctly.
    with open(filepath, mode="r", encoding="utf-8", newline="") as file:
        return [row[column] for row in csv.DictReader(file)]

# Main program
def _resolve_task_names(task_names):
    """Expand the "Feynman" shorthand and choose the constant value for the runs.

    Args:
        task_names: Task-name list taken from the ``paths_tasks`` config section.

    Returns:
        A ``(task_names, const_value)`` tuple. ``const_value`` is ``np.pi`` when
        the configured list starts with ``"Feynman"`` and ``False`` otherwise.
    """
    # Decide BEFORE expanding: once the shorthand is replaced by the real task
    # names, task_names[0] is no longer "Feynman" and a later check never fires.
    is_feynman = bool(task_names) and task_names[0] == "Feynman"
    if is_feynman and len(task_names) == 1:
        # A lone "Feynman" entry stands for the last 50 tasks in the Feynman CSV.
        task_names = extract_names_from_csv("./data/feynman_data.csv")[-50:]
    const_value = np.pi if is_feynman else False
    return task_names, const_value


def main(config_file, log_dir):
    """Train on every configured task ``running_num`` times and log recovery rates.

    Args:
        config_file: Path to a JSON config with optional ``"hyperparameters"``,
            ``"paths_tasks"`` and ``"gp_hyperparameters"`` sections; missing keys
            fall back to the defaults below.
        log_dir: Directory (created if missing) receiving one
            ``<task>_training_log.txt`` per task and a shared ``model_statis.txt``
            to which one summary line per task is appended.
    """
    # Read the configuration file (JSON text is UTF-8).
    with open(config_file, "r", encoding="utf-8") as f:
        config = json.load(f)

    # The config is split into three optional sections.
    hyperparameters = config.get("hyperparameters", {})
    paths_tasks = config.get("paths_tasks", {})
    gp_hyperparameters = config.get("gp_hyperparameters", {})

    # Model / training hyperparameters.
    learning_rate = hyperparameters.get("learning_rate", 0.001)
    num_epochs = hyperparameters.get("num_epochs", 200)
    batch_size = hyperparameters.get("batch_size", 500)
    max_depth = hyperparameters.get("max_depth", 10)
    use_risk_seeking = hyperparameters.get("use_risk_seeking", False)
    beta = hyperparameters.get("beta", 0.9)
    entropy_coefficient = hyperparameters.get("entropy_coefficient", 0.01)
    count_reward_weight = hyperparameters.get("count_reward_weight", 1.0)
    hidden_size = hyperparameters.get("hidden_size", 64)  # hidden-layer size
    epsilon = hyperparameters.get("epsilon", 0.01)
    ppo = hyperparameters.get("ppo", True)
    clip_epsilon = hyperparameters.get("clip_epsilon", 0.2)
    self_loops = hyperparameters.get("self_loops", True)
    tree_type = hyperparameters.get("tree_type", "partial")
    gp_control = hyperparameters.get("gp", False)
    running_num = hyperparameters.get("running_num", 10)

    # Data paths and task selection.
    dataset_path = paths_tasks.get("dataset_path", "./data/dataset.csv")
    sample_method = paths_tasks.get("sample_method", "random")
    task_names, const_value = _resolve_task_names(
        paths_tasks.get("task_names", ["example_1"])
    )

    # Loop-invariant output locations.
    os.makedirs(log_dir, exist_ok=True)
    statis_path = os.path.join(log_dir, "model_statis.txt")

    # Process each task in turn.
    for name in task_names:
        print(f"Processing task: {name}")
        local_success_count = 0
        log_file = os.path.join(log_dir, f"{name}_training_log.txt")

        for times in range(running_num):
            with open(log_file, "a") as f:
                f.write(f"Training Log for Task: {name} with {times} time\n\n")

            # Load (re-sample) the task's data for this run.
            expression, variable_ranges, B, inputs_array, outputs_array = generate_data_by_name(
                dataset_path, name, sample_method
            )
            input_data = inputs_array
            label_data = outputs_array
            variable_dict = generate_variable_info(input_data, const_value=const_value)

            # Register operators; the arities actually used come from the merge below.
            operator_functions, _, trig, exp_log = register_operator_functions()
            combined_dict, operator_indices, operator_arities = merge_operator_and_variable_dict(
                operator_functions, variable_dict
            )

            Program.initialize(operator_functions, combined_dict, operator_indices, operator_arities, input_data, label_data, const_value=const_value)

            model = ActionGenerator(
                hidden_size=hidden_size,
                n=max_depth,
                operator_arities=operator_arities,
                batch_size=batch_size,
                trig=trig,
                exp_log=exp_log,
                epsilon=epsilon,
                include_self_loops=self_loops,
                tree_type=tree_type,
            )

            # Optional genetic-programming module handed to the trainer.
            if gp_control:
                gp_config = GP_config(gp_hyperparameters)
                gp_sr = SymbolicRegression(gp_config, operator_functions, const_value)
                gp_sr.set_data(X=input_data, y=label_data)
            else:
                gp_sr = None

            trainer = PolicyGradientTrainer(
                generator=model,
                gp_module=gp_sr,
                learning_rate=learning_rate,
                entropy_weight=entropy_coefficient,
                count_reward_weight=count_reward_weight,
                use_risk_seeking=use_risk_seeking,
                ppo=ppo,
                clip_epsilon=clip_epsilon,
                quantile=beta,
                log_file=log_file,
            )

            print("\nStarting training...")
            trainer.train(num_epochs=num_epochs, name=name)
            Program.clear_cache()

            if trainer.success_flag:
                local_success_count += 1

        # Guard against running_num == 0, which would otherwise divide by zero.
        recovery_rate = local_success_count / running_num if running_num else 0.0
        with open(statis_path, "a") as f:
            f.write(f"This model's number of  successful recovered task {name} is : {local_success_count}, Recovery rate is {recovery_rate}\n")



if __name__ == "__main__":
    # Suppress all RuntimeWarning-category warnings for the whole run.
    warnings.filterwarnings("ignore", category=RuntimeWarning)

    config_files = [
        "./config/config_ppo_0.2_gp5.json",
    ]
    for config_file in config_files:
        # Digits-only timestamp, e.g. 20231015143015; read the clock once so the
        # printed value matches the log-directory name.
        time_str = datetime.now().strftime("%Y%m%d%H%M%S")
        print(time_str)
        # "./config/foo.json" -> "logs/foo_<timestamp>"
        config_stem = os.path.splitext(os.path.basename(config_file))[0]
        output_path = os.path.join("logs", f"{config_stem}_{time_str}")

        # Run the main program.
        main(config_file, output_path)