"""
Plot, as a function of the grid discretization step, the convergence of final
estimates compared to the estimates using EM in a continuous setting.
"""

from trunc_norm_kernel.metric import negative_log_likelihood
from trunc_norm_kernel.model import TruncNormKernel, Intensity
from trunc_norm_kernel.optim import em_truncated_norm
from trunc_norm_kernel.simu import simulate_data
from raised_torch.utils.utils_plot import plot_hist_params
from raised_torch.utils.utils import grid_projection, check_tensor, get_sparse_from_tt
from raised_torch.kernels import compute_kernels
from raised_torch.solver import initialize, training_loop, compute_loss, optimizer
from raised_torch.model import Model
from raised_torch.simu_pp import simu
from tick.hawkes import HawkesEM
from ast import increment_lineno
import numpy as np
import pandas as pd
from pathlib import Path
import itertools
import json
import torch
from tqdm import tqdm
from joblib import Memory, Parallel, delayed, hash
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import seaborn as sns

from tueplots import bundles


# Base font size (in points) used for every text element of the figures.
FONTSIZE = 11

# Global matplotlib settings: uniform font size, LaTeX text rendering with a
# serif family, and a LaTeX preamble redefining \rmdefault and \sfdefault.
plt.rcParams.update({
    'xtick.labelsize': FONTSIZE,
    'ytick.labelsize': FONTSIZE,
    'font.size': FONTSIZE,
    'text.usetex': True,
    'font.family': 'serif',
    'text.latex.preamble': '\\renewcommand{\\rmdefault}{ptm}\\renewcommand{\\sfdefault}{phv}',
})

# Legends are drawn one point smaller than the rest of the text.
plt.rcParams['legend.fontsize'] = FONTSIZE - 1


# On-disk joblib cache used by the ``@memory.cache`` decorated function below.
CACHEDIR = Path('./__cache__')
memory = Memory(CACHEDIR, verbose=0)

# Root folder under which every experiment gets its own sub-folder.
# ``exist_ok=True`` makes the creation idempotent and avoids the
# check-then-create race of ``if not exists(): mkdir()``.
SAVE_RESULTS_PATH = Path('./fig1')
SAVE_RESULTS_PATH.mkdir(parents=True, exist_ok=True)


@memory.cache(ignore=['driver_tt', 'acti_tt', 'init_params', 'cont_params'])
def compute_discretization_error(T, L, seed, driver_tt, acti_tt, init_params,
                                 true_params, cont_params, poisson_intensity):
    """Fit the discretized EM for one grid parameter and report its errors.

    Both timestamp sets are projected with ``grid_projection`` (duplicates
    kept), ``em_truncated_norm`` is run with ``use_dis=True`` and
    ``sfreq=L``, and the resulting estimates are compared (absolute
    difference) with the continuous-EM estimates and with the true
    parameters.

    Parameters
    ----------
    T : forwarded as ``T`` to ``em_truncated_norm`` and stored in the row.
    L : grid parameter given to ``grid_projection`` and used as ``sfreq``.
    seed : only stored in the returned row.
    driver_tt, acti_tt : timestamps to project and fit on.
    init_params : initial parameters forwarded to ``em_truncated_norm``.
    true_params : dict with keys 'baseline', 'alpha', 'm' and 'sigma'.
    cont_params : dict holding at least the same four keys (continuous EM).
    poisson_intensity : only stored in the returned row.

    Returns
    -------
    dict
        One result row: the identifiers, ``time_em``, the ``*_em``,
        ``*_true`` and ``*_cont`` values, then the ``*_em_err_cont`` and
        ``*_em_err_true`` absolute errors.

    Notes
    -----
    The result is cached through ``memory.cache`` with ``driver_tt``,
    ``acti_tt``, ``init_params`` and ``cont_params`` ignored.
    NOTE(review): ``lower``, ``upper`` and ``max_iter`` are read from module
    globals and are not arguments of this function, so they are presumably
    not part of the cache key -- confirm before changing them with a warm
    cache.
    """
    # snap the timestamps onto the grid
    grid_driver_tt = grid_projection(driver_tt, L, remove_duplicates=False)
    grid_acti_tt = grid_projection(acti_tt, L, remove_duplicates=False)

    # discretized EM fit
    fitted, history = em_truncated_norm(
        grid_acti_tt, grid_driver_tt, lower=lower, upper=upper, T=T, sfreq=L,
        use_dis=True, init_params=init_params, alpha_pos=True,
        n_iter=max_iter, verbose=False, disable_tqdm=True, compute_loss=True)
    baseline_em, alpha_em, m_em, sigma_em = fitted
    params_em = dict(baseline=baseline_em, alpha=alpha_em,
                     m=np.array(m_em), sigma=np.array(sigma_em))

    # identifiers and timing first, so they become the leading columns
    row = {'T': T, 'L': L, 'seed': seed,
           'poisson_intensity': poisson_intensity,
           'time_em': history[-1]['time_loop']}

    # raw values: discretized estimates, ground truth, continuous estimates
    for suffix, values in (('_em', params_em),
                           ('_true', true_params),
                           ('_cont', cont_params)):
        for name, value in values.items():
            row[name + suffix] = value

    # absolute errors of the discretized estimates against both references
    for ref_name, reference in (('cont', cont_params), ('true', true_params)):
        for name, estimate in params_em.items():
            row[name + '_em_err_' + ref_name] = np.abs(
                estimate - reference[name])

    return row


