import operator
import random
from typing import List, Callable, Tuple
from deap import gp, algorithms, tools, creator, base
import numpy as np
import json
import re

# Operator functions used as GP primitives
def add(a, b):
    """Return ``a + b`` (binary GP primitive)."""
    return operator.add(a, b)

def sub(a, b):
    """Return ``a - b`` (binary GP primitive)."""
    return operator.sub(a, b)

def mul(a, b):
    """Return ``a * b`` (binary GP primitive)."""
    return operator.mul(a, b)

def div(a, b):
    """
    Return ``a / (b + 1e-20)``.

    The 1e-20 shift lets ``b == 0`` produce a (very large) quotient instead of
    raising ZeroDivisionError. NOTE(review): ``b == -1e-20`` still divides by zero.
    """
    shifted_denominator = b + 1e-20
    return operator.truediv(a, shifted_denominator)

def power(a, b):
    """
    Return ``a ** b``.

    NOTE: a negative ``a`` with a fractional ``b`` gives nan for NumPy floats
    and a complex number for Python floats.
    """
    return operator.pow(a, b)

def sin(a):
    """Element-wise sine via NumPy (accepts scalars or arrays)."""
    result = np.sin(a)
    return result

def cos(a):
    """Element-wise cosine via NumPy (accepts scalars or arrays)."""
    result = np.cos(a)
    return result

def exp(a):
    """Element-wise ``e ** a`` via NumPy; large inputs overflow to inf."""
    result = np.exp(a)
    return result

def log(a):
    """
    Natural log of ``|a|`` via NumPy, so negative inputs are accepted.

    ``a == 0`` yields -inf (NumPy emits a RuntimeWarning).
    """
    magnitude = np.abs(a)
    return np.log(magnitude)

def square(a):
    """Element-wise ``a ** 2`` via NumPy."""
    result = np.square(a)
    return result

def third_power(a):
    """Return ``a ** 3``."""
    return operator.pow(a, 3)

# Operator table: symbol -> function and arity
def register_operator_functions():
    """
    Build the table of operators available to the GP search.

    :return: tuple ``(operator_functions, arities, trig_function, exp_log_function)``
        - operator_functions: dict symbol -> {'function': callable, 'arity': int},
          in the order + - * / sin cos exp log
        - arities: each operator's arity, in table order
        - trig_function: table positions of 'sin' and 'cos'
        - exp_log_function: table positions of 'exp' and 'log'
    """
    binary_ops = [('+', add), ('-', sub), ('*', mul), ('/', div)]
    unary_ops = [('sin', sin), ('cos', cos), ('exp', exp), ('log', log)]

    operator_functions = {symbol: {'function': fn, 'arity': 2} for symbol, fn in binary_ops}
    operator_functions.update({symbol: {'function': fn, 'arity': 1} for symbol, fn in unary_ops})

    symbols = list(operator_functions)
    trig_function = [pos for pos, symbol in enumerate(symbols) if symbol in ('sin', 'cos')]
    exp_log_function = [pos for pos, symbol in enumerate(symbols) if symbol in ('exp', 'log')]
    arities = [entry['arity'] for entry in operator_functions.values()]
    return operator_functions, arities, trig_function, exp_log_function

# Configuration
class GP_config:
    """
    Hyper-parameters for the genetic-programming run.

    All values are stored in the ``gp`` dict; any key missing from the input
    falls back to the value in ``DEFAULTS``.
    """

    # Default hyper-parameters, in the order they appear in ``self.gp``.
    DEFAULTS = {
        'pops': 100,        # population size
        'cxpb': 0.7,        # crossover probability
        'mutpb': 0.2,       # mutation probability
        'times': 50,        # number of generations
        'tournsize': 3,     # tournament size for selection
        'max_height': 17,   # maximum tree height
        'hof_size': 5,      # Hall of Fame size
        'threshold': 0.01,  # stopping threshold (NOTE(review): not read in this file)
    }

    def __init__(self, config_data):
        """
        Build the configuration from an already-parsed mapping.

        :param config_data: dict of hyper-parameters, e.g. the
            ``"gp_hyperparameters"`` section of the JSON config file. Only the
            keys in ``DEFAULTS`` are read; ``None`` is treated as an empty dict.
        """
        config_data = config_data or {}
        self.gp = {key: config_data.get(key, default) for key, default in self.DEFAULTS.items()}

