import os
import json
import time
import traceback
from typing import List
import pickle

from arguments import parse_args
from src.utils import wait_for_gpu

# Command-line options are parsed at import time because the GPU wait below
# has to run before the heavier project imports that follow it.
args = parse_args()

# Unless args.nowait is set, hand control to wait_for_gpu first.
# NOTE(review): presumably this blocks until the selected GPU has enough free
# memory, which is why the remaining imports are deferred (noqa: E402) - confirm.
if not args.nowait:
    wait_for_gpu(
        gpu_id=args.gpu,
        check_interval=args.gpu_wait_interval,
        consecutive_counts=args.gpu_wait_counts,
        memory_threshold=args.gpu_wait_threshold,
    )
# os.environ["CUDA_VISIBLE_DEVICES"] = "0"

from src.agent import Agent, embed_model # noqa: E402

import datasets # noqa: E402
import networkx as nx # noqa: E402
from rich.progress import track  # noqa: E402
from src.prompts import build_prompt # noqa: E402
from src.models import select_llm_model, select_rerank_model, log_token_usage # noqa: E402
from src.utils import logger, json_parser, check_response, Stage # noqa: E402
from src.graph_utils import (  # noqa: E402
    summary,
    build_vec_graph,
    fill_subgraph,
    build_subgraph,
    get_relevants,
    path_to_str,
    merge_similar_paths,
) # noqa: E402
from src.scripts.get_valid_paths import get_all_paths # noqa: E402
from src.eval import eval_from_config, ParseError # noqa: E402
from src.config import Config # noqa: E402
from src.memory import Memory # noqa: E402



# Default run configuration. It is expanded into ``Config(**...)`` by ``main``
# and is used when no ``args.config`` is supplied on the command line.
config_850 = {
    "embed_model": "bge-m3",  # also names the graph cache directory: cache/<embed_model>
    "llm_model": "gpt-4o-mini",  # passed to select_llm_model and to Agent(model_name=...)
    "rerank_model": "ckpt/bge-reranker-v2-m3-0314-results",  # passed to select_rerank_model
    "reasoning_model": "deepseek-ai/DeepSeek-R1",  # also loaded via select_llm_model
    "dataset_name": "webqsp",
    # Dataset location handed to datasets.load_dataset; prefixed with
    # $DATASETS_DIR when that environment variable is set.
    "dataset": os.path.join(os.getenv("DATASETS_DIR", ""), "rmanluo/RoG-webqsp"),
    "use_data": "graphs",  # options: graphs, paths, pre_retrieve (main also handles "no_hyrp")
    "log_interval": 100,  # main runs eval_from_config when idx % log_interval == 0
    "node_count_threshold": 0.3,  # forwarded to get_relevants as node_count
    "edge_count_threshold": 0.3,  # forwarded to get_relevants as edge_count
    "max_rerank_nodes": 400,  # combined with rerank_rate via max(...) when slicing ranked paths
    "rerank_threshold": 0.01,  # ranked paths must score strictly above this to be kept
    "rerank_rate": 0.5,  # fraction of the ranked paths considered (see max_rerank_nodes)
    "top_threshold": 16000,  # token budget (model.count_tokens) for the kept paths
    "resume": False,  # overwritten with args.resume in the __main__ block
    "init_prompt_type": "initial_with_samples",
    "sample_topk": 3,  # number of memory samples retrieved for the prompts
    "sample_topk_agent": 3,  # number of memory samples retrieved for the agent
    "final_prompt_type": "final",
    "use_merge_paths": True,  # merge similar reranked paths via merge_similar_paths
    "use_agent": True,  # produce the final answer with Agent instead of a plain prompt
    "run_name": "850",
    "split": "test",  # dataset split passed to datasets.load_dataset
    "memory_file": "data/memory/cwq.train.valid.llama3-1-8b+webqsp.RoG-webqsp.train.valid.txt",  # passed to Memory
}

