import truststore
truststore.inject_into_ssl()

import argparse
import collections
import torch
import numpy as np
from sklearn.metrics import top_k_accuracy_score
from tqdm import tqdm

from easy_tpp.config_factory import Config
from easy_tpp.runner import Runner

from easy_tpp.utils import set_device
# import pickle

def _masked_rows(values, mask, width):
    """Return the rows of ``values`` at positions where ``mask`` is True as a NumPy array.

    ``values`` must carry one trailing dimension beyond ``mask`` (the mask is broadcast
    over it); the result is reshaped to ``(num_selected, width)``. ``detach()`` is a
    no-op for tensors without autograd history, so this serves both gradient-enabled
    (FullyNN) and ``no_grad`` evaluation.
    """
    selected = torch.masked_select(values, mask[..., None])
    return selected.detach().cpu().numpy().reshape((-1, width))


def predict_next_event(data_loader, model_runner, num_marks, device, fullynn_flag=False, num_test_seq=1000):
    """Run one-step-ahead prediction at every event and collect graded targets/predictions.

    Batches are consumed until at least ``num_test_seq`` sequences have been processed
    (whole batches are always used) or the loader is exhausted.

    Args:
        data_loader: iterable of batches supporting ``.to(device).values()``.
        model_runner: runner whose ``model.predict_one_step_at_every_event`` is evaluated.
        num_marks: number of event types; a label equal to ``num_marks`` is not graded.
        device: device the batches are moved to.
        fullynn_flag: kept for API compatibility. Predictions are always detached, which
            covers both the gradient-enabled FullyNN path and ``no_grad`` evaluation.
        num_test_seq: number of sequences to evaluate (rounded up to whole batches).

    Returns:
        A tuple ``(all_dtime, all_dtime_pred, all_labels, all_labels_score)``:
            all_dtime: list of true inter-event times of graded events.
            all_dtime_pred: list with, per graded event, a list of ``num_sample`` raw
                dtime predictions.
            all_labels: list of true event types of graded events.
            all_labels_score: list of per-batch arrays of shape ``(n_graded, num_marks)``.
    """
    # Loop-invariant: number of sampled next-event times per prediction.
    num_samples = model_runner.runner_config.model_config.thinning.num_sample

    all_dtime = []
    all_dtime_pred = []
    all_labels = []
    all_labels_score = []

    loader_iter = iter(data_loader)
    remaining = num_test_seq

    # TODO: we can't evaluate batch with only one seq due to the bug in event_tokenizer
    with tqdm(total=num_test_seq) as pbar:  # context manager closes the bar on every exit path
        while remaining > 0:
            # Use a sentinel rather than catching StopIteration around the whole loop, so
            # a StopIteration raised inside the model is not mistaken for an empty loader.
            batch = next(loader_iter, None)
            if batch is None:  # fewer test sequences in the loader than requested
                break

            # NOTE(review): list() keeps the batch indexable even if values() returns a
            # dict view; the model is assumed to only unpack/index it — confirm.
            batch = list(batch.to(device).values())
            batch_size = batch[0].shape[0]
            remaining -= batch_size

            label_dtime, label_type = batch[1][:, 1:], batch[2][:, 1:]
            # Avoid grading right-window events if padded (type == num_marks). Clone first:
            # batch[3][:, 1:] is a view, so editing it in place would also alter the mask
            # that is handed to the model below.
            mask = batch[3][:, 1:].clone()
            mask[label_type == num_marks] = False

            # pred_dtime: [batch_size, seq_len, num_sample]
            # pred_type: [batch_size, seq_len, num_marks]
            pred_dtime, pred_type = model_runner.model.predict_one_step_at_every_event(
                batch=batch,
                get_raw_pred_next_time=True,
                get_raw_mark_distribution=True,
            )

            all_dtime.extend(torch.masked_select(label_dtime, mask).detach().cpu().numpy().reshape(-1).tolist())
            all_dtime_pred.extend(_masked_rows(pred_dtime, mask, num_samples).tolist())
            all_labels.extend(torch.masked_select(label_type, mask).detach().cpu().numpy().reshape(-1).tolist())
            all_labels_score.append(_masked_rows(pred_type, mask, num_marks))

            pbar.update(batch_size)

    return all_dtime, all_dtime_pred, all_labels, all_labels_score


