from typing import List, Dict
import pandas as pd
import recognizers_suite
from recognizers_suite import Culture
import re
import unicodedata
from fuzzywuzzy import fuzz

from utils.extraction_from_sql import *
from utils.sql_stuff import ALL_KEY_WORDS

# Culture handed to the recognizers_suite recognizers called in str_normalize().
culture = Culture.English


def maybe_normalize_float(span):
    """Return the canonical ``str(float(...))`` form of a plain decimal string.

    A span qualifies when it is an explicitly signed number with at least one
    digit before the optional dot (``+3``, ``-2.50``), or an unsigned run of
    digits with an optional dot (``3``, ``.5``, ``5.``). Everything else,
    including empty/None input and a lone ``.``, is returned unchanged.
    """
    # FIXME: the regex validation replaces a try/except around float() on purpose:
    #  the try/except led to an error when using 8 V100 gpus with cuda 11.0,
    #  and the reason is still unknown.
    if not span:
        return span
    if span == '.':
        return span

    looks_signed = re.match(r"^[+-][0-9]+[.]?[0-9]*$", span)
    looks_unsigned = looks_signed or re.match(r"^[0-9]*[.]?[0-9]*$", span)
    if looks_unsigned:
        return str(float(span))
    return span


def finditer(sub_str: str, mother_str: str):
    """Collect the non-overlapping occurrences of ``sub_str`` in ``mother_str``.

    Args:
        sub_str: the text to look for. An empty string yields no matches
            (it used to make the loop spin forever, because ``str.find``
            keeps returning the same index for an empty needle).
        mother_str: the text to search in.

    Returns:
        A list of ``(start, end)`` index pairs, ``end`` exclusive, in
        left-to-right order.

    NOTE: the search window stops one character before the end of
    ``mother_str`` (``find(..., -1)``), so an occurrence that includes the very
    last character is never reported. This is kept on purpose: normalize_sql()
    in this file reads ``sql_str[end_idx]`` for every returned pair, which
    would be out of range for a match ending at the last character.
    """
    if not sub_str:
        return []

    result = []
    start_index = 0
    while True:
        start_index = mother_str.find(sub_str, start_index, -1)
        if start_index == -1:
            break
        end_idx = start_index + len(sub_str)
        result.append((start_index, end_idx))
        # Resume after this match so reported spans never overlap.
        start_index = end_idx
    return result


def normalize_header(header: List):
    """Header normalizer for converting a table into a DataFrame.

    The first occurrence of a column name is kept as is; every later
    occurrence is renamed to ``<name>_<k>`` with ``k`` counting up from 1.
    A candidate name that is already used by another column (original or
    generated) is skipped, so the returned header never contains duplicates.

    Args:
        header: list of column names (hashable values; non-strings are
            formatted with ``str.format`` when a suffix is added).

    Returns:
        A new list of the same length; ``header`` itself is not modified.
    """
    normalized_header = []
    taken = set(header)  # every name that must not be produced by renaming
    seen = set()  # names already emitted once
    next_suffix = {}  # per-name counter for the next "_<k>" to try
    for col in header:
        if col not in seen:
            seen.add(col)
            normalized_header.append(col)
            continue

        # Rename duplicated header.
        col_number = next_suffix.get(col, 1)
        candidate = "{}_{}".format(col, col_number)
        while candidate in taken:
            col_number += 1
            candidate = "{}_{}".format(col, col_number)
        taken.add(candidate)
        next_suffix[col] = col_number + 1
        normalized_header.append(candidate)

    return normalized_header


