import re


####### Code translation

def extract_python_code_from_text(text):
    """
    Extracts Python code from a text block that contains it in a fenced (```) code block.

    A block tagged ``python`` / ``python3`` (any letter case) is preferred; otherwise
    the first fenced block of any kind is used, with an optional language tag on
    its opening line (e.g. ``py``) removed. Only the first matching block is
    returned, with surrounding whitespace stripped.

    Parameters:
    - text (str): Text containing the fenced code, e.g. a model response.

    Returns:
    - str: The extracted code.

    Raises:
    - ValueError: If the text contains no fenced code block.
    """
    # Non-greedy '.*?' with DOTALL stops at the first closing fence but, unlike
    # '[^`]*', still accepts single backticks inside the code.
    matches = re.findall(r'```python3?(.*?)```', text, re.DOTALL | re.IGNORECASE)
    if matches:
        # Assuming the first match is the target code
        return matches[0].strip()

    # Fall back to any fenced block; skip an optional language tag line so that
    # e.g. "```py\n...```" does not return "py" as part of the code.
    matches = re.findall(r'```(?:[\w+-]*[ \t]*\n)?(.*?)```', text, re.DOTALL)
    if matches:
        return matches[0].strip()
    raise ValueError("No Python code found in the given text.")


def python2codejs(python_code):
    """
    Translates Python code into JSON representation.

    Understood statements are ``move_forward()``, ``move_backward()``,
    ``turn_left()``, ``turn_right()``, ``setpc('<value>')``, ``[MASK]`` and
    ``for <var> in range(<times>):`` loops (indented body, or a single statement
    after the colon). An optional ``def run():`` header, comments, blank lines,
    ``pass`` lines and top-level ``run()`` calls are ignored.

    Parameters:
    - python_code (str): The Python source to translate.

    Returns:
    - dict: ``{'run': [command, ...]}``. Each command has a ``'type'`` of 'fd',
      'bk', 'lt', 'rt', 'setpc' (plus ``'value'``), '[MASK]', or 'repeat' (plus
      ``'times'`` -- an int when the range argument is all digits, otherwise the
      raw string -- and ``'body'``, a list of commands).

    Raises:
    - ValueError: If a statement is not one of the supported commands.
    """

    def indent_width(line):
        # Width of the line's leading whitespace; tabs expand to multiples of 8
        # columns as in Python's tokenizer, so tab- and space-indented code both work.
        leading = line[:len(line) - len(line.lstrip())]
        return len(leading.expandtabs(8))

    def translate_single(line):
        # Translate one stripped statement into a command dict.
        if "move_forward()" in line:
            return {'type': 'fd'}
        if "move_backward()" in line:
            return {'type': 'bk'}
        if "turn_left()" in line:
            return {'type': 'lt'}
        if "turn_right()" in line:
            return {'type': 'rt'}
        if "setpc" in line:
            # The value is the first quoted string; double quotes are checked first.
            for quote in ('"', "'"):
                if quote in line:
                    return {'type': 'setpc', 'value': line.split(quote)[1]}
            raise ValueError(f"Invalid command: {line}")
        if line == "[MASK]":
            return {'type': '[MASK]'}
        raise ValueError(f"Invalid command: {line}")

    def translate_group(group):
        # Translate a statement together with its (more deeply indented) body lines.
        header = group[0].strip()
        if header.startswith('for ') and 'in range(' in header:
            times = header.split('range(')[1].split(')')[0].strip()
            # Support a single statement written on the header line itself.
            inline_body = header.split(':', 1)[1].strip() if ':' in header else ''
            body_lines = ([inline_body] if inline_body else []) + group[1:]
            return {
                'type': 'repeat',
                'times': int(times) if times.isdigit() else times,
                # An empty body (e.g. only 'pass') yields an empty list.
                'body': python2codejs('\n'.join(body_lines))['run'],
            }
        if len(group) > 1:
            raise ValueError(f"Invalid command: {header}")
        return translate_single(header)

    # Strip comments. Single-line string literals are matched first and kept, so a
    # '#' inside one (e.g. setpc('#D60000')) survives; whitespace before a comment
    # is removed with it.
    # NOTE(review): triple-quoted strings and escaped quotes are not recognised.
    python_code = re.sub(r"""("[^"\n]*"|'[^'\n]*')|\s*#.*""",
                         lambda m: m.group(1) or '', python_code)

    lines = []
    for raw_line in python_code.split('\n'):
        # Trailing whitespace carries no meaning and is dropped.
        line = raw_line.rstrip()
        stripped = line.strip()
        # Skip blank lines, lone 'pass' statements and top-level run() calls.
        if not stripped or stripped == 'pass' or line.startswith('run()'):
            continue
        lines.append(line)

    commands = []
    start = 0
    while start < len(lines):
        if lines[start].strip() == 'def run():':
            start += 1
            continue
        # A statement owns every following line that is indented deeper than it.
        base = indent_width(lines[start])
        end = start + 1
        while end < len(lines) and indent_width(lines[end]) > base:
            end += 1
        commands.append(translate_group(lines[start:end]))
        start = end

    return {'run': commands}


