#!/usr/bin/env python
#  type: ignore


""" Text-based normalizers, used to mitigate simple attacks against watermarking.

This implementation is unlikely to cover every possible exploit within the Unicode standard;
it represents our best effort at the time of writing.

These normalizers can be used as stand-alone normalizers. They could be made to conform to HF tokenizers standard, but that would
require messing with the limited rust interface of tokenizers.NormalizedString
"""

import re
import unicodedata
from collections import defaultdict
from functools import cache

import watermark_stealing.watermarks.kgw.homoglyphs as hg


def normalization_strategy_lookup(strategy_name: str) -> object:
    """Build the normalizer registered under ``strategy_name``.

    Recognized names are ``"unicode"`` (UnicodeSanitizer), ``"homoglyphs"``
    (HomoglyphCanonizer) and ``"truecase"`` (TrueCaser). Each normalizer is
    constructed with its default arguments. Matching is exact and case-sensitive.

    Returns:
        A new normalizer instance, or ``None`` if the name is not recognized.
    """
    # The constructors are wrapped in lambdas so that only the requested class
    # is looked up and instantiated.
    registry = (
        ("unicode", lambda: UnicodeSanitizer()),
        ("homoglyphs", lambda: HomoglyphCanonizer()),
        ("truecase", lambda: TrueCaser()),
    )
    for known_name, build in registry:
        if strategy_name == known_name:
            return build()
    return None


class HomoglyphCanonizer:
    """Attempts to detect homoglyph attacks and find a consistent canon.

    The canon is chosen on a per-ISO-category level: the category that occurs most often
    in the input becomes the target, and characters that belong neither to the target
    category nor to "COMMON" are replaced through a homoglyph table.
    Language-level would also be possible (see commented code).
    """

    def __init__(self):
        self.homoglyphs = None

    def __call__(self, homoglyphed_str: str) -> str:
        """Return ``homoglyphed_str`` with foreign-category characters mapped to the canon.

        Empty input is returned unchanged; there is no dominant category to select, and
        ``max()`` over an empty count table would raise ``ValueError``.
        """
        if not homoglyphed_str:
            return homoglyphed_str
        # find canon:
        target_category, all_categories = self._categorize_text(homoglyphed_str)
        homoglyph_table = self._select_canon_category_and_load(target_category, all_categories)
        return self._sanitize_text(target_category, homoglyph_table, homoglyphed_str)

    def _categorize_text(self, text: str) -> tuple[str, tuple[str, ...]]:
        """Count the category of every character in ``text``.

        Returns:
            ``(target_category, all_categories)`` where ``target_category`` is the most
            frequent category (on ties, the one seen first) and ``all_categories`` holds
            every category that occurred, in order of first appearance.
        """
        iso_categories = defaultdict(int)
        # self.iso_languages = defaultdict(int)

        for char in text:
            iso_categories[hg.Categories.detect(char)] += 1
            # for lang in hg.Languages.detect(char):
            #     self.iso_languages[lang] += 1
        target_category = max(iso_categories, key=iso_categories.get)
        all_categories = tuple(iso_categories)
        return target_category, all_categories

    # The cache lives on a static function rather than a bound method: a cache on an
    # instance method would include ``self`` in every key, which pins instances in memory
    # and prevents separate instances from sharing the tables loaded from file.
    @staticmethod
    @cache
    def _select_canon_category_and_load(
        target_category: str, all_categories: tuple[str, ...]
    ) -> dict:
        """Load (and memoize) the table mapping source characters to canon homoglyphs.

        Args:
            target_category: category chosen as the canon.
            all_categories: every category found in the text; must be hashable because
                it is part of the cache key.
        """
        homoglyph_table = hg.Homoglyphs(
            categories=(target_category, "COMMON")
        )  # alphabet loaded here from file

        source_alphabet = hg.Categories.get_alphabet(all_categories)
        restricted_table = homoglyph_table.get_restricted_table(
            source_alphabet, homoglyph_table.alphabet
        )  # table loaded here from file
        return restricted_table

    def _sanitize_text(
        self, target_category: str, homoglyph_table: dict, homoglyphed_str: str
    ) -> str:
        """Replace every character outside the canon by its first listed homoglyph.

        Characters whose category contains ``target_category`` or ``"COMMON"``, or whose
        category is empty, are kept as they are.
        """
        pieces = []  # collected and joined once instead of repeated string concatenation
        for char in homoglyphed_str:
            # langs = hg.Languages.detect(char)
            cat = hg.Categories.detect(char)
            # NOTE(review): the membership tests and len() assume `detect` returns a
            # string or collection - confirm it never returns None for unknown characters.
            if target_category in cat or "COMMON" in cat or len(cat) == 0:
                pieces.append(char)
            else:
                # NOTE(review): assumes the table has a non-empty entry for every
                # non-canon character - confirm against hg.get_restricted_table.
                pieces.append(list(homoglyph_table[char])[0])
        return "".join(pieces)


