R"""Tests for matching.py

Note that these tests mostly check that basic functionality works.
Production-quality tests would require more test cases.

To run (the paths below are specific to the author's machine; adjust as needed):

cd ~/Desktop/projects/extract_merge1
export PYTHONPATH=$PYTHONPATH:~/Desktop/projects/extract_merge1

python3 em/datasets/antiderivative/expression_metadata/test_matching.py


"""
import unittest

import sympy as sp
from sympy.parsing.sympy_parser import parse_expr


from em.datasets.antiderivative.expression_metadata import matching as m


# The variable every matcher below is evaluated with respect to; sympify('x')
# in the tests produces an equal Symbol.
x = sp.symbols('x')


class TestLiteral(unittest.TestCase):
    """Tests for ``m.Literal`` built from strings, SymPy expressions, or lists."""

    @staticmethod
    def _matches(pattern, expr):
        """Return ``pattern.match(expr, x)``; a fresh pattern is passed per call."""
        return pattern.match(expr, x)

    def test_match_str(self):
        # A string literal matches the equal SymPy expression but not x**3.
        self.assertTrue(self._matches(m.Literal('x'), x))
        self.assertTrue(self._matches(m.Literal('4 * x - 6'), 4 * x - 6))

        self.assertFalse(self._matches(m.Literal('4 * x - 6'), x**3))

    def test_match_expr(self):
        # Same cases as test_match_str, with the literal given as a SymPy expression.
        self.assertTrue(self._matches(m.Literal(x), x))
        self.assertTrue(self._matches(m.Literal(4 * x - 6), 4 * x - 6))

        self.assertFalse(self._matches(m.Literal(4 * x - 6), x**3))

    def test_match_exprFromParseExpr(self):
        # The target may also be produced by parse_expr rather than operator overloading.
        self.assertTrue(
            self._matches(
                m.Literal('4 * x - 6'),
                parse_expr('4 * x - 6', evaluate=True, local_dict={'x': x}),
            )
        )

    def test_match_multipleLiterals(self):
        # A list of literals matches x when it contains 'x'; a list without it does not.
        self.assertTrue(self._matches(m.Literal(['x', 'cos(y)']), x))
        self.assertFalse(self._matches(m.Literal(['x**2', 'cos(y)']), x))


class TestFunction(unittest.TestCase):
    """Tests for ``m.Function``: the outer function and, optionally, its argument."""

    @staticmethod
    def _matches(pattern, source):
        """Parse ``source`` with sympify and run ``pattern`` against it w.r.t. ``x``."""
        return pattern.match(sp.sympify(source), x)

    def test_match_noArgs(self):
        # With no argument pattern, cos(...) matches regardless of its argument,
        # while a bare symbol or a different function does not.
        self.assertTrue(self._matches(m.Function(sp.cos), 'cos(x)'))
        self.assertTrue(self._matches(m.Function(sp.cos), 'cos(x**2 - 1)'))

        self.assertFalse(self._matches(m.Function(sp.cos), 'x'))
        self.assertFalse(self._matches(m.Function(sp.cos), 'sin(x**2 - 1)'))

    def test_match_withArgs(self):
        # The argument pattern must match too, including a nested Function pattern.
        def cos_of_cos_y():
            return m.Function(sp.cos, m.Function(sp.cos, m.Literal('y')))

        self.assertTrue(self._matches(m.Function(sp.cos, m.Literal('x')), 'cos(x)'))
        self.assertFalse(self._matches(m.Function(sp.cos, m.Literal('y')), 'cos(x)'))

        self.assertTrue(self._matches(cos_of_cos_y(), 'cos(cos(y))'))
        self.assertFalse(self._matches(cos_of_cos_y(), 'cos(y)'))

    def test_match_multipleFns(self):
        # A list of functions accepts cos (listed) and rejects sin (not listed).
        self.assertTrue(self._matches(m.Function([sp.cos, sp.tan]), 'cos(x)'))
        self.assertFalse(self._matches(m.Function([sp.cos, sp.tan]), 'sin(x)'))


class TestAdd(unittest.TestCase):
    """Tests for ``m.Add``: term patterns must match summands; extra summands are allowed."""

    @staticmethod
    def _matches(terms, source):
        """Build ``m.Add`` from Literal patterns for ``terms`` and test it on ``source``."""
        return m.Add([m.Literal(term) for term in terms]).match(sp.sympify(source), x)

    def test_match(self):
        # 5 + x is matched by the terms (5, x) but not by (6, x).
        self.assertTrue(self._matches(('5', 'x'), '5 + x'))
        self.assertFalse(self._matches(('6', 'x'), '5 + x'))

    def test_match_orderDoesNotMatter(self):
        # Writing the target as x + 5 instead of 5 + x still matches.
        self.assertTrue(self._matches(('5', 'x'), 'x + 5'))

    def test_match_needsOnlySubset(self):
        # Additional summands (z**2, y) in the target do not prevent a match.
        self.assertTrue(self._matches(('5', 'x'), 'z**2 + x + y + 5'))
        self.assertTrue(self._matches(('5', 'x'), '5 + z**2 + x + y'))


