"""Utilities related to converting things."""
import math
import re
from typing import Iterable

import numexpr as ne

import sympy as sp
from sympy.parsing.sympy_parser import parse_expr
from sympy.calculus.util import AccumBounds

from . import constants
from . import sympy_util as sp_util

OPERATOR_ARITIES = constants.OPERATOR_ARITIES
SYMPY_OPERATORS = constants.SYMPY_OPERATORS

###############################################################################


class ValueErrorExpression(Exception):
    """Raised when an expression is not acceptable after parsing, e.g. its SymPy
    form contains the imaginary unit or accumulation bounds."""


class UnknownSymPyOperator(Exception):
    """Raised when a SymPy expression node has no known prefix-token encoding."""


class InvalidPrefixExpression(Exception):
    """Raised when a prefix token list cannot be parsed or rendered.

    The offending detail is stored on ``data``; ``str()`` of the exception
    shows the ``repr`` of that detail.
    """

    def __init__(self, data):
        super().__init__(data)
        self.data = data

    def __str__(self):
        return f"{self.data!r}"


###############################################################################


# def count_nested_exp(s):
#     """Return the maximum number of nested exponential functions in an infix expression."""
#     stack = []
#     count = 0
#     max_count = 0
#     for v in re.findall('[+-/*//()]|[a-zA-Z0-9]+', s):
#         if v == '(':
#             stack.append(v)
#         elif v == ')':
#             while True:
#                 x = stack.pop()
#                 if x in EXP_OPERATORS:
#                     count -= 1
#                 if x == '(':
#                     break
#         else:
#             stack.append(v)
#             if v in EXP_OPERATORS:
#                 count += 1
#                 max_count = max(max_count, count)
#     assert len(stack) == 0
#     return max_count


# def is_valid_expr(s):
#     """Check that we are able to evaluate an expression (and that it will not blow in SymPy evaluation)."""
#     s = s.replace('Derivative(f(x),x)', '1')
#     s = s.replace('Derivative(1,x)', '1')
#     s = s.replace('(E)', '(exp(1))')
#     s = s.replace('(I)', '(1)')
#     s = s.replace('(pi)', '(1)')
#     s = re.sub(r'(?<![a-z])(f|g|h|Abs|sign|ln|sin|cos|tan|sec|csc|cot|asin|acos|atan|asec|acsc|acot|tanh|sech|csch|coth|asinh|acosh|atanh|asech|acoth|acsch)\(', '(', s)
#     count = count_nested_exp(s)
#     if count >= 4:
#         return False
#     for v in EVAL_VALUES:
#         try:
#             local_dict = {s: (v + 1e-4 * i) for i, s in enumerate(EVAL_SYMBOLS)}
#             value = ne.evaluate(s, local_dict=local_dict).item()
#             if not (math.isnan(value) or math.isinf(value)):
#                 return True
#         except (FloatingPointError, ZeroDivisionError, TypeError, MemoryError):
#             continue
#     return False

###############################################################################


def write_int(val, base=10, balanced=False):
    """
    Convert an integer to a list of prefix tokens representing it in ``base``.

    The first token is a marker:
      - ``'INT+'`` / ``'INT-'`` for an ordinary positive base; the following
        digits are those of ``abs(val)``, most significant first;
      - ``'INT'`` for a negative base or a balanced base, where the sign is
        carried by the digits themselves.
    The base can be negative. In balanced bases (positive), digits range from
    -(base-1)//2 to (base-1)//2.

    Args:
        val: integer to encode.
        base: radix; ``abs(base)`` must be at least 2.
        balanced: use balanced digits (requires a positive base).

    Returns:
        List of string tokens, e.g. ``write_int(-42) == ['INT-', '4', '2']``.

    Raises:
        ValueError: if ``abs(base) < 2``, or ``balanced`` with a negative base.
    """
    if abs(base) < 2:
        raise ValueError(f"Invalid base: {base}")
    if balanced and base < 0:
        raise ValueError("Balanced representation requires a positive base")

    neg = False
    if balanced:
        max_digit = (base - 1) // 2
    else:
        max_digit = abs(base) - 1
        if base > 0:
            # Ordinary positive base: encode the magnitude, sign goes in the marker.
            neg = val < 0
            val = -val if neg else val

    digits = []  # least significant digit first
    while True:
        rem = val % base
        val = val // base
        # Python's % takes the sign of the divisor, so for a negative base rem
        # lands in (base, 0]; for balanced bases it may exceed max_digit.
        # Shift it into range and carry one into the next position.
        if rem < 0 or rem > max_digit:
            rem -= base
            val += 1
        digits.append(str(rem))
        if val == 0:
            break

    if base < 0 or balanced:
        marker = 'INT'
    else:
        marker = 'INT-' if neg else 'INT+'
    return [marker] + digits[::-1]


