from functools import partial

import jax
import numpy as np

import jax.numpy as jnp
from src.datasets.dataset_creator.input_generation.single_point_markers import create_sparse_point_grid
from src.datasets.dataset_creator.dsl_0 import *
from src.datasets.dataset_creator.constants import *
from src.datasets.dataset_creator.arc_types import *
from src.datasets.dataset_creator.input_generation.v1_objectness import (
    object_generator,
    object_generator_with_holes,
    diverse_shapes_generator,
)
from src.datasets.dataset_creator.input_generation.v1_generic import (
    v0_arc_general_input,
    binary_mask_pattern_generator,
)
from src.datasets.dataset_creator.input_generation.v1_patterns import (
    missing_pattern_input_generator,
    generate_partial_pattern,
    generate_beam_extension_input,
    generate_complex_beam_input,
)


import numpy as np
from typing import Tuple, Dict
from jax import random
import jax
import jax.numpy as jnp

import jax
import jax.numpy as jnp


def generate_unique_colors(key, num_colors):
    """
    Draw ``num_colors`` distinct colour ids from 1..9 (black/0 is never chosen).

    The ids are the first ``num_colors`` entries of a random permutation of
    1..9 and are returned as a plain Python list of ints.
    """
    palette = jnp.arange(1, 10)
    shuffled = jax.random.permutation(key, palette)
    chosen = shuffled[:num_colors]
    return chosen.tolist()


def create_fractal_pattern(key, n):
    """
    Build an n x n lattice: (even, even) cells get the first colour,
    (odd, odd) cells the second, and all remaining cells stay 0.
    """
    first, second = generate_unique_colors(key, 2)
    grid = jnp.zeros((n, n), dtype=jnp.int32)
    grid = grid.at[::2, ::2].set(first)
    return grid.at[1::2, 1::2].set(second)


def create_spiral_pattern(key, n):
    """
    Build an n x n two-colour spiral.

    The walk starts at the top-left cell and first moves down the left column.
    Whenever the next cell is off the grid or already visited, it turns to the
    next direction in the cycle down -> right -> up -> left. Consecutive cells
    along the walk alternate between the two colours.

    :param key: JAX PRNG key used to pick the two colours.
    :param n: Side length (a Python int or a concrete integer scalar).
    :return: ``jnp.int32`` array of shape ``(n, n)``.
    """
    colors = generate_unique_colors(key, 2)
    size = int(n)

    # Build on the host: n*n individual device scatters/reads are slow.
    grid = np.full((size, size), colors[0], dtype=np.int32)
    # Track visited cells explicitly. Inferring "unvisited" from "still holds
    # colors[0]" is wrong, because every even step also paints colors[0]; the
    # walk would then re-enter painted cells and skip the interior.
    visited = np.zeros((size, size), dtype=bool)

    directions = [(0, 1), (1, 0), (0, -1), (-1, 0)]  # (dx, dy): down, right, up, left
    heading = 0
    x = y = 0
    for step in range(size * size):
        grid[y, x] = colors[step % 2]
        visited[y, x] = True
        dx, dy = directions[heading]
        next_x, next_y = x + dx, y + dy
        if not (0 <= next_x < size and 0 <= next_y < size) or visited[next_y, next_x]:
            heading = (heading + 1) % 4
            dx, dy = directions[heading]
            next_x, next_y = x + dx, y + dy
        # After the final step this may point off-grid, but it is never used.
        x, y = next_x, next_y

    return jnp.asarray(grid)