# Register operators dynamically
def register_operators(pset: gp.PrimitiveSet, operator_functions: dict) -> None:
    """
    Add every function in ``operator_functions`` to ``pset`` as a primitive.

    Each entry's ``"function"`` is registered with ``"arity"`` inputs; DEAP
    names the primitive after the function's ``__name__``. The operator symbols
    (the dict keys) are not used here.

    :param pset: DEAP PrimitiveSet to extend (modified in place)
    :param operator_functions: symbol -> {"function": callable, "arity": int}
    """
    for entry in operator_functions.values():
        pset.addPrimitive(entry["function"], entry["arity"])

# Convert a prefix expression into DEAP's string form
def prefix_to_internal_expression(prefix_expr: List[str], operator_functions: dict) -> str:
    """
    Convert a prefix (Polish-notation) token list into a function-call string
    such as ``add(exp(x1), x2)``, suitable for ``gp.PrimitiveTree.from_string``.

    Tokens that are keys of ``operator_functions`` are replaced by their
    function's ``__name__`` and consume ``arity`` operands; every other token is
    treated as a terminal and copied verbatim.

    :param prefix_expr: tokens in prefix order, e.g. ``["+", "x_2", "sin", "x_1"]``
    :param operator_functions: symbol -> {"function": callable, "arity": int}
    :return: the nested call expression as a string
    :raises ValueError: if the tokens do not form exactly one complete expression
    """
    stack = []
    # Scanning right-to-left, each operator's operands are already on the stack,
    # with its first operand on top.
    for token in reversed(prefix_expr):
        info = operator_functions.get(token)
        if info is None:
            stack.append(token)  # terminal
            continue
        arity = info["arity"]
        if len(stack) < arity:
            raise ValueError(
                f"operator {token!r} needs {arity} operand(s) but only "
                f"{len(stack)} are available in {prefix_expr!r}"
            )
        children = [stack.pop() for _ in range(arity)]
        stack.append(f"{info['function'].__name__}({', '.join(children)})")
    if len(stack) != 1:
        raise ValueError(
            f"malformed prefix expression {prefix_expr!r}: expected one root "
            f"expression, got {len(stack)}"
        )
    return stack[0]

# Parse a DEAP expression string into its prefix token sequence
# Tokens: numeric literal (optional sign/decimal/exponent, only where it cannot
# be an infix operand), identifier, or single-character punctuation/operator.
_EXPR_TOKEN_RE = re.compile(
    r"(?<![\w.)])[-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?(?![\w.])"
    r"|\w+"
    r"|[+\-*/(),]"
)


def expr_to_sequence(expr: str) -> List[str]:
    """
    Parse a function-call expression string into its prefix token sequence.

    Example: ``"add(add(x1, x2), x3)"`` -> ``["add", "add", "x1", "x2", "x3"]``.
    Numeric literals such as ``3.14``, ``-1.5`` or ``1e-05`` stay single tokens.

    :param expr: expression in call form, e.g. ``str(individual)``
    :return: the non-punctuation tokens in their original (prefix) order
    :raises ValueError: if the parentheses are unbalanced
    """
    sequence = []
    depth = 0
    for token in _EXPR_TOKEN_RE.findall(expr):
        if token == '(':
            depth += 1
        elif token == ')':
            depth -= 1
            if depth < 0:
                raise ValueError(f"unbalanced ')' in expression {expr!r}")
        elif token != ',':
            # In call form the textual order already is the prefix order, so
            # structural tokens are simply dropped.
            sequence.append(token)
    if depth != 0:
        raise ValueError(f"unbalanced '(' in expression {expr!r}")
    return sequence