def codejs2python(codejs):
    """
    Translates code from JSON format to Python.

    Parameters:
    - codejs (dict): The program in JSON format, ``{'run': [command, ...]}``, where
      each command has a ``'type'`` of 'fd', 'bk', 'lt', 'rt', '[MASK]', 'setpc'
      (with ``'value'``) or 'repeat' (with ``'times'`` and ``'body'``).

    Returns:
    - str: Python source defining ``run()``. Bodies are indented with tabs, nested
      loops use the loop variables i, j, k, ..., and an empty body becomes ``pass``.

    Raises:
    - ValueError: If a command has an unknown ``'type'``.
    """
    simple_commands = {
        'fd': 'move_forward()',
        'bk': 'move_backward()',
        'lt': 'turn_left()',
        'rt': 'turn_right()',
        '[MASK]': '[MASK]',
    }

    def get_loop_var(nesting_level):
        """
        Returns the loop variable based on the nesting level (i, j, k, ..., z, a, b, ...)
        """
        # Cycle within 'a'..'z' so deep nesting never produces non-letter characters.
        return chr(ord('a') + (ord('i') - ord('a') + nesting_level) % 26)

    def translate_body(commands, indent_level, nesting_level):
        """
        Translates a list of commands; an empty list becomes 'pass' so the output stays valid Python.
        """
        if not commands:
            return '\t' * indent_level + 'pass'
        return '\n'.join(translate_command(c, indent_level, nesting_level) for c in commands)

    def translate_command(command, indent_level, nesting_level):
        """
        Translates individual commands based on their type with correct indentation.
        """
        indent = '\t' * indent_level
        kind = command['type']
        if kind in simple_commands:
            return indent + simple_commands[kind]
        if kind == 'setpc':
            return f"{indent}setpc('{command['value']}')"
        if kind == 'repeat':
            loop_header = f"{indent}for {get_loop_var(nesting_level)} in range({command['times']}):"
            return loop_header + '\n' + translate_body(command['body'], indent_level + 1, nesting_level + 1)
        raise ValueError(f"Unknown command type: {kind}")

    return "def run():\n" + translate_body(codejs['run'], 1, 0)


####### Task translation