def create_concentric_pattern(key, n):
    """
    Build n x n concentric square rings alternating between two colours.

    Ring ``r`` covers rows/cols ``r .. n-r-1``. Inner rings are painted over
    outer ones, so the outermost ring has the first colour.
    """
    palette = generate_unique_colors(key, 2)
    grid = jnp.zeros((n, n), dtype=jnp.int32)
    for ring in range((n + 1) // 2):
        grid = grid.at[ring : n - ring, ring : n - ring].set(palette[ring % 2])
    return grid


def create_diagonal_pattern(key, n):
    """
    Split an n x n grid along its main diagonal.

    Cells on or below the diagonal (col <= row) get the first colour; cells
    above it get the second.
    """
    lower, upper = generate_unique_colors(key, 2)
    rows = jnp.arange(n)[:, None]
    cols = jnp.arange(n)[None, :]
    return jnp.where(cols <= rows, lower, upper).astype(jnp.int32)


def create_cross_pattern(key, n):
    """
    Draw a plus sign in the second colour along the middle row and column
    (index ``n // 2``) of an n x n grid filled with the first colour.
    """
    background, stroke = generate_unique_colors(key, 2)
    centre = n // 2
    grid = jnp.full((n, n), background, dtype=jnp.int32)
    return grid.at[centre, :].set(stroke).at[:, centre].set(stroke)


def create_checkerboard_pattern(key, n):
    """
    Build an n x n checkerboard of 1x1 squares.

    Cells with even ``row + col`` get the first colour; the others get the second.
    """
    even, odd = generate_unique_colors(key, 2)
    idx = jnp.arange(n)
    parity = (idx[:, None] + idx[None, :]) % 2
    return jnp.where(parity == 0, even, odd).astype(jnp.int32)


def create_random_blob_pattern(key, n):
    """
    Scatter a random set of second-colour cells over an n x n grid of the first colour.

    Between ``n`` and ``2n - 1`` distinct cells are recoloured.

    :param key: JAX PRNG key.
    :param n: Side length.
    :return: ``jnp.int32`` array of shape ``(n, n)``.
    """
    # Three independent subkeys. Previously the cell positions were drawn
    # from the parent ``key`` after it had already been split, which breaks
    # JAX's "never reuse a key" rule and correlates the draws.
    blob_key, count_key, cell_key = jax.random.split(key, 3)
    colors = generate_unique_colors(blob_key, 2)
    pattern = jnp.full((n, n), colors[0], dtype=jnp.int32)
    # maxval is exclusive: count is in [n, 2n), which never exceeds n * n.
    count = jax.random.randint(count_key, (), n, n * 2).item()
    indices = jax.random.choice(cell_key, n * n, (count,), replace=False)
    pattern = pattern.at[indices // n, indices % n].set(colors[1])
    return pattern


def create_arrow_pattern(key, n):
    """
    Draw an arrow through the centre of an n x n grid.

    The shaft runs along the middle row or column. The head is a short
    perpendicular bar (up to three cells) on the border the arrow points at.
    The direction is random: 0 = up, 1 = right, 2 = down, 3 = left.

    :param key: JAX PRNG key.
    :param n: Side length.
    :return: ``jnp.int32`` array of shape ``(n, n)``.
    """
    # Use separate subkeys. Drawing both the colours and the direction from
    # the same key would make the direction a function of the same
    # randomness as the colour permutation.
    color_key, direction_key = jax.random.split(key)
    colors = generate_unique_colors(color_key, 2)
    pattern = jnp.full((n, n), colors[0], dtype=jnp.int32)
    direction = jax.random.choice(direction_key, 4).item()
    mid = n // 2
    if direction == 0:  # up
        pattern = pattern.at[1:, mid].set(colors[1])
        pattern = pattern.at[0, mid - 1 : mid + 2].set(colors[1])
    elif direction == 1:  # right
        pattern = pattern.at[mid, : n - 1].set(colors[1])
        pattern = pattern.at[mid - 1 : mid + 2, n - 1].set(colors[1])
    elif direction == 2:  # down
        pattern = pattern.at[: n - 1, mid].set(colors[1])
        pattern = pattern.at[n - 1, mid - 1 : mid + 2].set(colors[1])
    else:  # left
        pattern = pattern.at[mid, 1:].set(colors[1])
        pattern = pattern.at[mid - 1 : mid + 2, 0].set(colors[1])
    return pattern


def generate_interesting_pattern(key, n):
    """
    Build a random n x n pattern.

    The key is split in two. The first half picks one of the ``create_*_pattern``
    builders below uniformly at random; the second half is passed to that builder.
    """
    choice_key, build_key = jax.random.split(key)
    builders = (
        create_fractal_pattern,
        create_spiral_pattern,
        create_concentric_pattern,
        create_diagonal_pattern,
        create_cross_pattern,
        create_checkerboard_pattern,
        create_random_blob_pattern,
        create_arrow_pattern,
    )
    picked = jax.random.choice(choice_key, len(builders)).item()
    builder = builders[picked]
    return builder(build_key, n)


def draw_interesting_patterns(inputs, key, pattern_size):
    """
    Stamp one random pattern at every non-zero cell of the input grid.

    A single ``pattern_size`` x ``pattern_size`` pattern is generated. Each
    non-zero cell becomes the top-left corner of a copy of it, clipped at the
    grid border. All other cells of the output are 0. Where stamps overlap,
    the anchor that comes later in row-major order wins.

    :param inputs: ``(input_grid, extra)`` pair; only ``input_grid`` is used.
    :param key: JAX PRNG key passed to ``generate_interesting_pattern``.
    :param pattern_size: Side length n of the n x n pattern.
    :return: NumPy integer array with the same shape as ``input_grid``.
    """
    input_grid, _ = inputs

    grid = np.asarray(input_grid)
    height, width = grid.shape
    output = np.zeros((height, width), dtype=np.int32)

    # Pull the pattern to the host once. Indexing the JAX array per cell
    # costs a device round-trip for every one of the n*n cells of each stamp.
    pattern = np.asarray(generate_interesting_pattern(key, pattern_size))

    # np.argwhere yields anchors in row-major order, so overlapping stamps
    # are overwritten in the same order as a row-by-row scan.
    for row, col in np.argwhere(grid != 0):
        rows = min(pattern_size, height - row)
        cols = min(pattern_size, width - col)
        output[row : row + rows, col : col + cols] = pattern[:rows, :cols]

    # Return the platform default integer dtype.
    return output.astype(int)


def color_swap(I):
    """
    Swap the least common and the most common colour of the grid.

    The grid is converted to nested tuples for the DSL helpers. The result of
    ``switch`` is returned as a NumPy array.
    """
    grid = tuple(map(tuple, I))
    rare = leastcolor(grid)
    common = mostcolor(grid)
    return np.array(switch(grid, rare, common))


def isolate_color(I):
    """
    Isolate the least common colour of the grid.

    Cells holding the least common colour (as chosen by ``leastcolor``) keep
    that colour; every other cell becomes 0. Blanking the least common colour
    instead would do the opposite of what the name and this contract promise.

    :param I: Grid as a nested sequence or 2-D array.
    :return: NumPy array with the same shape as ``I``.
    """
    grid = tuple(tuple(row) for row in I)
    target = leastcolor(grid)
    cells = np.array(grid)
    return np.where(cells == target, cells, 0)


def diagonal_fill(I):
    """
    Paint the main diagonal, starting at the top-left corner, with the grid's
    least common colour.

    The diagonal comes from ``shoot(ORIGIN, UNITY)`` and is applied with
    ``fill``; the result is returned as a NumPy array.
    """
    grid = tuple(map(tuple, I))
    paint = leastcolor(grid)
    diagonal = shoot(ORIGIN, UNITY)
    return np.array(fill(grid, paint, diagonal))


def rotate_90(I):
    """Return the grid after the DSL's ``rot90`` (90 degrees clockwise) as a JAX array."""
    return jnp.array(rot90(tuple(map(tuple, I))))


def rotate_180(I):
    """Return the grid after the DSL's ``rot180`` (a half turn) as a JAX array."""
    return jnp.array(rot180(tuple(map(tuple, I))))


def rotate_270(I):
    """Return the grid after the DSL's ``rot270`` (270 degrees clockwise, i.e. 90 counter-clockwise) as a JAX array."""
    return jnp.array(rot270(tuple(map(tuple, I))))


def tuple_to_numpy(tuple_grid):
    """Convert a nested-tuple grid into a NumPy array."""
    array = np.array(tuple_grid)
    return array


def numpy_to_tuple(np_grid):
    """Convert a 2-D array (or any iterable of rows) into a tuple of row tuples."""
    return tuple(map(tuple, np_grid))


def checkerboard_overlay(I):
    """
    Overwrite every cell with even ``row + col`` with the grid's least common
    colour. Odd-parity cells are left unchanged.
    """
    grid = numpy_to_tuple(I)
    paint = leastcolor(grid)
    painted_rows = []
    for r, line in enumerate(grid):
        painted_rows.append(
            tuple(paint if (r + c) % 2 == 0 else value for c, value in enumerate(line))
        )
    return tuple_to_numpy(tuple(painted_rows))


import jax
import jax.numpy as jnp


def alternate_quadrants(inputs, key):
    """
    Rearrange the four quadrants of a grid with a random non-identity permutation.

    If a dimension is odd, its last row/column is dropped first so that the
    grid splits into four equal quadrants. The output therefore always has
    even dimensions.

    :param inputs: ``(grid, extra)`` pair; only the grid is used.
    :param key: JAX PRNG key that selects one row of the permutation table.
    :return: JAX array holding the rearranged (possibly trimmed) grid.
    """
    grid, _ = inputs
    grid = jnp.array(grid)

    height, width = grid.shape
    quad_h, quad_w = height // 2, width // 2
    grid = grid[: quad_h * 2, : quad_w * 2]

    # Quadrant ids: 0 = top-left, 1 = top-right, 2 = bottom-left, 3 = bottom-right.
    quadrants = [
        grid[:quad_h, :quad_w],
        grid[:quad_h, quad_w:],
        grid[quad_h:, :quad_w],
        grid[quad_h:, quad_w:],
    ]

    # 12 distinct non-identity permutations (of the 23 possible). Entry p of a
    # row is the source quadrant placed in slot p.
    all_permutations = jnp.array(
        [
            [1, 0, 3, 2],  # double swap (0 1)(2 3)
            [2, 3, 0, 1],  # double swap (0 2)(1 3)
            [3, 2, 1, 0],  # double swap (0 3)(1 2)
            [1, 3, 0, 2],  # 4-cycle
            [2, 0, 3, 1],  # 4-cycle
            [3, 1, 2, 0],  # single swap (0 3)
            [1, 2, 3, 0],  # 4-cycle
            [2, 3, 1, 0],  # 4-cycle
            [3, 0, 2, 1],  # 3-cycle, slot 2 fixed
            [1, 3, 2, 0],  # 3-cycle, slot 2 fixed
            [2, 1, 3, 0],  # 3-cycle, slot 1 fixed
            [3, 2, 0, 1],  # 4-cycle
        ]
    )

    perm_idx = jax.random.randint(key, (), 0, len(all_permutations))
    # Convert to Python ints once rather than indexing a list with 0-d device arrays.
    new_order = [int(i) for i in all_permutations[perm_idx]]
    top_left, top_right, bottom_left, bottom_right = (quadrants[i] for i in new_order)

    top = jnp.hstack((top_left, top_right))
    bottom = jnp.hstack((bottom_left, bottom_right))
    return jnp.vstack((top, bottom))


def fill_shape_holes(inputs, key: random.PRNGKey) -> np.ndarray:
    """
    Fill every recorded hole of every object with one shared random colour.

    :param inputs: ``(input_grid, extra)``. ``extra["objects"]`` is indexed by
        object id, and each entry unpacks to ``(top, left, height, width, color)``.
        ``extra["object_holes"]`` maps an object id to a list of
        ``(top, left, height, width)`` rectangles, given relative to that
        object's top-left corner.
    :param key: JAX PRNG key used to draw the fill colour.
    :return: A NumPy copy of ``input_grid`` with every hole rectangle painted.
    """
    input_grid, extra = inputs

    # np.array always copies and always gives a writable NumPy array, so this
    # also works when the grid arrives as an (immutable) JAX array.
    output_grid = np.array(input_grid)
    objects = extra["objects"]
    object_holes = extra["object_holes"]

    # One colour from 1..9 (never black) is shared by all holes. It is NOT
    # filtered against the objects' own colours, so a hole may be filled with
    # its object's colour.
    fill_color = int(random.choice(key, np.arange(1, 10)))

    for obj_index, holes in object_holes.items():
        top, left, _height, _width, _color = objects[obj_index]
        for hole_top, hole_left, hole_height, hole_width in holes:
            row = top + hole_top
            col = left + hole_left
            output_grid[row : row + hole_height, col : col + hole_width] = fill_color

    return output_grid


def tile_interesting_patterns(
    inputs, key, *, min_pattern_size: int = 3, max_pattern_size: int = 3
) -> np.ndarray:
    """
    Replace each mask cell of ``mask_color`` with a copy of one random pattern.

    The output is a square of side ``mask_size * pattern_size``. Blocks for
    mask cells of any other colour are left as 0.

    :param inputs: ``(mask, extra)``. ``extra["mask_size"]`` is the number of
        mask rows/columns that are read; ``extra["mask_color"]`` is the colour
        that selects a block.
    :param key: JAX PRNG key for the pattern (and for its size when a size
        range is requested).
    :param min_pattern_size: Smallest pattern side length (inclusive).
    :param max_pattern_size: Largest pattern side length (inclusive). The
        default 3..3 matches the previous behaviour: ``randint(key, (), 3, 4)``
        can only return 3 because ``maxval`` is exclusive.
    :raises ValueError: If ``min_pattern_size > max_pattern_size``.
    :raises AssertionError: If the generated pattern has only one colour.
    """
    mask, extra = inputs
    mask_size = extra["mask_size"]
    mask_color = extra["mask_color"]

    if min_pattern_size > max_pattern_size:
        raise ValueError(
            f"min_pattern_size ({min_pattern_size}) > max_pattern_size ({max_pattern_size})"
        )
    if min_pattern_size == max_pattern_size:
        # Fixed size: no size draw is needed, so the whole key goes to the
        # pattern generator (same pattern for a given key as before).
        requested_size, pattern_key = min_pattern_size, key
    else:
        # Independent subkeys so the size and the pattern are not correlated.
        size_key, pattern_key = jax.random.split(key)
        requested_size = int(
            jax.random.randint(size_key, (), min_pattern_size, max_pattern_size + 1)
        )

    pattern = np.asarray(generate_interesting_pattern(pattern_key, requested_size))
    pattern_size = pattern.shape[0]

    # A single-colour pattern would just recolour the mask. Testing
    # ``(pattern == pattern[0]).all()`` only checks "all rows identical" and
    # rejects valid two-colour patterns such as vertical stripes. The check is
    # an explicit raise so it survives ``python -O``; it stays AssertionError
    # for callers that relied on the earlier ``assert``.
    if np.unique(pattern).size < 2:
        raise AssertionError("generated pattern must contain at least two colours")

    output_size = mask_size * pattern_size
    output = np.zeros((output_size, output_size), dtype=int)

    for i in range(mask_size):
        for j in range(mask_size):
            if mask[i, j] == mask_color:
                output[
                    i * pattern_size : (i + 1) * pattern_size,
                    j * pattern_size : (j + 1) * pattern_size,
                ] = pattern

    return output


def tile_mask(inputs, key) -> np.ndarray:
    """
    Tile the mask with copies of itself (``key`` is unused).

    Every mask cell equal to ``extra["mask_color"]`` becomes a full copy of
    the mask in the output. Every other cell becomes an all-zero block. The
    output side is ``mask_size * mask.shape[0]``.
    """
    grid, extra = inputs
    size = extra["mask_size"]
    fg = extra["mask_color"]

    tile = grid
    step = tile.shape[0]
    result = np.zeros((size * step, size * step), dtype=int)

    for row in range(size):
        for col in range(size):
            if grid[row, col] != fg:
                continue
            r0, c0 = row * step, col * step
            result[r0 : r0 + step, c0 : c0 + step] = tile

    return result


def complete_pattern(inputs, key) -> np.ndarray:
    """
    Restore hidden cells from the reference pattern (``key`` is unused).

    Where ``extra["mask"]`` is truthy the input cell is kept. Everywhere else
    the value from ``extra["original_pattern"]`` is used.
    """
    observed, extra = inputs
    reference = extra["original_pattern"]
    keep = extra["mask"]
    return np.where(keep, observed, reference)


def complete_pattern2(inputs, key) -> np.ndarray:
    """
    Rebuild the full square pattern described by ``extra`` (``key`` is unused).

    Only the side length is taken from the partial input. Every cell value is
    regenerated from the pattern parameters. Recognised
    ``extra["pattern_type"]`` values and the keys each one reads:

    - ``"stripe"``: ``colors``, ``stripe_width`` (horizontal bands).
    - ``"checkerboard"``: ``colors``, ``square_size``.
    - ``"wave"``: ``colors``, ``wave_length``, ``amplitude`` (a per-row sine boundary).
    - ``"zigzag"``: ``colors``, ``zigzag_width``.
    - ``"diamond"``: ``colors``, ``diamond_size``.

    Any other pattern type yields an all-zero grid.
    """
    partial_pattern, extra = inputs
    size = partial_pattern.shape[0]
    result = np.zeros((size, size), dtype=int)
    kind = extra["pattern_type"]

    if kind == "stripe":
        palette = extra["colors"]
        band = extra["stripe_width"]
        for r in range(size):
            result[r] = palette[(r // band) % 2]

    elif kind == "wave":
        palette = extra["colors"]
        period = extra["wave_length"]
        amp = extra["amplitude"]
        for r in range(size):
            boundary = size // 2 + int(amp * np.sin(2 * np.pi * r / period))
            # Columns 0..boundary (inclusive) take palette[0]. Clamp the cut
            # so a boundary outside the grid never becomes a negative slice
            # index, and only touch a colour when some cell actually uses it.
            cut = min(max(boundary + 1, 0), size)
            if cut > 0:
                result[r, :cut] = palette[0]
            if cut < size:
                result[r, cut:] = palette[1]

    elif kind in ("checkerboard", "zigzag", "diamond"):
        palette = extra["colors"]
        if kind == "checkerboard":
            side = extra["square_size"]

            def pick(r, c):
                return ((r // side) + (c // side)) % 2

        elif kind == "zigzag":
            width = extra["zigzag_width"]

            def pick(r, c):
                return c % 2 if (r // width) % 2 == 0 else (c + 1) % 2

        else:
            side = extra["diamond_size"]

            def pick(r, c):
                return 0 if (r % side) + (c % side) < side else 1

        for r in range(size):
            for c in range(size):
                result[r, c] = palette[pick(r, c)]

    return result


def extend_beams(inputs, key) -> np.ndarray:
    """
    Extend every non-zero cell into a full-row and full-column beam of its colour.

    Sources are applied column by column from right to left, and bottom to
    top within a column. Later writes win, so where beams cross, sources
    further left (and higher within a column) take precedence.
    ``key`` is unused.
    """
    grid, _extra = inputs
    size = grid.shape[0]
    result = np.zeros((size, size), dtype=int)

    sources = [
        (row, col)
        for col in reversed(range(size))
        for row in reversed(range(size))
        if grid[row, col] != 0
    ]
    for row, col in sources:
        shade = grid[row, col]
        result[row, :] = shade
        result[:, col] = shade

    return result


def extend_beams2(inputs, key) -> np.ndarray:
    """
    Paint a full row and a full column of ``extra["beam_color"]`` through every
    input cell of that colour. Cells of other colours are ignored, and
    everything else in the output is 0. ``key`` is unused.
    """
    grid, extra = inputs
    size = grid.shape[0]
    result = np.zeros((size, size), dtype=int)
    beam = extra["beam_color"]

    hits = np.asarray(grid)[:size, :size] == beam
    # All beams share one colour, so the order in which they are drawn is irrelevant.
    result[hits.any(axis=1), :] = beam
    result[:, hits.any(axis=0)] = beam

    return result


# Registry of (input generator, program) task definitions. Programs are called
# as ``program(inputs, key)``; the lambdas below use only ``inputs[0]``.
TRANSFORMATIONS = [
    # Pattern-stamping tasks, one per stamp size. The seven sizes share the
    # weight equally (1/7 each), so together they count like a single task.
    *(
        dict(
            name=f"pattern_{size}",
            input_generator=partial(create_sparse_point_grid, min_distance_points=size, min_distance_right=size),
            program=partial(draw_interesting_patterns, pattern_size=size),
            frequency_weight=1 / 7,
        )
        for size in (3, 4, 5, 6, 10, 11, 30)
    ),
    # Whole-grid DSL transformations on general ARC-like inputs.
    dict(
        name="color_swap",
        input_generator=v0_arc_general_input,
        program=lambda x, key: color_swap(x[0]),
        frequency_weight=1,
    ),
    dict(
        name="isolate_color",
        input_generator=v0_arc_general_input,
        program=lambda x, key: isolate_color(x[0]),
        frequency_weight=1,
    ),
    dict(
        name="diagonal_fill",
        input_generator=v0_arc_general_input,
        program=lambda x, key: diagonal_fill(x[0]),
        frequency_weight=1,
    ),
    # The three rotations share one unit of weight.
    dict(
        name="rotate_90",
        input_generator=v0_arc_general_input,
        program=lambda x, key: rotate_90(x[0]),
        frequency_weight=1 / 3,
    ),
    dict(
        name="rotate_180",
        input_generator=v0_arc_general_input,
        program=lambda x, key: rotate_180(x[0]),
        frequency_weight=1 / 3,
    ),
    dict(
        name="rotate_270",
        input_generator=v0_arc_general_input,
        program=lambda x, key: rotate_270(x[0]),
        frequency_weight=1 / 3,
    ),
    dict(
        name="checkerboard_overlay",
        input_generator=v0_arc_general_input,
        program=lambda x, key: checkerboard_overlay(x[0]),
        frequency_weight=1,
    ),
    dict(
        name="alternate_quadrants",
        input_generator=v0_arc_general_input,
        program=alternate_quadrants,
        frequency_weight=1,
    ),
    # Object-hole filling.
    dict(
        name="object_holes",
        input_generator=partial(object_generator_with_holes, min_distance=2),
        program=fill_shape_holes,
        frequency_weight=1,
    ),
    dict(
        name="object_holes_interesting",
        input_generator=diverse_shapes_generator,
        program=fill_shape_holes,
        frequency_weight=1,
    ),
    # Mask tiling.
    dict(
        name="tiled_interesting_patterns",
        input_generator=binary_mask_pattern_generator,
        program=tile_interesting_patterns,
        frequency_weight=1,
    ),
    dict(
        name="tiled_mask",
        input_generator=binary_mask_pattern_generator,
        program=tile_mask,
        frequency_weight=1,
    ),
    # Pattern completion and beam extension.
    dict(
        name="pattern_completion",
        input_generator=missing_pattern_input_generator,
        program=complete_pattern,
        frequency_weight=1,
    ),
    dict(
        name="extend_repeating_pattern",
        input_generator=generate_partial_pattern,
        program=complete_pattern2,
        frequency_weight=1,
    ),
    dict(
        name="extend_color_beams",
        input_generator=generate_beam_extension_input,
        program=extend_beams,
        frequency_weight=1,
    ),
    dict(
        name="extend_beams_complex_input",
        input_generator=generate_complex_beam_input,
        program=extend_beams2,
        frequency_weight=1,
    ),
]
