import json


class RetrievalAgent():
    """Extract query-relevant knowledge from per-document knowledge bases with an LLM.

    Each knowledge base (chunk, graph, table, algorithm, catalogue) is a
    directory containing one JSON file per document, named
    ``data_{data_id}.json``. The selected file is loaded and ``llm.response``
    is prompted to pull out what is relevant to the sub-queries.
    """

    def __init__(self, llm, chunk_kb_path, graph_kb_path, table_kb_path, algorithm_kb_path, catalogue_kb_path):
        """Store the LLM client and the directory of each knowledge base.

        Args:
            llm: Object exposing ``response(prompt)``; its return value is
                collected as the retrieved knowledge.
            chunk_kb_path: Directory of chunk KB files.
            graph_kb_path: Directory of graph KB files.
            table_kb_path: Directory of table KB files.
            algorithm_kb_path: Directory of algorithm KB files.
            catalogue_kb_path: Directory of catalogue KB files.
        """
        self.llm = llm
        self.chunk_kb_path = chunk_kb_path
        self.graph_kb_path = graph_kb_path
        self.table_kb_path = table_kb_path
        self.algorithm_kb_path = algorithm_kb_path
        self.catalogue_kb_path = catalogue_kb_path

    def retrieve(self, query, subqueries, chosen, data_id, extra_instruction):
        """Dispatch retrieval to the knowledge base named by ``chosen``.

        Args:
            query: The original query (forwarded to the per-KB methods).
            subqueries: List of sub-query strings.
            chosen: One of ``'chunk'``, ``'table'``, ``'graph'``,
                ``'algorithm'``, ``'catalogue'``.
            data_id: Identifier selecting ``data_{data_id}.json`` in the KB.
            extra_instruction: Optional string appended to every sub-query;
                ``None`` leaves the sub-queries unchanged.

        Returns:
            The list produced by the selected ``do_retrieve_*`` method.

        Raises:
            ValueError: If ``chosen`` is not a known knowledge-base name.
        """
        print(f"data_id: {data_id}, retrieve...")

        if extra_instruction is not None:
            subqueries = [subquery + extra_instruction for subquery in subqueries]

        handlers = {
            "chunk": self.do_retrieve_chunk,
            "table": self.do_retrieve_table,
            "graph": self.do_retrieve_graph,
            "algorithm": self.do_retrieve_algorithm,
            "catalogue": self.do_retrieve_catalogue,
        }
        if chosen not in handlers:
            raise ValueError(f"chosen should be in {list(handlers)}, got {chosen!r}")

        return handlers[chosen](query, subqueries, data_id)

    @staticmethod
    def _load_kb(kb_path, data_id):
        """Load and return the parsed JSON file ``{kb_path}/data_{data_id}.json``."""
        # `with` guarantees the handle is closed even if json.load raises.
        with open(f"{kb_path}/data_{data_id}.json", encoding="utf-8") as f:
            return json.load(f)

    def _retrieve_per_subquery(self, label, subqueries, content, data_id):
        """Prompt the LLM once per sub-query against ``content``; return answers in order."""
        subknowledges = []
        for s, subquery in enumerate(subqueries):
            print(f"data_id: {data_id}, {label}... in subquery {s}/{len(subqueries)} in subqueries ..")
            # Same Query/Document/Output layout as do_retrieve_chunk.
            prompt = f"Query:\n{subquery}\n\nDocument:\n{content}\n\nOutput:"
            subknowledges.append(self.llm.response(prompt))
        return subknowledges

    def do_retrieve_chunk(self, query, subqueries, data_id):
        """Prompt the LLM once per chunk with all sub-queries joined by newlines.

        Returns one string per chunk: the text before the chunk's first ``:``
        (used as its title) followed directly by the LLM output.
        """
        chunks = self._load_kb(self.chunk_kb_path, data_id)

        composed_query = "\n".join(subqueries)

        subknowledges = []
        for c, chunk in enumerate(chunks):
            print(f"retrieve chunk {c}/{len(chunks)} in chunks ..")

            prompt = f"Query:\n{composed_query}\n\nDocument:\n{chunk}\n\nOutput:"
            tmp_output = self.llm.response(prompt)
            title = chunk.split(":")[0]
            # NOTE(review): title and output are joined with no separator; confirm intended.
            subknowledges.append(f"{title}{tmp_output}")

        return subknowledges

    def do_retrieve_table(self, query, subqueries, data_id):
        """Prompt the LLM per sub-query against all tables, numbered from 1."""
        print(f"data_id: {data_id}, do_retrieve_table...")

        tables = self._load_kb(self.table_kb_path, data_id)
        # Label tables so the LLM can refer to them individually.
        tables_content = "\n\n".join(
            f"Table {t}:\n{table}" for t, table in enumerate(tables, start=1)
        )

        return self._retrieve_per_subquery("do_retrieve_table", subqueries, tables_content, data_id)

    def do_retrieve_graph(self, query, subqueries, data_id):
        """Prompt the LLM per sub-query against all graph entries joined by blank lines."""
        print(f"data_id: {data_id}, do_retrieve_graph...")

        graphs = self._load_kb(self.graph_kb_path, data_id)
        graphs_content = "\n\n".join(graphs)

        return self._retrieve_per_subquery("do_retrieve_graph", subqueries, graphs_content, data_id)

    def do_retrieve_algorithm(self, query, subqueries, data_id):
        """Prompt the LLM per sub-query against all algorithm entries joined by blank lines."""
        print(f"data_id: {data_id}, do_retrieve_algorithm...")

        algorithms = self._load_kb(self.algorithm_kb_path, data_id)
        algorithms_content = "\n\n".join(algorithms)

        return self._retrieve_per_subquery("do_retrieve_algorithm", subqueries, algorithms_content, data_id)

    def do_retrieve_catalogue(self, query, subqueries, data_id):
        """Prompt the LLM per sub-query against all catalogue entries joined by blank lines."""
        print(f"data_id: {data_id}, do_retrieve_catalogue...")

        catalogues = self._load_kb(self.catalogue_kb_path, data_id)
        catalogues_content = "\n\n".join(catalogues)

        return self._retrieve_per_subquery("do_retrieve_catalogue", subqueries, catalogues_content, data_id)