import argparse
import pickle

import torch
from torch import nn
import numpy as np
from collections import defaultdict
from tqdm import tqdm

from easy_tpp.config_factory import Config
from easy_tpp.config_factory import DataSpecConfig
from easy_tpp.preprocess import EventTokenizer
from easy_tpp.preprocess.dataset import TPPDataset, get_data_loader
# from easy_tpp.utils.torch_utils import set_seed
from easy_tpp.utils import set_seed, create_folder, count_model_params, set_device

from easy_tpp.model import TorchRMTPP as RMTPP
from easy_tpp.model import TorchDLHP as DLHP
from easy_tpp.model import TorchNHP as _NHP
from easy_tpp.model import TorchSAHP as SAHP
from easy_tpp.model import TorchTHP as THP
from easy_tpp.model import TorchIntensityFree as IntensityFree
from easy_tpp.model import TorchAttNHP as AttNHP
from easy_tpp.model.torch_model.torch_nhp import ContTimeLSTMCell
from easy_tpp.model.torch_model.torch_baselayer import ScaledSoftplus
from easy_tpp.utils.hawkes import HawkesModel, SelfCorrectingModel
import os


class NHP(_NHP):
    """Neural Hawkes Process variant with a learned initial state.

    Based on "The Neural Hawkes Process: A Neurally Self-Modulating Multivariate Point Process",
    NeurIPS 2017, https://arxiv.org/abs/1612.09328.

    After the parent constructor has run, this subclass sets its own ``rnn_cell``,
    ``layer_intensity`` and a trainable ``_init_state`` vector.
    """

    def __init__(self, model_config):
        """Build the model.

        Args:
            model_config (EasyTPP.ModelConfig): config of model specs. ``model_specs`` may carry
                the optional entries ``beta`` (default 1.0) and ``bias`` (default True).
        """
        super().__init__(model_config)
        specs = model_config.model_specs
        self.beta = specs.get('beta', 1.0)
        self.bias = specs.get('bias', True)

        # Continuous-time LSTM cell operating on the hidden size set up by the parent class.
        self.rnn_cell = ContTimeLSTMCell(self.hidden_size)

        # Intensity head (eq. 4a): linear map to one value per event type, then ScaledSoftplus.
        to_event_types = nn.Linear(self.hidden_size, self.num_event_types, bias=self.bias)
        positive_scale = ScaledSoftplus(self.num_event_types)
        self.layer_intensity = nn.Sequential(to_event_types, positive_scale)

        # One flat trainable vector holding the four initial state pieces back to back.
        self._init_state = nn.Parameter(torch.randn(4 * self.hidden_size))

    def get_init_state(self, batch_size):
        """Return the learned initial state, shared across the batch.

        Args:
            batch_size (int): number of sequences in the batch.

        Returns:
            tuple: four tensors ``(c_t, c_bar_t, delta_t, o_t)``, each of shape
                ``(batch_size, hidden_size)``; every row is the same learned vector.
        """
        # (4 * hidden,) -> (1, 4 * hidden) -> (batch, 4 * hidden), without copying the parameter.
        tiled = self._init_state.unsqueeze(0).expand(batch_size, -1)
        cell, cell_bar, decay, out_gate = tiled.chunk(4, dim=1)
        return cell, cell_bar, decay, out_gate


# Maps each --model_name choice to the experiment id passed to Config.build_from_yaml_file.
EXPERIMENT_ID = dict(
    rmtpp='RMTPP_train',
    nhp='NHP_train',
    sahp='SAHP_train',
    thp='THP_train',
    dlhp='DLHP_train',
    iftpp='IntensityFree_train',
    attnhp='AttNHP_train',
)

# Maps each --model_name choice to the model class instantiated by make_model.
# 'nhp' points at the NHP subclass defined in this file, not at the library class directly.
MODELS = dict(
    rmtpp=RMTPP,
    nhp=NHP,
    sahp=SAHP,
    thp=THP,
    dlhp=DLHP,
    iftpp=IntensityFree,
    attnhp=AttNHP,
)


