from . import idnadata
import bisect
import unicodedata
import re
from typing import Union, Optional
from .intranges import intranges_contain

# Canonical combining class value that valid_contextj() treats as a virama
_virama_combining_class = 9
# ASCII prefix that alabel() prepends to, and ulabel() strips from, Punycode labels
_alabel_prefix = b'xn--'
# Label separators recognised when splitting a domain in non-strict mode:
# U+002E, U+3002, U+FF0E and U+FF61
_unicode_dots_re = re.compile('[\u002e\u3002\uff0e\uff61]')

class IDNAError(UnicodeError):
    """Root of the exception hierarchy raised for IDNA-encoding problems.

    Derives from ``UnicodeError`` (and therefore ``ValueError``), so callers
    catching either of those also catch every error defined in this module.
    """


class IDNABidiError(IDNAError):
    """Raised when a label violates the bidirectional-text requirements."""


class InvalidCodepoint(IDNAError):
    """Raised when a label contains a disallowed or unallocated codepoint."""


class InvalidCodepointContext(IDNAError):
    """Raised when a codepoint is not valid in the context where it appears."""


def _combining_class(cp):
    # type: (int) -> int
    v = unicodedata.combining(chr(cp))
    if v == 0:
        if not unicodedata.name(chr(cp)):
            raise ValueError('Unknown character in unicodedata')
    return v

def _is_script(cp, script):
    # type: (str, str) -> bool
    """Tell whether the single character *cp* lies in the ranges for *script*.

    *script* is a key of ``idnadata.scripts``; the lookup raises ``KeyError``
    for a script name that is not present there.
    """
    script_ranges = idnadata.scripts[script]
    return intranges_contain(ord(cp), script_ranges)

def _punycode(s):
    # type: (str) -> bytes
    return s.encode('punycode')

def _unot(s):
    # type: (int) -> str
    return 'U+{:04X}'.format(s)


def valid_label_length(label):
    # type: (Union[bytes, str]) -> bool
    """Return True when *label* is at most 63 items long, False otherwise."""
    return len(label) <= 63


def valid_string_length(label, trailing_dot):
    # type: (Union[bytes, str], bool) -> bool
    """Return True when the whole domain string fits the length limit.

    The limit is 254 when *trailing_dot* is truthy (the dot is counted in
    ``len(label)``) and 253 otherwise.
    """
    if trailing_dot:
        limit = 254
    else:
        limit = 253
    return len(label) <= limit


def check_bidi(label, check_ltr=False):
    # type: (str, bool) -> bool
    """Check *label* against the numbered Bidi rules applied below.

    The rules are only enforced when the label contains at least one
    character of directionality R, AL or AN, or when *check_ltr* is true.

    Returns True when the label passes (or when no check is required).
    Raises IDNABidiError when a character has no known directionality or
    when any rule is violated.
    """
    # First pass: every character must have a known directionality; remember
    # the values so the rule checks below do not need to look them up again.
    directions = []
    contains_rtl = False
    for (idx, cp) in enumerate(label, 1):
        direction = unicodedata.bidirectional(cp)
        if direction == '':
            # String likely comes from a newer version of Unicode
            raise IDNABidiError('Unknown directionality in label {} at position {}'.format(repr(label), idx))
        directions.append(direction)
        if direction in ('R', 'AL', 'AN'):
            contains_rtl = True
    if not (contains_rtl or check_ltr):
        return True

    # Bidi rule 1: the first character decides whether this is an RTL or LTR label
    first_direction = unicodedata.bidirectional(label[0])
    if first_direction in ('R', 'AL'):
        rtl = True
    elif first_direction == 'L':
        rtl = False
    else:
        raise IDNABidiError('First codepoint in label {} must be directionality L, R or AL'.format(repr(label)))

    rtl_allowed = ('R', 'AL', 'AN', 'EN', 'ES', 'CS', 'ET', 'ON', 'BN', 'NSM')
    ltr_allowed = ('L', 'EN', 'ES', 'CS', 'ET', 'ON', 'BN', 'NSM')

    valid_ending = False
    number_type = None  # type: Optional[str]
    for (idx, direction) in enumerate(directions, 1):
        if rtl:
            # Bidi rule 2
            if direction not in rtl_allowed:
                raise IDNABidiError('Invalid direction for codepoint at position {} in a right-to-left label'.format(idx))
            # Bidi rule 3: trailing NSM characters do not change the verdict
            if direction in ('R', 'AL', 'EN', 'AN'):
                valid_ending = True
            elif direction != 'NSM':
                valid_ending = False
            # Bidi rule 4: AN and EN must not both appear
            if direction in ('AN', 'EN'):
                if number_type is None:
                    number_type = direction
                elif number_type != direction:
                    raise IDNABidiError('Can not mix numeral types in a right-to-left label')
        else:
            # Bidi rule 5
            if direction not in ltr_allowed:
                raise IDNABidiError('Invalid direction for codepoint at position {} in a left-to-right label'.format(idx))
            # Bidi rule 6: trailing NSM characters do not change the verdict
            if direction in ('L', 'EN'):
                valid_ending = True
            elif direction != 'NSM':
                valid_ending = False

    if not valid_ending:
        raise IDNABidiError('Label ends with illegal codepoint directionality')

    return True