def main(run_name: dict):
    """Run the retrieve -> rerank -> answer pipeline over one dataset split.

    For every instance the loop:

    1. builds the graph with ``build_vec_graph`` (or loads it from the pickle
       cache under ``cache/<embed_model>``),
    2. obtains candidate answers and reasoning paths according to
       ``conf.use_data``,
    3. retrieves a subgraph with ``get_relevants`` / ``build_subgraph``,
    4. scores every path from the question entities with the rerank model,
    5. asks the agent, the reasoning model or the plain LLM for the final
       answer.

    Each processed instance summary is appended to ``conf.results_path`` and
    ``eval_from_config`` is run periodically and once more at the end.

    Args:
        run_name: Mapping of configuration options, expanded into ``Config``
            as keyword arguments. The parameter name is historical; the value
            is a dict such as ``config_850``, not a string.
    """

    conf = Config(**run_name)
    model = select_llm_model(conf.llm_model)
    reasoning_model = select_llm_model(conf.reasoning_model)
    rerank_model = select_rerank_model(conf.rerank_model)

    # > preprocessing config
    if args.debug:
        logger.warning(">>>> Debug mode is enabled")

    # Load the dataset split and the sample memory.
    data = datasets.load_dataset(conf.dataset, split=conf.split)
    memory = Memory(embed_model=embed_model, dataset_name=conf.dataset_name, memory_file=conf.memory_file)

    if conf.dataset_cutoff:
        data = data.select(range(conf.dataset_cutoff))
        logger.warning(f">>>> Cutting dataset to {conf.dataset_cutoff} instances")

    # Load the pre_retrieve data: one JSON object per line, indexed by "id".
    if conf.use_data == "pre_retrieve":
        assert conf.pre_retrieve_path, "Pre retrieve path is not set"
        logger.warning(f"Loading pre_retrieve data from {conf.pre_retrieve_path}")
        pre_retrieve_data = {}
        with open(conf.pre_retrieve_path, "r") as f:
            for line in f:
                ins = json.loads(line)
                pre_retrieve_data[ins["id"]] = ins

    # Ids that were already processed in a previous run; empty unless
    # conf.resume is set. The error ids are not needed here.
    processed_ids, _error_ids = resume_from_file(conf)
    cache_dir = f"cache/{conf.embed_model}"
    os.makedirs(cache_dir, exist_ok=True)

    # > main loop
    for idx, ins in enumerate(track(data, description="Processing", total=len(data))):
        if conf.resume and ins["id"] in processed_ids:
            continue

        ins_summary = summary(ins)
        q_entity = ins_summary["q_entity"]
        logger.warning(f"Question: [{idx}] [{ins['id']}] {ins_summary['question']}? (Avg Path Len: {ins_summary['mean_path_len']:.2f})")
        logger.warning(f"Answer: {ins_summary['answer'][:30]}")  # only log the first 30 answers

        if conf.exclude_missing and ins_summary.get("answer_not_in_graph"):
            logger.warning(f"Skipping instance {ins['id']} because it has an error")
            continue

        # ONLY for TEST
        # if len(q_entity) == 0:
        #     logger.warning(f"Skipping instance {ins['id']} because it has no entity")
        #     continue

        if ins_summary['mean_path_len'] < 2 and args.debug:
            logger.error(f"Skipping instance {ins['id']} because it has a mean path length of {ins_summary['mean_path_len']}")
            continue

        # if args.debug and ins["id"] != "WebQTrn-567_693feb48c0515cd069014e7ca2846b37":
        #     logger.error(f"Skipping instance {ins['id']} because it has a mean path length of {ins_summary['mean_path_len']}")
        #     continue

        # Debug mode only keeps single-entity, single-answer instances.
        if args.debug:
            if len(q_entity) != 1:
                logger.error(f"Skipping instance {ins['id']} because it has {len(q_entity)} entities")
                continue

            if len(ins_summary["answer"]) > 1:
                logger.error(f"Skipping instance {ins['id']} because it has {len(ins_summary['answer'])} answers")
                continue

        # When dumping rerank results, skip instances whose dump already exists.
        if conf.save_rerank_result:
            result_path = conf.save_rerank_result_dir or f"data/rerank_results/{conf.timestamp}_{conf.dataset_name}_{conf.split}_{conf.rerank_model.split('/')[-1]}"
            _rerank_result = os.path.join(result_path, f"rerank_paths_{ins['id']}.json")
            if os.path.exists(_rerank_result):
                logger.warning(f"Skipping instance {ins['id']} because it has already been processed")
                continue

        try:  # main per-instance logic
            # > build graph
            if conf.no_kge_cache:  # do not use the KGE cache
                wholeG = build_vec_graph(ins["graph"], embed_model=embed_model, use_agg=not conf.no_agg, agg_method=conf.agg_method)
            else:
                cache_file = os.path.join(cache_dir, f"{ins['id']}.pkl")
                if os.path.exists(cache_file):
                    # NOTE(review): pickle.load executes arbitrary code from the
                    # file; only safe while the cache directory is trusted.
                    with open(cache_file, "rb") as f:
                        wholeG = pickle.load(f)
                else:
                    wholeG = build_vec_graph(ins["graph"], embed_model=embed_model, use_agg=not conf.no_agg, agg_method=conf.agg_method)
                    with open(cache_file, "wb") as f:
                        pickle.dump(wholeG, f)  # save to the cache

            logger.info(f"Graph stats - Nodes: {len(wholeG.nodes)}, Edges: {len(wholeG.edges)}")

            # > init prompt
            if (conf.init_prompt_type == "initial_with_samples" or conf.final_prompt_type == "final_with_samples") and conf.sample_topk > 0:
                samples = memory.retrieve_similar(ins["question"], neg_query=q_entity[0], top_n=conf.sample_topk, return_format="qa")[0]
                samples = "\n---\n".join(samples)
                # logger.error(f"Samples: {samples}")
            else:
                samples = ""

            # > init response
            if conf.use_data == "graphs" or conf.use_data is None:
                prompt_type = conf.init_prompt_type or "initial"
                prompt = build_prompt(
                    type=prompt_type,
                    question=ins["question"],
                    dataset_name=conf.dataset_name or 'webqsp',
                    q_entity=ins_summary["q_entity"],
                    samples=samples,
                )
                response_raw = model.predict(prompt)
                response = json_parser(response_raw.content)
                ins_summary["response_raw"] = response_raw.content
                logger.info(f"Initial Response: {response_raw.content}")
                log_token_usage("initial", prompt, response_raw.content, ins_id=ins["id"])

                # The response could not be parsed: record the error and move on.
                if not response:
                    ins_summary["error_type"] = ParseError.NO_VALID_RESPONSE.name
                    logger.error(ins_summary["error_type"])
                    save_result([ins_summary], conf.results_path)
                    continue

                candidate_answers, reasoning_paths = check_response(response)
                ins_summary["candidate_answers"] = candidate_answers
                ins_summary["reasoning_paths"] = reasoning_paths
                # The question itself is always added as an extra retrieval query.
                reasoning_paths = reasoning_paths + [ins["question"]]
                candidate_answers = candidate_answers + [ins["question"]]

            elif conf.use_data == "paths":
                candidate_answers = path_to_str(q_entity[0], ins_summary["truth_paths"], with_deco=False) + [ins["question"]]
                reasoning_paths = path_to_str(q_entity[0], ins_summary["truth_paths"], with_deco=False) + [ins["question"]]

            elif conf.use_data == "pre_retrieve":
                # Check membership before the first lookup so a missing id
                # produces a clear error instead of failing inside the log call.
                if ins["id"] not in pre_retrieve_data:
                    raise KeyError(f"Instance {ins['id']} is missing from the pre_retrieve data")
                pre_retrieved = pre_retrieve_data[ins["id"]]
                logger.warning(f"Pre Retrieve Init Response: {pre_retrieved['response_raw']}")
                candidate_answers = pre_retrieved["candidate_answers"] + [ins["question"]]
                reasoning_paths = pre_retrieved["reasoning_paths"] + [ins["question"]]

            elif conf.use_data == "no_hyrp":
                candidate_answers = [ins["question"]] + ins_summary["q_entity"]
                reasoning_paths = [ins["question"]] + ins_summary["q_entity"]

            else:
                # Without this branch the names below stay undefined and the
                # failure surfaces later as an opaque NameError.
                raise ValueError(f"Unsupported use_data value: {conf.use_data!r}")

            if args.debug:
                if any(a in candidate_answers for a in ins_summary["answer"]):
                    logger.error(f"Skipping instance {ins['id']} because the initial answer is in the answers")
                    continue

                if len(ins_summary["answer"]) != 1:
                    logger.error(f"Skipping instance {ins['id']} because the initial answer is not a single entity")
                    continue

            # > retrieve subgraph
            node_embeddings = embed_model.encode(candidate_answers, batch_size=256)
            edge_embeddings = embed_model.encode(reasoning_paths, batch_size=256)

            # Build the subgraph.
            relevant_nodes, relevant_edges = get_relevants(
                wholeG,
                nodes=candidate_answers,
                edges=reasoning_paths,
                node_embeds=node_embeddings,
                edge_embeds=edge_embeddings,
                node_count=conf.node_count_threshold,
                edge_count=conf.edge_count_threshold,
                use_bm25=conf.kge_use_bm25,
            )
            subgraph = build_subgraph(wholeG, relevant_nodes, relevant_edges)
            subgraph = fill_subgraph(subgraph, wholeG, ins)
            ins_summary["sub_nodes"] = list(subgraph.nodes)
            logger.info(f"Subgraph stats - Nodes: {len(subgraph.nodes)}, Edges: {len(subgraph.edges)}")

            if conf.stop_at == Stage.AFTER_SUBGRAPH.value:
                raise Exception(ParseError.HUMAN_CHECK)

            # > rerank paths
            paths = []
            cand_nodes_in_graph = [node for node in candidate_answers if node in wholeG]
            leaves = list(subgraph.nodes) + cand_nodes_in_graph
            for leaf in leaves:
                _triples = get_all_paths(q_entity, [leaf], wholeG)  # all paths to this leaf
                for _triple in _triples:
                    if len(_triple) == 0:
                        continue

                    path_str = path_to_str(q_entity, [_triple], with_deco=False)[0]
                    paths.append({
                        "node": leaf,
                        "path": path_str,  # f"Answer: {leaf}\nPath: {path_str}",
                        "triples": _triple,
                        "is_truth": leaf in ins_summary["answer"],
                    })
            scores = rerank_model.score_batch(ins["question"], [p["path"] for p in paths], batch_size=64)
            for i, path_info in enumerate(paths):
                path_info["score"] = scores[i]

            # Save the rerank result.
            if conf.save_rerank_result:
                os.makedirs(os.path.dirname(_rerank_result), exist_ok=True)
                with open(_rerank_result, "w") as f:
                    ins_summary["rerank_paths"] = paths
                    json.dump(ins_summary, f, ensure_ascii=False)

                continue  # test mode: only the rerank result is saved

            # > build final paths
            paths = sorted(paths, key=lambda x: x["score"], reverse=True)
            # TODO: should the order of the two steps below be swapped?
            if conf.rerank_dynamic_fix:
                path_top_precent = [p for p in paths if p["score"] > conf.rerank_threshold]
                path_top_precent = path_top_precent[: max(conf.max_rerank_nodes, int(len(path_top_precent) * conf.rerank_rate))]
            else:
                path_top_precent = paths[: max(conf.max_rerank_nodes, int(len(paths) * conf.rerank_rate))]
                path_top_precent = [p for p in path_top_precent if p["score"] > conf.rerank_threshold]

            if len(path_top_precent) == 0:
                logger.warning(f"No valid paths found for instance {ins['id']}, max score: {max(scores)}")
                path_top_precent = paths[:10]

            if conf.top_threshold:  # truncate when a token threshold is configured
                # Paths are added until the joined text reaches the threshold,
                # so the last path added may push the total past it.
                payload = [path_top_precent[0]["path"]]
                while model.count_tokens(" ".join(payload)) < conf.top_threshold and len(payload) < len(path_top_precent):
                    payload.append(path_top_precent[len(payload)]["path"])
                # Decide whether anything is dropped BEFORE slicing; comparing
                # afterwards always sees equal lengths and never warns.
                truncated = len(payload) < len(path_top_precent)
                path_top_precent = path_top_precent[: len(payload)]
                if truncated:
                    logger.warning(f"Truncated paths for instance {ins['id']}, max score: {max(scores)}")

            # Paths that survived truncation.
            rerank_paths = [p["path"] for p in path_top_precent]
            rerank_nodes = [p["node"] for p in path_top_precent]
            ins_summary["path_top_precent"] = path_top_precent

            if conf.use_merge_paths:
                rerank_triples = [p["triples"] for p in path_top_precent]
                rerank_triples = merge_similar_paths(rerank_triples)
                rerank_paths = path_to_str(q_entity, rerank_triples, with_deco=False)

            logger.debug(f"TOP(NO.1): {path_top_precent[0]}")
            logger.debug(f"LAST(NO.{len(path_top_precent)}): {path_top_precent[-1]}")
            logger.info(f"Reranked nodes stats - Nodes: {len(rerank_nodes)}, Tokens: {model.count_tokens(' '.join(rerank_paths))}")

            # > final response
            if conf.stop_at == Stage.AFTER_REANK.value:
                raise Exception(ParseError.HUMAN_CHECK)

            # Generate the final answer from the truncated paths.
            if conf.use_agent:
                agent = Agent(model_name=conf.llm_model, agent_model=conf.agent_model)

                top_n = conf.sample_topk_agent if conf.sample_topk_agent else conf.sample_topk
                samples = memory.retrieve_similar(ins["question"], neg_query=q_entity[0], top_n=top_n, return_format="agent")[0]
                samples = "\n---\n".join(samples)
                final_response = agent.invoke_msg(ins["question"], subgraph, rerank_paths, samples)
            elif conf.use_reasoning and reasoning_model is not None:
                final_prompt = build_prompt(type=conf.final_prompt_type, question=ins["question"], reasoning_paths=rerank_paths, samples=samples)
                final_response = reasoning_model.predict(final_prompt)
                log_token_usage("final", final_prompt, final_response.content)
            else:
                final_prompt = build_prompt(type=conf.final_prompt_type, question=ins["question"], reasoning_paths=rerank_paths, samples=samples)
                final_response = model.predict(final_prompt)
                log_token_usage("final", final_prompt, final_response.content)

            ins_summary["final_response_raw"] = final_response.content
            logger.info(f"Final Response: {final_response.content}")
            logger.info(f"Answer: {ins_summary['answer']}")

        except Exception as e:
            # A deliberate stage stop is raised as Exception(ParseError.HUMAN_CHECK)
            # and is not an error. Exceptions raised without arguments have an
            # empty ``args`` tuple, so guard the lookup to keep the handler
            # itself from raising IndexError and aborting the whole run.
            if not e.args or e.args[0] != ParseError.HUMAN_CHECK:
                ins_summary["error_type"] = ParseError.OTHER.name
                logger.error(f"{ins_summary['error_type']}: {e}, {traceback.format_exc()}")
                # NOTE(review): presumably a back-off for transient API errors - confirm.
                time.sleep(100)

        # > save result
        save_result([ins_summary], conf.results_path)

        # Evaluate every log_interval instances, and every 10 during the first 300.
        if idx % conf.log_interval == 0  or (idx < 300 and idx % 10 == 0):
            eval_from_config(conf.path, verbose=True, exclude_missing=conf.exclude_missing)

    # > eval
    eval_from_config(conf.path, verbose=True, exclude_missing=conf.exclude_missing)
    print(conf.path)


