from enum import Enum
import json
import os
import time
from loguru import logger
import re
from langchain_core.output_parsers import JsonOutputParser
import typer

# The former setup_logger function was removed; loguru is configured here,
# as a side effect of importing this module.
# The log file name embeds the time at which this module is imported, and the
# path is relative to the current working directory.
LOG_PATH = f'outputs/logs/{time.strftime("%Y-%m-%d_%H-%M-%S")}.log'
os.makedirs(os.path.dirname(LOG_PATH), exist_ok=True)

# Configure loguru to write to both the console and the log file.
# Clear the handlers loguru already has (presumably its default stderr sink)
# before registering the two sinks below.
logger.remove()
# Console sink: a callable that prints each formatted message without adding
# an extra newline, using a short colourised format.
logger.add(lambda msg: print(msg, end=''), colorize=True, format="<green>{time:MM-DD HH:mm:ss}</green> | <level>{level: <8}</level>: <level>{message}</level>")
# File sink: DEBUG level and above, rotation at 500 MB, 10-day retention,
# and a more detailed format that includes the call site.
logger.add(LOG_PATH, rotation="500 MB", encoding="utf-8", enqueue=True, retention="10 days", level="DEBUG", format="{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {name}:{function}:{line} - {message}")

# Rebind the module-level `logger` to a logger carrying name="utils".
# NOTE(review): the `{name}` placeholder in the file format above is presumably
# loguru's built-in module-name field rather than this bound value - confirm.
logger = logger.bind(name="utils")

def json_parser(x):
    """
    Parse JSON out of a text blob, falling back to regex extraction.

    The input is first handed to ``JsonOutputParser().parse`` unchanged. If
    that raises, two regex patterns are tried in order against a copy of the
    input whose newlines are replaced by spaces: the contents of a ``` or
    ```json fenced block, then a ``{...}`` span followed by end-of-text or
    ``Note:``. The first captured group that parses successfully is returned.

    Args:
        x: Text expected to contain JSON.

    Returns:
        The parsed value, or None when the input is empty/falsy or every
        attempt failed. Failures are logged, never raised from the parsing
        attempts themselves.
    """
    if not x:
        logger.warning("Input to json_parser is empty.")
        return None

    # Attempt 1: parse the raw input as-is.
    try:
        result = JsonOutputParser().parse(x)
        logger.debug("Successfully parsed JSON directly.")
        return result
    except Exception as e:
        logger.debug(f"Direct JSON parsing failed: {e}")

    # Attempt 2: pull the JSON text out of a fenced block or a bare object.
    fallback_patterns = (
        r'```(?:json)?\s*([\s\S]*?)\s*```',
        r'({[\s\S]*?})\s*(?:$|Note:)',
    )
    for pattern in fallback_patterns:
        try:
            found = re.search(pattern, x.replace("\n", " "), re.DOTALL)
            if found is None:
                continue
            result = JsonOutputParser().parse(found.group(1))
            logger.debug(f"Successfully parsed JSON with pattern '{pattern}'.")
            return result
        except Exception as e:
            logger.debug(f"JSON parsing with pattern '{pattern}' failed: {e}")

    # Only the first 200 characters are logged to keep the log line short.
    logger.error(f"Failed to parse JSON from input: {x[:200]}...")
    return None

def list_parser(x):
    """
    Split text on newline characters and strip each resulting piece.

    Args:
        x: Text with one entry per line.

    Returns:
        A list with one stripped string per ``"\\n"``-separated piece (blank
        lines become empty strings), or None when the input is empty/falsy.
    """
    if x:
        return list(map(str.strip, x.split("\n")))

    logger.warning("Input to list_parser is empty.")
    return None

def check_response(response):
    """
    Extract the candidate answers and reasoning paths from a response dict.

    Args:
        response: Expected to be a dict that may contain the keys
            ``"candidate_answers"`` and ``"reasoning_paths"``.

    Returns:
        A ``(candidate_answers, reasoning_paths)`` tuple. A missing key yields
        an empty list; a non-dict response yields ``([], [])`` and a warning.
    """
    if isinstance(response, dict):
        answers = response.get("candidate_answers", [])
        paths = response.get("reasoning_paths", [])

        # Both being empty is only reported, not treated as an error.
        if not (answers or paths):
            logger.info("No candidate answers or reasoning paths found in response.")

        return answers, paths

    logger.warning(f"Response is not a dictionary: {response}")
    return [], []


