import argparse
import pickle
import torch
import numpy as np
import matplotlib.pyplot as plt
from collections import defaultdict
from synthetic_run import make_data_loader, make_model, generate_Poisson_data  # Hawkes
from easy_tpp.utils.hawkes import HawkesModel, SelfCorrectingModel
from easy_tpp.utils import set_seed
from torch.distributions import Categorical
from easy_tpp.model.torch_model.torch_intensity_free import LogNormalMixtureDistribution, clamp_preserve_gradients
from matplotlib.ticker import MultipleLocator
import matplotlib
# Global plot style: serif fonts, LaTeX text rendering (needs a working LaTeX
# installation at draw time) and a small default figure size.
matplotlib.rcParams['font.family'] = 'serif'
matplotlib.rcParams['text.usetex'] = True
plt.rcParams['figure.figsize'] = (4, 3)
# plt.rcParams["figure.figsize"] = (4, 1)  # Poisson



# EST_COLOR = [173/255, 201/255, 239/255, 1]  # light blue
# # EST_COLOR = [140/255, 201/255, 239/255, 1]  # my blue
# # EST_COLOR = [48/255, 112/255, 173/255, 1]  # blue
# TRUE_COLOR = [204/255, 169/255, 236/255, 1]  # purple
# # TRUE_COLOR = [150/255, 150/255, 239/255, 1]  # my purple
# # EST_COLOR = [229/255, 170/255, 204/255, 1]
# # EST_COLOR = [225/255, 170/255, 237/255, 1]  # light purple
# # TRUE_COLOR = [218/255, 149/255, 52/255, 1]

# Colour of the estimated-intensity curves: RGBA with channels in [0, 1],
# a blue at 80% opacity.
EST_COLOR = [channel / 255 for channel in (68, 156, 218)] + [0.8]
# TRUE_COLOR = [255/255, 100/255, 100/255, 1]  # alternative: red
# Colour of the ground-truth curves ('k' = black).
TRUE_COLOR = 'k'


# Display names for the --model_name values; used in the plot titles.
MODEL_NAMES = dict(
    rmtpp='RMTPP',
    nhp='NHP',
    sahp='SAHP',
    thp='THP',
    dlhp='DLHP (ours)',
    iftpp='IFTPP',
    attnhp='AttNHP',
)

# Short display names for the --ground_truth values.
GROUND_TRUTH = dict(poisson='Poisson', hawkes='Hawkes', self_correcting='SC')

# Base font size for the Poisson plot's ticks, axis labels and title.
FONTSIZE = 12


