import requests
import re
import random

# Number of passages requested per query from the retrieval service
# (sent as the "topk" field of the request payload).
TOP_K = 3
# Retrieval endpoint that batch_search POSTs its queries to.
SEARCH_URL = "http://127.0.0.1:8011/retrieve"
# Upper bound on the number of LLM turns in ThinkTwoStepPipeline.run_llm_loop.
MAX_ITERATION = 6

########################################################
#  utils for search
########################################################
def batch_search(query, timeout=120):
    """Retrieve the top ``TOP_K`` passages for ``query`` and format them as text.

    Sends a single-query batch to the retrieval service at ``SEARCH_URL`` and
    renders each returned document as ``Doc i(Title: <title>) <body>`` on its
    own line. The first line of a document's ``contents`` is treated as the
    title and the remainder as the body.

    Args:
        query: The search query string.
        timeout: Seconds to wait for the retrieval service before ``requests``
            raises a timeout error; ``None`` waits indefinitely.

    Returns:
        The formatted passages for ``query`` ("" if the service returned an
        empty document list for it).

    Raises:
        requests.HTTPError: If the retrieval service answers with an error status.
        IndexError: If the service's ``result`` list is empty.
    """
    def search_tool(queries):
        payload = {
            "queries": queries,
            "topk": TOP_K,
            "return_scores": True
        }
        response = requests.post(SEARCH_URL, json=payload, timeout=timeout)
        # Fail loudly on HTTP errors instead of tripping over an error body later.
        response.raise_for_status()
        return response.json()

    def passages2string(retrieval_result):
        lines = []
        for idx, doc_item in enumerate(retrieval_result, start=1):
            content = doc_item['document']['contents']
            # First line is the title; everything after the first newline is the body.
            title, _, text = content.partition("\n")
            lines.append(f"Doc {idx}(Title: {title}) {text}\n")
        return "".join(lines)

    # Exactly one query is sent, so only the first result list is formatted.
    results = search_tool([query])['result']
    return passages2string(results[0])


########################################################
#  utils for determining the action
########################################################
def act(response: str):
    """Parse the model's next action from ``response`` and execute it.

    A ``<search>...</search>`` block takes precedence: its stripped contents
    are sent to ``batch_search``. Otherwise an ``<answer>...</answer>`` block
    yields the final answer. Only the first well-formed occurrence of each tag
    pair is used.

    Args:
        response: Raw LLM output.

    Returns:
        ``{"type": "search", "content": <formatted passages>, "query": <query>}``,
        ``{"type": "answer", "content": <answer>}``, or ``None`` when neither a
        well-formed search block nor a well-formed answer block is present.
    """
    # re.search (rather than substring checks + findall()[0]) avoids an
    # IndexError when both tags appear but are out of order.
    search_match = re.search(r'<search>(.*?)</search>', response, re.DOTALL)
    if search_match:
        search_query = search_match.group(1).strip()
        search_results = batch_search(search_query)
        return {"type": "search", "content": search_results, "query": search_query}

    answer_match = re.search(r'<answer>(.*?)</answer>', response, re.DOTALL)
    if answer_match:
        return {"type": "answer", "content": answer_match.group(1).strip()}

    return None

def extract_summary(response: str, summary_tag: str):
    """Return the first ``<summary_tag>...</summary_tag>`` block in ``response``.

    The inner text is stripped of surrounding whitespace and re-wrapped in the
    same tag pair, so the result can be spliced directly into the next prompt.

    Args:
        response: Raw LLM output.
        summary_tag: Tag name to look for (e.g. ``"summary"``); matched
            literally, so regex metacharacters in it are not special.

    Returns:
        ``"<tag>inner</tag>"`` for the first well-formed block, or ``None`` if
        no such block exists.
    """
    tag = re.escape(summary_tag)
    # re.search avoids an IndexError when both tags occur but out of order.
    match = re.search(f"<{tag}>(.*?)</{tag}>", response, re.DOTALL)
    if match is None:
        return None
    summary = match.group(1).strip()
    return f"<{summary_tag}>{summary}</{summary_tag}>"