def convert_json_to_jsonl(json_file, jsonl_file):
    """
    Rewrite a JSON file as JSON Lines.

    The whole input is loaded with ``json.load``; each element obtained by
    iterating over the loaded value is written as one line of JSON
    (non-ASCII characters are kept as-is).

    Args:
        json_file: Path of the JSON file to read (UTF-8).
        jsonl_file: Path of the JSONL file to write (UTF-8, overwritten).
    """
    logger.info(f"Converting {json_file} to {jsonl_file}...")

    with open(json_file, 'r', encoding='utf-8') as src:
        records = json.load(src)

    with open(jsonl_file, 'w', encoding='utf-8') as dst:
        for record in records:
            serialized = json.dumps(record, ensure_ascii=False)
            dst.write(f"{serialized}\n")

    logger.info("Conversion complete.")

def convert_negs_to_neg(negs_file, neg_file):
    """
    Copy a JSONL file, renaming each record's ``"negs"`` key to ``"neg"``.

    Args:
        negs_file: Path of the JSONL file to read (UTF-8); every record must
            contain a ``"negs"`` key, otherwise ``KeyError`` is raised.
        neg_file: Path of the JSONL file to write (UTF-8, overwritten).
    """
    logger.info(f"Converting negs in {negs_file} to neg in {neg_file}...")

    with open(negs_file, 'r', encoding='utf-8') as src:
        records = list(map(json.loads, src))

    with open(neg_file, 'w', encoding='utf-8') as dst:
        for record in records:
            # Move the value under the new key and drop the old one.
            record["neg"] = record.pop("negs")
            serialized = json.dumps(record, ensure_ascii=False)
            dst.write(f"{serialized}\n")

    logger.info("Conversion complete.")

def cut_max_samples(file, max_samples=30):
    """
    Truncate the ``"neg"`` and ``"pos"`` entries of every record in a JSONL file.

    The file is rewritten in place. All records are transformed and
    serialised *before* the file is reopened for writing, so a malformed
    record (missing ``"neg"``/``"pos"``, or not serialisable) raises while
    the original file is still intact instead of leaving it truncated.

    Args:
        file: Path of the JSONL file to read and overwrite (UTF-8).
        max_samples (int): Maximum number of items kept in each of
            ``record["neg"]`` and ``record["pos"]`` (applied as a slice).

    Raises:
        KeyError: If a record lacks ``"neg"`` or ``"pos"``.
        json.JSONDecodeError: If a line is not valid JSON.
    """
    logger.info(f"Cutting max samples to {max_samples} for {file}...")
    with open(file, 'r', encoding='utf-8') as f:
        data = [json.loads(line) for line in f]

    # Build every output line first: opening the file with 'w' truncates it
    # immediately, so nothing that can fail may happen after that point.
    output_lines = []
    for item in data:
        item["neg"] = item["neg"][:max_samples]
        item["pos"] = item["pos"][:max_samples]
        output_lines.append(json.dumps(item, ensure_ascii=False) + '\n')

    with open(file, 'w', encoding='utf-8') as f:
        for line in output_lines:
            f.write(line)
    logger.info("Cutting complete.")


