# coding=utf-8
# Copyright 2022 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utility functions for miniwob environments."""

import collections
import re

import gin
import numpy as np

from a2perf.domains.web_navigation.gwob.CoDE.vocabulary_utils import tokenize
from a2perf.domains.web_navigation.gwob.miniwob_plusplus.python.miniwob import action
from a2perf.domains.web_navigation.gwob.miniwob_plusplus.python.miniwob import environment
from a2perf.domains.web_navigation.gwob.miniwob_plusplus.python.miniwob import reward
from a2perf.domains.web_navigation.gwob.miniwob_plusplus.python.miniwob import state

# See https://developer.mozilla.org/en-US/docs/Web/HTML/Element/input/text
INPUT_ELEMENTS = (
    "input_button",
    "input_checkbox",
    "input_color",
    "input_date",
    "input_datetime",
    "input_datetime-local",
    "input_email",
    "input_file",
    "input_hidden",
    "input_image",
    "input_month",
    "input_number",
    "input_password",
    "input_radio",
    "input_range",
    "input_reset",
    "input_search",
    "input_submit",
    "input_tel",
    "input_text",
    "input_time",
    "input_url",
    "input_week",
    "input_password",
)

COMMON_ELEMENTS = ("button", "label", "select")
ADDITIONAL_ELEMENTS = ("a", "img", "div", "span")


def create_environment(subdomain, base_url, random_state, browser_args):
    """Create a miniwob environment with subdomain.

    Args:
      subdomain: Name of the environment.
      base_url: Base url that shows the location of the envionment resources.
      browser_args: Additional arguments that are passed to environment
        configuration.

    Returns:
      A MiniWoBEnvironment instance for interacting with websites.
    """

    env = environment.MiniWoBEnvironment(subdomain)
    env.configure(
        num_instances=1,
        seeds=[random_state.randint(0, 1000)],
        base_url=base_url,
        reward_processor=reward.get_raw_reward,
        **browser_args)
    return env


def create_element_click(element, fail_hard=True):
    """Create dom element click action.

      This allows gin configuration for failure case.

    Args:
      element: A web page element.
      fail_hard: If true, it throw an exception if the click fails.

    Returns:
      A MiniWoBElementClick action.
    """
    return action.MiniWoBElementClick(
        element, fail_hard=fail_hard, handle_select_in_selenium=False)


def create_dom_graph(observation, prune_refs=None):
    """Return dom nodes and graph from a given observation.

      Create a list of (senders, receivers) as the DOM structure representation.

    Args:
      observation: Observation from web environment.
      prune_refs: If given, these will be used to prune elements based on their
        ref attribute. This attribute is miniwob specific.

    Returns:
      A tuple of (dom elements, sender elements, receiver elements) that
      corresponds to nodes and edges of DOM graph.
    """
    dom_elems = observation.dom_elements
    if prune_refs:
        dom_elems = [elem for elem in dom_elems if elem.ref in prune_refs]
    senders, receivers = [], []
    for element in dom_elems:
        senders.append(element.ref)
        for child in element.children:
            receivers.append(child.ref)
    return dom_elems, senders, receivers


@gin.configurable
def order_dom_elements(dom_elements, only_leaf=False, html_id_prefix=None):
    """Order dom elements based on their ref.

      This is used to have a consistent order on DOM elements when selecting an
      element using an index.
    Args:
      dom_elements: A sequence of elements from DOM tree.
      only_leaf: If true, return only leaf nodes.
      html_id_prefix: If given, this will be used to select DOM elements based on
        their html ids. If html id matches the prefix, that DOM element will be
        included.

    Returns:
      A sequence of DOM elements that are sorted based on their ref.
    """
    dom_elems = []
    if html_id_prefix:
        visited = set()
        matching_elements = set()

        def traverse_and_label(element):
            # 'ref' is an attribute generated by MiniWoB. It allows us to point to
            # any element uniquely within an episode. This is not DOM or HTML related
            # feature and only specific to MiniWoB.
            if element.ref in visited:
                return element.ref in matching_elements
            matches = False
            for child in element.children:
                matches = traverse_and_label(child) | matches
            visited.add(element.ref)
            if not matches and element.id and element.id.startswith(html_id_prefix):
                matching_elements.add(element.ref)
                return True
            return False

        for elem in dom_elements:
            matches = traverse_and_label(elem)
            if matches and elem.ref and (elem.is_leaf or not only_leaf):
                dom_elems.append(elem)
    else:
        for elem in dom_elements:
            if elem.ref and (elem.is_leaf or not only_leaf):
                dom_elems.append(elem)
    return sorted(dom_elems, key=lambda x: x.ref)