def model_estimated_match(answer, golden_answer, question, llm_client):
    """Ask an LLM judge whether ``answer`` agrees with ``golden_answer``.

    Args:
        answer: The model's predicted answer.
        golden_answer: The reference answer.
        question: The question both answers address.
        llm_client: Object exposing ``generate_response(prompt)`` returning text.

    Returns:
        The judge's verdict as an int: 1 if it says the answers align, 0 if not.

    Raises:
        ValueError: If the judge's reply contains no standalone ``0`` or ``1``.
    """
    prompt = f"""
    Your goal is to determine if a model's answer answers the question based on the golden answer.
    The question is: {question}
    The model's answer is: {answer}
    The golden answer is: {golden_answer}
    Output your answer as 0 or 1, where 0 means the model's answer does not align with the golden answer and 1 means the model's answer aligns with the golden answer. Output only the number, no other text.
    """
    verdict = llm_client.generate_response(prompt)
    # Judges frequently wrap the digit in extra text ("1.", "Answer: 0"), so take
    # the first standalone 0/1 instead of requiring the reply to be only a digit.
    match = re.search(r"\b[01]\b", verdict)
    if match is None:
        raise ValueError(f"Could not parse a 0/1 verdict from judge reply: {verdict!r}")
    return int(match.group(0))

########################################################
#  pipelines
########################################################
from abc import ABC, abstractmethod

class Pipeline(ABC):
    """Abstract base for question-answering pipelines driven by an LLM client.

    Subclasses implement ``run_llm_loop``; the client passed at construction
    is kept on ``self.llm_client`` for them to use.
    """

    def __init__(self, llm_client):
        # Kept as-is; subclasses decide which client methods they call.
        self.llm_client = llm_client

    @abstractmethod
    def run_llm_loop(self, prompt):
        """Run the pipeline on ``prompt``; concrete subclasses define the result."""
        ...


