from functools import reduce
import itertools
from operator import add

from sympy.codegen.matrix_nodes import MatrixSolve
from sympy.core.add import Add
from sympy.core.containers import Tuple
from sympy.core.expr import UnevaluatedExpr
from sympy.core.function import Function
from sympy.core.mul import Mul
from sympy.core.power import Pow
from sympy.core.relational import Eq
from sympy.core.singleton import S
from sympy.core.symbol import (Symbol, symbols)
from sympy.core.sympify import sympify
from sympy.functions.elementary.exponential import exp
from sympy.functions.elementary.miscellaneous import sqrt
from sympy.functions.elementary.piecewise import Piecewise
from sympy.functions.elementary.trigonometric import (cos, sin)
from sympy.matrices.dense import Matrix
from sympy.matrices.expressions import Inverse, MatAdd, MatMul, Transpose
from sympy.polys.rootoftools import CRootOf
from sympy.series.order import O
from sympy.simplify.cse_main import cse
from sympy.simplify.simplify import signsimp
from sympy.tensor.indexed import (Idx, IndexedBase)

from sympy.core.function import count_ops
from sympy.simplify.cse_opts import sub_pre, sub_post
from sympy.functions.special.hyper import meijerg
from sympy.simplify import cse_main, cse_opts
from sympy.utilities.iterables import subsets
from sympy.testing.pytest import XFAIL, raises
from sympy.matrices import (MutableDenseMatrix, MutableSparseMatrix,
        ImmutableDenseMatrix, ImmutableSparseMatrix)
from sympy.matrices.expressions import MatrixSymbol


w, x, y, z = symbols('w,x,y,z')
x0, x1, x2, x3, x4, x5, x6, x7, x8, x9, x10, x11, x12 = symbols('x:13')


def test_numbered_symbols():
    ns = cse_main.numbered_symbols(prefix='y')
    assert list(itertools.islice(
        ns, 0, 10)) == [Symbol('y%s' % i) for i in range(0, 10)]
    ns = cse_main.numbered_symbols(prefix='y')
    assert list(itertools.islice(
        ns, 10, 20)) == [Symbol('y%s' % i) for i in range(10, 20)]
    ns = cse_main.numbered_symbols()
    assert list(itertools.islice(
        ns, 0, 10)) == [Symbol('x%s' % i) for i in range(0, 10)]

# Dummy "optimization" functions for testing.


def opt1(expr):
    return expr + y


def opt2(expr):
    return expr*z


def test_preprocess_for_cse():
    assert cse_main.preprocess_for_cse(x, [(opt1, None)]) == x + y
    assert cse_main.preprocess_for_cse(x, [(None, opt1)]) == x
    assert cse_main.preprocess_for_cse(x, [(None, None)]) == x
    assert cse_main.preprocess_for_cse(x, [(opt1, opt2)]) == x + y
    assert cse_main.preprocess_for_cse(
        x, [(opt1, None), (opt2, None)]) == (x + y)*z


def test_postprocess_for_cse():
    assert cse_main.postprocess_for_cse(x, [(opt1, None)]) == x
    assert cse_main.postprocess_for_cse(x, [(None, opt1)]) == x + y
    assert cse_main.postprocess_for_cse(x, [(None, None)]) == x
    assert cse_main.postprocess_for_cse(x, [(opt1, opt2)]) == x*z
    # Note the reverse order of application.
    assert cse_main.postprocess_for_cse(
        x, [(None, opt1), (None, opt2)]) == x*z + y


def test_cse_single():
    # Simple substitution.
    e = Add(Pow(x + y, 2), sqrt(x + y))
    substs, reduced = cse([e])
    assert substs == [(x0, x + y)]
    assert reduced == [sqrt(x0) + x0**2]

    subst42, (red42,) = cse([42])  # issue_15082
    assert len(subst42) == 0 and red42 == 42
    subst_half, (red_half,) = cse([0.5])
    assert len(subst_half) == 0 and red_half == 0.5


