import numpy as np

def firing_time(spikes):
    """Collect the spike times of every neuron in a spike raster.

    Axis 0 of ``spikes`` is treated as time and axis 1 as the neuron index;
    any non-zero entry counts as a spike.

    Parameters
    ----------
    spikes : 2-D array
        Spike raster indexed ``[time_step, neuron]``.

    Returns
    -------
    reco : list of lists
        ``reco[j]`` starts with a sentinel ``0`` followed by the time-step
        indices at which neuron ``j`` fired, in ascending order
        (``np.nonzero`` yields indices in row-major order).
    tot : numpy.ndarray
        All ``reco`` lists concatenated in neuron order, sentinels included.
    """
    spike_idx = np.nonzero(spikes)
    time_, loc_ = spike_idx[0], spike_idx[1]

    # Every neuron's list starts with a 0 entry, so no list is ever empty
    # (np.mean / np.std of an empty list would give nan). Callers' statistics
    # therefore include this 0.
    reco = [[0] for _ in range(spikes.shape[1])]
    for t, s in zip(time_, loc_):
        reco[s].append(t)

    # Flatten in a single pass. The old `tot = tot + reco[i]` re-copied the
    # accumulated list on every iteration, which is quadratic overall.
    tot = [t for times in reco for t in times]
    return reco, np.array(tot)

def inter(reco, spikes):
    """Summarise the longest and shortest inter-spike interval per neuron.

    Parameters
    ----------
    reco : sequence of sequences
        ``reco[i]`` holds the spike times of neuron ``i`` (for example the
        first value returned by ``firing_time``). Duplicate times are
        collapsed with ``np.unique`` before intervals are taken.
    spikes : 2-D array
        Only its shape is used: ``shape[1]`` is the number of neurons
        visited, ``shape[0]`` is the divisor of the result.

    Returns
    -------
    (float, float)
        The sum over neurons of each neuron's maximum interval, and the sum
        of each neuron's minimum interval, both divided by
        ``spikes.shape[0]``. A neuron with fewer than two distinct times
        (including an empty list) contributes 0 to both sums.
    """
    n_neurons = spikes.shape[1]
    max_inter = np.zeros(n_neurons)
    min_inter = np.zeros(n_neurons)
    for i in range(n_neurons):
        times = np.unique(np.asarray(reco[i]))
        # At least two distinct times are needed to form an interval. Using
        # `< 2` also covers an empty list, where np.max would raise.
        if times.size < 2:
            continue  # both entries stay at their initial 0
        gaps = np.diff(times)
        max_inter[i] = np.max(gaps)
        min_inter[i] = np.min(gaps)
    # NOTE(review): normalised by shape[0] (the time axis in firing_time's
    # layout) rather than by the neuron count; confirm this is intended.
    return np.sum(max_inter) / spikes.shape[0], np.sum(min_inter) / spikes.shape[0]

def _neuron_average(reco, spikes, reducer):
    """Apply ``reducer`` to each neuron's list in ``reco``, sum, divide by ``spikes.shape[0]``."""
    return np.sum([reducer(times) for times in reco]) / spikes.shape[0]


def math_analysis(real_spikes, spikes):
    """Print summary statistics for a reference and a candidate spike raster.

    Each printed line shows the value for ``real_spikes`` first and for
    ``spikes`` second. Axis 0 is treated as time and axis 1 as the neuron
    index (the layout ``firing_time`` assumes). Apart from the spectral norm,
    every statistic is divided by ``shape[0]`` of its own raster.

    Printed statistics, in order:

    * spectral norm (largest singular value);
    * total spike count / ``shape[0]``;
    * sum of per-neuron mean spike time / ``shape[0]``;
    * sum of per-neuron spike-time std / ``shape[0]``;
    * the max / min inter-spike-interval summaries from ``inter``.

    The per-neuron spike-time lists come from ``firing_time``. Those lists
    include its leading ``0`` entry, so that 0 also enters the mean and std.
    """
    # Only the singular values are needed. compute_uv=False avoids building
    # the (T x T) and (N x N) factor matrices, which were discarded anyway.
    real_s = np.linalg.svd(real_spikes, compute_uv=False)
    s = np.linalg.svd(spikes, compute_uv=False)
    print("Spectral norm: {}  \t  {}".format(np.max(real_s), np.max(s)))

    real_rate = np.sum(real_spikes) / real_spikes.shape[0]
    rate = np.sum(spikes) / spikes.shape[0]
    print("fireRate: {} \t  {}".format(real_rate, rate))

    # The flattened spike-time array (second return value) is not used here.
    real_reco, _ = firing_time(real_spikes)
    reco, _ = firing_time(spikes)

    real_mean = _neuron_average(real_reco, real_spikes, np.mean)
    mean = _neuron_average(reco, spikes, np.mean)
    print("mean: {} \t  {}".format(real_mean, mean))

    real_std = _neuron_average(real_reco, real_spikes, np.std)
    std = _neuron_average(reco, spikes, np.std)
    print("std: {} \t  {}".format(real_std, std))

    real_max_inter, real_min_inter = inter(real_reco, real_spikes)
    max_inter, min_inter = inter(reco, spikes)
    print("max_interval: {} \t  {}".format(real_max_inter, max_inter))
    print("min_interval: {} \t  {}".format(real_min_inter, min_inter))