def wait_for_gpu(gpu_id=None, memory_threshold=20000, check_interval=300, consecutive_counts=3):
    """
    Block until a GPU reports enough free memory, then select it.

    Free memory is read by running ``nvidia-smi``. A GPU counts as available
    once ``consecutive_counts`` successive readings are at or above
    ``memory_threshold``. A failed or unparsable ``nvidia-smi`` call is
    treated as a reading of 0 (0 MB free, or 0 GPUs), so the wait continues.

    Args:
        gpu_id: ID of the GPU to monitor. If None, GPU indices
            ``0 .. count-1`` are probed in turn and the first one that
            qualifies is used.
        memory_threshold (int): Free-memory threshold in MB; a reading below
            it means the GPU is considered busy.
        check_interval (int): Seconds to sleep between readings (and between
            rounds over all GPUs).
        consecutive_counts (int): Number of consecutive sufficient readings
            required before the GPU is considered available.

    Returns:
        When ``gpu_id`` is given: ``gpu_id`` exactly as passed in (this call
        does not return until that GPU qualifies). When ``gpu_id`` is None:
        the selected GPU index as a ``str``, or None if the GPU count query
        yields 0.

    Side effects:
        Sets ``os.environ["CUDA_VISIBLE_DEVICES"]`` to the selected GPU.
    """
    import subprocess
    import time

    def query_nvidia_smi(field, *extra_args):
        """Return the first line of an nvidia-smi query as an int, or 0 on failure."""
        # An argument list (no shell) keeps gpu_id from being interpreted by
        # a shell.
        cmd = ["nvidia-smi", f"--query-gpu={field}", "--format=csv,noheader,nounits", *extra_args]
        try:
            output = subprocess.check_output(cmd).decode('utf-8').strip()
            return int(output.split("\n")[0])
        # Deliberately narrow: a bare except here would also swallow
        # KeyboardInterrupt/SystemExit and make the wait hard to cancel.
        except (subprocess.CalledProcessError, OSError, ValueError):
            return 0

    def get_gpu_count():
        """Return the number of GPUs reported by nvidia-smi (0 on failure)."""
        return query_nvidia_smi("count")

    def check_gpu_memory(gpu_id):
        """Return the free memory in MB of the given GPU (0 on failure)."""
        return query_nvidia_smi("memory.free", "-i", str(gpu_id))

    # Monitor one specific GPU.
    if gpu_id is not None:
        logger.info(f"开始监测GPU:{gpu_id}，等待显存大于{memory_threshold}MB")

        count = 0
        while count < consecutive_counts:
            try:
                free_memory = check_gpu_memory(gpu_id)

                if free_memory >= memory_threshold:
                    count += 1
                    logger.info(f"GPU:{gpu_id} 空闲显存: {free_memory}MB, 连续检测次数: {count}/{consecutive_counts}")
                else:
                    # Any insufficient reading restarts the streak.
                    count = 0
                    logger.info(f"GPU:{gpu_id} 空闲显存不足: {free_memory}MB < {memory_threshold}MB, 等待中...")

                # No need to sleep after the final successful reading.
                if count < consecutive_counts:
                    time.sleep(check_interval)
            except Exception as e:
                logger.error(f"监测GPU时出错: {e}")
                count = 0
                time.sleep(check_interval)

        # Expose only the selected GPU through the environment variable.
        os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
        logger.info(f"GPU:{gpu_id} 已空闲，显存大于{memory_threshold}MB，已设置CUDA_VISIBLE_DEVICES={gpu_id}")
        return gpu_id

    # Monitor every GPU; the first one that qualifies wins.
    else:
        gpu_count = get_gpu_count()
        if gpu_count == 0:
            logger.error("未检测到可用的GPU")
            return None

        logger.info(f"开始监测所有GPU，等待任意显卡显存大于{memory_threshold}MB")

        while True:
            for current_gpu_id in range(gpu_count):
                count = 0
                consecutive_fail = 0

                # Keep sampling the current GPU until it qualifies, shows
                # insufficient memory, or errors three times.
                while count < consecutive_counts and consecutive_fail < 3:
                    try:
                        free_memory = check_gpu_memory(current_gpu_id)

                        if free_memory >= memory_threshold:
                            count += 1
                            logger.info(f"GPU:{current_gpu_id} 空闲显存: {free_memory}MB, 连续检测次数: {count}/{consecutive_counts}")
                        else:
                            count = 0
                            consecutive_fail += 1
                            logger.info(f"GPU:{current_gpu_id} 空闲显存不足: {free_memory}MB < {memory_threshold}MB")
                            # Busy GPU: move on to the next one right away.
                            break

                        if count < consecutive_counts:
                            time.sleep(check_interval)
                    except Exception as e:
                        logger.error(f"监测GPU:{current_gpu_id}时出错: {e}")
                        count = 0
                        consecutive_fail += 1
                        time.sleep(check_interval)

                # This GPU met the requirement.
                if count >= consecutive_counts:
                    # Expose only the selected GPU through the environment variable.
                    os.environ["CUDA_VISIBLE_DEVICES"] = str(current_gpu_id)
                    logger.info(f"GPU:{current_gpu_id} 已空闲，显存大于{memory_threshold}MB，已设置CUDA_VISIBLE_DEVICES={current_gpu_id}")
                    return str(current_gpu_id)

            logger.info("所有GPU均不满足条件，等待下一轮检测...")
            time.sleep(check_interval)


