# coding=utf-8
# Copyright 2022 The Microsoft, The Google and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import dataclasses
import enum
import functools
import math
import re

# The following script is adapted from the WikiSQL utilities of TaPas.
# Original: https://github.com/google-research/tapas/blob/master/tapas/utils/wikisql_utils.py
from typing import Any, List


# Answer text that `retrieve_wikisql_query_answer_tapas` returns (wrapped in a list)
# when no answer text was produced.
EMPTY_ANSWER = "none"
# Value that `_get_float_answer` returns when there are no answer coordinates and
# the aggregation is not COUNT.
EMPTY_ANSWER_AGG = "none"


def _split_thousands(delimiter, value):
    """Tells whether `delimiter` could be acting as a thousands separator in `value`.

    Args:
      delimiter: separator string to split `value` on.
      value: string to inspect.
    Returns:
      True if splitting `value` on `delimiter` yields more than one piece and at
      least one of the pieces is exactly three characters long, False otherwise.
    """
    pieces = value.split(delimiter)
    if len(pieces) <= 1:
        # The delimiter does not occur in the value at all.
        return False
    for piece in pieces:
        if len(piece) == 3:
            return True
    return False


def convert_to_float(value):
    """Converts value to a float using a series of increasingly complex heuristics.

    Floats are returned unchanged and ints are cast with `float`. Strings are
    checked against the following patterns, in this order:
      * both "." and "," are present (e.g. "1,000.7"): commas are removed;
      * "," is present and `_split_thousands` accepts the string (e.g. "1,000"):
        commas are removed;
      * exactly one "," remains (e.g. "5,5556"): it is read as a decimal point;
      * more than one "." (e.g. "0.0.0.1"): dots are removed;
      * more than one "," (e.g. "0,0,0,1"): commas are removed;
      * anything else is handed to `float` unchanged.

    Args:
      value: object that needs to be converted. Allowed types include
        float/int/strings.
    Returns:
      A float interpretation of value.
    Raises:
      TypeError: if value is not a float, an int or a string.
      ValueError: if the float conversion of value fails.
    """
    if isinstance(value, float):
        return value
    if isinstance(value, int):
        return float(value)
    if not isinstance(value, str):
        raise TypeError("Argument value is not a string. Can't parse it as float")

    has_comma = "," in value
    try:
        # Example: 1,000.7
        if has_comma and "." in value:
            return float(value.replace(",", ""))
        # 1,000
        if has_comma and _split_thousands(",", value):
            return float(value.replace(",", ""))
        # 5,5556
        # Any comma-containing value accepted by `_split_thousands` has already
        # returned (or raised) above, so a single comma here is a decimal mark.
        if value.count(",") == 1:
            return float(value.replace(",", "."))
        # 0.0.0.1
        if value.count(".") > 1:
            return float(value.replace(".", ""))
        # 0,0,0,1
        if value.count(",") > 1:
            return float(value.replace(",", ""))
        return float(value)
    except ValueError:
        # Avoid adding the sanitized value in the error message. `from None` also
        # keeps the built-in "could not convert string to float: ..." error, which
        # embeds the value, out of the chained traceback.
        raise ValueError("Unable to convert value to float") from None


def _normalize_float(answer):
    """Normalizes an answer to a float when possible.

    Args:
      answer: value to normalize, or None.
    Returns:
      None if `answer` is None or converts to NaN, the float produced by
      `convert_to_float` when the conversion succeeds, and `answer.lower()` when
      `convert_to_float` raises ValueError.
    """
    if answer is None:
        return None
    try:
        number = convert_to_float(answer)
    except ValueError:
        # Not a number: fall back to the lower-cased text.
        return answer.lower()
    if isinstance(number, float) and math.isnan(number):
        return None
    return number


# Maps a column type name to the callable `_parse_value` applies to a value of that
# column: "text" values are returned unchanged, "real" values go through
# `convert_to_float`.
_TYPE_CONVERTER = {
    "text": lambda x: x,
    "real": convert_to_float,
}


class _Aggregation(enum.Enum):
    """Aggregations as defined by WikiSQL. Indexes match the data."""

    # No aggregation. `_get_answer_coordinates` also reports this for every
    # aggregation index below 3.
    NONE = 0
    # MAX and MIN are resolved inside `_get_answer_coordinates`, which keeps only the
    # cell with the largest / smallest value when several rows match.
    MAX = 1
    MIN = 2
    # COUNT, SUM and AVERAGE are computed by `_get_float_answer`.
    COUNT = 3
    SUM = 4
    AVERAGE = 5


