from typing import List, Tuple, Dict
from functools import partial
from jax import random
import numpy as np
from collections import deque
from src.datasets.dataset_creator.input_generation.v1_objectness import (
    object_generator,
    generate_half_objects_input,
    generate_half_objects_input_vert,
    noisy_object_generator,
)


def rotate_object(obj: np.ndarray, k: int) -> np.ndarray:
    """Return ``obj`` rotated by ``k`` quarter turns, exactly as ``np.rot90(obj, k)``."""
    # axes=(0, 1) is np.rot90's default plane: rotate from the row axis toward the column axis.
    quarter_turns = k
    return np.rot90(obj, quarter_turns, axes=(0, 1))


def rotate_objects_in_grid(input_grid: np.ndarray, extra: Dict, k: int) -> np.ndarray:
    output_grid = np.zeros_like(input_grid)
    objects = extra["objects"]

    for top, left, height, width, color in objects:
        obj = input_grid[top : top + height, left : left + width]
        rotated_obj = rotate_object(obj, k)

        new_height, new_width = rotated_obj.shape
        new_top = top + (height - new_height) // 2
        new_left = left + (width - new_width) // 2

        new_top = max(0, min(new_top, input_grid.shape[0] - new_height))
        new_left = max(0, min(new_left, input_grid.shape[1] - new_width))

        output_grid[new_top : new_top + new_height, new_left : new_left + new_width] = rotated_obj

    return output_grid


def move_objects_in_grid_wrap(input_grid: np.ndarray, extra: Dict, dx: int, dy: int) -> np.ndarray:
    """Translate every object by ``(dy, dx)`` on a toroidal (wrap-around) grid.

    Only the contents of the bounding boxes listed in ``extra["objects"]``
    (``(top, left, height, width, color)``) are carried over; all other cells
    of the result are 0. A pixel at ``(r, c)`` lands on
    ``((r + dy) % rows, (c + dx) % cols)``, so shapes crossing an edge
    re-enter on the opposite side.
    """
    rows, cols = input_grid.shape

    # Mark every cell covered by some object's bounding box.
    footprint = np.zeros((rows, cols), dtype=bool)
    for top, left, height, width, _ in extra["objects"]:
        footprint[top : top + height, left : left + width] = True

    # Keep just those cells, then shift the whole picture with wrap-around.
    staged = np.zeros_like(input_grid)
    staged[footprint] = input_grid[footprint]
    return np.roll(staged, shift=(dy, dx), axis=(0, 1))


def keep_objects_by_rank(input_grid: np.ndarray, extra: Dict, ranks_to_keep: List[int]) -> np.ndarray:
    """Return a grid holding only the objects whose size rank is in ``ranks_to_keep``.

    ``extra["size_ranking"][i]`` is the rank of ``extra["objects"][i]``.
    Each object is a ``(top, left, height, width, color)`` bounding box. The
    full box of every selected object is copied from ``input_grid``, and every
    other cell is 0.
    """
    objects = extra["objects"]
    ranking = extra["size_ranking"]
    kept = np.zeros_like(input_grid)

    selected = [objects[idx] for idx, rank in enumerate(ranking) if rank in ranks_to_keep]
    for top, left, height, width, _ in selected:
        row_span = slice(top, top + height)
        col_span = slice(left, left + width)
        kept[row_span, col_span] = input_grid[row_span, col_span]

    return kept


def keep_largest_object(input_grid: np.ndarray, extra: Dict) -> np.ndarray:
    """Keep only the object(s) with size rank 1 (the largest); all other cells become 0."""
    top_rank = 1
    return keep_objects_by_rank(input_grid, extra, ranks_to_keep=[top_rank])


def keep_two_largest_objects(input_grid: np.ndarray, extra: Dict) -> np.ndarray:
    """Keep only the objects with size rank 1 or 2; all other cells become 0."""
    leading_ranks = [1, 2]
    return keep_objects_by_rank(input_grid, extra, ranks_to_keep=leading_ranks)


