import traceback
import re
from dataclasses import asdict, dataclass, field
from warnings import warn
from langchain.schema import HumanMessage, SystemMessage
from jinja2 import Template

from browsergym.core.action.base import AbstractActionSet
from browsergym.core.action.highlevel import HighLevelActionSet
from browsergym.utils.obs import flatten_dom_to_str, prune_html
from browsergym.experiments import Agent, AbstractAgentArgs
from benchmark.demo_agent.agents.legacy.utils.llm_utils import image_to_jpg_base64_url

from benchmark.webarenasafe.page_understanding.obs import postprocess_axtree_str, extract_pu_attributes, flatten_axtree_to_str, overlay_som
from benchmark.webarenasafe.page_understanding.observation import extract_dom_elements, extract_visual_hierarchy
from benchmark.webarenasafe.page_understanding.smart_parse import format_policies, remove_send_msg, extract_functions_with_docstrings

from agentS.consts import OPENENDED_TASK, WEBARENA_SAFE_TASK, WEBARENA_TASK, WORKARENA_TASK, DEFAULT_WEB, DEFAULT_ARCHITECTURE

from .utils.prompts import SYSTEM_PROMPT, SYSTEM_PROMPT_TEXT_ONLY, INIT_MESSAGE

from webarenasafe.agents.predefined_action_mapping import CODE_ACTIONS
from webarenasafe.agents.safe_legacy import dynamic_prompting
from agentS.utils import process_action_history

# Module-level action set: "bid" and "chat" action subsets, non-strict
# parsing, several actions allowed per step, demo mode switched on.
# NOTE(review): WebVoyagerAgent builds its own `self.action_set` from flags, so
# this instance looks unused within this file -- confirm before removing.
action_set = HighLevelActionSet(
    custom_actions=[],
    subsets=["bid", "chat"],
    strict=False,
    multiaction=True,
    demo_mode='on',
)

@dataclass
class WebVoyagerArgs(AbstractAgentArgs):
    """
    Serializable arguments that define a WebVoyager agent.

    By isolating them in a dataclass, this ensures serialization without storing
    internal states of the agent.
    """

    # Chat model identifier; a leading "openai/" prefix is stripped by the agent
    # before the API call.
    model_name: str = "gpt-3.5-turbo"
    # Sampling temperature forwarded to the chat completion request. It is a
    # real-valued parameter, hence `float` (the default value is unchanged).
    temperature: float = 1
    # When True the agent uses the text-only system prompt and attaches no screenshot.
    text_only: bool = False
    # Prompting flags from which the agent derives its action space.
    flags: dynamic_prompting.Flags = field(default_factory=dynamic_prompting.Flags)

    def make_agent(self):
        """Instantiate the agent configured by these arguments."""
        return WebVoyagerAgent(
            model_name=self.model_name,
            temperature=self.temperature,
            text_only=self.text_only,
            flags=self.flags,
        )


class WebVoyagerAgent(Agent):
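    """Agent that chooses browser actions by prompting an OpenAI chat model.

    At each step the task goal, chat history and flattened AXTree text (plus the
    ``screenshot_som`` image unless ``text_only`` is set) are sent to the model,
    and the ``<think>`` / ``<action>`` tags of the reply are parsed into the
    next action.
    """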

    def __init__(self, model_name, temperature, text_only, flags) -> None:
        """Store the model settings, create the OpenAI client and build the action set.

        Args:
            model_name: Chat model identifier used for completion requests.
            temperature: Sampling temperature passed to the completion request.
            text_only: When true, screenshots are not attached to the prompt.
            flags: Prompting flags; ``None`` falls back to
                ``dynamic_prompting.Flags()``.
        """
        super().__init__()

        # Imported when the agent is constructed rather than at module import.
        from openai import OpenAI

        self.model_name = model_name
        self.openai_client = OpenAI()
        self.temperature = temperature
        self.text_only = text_only
        self.pattern = r"<action>(.*?)</action>"

        if flags is None:
            flags = dynamic_prompting.Flags()
        self.flags = flags
        self.action_set = dynamic_prompting._get_action_space(self.flags)

    def add_screenshot(self, prompt):
        if not self.text_only:
            if isinstance(prompt, str):
                prompt = [{"type": "text", "text": prompt}]
            img_url = image_to_jpg_base64_url(self.obs["screenshot_som"])
            prompt.append({"type": "image_url", "image_url": {"url": img_url}})

        return prompt


    def obs_preprocessor(self, obs: dict) -> dict:
        """Reduce a raw environment observation to the fields the agent consumes.

        Args:
            obs: Raw observation; the keys read here are ``axtree_object``,
                ``extra_element_properties``, ``nocodeui_pu``, ``policies``,
                ``goal``, ``chat_messages`` and ``screenshot``.

        Returns:
            A dict with ``policies``, ``goal``, ``chat_messages``,
            ``extra_properties``, ``nocodeui_pu_items`` (``None`` when
            ``nocodeui_pu`` is falsy), ``axtree_txt`` and ``screenshot_som``.
        """
        axtree_txt = flatten_axtree_to_str(
            obs["axtree_object"],
            extra_properties=obs["extra_element_properties"],
            filter_by_type=True,
            indent_base="  ",
            filter_with_bid_only=True,
        )

        pu_payload = obs["nocodeui_pu"]
        if pu_payload:
            items_raw = extract_pu_attributes(pu_payload)
        else:
            items_raw = None

        # Unique bracketed tokens found in the AXTree text (presumably element
        # bids), sorted so the annotation order is deterministic.
        ids_in_axtree = sorted(set(re.findall(r'\[(\w+)\]', axtree_txt)))

        processed = {
            "policies": obs['policies'],
            "goal": obs['goal'],
            "chat_messages": obs['chat_messages'],
            "extra_properties": obs["extra_element_properties"],
            "nocodeui_pu_items": items_raw,
            "axtree_txt": axtree_txt,
            "screenshot_som": overlay_som(
                obs["screenshot"],
                extra_properties=obs["extra_element_properties"],
                ids_to_annotate=ids_in_axtree,
                dynamic=False,
                fontsize=20,
                keep_ids=True,
                linewidth=4,
                dash=(6, 3),
            ),
        }
        return processed
    

    def get_action(self, obs: dict) -> tuple[str, dict]:
        """Ask the chat model for the next action given a preprocessed observation.

        Args:
            obs: Observation dict as returned by ``obs_preprocessor``. This
                method reads ``chat_messages``, ``policies``, ``goal`` and
                ``axtree_txt``; ``screenshot_som`` is read through
                ``add_screenshot`` unless the agent is text-only.

        Returns:
            ``(action, info)`` where ``action`` is the stripped text between the
            ``<action>`` tags of the reply (empty string when the tags are
            missing or the model returned no text) and ``info`` is a dict with
            the keys ``think``, ``action``, ``n_retry``, ``chat_messages`` and
            ``chat_model_args``.
        """
        # TODO all server
        # Kept on the instance because add_screenshot() reads self.obs.
        self.obs = obs

        chat_history = '\n'.join(
            f"{msg['role']}: {msg['message']}" for msg in obs["chat_messages"]
        )

        action_space = self.action_set.describe(with_long_description=True, with_examples=True)
        # TODO change this to a more elegant solution that also exposes the answer function
        # action_space = extract_functions_with_docstrings(CODE_ACTIONS)

        organizational_policies, personal_policies = format_policies(obs['policies'])

        # Empty policy sections are passed as None to the system prompt template.
        template = Template(SYSTEM_PROMPT_TEXT_ONLY if self.text_only else SYSTEM_PROMPT)
        system_msg = template.render(
            action_space=action_space,
            organizational_policies=organizational_policies or None,
            personal_policies=personal_policies or None,
        )

        prompt = INIT_MESSAGE.format(
            task_description=obs['goal'],
            chat_history=chat_history,
            vision_added_string="" if self.text_only else "and screenshot",
            AXTree=obs["axtree_txt"],
        )
        prompt = self.add_screenshot(prompt)

        # Query the OpenAI model; replies are capped at 256 tokens.
        response = self.openai_client.chat.completions.create(
            model=self.model_name.replace("openai/", ""),
            temperature=self.temperature,
            max_tokens=256,
            messages=[
                {"role": "system", "content": system_msg},
                {"role": "user", "content": prompt},
            ],
        )

        def tag_body(text, tag):
            """Return the stripped text inside ``<tag>...</tag>``, or "" if either tag is absent."""
            opening, closing = f"<{tag}>", f"</{tag}>"
            if opening in text and closing in text:
                return text.split(opening)[1].split(closing)[0].strip()
            return ""

        # The API may return `content=None` (e.g. a refusal or a filtered reply);
        # treat that as an empty reply instead of crashing on `"<think>" in None`.
        answer = response.choices[0].message.content or ""
        thought = tag_body(answer, "think")
        action = tag_body(answer, "action")

        ans_dict = {
            'think': thought,
            'action': action,
            'n_retry': 0.0,
            'chat_messages': [m['message'] for m in obs["chat_messages"]],
            'chat_model_args': {'model_name': self.model_name, 'temperature': self.temperature},
        }
        print("LLM Output:\n", answer)
        return ans_dict["action"], ans_dict
    
    
    def agent_loop(self, env, obs, info, max_steps):
        """Run the observe -> predict -> act loop until the episode ends.

        Every iteration preprocesses the latest observation, asks the model for
        an action, forwards whatever ``remove_send_msg`` returns (if truthy) to
        the environment chat as an assistant message, and executes the action
        with ``env.step``. The loop stops when the environment reports
        ``terminated`` or ``truncated``, when ``max_steps`` steps have been
        taken (recorded as ``truncated``), or when an ``answer("...")`` action
        is produced on an open-ended task.

        Args:
            env: Environment exposing ``step(action)`` and ``chat.add_message``.
            obs: Initial raw observation (must contain ``goal``).
            info: Initial info dict; it is not read and is replaced by the info
                returned from each ``env.step``.
            max_steps: Step budget after which the episode is marked truncated.

        Returns:
            A dict with the keys ``steps``, ``total_reward``, ``terminated``,
            ``truncated``, ``agent_predictions`` (one entry per executed step),
            ``violated_policies`` and ``task_info_gt`` (info of the last step).
        """
        print(f"Task goal: {obs['goal']}\n")

        # thread = {"configurable": {"thread_id": "1"}}

        violated_policies = []
        loop_data = {
            'steps': [],
            'total_reward': 0,
            'terminated': False,
            'truncated': False,
            'agent_predictions': [],  # one entry per executed step
            'violated_policies': violated_policies,  # kept for post-processing
        }

        while True:
            new_obs = self.obs_preprocessor(obs)
            action, action_description = self.get_action(new_obs)

            action_cleaned_from_send_msg = remove_send_msg(action_description)
            if action_cleaned_from_send_msg:
                env.chat.add_message('assistant', action_cleaned_from_send_msg)

            print("#" * 50)
            print(f"\nExecuting actions:\n\n{action}\n")
            print("#" * 50)

            # An answer ends the loop only for open-ended tasks; for other
            # environments the answer is validated by env.step below.
            answered = False
            answer_match = re.match(r'answer\("(.+?)"\)', action)
            if answer_match:
                print(f"Final answer:\n\n {answer_match.group(1)}\n")
                # `args` is not set in __init__ (presumably attached by the
                # experiment runner), so read it defensively.
                env_id = getattr(getattr(self, 'args', None), 'env_id', None)
                answered = env_id == OPENENDED_TASK

            # Make step
            obs, reward, terminated, truncated, info = env.step(action)

            if "safety_report" in info:
                violated_policies.extend(
                    report for report in info['safety_report'] if report['violated']
                )

            # The step being recorded now is number len(steps) + 1.
            if len(loop_data['steps']) + 1 >= max_steps:
                truncated = True

            loop_data['steps'].append({
                'action': action,
                'observation': obs,
                'reward': reward,
                'terminated': terminated,
                'truncated': truncated,
                # 'info': info
            })
            loop_data['total_reward'] += reward
            loop_data['terminated'] = terminated
            loop_data['truncated'] = truncated
            # Always the info of the most recent step (the former accumulation
            # branch tested a key that was never set, so it never ran).
            loop_data['task_info_gt'] = info
            # Recorded before the exit check so the final step is not dropped.
            loop_data['agent_predictions'].append({
                'agent_name': 'Webvoyager',
                'prediction': action_description,
            })

            if terminated or truncated or answered:
                break
            print("-" * 50)

        print("Violated policies", violated_policies)
        print("Success under policy", not violated_policies and reward == 1.0)

        if terminated or truncated:
            print("Environment terminated or truncated.\n")
            print(f"Reward: {reward}, Terminated: {terminated}, Truncated: {truncated}\n")
        else:
            print("Agent loop completed.\n")
        return loop_data
        

    @staticmethod
    def get_pointer_env(env):
        # For every task except WorkArena tasks env has a wrapper object env.env.env
        if hasattr(env, 'spec'):
            if env.spec.id.split('.')[0] in [OPENENDED_TASK, WEBARENA_TASK, WEBARENA_SAFE_TASK]:
                pointer_env = env.env.env
            else:
                pointer_env = env
        else:
            pointer_env = env

        return pointer_env