def procedure(true_params, poisson_intensity, T, seed):
    """Run one replication: simulate, fit continuous EM, sweep the grids.

    Data are simulated with ``simulate_data`` (``sfreq=None``), initial
    parameters are obtained from ``initialize`` ('smart_start'), the
    continuous EM (``use_dis=False``) is run once, and its last history
    entry is used as the continuous reference. The discretized fit is then
    evaluated for every grid parameter in 10, 20, ..., 90, 100, ..., 900,
    1000, ..., 9000 and 10000.

    Parameters
    ----------
    true_params : dict with keys 'baseline', 'alpha', 'm' and 'sigma'.
    poisson_intensity : forwarded to ``simulate_data``.
    T : forwarded as ``T`` to the simulation, initialization and EM calls.
    seed : seed forwarded to ``simulate_data``.

    Returns
    -------
    pandas.DataFrame
        One row per grid parameter, as returned by
        ``compute_discretization_error``.

    Notes
    -----
    Reads the module globals ``lower``, ``upper``, ``n_drivers`` and
    ``max_iter``.
    """
    # simulate driver and activation timestamps (kernel/intensity unused)
    driver_tt, acti_tt, _kernel, _intensity = simulate_data(
        lower=lower, upper=upper,
        m=true_params['m'], sigma=true_params['sigma'],
        sfreq=None,
        baseline=true_params['baseline'], alpha=true_params['alpha'],
        T=T, n_drivers=n_drivers, seed=seed,
        return_nll=False, verbose=True, poisson_intensity=poisson_intensity)

    # starting point shared by the continuous and the discretized fits
    init_params = initialize(driver_tt, acti_tt, T, initializer='smart_start',
                             lower=lower, upper=upper,
                             kernel_name='gaussian')

    # reference estimates: last history entry of the continuous EM
    _, history = em_truncated_norm(
        acti_tt, driver_tt, lower=lower, upper=upper, T=T, sfreq=None,
        use_dis=False, init_params=init_params, alpha_pos=True,
        n_iter=max_iter, verbose=False, disable_tqdm=False, compute_loss=True)
    history_df = pd.DataFrame(history)
    cont_params = history_df.iloc[-1].to_dict()

    # grid parameters: 1..9 times 10, 100 and 1000, plus 10**4
    grid_sizes = [mantissa * 10**exponent
                  for exponent in (1, 2, 3)
                  for mantissa in range(1, 10)] + [10**4]

    rows = [
        compute_discretization_error(
            T, grid_size, seed, driver_tt, acti_tt, init_params, true_params,
            cont_params, poisson_intensity)
        for grid_size in tqdm(grid_sizes)
    ]

    return pd.DataFrame(rows)


# LaTeX symbol used as panel title for each model parameter.
dict_name_latex = dict(
    baseline=r'$\mu$',
    alpha=r'$\alpha$',
    m=r'$m$',
    sigma=r'$\sigma$',
)


