import logging
from src.turtlegfx.emulate.emulator import Executor, _turtle_worker
from src.turtlegfx.utils.img_utils import compare_images
from src.turtlegfx.utils.base64img import convert_base64_to_img
from PIL import Image, ImageDraw
from src.turtlegfx.utils.norm_lines import normalize_length_of_lines, update_pensize_of_lines, center_position_of_lines

logger = logging.getLogger(__name__)

def get_min_image_size_for_lines(lines, padding=10):
    """
    Compute the (width, height) of the smallest canvas that fits every shape.

    Endpoints of 'line' shapes, vertices of 'fill' shapes, and endpoints of
    any strokes stored under a fill's 'lines' key all contribute to the
    bounding box. Shapes of any other type are ignored.

    Args:
        lines (list): Shape dictionaries carrying a 'type' key.
        padding (int | float): Margin added on each side of the bounding box.

    Returns:
        tuple: (width, height), each truncated to int; (0, 0) when no
            coordinates were found.
    """
    points = []

    def add_segment(segment):
        # A segment contributes both of its endpoints.
        points.append(segment['start'])
        points.append(segment['end'])

    for shape in lines:
        kind = shape.get('type')
        if kind == 'line':
            add_segment(shape)
        elif kind == 'fill':
            points.extend(shape['vertices'])
            for inner in shape.get('lines', ()):
                add_segment(inner)

    if not points:
        return (0, 0)  # nothing drawable was found

    xs = [point[0] for point in points]
    ys = [point[1] for point in points]
    width = int((max(xs) + padding) - (min(xs) - padding))
    height = int((max(ys) + padding) - (min(ys) - padding))
    return (width, height)

def _turtle_line_commands(line):
    """Return the indented Turtle statements that draw one recorded line segment."""
    start = line['start']
    end = line['end']
    pencolor = line.get('pencolor', 'black')
    pensize = line.get('pensize', 1)
    # repr() yields a valid Python literal for both string colors (quotes
    # escaped) and RGB tuples, unlike wrapping the value in "...".
    return [
        f'    t.pencolor({pencolor!r})',
        f'    t.pensize({pensize})',
        '    t.penup()',
        f'    t.goto({start[0]}, {start[1]})',
        '    t.pendown()',
        f'    t.goto({end[0]}, {end[1]})',
    ]


def generate_turtle_code_from_lines(lines):
    """
    Generate a Python code string that uses Turtle commands to draw the given lines.

    The result defines ``draw(t)``. It is always syntactically valid, even for
    an empty input, in which case the body is ``pass``.

    Args:
        lines (list): Shape dictionaries. 'line' shapes need 'start'/'end' and
            may carry 'pencolor'/'pensize'. 'fill' shapes need 'vertices',
            may carry 'fillcolor', and may carry a 'lines' list of strokes
            recorded while filling; those strokes are emitted after the fill.

    Returns:
        str: Python source for a ``draw(t)`` function.

    Raises:
        ValueError: If a shape has an unknown 'type'.
    """
    code_lines = ['def draw(t):']

    for shape in lines:
        shape_type = shape.get('type')
        if shape_type == 'line':
            code_lines.extend(_turtle_line_commands(shape))
        elif shape_type == 'fill':
            vertices = shape['vertices']
            fillcolor = shape.get('fillcolor', 'white')
            # A fill without vertices covers no area; skip the polygon but
            # still emit any strokes recorded with it.
            if vertices:
                first_x, first_y = vertices[0][0], vertices[0][1]
                code_lines.append(f'    t.fillcolor({fillcolor!r})')
                code_lines.append('    t.penup()')
                code_lines.append(f'    t.goto({first_x}, {first_y})')
                code_lines.append('    t.pendown()')
                code_lines.append('    t.begin_fill()')
                code_lines.extend(f'    t.goto({v[0]}, {v[1]})' for v in vertices[1:])
                code_lines.append(f'    t.goto({first_x}, {first_y})')  # close the shape
                code_lines.append('    t.end_fill()')
            # Strokes recorded during the fill are drawn on top, mirroring
            # draw_lines_to_image_PIL.
            for inner in shape.get('lines', ()):
                code_lines.extend(_turtle_line_commands(inner))
        else:
            raise ValueError(f"Unknown shape type: {shape.get('type')}")

    # A def with no statements is a SyntaxError; keep the output runnable.
    if len(code_lines) == 1:
        code_lines.append('    pass')

    return '\n'.join(code_lines)