@gin.configurable
def get_dom_elements(observation, prune_refs=None, prune_form_input=False):
    """Return a list of dom elements from a given observation.

    Args:
      observation: Observation from web environment.
      prune_refs: If given, these will be used to prune elements based on their
        refs. Refs are miniwob specific and might not work outside of miniwob.
      prune_form_input: If true, take only input elements such as input_text,
        input_radio or if they have text values that could be used as label of an
        input.

    Returns:
      A list of DOM elements.
    """

    def check_prune_element(elem):
        return elem.tag in INPUT_ELEMENTS + COMMON_ELEMENTS or (
                elem.tag in ADDITIONAL_ELEMENTS and (elem.text or elem.value))

    dom_elements = []

    def _get_elements(element, dom_elements):
        for child in element.children:
            if child.tag == "option":
                continue
            _get_elements(child, dom_elements)
        dom_elements.append(element)

    if not prune_form_input:
        dom_elements = observation.dom_elements
    else:
        _get_elements(observation.dom, dom_elements)
    dom_elements = order_dom_elements(dom_elements)
    if prune_form_input:
        dom_elements = [elem for elem in dom_elements if check_prune_element(elem)]
    if prune_refs:
        return [
            elem for elem in dom_elements if elem.ref and elem.ref in prune_refs
        ]
    return dom_elements


def select_dom_element(dom_elements, index):
    """Select a dom element for a given index.

    Args:
      dom_elements: A list of dom elements.
      index: An integer index pointing to a dom element.

    Returns:
      A dom element that corresponds to the given index.
    """
    # An ordered dictionary to map from element refs to dom elements.
    # This is used to ensure refs are unique and sorted.
    # Sorting aligns with order_dom_elements function.
    dom_map = collections.OrderedDict()
    for elem in order_dom_elements(dom_elements):
        dom_map[elem.ref] = elem
    if index >= len(dom_map):  # index is invalid
        return None
    dom_elem = list(dom_map.items())[index][1]
    return dom_elem


def generate_click_action(dom_elements, index_or_element):
    """Generate a click action using dom element or its index.

    Args:
      dom_elements: A list of dom elements.
      index_or_element: An integer index pointing to a dom element or a DOMElement
        instance.

    Returns:
      A miniwob click action.
    """
    if isinstance(index_or_element, state.DOMElement):
        dom_elem = index_or_element
    else:
        dom_elem = select_dom_element(dom_elements, index_or_element)
    if not dom_elem:
        return None
    click = create_element_click(dom_elem)
    return click


def generate_focus_and_type_action(dom_elements, index_or_element,
                                   type_seq, typed_refs):
    """Generate focus and type action using dom element or its index.

    Args:
      dom_elements: A list of dom elements.
      index_or_element: An integer index pointing to a dom element or a DOMElement
        instance.
      type_seq: A sequence of tokens to use with keyboard action.
      typed_refs: A set of refs that correspond to elements that are already typed
        in the current episode. This is used to check whether a space should be
        inserted before typing or not.

    Returns:
      A MiniWoBFocusAndType action that simulates typing a sequence into an
      element.
    """
    if isinstance(index_or_element, state.DOMElement):
        dom_elem = index_or_element
    else:
        dom_elem = select_dom_element(dom_elements, index_or_element)

    if not dom_elem:
        return None
    if type_seq and dom_elem.ref:
        if dom_elem.ref in typed_refs:
            type_seq = " " + type_seq
        else:
            typed_refs.add(dom_elem.ref)
    return action.MiniWoBFocusAndType(
        dom_elem, type_seq, handle_select_in_selenium=False)


def generate_web_action(dom_elements,
                        action_type,
                        index_or_dom_element,
                        type_seq="",
                        typed_refs=None):
    """Generate a click or keyboard action.

    Args:
      dom_elements: A list of dom elements.
      action_type: Type of the action. Either click (1) or keyboard (0).
      index_or_dom_element: An integer index pointing to a dom element or a
        DOMElement instance.
      type_seq: A sequence of tokens to use with keyboard action.
      typed_refs: A set of refs that correspond to elements that are already typed
        in the current episode. This is used to check whether a space should be
        inserted before typing or not.

    Returns:
      A MiniWoBClick or MiniWoBFocusAndType action that
      simulates clicking on an element or typing a sequence into an element.

    Raises:
      A ValueError if the action type is incorrect.
    """
    if action_type == 0:  # focus and keyboard
        return generate_focus_and_type_action(dom_elements, index_or_dom_element,
                                              type_seq, typed_refs)
    elif action_type == 1:  # click
        return generate_click_action(dom_elements, index_or_dom_element)
    raise ValueError(
        "Web action should be 0 (focus-and-keyboard) or 1 (click) but found {}"
        .format(action_type))