def get_args():
    """Parse the command-line arguments of the synthetic experiments.

    Returns:
        argparse.Namespace: the parsed arguments, post-processed so that

            * ``out_dir`` carries no trailing ``/``;
            * ``poisson_windows`` is always a list of ``[left, right]`` pairs, whether it comes
              from the default or from flat values given on the command line
              (``--poisson_windows 1 2 3 4`` -> ``[[1.0, 2.0], [3.0, 4.0]]``).

    Raises:
        SystemExit: via ``parser.error`` when ``--poisson_windows`` receives an odd number of
            values (in addition to the usual argparse errors).
    """
    parser = argparse.ArgumentParser(description='Arguments for synthetic experiments.')
    parser.add_argument('--seed', nargs='+', type=int, default=[123], help='List of random seeds.')
    parser.add_argument('--model_name', type=str, default='dlhp',
                        help='Model name: rmtpp | nhp | sahp | thp | iftpp | attnhp | dlhp.')
    parser.add_argument('--ground_truth', type=str, default='hawkes',
                        help='Ground truth intensity to generate seqs. from: poisson | hawkes | self_correcting.')
    parser.add_argument('--num_seqs', type=int, default=50000, help='Number of generated sequences.')
    parser.add_argument('--num_epochs', type=int, default=300, help='Number of epochs for training.')
    parser.add_argument('--num_marks', default=3, type=int, help='Number of marks.')
    parser.add_argument('--max_time', default=12, type=float, help='Right end of observation window.')
    parser.add_argument('--poisson_rates', nargs="+", default=[1., 1., 1.], type=float,
                        help='Poisson rates for each mark.')
    parser.add_argument('--poisson_windows', nargs="+", default=[[1., 2.], [3., 4.], [5., 6.]],
                        type=float, help='Left and right ends of each mark.')

    parser.add_argument('--lr', type=float, default=1e-3, help='Learning rate.')
    parser.add_argument('--test_pct', default=0.2, type=float,
                        help="Percentage of test sequences compared to number of training sequences.")
    parser.add_argument('--out_dir', default='./synthetic/', type=str,
                        help='Dir for saving generated sequences and trained models.')
    parser.add_argument('--batch_size', type=int, default=128, help='Batch size.')
    args = parser.parse_args()
    args.out_dir = args.out_dir.rstrip("/")

    # With nargs='+' and type=float argparse produces a flat list of floats, whereas the default
    # (and generate_Poisson_data, which indexes windows[k][0] / windows[k][1]) uses
    # [left, right] pairs. Regroup user-supplied values so both sources have the same shape.
    windows = args.poisson_windows
    if windows and not isinstance(windows[0], (list, tuple)):
        if len(windows) % 2 != 0:
            parser.error('--poisson_windows expects an even number of values (left/right pairs).')
        args.poisson_windows = [list(windows[i:i + 2]) for i in range(0, len(windows), 2)]
    return args



#
# class Hawkes():
#     def __init__(self, marks=3, alpha=None):
#         self.marks = marks
#         self.kernel = 'exponential'
#         self.mu = np.random.uniform(low=0.1, high=0.5, size=(self.marks))  # random between 0.1 to 0.5
#         if alpha is None:
#             self.alpha = np.random.uniform(low=0.3, high=0.8, size=(self.marks))
#         else:
#             self.alpha = alpha  # a_ij  refers to the effect of event i on event j
#         self.beta = np.random.uniform(low=0.8, high=1.2, size=(self.marks))
#
#     def exponential_kernel(self, t, k, sampled_times, sampled_marks):
#         # TODO: beta as matrices? change to list?
#
#         assert (len(sampled_times) == len(sampled_marks))  # each event should have a time and a type
#         if not sampled_times:
#             return 0
#
#         kernel_val = 0
#         if len(self.alpha.shape) == 1:  # 1D array
#             for event_time, event_mark in zip(sampled_times, sampled_marks):
#                 if event_time > t:
#                     break
#                 if event_mark == k:
#                     kernel_val += self.alpha[k] * np.exp(-self.beta[k] * (t - event_time))
#
#         elif len(self.alpha.shape) == 2:  # 2D matrix
#             for event_time, event_mark in zip(sampled_times, sampled_marks):
#                 if event_time > t:
#                     break
#                 kernel_val += self.alpha[event_mark][k] * np.exp(-self.beta[k] * (t - event_time))
#         else:
#             raise ValueError('alpha is not valid')
#         return kernel_val
#
#     def intensity(self, t, k, sampled_times, sampled_marks):
#         if self.kernel == 'exponential':
#             kernel_func = self.exponential_kernel
#         else:
#             raise ValueError('kernel function not supported')
#
#         return self.mu[k] + kernel_func(t, k, sampled_times, sampled_marks)
#
#     def total_intensity(self, t, sampled_times, sampled_marks):
#         return sum(self.intensity(t, k, sampled_times, sampled_marks) for k in range(self.marks))
#
#     def sample(self, right_window, left_window=0, length_limit=None):
#         # https://stackoverflow.com/questions/6076270/lambda-function-in-list-comprehensions
#         # may also use (1) functools.partial or (2) nested lambda
#         intensity_funcs = [lambda t, sampled_times, sampled_marks, k=k:  # set k=k for default values
#                            self.intensity(t, k, sampled_times, sampled_marks) for k in range(self.marks)]
#
#         K = self.marks
#         # dom_rate = self.dom_rate
#
#         sampled_times, sampled_marks = [], []
#         t = left_window
#         while (length_limit is None) or (len(sampled_times) < length_limit):  # update <= to be <
#             # upper bound for each interval of sampled time
#             dom_rate = sum([f(t, sampled_times, sampled_marks) for f in intensity_funcs])
#             t = t - np.log(1 - np.random.uniform()) / dom_rate
#             if t > right_window:
#                 break
#             intensity_vals = [f(t, sampled_times, sampled_marks) for f in intensity_funcs]
#             total_intensity = sum(intensity_vals)
#             a = total_intensity / dom_rate
#             if np.random.uniform() < a:
#                 # sampled_times.append(t)
#                 sampled_times.append(float(t))
#                 sampled_marks.append(np.random.choice(K,
#                                                       p=[intensity / total_intensity for intensity in intensity_vals]))
#         return sampled_times, sampled_marks
#


