# Explanation: This version of the code is based on an old architecture, and although the actual running process is the same as described in the paper, the logic is still not clear enough. We will organize and release the official code later

import os
import json
import copy
import time
import tqdm
import random
random.seed(1024)
import argparse

from modules.qwenapi import QwenAPI

from agents.merging_agent import MergingAgent
from agents.analysis_agent import AnalysisAgent
from agents.retrieval_agent import RetrievalAgent
from agents.construction_agent import ConstructionAgent


if __name__ == '__main__':

    parser = argparse.ArgumentParser()
    parser.add_argument("--max_iter", type=int, default=1)
    parser.add_argument("--llm_name", type=str, default="qwen")
    parser.add_argument("--dataset_name", type=str, default="loong")
    parser.add_argument("--url", type=str, default="")
    parser.add_argument("--weak_url", type=str, default="")
    parser.add_argument("--weak_agent", type=str, default=None)
    parser.add_argument("--router_url", type=str, default="")
    parser.add_argument("--worker_id", type=int, choices=[0, 1, 2, 3, 4, 5, 6, 7], default=0)
    parser.add_argument("--start_bias", type=int, default=0) # used to manually skip last time error data
    parser.add_argument("--output_path_suffix", type=str, default="")
    args = parser.parse_args()

    for k, v in vars(args).items():
        print(f"{k}: {v}")
    print('\nstart...')

    llm = QwenAPI(url=f"http://{args.url}/v1/chat/completions")
    router = QwenAPI(url=f"http://{args.router_url}/v1/chat/completions")

    eval_data_path = "./Loong/data/loong_process.jsonl"
    eval_datas = [json.loads(l) for l in open(eval_data_path)]
    random.shuffle(eval_datas)
    eval_datas = eval_datas[200*args.worker_id+args.start_bias : 200*(args.worker_id+1)]
    print(f"len eval_datas: {len(eval_datas)}")

    intermediate_results_dir = f"./intermediate_results/{args.llm_name}/{args.dataset_name}{args.output_path_suffix}"
    os.makedirs(intermediate_results_dir) if not os.path.exists(intermediate_results_dir) else None

    chunk_kb_path = f"{intermediate_results_dir}/chunk_kb"
    graph_kb_path = f"{intermediate_results_dir}/graph_kb"
    table_kb_path = f"{intermediate_results_dir}/table_kb"
    algorithm_kb_path = f"{intermediate_results_dir}/algorithm_kb"
    catalogue_kb_path = f"{intermediate_results_dir}/catalogue_kb"
    os.makedirs(chunk_kb_path) if not os.path.exists(chunk_kb_path) else None
    os.makedirs(graph_kb_path) if not os.path.exists(graph_kb_path) else None
    os.makedirs(table_kb_path) if not os.path.exists(table_kb_path) else None
    os.makedirs(algorithm_kb_path) if not os.path.exists(algorithm_kb_path) else None
    os.makedirs(catalogue_kb_path) if not os.path.exists(catalogue_kb_path) else None

    output_dir = f"./eval_results/{args.llm_name}/{args.dataset_name}{args.output_path_suffix}"
    os.makedirs(output_dir) if not os.path.exists(output_dir) else None
    fw = open(f"{output_dir}/final_output_{args.worker_id}.jsonl", "a")
    fw_error = open(f"{output_dir}/final_output_error_{args.worker_id}.jsonl", "a")
    exiting_data = [json.loads(l) for l in open(f"{output_dir}/final_output_{args.worker_id}.jsonl")]
    exiting_data_ids = [d["id"] for d in exiting_data]    

    analysis_agent = AnalysisAgent(llm=llm, router=router)
    construction_agent = ConstructionAgent(llm=llm, chunk_kb_path=chunk_kb_path, graph_kb_path=graph_kb_path, table_kb_path=table_kb_path, algorithm_kb_path=algorithm_kb_path, catalogue_kb_path=catalogue_kb_path)
    retrieval_agent = RetrievalAgent(llm=llm, chunk_kb_path=chunk_kb_path, graph_kb_path=graph_kb_path, table_kb_path=table_kb_path, algorithm_kb_path=algorithm_kb_path, catalogue_kb_path=catalogue_kb_path)
    merging_agent = MergingAgent(llm=llm)

    if args.weak_agent is None:
        print("No weak agent")
    else:
        weak_llm = QwenAPI(url=f"http://{args.weak_url}/v1/chat/completions")
        if args.weak_agent == 'analysis':
            analysis_agent = AnalysisAgent(llm=weak_llm, router=router)
            print("analysis agent is weak")
        elif args.weak_agent == 'construction':
            construction_agent = ConstructionAgent(llm=weak_llm, chunk_kb_path=chunk_kb_path, graph_kb_path=graph_kb_path, table_kb_path=table_kb_path, algorithm_kb_path=algorithm_kb_path, catalogue_kb_path=catalogue_kb_path)
            print("construction agent is weak")
        elif args.weak_agent == 'retrieval':
            retrieval_agent = RetrievalAgent(llm=weak_llm, chunk_kb_path=chunk_kb_path, graph_kb_path=graph_kb_path, table_kb_path=table_kb_path, algorithm_kb_path=algorithm_kb_path, catalogue_kb_path=catalogue_kb_path)
            print("retrieval agent is weak")
        elif args.weak_agent == 'merging':
            merging_agent = MergingAgent(llm=weak_llm)
            print("merging agent is weak")
        else:
            raise ValueError(f"args.weak_agent: {args.weak_agent} is not supported")

    for i, data in enumerate(eval_datas): # data: {"instruction": "", "question": "", "docs": "", "prompt_template": "{},{},{}"}
        if data["id"] in exiting_data_ids:
            print(f"################## Skipping {i}th data existing... ##################")
            continue
        print(f"################## Processing {i}th data... ##################")

        query = data['prompt_template'].format(instruction=data['instruction'], question=data['question'], docs="......")

        try:
            current_time = time.time()
            fw_intermediate = open(f"{intermediate_results_dir}/{data['id']}.jsonl", "w")

            iter_time = 0
            _, titles = construction_agent.split_content_and_tile(data['docs'])
            kb_info = "The titles of the docs are: " + "\n".join(list(set(titles)))
            while True: # loop for analysis-construction
                iter_time += 1
                print(f"analysis-construction iter_time: {iter_time}")

                subqueries, chosen = analysis_agent.analyze(query, kb_info, data_id=data['id'])
                
                fw_intermediate.write(json.dumps({"subqueries": subqueries, "chosen": chosen}, ensure_ascii=False) + "\n")
                fw_intermediate.flush()

                instruction, kb_info = construction_agent.construct(subqueries, chosen, data["docs"], data_id=data['id'])

                fw_intermediate.write(json.dumps({"instruction": instruction, "kb_info": kb_info}, ensure_ascii=False) + "\n")
                fw_intermediate.flush()

                if iter_time >= args.max_iter:
                    break

            iter_time = 0
            extra_instruction = None
            while True: # loop for retrieval-merging
                iter_time += 1
                print(f"retrieval-merging iter_time: {iter_time}")

                subknowledges = retrieval_agent.retrieve(query, subqueries, chosen, data_id=data['id'], extra_instruction=extra_instruction)

                fw_intermediate.write(json.dumps({"subknowledges": subknowledges}, ensure_ascii=False) + "\n")
                fw_intermediate.flush()

                answer, if_continue, extra_instruction = merging_agent.merge(query, subqueries, subknowledges, chosen, data_id=data['id'])

                fw_intermediate.write(json.dumps({"answer": answer, "if_continue": if_continue, "extra_instruction": extra_instruction}, ensure_ascii=False) + "\n")
                fw_intermediate.flush()

                if if_continue != "CONTINUE":
                    break
                
                if iter_time >= args.max_iter:
                    break
            
            used_time = (time.time() - current_time) / 60
            print(f"level:{data['level']},set:{data['set']},type:{data['type']}")
            print(f"used time: {used_time:.2f} min")

            data['generate_response'] = answer
            data['used_time'] = used_time
            fw.write(json.dumps(data, ensure_ascii=False) + "\n")
            fw.flush()

        except Exception as e:
            print(f"(print in main.py) Error: {e}")
            data['generate_response'] = "meet error"
            data['used_time'] = -100
            fw_error.write(json.dumps(data, ensure_ascii=False) + "\n")
            fw_error.flush()

    print("all done")