import json
import re
from pathlib import Path
from typing import Any, TypedDict

from browser_env import ActionParsingError, Trajectory
from browser_env.env_config import URL_MAPPINGS
from browser_env.utils import StateInfo
from llms import lm_config
from llms.tokenizers import Tokenizer
from llms.utils import APIInput


class _InstructionOptional(TypedDict, total=False):
    """Keys that only some instruction files provide."""

    # Text appended after the formatted template by the constructors in this
    # file that build their prompt as `current + finale`.
    finale: str


class Instruction(_InstructionOptional):
    """Instruction for constructing prompt"""

    # Introductory text; sent as the system message / prompt preamble.
    intro: str
    # Few-shot (input, output) pairs.
    examples: list[tuple[str, str]]
    # `str.format` template filled in for the current step.
    template: str
    # Free-form settings read by the constructors, e.g. "keywords",
    # "action_splitter", "answer_phrase", "force_prefix".
    meta_data: dict[str, Any]


class PromptConstructor(object):
    """Base class that turns an instruction file and a page state into LM input.

    Subclasses implement `construct` (build the prompt) and `_extract_action`
    (pull the action string out of a model response).
    """

    def __init__(
        self,
        instruction_path: str | Path,
        lm_config: lm_config.LMConfig,
        tokenizer: Tokenizer,
    ):
        """Load the instruction JSON and keep the LM configuration.

        Args:
            instruction_path: Path of the JSON instruction file.
            lm_config: Provider / mode / model / generation settings.
            tokenizer: Tokenizer kept on the instance for subclasses.
        """
        self.instruction_path = Path(instruction_path)
        self.obs_modality = "text"
        self.lm_config = lm_config
        # Use a context manager so the handle is closed deterministically.
        with open(self.instruction_path, encoding="utf-8") as f:
            instruction = json.load(f)
        # JSON has no tuple type; restore the (input, output) pairs.
        instruction["examples"] = [tuple(e) for e in instruction["examples"]]
        self.instruction: Instruction = instruction
        self.tokenizer = tokenizer

    def get_lm_api_input(
        self, intro: str, examples: list[tuple[str, str]], current: str
    ) -> APIInput:
        """Return the required input format for the configured LM API.

        Args:
            intro: Introductory / system text.
            examples: Few-shot (input, output) pairs.
            current: The formatted input for the current step.

        Returns:
            A list of chat messages for OpenAI chat mode, otherwise a single
            prompt string.

        Raises:
            ValueError: The provider does not support the configured mode or
                model.
            NotImplementedError: The provider is unknown.
        """
        message: list[dict[str, str]] | str
        provider = self.lm_config.provider

        if "openai" in provider:
            if self.lm_config.mode == "chat":
                message = [{"role": "system", "content": intro}]
                for (x, y) in examples:
                    message.append(
                        {
                            "role": "system",
                            "name": "example_user",
                            "content": x,
                        }
                    )
                    message.append(
                        {
                            "role": "system",
                            "name": "example_assistant",
                            "content": y,
                        }
                    )
                message.append({"role": "user", "content": current})
                return message
            elif self.lm_config.mode == "completion":
                message = f"{intro}\n\n"
                message += "Here are a few examples:\n"
                for (x, y) in examples:
                    message += f"Observation\n:{x}\n\n"
                    message += f"Action: {y}\n\n"
                message += "Now make prediction given the observation\n\n"
                message += f"Observation\n:{current}\n\n"
                message += "Action:"
                return message
            else:
                raise ValueError(
                    f"OpenAI models do not support mode {self.lm_config.mode}"
                )
        elif "huggingface" in provider:
            # https://huggingface.co/blog/llama2#how-to-prompt-llama-2
            # https://github.com/facebookresearch/llama/blob/main/llama/generation.py#L320
            if "Llama-2" not in self.lm_config.model:
                # Fall back to the model name so that a missing "model_tag"
                # entry cannot mask this error with a KeyError.
                model_tag = self.lm_config.gen_config.get(
                    "model_tag", self.lm_config.model
                )
                raise ValueError(
                    f"Huggingface models do not support model_tag {model_tag}"
                )
            if self.lm_config.mode != "chat":
                raise ValueError("Only chat mode is supported for Llama-2")

            B_INST, E_INST = "[INST]", "[/INST]"
            B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
            BOS, EOS = "<s>", "</s>"
            system = B_SYS + intro + E_SYS
            # The system message is prepended to the first user turn. With no
            # examples, the first user turn is the current observation.
            if examples:
                turns = [
                    (system + examples[0][0], examples[0][1]),
                    *examples[1:],
                ]
                last_user = current
            else:
                turns = []
                last_user = system + current
            message = "".join(
                f"{BOS}{B_INST} {x.strip()} {E_INST} {y.strip()} {EOS}"
                for (x, y) in turns
            )
            # add the current observation
            force_prefix = self.instruction["meta_data"].get("force_prefix", "")
            message += f"{BOS}{B_INST} {last_user.strip()} {E_INST} {force_prefix}"
            return message
        elif "ours" in provider:
            # This provider gets no few-shot examples in the prompt.
            message = f"{intro}\n\n"
            message += "Now make prediction given the observation\n\n"
            message += f"Observation\n:{current}\n\n"
            message += "Action:"
            return message
        else:
            raise NotImplementedError(
                f"Provider {self.lm_config.provider} not implemented"
            )

    def construct(
        self,
        trajectory: Trajectory,
        intent: str,
        meta_data: dict[str, Any] = {},
    ) -> APIInput:
        """Build the LM input for the current step; implemented by subclasses."""
        raise NotImplementedError

    def map_url_to_real(self, url: str) -> str:
        """Map the urls to their real world counterparts"""
        for local, real in URL_MAPPINGS.items():
            url = url.replace(local, real)
        return url

    def map_url_to_local(self, url: str) -> str:
        """Map the urls to their local counterparts

        Both the mapped real URL and, when it uses the http scheme, its https
        variant are replaced.
        """
        http_scheme = "http://"
        for local, real in URL_MAPPINGS.items():
            url = url.replace(real, local)
            # Only swap the scheme prefix. A blanket replace("http", "https")
            # would also rewrite "http" inside the host or path, and would
            # turn an https URL into "httpss".
            if real.startswith(http_scheme):
                secure = "https://" + real[len(http_scheme):]
                url = url.replace(secure, local)
        return url

    def _extract_action(self, response: str) -> str:
        """Pull the raw action string out of `response`; implemented by subclasses."""
        raise NotImplementedError

    def extract_action(self, response: str) -> str:
        """Extract the action from `response` and map real URLs back to local ones."""
        response = self._extract_action(response)
        response = self.map_url_to_local(response)
        return response


