#!/usr/bin/env python3

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import os

# Colour palette used by the plotting functions in this file.  Within each hue
# the number runs from darkest (1) to lightest (5); the blue family has no
# level 1.
blue2, blue3, blue4 = '#2e518c', '#5079b3', '#7da7d9'
green1, green2, green3, green4, green5 = (
    '#146614', '#2e8c2e', '#50b350', '#7dd97d', '#b3ffb3',
)
red1, red2, red3, red4, red5 = (
    '#660000', '#8c1919', '#b33e3e', '#d97272', '#ffb3b3',
)
magenta1, magenta2, magenta3, magenta4, magenta5 = (
    '#581466', '#762e8c', '#9650b3', '#b87dd9', '#dfb3ff',
)
orange1, orange2, orange3, orange4, orange5 = (
    '#b34e0b', '#c67322', '#d99a3d', '#ecc05c', '#ffe480',
)
cyan1, cyan2, cyan3, cyan4, cyan5 = (
    '#146666', '#2e8c8c', '#50b3b3', '#7dd9d9', '#b3ffff',
)
gray1, gray2, gray3, gray4, gray5 = (
    '#4d4d4d', '#6c6c6c', '#8c8c8c', '#acacac', '#cccccc',
)


def generate_plot(targets, net_outputs, net_outputs_at, observations, *,
                  cell=(8, 8), output_path="wave-results.pdf"):
    """Create and save the summary figure of the 2D wave experiment.

    The upper four rows of the figure show the wave at time steps
    2, 4, ..., 40, one row each for (top to bottom) the ground truth, the
    observations, the network output and the network output (AT); every image
    is framed in the colour of its row.  The lower panel plots the amplitude
    of a single grid cell over the whole sequence for all four conditions.

    Args:
        targets, net_outputs, net_outputs_at, observations: arrays whose shape
            is read from ``targets.shape`` as ``(sl, side * side, k)``.  They
            are reshaped to ``(sl, side, side)``, which only succeeds when the
            trailing axis has size 1.
        cell: ``(row, col)`` of the grid cell plotted in the time panel.
        output_path: file the figure is saved to.

    Raises:
        ValueError: if the flattened spatial size is not a perfect square, or
            the sequence is too short to contain the displayed time steps.

    Note:
        Modifies global matplotlib rc settings and calls ``plt.show()``.
    """
    # Sequence length and flattened spatial size of the data.
    sl, in_dim_1d, _ = targets.shape
    in_dim_2d = int(round(np.sqrt(in_dim_1d)))
    if in_dim_2d * in_dim_2d != in_dim_1d:
        raise ValueError(
            f"Input dimension {in_dim_1d} is not a perfect square; cannot "
            f"reshape it into a 2D wave field."
        )

    # Number of wave images per row and the time steps they show
    # (2, 4, ..., 2 * numh, i.e. 2..40 for numh = 20).
    numh = 20
    wave_idcs = range(2, 2 + 2 * numh, 2)
    if sl <= wave_idcs[-1]:
        raise ValueError(
            f"Sequence length {sl} is too short; at least "
            f"{wave_idcs[-1] + 1} time steps are needed."
        )

    # Reshape the flattened spatial dimension into a square grid.
    shape_2d = (sl, in_dim_2d, in_dim_2d)
    targets = np.reshape(targets, shape_2d)
    observations = np.reshape(observations, shape_2d)
    net_outputs = np.reshape(net_outputs, shape_2d)
    net_outputs_at = np.reshape(net_outputs_at, shape_2d)

    # One entry per image row (top to bottom), which is also the drawing order
    # of the lines in the time panel:
    # (data, frame colour, legend label, line keyword arguments).
    rows = (
        (targets, "black", "Ground truth",
         dict(color="black", linewidth=0.5, zorder=10, linestyle="--",
              dashes=(5, 1))),
        (observations, gray2, "Observations",
         dict(color=gray2, zorder=1)),
        (net_outputs, orange3, "Network output",
         dict(color=orange3, zorder=2)),
        (net_outputs_at, blue3, "Network output (AT)",
         dict(color=blue3, zorder=3)),
    )

    # Global plot customizations; they must be set before the axes are made.
    tic_font_size = "xx-small"
    label_font_size = "x-small"
    # "sans-serif" is matplotlib's generic family name ("sans serif" is not
    # recognised).  For Palatino or another serif font use family="serif"
    # together with serif=["Palatino"].
    plt.rc("font", family="sans-serif")
    plt.rc("text", usetex=False)
    plt.rc("xtick", labelsize=tic_font_size)
    plt.rc("ytick", labelsize=tic_font_size)
    plt.rc("axes", axisbelow=True)
    matplotlib.rcParams["lines.linewidth"] = 0.8
    matplotlib.rcParams["legend.fancybox"] = True

    fig = plt.figure(constrained_layout=False, figsize=(6.75, 3.0))
    # Grid rows 0-3 hold the wave images, rows 4-7 the time panel.
    gs = fig.add_gridspec(8, numh)

    # Wave images: hide ticks, colour the frame, draw the selected time step.
    for row, (frames, frame_color, _, _) in enumerate(rows):
        for col, step in enumerate(wave_idcs):
            ax = fig.add_subplot(gs[row, col])
            ax.set_xticks([])
            ax.set_yticks([])
            for side in ("top", "bottom", "right", "left"):
                ax.spines[side].set_edgecolor(frame_color)
                ax.spines[side].set_linewidth(1.5)
            ax.imshow(frames[step], cmap="Blues", vmin=-1.0, vmax=1.0)

    # Activity of a single cell over time for the four conditions.
    ax_time = fig.add_subplot(gs[4:, :])
    cell_row, cell_col = cell
    for frames, _, label, line_kwargs in rows:
        ax_time.plot(range(sl), frames[:, cell_row, cell_col], label=label,
                     **line_kwargs)

    ax_time.set_xlabel("Time $t$", fontsize=label_font_size)
    ax_time.tick_params(
        direction="in",
        bottom=True, top=True,
        left=True, right=True,
        zorder=3,
    )
    ax_time.grid(linewidth=0.5, linestyle="dashed", zorder=0)
    ax_time.legend(fontsize=tic_font_size, loc="upper right", ncol=4)
    ax_time.label_outer()

    # Remove the horizontal white space between the wave images.
    fig.subplots_adjust(
        left=0.03, bottom=0.11, right=0.99, top=1.0, wspace=0.0, hspace=0.1
    )

    fig.savefig(output_path, bbox_inches="tight", pad_inches=0)
    plt.show()