def test_cse_single2():
    # Simple substitution, test for being able to pass the expression directly
    e = Add(Pow(x + y, 2), sqrt(x + y))
    substs, reduced = cse(e)
    assert substs == [(x0, x + y)]
    assert reduced == [sqrt(x0) + x0**2]
    substs, reduced = cse(Matrix([[1]]))
    assert isinstance(reduced[0], Matrix)

    subst42, (red42,) = cse(42)  # issue 15082
    assert len(subst42) == 0 and red42 == 42
    subst_half, (red_half,) = cse(0.5)  # issue 15082
    assert len(subst_half) == 0 and red_half == 0.5


def test_cse_not_possible():
    # No substitution possible.
    e = Add(x, y)
    substs, reduced = cse([e])
    assert substs == []
    assert reduced == [x + y]
    # issue 6329
    eq = (meijerg((1, 2), (y, 4), (5,), [], x) +
          meijerg((1, 3), (y, 4), (5,), [], x))
    assert cse(eq) == ([], [eq])


def test_nested_substitution():
    # Substitution within a substitution.
    e = Add(Pow(w*x + y, 2), sqrt(w*x + y))
    substs, reduced = cse([e])
    assert substs == [(x0, w*x + y)]
    assert reduced == [sqrt(x0) + x0**2]


def test_subtraction_opt():
    # Make sure subtraction is optimized.
    e = (x - y)*(z - y) + exp((x - y)*(z - y))
    substs, reduced = cse(
        [e], optimizations=[(cse_opts.sub_pre, cse_opts.sub_post)])
    assert substs == [(x0, (x - y)*(y - z))]
    assert reduced == [-x0 + exp(-x0)]
    e = -(x - y)*(z - y) + exp(-(x - y)*(z - y))
    substs, reduced = cse(
        [e], optimizations=[(cse_opts.sub_pre, cse_opts.sub_post)])
    assert substs == [(x0, (x - y)*(y - z))]
    assert reduced == [x0 + exp(x0)]
    # issue 4077
    n = -1 + 1/x
    e = n/x/(-n)**2 - 1/n/x
    assert cse(e, optimizations=[(cse_opts.sub_pre, cse_opts.sub_post)]) == \
        ([], [0])
    assert cse(((w + x + y + z)*(w - y - z))/(w + x)**3) == \
        ([(x0, w + x), (x1, y + z)], [(w - x1)*(x0 + x1)/x0**3])


def test_multiple_expressions():
    e1 = (x + y)*z
    e2 = (x + y)*w
    substs, reduced = cse([e1, e2])
    assert substs == [(x0, x + y)]
    assert reduced == [x0*z, x0*w]
    l = [w*x*y + z, w*y]
    substs, reduced = cse(l)
    rsubsts, _ = cse(reversed(l))
    assert substs == rsubsts
    assert reduced == [z + x*x0, x0]
    l = [w*x*y, w*x*y + z, w*y]
    substs, reduced = cse(l)
    rsubsts, _ = cse(reversed(l))
    assert substs == rsubsts
    assert reduced == [x1, x1 + z, x0]
    l = [(x - z)*(y - z), x - z, y - z]
    substs, reduced = cse(l)
    rsubsts, _ = cse(reversed(l))
    assert substs == [(x0, -z), (x1, x + x0), (x2, x0 + y)]
    assert rsubsts == [(x0, -z), (x1, x0 + y), (x2, x + x0)]
    assert reduced == [x1*x2, x1, x2]
    l = [w*y + w + x + y + z, w*x*y]
    assert cse(l) == ([(x0, w*y)], [w + x + x0 + y + z, x*x0])
    assert cse([x + y, x + y + z]) == ([(x0, x + y)], [x0, z + x0])
    assert cse([x + y, x + z]) == ([], [x + y, x + z])
    assert cse([x*y, z + x*y, x*y*z + 3]) == \
        ([(x0, x*y)], [x0, z + x0, 3 + x0*z])


@XFAIL # CSE of non-commutative Mul terms is disabled
def test_non_commutative_cse():
    A, B, C = symbols('A B C', commutative=False)
    l = [A*B*C, A*C]
    assert cse(l) == ([], l)
    l = [A*B*C, A*B]
    assert cse(l) == ([(x0, A*B)], [x0*C, x0])


# Test if CSE of non-commutative Mul terms is disabled
def test_bypass_non_commutatives():
    A, B, C = symbols('A B C', commutative=False)
    l = [A*B*C, A*C]
    assert cse(l) == ([], l)
    l = [A*B*C, A*B]
    assert cse(l) == ([], l)
    l = [B*C, A*B*C]
    assert cse(l) == ([], l)


