from pylab import *
import dill, sys, re
import numpy as np
import matplotlib.pyplot as plt
from kernelhawkes import MultivariateKernelHawkes, MultivariateExponentialHawkes, MultivariateBasisHawkes
from myfunc import k2_hawkes, k2_hawkes_rfm
import tensorflow as tf

import os
import warnings
# Silence TensorFlow's C++ logging.
# NOTE(review): TensorFlow is imported above this line; this variable is generally
# read when TensorFlow loads, so it may need to be set before `import tensorflow`.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
# Suppress warnings (the second filter already covers FutureWarning as a Warning subclass).
warnings.simplefilter(action='ignore', category=FutureWarning)
warnings.simplefilter(action='ignore', category=Warning)
# Hide all GPUs so TensorFlow runs on CPU only.
tf.config.set_visible_devices([], 'GPU')

def reform_spk(spk, n_node=None):
    """Group a flat event list into one list of event times per node.

    Parameters
    ----------
    spk : iterable
        Events; each element ``x`` provides the event time as ``x[0]`` and a
        1-based node index as ``x[1]``.
    n_node : int, optional
        Number of nodes. If omitted, it is inferred as the largest node index
        present in ``spk``, so trailing nodes without events are dropped. Pass
        it explicitly when ``spk`` may not contain every node (e.g. a
        truncated training split).

    Returns
    -------
    list of list
        ``y[k]`` holds the event times of node ``k + 1`` in input order. An
        empty ``spk`` without ``n_node`` gives ``[]``.
    """
    events = list(spk)
    if n_node is None:
        # Original behaviour: node count = largest node id; empty input -> no nodes.
        n_node = max([x[1] for x in events]) if events else 0

    y = [[] for _ in range(n_node)]
    for x in events:
        y[x[1]-1].append(x[0])  # node ids are 1-based

    return y


dfile = 'data_synthetic/3D_EX_T2000.dill'
# NOTE: dill (like pickle) can execute arbitrary code while loading; only load trusted files.
with open(dfile,'rb') as f:
    data = dill.load(f)
d_spk, k_true, T, mu = data['spk'], data['ker'], data['T'], data['mu']
n_node = len(k_true)

# Results go to result/<second component of dfile split on '/' and '.'>/,
# e.g. 'result/3D_EX_T2000/' for the path above.
dir_result = 'result/'+re.split(r'[/.]+',dfile)[1]+'/'
os.makedirs(dir_result, exist_ok=True)

# Cross-validation grids: g enters the estimators as reg=1/g, b as gamma=b**2.
set_g = [0.1, 0.5, 1.0]
set_b = [0.5, 1.0, 1.5]
set_par = [[x,y] for x in set_g for y in set_b]

support = 5     # kernel support passed to every estimator
p_train = 0.8   # fraction of [0, T] used as the training prefix in cross-validation

# Kernel evaluation grid on (0, support] with trapezoidal quadrature weights.
t_ev = linspace(1.e-9,support,1000)
dt = diff(t_ev)
dt_ev = 0.5*r_[dt[0],(dt[1:]+dt[:-1]),dt[-1]]

models  = ['K2H','Bonnet', 'EXP', 'BER', 'GAU']
ise = {m: [] for m in models}   # integrated squared kernel error, one entry per dataset
cpu = {m: [] for m in models}   # fit time, one entry per dataset

def _quad_weights(t):
    """Return trapezoidal quadrature weights for the sorted 1-D grid ``t`` (needs >= 2 points).

    Each interior point gets half the sum of its two neighbouring gaps and each
    endpoint gets half of its single gap, so ``sum(f(t) * w)`` approximates the
    integral of ``f`` over ``[t[0], t[-1]]``.
    """
    gaps = diff(t)
    return 0.5*r_[gaps[0],(gaps[1:]+gaps[:-1]),gaps[-1]]


def _plot_kernels(ax, t_ev, ker, color, n_node):
    """Overlay the estimated kernel matrix ``ker[i][j]`` (sampled on ``t_ev``) on the axes grid."""
    for ii in range(n_node):
        for jj in range(n_node):
            ax[ii,jj].plot(t_ev,ker[ii][jj],color,lw=0.8,alpha=0.8)