class TestRingOpsUntil(unittest.TestCase):
    """Tests for ``m.RingOpsUntil`` wrapping the pattern ``Literal('x')``."""

    @staticmethod
    def _finds_x(source):
        """Return whether ``RingOpsUntil(Literal('x'))`` matches ``sympify(source)``."""
        return m.RingOpsUntil(m.Literal('x')).match(sp.sympify(source), x)

    def test_match_literal(self):
        # With no surrounding operations, only the bare literal itself matches.
        self.assertTrue(self._finds_x('x'))

        self.assertFalse(self._finds_x('y'))
        self.assertFalse(self._finds_x('x**2'))

    def test_match_withinOps(self):
        # x is found beneath sums, differences and products...
        self.assertTrue(self._finds_x('5 * x'))
        self.assertTrue(self._finds_x('5 + x'))
        self.assertTrue(self._finds_x('(5 + x * y**2) * (z**3 - b)'))
        self.assertTrue(self._finds_x('(5 - x * y**2) * (z**3 - b)'))

        # ...but not when x is absent, or appears only inside a power (x**2).
        self.assertFalse(self._finds_x('(5 - z * y**2) * (z**3 - b)'))
        self.assertFalse(self._finds_x('(5 - x**2 * y**2) * (z**3 - b)'))


class TestConstant(unittest.TestCase):
    """Tests for ``m.Constant`` with respect to ``x``."""

    @staticmethod
    def _is_constant(source):
        """Return whether a fresh ``m.Constant()`` matches ``sympify(source)``."""
        return m.Constant().match(sp.sympify(source), x)

    def test_match(self):
        # Numbers and x-free compound expressions match; an expression in x does not.
        self.assertTrue(self._is_constant('5'))
        self.assertTrue(self._is_constant('tan(cos(1/5**2)) + 4*log(55)'))

        self.assertFalse(self._is_constant('5 * x'))


class TestPolynomialIn(unittest.TestCase):
    """Tests for ``m.PolynomialIn`` over the generators ``x`` and ``cos(x)``."""

    @staticmethod
    def _matches(source):
        """Return whether ``PolynomialIn([x, cos(x)])`` matches ``sympify(source)``."""
        return m.PolynomialIn([x, sp.cos(x)]).match(sp.sympify(source), x)

    def test_match(self):
        # Polynomials in x and cos(x) match, including a lone generator...
        self.assertTrue(self._matches('cos(x)'))
        self.assertTrue(self._matches('x**2 * (cos(x))**2 + 55 * cos(x) + x**67 + 7'))

        # ...but a term in tan(x), which is not a generator, breaks the match.
        self.assertFalse(self._matches('x**2 * (cos(x))**2 + 55 * cos(x) + tan(x) + 7'))


class TestPow(unittest.TestCase):
    """Tests for ``m.Pow`` with base-only and base-plus-exponent patterns."""

    @staticmethod
    def _matches(pattern, source):
        """Parse ``source`` with sympify and run ``pattern`` against it w.r.t. ``x``."""
        return pattern.match(sp.sympify(source), x)

    def test_match(self):
        # Positive cases: base only, base and exponent, and Function sub-patterns.
        self.assertTrue(self._matches(m.Pow(x), 'x**2'))
        self.assertTrue(self._matches(m.Pow(x, 2), 'x**2'))
        self.assertTrue(self._matches(m.Pow(x, m.Function(sp.cos)), 'x**cos(x)'))
        self.assertTrue(
            self._matches(m.Pow(m.Function(sp.log), m.Function(sp.cos)), 'log(x)**cos(x)')
        )

        # Negative cases: wrong base, wrong exponent, and swapped base/exponent.
        self.assertFalse(self._matches(m.Pow(x), 'cos(x)**2'))
        self.assertFalse(self._matches(m.Pow(x, 2), 'x**cos(x)'))
        self.assertFalse(
            self._matches(m.Pow(m.Function(sp.log), m.Function(sp.cos)), 'cos(x)**log(x)')
        )


if __name__ == "__main__":
    # Run every TestCase in this module when executed as a script.
    unittest.main()