def generate_Poisson_data(num_seqs, K, T, rates, windows, seed=None, drop_last_mark=False):
    """Sample sequences of windowed homogeneous Poisson events.

    For every index ``k`` of ``rates``, events are drawn with rate ``rates[k]`` inside the window
    ``[windows[k][0], windows[k][1]]`` by accumulating exponential inter-arrival times. All kept
    events are labelled with mark ``0``, regardless of ``k``.

    Each sequence is framed by boundary entries with mark ``-1``: time ``0`` at the start and
    ``T`` at the end. A sequence without any event is ``[0, T]``, except that every tenth
    sequence (``(seq_id + 1) % 10 == 0``) is ``[0, T, T + 1]``.

    Args:
        num_seqs (int): number of sequences to generate.
        K (int): number of marks; only used together with ``drop_last_mark`` to pick index ``K - 1``.
        T (float): right end of the observation window, appended as the closing boundary time.
        rates (Sequence[float]): Poisson rate of each component.
        windows (Sequence[Sequence[float]]): ``[left, right]`` window of each component; must have
            at least ``len(rates)`` entries.
        seed (int, optional): seed for NumPy's global RNG. ``None`` leaves the RNG untouched;
            any integer, including ``0``, is applied.
        drop_last_mark (bool): if True, the events of component ``K - 1`` are still sampled
            (so the random stream is consumed) but not added to the sequence.

    Returns:
        tuple: ``(input_data, event_times_dict)`` where ``input_data`` is a ``defaultdict(list)``
            with keys ``'time_seqs'``, ``'time_delta_seqs'`` and ``'type_seqs'`` (one list per
            sequence), and ``event_times_dict`` maps mark ``0`` to all kept event times across
            all sequences.
    """
    # `if seed:` would silently skip seeding for seed == 0, so compare against None explicitly.
    if seed is not None:
        np.random.seed(seed)
    input_data = defaultdict(list)
    event_times_dict = defaultdict(list)

    for seq_id in range(num_seqs):
        times, marks = [], []

        for k, rate in enumerate(rates):
            left = windows[k][0]
            width = windows[k][1] - left
            t, t_list = 0, []
            while True:
                # Inverse-CDF draw of an exponential inter-arrival time with the given rate.
                t = t - 1 / rate * np.log(1 - np.random.uniform())
                if t > width:
                    break
                t_list.append(float(t) + left)  # shift from window-relative to absolute time

            if drop_last_mark and k == K - 1:
                continue
            times.extend(t_list)
            marks.extend([0] * len(t_list))
            event_times_dict[0].extend(t_list)

        if times:
            # Merge the per-component events into one time-ordered sequence (stable sort).
            ordered = sorted(zip(times, marks), key=lambda event: event[0])
            sampled_times = [0] + [event_time for event_time, _ in ordered] + [T]
            sampled_marks = [-1] + [event_mark for _, event_mark in ordered] + [-1]
        elif (seq_id + 1) % 10 == 0:
            # Every tenth sequence gets one extra boundary entry when it has no events.
            sampled_times = [0, T, T + 1]
            sampled_marks = [-1, -1, -1]
        else:
            sampled_times = [0, T]
            sampled_marks = [-1, -1]

        # Inter-event gaps; the first entry has no predecessor and gets 0.
        dts = [0] + [later - earlier for earlier, later in zip(sampled_times, sampled_times[1:])]

        input_data['time_seqs'].append(sampled_times)
        input_data['time_delta_seqs'].append(dts)
        input_data['type_seqs'].append(sampled_marks)
    return input_data, event_times_dict


