"""工具函数"""
import os
import re
import json
import numpy as np
import pandas as pd
import logger
import traceback
from schema import ApiConfig, ModuleConfig

# Path of the JSON config file read by load_api_config() and
# load_module_config(); relative, so it resolves against the current working
# directory.
config_file: str = "config.json"
# Font file name. Not referenced in this part of the file — presumably used
# elsewhere (e.g. for rendering CJK text in charts); TODO confirm.
font_file: str = "msyh.ttf"

# Module-level state filled in by the loader functions (None until they run).
# Name of the active API config; set by load_api_config().
active_api_config = None
# ApiConfig chosen by load_api_config(); stays None if no entry matches.
api_config = None
# ModuleConfig set by load_module_config(); human_review() reads it through
# `utils.module_config` (assumes this file is the `utils` module — confirm).
module_config = None


def human_review(stage: str, context: dict = None) -> str:
    """
    Ask a human to confirm a key stage when interactive review is enabled.

    Review is enabled when ``utils.module_config.enable_human_review`` is
    truthy; if the attribute (or the config itself) is missing, no prompt is
    shown.

    :param stage: label printed in the prompt header
    :param context: optional data pretty-printed as JSON before the prompt
    :return: the user's whitespace-stripped input (a suggestion, or "Y"/"y"
             to proceed), or "Y" when review is disabled
    """
    # Imported lazily so the config is read at call time rather than at
    # import time (presumably this file is itself `utils`; confirm).
    import utils

    review_enabled = getattr(utils.module_config, "enable_human_review", False)
    if not review_enabled:
        return "Y"

    print(f"\n【人机交互·{stage}】请审查计划或后续任务内容：")
    if context:
        print(json.dumps(context, indent=4, ensure_ascii=False))
    return input("输入建议或直接输入Y(y)确认继续：").strip()

def load_api_config(config_name: str = "GLM") -> tuple:
    """Load the active API configuration from ``config_file``.

    NOTE: ``config_name`` is accepted for backward compatibility but has no
    effect. The lookup always uses the file's ``"active_api_config"`` value,
    which defaults to "GLM" when the key is absent.

    Every entry under ``"api_configs"`` is converted with
    ``ApiConfig.from_dict``. The first one whose ``config_name`` equals the
    active name is selected. Also updates the module-level globals
    ``api_config`` and ``active_api_config``.

    :param config_name: ignored (see note above)
    :return: tuple ``(api_config, active_api_config)``. ``api_config`` is the
             matching ``ApiConfig``, or ``None`` when no entry matches.
    """
    global api_config, active_api_config
    with open(config_file, "r", encoding="utf-8") as file:
        data = json.load(file)
    active_api_config = data.get("active_api_config", "GLM")
    # Any requested name is overridden by the active one from the file.
    config_name = active_api_config
    api_configs = [
        ApiConfig.from_dict(config) for config in data.get("api_configs", [])
    ]
    api_config = next(
        (config for config in api_configs if config.config_name == config_name),
        None,
    )
    return api_config, active_api_config


def load_module_config() -> ModuleConfig:
    """Read the ``"module_config"`` section of ``config_file`` and cache it.

    :return: the parsed ``ModuleConfig``; the same object is stored in the
             module-level ``module_config`` global
    :raises KeyError: if the file has no ``"module_config"`` entry
    """
    global module_config
    with open(config_file, "r", encoding="utf-8") as fh:
        raw = json.load(fh)
    section = raw["module_config"]
    module_config = ModuleConfig.from_dict(section)
    return module_config


def remove_code_recursive(data, field_name: str = "code") -> None:
    """Delete every ``field_name`` key from a nested dict/list structure, in place.

    All ``dict`` and ``list`` containers reachable from ``data`` are visited.
    Any dict that contains ``field_name`` has that entry deleted, and its
    value is not searched; the remaining values are searched in turn. Other
    objects are left untouched.

    The walk uses an explicit stack rather than recursion, so deeply nested
    input is not cut short by the interpreter's recursion limit. Each
    container is visited once, so self-referencing input terminates.

    Best-effort by design: any unexpected error silently ends the walk
    instead of propagating to the caller.

    :param data: structure to clean; modified in place
    :param field_name: dict key to remove (default ``"code"``)
    :return: None
    """
    # Map id -> container. Holding the reference keeps each visited object
    # alive, so its id cannot be reused while the walk is running.
    visited = {}
    stack = [data]
    try:
        while stack:
            node = stack.pop()
            if not isinstance(node, (dict, list)) or id(node) in visited:
                continue
            visited[id(node)] = node
            if isinstance(node, dict):
                if field_name in node:
                    del node[field_name]
                stack.extend(node.values())
            else:
                stack.extend(node)
    except Exception:
        # Deliberately best-effort: never raise to the caller.
        return

