import math
from itertools import chain, combinations
from typing import (
    Collection,
    Dict,
    List,
    Literal,
    Optional,
    Sequence,
    Set,
    Tuple,
    TypeVar,
)

import matplotlib
import matplotlib.axes
import matplotlib.legend
import matplotlib.lines
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.image import AxesImage


class ColormapData:
    """A group of squares that share the same colour(s) and legend label.

    Instances are plain holders: the arguments are stored as given, without
    copying or validation.
    """

    def __init__(
        self,
        square_numbers: Collection[int],
        colors: List[str],
        label: Optional[str] = None,
    ):
        # data   -> square numbers (the plotting code treats them as 1-based)
        # colors -> Matplotlib colour specs; several colours become stripes
        # label  -> legend text, or None
        self.data, self.colors, self.label = square_numbers, colors, label


def colormap(
    data: List[ColormapData],
    title: Optional[str] = None,
    legend_title: Optional[str] = None,
    max_squares: Optional[int] = None,
    number_annonations: bool = True,
    extent_color: str = "white",
    extent_label: Optional[str] = None,
    all_common_color: str = "lime",
    all_common_label: Optional[str] = None,
    any_color: str = "pink",
    any_label: Optional[str] = None,
    excluded_data: Optional[ColormapData] = None,
    outlined_data: Optional[List[ColormapData]] = None,
    legend_anchor: Optional[Tuple[float, float]] = None,
    legend_loc: Optional[str] = None,
    legend_fontsize: int = 10,
    fontsize: int = 11,
    figsize: Tuple[float, float] = (8, 5.2),
    width_ratios: Optional[Sequence[float]] = (2.5, 1),
    orientation: Literal["horizontal", "vertical"] = "horizontal",
):
    """Draw a numbered grid of squares coloured by data-set membership.

    The figure has three panels: the square grid, a legend panel and a
    histogram of category sizes. In ``"horizontal"`` orientation the grid
    spans the left column, with the legend top-right and the histogram
    bottom-right. In ``"vertical"`` orientation the legend is top-left, the
    histogram top-right and the grid spans the bottom row.

    Args:
        data: Data sets to plot. Overlapping sets are split into disjoint
            regions; regions covered by several sets get striped colours.
        title: Title drawn above the square grid.
        legend_title: Title of the legend.
        max_squares: Number of squares. Defaults to the largest square
            number found in ``data`` / ``excluded_data``.
        number_annonations: Write each square's number on it (the misspelt
            name is kept for backward compatibility).
        extent_color, extent_label: Colour / legend label for squares that
            no data set covers.
        all_common_color, all_common_label: Colour / label for squares that
            belong to every data set.
        any_color, any_label: Colour / label for the union of all data sets.
            The histogram bar is shown only when ``any_label`` is set.
        excluded_data: Squares painted on top of everything else and counted
            separately in the histogram.
        outlined_data: Squares that get a thick outline in the entry's first
            colour. When given, grid lines are replaced by per-square black
            edges.
        legend_anchor, legend_loc, legend_fontsize: Legend placement and size.
        fontsize: Font size of the square numbers.
        figsize: ``(width, height)`` in horizontal mode; swapped in vertical
            mode.
        width_ratios: ``(grid, side panel)`` size ratio. Used as column widths
            in horizontal mode and, reversed, as row heights in vertical mode.
            ``None`` means equal sizes.
        orientation: ``"horizontal"`` or ``"vertical"`` layout.

    Returns:
        The ``AxesImage`` backing the square grid.

    Raises:
        ValueError: If there are no square numbers and ``max_squares`` is not
            set.
    """
    if outlined_data is None:
        outlined_data = []

    disjoint_data = __disjoin_colormap_data(data, all_common_color, all_common_label)

    all_colors = __get_all_colors(
        data,
        excluded_data,
        extent_color,
        all_common_color,
        any_color,
    )

    color_to_label = __get_color_to_label_dict(
        data,
        excluded_data,
        extent_label,
        extent_color,
        all_common_color,
        all_common_label,
        any_color,
        any_label,
    )

    all_data = [data_item for d in disjoint_data for data_item in d.data]
    if excluded_data:
        all_data.extend(excluded_data.data)

    if max_squares:
        squares_count = max_squares
    elif all_data:
        squares_count = max(all_data)
    else:
        raise ValueError(
            "colormap() needs at least one square number or an explicit max_squares"
        )

    # The helper returns (short side, long side). The short side is used as
    # the column count, so the grid is at least as tall as it is wide.
    cols, rows = __get_largest_rectangle_size(squares_count)
    # An all-zero image: with a constant image every cell maps to the first
    # colour of ``all_colors`` (the extent colour). Squares are drawn as patches
    # on top of it.
    plot_data = np.zeros((rows, cols))

    cmap = matplotlib.colors.ListedColormap(all_colors)

    horizontal = orientation == "horizontal"
    fig = plt.figure(
        constrained_layout=True,
        figsize=figsize if horizontal else (figsize[1], figsize[0]),
    )
    if horizontal:
        grid_width_ratios = None if width_ratios is None else list(width_ratios)
        grid_height_ratios = [1, 1]
    else:
        # The (grid, side panel) ratios become row heights, with the side
        # panel in the top row, hence the reversal.
        grid_width_ratios = [1, 1]
        grid_height_ratios = (
            None if width_ratios is None else list(reversed(width_ratios))
        )
    gs = fig.add_gridspec(
        2,
        2,
        width_ratios=grid_width_ratios,
        height_ratios=grid_height_ratios,
    )

    if horizontal:
        ax = fig.add_subplot(gs[:, 0])
        ax_empty = fig.add_subplot(gs[0, 1])
        ax_histogram = fig.add_subplot(gs[1, 1])
    else:
        ax = fig.add_subplot(gs[1, :])
        ax_empty = fig.add_subplot(gs[0, 0])
        ax_histogram = fig.add_subplot(gs[0, 1])
    # This axes only hosts the legend.
    ax_empty.axis("off")

    im = ax.imshow(plot_data, cmap=cmap)

    ax.set_yticks([])
    ax.set_xticks([])

    # Minor ticks sit on the cell borders and drive the grid lines below.
    ax.set_xticks(np.arange(-0.5, cols, 1), minor=True)
    ax.set_yticks(np.arange(-0.5, rows, 1), minor=True)

    if len(outlined_data) == 0:
        ax.grid(which="minor", color="black", linestyle="-", linewidth=1)

    ax.set_ylabel("")
    ax.tick_params(top=True, bottom=False, labeltop=True, labelbottom=False)
    ax.tick_params(which="minor", bottom=False, left=False)

    if title:
        # Target the grid axes explicitly. plt.title() would hit the current
        # axes, which is the histogram because it is created last.
        ax.set_title(title)

    __draw_colormap_squares(
        disjoint_data, cols, ax, squares_count, excluded_data, len(outlined_data) > 0
    )
    __draw_colormap_outlines(outlined_data, cols, ax)

    if number_annonations:
        annotation_data = disjoint_data
        if excluded_data is not None:
            # Excluded squares are painted after (on top of) the disjoint
            # regions, so their numbers must contrast with the excluded colour.
            # Remove them from the other regions so each square is annotated
            # once.
            excluded_squares = set(excluded_data.data)
            annotation_data = [
                ColormapData(set(d.data) - excluded_squares, d.colors, d.label)
                for d in disjoint_data
            ] + [excluded_data]
        __annotate_colormap(
            annotation_data,
            im,
            squares_count,
            cols=cols,
            extent_color=extent_color,
            fontsize=fontsize,
        )

    __add_colormap_legend(
        ax_empty,
        legend_title,
        all_colors,
        color_to_label,
        legend_anchor,
        legend_loc,
        legend_fontsize,
    )

    add_histogram(
        data,
        ax_histogram,
        excluded_data,
        all_common_color,
        any_color,
        any_label,
        extent_color,
        squares_count,
        legend_fontsize,
    )

    return im


