import io
from typing import Callable

import matplotlib.pyplot as plt
import numpy as np
from matplotlib import colors
from PIL import Image


def create_color_plot(ax, data_array: np.ndarray, color_map: dict) -> None:
    """Render ``data_array`` on ``ax`` as a grid of color-coded cells.

    Args:
        ax: Matplotlib axes to draw on (``ax.imshow`` is called on it).
        data_array: Grid of integer cell values.
        color_map: Mapping from integer cell value to a color. Any value
            from 0 up to the largest key that has no entry is drawn white
            (``#FFFFFF``).

    Raises:
        ValueError: If ``color_map`` is empty.
    """
    if not color_map:
        raise ValueError("color_map must contain at least one entry")

    highest = max(color_map)
    palette = [color_map.get(value, "#FFFFFF") for value in range(highest + 1)]
    cmap = colors.ListedColormap(palette)

    # Pin the normalization range. Without vmin/vmax, imshow rescales the
    # colormap to the min/max actually present in the data, so a grid holding
    # only some of the values would be painted with the wrong colors. With
    # these half-integer bounds, integer value v lands exactly in palette[v].
    ax.imshow(
        data_array,
        cmap=cmap,
        interpolation="nearest",
        vmin=-0.5,
        vmax=highest + 0.5,
    )


def create_number_plot(
    ax, data_array: np.ndarray, font_size: int, font_color: str
) -> None:
    """Write every cell's value as centered text on a white background.

    The axes are framed so that each cell is a unit square centered on its
    integer ``(column, row)`` coordinate, with row 0 at the top. This is the
    same orientation ``imshow`` uses by default, so the number view matches
    the color-coded view of the same grid.

    Args:
        ax: Matplotlib axes to draw on.
        data_array: Two-dimensional grid of cell values.
        font_size: Font size used for the cell text.
        font_color: Color used for the cell text.
    """
    n_rows, n_cols = data_array.shape[0], data_array.shape[1]

    ax.set_facecolor("white")
    ax.set_xticks(np.arange(n_cols))
    ax.set_yticks(np.arange(n_rows))

    # Without explicit limits the view depends on the default (0, 1) range and
    # the y-axis grows upward, which draws the grid vertically flipped.
    # The reversed y-limits put row 0 at the top.
    ax.set_xlim(-0.5, n_cols - 0.5)
    ax.set_ylim(n_rows - 0.5, -0.5)
    # Square cells, as imshow produces for the color-coded view.
    ax.set_aspect("equal")

    for row, col in np.ndindex(n_rows, n_cols):
        ax.text(
            col,
            row,
            str(data_array[row, col]),
            ha="center",
            va="center",
            color=font_color,
            fontsize=font_size,
        )


def plot_abstract_reasoning(
    data: list[list[int]] | np.ndarray,
    output_path: str | None = None,
    dpi: int = 300,
    grid_color: str = "white",
    grid_width: float = 2,
    color_map: dict | None = None,
    title: str | None = None,
    fig_size: tuple[int, int] = (10, 10),
    pil_size: tuple[int, int] = (512, 512),
    return_pil: bool = True,
    font_size: int = 20,
    font_color: str = "black",
    display_mode: str = "color",
) -> Image.Image | None:
    """Plot an abstract reasoning grid as color-coded cells or as numbers.

    Exactly one output action is taken, in this order of precedence:
    return a PIL image (``return_pil``), save to ``output_path``, or show the
    figure interactively.

    Args:
        data: The grid to plot, as a list of rows or a numpy array.
        output_path: Path to save the image to. Only used when ``return_pil``
            is False; if it is also None, the plot is displayed instead.
        dpi: Resolution of the rendered image (dots per inch).
        grid_color: Color of the grid lines drawn between cells.
        grid_width: Width of the grid lines.
        color_map: Mapping from cell value to color, used when
            ``display_mode`` is 'color'. If None, a built-in ten-color
            palette for the values 0-9 is used.
        title: Optional title for the plot.
        fig_size: Figure size in inches (width, height).
        pil_size: Size the returned PIL image is resized to (width, height).
        return_pil: If True, return a PIL image instead of saving or
            displaying the plot.
        font_size: Font size for the numbers when ``display_mode`` is
            'number'.
        font_color: Font color for the numbers when ``display_mode`` is
            'number'.
        display_mode: 'color' for color-coded cells, 'number' for numbers
            written in the cells.

    Returns:
        The rendered PIL image if ``return_pil`` is True, otherwise None.

    Raises:
        ValueError: If ``display_mode`` is neither 'color' nor 'number'.
    """
    # Validate before creating a figure so a bad argument cannot leak one.
    if display_mode not in ("color", "number"):
        raise ValueError("display_mode must be either 'color' or 'number'")

    data_array = np.asarray(data)

    fig, ax = plt.subplots(figsize=fig_size)
    # The figure is closed on every path except the interactive one, where
    # pyplot must keep it alive for the window to stay open.
    keep_open = False
    try:
        if display_mode == "color":
            if color_map is None:
                color_map = {
                    0: "#000000",  # black
                    1: "#0074D9",  # blue
                    2: "#FF4136",  # red
                    3: "#2ECC40",  # green
                    4: "#FFDC00",  # yellow
                    5: "#AAAAAA",  # grey
                    6: "#F012BE",  # pink
                    7: "#FF851B",  # orange
                    8: "#800080",  # purple
                    9: "#A52A2A",  # brown
                }
            create_color_plot(ax, data_array, color_map)
        else:
            create_number_plot(ax, data_array, font_size, font_color)

        # Grid lines sit on the cell borders, i.e. at the half-integer minor
        # tick positions.
        ax.set_xticks(np.arange(-0.5, data_array.shape[1], 1), minor=True)
        ax.set_yticks(np.arange(-0.5, data_array.shape[0], 1), minor=True)
        ax.grid(
            which="minor", color=grid_color, linestyle="-", linewidth=grid_width
        )

        # Remove the major ticks, and hide the minor tick marks while keeping
        # the minor grid lines they anchor.
        ax.set_xticks([])
        ax.set_yticks([])
        ax.tick_params(which="minor", bottom=False, left=False)

        if title:
            ax.set_title(title)

        # Tight layout to maximize the plot size.
        fig.tight_layout()

        if return_pil:
            buf = io.BytesIO()
            fig.savefig(buf, format="png", dpi=dpi, bbox_inches="tight")
            buf.seek(0)
            # resize() returns a new, fully loaded image, so the source image
            # can be released as soon as the block exits.
            with Image.open(buf) as rendered:
                return rendered.resize(pil_size)

        if output_path:
            fig.savefig(output_path, dpi=dpi, bbox_inches="tight")
            return None

        keep_open = True
        plt.show()
        return None
    finally:
        if not keep_open:
            plt.close(fig)