def draw_lines_to_image_PIL(lines, image_size=(500, 500)):
    """
    Draw lines and filled shapes on a white RGB image with PIL and return it.

    Shapes are drawn in list order, so later shapes paint over earlier ones.
    Coordinates are truncated to ints and used directly as pixel positions.
    No translation or axis flip is applied, so anything outside the image
    bounds is clipped.

    Args:
        lines (list): Shape dictionaries. 'line' shapes need 'start'/'end' and
            may carry 'pencolor'/'pensize'. 'fill' shapes need 'vertices',
            may carry 'fillcolor', and may carry a 'lines' list of strokes
            recorded while filling; those strokes are drawn over the polygon.
        image_size (tuple): (width, height) of the output image.

    Returns:
        PIL.Image.Image: The rendered image.

    Raises:
        ValueError: If a shape has an unknown 'type'.
    """
    img = Image.new('RGB', image_size, 'white')
    draw = ImageDraw.Draw(img)

    def draw_stroke(segment):
        """Draw one line segment with its pen color and size."""
        start = tuple(map(int, segment['start']))
        end = tuple(map(int, segment['end']))
        pencolor = segment.get('pencolor', 'black')
        # PIL's line width must be an int, but recorded pen sizes can be
        # floats (e.g. when pensize normalization is disabled).
        width = max(1, int(round(segment.get('pensize', 1))))
        draw.line([start, end], fill=pencolor, width=width)

    for shape in lines:
        shape_type = shape.get('type')
        if shape_type == 'fill':
            # First draw the filled polygon.
            vertices = [tuple(map(int, v)) for v in shape['vertices']]
            # PIL raises for polygons with fewer than two points; such a fill
            # covers no area anyway.
            if len(vertices) >= 2:
                draw.polygon(vertices, fill=shape.get('fillcolor', 'white'))
            # Then draw the strokes that were recorded during the fill.
            for stroke in shape.get('lines', ()):
                draw_stroke(stroke)
        elif shape_type == 'line':
            draw_stroke(shape)
        else:
            raise ValueError(f"Unknown shape type: {shape.get('type')}")

    return img

def draw_lines_to_image_turtle_deprecated(lines, image_size=(500, 500)):
    """
    Render shapes by round-tripping them through generated Turtle code.

    The shapes are converted into a ``draw(t)`` program with
    ``generate_turtle_code_from_lines``. That program is executed by
    ``_turtle_worker`` without a screen, and the base64 image it returns is
    decoded into a PIL image.
    Note: this is **not** used in the current evaluation pipeline.

    Args:
        lines (list): Shape dictionaries ('line' / 'fill').
        image_size (tuple): Currently unused; it is not forwarded to
            ``_turtle_worker``.

    Returns:
        PIL.Image.Image: The image produced by the worker.

    Raises:
        RuntimeError: If the worker reports a status other than 'success'.
    """
    worker_result = _turtle_worker(
        generate_turtle_code_from_lines(lines),
        show_screen=False,
        record_turtle_states=False,
    )
    if worker_result['status'] == 'success':
        return convert_base64_to_img(worker_result['image'])
    raise RuntimeError(f"Failed to draw lines: {worker_result['message']}")

