import numpy as np
import tensorflow.compat.v1 as tf
import math
import os
import time
import matplotlib.pyplot as plt
import sys
from sklearn import metrics


# Score added to every item a user must not be ranked on (items in ``mask``,
# plus ``mask_valid`` when evaluating with isTest=True).
_MASK_PENALTY = -9999


def _arg_top_k(scores, k):
    """Return, per row of the 2-D ``scores``, the column indices of the ``k``
    highest scores ordered from highest to lowest."""
    neg = -scores
    n_rows, n_cols = neg.shape
    # np.argpartition requires kth < n_cols; the clamp only changes anything
    # when k == n_cols (in which case every column is returned, sorted).
    kth = min(k, n_cols - 1)
    part = np.argpartition(neg, kth, axis=1)[:, :k]
    rows = np.arange(n_rows)[:, None]
    order = np.argsort(neg[rows, part], axis=1)
    return part[rows, order]


def _score_batch(sess, model, rows, mask, mask_valid, n_items, isTest):
    """Run the model on one batch of test rows and apply the exclusion mask.

    Returns ``(prediction, excluded)``: ``prediction`` is the model output plus
    ``_MASK_PENALTY`` on excluded items, and ``excluded`` is a boolean array of
    shape ``(len(rows), n_items)`` marking those items.
    """
    users = [row[0] for row in rows]
    histories = [row[1] for row in rows]
    # Right-pad every history with 0 up to the longest history in the batch;
    # every sequence length passed to the model is that padded length.
    max_len = max(len(h) for h in histories)
    hist = [list(h) + [0] * (max_len - len(h)) for h in histories]
    sl = [max_len] * len(rows)

    excluded = np.zeros((len(rows), n_items), dtype=bool)
    for u, user_id in enumerate(users):
        excluded[u, np.asarray(mask[user_id], dtype=np.intp)] = True
        if isTest and len(mask_valid) != 0:
            excluded[u, np.asarray(mask_valid[user_id], dtype=np.intp)] = True
    penalty = np.where(excluded, float(_MASK_PENALTY), 0.0)
    # NOTE(review): assumes run_evaluate_user(...)[0] is a (batch, n_items)
    # score array — confirm against the model class.
    prediction = model.run_evaluate_user(sess, users, hist, sl)[0] + penalty
    return prediction, excluded


def _user_aucs(prediction, excluded, positives_per_user):
    """Return each user's ROC AUC over their non-excluded items, skipping users
    for whom AUC is undefined (no positives, or no negatives)."""
    aucs = []
    for scores, excl, positives in zip(prediction, excluded, positives_per_user):
        candidates = np.flatnonzero(~excl)
        positive_set = set(positives)
        labels = [1 if item in positive_set else 0 for item in candidates.tolist()]
        n_pos = sum(labels)
        # roc_auc_score raises unless both label classes are present.
        if n_pos == 0 or n_pos == len(labels):
            continue
        aucs.append(metrics.roc_auc_score(labels, scores[candidates]))
    return aucs


