""" 
    This file contains the code for calling all LLM APIs. 
    Adapted from MLAgentBench/LLM.py
    Only support OpenAI API and open-source models for now.
"""

import os
import re
import sys
import torch
import backoff
import openai

import json
import time
import requests
import tiktoken
import matplotlib.pyplot as plt
import networkx as nx

from requests.exceptions import Timeout, ConnectionError
from verl.utils.conversation import get_conv_template, Conversation
import envs.MLAgentBench.high_level_actions as high_level_actions
from envs.MLAgentBench.schema import Action


HUMAN_PROMPT = "\n\nHuman:"
AI_PROMPT = "\n\nAssistant:"
FAST_MODEL = "gpt-4o-mini"

initial_prompt = """You are a helpful research assistant. You have access to the following tools:
{tools_prompt}

Research Problem: {task_description}

Always respond in this format exactly:
{format_prompt}
Observation: 
```
the result of the action
```

"""

format_prompt_dict = {
    "Reflection": "What does the observation mean? If there is an error, what caused the error and how to debug?",
    "Research Plan and Status": "The full high level research plan, with current status and confirmed results of each step briefly annotated. It must only include progress that has been made by previous steps. If there is any update, enclose the new update text in double asterisks **like this**. If there is no update, just copy the previous step Research Plan and Status. The high level plan from the previous step should be fully retained, unless it is intentionally revised.",
    "Fact Check": "List all objective statements in the updates to Research Plan and Status one by one and point out whether it is guessed versus directly confirmed by the previous observation directly above. Performance numbers can only be confirmed by running the code and observing the output.",
    "Thought": "What you are currently doing, what actions to perform and why",
    "Action": "the action to take, should be one of the names of the tools",
    "Action Input": "the input to the action as a valid JSON string",
}

def _add_to_set(s, new_stop):
    if not s:
        return
    if isinstance(s, str):
        new_stop.add(s)
    else:
        new_stop.update(s)


