import matplotlib.pyplot as plt
import matplotlib.patches
import numpy as np
import seaborn as sns
# from trusted_monitoring_calculations import calculate_blue_winrate_trusted_monitoring_from_Game_via_recurrence, \
#     calculate_usefulness_exactly_trusted_monitoring_from_Game
from typing import Union, TypeVar, Tuple, List, Callable

# Generic type placeholders. They are not referenced by the functions defined in this part of the file.
# NOTE(review): presumably stand for blue-team / red-team strategy types used elsewhere - confirm.
T_blue = TypeVar('T_blue')
T_red = TypeVar('T_red')

# Maps a protocol label to a shorthand hex colour string ("#RGB" / "#rgb").
# NOTE(review): presumably used to keep plot colours consistent across figures - confirm against callers.
PROTOCOL_COLOURS = {
    "never audit": "#e22",
    "upfront auditing": "#28B",
    "trusted model": "#F81",
    "trusted monitoring": "#3A3",
}


def create_heatmap(ax, pddf, x_name: str, y_name: str, z_name: str, cmap='Greens', num_levels=None, xlim=(0, 1),
                   ylim=(0, 1)):
    """
    Draw ``z_name`` as a heatmap over (``x_name``, ``y_name``) on ``ax``, with contour lines and a colorbar.

    Every row of ``pddf`` is drawn as a square scatter marker coloured by its ``z_name`` value. The z values
    are then interpolated with cubic ``scipy.interpolate.griddata`` onto a regular grid spanning the data's
    x/y range, and semi-transparent white contour lines are overlaid. Any legend on the axes is removed and a
    colorbar normalised to the min/max of ``z_name`` is attached to the axes' figure instead.

    :param ax: Matplotlib axes to draw on. Its aspect is set to 'equal' and its limits, labels and title
        are overwritten.
    :param pddf: Table indexed by column name whose columns provide ``unique()``, ``min()`` and ``max()``
        (presumably a pandas DataFrame - confirm against callers).
    :param x_name: Column used for the x coordinate (also the x-axis label).
    :param y_name: Column used for the y coordinate (also the y-axis label).
    :param z_name: Column used for the colour (also the title and colorbar label).
    :param cmap: Colour map / palette name used for both the markers and the colorbar.
    :param num_levels: ``levels`` argument passed to ``ax.contour``. When None it defaults to half the smaller
        number of unique x/y values, but never less than 1.
    :param xlim: x-axis limits applied after plotting.
    :param ylim: y-axis limits applied after plotting.
    """
    ax.set_aspect('equal', adjustable='box')

    # Size of the axes in display units, read once and reused for both marker and contour-line sizing.
    ax_extent = ax.get_window_extent()
    num_points_in_ax_width = ax_extent.width
    num_points_in_ax_height = ax_extent.height

    # Display-space size of a single grid cell, assuming one cell per unique x / y value.
    num_x_datapoints = len(pddf[x_name].unique())
    num_y_datapoints = len(pddf[y_name].unique())
    single_datapoint_width = num_points_in_ax_width / num_x_datapoints
    single_datapoint_height = num_points_in_ax_height / num_y_datapoints

    if num_levels is None:
        # Clamp to 1 so the line-width computation below never divides by zero on very small grids.
        num_levels = max(1, min(num_x_datapoints, num_y_datapoints) // 2)

    # Choose marker size: a quarter of the area of one grid cell.
    marker_size = (single_datapoint_height * single_datapoint_width) / 4

    sns.scatterplot(data=pddf, x=x_name, y=y_name, hue=z_name, ax=ax, palette=cmap, s=marker_size, marker='s',
                    edgecolor=None)
    ax.set_xlim(*xlim)
    ax.set_ylim(*ylim)
    ax.set_xlabel(x_name)
    ax.set_ylabel(y_name)
    ax.set_title(z_name)

    # Generate a regular grid of x and y values covering the data range
    x_min, x_max = pddf[x_name].min(), pddf[x_name].max()
    y_min, y_max = pddf[y_name].min(), pddf[y_name].max()
    x_grid, y_grid = np.meshgrid(np.linspace(x_min, x_max, num_x_datapoints),
                                 np.linspace(y_min, y_max, num_y_datapoints))

    # Imported here so scipy is only required when a heatmap is actually drawn.
    from scipy.interpolate import griddata
    z_values = griddata((pddf[x_name], pddf[y_name]), pddf[z_name], (x_grid, y_grid), method='cubic')

    # Contour line width: 10% of the smaller axes dimension, shared out over 2 * num_levels.
    pctn_plot_that_is_lines = 10
    lw = min(num_points_in_ax_width, num_points_in_ax_height) * (pctn_plot_that_is_lines / 100) / (2 * num_levels)
    ax.contour(x_grid, y_grid, z_values, levels=num_levels, linewidths=lw, colors='w', alpha=0.5)

    # Replace the hue legend with a continuous colorbar. get_legend() returns None when no legend
    # exists on the axes, so guard before removing it.
    legend = ax.get_legend()
    if legend is not None:
        legend.remove()
    norm = plt.Normalize(pddf[z_name].min(), pddf[z_name].max())
    sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
    sm.set_array([])
    ax.figure.colorbar(sm, ax=ax, label=z_name)


def pad_range(min, max):
    """
    Widen the interval ``[min, max]`` by 5% of its length on each side.

    Returns the padded ``(lower, upper)`` bounds as a tuple.
    """
    # NOTE: the parameter names shadow the builtins; they are kept because they are part of the signature.
    margin = (max - min) * 0.05
    lower = min - margin
    upper = max + margin
    return lower, upper


def plot_with_zoom(
        create_graph_func,
        x_zoom_range: tuple[float, float],
        y_zoom_range: tuple[float, float] = None,
        title: str = None,
        right_fig_is_log_x: bool = False,
        right_fig_is_log_y: bool = False,
        figsize=(12, 6),
        bound_to_0_1: bool = False,
        f_ax1_ax2: tuple[plt.Figure, plt.Axes, plt.Axes] = None,
):
    """
    Draw a graph next to a zoomed-in copy of it, with guide lines linking the zoom region to the copy.

    ``create_graph_func`` is called once per axes. The left axes (``ax1``) shows the full graph; the right
    axes (``ax2``) shows the same graph restricted to the (padded) zoom range. A rectangle marking the zoom
    region is drawn on ``ax1``, a dashed border is drawn around ``ax2``, and dashed lines join the
    corresponding corners.

    :param create_graph_func: Callable taking a single axes argument and drawing the graph on it.
    :param x_zoom_range: ``(x_min, x_max)`` of the zoom region, or None for no zoom. With None and no
        ``f_ax1_ax2``, a single half-width figure is created; in either None case the figure is returned
        without any zoom decoration.
    :param y_zoom_range: ``(y_min, y_max)`` of the zoom region. When None, it is derived from the y values
        of the Line2D / PathCollection points on ``ax1`` whose x lies within ``x_zoom_range`` (or from all
        points if none lie within it).
    :param title: Optional figure-level title.
    :param right_fig_is_log_x: Use a log x scale on the zoomed axes.
    :param right_fig_is_log_y: Use a log y scale on the zoomed axes.
    :param figsize: Size of the two-panel figure; the single-panel figure uses half the width.
    :param bound_to_0_1: Set the limits of ``ax1`` to the padded unit square.
    :param f_ax1_ax2: Optional existing ``(fig, ax1, ax2)`` to draw on instead of creating a new figure.
    :return: The figure that was drawn on.
    :raises ValueError: If ``y_zoom_range`` is None and ``ax1`` contains no line or scatter points.
    """
    import warnings
    # Imported explicitly rather than relying on pyplot having loaded these submodules as a side effect.
    from matplotlib.collections import PathCollection
    from matplotlib.lines import Line2D

    with warnings.catch_warnings():
        # Silence FutureWarnings raised while the caller's plotting code runs.
        warnings.simplefilter(action='ignore', category=FutureWarning)
        if f_ax1_ax2 is not None:
            fig, ax1, ax2 = f_ax1_ax2
            create_graph_func(ax1)
            create_graph_func(ax2)
        elif x_zoom_range is not None:
            fig, (ax1, ax2) = plt.subplots(1, 2, figsize=figsize)
            create_graph_func(ax1)
            create_graph_func(ax2)
        else:
            fig, ax1 = plt.subplots(1, 1, figsize=(figsize[0] / 2, figsize[1]))
            create_graph_func(ax1)

    if bound_to_0_1:
        ax1.set_xlim(*pad_range(0, 1))
        ax1.set_ylim(*pad_range(0, 1))

    if title is not None:
        # Title the figure we drew on; pyplot's suptitle would target whatever figure is current,
        # which need not be `fig` when the caller supplied f_ax1_ax2.
        fig.suptitle(title)

    if x_zoom_range is None:
        return fig

    ax2.set_xlim(*pad_range(*x_zoom_range))

    ax2.set_ylabel("")

    x_min, x_max = x_zoom_range
    if y_zoom_range is None:
        # Collect the (x, y) points of every line and scatter collection drawn on the full graph.
        points = []
        for child in ax1.get_children():
            if isinstance(child, Line2D):
                points.extend(child.get_xydata())
            elif isinstance(child, PathCollection):
                # get_offsets() may be a masked array or a plain ndarray; getdata handles both
                # (ndarray.data is a memoryview and cannot be iterated row by row).
                points.extend(np.ma.getdata(child.get_offsets()))

        if not points:
            raise ValueError("No points found in plot")

        # Fit the y range to the points inside the x zoom range, falling back to all points.
        ys_in_range = [point[1] for point in points if x_min <= point[0] <= x_max]
        if not ys_in_range:
            ys_in_range = [point[1] for point in points]
        y_min = min(ys_in_range)
        y_max = max(ys_in_range)

        if y_min == y_max:
            # Avoid a zero-height zoom region.
            y_min -= 0.01
            y_max += 0.01
    else:
        y_min, y_max = y_zoom_range

    ax2.set_ylim(*pad_range(y_min, y_max))

    # Corners of the zoom region in ax1 data coordinates (clockwise from top-left)
    coordsA = [(x_min, y_max), (x_max, y_max), (x_max, y_min), (x_min, y_min)]
    # Matching corners of the second plot in axes-fraction coordinates
    coordsB = [(0, 1), (1, 1), (1, 0), (0, 0)]

    # Dashed guide lines from each corner of the zoom region to the matching corner of ax2
    for xyA, xyB in zip(coordsA, coordsB):
        con = matplotlib.patches.ConnectionPatch(xyA=xyA, xyB=xyB, coordsA="data", coordsB="axes fraction",
                                                 axesA=ax1, axesB=ax2, color="black", linestyle="--", alpha=0.2)
        fig.add_artist(con)

    # Solid outline of the zoom region on the first plot
    for corner1, corner2 in zip(coordsA, coordsA[1:] + coordsA[:1]):
        con = matplotlib.patches.ConnectionPatch(xyA=corner1, xyB=corner2, coordsA="data", coordsB="data",
                                                 axesA=ax1, axesB=ax1, color="black", linestyle="-", alpha=0.2)
        fig.add_artist(con)

    # Dashed border around the second plot
    for corner1, corner2 in zip(coordsB, coordsB[1:] + coordsB[:1]):
        con = matplotlib.patches.ConnectionPatch(xyA=corner1, xyB=corner2, coordsA="axes fraction",
                                                 coordsB="axes fraction",
                                                 axesA=ax2, axesB=ax2, color="black", linestyle="--", alpha=0.2)
        fig.add_artist(con)

    ax2.set_zorder(1)

    if right_fig_is_log_x:
        ax2.set_xscale('log')
    if right_fig_is_log_y:
        ax2.set_yscale('log')

    return fig


# Running this module directly does nothing.
if __name__ == '__main__':
    pass
