import matplotlib.pyplot as plt
import matplotlib
# matplotlib.use('Agg')
import numpy as np
import brainpy as bp

from utils.mathematical_analysis import firing_time

def draw(spikes, name='Simulator'):
    """Show a 2x2 panel of matrix decompositions of ``spikes``.

    Panels, left-to-right then top-to-bottom:
      1. U from ``np.linalg.svd(spikes)`` (coolwarm colormap)
      2. V^H from the same SVD (coolwarm colormap)
      3. the trailing square block of ``spikes``: its last
         ``spikes.shape[1]`` rows, or every row when there are fewer rows
         than columns (twilight colormap)
      4. Q from ``np.linalg.qr(spikes)`` (default colormap)

    Axis ticks are hidden on every panel and the figure is displayed with
    ``plt.show()``.

    Args:
        spikes: 2-D array to decompose. Presumably (time steps, channels),
            judging by how ``visual`` builds its raster time axis from
            ``shape[0]`` -- TODO confirm.
        name: Currently unused; kept for interface compatibility.
    """
    spikes = np.asarray(spikes)
    n_rows, n_cols = spikes.shape

    # NOTE: full_matrices defaults to True, so U is n_rows x n_rows; this can
    # be very large for long recordings.
    u, _, vh = np.linalg.svd(spikes)
    q, _ = np.linalg.qr(spikes)

    # Clamp at 0: a negative start would silently select only the last
    # (n_cols - n_rows) rows when there are fewer rows than columns.
    tail = spikes[max(n_rows - n_cols, 0):]

    # A fresh figure guarantees we never paint over a previous panel, which
    # otherwise happens under non-interactive backends where show() is a no-op.
    _, axes = plt.subplots(2, 2)
    panels = (
        (u, "coolwarm"),
        (vh, "coolwarm"),
        (tail, "twilight"),
        (q, None),  # None -> Matplotlib's default colormap
    )
    for ax, (matrix, cmap) in zip(axes.flat, panels):
        ax.pcolormesh(matrix, cmap=cmap)
        ax.set_xticks([])
        ax.set_yticks([])

    plt.show()


def visual(real_spikes, spikes, *, steps_per_second=10):
    """Visually compare recorded spikes (``real_spikes``) with simulated ones.

    Shows, in order:
      1. the ``draw`` decomposition panel for ``real_spikes``;
      2. the same panel for ``spikes``;
      3. a box plot of the second value returned by ``firing_time`` for each
         input, labelled "Organoid" and "NOSF";
      4. overlaid raster plots of both matrices (organoid in the default
         color, NOSF in red).

    Args:
        real_spikes: Recorded ("Organoid") spike matrix. Its rows are used as
            time steps on the raster x-axis.
        spikes: Simulated ("NOSF") spike matrix, plotted the same way.
        steps_per_second: Rows per second, used to convert row indices to
            seconds on the raster x-axis. The default of 10 matches the
            previously hard-coded divisor.
    """
    draw(real_spikes, name="Organoid")
    draw(spikes, name="NOSF")

    _, real = firing_time(real_spikes)
    _, simula = firing_time(spikes)

    positions = [0.65, 1]
    fig, ax = plt.subplots()
    # Tick labels are applied separately rather than via boxplot(labels=...):
    # that keyword was renamed to `tick_labels` in Matplotlib 3.9 and is
    # deprecated, while set_xticklabels works on every version.
    ax.boxplot([real, simula], showfliers=True, positions=positions, patch_artist=False)
    ax.set_xticks(positions)
    ax.set_xticklabels(["Organoid", "NOSF"])
    ax.tick_params(axis='x', pad=21)
    plt.show()

    # raster_plot draws on the current figure; start a new one so the rasters
    # never land on the box-plot axes (show() is a no-op on e.g. Agg).
    plt.figure()
    bp.visualize.raster_plot(
        np.arange(real_spikes.shape[0]) / steps_per_second, real_spikes,
        show=False, xlabel='Time (s)', ylabel="MEA index", label="Organoid",
    )
    bp.visualize.raster_plot(
        np.arange(spikes.shape[0]) / steps_per_second, spikes,
        xlabel='Time (s)', ylabel="O/S node index", show=False, color='r', label="NOSF",
    )
    plt.show()