import contextlib
import itertools
import re
import typing
from enum import Enum
from typing import Callable

import sympy
from sympy import Add, Implies, sqrt
from sympy.core import Mul, Pow
from sympy.core import (S, pi, symbols, Function, Rational, Integer,
                        Symbol, Eq, Ne, Le, Lt, Gt, Ge)
from sympy.functions import Piecewise, exp, sin, cos
from sympy.printing.smtlib import smtlib_code
from sympy.testing.pytest import raises, Failed

# Untyped symbols shared by the tests below; having no assumptions attached,
# smtlib_code must infer (or default) their SMT-LIB sorts.
x, y, z = symbols('x y z')


class _W(Enum):
    """Warnings ``smtlib_code`` may log, each paired with a pattern for its message.

    Patterns are case-insensitive and written as raw strings; the periods are
    escaped so they match literal periods only (an unescaped ``.`` would
    accept any character and let a wrong message slip through).
    """
    # "Could not infer type of `<sym>`. Defaulting to float."
    DEFAULTING_TO_FLOAT = re.compile(r"Could not infer type of `.+`\. Defaulting to float\.", re.I)
    # "Non-Symbol/Function `<expr>` will not be declared."
    WILL_NOT_DECLARE = re.compile(r"Non-Symbol/Function `.+` will not be declared\.", re.I)
    # "Non-Boolean expression `<expr>` will not be asserted. Converting to SMTLib verbatim."
    WILL_NOT_ASSERT = re.compile(r"Non-Boolean expression `.+` will not be asserted\. Converting to SMTLib verbatim\.", re.I)


@contextlib.contextmanager
def _check_warns(expected: typing.Iterable[_W]):
    """Record warnings logged inside the ``with`` body and compare them to *expected*.

    Yields a callable suitable for ``smtlib_code(..., log_warn=...)``; each
    message passed to it is recorded in order. If the body completes without
    raising, the recorded messages are paired positionally with *expected*.
    Any missing, unexpected or mismatching warning is collected, and all of
    them are reported together in a single ``Failed``.

    Parameters
    ==========

    expected : iterable of _W
        The warnings expected, in the order they should be logged.

    Raises
    ======

    Failed
        If the recorded warnings differ from *expected*.
    """
    warns: typing.List[str] = []
    yield warns.append

    errors = []
    # zip_longest pads the shorter side with None, so surplus entries on
    # either side surface as a None partner instead of being dropped.
    for i, (w, e) in enumerate(itertools.zip_longest(warns, expected)):
        # Test for the padding value explicitly: an empty message is still a
        # received warning and must not be reported as "missing".
        if e is None:
            errors.append(f"[{i}] Received unexpected warning `{w}`.")
        elif w is None:
            errors.append(f"[{i}] Did not receive expected warning `{e.name}`.")
        elif not e.value.match(w):
            errors.append(f"[{i}] Warning `{w}` does not match expected {e.name}.")

    if errors:
        raise Failed('\n'.join(errors))


def test_Integer():
    cases = ((Integer(67), "67"), (Integer(-1), "-1"))

    # With a logger attached, every bare (non-Boolean) integer is reported
    # as "will not be asserted".
    with _check_warns([_W.WILL_NOT_ASSERT] * len(cases)) as warn:
        for value, expected in cases:
            assert smtlib_code(value, log_warn=warn) == expected

    # Without a logger the same inputs must print identically.
    with _check_warns([]):
        for value, expected in cases:
            assert smtlib_code(value) == expected


def test_Rational():
    # Bare rationals: normalised by SymPy, printed as (/ p q) or an integer.
    bare = (
        (Rational(3, 7), "(/ 3 7)"),
        (Rational(18, 9), "2"),
        (Rational(3, -7), "(/ -3 7)"),
        (Rational(-3, -7), "(/ 3 7)"),
    )
    with _check_warns([_W.WILL_NOT_ASSERT] * len(bare)) as warn:
        for value, expected in bare:
            assert smtlib_code(value, log_warn=warn) == expected

    # Rationals combined with the untyped symbol x (which defaults to Real).
    with_symbol = (
        (x + Rational(3, 7), {'auto_declare': False}, "(+ (/ 3 7) x)"),
        (Rational(3, 7) * x, {}, "(declare-const x Real)\n(* (/ 3 7) x)"),
    )
    with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT] * 2) as warn:
        for expr, options, expected in with_symbol:
            assert smtlib_code(expr, log_warn=warn, **options) == expected


def test_Relational():
    expected_by_relation = (
        (Eq, "(assert (= x y))"),
        (Ne, "(assert (not (= x y)))"),
        (Le, "(assert (<= x y))"),
        (Lt, "(assert (< x y))"),
        (Gt, "(assert (> x y))"),
        (Ge, "(assert (>= x y))"),
    )
    # x and y are untyped, so every call logs two DEFAULTING_TO_FLOAT
    # warnings; relations are Boolean and therefore asserted without warning.
    with _check_warns([_W.DEFAULTING_TO_FLOAT] * (2 * len(expected_by_relation))) as warn:
        for relation, expected in expected_by_relation:
            assert smtlib_code(relation(x, y), auto_declare=False, log_warn=warn) == expected


def test_Function():
    """Printing of function applications: built-in functions, functions renamed
    via ``known_functions`` and uninterpreted functions typed through
    ``symbol_table`` (emitted as ``declare-fun``)."""
    # sin/cos print under their own names; x is untyped, hence the
    # DEFAULTING_TO_FLOAT warning, and the non-Boolean result is not asserted.
    with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w:
        assert smtlib_code(sin(x) ** cos(x), auto_declare=False, log_warn=w) == "(pow (sin x) (cos x))"

    # symbol_table types x as int, known_types renames that sort and
    # known_functions renames Abs. y's entry is unused because y does not
    # appear in the expression (it is not declared).
    with _check_warns([_W.WILL_NOT_ASSERT]) as w:
        assert smtlib_code(
            abs(x),
            symbol_table={x: int, y: bool},
            known_types={int: "INTEGER_TYPE"},
            known_functions={sympy.Abs: "ABSOLUTE_VALUE_OF"},
            log_warn=w
        ) == "(declare-const x INTEGER_TYPE)\n" \
             "(ABSOLUTE_VALUE_OF x)"

    # An undefined function typed via symbol_table is declared with
    # declare-fun, and its argument x takes the parameter sort (Bool). The
    # Real-valued result is not asserted.
    my_fun1 = Function('f1')
    with _check_warns([_W.WILL_NOT_ASSERT]) as w:
        assert smtlib_code(
            my_fun1(x),
            symbol_table={my_fun1: Callable[[bool], float]},
            log_warn=w
        ) == "(declare-const x Bool)\n" \
             "(declare-fun f1 (Bool) Real)\n" \
             "(f1 x)"

    # With fully-typed Bool-valued functions nothing is defaulted and every
    # expression is asserted, so no warnings are expected.
    with _check_warns([]) as w:
        assert smtlib_code(
            my_fun1(x),
            symbol_table={my_fun1: Callable[[bool], bool]},
            log_warn=w
        ) == "(declare-const x Bool)\n" \
             "(declare-fun f1 (Bool) Bool)\n" \
             "(assert (f1 x))"

        # y's sort (Bool) is inferred from the function's return type through Eq.
        assert smtlib_code(
            Eq(my_fun1(x, z), y),
            symbol_table={my_fun1: Callable[[int, bool], bool]},
            log_warn=w
        ) == "(declare-const x Int)\n" \
             "(declare-const y Bool)\n" \
             "(declare-const z Bool)\n" \
             "(declare-fun f1 (Int Bool) Bool)\n" \
             "(assert (= (f1 x z) y))"

        # A function mapped in known_functions is printed under that name and
        # is not declared; Eq is remapped to '==' the same way.
        assert smtlib_code(
            Eq(my_fun1(x, z), y),
            symbol_table={my_fun1: Callable[[int, bool], bool]},
            known_functions={my_fun1: "MY_KNOWN_FUN", Eq: '=='},
            log_warn=w
        ) == "(declare-const x Int)\n" \
             "(declare-const y Bool)\n" \
             "(declare-const z Bool)\n" \
             "(assert (== (MY_KNOWN_FUN x z) y))"

    # Without a symbol_table no sorts can be inferred: x, y and z each
    # default to Real with one warning apiece.
    with _check_warns([_W.DEFAULTING_TO_FLOAT] * 3) as w:
        assert smtlib_code(
            Eq(my_fun1(x, z), y),
            known_functions={my_fun1: "MY_KNOWN_FUN", Eq: '=='},
            log_warn=w
        ) == "(declare-const x Real)\n" \
             "(declare-const y Real)\n" \
             "(declare-const z Real)\n" \
             "(assert (== (MY_KNOWN_FUN x z) y))"


def test_Pow():
    """Powers print as ``(pow base exp)``, including nested, rational and
    negative exponents, and inside larger typed expressions."""
    # One untyped symbol -> one DEFAULTING_TO_FLOAT warning per call.
    with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w:
        assert smtlib_code(x ** 3, auto_declare=False, log_warn=w) == "(pow x 3)"
    with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w:
        assert smtlib_code(x ** (y ** 3), auto_declare=False, log_warn=w) == "(pow x (pow y 3))"
    with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w:
        assert smtlib_code(x ** Rational(2, 3), auto_declare=False, log_warn=w) == '(pow x (/ 2 3))'

        # The definitions below sit inside this `with` but pass no logger to
        # anything, so they do not affect the warnings checked above.
        # a is declared Int and b Real from their assumptions; c has none.
        a = Symbol('a', integer=True)
        b = Symbol('b', real=True)
        c = Symbol('c')

        def g(x): return 2 * x

        # with a=1, b=2 this evaluates to 7/3 = 2.333..., i.e. the
        # 2 + Rational(1, 3) it is compared against below
        expr = 1 / (g(a) * 3.5) ** (a - b ** a) / (a ** 2 + b)

    # All symbols get sorts here (c is inferred Bool from its Boolean uses),
    # and every list element is Boolean, so no warnings are expected. The
    # float 3.5 folds with g's factor 2 into the printed float `7.`.
    with _check_warns([]) as w:
        assert smtlib_code(
            [
                Eq(a < 2, c),
                Eq(b > a, c),
                c & True,
                Eq(expr, 2 + Rational(1, 3))
            ],
            log_warn=w
        ) == '(declare-const a Int)\n' \
             '(declare-const b Real)\n' \
             '(declare-const c Bool)\n' \
             '(assert (= (< a 2) c))\n' \
             '(assert (= (> b a) c))\n' \
             '(assert c)\n' \
             '(assert (= ' \
             '(* (pow (* 7. a) (+ (pow b a) (* -1 a))) (pow (+ b (pow a 2)) -1)) ' \
             '(/ 7 3)' \
             '))'

    # Unevaluated Mul/Pow keep their structure in the output. Only c warns:
    # it is untyped here and defaults to Real, while b is real by assumption.
    with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w:
        assert smtlib_code(
            Mul(-2, c, Pow(Mul(b, b, evaluate=False), -1, evaluate=False), evaluate=False),
            log_warn=w
        ) == '(declare-const b Real)\n' \
             '(declare-const c Real)\n' \
             '(* -2 c (pow (* b b) -1))'


def test_basic_ops():
    F, A = _W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT
    # (expression, expected warnings, expected output); one F per untyped symbol.
    cases = (
        (x * y, [F, F, A], "(* x y)"),
        (x + y, [F, F, A], "(+ x y)"),
        (-x, [F, A], "(* -1 x)"),
    )
    for expr, warnings, expected in cases:
        with _check_warns(warnings) as warn:
            assert smtlib_code(expr, auto_declare=False, log_warn=warn) == expected

    # TODO: implement a re-write for subtraction. x - y currently prints as
    # '(+ x (* -1 y))' rather than '(- x y)'; once fixed, add the case
    # (x - y, [F, F, A], "(- x y)") above.


def test_quantifier_extensions():
    """Show how a user-defined Boolean class can plug into the SMT-LIB printer.

    ``ForAll`` below provides an ``_smtlib(self, printer)`` method, which
    produces the ``forall`` text seen in the expected outputs. It also
    overrides ``free_symbols`` so that bound variables are not auto-declared.
    """
    from sympy.logic.boolalg import Boolean
    from sympy import Interval, Tuple, sympify

    # start For-all quantifier class example
    class ForAll(Boolean):
        # Renders as (forall ( (name sort interval) ...) body). Each bound
        # symbol's sort is looked up in the printer's symbol table and mapped
        # through its known-type names (output such as "(x Real [-42, 21])").
        def _smtlib(self, printer):
            bound_symbol_declarations = [
                printer._s_expr(sym.name, [
                    printer._known_types[printer.symbol_table[sym]],
                    Interval(start, end)
                ]) for sym, start, end in self.limits
            ]
            return printer._s_expr('forall', [
                printer._s_expr('', bound_symbol_declarations),
                self.function
            ])

        @property
        def bound_symbols(self):
            # The symbols introduced by the (symbol, start, end) limits.
            return {s for s, _, _ in self.limits}

        @property
        def free_symbols(self):
            # Free symbols of the body minus the bound ones. The comparison is
            # by name, so a free symbol that merely shares a bound symbol's
            # name is excluded as well.
            bound_symbol_names = {s.name for s in self.bound_symbols}
            return {
                s for s in self.function.free_symbols
                if s.name not in bound_symbol_names
            }

        def __new__(cls, *args):
            # Tuple arguments are (symbol, start, end) limits; Boolean
            # arguments form the body, of which there must be exactly one.
            limits = [sympify(a) for a in args if isinstance(a, (tuple, Tuple))]
            function = [sympify(a) for a in args if isinstance(a, Boolean)]
            assert len(limits) + len(function) == len(args)
            assert len(function) == 1
            function = function[0]

            # A nested ForAll body is flattened into a single quantifier: the
            # outer limits come first, followed by the inner ones.
            if isinstance(function, ForAll): return ForAll.__new__(
                ForAll, *(limits + function.limits), function.function
            )
            inst = Boolean.__new__(cls)
            inst._args = tuple(limits + [function])
            inst.limits = limits
            inst.function = function
            return inst

    # end For-All Quantifier class example

    f = Function('f')
    # Eq(f(x), f(x)) is already True, so the body prints as `true` and f,
    # which no longer appears, is not declared.
    with _check_warns([_W.DEFAULTING_TO_FLOAT]) as w:
        assert smtlib_code(
            ForAll((x, -42, +21), Eq(f(x), f(x))),
            symbol_table={f: Callable[[float], float]},
            log_warn=w
        ) == '(assert (forall ( (x Real [-42, 21])) true))'

    # Two limits in one quantifier: f is declared from the symbol_table,
    # while the bound x and y are not declared.
    with _check_warns([_W.DEFAULTING_TO_FLOAT] * 2) as w:
        assert smtlib_code(
            ForAll(
                (x, -42, +21), (y, -100, 3),
                Implies(Eq(x, y), Eq(f(x), f(y)))
            ),
            symbol_table={f: Callable[[float], float]},
            log_warn=w
        ) == '(declare-fun f (Real) Real)\n' \
             '(assert (' \
             'forall ( (x Real [-42, 21]) (y Real [-100, 3])) ' \
             '(=> (= x y) (= (f x) (f y)))' \
             '))'

    a = Symbol('a', integer=True)
    b = Symbol('b', real=True)
    c = Symbol('c')

    # Nested quantifiers are merged into one forall. a and b get their sorts
    # from their assumptions; the free, untyped c is used as a Boolean and is
    # declared Bool without any warning.
    with _check_warns([]) as w:
        assert smtlib_code(
            ForAll(
                (a, 2, 100), ForAll(
                    (b, 2, 100),
                    Implies(a < b, sqrt(a) < b) | c
                )),
            log_warn=w
        ) == '(declare-const c Bool)\n' \
             '(assert (forall ( (a Int [2, 100]) (b Real [2, 100])) ' \
             '(or c (=> (< a b) (< (pow a (/ 1 2)) b)))' \
             '))'


def test_mix_number_mult_symbols():
    """Numeric constants: renamed via ``known_constants``, rendered through
    ``known_functions`` (exp), or printed numerically at ``precision`` digits."""
    # pi is printed under its known_constants name; the non-Boolean result is
    # printed verbatim rather than asserted.
    with _check_warns([_W.WILL_NOT_ASSERT]) as w:
        assert smtlib_code(
            1 / pi,
            known_constants={pi: "MY_PI"},
            log_warn=w
        ) == '(pow MY_PI -1)'

    # A list prints one line per element: the Eq is asserted, while 1/pi is
    # not (hence a single warning).
    with _check_warns([_W.WILL_NOT_ASSERT]) as w:
        assert smtlib_code(
            [
                Eq(pi, 3.14, evaluate=False),
                1 / pi,
            ],
            known_constants={pi: "MY_PI"},
            log_warn=w
        ) == '(assert (= MY_PI 3.14))\n' \
             '(pow MY_PI -1)'

    # With exp mapped in known_functions, E is printed as (exp 1) even though
    # known_constants also names it 'e'; pi and the golden ratio use their
    # known_constants names.
    with _check_warns([_W.WILL_NOT_ASSERT]) as w:
        assert smtlib_code(
            Add(S.Zero, S.One, S.NegativeOne, S.Half,
                S.Exp1, S.Pi, S.GoldenRatio, evaluate=False),
            known_constants={
                S.Pi: 'p', S.GoldenRatio: 'g',
                S.Exp1: 'e'
            },
            known_functions={
                Add: 'plus',
                exp: 'exp'
            },
            precision=3,
            log_warn=w
        ) == '(plus 0 1 -1 (/ 1 2) (exp 1) p g)'

    # Constants without a name fall back to a numeric value at the given
    # precision (the golden ratio becomes 1.62).
    with _check_warns([_W.WILL_NOT_ASSERT]) as w:
        assert smtlib_code(
            Add(S.Zero, S.One, S.NegativeOne, S.Half,
                S.Exp1, S.Pi, S.GoldenRatio, evaluate=False),
            known_constants={
                S.Pi: 'p'
            },
            known_functions={
                Add: 'plus',
                exp: 'exp'
            },
            precision=3,
            log_warn=w
        ) == '(plus 0 1 -1 (/ 1 2) (exp 1) p 1.62)'

    # No constant names and no exp mapping: E, pi and the golden ratio are
    # all printed numerically.
    with _check_warns([_W.WILL_NOT_ASSERT]) as w:
        assert smtlib_code(
            Add(S.Zero, S.One, S.NegativeOne, S.Half,
                S.Exp1, S.Pi, S.GoldenRatio, evaluate=False),
            known_functions={Add: 'plus'},
            precision=3,
            log_warn=w
        ) == '(plus 0 1 -1 (/ 1 2) 2.72 3.14 1.62)'

    # Without an exp mapping, E's known_constants name 'e' is used.
    with _check_warns([_W.WILL_NOT_ASSERT]) as w:
        assert smtlib_code(
            Add(S.Zero, S.One, S.NegativeOne, S.Half,
                S.Exp1, S.Pi, S.GoldenRatio, evaluate=False),
            known_constants={S.Exp1: 'e'},
            known_functions={Add: 'plus'},
            precision=3,
            log_warn=w
        ) == '(plus 0 1 -1 (/ 1 2) e 3.14 1.62)'


def test_boolean():
    """Boolean connectives, and sort inference for symbols used with them
    and with typed uninterpreted functions."""
    # Untyped symbols used only as Boolean operands are declared Bool
    # without any warning.
    with _check_warns([]) as w:
        assert smtlib_code(x & y, log_warn=w) == '(declare-const x Bool)\n' \
                                                 '(declare-const y Bool)\n' \
                                                 '(assert (and x y))'
        assert smtlib_code(x | y, log_warn=w) == '(declare-const x Bool)\n' \
                                                 '(declare-const y Bool)\n' \
                                                 '(assert (or x y))'
        assert smtlib_code(~x, log_warn=w) == '(declare-const x Bool)\n' \
                                              '(assert (not x))'
        assert smtlib_code(x & y & z, log_warn=w) == '(declare-const x Bool)\n' \
                                                     '(declare-const y Bool)\n' \
                                                     '(declare-const z Bool)\n' \
                                                     '(assert (and x y z))'

    # z appears only inside a numeric comparison, so it defaults to Real.
    with _check_warns([_W.DEFAULTING_TO_FLOAT]) as w:
        assert smtlib_code((x & ~y) | (z > 3), log_warn=w) == '(declare-const x Bool)\n' \
                                                              '(declare-const y Bool)\n' \
                                                              '(declare-const z Real)\n' \
                                                              '(assert (or (> z 3) (and x (not y))))'

    f = Function('f')
    g = Function('g')
    h = Function('h')
    # Arguments take the functions' parameter sorts (Bool). y, compared only
    # via Gt/Lt, gets no sort from the Int-valued functions and defaults to Real.
    with _check_warns([_W.DEFAULTING_TO_FLOAT]) as w:
        assert smtlib_code(
            [Gt(f(x), y),
             Lt(y, g(z))],
            symbol_table={
                f: Callable[[bool], int], g: Callable[[bool], int],
            }, log_warn=w
        ) == '(declare-const x Bool)\n' \
             '(declare-const y Real)\n' \
             '(declare-const z Bool)\n' \
             '(declare-fun f (Bool) Int)\n' \
             '(declare-fun g (Bool) Int)\n' \
             '(assert (> (f x) y))\n' \
             '(assert (< y (g z)))'

    # Same functions, but y is now equated with f(x), so it takes f's Int
    # return sort and no warning is logged.
    with _check_warns([]) as w:
        assert smtlib_code(
            [Eq(f(x), y),
             Lt(y, g(z))],
            symbol_table={
                f: Callable[[bool], int], g: Callable[[bool], int],
            }, log_warn=w
        ) == '(declare-const x Bool)\n' \
             '(declare-const y Int)\n' \
             '(declare-const z Bool)\n' \
             '(declare-fun f (Bool) Int)\n' \
             '(declare-fun g (Bool) Int)\n' \
             '(assert (= (f x) y))\n' \
             '(assert (< y (g z)))'

    # Composed functions: every symbol's sort follows from a parameter or
    # return sort in the symbol_table, so nothing is defaulted.
    with _check_warns([]) as w:
        assert smtlib_code(
            [Eq(f(x), y),
             Eq(g(f(x)), z),
             Eq(h(g(f(x))), x)],
            symbol_table={
                f: Callable[[float], int],
                g: Callable[[int], bool],
                h: Callable[[bool], float]
            },
            log_warn=w
        ) == '(declare-const x Real)\n' \
             '(declare-const y Int)\n' \
             '(declare-const z Bool)\n' \
             '(declare-fun f (Real) Int)\n' \
             '(declare-fun g (Int) Bool)\n' \
             '(declare-fun h (Bool) Real)\n' \
             '(assert (= (f x) y))\n' \
             '(assert (= (g (f x)) z))\n' \
             '(assert (= (h (g (f x))) x))'


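# TODO: make smtlib_code support arrays/containers. The disabled test below was
# carried over from the Julia printer tests and still calls julia_code.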
# def test_containers():
#     assert julia_code([1, 2, 3, [4, 5, [6, 7]], 8, [9, 10], 11]) == \
#            "Any[1, 2, 3, Any[4, 5, Any[6, 7]], 8, Any[9, 10], 11]"
#     assert julia_code((1, 2, (3, 4))) == "(1, 2, (3, 4))"
#     assert julia_code([1]) == "Any[1]"
#     assert julia_code((1,)) == "(1,)"
#     assert julia_code(Tuple(*[1, 2, 3])) == "(1, 2, 3)"
#     assert julia_code((1, x * y, (3, x ** 2))) == "(1, x .* y, (3, x .^ 2))"
#     # scalar, matrix, empty matrix and empty list
#     assert julia_code((1, eye(3), Matrix(0, 0, []), [])) == "(1, [1 0 0;\n0 1 0;\n0 0 1], zeros(0, 0), Any[])"

def test_smtlib_piecewise():
    expected_warns = [_W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]

    two_pieces = Piecewise((x, x < 1),
                           (x ** 2, True))
    with _check_warns(expected_warns) as warn:
        assert smtlib_code(two_pieces, auto_declare=False, log_warn=warn) == \
            '(ite (< x 1) x (pow x 2))'

    # Each additional piece nests one more ite in the else branch.
    four_pieces = Piecewise((x ** 2, x < 1),
                            (x ** 3, x < 2),
                            (x ** 4, x < 3),
                            (x ** 5, True))
    with _check_warns(expected_warns) as warn:
        assert smtlib_code(four_pieces, auto_declare=False, log_warn=warn) == (
            '(ite (< x 1) (pow x 2) '
            '(ite (< x 2) (pow x 3) '
            '(ite (< x 3) (pow x 4) '
            '(pow x 5))))'
        )

    # A Piecewise without a True (default) condition cannot be printed: the
    # call raises AssertionError after logging both warnings.
    no_default = Piecewise((x, x < 1), (x ** 2, x > 1), (sin(x), x > 0))
    with _check_warns(expected_warns) as warn:
        raises(AssertionError, lambda: smtlib_code(no_default, log_warn=warn))


def test_smtlib_piecewise_times_const():
    pw = Piecewise((x, x < 1), (x ** 2, True))
    # (expression, number of untyped symbols, expected output)
    cases = (
        (2 * pw, 1, '(declare-const x Real)\n(* 2 (ite (< x 1) x (pow x 2)))'),
        (pw / x, 1, '(declare-const x Real)\n(* (pow x -1) (ite (< x 1) x (pow x 2)))'),
        (pw / (x * y), 2, '(declare-const x Real)\n(declare-const y Real)\n(* (pow x -1) (pow y -1) (ite (< x 1) x (pow x 2)))'),
        (pw / 3, 1, '(declare-const x Real)\n(* (/ 1 3) (ite (< x 1) x (pow x 2)))'),
    )
    for expr, untyped, expected in cases:
        # One DEFAULTING_TO_FLOAT per untyped symbol, then the not-asserted notice.
        warns = [_W.DEFAULTING_TO_FLOAT] * untyped + [_W.WILL_NOT_ASSERT]
        with _check_warns(warns) as warn:
            assert smtlib_code(expr, log_warn=warn) == expected


# TODO: make smtlib_code support arrays/matrices? The disabled tests below were
# largely carried over from the Julia printer tests.
# def test_smtlib_matrix_assign_to():
#     A = Matrix([[1, 2, 3]])
#     assert smtlib_code(A, assign_to='a') == "a = [1 2 3]"
#     A = Matrix([[1, 2], [3, 4]])
#     assert smtlib_code(A, assign_to='A') == "A = [1 2;\n3 4]"

# def test_julia_matrix_1x1():
#     A = Matrix([[3]])
#     B = MatrixSymbol('B', 1, 1)
#     C = MatrixSymbol('C', 1, 2)
#     assert julia_code(A, assign_to=B) == "B = [3]"
#     raises(ValueError, lambda: julia_code(A, assign_to=C))

# def test_julia_matrix_elements():
#     A = Matrix([[x, 2, x * y]])
#     assert julia_code(A[0, 0] ** 2 + A[0, 1] + A[0, 2]) == "x .^ 2 + x .* y + 2"
#     A = MatrixSymbol('AA', 1, 3)
#     assert julia_code(A) == "AA"
#     assert julia_code(A[0, 0] ** 2 + sin(A[0, 1]) + A[0, 2]) == \
#            "sin(AA[1,2]) + AA[1,1] .^ 2 + AA[1,3]"
#     assert julia_code(sum(A)) == "AA[1,1] + AA[1,2] + AA[1,3]"

def test_smtlib_boolean():
    # (value, extra keyword arguments, expected output)
    cases = (
        (True, {'auto_assert': False}, 'true'),
        (True, {}, '(assert true)'),
        (S.true, {}, '(assert true)'),
        (S.false, {}, '(assert false)'),
        (False, {}, '(assert false)'),
        (False, {'auto_assert': False}, 'false'),
    )
    # Boolean literals are assertable, so none of these calls may warn.
    with _check_warns([]) as warn:
        for value, options, expected in cases:
            assert smtlib_code(value, log_warn=warn, **options) == expected


def test_not_supported():
    f = Function('f')
    derivative = f(x).diff(x)

    # Derivatives are not printable to SMT-LIB; the test expects a KeyError
    # after the warnings below have been logged.
    with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as warn:
        raises(KeyError, lambda: smtlib_code(
            derivative,
            symbol_table={f: Callable[[float], float]},
            log_warn=warn,
        ))

    # The same is expected for ComplexInfinity.
    with _check_warns([_W.WILL_NOT_ASSERT]) as warn:
        raises(KeyError, lambda: smtlib_code(S.ComplexInfinity, log_warn=warn))
