
import argparse
import collections
import torch
import numpy as np
from sklearn.metrics import top_k_accuracy_score
from tqdm import tqdm

from easy_tpp.config_factory import Config
from easy_tpp.runner import Runner

from easy_tpp.utils import set_device, set_seed
from synthetic_run import make_data_loader, make_model, generate_Poisson_data  # Hawkes
from easy_tpp.utils.hawkes import HawkesModel, SelfCorrectingModel
import pickle

def get_data(args):
    """Load the pickled Hawkes test set stored under ``args.out_dir``.

    The file read is ``<out_dir>/hawkes/test_data_seq1000.pickle``; trailing
    slashes on ``args.out_dir`` are ignored. The unpickled object is returned
    unchanged.
    """
    root_dir = args.out_dir.rstrip('/')
    pickle_path = f"{root_dir}/hawkes/test_data_seq1000.pickle"
    # NOTE: unpickling can execute arbitrary code -- only load trusted files.
    with open(pickle_path, 'rb') as handle:
        return pickle.load(handle)


def predict_next_event(data_loader, model, num_marks, device, fullynn_flag=False, num_test_seq=1000):
    """Collect one-step-ahead predictions for up to ``num_test_seq`` test sequences.

    Batches are drawn from ``data_loader`` until at least ``num_test_seq``
    sequences have been processed or the loader is exhausted, whichever comes
    first.

    Args:
        data_loader: Iterable of batches. Each batch must support
            ``batch.to(device).values()`` and the result must be indexable;
            entries 1, 2 and 3 are used as inter-event times, event types and
            the non-pad mask respectively (presumably the EasyTPP batch
            layout -- confirm against the tokenizer).
        model: Object exposing ``predict_one_step_at_every_event``.
        num_marks: Number of event types. Positions whose type id equals
            ``num_marks`` are excluded from grading.
        device: Device the batch is moved to.
        fullynn_flag: Retained for backward compatibility. Predictions are
            always detached before conversion to NumPy, so the flag no longer
            changes behaviour.
        num_test_seq: Maximum number of sequences to evaluate.

    Returns:
        Tuple ``(all_dtime, all_dtime_pred, all_labels, all_labels_score)``:
        true inter-event times (list of floats), raw predicted samples (one
        list of samples per graded event), true event types (list), and a list
        with one ``(num_graded_events, num_marks)`` NumPy array per batch.
    """
    all_dtime = []
    all_dtime_pred = []  # raw samples, one row per graded event
    all_labels = []
    all_labels_score = []

    batches = iter(data_loader)
    remaining = num_test_seq

    # TODO: we can't evaluate batch with only one seq due to the bug in event_tokenizer
    # The context manager guarantees the bar is closed even if the loader runs
    # out early or the model raises.
    with tqdm(total=num_test_seq) as pbar:
        while remaining > 0:
            # Only the fetch is guarded, so a StopIteration raised elsewhere is
            # not silently swallowed.
            try:
                batch = next(batches)
            except StopIteration:  # fewer test sequences than requested
                break

            batch = batch.to(device).values()
            batch_size = batch[0].shape[0]
            remaining -= batch_size

            label_dtime, label_type = batch[1][:, 1:], batch[2][:, 1:]
            mask = batch[3][:, 1:]
            # Avoid grading right-window events if padded. NOTE: `mask` is a
            # slice of batch[3], so (for tensor batches) this in-place write is
            # also visible to the model call below.
            mask[label_type == num_marks] = False

            # pred_dtime: [batch_size, seq_len, num_sample]
            # pred_type: [batch_size, seq_len, num_marks]
            pred_dtime, pred_type = model.predict_one_step_at_every_event(
                batch=batch,
                get_raw_pred_next_time=True,
                get_raw_mark_distribution=True)

            event_mask = mask[..., None]
            # Take the sample count from the prediction itself instead of
            # assuming a fixed 64 samples per event.
            num_samples = pred_dtime.shape[-1]

            all_dtime.extend(torch.masked_select(label_dtime, mask).detach().cpu().tolist())
            all_dtime_pred.extend(
                torch.masked_select(pred_dtime, event_mask)
                .detach().cpu().numpy().reshape((-1, num_samples)).tolist())

            all_labels.extend(torch.masked_select(label_type, mask).detach().cpu().tolist())
            all_labels_score.append(
                torch.masked_select(pred_type, event_mask)
                .detach().cpu().numpy().reshape((-1, num_marks)))

            pbar.update(batch_size)

    return all_dtime, all_dtime_pred, all_labels, all_labels_score


