import numpy as np
import scipy.stats as scs
from functools import lru_cache
from tqdm import tqdm
import sys
from scipy.special import logsumexp
from abc import ABCMeta


def Gaussian_kernel_1D(x,loc,scale=1):
    '''Evaluate the univariate Gaussian density N(x; loc, scale^2).

    Parameters:
    ---
    - x: point(s) at which the density is evaluated.
    - loc: mean(s) of the Gaussian; broadcast against x by numpy.
    - scale: standard deviation (default 1).

    Returns:
    - the density value(s) (NOT the log-density).
    '''
    standardized = (loc - x) / scale
    normalizer = np.power(2 * np.pi * (scale) ** 2, -0.5)
    return normalizer * np.exp(-0.5 * np.power(standardized, 2))

def Gaussian_kernel_multivariate(x,loc,scale=1):
    '''Gaussian multivariate density with independent coordinates.

    The density is the product over coordinates of univariate normal
    densities, computed as exp(sum of per-coordinate log-densities).

    Parameters:
    ---
    - x: the point at which the density is evaluated; broadcast against loc.
    - loc: 2-d array of means, one row per component (the sum runs over axis 1).
    - scale: standard deviation shared by all coordinates (default 1).
      It was previously accepted but silently ignored.

    Returns:
    - 1-d array with one density value (NOT log-density) per row of loc.
    '''
    per_coordinate_logpdf = scs.norm.logpdf(x, loc=loc, scale=scale)
    return np.exp(per_coordinate_logpdf.sum(axis=1))

def alpha(n,power):
    '''Return the weight for the n-th iteration.

    n: the iteration index.
    power: the power used for discounting.
    return: 1/(n+1)^power
    '''
    denominator = (n + 1) ** power
    return 1 / denominator


def update(a,K,samples,weights,data, new_version = 2):
    '''Performs one step following algorithm 1 in the PRticle paper.

    The new weights are  weights * (1 + a * (K/D - 1)),  where K is the kernel
    evaluated at (data, samples) and D is the mean over samples of K * weights.

    Parameters:
    ---

    - a: (scalar) the weight for the recursive mixture function.

    - K: (function that outputs T-dimensional vector)
     the kernel function, called as K(data, samples); must return the log_pdf
     (new_version 2) or the pdf (new_version 0,1).

    - samples: (T-dimensional vector)
     the samples from the proposal of the mixing density.

    - weights: (T-dimensional vector) the weights for the proposal of the current mixing density estimate.

    - data: (d-dimensional vector) the current observation.

    - new_version: (int 2,1,0) the version of the update function.
        If ==2 (default, most stable) THE KERNEL NEEDS TO GIVE LOGPROBABILITIES
            and the update is computed with logsumexp and log tricks.
        If ==1 nothing controls the weight values (debugging variant: prints D
            on every call, and on NaN prints diagnostics and returns ['error']).
        Any other value (documented as 0) clamps D from below at 1e-10 and
            replaces NaN weights by 0.

    Returns:
    - (T-dimensional vector) the updated weights, or the list ['error'] when
      new_version == 1 produced NaNs.
    '''
    # The kernel only depends on (data, samples): evaluate it once and reuse it
    # for both the normaliser D and the weight update.
    kernel_values = K(data,samples)

    if new_version == 2: # K is a logpdf, everything is done in log space
        log_weights = np.log(weights)
        # log of D = mean_i( exp(K_i) * weights_i )
        D_log = logsumexp(kernel_values + log_weights - np.log(len(samples)))
        new_weights = np.exp(log_weights + np.log(1 + a*(np.exp(kernel_values - D_log) - 1)))
        if np.isnan(new_weights).any():
            print('NANSSSSSSSSSSSSSSS',a)
        return new_weights

    elif new_version == 1: # no clamps, used for debugging
        D = (kernel_values * weights).mean()
        print(D)
        new_weights = weights * (1 + a*((kernel_values/D) - 1))
        if np.isnan(new_weights).any():
            print(D,np.isnan(weights),
                  np.max(np.nan_to_num(weights,nan=0)),a,
                  np.mean(np.nan_to_num(weights,nan=0)))

            print('-----------------')

            print(D,np.isnan(new_weights),
                  np.max(np.nan_to_num(new_weights,nan=0)),a,
                  np.mean(np.nan_to_num(new_weights,nan=0)))
            # Sentinel kept for backward compatibility with existing callers.
            return ['error']
        return new_weights

    else: # with clamps. Works until dimension of data is 5. Then it forces all weights to be equal -> no IS step.
        # Clamp the normaliser away from zero so the division cannot blow up.
        D = max((kernel_values * weights).mean(),1e-10)
        new_weights = weights * (1 + a*((kernel_values/D) - 1))
        return np.nan_to_num(new_weights,nan=0.0)
    