@gin.configurable
def dom_attributes(dom_element, num_attributes=-1):
    """Return a list of dom attributes.


    Args:
      dom_element: A dom element.
      num_attributes: If given, use this many attributes that are sorted according
        to their human readability and importance.

    Returns:
      A list of attributes that represents the dom element.
    """
    result = [
        dom_element.tag, dom_element.value, dom_element.text,
        dom_element.placeholder, dom_element.classes
    ]
    if num_attributes <= 0 or num_attributes > len(result):
        return result
    return result[0:num_attributes]


def dom_element_representation(dom_element,
                               local_vocabulary,
                               dom_attribute_sequence_length,
                               num_attributes=-1):
    """Tokenize and indexify a dom element.

    Args:
      dom_element: A dom element instance.
      local_vocabulary: A dictionary to map tokens to ids.
      dom_attribute_sequence_length: Maximum allowed length of an attribute
        sequence.
      num_attributes: If given, use this many attributes that are sorted according
        to their human readability and importance.

    Returns:
      A tuple of (dom, mask) where dom is the ids of attributes for each dom
      element and mask is the masking for attribute sequences.
    """
    dom_element_attributes = dom_attributes(
        dom_element, num_attributes=num_attributes)
    tokenized_attributes = [
        tokenize(dom_element_attr) for dom_element_attr in dom_element_attributes
    ]
    indexified_attributes = [
        indexify(attr, local_vocabulary, dom_attribute_sequence_length)
        for attr in tokenized_attributes
    ]
    mask = np.asarray([[0] * dom_attribute_sequence_length
                       for _ in range(len(tokenized_attributes))],
                      dtype="float")
    for i, dom_element_attr in enumerate(tokenized_attributes):
        mask[i, 0:len(dom_element_attr)] = 1.0
    return indexified_attributes, mask


@gin.configurable
def get_word_to_id(word,
                   local_vocabulary,
                   check_for_numbers=False,
                   token_length=-1):
    """Get id of a word.

      This function explicitly extracts numbers from the word. If there are
      multiple words, it returns the first occurrence.
      If a token is not in the vocabulary, it is added to it as long as the size
      of the vocabulary is less than the allowed amount.

    Args:
      word: A single token.
      local_vocabulary: Mapping from tokens to ids.
      check_for_numbers: If true, check the availability of a number in the token
        and use that.
      token_length: If given, tokens longer than this threshold will not be added
        to the vocabulary.

    Returns:
      A single integer id that corresponds to the word.
    """
    max_vocab_size = local_vocabulary.max_vocabulary_size

    if word in local_vocabulary:
        return local_vocabulary[word]
    elif str(word) in local_vocabulary:
        return local_vocabulary[str(word)]
    elif token_length > 0 and len(word) > token_length:
        return local_vocabulary["OOV"]
    elif check_for_numbers:
        try:
            raw_numbers = re.findall(r"\d+", word)
            if raw_numbers:
                if raw_numbers[0] not in local_vocabulary:
                    if max_vocab_size > 0 and len(local_vocabulary) >= max_vocab_size:
                        return local_vocabulary["OOV"]
                    local_vocabulary.add_to_vocabulary([raw_numbers[0]])
                return local_vocabulary[raw_numbers[0]]
        except TypeError as _:
            pass
        except ValueError as _:
            pass
        except AttributeError as _:
            pass
    # some numbers (ui-id-NUMBER) from autocomplete will be OOV .
    if max_vocab_size != -1 and len(local_vocabulary) >= max_vocab_size:
        return local_vocabulary["OOV"]
    local_vocabulary.add_to_vocabulary([str(word)])
    return get_word_to_id(
        word, local_vocabulary, check_for_numbers=check_for_numbers)


@gin.configurable
def indexify(tokens, local_vocabulary, max_sequence_length):
    """Map tokens to ids and pad with id of NULL if necessary.

    Args:
      tokens: A sequence of tokens. If not, assume a single token.
      local_vocabulary: A dictionary of word to id mapping.
      max_sequence_length: Maximum allowed sequence length.

    Returns:
      A sequence of ids that corresponds to the input sequence of tokens.
    """
    if not isinstance(tokens, list):
        tokens = [tokens]
    indices = [get_word_to_id(s, local_vocabulary) for s in tokens]
    if len(indices) > max_sequence_length:
        return indices[:max_sequence_length]
    pad_length = max_sequence_length - len(indices)
    return indices + [local_vocabulary["NULL"]] * pad_length
