from Cython.TestUtils import CythonTest
from Cython.Compiler.TreeFragment import *
from Cython.Compiler.Nodes import *
from Cython.Compiler.UtilNodes import *

class TestTreeFragments(CythonTest):
    """Tests for TreeFragment parsing, copying, substitution and temps."""

    def test_basic(self):
        """A copy of a parsed fragment reproduces the original source code."""
        F = self.fragment("x = 4")
        T = F.copy()
        self.assertCode("x = 4", T)

    def test_copy_is_taken(self):
        """Mutating a copy must leave the fragment's original root untouched."""
        F = self.fragment("if True: x = 4")
        T1 = F.root
        T2 = F.copy()
        self.assertEqual("x", T2.stats[0].if_clauses[0].body.lhs.name)
        T2.stats[0].if_clauses[0].body.lhs.name = "other"
        self.assertEqual("x", T1.stats[0].if_clauses[0].body.lhs.name)

    def test_substitutions_are_copied(self):
        """Each occurrence of a substituted name receives its own node object."""
        T = self.fragment("y + y").substitute({"y": NameNode(pos=None, name="x")})
        self.assertEqual("x", T.stats[0].expr.operand1.name)
        self.assertEqual("x", T.stats[0].expr.operand2.name)
        # The same replacement node must not be shared between both operands.
        self.assertIsNot(T.stats[0].expr.operand1, T.stats[0].expr.operand2)

    def test_substitution(self):
        """A name in the fragment is replaced by the supplied node."""
        F = self.fragment("x = 4")
        y = NameNode(pos=None, name="y")
        T = F.substitute({"x" : y})
        self.assertCode("y = 4", T)

    def test_exprstat(self):
        """A bare-name expression statement can be replaced by a statement node."""
        F = self.fragment("PASS")
        pass_stat = PassStatNode(pos=None)
        T = F.substitute({"PASS" : pass_stat})
        self.assertIsInstance(T.stats[0], PassStatNode, T)

    def test_pos_is_transferred(self):
        """A substituted node takes the position of the name it replaces."""
        F = self.fragment("""
        x = y
        x = u * v ** w
        """)
        # The replacement is created with pos=None; after substitution it
        # should carry the position of `v` from the parsed fragment.
        T = F.substitute({"v" : NameNode(pos=None, name="a")})
        v = F.root.stats[1].rhs.operand2.operand1
        a = T.stats[1].rhs.operand2.operand1
        self.assertEqual(v.pos, a.pos)

    def test_temps(self):
        """Names listed in `temps` become TempRefNodes sharing one handle."""
        # Reset the shared class-level counter so this test starts from a
        # known state.
        TemplateTransform.temp_name_counter = 0
        F = self.fragment("""
            TMP
            x = TMP
        """)
        T = F.substitute(temps=["TMP"])
        # When temps are requested, the statements are reached via T.body
        # rather than T.stats.
        s = T.body.stats
        self.assertIsInstance(s[0].expr, TempRefNode)
        self.assertIsInstance(s[1].rhs, TempRefNode)
        # Both uses of TMP must refer to the very same temp handle.
        self.assertIs(s[0].expr.handle, s[1].rhs.handle)

if __name__ == "__main__":
    # Running this module directly executes its test cases via unittest.
    from unittest import main as run_unittests
    run_unittests()