def check_initial_combiner(label):
    # type: (str) -> bool
    """Reject labels whose first character is in a Unicode 'M' (mark) category.

    Returns True when the label is acceptable; raises IDNAError otherwise.
    """
    first_category = unicodedata.category(label[0])
    if first_category.startswith('M'):
        raise IDNAError('Label begins with an illegal combining character')
    return True


def check_hyphen_ok(label):
    # type: (str) -> bool
    """Validate hyphen placement inside *label*.

    Raises IDNAError when the 3rd and 4th characters are both hyphens, or
    when the label starts or ends with a hyphen. Returns True otherwise.
    """
    third_and_fourth = label[2:4]
    if third_and_fourth == '--':
        raise IDNAError('Label has disallowed hyphens in 3rd and 4th position')
    if '-' in (label[0], label[-1]):
        raise IDNAError('Label must not start or end with a hyphen')
    return True


def check_nfc(label):
    # type: (str) -> None
    """Raise IDNAError unless *label* is already in Normalization Form C."""
    normalized = unicodedata.normalize('NFC', label)
    if label != normalized:
        raise IDNAError('Label must be in Normalization Form C')


def valid_contextj(label, pos):
    # type: (str, int) -> bool
    """Decide whether the joiner at ``label[pos]`` is allowed in its context.

    Only U+200C (ZERO WIDTH NON-JOINER) and U+200D (ZERO WIDTH JOINER) are
    handled; any other codepoint yields False.

    * Either joiner is accepted when the preceding character has the virama
      combining class.
    * U+200C is additionally accepted when the nearest non-transparent
      character before it has joining type L or D and the nearest
      non-transparent character after it has joining type R or D.

    Raises ValueError (from ``_combining_class``) when the character before
    the joiner is unknown to ``unicodedata``.
    """
    cp_value = ord(label[pos])

    if cp_value not in (0x200c, 0x200d):
        return False

    # Both joiners are fine immediately after a virama
    if pos > 0 and _combining_class(ord(label[pos - 1])) == _virama_combining_class:
        return True

    if cp_value == 0x200d:
        # ZWJ has no other permitted context
        return False

    transparent = ord('T')

    # Walk backwards over transparent characters only. The first
    # non-transparent character settles the question: scanning past it would
    # accept labels where something other than transparent characters sits
    # between the joining character and the ZWNJ.
    before_ok = False
    for i in range(pos - 1, -1, -1):
        joining_type = idnadata.joining_types.get(ord(label[i]))
        if joining_type == transparent:
            continue
        before_ok = joining_type in (ord('L'), ord('D'))
        break

    if not before_ok:
        return False

    # Same rule going forwards: skip transparent characters, then the first
    # other character must be right- or dual-joining.
    for i in range(pos + 1, len(label)):
        joining_type = idnadata.joining_types.get(ord(label[i]))
        if joining_type == transparent:
            continue
        return joining_type in (ord('R'), ord('D'))

    return False


