import os
import json
import torch
import numpy as np
from tqdm import tqdm
from utils.metrics import get_metrics, thresh_max_f1
from utils.parser import parse
from utils.loss import MyLoss
from utils.dataloader import get_dataloader
from utils.utils import Timer, EarlyStop, set_random_seed, to_gpu


def evaluate(args, stage, model, loss, loader):
    """Run ``model`` over every batch of ``loader`` and score its predictions.

    Args:
        args: Run configuration. Reads ``args.device`` (passed to ``to_gpu``) and,
            for the 'val' stage, ``args.threshold``. Writes ``args.threshold_value``
            for the 'train' and 'val' stages; every other stage reuses whatever
            value is already stored there.
        stage: 'train' fixes the threshold at 0.5. 'val' uses ``args.threshold`` when
            it is truthy, otherwise the threshold returned by ``thresh_max_f1`` on
            this loader's predictions. Any other value (e.g. 'test') keeps the
            current ``args.threshold_value``.
        model: Called as ``model(x, p, y)``. Its first output element is collected
            and passed through a sigmoid to form the predicted probabilities.
        loss: Called as ``loss(z, p, y)``. It must return a single-element tensor,
            because ``.item()`` is taken on it.
        loader: Iterable of ``(u, x, y, p)`` batches. ``u`` is ignored, and ``p`` is
            used as the target.

    Returns:
        ``(eval_loss, scores, pred, real)``:
        - ``eval_loss``: unweighted mean of the per-batch losses.
        - ``scores``: result of ``get_metrics`` at ``args.threshold_value``.
        - ``pred``: sigmoid probabilities as a numpy array.
        - ``real``: concatenated targets as a numpy array.
    """
    # Switch to inference mode once; this does not change per batch.
    model.eval()
    pred, real, eval_loss = [], [], []
    with torch.no_grad():
        for _, x, y, p in tqdm(loader, ncols=150):
            x, y, p = to_gpu(x, y, p, device=args.device)
            z = model(x, p, y)
            batch_loss = loss(z, p, y)

            pred.append(z[0])
            real.append(p)
            eval_loss.append(batch_loss.item())

    # Each batch counts equally, regardless of its size.
    eval_loss = np.mean(eval_loss).item()
    pred = torch.sigmoid(torch.cat(pred, dim=0)).cpu().numpy()
    real = torch.cat(real, dim=0).cpu().numpy()

    if stage == 'train':
        args.threshold_value = 0.5
    elif stage == 'val':
        if args.threshold:
            args.threshold_value = float(args.threshold)
        else:
            # Tune the decision threshold to maximise F1 on this split.
            args.threshold_value = thresh_max_f1(y_true=real, y_prob=pred)
        print(f"Use threshold {args.threshold_value}")

    scores = get_metrics(pred, real, threshold_value=args.threshold_value)
    return eval_loss, scores, pred, real


def _is_skipped_key(key):
    """Return True for metric keys of the form 'name-K' whose integer K is not 5, 10 or 15.

    Keys without '-' are kept. Keys whose last '-'-separated part is not a plain
    decimal integer are also kept, rather than raising.
    """
    text = str(key)
    if '-' not in text:
        return False
    suffix = text.split('-')[-1]
    return suffix.isdecimal() and int(suffix) not in (5, 10, 15)


def _print_score_table(title, scores):
    """Print a banner, then a tab-separated row of metric names and a row of 4-decimal values."""
    print('*' * 30, title, '*' * 30)
    keys = [k for k in scores if not _is_skipped_key(k)]
    print(''.join(f"{k}\t" for k in keys))
    print(''.join("{:.4f}\t".format(scores[k]) for k in keys))


if __name__ == '__main__':
    args = parse()
    set_random_seed(args.seed)
    print(args)

    save_folder = os.path.join('./saves', args.dataset + '-' + args.setting, args.model, args.name)
    _, val_loader, test_loader = get_dataloader(args)

    # Checkpoints are loaded as best-model-<run>.pt, so count exactly those files.
    # This keeps other *.pt files in the folder from adding runs that do not exist.
    num_runs = sum(
        1 for f in os.listdir(save_folder) if f.startswith('best-model-') and f.endswith('.pt')
    )
    if num_runs == 0:
        raise FileNotFoundError(f"No best-model-*.pt checkpoints found in {save_folder}")

    test_scores_multiple_runs = []
    for run in range(num_runs):
        early_stop = EarlyStop(args, model_path=os.path.join(save_folder, f'best-model-{run}.pt'))

        # read model
        model = early_stop.load_best_model()
        print(model)
        print('Number of model parameters is', sum(p.nelement() for p in model.parameters()))
        loss = MyLoss(args)

        # Choose the decision threshold (stored in args.threshold_value) on the
        # validation split, never on the test split. Tuning on test labels would
        # leak them into the reported test scores.
        evaluate(args, 'val', model, loss, val_loader)
        # test model with the threshold chosen above
        _, test_scores, _, _ = evaluate(args, 'test', model, loss, test_loader)

        test_scores_multiple_runs.append(test_scores)

    # merge results from several runs: per-metric mean and standard deviation
    test_scores = {'mean': {}, 'std': {}}
    for k in test_scores_multiple_runs[0]:
        values = [scores[k] for scores in test_scores_multiple_runs]
        test_scores['mean'][k] = np.mean(values).item()
        test_scores['std'][k] = np.std(values).item()

    print(f"Dataset: {args.dataset}, model: {args.model}, setting: {args.setting}, name: {args.name}")
    _print_score_table('mean', test_scores['mean'])
    _print_score_table('std', test_scores['std'])
