from src.datasets.dataset_creator.dsl_0 import *
from src.datasets.dataset_creator.constants import *
from src.datasets.dataset_creator.arc_types import *
import numpy as np
from typing import Dict
from src.datasets.dataset_creator.input_generation.v1_counting import *


def count_most_common_color(inputs, key):
    """Render the count of the most common color as a single-row bar.

    Args:
        inputs: A ``(grid, extra)`` pair. Only ``grid`` is used.
        key: Unused.

    Returns:
        np.ndarray: A 1xN array filled with the most common color of the
        grid, where N is the number of cells holding that color, capped at 30.
    """
    grid, _extra = inputs

    # The grid is rebuilt as a tuple of tuples before being handed to the DSL
    # helpers (presumably they expect that representation - confirm in dsl_0).
    frozen_grid = tuple(tuple(cells) for cells in grid)

    dominant = mostcolor(frozen_grid)
    bar_width = min(colorcount(frozen_grid, dominant), 30)

    return np.array(canvas(dominant, (1, bar_width)))


def count_adjacent_same_color(inputs, key) -> np.ndarray:
    """
    Mark every cell that has at least one 4-connected neighbour of equal value.

    Args:
        inputs: A ``(grid, extra)`` pair. ``grid`` is a 2-D np.ndarray;
            ``extra`` is not used.
        key: Unused.

    Returns:
        np.ndarray: Integer grid of the same shape as ``grid`` holding 1 where
                    the cell equals its upper, lower, left or right neighbour,
                    and 0 where it matches none of them.
    """
    grid, extra = inputs

    rows, cols = grid.shape
    has_match = np.zeros((rows, cols), dtype=bool)

    # Compare each row with the one below it; a match flags both cells.
    vertical = grid[1:, :] == grid[:-1, :]
    has_match[1:, :] |= vertical
    has_match[:-1, :] |= vertical

    # Compare each column with the one to its right; a match flags both cells.
    horizontal = grid[:, 1:] == grid[:, :-1]
    has_match[:, 1:] |= horizontal
    has_match[:, :-1] |= horizontal

    return has_match.astype(int)


def create_color_count_bar_chart(inputs, key) -> np.ndarray:
    """
    Build a horizontal bar chart of how often each non-zero color occurs.

    Args:
        inputs: A ``(grid, extra)`` pair. ``grid`` is an np.ndarray of colors;
            ``extra`` is not used.
        key: Unused.

    Returns:
        np.ndarray: One row per distinct non-zero color, in ascending color
                    order. Each row starts with ``count`` cells of that color
                    and is zero-padded to the width of the largest count.
                    A 1x1 zero grid is returned when the input is empty or
                    contains only zeros.
    """
    grid, extra = inputs

    placeholder = np.zeros((1, 1), dtype=int)
    if grid.size == 0:
        return placeholder

    colors, tallies = np.unique(grid, return_counts=True)

    # Color 0 is excluded from the chart.
    keep = colors != 0
    colors, tallies = colors[keep], tallies[keep]
    if colors.size == 0:
        return placeholder

    chart = np.zeros((colors.size, tallies.max()), dtype=int)
    # Iterating a 2-D array yields row views, so writes land in `chart`.
    for bar, color, tally in zip(chart, colors, tallies):
        bar[:tally] = color

    return chart


def count_repeating_pattern(inputs, key) -> np.ndarray:
    """
    Present a pattern count from the extras as a 1xN bar of color 1.

    This module previously defined this function twice under the same name:
    once reading ``extras["num_placements"]`` and once reading
    ``extras["num_repetitions"]``. The later definition shadowed the earlier
    one, so only ``"num_repetitions"`` was ever honoured. This single
    definition accepts either key.

    Args:
        inputs: A ``(grid, extras)`` pair. ``grid`` is not used; ``extras``
            must contain ``"num_repetitions"`` or ``"num_placements"``.
        key: Unused.

    Returns:
        np.ndarray: A 1xN grid filled with 1, where N is the count taken from
                    the extras.

    Raises:
        KeyError: If neither count key is present in ``extras``.
    """
    grid, extras = inputs

    # "num_repetitions" is looked up first: it is the key the previously
    # effective (last) definition used, so existing callers see no change.
    for count_key in ("num_repetitions", "num_placements"):
        if count_key in extras:
            count = extras[count_key]
            break
    else:
        raise KeyError(
            "extras must contain 'num_repetitions' or 'num_placements'"
        )

    # One cell of color 1 per counted occurrence.
    return np.ones((1, count), dtype=int)


