import json
import os
import random
import time
import traceback
from functools import partial
from pathlib import Path

import networkx as nx
from datasets import load_dataset
from dotenv import load_dotenv
from loguru import logger
from openai import OpenAI
from rich.progress import track

from src.graph_utils import summary, path_to_str


# Populate os.environ (via python-dotenv, typically from a local .env file) before the
# os.getenv(...) lookups below read API keys and endpoints.
load_dotenv()


def call_openai(prompt):
    """Send ``prompt`` as one user message to gpt-4o and return the reply text.

    The API key and endpoint come from the OPENAI_API_KEY and OPENAI_API_BASE
    environment variables; sampling temperature is fixed at 0.1.
    """
    api_key = os.getenv("OPENAI_API_KEY")
    api_base = os.getenv("OPENAI_API_BASE")
    client = OpenAI(api_key=api_key, base_url=api_base)
    logger.info("Calling gpt-4o with openai...")
    conversation = [{"role": "user", "content": prompt}]
    completion = client.chat.completions.create(
        model="gpt-4o",
        messages=conversation,
        temperature=0.1
    )
    first_choice = completion.choices[0]
    return first_choice.message.content

def resolve_vllm_model(model):
    """Map a short/CLI model name to the model name served by the vLLM endpoint.

    Raises:
        KeyError: if ``model`` has no known vLLM mapping.
    """
    model_map = {
        "qwen3-32b": "qwen3:32b",
        # process_dataset() and the CLI spell this model "llama-3-1-8b"; the older
        # "llama3-1-8b" spelling is kept so existing callers keep working.
        "llama-3-1-8b": "llama3.1:8b",
        "llama3-1-8b": "llama3.1:8b",
    }
    try:
        return model_map[model]
    except KeyError:
        raise KeyError(f"Unsupported vLLM model: {model!r} (known: {sorted(model_map)})") from None


def call_vllm(prompt, model):
    """Send ``prompt`` to a locally served model through vLLM's OpenAI-compatible API.

    Args:
        prompt: text sent as a single user message.
        model: short model name, translated by ``resolve_vllm_model``.

    Returns:
        The text content of the first choice.

    Raises:
        KeyError: if ``model`` is not a known vLLM model name.
    """
    served_model = resolve_vllm_model(model)

    # The endpoint comes from VLLM_API_BASE; the API key value is a placeholder.
    client = OpenAI(api_key="empty", base_url=os.getenv("VLLM_API_BASE"))
    logger.info(f"Calling {served_model} with vllm...")
    response = client.chat.completions.create(
        model=served_model,
        messages=[{"role": "user", "content": prompt}],
        temperature=0.1,
        # Passed through to the chat template to switch off "thinking" output.
        extra_body={"chat_template_kwargs": {"enable_thinking": False}},
    )
    return response.choices[0].message.content

def vllm_health_check(timeout: float = 10.0):
    """Return True if the vLLM server behind VLLM_API_BASE answers its health URL with HTTP 200.

    The health URL is derived from VLLM_API_BASE by replacing its last "v1" with
    "health" (e.g. ``http://host:8000/v1`` -> ``http://host:8000/health``).

    Args:
        timeout: seconds to wait for the health endpoint before giving up.

    Returns:
        False (after logging an error) in any of these cases: VLLM_API_BASE is unset,
        the request fails or times out, or the status code is not 200. True otherwise.
    """
    import requests

    api_base = os.getenv("VLLM_API_BASE")
    if not api_base:
        logger.error("VLLM health check failed: VLLM_API_BASE is not set")
        return False

    # Swap only the last "v1" (presumably the trailing API-version segment);
    # str.replace would also rewrite any "v1" earlier in the host name or path.
    health_uri = "health".join(api_base.rsplit("v1", 1))
    try:
        # Without a timeout an unreachable server could block forever.
        response = requests.get(health_uri, timeout=timeout)
    except requests.RequestException as err:
        logger.error(f"VLLM health check failed ({health_uri}): {err}")
        return False

    if response.status_code != 200:
        logger.error(f"VLLM health check failed ({health_uri}): {response.status_code}")
        return False
    logger.info(f"VLLM health check passed: {response}")
    return True