def PRticle(obs, samples, perm_nb=20, kernel_fn=Gaussian_kernel_1D, update_version = 2, discount_power=1, verbose=False):
    '''
    Performs the PRticle algorithm by iterating the 'update' function over the observations.

    Parameters:
    ---
    - obs: (array-like) the observations to fit the PRticle filter to.
    - samples: (array-like) the samples from the proposal for the mixing density.
    - perm_nb: (integer) the number of permutations to average over.
        The observations are only shuffled when perm_nb > 1.
    - kernel_fn: (function) the kernel function. Should take in two arguments: data, samples,
        and return one value per sample.
    - update_version: (integer) the version of the update function.
        If ==0 then it uses clamps on weights,
            if ==1 it uses nothing to control weight values,
                if ==2 (default, most stable) THE KERNEL NEEDS TO GIVE LOGPROBABILITIES and
                the update function uses logsumexp and log tricks.
    - discount_power: (scalar) the power for discounting in the recursive mixture weights' function.
    - verbose: (boolean) if True, prints the progress of the algorithm.

    Returns:
    - (list of arrays, one per observation) proposal weights, averaged over permutations,
      corresponding to Importance Sampling (IS) samples from the mixing density targets.
      An empty list is returned when obs is empty.

    NOTE(review): entry 0 of the result is never accumulated into and is therefore
    all zeros; entry i (i >= 1) holds the weights after i observations, so the last
    permuted observation is never consumed. Kept as-is for backward compatibility.
    NOTE(review): the default kernel_fn returns a density while the default
    update_version == 2 expects log-densities - confirm which one callers intend.
    '''
    n_obs = len(obs)
    if n_obs == 0:
        # Nothing to fit; avoids indexing obs[0] below.
        return []

    # K outputs a vector of the same length as samples (T): [..., K(data_point, sample_i), ...].
    # Evaluate it once, only to learn T.
    n_particles = len(kernel_fn(obs[0],samples))
    final_w = [np.zeros(n_particles) for _ in range(n_obs)]

    if verbose and (update_version == 2):
        print('Using logpdf version of update function. Kernel should return a logpdf.')

    for _ in range(perm_nb): # average over permutations
        obs_permute = np.random.permutation(obs) if perm_nb > 1 else obs

        steps = range(1, n_obs)
        if verbose: # wrap the iteration in a progress bar
            steps = tqdm(steps)

        current = np.ones(n_particles) # uniform initial weights
        for i in steps:
            current = update(a = alpha(i,power=discount_power),K = kernel_fn,
                             samples = samples,weights = current,
                             data = obs_permute[i-1], new_version = update_version)
            final_w[i] += current

    # average over permutations after all iterations
    return [accumulated/perm_nb for accumulated in final_w]


class Method(metaclass = ABCMeta):
    """
        This abstract base class represents an inference method.

    """

    def __getstate__(self):
        """Return the instance state to pickle, without the 'backend' attribute.

        Cloudpickle is used with the MPIBackend. This function ensures that the backend itself
        is not pickled. Instances that never set a 'backend' attribute are
        handled too (previously this raised KeyError).
        """
        state = self.__dict__.copy()
        # pop with a default: tolerate instances without a backend.
        state.pop('backend', None)
        return state