def _kernel_ise(k_true, ker, t_ev, dt_ev, n_node):
    """Return the integrated squared error between the true and estimated kernels.

    ``k_true[i][j]`` is a callable evaluated on ``t_ev``; ``ker[i][j]`` holds the
    estimate on the same grid. Squared errors are integrated with the weights
    ``dt_ev`` and summed over all (i, j) pairs.
    """
    err = 0
    for ii in range(n_node):
        for jj in range(n_node):
            err += sum((k_true[ii][jj](t_ev)-ker[ii][jj])**2*dt_ev)
    return err


def _neg_loglik(est, t_grid, dt_grid, hist, events):
    """Return the point-process negative log-likelihood of ``events`` under ``est``.

    The integral term is the weighted sum of ``est.intensity_link`` over ``t_grid``
    (weights ``dt_grid``). For each node ``i``, the event term subtracts the log of
    row ``i`` of the intensity evaluated at that node's event times ``events[i]``.
    NOTE(review): assumes ``intensity_link`` returns one row per node; confirm.
    """
    val = np.sum(est.intensity_link(t_grid,hist)*dt_grid[None,:])
    for i, s in enumerate(events):
        val -= sum(log(est.intensity_link(s,hist)[i]))
    return val


def _heldout_score(est, full, train):
    """Return NLL(full data) - NLL(training prefix), i.e. the score of the held-out segment.

    ``full`` and ``train`` are ``(t_grid, dt_grid, hist, events)`` tuples passed to
    ``_neg_loglik``. Lower is better.
    """
    return _neg_loglik(est, *full) - _neg_loglik(est, *train)


def _basis_hawkes(basis, g, support, num_int):
    """Build the basis-expansion Hawkes estimator shared by the BER/GAU models (reg = 1/g)."""
    return MultivariateBasisHawkes(basis=basis, num_basis=50,
                                   link_param=100, support=support,
                                   reg=1./g, num_int=num_int)