def get_args():
    """Parse the command-line options for the synthetic-intensity plotting script.

    Returns:
        argparse.Namespace holding the options defined below. Any trailing '/'
        is stripped from ``out_dir`` so callers can append '/<subdir>' directly.

    Raises:
        SystemExit: (from argparse) on an unknown option or an invalid
            ``--model_name`` / ``--ground_truth`` value.
    """
    parser = argparse.ArgumentParser(description='Arguments for synthetic experiments.')
    parser.add_argument('--seed', type=int, default=123, help='Random seed.')
    # ``choices`` rejects unknown names at parse time instead of failing later
    # with NotImplementedError / KeyError deep inside evaluation or plotting.
    parser.add_argument('--model_name', type=str, default='dlhp',
                        choices=['rmtpp', 'nhp', 'sahp', 'iftpp', 'dlhp', 'attnhp', 'thp'],
                        help='Model name: rmtpp | nhp | sahp | iftpp | dlhp | attnhp | thp.')
    parser.add_argument('--ground_truth', type=str, default='hawkes',
                        choices=['poisson', 'hawkes', 'self_correcting'],
                        help='Ground truth intensity to generate seqs. from: poisson | hawkes | self_correcting.')
    parser.add_argument('--num_marks', default=3, type=int, help='Number of marks.')
    # NOTE: argparse only converts *string* defaults, so by default this stays the
    # int 12 and file names built from it contain 'T12'; passing `--max_time 12`
    # explicitly yields the float 12.0, i.e. 'T12.0' in file names.
    parser.add_argument('--max_time', default=12, type=float, help='Right end of observation window.')
    parser.add_argument('--num_seqs', type=int, default=50000, help='Number of generated sequences.')
    parser.add_argument('--num_test_seqs', type=int, default=10000,
                        help='Number of generated test sequences.')  # todo: or 1k
    parser.add_argument('--num_epochs', type=int, default=300, help='Number of epochs for training.')
    parser.add_argument('--num_mc_sample_per_step', type=int, default=1000,
                        help='Number of MC points per inter-event interval.')
    parser.add_argument('--batch_size', type=int, default=32, help='Batch size.')  # 32 for attnhp
    parser.add_argument('--plot_num_seqs', type=int, default=3, help='Number of example sequences to plot.')
    parser.add_argument('--drop_poisson_last_mark', action='store_true',
                        help='If true, evaluate sequences without the last mark and fixed maximum T.')
    # store_false: the attribute is True unless this flag is given on the command line.
    parser.add_argument('--generate_empty_poisson_seq', action='store_false',
                        help='Poisson only. Evaluating on freshly generated (near-)zero-rate sequences is ON '
                             'by default; pass this flag to turn it OFF.')
    parser.add_argument('--out_dir', default='./synthetic/', type=str,
                        help='Dir for saving generated sequences and trained models.')
    args = parser.parse_args()
    args.out_dir = args.out_dir.rstrip("/")
    return args


def get_data(args):
    """Load the pickled data needed for plotting.

    Args:
        args: parsed CLI namespace; reads ``ground_truth``, ``model_name``,
            ``num_seqs``, ``num_test_seqs``, ``num_marks``, ``max_time`` and
            ``out_dir``.

    Returns:
        dict.
        - ``poisson``: the dict stored in
          ``<out_dir>/poisson/<model_name>/data_seq<num_seqs>.pickle``. It is
          expected to hold 'input_data', 'event_times_dict_by_type', 'rates',
          'windows', 'num_marks' and 'T'.
        - ``hawkes`` / ``self_correcting``: ``{'num_marks', 'T'}`` copied from
          the training pickle, plus ``'input_data'`` set to the unpickled test
          data.

    Raises:
        NotImplementedError: for any other ``args.ground_truth``.
        FileNotFoundError: if an expected pickle does not exist.
    """
    # NOTE(review): pickle.load can execute arbitrary code stored in the file;
    # only point --out_dir at pickles this project generated itself.
    out_dir = args.out_dir.rstrip('/')

    if args.ground_truth == 'poisson':
        # Previously hard-coded to './synthetic/' (the --out_dir default), which
        # silently ignored a custom --out_dir for Poisson data only.
        filename = f'{out_dir}/{args.ground_truth}/{args.model_name}/data_seq{args.num_seqs}.pickle'
        with open(filename, 'rb') as f:
            return pickle.load(f)

    if args.ground_truth in ('hawkes', 'self_correcting'):
        base = f'{out_dir}/{args.ground_truth}'
        suffix = f'K{args.num_marks}_T{args.max_time}_no_T.pickle'
        # Only 'num_marks' and 'T' are taken from the training file; the
        # sequences to evaluate come from the test file.
        with open(f'{base}/train_data_seq{args.num_seqs}_{suffix}', 'rb') as f:
            train_data = pickle.load(f)
        with open(f'{base}/test_data_seq{args.num_test_seqs}_{suffix}', 'rb') as f:
            test_data = pickle.load(f)
        return {
            'num_marks': train_data['num_marks'],
            'T': train_data['T'],
            'input_data': test_data,
        }

    raise NotImplementedError(f'Unknown ground truth: {args.ground_truth}')