@XFAIL # CSE fails when replacing non-commutative sub-expressions
def test_non_commutative_order():
    A, B, C = symbols('A B C', commutative=False)
    x0 = symbols('x0', commutative=False)
    l = [B+C, A*(B+C)]
    assert cse(l) == ([(x0, B+C)], [x0, A*x0])


@XFAIL # Worked in gh-11232, but was reverted due to performance considerations
def test_issue_10228():
    assert cse([x*y**2 + x*y]) == ([(x0, x*y)], [x0*y + x0])
    assert cse([x + y, 2*x + y]) == ([(x0, x + y)], [x0, x + x0])
    assert cse((w + 2*x + y + z, w + x + 1)) == (
        [(x0, w + x)], [x0 + x + y + z, x0 + 1])
    assert cse(((w + x + y + z)*(w - x))/(w + x)) == (
        [(x0, w + x)], [(x0 + y + z)*(w - x)/x0])
    a, b, c, d, f, g, j, m = symbols('a, b, c, d, f, g, j, m')
    exprs = (d*g**2*j*m, 4*a*f*g*m, a*b*c*f**2)
    assert cse(exprs) == (
        [(x0, g*m), (x1, a*f)], [d*g*j*x0, 4*x0*x1, b*c*f*x1]
)

@XFAIL
def test_powers():
    assert cse(x*y**2 + x*y) == ([(x0, x*y)], [x0*y + x0])


def test_issue_4498():
    assert cse(w/(x - y) + z/(y - x), optimizations='basic') == \
        ([], [(w - z)/(x - y)])


def test_issue_4020():
    assert cse(x**5 + x**4 + x**3 + x**2, optimizations='basic') \
        == ([(x0, x**2)], [x0*(x**3 + x + x0 + 1)])


def test_issue_4203():
    assert cse(sin(x**x)/x**x) == ([(x0, x**x)], [sin(x0)/x0])


def test_issue_6263():
    e = Eq(x*(-x + 1) + x*(x - 1), 0)
    assert cse(e, optimizations='basic') == ([], [True])


def test_dont_cse_tuples():
    from sympy.core.function import Subs
    f = Function("f")
    g = Function("g")

    name_val, (expr,) = cse(
        Subs(f(x, y), (x, y), (0, 1))
        + Subs(g(x, y), (x, y), (0, 1)))

    assert name_val == []
    assert expr == (Subs(f(x, y), (x, y), (0, 1))
            + Subs(g(x, y), (x, y), (0, 1)))

    name_val, (expr,) = cse(
        Subs(f(x, y), (x, y), (0, x + y))
        + Subs(g(x, y), (x, y), (0, x + y)))

    assert name_val == [(x0, x + y)]
    assert expr == Subs(f(x, y), (x, y), (0, x0)) + \
        Subs(g(x, y), (x, y), (0, x0))


def test_pow_invpow():
    assert cse(1/x**2 + x**2) == \
        ([(x0, x**2)], [x0 + 1/x0])
    assert cse(x**2 + (1 + 1/x**2)/x**2) == \
        ([(x0, x**2), (x1, 1/x0)], [x0 + x1*(x1 + 1)])
    assert cse(1/x**2 + (1 + 1/x**2)*x**2) == \
        ([(x0, x**2), (x1, 1/x0)], [x0*(x1 + 1) + x1])
    assert cse(cos(1/x**2) + sin(1/x**2)) == \
        ([(x0, x**(-2))], [sin(x0) + cos(x0)])
    assert cse(cos(x**2) + sin(x**2)) == \
        ([(x0, x**2)], [sin(x0) + cos(x0)])
    assert cse(y/(2 + x**2) + z/x**2/y) == \
        ([(x0, x**2)], [y/(x0 + 2) + z/(x0*y)])
    assert cse(exp(x**2) + x**2*cos(1/x**2)) == \
        ([(x0, x**2)], [x0*cos(1/x0) + exp(x0)])
    assert cse((1 + 1/x**2)/x**2) == \
        ([(x0, x**(-2))], [x0*(x0 + 1)])
    assert cse(x**(2*y) + x**(-2*y)) == \
        ([(x0, x**(2*y))], [x0 + 1/x0])