def str_normalize(user_input, recognition_types=None):
    """A string normalizer which recognizes and normalizes values based on recognizers_suite.

    Args:
        user_input: the value to normalize; it is converted with ``str()`` first
            and every literal backslash-n sequence is replaced by ``"; "``.
        recognition_types: names of the recognizers to run, in order. Each name
            is resolved to ``recognizers_suite.recognize_<name>``. Defaults to
            ``["datetime", "number"]``.

    Returns:
        The normalized string. Recognized datetimes are replaced by their timex
        value and recognized numbers by their resolved value. A trailing
        ``" 00:00:00"`` is dropped from strings shaped like ``a-b-c 00:00:00``.
    """
    user_input = str(user_input)
    user_input = user_input.replace("\\n", "; ")

    def replace_by_idx_pairs(orig_str, strs_to_replace, idx_pairs):
        """Splice each replacement over its ``[start, end)`` span of ``orig_str``."""
        assert len(strs_to_replace) == len(idx_pairs)
        last_end = 0
        to_concat = []
        for idx_pair, str_to_replace in zip(idx_pairs, strs_to_replace):
            to_concat.append(orig_str[last_end:idx_pair[0]])
            to_concat.append(str_to_replace)
            last_end = idx_pair[1]
        to_concat.append(orig_str[last_end:])
        return ''.join(to_concat)

    if recognition_types is None:
        # Currently disabled: "ordinal", "percentage", "age", "currency",
        # "dimension", "temperature".
        recognition_types = ["datetime", "number"]

    fraction_like = re.compile(r"\d+/\d+")

    for recognition_type in recognition_types:
        # Re-checked on every pass because user_input may have been rewritten
        # by the previous recognizer.
        if fraction_like.match(user_input):
            # avoid calculating str as 1991/92
            continue
        recognizer = getattr(recognizers_suite, "recognize_{}".format(recognition_type))
        recognized_list = recognizer(user_input, culture)  # may match multiple parts
        strs_to_replace = []
        idx_pairs = []
        for recognized in recognized_list:
            if recognition_type != 'datetime':
                recognized_value = recognized.resolution['value']
                if str(recognized_value).startswith("P"):
                    # values starting with "P" are treated as periods and left untouched
                    continue
                strs_to_replace.append(recognized_value)
                # recognized.end is used as an inclusive index, hence the + 1
                idx_pairs.append((recognized.start, recognized.end + 1))
            elif recognized.resolution and len(recognized.resolution['values']) == 1:
                # resolution can be None in some cases; ambiguous datetimes
                # (more than one candidate value) are left as they are.
                # We use timex as normalization.
                strs_to_replace.append(recognized.resolution['values'][0]['timex'])
                idx_pairs.append((recognized.start, recognized.end + 1))

        if strs_to_replace:
            user_input = replace_by_idx_pairs(user_input, strs_to_replace, idx_pairs)

    # '2008-04-13 00:00:00' -> '2008-04-13'
    # The endswith() check matters: re.match only anchors the start, so without
    # it a string that merely *contains* the timestamp would lose its last
    # nine characters.
    midnight_suffix = " 00:00:00"
    if user_input.endswith(midnight_suffix) and re.match(r"(.*)-(.*)-(.*) 00:00:00", user_input):
        user_input = user_input[:-len(midnight_suffix)]
    return user_input


def prepare_df_for_neuraldb_from_table(table: Dict, add_row_id=True, normalize=True, lower_case=True):
    """Build a DataFrame from a ``{'header': [...], 'rows': [[...], ...]}`` table dict.

    When ``add_row_id`` is set and the header has no ``'row_id'`` column, a
    leading ``row_id`` column holding the row position as a string is added.
    When ``normalize`` is set the frame is passed through ``convert_df_type``
    (with ``lower_case`` forwarded); otherwise it is returned untouched.
    """
    columns = table['header']
    records = table['rows']

    # warning: please add row_id to make neuraldb work normally
    needs_row_id = add_row_id and 'row_id' not in columns
    if needs_row_id:
        columns = ["row_id"] + columns
        records = [[str(position)] + record for position, record in enumerate(records)]

    frame = pd.DataFrame(data=records, columns=columns)
    if not normalize:
        return frame
    return convert_df_type(frame, lower_case=lower_case)


def get_table_content_in_column(table):
    """Return the table's cells grouped by column instead of by row.

    ``table`` is either a DataFrame or a standard table dict with ``'header'``
    and ``'rows'`` keys. The result holds one list per header entry, each
    containing that column's cell from every row, in row order.
    """
    if isinstance(table, pd.DataFrame):
        column_names, records = table.columns.tolist(), table.values.tolist()
    else:
        # Standard table dict format
        column_names, records = table['header'], table['rows']

    return [
        [record[col_idx] for record in records]
        for col_idx in range(len(column_names))
    ]


