import numpy as np
from scipy import stats
import itertools


def thinning_algorithm(lam_t_vec):
    """
    Ogata's thinning algorithm

    Draws event times on [0, T), with T = len(lam_t_vec). Candidate
    events are generated from a homogeneous Poisson process with rate
    max(lam_t_vec) and each candidate at time s is kept with probability
    lam_t_vec[floor(s)] / max(lam_t_vec), i.e. the intensity is treated
    as constant on every unit-width bin [k, k + 1).

    Example:
    --------
    t_vec = np.arange(100)
    lam_t_vec = 1 + np.sin(t_vec/10)
    t_list = thinning_algorithm(lam_t_vec)

    plt.plot(t_vec,lam_t_vec)
    plt.plot(t_list, np.zeros(len(t_list)), '.')

    Parameters:
    -----------
    lam_t_vec: numpy.ndarray
        Discrete samples from a continuous function

    Returns:
    --------
    t_list: numpy.ndarray
        Samples of the events. The values in this array
        represent the event time, in increasing order. The array
        is empty when no candidate is accepted, when lam_t_vec is
        empty, or when its maximum is not positive.
    """
    lam_t_vec = np.asarray(lam_t_vec)
    T = len(lam_t_vec)
    if T == 0:
        return np.array([])

    lam_max = np.max(lam_t_vec)
    # A non-positive (or NaN) bound cannot drive the candidate process;
    # dividing by it would only produce inf/NaN waiting times.
    if not lam_max > 0:
        return np.array([])

    s = 0
    t_list = []
    while True:
        # Exponential inter-arrival time of the dominating process.
        u = np.random.rand()
        s += -np.log(u) / lam_max
        # `>=` keeps int(s) a valid index into lam_t_vec below.
        if s >= T:
            break
        D = np.random.rand()
        if D <= lam_t_vec[int(s)] / lam_max:
            t_list.append(s)

    # Every stored time is < T by construction, so no trimming is needed
    # and an empty result is returned as an empty array.
    return np.array(t_list)

def convert_spiketimes_to_discreteincrements(samples, t_init, t_max, dt):
    """
    Bin event times into boolean occupancy indicators.

    Parameters:
    -----------
    samples: array-like
        Event times. A 1-D input is treated as a single column; for a
        2-D input every column is binned independently.
    t_init: float
        Left edge of the first bin
    t_max: float
        Left edge of the last bin (the bin edges run from t_init to
        t_max + dt)
    dt: float
        Requested bin width. The edges are built with np.linspace, so
        the width equals dt when (t_max - t_init) / dt is a whole number.

    Returns:
    --------
    X: numpy.ndarray
        Boolean array with one row per bin and one column per column of
        samples. X[n, j] is True when at least one entry of column j
        lies in [t_intervals[n], t_intervals[n + 1]).
    t_intervals: numpy.ndarray
        Left edge of every bin.

    Notes:
    ------
    Two diagnostics are printed (nothing is raised): one when the rows
    of samples are not all distinct, and one when the number of occupied
    cells in X differs from the number of entries in samples.
    """
    samples = np.array(samples)
    if samples.ndim == 1:
        samples = samples[:, np.newaxis]

    # Compare the number of distinct rows (axis=0) with the number of rows.
    # With axis=1 the length of the result is always len(samples), so the
    # warning could never be triggered.
    if len(np.unique(samples, axis=0)) != len(samples):
        print('Samples are not unique')

    N_intervals = int((t_max - t_init) / dt) + 2
    t_intervals = np.linspace(t_init, t_max + dt, N_intervals)

    # One boolean row per half-open bin [lower, upper).
    X = np.array([
        np.any((samples >= lower) & (samples < upper), axis=0)
        for lower, upper in zip(t_intervals[:-1], t_intervals[1:])
    ])

    n_entries = len(samples) * samples.shape[1]
    n_occupied = np.sum(X)
    if n_occupied != n_entries:
        print('np.sum(X):', n_occupied)
        print('(len(samples) * samples.shape[1]):', n_entries)
        print('dt is not small enough.')
    return X, t_intervals[:-1]

