import logging
import os
import torch
from torch.autograd import Variable
from torch.optim import lr_scheduler
from tensorboardX import SummaryWriter
from timeit import default_timer as timer
import pytrec_eval
import json
import numpy as np

logger = logging.getLogger(__name__)


def gen_time_str(t):
    """Format a duration given in seconds as a ``minutes:seconds`` string.

    ``t`` is truncated to an integer first. Minutes are right-aligned to a
    width of two characters and seconds are zero-padded to two digits.
    """
    minutes, seconds = divmod(int(t), 60)
    return '%2d:%02d' % (minutes, seconds)


def output_value(epoch, mode, step, time, loss, info, end, config):
    """Print one progress line made of column-aligned fields.

    The fields ``epoch``, ``mode``, ``step``, ``time``, ``loss`` and ``info``
    are stringified and laid out left to right; after each of the first five
    the line is space-padded to a minimum width of 7, 14, 25, 40 and 48
    characters respectively. Every space in the finished line is then
    replaced by the delimiter read from ``config.get("output", "delimiter")``;
    if that lookup fails for any reason a single space is used.

    ``end`` is forwarded to ``print`` unless it is ``None``, in which case
    ``print``'s default line ending applies.
    """
    try:
        delimiter = config.get("output", "delimiter")
    except Exception:
        # Best-effort lookup: any failure falls back to plain spaces.
        delimiter = " "

    # The first three fields carry a trailing space before padding; the
    # time and loss fields are padded directly.
    line = (str(epoch) + " ").ljust(7)
    line = (line + str(mode) + " ").ljust(14)
    line = (line + str(step) + " ").ljust(25)
    line = (line + str(time)).ljust(40)
    line = (line + str(loss)).ljust(48)
    line = (line + str(info)).replace(" ", delimiter)

    if end is None:
        print(line)
    else:
        print(line, end=end)


def valid(model, dataset, epoch, writer, config, gpu_list, output_function, mode="valid"):
    """Run the model over ``dataset`` and report NDCG@20 and P@20.

    The model is put into eval mode for the pass and switched back to train
    mode before returning. For every batch the model output is read under the
    keys ``'qid'``, ``'did'``, ``'score'`` and ``'label'``; these are
    accumulated into pytrec_eval ``qrel`` (relevance labels) and ``run``
    (predicted scores) dictionaries keyed by the stringified query and
    document ids. Both metrics are printed and returned.

    Args:
        model: Callable invoked as
            ``model(data, config, gpu_list, None, "valid")``. It must also
            provide ``eval()`` and ``train()``.
        dataset: Iterable of dict batches. Tensor values are moved to the GPU
            with ``.cuda()`` when ``gpu_list`` is non-empty.
        epoch: Unused here; kept for interface compatibility.
        writer: Unused here; kept for interface compatibility.
        config: Passed through to the model unchanged.
        gpu_list: Sequence of GPU ids; only its emptiness is checked here.
        output_function: Unused here; kept for interface compatibility.
        mode: Unused here; kept for interface compatibility.

    Returns:
        A ``(ndcg, p20)`` tuple holding the aggregated ``ndcg_cut_20`` and
        ``P_20`` measures over all queries seen.
    """
    model.eval()

    qrel = {}
    run = {}
    use_gpu = len(gpu_list) > 0

    # Validation never backpropagates, so gradient tracking is disabled.
    with torch.no_grad():
        for data in dataset:
            if use_gpu:
                for key, value in data.items():
                    if isinstance(value, torch.Tensor):
                        data[key] = value.cuda()

            # NOTE(review): the ``mode`` argument is not forwarded; the model
            # is always called with the literal "valid". Confirm whether that
            # is intended before changing it.
            results = model(data, config, gpu_list, None, "valid")

            qid = results['qid'].cpu().numpy()
            did = results['did'].cpu().numpy()
            score = results['score'].cpu().numpy()
            label = results['label'].cpu().numpy()
            for i, query in enumerate(qid):
                qid_tmp = str(query)
                did_tmp = str(did[i])
                if qid_tmp not in qrel:
                    qrel[qid_tmp] = {}
                    run[qid_tmp] = {}
                qrel[qid_tmp][did_tmp] = int(label[i].item())
                run[qid_tmp][did_tmp] = float(score[i].item())

    evaluator = pytrec_eval.RelevanceEvaluator(qrel, {'ndcg_cut', 'P'})
    # Score the run once and reuse the per-query measures for both metrics.
    per_query = list(evaluator.evaluate(run).values())
    ndcg = pytrec_eval.compute_aggregated_measure(
        'ndcg_cut_20', [measures['ndcg_cut_20'] for measures in per_query])
    p20 = pytrec_eval.compute_aggregated_measure(
        'P_20', [measures['P_20'] for measures in per_query])
    print("NDCG", ndcg)
    print("P", p20)

    model.train()
    return ndcg, p20