def plot_fig1(folder_name, L_max=10**3):
    """Plot the per-parameter estimation error against the grid step.

    Loads the results saved by ``experiment`` in ``folder_name``, keeps the
    rows with ``L <= L_max`` and draws a 2x2 figure (one panel per
    parameter). For each value of ``T`` and each method (discretized EM and
    continuous EM) the median of the absolute error to the true parameter is
    drawn against ``dt = 1 / L``, with a 25%-75% quantile band computed over
    the rows sharing the same ``T`` and ``dt``. The figure is saved as PNG
    and PDF in ``folder_name`` and then shown.

    Parameters
    ----------
    folder_name : pathlib.Path
        Folder holding ``df_convergence_estimates_em.csv``; the figures are
        written there as well.
    L_max : int, default 10**3
        Largest grid parameter kept; it also sets the right x-limit
        (``1 / L_max``).
    """
    param_names = ['baseline', 'alpha', 'm', 'sigma']
    indexed_params = ['alpha', 'm', 'sigma']
    id_cols = ['T', 'L', 'dt', 'seed']

    # NOTE: despite its '.csv' suffix this file is a pickle (it is written
    # with ``DataFrame.to_pickle`` in ``experiment``).
    df = pd.read_pickle(folder_name / 'df_convergence_estimates_em.csv')

    # These columns hold indexable values (presumably one entry per driver);
    # only the first entry is kept.
    cols = [param + '_em_err_' + suf
            for param in indexed_params
            for suf in ['true', 'cont']]
    cols += [param + suf
             for param in indexed_params
             for suf in ['_cont', '_true', '_em']]
    for col in cols:
        df[col] = df[col].apply(lambda x: x[0])

    # absolute error of the continuous-EM estimates to the ground truth
    for param in param_names:
        df[param + '_cont_err_true'] = np.abs(
            df[param + '_cont'] - df[param + '_true'])

    df['dt'] = 1 / df['L']

    sub_df = df[df['L'] <= L_max]

    # Long format: one frame per method with a shared '<param>_err_true'
    # naming. ``rename``/``assign`` return new frames, so the slice of
    # ``sub_df`` is never mutated (no SettingWithCopyWarning).
    frames = []
    for tag, label in [('cont', 'continuous'), ('em', 'EM')]:
        err_cols = [param + '_' + tag + '_err_true' for param in param_names]
        frames.append(
            sub_df[err_cols + id_cols]
            .rename(columns={col: col.replace('_' + tag, '')
                             for col in err_cols})
            .assign(estimates=label))
    sub_df_final = pd.concat(frames)

    fig, axes = plt.subplots(2, 2, figsize=(5.5, 4), sharex=True)
    axes = axes.reshape(-1)

    T = np.sort(sub_df_final["T"].unique())

    # One colour per T: viridis_r sampled uniformly, first sample dropped.
    # At least 4 colours are generated (the original palette), more when
    # there are more than 4 distinct T so that ``palette[j]`` cannot fail.
    n_colors = max(4, len(T))
    palette = [matplotlib.cm.viridis_r(x)
               for x in np.linspace(0, 1, n_colors + 1)][1:]
    # (value of the 'estimates' column, line format, hatch of the band)
    methods = [("continuous", "--", '/'), ("EM", "o-", None)]

    for i, param in enumerate(param_names):
        ax = axes[i]

        for method, ls, hatch in methods:
            for j, this_T in enumerate(T):
                this_df = sub_df_final.query(
                    "T == @this_T and estimates == @method")
                curve = this_df.groupby("dt")[f'{param}_err_true'].quantile(
                    [0.25, 0.5, 0.75]).unstack()
                ax.loglog(
                    curve.index, curve[0.5], ls, lw=2, c=palette[j],
                    markersize=5, markevery=2
                )
                ax.fill_between(
                    curve.index, curve[0.25], curve[0.75], alpha=0.2,
                    color=palette[j], hatch=hatch,
                    edgecolor=palette[j] if hatch else None
                )
        # decreasing x axis; the right limit follows the L_max filter
        ax.set_xlim(1e-1, 1 / L_max)
        ax.set_title(dict_name_latex[param])
        if (i == 0) or (i == 2):
            ax.set_ylabel(r'$\ell_2$ error')
        if i >= 2:
            ax.set_xlabel(r'$\Delta$')

    bbox_to_anchor = (-0.2, 1.2, 1, 0.01)
    labels_m = ["EM", "Cont. EM"]
    handles_m = [plt.Line2D([], [], c="k", lw=2, marker='o', markersize=5),
                 plt.Line2D([], [], c="k", ls="--", lw=2)]
    axes[1].legend(
        handles_m,
        labels_m,
        ncol=3,
        title="Method",
        bbox_to_anchor=bbox_to_anchor,
        loc="lower left",
    )

    handles_T = [plt.Line2D([], [], c=palette[i], label=t, lw=2)
                 for i, t in enumerate(T)]
    axes[0].legend(
        handles_T,
        # round before the integer formatting so that a log10 falling just
        # below an integer is not truncated to the previous exponent
        # (labels assume T is a power of ten)
        [r"$10^{%d}$" % int(round(float(np.log10(t)))) for t in T],
        ncol=len(T),
        title="$T$",
        bbox_to_anchor=bbox_to_anchor,
        loc="lower right",
    )
    fig.tight_layout()
    plt.savefig(
        folder_name / "fig_convergence_estimates_em-true_cont_true.png",
        bbox_inches='tight')
    plt.savefig(
        folder_name / "fig_convergence_estimates_em-true_cont_true.pdf",
        bbox_inches='tight')
    plt.show()