def plot_kernel_activity(ax, wave_act, label, color, t, time_window_disp,
                         linestyle="solid", linewidth=1.5, make_tf_line=False,
                         tf_line_x=30):
    """Plot the activity of one wave cell over time into ``ax``.

    The visible x-range is a window of width ``time_window_disp`` that ends at
    time step ``t``.  While ``t`` is still smaller than the window, the range
    is fixed to ``[0, time_window_disp]``.  As a result, the trace scrolls
    once ``t`` exceeds the window.

    Args:
        ax: matplotlib axes to draw into.
        wave_act: 1D sequence of amplitudes; sample ``i`` is drawn at x = i.
        label, color, linestyle, linewidth: forwarded to ``ax.plot``.
        t: current time step; only used to compute the x-limits.
        time_window_disp: width of the displayed time window.
        make_tf_line: if True, first draw a dotted white vertical marker at
            x = ``tf_line_x`` spanning y in [-1, 1].
        tf_line_x: x position of that marker (30 by default, the previously
            hard-coded value).
    """
    if make_tf_line:
        ax.plot([tf_line_x, tf_line_x], [-1.0, 1.0],
                color='white', linestyle='dotted', linewidth=2.0)

    ax.set_xlabel('Time $t$')
    ax.set_ylabel('Wave amplitude')
    ax.set_ylim(-0.8, 0.8)
    # Sliding window ending at t (a fixed [0, window] range early on).
    ax.set_xlim(max(0, t - time_window_disp), max(time_window_disp, t))

    ax.plot(range(len(wave_act)), wave_act,
            label=label, color=color, linestyle=linestyle, linewidth=linewidth)