def generate_Hawkes_data(num_seqs, marks=3, T=10, alpha=None, seed=None):
    """Sample sequences from a freshly constructed ``HawkesModel``.

    Every sequence starts with a boundary entry at time ``0`` with mark ``-1``. No closing entry
    at ``T`` is appended (note from the original code: IFTPP doesn't support logL with a right
    window).

    Args:
        num_seqs (int): number of sequences to sample.
        marks (int): number of marks, passed to ``HawkesModel`` as ``num_marks``.
        T (float): right end of the sampling window, passed to ``sample_points``.
        alpha: unused; kept so existing callers keep working.
        seed (int, optional): seed for NumPy's global RNG. ``None`` leaves the RNG untouched;
            any integer, including ``0``, is applied. Only NumPy is seeded here.

    Returns:
        tuple: ``(input_data, h)`` where ``input_data`` is a ``defaultdict(list)`` with keys
            ``'time_seqs'``, ``'type_seqs'`` and ``'time_delta_seqs'`` (one list per sequence),
            and ``h`` is the ``HawkesModel`` the sequences were sampled from.
    """
    # `if seed:` would silently skip seeding for seed == 0, so compare against None explicitly.
    if seed is not None:
        np.random.seed(seed)
    h = HawkesModel(num_marks=marks, int_strength=0.5)  # K=3, T=10; K=10, T=1
    input_data = defaultdict(list)
    for _ in tqdm(range(num_seqs)):
        sampled_times, sampled_marks = h.sample_points(None, None, T=T, left_window=0)
        # Only the first row of the returned arrays is used — presumably a batch of size 1;
        # verify against HawkesModel.sample_points.
        sampled_times = [0] + sampled_times.tolist()[0]
        sampled_marks = [-1] + sampled_marks.tolist()[0]

        # Inter-event gaps; the first entry has no predecessor and gets 0.
        dts = [0] + [later - earlier for earlier, later in zip(sampled_times, sampled_times[1:])]
        input_data['time_seqs'].append(sampled_times)
        input_data['time_delta_seqs'].append(dts)
        input_data['type_seqs'].append(sampled_marks)
    return input_data, h


def generate_SC_data(num_seqs, marks=3, T=10, alpha=None, seed=None):
    """Sample sequences from a freshly constructed ``SelfCorrectingModel``.

    Every sequence starts with a boundary entry at time ``0`` with mark ``-1``; no closing entry
    at ``T`` is appended.

    Args:
        num_seqs (int): number of sequences to sample.
        marks (int): number of marks, passed to ``SelfCorrectingModel`` as ``num_marks``.
        T (float): right end of the sampling window, passed to ``sample_points``.
        alpha: unused; kept so existing callers keep working.
        seed (int, optional): seed for NumPy's global RNG. ``None`` leaves the RNG untouched;
            any integer, including ``0``, is applied. Only NumPy is seeded here.

    Returns:
        tuple: ``(input_data, h)`` where ``input_data`` is a ``defaultdict(list)`` with keys
            ``'time_seqs'``, ``'type_seqs'`` and ``'time_delta_seqs'`` (one list per sequence),
            and ``h`` is the ``SelfCorrectingModel`` the sequences were sampled from.
    """
    # TODO: could be combined with Hawkes
    # `if seed:` would silently skip seeding for seed == 0, so compare against None explicitly.
    if seed is not None:
        np.random.seed(seed)
    h = SelfCorrectingModel(num_marks=marks, int_strength=0.5)
    input_data = defaultdict(list)
    for _ in tqdm(range(num_seqs)):
        sampled_times, sampled_marks = h.sample_points(None, None, T=T, left_window=0)
        # Only the first row of the returned arrays is used — presumably a batch of size 1;
        # verify against SelfCorrectingModel.sample_points.
        sampled_times = [0] + sampled_times.tolist()[0]
        sampled_marks = [-1] + sampled_marks.tolist()[0]

        # Inter-event gaps; the first entry has no predecessor and gets 0.
        dts = [0] + [later - earlier for earlier, later in zip(sampled_times, sampled_times[1:])]
        input_data['time_seqs'].append(sampled_times)
        input_data['time_delta_seqs'].append(dts)
        input_data['type_seqs'].append(sampled_marks)
    return input_data, h