def find_unique_pattern(inputs, key) -> np.ndarray:
    """
    Extract the square window that holds the unique pattern.

    Args:
        inputs: A ``(grid, extras)`` pair. ``extras["pattern_size"]`` is the
            side length of the window and ``extras["unique_position"]`` is its
            ``(row, col)`` origin in ``grid``, or None.
        key: Unused.

    Returns:
        np.ndarray: The ``pattern_size`` x ``pattern_size`` slice of ``grid``
                    starting at the unique position (a slice of the input, not
                    a copy), or an all-zero grid of that size when the
                    position is None.
    """
    grid, extras = inputs

    side = extras["pattern_size"]
    origin = extras["unique_position"]

    if origin is not None:
        top, left = origin
        return grid[top : top + side, left : left + side]

    # No unique pattern recorded: return a blank window of the same size.
    return np.zeros((side, side), dtype=int)


def bitmap_or_operation(inputs: Tuple[np.ndarray, Dict], key: Any) -> np.ndarray:
    """
    Combine the two bitmaps stored side by side in the grid with logical OR.

    The first bitmap is columns ``[0, N)`` and the second is every column
    after index ``N``, where ``N = extra["bitmap_size"]``. Column ``N`` itself
    is skipped (presumably a separator - confirm against the generator).

    Args:
        inputs (Tuple[np.ndarray, Dict]): The input grid and extra information
        key (Any): Unused parameter for compatibility

    Returns:
        np.ndarray: Integer grid with 1 where either bitmap is non-zero and 0
                    elsewhere
    """
    grid, extra = inputs
    width = extra["bitmap_size"]

    left_set = grid[:, :width] != 0
    right_set = grid[:, width + 1 :] != 0

    return (left_set | right_set).astype(int)


def bitmap_and_operation(inputs: Tuple[np.ndarray, Dict], key: Any) -> np.ndarray:
    """
    Combine the two bitmaps stored side by side in the grid with logical AND.

    The first bitmap is columns ``[0, N)`` and the second is every column
    after index ``N``, where ``N = extra["bitmap_size"]``. Column ``N`` itself
    is skipped (presumably a separator - confirm against the generator).

    Args:
        inputs (Tuple[np.ndarray, Dict]): The input grid and extra information
        key (Any): Unused parameter for compatibility

    Returns:
        np.ndarray: Integer grid with 1 where both bitmaps are non-zero and 0
                    elsewhere
    """
    grid, extra = inputs
    width = extra["bitmap_size"]

    left_set = grid[:, :width] != 0
    right_set = grid[:, width + 1 :] != 0

    return (left_set & right_set).astype(int)


def bitmap_nor_operation(inputs: Tuple[np.ndarray, Dict], key: Any) -> np.ndarray:
    """
    Combine the two bitmaps stored side by side in the grid with logical NOR.

    The first bitmap is columns ``[0, N)`` and the second is every column
    after index ``N``, where ``N = extra["bitmap_size"]``. Column ``N`` itself
    is skipped (presumably a separator - confirm against the generator).

    Args:
        inputs (Tuple[np.ndarray, Dict]): The input grid and extra information
        key (Any): Unused parameter for compatibility

    Returns:
        np.ndarray: Integer grid with 1 where neither bitmap is set and 0
                    elsewhere
    """
    grid, extra = inputs
    width = extra["bitmap_size"]

    left_clear = np.logical_not(grid[:, :width])
    right_clear = np.logical_not(grid[:, width + 1 :])

    # NOR via De Morgan: not (a or b) == (not a) and (not b).
    return np.logical_and(left_clear, right_clear).astype(int)


# Registry of the transformations defined in this module. Each entry pairs an
# input generator with the program applied to that generator's output, plus a
# display name and a frequency weight.
TRANSFORMATIONS = [
    dict(
        name="count_most_common_color",
        input_generator=generate_limited_color_grid,
        program=count_most_common_color,
        frequency_weight=1,
    ),
    dict(
        name="adjacent_same_color_counter",
        input_generator=generate_random_colored_square,
        program=count_adjacent_same_color,
        frequency_weight=1,
    ),
    dict(
        name="color_count_bar_chart",
        input_generator=generate_limited_color_grid,
        program=create_color_count_bar_chart,
        frequency_weight=1,
    ),
    dict(
        name="obvious_repeating_pattern_counter",
        input_generator=generate_grid_with_repeating_pattern,
        program=count_repeating_pattern,
        frequency_weight=1,
    ),
    dict(
        name="unique_pattern_finder",
        input_generator=generate_grid_with_unique_pattern,
        program=find_unique_pattern,
        frequency_weight=1,
    ),
    dict(
        name="bitmap_or_operation",
        input_generator=generate_two_binary_bitmaps,
        program=bitmap_or_operation,
        frequency_weight=1,
    ),
    dict(
        name="bitmap_and_operation",
        input_generator=generate_two_binary_bitmaps,
        program=bitmap_and_operation,
        frequency_weight=1,
    ),
    dict(
        name="bitmap_nor_operation",
        input_generator=generate_two_binary_bitmaps,
        program=bitmap_nor_operation,
        frequency_weight=1,
    ),
]