def parse_int(lst, base=10, balanced=False):
    """
    Parse a list that starts with an integer.
    Return the integer value, and the position it ends in the list.

    The list must start with an integer marker (``'INT+'``/``'INT-'`` for an
    ordinary positive base, ``'INT'`` for a negative or balanced base),
    followed by one or more digit tokens, most significant first. Parsing
    stops at the first token that is not a (possibly negative) digit string.

    Args:
        lst: sequence of string tokens.
        base: radix the digits are written in.
        balanced: whether the digits use a balanced representation.

    Returns:
        ``(value, n_consumed)``, where ``n_consumed`` counts the marker plus the
        digit tokens that were read.

    Raises:
        InvalidPrefixExpression: if ``lst`` is empty, does not start with a
            marker valid for ``base``/``balanced``, or the marker is not
            followed by at least one digit token.
    """
    if not lst:
        raise InvalidPrefixExpression("Invalid integer in prefix expression")
    head = lst[0]
    valid_marker = (
        (balanced and head == 'INT')
        or (base >= 2 and head in ('INT+', 'INT-'))
        or (base <= -2 and head == 'INT')
    )
    if not valid_marker:
        raise InvalidPrefixExpression("Invalid integer in prefix expression")

    val = 0
    n_digits = 0
    for tok in lst[1:]:
        # Accept "7" as well as "-3" (negative digits occur in balanced bases).
        # startswith() avoids an IndexError on an empty-string token.
        body = tok[1:] if tok.startswith('-') else tok
        if not body.isdigit():
            break
        val = val * base + int(tok)
        n_digits += 1

    if n_digits == 0:
        # A bare marker is malformed; never silently read it as 0.
        raise InvalidPrefixExpression(f"Integer marker {head!r} is not followed by any digit")

    if base > 0 and head == 'INT-':
        val = -val
    return val, n_digits + 1


###############################################################################