def make_data_loader(input_data, num_marks, batch_size=128):
    """Wrap generated sequences in a torch data loader.

    Args:
        input_data (dict): sequences as produced by the ``generate_*_data`` functions.
        num_marks (int): number of event types; also used as the pad token id.
        batch_size (int): number of sequences per batch.

    Returns:
        The loader returned by ``get_data_loader`` for the ``'torch'`` backend.
    """
    data_spec = DataSpecConfig.parse_from_yaml_config(dict(
        num_event_types=num_marks,
        batch_size=batch_size,
        pad_token_id=num_marks,
    ))
    return get_data_loader(
        TPPDataset(input_data),
        'torch',
        EventTokenizer(data_spec),
        batch_size=batch_size,
    )


def make_model(num_marks, model_name, plot=False):
    """Instantiate the requested model from the synthetic-experiment YAML config.

    Args:
        num_marks (int): number of event types; overrides the value from the YAML file and is
            also used as the pad token id.
        model_name (str): key into ``EXPERIMENT_ID`` / ``MODELS``.
        plot (bool): if True, ``use_mc_samples`` is switched off in the model config.

    Returns:
        tuple: ``(model, config)`` — the model moved to the device selected by
            ``config.trainer_config.gpu``, and the full config it was built from.
    """
    # TODO: initialize model using dot dict
    config = Config.build_from_yaml_file(
        'configs/exp_config_synthetic.yaml',
        experiment_id=EXPERIMENT_ID[model_name],
    )
    model_config = config.model_config

    # added for IFTPP
    model_config.mean_log_inter_time = 0.
    model_config.std_log_inter_time = 1.
    if plot:
        model_config.use_mc_samples = False  # TODO: check this

    # TODO: better ways of initializing these models
    # Override the mark-dependent fields from the YAML file with the requested number of marks.
    mark_overrides = (
        ('num_event_types', num_marks),
        ('num_event_types_pad', num_marks + 1),  # TODO: check where this is used
        ('pad_token_id', num_marks),
    )
    for field_name, field_value in mark_overrides:
        setattr(model_config, field_name, field_value)

    model = MODELS[model_name](model_config)
    device = set_device(config.trainer_config.gpu)
    return model.to(device), config