# Convert a prefix sequence into user-facing operator symbols
def sequence_to_operator_sequence(sequence: List[str], operator_functions: dict) -> List[str]:
    """
    Replace function names in a prefix sequence with their operator symbols.

    :param sequence: prefix tokens, e.g. ``["add", "mul", "x1", "x2", "x3"]``
    :param operator_functions: symbol -> {"function": callable, "arity": int};
        each function's ``__name__`` is mapped back to its symbol
    :return: new flat list, e.g. ``["+", "*", "x1", "x2", "x3"]``; tokens that are
        not function names are passed through unchanged
    """
    name_to_symbol = {}
    for symbol, entry in operator_functions.items():
        name_to_symbol[entry["function"].__name__] = symbol
    return [name_to_symbol.get(tok, tok) for tok in sequence]

# Genetic-algorithm driver for symbolic regression
class SymbolicRegression:
    """
    Symbolic regression with DEAP genetic programming.

    Usage: construct, call ``set_data(X, y)``, then ``ga_run()``. Fitness is the
    mean squared error of an evolved expression over the data (minimised).
    """

    def __init__(self, config_s: GP_config, operator_functions: dict, const_value: float = None):
        """
        :param config_s: hyper-parameters; ``config_s.gp`` supplies
            ``tournsize``, ``max_height``, ``hof_size``, ``pops``, ``cxpb``,
            ``mutpb`` and ``times``.
        :param operator_functions: symbol -> {"function", "arity"} mapping whose
            functions become GP primitives.
        :param const_value: optional constant added as a terminal named
            ``const``; a falsy value (the default ``None``, or 0) adds none.
        """
        self.operator_functions = operator_functions
        self.config_s = config_s
        self.const_value = const_value
        self.data = []  # (inputs, target) pairs, filled in by set_data()
        # creator.create() overwrites an existing class of the same name (with a
        # RuntimeWarning), leaving individuals from an earlier instance with a
        # stale class; create the classes only once per process.
        if not hasattr(creator, "FitnessMin"):
            creator.create("FitnessMin", base.Fitness, weights=(-1.0,))
        if not hasattr(creator, "Individual"):
            creator.create("Individual", gp.PrimitiveTree, fitness=creator.FitnessMin)
        self.toolbox = base.Toolbox()
        input_dim = 1  # provisional; set_data() rebuilds the set for the real width
        self.pset = gp.PrimitiveSet("MAIN", input_dim)
        self._initialize_primitive_set(input_dim)
        self.toolbox.register("evaluate", self.evaluate)
        self.toolbox.register("select", tools.selTournament, tournsize=config_s.gp['tournsize'])
        self.toolbox.register("mate", gp.cxOnePoint)
        # Apply the configured max_height: crossover offspring taller than the
        # limit are replaced by one of their parents (gp.staticLimit).
        self.toolbox.decorate(
            "mate",
            gp.staticLimit(key=operator.attrgetter("height"), max_value=config_s.gp['max_height']),
        )
        self._register_pset_tools()

    def _initialize_primitive_set(self, input_dim: int):
        """
        Populate ``self.pset``: rename arguments to ``x_1`` … ``x_n``, add the
        optional ``const`` terminal, and register the operator primitives.
        """
        for i in range(input_dim):
            self.pset.renameArguments(**{f"ARG{i}": f"x_{i+1}"})
        if self.const_value:
            self.pset.addTerminal(self.const_value, name="const")  # optional constant terminal (e.g. π)
        register_operators(self.pset, self.operator_functions)

    def _register_pset_tools(self):
        """(Re)register every toolbox entry that depends on ``self.pset``."""
        # Toolbox.register binds its arguments immediately, so "individual" and
        # "population" must be re-registered too; otherwise they keep calling
        # the previous "expr" generator and thus the previous primitive set.
        self.toolbox.register("expr", gp.genHalfAndHalf, pset=self.pset, min_=1, max_=2)
        self.toolbox.register("individual", tools.initIterate, creator.Individual, self.toolbox.expr)
        self.toolbox.register("population", tools.initRepeat, list, self.toolbox.individual)
        self.toolbox.register("compile", gp.compile, pset=self.pset)
        self.toolbox.register("mutate", gp.mutNodeReplacement, pset=self.pset)

    def set_data(self, X: np.ndarray, y: np.ndarray):
        """
        Set or replace the training data.

        :param X: inputs of shape (n_samples, n_features); a 1-D array is
            treated as a single feature.
        :param y: targets, one per sample.
        :raises ValueError: if X or y is not a NumPy array, or the sample counts
            differ.

        When the number of features differs from the current primitive set, the
        set is rebuilt (arguments ``x_1`` … ``x_n``) and every pset-dependent
        toolbox entry is re-registered.
        """
        if not isinstance(X, np.ndarray) or not isinstance(y, np.ndarray):
            raise ValueError("输入数据 X 和输出数据 y 必须是 NumPy 数组")
        if X.shape[0] != y.shape[0]:
            raise ValueError("输入数据 X 和输出数据 y 的样本数必须一致")
        if X.ndim == 1:
            X = X.reshape(-1, 1)
        # Evaluate in float64: with integer data, NumPy arithmetic inside the
        # evolved expressions would silently wrap around on overflow.
        X = X.astype(np.float64)
        y = y.astype(np.float64)
        self.data = [(list(row), target) for row, target in zip(X, y)]
        new_input_dim = X.shape[1]
        if new_input_dim != len(self.pset.arguments):
            self.pset = gp.PrimitiveSet("MAIN", new_input_dim)
            self._initialize_primitive_set(new_input_dim)
            self._register_pset_tools()

    def evaluate(self, individual):
        """DEAP evaluation hook: return the individual's MSE as a 1-tuple."""
        try:
            error = self.calculate_fitness(individual)
            return error,
        except ZeroDivisionError:
            return float('inf'),
        except ValueError:
            return float('inf'),
        except Exception as e:
            print(f"Unknown error during evaluation: {e}")
            return float('inf'),

    def calculate_fitness(self, individual):
        """
        Mean squared error of ``individual`` over ``self.data``.

        Returns ``float('inf')`` when the expression cannot be compiled or
        evaluated, when there is no data, or when the error is not finite.
        """
        try:
            func = self.toolbox.compile(expr=individual)
            if not callable(func):
                print(f"Warning: The compiled function is not callable, it is: {type(func)}")
                return float('inf')
            error = 0.0
            # Overflow, log(0), inf - inf etc. are routine for random programs;
            # silence NumPy's warnings and let the finiteness check handle them.
            with np.errstate(all="ignore"):
                for x, y in self.data:
                    pred = func(*x) if isinstance(x, (list, tuple)) else func(x)
                    error += (pred - y) ** 2
            mse = error / len(self.data)
            # NaN does not order against other fitness values, so map any
            # non-finite error to the worst possible fitness.
            return mse if np.isfinite(mse) else float('inf')
        except ZeroDivisionError:
            return float('inf')
        except ValueError:
            return float('inf')
        except Exception as e:
            print(f"Unknown error during prediction: {e} with individual={individual}")
            return float('inf')

    def ga_run(self, initial_pop: List = None) -> Tuple[List, float]:
        """
        Run DEAP's ``eaSimple`` and return the best individual found.

        :param initial_pop: optional seed individuals. The working population is
            truncated or topped up with random individuals to exactly
            ``config_s.gp['pops']``; the caller's list is not modified.
        :return: ``(best_individual, best_fitness)`` where ``best_fitness`` is
            the best individual's MSE, taken from the Hall of Fame.
        """
        stats = tools.Statistics(lambda ind: ind.fitness.values)
        stats.register("avg", np.mean)
        stats.register("min", np.min)
        stats.register("max", np.max)
        hof = tools.HallOfFame(self.config_s.gp['hof_size'])
        required_size = self.config_s.gp['pops']
        if initial_pop is None:
            pop = self.toolbox.population(n=required_size)
        else:
            pop = list(initial_pop[:required_size])  # copy: never mutate the caller's list
            shortfall = required_size - len(pop)
            if shortfall > 0:
                pop.extend(self.toolbox.population(n=shortfall))
        pop, _logbook = algorithms.eaSimple(
            pop,
            self.toolbox,
            cxpb=self.config_s.gp['cxpb'],
            mutpb=self.config_s.gp['mutpb'],
            ngen=self.config_s.gp['times'],
            stats=stats,
            halloffame=hof,
            verbose=False
        )
        global_best_individual = hof[0]
        global_best_fitness = global_best_individual.fitness.values[0]
        return global_best_individual, global_best_fitness