def get_est_intensity(args, data_loader):
    """Evaluate the trained model's per-mark intensity at sampled times for a few sequences.

    Only the first batch yielded by ``data_loader`` is used (the loop ``break``s
    after it). Only sequences ``0 .. args.plot_num_seqs - 1`` of that batch are
    returned.

    NOTE(review): the trained model is NOT a parameter. It is read from the
    module-level global ``model`` assigned in the ``__main__`` block, so calling
    this from another module raises NameError unless that global exists.

    Args:
        args: parsed CLI namespace; reads ``model_name``, ``plot_num_seqs`` and
            ``num_marks``.
        data_loader: iterable of batches. ``batch.values()`` unpacks into
            (time stamps, inter-event times, marks, non-pad mask, attention mask).

    Returns:
        defaultdict(dict) mapping seq_id -> {
            'ts': 1-D tensor of evaluation times (event time + sampled offset),
                flattened over (event, sample);
            'intensity': tensor reshaped to (-1, args.num_marks), presumably
                one row per entry of 'ts';
            'true_times': list of event times at positions 1..n (position 0 skipped);
            'true_marks': list of the marks at the same positions,
        }
        where n is the number of non-pad positions of the sequence minus one.

    Raises:
        NotImplementedError: for an unsupported ``args.model_name``.
    """
    estimated_intensity = defaultdict(dict)
    # Inference only: no autograd graph is built inside this block.
    with torch.no_grad():
        for i, batch in enumerate(data_loader):
            # Each branch unpacks the batch. Except for iftpp (handled per sequence
            # further down), it sets dts_sample_B_Nm1_G via model.make_dtime_loss_samples
            # and intensity_dts_B_Nm1_G_M. Judging by the names and the dlhp comment,
            # the shapes are presumably (B, N-1, G) and (B, N-1, G, M) -- TODO confirm per model.
            if args.model_name == 'nhp':
                left_hiddens, right_hiddens = model.forward(batch.values())
                ts_BN, dts_BN, marks_BN, batch_non_pad_mask, _ = batch.values()
                dts_sample_B_Nm1_G = model.make_dtime_loss_samples(dts_BN[:, 1:])
                # States at the sampled offsets, computed from the right-hand hidden states
                # of every event except the last.
                state_t_sample_B_Nm1_G_H = model.get_states(right_hiddens[..., :-1, :], dts_sample_B_Nm1_G)
                intensity_dts_B_Nm1_G_M = model.layer_intensity(state_t_sample_B_Nm1_G_H)
            elif args.model_name == 'rmtpp':
                ts_BN, dts_BN, marks_BN, batch_non_pad_mask, _ = batch.values()
                _, right_hiddens_BNH = model.forward(batch.values())  # update for new RMTPP implementation
                right_hiddens_B_Nm1_H = right_hiddens_BNH[..., :-1, :]
                dts_sample_B_Nm1_G = model.make_dtime_loss_samples(dts_BN[:, 1:])
                intensity_dts_B_Nm1_G_M = model.evolve_and_get_intentsity(right_hiddens_B_Nm1_H, dts_sample_B_Nm1_G)
            elif args.model_name == 'dlhp':
                forward_results = model.forward(batch.values())  # N minus 1 comparing with sequence lengths
                right_xs_BNLP, right_us_BNH = forward_results["right_xs_BNLP"], forward_results["right_us_BNH"]
                # Drop the last position of every (non-None) entry so it lines up with the N-1 intervals.
                right_us_BNm1H = [None if right_u_BNH is None else right_u_BNH[:, :-1, :] for right_u_BNH in right_us_BNH]

                ts_BN, dts_BN, marks_BN, batch_non_pad_mask, _ = batch.values()
                dts_sample_B_Nm1_G = model.make_dtime_loss_samples(dts_BN[:, 1:])

                # evaluate intensity at dt_samples for MC *from the left limit* after decay -> shape (B, N-1, MC, M)
                intensity_dts_B_Nm1_G_M = model._evolve_and_get_intensity_at_sampled_dts(
                    right_xs_BNLP[:, :-1],
                    dts_sample_B_Nm1_G,
                    right_us_BNm1H,
                )
            elif args.model_name == 'sahp':
                ts_BN, dts_BN, marks_BN, batch_non_pad_mask, attention_mask = batch.values()
                # Encode every event except the last.
                enc_out = model.forward(ts_BN[:, :-1], dts_BN[:, :-1], marks_BN[:, :-1], attention_mask[:, :-1, :-1])
                dts_sample_B_Nm1_G = model.make_dtime_loss_samples(dts_BN[:, 1:])
                state_t_sample = model.compute_states_at_sample_times(encode_state=enc_out, sample_dtimes=dts_sample_B_Nm1_G)
                intensity_dts_B_Nm1_G_M = model.softplus(state_t_sample)

            elif args.model_name == 'thp':
                ts_BN, dts_BN, marks_BN, batch_non_pad_mask, attention_mask = batch.values()
                # Encode every event except the last.
                enc_out = model.forward(ts_BN[:, :-1], marks_BN[:, :-1], attention_mask[:, :-1, :-1])
                dts_sample_B_Nm1_G = model.make_dtime_loss_samples(dts_BN[:, 1:])
                state_t_sample = model.compute_states_at_sample_times(event_states=enc_out,
                                                                     sample_dtimes=dts_sample_B_Nm1_G)
                intensity_dts_B_Nm1_G_M = model.softplus(state_t_sample)
            elif args.model_name == 'iftpp':
                # Only unpack here; IFTPP intensities are derived per sequence in the loop below.
                ts_BN, dts_BN, marks_BN, batch_non_pad_mask, _ = batch.values()
            elif args.model_name == 'attnhp':
                ts_BN, dts_BN, marks_BN, batch_non_pad_mask, attention_mask = batch.values()
                dts_sample_B_Nm1_G = model.make_dtime_loss_samples(dts_BN[:, 1:])

                # [batch_size, seq_len, num_sample]
                # Absolute sample times = time of the preceding event + sampled offset.
                sample_times = dts_sample_B_Nm1_G + ts_BN[:, :-1].unsqueeze(-1)
                intensity_dts_B_Nm1_G_M = model.compute_intensities_at_sample_times(ts_BN[:, :-1],
                                                                           dts_BN[:, :-1],  # not used
                                                                           marks_BN[:, :-1],
                                                                           sample_times,
                                                                           attention_mask=attention_mask[:, :-1, :-1])
            else:
                raise NotImplementedError

            for seq_id in range(args.plot_num_seqs):
                if args.model_name == 'iftpp':
                    # Build this single sequence's inter-event-time distribution (a log-normal
                    # mixture) and mark distribution from the model's context vectors.
                    event_num_per_seq = sum(batch_non_pad_mask[seq_id, :]) - 1
                    context = model.forward(dts_BN[seq_id, :event_num_per_seq][None, ...], marks_BN[seq_id, :event_num_per_seq][None, ...])

                    # [batch_size, seq_len, 3 * num_mix_components]
                    raw_params = model.linear(context)
                    locs = raw_params[..., :model.num_mix_components]
                    log_scales = raw_params[..., model.num_mix_components: (2 * model.num_mix_components)]
                    log_weights = raw_params[..., (2 * model.num_mix_components):]

                    log_scales = clamp_preserve_gradients(log_scales, -5.0, 3.0)
                    log_weights = torch.log_softmax(log_weights, dim=-1)
                    inter_time_dist = LogNormalMixtureDistribution(
                        locs=locs,
                        log_scales=log_scales,
                        log_weights=log_weights,
                        mean_log_inter_time=model.mean_log_inter_time,
                        std_log_inter_time=model.std_log_inter_time
                    )
                    mark_logits = torch.log_softmax(model.mark_linear(context), dim=-1)
                    mark_dist = Categorical(logits=mark_logits)
                    dts_sample_B_Nm1_G = model.make_dtime_loss_samples(dts_BN[:, 1:])

                dts_sample_Nm1_MC = dts_sample_B_Nm1_G[seq_id, :]
                event_times_N = ts_BN[seq_id, :]  # left end of the interval to start evaluating dts
                event_num_per_seq = sum(batch_non_pad_mask[seq_id, :]) - 1  # exclude last event

                # Evaluation times: each event time (except the last) plus its sampled offsets, flattened.
                ts_per_seq_NumxMC = torch.reshape(
                    event_times_N[:event_num_per_seq, None] + dts_sample_Nm1_MC[:event_num_per_seq], (-1,))
                if args.model_name == 'iftpp':
                    # ts_per_seq_NumxMC = torch.where(ts_per_seq_NumxMC == 0, torch.finfo(torch.float32).eps, ts_per_seq_NumxMC)
                    # ts_per_seq_Num_MC = event_times_N[:event_num_per_seq, None] + dts_sample_Nm1_MC[:event_num_per_seq]  # torch.Size([15, 1000])
                    # Drops the last mark category. NOTE(review): presumably a padding class -- confirm.
                    mark_prob = mark_dist.probs[:, :, :-1]  # torch.Size([1, 15, 3])

                    # ts_per_seq_MC_Num = torch.transpose(ts_per_seq_Num_MC, 0, 1)
                    # ts_per_seq_MC_Num = torch.where(ts_per_seq_MC_Num == 0, torch.finfo(torch.float32).eps, ts_per_seq_MC_Num)

                    dts_per_seq_Num_MC = dts_sample_Nm1_MC[:event_num_per_seq, :][None, ...]  # torch.Size([1, 15, 1000])
                    # In-place: the first sampled offset is replaced by the second, presumably to avoid
                    # evaluating the log-normal density at dt == 0 (TODO confirm the first sample is 0).
                    # This writes through a view into dts_sample_B_Nm1_G; 'ts' above was already
                    # computed from the original offsets and is unaffected.
                    dts_per_seq_Num_MC[..., 0] = dts_per_seq_Num_MC[..., 1]
                    # pdf / cdf are evaluated one sample column at a time (the vectorised call kept
                    # commented out below was presumably not supported -- confirm).
                    pdfs = []  # batch x num x grid_points
                    cdfs = []
                    for g in range(dts_per_seq_Num_MC.shape[-1]):
                        pdfs.append(inter_time_dist.log_prob(dts_per_seq_Num_MC[...,g]).exp())
                        cdfs.append(inter_time_dist.cdf(dts_per_seq_Num_MC[...,g]))
                    pdfs = torch.stack(pdfs, dim=-1)
                    cdfs = torch.stack(cdfs, dim=-1)

                    # pdfs = inter_time_dist.log_prob(dts_per_seq_Num_MC).exp()
                    # cdfs = inter_time_dist.cdf(dts_per_seq_Num_MC)
                    # Intensity of mark k = p(k) * f(dt) / (1 - F(dt)): mark probability times the hazard
                    # rate of the inter-event-time distribution. Broadcasts to (1, Num, MC, M).
                    intensity_est_per_seq_1_Num_MC_M = mark_prob[..., None, :] * pdfs[..., None] / (1 - cdfs[..., None])
                    intensity_est_per_seq_NumxMC_M = torch.reshape(intensity_est_per_seq_1_Num_MC_M, shape=(-1, args.num_marks))
                    # intensity_est_per_seq_NumxMC_M = torch.reshape(torch.transpose(intensity_est_per_seq_MC_Num_M, 0, 1), shape=(-1, args.num_marks))
                else:
                    intensity_est_per_seq_NumxMC_M = torch.reshape(
                        intensity_dts_B_Nm1_G_M[seq_id, :event_num_per_seq, ...], shape=(-1, args.num_marks))

                # Observed events at positions 1..event_num_per_seq (position 0 is skipped).
                true_times = event_times_N[1:event_num_per_seq + 1].tolist()
                true_marks = marks_BN[seq_id, :][1:event_num_per_seq + 1].tolist()
                estimated_intensity[seq_id]['ts'] = ts_per_seq_NumxMC
                estimated_intensity[seq_id]['intensity'] = intensity_est_per_seq_NumxMC_M
                estimated_intensity[seq_id]['true_times'] = true_times
                estimated_intensity[seq_id]['true_marks'] = true_marks
            # Only the first batch is used.
            break
    return estimated_intensity