def convert_df_type(df: pd.DataFrame, lower_case=True):
    """
    A simple converter of dataframe data type from string to int/float/datetime.

    The passed frame is modified in place (columns are renamed and reassigned)
    and the same object is returned. The steps run in this order:

    1. rename empty column names, then de-duplicate column names;
    2. map the null tokens ``''``, ``'-'`` and ``'/'`` to the string ``"None"``,
       and further to ``"NaN"`` in columns whose other cells all parse as floats;
    3. normalize every cell with ``str_normalize``;
    4. strip a ``"-01-01 00:00:00"`` / ``"-01 00:00:00"`` / ``" 00:00:00"`` suffix
       when every cell of a column carries it;
    5. if ``lower_case``, lower-case headers (de-duplicating again) and
       lower-case + strip every cell;
    6. cast each column to int, float or datetime when the whole column converts.

    Args:
        df: the table to convert.
        lower_case: whether to apply step 5.

    Returns:
        The converted ``df``.
    """
    # Rename empty columns
    # Fixme: give it a better name when all finished!
    df.columns = ['FilledColumnName' if header == '' else header for header in df.columns]

    # Rename duplicate columns
    new_columns = []
    for header in df.columns:
        if header in new_columns:
            new_header, suffix = header, 2
            while new_header in new_columns:
                new_header = header + '_' + str(suffix)
                suffix += 1
            new_columns.append(new_header)
        else:
            new_columns.append(header)
    df.columns = new_columns

    # Recognize null values like "-"
    null_tokens = ['', '-', '/']
    for header in df.columns:
        df[header] = df[header].map(lambda x: str(None) if x in null_tokens else x)

    # Convert the null values in digit column to "NaN"
    null_markers = [str(None), str(None).lower()]  # None or none
    all_col_values = get_table_content_in_column(df)
    for col_i, one_col_values in enumerate(all_col_values):
        all_number_flag = True
        for cell_value in one_col_values:
            try:
                float(cell_value)
            except Exception:
                if cell_value not in null_markers:
                    all_number_flag = False
                    break
        if all_number_flag:
            _header = df.columns[col_i]
            df[_header] = df[_header].map(lambda x: "NaN" if x in null_markers else x)

    # Normalize cell values.
    for header in df.columns:
        df[header] = df[header].map(lambda x: str_normalize(x))

    # Strip the mis-added "01-01 00:00:00"
    # The suffixes are tried from longest to shortest and only the first one
    # shared by every cell of the column is removed.
    all_col_values = get_table_content_in_column(df)
    for col_i, one_col_values in enumerate(all_col_values):
        cell_texts = [str(cell_value) for cell_value in one_col_values]
        for suffix in ("-01-01 00:00:00", "-01 00:00:00", " 00:00:00"):
            if all(cell_text.endswith(suffix) for cell_text in cell_texts):
                _header = df.columns[col_i]
                df[_header] = df[_header].map(lambda x, cut=len(suffix): x[:-cut])
                break

    # Do header and cell value lower case
    if lower_case:
        new_columns = []
        for header in df.columns:
            lower_header = str(header).lower()
            if lower_header in new_columns:
                new_header, suffix = lower_header, 2
                while new_header in new_columns:
                    new_header = lower_header + '-' + str(suffix)
                    suffix += 1
                new_columns.append(new_header)
            else:
                new_columns.append(lower_header)
        df.columns = new_columns
        for header in df.columns:
            df[header] = df[header].map(lambda x: str(x).lower().strip())

    # Recognize header type
    for header in df.columns:

        float_able = False
        int_able = False

        # Recognize int & float type
        try:
            df[header].astype("float")
            float_able = True
        except Exception:
            # best effort: the column simply is not numeric
            pass

        if float_able:
            try:
                if all(df[header].astype("float") == df[header].astype(int)):
                    int_able = True
            except Exception:
                # e.g. "1.5" or "nan" cannot be cast to int; keep it a float column
                pass

            if int_able:
                df[header] = df[header].astype(int)
            else:
                df[header] = df[header].astype(float)
            # A numeric column is never reinterpreted as datetimes.
            continue

        # Recognize datetime type
        # An explicit unit is required: the unit-less "datetime64" dtype is
        # rejected by recent pandas releases, which (behind the old bare
        # except) silently disabled this conversion.
        try:
            df[header] = df[header].astype("datetime64[ns]")
        except Exception:
            # best effort: leave the column as it is
            pass

    return df