def valid_contexto(label, pos, exception=False):
    # type: (str, int, bool) -> bool
    """Decide whether the codepoint at ``label[pos]`` is allowed in its context.

    Handles U+00B7, U+0375, U+05F3/U+05F4, U+30FB and the two Arabic digit
    ranges U+0660..U+0669 and U+06F0..U+06F9; every other codepoint yields
    False. The *exception* argument is accepted but not used.
    """
    cp_value = ord(label[pos])
    last = len(label) - 1

    if cp_value == 0x00b7:
        # MIDDLE DOT: only between two 'l' characters
        return 0 < pos < last and label[pos - 1] == 'l' and label[pos + 1] == 'l'

    if cp_value == 0x0375:
        # Must be followed by a Greek character
        if pos < last and len(label) > 1:
            return _is_script(label[pos + 1], 'Greek')
        return False

    if cp_value in (0x05f3, 0x05f4):
        # Must be preceded by a Hebrew character
        if pos > 0:
            return _is_script(label[pos - 1], 'Hebrew')
        return False

    if cp_value == 0x30fb:
        # KATAKANA MIDDLE DOT: the label needs at least one Hiragana,
        # Katakana or Han character besides the dots themselves
        return any(
            cp != '\u30fb' and (
                _is_script(cp, 'Hiragana') or
                _is_script(cp, 'Katakana') or
                _is_script(cp, 'Han'))
            for cp in label)

    if 0x660 <= cp_value <= 0x669:
        # Digits from this range must not share a label with U+06F0..U+06F9
        return not any(0x6f0 <= ord(cp) <= 0x06f9 for cp in label)

    if 0x6f0 <= cp_value <= 0x6f9:
        # ...and the other way round
        return not any(0x660 <= ord(cp) <= 0x0669 for cp in label)

    return False


def check_label(label):
    # type: (Union[str, bytes, bytearray]) -> None
    """Validate a single label, raising on the first problem found.

    Bytes-like input is decoded as UTF-8 first. The label must be non-empty,
    in NFC, have acceptable hyphen placement and not start with a combining
    mark. Each codepoint must then be PVALID, or be a CONTEXTJ/CONTEXTO
    codepoint whose contextual rule is satisfied. Finally the Bidi check is
    applied.

    Raises:
        IDNAError: empty label, failed NFC/hyphen/initial-combiner check, or
            an unknown codepoint next to a joiner.
        InvalidCodepointContext: a CONTEXTJ/CONTEXTO codepoint is used in a
            context that is not allowed.
        InvalidCodepoint: a codepoint is in none of the permitted classes.
        IDNABidiError: the Bidi check fails.
    """
    if isinstance(label, (bytes, bytearray)):
        label = label.decode('utf-8')
    if len(label) == 0:
        raise IDNAError('Empty Label')

    check_nfc(label)
    check_hyphen_ok(label)
    check_initial_combiner(label)

    for (pos, cp) in enumerate(label):
        cp_value = ord(cp)
        if intranges_contain(cp_value, idnadata.codepoint_classes['PVALID']):
            continue
        elif intranges_contain(cp_value, idnadata.codepoint_classes['CONTEXTJ']):
            # Only the lookup itself is guarded: IDNAError derives from
            # UnicodeError and therefore from ValueError, so raising
            # InvalidCodepointContext inside this try block would be caught
            # by the handler and reported as an unknown-codepoint error.
            try:
                joiner_ok = valid_contextj(label, pos)
            except ValueError:
                raise IDNAError('Unknown codepoint adjacent to joiner {} at position {} in {}'.format(
                    _unot(cp_value), pos+1, repr(label)))
            if not joiner_ok:
                raise InvalidCodepointContext('Joiner {} not allowed at position {} in {}'.format(
                    _unot(cp_value), pos+1, repr(label)))
        elif intranges_contain(cp_value, idnadata.codepoint_classes['CONTEXTO']):
            if not valid_contexto(label, pos):
                raise InvalidCodepointContext('Codepoint {} not allowed at position {} in {}'.format(_unot(cp_value), pos+1, repr(label)))
        else:
            raise InvalidCodepoint('Codepoint {} at position {} of {} not allowed'.format(_unot(cp_value), pos+1, repr(label)))

    check_bidi(label)