class NumpyEncoder(json.JSONEncoder):
    """JSON encoder that also serializes numpy arrays and numpy scalars.

    Arrays are converted with ``tolist()`` and numpy scalars (``np.int64``,
    ``np.float32``, ``np.bool_``, ...) with ``item()``. Any other
    unsupported object is delegated to ``json.JSONEncoder.default``, which
    raises ``TypeError``.
    """

    def default(self, obj):
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if isinstance(obj, np.generic):
            # numpy scalar -> closest native Python scalar
            return obj.item()
        return super().default(obj)


def get_folder_name(true_params, poisson_intensity, save_json=True):
    """Return (and create) the results folder of one experiment.

    The folder is a sub-folder of ``SAVE_RESULTS_PATH`` named after the
    joblib ``hash`` of ``[true_params, poisson_intensity]``.

    Parameters
    ----------
    true_params : dict
        True parameters of the experiment; numpy arrays are allowed (they
        are serialized with ``NumpyEncoder``).
    poisson_intensity : value stored under 'poisson_intensity' in the JSON.
    save_json : bool, default True
        If True, (re)write ``experiment_param.json`` in the folder with
        ``poisson_intensity`` and the content of ``true_params``.

    Returns
    -------
    pathlib.Path
        Path of the experiment folder, which exists on return.
    """
    folder_name = SAVE_RESULTS_PATH / hash([true_params, poisson_intensity])
    # idempotent creation, without a separate existence check
    folder_name.mkdir(parents=True, exist_ok=True)

    if save_json:
        experiment_param = {'poisson_intensity': poisson_intensity,
                            **true_params}
        with open(folder_name / 'experiment_param.json', 'w',
                  encoding='utf-8') as f:
            json.dump(experiment_param, f, ensure_ascii=False,
                      indent=4, cls=NumpyEncoder)

    return folder_name


def experiment(true_params, poisson_intensity):
    """Run the full experiment, save its results and plot the figure.

    For each ``T`` in (1000, 10000), ``procedure`` is run in parallel with
    joblib for the seeds 0..49. The accumulated results are re-written to
    ``df_convergence_estimates_em.csv`` (a pickle, despite its suffix) in
    the experiment folder after each ``T``, and ``plot_fig1`` is finally
    called on that folder.

    Parameters
    ----------
    true_params : dict
        True parameters, forwarded to ``procedure``.
    poisson_intensity : forwarded to ``procedure``.
    """
    folder_name = get_folder_name(true_params, poisson_intensity)

    # No empty placeholder frame is used to seed the accumulation:
    # concatenating with an empty DataFrame relies on deprecated pandas
    # behavior for dtype inference.
    df = None
    list_seed = list(range(50))
    for this_T in [1_000, 10_000]:
        new_dfs = Parallel(n_jobs=min(50, len(list_seed)), verbose=1)(
            delayed(procedure)(true_params,
                               poisson_intensity, this_T, this_seed)
            for this_seed in list_seed)
        if df is not None:
            # previous results go last (same row order as before)
            new_dfs.append(df)
        df = pd.concat(new_dfs)
        # checkpoint after every T
        df.to_pickle(folder_name / 'df_convergence_estimates_em.csv')

    plot_fig1(folder_name)