def normalize(x):
    """ Normalize string.

    Copied from WikiTableQuestions dataset official evaluator.

    Steps: strip diacritics, unify quote and dash characters, repeatedly remove
    trailing citations / parenthesized details / an enclosing pair of double
    quotes until nothing changes, drop one final '.', collapse whitespace and
    lower-case.

    Args:
        x: the string to normalize, or None.

    Returns:
        The normalized string, or None when ``x`` is None.
    """
    if x is None:
        return None
    # Remove diacritics
    x = ''.join(c for c in unicodedata.normalize('NFKD', x)
                if unicodedata.category(c) != 'Mn')
    # Normalize quotes and dashes
    x = re.sub(r"[‘’´`]", "'", x)
    x = re.sub(r"[“”]", "\"", x)
    x = re.sub(r"[‐‑‒–—−]", "-", x)
    # Each removal can expose another trailing decoration, so iterate to a fixed point.
    while True:
        old_x = x
        # Remove citations
        x = re.sub(r"((?<!^)\[[^\]]*\]|\[\d+\]|[•♦†‡*#+])*$", "", x.strip())
        # Remove details in parenthesis
        x = re.sub(r"(?<!^)( \([^)]*\))*$", "", x.strip())
        # Remove outermost quotation mark
        x = re.sub(r'^"([^"]*)"$', r'\1', x.strip())
        if x == old_x:
            break
    # Remove final '.'
    if x and x[-1] == '.':
        x = x[:-1]
    # Collapse whitespaces and convert to lower case
    x = re.sub(r'\s+', ' ', x, flags=re.U).lower().strip()
    return x