class DirectPromptConstructor(PromptConstructor):
    """The agent will direct predict the action"""

    # No __init__ override: the inherited constructor already takes
    # (instruction_path, lm_config, tokenizer) and nothing extra is set up.

    def construct(
        self,
        trajectory: Trajectory,
        intent: str,
        meta_data: dict[str, Any] = {},
    ) -> APIInput:
        """Construct prompt given the trajectory

        Args:
            trajectory: Its last element is read as the current state.
            intent: The task objective inserted into the template.
            meta_data: Must contain "action_history"; its last entry is used
                as the previous action.

        Returns:
            The provider-specific input built by `get_lm_api_input`.
        """
        intro = self.instruction["intro"]
        examples = self.instruction["examples"]
        template = self.instruction["template"]
        keywords = self.instruction["meta_data"]["keywords"]
        state_info: StateInfo = trajectory[-1]  # type: ignore[assignment]

        obs = state_info["observation"][self.obs_modality]
        max_obs_length = self.lm_config.gen_config["max_obs_length"]
        if max_obs_length:
            # Truncate the observation to at most `max_obs_length` tokens.
            obs = self.tokenizer.decode(self.tokenizer.encode(obs)[:max_obs_length])  # type: ignore[arg-type]

        page = state_info["info"]["page"]
        url = page.url
        previous_action_str = meta_data["action_history"][-1]

        # input x
        current = template.format(
            objective=intent,
            url=self.map_url_to_real(url),
            observation=obs,
            previous_action=previous_action_str,
        )

        # make sure all keywords are replaced
        # NOTE(review): "{k}" is a plain string, not an f-string, so this only
        # rejects a literal "{k}" in the prompt. `str.format` above already
        # raises if the template has a placeholder without a value. Left
        # unchanged because a stricter check would also reject pages whose
        # text happens to contain e.g. "{url}".
        assert all(["{k}" not in current for k in keywords])
        prompt = self.get_lm_api_input(intro, examples, current)
        return prompt

    def _extract_action(self, response: str) -> str:
        """Return the text between the first pair of action splitters.

        Raises:
            ActionParsingError: No splitter-delimited span is found.
        """
        action_splitter = self.instruction["meta_data"]["action_splitter"]
        # The splitter is a literal delimiter, so escape it before building
        # the pattern. DOTALL lets the action span several lines.
        delimiter = re.escape(action_splitter)
        match = re.search(rf"{delimiter}(.*?){delimiter}", response, re.DOTALL)
        if match:
            return match.group(1).strip()
        raise ActionParsingError(
            f"Cannot parse action from response {response}"
        )


