import collections
import logging

import math
import numpy as np
import regex as re
from absl import logging
from shapely.geometry.polygon import Polygon

from veoplace.utils import get_clamped_corners
from veoplace.utils import parse_valid_rectangle


def parse_suggest_random_regions(gemini_response, env, selected_macros,
                                 max_attempts=3):
    """Sample random, non-overlapping regions for the selected macros.

    This is a random baseline: the model response is not read. Each region is
    exactly the macro's size (no padding), and regions never share grid cells.

    Args:
        gemini_response: Unused; accepted so the signature matches the other
            parsers registered in PROMPT_PARSERS.
        env: Object exposing ``grid_size``, the side length of the square grid.
        selected_macros: Iterable of dicts with 'name', 'width' and 'height'
            keys (width/height in grid cells).
        max_attempts: Number of random positions tried per macro before it is
            given up on.

    Returns:
        Dict mapping macro name to its four corners
        ``[bottom-left, bottom-right, top-right, top-left]``. Macros that are
        larger than the grid, or for which no non-overlapping position was
        found within ``max_attempts`` samples, are omitted.
    """
    grid_size = env.grid_size
    regions = {}

    # 1 marks a grid cell already claimed by an accepted region.
    suggestion_mask = np.zeros((grid_size, grid_size))

    for macro in selected_macros:
        macro_name = macro['name']
        macro_width = macro['width']
        macro_height = macro['height']

        # Largest bottom-left corner that keeps the region inside the grid.
        max_x = grid_size - macro_width
        max_y = grid_size - macro_height
        if max_x < 0 or max_y < 0:
            # The macro cannot fit anywhere; np.random.randint would raise on
            # the empty range, so skip it instead of aborting the whole parse.
            continue

        for _ in range(max_attempts):
            x1 = np.random.randint(0, max_x + 1)
            y1 = np.random.randint(0, max_y + 1)
            x2 = x1 + macro_width
            y2 = y1 + macro_height

            # Reject the sample if it touches any cell claimed earlier.
            if suggestion_mask[y1:y2, x1:x2].any():
                continue

            regions[macro_name] = [
                (x1, y1),  # bottom-left
                (x2, y1),  # bottom-right
                (x2, y2),  # top-right
                (x1, y2),  # top-left
            ]
            suggestion_mask[y1:y2, x1:x2] = 1
            break

    return regions


def extract_coordinates(text, short_name):
    """Extract the first two ``(x, y)`` pairs that follow a macro's short name.

    The pattern is: ``short_name``, then anything up to a colon, then two
    integer pairs, all on one line (``.`` does not cross newlines). The short
    name is matched literally (regex metacharacters in it are escaped) and
    must not be embedded in a longer identifier, so ``"M1"`` does not match
    the ``"M10"`` entry.

    Args:
        text: The model response to search.
        short_name: The macro's short name as it appears in the response.

    Returns:
        ``[(x1, y1), (x2, y2)]`` from the first match, or None when no match
        is found (the format is too different to trust).
    """
    pattern = (rf'(?<!\w){re.escape(short_name)}(?!\w)'
               r'.*?:.*?\((\d+),\s*(\d+)\).*?\((\d+),\s*(\d+)\)')

    match = re.search(pattern, text)
    if match is None:
        return None

    x1, y1, x2, y2 = (int(group) for group in match.groups())
    return [(x1, y1), (x2, y2)]


def _macro_area(env, macro_name):
    """Return width * height of ``macro_name`` as stored in env.node_dimensions."""
    dims = env.node_dimensions[env.node_to_idx[macro_name]]
    return dims[0] * dims[1]


def _prioritize_macros(env, macro_names):
    """Order macros: the largest of each color group first, then the rest by area.

    Groups are keyed by ``env.color_config`` and visited in first-seen order.
    Both sorts are stable, so macros with equal area keep their input order.
    """
    color_to_macros = collections.defaultdict(list)
    for macro_name in macro_names:
        color_to_macros[env.color_config[macro_name]].append(macro_name)

    def area(name):
        return _macro_area(env, name)

    group_leaders = []
    remaining = []
    for macros in color_to_macros.values():
        by_area = sorted(macros, key=area, reverse=True)
        group_leaders.append(by_area[0])
        remaining.extend(by_area[1:])

    remaining.sort(key=area, reverse=True)
    return group_leaders + remaining