def keep_smallest_object(input_grid: np.ndarray, extra: Dict) -> np.ndarray:
    """Keep only the object(s) holding the highest rank number, i.e. the smallest size.

    Raises ``ValueError`` (from ``max``) when ``extra["size_ranking"]`` is empty.
    """
    lowest_placed_rank = max(extra["size_ranking"])
    return keep_objects_by_rank(input_grid, extra, ranks_to_keep=[lowest_placed_rank])


def keep_two_smallest_objects(input_grid: np.ndarray, extra: Dict) -> np.ndarray:
    """Keep only the objects holding the two highest rank numbers (the two smallest sizes).

    If every object shares the same rank (e.g. the grid has a single object),
    only that rank is kept. Previously ``max`` was called on an empty
    generator in that case and raised ``ValueError``.

    Raises:
        ValueError: if ``extra["size_ranking"]`` is empty.
    """
    size_ranking = extra["size_ranking"]
    max_rank = max(size_ranking)
    ranks_to_keep = [max_rank]

    # ``default=None`` tolerates the case where no rank differs from ``max_rank``.
    second_max_rank = max((rank for rank in size_ranking if rank != max_rank), default=None)
    if second_max_rank is not None:
        ranks_to_keep.append(second_max_rank)

    return keep_objects_by_rank(input_grid, extra, ranks_to_keep)


def reorder_objects_by_size(input_grid: np.ndarray, extra: Dict) -> np.ndarray:
    """Pack the objects into a blank grid in size-rank order, then trim empty rows/columns.

    Objects come from ``extra["objects"]`` as ``(top, left, height, width, color)``
    bounding boxes. They are visited in ascending ``extra["size_ranking"]`` order
    (a stable sort, so ties keep their input order). Each box's contents are
    copied from ``input_grid`` into a blank grid of the same shape.

    Placement goes left to right, with each box placed directly after the
    previous one and no gap. When horizontal room runs out, the packer jumps to
    the next row index at which ``input_grid`` (not the output) is entirely
    black, and restarts at column 0. An occupancy mask prevents overlaps.
    Placement stops at the first object that cannot be fitted, and any
    remaining objects are dropped.

    Finally, every all-zero row and every all-zero column is removed, including
    interior ones, not only the border. The result is therefore generally
    smaller than ``input_grid``, and has shape ``(0, 0)`` if nothing was placed.
    """
    output_grid = np.zeros_like(input_grid)
    objects = extra["objects"]
    size_ranking = extra["size_ranking"]

    # Ascending rank order; rank 1 is treated as the largest elsewhere in this module.
    sorted_objects = sorted(zip(objects, size_ranking), key=lambda x: x[1])

    # Cells of output_grid already claimed by a placed bounding box.
    occupied_mask = np.zeros_like(input_grid, dtype=bool)

    def find_next_all_black_row(start_row: int) -> int:
        # First row index >= start_row whose INPUT row is all zeros; None if there is none
        # (despite the ``int`` annotation).
        # NOTE(review): this inspects input_grid rather than output_grid/occupied_mask —
        # confirm that is the intended row-break rule.
        for row in range(start_row, input_grid.shape[0]):
            if np.all(input_grid[row] == 0):
                return row
        return None

    def find_next_position(height: int, width: int, start_row: int, start_col: int) -> Tuple[int, int]:
        # Scan rightward from (start_row, start_col) for a free height x width slot. When the
        # box would run past the right edge, jump to the next all-black input row at column 0.
        # Returns (None, None) when no slot is found.
        row = start_row
        col = start_col

        while row is not None and row < input_grid.shape[0] - height + 1:
            if col + width <= input_grid.shape[1] and not np.any(
                occupied_mask[row : row + height, col : col + width]
            ):
                return row, col

            col += 1
            if col + width > input_grid.shape[1]:
                row = find_next_all_black_row(row + 1)
                col = 0

        return None, None

    def place_object(obj: np.ndarray, row: int, col: int, height: int, width: int):
        # Copy the whole bounding box (zeros included) and mark it occupied.
        output_grid[row : row + height, col : col + width] = obj
        occupied_mask[row : row + height, col : col + width] = True

    current_row = 0
    current_col = 0
    for (top, left, height, width, color), _ in sorted_objects:
        obj = input_grid[top : top + height, left : left + width]

        next_row, next_col = find_next_position(height, width, current_row, current_col)

        if next_row is None or next_col is None:
            # This object does not fit anywhere; it and all later objects are dropped.
            break

        place_object(obj, next_row, next_col, height, width)

        # The next search starts immediately to the right of the box just placed.
        current_row = next_row
        current_col = next_col + width

        # Fewer than two columns remain on this band: move to the next all-black input row.
        if current_col + 1 >= input_grid.shape[1]:
            current_row = find_next_all_black_row(current_row + 1)
            if current_row is None:
                break
            current_col = 0

    # Drop every all-zero row and column (interior ones too), not just the border.
    non_empty_rows = np.any(output_grid != 0, axis=1)
    non_empty_cols = np.any(output_grid != 0, axis=0)
    trimmed_output_grid = output_grid[non_empty_rows][:, non_empty_cols]

    return trimmed_output_grid