def resume_from_file(conf):
    """Collect the ids already handled in an earlier run and prune bad rows.

    When ``conf.resume`` is falsy nothing is read and two empty sets are
    returned. Otherwise every line of ``conf.results_path`` is parsed as JSON
    and validated: its ``final_response_raw`` must parse via ``json_parser``
    to something truthy with a truthy ``most_possible_answer``, and the record
    must carry no ``error_type``. Invalid rows are logged, reported in the
    error set and removed from the file, which is rewritten in place.

    Args:
        conf: Object exposing ``results_path`` and ``resume``.

    Returns:
        A ``(processed_ids, error_ids)`` tuple of sets of instance ids.

    Raises:
        AssertionError: if resuming is requested but the results file is missing.
    """
    results_path = conf.results_path
    processed_ids = set()
    error_ids = set()

    if not conf.resume:
        return processed_ids, error_ids

    assert os.path.exists(results_path), f"Results file {results_path} does not exist"

    with open(results_path, "r") as f:
        lines = f.readlines()
    all_cnt = len(lines)

    # Collect the surviving rows in a new list. Removing from ``lines`` while
    # iterating over it makes the loop skip the row after every removed one,
    # so that row would be neither validated nor recorded as processed.
    kept_lines = []
    for line in lines:
        ins_summary = json.loads(line)

        try:
            final_response = json_parser(ins_summary["final_response_raw"])
            # Explicit raises instead of ``assert`` so the validation still
            # runs when Python is started with -O.
            if not final_response:
                raise ValueError(f"No valid answer found for instance {ins_summary['id']}")
            if not final_response.get("most_possible_answer"):
                raise ValueError(f"No valid answer found for instance {ins_summary['id']}")
            if ins_summary.get("error_type"):
                raise ValueError(f"Error type found for instance {ins_summary['id']}")
        except Exception as e:
            logger.error(f"Error parsing instance {ins_summary['id']}: {e}")
            error_ids.add(ins_summary["id"])
            continue

        kept_lines.append(line)
        processed_ids.add(ins_summary["id"])

    # Rewrite the file without the invalid rows.
    with open(results_path, "w") as f:
        f.writelines(kept_lines)

    logger.warning(f"Resuming from {results_path}, {len(processed_ids)}/{all_cnt} instances processed")

    return processed_ids, error_ids


