import pickle
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import numpy as np
from tqdm import tqdm
from scipy.ndimage import gaussian_filter

# Uncomment to make the random sample/waveform selection in plot_one reproducible.
# np.random.seed(42)

def _apply_dark_style():
    """Globally switch matplotlib to a dark seaborn style with custom colors.

    Side effects: mutates ``plt.rcParams``; the changes persist after the call.
    """
    # "seaborn-dark" was renamed to "seaborn-v0_8-dark" in matplotlib 3.6 and the
    # old alias was later removed, so prefer the new name and fall back for old
    # versions instead of crashing.
    for style_name in ("seaborn-v0_8-dark", "seaborn-dark"):
        if style_name in plt.style.available:
            plt.style.use(style_name)
            break

    for param in ['figure.facecolor', 'axes.facecolor', 'savefig.facecolor']:
        plt.rcParams[param] = '#212946'  # bluish dark grey
    plt.rcParams['figure.facecolor'] = 'k'
    for param in ['text.color', 'axes.labelcolor', 'xtick.color', 'ytick.color']:
        plt.rcParams[param] = '0.9'  # very light grey


def _plot_glowing_line(ax, y, color='#08F7FE', n_lines=10, diff_linewidth=1.05,
                       alpha_value=0.03):
    """Draw ``y`` on ``ax`` as a solid line plus wider translucent copies (a glow)."""
    ax.plot(y, color=color)
    # Each successive copy is wider; stacking many low-alpha copies gives a soft halo.
    for n in range(1, n_lines + 1):
        ax.plot(y,
                linewidth=2 + (diff_linewidth * n),
                alpha=alpha_value,
                color=color)


def plot_one(all_hiddens, layer_id=5, num_plot=16, sigma=8):  # N, 12, 768, L (author's note)
    """Show a grid of smoothed, "glowing" waveforms from one randomly chosen sample.

    One sample is drawn at random from ``all_hiddens``. From that sample's
    ``layer_id``-th entry, ``num_plot`` rows are drawn without replacement. Each
    row is smoothed with a Gaussian filter and drawn in its own cell of a roughly
    square grid, and the figure is displayed with ``plt.show()``.

    Args:
        all_hiddens: Indexable collection of samples; ``sample[layer_id]`` must
            support numpy fancy indexing along its first axis, whose entries are
            the waveforms. The trailing comment suggests an overall shape of
            ``(N, 12, 768, L)`` — assumption, TODO confirm.
        layer_id: Index of the entry within the chosen sample to draw rows from.
        num_plot: Number of waveforms to plot (sampled without replacement).
        sigma: Standard deviation of the Gaussian smoothing applied to each
            waveform (previously hard-coded to 8).

    Raises:
        ValueError: If ``num_plot`` is < 1 or exceeds the number of rows available.

    Side effects:
        Overwrites global ``plt.rcParams`` entries (dark theme); these persist.
    """
    # Passing an int to np.random.choice samples as if from np.arange(n), so the
    # random stream consumed is the same as the original arange-based calls.
    hiddens = all_hiddens[np.random.choice(len(all_hiddens), 1)[0]]
    layer = hiddens[layer_id]
    if not 1 <= num_plot <= len(layer):
        raise ValueError(
            f"num_plot must be between 1 and {len(layer)}, got {num_plot}")
    wfs = layer[np.random.choice(len(layer), num_plot, replace=False)]
    sqrtn = int(np.ceil(np.sqrt(num_plot)))  # grid side length

    # Apply the theme BEFORE creating the figure/axes; otherwise the figure
    # facecolor only takes effect on the next call.
    _apply_dark_style()

    fig = plt.figure(figsize=(sqrtn, sqrtn))
    gs = gridspec.GridSpec(sqrtn, sqrtn)
    gs.update(wspace=0.05, hspace=0.05)

    for i, wf in tqdm(enumerate(wfs), total=len(wfs)):
        wf = gaussian_filter(wf.astype(float), sigma=sigma)
        ax = fig.add_subplot(gs[i])
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_aspect('auto')
        ax.grid(color='#2A3459')  # bluish dark grey, slightly lighter than the background
        _plot_glowing_line(ax, wf)

    plt.show()

if __name__ == '__main__':
    data_path = "data/processed_files/processed_samples_for_nld_mae.pkl"
    # pickle.load can execute arbitrary code: only load files produced by this project.
    with open(data_path, 'rb') as f:
        all_outputs = pickle.load(f)
    # order: ppg, ecg, gsr, eeg_g, acc_x
    # NOTE(review): the original comment said "plot ppg", but under the order listed
    # above index 2 would be gsr (ppg would be 0). Confirm which modality is intended.
    modality_idx = 2
    plot_one(all_outputs[modality_idx][0], layer_id=5, num_plot=64)