def test_postprocess():
    eq = (x + 1 + exp((x + 1)/(y + 1)) + cos(y + 1))
    assert cse([eq, Eq(x, z + 1), z - 2, (z + 1)*(x + 1)],
        postprocess=cse_main.cse_separate) == \
        [[(x0, y + 1), (x2, z + 1), (x, x2), (x1, x + 1)],
        [x1 + exp(x1/x0) + cos(x0), z - 2, x1*x2]]


def test_issue_4499():
    # previously, this gave 16 constants
    from sympy.abc import a, b
    B = Function('B')
    G = Function('G')
    t = Tuple(*
        (a, a + S.Half, 2*a, b, 2*a - b + 1, (sqrt(z)/2)**(-2*a + 1)*B(2*a -
        b, sqrt(z))*B(b - 1, sqrt(z))*G(b)*G(2*a - b + 1),
        sqrt(z)*(sqrt(z)/2)**(-2*a + 1)*B(b, sqrt(z))*B(2*a - b,
        sqrt(z))*G(b)*G(2*a - b + 1), sqrt(z)*(sqrt(z)/2)**(-2*a + 1)*B(b - 1,
        sqrt(z))*B(2*a - b + 1, sqrt(z))*G(b)*G(2*a - b + 1),
        (sqrt(z)/2)**(-2*a + 1)*B(b, sqrt(z))*B(2*a - b + 1,
        sqrt(z))*G(b)*G(2*a - b + 1), 1, 0, S.Half, z/2, -b + 1, -2*a + b,
        -2*a))
    c = cse(t)
    ans = (
        [(x0, 2*a), (x1, -b + x0), (x2, x1 + 1), (x3, b - 1), (x4, sqrt(z)),
         (x5, B(x3, x4)), (x6, (x4/2)**(1 - x0)*G(b)*G(x2)), (x7, x6*B(x1, x4)),
         (x8, B(b, x4)), (x9, x6*B(x2, x4))],
        [(a, a + S.Half, x0, b, x2, x5*x7, x4*x7*x8, x4*x5*x9, x8*x9,
          1, 0, S.Half, z/2, -x3, -x1, -x0)])
    assert ans == c


def test_issue_6169():
    r = CRootOf(x**6 - 4*x**5 - 2, 1)
    assert cse(r) == ([], [r])
    # and a check that the right thing is done with the new
    # mechanism
    assert sub_post(sub_pre((-x - y)*z - x - y)) == -z*(x + y) - x - y


def test_cse_Indexed():
    len_y = 5
    y = IndexedBase('y', shape=(len_y,))
    x = IndexedBase('x', shape=(len_y,))
    i = Idx('i', len_y-1)

    expr1 = (y[i+1]-y[i])/(x[i+1]-x[i])
    expr2 = 1/(x[i+1]-x[i])
    replacements, reduced_exprs = cse([expr1, expr2])
    assert len(replacements) > 0


def test_cse_MatrixSymbol():
    # MatrixSymbols have non-Basic args, so make sure that works
    A = MatrixSymbol("A", 3, 3)
    assert cse(A) == ([], [A])

    n = symbols('n', integer=True)
    B = MatrixSymbol("B", n, n)
    assert cse(B) == ([], [B])

    assert cse(A[0] * A[0]) == ([], [A[0]*A[0]])

    assert cse(A[0,0]*A[0,1] + A[0,0]*A[0,1]*A[0,2]) == ([(x0, A[0, 0]*A[0, 1])], [x0*A[0, 2] + x0])

def test_cse_MatrixExpr():
    A = MatrixSymbol('A', 3, 3)
    y = MatrixSymbol('y', 3, 1)

    expr1 = (A.T*A).I * A * y
    expr2 = (A.T*A) * A * y
    replacements, reduced_exprs = cse([expr1, expr2])
    assert len(replacements) > 0

    replacements, reduced_exprs = cse([expr1 + expr2, expr1])
    assert replacements

    replacements, reduced_exprs = cse([A**2, A + A**2])
    assert replacements