def strtify(obj):
    """
    Return a text form of *obj*.

    Dicts are serialized as JSON with non-ASCII characters kept as-is;
    every other value goes through ``str()``.
    """
    return json.dumps(obj, ensure_ascii=False) if isinstance(obj, dict) else str(obj)


def parse_res(response):
    """
    Extract the textual payload from a chat-completion style response.

    Steps applied to ``response.choices[0].message.content``:
    1. drop everything up to and including the first ``</think>`` tag;
    2. if a ```json fence is present, keep only the text between it and the
       next ``` fence;
    3. strip surrounding whitespace and remove every ``"\\n"`` character.

    :param response: object exposing ``choices[0].message.content``
                     (presumably an OpenAI-compatible completion — confirm)
    :return: the cleaned string. If processing fails, the raw content is
             returned when ``response`` has ``choices``, otherwise ``""``.
    """
    try:
        res = response.choices[0].message.content
        if "</think>" in res:
            # Discard the model's reasoning section.
            res = res.split("</think>", 1)[1]
        # "```json" contains "```", so a separate "```" check is redundant.
        if "```json" in res:
            # Keep only the body of the first ```json ... ``` fence.
            res = res.split("```json", 1)[1]
            res = res.split("```", 1)[0]
        return res.strip().replace("\n", "")
    except Exception as e:
        logger.trace(f"【解析结果出错】: {e}", traceback.format_exc())
        # Fall back to the unprocessed content.
        try:
            return response.choices[0].message.content if hasattr(response, "choices") else ""
        except Exception:
            return ""


def remove_comments(code_str: str) -> str:
    """Remove ``#`` comments and blank lines from Python source text.

    String literals are matched first, so a ``#`` inside a quoted string is
    kept. Quoted strings are recognised on a single line only, and
    backslash-escaped quotes inside them are honoured. Remaining lines are
    right-stripped, blank lines are dropped, and the result is joined with
    ``os.linesep``.

    This is a lightweight regex approach, not a tokenizer: a ``#`` inside a
    multi-line triple-quoted string is still treated as a comment.

    :param code_str: source code to clean
    :return: the cleaned code
    """
    # Group 1: a single-line '...' or "..." literal, escape-aware so that an
    #          escaped quote does not end the literal early.
    # Group 2: a comment running to the end of the line.
    pattern = r"(\"(?:\\.|[^\"\\\n])*\"|\'(?:\\.|[^\'\\\n])*\')|(\#.*?$)"

    def replace_func(match):
        # Drop comments; keep string literals untouched.
        if match.group(2) is not None:
            return ""
        return match.group(1)

    clean_code = re.sub(pattern, replace_func, code_str, flags=re.MULTILINE)
    return os.linesep.join(
        s.rstrip() for s in clean_code.splitlines() if s.strip()
    )


def remove_imports_and_empty_lines(code_str: str) -> str:
    """Remove import statements and blank lines from Python source text.

    A line whose stripped text starts with ``"import "`` or ``"from "`` is
    treated as an import statement. Multi-line imports are removed in full:
    - ``from X import (`` is skipped through the line holding the closing ``)``;
    - a backslash-continued import is skipped through its last continuation
      line.

    Kept lines retain their original indentation and are joined with
    ``os.linesep``.

    :param code_str: source code to clean
    :return: the code without import statements and blank lines
    """
    cleaned_lines = []
    # What ends the multi-line import currently being skipped:
    # ")" for a parenthesised import, "\\" for a backslash continuation,
    # None when not inside an import.
    closer = None
    for line in code_str.splitlines():
        stripped_line = line.strip()
        # Ignore trailing comments when looking for brackets/continuations.
        code_part = stripped_line.split("#", 1)[0].rstrip()
        if closer == ")":
            if ")" in code_part:
                closer = None
            continue
        if closer == "\\":
            if not code_part.endswith("\\"):
                closer = None
            continue
        if not stripped_line:
            continue
        if stripped_line.startswith(("import ", "from ")):
            if re.match(r"from\s+[\w.]+\s+import\s*\(", code_part) and ")" not in code_part:
                closer = ")"
            elif code_part.endswith("\\"):
                closer = "\\"
            continue
        cleaned_lines.append(line)
    return os.linesep.join(cleaned_lines)



