import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML, display

def report(predictions, _test_sample, label_="", idx_plotted=None):
    """Plot predicted vs. true values per component and print an error score.

    Each selected column of ``_test_sample`` gets its own subplot, with the
    true values drawn as a solid blue line and the predictions as a dashed
    red line. A legend is added to the first subplot only.

    Parameters:
    predictions (np.ndarray): predicted values; reshaped to ``_test_sample.shape``.
    _test_sample (np.ndarray): 2D reference array indexed as ``[time, column]``.
    label_ (str): label for the prediction curve; also used as the figure
        title and as the prefix of the printed error line.
    idx_plotted (sequence of int, optional): column indices to plot; defaults
        to every column of ``_test_sample``.
    """
    if idx_plotted is None:
        idx_plotted = np.arange(_test_sample.shape[1])
    predictions = predictions.reshape(_test_sample.shape)
    n_rows = len(idx_plotted)

    # Figure number 1 is reused on every call, so outside an inline notebook
    # backend successive calls draw onto the same figure.
    plt.figure(1, figsize=(10, 2 * n_rows))
    # Set a real figure title. Assigning ``plt.title = ...`` would instead
    # overwrite pyplot's ``title`` function for the whole process.
    if label_:
        plt.suptitle(label_)

    n_steps = _test_sample.shape[0]
    t = np.arange(n_steps)
    for i, idx in enumerate(idx_plotted):
        plt.subplot(n_rows, 1, i + 1)
        plt.plot(t, _test_sample[:, idx], 'b', label='true state' if i == 0 else "")
        plt.plot(t, predictions[:, idx], 'r--', label=label_ if i == 0 else "")
        if i == 0:  # legend only on the first subplot to avoid repetition
            plt.legend(loc="upper left")

    # Frobenius norm of the full error matrix divided by sqrt(number of time
    # steps): the root-mean over time of the squared per-step error norm
    # (it is not additionally averaged over columns).
    err = np.linalg.norm(predictions - _test_sample, 'fro') / np.sqrt(n_steps)
    print(f"{label_} RMSE: {err}")


def plot_modes(modes_, size=10, interpolation='bilinear'):
    """Show each entry of ``modes_`` as an image, side by side in one row.

    Parameters:
    modes_ (np.ndarray): stack of images; one subplot is drawn for each
        ``modes_[i]`` along the first axis.
    size (float): width and height, in inches, allotted to each subplot.
    interpolation (str or None): forwarded to ``plt.imshow``; ``None`` means
        matplotlib's default interpolation.
    """
    n_modes = modes_.shape[0]
    # The subplots form a single row of n_modes columns, so the figure must
    # grow horizontally, not vertically.
    plt.figure(figsize=(size * n_modes, size))
    for i in range(n_modes):
        plt.subplot(1, n_modes, i + 1)
        # imshow treats interpolation=None as "use the default", so the
        # caller's choice can always be forwarded (previously a non-None value
        # was silently dropped).
        plt.imshow(modes_[i], interpolation=interpolation)

def show_animation(array, title=None, fig_size=(5, 5), cmap='jet'):
    """
    Display an animation of a single time series of images as inline HTML.

    Parameters:
    array (np.ndarray): time series of images; ``array[k]`` is the frame
        shown at step k, and ``array.shape[0]`` is the number of frames.
    title (str): optional title drawn above the animation.
    fig_size (tuple): figure size in inches.
    cmap (str): colormap passed to ``imshow`` (default ``'jet'``).
    """
    fig, ax = plt.subplots(1, 1, figsize=fig_size)
    # NOTE(review): imshow appears to fix the colour limits from the first
    # frame, and set_array does not rescale them, so later frames outside
    # that range may saturate — confirm whether that is intended.
    im = ax.imshow(array[0], cmap=cmap)
    if title:
        ax.set_title(title, fontsize=16, pad=10)

    def update(frame):
        im.set_array(array[frame])
        return [im]

    ani = animation.FuncAnimation(fig, update, frames=range(array.shape[0]), blit=True, repeat=True)
    plt.close(fig)  # Prevents the figure from displaying as a static plot

    # Render the animation to JavaScript/HTML and display it inline.
    html_template = """
    <div style="display: flex; justify-content: space-around; align-items: center;">
        {}
    </div>
    """
    html_content = html_template.format(f'<div style="flex: 1;">{ani.to_jshtml()}</div>')
    display(HTML(html_content))

def show_animations_side_by_side(arrays, titles=None, fig_size=(5, 5), cmap='jet'):
    """
    Display animations of multiple arrays side by side with optional titles.

    All panels advance in lockstep; the animation runs for as many frames as
    the shortest array has.

    Parameters:
    arrays (list of np.ndarray): time series of images; ``arrays[j][k]`` is
        frame k of panel j.
    titles (list of str): optional titles; ``titles[j]`` labels panel j, and
        panels beyond ``len(titles)`` are left untitled.
    fig_size (tuple): size of each panel in inches; the figure width is
        ``fig_size[0] * len(arrays)``.
    cmap (str): colormap passed to ``imshow`` (default ``'jet'``).

    Raises:
    ValueError: if ``arrays`` is empty.
    """
    n_panels = len(arrays)
    if n_panels == 0:
        raise ValueError("arrays must contain at least one array")

    # squeeze=False always yields a 2D grid of Axes, so a single panel works
    # like several (plain subplots(1, 1) returns a bare, non-iterable Axes).
    fig, axes = plt.subplots(
        1, n_panels, figsize=(fig_size[0] * n_panels, fig_size[1]), squeeze=False
    )
    ims = []
    for i, (ax, frames) in enumerate(zip(axes[0], arrays)):
        # NOTE(review): colour limits appear to be fixed from each panel's
        # first frame; later frames outside that range may saturate.
        im = ax.imshow(frames[0], cmap=cmap)
        ims.append(im)
        ax.axis('off')  # Hide axes for better visualization
        if titles and i < len(titles):
            ax.set_title(titles[i], fontsize=16, pad=10)

    def update(frame):
        for im, frames in zip(ims, arrays):
            im.set_array(frames[frame])
        return ims

    # Stop at the shortest series so no panel is indexed past its end.
    n_frames = min(frames.shape[0] for frames in arrays)
    ani = animation.FuncAnimation(fig, update, frames=range(n_frames), blit=True, repeat=True)
    plt.close(fig)  # Prevents the figure from displaying as a static plot

    # Render the animation to JavaScript/HTML and display it inline.
    html_template = """
    <div style="display: flex; justify-content: space-around; align-items: center;">
        {}
    </div>
    """
    html_content = html_template.format(f'<div style="flex: 1;">{ani.to_jshtml()}</div>')
    display(HTML(html_content))