# Entry point
def main():
    """
    Demo run: evolve an expression for a small three-feature data set.

    Reads GP hyper-parameters from the ``"gp_hyperparameters"`` section of
    ``./config/config.json``, seeds the population with hand-written prefix
    expressions, runs the GA and prints the best expression both in DEAP form
    and as an operator-symbol prefix sequence.
    """
    # Operator table (symbol -> function/arity)
    operator_functions, _, _, _ = register_operator_functions()

    # Sample data as NumPy arrays; every target equals x_1² + x_2² + x_3²
    X = np.array([
        [1, 2, 3],
        [2, 3, 4],
        [3, 4, 5],
        [4, 5, 6],
        [5, 6, 7]
    ])
    y = np.array([14, 29, 50, 77, 110])

    # Load the configuration file
    config_file = './config/config.json'
    with open(config_file, 'r', encoding='utf-8') as f:
        config = json.load(f)
    config_gp = config.get("gp_hyperparameters", {})
    config = GP_config(config_gp)

    # Build the regressor. const_value must be supplied; None adds no constant
    # terminal (pass e.g. np.pi to add one).
    sr = SymbolicRegression(config, operator_functions, const_value=None)
    sr.set_data(X, y)

    # Seed population: each row is one prefix (Polish-notation) expression
    initial_population_prefix = [
        ["+", "x_2", "sin", "x_1"],
        ["-", "x_2", "x_1"],
        ["*", "x_2", "x_1"],
        ["/", "x_2", "x_1"],
        ["+", "x_2", "sin", "x_1"],
        ["*", "+", "x_2", "x_1", "x_1"],
        ["+", "x_2", "*", "x_1", "x_3"],
        ["*", "+", "x_3", "x_1", "+", "x_1", "x_2"],
    ]

    # Convert the prefix expressions into DEAP individuals over sr.pset
    initial_population = []
    for prefix_expr in initial_population_prefix:
        internal_expression_str = prefix_to_internal_expression(prefix_expr, operator_functions)
        initial_expr = gp.PrimitiveTree.from_string(internal_expression_str, sr.pset)
        initial_population.append(creator.Individual(initial_expr))

    # Run the GA
    best_individual, best_fitness = sr.ga_run(initial_pop=initial_population)

    # Report the result
    print("\n=== 最优结果 ===")
    print("最优的表达式:")
    print(best_individual)
    print("该表达式的拟合误差为:")
    print(sr.calculate_fitness(best_individual))
    # best_fitness is the best individual's fitness, not an average over the Hall of Fame
    print(f"Hall of Fame 最优个体的适应度为: {best_fitness}")

    # DEAP string form -> prefix token sequence -> operator symbols
    best_expr_str = str(best_individual)
    sequence = expr_to_sequence(best_expr_str)
    user_expression_str = sequence_to_operator_sequence(sequence, operator_functions)

    print("\n=== 最优表达式的用户友好形式 ===")
    print(user_expression_str)


if __name__ == "__main__":
    main()