IT = TypeVar("IT")


def __powerset(s: List[IT]) -> List[Set[IT]]:
    """Return every subset of ``s`` as a set, smallest subsets first.

    Subsets of the same size follow ``itertools.combinations`` order. The
    result therefore starts with the empty set, ends with ``set(s)``, and has
    ``2 ** len(s)`` entries. Elements must be hashable.
    """
    subsets: List[Set[IT]] = []
    for size in range(len(s) + 1):
        subsets.extend(set(combo) for combo in combinations(s, size))
    return subsets


def __disjoin_colormap_data(
    data: List[ColormapData],
    all_common_color: Optional[str] = None,
    all_common_label: Optional[str] = None,
) -> List[ColormapData]:
    """Split overlapping data sets into disjoint regions, one per membership pattern.

    For every non-empty subset S of ``data`` (in ``__powerset`` order), the
    region is the set of squares that belong to every member of S and to no
    other entry of ``data``. Each non-empty region becomes a new
    ``ColormapData``:

    * S has a single member: that member's colours and label are reused.
    * S is all of ``data`` and ``all_common_color`` is set: the single colour
      ``all_common_color`` with ``all_common_label``.
    * Otherwise: all member colours, sorted brightest first, with no label.

    Regions from the first two cases come first in the result; the mixed,
    unlabelled regions follow. The runtime is exponential in ``len(data)``.
    """
    primary: List[ColormapData] = []
    mixed: List[ColormapData] = []
    member_count = len(data)

    for members in __powerset(data):
        if not members:
            continue  # the empty subset never yields squares

        member_sets = [set(item.data) for item in members]
        region = member_sets[0].intersection(*member_sets[1:])
        for outsider in (item for item in data if item not in members):
            region.difference_update(outsider.data)
        if not region:
            continue

        # Only meaningful for single-member subsets, where it is that member.
        brightest = max(
            members, key=lambda item: __get_luminance_from_colors(item.colors)
        )

        if len(members) == 1:
            primary.append(
                ColormapData(region, colors=brightest.colors, label=brightest.label)
            )
        elif len(members) == member_count and all_common_color is not None:
            primary.append(
                ColormapData(region, colors=[all_common_color], label=all_common_label)
            )
        else:
            stripes = [color for item in members for color in item.colors]
            stripes.sort(key=lambda color: -__get_luminance_from_color(color))
            mixed.append(ColormapData(region, colors=stripes))

    return primary + mixed