def fill_object_rectangles(input_grid: np.ndarray, extra: Dict, key: random.PRNGKey) -> np.ndarray:
    """Paint every black cell inside each object's bounding box with one random colour.

    The colour is drawn once from ``key`` in ``[1, 9]`` and shared by all boxes.
    Boxes come from ``extra["objects"]`` as ``(top, left, height, width, color)``.
    Non-black cells and cells outside all boxes are unchanged, and
    ``input_grid`` itself is not modified.
    """
    boxes = extra["objects"]
    result = input_grid.copy()

    # jax.random.randint's maxval is exclusive, so this yields 1..9 and never black (0).
    paint = random.randint(key, (), 1, 10).item()

    for top, left, height, width, _ in boxes:
        # Basic slicing gives a view, so the masked assignment edits ``result`` directly.
        box = result[top : top + height, left : left + width]
        box[box == 0] = paint

    return result


def draw_shape_outlines(input_grid: np.ndarray, extra: Dict, key: random.PRNGKey) -> np.ndarray:
    """Paint the black surroundings of every object with one random colour.

    For each ``(top, left, height, width, color)`` in ``extra["objects"]`` the
    window examined is the bounding box grown by one cell on every side,
    clipped to the grid. A 4-connected flood fill starts from the window's
    border and passes through every cell except those equal to the object's
    non-zero ``color``. The black cells it reaches are painted with a single
    colour drawn once from ``key`` in ``[1, 9]``. Black cells enclosed by the
    shape (holes) are left untouched.

    NOTE(review): every reachable black cell in the window is painted, not only
    cells touching the shape. Concave parts of the bounding box are therefore
    filled too. Confirm this is the intended "outline".

    Objects are processed in order on the same output grid, so outlines drawn
    for earlier objects are visible to later fills. ``input_grid`` itself is not
    modified.
    """
    canvas = input_grid.copy()
    boxes = extra["objects"]

    # One colour in 1..9 (jax.random.randint's maxval is exclusive) for all outlines.
    paint = random.randint(key, (), 1, 10).item()
    grid_rows, grid_cols = canvas.shape

    def outside_black_cells(window: np.ndarray, shape_color) -> np.ndarray:
        # Boolean mask of black cells connected to the window border without
        # crossing a pixel of ``shape_color``.
        n_rows, n_cols = window.shape
        reached = np.zeros((n_rows, n_cols), dtype=bool)
        pending = [(r, edge) for r in range(n_rows) for edge in (0, n_cols - 1)]
        pending += [(edge, c) for c in range(n_cols) for edge in (0, n_rows - 1)]

        while pending:
            r, c = pending.pop()
            if reached[r, c]:
                continue
            reached[r, c] = True

            value = window[r, c]
            # Black cells and cells of other colours let the fill through;
            # only the object's own (non-black) colour blocks it.
            if value != 0 and value == shape_color:
                continue
            for dr, dc in ((-1, 0), (1, 0), (0, -1), (0, 1)):
                nr, nc = r + dr, c + dc
                if 0 <= nr < n_rows and 0 <= nc < n_cols and not reached[nr, nc]:
                    pending.append((nr, nc))

        return reached & (window == 0)

    for top, left, height, width, color in boxes:
        r0, r1 = max(0, top - 1), min(grid_rows, top + height + 1)
        c0, c1 = max(0, left - 1), min(grid_cols, left + width + 1)
        window = canvas[r0:r1, c0:c1]  # a view: painting it edits ``canvas``
        window[outside_black_cells(window, color)] = paint

    return canvas


