import matplotlib as mpl
import matplotlib.colors as mpc
import matplotlib.pyplot as plt
import numpy as np
from symo.metrics import Metrics


def default_rcparams(dpi: int = 200):
    """Return a dict of matplotlib rcParams overrides for the project's figures.

    The style uses small fonts, thin lines and axes, a light grid, ticks on
    all four sides pointing inward, and STIX fonts for both text and math.

    Args:
        dpi: Value stored under ``"figure.dpi"``.

    Returns:
        A freshly built dict mapping rcParams keys to values.
    """
    # NOTE(review): "legend.framelinewidth" and "patch.alpha" do not appear in
    # the standard rcParams list of the matplotlib versions I know of, and
    # rcParams.update() rejects unknown keys -- confirm against the installed
    # matplotlib before relying on them.
    figure_settings = {
        "figure.autolayout": True,
        "figure.dpi": dpi,
    }
    artist_settings = {
        "lines.linewidth": 1,
        "scatter.marker": "o",
        "scatter.edgecolors": "face",
        "legend.framelinewidth": 0.1,
        "patch.alpha": 0.6,
    }
    grid_settings = {
        "axes.grid": True,
        "axes.linewidth": 0.3,
        "grid.alpha": 1.0,
        "grid.linewidth": 0.1,
    }
    # Mirror ticks onto the top/right spines and point every tick inward.
    tick_settings = {
        "xtick.top": True,
        "ytick.right": True,
        "xtick.direction": "in",
        "ytick.direction": "in",
        "xtick.major.width": 0.3,
        "xtick.minor.width": 0.3,
        "ytick.major.width": 0.3,
        "ytick.minor.width": 0.3,
    }
    # STIX is used for both regular text and mathtext so they match.
    marker_and_font_settings = {
        "lines.markersize": 5,
        "font.size": 6,
        "mathtext.fontset": "stix",
        "font.family": "STIXGeneral",
    }

    # Merge in a fixed order so the key order matches the documented grouping.
    return {
        **figure_settings,
        **artist_settings,
        **grid_settings,
        **tick_settings,
        **marker_and_font_settings,
    }


def blue_white_orange(resolution: int = 256):
    """Build a diverging colormap running dark blue -> white -> dark orange.

    The low end is dark blue, the colours lighten towards a pure-white
    centre, and then darken to orange at the high end.

    Args:
        resolution: Total number of entries in the returned colormap.

    Returns:
        A ``ListedColormap`` named ``"BlueWhiteOrange"`` with exactly
        ``resolution`` colours.
    """
    # Blues_r runs dark -> light and Oranges runs light -> dark, so both light
    # ends meet at the white centre and the dark ends sit at the extremes.
    # (The plain Blues / Oranges_r pairing put the darkest colours next to
    # white, which produced a hard jump at the centre.)
    blues_r = mpl.colormaps["Blues_r"].resampled(resolution)
    oranges = mpl.colormaps["Oranges"].resampled(resolution)
    white = np.array([[1.0, 1.0, 1.0, 1.0]])
    half_size = resolution // 2

    if resolution % 2 == 1:
        # Odd: one exact white midpoint flanked by half_size colours per side.
        # Drop the lightest sample of each side, since white takes its place.
        blues_colors = blues_r(np.linspace(0, 1, half_size + 1))[:-1]
        oranges_colors = oranges(np.linspace(0, 1, half_size + 1))[1:]
        colors = np.vstack((blues_colors, white, oranges_colors))
    else:
        # Even: two halves of half_size colours whose innermost entries
        # (the lightest of each half) are replaced by pure white.
        blues_colors = blues_r(np.linspace(0, 1, half_size))
        blues_colors[-1] = white[0]
        oranges_colors = oranges(np.linspace(0, 1, half_size))
        oranges_colors[0] = white[0]
        colors = np.vstack((blues_colors, oranges_colors))

    return mpc.ListedColormap(colors, name="BlueWhiteOrange")


def orange_blue(resolution: int = 256):
    """Build a diverging colormap running dark orange -> white -> dark blue.

    The first half samples ``Oranges_r`` and the second half samples
    ``Blues``, each at ``resolution`` evenly spaced points. The two entries
    where the halves meet are forced to pure white.

    Args:
        resolution: Number of colours taken from each half.

    Returns:
        A ``ListedColormap`` named ``"OrangeBlue"`` with ``2 * resolution``
        entries.
    """
    sample_points = np.linspace(0, 1, resolution)
    white = [1, 1, 1, 1]

    orange_half = mpl.colormaps["Oranges_r"].resampled(resolution)(sample_points)
    orange_half[-1] = white  # innermost orange entry becomes white

    blue_half = mpl.colormaps["Blues"].resampled(resolution)(sample_points)
    blue_half[0] = white  # innermost blue entry becomes white

    palette = np.vstack((orange_half, blue_half))
    return mpc.ListedColormap(palette, name="OrangeBlue")


