

import matplotlib.pyplot as plt
import numpy as np
import scienceplots
from matplotlib.cm import ScalarMappable
from matplotlib.colors import Normalize
from matplotlib.figure import Figure

from icpe.utils import scale


def plot_param(Ps: np.ndarray,
               Qs: np.ndarray) -> Figure:
    '''
    Plot the attention parameters P and Q of every layer as heatmaps.

    Row i of the figure shows layer i, with P in the left column and Q in the
    right column. Each matrix is passed through `scale` (intended to normalize
    its range to [-1, 1]) and drawn with a shared colormap fixed to [-1, 1].
    Both matrices of a layer are multiplied by the same sign, chosen so that
    the bottom-right entry of P is not negative.

    param Ps: np.ndarray (attention parameter P of shape (l, 2d+1, 2d+1))
    param Qs: np.ndarray (attention parameter Q of shape (l, 2d+1, 2d+1))
    return: the Figure containing the (l, 2) grid of heatmaps and a colorbar
    '''
    assert Ps.ndim == Qs.ndim == 3, 'Ps and Qs must be 3-dimensional'
    assert Ps.shape == Qs.shape, 'Ps and Qs must have the same shape'

    n_layers = Ps.shape[0]
    fig, axes = plt.subplots(n_layers, 2, figsize=(12, 5 * n_layers))
    # plt.subplots returns a 1-D array for a single row; force (rows, cols).
    axes = np.atleast_2d(axes)
    cmap = 'viridis'
    norm = Normalize(vmin=-1, vmax=1)

    for layer in range(n_layers):
        P, Q = Ps[layer], Qs[layer]
        # Read the sign once from the raw P so that P and Q are flipped
        # together. A zero entry is treated as positive so that the matrices
        # are never multiplied by 0 and blanked out.
        sign = -1.0 if P[-1, -1] < 0 else 1.0
        panels = (('P', scale(P) * sign), ('Q', scale(Q) * sign))

        for col, (name, matrix) in enumerate(panels):
            ax = axes[layer, col]
            ax.imshow(matrix, cmap=cmap, norm=norm, aspect='equal')
            # Brace the index so multi-digit layers are fully subscripted.
            ax.set_title(f"${name}_{{{layer}}}$", fontsize=20)
            ax.axis('off')  # Hide axes for a cleaner plot

    # One shared colorbar, attached to the last row's Q panel.
    fig.colorbar(ScalarMappable(norm=norm, cmap=cmap),
                 ax=axes[-1, 1], orientation='vertical')

    fig.tight_layout(rect=[0, 0, 1, 0.95])
    return fig


def plot_validation(mean: np.ndarray, ste: "np.ndarray | None",
                    ns: np.ndarray, log_scale: bool) -> Figure:
    '''
    Plot the validation error against context length, with an optional
    shaded band of one standard error around the mean.

    param mean: mean of the validation error
    param ste: standard error of the validation error, or None to omit the band
    param ns: context lengths
    param log_scale: whether to plot in log scale (natural log of the values)
    return: the Figure containing the plot
    '''
    # Coerce to arrays so that `mean + ste` is element-wise even for lists.
    mean = np.asarray(mean)
    to_plot = np.log if log_scale else np.asarray

    # Compute everything before creating the figure so a failure here does
    # not leave an empty figure behind.
    ys = to_plot(mean)
    band = None
    if ste is not None:
        ste = np.asarray(ste)
        # NOTE: in log scale, points where mean - ste <= 0 become NaN/-inf.
        band = (to_plot(mean - ste), to_plot(mean + ste))

    fig = plt.figure(figsize=(10, 5))
    ax = fig.gca()
    ax.plot(ns, ys)
    if band is not None:
        lower, upper = band
        ax.fill_between(ns, lower, upper, alpha=0.2)
    ax.set_xlabel('context length', fontsize=18)
    ax.set_ylabel('msve (log scale)' if log_scale else 'msve', fontsize=18)
    ax.grid()
    return fig