def candidate_ranking(sess, model, mask, test_set, all_item, topN, log_item, mask_valid,
                      isTest=False, batch_size=512):
    """Rank all items for every test user and compute top-N metrics and user AUC.

    Args:
        sess: TensorFlow session forwarded to ``model.run_evaluate_user``.
        model: object exposing ``run_evaluate_user(sess, users, hist, sl)``.
        mask: ``mask[user_id]`` -> item indices excluded from that user's ranking.
        test_set: indexable of rows; ``row[0]`` is the user id, ``row[1]`` the
            history of item ids, ``row[2]`` the "all click" ground-truth items
            and ``row[3]`` the positive ground-truth items (also used for AUC).
        all_item: only its length (number of candidate items) is used.
        topN: list of cutoffs; ``topN[-1]`` items are retrieved per user, so it
            is presumably the largest cutoff.
        log_item: forwarded to ``computeTopNAccuracy``.
        mask_valid: like ``mask``; applied only when ``isTest`` is true and it
            is non-empty.
        isTest: also exclude the ``mask_valid`` items.
        batch_size: number of test rows scored per model call.

    Returns:
        ``([precision, recall, NDCG, popularity],
        [pos_precision, pos_recall, pos_NDCG, pos_popularity], user_pred, uauc)``
        where ``user_pred`` holds each user's top ``topN[-1]`` item indices and
        ``uauc`` is the mean per-user AUC (NaN if no user has a defined AUC).
    """
    n_items = len(all_item)
    num_rows = len(test_set)
    groundTruth_all_click = [test_set[j][2] for j in range(num_rows)]
    groundTruth_pos = [test_set[j][3] for j in range(num_rows)]

    user_pred = []
    auc = []
    # Stepping by batch_size covers the final partial batch and never produces
    # an empty batch (which would make max() over histories fail).
    for start in range(0, num_rows, batch_size):
        rows = [test_set[j] for j in range(start, min(start + batch_size, num_rows))]
        prediction, excluded = _score_batch(sess, model, rows, mask, mask_valid, n_items, isTest)
        auc.extend(_user_aucs(prediction, excluded, [row[3] for row in rows]))
        user_pred.extend(_arg_top_k(prediction, topN[-1]).tolist())

    uauc = sum(auc) / len(auc) if auc else float("nan")
    precision, recall, NDCG, test_popularity = computeTopNAccuracy(
        groundTruth_all_click, user_pred, topN, log_item)
    pos_precision, pos_recall, pos_NDCG, pos_test_popularity = computeTopNAccuracy(
        groundTruth_pos, user_pred, topN, log_item)
    return ([precision, recall, NDCG, test_popularity],
            [pos_precision, pos_recall, pos_NDCG, pos_test_popularity],
            user_pred, uauc)




def computeTopNAccuracy(GroundTruth, predictedIndices, topN, log_item):
    """Compute precision, recall and NDCG at every cutoff in ``topN``.

    Args:
        GroundTruth: per-user collections of relevant item ids, aligned by index
            with ``predictedIndices``.
        predictedIndices: per-user ranked item ids (best first), each at least
            ``max(topN)`` long.
        topN: list of cutoffs ``k``.
        log_item: mapping item id -> popularity key; items missing from it are
            counted under the key ``0.5``.

    Returns:
        ``(precision, recall, NDCG, test_popularity)``, each a list aligned with
        ``topN``. Metrics are rounded to 4 decimals and averaged over *all*
        users in ``predictedIndices``; users with empty ground truth contribute
        0 but still count in the denominator. ``test_popularity[i]`` counts how
        often each ``log_item`` key appears in the top-``topN[i]`` lists of users
        with non-empty ground truth. With no users every metric is 0.0.
    """
    precision = []
    recall = []
    NDCG = []
    test_popularity = []

    num_users = len(predictedIndices)
    # Build membership sets once: they are probed for every ranked item at
    # every cutoff. Lengths still come from the original collections.
    truth_sets = [set(GroundTruth[i]) for i in range(num_users)]

    for k in topN:
        sum_precision = 0
        sum_recall = 0
        sum_ndcg = 0
        popularity = {}
        for i in range(num_users):
            truth = GroundTruth[i]
            if len(truth) == 0:
                continue
            truth_set = truth_sets[i]
            hits = 0
            dcg = 0
            idcg = 0
            # The ideal ranking places min(len(truth), k) relevant items first.
            ideal_left = len(truth)
            for j in range(k):
                item = predictedIndices[i][j]
                key = log_item.get(item, 0.5)
                popularity[key] = popularity.get(key, 0) + 1
                gain = 1.0 / math.log2(j + 2)
                if item in truth_set:
                    dcg += gain
                    hits += 1
                if ideal_left > 0:
                    idcg += gain
                    ideal_left -= 1
            ndcg = dcg / idcg if idcg != 0 else 0
            sum_precision += hits / k
            sum_recall += hits / len(truth)
            sum_ndcg += ndcg

        if num_users:
            precision.append(round(sum_precision / num_users, 4))
            recall.append(round(sum_recall / num_users, 4))
            NDCG.append(round(sum_ndcg / num_users, 4))
        else:
            precision.append(0.0)
            recall.append(0.0)
            NDCG.append(0.0)
        test_popularity.append(popularity)

    return precision, recall, NDCG, test_popularity