class UnicodeSanitizer:
    """Regex-based unicode sanitizer. Has different levels of granularity.

    * ruleset="whitespaces"    - attempts to remove only whitespace unicode characters
    * ruleset="IDN.blacklist"  - does its best to remove unusual unicode based on  Network.IDN.blacklist characters
    * ruleset="ascii"          - brute-forces all text into ascii

    Any other ruleset value falls back to the "ascii" behavior.

    This is unlikely to be a comprehensive list.

    You can find a more comprehensive discussion at https://www.unicode.org/reports/tr36/
    and https://www.unicode.org/faq/security.html
    """

    def __init__(self, ruleset="whitespaces"):
        """Compile the replacement pattern for the chosen ``ruleset``."""
        if ruleset == "whitespaces":
            # Characters matched by this pattern:
            #   U+00A0: Non-breaking space
            #   U+1680: Ogham space mark
            #   U+180E: Mongolian vowel separator
            #   U+2000-U+200B: Various space characters (en space, em space, thin space,
            #                  hair space, ..., zero-width space)
            #   U+200C, U+200D: Zero-width non-joiner and zero-width joiner
            #   U+200E, U+200F: Left-to-right mark, right-to-left mark
            #   U+2060: Word joiner
            #   U+2063: Invisible separator
            #   U+202F: Narrow non-breaking space
            #   U+205F: Medium mathematical space
            #   U+3000: Ideographic space
            #   U+FEFF: Zero-width non-breaking space
            #   U+FFA0: Halfwidth hangul filler
            #   U+FFF9-U+FFFB: Interlinear annotation characters
            #   U+FE00-U+FE0F: Variation selectors
            #   U+202A-U+202E: Bidirectional embedding/override characters
            #   U+3164: Korean hangul filler
            #
            # Note that these characters are not always superfluous whitespace characters!
            self.pattern = re.compile(
                r"[\u00A0\u1680\u180E\u2000-\u200B\u200C\u200D\u200E\u200F\u2060\u2063\u202F\u205F\u3000\uFEFF\uFFA0\uFFF9\uFFFA\uFFFB"
                r"\uFE00\uFE01\uFE02\uFE03\uFE04\uFE05\uFE06\uFE07\uFE08\uFE09\uFE0A\uFE0B\uFE0C\uFE0D\uFE0E\uFE0F\u3164\u202A\u202B\u202C\u202D"
                r"\u202E\u202F]"
            )
        elif ruleset == "IDN.blacklist":
            # First alternative:
            #   U+00A0, U+1680, U+180E, U+2000-U+200B, U+202F, U+205F, U+2060, U+2063,
            #   U+FEFF: whitespace-like characters
            #   U+FFF9-U+FFFB: interlinear annotation characters
            #   U+D800-U+DBFF: a lone high surrogate, optionally followed by a lone low
            #                  surrogate (U+DC00-U+DFFF)
            # Second alternative:
            #   U+E0020-U+E007F: tag characters
            #
            # The tag-character range was previously spelled as the UTF-16 surrogate
            # pairs \uDB40\uDC20-\uDB40\uDC7F inside one character class. Python's `re`
            # works on code points, so that class contained the descending range
            # \uDC20-\uDB40 and re.compile() raised "bad character range". The range is
            # now written with the code points those pairs encode.
            # NOTE(review): the old second alternative also required a trailing low
            # surrogate ([\uDC00-\uDFFF]); it was dropped on the assumption that it was
            # an artifact of the UTF-16 spelling - confirm.
            self.pattern = re.compile(
                r"[\u00A0\u1680\u180E\u2000-\u200B\u202F\u205F\u2060\u2063\uFEFF\uFFF9-\uFFFB\uD800-\uDB7F\uDB80-\uDBFF]"
                r"[\uDC00-\uDFFF]?|[\U000E0020-\U000E007F]"
            )
        else:
            # Simple restriction to "no-unicode": every run of non-ASCII characters is
            # matched. ASCII control characters are not matched by this pattern.
            self.pattern = re.compile(r"[^\x00-\x7F]+")

    def __call__(self, text: str) -> str:
        """Sanitize ``text`` and return the result.

        Steps, in order: NFC normalization, replacement of every pattern match by a single
        space, collapsing of consecutive spaces, and removal of all characters in Unicode
        category "Cc" (control characters, which includes tab and newline).
        """
        text = unicodedata.normalize("NFC", text)  # canon forms
        text = self.pattern.sub(" ", text)  # pattern match
        text = re.sub(" +", " ", text)  # collapse whitespaces
        text = "".join(
            c for c in text if unicodedata.category(c) != "Cc"
        )  # Remove any remaining non-printable characters
        return text


