
import json
import base64
import re
import os
import argparse
from typing import Dict, List, Tuple
from PIL import Image
import io
import shlex


async def get_corresponding_functions(conversation, all_functions):
    all_tool_names = set([msg['name'] for msg in conversation if msg['role'] == 'tool'])
    return [i for i in all_functions if i['name'] in all_tool_names]

def replace_str(s, str_mapping_dict):
    for source_str in str_mapping_dict.keys():
        s = s.replace(source_str, str_mapping_dict[source_str])
    return s

def extract_json(response_text: str) -> Tuple[Dict | str, bool]:
    pattern = r"```json(.*?)```"

    match = re.search(pattern, response_text, re.DOTALL)
    if match:
        json_string = match.group(1)
        try:
            data = json.loads(json_string)
            return data, True
        except Exception as e:
            return str(e), False

    return "Error: No JSON format ```json(...)``` found", False

async def extract_json_dict(response_text: str) -> Tuple[Dict, bool]:
    begin_idx = response_text.find('```json') + len('```json')
    end_idx = response_text.rfind('```')
    if begin_idx != -1 and end_idx != -1:
        json_string = response_text[begin_idx:end_idx].strip()
        try:
            data = json.loads(json_string)
            return data, True
        except Exception as e:
            return {"info": str(e)}, False
    return {"info": "Error: No JSON found"}, False


def save_json(file_path, data, indent=4):
    with open(file_path, 'w', encoding='utf-8') as f:
        json.dump(data, f, ensure_ascii=False, indent=indent)

def save_json_incre_list(file_path, data, indent=4):
    if os.path.isfile(file_path):
        existing_data = load_json(file_path)
        existing_data.extend(data)
        data = existing_data
    with open(file_path, 'w', encoding='utf-8') as f:
        json.dump(data, f, ensure_ascii=False, indent=indent)

def load_json(file_path):
    with open(file_path, 'r', encoding='utf-8') as f:
        data = json.load(f)
    return data

def load_txt(file_path, encoding='utf-8'):
    with open(file_path, 'r', encoding=encoding) as f:
        data = f.read()
    return data

def load_txt_raw(file_path, encoding='utf-8'):
    with open(file_path, 'r', encoding=encoding) as f:
        data = f.read()
    return data

def save_txt(file_path, data, encoding='utf-8'):
    with open(file_path, 'w', encoding=encoding) as f:
        f.write(data)

def encode_media(file_path: str) -> str:
    """Encodes a media file (image or video) into a Base64 string."""
    with open(file_path, "rb") as media_file:
        return base64.b64encode(media_file.read()).decode("utf-8")

def parse_args():
    """
    解析命令行参数。

    返回:
        argparse.Namespace: 包含所有命令行参数的对象。
    """
    parser = argparse.ArgumentParser(
        description="一个示例程序，用于演示如何解析命令行参数。",
        formatter_class=argparse.RawTextHelpFormatter
    )

    # --- 核心参数 ---
    parser.add_argument(
        '--model_name',
        type=str,
        default='claude-3-7-sonnet-20250219',
        help='指定要使用的模型名称。\n(默认值: claude-3-7-sonnet-20250219)'
    )

    parser.add_argument(
        '--user_query',
        type=str,
        default=None,
        help='用户输入的查询或问题。\n如果没有提供，程序可能会进入交互模式或读取标准输入。'
    )

    parser.add_argument(
        '--gen_part_prompt',
        type=str,
        default=None,
        help='指定生成主题prompt是哪一部分。'
    )

    parser.add_argument(
        '--chat_mode',
        action='store_true',
        default=False,
        help='启用连续对话模式。'
    )

    parser.add_argument(
        '--system_prompt_txt_path',
        type=str,
        default=None,
        help='指定system prompt文件的路径。'
    )

    parser.add_argument(
        '--time_str',
        type=str,
        default=None,
        help='指定时间字符串。'
    )

    parser.add_argument(
        '--query_txt_file',
        type=str,
        default=None,
        help='指定查询文本文件的路径。'
    )

    # --- 配置文件路径 ---
    parser.add_argument(
        '--model_config_path',
        type=str,
        default='configs/model.json',
        help='指定模型配置文件的路径。\n(默认值: configs/model.json)'
    )

    parser.add_argument(
        '--conversation_path',
        type=str,
        default=None,
        help='指定历史对话文件的路径。'
    )

    parser.add_argument(
        '--all_deleted_functions_file',
        type=str,
        default=None,
        help='指定预加载所有待删除函数的文件路径。'
    )

    parser.add_argument(
        '--restart_file',
        type=str,
        default=None,
        help='指定重启文件的路径。'
    )

    parser.add_argument(
        '--cold_start_restart_file',
        type=str,
        default=None,
        help='指定冷启动重启文件的路径。'
    )

    parser.add_argument(
        '--mcp_server_config_path',
        type=str,
        default=None,
        help='指定MCP配置文件的路径。\n(默认值: configs/mcp.json)'
    )

    parser.add_argument(
        '--tool_parts',
        type=str,
        default=None,
        help='指定工具部件'
    )

    parser.add_argument(
        '--mcp_server_config_class',
        type=str,
        default=None,
        help='指定MCP配置类的名称。'
    )

    parser.add_argument(
        '--log_messages_path',
        type=str,
        default=None,
        help='指定日志文件的路径。'
    )

    parser.add_argument(
        '--cold_start_temperature',
        type=float,
        default=0.7,
        help='指定冷启动温度。'
    )

    args = parser.parse_args()
    args.model_name = args.model_name or None
    if args.log_messages_path is None:
        args.log_messages_path = f'logs/{args.time_str}.json'
    args.mcp_server_config_class = args.mcp_server_config_class.split(',') if args.mcp_server_config_class else None
    args.mcp_server_config = None
    args.tool_parts = args.tool_parts.split(',') if args.tool_parts else None

    return args