def call_siliconflow(prompt, model):
    """Query ``model`` on SiliconFlow's OpenAI-compatible endpoint and return the reply text.

    Reads SILICONFLOW_API_KEY from the environment. Temperature is fixed at 0.1 and
    ``enable_thinking`` is switched off through ``chat_template_kwargs``.
    """
    api_key = os.getenv("SILICONFLOW_API_KEY")
    client = OpenAI(api_key=api_key, base_url="https://api.siliconflow.cn/v1")
    logger.info(f"Calling {model} with siliconflow...")
    request_kwargs = dict(
        model=model,
        messages=[{"role": "user", "content": prompt}],
        temperature=0.1,
        extra_body={"chat_template_kwargs": {"enable_thinking": False}},
    )
    completion = client.chat.completions.create(**request_kwargs)
    return completion.choices[0].message.content

def get_all_paths(q_entity, a_entity, G: nx.DiGraph, max_paths_per_answer: int = 100):
    """Collect shortest relation paths from the question entities to each answer entity.

    For every (question entity, answer entity) pair that is present in ``G`` and
    connected, all shortest paths are enumerated and turned into a list of
    ``(head, relation, tail)`` tuples, where ``relation`` is the edge attribute
    ``G[head][tail]["relation"]``. Paths are grouped per answer entity so each
    reachable answer gets its own quota of paths.

    Args:
        q_entity: question entities used as path sources.
        a_entity: answer entities used as path targets.
        G: directed graph whose edges carry a ``relation`` string containing a
            direction marker (``-->`` or ``<--``).
        max_paths_per_answer: maximum number of paths kept per answer entity. When
            an answer exceeds it, its paths are sorted by length (stable) and the
            shortest ones are kept.

    Returns:
        A flat list of paths, grouped by answer in first-discovery order. Each path
        is a list of ``(head, relation, tail)`` tuples.

    Raises:
        ValueError: if an edge relation contains neither ``-->`` nor ``<--``.
    """
    from collections import defaultdict

    # Drop duplicate entities (keeping first-seen order) so the same path is not collected twice.
    sources = list(dict.fromkeys(q_entity))
    targets = list(dict.fromkeys(a_entity))

    # Report each missing entity once, instead of once per (question, answer) pair.
    for e in sources:
        if e not in G:
            logger.warning(f"Q_Entity {e} not found in graph")
    for a in targets:
        if a not in G:
            logger.warning(f"Answer {a} not found in graph")
    sources = [e for e in sources if e in G]
    targets = [a for a in targets if a in G]

    answer_triples = defaultdict(list)
    for e in sources:
        for a in targets:
            if not nx.has_path(G, e, a):
                logger.warning(f"No path from {e} to {a}")
                continue

            for path in nx.all_shortest_paths(G, e, a):
                path_triples = []
                for head, tail in zip(path, path[1:]):
                    relation = G[head][tail]["relation"]
                    # Downstream conversion relies on this marker, so fail loudly
                    # (an assert would be stripped under `python -O`).
                    if "-->" not in relation and "<--" not in relation:
                        raise ValueError(
                            f"Relation on edge {head!r} -> {tail!r} has no direction marker: {relation!r}"
                        )
                    path_triples.append((head, relation, tail))
                answer_triples[a].append(path_triples)

    # Cap each answer's paths separately so no single answer crowds out the others.
    filtered_triples = []
    for triples in answer_triples.values():
        if len(triples) > max_paths_per_answer:
            triples.sort(key=len)  # prefer shorter paths
            triples = triples[:max_paths_per_answer]
        filtered_triples.extend(triples)

    return filtered_triples

def parse_path_ids(response: str) -> list[int]:
    """Extract the path ids a model selected from its free-text reply.

    The strategies below are tried in order. The first pattern that matches decides
    the result, even if it contains no ids; for example ``Answer: None`` returns ``[]``.

    1. An ``Answer:`` line, the format the prompt requests. The following variations
       are tolerated:
       - markdown emphasis (``**Answer:**``);
       - an optional ``Path``/``Paths``/``path ids`` word;
       - an opening ``[`` or ``(``;
       - ids separated by commas and/or whitespace, e.g. ``Answer: [1, 2]`` or
         ``Answer: 1 2``.
    2. A phrase such as ``path ids: 1, 2`` or ``I recommend paths 1,2,3``
       (case-insensitive).
    3. Every standalone integer in the reply, ignoring values >= 1000 (assumed not
       to be path ids).

    Args:
        response: raw text returned by the model.

    Returns:
        The extracted ids, in the order they appear in the reply.
    """
    import re

    # Strategy 1: the requested "Answer: <ids>" format and its common variants.
    answer_pattern = (
        r"Answer[\s*]*:[\s*]*"                      # Answer:, **Answer:**, **Answer**:
        r"(?:(?i:paths?)(?:\s*(?i:ids?))?\s*:?\s*)?"  # optional "Path(s)" / "path ids:"
        r"[\[(]?\s*"                                # optional opening bracket
        r"([\d,\s]+)"                               # the ids themselves
    )
    match = re.search(answer_pattern, response)
    if match:
        # findall handles comma-, space- and newline-separated ids alike.
        return [int(num) for num in re.findall(r"\d+", match.group(1))]

    # Strategy 2: e.g. "path ids: 1, 2, 3" or "I recommend paths 1,2,3".
    alt_pattern = r"path(?:s|[\s\:]+)(?:ids[\s\:]+)?(?:are[\s\:]+)?([\d,\s]+)"
    alt_match = re.search(alt_pattern, response, re.IGNORECASE)
    if alt_match:
        return [int(num) for num in re.findall(r"\d+", alt_match.group(1))]

    # Strategy 3: any standalone number; large values are unlikely to be path ids.
    candidates = [int(num) for num in re.findall(r"\b(\d+)\b", response)]
    return [path_id for path_id in candidates if path_id < 1000]


