from agents.base_agent import BaseAgent_openai
from prompts.rag_agent_prompts import ARXIV_SYSTEM_PROMPT, IDEAS_SYSTEM_PROMPT, CALLING_FUNCTION_PROMPT
from utils.extractors import *
import random


class RagAgentAPI(BaseAgent_openai):
    """Agent built on ``BaseAgent_openai`` for retrieval-augmented idea handling.

    It can ask the model for arXiv search queries, extract tagged ideas from
    raw texts, sample from a pool of ideas, and ask the model whether RAG
    retrieval should be invoked for the current task.
    """

    def __init__(self, config, agent_name, debug_logger,
                 main_logger, sub_logger,
                 checker=None, rag_agent=None):
        super().__init__(config, agent_name, debug_logger, main_logger, sub_logger, checker, rag_agent)

        # Pool that retrieve_rag_ideas() samples from. generate_pool_ideas()
        # returns its result instead of storing it here — presumably the caller
        # assigns this attribute. TODO confirm.
        self.pool_ideas = []
        # Upper bound on how many ideas retrieve_rag_ideas() returns.
        self.pool_size = config['number_rag_ideas']

    def get_arxiv_queries(self, n=1):
        """Ask the model for ``n`` arXiv search queries about the background data.

        The background text is read from ``config['background_data_path']`` and
        formatted into ``ARXIV_SYSTEM_PROMPT`` together with the requested count
        and a ``query1;query2;...`` example. If the reply contains a ``<start>``
        tag, only the text after it (up to ``<end>``) is parsed.

        Args:
            n: Number of queries to request.

        Returns:
            A list of non-empty, whitespace-stripped query strings split on
            ``;``, or the model response unchanged when it is falsy.
        """
        self.clear_context()
        # read() preserves the file's own line breaks; joining readlines() with
        # '\n' doubled every newline because each line already ends with one.
        with open(self.config['background_data_path'], 'r', encoding='utf-8') as f:
            background_info = f.read()

        example = ";".join(f"query{i + 1}" for i in range(n))
        query_suffix = "y" if n == 1 else "ies"

        self.instructions = ARXIV_SYSTEM_PROMPT.format(
            n,
            query_suffix,
            example,
            background_info
        )

        response = self.generate_response('')
        if not response:
            return response

        if "<start>" in response:
            # Keep only the tagged section; without a <start> tag the whole
            # reply is parsed.
            response = response.split("<start>")[1].split("<end>")[0]
        # Strip stray whitespace/newlines and drop empty entries (e.g. from a
        # trailing ';') so callers never search with a blank query.
        return [query.strip() for query in response.split(";") if query.strip()]

    def generate_pool_ideas(self, raw_texts):
        """Extract ideas from each raw text with the model.

        For each text, ``IDEAS_SYSTEM_PROMPT`` and the stripped text are sent as
        one prompt; every ``<start>...<end>`` section of the reply becomes one
        idea.

        Args:
            raw_texts: Iterable of strings to mine for ideas.

        Returns:
            List of non-empty idea strings from all texts, in order.
            ``self.pool_ideas`` is not modified.
        """
        self.clear_context()
        self.instructions = IDEAS_SYSTEM_PROMPT
        pool_ideas = []
        for raw_text in raw_texts:
            # NOTE(review): the instructions are both stored on
            # self.instructions and inlined into the prompt — confirm whether
            # generate_response already applies self.instructions.
            full_prompt = f"{self.instructions.strip()}\n\nInput:\n{raw_text.strip()}"
            response = self.generate_response(full_prompt)
            if not response:
                # The other methods treat a falsy reply as "no answer"; skip it
                # rather than crash on None.split().
                continue

            # Element 0 is the text before the first <start> tag, never an idea.
            for chunk in response.split("<start>")[1:]:
                if "<end>" in chunk:
                    idea = chunk.split("<end>", 1)[0].strip()
                    if idea:
                        pool_ideas.append(idea)
        return pool_ideas

    def retrieve_rag_ideas(self, number_ideas=5):
        """Return up to ``number_ideas`` randomly sampled ideas, newline-joined.

        The count is capped by ``self.pool_size`` and by the number of ideas
        actually in ``self.pool_ideas``. An empty string is returned when the
        capped count is not positive.
        """
        # Capping by len(self.pool_ideas) avoids ValueError from random.sample
        # when the pool holds fewer ideas than requested (e.g. still empty).
        number_ideas = min(self.pool_size, number_ideas, len(self.pool_ideas))
        if number_ideas <= 0:
            return ''

        return '\n'.join(random.sample(self.pool_ideas, number_ideas))

    def function_calling(self, background_data, current_task,
                         previous_ideas):
        """Ask the model whether RAG retrieval should be invoked for the task.

        The reply is parsed with ``extract_json`` and its ``"function_calling"``
        field decides the result.

        Returns:
            True only when the parsed JSON object has a truthy
            ``"function_calling"`` value. A string value counts as True only if
            it reads ``"true"`` (case-insensitive). Returns False otherwise,
            including when the reply does not parse into an object with that
            key.
        """
        self.clear_context()
        self.instructions = CALLING_FUNCTION_PROMPT.format(
            background_data=background_data,
            current_task=current_task,
            previous_ideas=previous_ideas
        )

        response = self.generate_response('')
        rag_necessary = extract_json(response)

        if rag_necessary is None:
            return False
        if not isinstance(rag_necessary, dict) or 'function_calling' not in rag_necessary:
            print(f"Unexpected function-calling response: {rag_necessary!r}")
            return False

        decision = rag_necessary['function_calling']
        if isinstance(decision, str):
            # Models sometimes emit the flag as a string; the former `== False`
            # comparison treated "false" as True.
            return decision.strip().lower() == 'true'
        return bool(decision)