def get_normalized_images_from_code(code1, code2,
                                    show_screen=False,
                                    length_invariant=True,
                                    position_invariant=True,
                                    pensize_invariant=True):
    """
    Run two turtle programs and rasterize their recorded shapes onto equal-size canvases.

    Each program is executed with turtle-state recording enabled. The
    recorded shapes are then optionally processed, in this order:
    centered, length-normalized, and pen-size-normalized. Both shape lists
    are drawn with ``draw_lines_to_image_PIL`` on a canvas large enough for
    the larger of the two bounding boxes.

    Args:
        code1 (str): First turtle graphics code.
        code2 (str): Second turtle graphics code.
        show_screen (bool): Forwarded to ``Executor.run``.
        length_invariant (bool): Apply ``normalize_length_of_lines``.
        position_invariant (bool): Apply ``center_position_of_lines``.
        pensize_invariant (bool): Apply ``update_pensize_of_lines(..., pensize=1)``.

    Returns:
        tuple: ``(img1, img2)`` PIL images. Returns ``(None, None)`` when
        either program does not finish with status ``'success'`` or records
        no shapes. Exceptions raised by the executor or the normalization
        helpers are not caught here.
    """
    # Run both programs; only the recorded turtle states are used, not the
    # image the executor renders.
    _, res1 = Executor().run(code=code1, show_screen=show_screen, record_turtle_states=True)
    _, res2 = Executor().run(code=code2, show_screen=show_screen, record_turtle_states=True)

    # Special case 1: at least one program did not run successfully.
    if res1['status'] != 'success' or res2['status'] != 'success':
        logger.debug("Both codes must be valid, but got %s and %s; returning None",
                     res1['status'], res2['status'])
        return None, None

    line_sets = [res1['turtle_states']['states'], res2['turtle_states']['states']]

    # Special case 2: at least one program drew nothing.
    if any(len(shapes) == 0 for shapes in line_sets):
        logger.debug("One or both line sets are empty")
        return None, None

    # Apply each requested normalization to both shape lists.
    if position_invariant:
        line_sets = [center_position_of_lines(shapes) for shapes in line_sets]
    if length_invariant:
        line_sets = [normalize_length_of_lines(shapes) for shapes in line_sets]
    if pensize_invariant:
        line_sets = [update_pensize_of_lines(shapes, pensize=1) for shapes in line_sets]

    # Both images share one size: the max of the two minimal bounding boxes.
    # NOTE(review): shapes are drawn at their raw coordinates (not offset by
    # the bounding-box minimum), so anything at negative x/y is clipped unless
    # the normalization helpers move it into positive space. Confirm.
    sizes = [get_min_image_size_for_lines(shapes) for shapes in line_sets]
    canvas_size = (max(w for w, _ in sizes), max(h for _, h in sizes))

    img1, img2 = [draw_lines_to_image_PIL(shapes, image_size=canvas_size) for shapes in line_sets]

    # Debugging aid: open both images in an external viewer.
    if logger.isEnabledFor(logging.DEBUG):
        img1.show()
        img2.show()

    return img1, img2

def get_tolerance_from_codes(code1, code2, tolerance_if_fill=5, tolerance_if_nofill=8):
    """
    Choose the image-comparison tolerance based on whether either program uses fills.

    Fill usage is detected by a plain substring search for ``begin_fill`` or
    ``end_fill`` in the source text. Occurrences inside comments or string
    literals therefore also count.

    Args:
        code1 (str): The first turtle graphics code.
        code2 (str): The second turtle graphics code.
        tolerance_if_fill (int): Tolerance returned when either code uses
            fills (default 5).
        tolerance_if_nofill (int): Tolerance returned otherwise (default 8).

    Returns:
        int: ``tolerance_if_fill`` if either code mentions ``begin_fill`` or
        ``end_fill``, else ``tolerance_if_nofill``.
    """
    def has_fill_colors(code):
        return "begin_fill" in code or "end_fill" in code

    uses_fill = has_fill_colors(code1) or has_fill_colors(code2)
    return tolerance_if_fill if uses_fill else tolerance_if_nofill

def compare_lines_from_code(code1, code2,
                            show_screen=False,
                            length_invariant=True,
                            position_invariant=True,
                            pensize_invariant=True,
                            return_imgs=False,
                            tolerance_if_fill=5,
                            tolerance_if_nofill=8):
    """
    Compare the two sets of lines drawn by the code.

    Both programs are rendered via ``get_normalized_images_from_code``. The
    images are then compared with ``compare_images`` using a tolerance chosen
    by ``get_tolerance_from_codes``.

    Args:
        code1 (str): First turtle graphics code.
        code2 (str): Second turtle graphics code.
        show_screen (bool): Whether to show the turtle screen.
        length_invariant (bool): Whether to normalize line lengths.
        position_invariant (bool): Whether to center the drawings.
        pensize_invariant (bool): Whether to normalize pen sizes.
        return_imgs (bool): Whether to also return the generated images.
        tolerance_if_fill (int): Comparison tolerance when either code uses fills.
        tolerance_if_nofill (int): Comparison tolerance otherwise.

    Returns:
        bool or tuple: If ``return_imgs`` is False, whether the images match.
        If ``return_imgs`` is True, ``(are_same, img1, img2)``. The images are
        ``None`` when either program failed or drew nothing, in which case
        ``are_same`` is False.
    """
    img1, img2 = get_normalized_images_from_code(
        code1, code2,
        show_screen=show_screen,
        length_invariant=length_invariant,
        position_invariant=position_invariant,
        pensize_invariant=pensize_invariant,
    )

    # Special case: a program failed or drew nothing, so the drawings are not
    # considered equal.
    if img1 is None or img2 is None:
        are_same = False
    else:
        tolerance = get_tolerance_from_codes(
            code1, code2,
            tolerance_if_fill=tolerance_if_fill,
            tolerance_if_nofill=tolerance_if_nofill,
        )
        are_same = compare_images(img1, img2, tolerance=tolerance)

    if return_imgs:
        return are_same, img1, img2
    return are_same