def get_mesh(T_max, dt, d):
    """
    Build a regular d-dimensional grid over [0, T_max] in every dimension.

    Parameters:
    -----------
    T_max: float
        Max time for the Poisson process
    dt: float
        Time step (width of a bin)
    d: int
        Number of dimensions (Number of processes)

    Returns:
    --------
    pos: numpy.ndarray
        Returns d dimension mesh, where each dimension
        represents a Poisson process. The array has d axes of
        length int(T_max / dt) + 1 followed by a last axis of
        length d; pos[i_1, ..., i_d] holds the coordinates of that
        grid point.
    """
    numPoints = int(T_max / dt) + 1
    # The grid spans [0, T_max]; a fixed upper bound of 10 would only give
    # a spacing of dt when T_max == 10.
    axis_points = np.linspace(0, T_max, numPoints)

    # 'ij' indexing keeps axis k of the mesh aligned with coordinate k.
    meshgrid = np.array(np.meshgrid(*[axis_points] * d, indexing='ij'))

    # Move the coordinate axis from the front to the back.
    pos = np.ascontiguousarray(np.moveaxis(meshgrid, 0, -1))
    return pos

def get_gaussian_pdf(T_max, dt, d):
    """
    Evaluate one randomly parameterised Gaussian on the mesh from get_mesh.

    The covariance is drawn from an inverse-Wishart distribution with d
    degrees of freedom and identity scale; the mean is a mesh point picked
    with equal probability for every point.

    Parameters:
    -----------
    T_max: float
        Max time for the Poisson process
    dt: float
        Time step (width of a bin)
    d: int
        Number of dimensions (Number of processes)

    Returns:
    --------
    gauss_pdf: numpy.ndarray
        Gaussian pdf evaluated at every mesh point and rescaled
        so that its values sum to one.
    """
    grid = get_mesh(T_max, dt, d)
    grid_shape = grid.shape[:-1]

    # Random covariance first, then the random mean (draw order matters
    # for reproducibility under a fixed seed).
    random_cov = stats.invwishart(d, np.eye(d)).rvs()

    # One-hot multinomial draw with equal weights selects a single grid cell.
    uniform_weights = np.ones(grid_shape).flatten() / np.prod(grid_shape)
    one_hot = np.random.multinomial(1, uniform_weights).reshape(grid_shape)
    chosen_cell = np.where(one_hot)
    centre = grid[chosen_cell][0]

    density = stats.multivariate_normal(centre, random_cov).pdf(grid)
    return density / np.sum(density)

def create_mixture_of_gaussian_pdf(T_max, dt, d, N_gauss):
    """
    Average N_gauss independently drawn Gaussian pdfs with equal weights.

    Parameters:
    -----------
    T_max: float
        Max time for the Poisson process
    dt: float
        Time step (width of a bin)
    d: int
        Number of dimensions (Number of processes)
    N_gauss: int
        Number of Gaussian components in the mixture

    Returns:
    --------
    mixture_of_gauss_pdf: numpy.ndarray
        Element-wise mean of the N_gauss arrays returned by
        get_gaussian_pdf.
    """
    components = []
    for _ in range(N_gauss):
        components.append(get_gaussian_pdf(T_max, dt, d))

    stacked = np.array(components)
    # Equal weight 1 / N_gauss for every component.
    return np.sum(stacked, axis=0) / N_gauss

def get_exact_intensity_samples(pdf, n):
    """
    Draw n multinomial samples over the cells of pdf and report which
    cells were hit.

    Parameters:
    -----------
    pdf: numpy.ndarray
        Cell probabilities; flattened and passed to
        np.random.multinomial as the probability vector.
    n: int
        Number of multinomial trials

    Returns:
    --------
    sample_idx_unravelled: numpy.ndarray
        One row per cell that received at least one draw, holding the
        index of that cell in pdf. A cell hit several times appears
        only once. When no cell is hit the result is an empty 1-D array.
    """
    counts = np.random.multinomial(n, pdf.flatten())
    hit_cells = np.where(counts != 0)[0]

    # Translate flat positions back into indices of the original array.
    unravelled = []
    for flat_pos in hit_cells:
        unravelled.append(np.unravel_index(flat_pos, pdf.shape))
    return np.array(unravelled)