def add_histogram(
    data: List[ColormapData],
    ax_histogram: plt.Axes,
    excluded: Optional[ColormapData],
    all_common_color: str,
    any_color: str,
    any_label: Optional[str],
    extent_color: str,
    squares_count: int,
    fontsize: int,
):
    """Draw one horizontal bar per square category, each labelled with its size.

    Bars drawn, each in the category's first colour:

    * one per entry of ``data``, sized by the full entry;
    * squares shared by every entry (``all_common_color``), if there are any;
    * squares in at least one entry (``any_color``), only when ``any_label``
      is set and the set is non-empty;
    * squares ``1..squares_count`` in no entry (``extent_color``), if any;
    * the ``excluded`` squares, if given and non-empty.

    The shared, any and extent sets leave out the excluded squares. Bars are
    sorted by length, shortest at index 0, and white bars get a black edge so
    they stay visible. Ticks are hidden. The x-limit is the largest category
    size plus 10, which leaves room for the bar labels.
    """
    excluded_data = excluded.data if excluded is not None else []
    data_sets = [set(d.data) for d in data]
    if data_sets:
        all_common_data = (
            data_sets[0].intersection(*data_sets[1:]).difference(excluded_data)
        )
        any_data = data_sets[0].union(*data_sets[1:]).difference(excluded_data)
    else:
        # No entries: nothing is shared or covered, so every square is extent.
        all_common_data = set()
        any_data = set()
    extent_data = (
        set(range(1, squares_count + 1)).difference(any_data).difference(excluded_data)
    )
    histogram_data = data.copy()

    if len(all_common_data) > 0:
        histogram_data.append(ColormapData(all_common_data, [all_common_color]))

    if any_label and len(any_data) > 0:
        histogram_data.append(ColormapData(any_data, [any_color]))

    if len(extent_data) > 0:
        histogram_data.append(ColormapData(extent_data, [extent_color]))

    if excluded and len(excluded.data) > 0:
        histogram_data.append(excluded)

    for index, d in enumerate(sorted(histogram_data, key=lambda x: len(x.data))):
        bar = ax_histogram.barh(
            index,
            len(d.data),
            color=d.colors[0],
            edgecolor="black" if d.colors[0] == "white" else None,
            height=0.9,
            label=len(d.data),
        )
        ax_histogram.bar_label(
            bar,
            label_type="edge",
            padding=4,
            fontsize=fontsize,
        )

    ax_histogram.set_xticks([])
    ax_histogram.set_yticks([])
    # Size the axis from every drawn bar, including the excluded one, so no bar
    # or label is clipped. The any/extent/common sizes are kept in the maximum
    # even when their bar is not drawn.
    longest = max(
        [len(d.data) for d in histogram_data]
        + [len(any_data), len(extent_data), len(all_common_data)]
    )
    ax_histogram.set_xlim(0, longest + 10)