class CoTPromptConstructor(PromptConstructor):
    """The agent will perform step-by-step reasoning before the answer"""

    def __init__(
        self,
        instruction_path: str | Path,
        lm_config: lm_config.LMConfig,
        tokenizer: Tokenizer,
    ):
        """Load the instruction file and cache its "answer_phrase" setting."""
        super().__init__(instruction_path, lm_config, tokenizer)
        self.answer_phrase = self.instruction["meta_data"]["answer_phrase"]

    def construct(
        self,
        trajectory: Trajectory,
        intent: str,
        meta_data: dict[str, Any] = {},
    ) -> APIInput:
        """Construct the prompt for the last state of `trajectory`.

        Args:
            trajectory: Its last element is read as the current state.
            intent: The task objective inserted into the template.
            meta_data: Must contain "action_history"; its last entry is used
                as the previous action.

        Returns:
            The provider-specific input built by `get_lm_api_input`.
        """
        intro = self.instruction["intro"]
        examples = self.instruction["examples"]
        template = self.instruction["template"]
        keywords = self.instruction["meta_data"]["keywords"]
        state_info: StateInfo = trajectory[-1]  # type: ignore[assignment]

        obs = state_info["observation"][self.obs_modality]
        max_obs_length = self.lm_config.gen_config["max_obs_length"]
        if max_obs_length:
            # Truncate the observation to at most `max_obs_length` tokens.
            obs = self.tokenizer.decode(self.tokenizer.encode(obs)[:max_obs_length])  # type: ignore[arg-type]

        page = state_info["info"]["page"]
        url = page.url
        previous_action_str = meta_data["action_history"][-1]
        current = template.format(
            objective=intent,
            url=self.map_url_to_real(url),
            observation=obs,
            previous_action=previous_action_str,
        )

        # NOTE(review): "{k}" is a plain string, not an f-string, so this only
        # rejects a literal "{k}" in the prompt. `str.format` above already
        # raises if the template has a placeholder without a value. Left
        # unchanged because a stricter check would also reject pages whose
        # text happens to contain e.g. "{url}".
        assert all(["{k}" not in current for k in keywords])

        prompt = self.get_lm_api_input(intro, examples, current)
        return prompt

    def _extract_action(self, response: str) -> str:
        """Return the text between the first pair of action splitters.

        Raises:
            ActionParsingError: No splitter-delimited span is found.
        """
        # find the first occurrence of action
        action_splitter = self.instruction["meta_data"]["action_splitter"]
        # The splitter is a literal delimiter, so escape it before building
        # the pattern. DOTALL lets the action span several lines.
        delimiter = re.escape(action_splitter)
        match = re.search(rf"{delimiter}(.*?){delimiter}", response, re.DOTALL)
        if match:
            return match.group(1).strip()
        # The search is for the splitter, not the answer phrase, so the
        # message names the splitter.
        raise ActionParsingError(
            f'Cannot find an action delimited by "{action_splitter}" in "{response}"'
        )