def alabel(label):
    # type: (str) -> bytes
    """Convert one label to its ASCII form and return it as bytes.

    A label that is already pure ASCII is validated via ``ulabel`` and
    returned unchanged (as bytes). Any other label is validated with
    ``check_label``, Punycode-encoded and given the ``xn--`` prefix.

    Raises IDNAError when the result is longer than a label may be, plus
    whatever ``ulabel`` / ``check_label`` raise for an invalid label.
    """
    try:
        ascii_form = label.encode('ascii')
        # Called for validation only; the decoded value is not needed
        ulabel(ascii_form)
        if not valid_label_length(ascii_form):
            raise IDNAError('Label too long')
        return ascii_form
    except UnicodeEncodeError:
        # Not pure ASCII: fall through to the Punycode path below
        pass

    if not label:
        raise IDNAError('No Input')

    label = str(label)
    check_label(label)
    prefixed = _alabel_prefix + _punycode(label)

    if not valid_label_length(prefixed):
        raise IDNAError('Label too long')

    return prefixed


def ulabel(label):
    # type: (Union[str, bytes, bytearray]) -> str
    """Convert one label to its Unicode form and return it as ``str``.

    * A ``str`` that is not pure ASCII is validated with ``check_label`` and
      returned as-is.
    * ASCII input (``str`` or bytes-like) is lower-cased. Without the
      ``xn--`` prefix it is validated and returned as text; with the prefix
      the remainder is Punycode-decoded, validated and returned.

    Raises IDNAError when the A-label has no payload, ends with a hyphen or
    cannot be Punycode-decoded, plus whatever ``check_label`` raises.
    """
    if not isinstance(label, (bytes, bytearray)):
        try:
            label_bytes = label.encode('ascii')
        except UnicodeEncodeError:
            check_label(label)
            return label
    else:
        label_bytes = label

    label_bytes = label_bytes.lower()
    if not label_bytes.startswith(_alabel_prefix):
        check_label(label_bytes)
        return label_bytes.decode('ascii')

    label_bytes = label_bytes[len(_alabel_prefix):]
    if not label_bytes:
        raise IDNAError('Malformed A-label, no Punycode eligible content found')
    if label_bytes.decode('ascii')[-1] == '-':
        raise IDNAError('A-label must not end with a hyphen')

    # The punycode codec signals malformed input with a bare UnicodeError;
    # translate it so callers only have to deal with this package's own
    # exception type (which is itself a UnicodeError).
    try:
        label = label_bytes.decode('punycode')
    except UnicodeError:
        raise IDNAError('Invalid A-label')
    check_label(label)
    return label