def plot_matrix(
    fig,
    ax,
    mat,
    title,
    show_bar: bool = True,
    clim_max: float | None = None,
    clim_min: float | None = None,
    cmap: mpc.Colormap | None = None,
):
    """Draw ``mat`` on ``ax`` with a colour scale centred on zero.

    Args:
        fig: Figure owning ``ax``; used to attach the colorbar.
        ax: Axes to draw into.
        mat: 2-D array to display with ``ax.matshow``.
        title: Title placed above the axes.
        show_bar: Whether to add a colorbar next to ``ax``.
        clim_max: Upper colour limit; defaults to the largest non-NaN entry.
        clim_min: Lower colour limit; defaults to the smallest non-NaN entry.
        cmap: Colormap to use; defaults to ``orange_blue().reversed()``.

    Returns:
        The image returned by ``ax.matshow``, so callers can adjust it.
    """
    if cmap is None:
        cmap = orange_blue().reversed()

    # TwoSlopeNorm maps 0 to the middle of the colormap, so negative and
    # positive values land on opposite halves of the diverging map.
    norm = mpc.TwoSlopeNorm(vcenter=0)

    im = ax.matshow(mat, cmap=cmap, norm=norm, aspect="equal")
    if show_bar:
        fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)

    ax.set_title(title, y=1.05)
    ax.set_xticks([])
    ax.set_yticks([])

    # nanmin/nanmax: with np.min/np.max a single NaN entry would make the
    # colour limit NaN and break the whole colour scale.
    if clim_min is None:
        clim_min = np.nanmin(mat)
    if clim_max is None:
        clim_max = np.nanmax(mat)

    im.set_clim(clim_min, clim_max)
    return im


def plot_metric_norms(metrics: Metrics, **kwargs):
    """Plot gradient, update and weight norms stored in ``metrics``.

    Produces a 3-row grid:
    - Row 0 holds gradient norms, row 1 update norms, row 2 weight norms.
    - Column 0 shows ``metrics.full_*_norm`` and is titled "Full".
    - Each further column shows one key of the per-parameter norm dicts and
      is titled with that key.

    Args:
        metrics: Object exposing ``full_grad_norm``, ``full_update_norm`` and
            ``full_weight_norm`` (plottable sequences), plus
            ``param_grad_norm``, ``param_update_norm`` and
            ``param_weight_norm`` (sequences of dicts). Keys are taken from
            the first entry of ``param_grad_norm`` and must exist in every
            dict of all three sequences.
        **kwargs: Forwarded to every ``Axes.plot`` call.

    Returns:
        ``(fig, axes)``, where ``axes`` is transposed relative to
        ``plt.subplots``: ``axes[col]`` is ``(grad_ax, update_ax, weight_ax)``.
    """
    grad_history = metrics.param_grad_norm
    keys = list(grad_history[0].keys()) if grad_history else []

    def _series(history):
        # One list per key, collected across all entries of ``history``.
        return {key: [entry[key] for entry in history] for key in keys}

    grads = _series(metrics.param_grad_norm)
    updates = _series(metrics.param_update_norm)
    weights = _series(metrics.param_weight_norm)

    n_rows = 3
    n_params = len(keys)
    figsize = (n_params + n_rows, n_rows)
    # squeeze=False keeps a 2-D axes array even when there are no
    # per-parameter columns (ncols == 1); otherwise indexing below breaks.
    fig, axes = plt.subplots(
        nrows=n_rows,
        ncols=n_params + 1,
        figsize=figsize,
        sharex=True,
        squeeze=False,
    )

    axes = axes.T  # axes[col] -> (grad_ax, update_ax, weight_ax)

    full_grad_ax, full_update_ax, full_weight_ax = axes[0]
    full_grad_ax.plot(metrics.full_grad_norm, **kwargs)
    full_update_ax.plot(metrics.full_update_norm, **kwargs)
    full_weight_ax.plot(metrics.full_weight_norm, **kwargs)
    full_grad_ax.set_title("Full")

    full_grad_ax.set_ylabel("Grad")
    full_update_ax.set_ylabel("Update")
    full_weight_ax.set_ylabel("Weight")

    for key, (grad_ax, update_ax, weight_ax) in zip(keys, axes[1:]):
        grad_ax.plot(grads[key], **kwargs)
        update_ax.plot(updates[key], **kwargs)
        weight_ax.plot(weights[key], **kwargs)
        # Title each column so the per-parameter panels are identifiable.
        grad_ax.set_title(str(key))

    return fig, axes


def diag_scale(mat):
    """Scale a square matrix symmetrically by its diagonal.

    Each entry ``(i, j)`` is divided by ``sqrt(mat[i, i] * mat[j, j])``,
    i.e. the result is ``D^-1/2 @ mat @ D^-1/2`` with ``D = diag(mat)``.
    Applied to a covariance matrix, this yields the correlation matrix.

    Zero or negative diagonal entries are not guarded against. They produce
    inf/NaN values (with numpy RuntimeWarnings) rather than an exception.
    """
    inv_root_diag = 1 / np.sqrt(mat.diagonal())
    column_factors = inv_root_diag[None, :]  # scales column j
    row_factors = inv_root_diag[:, None]  # scales row i
    # Multiply columns first, then rows, matching the original evaluation order.
    return column_factors * mat * row_factors