def taskjs2ascii(source_json):
    """
    Translates a task description from JSON to an ASCII-art representation.

    Tasks with a non-empty ``'lines'`` list are rendered as drawing tasks: tiles are
    lattice points joined by ``-``/``|`` edges, target lines are written with a
    one-letter colour code and the turtle is an arrow on its point. All other tasks
    are rendered as a grid of cells framed by ``+---+``/``|`` borders (walls shown
    as ``===``/``‖``), with forbidden cells marked ``X``, items written as
    ``<count><colour letter><shape symbol>`` and the turtle as an arrow.

    Parameters:
    - source_json (dict): The task description in JSON format (already parsed).

    Returns:
    - str: The ASCII representation of the task.
    """

    def item_taskjs2ascii(taskjs):
        """
        Translate the task to an ASCII representation, including walls, items, and turtle.
        """
        color_map = {
            "red"   : "R",
            "green" : "G",
            "blue"  : "B",
            "yellow": "Y",
            "black" : "K",
            "white" : "W",
            "orange": "O",
            "purple": "U",
            "pink"  : "P"
        }

        name_map = {
            "circle"    : "o",
            "rectangle" : "□",
            "triangle"  : "△",
            "cross"     : "+",
            "strawberry": "S",
            "lemon"     : "L"
        }

        tiles = taskjs.get("tiles", [])

        # Determine grid size (coordinates are 0-based)
        max_x = max(tile["x"] for tile in tiles) + 1
        max_y = max(tile["y"] for tile in tiles) + 1

        # Cell contents are drawn 3 characters wide
        grid = [["   " for _ in range(max_x)] for _ in range(max_y)]
        # vertical_walls[y][x] is the edge left of cell (x, y);
        # horizontal_walls[y][x] is the edge above cell (x, y)
        vertical_walls = [[False] * (max_x + 1) for _ in range(max_y)]
        horizontal_walls = [[False] * max_x for _ in range(max_y + 1)]

        # Handle tiles (forbidden cells and walls)
        for tile in tiles:
            x, y = tile["x"], tile["y"]
            if not tile.get("allowed", True):
                grid[y][x] = " X "
            # An edge is shared by two neighbouring tiles (the bottom of (x, y) is the
            # top of (x, y + 1)), so a wall is drawn if either tile declares it; a
            # False on one side must not erase the other side's wall.
            walls = tile.get("walls") or {}
            if walls.get("top"):
                horizontal_walls[y][x] = True
            if walls.get("bottom"):
                horizontal_walls[y + 1][x] = True
            if walls.get("left"):
                vertical_walls[y][x] = True
            if walls.get("right"):
                vertical_walls[y][x + 1] = True

        # Handle items, e.g. 3 red strawberries -> "3RS"
        # NOTE(review): counts >= 10 make the cell wider than 3 characters and misalign the row.
        for item in taskjs.get("items") or []:
            grid[item["y"]][item["x"]] = (f'{item["count"]}{color_map[item["color"].lower()]}'
                                          f'{name_map[item["name"].lower()]}')

        # Handle turtle; direction indexes ^ > v < (reduced mod 4, as in drawing tasks)
        turtle = taskjs.get("turtle", {})
        if turtle:
            arrow = ["^", ">", "v", "<"][turtle["direction"] % 4]
            grid[turtle["y"]][turtle["x"]] = " " + arrow + " "

        def border(row):
            # Horizontal border line above grid row `row` (row == max_y is the bottom edge)
            return "".join("+" + ("===" if wall else "---") for wall in horizontal_walls[row]) + "+"

        # Construct the ASCII art: a border line above every row, then the bottom border
        art_lines = []
        for y in range(max_y):
            art_lines.append(border(y))
            cells = "".join(("‖" if vertical_walls[y][x] else "|") + grid[y][x] for x in range(max_x))
            art_lines.append(cells + ("‖" if vertical_walls[y][max_x] else "|"))
        art_lines.append(border(max_y))

        return "\n".join(art_lines)

    def draw_taskjs2ascii(taskjs):
        """
        Translate a drawing task: tiles are points, target lines use their colour letter.
        """
        # Color map for line colors
        color_map = {
            "#D60000": "R",  # Red
            "#009624": "G",  # Green
            "#0D47A1": "B",  # Blue
            "#FFD600": "Y",  # Yellow
            "#000000": "K",  # Black
            "#FFFFFF": "W",  # White
            "#FFA500": "O",  # Orange
            "#800080": "U",  # Purple
            "#FFC0CB": "P"  # Pink
        }

        # Directions for the turtle
        directions = ["^", ">", "v", "<"]

        # Determine grid size and existing tiles
        tiles = {(tile["x"], tile["y"]): tile for tile in taskjs["tiles"]}
        max_x = max(x for x, _ in tiles) + 1
        max_y = max(y for _, y in tiles) + 1

        # Point (x, y) is drawn at row 2*y, column 4*x; edges fill the gaps between points
        grid_width = max_x * 4 + 1
        grid_height = max_y * 2 + 1
        grid = [[" "] * grid_width for _ in range(grid_height)]

        # Place corners and edges for existing tiles
        for x, y in tiles:
            grid[y * 2][x * 4] = "+"
            if (x + 1, y) in tiles:
                for i in range(1, 4):
                    grid[y * 2][x * 4 + i] = "-"
                grid[y * 2][(x + 1) * 4] = "+"
            if (x, y + 1) in tiles:
                grid[y * 2 + 1][x * 4] = "|"
                grid[(y + 1) * 2][x * 4] = "+"

        # Draw lines
        # NOTE(review): lines with x1 != x2 are treated as horizontal at y1, so a
        # diagonal line would be drawn incorrectly -- confirm tasks never contain one.
        for line in taskjs["lines"]:
            color = color_map[line["color"].upper()]
            if line["x1"] == line["x2"]:  # Vertical line: mark the '|' slot between each pair of points
                y_start, y_end = sorted((line["y1"], line["y2"]))
                x_pos = line["x1"] * 4
                for y in range(y_start * 2 + 1, y_end * 2, 2):
                    grid[y][x_pos] = color
            else:  # Horizontal line: mark the three '-' slots between each pair of points
                x_start, x_end = sorted((line["x1"], line["x2"]))
                y_pos = line["y1"] * 2
                for x in range(x_start * 4 + 1, x_end * 4, 4):
                    for dx in range(3):
                        grid[y_pos][x + dx] = color

        # Place the turtle
        turtle = taskjs["turtle"]
        grid[turtle["y"] * 2][turtle["x"] * 4] = directions[turtle["direction"] % 4]

        # Blank out '+' corners that have no edge or line touching them
        for y in range(grid_height):
            for x in range(grid_width):
                if y % 2 == 0 and x % 4 == 0 and grid[y][x] == "+":
                    if all(grid[y + d][x] == " " for d in (-1, 1) if 0 <= y + d < grid_height) and \
                            all(grid[y][x + d] == " " for d in (-2, -1, 1, 2, 3) if 0 <= x + d < grid_width):
                        grid[y][x] = " "

        # Convert grid to string, trimming trailing spaces and leading/trailing blank rows
        return "\n".join("".join(row).rstrip() for row in grid).strip("\n")

    # Drawing tasks are recognised by a non-empty list of target lines
    if source_json.get('lines'):
        return draw_taskjs2ascii(source_json)
    return item_taskjs2ascii(source_json)


