import numpy as np
from scipy import stats

def get_scipy_multivar_gauss(T_max, d):
    """Build a randomly parameterised, frozen ``d``-dimensional Gaussian.

    The covariance is a single draw from an inverse-Wishart distribution
    with ``d`` degrees of freedom and an identity scale matrix.  The mean is
    drawn uniformly from ``[0, T_max)`` independently in every coordinate.
    Both draws use NumPy's global random state, covariance first.

    Returns
    -------
    scipy.stats multivariate_normal frozen distribution
    """
    identity_scale = np.eye(d)
    covariance = stats.invwishart(df=d, scale=identity_scale).rvs()
    centre = np.random.uniform(low=0, high=T_max, size=d)
    return stats.multivariate_normal(mean=centre, cov=covariance)

def create_mixture_of_pdf(T_max, d, N, func):
    """Return a list of ``N`` mixture components built by ``func``.

    ``func(T_max, d)`` is called ``N`` separate times, in order, and each
    return value becomes one element of the list.
    """
    return [func(T_max, d) for _component in range(N)]

def draw_samples(scipy_mixture_gauss, T_max, N_samples, weights=None):
    """Draw samples from a mixture of distributions, truncated to [0, T_max].

    For each sample a component is chosen with probability proportional to
    ``weights`` and ``rvs()`` is called on it.  A draw with any coordinate
    below 0 or above ``T_max`` is rejected, and the component choice and
    draw are repeated until one is accepted.

    Parameters
    ----------
    scipy_mixture_gauss : sequence
        Mixture components; each must provide an ``rvs()`` method
        (e.g. frozen scipy.stats distributions).
    T_max : float
        Upper bound of the accepted region (the lower bound is 0).
    N_samples : int
        Number of accepted samples to return.
    weights : array_like, optional
        Non-negative, unnormalised component weights, one per component.
        Defaults to equal weights, which was the only behaviour before this
        parameter existed.

    Returns
    -------
    ndarray
        The accepted draws stacked along axis 0.

    Raises
    ------
    ValueError
        If ``weights`` does not have one entry per component.

    Notes
    -----
    The rejection loop never terminates if the selectable components put
    (almost) no probability mass inside ``[0, T_max]``.
    """
    if weights is None:
        weights = np.ones(len(scipy_mixture_gauss))
    weights = np.asarray(weights, dtype=float)
    if len(weights) != len(scipy_mixture_gauss):
        raise ValueError('weights must have one entry per mixture component')
    # Normalised once here instead of on every draw (loop invariant).
    probs = weights / np.sum(weights)

    samples = []
    for i in range(N_samples):
        if (i % 100000) == 100000 - 1:
            print(i + 1, 'samples drawn')
        while True:
            # multinomial(1, p) returns a one-hot vector; its nonzero index is
            # the chosen component.  Kept (rather than np.random.choice) so
            # the random stream matches earlier versions for a given seed.
            component = np.flatnonzero(np.random.multinomial(1, probs))[0]
            draw = scipy_mixture_gauss[component].rvs()
            if not np.any((draw < 0) | (draw > T_max)):
                samples.append(draw)
                break
    return np.array(samples)

def get_emperical_pdf(samples, T_max, N_samples, index, loc='middle'):
    """Estimate a normalised histogram density for one column of ``samples``.

    Parameters
    ----------
    samples : ndarray
        2-D array; column ``index`` is histogrammed.
    T_max : float
        Upper end of the range ``[0, T_max]`` used to build the bin grid.
        Assumed positive -- TODO confirm with callers.
    N_samples : int
        Number of grid points in ``np.linspace(0, T_max, N_samples)``
        (despite the name, this is not the number of samples).
    index : int
        Column of ``samples`` to histogram.
    loc : str
        ``'middle'`` uses the midpoints of the grid as bin edges, so bins are
        centred on the interior grid points; any other value uses the grid
        points themselves as bin edges.

    Returns
    -------
    pdf : ndarray
        Density per bin, normalised so that ``sum(pdf * bin_width) == 1``
        over the samples that fall inside the bins.  Bins are half-open
        ``[left, right)``.
    bin_location : ndarray
        Centre of each bin.
    """
    edges = np.linspace(0, T_max, N_samples)
    if loc == 'middle':
        edges = (edges[:-1] + edges[1:]) / 2

    bin_location = (edges[1:] + edges[:-1]) / 2
    n_bins = max(len(edges) - 1, 0)

    # For half-open bins [edges[i], edges[i+1]), searchsorted(side='right')-1
    # yields i.  Values outside the edges (and NaN, which sorts last) get an
    # out-of-range index and are dropped, exactly like the original
    # per-bin comparisons, but in O(n log bins) instead of O(n * bins).
    values = samples[:, index]
    bin_idx = np.searchsorted(edges, values, side='right') - 1
    in_range = (bin_idx >= 0) & (bin_idx < n_bins)
    count = np.bincount(bin_idx[in_range], minlength=n_bins)

    # Divide by the real bin width.  T_max / len(count) equals the bin width
    # only when the grid points are the edges; with loc='middle' there is one
    # bin fewer than grid intervals, which under-scaled the density.
    pdf = count / np.sum(count) / np.diff(edges)
    return pdf, bin_location

def convert_spiketimes_to_discreteincrements(samples, t_init, t_max, dt):
    """Convert event times into a boolean per-interval occupancy matrix.

    The range from ``t_init`` to ``t_max + dt`` is split into consecutive
    half-open intervals of width ``dt``.  They are slightly wider when
    ``(t_max - t_init) / dt`` is not a whole number.  Each column of
    ``samples`` is binned independently.

    Parameters
    ----------
    samples : array_like
        1-D (treated as a single column) or 2-D array of times.
    t_init : float
        Start of the interval grid.
    t_max : float
        Last time that the grid must cover.
    dt : float
        Target interval width.

    Returns
    -------
    X : ndarray of bool, shape (n_intervals, n_columns)
        ``X[n, j]`` is True if column ``j`` has at least one time in
        ``[t_intervals[n], t_intervals[n + 1])``.
    t_intervals : ndarray
        Left edge of each interval.

    Diagnostics are printed (not raised) in two cases.  One is when a column
    contains repeated times.  The other is when the number of occupied
    intervals differs from the number of times, i.e. two times of one column
    share an interval or a time falls outside the grid.
    """
    samples = np.array(samples)
    if samples.ndim == 1:
        samples = samples[:, np.newaxis]

    # Uniqueness is checked per column because occupancy below is per
    # column: repeated times in a column collapse into one interval.
    # (np.unique(samples, axis=1) deduplicates whole columns and never
    # changes the row count, so the old check could never fire.)
    if any(len(np.unique(column)) != len(column) for column in samples.T):
        print('Samples are not unique')

    # Round away float noise before truncating, so that e.g.
    # 0.3 / 0.1 == 2.9999999999999996 counts as 3 whole steps, not 2.
    n_steps = int(round((t_max - t_init) / dt, 9))
    N_intervals = n_steps + 2

    t_intervals = np.linspace(t_init, t_max + dt, N_intervals)

    X = np.array([
        np.any((samples >= left) & (samples < right), axis=0)
        for left, right in zip(t_intervals[:-1], t_intervals[1:])
    ])
    expected = len(samples) * samples.shape[1]
    if np.sum(X) != expected:
        print('np.sum(X):', np.sum(X))
        print('(len(samples) * samples.shape[1]):', expected)
        print('dt is not small enough.')
    return X, t_intervals[:-1]