def simulate_gravity(input_grid: np.ndarray, extra: Dict) -> np.ndarray:
    """Let every object fall straight down until it lands on the floor or another object.

    Objects from ``extra["objects"]`` (``(top, left, height, width, color)``)
    are dropped one at a time onto an initially empty grid. The object whose
    bounding box reaches lowest goes first. An object keeps falling one row at
    a time while two conditions hold: its box stays inside the grid, and none
    of its non-zero pixels would overlap a non-zero pixel already placed. Only
    non-zero pixels are written, so earlier objects are never erased.
    """
    settled = np.zeros_like(input_grid)
    floor = settled.shape[0]

    def collides(solid: np.ndarray, row: int, col: int, height: int, width: int) -> bool:
        # Blocked if the box would leave the grid at the bottom, or if any solid
        # pixel would land on an already occupied cell.
        if row + height > floor:
            return True
        beneath = settled[row : row + height, col : col + width]
        return bool(np.any(solid & (beneath != 0)))

    # Lowest bottom edge first, so supports settle before the objects above them.
    drop_order = sorted(extra["objects"], key=lambda box: box[0] + box[2], reverse=True)

    for top, left, height, width, _color in drop_order:
        sprite = input_grid[top : top + height, left : left + width]
        solid = sprite != 0

        resting_top = top
        while not collides(solid, resting_top + 1, left, height, width):
            resting_top += 1

        target = settled[resting_top : resting_top + height, left : left + width]
        target[solid] = sprite[solid]

    return settled


def expand_shapes(input_grid: np.ndarray, extra: Dict) -> np.ndarray:
    """Grow every object by one pixel in all eight directions.

    The bounding boxes in ``extra["objects"]`` (``(top, left, height, width,
    color)``) are first copied onto a blank grid. Every pixel equal to an
    object's ``color`` then claims its empty 8-neighbours for that colour.

    - A cell claimed by a single colour takes that colour.
    - A cell claimed by two or more different colours is contested and stays 0.
    - Cells that already hold a pixel are never overwritten.

    Fixes over the previous version:

    - Competing claims on the same empty cell are now detected. Before, the
      last object silently won, and occupied cells of other objects were the
      ones marked contested, so they could get repainted.
    - Claims are tracked in a signed ``int64`` array, so the contested
      sentinel cannot wrap around (or overflow) for unsigned grids.
    """
    output_grid = np.zeros_like(input_grid)
    objects = extra["objects"]
    height, width = input_grid.shape

    # First, place all original shapes.
    for top, left, obj_height, obj_width, _color in objects:
        output_grid[top : top + obj_height, left : left + obj_width] = input_grid[
            top : top + obj_height, left : left + obj_width
        ]

    # Directions for expansion (including diagonals).
    directions = [(-1, -1), (-1, 0), (-1, 1), (0, -1), (0, 1), (1, -1), (1, 0), (1, 1)]

    contested = -1
    # Signed and independent of the grid dtype: 0 = unclaimed, >0 = claiming colour, -1 = contested.
    claims = np.zeros((height, width), dtype=np.int64)

    for top, left, obj_height, obj_width, color in objects:
        # Scan the bounding box plus a one-cell margin, clipped to the grid.
        for x in range(max(0, top - 1), min(height, top + obj_height + 1)):
            for y in range(max(0, left - 1), min(width, left + obj_width + 1)):
                if output_grid[x, y] != color:
                    continue
                for dx, dy in directions:
                    nx, ny = x + dx, y + dy
                    if not (0 <= nx < height and 0 <= ny < width):
                        continue
                    if output_grid[nx, ny] != 0:
                        continue  # never grow into an existing pixel
                    current = claims[nx, ny]
                    if current == 0:
                        claims[nx, ny] = color
                    elif current != color:
                        # Two different colours want this cell: leave it as background.
                        claims[nx, ny] = contested

    grown = claims > 0
    output_grid[grown] = claims[grown]

    return output_grid


