import numpy as np
import tensorflow.compat.v1 as tf
import math
import os
import time
import matplotlib.pyplot as plt
import sys
from sklearn import metrics


def Split_candidate_ranking(sess, model, mask, test_set, all_item, topN, log_item, mask_valid, isTest=True):
    """Rank every candidate item for each test record and score the top-N lists.

    Records are fed to ``model.run_evaluate_user`` in batches of 512. Before
    ranking, an additive penalty of -9999 is applied to the items listed in
    ``mask[user_id]``. When ``isTest`` is true and ``mask_valid`` is non-empty,
    the penalty is also applied to the items in ``mask_valid[user_id]``. This
    pushes those items to the bottom of the ranking.

    Args:
        sess: passed through unchanged to ``model.run_evaluate_user``.
        model: object exposing ``run_evaluate_user(sess, user, hist, sl)``.
            The first element of its return value is added to a
            (batch, len(all_item)) penalty matrix. It is presumably a
            (batch, num_items) NumPy score array (TODO confirm).
        mask: indexable by user id; each entry is an iterable of item indices
            to penalise.
        test_set: indexable sequence of records. ``record[0]`` is the user id,
            ``record[1]`` is the history item sequence, and ``record[3]`` is the
            ground-truth item collection.
        all_item: only ``len(all_item)`` is used, as the number of candidates.
        topN: list of cutoffs. ``topN[-1]`` is the ranking depth, so it is
            presumably the largest cutoff.
        log_item: forwarded unchanged to ``Split_computeTopNAccuracy``.
        mask_valid: like ``mask``, applied only when ``isTest`` is true. An
            empty container disables it.
        isTest: whether to also apply ``mask_valid``.

    Returns:
        The ``(precision, recall, NDCG)`` tuple from
        ``Split_computeTopNAccuracy``.
    """
    def partition_arg_topK(matrix, K, axis=0):
        """Return the indices of the K smallest entries along ``axis``, in ascending order."""
        # np.argpartition requires kth < size along the axis. Clamping lets K
        # equal the number of entries and leaves the result unchanged otherwise.
        kth = min(K, matrix.shape[axis] - 1)
        a_part = np.argpartition(matrix, kth, axis=axis)
        if axis == 0:
            col_index = np.arange(matrix.shape[1])
            order = np.argsort(matrix[a_part[0:K, :], col_index], axis=0)
            return a_part[0:K, :][order, col_index]
        row_index = np.arange(matrix.shape[0])[:, None]
        order = np.argsort(matrix[row_index, a_part[:, 0:K]], axis=1)
        return a_part[:, 0:K][row_index, order]

    batch_size = 512
    num_of_data = len(test_set)
    num_items = len(all_item)
    # Loop-invariant: the validation mask is only used at test time and only
    # when it was actually provided.
    apply_valid_mask = isTest and len(mask_valid) != 0

    groundTruth = [test_set[u][3] for u in range(num_of_data)]
    user_pred = []

    # Full batches plus a possibly shorter final batch. An always-run tail batch
    # would call max() on an empty list when num_of_data % batch_size == 0.
    for start in range(0, num_of_data, batch_size):
        end = min(start + batch_size, num_of_data)
        batch = [test_set[j] for j in range(start, end)]

        user = [record[0] for record in batch]
        hist_t = [record[1] for record in batch]
        max_len = max(len(h) for h in hist_t)
        # Every entry of sl is the longest history length in this batch.
        sl = [max_len] * len(batch)
        # Right-pad each history with item id 0 up to the batch's longest history.
        hist = [list(h) + [0] * (max_len - len(h)) for h in hist_t]

        # Additive penalty matrix: excluded items get -9999 so they rank last.
        inTrainSet = [[0] * num_items for _ in batch]
        for row, user_id in enumerate(user):
            for item in mask[user_id]:
                inTrainSet[row][item] = -9999
            if apply_valid_mask:
                for item in mask_valid[user_id]:
                    inTrainSet[row][item] = -9999

        prediction = model.run_evaluate_user(sess, user, hist, sl)[0] + inTrainSet
        # Negate so that the highest scores become the smallest values.
        result = partition_arg_topK(-prediction, topN[-1], axis=1)
        user_pred.extend(result.tolist())

    precision, recall, NDCG = Split_computeTopNAccuracy(groundTruth, user_pred, topN, log_item)
    return precision, recall, NDCG




def Split_computeTopNAccuracy(GroundTruth, predictedIndices, topN, log_item):
    """Compute per-user precision, recall and NDCG for each cutoff in ``topN``.

    Users whose ground-truth collection is empty are skipped. The three
    returned lists are flat. For each cutoff, in ``topN`` order, they hold one
    value per non-skipped user, in ``predictedIndices`` order. The cutoffs are
    therefore concatenated rather than averaged.

    Args:
        GroundTruth: ``GroundTruth[i]`` is the collection of relevant items for
            user ``i``.
        predictedIndices: ``predictedIndices[i]`` is user ``i``'s ranked item
            list, which must hold at least ``max(topN)`` entries.
        topN: list of cutoff depths.
        log_item: unused; kept for interface compatibility.

    Returns:
        ``(precision, recall, NDCG)`` as described above.
    """
    precision, recall, NDCG = [], [], []

    for cutoff in topN:
        for i, ranked in enumerate(predictedIndices):
            truth = GroundTruth[i]
            n_relevant = len(truth)
            if n_relevant == 0:
                continue

            hits = 0
            dcg = 0
            idcg = 0
            for rank in range(cutoff):
                discount = 1.0 / math.log2(rank + 2)
                if ranked[rank] in truth:
                    dcg += discount
                    hits += 1
                # The ideal ranking places every relevant item first, so only
                # the first n_relevant positions contribute to IDCG.
                if rank < n_relevant:
                    idcg += discount

            ndcg = 0 + dcg / idcg if idcg != 0 else 0
            precision.append(hits / cutoff)
            recall.append(hits / n_relevant)
            NDCG.append(ndcg)

    return precision, recall, NDCG