class Converter:
    """Convert between prefix token lists, infix strings and SymPy expressions.

    Args:
        all_operators: tokens treated as operators when reading a prefix list;
            each needs an entry in ``OPERATOR_ARITIES``.
        all_leaf_symbols: tokens treated as leaves when reading a prefix list.
        rewrite_functions: names forwarded to ``sp_util.rewrite_sympy_expr``.
        variable_names: symbol names accepted as differentiation variables.
        local_dict: namespace passed to SymPy's ``parse_expr``.
    """

    # Unary functions rendered verbatim as ``name(arg)`` in infix form.
    _UNARY_FUNCTIONS = frozenset([
        'sign', 'sqrt', 'exp', 'ln',
        'sin', 'cos', 'tan', 'cot', 'sec', 'csc',
        'asin', 'acos', 'atan', 'acot', 'asec', 'acsc',
        'sinh', 'cosh', 'tanh', 'coth', 'sech', 'csch',
        'asinh', 'acosh', 'atanh', 'acoth', 'asech', 'acsch',
    ])

    def __init__(
        self,
        all_operators: Iterable[str],
        all_leaf_symbols: Iterable[str],
        rewrite_functions: Iterable[str],
        variable_names: Iterable[str],
        local_dict,
    ):
        self._all_operators = frozenset(all_operators)
        self._all_leaf_symbols = frozenset(all_leaf_symbols)
        self._rewrite_functions = tuple(rewrite_functions)
        self._variable_names = frozenset(variable_names)
        self._local_dict = local_dict

    ###########################################################################

    def _write_infix(self, token, args):  # noqa: C901
        """Render operator ``token`` applied to already-rendered infix ``args``.

        The output is a string SymPy's ``parse_expr`` can read.

        Raises:
            InvalidPrefixExpression: if ``token`` is not a known operator but
                was given arguments (previously those arguments were silently
                dropped and only the token name was returned).
        """
        if token == 'add':
            return f'({args[0]})+({args[1]})'
        elif token == 'sub':
            return f'({args[0]})-({args[1]})'
        elif token == 'mul':
            return f'({args[0]})*({args[1]})'
        elif token == 'div':
            return f'({args[0]})/({args[1]})'
        elif token == 'pow':
            return f'({args[0]})**({args[1]})'
        elif token == 'rac':
            # rac(x, n) is the n-th root of x
            return f'({args[0]})**(1/({args[1]}))'
        elif token == 'abs':
            return f'Abs({args[0]})'
        elif token == 'inv':
            return f'1/({args[0]})'
        elif token == 'pow2':
            return f'({args[0]})**2'
        elif token == 'pow3':
            return f'({args[0]})**3'
        elif token == 'pow4':
            return f'({args[0]})**4'
        elif token == 'pow5':
            return f'({args[0]})**5'
        elif token in self._UNARY_FUNCTIONS:
            return f'{token}({args[0]})'
        elif token == 'derivative':
            return f'Derivative({args[0]},{args[1]})'
        elif token == 'f':
            return f'f({args[0]})'
        elif token == 'g':
            return f'g({args[0]},{args[1]})'
        elif token == 'h':
            return f'h({args[0]},{args[1]},{args[2]})'
        elif token.startswith('INT'):
            # 'INT+'/'INT-' -> sign character followed by the digits
            return f'{token[-1]}{args[0]}'
        elif not args:
            # Argument-less tokens are written as-is.
            return token
        raise InvalidPrefixExpression(f"Unknown token in prefix expression: {token}, with arguments {args}")

    def _prefix_to_infix(self, expr):
        """
        Parse one sub-expression from the front of prefix token list ``expr``.

        Returns:
            ``(infix, rest)``: the infix string of the first complete
            sub-expression, and the list of tokens left unconsumed after it.

        Raises:
            InvalidPrefixExpression: on an empty list or an unparseable token.
        """
        if len(expr) == 0:
            raise InvalidPrefixExpression("Empty prefix list.")
        t = expr[0]
        if t in self._all_operators:
            args = []
            rest = expr[1:]
            # Consume exactly as many sub-expressions as the operator's arity.
            for _ in range(OPERATOR_ARITIES[t]):
                arg, rest = self._prefix_to_infix(rest)
                args.append(arg)
            return self._write_infix(t, args), rest
        elif t in self._all_leaf_symbols or t == 'I':
            return t, expr[1:]
        else:
            # Anything else must be an integer literal ('INT+'/'INT-' + digits).
            val, i = parse_int(expr)
            return str(val), expr[i:]

    def prefix_to_infix(self, expr):
        """
        Convert a complete prefix token list to a parenthesized infix string.

        Raises:
            InvalidPrefixExpression: if ``expr`` is malformed or has tokens
                left over after one complete expression.
        """
        p, r = self._prefix_to_infix(expr)
        if len(r) > 0:
            raise InvalidPrefixExpression(f"Incorrect prefix expression \"{expr}\". \"{r}\" was not parsed.")
        return f'({p})'

    ###########################################################################

    def rewrite_sympy_expr(self, expr):
        """Rewrite a SymPy expression using the configured rewrite functions."""
        return sp_util.rewrite_sympy_expr(expr, self._rewrite_functions)

    def infix_to_sympy(self, infix, no_rewrite=False):
        """Convert an infix string to a SymPy expression.

        Args:
            infix: infix expression string.
            no_rewrite: skip ``rewrite_sympy_expr`` when true.

        Raises:
            ValueErrorExpression: if the result contains the imaginary unit or
                ``AccumBounds``.
        """
        # NOTE(review): parse_expr evaluates its input with eval(); only pass
        # strings produced by this module or otherwise trusted/sanitized.
        expr = parse_expr(infix, evaluate=True, local_dict=self._local_dict)

        if expr.has(sp.I) or expr.has(AccumBounds):
            raise ValueErrorExpression

        if not no_rewrite:
            expr = self.rewrite_sympy_expr(expr)

        return expr

    ###########################################################################

    def _sympy_to_prefix(self, op, expr):
        """Encode SymPy node ``expr`` whose head maps to prefix operator ``op``."""
        n_args = len(expr.args)

        # derivative operator: repeat 'derivative' once per order of each variable
        if op == 'derivative':
            # NOTE(review): these asserts are stripped under `python -O`.
            assert n_args >= 2
            assert all(len(arg) == 2 and str(arg[0]) in self._variable_names and int(arg[1]) >= 1 for arg in expr.args[1:]), expr.args
            parse_list = self.sympy_to_prefix(expr.args[0])
            for var, degree in expr.args[1:]:
                n = int(degree)
                parse_list = ['derivative'] * n + parse_list + [str(var)] * n
            return parse_list

        # add/mul may be n-ary; every other operator is unary or binary.
        assert (op == 'add' or op == 'mul') and (n_args >= 2) or (op != 'add' and op != 'mul') and (1 <= n_args <= 2)

        # square root
        if op == 'pow' and isinstance(expr.args[1], sp.Rational) and expr.args[1].p == 1 and expr.args[1].q == 2:
            return ['sqrt'] + self.sympy_to_prefix(expr.args[0])

        # parse children: emit `op` before every child but the last, so an
        # n-ary add/mul becomes a right-nested chain of binary operators
        parse_list = []
        for i, arg in enumerate(expr.args):
            if i == 0 or i < n_args - 1:
                parse_list.append(op)
            parse_list += self.sympy_to_prefix(arg)

        return parse_list

    def sympy_to_prefix(self, expr):
        """Convert a SymPy expression to a list of prefix tokens.

        Raises:
            UnknownSymPyOperator: if ``expr`` matches no entry of
                ``SYMPY_OPERATORS`` and is not a symbol/number/constant.
        """
        if isinstance(expr, sp.Symbol):
            return [str(expr)]
        elif isinstance(expr, sp.Integer):
            return write_int(int(expr))
        elif isinstance(expr, sp.Rational):
            # checked after Integer, which is a Rational subclass
            return ['div'] + write_int(int(expr.p)) + write_int(int(expr.q))
        elif expr == sp.E:
            return ['E']
        elif expr == sp.pi:
            return ['pi']
        elif expr == sp.I:
            return ['I']

        # SymPy operator
        for op_type, op_name in SYMPY_OPERATORS.items():
            if isinstance(expr, op_type):
                return self._sympy_to_prefix(op_name, expr)

        # unknown operator
        raise UnknownSymPyOperator(f"Unknown SymPy operator: {expr}")