def uts46_remap(domain, std3_rules=True, transitional=False):
    # type: (str, bool, bool) -> str
    """Re-map the characters in the string according to UTS46 processing.

    Each character is looked up in the ``uts46data`` table and, depending on
    its status and the *std3_rules* / *transitional* flags, is kept, replaced
    by its mapping, or dropped (status ``'I'``). Any other outcome raises
    InvalidCodepoint. The re-mapped text is returned NFC-normalized.
    """
    from .uts46data import uts46data
    pieces = []

    for pos, char in enumerate(domain):
        code_point = ord(char)
        try:
            # The first 256 rows are indexed directly by codepoint; beyond
            # that the row is located by bisection.
            if code_point < 256:
                row_index = code_point
            else:
                row_index = bisect.bisect_left(uts46data, (code_point, 'Z')) - 1
            uts46row = uts46data[row_index]
            status = uts46row[1]
            replacement = uts46row[2] if len(uts46row) == 3 else None  # type: ignore

            keep_as_is = (
                status == 'V' or
                (status == 'D' and not transitional) or
                (status == '3' and not std3_rules and replacement is None))
            use_mapping = replacement is not None and (
                status == 'M' or
                (status == '3' and not std3_rules) or
                (status == 'D' and transitional))

            if keep_as_is:
                pieces.append(char)
            elif use_mapping:
                pieces.append(replacement)
            elif status != 'I':
                # Funnel into the handler below, same as a failed table lookup
                raise IndexError()
        except IndexError:
            raise InvalidCodepoint(
                'Codepoint {} not allowed at position {} in {}'.format(
                _unot(code_point), pos + 1, repr(domain)))

    return unicodedata.normalize('NFC', ''.join(pieces))


def encode(s, strict=False, uts46=False, std3_rules=False, transitional=False):
    # type: (Union[str, bytes, bytearray], bool, bool, bool, bool) -> bytes
    """Encode a whole domain name to its ASCII form, label by label.

    Bytes-like input is decoded as ASCII first. With *uts46* the text is
    passed through ``uts46_remap`` (using *std3_rules* and *transitional*).
    With *strict* only ``'.'`` separates labels; otherwise every separator
    matched by ``_unicode_dots_re`` does. A single trailing separator is
    preserved as a trailing ``b'.'``.

    Raises IDNAError for an empty domain, an empty label or an over-long
    result, plus whatever ``alabel`` raises for an invalid label.
    """
    if isinstance(s, (bytes, bytearray)):
        s = s.decode('ascii')
    if uts46:
        s = uts46_remap(s, std3_rules, transitional)

    labels = s.split('.') if strict else _unicode_dots_re.split(s)
    if not labels or labels == ['']:
        raise IDNAError('Empty domain')

    # An empty final element means the input ended with a separator
    trailing_dot = labels[-1] == ''
    if trailing_dot:
        del labels[-1]

    result = []
    for label in labels:
        encoded_label = alabel(label)
        if not encoded_label:
            raise IDNAError('Empty label')
        result.append(encoded_label)
    if trailing_dot:
        result.append(b'')

    encoded_domain = b'.'.join(result)
    if not valid_string_length(encoded_domain, trailing_dot):
        raise IDNAError('Domain too long')
    return encoded_domain


def decode(s, strict=False, uts46=False, std3_rules=False):
    # type: (Union[str, bytes, bytearray], bool, bool, bool) -> str
    """Decode a whole domain name to its Unicode form, label by label.

    Bytes-like input is decoded as ASCII first. With *uts46* the text is
    passed through ``uts46_remap`` (non-transitional, using *std3_rules*).
    With *strict* only ``'.'`` separates labels; otherwise every separator
    matched by ``_unicode_dots_re`` does. A single trailing separator is
    preserved as a trailing ``'.'``.

    Raises IDNAError for an empty domain or an empty label, plus whatever
    ``ulabel`` raises for an invalid label.
    """
    if isinstance(s, (bytes, bytearray)):
        s = s.decode('ascii')
    if uts46:
        s = uts46_remap(s, std3_rules, False)

    labels = s.split('.') if strict else _unicode_dots_re.split(s)
    if not labels or labels == ['']:
        raise IDNAError('Empty domain')

    # An empty final element means the input ended with a separator
    trailing_dot = not labels[-1]
    if trailing_dot:
        del labels[-1]

    result = []
    for label in labels:
        decoded_label = ulabel(label)
        if not decoded_label:
            raise IDNAError('Empty label')
        result.append(decoded_label)
    if trailing_dot:
        result.append('')
    return '.'.join(result)
