import numpy as np
import pandas as pd
import ast  # 用于安全地解析字符串为字典
import numexpr as ne  # 用于快速计算表达式
from sympy import symbols, lambdify, parse_expr
from sympy.parsing.sympy_parser import standard_transformations, implicit_multiplication_application, function_exponentiation


def generate_test_data(expression_str: str, variable_range_str: str, num_samples: int, sampling_method: str = "random"):
    """
    根据输入的表达式和变量范围生成测试数据。
    :param expression_str: 表达式字符串（如 "sin(x) + cos(y) + log(z)"）
    :param variable_range_str: 变量范围字符串（如 "{'x': [-1, 1], 'y': [0, 2], 'z': [1, 10]}"）
    :param num_samples: 采样点的数量
    :param sampling_method: 采样方法，可选值为 "random"（随机采样）或 "uniform"（均匀采样）
    :return: 输入数据 (num_samples, num_variables) 和标签数据 (num_samples,)
    """
    # 解析变量范围字符串
    try:
        variable_ranges = ast.literal_eval(variable_range_str)
        if not isinstance(variable_ranges, dict):
            raise ValueError("Invalid variable_range_str format.")
    except (ValueError, SyntaxError) as e:
        print(f"Error parsing variable_range_str: {e}")
        return None, None

    # 提取变量名和范围
    variables = list(variable_ranges.keys())
    ranges = [variable_ranges[var] for var in variables]

    # 定义符号变量
    expr_symbols = symbols(variables)

    # 将表达式字符串转换为符号表达式
    transformations = (standard_transformations + (implicit_multiplication_application, function_exponentiation))
    expr = parse_expr(
        expression_str,
        local_dict={var: expr_symbols[i] for i, var in enumerate(variables)},
        transformations=transformations
    )

    # 将符号表达式转换为可计算函数
    func = lambdify(expr_symbols, expr, modules="numpy")

    # 生成输入数据
    if sampling_method == "uniform":
        # 均匀采样
        inputs = {
            var: np.linspace(low, high, num_samples)
            for var, (low, high) in zip(variables, ranges)
        }
    elif sampling_method == "random":
        # 随机采样
        inputs = {
            var: np.random.uniform(low, high, size=num_samples)
            for var, (low, high) in zip(variables, ranges)
        }
    else:
        raise ValueError(f"Unsupported sampling_method: {sampling_method}. Use 'random' or 'uniform'.")

    # 将输入变量合并为二维数组
    inputs_array = np.column_stack([inputs[var] for var in variables])

    # 计算标签数据
    label_data = func(*[inputs[var] for var in variables])

    return inputs_array, label_data

def generate_data_by_name(file_path: str, name: str, method: str = "random"):
    """
    根据给定的名字抽取 CSV 文件中对应的行内容，并生成数据。
    参数:
    - file_path (str): 配置文件路径。
    - name (str): 要抽取的表达式名字。
    返回:
    - expression (str): 抽取的表达式。
    - variable_ranges (dict): 抽取的变量取值范围。
    - B (int): 抽取的数据量。
    - inputs_array (np.ndarray): 输入变量数组。
    - outputs_array (np.ndarray): 输出数组。
    """
    try:
        # 读取 CSV 文件
        df = pd.read_csv(file_path, skip_blank_lines=True)
        # 检查列数是否正确
        if df.shape[1] != 4:
            raise ValueError("CSV 文件格式错误：每行应包含 4 列（name, expression, variable_ranges, B）。")
        # 按名字筛选行
        selected_row = df[df["name"] == name]
        if selected_row.empty:
            raise ValueError(f"No row found with name '{name}'.")
        # 提取行内容
        expression = selected_row.iloc[0]["expression"]
        variable_ranges = ast.literal_eval(selected_row.iloc[0]["variable_ranges"])  # 安全解析字符串为字典
        B = int(selected_row.iloc[0]["B"])
        # 使用 generate_test_data 生成数据
        inputs_array, outputs_array = generate_test_data(expression, str(variable_ranges), B, method)
        return expression, variable_ranges, B, inputs_array, outputs_array
    except Exception as e:
        print(f"Error processing CSV file: {e}")
        return None

# 示例用法
if __name__ == "__main__":
    # 配置文件路径
    config_file = "./data/dataset.csv"

    # 给定名字
    name = "Feynman_test"  # 替换为你想要的名字

    try:
        # 根据名字生成数据
        result = generate_data_by_name(config_file, name)
        if result is not None:
            expression, variable_ranges, B, inputs_array, outputs_array = result

            # 打印结果
            print("\nSelected Row by Name:")
            print("Name:", name)
            print("Expression:", expression)
            print("Variable Ranges:", variable_ranges)
            print("Data Size (B):", B)
            print("Inputs:")
            print(inputs_array)
            print("Outputs:")
            print(outputs_array)
    except Exception as e:
        print(e)