def main(config, model, data_loader, num_marks=3, top_k_accuracy=1, num_test_seq=1000):
    """Evaluate ``model`` on next-event prediction and report MSE and accuracy.

    Args:
        config: Configuration object; only ``config.trainer_config.gpu`` is read.
        model: Model to evaluate; switched to eval mode before prediction.
        data_loader: Source of test batches, forwarded to ``predict_next_event``.
        num_marks: Number of event types, forwarded to ``predict_next_event``.
        top_k_accuracy: Currently unused; kept for interface compatibility.
        num_test_seq: Maximum number of test sequences to evaluate.

    Returns:
        Tuple ``(mse, acc)``: mean squared error between the true inter-event
        times and the mean of the predicted samples, and the fraction of events
        whose arg-max predicted type equals the true type.
    """
    device = set_device(config.trainer_config.gpu)

    print('Start evaluation...')
    model.eval()
    with torch.no_grad():
        true_dtime, sampled_dtime, true_marks, score_chunks = predict_next_event(
            data_loader, model, num_marks, device, num_test_seq=num_test_seq)

    # One score array per batch -> a single (num_events, num_marks) matrix.
    mark_scores = np.concatenate(score_chunks, axis=0)

    print('Computing stats...')
    # Point estimate of the next inter-event time: mean over the raw samples.
    expected_dtime = np.mean(sampled_dtime, axis=-1)
    squared_error = (np.array(true_dtime) - np.array(expected_dtime)) ** 2
    mse = np.mean(squared_error)
    print(f'MSE: {mse}')

    predicted_marks = np.argmax(mark_scores, axis=-1)
    is_correct = np.array(true_marks) == np.array(predicted_marks)
    acc = np.mean(is_correct)
    print(f'Accuracy: {acc}')

    return mse, acc

def get_args():
    """Parse the command-line options of the evaluation script.

    Returns:
        The parsed ``argparse.Namespace`` with any trailing ``/`` removed from
        ``out_dir``.
    """
    # (flag, add_argument keyword arguments), in the order shown by --help.
    option_specs = (
        ('--dataset', dict(
            type=str, required=False, default='synthetic',
            help='Dataset to determine configuration yaml to train and evaluate the model.')),
        ('--experiment_id', dict(
            type=str, required=False, default='AttNHP_eval',
            help='Experiment id in the config file.')),
        ('--num_test_seq', dict(type=int, required=False, default=50)),
        ('--num_marks', dict(default=3, type=int, help='Number of marks.')),
        ('--batch_size', dict(type=int, default=3, help='Batch size.')),  # 16 for attnhp
        ('--out_dir', dict(
            default='./synthetic/', type=str,
            help='Dir for saving generated sequences and trained models.')),
    )

    cli = argparse.ArgumentParser()
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)

    parsed = cli.parse_args()
    parsed.out_dir = parsed.out_dir.rstrip("/")
    return parsed


if __name__ == '__main__':
    # Maps the experiment id (as used in the YAML config) to the model id
    # passed to make_model and used as the checkpoint sub-directory name.
    model_list = {
        'RMTPP_eval': 'rmtpp',
        'NHP_eval': 'nhp',
        'SAHP_eval': 'sahp',
        'THP_eval': 'thp',
        'AttNHP_eval': 'attnhp',
        'IntensityFree_eval': 'iftpp',
        'DLHP_eval': 'dlhp',
    }

    args = get_args()
    # Fail early with a readable message instead of a bare KeyError after the
    # data has already been loaded.
    if args.experiment_id not in model_list:
        raise SystemExit(
            f"Unknown --experiment_id {args.experiment_id!r}; "
            f"expected one of: {', '.join(model_list)}")
    model_id = model_list[args.experiment_id]

    data = get_data(args)
    set_seed(123)

    print('Setting up model...')

    model_dir = f"{args.out_dir}/hawkes/{model_id}"
    model_path = f"{model_dir}/model_seq50000_e300_no_T_0.pt"  # saved for different random seeds

    data_loader = make_data_loader(data, args.num_marks, args.batch_size)
    # make_model also returns a config, but evaluation uses the YAML one below.
    model, _ = make_model(args.num_marks, model_id)
    # NOTE: torch.load unpickles the checkpoint -- only load trusted files.
    model.load_state_dict(torch.load(model_path))

    config = Config.build_from_yaml_file('configs/next_event_syn.yaml', experiment_id=args.experiment_id)
    # Forward num_marks explicitly: the model and data loader were built with
    # args.num_marks, so evaluation must not fall back to main()'s default of 3.
    mse, acc = main(config, model, data_loader, num_marks=args.num_marks,
                    top_k_accuracy=1, num_test_seq=args.num_test_seq)