def __get_color_to_label_dict(
    data: List[ColormapData],
    excluded_data: Optional[ColormapData],
    extent_label: Optional[str],
    extent_color: str,
    all_common_color: Optional[str],
    all_common_label: Optional[str],
    any_color: Optional[str],
    any_label: Optional[str],
) -> Dict[str, Optional[str]]:
    """Map legend colours to their labels.

    Each entry of ``data`` contributes ``first colour -> label``; the label may
    be ``None``. The extent, all-common and any colours are added only when
    their label is not ``None``. The excluded entry, when given, is added last.
    If two sources share a colour, the one added later wins.
    """
    color_to_label: Dict[str, Optional[str]] = {d.colors[0]: d.label for d in data}
    if extent_label is not None:
        color_to_label[extent_color] = extent_label
    if all_common_label is not None:
        color_to_label[all_common_color] = all_common_label
    if any_label is not None:
        color_to_label[any_color] = any_label
    if excluded_data is not None:
        color_to_label[excluded_data.colors[0]] = excluded_data.label
    return color_to_label


def __get_all_colors(
    data: List[ColormapData],
    excluded_data: Optional[ColormapData],
    extent_color: str,
    all_common_color: str,
    any_color: str,
) -> List[str]:
    """Return every colour the plot may use, de-duplicated in first-seen order.

    The order is extent, all-common, any, the excluded entry's first colour
    (if given), then every colour of every data entry. ``colormap`` builds its
    ``ListedColormap`` from this list and renders an all-zero image, so the
    first entry (``extent_color``) is presumably the grid background. Keep it
    first.
    """
    all_colors = [extent_color, all_common_color, any_color]

    if excluded_data is not None:
        all_colors.append(excluded_data.colors[0])

    # Filter out missing entries *before* dereferencing ``d.colors``.
    all_colors.extend(color for d in data if d is not None for color in d.colors)

    return __list_unique(all_colors)


