import numpy as np
import tensorflow as tf
from . import kernels_rfm
from scipy import stats
import time, sys

from pylab import *

class k2_hawkes_rfm:

    def __init__(self, kernel='gaissian', n_rand_feature=200, seed=0):
        """Build the random-feature kernel object stored as ``self.ker``.

        Args:
            kernel: kernel name forwarded unchanged to ``kernels_rfm``.
            n_rand_feature: number of random features requested from
                ``kernels_rfm``.
            seed: seed forwarded to ``kernels_rfm``.

        The input dimension is fixed to 1 and ``qmc=True`` is always passed.
        """
        # NOTE(review): the default spelling 'gaissian' is kept as-is because
        # the names accepted by kernels_rfm are not visible from this file --
        # confirm whether 'gaussian' was intended.
        # NOTE(review): kernels_rfm is imported with `from . import
        # kernels_rfm` and called here; presumably the package exposes a
        # callable under that name -- verify.
        ker_options = dict(
            n_dim=1,
            kernel=kernel,
            n_rand_feature=n_rand_feature,
            seed=seed,
            qmc=True,
        )
        self.ker = kernels_rfm(**ker_options)

        
    def fit(self, spk, T, gamma, b, mu, support):
        """Store the data and factorise the regularised system matrix.

        Args:
            spk: nested sequence of event times, one inner sequence per node;
                converted with ``tf.ragged.constant``.
            T: value cast to the event dtype and stored as ``self.T``; used
                as an upper integration limit elsewhere in the class.
            gamma: scalar; ``1/gamma`` scales the identity added to the
                matrix returned by ``func_XI``.
            b: scalar multiplying the frequencies in every feature evaluation.
            mu: per-node values, indexed by node in ``predict``.
            support: stored as ``self.sup`` and used as a window length.

        Returns:
            Wall-clock seconds spent in this call (float).
        """
        started = time.time()

        events = tf.ragged.constant(spk)
        dtype = events.dtype
        # Every scalar/vector input is brought to the dtype of the event times.
        T, b, mu, gamma = (tf.cast(v, dtype) for v in (T, b, mu, gamma))
        n_node = events.shape[0]

        self.spk = events
        self.T = T
        self.mu = mu
        self.sup = support
        self.d_type = dtype
        # Both closures capture the cast `b`, so later calls reuse it.
        self.rfm = lambda x: self.ker.rfm(x[:, tf.newaxis], b)
        self.irfm = lambda T0, T1: self.integral_rfm(T0, T1, b)

        xi = self.func_XI(b)
        identity = tf.eye(n_node * self.ker.nrf, dtype=dtype)
        self.chol = tf.linalg.cholesky(xi + 1. / gamma * identity)

        return time.time() - started
        
    @tf.function()
    def func_XI(self, b):
        """Assemble the (n_node*nrf, n_node*nrf) block matrix used by ``fit``.

        For every ordered pair of rows ``(spk1, spk2)`` of ``self.spk`` an
        ``(nrf, nrf)`` block is accumulated by summing, over each event ``s1``
        of ``spk1``, the value of ``integral_rfm_rfm`` on the intervals
        ``[max(s1, s2), min(T, sup + min(s1, s2))]`` for the events ``s2`` of
        ``spk2`` whose interval is non-empty.

        Args:
            b: scalar tensor forwarded to ``integral_rfm_rfm``.

        Returns:
            Tensor of shape ``(nn*mm, nn*mm)`` with ``nn = self.spk.shape[0]``
            and ``mm = self.ker.nrf``; the block at block-row ``i`` and
            block-column ``j`` is the one computed for rows ``i`` and ``j``.
        """

        # One slot per ordered pair of rows, filled in row-major order.
        xi_array = tf.TensorArray(dtype=self.d_type,
                                  size=self.spk.shape[0]**2)
        j = 0
        for spk1 in self.spk:
            for spk2 in self.spk:
                xi = tf.zeros((self.ker.nrf,self.ker.nrf),dtype=self.d_type)
                for s1 in spk1:
                    # Integration limits for s1 against every event of spk2.
                    T0 = tf.maximum(spk2,s1)
                    T1 = tf.minimum(self.T, self.sup + tf.minimum(spk2,s1))
                    # T0 < T1 holds only when |s1 - s2| < sup and the later
                    # of the two events lies before T.
                    mask = T0 < T1 
                    s2 = tf.boolean_mask(spk2, mask)
                    T0, T1 = tf.boolean_mask(T0, mask), tf.boolean_mask(T1, mask)
                    # Repeat s1 so it has the same length as the kept s2.
                    ss1 = s1*tf.ones(tf.shape(s2),dtype=self.d_type)
                    xi += self.integral_rfm_rfm(T0,T1,ss1,s2,b)
                xi_array = xi_array.write(j,xi)
                j += 1

        # [pair, mm, mm] -> [nn, nn, mm, mm] -> [nn, mm, nn, mm] -> 2-D, so
        # that rows are indexed by (node i, feature) and columns by
        # (node j, feature).
        nn, mm = self.spk.shape[0], self.ker.nrf
        blocks = tf.reshape(xi_array.stack(), [nn,nn,mm,mm])
        blocks = tf.transpose(blocks, perm=[0,2,1,3])
        xi_big = tf.reshape(blocks, [nn*mm,nn*mm])
                
        return xi_big
    
    def integral_rfm_rfm(self, T0, T1, x, y, b):
        """Closed-form integral of products of shifted cosine features.

        With ``ww = [omega, omega]`` and phases ``dd = [0, ..., -pi/2, ...]``
        the entry ``(i, j)`` of the result is::

            an * sum_k  int_{T0[k]}^{T1[k]}
                2 * cos(b*ww[i]*(t - x[k]) + dd[i])
                  * cos(b*ww[j]*(t - y[k]) + dd[j]) dt,     an = 1 / nrf

        evaluated through the product-to-sum identity, each cosine integral
        being ``(T1-T0) * sinc(half-width term) * cos(mid-point term)``.

        Args:
            T0, T1: 1-D tensors with the lower / upper integration limits.
            x, y: 1-D tensors of shifts, same length as ``T0``.
            b: scalar tensor; its dtype is used for all constants.

        Returns:
            Tensor of shape ``(len(ww), len(ww))``.
        """
        # NOTE(review): assumes self.ker.omega has self.ker.nrf2 rows so that
        # ww and dd have equal length -- confirm against kernels_rfm.
        dtype = b.dtype
        freq = tf.cast(self.ker.omega, dtype)[:, 0]
        ww = tf.concat([freq, freq], axis=0)
        phase_zero = tf.zeros((self.ker.nrf2,), dtype=dtype)
        phase_quarter = tf.constant(-0.5*np.pi, dtype=dtype) + phase_zero
        dd = tf.concat([phase_zero, phase_quarter], axis=0)
        an = 1. / tf.cast(self.ker.nrf, dtype)

        # Broadcast layout: axis 0 -> feature i, axis 1 -> feature j,
        # axis 2 -> interval k.
        w_i, w_j = ww[:, None, None], ww[None, :, None]
        d_i, d_j = dd[:, None, None], dd[None, :, None]
        T0, T1 = T0[None, None, :], T1[None, None, :]
        x, y = x[None, None, :], y[None, None, :]

        w_sum, w_dif = w_i + w_j, w_i - w_j
        width = T1 - T0
        mid = 0.5*b*(T0 + T1)
        half = 0.5*b*width
        shift_x, shift_y = b*w_i*x, b*w_j*y

        # cos(P + Q) and cos(P - Q) parts of the product-to-sum expansion.
        A1 = tf.cos(mid*w_sum + d_i + d_j - shift_x - shift_y)
        A2 = tf.cos(mid*w_dif + d_i - d_j - shift_x + shift_y)
        A1 *= width*self.sinc(half*w_sum)
        A2 *= width*self.sinc(half*w_dif)

        return tf.reduce_sum(an*(A1 + A2), axis=2)
    
    """
    def integral_rfm_rfm(self, T0, T1, x, y, b):
                
        omega = tf.cast(self.ker.omega,b.dtype)[:,0]
        ww = tf.concat([omega,omega],axis=0)
        d0 = tf.zeros((self.ker.nrf2,),dtype=b.dtype)
        d1 = tf.constant(-0.5*np.pi,dtype=b.dtype) + d0
        dd = tf.concat([d0,d1],axis=0)
        an = 1. / tf.cast(self.ker.nrf,b.dtype)

        T0, T1 = T0[None,:], T1[None,:]
        x, y = x[None,:], y[None,:]

        v_array = tf.TensorArray(dtype=self.d_type,
                                 size=self.ker.nrf)
        for i in tf.range(self.ker.nrf):
            bwTd = 0.5*b*(T0+T1)*(ww[:,None]+ww[i]) \
                + dd[:,None] + dd[i] \
                - ww[:,None]*x - ww[i]*y
            A1 = tf.cos(bwTd)
            bwTd = 0.5*b*(T0+T1)*(ww[:,None]-ww[i]) \
                + dd[:,None] - dd[i] \
                - ww[:,None]*x + ww[i]*y
            A2 = tf.cos(bwTd)
            bwT = 0.5*b*(T1-T0)*(ww[:,None]+ww[i])
            A1 *= (T1-T0)*self.sinc(bwT)
            bwT = 0.5*b*(T1-T0)*(ww[:,None]-ww[i])
            A2 *= (T1-T0)*self.sinc(bwT)
            z = tf.reduce_sum(an * (A1 + A2), axis=1, keepdims=True)
            v_array = v_array.write(i,z)
        v = tf.concat(tf.unstack(v_array.stack(), axis=0), axis=1)
                
        return v
    """

    def integral_rfm(self, T0, T1, b):
        """Integrate the cosine features over ``[T0, T1]`` and sum the intervals.

        With ``ww = [omega, omega]`` and phases ``dd = [0, ..., -pi/2, ...]``
        row ``i`` of the result is::

            an * sum_k int_{T0[k]}^{T1[k]} cos(b*ww[i]*t + dd[i]) dt,
            an = sqrt(1 / nrf2)

        The integral is evaluated as ``(T1-T0) * sinc(h) * cos(m + dd)`` with
        ``h = b*ww*(T1-T0)/2`` and ``m = b*ww*(T0+T1)/2``.  By the
        sum-to-product identity this equals
        ``(sin(b*ww*T1+dd) - sin(b*ww*T0+dd)) / (b*ww)``, but it stays finite
        when ``b*ww`` is exactly zero (the quotient form yields 0/0 = NaN
        there) and it matches the way ``integral_rfm_rfm`` evaluates its
        integrals.

        Args:
            T0, T1: 1-D tensors with the lower / upper integration limits.
            b: scalar tensor; its dtype is used for all constants.

        Returns:
            Tensor of shape ``(len(ww), 1)``.
        """
        omega = tf.cast(self.ker.omega, b.dtype)[:, 0]
        ww = tf.concat([omega, omega], axis=0)
        d0 = tf.zeros((self.ker.nrf2,), dtype=b.dtype)
        d1 = tf.constant(-0.5*np.pi, dtype=b.dtype) + d0
        dd = tf.concat([d0, d1], axis=0)
        an = tf.sqrt(1. / tf.cast(self.ker.nrf2, b.dtype))

        # Rows index features, columns index intervals.
        ww, dd = ww[:, None], dd[:, None]
        T0, T1 = T0[None, :], T1[None, :]

        width = T1 - T0
        half_bw = 0.5 * b * ww
        per_interval = (width * self.sinc(half_bw * width)
                        * tf.cos(half_bw * (T0 + T1) + dd))
        y = tf.reduce_sum(per_interval, axis=1, keepdims=True)

        return an * y
        

    def sinc(self, x):
        """Element-wise unnormalised sinc, ``sin(x) / x``.

        For ``|x| <= 1e-7`` the Taylor polynomial ``1 - x**2/6 + x**4/120``
        is returned instead of the quotient.

        The quotient is evaluated on a guarded copy of ``x`` in which the
        small entries are replaced by 1.  ``tf.where`` evaluates both
        branches, so without the guard the unselected ``sin(0)/0`` is NaN and
        that NaN leaks into the gradient even though the forward value is
        fine.  Forward results are unchanged by the guard.

        Args:
            x: tensor of any shape.

        Returns:
            Tensor with the same shape as ``x``.
        """
        is_large = tf.greater(tf.abs(x), 1.e-7)
        safe_x = tf.where(is_large, x, tf.ones_like(x))
        return tf.where(is_large,
                        tf.sin(safe_x) / safe_x,
                        1. - x**2 / 6. + x**4 / 120.)
        
        
    def predict(self, x, edge):
        """Evaluate the fitted estimate for one directed edge at points ``x``.

        Two coefficient vectors are formed, each tiled over all node blocks,
        solved against the Cholesky factor from ``fit`` and restricted to the
        block of ``edge[1]``:

        * the sum of the features of all lags ``s1 - s2`` with
          ``0 < s1 - s2 < sup`` (``s1`` from node ``edge[0]``, ``s2`` from
          node ``edge[1]``);
        * the feature integrals over ``[0, min(T - s2, sup)]`` summed over
          the events ``s2`` of node ``edge[1]``, weighted by
          ``-mu[edge[0]]``.

        Args:
            x: 1-D array or tensor of evaluation points (it is indexed as
                ``x[:, tf.newaxis]`` by ``self.rfm``).
            edge: ``[node1, node2]`` with node indices ``0 .. N_node-1``,
                meaning the interaction ``node1 <- node2``.

        Returns:
            1-D tensor with one value per entry of ``x``.
        """
        spk1, spk2 = self.spk[edge[0]], self.spk[edge[1]]
        n_node, nrf = self.spk.shape[0], self.ker.nrf
        block = slice(nrf * edge[1], nrf * (edge[1] + 1))

        def solve_block(rhs):
            # Repeat the (nrf, 1) right-hand side for every node block, solve
            # with the stored factor and keep the rows of node edge[1].
            stacked = tf.tile(rhs, (n_node, 1))
            return tf.linalg.cholesky_solve(self.chol, stacked)[block]

        # Term built from event pairs whose lag lies inside the support.
        c_pairs = tf.zeros((nrf, 1), dtype=self.d_type)
        for s1 in spk1:
            s2 = tf.boolean_mask(spk2, (spk2 < s1) & (spk2 > s1 - self.sup))
            c_pairs += tf.reduce_sum(self.rfm(s1 - s2), axis=1, keepdims=True)

        # Term built from the integrated features; the window of each source
        # event is truncated at the end of the observation period.
        T1 = tf.minimum(self.T - spk2, self.sup)
        T0 = 0.0 * T1
        c_integral = self.irfm(T0, T1)

        # Features at the query points are shared by both terms.
        feat_x = self.rfm(x)
        g_a = tf.matmul(feat_x, solve_block(c_pairs), transpose_a=True)
        g_b = tf.matmul(feat_x, solve_block(c_integral), transpose_a=True)
        g_b = - self.mu[edge[0]] * g_b

        return (g_a + g_b)[:, 0]
        
    def predict_integral(self, region):
        """Return ``self.eq_ker.integral_reduce_sum(region, self.d_spk)``.

        ``region`` is first cast to the dtype of ``self.d_spk``.
        """
        # NOTE(review): neither self.d_spk nor self.eq_ker is assigned by any
        # method visible in this class, so this call raises AttributeError
        # unless they are set elsewhere -- confirm whether this method is
        # still in use.
        d_spk = self.d_spk
        return self.eq_ker.integral_reduce_sum(tf.cast(region, d_spk.dtype),
                                               d_spk)

    def predict_integral_squared(self, region):
        """Return ``self.eq_ker.integral_squared_reduce_sum(region, self.d_spk)``.

        ``region`` is first cast to the dtype of ``self.d_spk``.
        """
        # NOTE(review): neither self.d_spk nor self.eq_ker is assigned by any
        # method visible in this class, so this call raises AttributeError
        # unless they are set elsewhere -- confirm whether this method is
        # still in use.
        d_spk = self.d_spk
        return self.eq_ker.integral_squared_reduce_sum(
            tf.cast(region, d_spk.dtype), d_spk)
    