# Prompt asking the model which of a numbered list of candidate paths can answer the
# question. Placeholders {question}, {answer_entity} and {path_str} are filled in
# get_valid_paths(); the "Answer:" line is the first format parse_path_ids() looks for.
# NOTE(review): "Please response" should read "Please respond", and "<path id>" suggests a
# single id although several comma-separated ids are parsed; left unchanged so prompts
# stay identical to those used for previously generated outputs.
prompt_template = """
Seriously analyze these paths, which of these paths can be used to answer this question:
<Question>{question}</Question>
<Answer_Entity>{answer_entity}</Answer_Entity>

NOTE: Please response with the following format:
Thought: <your thought>
Answer: <path id>

The paths are:
{path_str}

Response:
"""

def get_valid_paths(ins: dict, call_model, batch_size: int = 300, max_paths: int = 300) -> list[list[str]]:
    """Ask an LLM which candidate question-to-answer paths actually support the answer.

    Candidates come from ``get_all_paths`` on the graph built by
    ``summary(ins, with_graph=True)``. They are shown to ``call_model`` in batches of
    ``batch_size``, and the ids parsed from each reply select the valid paths. Once a
    path is accepted, its relation sequence ("pattern") is remembered. Any later
    candidate with the same pattern is accepted without asking the model again.

    Args:
        ins: dataset record. After ``summary`` it must provide ``q_entity``,
            ``a_entity``, ``graph``, ``question``, ``id`` and ``truth_paths``.
        call_model: callable that takes a prompt string and returns the reply text.
        batch_size: number of candidate paths shown to the model per request.
        max_paths: no further batches are requested once at least this many valid
            paths have been collected. The last batch may push the total past it.

    Returns:
        The selected paths. ``ins["truth_paths"]`` is returned instead when there is
        no candidate path or the model selects none.
    """
    ins = summary(ins, with_graph=True)
    truth_paths = get_all_paths(ins["q_entity"], ins["a_entity"], ins["graph"])
    if not truth_paths:
        logger.warning(f"Warning: No valid paths found for {ins['id']}, use SHORTEST.")
        return ins["truth_paths"]

    def relation_pattern(path):
        # The space-joined relation labels along the path (element 1 of each step).
        return " ".join(step[1] for step in path)

    valid_paths = []
    valid_patterns = set()  # relation patterns already confirmed by the model

    for start in range(0, len(truth_paths), batch_size):
        batch_paths = truth_paths[start:start + batch_size]

        # Accept paths with an already-validated pattern directly; only unseen
        # patterns are sent to the model.
        new_batch_paths = []
        for path in batch_paths:
            if relation_pattern(path) in valid_patterns:
                valid_paths.append(path)
            else:
                new_batch_paths.append(path)

        if not new_batch_paths:
            continue

        path_lines = path_to_str(ins["q_entity"][0], new_batch_paths, with_deco=False)
        path_str = "\n".join(f"{j}. {line}" for j, line in enumerate(path_lines))
        prompt = prompt_template.format(
            question=ins["question"],
            answer_entity=ins["a_entity"],
            path_str=path_str,
        )
        response = call_model(prompt)

        accepted_ids = set()
        for path_id in parse_path_ids(response):
            # Skip out-of-range ids and repeats, so a reply like "Answer: 1, 1"
            # does not add the same path twice.
            if path_id >= len(new_batch_paths) or path_id in accepted_ids:
                continue
            accepted_ids.add(path_id)
            path = new_batch_paths[path_id]
            valid_paths.append(path)
            valid_patterns.add(relation_pattern(path))

        # Soft upper bound: stop requesting further batches.
        if len(valid_paths) >= max_paths:
            break

    if not valid_paths:
        logger.warning(f"Warning: No valid paths found for {ins['id']}, use SHORTEST.")
        return ins["truth_paths"]

    return valid_paths