def __draw_colormap_squares(
    data: List[ColormapData],
    cols: int,
    ax: plt.Axes,
    squares_count: int,
    excluded_data: Optional[ColormapData] = None,
    with_outlines: bool = False,
):
    """Paint every square listed in ``data`` (then ``excluded_data``) onto ``ax``.

    A square with k colours is split into k equal horizontal stripes, one per
    colour in list order. With ``with_outlines`` each stripe also gets a black
    edge; otherwise the stripes are plain solid fills. Squares in
    ``1..squares_count`` that nothing paints get only a black outline.
    Because ``excluded_data`` is painted last, it covers overlapping squares.
    """
    groups = data if excluded_data is None else data + [excluded_data]
    painted = set()

    def stripe_colors(color):
        # With outlines, fill via ``facecolor`` so the black ``edgecolor`` is
        # kept. Otherwise a single ``color`` sets both fill and edge.
        if with_outlines:
            return {"color": None, "facecolor": color, "edgecolor": "black"}
        return {"color": color, "facecolor": None, "edgecolor": None}

    for group in groups:
        for number in group.data:
            painted.add(number)
            row, column = divmod(number - 1, cols)

            for stripe, color in enumerate(group.colors):
                stripe_count = len(group.colors)
                ax.add_patch(
                    mpatches.Rectangle(
                        (column - 0.5, row - 0.5 + stripe / stripe_count),
                        1,
                        1 / stripe_count,
                        fill=True,
                        snap=False,
                        linewidth=1,
                        **stripe_colors(color),
                    )
                )

    for number in range(1, squares_count + 1):
        if number in painted:
            continue
        row, column = divmod(number - 1, cols)
        ax.add_patch(
            mpatches.Rectangle(
                (column - 0.5, row - 0.5),
                1,
                1,
                fill=False,
                snap=False,
                edgecolor="black",
                linewidth=1,
            )
        )


def __draw_colormap_outlines(
    data: List[ColormapData],
    cols: int,
    ax: plt.Axes,
):
    """Draw a thick (linewidth 3), unfilled border around each listed square.

    The border colour is the first colour of the entry the square belongs to.
    Because the rectangles are unfilled, the square's existing fill stays
    visible.
    """
    for entry in data:
        for number in entry.data:
            row, column = divmod(number - 1, cols)
            outline = mpatches.Rectangle(
                (column - 0.5, row - 0.5),
                1,
                1,
                fill=False,
                snap=False,
                edgecolor=entry.colors[0],
                linewidth=3,
            )
            ax.add_patch(outline)


def __add_colormap_legend(
    ax: plt.Axes,
    legend_title: Optional[str],
    colors: List[str],
    color_to_label: Dict[str, Optional[str]],
    legend_anchor: Optional[Tuple[float, float]] = None,
    legend_loc: Optional[str] = None,
    legend_fontsize: int = 12,
) -> matplotlib.legend.Legend:
    """Add a legend with one square patch per labelled colour to ``ax``.

    Patches follow the order of ``colors``; colours missing from
    ``color_to_label`` are skipped. By default the legend is anchored at the
    axes' top-left corner ``(0, 1)`` with ``loc="upper left"``. The title is
    drawn two points larger than the entries, entry text wraps, and the
    contents are left-aligned.

    Returns:
        The created ``Legend``.
    """
    patches = [
        mpatches.Patch(
            facecolor=color,
            label=color_to_label[color],
            edgecolor="black",
            linewidth=0.5,
        )
        for color in colors
        if color in color_to_label
    ]

    legend = ax.legend(
        handles=patches,
        borderpad=0,
        bbox_to_anchor=(0, 1) if legend_anchor is None else legend_anchor,
        borderaxespad=0.0,
        title=legend_title,
        loc="upper left" if legend_loc is None else legend_loc,
        fontsize=legend_fontsize,
        title_fontsize=legend_fontsize + 2,
    )

    for text in legend.get_texts():
        text.set_wrap(True)

    # Left-align the title and entries. Prefer the public API (Matplotlib
    # >= 3.6). On older versions, fall back to the private attribute that the
    # public API sets.
    if hasattr(legend, "set_alignment"):
        legend.set_alignment("left")
    else:
        legend._legend_box.align = "left"
    return legend


_UniqueT = TypeVar("_UniqueT")


def __list_unique(values: List[_UniqueT]) -> List[_UniqueT]:
    """Return ``values`` without duplicates, keeping first-occurrence order.

    The function is generic: this module calls it with colour strings, so the
    previous ``List[int]`` annotation was wrong. Items must be hashable.
    """
    # dict keys keep insertion order, so the first occurrence of each value wins.
    return list(dict.fromkeys(values))