import numpy as np
from typing import Tuple, Dict, List


def mirror_half_objects(input_grid: np.ndarray, extra: Dict) -> np.ndarray:
    """Complete each half-object by appending its mirror image next to it.

    ``extra["half_objects"]`` holds ``(top, left, height, width, color)`` boxes,
    paired positionally with ``extra["mirror_axes"]`` entries of
    ``(axis, position)``. Only ``axis`` is used.

    - ``"horizontal"``: the upside-down copy goes directly below the half,
      filling ``2 * height`` rows.
    - Any other value: the left-right flipped copy goes directly to the right,
      filling ``2 * width`` columns.

    The half is read from ``input_grid`` and the result is written into a copy.
    """
    result = input_grid.copy()
    halves = extra["half_objects"]
    axes = extra["mirror_axes"]

    for (top, left, height, width, _), (axis, _position) in zip(halves, axes):
        half = input_grid[top : top + height, left : left + width]
        if axis != "horizontal":
            # vertical mirror line: half followed by its left-right reflection
            result[top : top + height, left : left + 2 * width] = np.concatenate(
                [half, half[:, ::-1]], axis=1
            )
            continue
        # horizontal mirror line: half followed by its upside-down reflection
        result[top : top + 2 * height, left : left + width] = np.concatenate(
            [half, half[::-1, :]], axis=0
        )

    return result


def denoise_objects(inputs: Tuple[np.ndarray, Dict], key) -> np.ndarray:
    """Return the clean grid stored alongside a noisy example.

    ``inputs`` must be a 2-item ``(noisy_grid, extras)`` pair. The answer is
    ``extras["original_grid"]``, returned as-is (not copied). Neither the
    noisy grid nor ``key`` is used.
    """
    _noisy_grid, metadata = inputs
    clean_grid = metadata["original_grid"]
    return clean_grid