class MyPromptConstructor(PromptConstructor):
    """Build a template+finale prompt and parse "#Click# 12"-style operations.

    `operation` and `translate` are parallel lists: the i-th regex recognises
    an operation in the model response and the i-th entry of `translate` is
    the action command emitted for it.
    """
    operation = [
        r"#?(Click)#?\s*([0-9]+)",
        r"#?(Type)#?\s*([0-9]+)\s+[\'\"]{0,1}([\s\S]+)[\'\"]{0,1}",
        r"#?(Select)#?\s*([0-9]+)\s+[\'\"]{0,1}(.+)[\'\"]{0,1}",
        r"#?(Scroll_up)#?",
        r"#?(Scroll_down)#?",
        r"#?(Goto)#?\s*(https?:\/\/[-a-z0-9]+(?:\.[-a-z0-9]+)*\.(?:com|cn|edu|uk)(?:\/[-a-z0-9_:@&?=+,.!/~*'%$]*)?)",
        r"#?(Go_backward)#?",
        r"#?(Go_forward)#?",
        r"#?(Hover)#?\s*([0-9]+)",
        r"#?(Answer)#?\s+(.+)",
        r"#?(Login)#?",
        r"#?(Verify)#?",
        r"#?(Exit)#?",
        r"#?(Record)#?\s+[\'\"]{0,1}(.+)[\'\"]{0,1}",
    ]

    translate = [
        "click",
        "type",
        "select",
        "scroll [up]",
        "scroll [down]",
        "goto",
        "go_back",
        "go_forward",
        "hover",
        "stop",
        "stop",
        "stop",
        "stop",
        "record",
    ]

    def __init__(
        self,
        instruction_path: str | Path,
        lm_config: lm_config.LMConfig,
        tokenizer: Tokenizer,
    ):
        """Load the instruction file and set up per-step state."""
        super().__init__(instruction_path, lm_config, tokenizer)
        self.answer_phrase = self.instruction["meta_data"]["answer_phrase"]
        self.state = {}
        # Overwritten by `construct`; initialised here so `_extract_action`
        # does not raise AttributeError when called before any `construct`.
        self.nodes = {}

    def construct(
        self,
        trajectory: Trajectory,
        intent: str,
        meta_data: dict[str, Any] = {},
    ) -> APIInput:
        """Format the template for the last state and append the finale.

        Also stores the observation nodes on `self.nodes` and refreshes
        `self.state` for the following `_extract_action` call.

        Args:
            trajectory: Its last element is read as the current state.
            intent: The task objective inserted into the template.
            meta_data: Must contain "action_history". The first entry is
                skipped when rendering the previous actions.

        Returns:
            The formatted template followed by the instruction's "finale".
        """
        template = self.instruction["template"]
        keywords = self.instruction["meta_data"]["keywords"]
        finale = self.instruction["finale"]
        state_info: StateInfo = trajectory[-1]  # type: ignore[assignment]

        obs = state_info["observation"][self.obs_modality]
        max_obs_length = self.lm_config.gen_config["max_obs_length"]
        if max_obs_length:
            # Truncate the observation to at most `max_obs_length` tokens.
            obs = self.tokenizer.decode(self.tokenizer.encode(obs)[:max_obs_length])  # type: ignore[arg-type]

        info = state_info["info"]
        obs_metadata = info["observation_metadata"]["text"]
        position_info = obs_metadata["position_info"]
        html_parser = obs_metadata["html_parser"]
        self.nodes = obs_metadata["obs_nodes_info"]

        page = info["page"]
        url = self.map_url_to_real(page.url)
        position_bar = self._get_position_bar(position_info)

        # `ix` used to be undefined here (NameError). Numbering from 0 gives
        # the first entry that is actually shown, history[1], the number 1.
        history = [
            f"{ix}. {his}"
            for ix, his in enumerate(meta_data["action_history"])
        ]
        if len(history) == 1:
            previous_action_str = "None"
        else:
            previous_action_str = '\n'.join(history[1:])

        self.state.update({
            "url": url,
            "html": obs,
            "html_parser": html_parser,
            "segment": "None",
            "operation": "None",
        })

        current = template.format(
            objective=intent,
            url=url,
            html=obs,
            position=position_bar,
            previous_action=previous_action_str,
        )

        # NOTE(review): "{k}" is a plain string, not an f-string, so this only
        # rejects a literal "{k}" in the prompt. `str.format` above already
        # raises if the template has a placeholder without a value. Left
        # unchanged because a stricter check would also reject pages whose
        # text happens to contain e.g. "{url}".
        assert all(["{k}" not in current for k in keywords])

        prompt = current + finale

        return prompt

    def _extract_action(self, response: str) -> str:
        """Translate the first recognised operation in `response` to a command.

        Patterns are tried in `operation` order. The result is the matching
        `translate` entry followed by one " [param]" per captured argument.
        For click/hover/type, a label that matches a node in `self.nodes` is
        replaced by that node's id. Also records the extracted intention, and
        the matched node's segment, in `self.state`.

        Raises:
            ActionParsingError: No operation pattern matches.
        """
        self.state["intention"] = self._extract_intention(response)

        for regex, act in zip(self.operation, self.translate):
            match = re.search(regex, response)
            if not match:
                continue

            # Group 1 is the operation keyword; the rest are its arguments.
            param = list(match.groups()[1:])

            if act in ['click', 'hover', 'type', 'select']:
                if len(param) == 0:
                    continue

                label = param[0]
                for node_id, node in self.nodes.items():
                    if node['label'] != label:
                        continue
                    hp = self.state["html_parser"]
                    bid = hp.id_label_converter(label)
                    segment = hp.get_segment(bid)

                    print('[Label]', label, bid, segment)
                    self.state["segment"] = segment
                    # 'select' keeps the label; the others use the node id.
                    if act not in ['select']:
                        param[0] = node_id
                    break

            if act in ['stop', 'select', 'record']:
                if len(param) > 0:
                    param[-1] = param[-1].strip("\'\"")

            if act in ['type']:
                if len(param) > 0:
                    print('In prompt constructer', param[-1])
                    param[-1] = param[-1].strip("\'\"")
                    print(param[-1])
                    # A trailing newline in the typed text is dropped and
                    # turned into an extra '1' argument, otherwise '0'.
                    if param[-1].endswith('\n'):
                        param[-1] = param[-1][:-1]
                        param.append('1')
                    else:
                        param.append('0')

            command = act
            for p in param:
                command += f" [{p}]"

            print(command)
            return command

        raise ActionParsingError(
            f'Cannot find the answer phrase in "{response}"'
        )

    @staticmethod
    def _get_position_bar(data):
        """Render the scroll position as a text bar.

        Reads "position" (default 0.0) and "page_height" (default 1.0) from
        `data` and draws int(position) dashes before the '|' marker and at
        least one dash after it.
        """
        position = data.get("position", 0.0)
        page_height = data.get("page_height", 1.0)
        left_bar = '-' * int(position)
        right_bar = '-' * int(max(1, page_height - position))
        return f'[0{left_bar}|{round(position, 1)}{right_bar}{round(page_height, 1)}]'

    @staticmethod
    def _extract_intention(response, lang='en'):
        """Return the last "thinking process" span in `response`, or None.

        For lang 'en' the span is the single-line text between
        "#Thinking Process:" and "#Operation:". Any other `lang` uses the
        Chinese marker.
        """
        if lang == 'en':
            matches = re.findall(r"#Thinking Process:\s*(.+)\s*#Operation:", response)
            print('[Try to match]', matches)
        else:
            matches = re.findall(r"#思考过程: (.+)", response)

        if matches:
            return matches[-1]
        else:
            return None

    @staticmethod
    def _extract_segment(html: str, tag: str):
        """Return the bracketed element around "[tag]" in `html`, or None.

        Starting from the first occurrence of "[tag]", the span is grown left
        to the enclosing unmatched '<' and right to the enclosing unmatched
        '>'. While no content has been seen and the next enclosing span is at
        most 150 characters, the span is widened by one nesting level. Results
        longer than 150 characters are cut and suffixed with '...>'.

        NOTE(review): the content checks index `html[i - 2]` and
        `html[i + 2]` without bounds checks. The former wraps around for
        i < 2 and the latter can raise IndexError near the end of `html`.
        The only call site in this class is commented out, so the logic is
        left as is.
        """
        tag = f'[{tag}]'
        has_content = False

        def _left(html, start):
            # Scan backwards until one more '<' than '>' has been seen;
            # return that '<' index, or -1 if there is none.
            nonlocal has_content
            left_cnt, right_cnt = 0, 0
            for i in range(start, -1, -1):
                if html[i] == '<':
                    left_cnt += 1
                elif html[i] == '>':
                    if html[i - 2] != '|' and html[i - 2] != '>':
                        has_content = True
                    right_cnt += 1
                elif html[i] == '|':
                    if html[i + 2] != '<' and html[i + 2] != '>':
                        has_content = True
                if left_cnt == right_cnt + 1:
                    return i
            return -1

        def _right(html, start):
            # Scan forwards until one more '>' than '<' has been seen;
            # return the index just past that '>', or -1 if there is none.
            nonlocal has_content
            left_cnt, right_cnt = 0, 0
            for i in range(start, len(html), 1):
                if html[i] == '<':
                    left_cnt += 1
                elif html[i] == '>':
                    if html[i - 2] != '|' and html[i - 2] != '>':
                        has_content = True
                    right_cnt += 1
                elif html[i] == '|':
                    if html[i + 2] != '<' and html[i + 2] != '>':
                        has_content = True
                if left_cnt + 1 == right_cnt:
                    return i + 1
            return -1

        tag_start = html.find(tag)

        if tag_start == -1:
            return None

        left_bound, right_bound = _left(html, tag_start), _right(html, tag_start)
        while True:
            if left_bound == -1 or right_bound == -1:
                return None

            if has_content:
                break

            else:
                # Try the next enclosing level; keep the current span if it
                # does not exist or would exceed 150 characters.
                lb, rb = _left(html, left_bound - 1), _right(html, right_bound + 1)
                if lb == -1 or rb == -1:
                    break
                if rb - lb > 150:
                    break
                else:
                    left_bound, right_bound = lb, rb

        segment = html[left_bound:right_bound]

        if len(segment) > 150:
            return segment[:150] + '...>'

        return segment
    