def test_Piecewise():
    f = Piecewise((-z + x*y, Eq(y, 0)), (-z - x*y, True))
    ans = cse(f)
    actual_ans = ([(x0, x*y)],
        [Piecewise((x0 - z, Eq(y, 0)), (-z - x0, True))])
    assert ans == actual_ans


def test_ignore_order_terms():
    eq = exp(x).series(x,0,3) + sin(y+x**3) - 1
    assert cse(eq) == ([], [sin(x**3 + y) + x + x**2/2 + O(x**3)])


def test_name_conflict():
    z1 = x0 + y
    z2 = x2 + x3
    l = [cos(z1) + z1, cos(z2) + z2, x0 + x2]
    substs, reduced = cse(l)
    assert [e.subs(reversed(substs)) for e in reduced] == l


def test_name_conflict_cust_symbols():
    z1 = x0 + y
    z2 = x2 + x3
    l = [cos(z1) + z1, cos(z2) + z2, x0 + x2]
    substs, reduced = cse(l, symbols("x:10"))
    assert [e.subs(reversed(substs)) for e in reduced] == l


def test_symbols_exhausted_error():
    l = cos(x+y)+x+y+cos(w+y)+sin(w+y)
    sym = [x, y, z]
    with raises(ValueError):
        cse(l, symbols=sym)


def test_issue_7840():
    # daveknippers' example
    C393 = sympify( \
        'Piecewise((C391 - 1.65, C390 < 0.5), (Piecewise((C391 - 1.65, \
        C391 > 2.35), (C392, True)), True))'
    )
    C391 = sympify( \
        'Piecewise((2.05*C390**(-1.03), C390 < 0.5), (2.5*C390**(-0.625), True))'
    )
    C393 = C393.subs('C391',C391)
    # simple substitution
    sub = {}
    sub['C390'] = 0.703451854
    sub['C392'] = 1.01417794
    ss_answer = C393.subs(sub)
    # cse
    substitutions,new_eqn = cse(C393)
    for pair in substitutions:
        sub[pair[0].name] = pair[1].subs(sub)
    cse_answer = new_eqn[0].subs(sub)
    # both methods should be the same
    assert ss_answer == cse_answer

    # GitRay's example
    expr = sympify(
        "Piecewise((Symbol('ON'), Equality(Symbol('mode'), Symbol('ON'))), \
        (Piecewise((Piecewise((Symbol('OFF'), StrictLessThan(Symbol('x'), \
        Symbol('threshold'))), (Symbol('ON'), true)), Equality(Symbol('mode'), \
        Symbol('AUTO'))), (Symbol('OFF'), true)), true))"
    )
    substitutions, new_eqn = cse(expr)
    # this Piecewise should be exactly the same
    assert new_eqn[0] == expr
    # there should not be any replacements
    assert len(substitutions) < 1


def test_issue_8891():
    for cls in (MutableDenseMatrix, MutableSparseMatrix,
            ImmutableDenseMatrix, ImmutableSparseMatrix):
        m = cls(2, 2, [x + y, 0, 0, 0])
        res = cse([x + y, m])
        ans = ([(x0, x + y)], [x0, cls([[x0, 0], [0, 0]])])
        assert res == ans
        assert isinstance(res[1][-1], cls)


def test_issue_11230():
    # a specific test that always failed
    a, b, f, k, l, i = symbols('a b f k l i')
    p = [a*b*f*k*l, a*i*k**2*l, f*i*k**2*l]
    R, C = cse(p)
    assert not any(i.is_Mul for a in C for i in a.args)

    # random tests for the issue
    from sympy.core.random import choice
    from sympy.core.function import expand_mul
    s = symbols('a:m')
    # 35 Mul tests, none of which should ever fail
    ex = [Mul(*[choice(s) for i in range(5)]) for i in range(7)]
    for p in subsets(ex, 3):
        p = list(p)
        R, C = cse(p)
        assert not any(i.is_Mul for a in C for i in a.args)
        for ri in reversed(R):
            for i in range(len(C)):
                C[i] = C[i].subs(*ri)
        assert p == C
    # 35 Add tests, none of which should ever fail
    ex = [Add(*[choice(s[:7]) for i in range(5)]) for i in range(7)]
    for p in subsets(ex, 3):
        p = list(p)
        R, C = cse(p)
        assert not any(i.is_Add for a in C for i in a.args)
        for ri in reversed(R):
            for i in range(len(C)):
                C[i] = C[i].subs(*ri)
        # use expand_mul to handle cases like this:
        # p = [a + 2*b + 2*e, 2*b + c + 2*e, b + 2*c + 2*g]
        # x0 = 2*(b + e) is identified giving a rebuilt p that
        # is now `[a + 2*(b + e), c + 2*(b + e), b + 2*c + 2*g]`
        assert p == [expand_mul(i) for i in C]