def try_run(func, *args, max_retries=3, **kwargs):
    """
    Call ``func(*args, **kwargs)`` until it returns a truthy value.

    The call is repeated up to ``max_retries`` times while its result is
    falsy (``None``, ``False``, ``""``, an empty container, ...). Exceptions
    raised by ``func`` are not caught; they propagate to the caller
    immediately.

    :param func: the callable to run
    :param args: positional arguments forwarded to ``func``
    :param max_retries: maximum number of attempts (keyword-only)
    :param kwargs: keyword arguments forwarded to ``func``
    :return: the first truthy result, or ``None`` if every attempt returned
             a falsy value
    """
    # functools.partial objects and some callables have no __name__.
    func_name = getattr(func, "__name__", repr(func))
    attempts = 0
    while attempts < max_retries:
        res = func(*args, **kwargs)
        if res:
            return res
        attempts += 1
        # No exception is active here, so traceback.format_exc() would only
        # yield "NoneType: None"; log the falsy result instead.
        logger.error(f"第 {attempts} 次执行 {func_name} 出错，返回空结果: {res!r}")
    logger.error(f"执行 {func_name} 失败，已达到最大重试次数 {max_retries} 次。")
    return None


def parse_code(response):
    """
    Extract code from a fenced block in a chat-completion response.

    Reads ``response.choices[0].message.content``. If it contains a ```
    fence, the block is chosen and trimmed as follows:
    - the block opened by the first ```python fence is used;
    - otherwise the block opened by the first fence of any language
      (```, ```py, ...) is used, and a language tag on the opening fence
      line is dropped;
    - the text is kept up to the next ``` fence.

    Without any fence the whole content is used. The result is
    whitespace-stripped.

    :param response: object exposing ``choices[0].message.content``
    :return: the extracted code. If processing fails, the raw content is
             returned when ``response`` has ``choices``, otherwise ``""``.
    """
    try:
        res = response.choices[0].message.content
        if "```" in res:
            if "```python" in res:
                body = res.split("```python", 1)[1]
            else:
                # Non-python fence: skip an optional language tag that may
                # follow the backticks on the opening line.
                body = res.split("```", 1)[1]
                tag, newline, rest = body.partition("\n")
                if newline and re.fullmatch(r"[\w.+-]*", tag.strip()):
                    body = rest
            res = body.split("```", 1)[0]
        return res.strip()
    except Exception:
        # Fall back to the unprocessed content.
        try:
            return response.choices[0].message.content if hasattr(response, "choices") else ""
        except Exception:
            return ""




def get_table_meta(table_meta_filepath, table_name, columns):
    """
    Look up descriptions of selected columns of a table in a JSON metadata file.

    The file is read as a list of table entries, each with a ``"table_name"``
    and a ``"columns"`` list of objects carrying ``"name"`` and ``"desc"``.

    :param table_meta_filepath (str): path of the JSON metadata file
    :param table_name (str): table to look up; if several entries share the
                             name, the last one wins
    :param columns (list): column names whose descriptions are requested

    :return dict: ``{column: desc}`` in the order of ``columns``, with names
                  not present in the metadata omitted; or
                  ``{"error": ...}`` when the table is not found
    """
    with open(table_meta_filepath, "r", encoding="utf-8") as file:
        raw_table_data = json.load(file)

    table_meta = None
    for table in raw_table_data:
        # No early break: a later duplicate entry overrides an earlier one.
        if table["table_name"] == table_name:
            table_meta = table

    if table_meta is None:
        return {
            "error": f"数据表 {table_name} 的元信息不存在",
        }

    if not columns:
        return {}

    # Index the table's columns by name once instead of rescanning them for
    # every requested column; later duplicates override earlier ones.
    columns_by_name = {tmp["name"]: tmp for tmp in table_meta["columns"]}
    return {
        column: columns_by_name[column]["desc"]
        for column in columns
        if column in columns_by_name
    }