if __name__ == '__main__':
    # Script entry point. For every seed: load or generate training and test sequences from the
    # chosen ground-truth process, train the selected model, evaluate per-event test
    # log-likelihoods, and finally print the results aggregated over seeds.
    # TODO: save config in yaml and load it, set device
    args = get_args()

    # Generated data and the ground-truth model are shared by all trained models and live in
    # <out_dir>/<ground_truth>/; trained model checkpoints go to <out_dir>/<ground_truth>/<model_name>/.
    gt_dir = f"{args.out_dir.rstrip('/')}/{args.ground_truth}"
    gt_model_file = f"{gt_dir}/model_K{args.num_marks}_T{args.max_time}_no_T.pt"
    has_true_model = args.ground_truth in ('hawkes', 'self_correcting')

    res = defaultdict(list)
    for seed_i, seed in enumerate(args.seed):
        print(f'Current seed: {seed}')
        set_seed(seed)
        data_path = args.out_dir + f'/{args.ground_truth}/{args.model_name}/'
        create_folder(data_path)

        # ------------------------------------------------------------------ training data
        print(f'Generating {args.num_seqs} sequences from {args.ground_truth} processes for training...')

        train_file = f"{gt_dir}/train_data_seq{args.num_seqs}_K{args.num_marks}_T{args.max_time}_no_T.pickle"
        if os.path.exists(train_file):
            # NOTE: pickle is only safe here because the file is a cache written by this script.
            with open(train_file, 'rb') as f:
                data = pickle.load(f)
            input_data = data['input_data']

            if args.ground_truth == 'hawkes':
                h_model = HawkesModel(num_marks=args.num_marks, int_strength=0.5)
            elif args.ground_truth == 'self_correcting':
                h_model = SelfCorrectingModel(num_marks=args.num_marks, int_strength=0.5)
            elif args.ground_truth == 'poisson':
                # Poisson data has no saved ground-truth model; previously this path raised
                # NotImplementedError, so cached poisson data could never be reused.
                h_model = None
            else:
                raise NotImplementedError
            if h_model is not None:
                h_model.load_state_dict(torch.load(gt_model_file))
            print("Training data loaded successfully.")
        else:
            print("Training data does not exist.")

            if args.ground_truth == 'poisson':
                input_data, event_times_dict_by_type = generate_Poisson_data(
                    args.num_seqs, args.num_marks, args.max_time,
                    args.poisson_rates, args.poisson_windows, seed,
                    drop_last_mark=False)
                data = {'input_data': input_data, 'event_times_dict_by_type': event_times_dict_by_type,
                        'T': args.max_time, 'rates': args.poisson_rates, 'windows': args.poisson_windows,
                        'num_marks': args.num_marks}
            elif args.ground_truth == 'hawkes':
                input_data, h_model = generate_Hawkes_data(args.num_seqs, args.num_marks, args.max_time, seed=seed)
                torch.save(h_model.state_dict(), gt_model_file)
                data = {'input_data': input_data, 'T': args.max_time, 'num_marks': args.num_marks}
            elif args.ground_truth == 'self_correcting':
                input_data, h_model = generate_SC_data(args.num_seqs, args.num_marks, args.max_time, seed=seed)
                torch.save(h_model.state_dict(), gt_model_file)
                data = {'input_data': input_data, 'T': args.max_time, 'num_marks': args.num_marks}
            else:
                raise NotImplementedError

            # Only newly generated data needs writing; re-pickling loaded data is pure overhead.
            print('Saving generated sequences...')
            with open(train_file, 'wb') as f:
                pickle.dump(data, f)
            print('Generated sequences saved.')

        # ------------------------------------------------------------------ test data
        set_seed(seed)
        num_test_seqs = int(args.num_seqs * args.test_pct)
        print(f'Generating {num_test_seqs} test sequences from {args.ground_truth} processes...')
        test_file = f"{gt_dir}/test_data_seq{num_test_seqs}_K{args.num_marks}_T{args.max_time}_no_T.pickle"
        if os.path.exists(test_file):
            with open(test_file, 'rb') as f:
                test_data = pickle.load(f)
            print("Test data loaded successfully.")
        else:
            print("Test data does not exist.")

            # NOTE(review): the RNG is reset to the same seed that produced the training data, so
            # the test sequences may coincide with the first training sequences — confirm whether
            # that is intended. Changing it needs a way to sample from the existing h_model.
            # For hawkes / self_correcting, h_model is replaced by the model that generated the
            # test data, so the "true" log-likelihood below is computed with that model.
            if args.ground_truth == 'poisson':
                test_data, event_times_dict_by_type = generate_Poisson_data(
                    num_test_seqs, args.num_marks,
                    args.max_time, args.poisson_rates,
                    args.poisson_windows, seed,
                    drop_last_mark=False)
            elif args.ground_truth == 'hawkes':
                test_data, h_model = generate_Hawkes_data(num_test_seqs, args.num_marks, args.max_time, seed=seed)
            elif args.ground_truth == 'self_correcting':
                test_data, h_model = generate_SC_data(num_test_seqs, args.num_marks, args.max_time, seed=seed)
            else:
                raise NotImplementedError

            print('Saving generated sequences...')
            with open(test_file, 'wb') as f:
                pickle.dump(test_data, f)
            print('Generated sequences saved.')

        # ------------------------------------------------------------------ model setup
        set_seed(seed)
        print('Setting up model...')
        train_loader = make_data_loader(input_data, args.num_marks, args.batch_size)  # for training
        test_loader = make_data_loader(test_data, args.num_marks, args.batch_size)

        model, config = make_model(args.num_marks, args.model_name)
        device = set_device(config.trainer_config.gpu)  # loop-invariant, resolved once
        optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
        print(f'Model parameter count: {count_model_params(model)}')

        # ------------------------------------------------------------------ training
        print(f'Training model {args.model_name}...')
        for epoch in tqdm(range(args.num_epochs)):
            total_loss = 0
            total_num_event = 0
            for batch in train_loader:  # or data_loader.train_loader()
                batch.to(device)
                with torch.set_grad_enabled(True):
                    batch_loss, batch_num_event, _, _, _ = model.loglike_loss(batch=batch.values())
                optimizer.zero_grad()
                batch_loss.backward()
                optimizer.step()

                # .item() detaches the value; summing the loss tensor itself would keep autograd
                # history of every batch alive until the end of the epoch.
                total_loss += batch_loss.item()
                total_num_event += batch_num_event

            avg_loss = total_loss / total_num_event
            print(f'epochs {epoch}: train loss {avg_loss}')

            if epoch % 100 == 99:
                # TODO: save model with best validation logL instead of the last one; add validation loss!
                print('Save model...')
                torch.save(model.state_dict(),
                           f"{data_path.rstrip('/')}/model_seq{args.num_seqs}_e{epoch + 1}_no_T_{seed_i}.pt")
                print('Model saved.')

        # ------------------------------------------------------------------ evaluation
        print(f'Evaluating model {args.model_name}...')
        model.eval()
        total_loss, total_num_event = 0, 0
        total_mark_ll, total_time_ll = 0, 0

        true_ll = 0  # stays 0 for poisson, which has no ground-truth model here
        with torch.no_grad():
            if has_true_model:
                h_model.to(device)
            for batch in tqdm(test_loader):
                batch.to(device)
                batch = batch.values()
                ts_BN, dts_BN, marks_BN, batch_non_pad_mask, _ = batch

                batch_loss, batch_num_event, mark_ll, time_ll, _ = model.loglike_loss(batch=batch)

                if has_true_model:
                    # Log-likelihood of each test sequence under the ground-truth process.
                    for seq_idx in range(marks_BN.shape[0]):
                        keep = batch_non_pad_mask[seq_idx]
                        # Drop padding and the leading boundary entry; add a batch dimension of 1.
                        tgt_marks = marks_BN[seq_idx][keep][1:][None, ...]
                        tgt_times = ts_BN[seq_idx][keep][1:][None, ...]
                        T = ts_BN[seq_idx][keep][-1]  # last non-padded time is used as right window

                        # Created directly on tgt_times.device, so no device transfer is needed.
                        sample_timestamps = torch.rand(
                            tgt_times.shape[0],
                            10000,
                            dtype=tgt_times.dtype,
                            device=tgt_times.device
                        ).clamp(min=1e-8) * T  # ~ U(0, T)
                        return_dict = h_model.forward(tgt_marks, tgt_times, sample_timestamps)
                        seq_ll = h_model.log_likelihood(return_dict, T)['log_likelihood']
                        true_ll += seq_ll.item()

                total_loss += batch_loss.item()
                total_mark_ll += mark_ll.item()
                total_time_ll += time_ll.item()
                total_num_event += batch_num_event
            avg_loss = total_loss / total_num_event
            print(f'Test logL per event: marginal: {np.round(true_ll/total_num_event, 4)}, total: {np.round(-avg_loss, 4)}, mark: {np.round(total_mark_ll/total_num_event, 4)}, time: {np.round(total_time_ll/total_num_event, 4)}')
        res['marginal'].append(true_ll/total_num_event)
        res['total'].append(-avg_loss)
        res['time'].append(total_time_ll/total_num_event)
        res['mark'].append(total_mark_ll/total_num_event)
        print(res)

    assert (len(res['marginal']) == len(res['total']) == len(res['time']) == len(res['mark']))
    print()
    print("     AGGREGATED RESULTS:")
    # Outer double quotes: reusing the same quote type inside an f-string is a SyntaxError
    # before Python 3.12.
    print(f"Results for {args.model_name} averaged over {len(res['marginal'])} random seeds.")
    if len(res['marginal']) > 1:
        # The sample standard deviation (ddof=1) needs at least two seeds.
        print(f"Marginal logL: mean {np.round(np.mean(res['marginal']), 3)}; std {np.round(np.std(res['marginal'], ddof=1), 3)}")
        print(f"Total logL: mean {np.round(np.mean(res['total']), 3)}; std {np.round(np.std(res['total'], ddof=1), 3)}")
        print(f"Time logL: mean {np.round(np.mean(res['time']), 3)}; std {np.round(np.std(res['time'], ddof=1), 3)}")
        print(f"Mark logL : mean {np.round(np.mean(res['mark']), 3)}; std {np.round(np.std(res['mark'], ddof=1), 3)}")
    print('Experiments done.')