class ThinkTwoStepPipeline(Pipeline):
    """Multi-turn search/answer loop that carries only the last turn forward.

    Each turn the LLM is called on the current observation. The next turn sees
    the original prompt, the ``<summary>`` block extracted from this turn's
    reply (if any), this turn's raw reply, and, for a search, the retrieved
    passages. Anything older is dropped. A ``<search>`` block triggers a
    retrieval call via ``act``; an ``<answer>`` block ends the loop.
    """

    def __init__(self, llm_client):
        super().__init__(llm_client)
        # Two alternative instruction templates with a ``{question}`` placeholder.
        # NOTE(review): run_llm_loop does not currently use them (its formatting
        # lines are commented out), so callers must pass a fully formatted prompt.
        # The templates also ask for a <think> block while run_llm_loop extracts
        # <summary>; confirm which tag the prompts actually in use request.
        self.prompt_templates = ["""You will answer complex questions through iterative summary and web search. After each turn, your previous information will be discarded and the think part will be the only information you have to complete the task.
        So please store useful information in the think part.

Your response must include:

<think>
- Repeat previous searche queries made and think about how they lead to information potentially relevant to the question.
- Based on previous search queries and results, reason about how to adjust your search queries to get better results.
- Keep search queries and results here for future reference as everything else will be discarded.
</think>

Then either:
<search>
QUERY (only if you have turns left)
</search>

Or:
<answer>
FINAL ANSWER ONLY (no explanation)
</answer>

Follow this format strictly for your response so that it's either <think>...</think><search>...</search> or <think>...</think><answer>...</answer>.

Question: {question}\n
""", """You will answer complex questions through iterative summary and web search.

Your response must include:

<think>
- Keep information from the current information that is potentially relevant and useful for answering the question.
- The current information will be discarded in the next step and the think part will be the only information you have to complete the task.
- You should also summarize previous searches you have made to avoid repetitive searches.
- You will be told how many turns you have left inside the information given to you after you have made a search. You should keep track of the number of turns you have left.
</think>

Then either:
<search>
QUERY (only if you have turns left)
</search>

Or:
<answer>
FINAL ANSWER ONLY (no explanation)
</answer>

Follow this format strictly for your response so that it's either <think>...</think><search>...</search> or <think>...</think><answer>...</answer>.

Question: {question}\n"""]

    def run_llm_loop(self, prompt: str, model: str = "openai/gpt-4o-mini"):
        """Run up to ``MAX_ITERATION`` LLM turns starting from ``prompt``.

        Args:
            prompt: Fully formatted initial prompt; it is the prefix of every
                turn's observation.
            model: Model identifier forwarded to ``llm_client.generate_response``.

        Returns:
            ``(answer, results_dict)``. ``answer`` is the stripped contents of
            the ``<answer>`` block. It is ``None`` if a reply had no
            recognizable action or the turn budget ran out.
            ``results_dict`` records the trajectory:
              - ``"q"``: the prompt.
              - ``"t{k}"`` for turn ``k``: the extracted summary, or ``""``.
              - ``"r{k}"``: the raw reply (search and answer turns only).
              - ``"i{k}"``: the ``<information>`` block fed to the next turn
                (search turns only; ``""`` on the final turn).
        """
        # NOTE: template formatting is disabled; ``prompt`` is used verbatim.
        # template = random.choice(self.prompt_templates)
        # prompt = template.format(question=prompt)
        cur_response = ""
        cur_obs = prompt
        iteration_cnt = 0
        # Initialize results tracking dictionary
        results_dict = {"q": prompt}

        while iteration_cnt < MAX_ITERATION:
            # Query the LLM on the current observation (prompt + carry-over).
            cur_response = self.llm_client.generate_response(cur_obs, model=model)

            summary = extract_summary(cur_response, summary_tag="summary")
            # Carry-over from the previous turn (everything after the prompt);
            # empty on the first turn, so nothing is archived then.
            memory = cur_obs[len(prompt):]
            # Archive the carry-over when the client reports a memory system.
            if self.llm_client.has_memory and memory:
                self.llm_client.memory_system.add_note(memory)
            if summary:
                # Store summary in results dictionary
                results_dict[f"t{iteration_cnt}"] = summary
                # Next observation restarts from the prompt plus this summary.
                cur_obs = prompt + summary
            else:
                results_dict[f"t{iteration_cnt}"] = ""
                cur_obs = prompt
            
            # NOTE(review): act() runs the retrieval call immediately, including
            # on the final turn where the results are then discarded.
            action_dict = act(cur_response)

            # Turns remaining after this one; with <= 1 left the hint tells the
            # model to answer now.
            num_turns_left = MAX_ITERATION - iteration_cnt - 1
            if num_turns_left > 1:
                hint = f"[HINT]You have {num_turns_left} turns left.[/HINT]"
            else:
                hint = f"[HINT]You have {num_turns_left} turn left. You must answer the question now.[/HINT]"

            if action_dict is None:
                # Reply had neither a search nor an answer block: give up.
                return None, results_dict
            elif action_dict["type"] == "search":
                search_results = action_dict["content"]
                search_results = f"<information>\n{hint}\n{search_results}\n</information>"
                # Store the full raw reply (which contains the search query)
                results_dict[f"r{iteration_cnt}"] = cur_response
                # Store information in results dictionary; on the final turn no
                # later turn will read it, so it is recorded as empty.
                if iteration_cnt == MAX_ITERATION - 1:
                    results_dict[f"i{iteration_cnt}"] = ""
                else:
                    results_dict[f"i{iteration_cnt}"] = search_results
                # Next turn sees: prompt [+ summary] + this reply + retrieved info.
                next_obs = cur_obs + cur_response + search_results
            elif action_dict["type"] == "answer":
                # Store final answer in results dictionary
                results_dict[f"r{iteration_cnt}"] = cur_response
                return action_dict["content"], results_dict
            # Only reached after a search turn (other branches return above).
            cur_obs = next_obs

            iteration_cnt += 1
        
        # Turn budget exhausted without an answer.
        return None, results_dict


class ThinkSummaryOneStepPipeline(Pipeline):
    """Placeholder pipeline; not implemented yet.

    NOTE(review): ``run_llm_loop`` currently returns ``None`` rather than the
    ``(answer, results_dict)`` tuple that ``ThinkTwoStepPipeline`` returns.
    """

    def __init__(self, llm_client):
        super().__init__(llm_client)
        # No prompt template has been written for this pipeline yet.
        self.prompt_template = None

    def run_llm_loop(self, prompt):
        """Unimplemented stub: ignores ``prompt`` and returns ``None``."""
        return None


class ThinkSummaryTwoStepPipeline(Pipeline):
    """Placeholder pipeline; not implemented yet.

    NOTE(review): ``run_llm_loop`` currently returns ``None`` rather than the
    ``(answer, results_dict)`` tuple that ``ThinkTwoStepPipeline`` returns.
    """

    def __init__(self, llm_client):
        super().__init__(llm_client)
        # No prompt template has been written for this pipeline yet.
        self.prompt_template = None

    def run_llm_loop(self, prompt):
        """Unimplemented stub: ignores ``prompt`` and returns ``None``."""
        return None