class NewASPromptConstructor(PromptConstructor):
    """Build a template+finale prompt and parse function-call style operations.

    `operation` and `translate` are parallel lists: the i-th regex recognises
    a call such as click('AB') in the model response and the i-th entry of
    `translate` is the action command emitted for it.
    """
    operation = [
        r"(click)\(\s*[\'\"]([A-Z]{1,3})[\'\"]\s*\)",
        r"(type_string)\(\s*[\'\"]([A-Z]{1,3})[\'\"]\s*,\s*[\'\"]([\s\S]+)[\'\"]\s*,\s*(True|False)\s*\)",
        r"(select)\(\s*[\'\"]([A-Z]{1,3})[\'\"]\s*,\s*[\'\"]([\s\S]+)[\'\"]\s*\)",
        r"(scroll_page)\(\s*[\'\"]up[\'\"]\s*\)",
        r"(scroll_page)\(\s*[\'\"]down[\'\"]\s*\)",
        r"(jump_to)\(\s*[\'\"](.+)[\'\"]\s*,\s*(True|False)\s*\)",
        r"(go)\(\s*[\'\"]backward[\'\"]\s*\)",
        r"(go)\(\s*[\'\"]forward[\'\"]\s*\)",
        r"(hover)\(\s*[\'\"]([A-Z]{1,3})[\'\"]\s*\)",
        r"(finish)\(\s*\)",
        r"(finish)\(\s*(.+)\s*\)",
        r"(record)\(\s*[\'\"](.+)[\'\"]\s*\)",
        # The tab index must be captured, otherwise the emitted command is a
        # bare "page_focus" without its argument.
        r"(switch_tab)\(\s*([\d]+)\s*\)",
    ]

    translate = [
        "click",
        "type",
        "select",
        "scroll [up]",
        "scroll [down]",
        "goto",
        "go_back",
        "go_forward",
        "hover",
        "stop",
        "stop",
        "record",
        "page_focus",
    ]

    def __init__(
        self,
        instruction_path: str | Path,
        lm_config: lm_config.LMConfig,
        tokenizer: Tokenizer,
    ):
        """Load the instruction file and set up per-step state."""
        super().__init__(instruction_path, lm_config, tokenizer)
        self.answer_phrase = self.instruction["meta_data"]["answer_phrase"]
        self.state = {}
        # Overwritten by `construct`; initialised here so `_extract_action`
        # does not raise AttributeError when called before any `construct`.
        self.nodes = {}

    def construct(
        self,
        trajectory: Trajectory,
        intent: str,
        meta_data: dict[str, Any] = {},
    ) -> APIInput:
        """Format the template for the last state and append the finale.

        Also stores the observation nodes on `self.nodes` and refreshes
        `self.state` for the following `_extract_action` call.

        Args:
            trajectory: Its last element is read as the current state.
            intent: The task objective inserted into the template.
            meta_data: Must contain "action_history". The first entry is
                skipped when rendering the previous actions.

        Returns:
            The formatted template followed by the instruction's "finale".
        """
        template = self.instruction["template"]
        keywords = self.instruction["meta_data"]["keywords"]
        finale = self.instruction["finale"]
        state_info: StateInfo = trajectory[-1]  # type: ignore[assignment]

        obs = state_info["observation"][self.obs_modality]
        max_obs_length = self.lm_config.gen_config["max_obs_length"]
        if max_obs_length:
            # Truncate the observation to at most `max_obs_length` tokens.
            obs = self.tokenizer.decode(self.tokenizer.encode(obs)[:max_obs_length])  # type: ignore[arg-type]

        info = state_info["info"]
        obs_metadata = info["observation_metadata"]["text"]
        position_info = obs_metadata["position_info"]
        html_parser = obs_metadata["html_parser"]
        tabs_str = obs_metadata["tab_title"]
        self.nodes = obs_metadata["obs_nodes_info"]

        page = info["page"]
        url = self.map_url_to_real(page.url)
        position_bar = self._get_position_bar(position_info)

        history = meta_data["action_history"]
        if len(history) == 1:
            previous_action_str = "None"
        else:
            previous_action_str = '\n'.join(history[1:])

        self.state.update({
            "url": url,
            "html": obs,
            "html_parser": html_parser,
            "segment": "None",
            "operation": "None",
        })

        current = template.format(
            objective=intent,
            url=url,
            html=obs,
            position=position_bar,
            previous_action=previous_action_str,
            tabs=tabs_str,
        )

        # NOTE(review): "{k}" is a plain string, not an f-string, so this only
        # rejects a literal "{k}" in the prompt. `str.format` above already
        # raises if the template has a placeholder without a value. Left
        # unchanged because a stricter check would also reject pages whose
        # text happens to contain e.g. "{url}".
        assert all(["{k}" not in current for k in keywords])

        prompt = current + finale

        return prompt

    def _extract_action(self, response: str) -> str:
        """Translate the first recognised operation in `response` to a command.

        Patterns are tried in `operation` order. The result is the matching
        `translate` entry followed by one " [param]" per captured argument.
        For click/hover/type, a label that matches a node in `self.nodes` is
        replaced by that node's id. The trailing True/False of type/goto
        becomes '1'/'0'. Also records the matched node's segment in
        `self.state`.

        Raises:
            ActionParsingError: No operation pattern matches.
        """
        for regex, act in zip(self.operation, self.translate):
            match = re.search(regex, response)
            if not match:
                continue

            # Group 1 is the function name; the rest are its arguments.
            groups = match.groups()
            exact_act, param = groups[0], groups[1:]

            print(exact_act, param)
            param = list(param)
            if act in ['click', 'hover', 'type', 'select']:
                if len(param) == 0:
                    continue

                label = param[0]
                for node_id, node in self.nodes.items():
                    if node['label'] != label:
                        continue
                    hp = self.state["html_parser"]
                    bid = hp.id_label_converter(label)
                    segment = hp.get_segment(bid)

                    print('[Label]', label, bid, segment)
                    self.state["segment"] = segment
                    # 'select' keeps the label; the others use the node id.
                    if act not in ['select']:
                        param[0] = node_id
                    break

            if len(param) > 0:
                if act in ['stop', 'select', 'record']:
                    param[-1] = param[-1].strip("\'\"")
                if act in ['type', 'goto']:
                    param[-1] = '1' if param[-1] == 'True' else '0'

            command = act
            for p in param:
                command += f" [{p}]"

            print(command)
            return command

        raise ActionParsingError(
            f'Cannot find the answer phrase in "{response}"'
        )

    @staticmethod
    def _get_position_bar(data):
        """Render the scroll position as "<position> / <page_height>".

        Reads "position" (default 0.0) and "page_height" (default 1.0) from
        `data`; both are rounded to one decimal place.
        """
        position = data.get("position", 0.0)
        page_height = data.get("page_height", 1.0)
        return f"{round(position, 1)} / {round(page_height, 1)}"