def plot_Poisson(args, estimated_intensity, data, true_color=TRUE_COLOR, est_color=EST_COLOR):
    """Plot the true piecewise-constant Poisson rate against the model's estimate.

    Draws one figure per entry of ``estimated_intensity`` and shows it with
    ``plt.show()``.

    Args:
        args: parsed CLI namespace; reads ``model_name``. IFTPP gets a taller
            y-range, and the title uses ``MODEL_NAMES[args.model_name]``.
        estimated_intensity: mapping seq_id -> dict with keys 'ts',
            'intensity' (indexed ``[:, k]`` per mark), 'true_times' and
            'true_marks', as built by ``get_est_intensity``.
        data: Poisson data dict; reads 'rates', 'windows' (window ``k`` spans
            ``windows[k][0]`` to ``windows[k][1]``), 'num_marks' and 'T'.
        true_color: colour of the dotted ground-truth rate curve.
        est_color: colour of the estimated intensity curves. This argument was
            previously ignored in favour of the module constant.
    """
    poisson_rate = data['rates']
    poisson_windows = data['windows']
    num_marks = data['num_marks']
    # Same "tab10" palette the __main__ block stores in the global ``cmap``; looked
    # up here so this function also works when the module is imported.
    mark_cmap = plt.get_cmap("tab10")
    lw = 2

    for est in estimated_intensity.values():
        # Fresh figure per sequence (as plot_Hawkes_or_SC does), so sequences are
        # not drawn on top of each other when plt.show() does not close figures.
        plt.figure()

        # Observed events as markers just below zero. The last entry of
        # 'true_times' is skipped (it is treated as the end of the interval).
        for j, t in enumerate(est['true_times'][:-1]):
            plt.plot(t, -0.2, marker='X', color=mark_cmap(est['true_marks'][j]), ms=4)

        # Ground truth: zero between consecutive windows and rate[k] on window k
        # (assumes the windows are sorted and disjoint).
        t = 0
        for k, window in enumerate(poisson_windows):
            start, end = window[0], window[1]
            plt.hlines(0, t, start, linestyles=':', color=true_color, lw=lw)
            plt.hlines(poisson_rate[k], start, end, linestyles=':', color=true_color, lw=lw)
            plt.vlines(start, 0, poisson_rate[k], linestyles=':', color=true_color, lw=lw)
            plt.vlines(end, 0, poisson_rate[k], linestyles=':', color=true_color, lw=lw)
            t = end
        plt.hlines(0, t, data['T'], linestyles=':', color=true_color, lw=lw)

        # Estimated intensity, one curve per mark.
        for k in range(num_marks):
            plt.plot(est['ts'], est['intensity'][:, k], '-',
                     label="k={}".format(k), color=est_color, lw=1.5)

        if args.model_name == 'iftpp':
            plt.ylim(bottom=-0.4, top=max(poisson_rate) + 5)
            plt.yticks([0, 2., 4., 6.])
            plt.yticks([1., 3., 5.], minor=True)
        else:
            plt.ylim(bottom=-0.2, top=max(poisson_rate) + 0.2)
            plt.yticks([0, 0.5, 1])
        plt.xlim(left=0, right=7.5)
        plt.xticks([0, 1., 2., 3., 4., 5., 6., 7.], fontsize=FONTSIZE)
        plt.yticks(fontsize=FONTSIZE)
        plt.xlabel('Time', fontsize=FONTSIZE)
        plt.ylabel('Intensity', fontsize=FONTSIZE)
        plt.title(f'{MODEL_NAMES[args.model_name]} Estimated Intensity', fontsize=FONTSIZE + 2)

        # Save figures
        # if args.generate_empty_poisson_seq:
        #     plt.savefig(f'./synthetic/{args.ground_truth}/{args.model_name}/{args.model_name}_empty_seq_rate1.pdf',
        #                 bbox_inches='tight')
        # elif args.drop_poisson_last_mark:
        #     plt.savefig(f'./synthetic/{args.ground_truth}/{args.model_name}/{args.model_name}_drop_last_mark_rate1.pdf',
        #                 bbox_inches='tight')
        # else:
        #     plt.savefig(f'./synthetic/{args.ground_truth}/{args.model_name}/{args.model_name}_original_seq_rate1.pdf',
        #                 bbox_inches='tight')
        plt.show()