class _Operator(enum.Enum):
    """The boolean operators used by WikiSQL. Indexes match the data."""

    # `_compare` evaluates these as `src == tgt`, `src > tgt` and `src < tgt`.
    EQUALS = 0
    GREATER = 1
    LESSER = 2


@dataclasses.dataclass
class _Condition:
    """Represents a single SQL where clause (e.g. A = "a" or B > 5)."""

    # Key used to look up the cell in a row and the column type in table["types"].
    # NOTE(review): annotated as str, but `_get_answer_coordinates` fills it from
    # sql_query["conds"]["column_index"] - confirm the intended type.
    column: str
    # Comparison applied between the cell value and `cmp_value`.
    operator: _Operator
    # Raw value to compare against; `_respect_conditions` converts it with
    # `_parse_value` before comparing.
    cmp_value: Any


# Tokenizer used by `_normalize_for_match`: a token is either a run of word
# characters or a run of characters that are neither word characters nor whitespace.
_TOKENIZER = re.compile(r"\w+|[^\w\s]+", re.UNICODE | re.MULTILINE | re.DOTALL)


def _normalize_for_match(x):
    """Lower-cases `x` and returns the list of tokens `_TOKENIZER` finds in it."""
    lowered = x.lower()
    return _TOKENIZER.findall(lowered)


def _compare(operator, src, tgt):
    """Applies the comparison named by `operator` to `src` and `tgt`.

    Args:
      operator: `_Operator` member selecting the comparison.
      src: left-hand operand.
      tgt: right-hand operand.
    Returns:
      The result of `src == tgt`, `src > tgt` or `src < tgt`.
    Raises:
      ValueError: if `operator` is none of the `_Operator` members.
    """
    if operator == _Operator.EQUALS:
        outcome = src == tgt
    elif operator == _Operator.GREATER:
        outcome = src > tgt
    elif operator == _Operator.LESSER:
        outcome = src < tgt
    else:
        raise ValueError(f"Unknown operator: {operator}")
    return outcome


def _parse_value(table, column, cell_value):
    """Converts `cell_value` with the converter registered for its column type.

    The column type is read from table["types"][column] and looked up in
    `_TYPE_CONVERTER`, so "real" values become floats and "text" values are
    returned unchanged.
    """
    column_type = table["types"][column]
    converter = _TYPE_CONVERTER[column_type]
    return converter(cell_value)


def _is_string(x):
    """Returns True if `x` is a `str` instance."""
    return isinstance(x, str)


def _respect_conditions(table, row, conditions):
    """Checks whether `row` satisfies every condition.

    Args:
      table: table dict; its "types" entry is used to parse each comparison value.
      row: the row to test, indexed by `condition.column`.
      conditions: iterable of `_Condition` objects.
    Returns:
      True if all conditions hold for `row` (also when there are none), False as
      soon as one fails.
    Raises:
      TypeError: if a cell is not an instance of the parsed comparison value's type.
    """
    for condition in conditions:
        cell = row[condition.column]
        expected = _parse_value(table, condition.column, condition.cmp_value)

        # Text is compared token by token, case-insensitively.
        both_text = _is_string(cell) and _is_string(expected)
        if both_text:
            cell = _normalize_for_match(cell)
            expected = _normalize_for_match(expected)

        expected_type = type(expected)
        if not isinstance(cell, expected_type):
            raise TypeError("Type difference {} != {}".format(type(cell), expected_type))

        if _compare(condition.operator, cell, expected):
            continue
        return False
    return True