def taskjs2nl(task_json):
    """
    Describes a task given in JSON format in English.

    Parameters:
    - task_json (dict): The task description in JSON format (already parsed).

    Returns:
    - str: A description of the grid size, the turtle's start, accessible and
      forbidden cells, walls, items and lines.
    """
    # Hex line colours -> names. Keys are upper case; lookups upper-case the input
    # so lower-case hex codes resolve too.
    COLORS = {"#000000": "black",
              "#0D47A1": "blue",
              "#009624": "green",
              "#FFD600": "yellow",
              "#D60000": "red",
              "#FFFFFF": "white",
              "#FFA500": "orange",
              "#800080": "purple",
              "#FFC0CB": "pink"}

    # Helper to translate the direction number to a name and correct grammar for wall descriptions
    directions = ["north", "east", "south", "west"]
    wall_directions = {"top"  : "at the top edge", "bottom": "at the bottom edge", "left": "on the left side",
                       "right": "on the right side"}

    def pluralize(name, count):
        # English plural unless count is 1: strawberry -> strawberries, cross -> crosses, lemon -> lemons
        if count == 1:
            return name
        if name.endswith(('s', 'x', 'z', 'ch', 'sh')):
            return name + "es"
        if name.endswith('y') and name[-2:-1] not in ('a', 'e', 'i', 'o', 'u'):
            return name[:-1] + "ies"
        return name + "s"

    # Calculate grid size (rows x columns; coordinates are 0-based)
    m = max(tile['y'] for tile in task_json['tiles']) + 1
    n = max(tile['x'] for tile in task_json['tiles']) + 1
    turtle = task_json['turtle']

    # Describe the initial setup
    narrative = (f"A {m}x{n} grid. The turtle starts at ({turtle['x']},{turtle['y']}) "
                 f"facing {directions[turtle['direction'] % 4]}.\n\n")

    # Describe grid cells with accessible and forbidden areas, and walls
    accessible_cells = []
    forbidden_cells = []
    walls_info = []
    for tile in task_json['tiles']:
        cell_desc = f"({tile['x']},{tile['y']})"
        # Tiles without an 'allowed' flag are treated as accessible
        if tile.get('allowed', True):
            accessible_cells.append(cell_desc)
        else:
            forbidden_cells.append(cell_desc)
        for direction, exists in (tile.get('walls') or {}).items():
            if exists:
                walls_info.append(f"{cell_desc} has a wall {wall_directions[direction]}.")

    # Compile descriptions
    if accessible_cells:
        narrative += f"Accessible cells: {', '.join(accessible_cells)}.\n"
    if forbidden_cells:
        narrative += f"Forbidden cells: {', '.join(forbidden_cells)}.\n"
    narrative += "\n".join(walls_info) + "\n"

    # Describe items, taking into account singular and plural forms
    items = task_json.get('items') or []
    if items:
        narrative += "\nItems in the grid:\n"
        for item in items:
            item_name = pluralize(item['name'], item['count'])
            narrative += f"- {item['count']} {item['color']} {item_name} at ({item['x']},{item['y']}).\n"

    # Describe lines
    lines = task_json.get('lines') or []
    if lines:
        narrative += "\nLines in the grid:\n"
        for line in lines:
            color = COLORS[line['color'].upper()]
            article = "An" if color[0] in "aeiou" else "A"
            narrative += (f"- {article} {color} line from ({line['x1']},{line['y1']}) "
                          f"to ({line['x2']},{line['y2']}).\n")

    return narrative.strip()


