import json


class AnalysisAgent:
    """Pick a route label for a query and split the query into subqueries.

    ``analyze`` returns ``(subqueries, chosen)`` where ``chosen`` is one of
    ``"table"``, ``"graph"``, ``"algorithm"``, ``"catalogue"`` or ``"chunk"``.
    The label comes from precomputed predictions loaded from a JSON file when
    that file carries a ``predicted_decision`` field; otherwise it is obtained
    from an LLM call (``do_route``). Subqueries always come from an LLM call
    (``do_decompose``).
    """

    # Precomputed router output: a JSON list of objects with an "id" key and,
    # optionally, a "predicted_decision" key.
    DEFAULT_PREDICTION_PATH = (
        "/root/multiagent_doc2graph/train_router/output_prediction/"
        "weak_model_prediction_ic.json"
    )
    # Prompt templates; relative paths, so they resolve against the CWD.
    ROUTE_PROMPT_PATH = "agents/prompts/analysis_route.txt"
    DECOMPOSE_PROMPT_PATH = "agents/prompts/analysis_decompose.txt"
    # Labels recognised inside a predicted decision, checked in this order:
    # the first one found as a substring wins; anything else maps to "chunk".
    PREDICTION_LABELS = ("table", "graph", "algorithm", "catalogue")

    def __init__(self, llm, router, *, prediction_path=None, forced_decision="catalogue"):
        """Load the prediction file and build the id -> decision map.

        Args:
            llm: object exposing ``response(prompt) -> str``.
            router: stored as ``self.router``; not used by this class's methods.
            prediction_path: JSON prediction file; ``None`` means
                ``DEFAULT_PREDICTION_PATH``.
            forced_decision: when not ``None``, every id is mapped to this
                label instead of its own ``predicted_decision`` (an ablation
                run, announced by printing ``all_<label>``). Defaults to
                ``"catalogue"`` to match the previously hard-coded behaviour;
                pass ``None`` to route by the real predictions.
        """
        self.llm = llm
        self.router = router
        if prediction_path is None:
            prediction_path = self.DEFAULT_PREDICTION_PATH
        # JSON text is UTF-8; `with` guarantees the handle is closed.
        with open(prediction_path, "r", encoding="utf-8") as f:
            self.prepared_chosens = json.load(f)

        # Only the first record is inspected to decide whether predictions
        # exist; an empty list means "no predictions" rather than IndexError.
        if self.prepared_chosens and "predicted_decision" in self.prepared_chosens[0]:
            if forced_decision is None:
                self.id_2_predicted_decision = {
                    item["id"]: item["predicted_decision"] for item in self.prepared_chosens
                }
            else:
                self.id_2_predicted_decision = {
                    item["id"]: forced_decision for item in self.prepared_chosens
                }
                print(f"all_{forced_decision}")
        else:
            self.id_2_predicted_decision = None

    @staticmethod
    def _read_text(path):
        """Return the full contents of a UTF-8 text file, closing it promptly."""
        with open(path, "r", encoding="utf-8") as f:
            return f.read()

    def analyze(self, query, kb_info, data_id):
        """Return ``(subqueries, chosen)`` for ``query``.

        When predictions were loaded, ``chosen`` is derived from the decision
        stored for ``data_id`` (``KeyError`` if the id is absent); otherwise
        ``do_route`` asks the LLM.
        """
        print(f"data_id: {data_id}, analyze...")

        if self.id_2_predicted_decision is not None:
            predicted_decision = self.id_2_predicted_decision[data_id]
            chosen = next(
                (label for label in self.PREDICTION_LABELS if label in predicted_decision),
                "chunk",
            )
        else:
            chosen = self.do_route(query, kb_info, data_id)
        subqueries = self.do_decompose(query, kb_info, data_id)

        return subqueries, chosen

    def do_route(self, query, kb_info, data_id):
        """Ask the LLM for a route and map its free-text answer to a label.

        The route prompt template is filled with ``query`` and
        ``titles=kb_info``. The answer is matched case-insensitively:
        "table" wins over "graph"; anything else yields "chunk".
        """
        print(f"data_id: {data_id}, do_route...")

        raw_prompt = self._read_text(self.ROUTE_PROMPT_PATH)
        prompt = raw_prompt.format(query=query, titles=kb_info)
        answer = self.llm.response(prompt).lower()

        # NOTE(review): only table/graph/chunk are recognised here, whereas
        # prediction-based routing in `analyze` can also yield "algorithm" and
        # "catalogue" — confirm the route prompt never offers those.
        if "table" in answer:
            return "table"
        if "graph" in answer:
            return "graph"
        return "chunk"

    def do_decompose(self, query, kb_info, data_id):
        """Ask the LLM to decompose ``query``; return its output split per line.

        Empty lines in the LLM output are kept as empty-string entries.
        """
        print(f"data_id: {data_id}, do_decompose...")

        raw_prompt = self._read_text(self.DECOMPOSE_PROMPT_PATH)
        prompt = raw_prompt.format(query=query, kb_info=kb_info)
        return self.llm.response(prompt).split("\n")