# Registry of object-level transformation tasks.
# Each entry has:
#   "name"             - task identifier.
#   "input_generator"  - callable producing an example input; judging by the programs,
#                        presumably a (grid, extra) pair — TODO confirm against
#                        v1_objectness.
#   "program"          - callable (I, key) -> output grid. I is indexed as I[0] (grid) and
#                        I[1] (extra dict); key is a jax PRNG key, used only by the
#                        randomised tasks.
#   "frequency_weight" - presumably a relative sampling weight; verify with the consumer.
# NOTE(review): entry order and the per-entry partial objects are kept as-is in case a
# consumer relies on them.
TRANSFORMATIONS = [
    # Rotations: k quarter turns via np.rot90, applied per object.
    {
        "name": "rotate_objects_90",
        "input_generator": partial(object_generator, min_distance=1),
        "program": lambda I, key: rotate_objects_in_grid(I[0], I[1], k=1),
        "frequency_weight": 1,
    },
    {
        "name": "rotate_objects_180",
        "input_generator": partial(object_generator, min_distance=1),
        "program": lambda I, key: rotate_objects_in_grid(I[0], I[1], k=2),
        "frequency_weight": 1,
    },
    {
        "name": "rotate_objects_270",
        "input_generator": partial(object_generator, min_distance=1),
        "program": lambda I, key: rotate_objects_in_grid(I[0], I[1], k=3),
        "frequency_weight": 1,
    },
    # Wrap-around translations. jax.random.randint's maxval is exclusive, so the
    # shift distance is in 1..4 and is drawn from the per-example key.
    {
        "name": "move_objects_right",
        "input_generator": partial(object_generator, min_distance=1),
        "program": lambda I, key: move_objects_in_grid_wrap(
            I[0], I[1], dx=random.randint(key, (), 1, 5).item(), dy=0
        ),
        "frequency_weight": 1,
    },
    {
        "name": "move_objects_left",
        "input_generator": partial(object_generator, min_distance=1),
        "program": lambda I, key: move_objects_in_grid_wrap(
            I[0], I[1], dx=-random.randint(key, (), 1, 5).item(), dy=0
        ),
        "frequency_weight": 1,
    },
    {
        "name": "move_objects_down",
        "input_generator": partial(object_generator, min_distance=1),
        "program": lambda I, key: move_objects_in_grid_wrap(
            I[0], I[1], dx=0, dy=random.randint(key, (), 1, 5).item()
        ),
        "frequency_weight": 1,
    },
    {
        "name": "move_objects_up",
        "input_generator": partial(object_generator, min_distance=1),
        "program": lambda I, key: move_objects_in_grid_wrap(
            I[0], I[1], dx=0, dy=-random.randint(key, (), 1, 5).item()
        ),
        "frequency_weight": 1,
    },
    # Size-rank filters (these read extra["size_ranking"]).
    {
        "name": "keep_largest_object",
        "input_generator": partial(object_generator, min_distance=1),
        "program": lambda I, key: keep_largest_object(I[0], I[1]),
        "frequency_weight": 1,
    },
    {
        "name": "keep_two_largest_objects",
        "input_generator": partial(object_generator, min_distance=1),
        "program": lambda I, key: keep_two_largest_objects(I[0], I[1]),
        "frequency_weight": 1,
    },
    {
        "name": "keep_smallest_object",
        "input_generator": partial(object_generator, min_distance=1),
        "program": lambda I, key: keep_smallest_object(I[0], I[1]),
        "frequency_weight": 1,
    },
    {
        "name": "keep_two_smallest_objects",
        "input_generator": partial(object_generator, min_distance=1),
        "program": lambda I, key: keep_two_smallest_objects(I[0], I[1]),
        "frequency_weight": 1,
    },
    {
        "name": "reorder_objects_by_size",
        "input_generator": partial(object_generator, min_distance=1),
        "program": lambda I, key: reorder_objects_by_size(I[0], I[1]),
        "frequency_weight": 1,
    },
    # NOTE(review): despite the task name, fill_object_rectangles paints every black cell
    # inside each bounding box, not only enclosed holes.
    {
        "name": "fill_holes_in_shapes",
        "input_generator": partial(object_generator, min_distance=1),
        "program": lambda I, key: fill_object_rectangles(I[0], I[1], key),
        "frequency_weight": 1,
    },
    {
        "name": "draw_shape_boundaries",
        "input_generator": partial(object_generator, min_distance=1),
        "program": lambda I, key: draw_shape_outlines(I[0], I[1], key),
        "frequency_weight": 1,
    },
    {
        "name": "simulate_gravity",
        "input_generator": partial(object_generator, min_distance=1),
        "program": lambda I, key: simulate_gravity(I[0], I[1]),
        "frequency_weight": 1,
    },
    # Presumably the larger min_distance leaves room for the one-pixel growth — TODO confirm.
    {
        "name": "expand_shapes",
        "input_generator": partial(object_generator, min_distance=2),
        "program": lambda I, key: expand_shapes(I[0], I[1]),
        "frequency_weight": 1,
    },
    # Both mirror tasks share one program. The axis comes from extra["mirror_axes"],
    # which each generator supplies.
    {
        "name": "mirror_half_objects",
        "input_generator": generate_half_objects_input,
        "program": lambda I, key: mirror_half_objects(I[0], I[1]),
        "frequency_weight": 1,
    },
    {
        "name": "mirror_half_objects_vert",
        "input_generator": generate_half_objects_input_vert,
        "program": lambda I, key: mirror_half_objects(I[0], I[1]),
        "frequency_weight": 1,
    },
    # denoise_objects already has the (inputs, key) signature, so no lambda is needed.
    {
        "name": "denoise_objects",
        "input_generator": noisy_object_generator,
        "program": denoise_objects,
        "frequency_weight": 1,
    },
]
