import numpy as np
import tensorflow as tf
from . import eq_kernels_rfm
import time, sys

class k2_hawkes:
    """Kernel-based estimator of pairwise interaction functions between event streams.

    The class name suggests a Hawkes-process model. Typical use::

        model = k2_hawkes(kernel='gaussian')
        model.fit(spk, T, gamma, b, mu, n_int, support)
        g = model.predict(x, [node1, node2])

    ``kernel`` selects the basic shape ``func``. ``fit`` rescales it to
    ``ker(t, s) = gamma * func(b * (t - s))`` and discretizes ``[0, T]`` on an
    ``n_int``-point grid. It then assembles and LU-factorizes the system
    matrix ``I + V``, which ``predict`` reuses.
    """

    def __init__(self, kernel='gaussian'):
        """Select the kernel shape.

        Args:
            kernel: one of ``'gaussian'`` (``exp(-x**2)``), ``'laplace'``
                (``exp(-|x|)``) or ``'cauchy'`` (``2 / (1 + x**2)``).

        Raises:
            ValueError: if ``kernel`` is not one of the names above.
        """
        # Each kernel defines its shape ``func`` and ``ifunc(x, r)``: the
        # closed-form integral of func(x - s) over s from r[0] to r[1].
        if kernel == 'gaussian':
            self.func = lambda x: tf.exp(-x**2)
            self.ifunc = lambda x, r: 0.5*tf.sqrt(tf.cast(np.pi, x.dtype))*(tf.math.erf(x-r[0])-tf.math.erf(x-r[1]))
        elif kernel == 'laplace':
            self.func = lambda x: tf.exp(-tf.abs(x))
            self.ifunc = lambda x, r: tf.sign(x-r[1])*(tf.exp(-tf.abs(x-r[1]))-1.)-tf.sign(x-r[0])*(tf.exp(-tf.abs(x-r[0]))-1.)
        elif kernel == 'cauchy':
            self.func = lambda x: 2./(1.+x**2)
            self.ifunc = lambda x, r: 2.*tf.atan(x-r[0]) - 2.*tf.atan(x-r[1])
        else:
            raise ValueError(
                f"unknown kernel {kernel!r}; expected 'gaussian', 'laplace' or 'cauchy'")

    def fit(self, spk, T, gamma, b, mu, n_int, support):
        """Discretize the model on a time grid and factorize its linear system.

        Args:
            spk: per-node sequences of event times. They are converted with
                ``tf.ragged.constant``, whose dtype becomes the working dtype.
            T: end of the observation window; the grid spans ``[0, T]``.
            gamma: kernel amplitude.
            b: kernel argument scale (``ker(t, s) = gamma * func(b*(t - s))``).
            mu: per-node values, indexed by the target node in ``predict``.
                These are presumably baseline rates — TODO confirm.
            n_int: number of grid points.
            support: window length used to truncate interactions
                (see ``func_V`` and ``predict``).

        Returns:
            Wall-clock seconds spent in this call.
        """
        elapse_t0 = time.time()

        spk = tf.ragged.constant(spk)
        d_type = spk.dtype
        T = tf.cast(T, d_type)
        b = tf.cast(b, d_type)
        mu = tf.cast(mu, d_type)
        gamma = tf.cast(gamma, d_type)
        n_node = spk.shape[0]

        self.ker = lambda t, s: gamma*self.func(b*(t-s))
        self.iker = lambda t, r: gamma*self.ifunc(b*t, b*r)/b
        self.poi = tf.constant(np.linspace(0., T, n_int), d_type)
        self.spk = spk
        self.T = T
        self.mu = mu
        self.sup = support

        # Rectangle-rule quadrature: every grid point gets weight T/n_int,
        # so the weights sum to T.
        dT = tf.ones(self.poi.shape[0], dtype=d_type)*T/tf.cast(n_int, d_type)
        self.dT = dT

        # Assemble I + V one row block per node. Each func_V call returns an
        # (n_int, n_node*n_int) block. V already carries the gamma factor
        # through self.ker. LU-factorize once so predict() can reuse the factors.
        V = tf.zeros((0, n_int*n_node), d_type)
        for spk1 in spk:
            V = tf.concat([V, self.func_V(self.poi, self.poi, spk1, spk, T, dT, support)], axis=0)
        V += tf.eye(n_int*n_node, dtype=d_type)
        self.lu, self.p = tf.linalg.lu(V)

        return time.time() - elapse_t0

    @tf.function()
    def func_V(self, s, t, spk1, spk2_list, T, dT, sup):
        """Build one row block of the discretized operator V.

        Returns a matrix of shape ``(len(s), n * len(t))``, where
        ``n = spk2_list.shape[0]``. Column ``i * len(t) + k`` corresponds to
        node ``i`` of ``spk2_list`` and grid point ``t[k]``.

        Each entry sums ``ker(s, t[k] - (s1 - z))`` over events ``s1`` in
        ``spk1`` and events ``z`` of node ``i`` that pass the support/horizon
        masks below. Every column is then weighted by the quadrature weight
        ``dT`` of its grid point.
        """
        ss = s[:,None]
        n = spk2_list.shape[0]
        V_array = tf.TensorArray(dtype=s.dtype,
                                 size=spk2_list.shape[0]*t.shape[0])
        j = 0
        for s2 in spk2_list:
            for tt in t:
                v = tf.zeros((s.shape[0],),dtype=s.dtype)
                TT = tf.minimum(sup+s2,T)
                s2_ = tf.boolean_mask(s2, s2+tt < TT)
                for s1 in spk1:
                    mask = (s2_+tt > s1) & (s2_+tt < sup+s1)
                    z = tf.boolean_mask(s2_, mask)
                    v += tf.reduce_sum(self.ker(ss,tt-(s1-z[None,:])),axis=1)
                V_array = V_array.write(j,v[:,tf.newaxis])
                j += 1

        V = tf.concat(tf.unstack(V_array.stack(), axis=0), axis=1)
        V *= tf.tile(dT[None,:],(1,n))

        return V

    def _recent(self, times, t):
        """Return the entries of ``times`` that lie strictly inside ``(t - sup, t)``."""
        return tf.boolean_mask(times, (times < t) & (times > t - self.sup))

    def _solve(self, kk):
        """Tile the grid vector ``kk`` once per node and solve ``(I + V) h = kk``.

        Uses the LU factors computed in ``fit``.
        """
        kk = tf.tile(kk[:, tf.newaxis], multiples=[self.spk.shape[0], 1])
        return tf.linalg.lu_solve(self.lu, self.p, kk)

    def predict(self, x, edge):
        """Evaluate the estimated interaction function for one directed edge.

        Requires ``fit`` to have been called first.

        Args:
            x: 1-D evaluation points. They are cast to the dtype used in ``fit``.
            edge: pair ``[node1, node2]`` selecting the interaction
                node1 <- node2. Nodes are numbered 0 .. n_node-1.

        Returns:
            1-D tensor ``g_a + g_b + g_c + g_d`` with one value per point of ``x``.
        """
        d_type = self.T.dtype
        x = tf.cast(x, d_type)
        spk1, spk2 = self.spk[edge[0]], self.spk[edge[1]]
        mu = self.mu[edge[0]]

        # Row block of V evaluated at x; the two solve terms below share it.
        vv = self.func_V(x, self.poi, spk2, self.spk, self.T, self.dT, self.sup)

        # g_a / g_b: kernel summed over lags s1 - s2. Here s1 is each event of
        # node1, and s2 ranges over the node2 events in (s1 - sup, s1).
        g_a = tf.zeros(x.shape[0], dtype=d_type)
        kk = tf.zeros(self.poi.shape[0], dtype=d_type)
        for s1 in spk1:
            lag = s1 - self._recent(spk2, s1)[None, :]
            g_a += tf.reduce_sum(self.ker(x[:, None], lag), axis=1)
            kk += tf.reduce_sum(self.ker(self.poi[:, None], lag), axis=1)
        g_b = -tf.matmul(vv, self._solve(kk))[:, 0]

        # g_c / g_d: the same sums, with the node1 events replaced by the grid
        # points. Each term is weighted by its grid point's quadrature weight
        # and scaled by mu of node1.
        g_c = tf.zeros(x.shape[0], dtype=d_type)
        kk = tf.zeros(self.poi.shape[0], dtype=d_type)
        for s1, dt in zip(self.poi, self.dT):
            lag = s1 - self._recent(spk2, s1)[None, :]
            g_c += tf.reduce_sum(self.ker(x[:, None], lag), axis=1)*dt
            kk += tf.reduce_sum(self.ker(self.poi[:, None], lag), axis=1)*dt
        g_c = -mu * g_c
        g_d = mu * tf.matmul(vv, self._solve(kk))[:, 0]

        return g_a + g_b + g_c + g_d

    def predict_integral(self, region):
        """Delegate to ``self.eq_ker.integral_reduce_sum(region, self.d_spk)``.

        NOTE(review): neither ``self.d_spk`` nor ``self.eq_ker`` is assigned
        anywhere in this class. This method raises AttributeError unless they
        are set externally. It may be a leftover from an eq_kernels_rfm-based
        model — TODO confirm.
        """
        region = tf.cast(region,self.d_spk.dtype)

        return self.eq_ker.integral_reduce_sum(region,self.d_spk)

    def predict_integral_squared(self, region):
        """Delegate to ``self.eq_ker.integral_squared_reduce_sum(region, self.d_spk)``.

        NOTE(review): like ``predict_integral``, this depends on ``self.d_spk``
        and ``self.eq_ker``, which this class never assigns — TODO confirm.
        """
        region = tf.cast(region,self.d_spk.dtype)

        return self.eq_ker.integral_squared_reduce_sum(region,self.d_spk)
    
