import argparse
import json
import time

import numpy as np
import torch
from eval_subgraphrag import RetrieverDatasetEval
from tqdm import tqdm
from utils import exact_match

from llm_graph_walk import graph, llm, prompts, sample, sample_sr, text_encoder
from llm_graph_walk.sample_sr import END_REL


# Per-question fields copied verbatim from the preprocessed dataset record into
# each output record.
_COPIED_FIELDS = (
    "question",
    "seed_nodes",
    "seed_nodes_id",
    "answer_node",
    "answer_node_id",
    "sparql_query",
    "all_answers",
    "answer_subgraph",
    "full_subgraph",
    "n_hops",
    "graph_template",
    "answer_triples",
    "redundant",
)

# Rewrite the output file after this many new records, so a long run keeps a
# recent checkpoint on disk.
_SAVE_EVERY = 10


def _load_wikikg2_interface(wikikg_dir):
    """Load the preprocessed WikiKG2 arrays and wrap them in a KG interface.

    Args:
        wikikg_dir: Directory containing ``node_labels.npy``,
            ``relation_labels.npy``, ``edge_ids.npy`` and ``relation_types.npy``.

    Returns:
        A ``graph.KGInterfaceFromGraph`` built over a ``graph.Graph`` of the
        loaded arrays, with ``END_REL`` appended to the relation labels.

    Raises:
        ValueError: If ``wikikg_dir`` is ``None``.
    """
    if wikikg_dir is None:
        raise ValueError("--wikikg_dir is required")

    def load(name):
        # allow_pickle=True is needed if the arrays were saved with object dtype.
        return np.load(f"{wikikg_dir}/{name}.npy", allow_pickle=True)

    node_labels = load("node_labels")
    relation_labels = load("relation_labels")
    # Append the sampler's END_REL label as one extra relation, keeping the
    # original dtype. NOTE(review): if the dtype is fixed-width unicode
    # (e.g. '<U20'), astype would truncate a longer END_REL — confirm dtype.
    relation_labels = np.array(relation_labels.tolist() + [END_REL]).astype(
        relation_labels.dtype
    )
    edge_ids = load("edge_ids")
    relation_types = load("relation_types")

    knowledge_graph = graph.Graph(
        edge_ids, relation_types, node_labels, relation_labels
    )
    return graph.KGInterfaceFromGraph(knowledge_graph)


def _save_results(results, path):
    """Overwrite ``path`` with ``results`` serialized as a JSON list."""
    with open(path, "w") as f:
        json.dump(results, f)


def main(args):
    """Evaluate SubgraphRAG retrieval + LLM answering on a preprocessed dataset.

    For each question the trained retriever samples a subgraph from the seed
    nodes, the LLM answers from a prompt built over that subgraph, and the
    answer is scored with ``exact_match`` against the answer node's label.
    Each output record holds the copied dataset fields plus retrieval time,
    subgraph, metapaths, LLM answer and EM score.

    Results are written to ``args.output_file`` as a JSON list every
    ``_SAVE_EVERY`` records and once more when the loop ends. The final write
    also happens on an empty dataset or when the loop raises.

    Raises:
        NotImplementedError: If ``args.use_full_wikidata`` is set.
        ValueError: If ``args.wikikg_dir`` is missing.
    """
    # Fail fast, before any slow loading. This branch used to be a bare
    # `pass`, which left `kg` unbound and crashed later with a NameError.
    # The intended backend was graph.KGInterfaceFromWikidata over
    # args.wikidata_server_urls, which is currently disabled.
    if args.use_full_wikidata:
        raise NotImplementedError(
            "--use_full_wikidata is not supported yet; use --wikikg_dir instead"
        )

    kg = _load_wikikg2_interface(args.wikikg_dir)

    print(f"Loading checkpoint at {args.retriever_path}")
    chkpt = torch.load(args.retriever_path)
    config = chkpt["config"]
    t_enc = text_encoder.TextEncoderSR(config["model_name_or_path"], device="cuda")
    t_enc.load_state_dict(chkpt["model_state_dict"])
    t_enc.eval()

    samplefunc = sample_sr.SampleSubgraphSR(kg, t_enc)

    llm_api = llm.LLMAPI("openai", model=args.model)
    # Only used to build the final answer prompt and send it to the LLM.
    tog_sampler = sample.SampleSubgraph(llm_api, kg_interface=kg)

    ds = RetrieverDatasetEval(args.preprocessed_path)
    # shuffle=False keeps loader order aligned with ds.preprocessed_data, which
    # is indexed by position below. batch_size=None disables auto-batching.
    dl = torch.utils.data.DataLoader(ds, batch_size=None, shuffle=False)

    ds_out = []
    try:
        for i, dp in enumerate(tqdm(dl)):
            record = ds.preprocessed_data[i]
            dp_out = {k: record[k] for k in _COPIED_FIELDS}

            query = dp_out["question"]
            # Gold answer: the label of the annotated answer node only.
            answers = [kg.get_node_label(dp_out["answer_node_id"])]

            # perf_counter is monotonic, so it is the right clock for durations.
            start = time.perf_counter()
            subgraph, metapaths = samplefunc(
                query=query,
                # Assumes the last element of each dataset item holds the seed
                # node ids as a tensor — TODO confirm with RetrieverDatasetEval.
                seed_node_id=dp[-1].numpy().tolist(),
                max_subgraph_size=args.subgraph_size,
            )
            dp_out["graphrag_retrieval_seconds"] = time.perf_counter() - start
            dp_out["graphrag_subgraph"] = subgraph
            dp_out["graphrag_metapaths"] = metapaths

            prompt = tog_sampler.build_answer_prompt(query, subgraph)
            subgraph_answer = tog_sampler.llm_api(prompt)
            dp_out["graphrag_answer"] = subgraph_answer
            dp_out["graphrag_em"] = exact_match(subgraph_answer, answers)

            ds_out.append(dp_out)

            if len(ds_out) % _SAVE_EVERY == 0:
                _save_results(ds_out, args.output_file)
    finally:
        # Persist everything collected. This covers a trailing partial batch,
        # an empty dataset, and runs aborted by an exception (e.g. an LLM API
        # error), none of which were reliably saved before.
        _save_results(ds_out, args.output_file)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--use_full_wikidata",
        action="store_true",
        help="whether to use full Wikidata, instead of WikiKG2",
    )
    parser.add_argument(
        "--wikidata_server_urls",
        type=str,
        default="server_urls_new.txt",
        help="path of txt file with server url addresses",
    )
    parser.add_argument(
        "--wikikg_dir",
        type=str,
        help="directory containing the processed wikikg2",
    )
    parser.add_argument(
        "--preprocessed_path",
        type=str,
        help="path of preprocessed data files",
    )
    parser.add_argument(
        "--output_file", type=str, required=True, help="the output file name"
    )
    parser.add_argument(
        "--model", type=str, default="gpt-4o-mini", help="OpenAI LLM model name"
    )
    parser.add_argument(
        "--retriever_path",
        default="checkpoints/webqsp_sr_shortest_only/checkpoint_ep_10.pth",
        type=str,
        help="path to checkpoint of pretrained SubgraphRAG model",
    )
    parser.add_argument(
        "--subgraph_size",
        type=int,
        required=True,
        help="max number of edges in retrieved subgraph",
    )
    args = parser.parse_args()

    main(args)