def is_high_risk_command(command_str: str) -> tuple[bool, str]:
    """
    判断一个shell指令是否为高危指令。

    通过多重检查：
    1. 危险命令关键词
    2. 危险参数组合（正则表达式）
    3. 访问敏感路径
    4. 危险的管道或重定向

    :param command_str: 要检查的shell指令字符串。
    :return: 一个元组 (is_risk, reason)，is_risk为布尔值，reason为判断原因。
    """
    if not isinstance(command_str, str) or not command_str.strip():
        return False, "Empty command"

    # 规范化命令，方便解析
    normalized_command = command_str.strip()

    # --- 1. 定义高危特征 ---

    # 1.1 危险命令关键词 (只要出现就报警)
    # 这些命令本身就很危险，或者经常被用于恶意操作
    DANGEROUS_COMMANDS = {
        # 文件和磁盘操作
        'rm', 'mkfs', 'format', 'dd', 'shred', 'fdisk', 'gdisk',
        # 系统控制
        'shutdown', 'reboot', 'halt', 'poweroff',
        # 用户和权限管理
        'userdel', 'groupdel', 'usermod', 'groupmod', 'chown', 'chmod',
        # 网络和进程
        'kill', 'pkill',
        # 内核和系统配置
        'sysctl', 'insmod', 'modprobe', 'rmmod',
    }

    # 1.2 危险的参数/模式组合 (使用正则表达式匹配)
    DANGEROUS_PATTERNS = [
        # rm -rf / 或类似的操作
        re.compile(r'\brm\s+(-[a-zA-Z]*f[a-zA-Z]*r|-[a-zA-Z]*r[a-zA-Z]*f)\s+/\b'),
        re.compile(r'\brm\s+.*/\s+.*-rf'), # rm / -rf
        # 格式化或写入整个磁盘设备
        re.compile(r'\bdd\s+if=/dev/zero\s+of=/dev/sd[a-z]'),
        re.compile(r'\bmkfs\.[a-z]+\s+/dev/sd[a-z]'),
        # 递归地修改整个系统的权限
        re.compile(r'\bchmod\s+(-R|--recursive)\s+(777|666)\s+/'),
        re.compile(r'\bchown\s+(-R|--recursive)\s+.*\s+/'),
    ]

    # 1.3 敏感的文件/路径 (检查是否出现在命令中)
    SENSITIVE_PATHS = [
        # 系统核心目录
        '/', '/bin', '/sbin', '/etc', '/usr', '/boot', '/dev', '/proc', '/sys',
        # 敏感文件
        '/etc/passwd', '/etc/shadow', '/etc/sudoers',
        '/root',
        '~/.ssh/authorized_keys',
        '~/.bashrc', '~/.profile',
        # 磁盘设备
        '/dev/sd', '/dev/hd', '/dev/nvme',
    ]
    
    # 1.4 危险的管道或重定向
    DANGEROUS_PIPES_REDIRECTIONS = [
        # 下载脚本并直接执行
        re.compile(r'\|\s*(bash|sh|zsh)\b'),
        re.compile(r'\b(curl|wget)\s+.*\s*\|\s*(bash|sh|zsh)\b'),
        # 重定向到关键设备
        re.compile(r'>\s*/dev/sd[a-z]'),
    ]

    # --- 2. 开始检查 ---

    # 2.1 检查危险的管道和重定向
    for pattern in DANGEROUS_PIPES_REDIRECTIONS:
        if pattern.search(normalized_command):
            return True, f"Detected dangerous pipe or redirection: {pattern.pattern}"
            
    # 尝试使用shlex解析命令，对引号等有更好的支持
    try:
        tokens = shlex.split(normalized_command)
    except ValueError:
        # 如果解析失败（比如引号不匹配），就用简单的空格分割作为后备
        tokens = normalized_command.split()

    if not tokens:
        return False, "Empty command after parsing"

    # 获取主命令
    main_command = tokens[0].split('/')[-1] # 处理 /bin/rm 这种情况

    # 2.2 检查主命令是否在危险列表中
    if main_command in DANGEROUS_COMMANDS:
        # 特殊处理chmod, chown, rm: 只有当它们操作敏感路径时才认为是高危
        if main_command in {'chmod', 'chown', 'rm'}:
             for path in SENSITIVE_PATHS:
                 # 确保匹配的是一个完整的路径，而不是子字符串
                 # 例如，避免把 `/home/user/myetc` 误判为 `/etc`
                 if re.search(r'\s' + re.escape(path) + r'(\s|$)', normalized_command):
                     return True, f"Command '{main_command}' is operating on sensitive path '{path}'"
        else:
            return True, f"Command '{main_command}' is in the dangerous command list"

    # 2.3 检查危险的参数组合模式
    for pattern in DANGEROUS_PATTERNS:
        if pattern.search(normalized_command):
            return True, f"Detected dangerous command pattern: {pattern.pattern}"

    return False, "Command appears to be safe"


def is_base64_image(s: dict) -> bool:
    if 'mimeType' not in s or 'data' not in s:
        return False
    cur_s = f"data:{s['mimeType']};base64,{s['data']}"

    try:
        header, encoded_data = cur_s.split(',', 1)
        decoded_data = base64.b64decode(encoded_data)
    except (ValueError, TypeError):
        return False

    try:
        image_data = io.BytesIO(decoded_data)
        img = Image.open(image_data)
        img.verify()
    except Exception:
        return False
        
    return True


def load_tool_experience(tool_list):
    pass