if __name__ == "__main__":
    # Manual smoke test: sample turtle programs compared pairwise with
    # compare_lines_from_code (see the calls at the bottom).

    # A square traced with explicit commands, followed by one extra
    # forward(100) that presumably retraces the first side.
    code1 = """def draw(t):
    # t.pencolor('red')
    t.forward(100)
    t.left(90)
    t.forward(100)
    t.left(90)
    t.forward(100)
    t.left(90)
    t.forward(100)
    t.left(90)
    t.forward(100)
    """
    # The same square drawn with a loop inside a nested helper.
    code2 = """def draw(t):
    def draw_square(t):
        for _ in range(4):
            t.forward(100)
            t.left(90)
    draw_square(t)
    """
    # A 100-unit vertical segment (drawn up and back) plus a 100-unit segment
    # after turning right.
    code3 = """def draw(t):
    t.setheading(90)
    t.forward(100)
    t.back(100)
    t.rt(90)
    t.forward(100)
    """
    # Similar to code3, but returns to the origin via penup/goto and the
    # second segment is 200 units long.
    code4 = """def draw(t):
    t.setheading(90)
    t.forward(100)
    t.penup()
    t.goto(0, 0)
    t.pendown()
    t.rt(90)
    t.forward(200)
    """
    # NOTE(review): this snippet calls turtle.setup / turtle.setworldcoordinates
    # without defining `turtle` itself; whether it runs depends on the
    # executor's namespace. Confirm.
    code5 = '''def draw(t):
    """
    Draws two crosses intersecting each other on a white background,
    resembling a crosshair used by Rafał Bajer during his fishing trips.
    
    Each arm of the cross has a length of 8 units.
    The arms are black lines without fill color.
    The intersection point is centered on the canvas.
    """
    # Set up window size and position
    turtle.setup(width=400, height=400)
    turtle.setworldcoordinates(-200, -200, 200, 200)
    
    # Move to starting position
    t.penup()
    t.goto(0, 0)
    t.pendown()

    # Draw first cross
    for _ in range(2):
        t.forward(8)
        t.right(90)
        t.forward(8)
        t.right(90)

    # Rotate turtle to start second cross from different angle
    t.right(90)

    # Draw second cross
    for _ in range(2):
        t.forward(8)
        t.right(90)
        t.forward(8)
        t.right(90)
    '''
    # Two four-armed crosses (arms of 100 units), one drawn at the start
    # position and one after moving to (300, 0).
    code6='''def draw(t):
    t.setheading(90)

    def draw_cross(t):
        for _ in range(4):
            t.forward(100)
            t.backward(100)
            t.right(90)

    draw_cross(t)
    t.penup()
    t.goto(300, 0)
    t.pendown()
    draw_cross(t)
    '''
    # A cross drawn arm by arm: out-and-back along each of four headings.
    code_lines_ordered = """def draw(t):
        for _ in range(4):
            t.forward(100)
            t.back(100)
            t.right(90)
        """
    
    # The same segments drawn in a different order and direction; checks that
    # the comparison ignores stroke order.
    code_lines_unordered = """def draw(t):
        t.forward(100)
        t.back(100)
        t.back(100)
        t.forward(100)
        t.right(90)
        t.forward(100)
        t.back(100)
        t.back(100)
    """
    # Only the ordered/unordered pair is run; the other comparisons are
    # currently commented out.
    # is_equal_12 = compare_lines_from_code(code1, code2)
    # is_equal_34 = compare_lines_from_code(code3, code4)
    # is_equal_56 = compare_lines_from_code(code5, code6)
    is_equal_ordered_unordered = compare_lines_from_code(code_lines_ordered, code_lines_unordered)
    print(is_equal_ordered_unordered)
    # print(is_equal_12, is_equal_34, is_equal_56, is_equal_ordered_unordered)