def pause_to_confirm(message: str = "", pause_time: int = 10):
    """
    Pause for up to ``pause_time`` seconds; pressing Enter ends the pause early.

    A prompt is printed, then a daemon thread watches for input while the
    caller waits on that thread for at most ``pause_time`` seconds. On
    Windows the thread polls the keyboard via ``msvcrt``; elsewhere it waits
    for ``sys.stdin`` to become readable via ``select``. This function itself
    never reads from ``sys.stdin``.

    Args:
        message (str): Prompt text; ``"Pause to confirm"`` is used when empty.
        pause_time (int): Maximum pause duration in seconds.
    """
    import sys
    import time
    import threading
    from select import select

    message = message or "Pause to confirm"
    print(f"{message} will continue in {pause_time} seconds, press [Enter] to continue...")

    # Runs in the watcher thread: returns once Enter is detected or time is up.
    def wait_for_enter():
        # Windows and Unix need different mechanisms.
        if sys.platform == 'win32':
            import msvcrt
            # Stop polling at the deadline. Without it this daemon thread
            # would outlive the pause and keep consuming key presses.
            deadline = time.monotonic() + pause_time
            while time.monotonic() < deadline:
                if msvcrt.kbhit() and msvcrt.getch() in [b'\r', b'\n']:
                    return True
                time.sleep(0.1)
            return False
        else:
            # Unix: wait until stdin is readable or the timeout expires.
            rlist, _, _ = select([sys.stdin], [], [], pause_time)
            return bool(rlist)

    # Watch for Enter in a daemon thread so it can never block interpreter exit.
    enter_thread = threading.Thread(target=wait_for_enter)
    enter_thread.daemon = True
    enter_thread.start()

    # Resume as soon as the watcher finishes, or after pause_time at the latest.
    enter_thread.join(timeout=pause_time)


class Stage(Enum):
    """Stage markers, each identified by a lowercase string value.

    Presumably these name checkpoints of a processing pipeline - confirm
    against the callers, which are not visible in this file.
    """
    AFTER_SUBGRAPH = "after_subgraph"
    # NOTE(review): "REANK"/"reank" looks like a misspelling of "RERANK";
    # left unchanged because both the member name and its value are public.
    AFTER_REANK = "after_reank"


def read_jsonl(file_path):
    """
    Lazily yield one parsed JSON value per non-blank line of a JSONL file.

    This is a generator: the file is opened when iteration starts and stays
    open until the generator is exhausted or closed. Blank or whitespace-only
    lines (for example a trailing empty line) are skipped rather than being
    passed to ``json.loads``.

    Args:
        file_path: Path of the JSONL file to read (UTF-8).

    Yields:
        The value decoded from each non-blank line.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            if not line.strip():
                continue
            yield json.loads(line)


# Create the Typer application on which the commands below are registered.
app = typer.Typer(help="KGQA 工具集")


# CLI entry point for `wait-for-gpu`: forwards its options to wait_for_gpu()
# and reports the outcome on the console and in the log.
# The docstring and help= strings are left untranslated on purpose: Typer
# shows them as the command's help text, so they are runtime output.
@app.command("wait-for-gpu")
def wait_for_gpu_cli(
    gpu_id: int = typer.Option(None, "--gpu-id", "-g", help="要监测的GPU ID，如果不指定则监测所有GPU"),
    memory_threshold: int = typer.Option(20000, "--memory-threshold", "-m", help="空闲显存阈值（MB）"),
    check_interval: int = typer.Option(300, "--check-interval", "-i", help="检查间隔时间（秒）"),
    consecutive_counts: int = typer.Option(3, "--consecutive-counts", "-c", help="连续检测到空闲的次数")
):
    """
    等待 GPU 空闲的命令行工具
    """
    logger.info(f"开始等待GPU空闲，参数: gpu_id={gpu_id}, memory_threshold={memory_threshold}MB, check_interval={check_interval}s, consecutive_counts={consecutive_counts}")

    used_gpu = wait_for_gpu(
        gpu_id=gpu_id,
        memory_threshold=memory_threshold,
        check_interval=check_interval,
        consecutive_counts=consecutive_counts
    )

    # Compare against None rather than relying on truthiness: with
    # `--gpu-id 0` wait_for_gpu returns the int 0, which is a valid GPU but
    # falsy, and would otherwise be reported as a failure.
    if used_gpu is not None:
        logger.info(f"成功获取GPU: {used_gpu}")
        typer.echo(f"GPU {used_gpu} 已可用")
    else:
        logger.error("未能获取到可用的GPU")
        typer.echo("未能获取到可用的GPU", err=True)
        # Non-zero exit code so calling scripts can detect the failure.
        raise typer.Exit(code=1)


# CLI entry point for `pause-confirm`: forwards --message/-m and
# --pause-time/-t unchanged to pause_to_confirm().
# NOTE(review): the docstring and help= strings are kept as-is because Typer
# presumably displays them as the command's help text - confirm before editing.
@app.command("pause-confirm")
def pause_confirm(
    message: str = typer.Option("", "--message", "-m", help="提示信息"),
    pause_time: int = typer.Option(10, "--pause-time", "-t", help="暂停时间（秒）")
):
    """
    暂停程序并等待用户确认
    """
    # Delegate to the library function; nothing is returned to the CLI.
    pause_to_confirm(message=message, pause_time=pause_time)



# Run the Typer CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    app()


