# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
#
# This file was copied and modified from https://github.com/facebookresearch/fairseq/blob/main/fairseq/data/encoders/byte_utils.py

import re
import unicodedata

# Matches any run of whitespace; byte_encode() replaces each match with SPACE.
WHITESPACE_NORMALIZER = re.compile(r"\s+")
SPACE = " "
SPACE_ESCAPE = "\u2581"  # code point 9601
BPE_UNK = "\u2047"  # code point 8263

# Code point of the stand-in character for every byte value: the entry at
# index b represents byte b, so the list holds exactly 256 entries.
# Bytes 32-126 are represented by themselves; every other byte is assigned a
# code point from 256 upwards, in increasing order.  Code points 306, 307,
# 319, 320, 329 and 383 are left out (presumably because they would not
# survive the NFKC check performed right after this table).
PRINTABLE_BASE_CHARS = [
    *range(256, 288),  # bytes 0-31
    *range(32, 127),  # bytes 32-126
    *range(288, 306),  # bytes 127-144
    *range(308, 319),  # bytes 145-155
    *range(321, 329),  # bytes 156-163
    *range(330, 383),  # bytes 164-216
    *range(384, 423),  # bytes 217-255
]

# Import-time sanity check: each stand-in character must be left unchanged by
# NFKC normalization.  The offending code point is reported on failure.
for c in PRINTABLE_BASE_CHARS:
    assert unicodedata.normalize("NFKC", chr(c)) == chr(c), c

# Forward table (byte value -> stand-in character) and its inverse.
BYTE_TO_BCHAR = {b: chr(code) for b, code in enumerate(PRINTABLE_BASE_CHARS)}
BCHAR_TO_BYTE = {bchar: b for b, bchar in BYTE_TO_BCHAR.items()}
BCHAR_TO_BYTE[BPE_UNK] = ord(SPACE)  # map unk to space


def byte_encode(x: str) -> str:
    """Encode text as a string of per-byte stand-in characters.

    Each run of whitespace in ``x`` is first replaced by a single ``SPACE``.
    The result is UTF-8 encoded and every byte is translated through
    ``BYTE_TO_BCHAR``, so the output has one character per UTF-8 byte.

    Args:
        x: Text to encode.

    Returns:
        The stand-in character string for the whitespace-normalized text.
    """
    collapsed = WHITESPACE_NORMALIZER.sub(SPACE, x)
    utf8_bytes = collapsed.encode("utf-8")
    return "".join(BYTE_TO_BCHAR[byte] for byte in utf8_bytes)


def byte_decode(x: str) -> str:
    """Turn a string of stand-in characters back into text.

    Every character of ``x`` is looked up in ``BCHAR_TO_BYTE`` and the
    resulting bytes are decoded as UTF-8.

    Args:
        x: String of stand-in characters.

    Returns:
        The decoded text, or ``""`` when a ``ValueError`` is raised (which
        includes the bytes not being valid UTF-8).

    Raises:
        KeyError: If ``x`` contains a character that is not a key of
            ``BCHAR_TO_BYTE``; only ``ValueError`` is handled here.
    """
    try:
        raw = bytes(map(BCHAR_TO_BYTE.__getitem__, x))
        text = raw.decode("utf-8")
    except ValueError:
        return ""
    return text


def smart_byte_decode(x: str) -> str:
    """Decode stand-in characters, salvaging what it can from broken input.

    ``byte_decode`` is tried on the whole string first.  If it yields ``""``,
    a dynamic programme picks non-overlapping windows of 1 to 4 consecutive
    characters that ``byte_decode`` can decode, maximizing the number of
    such windows; characters outside the chosen windows are dropped.

    Args:
        x: String of stand-in characters.

    Returns:
        The full decoding when it is non-empty, otherwise the concatenation
        of the decoded windows in their original order.
    """
    decoded = byte_decode(x)
    if decoded != "":
        return decoded

    total = len(x)
    # best[k]: largest number of decodable windows within x[:k].
    # prev[k]: prefix length the optimum for x[:k] was extended from.
    best = [0] * (total + 1)
    prev = [0] * (total + 1)
    for end in range(1, total + 1):
        # Baseline: drop the character at position end - 1.
        best[end] = best[end - 1]
        prev[end] = end - 1
        # A UTF-8 sequence is at most 4 bytes long, so try windows of 1..4.
        for width in range(1, min(4, end) + 1):
            start = end - width
            if best[start] + 1 > best[end] and len(byte_decode(x[start:end])) > 0:
                best[end] = best[start] + 1
                prev[end] = start

    # Follow the back-pointers from the end; windows are met last-to-first.
    pieces = []
    end = total
    while end > 0:
        start = prev[end]
        if best[end] == best[start] + 1:
            pieces.append(byte_decode(x[start:end]))
        end = start
    pieces.reverse()
    return "".join(pieces)