def parse_suggest_variable_regions(gemini_response, env, selected_macros):
    """Parse one suggested region per selected macro from a model response.

    Macros are processed in priority order: the largest macro of each color
    group first, then the remaining macros by decreasing area. Larger macros
    therefore claim grid cells before smaller ones. For each macro, the
    coordinates come from the FIRST match of its short name in the response
    (see ``extract_coordinates``). They are validated with
    ``parse_valid_rectangle`` and clamped to the grid with
    ``get_clamped_corners``.

    A macro's entry is left as None when:
    - its coordinates are missing or fail validation (ValueError);
    - its region overlaps a region accepted for an earlier macro; or
    - the region is smaller than the macro's
      ``env.node_normalized_dimensions_int`` footprint.

    Args:
        gemini_response: Raw text of the model response.
        env: Placement environment. Reads grid_size, color_config,
            node_dimensions, node_to_idx, node_name_to_short_name and
            node_normalized_dimensions_int.
        selected_macros: Iterable of dicts with a 'name' key.

    Returns:
        Dict mapping every selected macro name, in priority order, to its
        corner points or None.
    """
    regions = {}

    # 1 marks a grid cell already claimed by an accepted region.
    suggestion_mask = np.zeros((env.grid_size, env.grid_size))

    macro_names = [m['name'] for m in selected_macros]

    for macro_name in _prioritize_macros(env, macro_names):
        regions[macro_name] = None
        short_name = env.node_name_to_short_name[macro_name]

        points = extract_coordinates(gemini_response, short_name)
        if not points:
            continue

        try:
            points = parse_valid_rectangle(points)
            points = get_clamped_corners(points, grid_size=env.grid_size)
            # NOTE(review): assumes get_clamped_corners orders corners so that
            # [0] is (min_x, min_y) and [-2] is (max_x, max_y), presumably
            # [BL, BR, TR, TL]; confirm against veoplace.utils.
            min_x, min_y = points[0]
            max_x, max_y = points[-2]
        except ValueError as e:
            logging.debug('Error parsing region for macro %s: %s',
                          short_name, e)
            continue

        # Reject regions touching any cell claimed by an earlier macro.
        if suggestion_mask[min_y:max_y, min_x:max_x].any():
            logging.debug(
                'Region for macro %s overlaps with previous suggestions: %s',
                short_name, points)
            continue

        # The region must be at least as large as the macro itself.
        macro_idx = env.node_to_idx[macro_name]
        macro_w = env.node_normalized_dimensions_int[macro_idx][0]
        macro_h = env.node_normalized_dimensions_int[macro_idx][1]
        region_w = max_x - min_x
        region_h = max_y - min_y
        if region_w < macro_w or region_h < macro_h:
            logging.debug(
                'Region for macro %s (%s): (%d, %d) is too small for macro '
                '(%d, %d). Coordinates: %s',
                short_name, macro_name, region_w, region_h, macro_w, macro_h,
                points)
            continue

        suggestion_mask[min_y:max_y, min_x:max_x] = 1
        regions[macro_name] = points

    logging.info("Parsed %d regions for %d macros",
                 sum(r is not None for r in regions.values()),
                 len(macro_names))

    return regions