@XFAIL
def test_issue_11577():
    def check(eq):
        r, c = cse(eq)
        assert eq.count_ops() >= \
            len(r) + sum([i[1].count_ops() for i in r]) + \
            count_ops(c)

    eq = x**5*y**2 + x**5*y + x**5
    assert cse(eq) == (
        [(x0, x**4), (x1, x*y)], [x**5 + x0*x1*y + x0*x1])
        # ([(x0, x**5*y)], [x0*y + x0 + x**5]) or
        # ([(x0, x**5)], [x0*y**2 + x0*y + x0])
    check(eq)

    eq = x**2/(y + 1)**2 + x/(y + 1)
    assert cse(eq) == (
        [(x0, y + 1)], [x**2/x0**2 + x/x0])
        # ([(x0, x/(y + 1))], [x0**2 + x0])
    check(eq)


def test_hollow_rejection():
    eq = [x + 3, x + 4]
    assert cse(eq) == ([], eq)


def test_cse_ignore():
    exprs = [exp(y)*(3*y + 3*sqrt(x+1)), exp(y)*(5*y + 5*sqrt(x+1))]
    subst1, red1 = cse(exprs)
    assert any(y in sub.free_symbols for _, sub in subst1), "cse failed to identify any term with y"

    subst2, red2 = cse(exprs, ignore=(y,))  # y is not allowed in substitutions
    assert not any(y in sub.free_symbols for _, sub in subst2), "Sub-expressions containing y must be ignored"
    assert any(sub - sqrt(x + 1) == 0 for _, sub in subst2), "cse failed to identify sqrt(x + 1) as sub-expression"

def test_cse_ignore_issue_15002():
    l = [
        w*exp(x)*exp(-z),
        exp(y)*exp(x)*exp(-z)
    ]
    substs, reduced = cse(l, ignore=(x,))
    rl = [e.subs(reversed(substs)) for e in reduced]
    assert rl == l


def test_cse_unevaluated():
    xp1 = UnevaluatedExpr(x + 1)
    # This used to cause RecursionError
    [(x0, ue)], [red] = cse([(-1 - xp1) / (1 - xp1)])
    if ue == xp1:
        assert red == (-1 - x0) / (1 - x0)
    elif ue == -xp1:
        assert red == (-1 + x0) / (1 + x0)
    else:
        msg = f'Expected common subexpression {xp1} or {-xp1}, instead got {ue}'
        assert False, msg


def test_cse__performance():
    nexprs, nterms = 3, 20
    x = symbols('x:%d' % nterms)
    exprs = [
        reduce(add, [x[j]*(-1)**(i+j) for j in range(nterms)])
        for i in range(nexprs)
    ]
    assert (exprs[0] + exprs[1]).simplify() == 0
    subst, red = cse(exprs)
    assert len(subst) > 0, "exprs[0] == -exprs[2], i.e. a CSE"
    for i, e in enumerate(red):
        assert (e.subs(reversed(subst)) - exprs[i]).simplify() == 0


def test_issue_12070():
    exprs = [x + y, 2 + x + y, x + y + z, 3 + x + y + z]
    subst, red = cse(exprs)
    assert 6 >= (len(subst) + sum([v.count_ops() for k, v in subst]) +
                 count_ops(red))


def test_issue_13000():
    eq = x/(-4*x**2 + y**2)
    cse_eq = cse(eq)[1][0]
    assert cse_eq == eq


def test_issue_18203():
    eq = CRootOf(x**5 + 11*x - 2, 0) + CRootOf(x**5 + 11*x - 2, 1)
    assert cse(eq) == ([], [eq])


def test_unevaluated_mul():
    eq = Mul(x + y, x + y, evaluate=False)
    assert cse(eq) == ([(x0, x + y)], [x0**2])

