import unittest
from src.turtlegfx.utils.code_modifier import remove_turtle_setup_calls

class TestCodeModifier(unittest.TestCase):
    """Unit tests for ``remove_turtle_setup_calls``.

    Most cases pass a source snippet to the function and compare the
    result with the code that is expected to remain, ignoring leading and
    trailing whitespace on both sides.
    """

    def _assert_cleaned(self, code, expected):
        """Assert that processing ``code`` yields ``expected``.

        Args:
            code: Source snippet handed to ``remove_turtle_setup_calls``.
            expected: Source text that should remain afterwards.

        Both the actual and the expected text are stripped before the
        comparison, so surrounding blank lines in the fixtures are ignored.
        The name deliberately does not start with ``test`` so the unittest
        loader does not collect it as a test case.
        """
        actual = remove_turtle_setup_calls(code)
        self.assertEqual(actual.strip(), expected.strip())

    def test_import_removal(self):
        """Test removal of all import statements.

        Turtle imports and unrelated imports alike are expected to vanish,
        leaving only the drawing call.
        """
        code = '''
import turtle
import math
from turtle import *
import turtle as t
from turtle import Screen, Turtle
import other_module
t.forward(100)
'''
        expected = '''
t.forward(100)
'''
        self._assert_cleaned(code, expected)

    def test_multiple_import_styles(self):
        """Test removal of different import statement styles.

        Covers ``from ... import``, aliased imports and dotted module paths.
        """
        code = '''
from math import sin, cos
import numpy as np
from turtle import *
import other.module as om
t.forward(100)
'''
        expected = '''
t.forward(100)
'''
        self._assert_cleaned(code, expected)

    def test_instance_creation_removal(self):
        """Test removal of turtle instance creation.

        Assignments from ``Turtle()``, ``Screen()`` and ``RawTurtle()`` are
        all expected to be dropped.
        """
        code = '''
t = turtle.Turtle()
screen = turtle.Screen()
raw_t = turtle.RawTurtle()
t.forward(100)
'''
        expected = '''
t.forward(100)
'''
        self._assert_cleaned(code, expected)

    def test_turtle_method_removal(self):
        """Test removal of specific turtle instance methods.

        The show/hide calls (long and short spellings) should disappear
        while the movement calls stay in their original order.
        """
        code = '''
t.forward(100)
t.hideturtle()
t.ht()
t.showturtle()
t.st()
t.right(90)
t.forward(50)
'''
        expected = '''
t.forward(100)
t.right(90)
t.forward(50)
'''
        self._assert_cleaned(code, expected)

    def test_screen_setup_removal(self):
        """Test removal of screen setup functions.

        Every module-level ``turtle.<setup>(...)`` call in the snippet is
        expected to be removed, leaving only the drawing call.
        """
        code = '''
import turtle
turtle.setup(800, 600)
turtle.bgcolor("white")
turtle.title("My Drawing")
turtle.screensize(1000, 1000)
turtle.mode("logo")
turtle.colormode(255)
turtle.delay(0)
t.circle(50)
'''
        expected = '''
t.circle(50)
'''
        self._assert_cleaned(code, expected)

    def test_nested_code_blocks(self):
        """Test handling of nested code blocks with imports.

        Imports and setup calls are placed inside a function, an ``if`` and
        a ``for`` body; the enclosing statements are expected to survive
        with only the drawing call left inside them.
        """
        code = '''
import turtle
def draw_shape():
    import math
    t = turtle.Turtle()
    turtle.setup(800, 600)
    if True:
        from turtle import *
        turtle.bgcolor("white")
        for i in range(4):
            turtle.tracer(0)
            t.forward(100)
    t.hideturtle()
    turtle.done()
'''
        expected = '''
def draw_shape():
    if True:
        for i in range(4):
            t.forward(100)
'''
        self._assert_cleaned(code, expected)

    def test_empty_code(self):
        """Test handling of empty code."""
        self.assertEqual(remove_turtle_setup_calls("").strip(), "")
        self.assertEqual(remove_turtle_setup_calls("\n").strip(), "")

    def test_syntax_error_handling(self):
        """Test handling of invalid Python code.

        Two malformed snippets are checked; each must raise ``SyntaxError``.
        """
        # Every line carries leading indentation. A Python parser rejects
        # this for the unexpected indent on the first statement
        # (IndentationError is a SyntaxError subclass), so on its own this
        # input never exercises the missing colon below.
        indented_code = '''
        import turtle
        turtle.setup(800, 600)
        if True
            t.forward(100)
        '''
        # Same statements starting at column 0: the only defect left is the
        # missing colon after ``if True``.
        missing_colon_code = '''
import turtle
turtle.setup(800, 600)
if True
    t.forward(100)
'''
        cases = (
            ('unexpected indent', indented_code),
            ('missing colon', missing_colon_code),
        )
        for label, invalid_code in cases:
            with self.subTest(case=label):
                with self.assertRaises(SyntaxError):
                    remove_turtle_setup_calls(invalid_code)

    def test_real_world_example(self):
        """Test with a real-world example including imports.

        The expected output has no comments, blank lines, imports, setup
        calls or ``t.speed(0)``; the trailing ``draw(turtle.Turtle())``
        call is expected to be kept.
        """
        code = '''
import turtle
import math
from numpy import array
import random as rd

def draw(t):
    # Set screen size and background color
    turtle.setup(800, 600)
    turtle.bgcolor("white")
    
    # Create a turtle object
    t = turtle.Turtle()
    t.speed(0)
    
    # Draw the pattern
    for i in range(8):
        t.begin_fill()
        t.circle(100, 90)
        t.end_fill()
        t.right(45)
    
    # Hide the turtle and finish
    t.hideturtle()
    turtle.done()

# Call the function
draw(turtle.Turtle())
'''
        expected = '''def draw(t):
    for i in range(8):
        t.begin_fill()
        t.circle(100, 90)
        t.end_fill()
        t.right(45)
draw(turtle.Turtle())
'''
        self._assert_cleaned(code, expected)

    def test_real_world_example_1(self):
        """Test with a real-world example of basic shape drawing.

        Note that the expected text spells string literals with single
        quotes although the input uses double quotes.
        """
        code = '''
import turtle
import math

def draw_shape(t):
    # Setup
    turtle.setup(800, 600)
    turtle.bgcolor("black")
    t.speed(0)
    t.pensize(2)
    t.color("cyan")
    
    # Draw pattern
    for i in range(36):
        t.forward(100)
        t.right(45)
        t.forward(50)
        t.right(90)
        t.forward(50)
        t.right(45)
    
    t.hideturtle()
    turtle.done()

# Create and run
t = turtle.Turtle()
draw_shape(t)
'''
        expected = '''def draw_shape(t):
    t.pensize(2)
    t.color('cyan')
    for i in range(36):
        t.forward(100)
        t.right(45)
        t.forward(50)
        t.right(90)
        t.forward(50)
        t.right(45)
draw_shape(t)'''
        self._assert_cleaned(code, expected)

    def test_real_world_example_2(self):
        """Test with a real-world example of complex pattern drawing.

        Calls made through a ``screen`` variable (``bgcolor``, ``setup``,
        ``title``, ``exitonclick``) are expected to be removed as well.
        """
        code = '''
import turtle
from math import sin, cos, pi

def create_pattern(t):
    # Screen setup
    screen = turtle.Screen()
    screen.bgcolor("black")
    screen.setup(800, 800)
    screen.title("Geometric Pattern")
    
    # Turtle setup
    t.speed(0)
    t.penup()
    t.goto(-100, 0)
    t.pendown()
    t.color("gold")
    t.pensize(2)
    
    # Draw complex pattern
    for i in range(180):
        t.forward(200)
        t.right(30)
        t.forward(20)
        t.left(60)
        t.forward(50)
        t.right(30)
        
        t.penup()
        t.setposition(0, 0)
        t.pendown()
        
        t.right(2)
    
    t.hideturtle()
    screen.exitonclick()

# Initialize and run
t = turtle.Turtle()
create_pattern(t)
'''
        expected = '''def create_pattern(t):
    t.penup()
    t.goto(-100, 0)
    t.pendown()
    t.color('gold')
    t.pensize(2)
    for i in range(180):
        t.forward(200)
        t.right(30)
        t.forward(20)
        t.left(60)
        t.forward(50)
        t.right(30)
        t.penup()
        t.setposition(0, 0)
        t.pendown()
        t.right(2)
create_pattern(t)'''
        self._assert_cleaned(code, expected)

    def test_real_world_example_3(self):
        """Test with a real-world example of recursive pattern.

        Non-turtle statements (the ``colors`` list, the early ``return``
        and the recursive call) are expected to remain; ``size/1.5`` is
        expected back as ``size / 1.5``.
        """
        code = '''
import turtle
import random

def recursive_pattern(t, size):
    if size < 10:
        return
        
    # Setup for this iteration
    colors = ["red", "purple", "blue", "green", "yellow", "orange"]
    t.speed(0)
    t.pensize(2)
    t.color(random.choice(colors))
    
    # Draw the pattern
    for _ in range(4):
        t.forward(size)
        t.right(90)
        recursive_pattern(t, size/1.5)
        
    t.hideturtle()

# Initialize
screen = turtle.Screen()
screen.bgcolor("black")
screen.setup(800, 800)
t = turtle.Turtle()
t.penup()
t.goto(-100, 100)
t.pendown()

# Draw
recursive_pattern(t, 200)
screen.exitonclick()
'''
        expected = '''def recursive_pattern(t, size):
    if size < 10:
        return
    colors = ['red', 'purple', 'blue', 'green', 'yellow', 'orange']
    t.pensize(2)
    t.color(random.choice(colors))
    for _ in range(4):
        t.forward(size)
        t.right(90)
        recursive_pattern(t, size / 1.5)
t.penup()
t.goto(-100, 100)
t.pendown()
recursive_pattern(t, 200)'''
        self._assert_cleaned(code, expected)

    def test_real_world_example_4(self):
        """Test with a real-world example using multiple turtles.

        The ``t1``/``t2`` creation lines are expected to be removed while
        the loops and calls that use those names remain.
        """
        code = '''
import turtle
import math

def create_spiral_pattern():
    # Screen setup
    screen = turtle.Screen()
    screen.setup(800, 800)
    screen.bgcolor("black")
    screen.title("Multi-turtle Spiral")
    
    # Create turtles
    t1 = turtle.Turtle()
    t2 = turtle.Turtle()
    
    # Setup turtles
    for t in [t1, t2]:
        t.speed(0)
        t.pensize(2)
        t.hideturtle()
    
    t1.color("cyan")
    t2.color("magenta")
    
    # Draw pattern
    for i in range(120):
        for t in [t1, t2]:
            t.forward(i * 4)
            t.right(89)
    
    screen.exitonclick()

create_spiral_pattern()
'''
        expected = '''def create_spiral_pattern():
    for t in [t1, t2]:
        t.pensize(2)
    t1.color('cyan')
    t2.color('magenta')
    for i in range(120):
        for t in [t1, t2]:
            t.forward(i * 4)
            t.right(89)
create_spiral_pattern()
'''
        self._assert_cleaned(code, expected)

# Run the test cases in this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()