class Agent:
    """LLM-driven research agent for MLAgentBench-style environments.

    The agent builds an initial prompt from the environment's tool
    descriptions and research problem. It then repeatedly asks an LLM for a
    response in a fixed entry format (see ``format_prompt_dict``), parses the
    ``Action`` / ``Action Input`` entries and executes them in the
    environment.

    It also provides:
    - format-reward scoring;
    - local (``generate``) and remote (``__call__``) generation;
    - a beam-search tree plot.
    """

    def __init__(self, args, env) -> None:
        """Set up logging, the list of prompt-visible tools and the initial prompt.

        Args:
            args: run configuration. Reads ``log_dir``, ``valid_format_entries``,
                ``edit_script_llm_name``, ``edit_script_llm_max_tokens`` and
                ``template_name``.
            env: environment exposing ``action_infos`` (tool name -> info
                object) and ``research_problem``.

        Side effects:
            Creates ``<log_dir>/agent_log``. Overwrites
            ``high_level_actions.EDIT_SCRIPT_MODEL``,
            ``high_level_actions.EDIT_SCRIPT_MAX_TOKENS`` and the module-level
            ``FAST_MODEL`` with the edit-script LLM settings.
        """
        self.args = args
        self.log_dir = os.path.join(args.log_dir, "agent_log")
        os.makedirs(self.log_dir, exist_ok=True)
        self.action_infos = env.action_infos
        # e.g. ["Reflection", "Research Plan and Status", "Fact Check", "Thought", "Action", "Action Input"]
        # NOTE(review): an earlier comment said "Do not work as the sft"; its meaning is unclear.
        self.valid_format_entries = args.valid_format_entries

        # Tools that are available in the env but hidden from the prompt.
        # Names missing from env.action_infos are simply ignored.
        actions_remove_from_prompt = (
            "Reflection", "Undo Edit Script", "Read File", "Write File", "Append File",
            "Retrieval from Research Log", "Append Summary to Research Log", "Python REPL",
            "Edit Script Segment (AI)",
        )
        tool_names = [t for t in env.action_infos.keys() if t not in actions_remove_from_prompt]
        self.prompt_tool_names = tool_names

        high_level_actions.EDIT_SCRIPT_MODEL = args.edit_script_llm_name
        high_level_actions.EDIT_SCRIPT_MAX_TOKENS = args.edit_script_llm_max_tokens
        global FAST_MODEL
        FAST_MODEL = args.edit_script_llm_name
        self.tools_prompt = self.construct_tools_prompt(tool_names, env.action_infos)

        # ``initial_prompt`` here is the module-level template, not the attribute.
        self.initial_prompt = initial_prompt.format(
            tools_prompt=self.tools_prompt,
            tool_names=self.prompt_tool_names,
            task_description=env.research_problem,
            format_prompt="\n".join(f"{k}: {format_prompt_dict[k]}" for k in self.valid_format_entries),
        )
        self.template_name = args.template_name

    def set_stop(self, stop_str, stop_token_id, stop_words):
        """Record the stop string, stop token ids and stop words used for generation."""
        self.stop_str = stop_str
        self.stop_token_ids = stop_token_id
        self.stop_words = stop_words

    def run_no_reward(self, env, sampler):
        """Run the agent loop until ``env.check_final()`` without any reward signal.

        Each step does the following:
        - samples one completion via ``sampler._sample_step``;
        - parses its ``Action`` / ``Action Input`` entries;
        - executes the action in ``env``;
        - appends the observation (plus any format notice) to the conversation.

        Everything is also appended to ``<log_dir>/agent_log/main_log``. Only
        single-beam, single-sample decoding is supported.
        """
        assert self.args.n_beam == 1 and self.args.num_sampling_sequences == 1
        log_path = os.path.join(self.log_dir, "main_log")
        with open(log_path, "a", 1) as f:
            f.write(self.initial_prompt + "\n")

        conv = get_conv_template(self.template_name)
        conv.append_message(conv.roles[0], self.initial_prompt)
        conv.append_message(conv.roles[1], None)  # placeholder for the assistant reply

        # Stop at "Observation:" so the model does not write the env's reply (MLAgentBench convention).
        self.set_stop(conv.stop_str, conv.stop_token_ids, ["Observation:"])

        while not env.check_final():
            header = f"==================== Step {env.current_step} ===================="
            with open(log_path, "a", 1) as f:
                f.write(header)
            print(header)
            # ===== generate single valid action ===== #
            completion = sampler._sample_step(
                conv=conv,
                agent=self,
                max_new_tokens=2048,  # maximal length for a single step
            )

            env.current_step += 1

            # ===== parse LLM output to env actions ===== #
            entries = self.parse_entries(completion, self.valid_format_entries)
            action = entries["Action"].strip()
            # NOTE(review): presumably the sampler only yields valid actions; this is a sanity check.
            assert action in self.prompt_tool_names
            raw_action_input = entries["Action Input"]

            try:
                action_input = self.parse_action_input(raw_action_input, self.action_infos[action])
            except Exception:
                # Fall back to the raw text; get_format_reward's notice below explains the problem to the model.
                action_input = raw_action_input

            # Add format related warnings
            notice, _ = self.get_format_reward(completion, self.valid_format_entries, return_notice=True)
            observation = env.execute(Action(action, action_input))

            conv.update_last_message(completion)
            conv.append_message(conv.roles[0], "Observation: " + observation + notice)
            conv.append_message(conv.roles[1], None)

            with open(log_path, "a", 1) as f:
                f.write(f"\n{completion}\nObservation: {observation}\n")

        print("********************* Finished! *********************")

    @torch.no_grad()
    def generate(self, actor, conv, max_new_tokens) -> str:
        """Sample one assistant reply locally from ``actor``.

        ``actor`` must expose ``tokenizer`` and ``model``. The model needs a
        ``device`` and a ``generate`` method; presumably it is a Hugging Face
        ``transformers`` model (TODO confirm).

        Generation stops at ``conv.stop_token_ids[-1]``. The decoded text is
        cut at the first ``self.stop_words[0]``, so ``set_stop`` must have
        been called beforehand.
        """
        input_ids = actor.tokenizer(conv.get_prompt(), return_tensors="pt").input_ids.to(actor.model.device)
        eos_token_id = conv.stop_token_ids[-1]
        generate_args = {
            "input_ids": input_ids,
            "top_k": None,
            "top_p": 1,
            "do_sample": True,
            "temperature": 1,
            "use_cache": True,
            "eos_token_id": eos_token_id,
            "pad_token_id": actor.tokenizer.pad_token_id,
            "min_new_tokens": 1,
            "max_new_tokens": max_new_tokens,
        }

        output_ids = actor.model.generate(**generate_args)[0][input_ids.shape[-1]:]
        # Drop the trailing stop token only if generation actually ended on it.
        # When max_new_tokens is reached, the last token is real content.
        if len(output_ids) > 0 and output_ids[-1].item() == eos_token_id:
            output_ids = output_ids[:-1]
        output_decoded = actor.tokenizer.decode(output_ids, skip_special_tokens=False)
        return output_decoded.split(self.stop_words[0])[0]

    def __call__(self, messages, max_new_tokens) -> str:
        """Query an OpenAI-compatible chat endpoint and return the reply text.

        The client is chosen from ``args.generator_model_name_or_path``:
        - ``gpt*``: uses ``OPENAI_API_KEY``;
        - ``deepseek*``: uses the (currently empty) DeepSeek base URL;
        - anything else: ``<args.controller_address>/v1``.

        Models whose name contains ``deepseek-r1`` are streamed. The text
        before ``</think>`` is dropped once an ``Observation:`` marker
        follows it. Other models are called with the stop words/string from
        ``set_stop``.

        Retries up to 3 times on timeouts and connection errors (from either
        ``requests`` or the ``openai`` client), then raises ``Exception``.
        """
        controller_addr = self.args.controller_address
        model_name = self.args.generator_model_name_or_path
        new_stop = set()
        _add_to_set(self.stop_words, new_stop)
        _add_to_set(self.stop_str, new_stop)
        print(f">> New stop: {new_stop}")
        for _ in range(3):
            try:
                # NOTE(review): the empty base_url values look like redacted endpoints; fill them in before use.
                if model_name.startswith("gpt"):
                    client = openai.OpenAI(api_key=os.environ["OPENAI_API_KEY"], base_url="")
                elif model_name.startswith("deepseek"):
                    client = openai.OpenAI(base_url="", api_key="empty")
                else:  # locally deployed model
                    client = openai.OpenAI(base_url=f"{controller_addr}/v1", api_key="empty")

                if "deepseek-r1" in model_name.lower():
                    stream = client.chat.completions.create(
                        model=model_name,
                        messages=messages,
                        temperature=0.7,
                        max_tokens=max_new_tokens,
                        stream=True,
                    )
                    text = ""
                    for chunk in stream:
                        if not chunk.choices:
                            # Some servers send choice-less chunks (e.g. usage info).
                            continue
                        choice = chunk.choices[0]
                        if hasattr(choice, "delta"):
                            if choice.delta.content is not None:
                                text += choice.delta.content
                        else:
                            text += choice.text

                        # Stop streaming once the answer part (after </think>) reaches "Observation:".
                        if "</think>" in text and "Observation:" in text:
                            answer = text.split("</think>")[1]
                            if "Observation:" in answer:
                                return answer.split("Observation:")[0]
                    return text

                completion = client.chat.completions.create(
                    model=model_name,
                    messages=messages,
                    temperature=0.7,
                    max_tokens=max_new_tokens,
                    stop=list(new_stop),
                )
                return completion.choices[0].message.content
            # The openai client raises its own error types, not requests' ones.
            # APITimeoutError subclasses APIConnectionError, so it is checked first.
            except (Timeout, openai.APITimeoutError):
                print("Timeout, retrying...")
            except (ConnectionError, openai.APIConnectionError):
                print("Connection error, retrying...")
            time.sleep(5)
        raise Exception("Timeout after 3 retries.")

    def tree_plot(self, num_beam, num_sqs):
        """Render the beam-search tree logged in ``<log_dir>/agent_log/beam_tree_log.jsonl``.

        Each JSONL record must provide:
        - ``step``;
        - ``outputs``;
        - ``verifier_scores``;
        - ``choices``: indices of the outputs kept at that step.

        Each node is labelled with the parsed ``Action`` and its verifier
        score, and coloured by whether it was chosen. Output ``i`` of step
        ``s`` is linked to kept output ``choices[i // num_sqs]`` of the
        previous record. The figure is saved to
        ``<log_dir>/beam_search_tree.png``. ``num_beam`` is currently unused.
        """
        data_path = os.path.join(self.args.log_dir, "agent_log/beam_tree_log.jsonl")
        with open(data_path, "r") as f:
            # Skip blank lines (e.g. a trailing newline), which json.loads rejects.
            jsonl_data = [json.loads(line) for line in f if line.strip()]

        G = nx.DiGraph()
        choice_colors = ["#cccccc", "#fdcdac"]  # index 0: not chosen, 1: chosen

        prev_choices = None  # "choices" of the previously processed record
        for step_data in jsonl_data:
            step = step_data["step"]
            verifier_scores = step_data["verifier_scores"]
            for i, output in enumerate(step_data["outputs"]):
                entries = self.parse_entries(output, self.valid_format_entries)
                node_name = f"Step {step}: Output {i + 1}"
                G.add_node(
                    node_name,
                    label=entries["Action"] + f"\n{verifier_scores[i]:.2f}",
                    action=entries["Action"],
                    choice=(i in step_data["choices"]),
                )

                if step > 1 and prev_choices:
                    # Assumes each kept output of the previous step spawned num_sqs children.
                    parent_idx = prev_choices[i // num_sqs]
                    parent_node_name = f"Step {step - 1}: Output {parent_idx + 1}"
                    G.add_edge(parent_node_name, node_name, weight=2)
            prev_choices = step_data["choices"]

        # Group nodes by step (the first number in "Step X: Output Y").
        layered_nodes = {}
        for node in G.nodes:
            node_step = int(re.findall(r"\d+", node)[0])
            layered_nodes.setdefault(node_step, []).append(node)

        # Vertical layout: one row per step, outputs spread horizontally.
        x_spacing = 1
        y_spacing = 2
        pos = {
            node: (i * x_spacing, -node_step * y_spacing)
            for node_step, nodes in layered_nodes.items()
            for i, node in enumerate(nodes)
        }

        node_colors = [choice_colors[int(G.nodes[node]["choice"])] for node in G.nodes]

        plt.figure(figsize=(5, 11))
        nx.draw(
            G,
            pos,
            with_labels=True,
            labels=nx.get_node_attributes(G, "label"),
            node_size=2000,
            font_size=10,
            node_color=node_colors,
            edge_color="gray",
            arrows=True,
        )

        plt.title("Beam Search Tree", fontsize=16, weight="bold")
        plt.savefig(f"{self.args.log_dir}/beam_search_tree.png", bbox_inches="tight")
        plt.close()

    # ############ Helper Functions #############
    def check_repeated_entries(self, s, entries):
        """Return ``(True, [entries seen more than once])`` or ``(False, None)``.

        An entry counts as present wherever ``"<entry>:"`` occurs in ``s``.
        The match is not anchored to line starts.
        """
        repeated_entries = [
            entry for entry in entries
            if len(re.findall(r"(?=" + re.escape(entry) + r":)", s)) > 1
        ]
        if repeated_entries:
            return True, repeated_entries
        return False, None

    def get_format_reward(self, completion, valid_format_entries, return_notice=False):
        """Score the format of an LLM completion and optionally explain problems.

        A completion is well formatted when all of the following hold:
        - its entries parse;
        - its ``Action`` is one of ``self.prompt_tool_names``;
        - its ``Action Input`` parses for that action.

        Duplicate entries only add a notice; they do not change the reward.

        Returns:
            ``0`` if the completion is well formatted, else ``-1``. With
            ``return_notice=True``, returns ``(notice, reward)``, where
            ``notice`` is feedback text for the model (``""`` when nothing is
            wrong).
        """
        notice = ""
        correct_format = False
        # too many repeated entries
        has_repeated, repeated_entry_list = self.check_repeated_entries(completion, valid_format_entries)
        if has_repeated and return_notice:
            repeated_entry_str = ", ".join(repeated_entry_list)
            notice += f"Note that your response contains duplicate {repeated_entry_str}. This may lead to unexpected or incorrect execution of the action. Please ensure each entry is unique.\n"
        # incorrect action and action input format
        try:
            parsed_entries = self.parse_entries(completion, valid_format_entries)
            action = parsed_entries["Action"].strip()
            # Explicit check rather than ``assert`` so validation survives ``python -O``.
            if action not in self.prompt_tool_names:
                if return_notice:
                    notice += f"The action {action} is not recognized. Please use one of the following actions: {', '.join(self.prompt_tool_names)}.\n"
            else:
                raw_action_input = parsed_entries["Action Input"]
                self.parse_action_input(raw_action_input, self.action_infos[action])
                correct_format = True
        except ParseEntryError:
            if return_notice:
                entry_usage = "\n".join([f"{k}: {format_prompt_dict[k]}" for k in self.valid_format_entries])
                notice += f"Your response format is incorrect. Please ensure that the response follows the correct format.\n{entry_usage}"
        except ParseActionInputError:
            if return_notice:
                usage = ",\n            ".join([f"{k}: [{v}]" for k, v in self.action_infos[action].usage.items()])
                notice += f"The action input for {action} needs to be a valid json with proper entries. You may have missed the comma between entries or used triple quotes (json does not recognizes triple quotes). Please use the correct format and try again:\n{usage}"
        except Exception:
            if return_notice:
                notice += f"An error occurred while parsing the response. Please ensure that the response follows the correct format.\n"

        format_reward = 0 if correct_format else -1
        if return_notice:
            return notice, format_reward
        return format_reward

    @staticmethod
    def construct_tool_prompt(tool_name, action_info):
        """Construct the prompt section for a single tool.

        ``action_info`` must provide ``description``, ``usage`` (argument
        name -> description) and ``return_value``. The exact whitespace of
        the template below is part of the prompt.
        """
        tool = action_info
        usage = ",\n            ".join([f"\"{k}\": [{v}]" for k, v in tool.usage.items()])

        tools_prompt = f"""{tool.description}
        Usage:
        ```
        Action: {tool_name}
        Action Input: {{
            {usage}
        }}
        Observation: [{tool.return_value}]
        ```
            """.strip() + "\n\n"
        return tools_prompt

    @classmethod
    def construct_tools_prompt(cls, tool_names, action_infos):
        """Concatenate ``"- <name>:"`` headers and per-tool prompts for ``tool_names``."""
        tools_prompt = ""
        for tool_name in tool_names:
            tools_prompt += f"""- {tool_name}:
        """
            tools_prompt += cls.construct_tool_prompt(tool_name, action_infos[tool_name])
        return tools_prompt

    @staticmethod
    def parse_entries(s, entries):
        """Split LLM output ``s`` into the given entries, in order.

        Each entry's text runs from ``"<entry>:"`` up to the next line that
        starts with any ``"<entry>:"``, or up to the end of the string.

        Returns:
            ``{entry: raw text}``. The text is not stripped.

        Raises:
            ParseEntryError: if the entries are not found in that order.
        """
        entries = [e.strip() for e in entries]
        escaped = [re.escape(e) for e in entries]

        # A value ends where any known entry starts a new line, or at the end of input.
        all_entries_lookahead = "(?=" + "|".join(f"^{e}:" for e in escaped) + r"|\Z)"
        pattern = "".join(rf"{e}:([\s\S]*?){all_entries_lookahead}" for e in escaped)

        result = re.search(pattern, s, re.MULTILINE)
        if result is None:
            raise ParseEntryError("Invalid: " + s)

        return dict(zip(entries, result.groups()))

    @staticmethod
    def parse_entries_old(s, entries):
        """Legacy greedy entry parser (each value runs to the next entry in order).

        Raises:
            Exception: if the entries are not found in that order.
        """
        entries = [e.strip() for e in entries]
        pattern = "".join(rf"{re.escape(e)}:([\s\S]*)" for e in entries)
        result = re.search(pattern, s, re.MULTILINE)
        if result is None:
            raise Exception("Invalid: " + s)

        return dict(zip(entries, result.groups()))

    @staticmethod
    def sanitize_json_string(s):
        """Best-effort repair of LLM-written JSON before ``json.loads``.

        The repair does the following:
        - removes surrounding Markdown code fences, including an optional
          ``json`` tag;
        - escapes backslashes, forward slashes and the control characters
          backspace, form feed, carriage return and tab;
        - escapes raw newlines inside double-quoted spans.

        Triple-quoted strings are not repaired.
        """
        # Anchored regexes instead of str.strip("```json"), which strips any of the
        # characters ` j s o n from both ends rather than the fence prefix.
        s = re.sub(r"^\s*`*\s*(?:json)?", "", s)
        s = re.sub(r"`*\s*$", "", s).strip()
        s = s.replace('\\', '\\\\')  # Escape backslashes first
        s = s.replace('/', '\\/')  # Escape forward slashes
        s = s.replace('\b', '\\b')  # Escape backspaces
        s = s.replace('\f', '\\f')  # Escape form feeds
        s = s.replace('\r', '\\r')  # Escape carriage returns
        s = s.replace('\t', '\\t')  # Escape horizontal tabs
        # Raw newlines are illegal inside JSON strings.
        return re.sub(r'"([^"]*)"', lambda m: '"' + m.group(1).replace('\n', '\\n') + '"', s)

    @classmethod
    def parse_action_input(cls, s, action_info):
        """Parse an ``Action Input`` string into a dict of the action's arguments.

        Parsing proceeds in three stages:
        1. plain ``json.loads``;
        2. ``json.loads`` on the sanitized string;
        3. regex matching on the original string.

        The JSON result must have exactly the keys of ``action_info.usage``.

        Raises:
            ParseActionInputError: if every method fails.
        """
        raw = s
        try:
            try:
                d = json.loads(s)
            except ValueError:
                d = json.loads(cls.sanitize_json_string(s))

            required_keys = set(action_info.usage.keys())
            if set(d.keys()) != required_keys:
                missing = required_keys - set(d.keys())
                extra = set(d.keys()) - required_keys
                error_msg = f"Key mismatch. Missing: {missing}, Extra: {extra}"
                raise ParseActionInputError(error_msg)
            return d
        except Exception:
            try:
                # Fall back to regex on the *raw* text. The sanitized text would
                # leak escapes such as "\/" into the extracted values.
                return cls.parse_action_input_by_matching(raw, action_info)
            except Exception as err:
                raise ParseActionInputError("Invalid Format") from err

    @staticmethod
    def parse_action_input_by_matching(s, action_info):
        """Extract ``action_info.usage`` keys from ``s`` with a regex.

        Keys must appear in ``usage`` order, as ``"key": value`` pairs
        separated by commas. Surrounding whitespace and quotes are stripped
        from each value.

        Raises:
            Exception: if the pattern does not match.
        """
        entries = list(action_info.usage.keys())
        start = s.find('{')
        s = s[start + 1:]  # start == -1 keeps the whole string
        end = s.rfind('}')
        if end != -1:  # without a closing brace, keep the tail intact
            s = s[:end]
        pattern = r",\s*".join(rf'"{re.escape(e)}":([\s\S]*)' for e in entries)
        result = re.search(pattern, s, re.MULTILINE)

        if result is None:
            raise Exception("Invalid Format")
        return {e: r.strip().strip('"') for e, r in zip(entries, result.groups())}


def log_to_file(log_file, prompt, completion, model, max_tokens_to_sample):
    """Append a prompt/completion pair and their token counts to ``log_file``.

    Token counts use tiktoken's encoding for ``model``. When tiktoken cannot
    resolve the model (e.g. a locally served one), the ``gpt-4o-mini``
    encoding is used instead, so counts for such models are approximate.
    Special-token text (e.g. ``<|endoftext|>``) is counted as plain text
    rather than raising. A ``None`` completion is logged as an empty string.
    """
    try:
        enc = tiktoken.encoding_for_model(model)
    except Exception:
        enc = tiktoken.encoding_for_model("gpt-4o-mini")  # for local model, this is not correct.

    if completion is None:
        completion = ""
    full_prompt = f"{HUMAN_PROMPT} {prompt} {AI_PROMPT}"
    # Count before opening the file so a failure cannot leave a half-written record.
    num_prompt_tokens = len(enc.encode(full_prompt, disallowed_special=()))
    num_sample_tokens = len(enc.encode(completion, disallowed_special=()))

    with open(log_file, "a", encoding="utf-8") as f:
        f.write("\n===================prompt=====================\n")
        f.write(full_prompt)
        f.write(f"\n==================={model} response ({max_tokens_to_sample})=====================\n")
        f.write(completion)
        f.write("\n===================tokens=====================\n")
        f.write(f"Number of prompt tokens: {num_prompt_tokens}\n")
        f.write(f"Number of sampled tokens: {num_sample_tokens}\n")
        f.write("\n\n")

@backoff.on_exception(backoff.expo, openai.RateLimitError)
def complete_text_openai(prompt, stop_sequences=None, model="gpt-3.5-turbo", max_tokens_to_sample=2000, temperature=0.2, log_file=None, **kwargs):
    """Complete ``prompt`` with an OpenAI(-compatible) chat model.

    With openai>=1.0, the client is configured from the environment
    (``OPENAI_API_KEY``, and ``OPENAI_BASE_URL`` for compatible endpoints
    such as DeepSeek). Extra ``kwargs`` are forwarded to the chat
    completion request.

    Returns:
        The reply text. It is also logged to ``log_file`` when one is given.

    Raises:
        openai.RateLimitError: retried with exponential backoff by the decorator.
        Exception: any other API error is printed and re-raised.
    """
    raw_request = {
        "model": model,
        "temperature": temperature,
        "max_tokens": max_tokens_to_sample,
        "stop": stop_sequences or None,  # API doesn't like empty list
        **kwargs,
    }
    messages = [{"role": "user", "content": prompt}]
    try:
        if int(openai.__version__.split(".")[0]) < 1:
            response = openai.ChatCompletion.create(**{"messages": messages, **raw_request})
            completion = response["choices"][0]["message"]["content"]
        else:
            client = openai.OpenAI()
            response = client.chat.completions.create(**{"messages": messages, **raw_request})
            completion = response.choices[0].message.content
    except Exception as e:
        # Re-raise so the backoff decorator can retry rate limits, and so other
        # failures do not surface later as an unbound ``completion``.
        print(e)
        raise

    if log_file is not None:
        log_to_file(log_file, prompt, completion, model, max_tokens_to_sample)
    return completion

def complete_text_local(prompt, stop_sequences=None, model="Qwen/Qwen2.5-Coder-32B-Instruct", max_tokens_to_sample=2000, temperature=0.2, log_file=None, **kwargs):
    """Complete ``prompt`` with a locally served, OpenAI-compatible model.

    The request always names the module-level ``FAST_MODEL``. ``model`` is
    only used for logging.
    NOTE(review): confirm whether ``model`` should be sent instead.
    NOTE(review): ``base_url`` is empty (presumably redacted); set it to the
    local server before use.

    Raises:
        Exception: any API error is printed and re-raised.
    """
    client = openai.OpenAI(base_url="", api_key="empty")
    raw_request = {
        "model": FAST_MODEL,
        "temperature": temperature,
        "max_tokens": max_tokens_to_sample,
        "stop": stop_sequences or None,  # API doesn't like empty list
        **kwargs,
    }
    messages = [{"role": "user", "content": prompt}]
    try:
        response = client.chat.completions.create(**{"messages": messages, **raw_request})
        completion = response.choices[0].message.content
    except Exception as e:
        # Re-raise instead of falling through to an unbound ``completion``.
        print(e)
        raise

    if log_file is not None:
        log_to_file(log_file, prompt, completion, model, max_tokens_to_sample)
    return completion

def complete_text(prompt, log_file, model, **kwargs):
    """Complete text with the backend matching ``model`` and return the reply.

    Routing:
    - names containing ``gpt`` or ``deepseek`` (case-insensitive) use the
      OpenAI API;
    - names containing ``Qwen2.5-Coder-32B-Instruct`` use the local server.

    Every backend stops at ``"Observation:"``.

    Raises:
        NotImplementedError: for any other model name.
    """
    model_lower = model.lower()
    if "gpt" in model_lower:
        # use OpenAI API
        completion = complete_text_openai(prompt, stop_sequences=["Observation:"], log_file=log_file, model=model, **kwargs)
    elif "Qwen2.5-Coder-32B-Instruct" in model:
        print("========================== Use Qwen ==========================")
        # use local model qwen
        completion = complete_text_local(prompt, stop_sequences=["Observation:", "<|im_end|>"], log_file=log_file, model=model, **kwargs)
    elif "deepseek" in model_lower:
        # DeepSeek exposes an OpenAI-compatible API.
        completion = complete_text_openai(prompt, stop_sequences=["Observation:"], log_file=log_file, model=model, **kwargs)
    else:
        raise NotImplementedError(f"Unsupported model for complete_text: {model}")
    return completion

def complete_text_fast(prompt, **kwargs):
    """Complete ``prompt`` with the current module-level ``FAST_MODEL`` at temperature 0.2.

    ``FAST_MODEL`` is read at call time, so a value set by ``Agent.__init__``
    is honoured. Remaining ``kwargs`` (e.g. ``log_file``) go to ``complete_text``.
    """
    fast_model = FAST_MODEL
    return complete_text(
        prompt=prompt,
        model=fast_model,
        temperature=0.2,
        **kwargs,
    )


class ParseEntryError(Exception):
    """Raised when LLM output cannot be split into the expected format entries."""

    def __init__(self, error_info):
        self.error_info = error_info  # payload describing what failed to parse
        super().__init__(error_info)

    def __str__(self):
        info = self.error_info
        return str(info)

class ParseActionInputError(Exception):
    """Raised when an ``Action Input`` cannot be parsed into the action's arguments."""

    def __init__(self, error_info):
        self.error_info = error_info  # payload describing what failed to parse
        super().__init__(error_info)

    def __str__(self):
        info = self.error_info
        return str(info)