def test_cse_release_variables():
    from sympy.simplify.cse_main import cse_release_variables
    _0, _1, _2, _3, _4 = symbols('_:5')
    eqs = [(x + y - 1)**2, x,
        x + y, (x + y)/(2*x + 1) + (x + y - 1)**2,
        (2*x + 1)**(x + y)]
    r, e = cse(eqs, postprocess=cse_release_variables)
    # this can change in keeping with the intention of the function
    assert r, e == ([
    (x0, x + y), (x1, (x0 - 1)**2), (x2, 2*x + 1),
    (_3, x0/x2 + x1), (_4, x2**x0), (x2, None), (_0, x1),
    (x1, None), (_2, x0), (x0, None), (_1, x)], (_0, _1, _2, _3, _4))
    r.reverse()
    r = [(s, v) for s, v in r if v is not None]
    assert eqs == [i.subs(r) for i in e]

def test_cse_list():
    _cse = lambda x: cse(x, list=False)
    assert _cse(x) == ([], x)
    assert _cse('x') == ([], 'x')
    it = [x]
    for c in (list, tuple, set):
        assert _cse(c(it)) == ([], c(it))
    #Tuple works different from tuple:
    assert _cse(Tuple(*it)) == ([], Tuple(*it))
    d = {x: 1}
    assert _cse(d) == ([], d)

def test_issue_18991():
    A = MatrixSymbol('A', 2, 2)
    assert signsimp(-A * A - A) == -A * A - A


def test_unevaluated_Mul():
    m = [Mul(1, 2, evaluate=False)]
    assert cse(m) == ([], m)


def test_cse_matrix_expression_inverse():
    A = ImmutableDenseMatrix(symbols('A:4')).reshape(2, 2)
    x = Inverse(A)
    cse_expr = cse(x)
    assert cse_expr == ([], [Inverse(A)])


def test_cse_matrix_expression_matmul_inverse():
    A = ImmutableDenseMatrix(symbols('A:4')).reshape(2, 2)
    b = ImmutableDenseMatrix(symbols('b:2'))
    x = MatMul(Inverse(A), b)
    cse_expr = cse(x)
    assert cse_expr == ([], [x])


def test_cse_matrix_negate_matrix():
    A = ImmutableDenseMatrix(symbols('A:4')).reshape(2, 2)
    x = MatMul(S.NegativeOne, A)
    cse_expr = cse(x)
    assert cse_expr == ([], [x])


def test_cse_matrix_negate_matmul_not_extracted():
    A = ImmutableDenseMatrix(symbols('A:4')).reshape(2, 2)
    B = ImmutableDenseMatrix(symbols('B:4')).reshape(2, 2)
    x = MatMul(S.NegativeOne, A, B)
    cse_expr = cse(x)
    assert cse_expr == ([], [x])


@XFAIL  # No simplification rule for nested associative operations
def test_cse_matrix_nested_matmul_collapsed():
    A = ImmutableDenseMatrix(symbols('A:4')).reshape(2, 2)
    B = ImmutableDenseMatrix(symbols('B:4')).reshape(2, 2)
    x = MatMul(S.NegativeOne, MatMul(A, B))
    cse_expr = cse(x)
    assert cse_expr == ([], [MatMul(S.NegativeOne, A, B)])


def test_cse_matrix_optimize_out_single_argument_mul():
    A = ImmutableDenseMatrix(symbols('A:4')).reshape(2, 2)
    x = MatMul(MatMul(MatMul(A)))
    cse_expr = cse(x)
    assert cse_expr == ([], [A])


@XFAIL  # Multiple simplification passed not supported in CSE
def test_cse_matrix_optimize_out_single_argument_mul_combined():
    A = ImmutableDenseMatrix(symbols('A:4')).reshape(2, 2)
    x = MatAdd(MatMul(MatMul(MatMul(A))), MatMul(MatMul(A)), MatMul(A), A)
    cse_expr = cse(x)
    assert cse_expr == ([], [MatMul(4, A)])


def test_cse_matrix_optimize_out_single_argument_add():
    A = ImmutableDenseMatrix(symbols('A:4')).reshape(2, 2)
    x = MatAdd(MatAdd(MatAdd(MatAdd(A))))
    cse_expr = cse(x)
    assert cse_expr == ([], [A])