def save_result(ins: List[dict], path: str, mode: str = "a+"):
    """Write every record in ``ins`` to ``path`` as one JSON document per line.

    Args:
        ins: Records to serialise; non-ASCII characters are written unescaped.
        path: Destination file.
        mode: Mode handed to ``open``; the default ``"a+"`` appends.
    """
    serialized = (json.dumps(record, ensure_ascii=False) + "\n" for record in ins)
    with open(path, mode) as out:
        out.writelines(serialized)

def get_path_with_edges(graph, start, end):
    """Return the shortest ``start`` -> ``end`` path with edge labels interleaved.

    The result alternates nodes and arrows, e.g.
    ``[start, "--rel-->", node, "--rel-->", end]``. Each label is read from
    the "relation" attribute of the edge data and falls back to an empty
    string when that attribute is absent.

    Args:
        graph: Graph accepted by ``nx.shortest_path`` and indexable as
            ``graph[u][v]``.
        start: Source node.
        end: Target node.
    """
    hops = nx.shortest_path(graph, start, end)
    rendered = [start]
    for src, dst in zip(hops, hops[1:]):
        relation = graph[src][dst].get("relation", "")
        rendered += [f"--{relation}-->", dst]
    return rendered

if __name__ == "__main__":
    # Pick the run configuration: the expression given in args.config when
    # present, otherwise the built-in config_850.
    # NOTE(review): eval() executes whatever is passed on the command line;
    # acceptable only while the caller is trusted.
    if args.config:
        run = eval(args.config)
    else:
        run = config_850
    # The command-line flag always overrides the configured resume value.
    # NOTE(review): the original comment says the flag defaults to False -
    # confirm in parse_args.
    run["resume"] = args.resume
    main(run)
