import argparse
import collections
import torch
import numpy as np
from tqdm import tqdm

from easy_tpp.config_factory import Config
from easy_tpp.runner import Runner

from easy_tpp.utils import set_device
import pickle


def compute_predictions(data_loader, model_runner, num_marks, device, is_if_model):
    """Collect per-event calibration inputs from every batch of ``data_loader``.

    Each batch is moved to ``device`` and passed to
    ``model_runner.model.loglike_loss(..., return_raw_ll=True)``; only the last
    element of its 5-tuple result (a dict of raw outputs) is used.

    Args:
        data_loader: Iterable of batches. Each batch must support
            ``.to(device).values()`` and yield five fields, of which the third
            (marks) and fourth (non-padding mask) are used here.
        model_runner: Object exposing ``model.loglike_loss``.
        num_marks: Mark value to exclude from the evaluation
            (presumably the padding mark id -- confirm against the data config).
        device: Device passed to ``batch.to``.
        is_if_model: When true, read ``'dt_cdf'`` and ``'mark_probs'`` directly
            from the result dict; otherwise derive them from
            ``'non_event_ll'`` and ``'mark_intensity'``.

    Returns:
        Tuple of four flat Python lists ``(dt_cdf, true_mark, mark_pred,
        mark_conf)`` holding one entry per retained event.
    """
    dt_cdf, true_mark, mark_pred, mark_conf = [], [], [], []

    for raw_batch in data_loader:
        fields = raw_batch.to(device).values()
        # Only the marks and the non-padding mask are needed from the batch.
        _, _, marks, non_pad_mask, _ = fields
        _, _, _, _, outputs = model_runner.model.loglike_loss(fields, return_raw_ll=True)

        if not is_if_model:
            # CDF of the inter-event time: 1 - exp(-non_event_ll).
            cdf = 1 - torch.exp(-outputs['non_event_ll'])
            # Normalize the per-mark intensities so they sum to one over marks.
            intensity = outputs['mark_intensity']
            mark_scores = intensity / intensity.sum(dim=-1, keepdim=True)
        else:
            # Intensity-free model: the CDF and mark probabilities are returned directly.
            cdf = outputs['dt_cdf']
            # NOTE: the mark probabilities are used as returned; they are not
            # renormalized over marks.
            mark_scores = outputs['mark_probs']

        confidence, predicted = mark_scores.max(dim=-1)

        # The first event of every sequence is skipped ([:, 1:]); keep only
        # non-padded positions whose mark differs from ``num_marks``.
        target_marks = marks[:, 1:]
        event_mask = torch.logical_and(non_pad_mask[:, 1:], target_marks != num_marks)

        per_batch = (
            (dt_cdf, cdf),
            (mark_pred, predicted),
            (mark_conf, confidence),
            (true_mark, target_marks),
        )
        for store, values in per_batch:
            store.extend(torch.masked_select(values, event_mask).cpu().numpy().tolist())

    return dt_cdf, true_mark, mark_pred, mark_conf