class TrueCaser:
    """True-casing is a capitalization normalization that restores conventional capitalization.

    This defends against attacks that wRIte TeXt lIkE spOngBoB.

    Here, a simple POS-tagger is used: the text is lower-cased, tagged, and tokens tagged
    as proper nouns are capitalized. The two backends differ in detail:

    * "spacy" keeps the original whitespace and also capitalizes sentence starts.
    * any other value selects NLTK, which re-joins the tokens with single spaces and
      capitalizes proper nouns only.
    """

    uppercase_pos = ["PROPN"]  # Name POS tags that should be upper-cased

    # NLTK resources fetched by the NLTK backend, in download order.
    _nltk_resources = ("punkt", "averaged_perceptron_tagger", "universal_tagset")
    # Set once every resource above has been downloaded successfully in this process.
    _nltk_ready = False

    def __init__(self, backend="spacy"):
        """Load the tagger for ``backend`` and select the matching normalization function."""
        if backend == "spacy":
            import spacy

            self.nlp = spacy.load("en_core_web_sm")
            self.normalize_fn = self._spacy_truecasing
        else:
            from nltk import pos_tag, word_tokenize  # noqa

            self._ensure_nltk_resources()
            self.normalize_fn = self._nltk_truecasing

    @classmethod
    def _ensure_nltk_resources(cls) -> None:
        """Download the NLTK resources once per process instead of on every call."""
        if TrueCaser._nltk_ready:
            return
        import nltk

        # A list (not a generator) so that every resource is attempted even if an
        # earlier download reports failure; the flag is only set when all succeeded,
        # so a failed download is retried on the next call.
        outcomes = [nltk.download(resource) for resource in cls._nltk_resources]
        TrueCaser._nltk_ready = all(outcomes)

    def __call__(self, random_capitalized_string: str) -> str:
        """Return the true-cased version of ``random_capitalized_string``."""
        truecased_str = self.normalize_fn(random_capitalized_string)
        return truecased_str

    def _spacy_truecasing(self, random_capitalized_string: str):
        """Lower-case the text, then capitalize proper nouns and sentence-initial tokens."""
        doc = self.nlp(random_capitalized_string.lower())
        POS = self.uppercase_pos
        # text_with_ws carries each token's trailing whitespace, so joining on the empty
        # string reproduces the spacing of the (lower-cased) input.
        truecased_str = "".join([
            w.text_with_ws.capitalize() if w.pos_ in POS or w.is_sent_start else w.text_with_ws
            for w in doc
        ])
        return truecased_str

    def _nltk_truecasing(self, random_capitalized_string: str):
        """Lower-case the text, then capitalize tokens tagged NNP/NNPS.

        Tokens are joined with single spaces, so the original spacing is not preserved.
        """
        from nltk import pos_tag, word_tokenize

        # Usually a no-op here: __init__ has already fetched the resources for this backend.
        self._ensure_nltk_resources()
        POS = ["NNP", "NNPS"]

        tagged_text = pos_tag(word_tokenize(random_capitalized_string.lower()))
        truecased_str = " ".join([w.capitalize() if p in POS else w for (w, p) in tagged_text])
        return truecased_str