def __get_largest_rectangle_size(n: int) -> Tuple[int, int]:
    """Return the factor pair of ``n`` whose two sides are closest to each other.

    The result is ``(short_side, long_side)`` with
    ``short_side * long_side == n`` and ``short_side <= long_side``. For a
    prime ``n`` this degenerates to ``(1, n)``.

    Raises:
        ValueError: If ``n`` is less than 1.
    """
    if n < 1:
        raise ValueError(f"n must be a positive integer, got {n!r}")
    # The most square-like pair is the largest divisor not exceeding sqrt(n).
    # math.isqrt is exact, unlike int(n ** 0.5) for large n.
    short_side = next(d for d in range(math.isqrt(n), 0, -1) if n % d == 0)
    return short_side, n // short_side


def __annotate_colormap(
    data: List[ColormapData],
    im: AxesImage,
    squares_count: int,
    extent_color: str,
    cols: int,
    fontsize: int,
):
    """Write each square's number, centred, in its grid cell on ``im``'s axes.

    Squares listed in ``data`` get black or white text, chosen from the
    luminance of that entry's colours. Every other square in
    ``1..squares_count`` gets a text colour chosen against ``extent_color``.
    """
    axes = im.axes
    labelled: Set[int] = set()

    def put_number(number, text_color):
        row, column = divmod(number - 1, cols)
        axes.text(
            column,
            row,
            number,
            fontsize=fontsize,
            horizontalalignment="center",
            verticalalignment="center",
            color=text_color,
        )

    for item in data:
        for number in item.data:
            labelled.add(number)
            contrast = __get_text_color(__get_luminance_from_colors(item.colors))
            put_number(number, contrast)

    extent_text_color = __get_text_color(__get_luminance_from_color(extent_color))

    for number in range(1, squares_count + 1):
        if number not in labelled:
            put_number(number, extent_text_color)


def __get_luminance_from_colors(colors: List[str]) -> float:
    """Return one representative luminance for a (possibly striped) colour list.

    With an odd number of colours (including one), the middle colour decides.
    With an even number, the two central colours are compared and the one
    whose luminance lies further from 0.45 wins. Presumably this picks the
    more clearly light or dark stripe for text contrast; the reason for 0.45
    is not documented here. An empty list raises ``IndexError``.
    """
    count = len(colors)
    mid = count // 2

    if count % 2 == 1:
        return __get_luminance_from_color(colors[mid])

    first = __get_luminance_from_color(colors[mid - 1])
    second = __get_luminance_from_color(colors[mid])
    if abs(first - 0.45) > abs(second - 0.45):
        return first
    return second


def __get_luminance_from_color(color: str) -> float:
    """Return the weighted luminance of any Matplotlib colour specification.

    The colour is converted with ``matplotlib.colors.to_rgb`` (alpha dropped)
    and its channels are weighted by ``__get_luminance``.
    """
    red, green, blue = matplotlib.colors.to_rgb(color)
    return __get_luminance(red, green, blue)


def __get_luminance(r: float, g: float, b: float) -> float:
    """Return the weighted luma of an RGB colour.

    Uses the ITU-R BT.601 weights 0.299, 0.587 and 0.114, so the result has
    the same scale as the inputs. Callers in this module pass the output of
    ``matplotlib.colors.to_rgb``, i.e. floats in ``[0, 1]``; the previous
    ``int`` annotations were wrong.
    """
    return 0.299 * r + 0.587 * g + 0.114 * b


def __get_text_color(luminance: float) -> str:
    """Pick a readable text colour for a background of the given luminance.

    Backgrounds strictly brighter than 0.5 get ``"black"``. Anything else,
    including exactly 0.5 and NaN, gets ``"white"``.
    """
    if luminance > 0.5:
        return "black"
    return "white"