@XFAIL  # Multiple simplification passed not supported in CSE
def test_cse_matrix_optimize_out_single_argument_add_combined():
    A = ImmutableDenseMatrix(symbols('A:4')).reshape(2, 2)
    x = MatMul(MatAdd(MatAdd(MatAdd(A))), MatAdd(MatAdd(A)), MatAdd(A), A)
    cse_expr = cse(x)
    assert cse_expr == ([], [MatMul(4, A)])


def test_cse_matrix_expression_matrix_solve():
    A = ImmutableDenseMatrix(symbols('A:4')).reshape(2, 2)
    b = ImmutableDenseMatrix(symbols('b:2'))
    x = MatrixSolve(A, b)
    cse_expr = cse(x)
    assert cse_expr == ([], [x])


def test_cse_matrix_matrix_expression():
    X = ImmutableDenseMatrix(symbols('X:4')).reshape(2, 2)
    y = ImmutableDenseMatrix(symbols('y:2'))
    b = MatMul(Inverse(MatMul(Transpose(X), X)), Transpose(X), y)
    cse_expr = cse(b)
    x0 = MatrixSymbol('x0', 2, 2)
    reduced_expr_expected = MatMul(Inverse(MatMul(x0, X)), x0, y)
    assert cse_expr == ([(x0, Transpose(X))], [reduced_expr_expected])


def test_cse_matrix_kalman_filter():
    """Kalman Filter example from Matthew Rocklin's SciPy 2013 talk.

    Talk titled: "Matrix Expressions and BLAS/LAPACK; SciPy 2013 Presentation"

    Video: https://pyvideo.org/scipy-2013/matrix-expressions-and-blaslapack-scipy-2013-pr.html

    Notes
    =====

    Equations are:

    new_mu = mu + Sigma*H.T * (R + H*Sigma*H.T).I * (H*mu - data)
           = MatAdd(mu, MatMul(Sigma, Transpose(H), Inverse(MatAdd(R, MatMul(H, Sigma, Transpose(H)))), MatAdd(MatMul(H, mu), MatMul(S.NegativeOne, data))))
    new_Sigma = Sigma - Sigma*H.T * (R + H*Sigma*H.T).I * H * Sigma
              = MatAdd(Sigma, MatMul(S.NegativeOne, Sigma, Transpose(H)), Inverse(MatAdd(R, MatMul(H*Sigma*Transpose(H)))), H, Sigma))

    """
    N = 2
    mu = ImmutableDenseMatrix(symbols(f'mu:{N}'))
    Sigma = ImmutableDenseMatrix(symbols(f'Sigma:{N * N}')).reshape(N, N)
    H = ImmutableDenseMatrix(symbols(f'H:{N * N}')).reshape(N, N)
    R = ImmutableDenseMatrix(symbols(f'R:{N * N}')).reshape(N, N)
    data = ImmutableDenseMatrix(symbols(f'data:{N}'))
    new_mu = MatAdd(mu, MatMul(Sigma, Transpose(H), Inverse(MatAdd(R, MatMul(H, Sigma, Transpose(H)))), MatAdd(MatMul(H, mu), MatMul(S.NegativeOne, data))))
    new_Sigma = MatAdd(Sigma, MatMul(S.NegativeOne, Sigma, Transpose(H), Inverse(MatAdd(R, MatMul(H, Sigma, Transpose(H)))), H, Sigma))
    cse_expr = cse([new_mu, new_Sigma])
    x0 = MatrixSymbol('x0', N, N)
    x1 = MatrixSymbol('x1', N, N)
    replacements_expected = [
        (x0, Transpose(H)),
        (x1, Inverse(MatAdd(R, MatMul(H, Sigma, x0)))),
    ]
    reduced_exprs_expected = [
        MatAdd(mu, MatMul(Sigma, x0, x1, MatAdd(MatMul(H, mu), MatMul(S.NegativeOne, data)))),
        MatAdd(Sigma, MatMul(S.NegativeOne, Sigma, x0, x1, H, Sigma)),
    ]
    assert cse_expr == (replacements_expected, reduced_exprs_expected)