def _get_float_answer(table, answer_coordinates, aggregation_op):
    """Applies the aggregation to the selected cells to get the reference float answer.

    Args:
      table: table dict; cells are read as table["rows"][i][j].
      answer_coordinates: (row, column) pairs of the selected cells.
      aggregation_op: `_Aggregation` member to apply.
    Returns:
      * with no coordinates: 0.0 for COUNT, `EMPTY_ANSWER_AGG` otherwise;
      * for COUNT: the number of coordinates as a float;
      * for a single cell that converts with `convert_to_float`: that float;
      * None when the aggregation is NONE and the above did not apply, or when
        some cell is not an int/float;
      * the sum (as float) for SUM and the mean for AVERAGE.
    Raises:
      ValueError: if a single cell cannot be converted while an aggregation other
        than NONE is requested, or if the aggregation is not supported here.
    """
    is_count = aggregation_op == _Aggregation.COUNT
    if not answer_coordinates:
        return 0.0 if is_count else EMPTY_ANSWER_AGG

    # Count can support non numeric answers.
    if is_count:
        return float(len(answer_coordinates))

    cells = [table["rows"][row][col] for row, col in answer_coordinates]

    # A lone cell is returned directly when it can be read as a float.
    if len(answer_coordinates) == 1:
        try:
            return convert_to_float(cells[0])
        except ValueError:
            if aggregation_op != _Aggregation.NONE:
                raise

    if aggregation_op == _Aggregation.NONE:
        return None

    # Other aggregation only support numeric values. Bail out if we have strings.
    for cell in cells:
        if not isinstance(cell, (int, float)):
            return None

    if aggregation_op == _Aggregation.SUM:
        return float(sum(cells))
    if aggregation_op == _Aggregation.AVERAGE:
        return sum(cells) / len(answer_coordinates)
    raise ValueError(f"Unknown aggregation: {aggregation_op}")


def _get_answer_coordinates(table, sql_query):
    """Executes the SQL query on the table to get the reference coordinates.

    Args:
      table: table dict with a "rows" entry (and whatever `_respect_conditions`
        reads from it).
      sql_query: dict with "agg" (aggregation index), "sel" (selected column) and
        "conds" holding the parallel lists "column_index", "operator_index" and
        "condition".
    Returns:
      A tuple `(coordinates, aggregation_op)`. `coordinates` lists one
      (row, selected column) pair per matching row; for MIN/MAX over several
      matching rows it is reduced to the single winning cell and the aggregation
      is reported as `_Aggregation.NONE`.
    """
    agg_index = sql_query["agg"]
    # MAX and MIN are automatically supported by the model.
    aggregation_op = _Aggregation(agg_index) if agg_index >= 3 else _Aggregation.NONE

    target_column = sql_query["sel"]
    conds = sql_query["conds"]
    conditions = [
        _Condition(column, _Operator(operator), cmp_value)
        for column, operator, cmp_value in zip(conds["column_index"], conds["operator_index"], conds["condition"])
    ]

    matches = [
        (row_index, target_column)
        for row_index, row in enumerate(table["rows"])
        if _respect_conditions(table, row, conditions)
    ]

    # Zero or one matching row: nothing left to reduce.
    if len(matches) <= 1:
        return matches, aggregation_op

    # Parsing of MIN/MAX.
    if agg_index in (1, 2):
        pick = max if agg_index == 1 else min
        # Pair each value with its position in `matches` so the winner can be mapped
        # back to its coordinates; ties on the value are broken by that position.
        candidates = [(table["rows"][i][j], position) for position, (i, j) in enumerate(matches)]
        _, best_position = pick(candidates)
        return [matches[best_position]], _Aggregation.NONE

    return matches, aggregation_op


def _get_answer_text(table, answer_coordinates, float_answer):
    """Builds the list of answer strings.

    Args:
      table: table dict; cells are read from table["real_rows"][r][c].
      answer_coordinates: (row, column) pairs of the answer cells.
      float_answer: value that takes precedence over the cells, or None.
    Returns:
      `[str(float_answer)]` when `float_answer` is not None, otherwise the string
      form of every cell referenced by `answer_coordinates`.
    """
    if float_answer is None:
        texts = []
        for row_index, column_index in answer_coordinates:
            texts.append(str(table["real_rows"][row_index][column_index]))
        return texts
    return [str(float_answer)]


def retrieve_wikisql_query_answer_tapas(table, example) -> List:
    """Executes a WikiSQL query on a table and returns the answer as text.

    Args:
      table: table dict passed on to `_get_answer_coordinates`,
        `_get_float_answer` and `_get_answer_text`.
      example: the SQL query dict passed to `_get_answer_coordinates`.
    Returns:
      The list of answer strings, or `[EMPTY_ANSWER]` when that list is empty.
    """
    coordinates, aggregation = _get_answer_coordinates(table, example)
    numeric_answer = _get_float_answer(table, coordinates, aggregation)
    texts = _get_answer_text(table, coordinates, numeric_answer)
    # keep the original data the same with TaPas
    return texts if texts else [EMPTY_ANSWER]