def ttestt_ascii():
    """Print the id, description and ASCII rendering of every task in XLOGO_TASK_IDS (manual check)."""
    from src.xlogomini.utils.load_data import load_task_json
    from src.xlogomini.utils.enums import XLOGO_TASK_IDS

    separator = "-----"
    for tid in XLOGO_TASK_IDS:
        task = load_task_json(tid)
        header = f"{task['id']}: {task['description']}"
        print(header)
        print(taskjs2ascii(task))
        print(separator)


def ttestt_codejs2python():
    """
    Round-trip each task's code JSON through codejs2python and python2codejs,
    printing "Success" when the result matches and both versions otherwise.
    """
    from src.xlogomini.utils.load_data import load_code_json
    from src.xlogomini.utils.enums import XLOGO_TASK_IDS

    for tid in XLOGO_TASK_IDS:
        print(tid)

        original = load_code_json(tid)
        # JSON -> Python -> JSON
        roundtrip = python2codejs(codejs2python(original))

        if original == roundtrip:
            print("Success")
        else:
            print(original)
            print(roundtrip)
        print("-----")


if __name__ == '__main__':
    # Smoke test: translate a sample answer that contains inline comments, a trailing
    # 'pass' and an unindented comment after the function, and print the JSON result.
    python_code = """def run():
    # Start from (1, 2), facing East
    # Setting the pen color to black to draw lines
    setpc('black')

    # We need to draw a series of lines from:
    # (0,0) to (1,0), (0,0) to (0,1), (1,0) to (2,0), (2,0) to (2,1), and so on...
    # Define movements to draw the black outline

    # Assume the turtle trails a line behind it as it moves.

    # First, align the turtle and move it to the beginning of the path at (0,0)
    turn_left()       # West
    move_forward()    # Now at (0, 2)
    move_forward()    # Now at (0, 3)
    turn_left()       # South
    move_backward()   # Now at (0, 2)
    move_backward()   # Now at (0, 1)
    move_backward()   # Now at (0, 0) start drawing

    # Begin drawing the outline
    move_forward()    # Draw to (1, 0)
    move_forward()    # Draw to (2, 0)
    turn_right()      # Now facing South
    move_forward()    # Draw to (2, 1)
    move_forward()    # Draw to (2, 2)
    move_forward()    # Draw to (2, 3)
    turn_right()      # Now facing West
    move_forward()    # Draw to (1, 3)
    move_forward()    # Draw to (0, 3)
    turn_right()      # Now facing North
    move_forward()    # Draw to (0, 2)
    move_forward()    # Draw to (0, 1)
    move_forward()    # Back at (0, 0), outline complete

    # Terminate the function
    pass

# Assume subsequent calls or event triggers 'run()'
    """
    print(python2codejs(python_code))

