import unittest
from anpl.parser import ANPLParser

class TestParse(unittest.TestCase):
    """Tests for ``ANPLParser``: parsing ANPL programs and extracting inline
    (backtick-quoted) function definitions."""

    def setUp(self):
        super().setUp()
        # A fresh parser per test so no state leaks between test cases.
        self.parser = ANPLParser()

    def test_00(self):
        """Smoke test: parse a program mixing regular defs and inline
        backtick functions. Only checks that ``parse`` does not raise; the
        parsed function table is printed for manual inspection."""
        test_str = """
def identify(input_grid):
    \"\"\"
    identify the smallest repeating unit
    \"\"\"
    for unit_length in range(1, input_grid.shape[0]):
       small_grids = `divide the grid into multiple small grid based on the length of the repeating unit`(unit_length)
       if `all the same, or the smaller grid is part of a larger grid`(small_grids):
            return small_grids[0]

def main(input_grid):
    unit = identify(input_grid)
    output_grid = `extend the input grid to 9x3 with unit`(input_grid, unit)
    output_grid = `change all blue pixels to red`(output_grid)
    return output_grid
"""
        anpl = self.parser.parse(test_str)
        print(anpl.funs)

    def test_extract_inline_function_def(self):
        """Every backtick-quoted inline function must be removed from the
        source and returned as a separate function definition."""
        source_code = """
def main(input_grid):
    unit = identify(input_grid)
    output_grid = `extend the input grid to 9x3 with unit`(unit)
    output_grid = `change all blue pixels to red`(output_grid)
    return output_grid
"""
        new_source, funs = self.parser.extract_inline_function_def(source_code)
        # No backtick may survive in the rewritten source.
        self.assertNotIn("`", new_source)
        # The sample above contains exactly two inline functions.
        self.assertEqual(len(funs), 2, "Did not correctly extract inline functions")

    def test_recursive(self):
        """A function that calls itself must be parsed as a single function
        flagged as recursive."""
        source_code = """
def flood_fill(grid, x, y):
    if grid[x, y] == black:
        f = array([3, 3, 3])
        grid[x, y] = yellow
        if x > 0:
            flood_fill(grid, x-1, y)
        if x < grid.shape[0] - 1:
            flood_fill(grid, x+1, y)
        if y > 0:
            flood_fill(grid, x, y-1)
        if y < grid.shape[1] - 1:
            flood_fill(grid, x, y+1)
"""
        anpl = self.parser.parse(source_code)
        funs = list(anpl.funs.values())
        self.assertEqual(len(funs), 1, "Should contain exactly one function")
        fun = funs[0]
        self.assertTrue(fun.is_recursive, "Should be a recursive function")
        print(fun)

    def test_reverse_define(self):
        """Smoke test: parse a program where ``main`` calls a function that
        is defined after it. Only checks that ``parse`` does not raise; the
        parsed function table is printed for manual inspection."""
        source_code = '''
import numpy as np
from scipy.ndimage import label
(black, blue, red, green, yellow, grey, pink, orange, teal, maroon) = range(10)

def get_pattern(input_grid: np.ndarray) -> np.ndarray:
    """
    Get the pattern in the 3x3 grid in the upper left
    """
    return input_grid[:3, :3]

def main(input_grid: np.ndarray) -> np.ndarray:
    """
    The input is a 10x10 grid with a 3x3 grid in the upper left surrounded by gray pixels.
    """
    p = get_pattern(input_grid)
    out = find_and_replace_pattern(input_grid, p)
    return out

def find_and_replace_pattern(input_grid: np.ndarray, pattern: np.ndarray) -> np.ndarray:
    """
    Find a pattern with the same shape as the input pattern but with a different color in the rest of the input grid,
    turn the pattern found into gray, return the changed 10x10 grid
    """
    # Find all unique colors in the input grid
    unique_colors = np.unique(input_grid)

    # Remove gray color from unique colors
    unique_colors = unique_colors[unique_colors != grey]

    # Create a dictionary to store the pattern with different colors
    pattern_dict = {}

    # Loop through all unique colors and create a pattern with the same shape as the input pattern
    for color in unique_colors:
        temp_pattern = np.zeros_like(pattern)
        temp_pattern[input_grid[3:, 3:] == color] = 1
        if np.array_equal(temp_pattern.shape, pattern.shape):
            pattern_dict[color] = temp_pattern

    # Find the pattern with the same shape as the input pattern but with a different color
    for color, temp_pattern in pattern_dict.items():
        if np.array_equal(temp_pattern, pattern):
            new_color = color
            break

    # Replace the pattern found with gray color
    input_grid[:3, :3][pattern == 1] = grey

    # Replace the pattern with the new color
    input_grid[3:, 3:][input_grid[3:, 3:] == new_color] = grey

    return input_grid
'''
        anpl = self.parser.parse(source_code)
        print(anpl.funs)