def main(config, use_test_data=True, pce_bin=50, ece_bin=20, is_if_model=False):
    """Evaluate the configured model and print/return its PCE and ECE.

    PCE is computed from the per-event ``dt_cdf`` values and ECE from the
    per-event mark predictions and confidences returned by
    ``compute_predictions``.

    Args:
        config: Experiment configuration. It is modified in place:
            ``trainer_config.max_epoch`` is set to 1 and
            ``model_config.loss_integral_num_sample_per_step`` to 50 before the
            runner is built from it.
        use_test_data: Evaluate on the test loader when true, otherwise on the
            validation loader.
        pce_bin: Number of equally spaced probability levels in (0, 1] used
            for PCE.
        ece_bin: Number of equal-width confidence bins over [0, 1] used for ECE.
        is_if_model: Forwarded to ``compute_predictions`` to select how the
            model outputs are read.

    Returns:
        Tuple ``(dt_cdf, true_mark, mark_pred, mark_conf, pce, ece)``: the four
        per-event lists, the PCE as a 0-dim torch tensor and the ECE as a number.
    """
    device = set_device(config.trainer_config.gpu)
    config.trainer_config.max_epoch = 1
    config.model_config.loss_integral_num_sample_per_step = 50

    model_runner = Runner.build_from_config(config)

    if use_test_data:
        data_loader = model_runner._data_loader.test_loader()
    else:
        data_loader = model_runner._data_loader.valid_loader()

    num_marks = model_runner.runner_config.data_config.data_specs.num_event_types
    # ignore FullyNN
    model_runner.model.eval()
    with torch.no_grad():
        dt_cdf, true_mark, mark_pred, mark_conf = compute_predictions(
            data_loader, model_runner, num_marks, device, is_if_model)

    num_events = len(dt_cdf)

    # compute PCE, Eq.61
    # pm holds the probability levels 1/pce_bin, 2/pce_bin, ..., 1 as a row vector.
    pm = torch.linspace(0, 1, pce_bin + 1)[1:][None, ...]
    # Empirical fraction of events whose CDF value is at most each level.
    indicator_eval = (torch.tensor(dt_cdf)[..., None] <= pm).int().sum(dim=0) / num_events
    pce = torch.mean(abs(indicator_eval - pm))
    print(f'PCE: {np.round(pce, 5)}')

    # compute ECE, weighted by number of predictions in each bin
    prob_bins = np.linspace(0, 1, ece_bin + 1)
    true_mark_np = np.array(true_mark)
    mark_pred_np = np.array(mark_pred)
    mark_conf_np = np.array(mark_conf)
    ece = 0
    last_bin = ece_bin - 1
    for i, (lower, upper) in enumerate(zip(prob_bins[:-1], prob_bins[1:])):
        if i == last_bin:
            # The last bin is closed on the right so that predictions with a
            # confidence of exactly 1.0 are counted instead of silently dropped.
            mark_mask_i = (lower <= mark_conf_np) & (mark_conf_np <= upper)
        else:
            mark_mask_i = (lower <= mark_conf_np) & (mark_conf_np < upper)
        true_mark_i = true_mark_np[mark_mask_i]
        mark_pred_i = mark_pred_np[mark_mask_i]
        mark_conf_i = mark_conf_np[mark_mask_i]
        if len(true_mark_i):
            # |accuracy - mean confidence| of the bin, weighted by its share of events.
            ece += abs(np.mean(true_mark_i == mark_pred_i) - np.mean(mark_conf_i)) * len(true_mark_i) / num_events
    print(f'ECE: {np.round(ece, 5)}')

    return dt_cdf, true_mark, mark_pred, mark_conf, pce, ece


if __name__ == '__main__':
    # Script entry point: evaluate every model in ``model_list`` on one dataset,
    # pickle the per-event predictions and print a PCE/ECE summary.
    import os  # local import: only needed to create the output directories

    # Display name -> default experiment id in the dataset yaml.
    model_list = {
        'RMTPP': 'RMTPP_eval',
        'NHP': 'NHP_eval',
        'SAHP': 'SAHP_eval',
        'THP': 'THP_eval',
        'IntensityFree': 'IntensityFree_eval',
        'DLHP': 'DLHP_eval',
        'AttNHP': 'AttNHP_eval',
        # 'FullyNN': 'FullyNN_eval',
    }

    # Dataset name -> configuration yaml.
    dataset_config_path = {
        'taxi': '../configs/next_event_taxi.yaml',
        'taobao': '../configs/next_event_tb.yaml',
        'stackoverflow': '../configs/next_event_so.yaml',
        'amazon': '../configs/next_event_amazon.yaml',
        'retweet': '../configs/next_event_rt.yaml',
        'ehrshot': '../configs/next_event_ehr.yaml',
        'nlb1rep': '../configs/next_event_nlb.yaml',
        'lastfm': '../configs/next_event_lastfm.yaml',
    }

    # The command line does not depend on the model, so parse it once up front.
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', type=str, required=False, default='lastfm',
                        choices=list(dataset_config_path),
                        help='Dataset to determine configuration yaml to train and evaluate the model.')
    parser.add_argument('--experiment_id', type=str, required=False, default=None,
                        help='Experiment id in the config file.')
    args = parser.parse_args()
    config_path = dataset_config_path[args.dataset]

    res = {}

    for model, model_id in model_list.items():
        print(f'Current model: {model}')
        print(config_path)

        # An explicit --experiment_id applies to every model; otherwise each
        # model uses its own default id from ``model_list``.
        experiment_id = model_id if args.experiment_id is None else args.experiment_id

        config = Config.build_from_yaml_file(config_path, experiment_id=experiment_id)
        dt_cdf, true_mark, mark_pred, mark_conf, pce, ece = main(
            config, use_test_data=True, is_if_model=(model == 'IntensityFree'))

        # Store both metrics as rounded plain numbers so the summary is uniform.
        res[model] = {'PCE': np.round(float(pce), 5), 'ECE': np.round(ece, 5)}

        save_dir = config.base_config.base_dir + f'{args.dataset}/{model}/'
        # Make sure the target directory exists before writing the results.
        os.makedirs(save_dir, exist_ok=True)

        outputs = (
            ('dt_cdf.pkl', dt_cdf),
            ('true_mark.pkl', true_mark),
            ('mark_pred.pkl', mark_pred),
            ('mark_conf.pkl', mark_conf),
        )
        for file_name, values in outputs:
            with open(save_dir + file_name, 'wb') as f:
                pickle.dump(values, f)

    for model, metrics in res.items():
        print(model)
        print(metrics)