def plot_Hawkes_or_SC(args, estimated_intensity, data, model):
    """Plot true vs. estimated per-mark intensities for Hawkes / self-correcting data.

    Produces one figure per sequence ``0 .. args.plot_num_seqs - 1`` and shows
    it with ``plt.show()``.

    Args:
        args: parsed CLI namespace; reads ``max_time``, ``plot_num_seqs``,
            ``num_marks``, ``ground_truth`` and ``model_name``.
        estimated_intensity: mapping seq_id -> dict with keys 'ts',
            'intensity' (indexed ``[:, k]`` per mark), 'true_times' and
            'true_marks', as built by ``get_est_intensity``.
        data: not used; kept so the call signature stays unchanged.
        model: ground-truth intensity model. It is called with ``tgt_marks``,
            ``tgt_timestamps`` and ``sample_timestamps`` keyword arguments and
            must return ``results['intensities']['all_mark_intensities']``
            (values at the events) and
            ``results['sample_intensities']['all_mark_intensities']`` (values on
            the time grid).
    """
    T = args.max_time
    # Evaluation grid for the ground truth: 100 points per unit of time.
    t_range = np.linspace(0, T, int(100 * T))
    # The "tab10" palette (the same one the __main__ block stores in the global
    # ``cmap``), looked up locally so this function works without that global.
    # Mark k gets colour k in both the true and the estimated curves.
    palette = plt.get_cmap("tab10")

    for seq_id in range(args.plot_num_seqs):
        plt.figure()
        est = estimated_intensity[seq_id]
        sampled_times, sampled_marks = est['true_times'], est['true_marks']

        # NOTE(review): t_range is float64 while sampled_times becomes float32
        # (torch's default dtype); confirm the ground-truth model handles this.
        results = model(
            tgt_marks=torch.tensor(sampled_marks)[None, ...],
            tgt_timestamps=torch.tensor(sampled_times)[None, ...],
            sample_timestamps=torch.tensor(t_range)[None, ...],
        )
        event_intensities = results['intensities']['all_mark_intensities']
        grid_intensities = results['sample_intensities']['all_mark_intensities']

        # Observed events, drawn on the true intensity of their own mark.
        for i, (t, mark) in enumerate(zip(sampled_times, sampled_marks)):
            plt.plot(t, event_intensities[:, i, mark].item(), marker='X', color=palette(mark))

        # All "True" curves first, then all "Est." curves, to fix the legend order.
        for k in range(args.num_marks):
            true_vals = grid_intensities[..., k].squeeze().detach().numpy()
            plt.plot(t_range, true_vals, color=palette(k), linestyle=':', label=r"True $k$={}".format(k + 1))
        for k in range(args.num_marks):
            plt.plot(est['ts'], est['intensity'][:, k], '-', label=r"Est. $k$={}".format(k + 1),
                     color=palette(k), linewidth=0.9)

        # Hawkes: legend top-left; self-correcting: top-right.
        legend_loc = 'upper left' if args.ground_truth == 'hawkes' else 'upper right'
        plt.legend(ncol=2, loc=legend_loc)
        plt.xlim(left=0, right=10)
        plt.ylim(bottom=0, top=3)
        plt.title(f'{MODEL_NAMES[args.model_name]} Estimated Intensity')
        plt.ylabel('Intensity')
        plt.xlabel('Time')
        # plt.savefig(f'./synthetic/{args.ground_truth}/plots/{args.model_name}_seq{seq_id}.pdf',bbox_inches='tight')
        plt.show()