for iii, spk in enumerate(d_spk):

    # True triggering kernel #############################
    # squeeze=False keeps `ax` 2-D so ax[ii,jj] also works when n_node == 1.
    fig, ax = subplots(n_node,n_node,figsize=(12,6),squeeze=False)
    for ii in range(n_node):
        for jj in range(n_node):
            ax[ii,jj].plot(t_ev,k_true[ii][jj](t_ev),'k--',lw=0.8)
            ax[ii,jj].set_xlim(-0.1,t_ev[-1])
            # Select y-axis display limit
            # ax[ii,jj].set_ylim(-1.0, 1.1) # Refractory scenario
            # ax[ii,jj].set_ylim(-0.1, 0.6) # Mutually-exciting scenario

    # Evaluation points for cross validation #############
    # Uniform grid on [0, T] plus points 1e-9 before/after every event time,
    # deduplicated and sorted.
    z = array([x[0] for x in spk])
    t_cr = array(sorted(set(r_[z-1.e-9,z+1.e-9,linspace(0,T,10000)])))
    dt_cr = _quad_weights(t_cr)
    t_cr_tr = t_cr[where(t_cr<=T*p_train)]
    dt_cr_tr = _quad_weights(t_cr_tr)

    # Split data into training and test ##################
    # Training set = events in the first p_train fraction of [0, T].
    spk_tr = [x for x in spk if x[0] <= T*p_train]

    # Format data for score evaluation ###################
    spkk, spkk_tr = reform_spk(spk), reform_spk(spk_tr)
    hist = [array(x)[:,newaxis] for x in spkk]
    hist_tr = [array(x)[:,newaxis] for x in spkk_tr]
    full = (t_cr, dt_cr, hist, spkk)
    train = (t_cr_tr, dt_cr_tr, hist_tr, spkk_tr)

    # Number of integration points for the basis/kernel estimators.
    num_int = max(1000,2*max([len(x) for x in spkk]))

    # Exponential Hawkes #################################
    ######################################################
    model = 'EXP'
    print('MODEL:',model)

    multi_rkhs = MultivariateExponentialHawkes()
    multi_rkhs.fit(spk)
    ker = multi_rkhs.kernel(t_ev)
    _plot_kernels(ax, t_ev, ker, 'g', n_node)
    cpu[model].append(multi_rkhs.fit_time[0])
    ise[model].append(_kernel_ise(k_true, ker, t_ev, dt_ev, n_node))

    # Bernstein (BER) and Gaussian (GAU) basis Hawkes ####
    ######################################################
    # Select reg = 1/g by held-out score, then refit on all data.
    for model, basis, color in [('BER', 'bernstein', 'm'), ('GAU', 'gaussian', 'c')]:
        print('MODEL:',model)

        set_score = []
        for g in set_g:
            multi_rkhs = _basis_hawkes(basis, g, support, num_int)
            multi_rkhs.fit(spk_tr)
            set_score.append(_heldout_score(multi_rkhs, full, train))

        opt_g = array(set_g)[argmin(array(set_score))]

        multi_rkhs = _basis_hawkes(basis, opt_g, support, num_int)
        multi_rkhs.fit(spk)
        ker = multi_rkhs.kernel(t_ev)
        _plot_kernels(ax, t_ev, ker, color, n_node)
        cpu[model].append(multi_rkhs.fit_time[0])
        ise[model].append(_kernel_ise(k_true, ker, t_ev, dt_ev, n_node))

    # Bonnet's Method ####################################
    ######################################################
    model = 'Bonnet'
    print('MODEL:',model)
    print('num_int:',num_int)

    set_score = []
    for (g,b) in set_par:
        multi_rkhs = MultivariateKernelHawkes(link_param=100, support=support,
                                              gamma=b**2, reg=1./g, num_int=num_int)
        multi_rkhs.fit(spk_tr)
        set_score.append(_heldout_score(multi_rkhs, full, train))
        print(g,b,':',set_score[-1])

    [opt_g, opt_b] = array(set_par)[argmin(array(set_score))]
    print([opt_g, opt_b])

    multi_rkhs = MultivariateKernelHawkes(link_param=100, support=support,
                                          gamma=opt_b**2, reg=1./opt_g, num_int=num_int)
    multi_rkhs.fit(spk)
    ker = multi_rkhs.kernel(t_ev)
    _plot_kernels(ax, t_ev, ker, 'b', n_node)
    cpu[model].append(multi_rkhs.fit_time[0])
    ise[model].append(_kernel_ise(k_true, ker, t_ev, dt_ev, n_node))

    # K2 Hawkes ##########################################
    ######################################################
    model = 'K2H'
    print('MODEL:',model)

    k2h = k2_hawkes_rfm(kernel='gaussian', n_rand_feature=100)

    def rate(x, s):
        # Per-node intensity of k2h at times x given event history s, floored at -99.
        return [maximum(y,-99) for y in k2h.intensity(x,s)]

    # Score = integral of squared intensity - 2 * sum of intensities at events
    # (least-squares contrast), evaluated as full-data score minus training score.
    set_score = []
    for (g,b) in set_par:
        _ = k2h.fit(spk_tr, T*p_train, gamma=g, b=b, support=support)

        score = np.sum(array(rate([t_cr]*n_node,spk))**2*dt_cr[None,:]) \
            - 2*sum([sum(x) for x in rate(spkk,spk)])
        score_tr = np.sum(array(rate([t_cr_tr]*n_node,spk_tr))**2*dt_cr_tr[None,:]) \
            - 2*sum([sum(x) for x in rate(spkk_tr,spk_tr)])

        set_score.append(score-score_tr)
        print(g,b,':',set_score[-1])

    [opt_g, opt_b] = array(set_par)[argmin(array(set_score))]
    print([opt_g, opt_b])

    t_fit = k2h.fit(spk, T, gamma=opt_g, b=opt_b, support=support)
    # Evaluate each estimated kernel once and reuse it for plotting and error.
    ker = [[k2h.predict(t_ev,edge=[ii,jj]) for jj in range(n_node)]
           for ii in range(n_node)]
    _plot_kernels(ax, t_ev, ker, 'r', n_node)
    cpu[model].append(t_fit)
    ise[model].append(_kernel_ise(k_true, ker, t_ev, dt_ev, n_node))

    savefig(dir_result+str(iii+1).zfill(3)+'.pdf')
    close('all')

    # Running summary over all datasets processed so far.
    print('')
    print('**ISE**')
    for m in models:
        print(m+':',mean(ise[m]),std(ise[m]))
    print('**CPU**')
    for m in models:
        print(m+':',mean(cpu[m]),std(cpu[m]))

    # Checkpoint after every dataset; `with` guarantees the file is flushed and closed.
    with open(dir_result+'perf.dill','wb') as f:
        dill.dump((ise,cpu),f)
    