def convert_paths_to_triples(
    paths: list[list[tuple[str, str, str]]],
) -> list[list[list[str]]]:
    """Convert directed path steps into canonical ``[subject, relation, object]`` triples.

    Each step is a ``(head, relation, tail)`` tuple whose relation string carries a
    direction marker:

    - ``<--`` (checked first) means the edge points from ``tail`` to ``head``, so the
      triple is flipped to ``[tail, rel, head]``.
    - ``-->`` keeps the order ``[head, rel, tail]``.

    In both cases the relation name is ``relation[4:-4]``, i.e. four characters are
    stripped from each end.
    NOTE(review): this assumes fixed-width decorations such as ``"<-- rel ---"`` /
    ``"--- rel -->"``; confirm against the path formatting in src.graph_utils.

    Args:
        paths: list of paths, each a list of ``(head, relation, tail)`` tuples.

    Returns:
        One list per input path, each containing ``[subject, relation, object]`` lists.

    Raises:
        ValueError: if a relation contains neither ``<--`` nor ``-->``.
    """
    def _canonical(head, relation, tail):
        name = relation[4:-4]
        if "<--" in relation:
            return [tail, name, head]
        if "-->" in relation:
            return [head, name, tail]
        raise ValueError(f"无效的关系格式: {relation}")

    return [[_canonical(h, r, t) for h, r, t in path] for path in paths]

def _load_existing_results(output_file: Path) -> dict:
    """Load earlier results keyed by ``id``, keeping only records with non-empty ``valid_paths``.

    Malformed JSON lines are skipped with a warning. When an ``id`` appears more than
    once, the last occurrence wins.
    """
    existing_data = {}
    if not output_file.exists():
        return existing_data

    with open(output_file, "r", encoding="utf-8") as f:
        lines = f.readlines()

    for line in track(lines, description="Loading existing data"):
        try:
            data = json.loads(line.strip())
        except json.JSONDecodeError:
            logger.warning(f"Warning: Skip invalid JSON line: {line[:50]}...")
            continue
        if data.get("valid_paths"):
            existing_data[data["id"]] = data
    return existing_data


def _select_model_caller(model: str):
    """Return a ``prompt -> reply`` callable for ``model``.

    Local models go through vLLM when its health check passes. If it fails,
    ``qwen3-32b`` falls back to SiliconFlow.

    Raises:
        ValueError: if the model name is unknown, or vLLM is unavailable for a model
            that has no fallback.
    """
    if model in ("gpt-4o", "gpt4o"):
        return call_openai
    if model in ("llama-3-1-8b", "qwen3-32b"):
        if vllm_health_check():
            return partial(call_vllm, model=model)
        if model == "qwen3-32b":
            return partial(call_siliconflow, model="Qwen/Qwen3-32B")
        raise ValueError(f"vLLM server unavailable for model: {model}")
    raise ValueError(f"Invalid model: {model}")


def _rewrite_results(output_file: Path, records) -> set:
    """Rewrite ``output_file`` with ``records`` (one JSON object per line) and return their ids.

    The data is written to a sibling ``.tmp`` file first and then moved over the
    original. An interruption mid-write therefore leaves the previous file intact
    instead of a truncated one.
    """
    tmp_file = output_file.with_name(output_file.name + ".tmp")
    processed_ids = set()
    with open(tmp_file, "w", encoding="utf-8") as f:
        for data in track(records, description="Writing processed data"):
            processed_ids.add(data["id"])
            f.write(json.dumps(data, ensure_ascii=False) + "\n")
    tmp_file.replace(output_file)
    return processed_ids