def normalize_sql(sql_str, all_headers, table_title=None):
    """SQL normalizer: point the query at table ``w`` and back-quote column names.

    Args:
        sql_str: the SQL text to normalize.
        all_headers: column names of the table. The list is only read; empty
            names are ignored and longer names are matched before shorter ones.
        table_title: if given, ``FROM <title>`` (as is and lower-cased) is
            rewritten to ``FROM w``.

    Returns:
        The SQL string with real newlines turned into a literal backslash-n
        and every matched header occurrence wrapped in back-quotes. Header
        occurrences that are fully inside quoted text, fully inside an
        already matched longer header, or already adjacent to a back-quote
        are left alone.

    Raises:
        AssertionError: if the matched spans cannot be paired into
            start/end boundaries.
    """

    # fixme: disable this if there are multiple tables
    if table_title:
        sql_str = sql_str.replace("FROM " + table_title, "FROM w")
        sql_str = sql_str.replace("FROM " + table_title.lower(), "FROM w")

    # Case 1: Fix the `` missing.
    # Drop the null header and try longer headers first. A sorted copy is used
    # so the caller's list is neither reordered nor shortened.
    headers = sorted((header for header in all_headers if header != ''), key=len, reverse=True)

    # Remove the '\n' in header.
    # This is because the WikiTQ won't actually show the str in two lines,
    # they use '\n' to mean that, and display it in the same line when print.
    sql_str = sql_str.replace("\\n", "\n")
    sql_str = sql_str.replace("\n", "\\n")

    # Add `` in SQL.
    # Per-character marks: 0 = free, 1 = part of a header match, 2 = quoted text.
    have_matched = [0 for _ in range(len(sql_str))]

    # match quotation
    idx_s_single_quotation = [_ for _ in range(1, len(sql_str)) if
                              sql_str[_] in ["\'"] and sql_str[_ - 1] not in ["\'"]]
    idx_s_double_quotation = [_ for _ in range(1, len(sql_str)) if
                              sql_str[_] in ["\""] and sql_str[_ - 1] not in ["\""]]
    for idx_s in [idx_s_single_quotation, idx_s_double_quotation]:
        # Only trust the quote positions when they pair up.
        if len(idx_s) % 2 == 0:
            for idx in range(len(idx_s) // 2):
                start_idx = idx_s[idx * 2]
                end_idx = idx_s[idx * 2 + 1]
                have_matched[start_idx: end_idx] = [2 for _ in range(end_idx - start_idx)]

    # match headers
    for header in headers:
        if (header in sql_str) and (header not in ALL_KEY_WORDS):
            for start_idx, end_idx in finditer(header, sql_str):
                # Skipped only when *every* character is already marked.
                fully_covered = all(have_matched[start_idx: end_idx])
                # Explicit bounds checks: index -1 would wrap around to the last
                # character, and end_idx may equal len(sql_str).
                tick_before = start_idx > 0 and sql_str[start_idx - 1] == "`"
                tick_after = end_idx < len(sql_str) and sql_str[end_idx] == "`"
                if not fully_covered and not tick_before and not tick_after:
                    have_matched[start_idx: end_idx] = [1 for _ in range(end_idx - start_idx)]
                    # a bit ugly, but anyway.

    # re-compose sql from the matched idx.
    start_have_matched = [0] + have_matched
    end_have_matched = have_matched + [0]
    start_idx_s = [idx - 1 for idx in range(1, len(start_have_matched)) if
                   start_have_matched[idx - 1] == 0 and start_have_matched[idx] == 1]
    end_idx_s = [idx for idx in range(len(end_have_matched) - 1) if
                 end_have_matched[idx] == 1 and end_have_matched[idx + 1] == 0]
    assert len(start_idx_s) == len(end_idx_s)
    spans = []
    current_idx = 0
    for start_idx, end_idx in zip(start_idx_s, end_idx_s):
        spans.append(sql_str[current_idx:start_idx])
        spans.append(sql_str[start_idx:end_idx + 1])
        current_idx = end_idx + 1
    spans.append(sql_str[current_idx:])
    # Spans alternate plain / header / plain ..., so joining with ` quotes each header.
    sql_str = '`'.join(spans)

    return sql_str


def postprosess_sql(sql_str, df, verbose=False):
    """
    Post-process SQL by fuzzy matching value with table contents.

    Values found at ``[VALUE]`` positions of the SQL template are lower-cased
    and, when a sufficiently similar table cell exists whose numbers do not
    contradict the value, replaced by that cell. Back-quoted column names are
    likewise snapped to the closest column of ``df``. ``QA(...)`` calls are
    protected by placeholders during the rewrite and restored afterwards.

    Args:
        sql_str: the generated SQL text.
        df: the table (DataFrame) the SQL runs against.
        verbose: print intermediate tokens and replacements.

    Returns:
        The rewritten SQL string.

    Raises:
        AssertionError: if the SQL tokens and the template tokens differ in length.
    """
    def _get_matched_cells(value_str, df, fuzz_threshold=70):
        """
        Get matched table cells with value token.

        Returns ``(cell, score)`` pairs sorted by descending score; an exact
        match (score 100) short-circuits and is returned alone.
        """
        matched_cells = []
        for row_id, row in df.iterrows():
            for cell in row:
                cell = str(cell)
                fuzz_score = fuzz.ratio(value_str, cell)
                if fuzz_score == 100:
                    return [(cell, fuzz_score)]
                if fuzz_score >= fuzz_threshold:
                    matched_cells.append((cell, fuzz_score))

        return sorted(matched_cells, key=lambda x: x[1], reverse=True)

    def _check_valid_fuzzy_match(value_str, matched_cell):
        """
        Check if the fuzzy match is valid, now considering:
        1. The number/date should not be disturbed, but adding new number or deleting number is valid.
        """
        number_pattern = r"[+]?[.]?[\d]+(?:,\d\d\d)*[\.]?\d*(?:[eE][-+]?\d+)?"
        numbers_in_value = re.findall(number_pattern, value_str)
        numbers_in_matched_cell = re.findall(number_pattern, matched_cell)
        # On a conversion failure the raw strings are compared instead.
        try:
            numbers_in_value = [float(num.replace(',', '')) for num in numbers_in_value]
        except ValueError:
            print(f"Can't convert number string {numbers_in_value} into float in _check_valid_fuzzy_match().")
        try:
            numbers_in_matched_cell = [float(num.replace(',', '')) for num in numbers_in_matched_cell]
        except ValueError:
            print(f"Can't convert number string {numbers_in_matched_cell} into float in _check_valid_fuzzy_match().")
        numbers_in_value = set(numbers_in_value)
        numbers_in_matched_cell = set(numbers_in_matched_cell)

        # Valid when one side's numbers are contained in the other's.
        return numbers_in_value.issubset(numbers_in_matched_cell) or numbers_in_matched_cell.issubset(numbers_in_value)

    # Drop trailing '\n```', a pattern that may appear in Codex SQL generation.
    # Only an actual three-backtick fence is removed: str.rstrip('```') strips
    # *any* trailing backtick and would eat the closing quote of a final `column`.
    if sql_str.endswith('```'):
        sql_str = sql_str[:-len('```')]
    sql_str = sql_str.rstrip('\n')

    # Replace QA module with placeholder
    qa_pattern = r"QA\(.+?;.*?`.+?`.*?\)"
    qas = re.findall(qa_pattern, sql_str)
    for idx, qa in enumerate(qas):
        sql_str = sql_str.replace(qa, f"placeholder{idx}")

    # Parse and replace SQL value with table contents
    sql_tokens = tokenize(sql_str)
    sql_template_tokens = extract_partial_template_from_sql(sql_str)
    # Fix 'between' keyword bug in parsing templates
    fixed_sql_template_tokens = []
    sql_tok_bias = 0
    for idx, sql_templ_tok in enumerate(sql_template_tokens):
        sql_tok = sql_tokens[idx + sql_tok_bias]
        if sql_tok == 'between' and sql_templ_tok == '[WHERE_OP]':
            fixed_sql_template_tokens.extend(['[WHERE_OP]', '[VALUE]', 'and'])
            sql_tok_bias += 2  # pass '[VALUE]', 'and'
        else:
            fixed_sql_template_tokens.append(sql_templ_tok)
    sql_template_tokens = fixed_sql_template_tokens
    for idx, tok in enumerate(sql_tokens):
        if tok in ALL_KEY_WORDS:
            sql_tokens[idx] = tok.upper()

    if verbose:
        print(sql_tokens)
        print(sql_template_tokens)

    assert len(sql_tokens) == len(sql_template_tokens)
    value_indices = [idx for idx, templ_tok in enumerate(sql_template_tokens) if templ_tok == '[VALUE]']
    for value_idx in value_indices:
        # Skip the value if the where condition column is QA module
        if value_idx >= 2 and sql_tokens[value_idx - 2].startswith('placeholder'):
            continue
        value_str = sql_tokens[value_idx]
        # Drop \"\" for fuzzy match
        is_string = False
        if len(value_str) >= 2 and value_str.startswith("\"") and value_str.endswith("\""):
            value_str = value_str[1:-1]
            is_string = True
        # If already fuzzy match, skip
        # (startswith/endswith stay safe when the literal was the empty string "")
        if value_str.startswith('%') or value_str.endswith('%'):
            continue
        value_str = value_str.lower()
        # Fuzzy Match
        matched_cells = _get_matched_cells(value_str, df)

        if verbose:
            print(matched_cells)

        # Take the best-scoring cell that passes the number check, if any.
        new_value_str = value_str
        for matched_cell, fuzz_score in matched_cells:
            if _check_valid_fuzzy_match(value_str, matched_cell):
                new_value_str = matched_cell
                if verbose and new_value_str != value_str:
                    print("\tfuzzy match replacing!", value_str, '->', matched_cell, f'fuzz_score:{fuzz_score}')
                break
        if is_string:
            new_value_str = f"\"{new_value_str}\""
        sql_tokens[value_idx] = new_value_str
    # Compose new sql string
    # Clean column name in SQL since columns may have been tokenized in the postprocessing, e.g., (ppp) -> ( ppp )
    new_sql_str = ' '.join(sql_tokens)
    sql_columns = re.findall(r'`\s(.*?)\s`', new_sql_str)
    for sql_col in sql_columns:
        matched_columns = []
        for col in df.columns:
            score = fuzz.ratio(sql_col.lower(), col)
            if score == 100:
                matched_columns = [(col, score)]
                break
            if score >= 80:
                matched_columns.append((col, score))
        matched_columns = sorted(matched_columns, key=lambda x: x[1], reverse=True)
        if matched_columns:
            matched_col = matched_columns[0][0]
            new_sql_str = new_sql_str.replace(f"` {sql_col} `", f"`{matched_col}`")
        else:
            new_sql_str = new_sql_str.replace(f"` {sql_col} `", f"`{sql_col}`")

    # Restore QA modules
    # Highest index first, so "placeholder1" cannot clobber the prefix of "placeholder10".
    for idx in reversed(range(len(qas))):
        new_sql_str = new_sql_str.replace(f"placeholder{idx}", qas[idx])

    # Fix '<>' when composing the new sql
    new_sql_str = new_sql_str.replace('< >', '<>')

    return new_sql_str