def plot_2d_wave(ax, _data, _title, rectangle=None, rectangle_color=None):
    """Render one snapshot of the spatio-temporally expanding 2D wave.

    The image is drawn with the "Blues" colormap clipped to [-0.8, 0.8],
    titled ``_title``, and shown without axis decorations.  When
    ``rectangle`` is given as ``[xs, ys]``, that polyline is drawn on top in
    ``rectangle_color``.
    """
    colour_scale = dict(cmap='Blues', vmin=-0.8, vmax=0.8)
    ax.imshow(_data, **colour_scale)
    ax.set_title(_title)
    ax.axis("off")

    # Nothing else to draw without a highlight outline.
    if rectangle is None:
        return

    outline_xs = rectangle[0]
    outline_ys = rectangle[1]
    ax.plot(outline_xs, outline_ys, linewidth=1.2, linestyle="solid",
            color=rectangle_color)


def create_video_frames(targets, net_outputs, net_outputs_at, observations):

    plt.style.use('dark_background')

    # Get sequence length and input dimension from the data
    sl, in_dim_1d, _ = targets.shape
    # Convert the one-dimensional into a two-dimensional input dimension
    in_dim_2d = int(np.sqrt(in_dim_1d))

    # Reshape the data from 1D into 2D at the input dimension
    targets = np.reshape(targets, (sl, in_dim_2d, in_dim_2d))
    net_outputs = np.reshape(net_outputs, (sl, in_dim_2d, in_dim_2d))
    net_outputs_at = np.reshape(net_outputs_at, (sl, in_dim_2d, in_dim_2d))
    observations = np.reshape(observations, (sl, in_dim_2d, in_dim_2d))

    # Put all data into one array
    data = np.array([targets, observations, net_outputs, net_outputs_at])

    # Check if the directory for saving the frames already exists
    os.makedirs("frames/", exist_ok=True)

    for t in range(0, len(targets)):

        print("Creating image " + str(t) + "/" + str(len(targets)))

        # Initialize the figure and axes for the subplots
        fig = plt.figure(1, figsize=[18, 7.2])
        ax0 = plt.subplot2grid((3, 4), (0, 0), colspan=1, rowspan=2)
        ax1 = plt.subplot2grid((3, 4), (0, 1), colspan=1, rowspan=2)
        ax2 = plt.subplot2grid((3, 4), (0, 2), colspan=1, rowspan=2)
        ax3 = plt.subplot2grid((3, 4), (0, 3), colspan=1, rowspan=2)
        ax4 = plt.subplot2grid((3, 4), (2, 0), colspan=4, rowspan=1)

        plot_2d_wave(ax=ax0, _data=targets[t], _title="Ground truth",
                     rectangle=[[9.5, 9.5, 10.5, 10.5, 9.5],
                                [9.5, 10.5, 10.5, 9.5, 9.5]],
                     rectangle_color=gray1)
        plot_2d_wave(ax=ax1, _data=observations[t], _title="Observations",
                     rectangle=[[9.5, 9.5, 10.5, 10.5, 9.5],
                                [9.5, 10.5, 10.5, 9.5, 9.5]],
                     rectangle_color=gray2)
        plot_2d_wave(ax=ax2, _data=net_outputs[t], _title="Network output",
                     rectangle=[[9.5, 9.5, 10.5, 10.5, 9.5],
                                [9.5, 10.5, 10.5, 9.5, 9.5]],
                     rectangle_color=orange3)
        plot_2d_wave(ax=ax3, _data=net_outputs_at[t], _title="Network output (AT)",
                     rectangle=[[9.5, 9.5, 10.5, 10.5, 9.5],
                                [9.5, 10.5, 10.5, 9.5, 9.5]],
                     rectangle_color=blue3)

        plot_kernel_activity(ax=ax4, wave_act=targets[:t, 9, 9],
                             label="Ground truth", color=gray1, t=t,
                             time_window_disp=150, linestyle="solid",
                             linewidth=1.5)
        plot_kernel_activity(ax=ax4, wave_act=observations[:t, 9, 9],
                             label="Observations", color=gray2, t=t,
                             time_window_disp=150, linestyle="--",
                             linewidth=1.5)
        plot_kernel_activity(ax=ax4, wave_act=net_outputs[:t, 9, 9],
                             label="Network output", color=orange3, t=t,
                             time_window_disp=150, linestyle="--",
                             linewidth=1.5)
        plot_kernel_activity(ax=ax4, wave_act=net_outputs_at[:t, 9, 9],
                             label="Network output (AT)", color=blue3, t=t,
                             time_window_disp=150, linestyle="--",
                             linewidth=1.5)

        ax4.legend(loc="upper left", ncol=4)

        plt.savefig("frames/" + str(t) + ".png")
        plt.close("all")