def process_dataset(dataset_name: str,
                   data_path: str,
                   split: str = "train",
                   output_dir: str = "./",
                   max_paths: int = 300,
                   model: str = "llama-3-1-8b") -> None:
    """Generate LLM-validated paths for one dataset split and append them to a JSONL file.

    Results go to ``{output_dir}/{split}.valid.{model}.jsonl``. The run is resumable:

    - existing records with non-empty ``valid_paths`` are kept and skipped;
    - failed or empty records are dropped, so they are retried;
    - every new record is flushed as soon as it is written.

    Args:
        dataset_name: dataset name (used for logging).
        data_path: dataset root directory, passed to ``datasets.load_dataset``.
        split: dataset split (train/test/validation).
        output_dir: output directory (created if missing).
        max_paths: maximum number of valid paths stored per record. A value <= 0
            disables the cap.
        model: model name: "llama-3-1-8b", "qwen3-32b", "gpt-4o" or "gpt4o".

    Raises:
        ValueError: if no model backend can be selected for ``model``.
    """
    output_file = Path(output_dir) / f"{split}.valid.{model}.jsonl"
    dataset_path = Path(data_path)
    logger.info(f"Output file: {output_file}")

    from src.utils import pause_to_confirm
    pause_to_confirm()

    output_file.parent.mkdir(parents=True, exist_ok=True)

    existing_data = _load_existing_results(output_file)
    call_model = _select_model_caller(model)

    # Compact the results file to the successful records; their ids are skipped below.
    processed_ids = _rewrite_results(output_file, existing_data.values())

    ds = load_dataset(dataset_path.as_posix(), split=split)
    total = len(ds)
    logger.info(f"Processing {dataset_name}-{split}... Total: {total}, Already processed: {len(processed_ids)}")

    # Append mode so an interrupted run can resume where it stopped.
    with open(output_file, "a", encoding="utf-8") as f:
        for idx, ins in enumerate(track(ds, description="Processing dataset")):
            if ins["id"] in processed_ids:
                logger.info(f"Skip {ins['id']}, already processed.")
                continue
            logger.info(f"\nProcessing {ins['id']} ({idx+1}/{total}) Question: {ins['question']}, answer: {ins['a_entity']}")

            try:
                valid_paths = get_valid_paths(ins, call_model)
                if max_paths > 0:
                    valid_paths = valid_paths[:max_paths]
                triples = convert_paths_to_triples(valid_paths)
                # Guard against records with an empty answer list.
                num_answers = max(len(ins["a_entity"]), 1)
                logger.info(f"#Valid paths: {len(triples)}, #Avg.Per Answer: {len(triples)/num_answers:.2f};")
                logger.info(f"Paths: {random.sample(valid_paths, min(3, len(triples)))}")

                result = ins.copy()
                result["valid_paths"] = triples
                f.write(json.dumps(result, ensure_ascii=False) + "\n")
                f.flush()  # persist immediately so an interruption loses at most one record

            except Exception as e:
                # Failed records are not written, so the next run retries them.
                logger.error(f"Error processing {ins['id']}: {e}, {traceback.format_exc()}")

def check_valid_paths(valid_path):
    """Report records in a JSONL results file whose ``valid_paths`` field is missing or empty.

    Each offending record is logged as a warning, and a summary line is logged at
    the end. Blank lines are ignored.

    Args:
        valid_path: path to a JSONL file with one JSON object (with an ``id``) per line.

    Returns:
        The number of records without usable ``valid_paths``.
    """
    with open(valid_path, "r", encoding="utf-8") as f:
        lines = f.readlines()
    logger.info(f"Checking {len(lines)} lines")

    invalid = 0
    for line in track(lines, description="Checking valid paths"):
        if not line.strip():
            continue
        ins = json.loads(line)
        # Covers a missing key, None and an empty list alike.
        if not ins.get("valid_paths"):
            logger.warning(f"Warning: {ins['id']} has no valid paths")
            invalid += 1
    logger.info(f"Total: {len(lines)}, Invalid: {invalid}")
    return invalid

def main():
    """Parse command-line options and run ``process_dataset`` with them."""
    import argparse

    parser = argparse.ArgumentParser(description="处理数据集生成有效路径")
    add_arg = parser.add_argument
    add_arg("--dataset", type=str, required=True, help="数据集名称")
    add_arg("--data_path", type=str, required=True, help="数据集根目录")
    add_arg("--split", type=str, default="train", help="数据集分割")
    add_arg("--output_dir", type=str, default="./", help="输出目录")
    add_arg("--max_paths", type=int, default=300, help="最大路径数量")
    add_arg("--model", type=str, default="llama-3-1-8b", help="模型",
            choices=["llama-3-1-8b", "gpt-4o", "qwen3-32b"])

    opts = parser.parse_args()
    process_dataset(
        dataset_name=opts.dataset,
        data_path=opts.data_path,
        split=opts.split,
        output_dir=opts.output_dir,
        max_paths=opts.max_paths,
        model=opts.model,
    )

if __name__ == "__main__":
    main()

    # Manual debugging helpers (disabled): uncomment to probe the vLLM endpoint or to
    # count records without valid paths in an existing results file.
    # logger.debug(f"VLLM health check {os.getenv('VLLM_API_BASE')}...")
    # vllm_health_check()
    # check_valid_paths("data/cwq/train.valid.llama3-1-8b.jsonl")