if __name__ == '__main__':
    # This block deliberately runs at module level: get_est_intensity() reads the
    # global ``model`` assigned below. ``cmap`` is likewise kept as a module global
    # for any helper that looks it up.
    args = get_args()
    data = get_data(args)
    # Honour --seed (was hard-coded to 123, which is also the flag's default).
    set_seed(args.seed)

    print('Setting up model...')
    # The mark count stored with the data takes precedence over --num_marks.
    args.num_marks = data['num_marks']

    data_path = args.out_dir + f'/{args.ground_truth}/{args.model_name}/'
    # The '_0' suffix selects one checkpoint among those saved for different random seeds.
    model_path = f"{data_path.rstrip('/')}/model_seq{args.num_seqs}_e{args.num_epochs}_no_T_0.pt"

    if args.ground_truth == 'poisson':
        if args.generate_empty_poisson_seq:
            # --generate_empty_poisson_seq is store_false, so this branch is the DEFAULT.
            # One (near-)zero rate per rate/window of the saved data (was a hard-coded
            # list of three, i.e. only correct for the default 3 marks).
            tiny_rate = torch.finfo(torch.float32).eps
            input_data, _ = generate_Poisson_data(num_seqs=args.num_seqs, K=args.num_marks, T=data['T'],
                                                  rates=[tiny_rate] * len(data['rates']),
                                                  windows=data['windows'], drop_last_mark=True)
        elif args.drop_poisson_last_mark:
            input_data, _ = generate_Poisson_data(num_seqs=args.num_seqs, K=args.num_marks, T=data['T'],
                                                  rates=data['rates'], windows=data['windows'], drop_last_mark=True)
        else:
            input_data = data['input_data']
    elif args.ground_truth in ('hawkes', 'self_correcting'):
        # Ground-truth intensity model, used to draw the reference curves.
        gt_model_cls = HawkesModel if args.ground_truth == 'hawkes' else SelfCorrectingModel
        h_model = gt_model_cls(num_marks=args.num_marks, int_strength=0.5)
        # Everything is evaluated on CPU, so map the checkpoint there as for the trained model.
        h_model.load_state_dict(torch.load(
            f"{args.out_dir.rstrip('/')}/{args.ground_truth}/model_K{args.num_marks}_T{args.max_time}_no_T.pt",
            map_location=torch.device('cpu')))
        input_data = data['input_data']
    else:
        raise NotImplementedError(f'Unknown ground truth: {args.ground_truth}')
    data_loader = make_data_loader(input_data, args.num_marks, args.batch_size)
    model, config = make_model(args.num_marks, args.model_name, plot=True)
    model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))

    print('Evaluating model...')
    model.eval()
    if args.model_name == 'iftpp' and not args.generate_empty_poisson_seq:
        model.loss_integral_num_sample_per_step = 20  # for IFTPP with events use 20
    else:
        # Was hard-coded to 1000 (the flag's default), which ignored --num_mc_sample_per_step.
        model.loss_integral_num_sample_per_step = args.num_mc_sample_per_step
    estimated_intensity = get_est_intensity(args, data_loader)

    print('Generating plots...')
    cmap = plt.get_cmap("tab10")
    # TODO: generate subplots...
    if args.ground_truth == 'poisson':
        plot_Poisson(args, estimated_intensity, data)
    elif args.ground_truth in ('hawkes', 'self_correcting'):
        plot_Hawkes_or_SC(args, estimated_intensity, data, h_model)
    else:
        raise NotImplementedError(f'Unknown ground truth: {args.ground_truth}')

    # TODO: check if DLHP level=1 -> Hawkes