class PRticle_filter(Method):
    '''
    This class encompasses the PRticle filter algorithm and implements it in parallel.

    Parameters:
    ---
        - backend: (object) the backend object to use for parallelization.
                    It must provide parallelize, map and collect.
        - samples: (array-like) the samples from the proposal for the mixing density.
        - perm_nb: (integer) the number of permutations to run.
        - kernel_fn: (function) the kernel function. Should take in two arguments: data, samples. THE KERNEL NEEDS TO GIVE LOGPROBABILITIES and
                    the update function uses logsumexp and log tricks.
        - discount_power: (scalar) the power for discounting in the recursive mixture weights' function.

    NOTE(review): the default kernel_fn returns a density, not a log-density,
    although the update below expects log-probabilities - confirm the intended default.
    '''

    def __init__(self, backend, samples, perm_nb=20, kernel_fn=Gaussian_kernel_1D, discount_power=1):
        self.backend = backend
        self.samples = samples
        self.perm_nb = perm_nb
        self.kernel_fn = kernel_fn
        self.discount_power = discount_power

    def alpha(self,n,power):
        '''
        Returns the weight for the nth iteration.
        power: the power for discounting
        return: 1/(n+1)^power
        '''
        return 1/((n+1)**power)

    def update(self,a,K,samples,weights,data):
        '''
        Performs one step following algorithm 1 in the PRticle paper.

        Parameters:
        ---

        - a: (scalar) the weight for the recursive mixture function.

        - K: (function that outputs T-dimensional vector)
        the kernel function, called as K(data, samples); returns log_pdf.

        - samples: (T-dimensional vector)
        the samples from the proposal of the mixing density.

        - weights: (T-dimensional vector) the weights for the proposal of the current mixing density estimate.

        - data: (d-dimensional vector) the current observation.

        Returns:
        - (T-dimensional vector) the updated weights.
        '''
        # K outputs a vector of the same length as samples (T): [..., K(data_point, sample_i), ...].
        # It is a logpdf, so the update is carried out in log space.
        # Evaluate the kernel once and reuse it for the normaliser and the update.
        log_kernel = K(data,samples)
        log_weights = np.log(weights)

        # log of D = mean_i( exp(K_i) * weights_i )
        D_log = logsumexp(log_kernel + log_weights - np.log(len(samples)))
        new_weights = np.exp(log_weights + np.log(1 + a*(np.exp(log_kernel - D_log) - 1)))

        return new_weights

    def single_PR(self,obs):
        '''
        Performs the PRticle algorithm on one random permutation of the observations
        by iterating the 'update' method. Estimates the mixing density f.

        The samples, kernel function and discount power are taken from the instance
        (self.samples, self.kernel_fn, self.discount_power).

        Parameters:
        ---
        - obs: (array-like) the observations to fit the PRticle filter to.

        Returns:
        - (list of len(obs) arrays) the proposal weights after 0, 1, ..., len(obs)-1
          observations of the permuted sequence; entry 0 is the all-ones initial weight vector.
        '''

        # initialise weights
        w = [np.ones(len(self.samples))]

        obs_permute = np.random.permutation(obs)

        # Loop over observations to update weights, using the class's own
        # update/alpha so the filter does not depend on module-level helpers.
        for i in range(1,len(obs)):
            new = self.update(a = self.alpha(i,power=self.discount_power),K = self.kernel_fn,
                              samples = self.samples,weights = w[i-1],
                              data = obs_permute[i-1])
            w.append(new)

        return w

    def parallel_PR(self,obs):
        '''
        This function is used to parallelize the PRticle algorithm: it runs
        single_PR once per permutation through the backend.

        Parameters
        ---
        obs : (array-like) the observations to fit the PRticle filter to.

        Returns
        ---
        (numpy array) the collected results of the perm_nb runs of single_PR.
        No averaging over permutations is performed here.
        NOTE(review): presumably of shape (perm_nb, len(obs), T), one entry per
        permutation - depends on what backend.collect returns; callers wanting the
        permutation average would then take the mean over the first axis.
        '''
        # Permutations are drawn here before being distributed; single_PR shuffles
        # its input once more, which leaves the order random.
        obs_parallel = self.backend.parallelize([np.random.permutation(obs) for i in range(self.perm_nb)])
        single_PR_ = lambda x,npc=None: self.single_PR(x)
        weights_parallel = self.backend.map(single_PR_, obs_parallel)
        weights = np.array(self.backend.collect(weights_parallel))

        return weights