# class multidimensional_intensity:
#     """
#     Sample from a multi-dimensional_intensity to create samples which
#     appear to be drawn from a Poisson distribution

#     Example:
#     --------
#     dt = 0.05
#     T_max = 10
#     XX, YY = np.mgrid[0:T_max:dt, 0:T_max:dt]
    
#     pos = np.empty(XX.shape + (2,))
#     pos[:, :, 0] = XX
#     pos[:, :, 1] = YY

#     c1 = 3
#     c2 = 4
#     c3 = 6
#     c4 = 2

#     rv1 = stats.multivariate_normal([1.5, 8.2], [[2.0, 0.7], [0.3, 1.0]])
#     rv2 = stats.multivariate_normal([7.5, 7.2], [[2.0, 1.0], [0.5, 0.5]])
#     rv3 = stats.multivariate_normal([1.5, 2.7], [[1.0, 0.0], [0.5, 1.5]])
#     rv4 = stats.multivariate_normal([7.5, 2.7], [[2.0, 3.0], [2.5, 5.5]])

#     ground_truth_intensity = (c1 * rv1.pdf(pos)) + (c2 * rv2.pdf(pos)) + (c3 * rv3.pdf(pos)) + (c4 * rv4.pdf(pos))
#     plt.pcolormesh(XX, YY, ground_truth_intensity, cmap='plasma')

#     mdi = multidimensional_intensity(XX, YY, ground_truth_intensity)
#     mdi.set_intensity(50)
#     samples = mdi.get_samples()

#     plt.pcolormesh(XX, YY, ground_truth_intensity, cmap='plasma')
#     plt.plot(samples[:,0], samples[:,1], '.')
#     """
#     def __init__(self, XX, YY, intensity_function):
#         """
#         Parameters:
#         -----------
#         XX: np.ndarray
#             Mesh of the X co-ordinates
#         YY: np.ndarray
#             Mesh of the y co-ordinates
#         intensity_function: np.ndarray
#             Input of a two dimensional intensity function

#         Returns:
#         --------
#         None
#         """
#         self.XX = XX
#         self.YY = YY
#         self.intensity_function = intensity_function
#         self.pdf_function = intensity_function / np.sum(intensity_function)

#     def set_intensity(self, intensity):
#         """
#         Changes the intensity of the intensity function

#         Parameters:
#         -----------
#         Intensity: int
#             Intensity of the intensity function

#         Returns:
#         --------
#         None
#         """
#         self.intensity = intensity
#         self.intensity_function = self.intensity * self.pdf_function

#     def get_samples(self):
#         """
#         Parameters:
#         -----------
#         None

#         Returns:
#         --------
#         Samples: numpy.ndarray
#         """
#         n = np.random.poisson(self.intensity)
#         samples_drawn = np.random.multinomial(n, self.pdf_function.flatten())
#         self.sample_idx = np.where(samples_drawn != 0)[0]
#         XX_flattened = XX.flatten()
#         YY_flattened = YY.flatten()
        
        
#         XX_flattened[self.sample_idx]
#         YY_flattened[self.sample_idx]
        
#         samples = np.vstack([XX_flattened[self.sample_idx] , YY_flattened[self.sample_idx]]).T
#         return samples

#     def get_exact_intensity_samples(self):
#         n = self.intensity
#         samples_drawn = np.random.multinomial(n, self.pdf_function.flatten())
#         self.sample_idx = np.where(samples_drawn != 0)[0]
#         XX_flattened = XX.flatten()
#         YY_flattened = YY.flatten()
        
        
#         XX_flattened[self.sample_idx]
#         YY_flattened[self.sample_idx]
        
#         samples = np.vstack([XX_flattened[self.sample_idx] , YY_flattened[self.sample_idx]]).T
#         return samples
    
    
#     def get_marginal(self, dimension):
#         np.sum(self.pdf_function, axis=dimension)