def main(config, use_test_data=True, top_k_accuracy=1, num_test_seq=1000, experiment_id=None):
    """Evaluate a configured TPP model on next-event prediction and print its metrics.

    Args:
        config: ``Config`` used to build the ``Runner``.
        use_test_data: evaluate on the test split if True, otherwise on the validation split.
        top_k_accuracy: k for the additional top-k event-type accuracy printout.
        num_test_seq: number of sequences to evaluate (rounded up to whole batches).
        experiment_id: experiment id, used to detect FullyNN evaluation. If None, falls back
            to ``args.experiment_id`` of the module-level CLI namespace when it exists.

    Returns:
        ``(rmse, acc)``: RMSE of the sample-mean inter-event-time prediction and the
        argmax event-type accuracy.

    Raises:
        ValueError: if no gradable events were collected.
    """
    import contextlib

    if experiment_id is None:
        # Backward compatibility: this used to read the global `args`, which only the
        # __main__ block defines; tolerate its absence when imported as a module.
        experiment_id = getattr(globals().get('args'), 'experiment_id', None)

    device = set_device(config.trainer_config.gpu)
    model_runner = Runner.build_from_config(config)

    if use_test_data:
        data_loader = model_runner._data_loader.test_loader()
    else:
        data_loader = model_runner._data_loader.valid_loader()

    num_marks = model_runner.runner_config.data_config.data_specs.num_event_types

    is_fullynn = experiment_id == 'FullyNN_eval'
    if is_fullynn:
        model_runner.model.train()  # gradient info needed, so autograd stays enabled
        grad_context = contextlib.nullcontext()
    else:
        model_runner.model.eval()
        grad_context = torch.no_grad()

    with grad_context:
        all_dtime, all_dtime_pred, all_labels, all_labels_score = predict_next_event(
            data_loader, model_runner, num_marks, device,
            fullynn_flag=is_fullynn, num_test_seq=num_test_seq)

    if not all_labels:
        raise ValueError('No gradable events were collected; cannot compute metrics.')

    all_labels_score = np.concatenate(all_labels_score, axis=0)
    true_dtime = np.asarray(all_dtime)
    true_labels = np.asarray(all_labels)

    print('Computing stats...')
    # Point prediction of the inter-event time: mean over the sampled predictions.
    expected_dtime_pred = np.mean(all_dtime_pred, axis=-1)
    rmse = np.sqrt(np.mean((true_dtime - expected_dtime_pred) ** 2))
    print(f'RMSE: {rmse}')

    pred_labels = np.argmax(all_labels_score, axis=-1)
    acc = np.mean(true_labels == pred_labels)
    print(f'Accuracy: {acc}')

    # NOTE: originally noted as giving nan for retweet + THP.
    acc_k = top_k_accuracy_score(true_labels, all_labels_score, k=top_k_accuracy,
                                 labels=np.arange(num_marks))
    print(f'Accuracy {top_k_accuracy}: {acc_k}')
    return rmse, acc



if __name__ == '__main__':

    # Models to evaluate, mapped to their evaluation experiment id in the dataset YAML.
    model_list = {
        'RMTPP': 'RMTPP_eval',
        'NHP': 'NHP_eval',
        'SAHP': 'SAHP_eval',
        'THP': 'THP_eval',
        'AttNHP': 'AttNHP_eval',
        'IntensityFree': 'IntensityFree_eval',
        'DLHP': 'DLHP_eval',
        # 'FullyNN': 'FullyNN_eval',
    }

    # Dataset short name -> next-event evaluation config file.
    dataset_config_path = {
        'taxi': 'configs/next_event_taxi.yaml',
        'tb': 'configs/next_event_tb.yaml',
        'so': 'configs/next_event_so.yaml',
        'amazon': 'configs/next_event_amazon.yaml',
        'rt': 'configs/next_event_rt.yaml',
        # 'ehr': 'configs/next_event_ehr.yaml',
    }

    # Build the CLI parser once; only the --experiment_id default changes per model.
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', type=str, required=False, default='taxi',
                        help='Dataset to determine configuration yaml to train and evaluate the model.')
    parser.add_argument('--experiment_id', type=str, required=False,
                        help='Experiment id in the config file.')
    parser.add_argument('--num_test_seq', type=int, required=False, default=1000)

    res = {}

    for model, model_id in model_list.items():
        print(f'Current model: {model}')
        # Each model defaults to its own experiment id; an explicit --experiment_id on the
        # command line still overrides it. `args` stays a module-level name because main()
        # may fall back to args.experiment_id.
        parser.set_defaults(experiment_id=model_id)
        args = parser.parse_args()
        config = Config.build_from_yaml_file(dataset_config_path[args.dataset], experiment_id=args.experiment_id)
        # NOTE(review): 'ehrshot' has no entry in dataset_config_path, so k is currently always 1.
        k = 10 if args.dataset == 'ehrshot' else 1
        rmse, acc = main(config, use_test_data=True, top_k_accuracy=k, num_test_seq=args.num_test_seq)
        res[model] = {'rmse': np.round(rmse, 5), f'top {k} acc': np.round(acc, 5)}

    # Print the summary once, after all models have been evaluated.
    for model_name, metrics in res.items():
        print(model_name)
        print(metrics)