def parse_suggest_all_regions(gemini_response, env, selected_macros):
    """Parse a region for the first macro of each color group.

    Only selected macros listed in ``env.first_macro_of_color_group`` are
    considered, in ``selected_macros`` order. For each one, the response is
    searched for ``<short_name>:`` followed by two ``(x, y)`` pairs on the
    same line, and the LAST such occurrence is used. The short name is
    matched literally (regex metacharacters are escaped) and must not be the
    tail of a longer identifier.

    Regions that fail validation (ValueError) or overlap a previously accepted
    region are skipped with a warning and left out of the result.

    Args:
        gemini_response: Raw text of the model response.
        env: Placement environment. Reads grid_size,
            first_macro_of_color_group and node_name_to_short_name.
        selected_macros: Iterable of dicts with a 'name' key.

    Returns:
        Dict mapping macro name to its validated, clamped corner points.
    """
    regions = {}

    # 1 marks a grid cell already claimed by an accepted region.
    suggestion_mask = np.zeros((env.grid_size, env.grid_size))

    macro_names = [m['name'] for m in selected_macros]
    for macro_name in macro_names:
        if macro_name not in env.first_macro_of_color_group:
            continue
        short_name = env.node_name_to_short_name[macro_name]

        # short_name + ':' + two coordinate pairs. The name is literal text,
        # and (?<!\w) stops e.g. "M1" from matching inside "XM1".
        pattern = (rf'(?<!\w){re.escape(short_name)}:'
                   r'.*?\((\d+),\s*(\d+)\).*?\((\d+),\s*(\d+)\)')

        matches = list(re.finditer(pattern, gemini_response))
        if not matches:
            logging.warning('Could not find region for macro %s: %s',
                            short_name, gemini_response)
            continue

        # Later mentions are treated as revisions, so the last one wins.
        x1, y1, x2, y2 = (int(group) for group in matches[-1].groups())
        points = [(x1, y1), (x2, y2)]
        try:
            points = parse_valid_rectangle(points)
            points = get_clamped_corners(points, grid_size=env.grid_size)
            # NOTE(review): assumes [0] is (min_x, min_y) and [-2] is
            # (max_x, max_y) in get_clamped_corners' output; confirm.
            min_x, min_y = points[0]
            max_x, max_y = points[-2]
        except ValueError as e:
            logging.warning('Error parsing region for macro %s: %s',
                            short_name, e)
            continue

        # Reject regions touching any cell claimed by an earlier macro.
        if suggestion_mask[min_y:max_y, min_x:max_x].any():
            logging.warning(
                    'Region for macro %s overlaps with previous suggestions: %s',
                    short_name, points)
            continue

        suggestion_mask[min_y:max_y, min_x:max_x] = 1
        regions[macro_name] = points

    return regions


def parse_suggest_region(gemini_response):
    """Return the four corner points of the last rectangle in a response.

    The response is scanned for comma-separated runs of exactly four integer
    ``(x, y)`` pairs, and the final run is used. The points are accepted only
    if shapely considers the polygon valid, its area is non-zero, and its area
    equals that of its minimum rotated rectangle.

    Args:
        gemini_response: Raw text of the model response.

    Returns:
        List of four ``(x, y)`` integer tuples in the order they appeared.

    Raises:
        ValueError: If no four-pair run is present, or the points do not form
            a valid rectangle with non-zero area.
    """
    four_pairs = r'\((\d+),\s*(\d+)\)\s*,\s*\((\d+),\s*(\d+)\)\s*,\s*\((\d+),\s*(\d+)\)\s*,\s*\((\d+),\s*(\d+)\)'

    found = re.findall(four_pairs, gemini_response)
    if not found:
        raise ValueError("Could not find valid coordinate pattern in response")

    # findall yields one 8-tuple of digit strings per match; keep the last.
    values = [int(v) for v in found[-1]]
    points = list(zip(values[0::2], values[1::2]))

    candidate = Polygon(points)
    is_rectangle = (
        candidate.is_valid
        and candidate.area != 0
        and math.isclose(candidate.minimum_rotated_rectangle.area,
                         candidate.area)
    )
    if not is_rectangle:
        raise ValueError(
                "Invalid rectangle: must have non-zero area and be a standard rectangle")

    return points
 

# Registry from prompt-strategy name to the function that turns the model's
# response into per-macro regions. Every parser takes
# (gemini_response, env, selected_macros). Several strategies share the
# variable-region parser; 'random' does not read the response at all.
PROMPT_PARSERS = dict(
    all_regions=parse_suggest_all_regions,
    variable_regions=parse_suggest_variable_regions,
    variable_regions_greedy=parse_suggest_variable_regions,
    conservative=parse_suggest_variable_regions,
    exploration=parse_suggest_variable_regions,
    random=parse_suggest_random_regions,
    recall_test=parse_suggest_variable_regions,
)