# %%


# define experiment parameters
# NOTE: these module-level names are read as globals by
# compute_discretization_error() and procedure() defined above.
n_drivers = 1
lower, upper = 0, 1
max_iter = 50


# IPython magics: this file is meant to be run cell by cell in an IPython
# session; these two lines are not valid plain-Python syntax.
%matplotlib inline
%pylab inline

# True parameters used for the simulation; alpha, m and sigma are arrays
# with a single entry.
true_params = {
    'baseline': 3,
    'alpha': np.array([1]),
    'm': np.array([0.2]),
    'sigma': np.array([0.1])
}
poisson_intensity = 0.5
# Experiment folder (created if needed; experiment_param.json is written in
# it because save_json defaults to True).
folder_name = get_folder_name(true_params, poisson_intensity)
fig_name = folder_name / "fig_convergence_estimates_em-true_cont_true.png"


# The PNG figure acts as a completion marker: when it already exists only
# the plot is redone from the saved results, otherwise the whole experiment
# is run (which also plots the figure).
if fig_name.exists():
    plot_fig1(folder_name, L_max=10**3)
else:
    experiment(true_params, poisson_intensity)

# %%

# Reload the saved results (a pickle, despite the '.csv' suffix) and rebuild
# the per-parameter errors used for the aggregated export below.
df = pd.read_pickle(folder_name / 'df_convergence_estimates_em.csv')

pre = 'em'
params = ['alpha', 'm', 'sigma']

# columns whose values are indexable: error columns first, then raw values
cols = [f'{param}_{pre}_err_{suf}'
        for param in ['alpha', 'm', 'sigma']
        for suf in ['true', 'cont']]
cols.extend(p + suf for p in params for suf in ['_cont', '_true', '_em'])

# keep only the first entry of each of these values
for col in cols:
    df[col] = df[col].apply(lambda entry: entry[0])

# absolute error of the continuous estimates to the true parameters
for param in ['baseline', 'alpha', 'm', 'sigma']:
    df[f'{param}_cont_err_true'] = (
        df[f'{param}_cont'] - df[f'{param}_true']).abs()

# grid step associated with each grid parameter
df['dt'] = 1 / df['L']


def compute_norm2_error(s, pre='em'):
    """Return the Euclidean norm of the four per-parameter errors of a row.

    Parameters
    ----------
    s : mapping (e.g. a DataFrame row) providing the entries
        ``'<param>_<pre>_err_true'`` for 'baseline', 'alpha', 'm' and
        'sigma'.
    pre : str, default 'em'
        Prefix selecting the set of error entries (e.g. 'em' or 'cont').

    Returns
    -------
    The square root of the sum of the squared errors.
    """
    suffix = '_' + pre + '_err_true'
    squared_errors = np.array([s[name + suffix]**2
                               for name in ('baseline', 'alpha', 'm', 'sigma')])
    return np.sqrt(squared_errors.sum())


# Euclidean norm of the error vector (baseline, alpha, m, sigma), row by
# row, for the discretized and for the continuous estimates.
df['em_err_norm2'] = df.apply(
    lambda x: compute_norm2_error(x, pre='em'), axis=1)
df['cont_err_norm2'] = df.apply(
    lambda x: compute_norm2_error(x, pre='cont'), axis=1)

sub_df = df[df['L'] <= 1e3]

# ``rename`` and ``assign`` return new frames, so the slices of ``sub_df``
# are never modified in place (no SettingWithCopyWarning); the resulting
# columns and their order are unchanged.
sub_df_em = (
    sub_df[['T', 'L', 'dt', 'seed', 'em_err_norm2']]
    .rename(columns={'em_err_norm2': 'err_norm2'})
    .assign(estimates='EM'))

sub_df_cont = (
    sub_df[['T', 'L', 'dt', 'seed', 'cont_err_norm2']]
    .rename(columns={'cont_err_norm2': 'err_norm2'})
    .assign(estimates='continuous'))

# long format, continuous rows first; written in the working directory
sub_df_final = pd.concat([sub_df_cont, sub_df_em])
sub_df_final.to_csv('error_discrete_EM.csv', index=False)

